Skip to content

Commit ca73670

Browse files
committed
add oprf over any range
1 parent 0652e85 commit ca73670

File tree

3 files changed

+270
-0
lines changed

3 files changed

+270
-0
lines changed

tfhe/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ rand_distr = "0.4.3"
2727
criterion = "0.5.1"
2828
doc-comment = "0.3.3"
2929
serde_json = "1.0.94"
30+
num-bigint = "0.4.6"
3031
# clap has to be pinned as its minimum supported rust version
3132
# changes often between minor releases, which breaks our CI
3233
clap = { version = "=4.5.30", features = ["derive"] }

tfhe/src/high_level_api/integers/oprf.rs

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,66 @@ impl<Id: FheUintId> FheUint<Id> {
150150
}
151151
})
152152
}
153+
154+
/// Generates an encrypted unsigned integer
155+
/// taken almost uniformly in `[0, excluded_upper_bound[` using the given seed.
156+
/// The encrypted value is oblivious to the server.
157+
/// It can be useful to make server random generation deterministic.
158+
///
159+
/// The norm-1 distance (defined as ∆(, ) := 1/2 Sum[ω∈Ω] |P(ω) − Q(ω)| between the actual distribution and the target uniform distribution is below the `max_distance` argument.
160+
///
161+
/// A safe value for `max_distance` is `2^-128`. It is the default value if None is provided.
162+
///
163+
/// Higher values allow better performance but must be considered carefully in the context of their target application
164+
/// as it may have serious unintended consequences.
165+
///
166+
/// ```rust
167+
/// use tfhe::prelude::FheDecrypt;
168+
/// use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheUint8, Seed};
169+
///
170+
/// let config = ConfigBuilder::default().build();
171+
/// let (client_key, server_key) = generate_keys(config);
172+
///
173+
/// set_server_key(server_key);
174+
///
175+
/// let excluded_upper_bound = 3;
176+
///
177+
/// let ct_res = FheUint8::generate_oblivious_pseudo_random_custom_range(Seed(0), excluded_upper_bound, None);
178+
///
179+
/// let dec_result: u16 = ct_res.decrypt(&client_key);
180+
/// assert!(dec_result < excluded_upper_bound as u16);
181+
/// ```
182+
pub fn generate_oblivious_pseudo_random_custom_range(
183+
seed: Seed,
184+
excluded_upper_bound: u64,
185+
max_distance: Option<f64>,
186+
) -> Self {
187+
global_state::with_internal_keys(|key| match key {
188+
InternalServerKey::Cpu(key) => {
189+
let num_blocks_output = Id::num_blocks(key.message_modulus()) as u64;
190+
191+
let ct = key
192+
.pbs_key()
193+
.par_generate_oblivious_pseudo_random_unsigned_custom_range2(
194+
seed,
195+
excluded_upper_bound,
196+
num_blocks_output,
197+
max_distance,
198+
);
199+
200+
Self::new(ct, key.tag.clone(), ReRandomizationMetadata::default())
201+
}
202+
#[cfg(feature = "gpu")]
203+
InternalServerKey::Cuda(cuda_key) => {
204+
panic!("Gpu does not support this operation yet.")
205+
}
206+
#[cfg(feature = "hpu")]
207+
InternalServerKey::Hpu(_device) => {
208+
panic!("Hpu does not support this operation yet.")
209+
}
210+
})
211+
}
212+
153213
#[cfg(feature = "gpu")]
154214
/// Returns the amount of memory required to execute generate_oblivious_pseudo_random_bounded
155215
///

tfhe/src/integer/oprf.rs

Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use super::{RadixCiphertext, ServerKey, SignedRadixCiphertext};
22
use crate::core_crypto::commons::generators::DeterministicSeeder;
33
use crate::core_crypto::prelude::DefaultRandomGenerator;
4+
use crate::shortint::MessageModulus;
45
use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
56

67
pub use tfhe_csprng::seeders::{Seed, Seeder};
@@ -229,6 +230,116 @@ impl ServerKey {
229230

230231
result
231232
}
233+
234+
/// Generates an encrypted `num_blocks_output` blocks unsigned integer
235+
/// taken almost uniformly in `[0, excluded_upper_bound[` using the given seed.
236+
/// The encrypted value is oblivious to the server.
237+
/// It can be useful to make server random generation deterministic.
238+
///
239+
/// The norm-1 distance (defined as ∆(, ) := 1/2 Sum[ω∈Ω] |P(ω) − Q(ω)| between the actual distribution and the target uniform distribution is below the `max_distance` argument.
240+
///
241+
/// A safe value for `max_distance` is `2^-128`. It is the default value if None is provided.
242+
///
243+
/// Higher values allow better performance but must be considered carefully in the context of their target application
244+
/// as it may have serious unintended consequences.
245+
///
246+
/// ```rust
247+
/// use tfhe::integer::gen_keys_radix;
248+
/// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128;
249+
/// use tfhe::Seed;
250+
///
251+
/// let size = 4;
252+
///
253+
/// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128, size);
254+
///
255+
/// let excluded_upper_bound = 3;
256+
/// let num_blocks_output = 8;
257+
///
258+
/// let ct_res = sks.par_generate_oblivious_pseudo_random_unsigned_custom_range2(Seed(0), excluded_upper_bound,num_blocks_output, None);
259+
///
260+
/// let dec_result: u64 = cks.decrypt(&ct_res);
261+
/// assert!(dec_result < excluded_upper_bound);
262+
/// ```
263+
pub fn par_generate_oblivious_pseudo_random_unsigned_custom_range2(
264+
&self,
265+
seed: Seed,
266+
excluded_upper_bound: u64,
267+
num_blocks_output: u64,
268+
max_distance: Option<f64>,
269+
) -> RadixCiphertext {
270+
let max_distance = max_distance.unwrap_or(2_f64.powi(-128));
271+
272+
let message_modulus = self.message_modulus();
273+
274+
let num_input_random_bits = num_input_random_bits_for_max_distance(
275+
excluded_upper_bound,
276+
max_distance,
277+
message_modulus,
278+
);
279+
280+
self.par_generate_oblivious_pseudo_random_unsigned_custom_range(
281+
seed,
282+
num_input_random_bits,
283+
excluded_upper_bound,
284+
num_blocks_output,
285+
)
286+
}
287+
}
288+
289+
fn num_input_random_bits_for_max_distance(
290+
excluded_upper_bound: u64,
291+
max_distance: f64,
292+
message_modulus: MessageModulus,
293+
) -> u64 {
294+
let log_message_modulus = message_modulus.0.ilog2() as u64;
295+
296+
let mut random_block_count = 1;
297+
298+
let random_block_count = loop {
299+
let random_bit_count = random_block_count * log_message_modulus;
300+
301+
let remainder = mod_pow_2(random_bit_count, excluded_upper_bound) as f64;
302+
303+
let distance = remainder * (excluded_upper_bound as f64 - remainder)
304+
/ (2_f64.powi(random_bit_count as i32) * excluded_upper_bound as f64);
305+
306+
if distance < max_distance {
307+
break random_block_count;
308+
}
309+
310+
random_block_count += 1;
311+
};
312+
313+
random_block_count * log_message_modulus
314+
}
315+
316+
// Computes 2^exponent % modulus
317+
fn mod_pow_2(exponent: u64, modulus: u64) -> u64 {
318+
if modulus == 1 {
319+
return 0;
320+
}
321+
322+
let mut result: u128 = 1;
323+
let mut base: u128 = 2; // We are calculating 2^i
324+
325+
// We cast exponent to u128 to match the loop, though u64 is fine
326+
let mut exp = exponent;
327+
let mod_val = modulus as u128;
328+
329+
while exp > 0 {
330+
// If exponent is odd, multiply result with base
331+
if exp % 2 == 1 {
332+
result = (result * base) % mod_val;
333+
}
334+
335+
// Square the base
336+
base = (base * base) % mod_val;
337+
338+
// Divide exponent by 2
339+
exp /= 2;
340+
}
341+
342+
result as u64
232343
}
233344

234345
impl ServerKey {
@@ -362,3 +473,101 @@ impl ServerKey {
362473
SignedRadixCiphertext::from(blocks)
363474
}
364475
}
476+
477+
#[cfg(test)]
478+
mod test {
479+
480+
use super::*;
481+
use crate::integer::gen_keys_radix;
482+
use crate::shortint::oprf::test::test_uniformity;
483+
use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128;
484+
use crate::Seed;
485+
use num_bigint::BigUint;
486+
use rand::{thread_rng, Rng};
487+
488+
// Helper: The "Oracle" implementation using BigInt
489+
// This is slow but mathematically guaranteed to be correct.
490+
fn oracle_mod_pow_2(exponent: u64, modulus: u64) -> u64 {
491+
if modulus == 0 {
492+
panic!("div by 0");
493+
}
494+
if modulus == 1 {
495+
return 0;
496+
}
497+
498+
let base = BigUint::from(2u32);
499+
let exp = BigUint::from(exponent);
500+
let modu = BigUint::from(modulus);
501+
502+
let res = base.modpow(&exp, &modu);
503+
res.iter_u64_digits().next().unwrap_or(0)
504+
}
505+
506+
#[test]
507+
fn test_edge_cases() {
508+
// 2^0 % 10 = 1
509+
assert_eq!(mod_pow_2(0, 10), 1, "Failed exponent 0");
510+
511+
// 2^10 % 1 = 0
512+
assert_eq!(mod_pow_2(10, 1), 0, "Failed modulus 1");
513+
514+
// 2^1 % 10 = 2
515+
assert_eq!(mod_pow_2(1, 10), 2, "Failed exponent 1");
516+
517+
// 2^3 % 5 = 8 % 5 = 3
518+
assert_eq!(mod_pow_2(3, 5), 3, "Failed small calc");
519+
}
520+
521+
#[test]
522+
fn test_boundaries_and_overflow() {
523+
assert_eq!(mod_pow_2(2, u64::MAX), 4);
524+
525+
assert_eq!(mod_pow_2(u64::MAX, 3), 2);
526+
527+
assert_eq!(mod_pow_2(5, 32), 0);
528+
}
529+
530+
#[test]
531+
fn test_fuzzing_against_oracle() {
532+
let mut rng = thread_rng();
533+
for _ in 0..1_000_000 {
534+
let exp: u64 = rng.gen();
535+
let mod_val: u64 = rng.gen();
536+
537+
let mod_val = if mod_val == 0 { 1 } else { mod_val };
538+
539+
let expected = oracle_mod_pow_2(exp, mod_val);
540+
let actual = mod_pow_2(exp, mod_val);
541+
542+
assert_eq!(
543+
actual, expected,
544+
"Mismatch! 2^{} % {} => Ours: {}, Oracle: {}",
545+
exp, mod_val, actual, expected
546+
);
547+
}
548+
}
549+
550+
#[test]
551+
fn test_uniformity_par_generate_oblivious_pseudo_random_unsigned_custom_range2() {
552+
let num_blocks = 8;
553+
554+
let sample_count: usize = 1_000;
555+
556+
let p_value_limit: f64 = 0.000_1;
557+
558+
let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128, num_blocks);
559+
560+
let excluded_upper_bound = 3;
561+
562+
test_uniformity(sample_count, p_value_limit, excluded_upper_bound, &|seed| {
563+
let img = sks.par_generate_oblivious_pseudo_random_unsigned_custom_range2(
564+
Seed(seed as u128),
565+
excluded_upper_bound,
566+
num_blocks as u64,
567+
Some(2_f64.powi(-32)),
568+
);
569+
570+
cks.decrypt(&img)
571+
});
572+
}
573+
}

0 commit comments

Comments
 (0)