Skip to content

Commit 66a05cb

Browse files
committed
chore: use dedicated types for compressed modswitched conformance
1 parent 0136736 commit 66a05cb

File tree

24 files changed

+519
-114
lines changed

24 files changed

+519
-114
lines changed

tfhe/src/c_api/high_level_api/booleans.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ impl_destroy_on_type!(CompressedFheBool);
4646
impl_clone_on_type!(CompressedFheBool);
4747
impl_serialize_deserialize_on_type!(CompressedFheBool);
4848
impl_safe_serialize_on_type!(CompressedFheBool);
49-
impl_safe_deserialize_conformant_on_type!(CompressedFheBool, FheBoolConformanceParams);
49+
impl_safe_deserialize_conformant_on_type!(CompressedFheBool, CompressedFheBoolConformanceParams);
5050
impl_try_encrypt_with_client_key_on_type!(CompressedFheBool{crate::high_level_api::CompressedFheBool}, bool);
5151

5252
#[no_mangle]

tfhe/src/c_api/high_level_api/integers.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,7 @@ macro_rules! create_integer_wrapper_type {
359359

360360
impl_safe_serialize_on_type!([<Compressed $name>]);
361361

362-
impl_safe_deserialize_conformant_on_type!([<Compressed $name>], [<$name ConformanceParams>]);
362+
impl_safe_deserialize_conformant_on_type!([<Compressed $name>], [<Compressed $name ConformanceParams>]);
363363

364364
#[no_mangle]
365365
pub unsafe extern "C" fn [<compressed_ $name:snake _decompress>](

tfhe/src/core_crypto/entities/compressed_modulus_switched_lwe_ciphertext.rs

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -135,30 +135,55 @@ impl<PackingScalar: UnsignedInteger> CompressedModulusSwitchedLweCiphertext<Pack
135135
}
136136
}
137137

138+
#[derive(Copy, Clone)]
139+
pub enum MsDecompressionType {
140+
ClassicPbs,
141+
MultiBitPbs(LweBskGroupingFactor),
142+
}
143+
144+
#[derive(Copy, Clone)]
145+
pub struct CompressedModulusSwitchedLweCiphertextConformanceParams<Scalar>
146+
where
147+
Scalar: UnsignedInteger,
148+
{
149+
pub ct_params: LweCiphertextConformanceParams<Scalar>,
150+
pub ms_decompression_type: MsDecompressionType,
151+
}
152+
138153
impl<Scalar: UnsignedInteger> ParameterSetConformant
139154
for CompressedModulusSwitchedLweCiphertext<Scalar>
140155
{
141-
type ParameterSet = LweCiphertextConformanceParams<Scalar>;
156+
type ParameterSet = CompressedModulusSwitchedLweCiphertextConformanceParams<Scalar>;
142157

143-
fn is_conformant(&self, lwe_ct_parameters: &LweCiphertextConformanceParams<Scalar>) -> bool {
158+
fn is_conformant(
159+
&self,
160+
compressed_ct_parameters: &CompressedModulusSwitchedLweCiphertextConformanceParams<Scalar>,
161+
) -> bool {
144162
let Self {
145163
packed_integers,
146164
lwe_dimension,
147165
} = self;
148166

167+
let CompressedModulusSwitchedLweCiphertextConformanceParams {
168+
ct_params,
169+
ms_decompression_type,
170+
} = compressed_ct_parameters;
171+
172+
let LweCiphertextConformanceParams {
173+
lwe_dim: params_lwe_dim,
174+
ct_modulus,
175+
} = ct_params;
176+
149177
let lwe_size = lwe_dimension.to_lwe_size().0;
150178

151179
let number_bits_to_pack = lwe_size * packed_integers.log_modulus().0;
152180

153181
let len = number_bits_to_pack.div_ceil(Scalar::BITS);
154182

155183
packed_integers.packed_coeffs().len() == len
156-
&& *lwe_dimension == lwe_ct_parameters.lwe_dim
157-
&& lwe_ct_parameters.ct_modulus.is_power_of_two()
158-
&& matches!(
159-
lwe_ct_parameters.ms_decompression_method,
160-
MsDecompressionType::ClassicPbs
161-
)
184+
&& lwe_dimension == params_lwe_dim
185+
&& ct_modulus.is_power_of_two()
186+
&& matches!(ms_decompression_type, MsDecompressionType::ClassicPbs)
162187
}
163188
}
164189

tfhe/src/core_crypto/entities/compressed_modulus_switched_multi_bit_lwe_ciphertext.rs

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -511,9 +511,12 @@ impl MultiBitModulusSwitchedLweCiphertext for FromCompressionMultiBitModulusSwit
511511
impl<Scalar: UnsignedInteger + CastInto<usize> + CastFrom<usize>> ParameterSetConformant
512512
for CompressedModulusSwitchedMultiBitLweCiphertext<Scalar>
513513
{
514-
type ParameterSet = LweCiphertextConformanceParams<Scalar>;
514+
type ParameterSet = CompressedModulusSwitchedLweCiphertextConformanceParams<Scalar>;
515515

516-
fn is_conformant(&self, lwe_ct_parameters: &LweCiphertextConformanceParams<Scalar>) -> bool {
516+
fn is_conformant(
517+
&self,
518+
compressed_ct_parameters: &CompressedModulusSwitchedLweCiphertextConformanceParams<Scalar>,
519+
) -> bool {
517520
let Self {
518521
body,
519522
packed_mask,
@@ -523,22 +526,30 @@ impl<Scalar: UnsignedInteger + CastInto<usize> + CastFrom<usize>> ParameterSetCo
523526
grouping_factor,
524527
} = self;
525528

526-
let lwe_dim = lwe_dimension.0;
529+
let CompressedModulusSwitchedLweCiphertextConformanceParams {
530+
ct_params,
531+
ms_decompression_type,
532+
} = compressed_ct_parameters;
533+
534+
let LweCiphertextConformanceParams {
535+
lwe_dim: params_lwe_dim,
536+
ct_modulus,
537+
} = ct_params;
527538

528539
*body >> packed_mask.log_modulus().0 == Scalar::ZERO
529-
&& packed_mask.is_conformant(&lwe_dim)
540+
&& packed_mask.is_conformant(&lwe_dimension.0)
530541
&& packed_diffs
531542
.as_ref()
532-
.is_none_or(|packed_diffs| packed_diffs.is_conformant(&lwe_dim))
533-
&& *lwe_dimension == lwe_ct_parameters.lwe_dim
534-
&& lwe_ct_parameters.ct_modulus.is_power_of_two()
535-
&& match lwe_ct_parameters.ms_decompression_method {
543+
.is_none_or(|packed_diffs| packed_diffs.is_conformant(&lwe_dimension.0))
544+
&& lwe_dimension == params_lwe_dim
545+
&& ct_modulus.is_power_of_two()
546+
&& match ms_decompression_type {
536547
MsDecompressionType::ClassicPbs => false,
537548
MsDecompressionType::MultiBitPbs(expected_grouping_factor) => {
538549
expected_grouping_factor.0 == grouping_factor.0
539550
}
540551
}
541-
&& *uncompressed_ciphertext_modulus == lwe_ct_parameters.ct_modulus
552+
&& uncompressed_ciphertext_modulus == ct_modulus
542553
}
543554
}
544555

tfhe/src/core_crypto/entities/lwe_ciphertext.rs

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -753,13 +753,6 @@ pub type LweCiphertextMutView<'data, Scalar> = LweCiphertext<&'data mut [Scalar]
753753
pub struct LweCiphertextConformanceParams<T: UnsignedInteger> {
754754
pub lwe_dim: LweDimension,
755755
pub ct_modulus: CiphertextModulus<T>,
756-
pub ms_decompression_method: MsDecompressionType,
757-
}
758-
759-
#[derive(Copy, Clone)]
760-
pub enum MsDecompressionType {
761-
ClassicPbs,
762-
MultiBitPbs(LweBskGroupingFactor),
763756
}
764757

765758
impl<C: Container> ParameterSetConformant for LweCiphertext<C>
@@ -777,9 +770,14 @@ where
777770
ciphertext_modulus,
778771
} = self;
779772

780-
check_encrypted_content_respects_mod(data, lwe_ct_parameters.ct_modulus)
781-
&& self.lwe_size() == lwe_ct_parameters.lwe_dim.to_lwe_size()
782-
&& *ciphertext_modulus == lwe_ct_parameters.ct_modulus
773+
let LweCiphertextConformanceParams {
774+
lwe_dim,
775+
ct_modulus,
776+
} = lwe_ct_parameters;
777+
778+
check_encrypted_content_respects_mod(data, *ct_modulus)
779+
&& self.lwe_size() == lwe_dim.to_lwe_size()
780+
&& ciphertext_modulus == ct_modulus
783781
}
784782
}
785783

tfhe/src/high_level_api/booleans/compressed.rs

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,12 @@ use crate::high_level_api::traits::Tagged;
99
use crate::integer::BooleanBlock;
1010
use crate::named::Named;
1111
use crate::prelude::FheTryEncrypt;
12-
use crate::shortint::ciphertext::{CompressedModulusSwitchedCiphertext, Degree};
13-
use crate::shortint::CompressedCiphertext;
14-
use crate::{ClientKey, FheBool, FheBoolConformanceParams, Tag};
12+
use crate::shortint::ciphertext::{
13+
CompressedModulusSwitchedCiphertext, CompressedModulusSwitchedCiphertextConformanceParams,
14+
Degree,
15+
};
16+
use crate::shortint::{AtomicPatternParameters, CompressedCiphertext};
17+
use crate::{ClientKey, FheBool, ServerKey, Tag};
1518
use serde::{Deserialize, Serialize};
1619
use tfhe_versionable::Versionize;
1720

@@ -111,12 +114,35 @@ impl FheTryEncrypt<bool, ClientKey> for CompressedFheBool {
111114
}
112115
}
113116

117+
#[derive(Copy, Clone)]
118+
pub struct CompressedFheBoolConformanceParams(
119+
pub(crate) CompressedModulusSwitchedCiphertextConformanceParams,
120+
);
121+
122+
impl<P: Into<AtomicPatternParameters>> From<P> for CompressedFheBoolConformanceParams {
123+
fn from(params: P) -> Self {
124+
let params = params.into();
125+
Self(params.to_compressed_modswitched_conformance_param())
126+
}
127+
}
128+
129+
impl From<&ServerKey> for CompressedFheBoolConformanceParams {
130+
fn from(sk: &ServerKey) -> Self {
131+
Self(
132+
sk.key
133+
.pbs_key()
134+
.key
135+
.compressed_modswitched_conformance_params(),
136+
)
137+
}
138+
}
139+
114140
impl ParameterSetConformant for CompressedFheBool {
115-
type ParameterSet = FheBoolConformanceParams;
141+
type ParameterSet = CompressedFheBoolConformanceParams;
116142

117-
fn is_conformant(&self, params: &FheBoolConformanceParams) -> bool {
143+
fn is_conformant(&self, params: &CompressedFheBoolConformanceParams) -> bool {
118144
match &self.inner {
119-
InnerCompressedFheBool::Seeded(seeded) => seeded.is_conformant(&params.0),
145+
InnerCompressedFheBool::Seeded(seeded) => seeded.is_conformant(&params.0.into()),
120146
InnerCompressedFheBool::ModulusSwitched(ct) => ct.is_conformant(&params.0),
121147
}
122148
}

tfhe/src/high_level_api/booleans/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
pub use base::{FheBool, FheBoolConformanceParams};
2-
pub use compressed::CompressedFheBool;
2+
pub use compressed::{CompressedFheBool, CompressedFheBoolConformanceParams};
33
pub use squashed_noise::SquashedNoiseFheBool;
44

55
pub(in crate::high_level_api) use compressed::InnerCompressedFheBool;

tfhe/src/high_level_api/booleans/tests.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,7 @@ fn compressed_bool_test_case(setup_fn: impl FnOnce() -> (ClientKey, Device)) {
318318

319319
mod cpu {
320320
use super::*;
321+
use crate::high_level_api::booleans::compressed::CompressedFheBoolConformanceParams;
321322
use crate::safe_serialization::{DeserializationConfig, SerializationConfig};
322323
use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
323324
use crate::FheBoolConformanceParams;
@@ -707,12 +708,14 @@ mod cpu {
707708
.serialize_into(&a, &mut serialized)
708709
.unwrap();
709710

710-
let params = FheBoolConformanceParams::from(&server_key);
711+
let params = CompressedFheBoolConformanceParams::from(&server_key);
711712
let deserialized_a = DeserializationConfig::new(1 << 20)
712713
.deserialize_from::<CompressedFheBool>(serialized.as_slice(), &params)
713714
.unwrap();
714715

715-
assert!(deserialized_a.is_conformant(&FheBoolConformanceParams::from(block_params)));
716+
assert!(
717+
deserialized_a.is_conformant(&CompressedFheBoolConformanceParams::from(block_params))
718+
);
716719

717720
let decrypted: bool = deserialized_a.decompress().decrypt(&keys);
718721
assert_eq!(decrypted, clear_a);

tfhe/src/high_level_api/integers/signed/compressed.rs

Lines changed: 57 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::marker::PhantomData;
2+
13
use tfhe_versionable::Versionize;
24

35
use crate::backward_compatibility::integers::{
@@ -6,20 +8,20 @@ use crate::backward_compatibility::integers::{
68
use crate::conformance::ParameterSetConformant;
79
use crate::core_crypto::prelude::SignedNumeric;
810
use crate::high_level_api::global_state;
9-
use crate::high_level_api::integers::signed::base::FheIntConformanceParams;
1011
use crate::high_level_api::integers::{FheInt, FheIntId};
1112
use crate::high_level_api::keys::InternalServerKey;
1213
use crate::high_level_api::re_randomization::ReRandomizationMetadata;
1314
use crate::high_level_api::traits::Tagged;
1415
use crate::integer::block_decomposition::DecomposableInto;
1516
use crate::integer::ciphertext::{
17+
CompressedModulusSwitchedRadixCiphertextConformanceParams,
1618
CompressedModulusSwitchedSignedRadixCiphertext,
1719
CompressedSignedRadixCiphertext as IntegerCompressedSignedRadixCiphertext,
1820
};
19-
use crate::integer::parameters::RadixCiphertextConformanceParams;
2021
use crate::named::Named;
2122
use crate::prelude::FheTryEncrypt;
22-
use crate::{ClientKey, Tag};
23+
use crate::shortint::AtomicPatternParameters;
24+
use crate::{ClientKey, ServerKey, Tag};
2325

2426
/// Compressed [FheInt]
2527
///
@@ -153,10 +155,51 @@ where
153155
}
154156
}
155157

158+
#[derive(Copy, Clone)]
159+
pub struct CompressedFheIntConformanceParams<Id: FheIntId> {
160+
pub(crate) params: CompressedSignedRadixCiphertextConformanceParams,
161+
pub(crate) id: PhantomData<Id>,
162+
}
163+
164+
impl<Id: FheIntId, P: Into<AtomicPatternParameters>> From<P>
165+
for CompressedFheIntConformanceParams<Id>
166+
{
167+
fn from(params: P) -> Self {
168+
let params = params.into();
169+
Self {
170+
params: CompressedSignedRadixCiphertextConformanceParams(
171+
CompressedModulusSwitchedRadixCiphertextConformanceParams {
172+
shortint_params: params.to_compressed_modswitched_conformance_param(),
173+
num_blocks_per_integer: Id::num_blocks(params.message_modulus()),
174+
},
175+
),
176+
id: PhantomData,
177+
}
178+
}
179+
}
180+
181+
impl<Id: FheIntId> From<&ServerKey> for CompressedFheIntConformanceParams<Id> {
182+
fn from(sk: &ServerKey) -> Self {
183+
Self {
184+
params: CompressedSignedRadixCiphertextConformanceParams(
185+
CompressedModulusSwitchedRadixCiphertextConformanceParams {
186+
shortint_params: sk
187+
.key
188+
.pbs_key()
189+
.key
190+
.compressed_modswitched_conformance_params(),
191+
num_blocks_per_integer: Id::num_blocks(sk.key.pbs_key().message_modulus()),
192+
},
193+
),
194+
id: PhantomData,
195+
}
196+
}
197+
}
198+
156199
impl<Id: FheIntId> ParameterSetConformant for CompressedFheInt<Id> {
157-
type ParameterSet = FheIntConformanceParams<Id>;
200+
type ParameterSet = CompressedFheIntConformanceParams<Id>;
158201

159-
fn is_conformant(&self, params: &FheIntConformanceParams<Id>) -> bool {
202+
fn is_conformant(&self, params: &CompressedFheIntConformanceParams<Id>) -> bool {
160203
let Self {
161204
ciphertext,
162205
id: _,
@@ -178,12 +221,17 @@ pub enum CompressedSignedRadixCiphertext {
178221
ModulusSwitched(CompressedModulusSwitchedSignedRadixCiphertext),
179222
}
180223

224+
#[derive(Copy, Clone)]
225+
pub struct CompressedSignedRadixCiphertextConformanceParams(
226+
pub(crate) CompressedModulusSwitchedRadixCiphertextConformanceParams,
227+
);
228+
181229
impl ParameterSetConformant for CompressedSignedRadixCiphertext {
182-
type ParameterSet = RadixCiphertextConformanceParams;
183-
fn is_conformant(&self, params: &RadixCiphertextConformanceParams) -> bool {
230+
type ParameterSet = CompressedSignedRadixCiphertextConformanceParams;
231+
fn is_conformant(&self, params: &CompressedSignedRadixCiphertextConformanceParams) -> bool {
184232
match self {
185-
Self::Seeded(ct) => ct.is_conformant(params),
186-
Self::ModulusSwitched(ct) => ct.is_conformant(params),
233+
Self::Seeded(ct) => ct.is_conformant(&params.0.into()),
234+
Self::ModulusSwitched(ct) => ct.is_conformant(&params.0),
187235
}
188236
}
189237
}

tfhe/src/high_level_api/integers/signed/static_.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
use crate::high_level_api::integers::signed::base::{FheInt, FheIntConformanceParams, FheIntId};
2-
use crate::high_level_api::integers::signed::compressed::CompressedFheInt;
2+
use crate::high_level_api::integers::signed::compressed::{
3+
CompressedFheInt, CompressedFheIntConformanceParams,
4+
};
35
use crate::high_level_api::{FheId, IntegerId};
46
use serde::{Deserialize, Serialize};
57
use tfhe_versionable::NotVersioned;
@@ -52,6 +54,7 @@ macro_rules! static_int_type {
5254

5355
// Conformance Params
5456
pub type [<FheInt $num_bits ConformanceParams>] = FheIntConformanceParams<[<FheInt $num_bits Id>]>;
57+
pub type [<CompressedFheInt $num_bits ConformanceParams>] = CompressedFheIntConformanceParams<[<FheInt $num_bits Id>]>;
5558
}
5659
};
5760
}

0 commit comments

Comments
 (0)