|
1 | 1 | use super::{RadixCiphertext, ServerKey, SignedRadixCiphertext}; |
2 | 2 | use crate::core_crypto::commons::generators::DeterministicSeeder; |
3 | 3 | use crate::core_crypto::prelude::DefaultRandomGenerator; |
| 4 | +use crate::shortint::MessageModulus; |
4 | 5 | use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator}; |
5 | 6 |
|
6 | 7 | pub use tfhe_csprng::seeders::{Seed, Seeder}; |
@@ -229,6 +230,116 @@ impl ServerKey { |
229 | 230 |
|
230 | 231 | result |
231 | 232 | } |
| 233 | + |
| 234 | + /// Generates an encrypted `num_blocks_output` blocks unsigned integer |
| 235 | + /// taken almost uniformly in `[0, excluded_upper_bound[` using the given seed. |
| 236 | + /// The encrypted value is oblivious to the server. |
| 237 | + /// It can be useful to make server random generation deterministic. |
| 238 | + /// |
| 239 | + /// The norm-1 distance (defined as ∆(, ) := 1/2 Sum[ω∈Ω] |P(ω) − Q(ω)| between the actual distribution and the target uniform distribution is below the `max_distance` argument. |
| 240 | + /// |
| 241 | + /// A safe value for `max_distance` is `2^-128`. It is the default value if None is provided. |
| 242 | + /// |
| 243 | + /// Higher values allow better performance but must be considered carefully in the context of their target application |
| 244 | + /// as it may have serious unintended consequences. |
| 245 | + /// |
| 246 | + /// ```rust |
| 247 | + /// use tfhe::integer::gen_keys_radix; |
| 248 | + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128; |
| 249 | + /// use tfhe::Seed; |
| 250 | + /// |
| 251 | + /// let size = 4; |
| 252 | + /// |
| 253 | + /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128, size); |
| 254 | + /// |
| 255 | + /// let excluded_upper_bound = 3; |
| 256 | + /// let num_blocks_output = 8; |
| 257 | + /// |
| 258 | + /// let ct_res = sks.par_generate_oblivious_pseudo_random_unsigned_custom_range2(Seed(0), excluded_upper_bound,num_blocks_output, None); |
| 259 | + /// |
| 260 | + /// let dec_result: u64 = cks.decrypt(&ct_res); |
| 261 | + /// assert!(dec_result < excluded_upper_bound); |
| 262 | + /// ``` |
| 263 | + pub fn par_generate_oblivious_pseudo_random_unsigned_custom_range2( |
| 264 | + &self, |
| 265 | + seed: Seed, |
| 266 | + excluded_upper_bound: u64, |
| 267 | + num_blocks_output: u64, |
| 268 | + max_distance: Option<f64>, |
| 269 | + ) -> RadixCiphertext { |
| 270 | + let max_distance = max_distance.unwrap_or(2_f64.powi(-128)); |
| 271 | + |
| 272 | + let message_modulus = self.message_modulus(); |
| 273 | + |
| 274 | + let num_input_random_bits = num_input_random_bits_for_max_distance( |
| 275 | + excluded_upper_bound, |
| 276 | + max_distance, |
| 277 | + message_modulus, |
| 278 | + ); |
| 279 | + |
| 280 | + self.par_generate_oblivious_pseudo_random_unsigned_custom_range( |
| 281 | + seed, |
| 282 | + num_input_random_bits, |
| 283 | + excluded_upper_bound, |
| 284 | + num_blocks_output, |
| 285 | + ) |
| 286 | + } |
| 287 | +} |
| 288 | + |
| 289 | +fn num_input_random_bits_for_max_distance( |
| 290 | + excluded_upper_bound: u64, |
| 291 | + max_distance: f64, |
| 292 | + message_modulus: MessageModulus, |
| 293 | +) -> u64 { |
| 294 | + let log_message_modulus = message_modulus.0.ilog2() as u64; |
| 295 | + |
| 296 | + let mut random_block_count = 1; |
| 297 | + |
| 298 | + let random_block_count = loop { |
| 299 | + let random_bit_count = random_block_count * log_message_modulus; |
| 300 | + |
| 301 | + let remainder = mod_pow_2(random_bit_count, excluded_upper_bound) as f64; |
| 302 | + |
| 303 | + let distance = remainder * (excluded_upper_bound as f64 - remainder) |
| 304 | + / (2_f64.powi(random_bit_count as i32) * excluded_upper_bound as f64); |
| 305 | + |
| 306 | + if distance < max_distance { |
| 307 | + break random_block_count; |
| 308 | + } |
| 309 | + |
| 310 | + random_block_count += 1; |
| 311 | + }; |
| 312 | + |
| 313 | + random_block_count * log_message_modulus |
| 314 | +} |
| 315 | + |
| 316 | +// Computes 2^exponent % modulus |
| 317 | +fn mod_pow_2(exponent: u64, modulus: u64) -> u64 { |
| 318 | + if modulus == 1 { |
| 319 | + return 0; |
| 320 | + } |
| 321 | + |
| 322 | + let mut result: u128 = 1; |
| 323 | + let mut base: u128 = 2; // We are calculating 2^i |
| 324 | + |
| 325 | + // We cast exponent to u128 to match the loop, though u64 is fine |
| 326 | + let mut exp = exponent; |
| 327 | + let mod_val = modulus as u128; |
| 328 | + |
| 329 | + while exp > 0 { |
| 330 | + // If exponent is odd, multiply result with base |
| 331 | + if exp % 2 == 1 { |
| 332 | + result = (result * base) % mod_val; |
| 333 | + } |
| 334 | + |
| 335 | + // Square the base |
| 336 | + base = (base * base) % mod_val; |
| 337 | + |
| 338 | + // Divide exponent by 2 |
| 339 | + exp /= 2; |
| 340 | + } |
| 341 | + |
| 342 | + result as u64 |
232 | 343 | } |
233 | 344 |
|
234 | 345 | impl ServerKey { |
@@ -362,3 +473,101 @@ impl ServerKey { |
362 | 473 | SignedRadixCiphertext::from(blocks) |
363 | 474 | } |
364 | 475 | } |
| 476 | + |
| 477 | +#[cfg(test)] |
| 478 | +mod test { |
| 479 | + |
| 480 | + use super::*; |
| 481 | + use crate::integer::gen_keys_radix; |
| 482 | + use crate::shortint::oprf::test::test_uniformity; |
| 483 | + use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128; |
| 484 | + use crate::Seed; |
| 485 | + use num_bigint::BigUint; |
| 486 | + use rand::{thread_rng, Rng}; |
| 487 | + |
| 488 | + // Helper: The "Oracle" implementation using BigInt |
| 489 | + // This is slow but mathematically guaranteed to be correct. |
| 490 | + fn oracle_mod_pow_2(exponent: u64, modulus: u64) -> u64 { |
| 491 | + if modulus == 0 { |
| 492 | + panic!("div by 0"); |
| 493 | + } |
| 494 | + if modulus == 1 { |
| 495 | + return 0; |
| 496 | + } |
| 497 | + |
| 498 | + let base = BigUint::from(2u32); |
| 499 | + let exp = BigUint::from(exponent); |
| 500 | + let modu = BigUint::from(modulus); |
| 501 | + |
| 502 | + let res = base.modpow(&exp, &modu); |
| 503 | + res.iter_u64_digits().next().unwrap_or(0) |
| 504 | + } |
| 505 | + |
| 506 | + #[test] |
| 507 | + fn test_edge_cases() { |
| 508 | + // 2^0 % 10 = 1 |
| 509 | + assert_eq!(mod_pow_2(0, 10), 1, "Failed exponent 0"); |
| 510 | + |
| 511 | + // 2^10 % 1 = 0 |
| 512 | + assert_eq!(mod_pow_2(10, 1), 0, "Failed modulus 1"); |
| 513 | + |
| 514 | + // 2^1 % 10 = 2 |
| 515 | + assert_eq!(mod_pow_2(1, 10), 2, "Failed exponent 1"); |
| 516 | + |
| 517 | + // 2^3 % 5 = 8 % 5 = 3 |
| 518 | + assert_eq!(mod_pow_2(3, 5), 3, "Failed small calc"); |
| 519 | + } |
| 520 | + |
| 521 | + #[test] |
| 522 | + fn test_boundaries_and_overflow() { |
| 523 | + assert_eq!(mod_pow_2(2, u64::MAX), 4); |
| 524 | + |
| 525 | + assert_eq!(mod_pow_2(u64::MAX, 3), 2); |
| 526 | + |
| 527 | + assert_eq!(mod_pow_2(5, 32), 0); |
| 528 | + } |
| 529 | + |
| 530 | + #[test] |
| 531 | + fn test_fuzzing_against_oracle() { |
| 532 | + let mut rng = thread_rng(); |
| 533 | + for _ in 0..1_000_000 { |
| 534 | + let exp: u64 = rng.gen(); |
| 535 | + let mod_val: u64 = rng.gen(); |
| 536 | + |
| 537 | + let mod_val = if mod_val == 0 { 1 } else { mod_val }; |
| 538 | + |
| 539 | + let expected = oracle_mod_pow_2(exp, mod_val); |
| 540 | + let actual = mod_pow_2(exp, mod_val); |
| 541 | + |
| 542 | + assert_eq!( |
| 543 | + actual, expected, |
| 544 | + "Mismatch! 2^{} % {} => Ours: {}, Oracle: {}", |
| 545 | + exp, mod_val, actual, expected |
| 546 | + ); |
| 547 | + } |
| 548 | + } |
| 549 | + |
| 550 | + #[test] |
| 551 | + fn test_uniformity_par_generate_oblivious_pseudo_random_unsigned_custom_range2() { |
| 552 | + let num_blocks = 8; |
| 553 | + |
| 554 | + let sample_count: usize = 1_000; |
| 555 | + |
| 556 | + let p_value_limit: f64 = 0.000_1; |
| 557 | + |
| 558 | + let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128, num_blocks); |
| 559 | + |
| 560 | + let excluded_upper_bound = 3; |
| 561 | + |
| 562 | + test_uniformity(sample_count, p_value_limit, excluded_upper_bound, &|seed| { |
| 563 | + let img = sks.par_generate_oblivious_pseudo_random_unsigned_custom_range2( |
| 564 | + Seed(seed as u128), |
| 565 | + excluded_upper_bound, |
| 566 | + num_blocks as u64, |
| 567 | + Some(2_f64.powi(-32)), |
| 568 | + ); |
| 569 | + |
| 570 | + cks.decrypt(&img) |
| 571 | + }); |
| 572 | + } |
| 573 | +} |
0 commit comments