Skip to content

Commit 71eaf01

Browse files
committed
fix(gpu): fix doc tests
1 parent 174adbb commit 71eaf01

File tree

3 files changed

+130
-19
lines changed

3 files changed

+130
-19
lines changed

tfhe/src/high_level_api/booleans/base.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2467,10 +2467,13 @@ impl BitNotSizeOnGpu for FheBool {
24672467
global_state::with_internal_keys(|key| {
24682468
if let InternalServerKey::Cuda(cuda_key) = key {
24692469
let streams = &cuda_key.streams;
2470+
let inner_block = self.ciphertext.on_gpu(streams).as_ref().duplicate(streams);
2471+
let boolean_block = CudaBooleanBlock::from_cuda_radix_ciphertext(inner_block);
2472+
24702473
cuda_key
24712474
.key
24722475
.key
2473-
.get_bitnot_size_on_gpu(&*self.ciphertext.on_gpu(streams), streams)
2476+
.get_boolean_bitnot_size_on_gpu(&boolean_block, streams)
24742477
} else {
24752478
0
24762479
}

tfhe/src/integer/gpu/mod.rs

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1627,6 +1627,56 @@ pub(crate) fn cuda_backend_get_boolean_bitop_size_on_gpu(
16271627
size_tracker
16281628
}
16291629

1630+
#[allow(clippy::too_many_arguments)]
1631+
pub(crate) fn cuda_backend_get_boolean_bitnot_size_on_gpu(
1632+
streams: &CudaStreams,
1633+
message_modulus: MessageModulus,
1634+
carry_modulus: CarryModulus,
1635+
glwe_dimension: GlweDimension,
1636+
polynomial_size: PolynomialSize,
1637+
big_lwe_dimension: LweDimension,
1638+
small_lwe_dimension: LweDimension,
1639+
ks_level: DecompositionLevelCount,
1640+
ks_base_log: DecompositionBaseLog,
1641+
pbs_level: DecompositionLevelCount,
1642+
pbs_base_log: DecompositionBaseLog,
1643+
is_unchecked: bool,
1644+
num_blocks: u32,
1645+
pbs_type: PBSType,
1646+
grouping_factor: LweBskGroupingFactor,
1647+
ms_noise_reduction_configuration: Option<&CudaModulusSwitchNoiseReductionConfiguration>,
1648+
) -> u64 {
1649+
let noise_reduction_type = resolve_noise_reduction_type(ms_noise_reduction_configuration);
1650+
1651+
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
1652+
let size_tracker = unsafe {
1653+
scratch_cuda_boolean_bitnot_64(
1654+
streams.ffi(),
1655+
std::ptr::addr_of_mut!(mem_ptr),
1656+
glwe_dimension.0 as u32,
1657+
polynomial_size.0 as u32,
1658+
big_lwe_dimension.0 as u32,
1659+
small_lwe_dimension.0 as u32,
1660+
ks_level.0 as u32,
1661+
ks_base_log.0 as u32,
1662+
pbs_level.0 as u32,
1663+
pbs_base_log.0 as u32,
1664+
grouping_factor.0 as u32,
1665+
message_modulus.0 as u32,
1666+
carry_modulus.0 as u32,
1667+
pbs_type as u32,
1668+
num_blocks,
1669+
is_unchecked,
1670+
false,
1671+
noise_reduction_type as u32,
1672+
)
1673+
};
1674+
unsafe {
1675+
cleanup_cuda_boolean_bitnot(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr));
1676+
}
1677+
size_tracker
1678+
}
1679+
16301680
#[allow(clippy::too_many_arguments)]
16311681
pub(crate) fn cuda_backend_get_bitop_size_on_gpu(
16321682
streams: &CudaStreams,

tfhe/src/integer/gpu/server_key/radix/bitwise_op.rs

Lines changed: 76 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@ use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaRadixCiphe
55
use crate::integer::gpu::server_key::CudaBootstrappingKey;
66
use crate::integer::gpu::{
77
cuda_backend_boolean_bitnot_assign, cuda_backend_boolean_bitop_assign,
8-
cuda_backend_get_bitop_size_on_gpu, cuda_backend_get_boolean_bitop_size_on_gpu,
9-
cuda_backend_get_full_propagate_assign_size_on_gpu, cuda_backend_unchecked_bitnot_assign,
10-
cuda_backend_unchecked_bitop_assign, BitOpType, CudaServerKey, PBSType,
8+
cuda_backend_get_bitop_size_on_gpu, cuda_backend_get_boolean_bitnot_size_on_gpu,
9+
cuda_backend_get_boolean_bitop_size_on_gpu, cuda_backend_get_full_propagate_assign_size_on_gpu,
10+
cuda_backend_unchecked_bitnot_assign, cuda_backend_unchecked_bitop_assign, BitOpType,
11+
CudaServerKey, PBSType,
1112
};
1213

1314
impl CudaServerKey {
@@ -87,7 +88,7 @@ impl CudaServerKey {
8788
/// ```rust
8889
/// use tfhe::core_crypto::gpu::CudaStreams;
8990
/// use tfhe::core_crypto::gpu::vec::GpuIndex;
90-
/// use tfhe::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext;
91+
/// use tfhe::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
9192
/// use tfhe::integer::gpu::gen_keys_radix_gpu;
9293
/// use tfhe::shortint::parameters::PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
9394
///
@@ -102,13 +103,13 @@ impl CudaServerKey {
102103
/// let ct = cks.encrypt_bool(msg);
103104
///
104105
/// // Copy to GPU
105-
/// let d_ct = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct, &streams);
106+
/// let d_ct = CudaBooleanBlock::from_boolean_block(&ct, &streams);
106107
///
107108
/// // Compute homomorphically a bitwise not:
108109
/// let d_ct_res = sks.boolean_bitnot(&d_ct, &streams);
109110
///
110111
/// // Copy back to CPU
111-
/// let ct_res = d_ct_res.to_radix_ciphertext(&streams);
112+
/// let ct_res = CudaBooleanBlock::to_boolean_block(&d_ct_res, &streams);
112113
///
113114
/// // Decrypt:
114115
/// let dec: bool = cks.decrypt_bool(&ct_res);
@@ -134,7 +135,7 @@ impl CudaServerKey {
134135
/// ```rust
135136
/// use tfhe::core_crypto::gpu::CudaStreams;
136137
/// use tfhe::core_crypto::gpu::vec::GpuIndex;
137-
/// use tfhe::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext;
138+
/// use tfhe::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
138139
/// use tfhe::integer::gpu::gen_keys_radix_gpu;
139140
/// use tfhe::shortint::parameters::PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
140141
///
@@ -151,14 +152,14 @@ impl CudaServerKey {
151152
/// let ct2 = cks.encrypt_bool(msg2);
152153
///
153154
/// // Copy to GPU
154-
/// let d_ct1 = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct1, &streams);
155-
/// let d_ct2 = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct2, &streams);
155+
/// let d_ct1 = CudaBooleanBlock::from_boolean_block(&ct1, &streams);
156+
/// let d_ct2 = CudaBooleanBlock::from_boolean_block(&ct2, &streams);
156157
///
157158
/// // Compute homomorphically a bitwise and:
158159
/// let d_ct_res = sks.boolean_bitand(&d_ct1, &d_ct2, &streams);
159160
///
160161
/// // Copy back to CPU
161-
/// let ct_res = d_ct_res.to_radix_ciphertext(&streams);
162+
/// let ct_res = CudaBooleanBlock::to_boolean_block(&d_ct_res, &streams);
162163
/// let expected = msg1 & msg2;
163164
///
164165
/// // Decrypt:
@@ -195,7 +196,7 @@ impl CudaServerKey {
195196
/// ```rust
196197
/// use tfhe::core_crypto::gpu::CudaStreams;
197198
/// use tfhe::core_crypto::gpu::vec::GpuIndex;
198-
/// use tfhe::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext;
199+
/// use tfhe::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
199200
/// use tfhe::integer::gpu::gen_keys_radix_gpu;
200201
/// use tfhe::shortint::parameters::PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
201202
///
@@ -212,14 +213,14 @@ impl CudaServerKey {
212213
/// let ct2 = cks.encrypt_bool(msg2);
213214
///
214215
/// // Copy to GPU
215-
/// let d_ct1 = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct1, &streams);
216-
/// let d_ct2 = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct2, &streams);
216+
/// let d_ct1 = CudaBooleanBlock::from_boolean_block(&ct1, &streams);
217+
/// let d_ct2 = CudaBooleanBlock::from_boolean_block(&ct2, &streams);
217218
///
218219
/// // Compute homomorphically a bitwise or:
219220
/// let d_ct_res = sks.boolean_bitor(&d_ct1, &d_ct2, &streams);
220221
///
221222
/// // Copy back to CPU
222-
/// let ct_res = d_ct_res.to_radix_ciphertext(&streams);
223+
/// let ct_res = CudaBooleanBlock::to_boolean_block(&d_ct_res, &streams);
223224
/// let expected = msg1 | msg2;
224225
///
225226
/// // Decrypt:
@@ -256,7 +257,7 @@ impl CudaServerKey {
256257
/// ```rust
257258
/// use tfhe::core_crypto::gpu::CudaStreams;
258259
/// use tfhe::core_crypto::gpu::vec::GpuIndex;
259-
/// use tfhe::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext;
260+
/// use tfhe::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
260261
/// use tfhe::integer::gpu::gen_keys_radix_gpu;
261262
/// use tfhe::shortint::parameters::PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
262263
///
@@ -273,14 +274,14 @@ impl CudaServerKey {
273274
/// let ct2 = cks.encrypt_bool(msg2);
274275
///
275276
/// // Copy to GPU
276-
/// let d_ct1 = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct1, &streams);
277-
/// let d_ct2 = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct2, &streams);
277+
/// let d_ct1 = CudaBooleanBlock::from_boolean_block(&ct1, &streams);
278+
/// let d_ct2 = CudaBooleanBlock::from_boolean_block(&ct2, &streams);
278279
///
279280
/// // Compute homomorphically a bitwise xor:
280281
/// let d_ct_res = sks.boolean_bitxor(&d_ct1, &d_ct2, &streams);
281282
///
282283
/// // Copy back to CPU
283-
/// let ct_res = d_ct_res.to_radix_ciphertext(&streams);
284+
/// let ct_res = CudaBooleanBlock::to_boolean_block(&d_ct_res, &streams);
284285
/// let expected = msg1 ^ msg2;
285286
///
286287
/// // Decrypt:
@@ -1282,6 +1283,63 @@ impl CudaServerKey {
12821283
self.get_bitop_size_on_gpu(ct_left, ct_right, BitOpType::Xor, streams)
12831284
}
12841285

1286+
pub fn get_boolean_bitnot_size_on_gpu(
1287+
&self,
1288+
ct: &CudaBooleanBlock,
1289+
streams: &CudaStreams,
1290+
) -> u64 {
1291+
let boolean_bitnot_mem = match &self.bootstrapping_key {
1292+
CudaBootstrappingKey::Classic(d_bsk) => cuda_backend_get_boolean_bitnot_size_on_gpu(
1293+
streams,
1294+
self.message_modulus,
1295+
self.carry_modulus,
1296+
d_bsk.glwe_dimension,
1297+
d_bsk.polynomial_size,
1298+
self.key_switching_key
1299+
.input_key_lwe_size()
1300+
.to_lwe_dimension(),
1301+
self.key_switching_key
1302+
.output_key_lwe_size()
1303+
.to_lwe_dimension(),
1304+
self.key_switching_key.decomposition_level_count(),
1305+
self.key_switching_key.decomposition_base_log(),
1306+
d_bsk.decomp_level_count,
1307+
d_bsk.decomp_base_log,
1308+
false,
1309+
1u32,
1310+
PBSType::Classical,
1311+
LweBskGroupingFactor(0),
1312+
d_bsk.ms_noise_reduction_configuration.as_ref(),
1313+
),
1314+
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
1315+
cuda_backend_get_boolean_bitnot_size_on_gpu(
1316+
streams,
1317+
self.message_modulus,
1318+
self.carry_modulus,
1319+
d_multibit_bsk.glwe_dimension,
1320+
d_multibit_bsk.polynomial_size,
1321+
self.key_switching_key
1322+
.input_key_lwe_size()
1323+
.to_lwe_dimension(),
1324+
self.key_switching_key
1325+
.output_key_lwe_size()
1326+
.to_lwe_dimension(),
1327+
self.key_switching_key.decomposition_level_count(),
1328+
self.key_switching_key.decomposition_base_log(),
1329+
d_multibit_bsk.decomp_level_count,
1330+
d_multibit_bsk.decomp_base_log,
1331+
false,
1332+
1u32,
1333+
PBSType::MultiBit,
1334+
d_multibit_bsk.grouping_factor,
1335+
None,
1336+
)
1337+
}
1338+
};
1339+
1340+
boolean_bitnot_mem
1341+
}
1342+
12851343
pub fn get_bitnot_size_on_gpu<T: CudaIntegerRadixCiphertext>(
12861344
&self,
12871345
ct: &T,

0 commit comments

Comments
 (0)