Skip to content

Commit 1b27416

Browse files
committed
refactor(gpu): moving cast_to_signed to the backend
1 parent 1d98c2b commit 1b27416

File tree

7 files changed

+310
-177
lines changed

7 files changed

+310
-177
lines changed

backends/tfhe-cuda-backend/cuda/include/integer/cast.h

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,37 @@ template <typename Torus> struct int_cast_to_unsigned_buffer {
127127
cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0));
128128
}
129129
};
130+
131+
template <typename Torus> struct int_cast_to_signed_buffer {
132+
int_radix_params params;
133+
bool allocate_gpu_memory;
134+
uint32_t num_input_blocks;
135+
uint32_t target_num_blocks;
136+
137+
int_extend_radix_with_sign_msb_buffer<Torus> *extend_buffer;
138+
139+
int_cast_to_signed_buffer(CudaStreams streams, int_radix_params params,
140+
uint32_t num_input_blocks,
141+
uint32_t target_num_blocks,
142+
bool allocate_gpu_memory, uint64_t &size_tracker) {
143+
this->params = params;
144+
this->allocate_gpu_memory = allocate_gpu_memory;
145+
this->num_input_blocks = num_input_blocks;
146+
this->target_num_blocks = target_num_blocks;
147+
this->extend_buffer = nullptr;
148+
149+
if (target_num_blocks > num_input_blocks) {
150+
uint32_t num_additional_blocks = target_num_blocks - num_input_blocks;
151+
this->extend_buffer = new int_extend_radix_with_sign_msb_buffer<Torus>(
152+
streams, params, num_input_blocks, num_additional_blocks,
153+
allocate_gpu_memory, size_tracker);
154+
}
155+
}
156+
157+
void release(CudaStreams streams) {
158+
if (this->extend_buffer) {
159+
this->extend_buffer->release(streams);
160+
delete this->extend_buffer;
161+
}
162+
}
163+
};

backends/tfhe-cuda-backend/cuda/include/integer/integer.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1004,6 +1004,24 @@ void cuda_unchecked_index_of_clear_64(
10041004

10051005
void cleanup_cuda_unchecked_index_of_clear_64(CudaStreamsFFI streams,
10061006
int8_t **mem_ptr_void);
1007+
1008+
uint64_t scratch_cuda_cast_to_signed_64(
1009+
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
1010+
uint32_t polynomial_size, uint32_t lwe_dimension, uint32_t ks_level,
1011+
uint32_t ks_base_log, uint32_t pbs_level, uint32_t pbs_base_log,
1012+
uint32_t grouping_factor, uint32_t num_input_blocks,
1013+
uint32_t target_num_blocks, uint32_t message_modulus,
1014+
uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory,
1015+
PBS_MS_REDUCTION_T noise_reduction_type);
1016+
1017+
void cuda_cast_to_signed_64(CudaStreamsFFI streams,
1018+
CudaRadixCiphertextFFI *output,
1019+
CudaRadixCiphertextFFI const *input, int8_t *mem,
1020+
bool input_is_signed, void *const *bsks,
1021+
void *const *ksks);
1022+
1023+
void cleanup_cuda_cast_to_signed_64(CudaStreamsFFI streams,
1024+
int8_t **mem_ptr_void);
10071025
} // extern C
10081026

10091027
#endif // CUDA_INTEGER_H

backends/tfhe-cuda-backend/cuda/src/integer/cast.cu

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,3 +116,45 @@ void cleanup_cuda_cast_to_unsigned_64(CudaStreamsFFI streams,
116116
delete mem_ptr;
117117
*mem_ptr_void = nullptr;
118118
}
119+
120+
uint64_t scratch_cuda_cast_to_signed_64(
121+
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
122+
uint32_t polynomial_size, uint32_t lwe_dimension, uint32_t ks_level,
123+
uint32_t ks_base_log, uint32_t pbs_level, uint32_t pbs_base_log,
124+
uint32_t grouping_factor, uint32_t num_input_blocks,
125+
uint32_t target_num_blocks, uint32_t message_modulus,
126+
uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory,
127+
PBS_MS_REDUCTION_T noise_reduction_type) {
128+
129+
int_radix_params params(pbs_type, glwe_dimension, polynomial_size,
130+
glwe_dimension * polynomial_size, lwe_dimension,
131+
ks_level, ks_base_log, pbs_level, pbs_base_log,
132+
grouping_factor, message_modulus, carry_modulus,
133+
noise_reduction_type);
134+
135+
return scratch_cuda_cast_to_signed<uint64_t>(
136+
CudaStreams(streams), (int_cast_to_signed_buffer<uint64_t> **)mem_ptr,
137+
params, num_input_blocks, target_num_blocks, allocate_gpu_memory);
138+
}
139+
140+
void cuda_cast_to_signed_64(CudaStreamsFFI streams,
141+
CudaRadixCiphertextFFI *output,
142+
CudaRadixCiphertextFFI const *input, int8_t *mem,
143+
bool input_is_signed, void *const *bsks,
144+
void *const *ksks) {
145+
146+
host_cast_to_signed<uint64_t>(CudaStreams(streams), output, input,
147+
(int_cast_to_signed_buffer<uint64_t> *)mem,
148+
input_is_signed, bsks, (uint64_t **)ksks);
149+
}
150+
151+
void cleanup_cuda_cast_to_signed_64(CudaStreamsFFI streams,
152+
int8_t **mem_ptr_void) {
153+
int_cast_to_signed_buffer<uint64_t> *mem_ptr =
154+
(int_cast_to_signed_buffer<uint64_t> *)(*mem_ptr_void);
155+
156+
mem_ptr->release(CudaStreams(streams));
157+
158+
delete mem_ptr;
159+
*mem_ptr_void = nullptr;
160+
}

backends/tfhe-cuda-backend/cuda/src/integer/cast.cuh

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,4 +160,49 @@ host_cast_to_unsigned(CudaStreams streams, CudaRadixCiphertextFFI *output,
160160
}
161161
}
162162

163+
template <typename Torus>
164+
uint64_t scratch_cuda_cast_to_signed(CudaStreams streams,
165+
int_cast_to_signed_buffer<Torus> **mem_ptr,
166+
int_radix_params params,
167+
uint32_t num_input_blocks,
168+
uint32_t target_num_blocks,
169+
bool allocate_gpu_memory) {
170+
171+
uint64_t size_tracker = 0;
172+
*mem_ptr = new int_cast_to_signed_buffer<Torus>(
173+
streams, params, num_input_blocks, target_num_blocks, allocate_gpu_memory,
174+
size_tracker);
175+
176+
return size_tracker;
177+
}
178+
179+
template <typename Torus>
180+
__host__ void
181+
host_cast_to_signed(CudaStreams streams, CudaRadixCiphertextFFI *output,
182+
CudaRadixCiphertextFFI const *input,
183+
int_cast_to_signed_buffer<Torus> *mem_ptr,
184+
bool input_is_signed, void *const *bsks, Torus **ksks) {
185+
186+
uint32_t current_num_blocks = input->num_radix_blocks;
187+
uint32_t target_num_blocks = mem_ptr->target_num_blocks;
188+
189+
if (input_is_signed) {
190+
if (target_num_blocks > current_num_blocks) {
191+
uint32_t num_blocks_to_add = target_num_blocks - current_num_blocks;
192+
host_extend_radix_with_sign_msb<Torus>(streams, output, input,
193+
mem_ptr->extend_buffer,
194+
num_blocks_to_add, bsks, ksks);
195+
} else {
196+
host_trim_radix_blocks_msb<Torus>(output, input, streams);
197+
}
198+
} else {
199+
if (target_num_blocks > current_num_blocks) {
200+
host_extend_radix_with_trivial_zero_blocks_msb<Torus>(output, input,
201+
streams);
202+
} else {
203+
host_trim_radix_blocks_msb<Torus>(output, input, streams);
204+
}
205+
}
206+
}
207+
163208
#endif

backends/tfhe-cuda-backend/src/bindings.rs

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2186,6 +2186,41 @@ unsafe extern "C" {
21862186
mem_ptr_void: *mut *mut i8,
21872187
);
21882188
}
2189+
unsafe extern "C" {
2190+
pub fn scratch_cuda_cast_to_signed_64(
2191+
streams: CudaStreamsFFI,
2192+
mem_ptr: *mut *mut i8,
2193+
glwe_dimension: u32,
2194+
polynomial_size: u32,
2195+
lwe_dimension: u32,
2196+
ks_level: u32,
2197+
ks_base_log: u32,
2198+
pbs_level: u32,
2199+
pbs_base_log: u32,
2200+
grouping_factor: u32,
2201+
num_input_blocks: u32,
2202+
target_num_blocks: u32,
2203+
message_modulus: u32,
2204+
carry_modulus: u32,
2205+
pbs_type: PBS_TYPE,
2206+
allocate_gpu_memory: bool,
2207+
noise_reduction_type: PBS_MS_REDUCTION_T,
2208+
) -> u64;
2209+
}
2210+
unsafe extern "C" {
2211+
pub fn cuda_cast_to_signed_64(
2212+
streams: CudaStreamsFFI,
2213+
output: *mut CudaRadixCiphertextFFI,
2214+
input: *const CudaRadixCiphertextFFI,
2215+
mem: *mut i8,
2216+
input_is_signed: bool,
2217+
bsks: *const *mut ffi::c_void,
2218+
ksks: *const *mut ffi::c_void,
2219+
);
2220+
}
2221+
unsafe extern "C" {
2222+
pub fn cleanup_cuda_cast_to_signed_64(streams: CudaStreamsFFI, mem_ptr_void: *mut *mut i8);
2223+
}
21892224
unsafe extern "C" {
21902225
pub fn scratch_cuda_integer_compress_radix_ciphertext_64(
21912226
streams: CudaStreamsFFI,

tfhe/src/integer/gpu/mod.rs

Lines changed: 80 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -5910,82 +5910,6 @@ pub(crate) unsafe fn cuda_backend_unchecked_partial_sum_ciphertexts_assign<
59105910
update_noise_degree(result, &cuda_ffi_result);
59115911
}
59125912

5913-
#[allow(clippy::too_many_arguments)]
5914-
/// # Safety
5915-
///
5916-
/// - The data must not be moved or dropped while being used by the CUDA kernel.
5917-
/// - This function assumes exclusive access to the passed data; violating this may lead to
5918-
/// undefined behavior.
5919-
pub(crate) unsafe fn cuda_backend_extend_radix_with_sign_msb<T: UnsignedInteger, B: Numeric>(
5920-
streams: &CudaStreams,
5921-
output: &mut CudaRadixCiphertext,
5922-
ct: &CudaRadixCiphertext,
5923-
bootstrapping_key: &CudaVec<B>,
5924-
keyswitch_key: &CudaVec<T>,
5925-
lwe_dimension: LweDimension,
5926-
glwe_dimension: GlweDimension,
5927-
polynomial_size: PolynomialSize,
5928-
ks_level: DecompositionLevelCount,
5929-
ks_base_log: DecompositionBaseLog,
5930-
pbs_level: DecompositionLevelCount,
5931-
pbs_base_log: DecompositionBaseLog,
5932-
num_additional_blocks: u32,
5933-
pbs_type: PBSType,
5934-
grouping_factor: LweBskGroupingFactor,
5935-
ms_noise_reduction_configuration: Option<&CudaModulusSwitchNoiseReductionConfiguration>,
5936-
) {
5937-
let message_modulus = ct.info.blocks.first().unwrap().message_modulus;
5938-
let carry_modulus = ct.info.blocks.first().unwrap().carry_modulus;
5939-
5940-
let noise_reduction_type = resolve_ms_noise_reduction_config(ms_noise_reduction_configuration);
5941-
5942-
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
5943-
5944-
let mut input_degrees = ct.info.blocks.iter().map(|b| b.degree.0).collect();
5945-
let mut input_noise_levels = ct.info.blocks.iter().map(|b| b.noise_level.0).collect();
5946-
let cuda_ffi_radix_input =
5947-
prepare_cuda_radix_ffi(ct, &mut input_degrees, &mut input_noise_levels);
5948-
5949-
let mut output_degrees = output.info.blocks.iter().map(|b| b.degree.0).collect();
5950-
let mut output_noise_levels = output.info.blocks.iter().map(|b| b.noise_level.0).collect();
5951-
let mut cuda_ffi_radix_output =
5952-
prepare_cuda_radix_ffi(output, &mut output_degrees, &mut output_noise_levels);
5953-
5954-
scratch_cuda_extend_radix_with_sign_msb_64(
5955-
streams.ffi(),
5956-
std::ptr::addr_of_mut!(mem_ptr),
5957-
glwe_dimension.0 as u32,
5958-
polynomial_size.0 as u32,
5959-
lwe_dimension.0 as u32,
5960-
ks_level.0 as u32,
5961-
ks_base_log.0 as u32,
5962-
pbs_level.0 as u32,
5963-
pbs_base_log.0 as u32,
5964-
grouping_factor.0 as u32,
5965-
1u32,
5966-
num_additional_blocks,
5967-
message_modulus.0 as u32,
5968-
carry_modulus.0 as u32,
5969-
pbs_type as u32,
5970-
true,
5971-
noise_reduction_type as u32,
5972-
);
5973-
5974-
cuda_extend_radix_with_sign_msb_64(
5975-
streams.ffi(),
5976-
&raw mut cuda_ffi_radix_output,
5977-
&raw const cuda_ffi_radix_input,
5978-
mem_ptr,
5979-
num_additional_blocks,
5980-
bootstrapping_key.ptr.as_ptr(),
5981-
keyswitch_key.ptr.as_ptr(),
5982-
);
5983-
5984-
cleanup_cuda_extend_radix_with_sign_msb_64(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr));
5985-
5986-
update_noise_degree(output, &cuda_ffi_radix_output);
5987-
}
5988-
59895913
#[allow(clippy::too_many_arguments)]
59905914
/// # Safety
59915915
///
@@ -10184,3 +10108,83 @@ pub(crate) unsafe fn cuda_backend_unchecked_index_of_clear<
1018410108
update_noise_degree(index_ct, &ffi_index);
1018510109
update_noise_degree(&mut match_ct.0.ciphertext, &ffi_match);
1018610110
}
10111+
10112+
#[allow(clippy::too_many_arguments)]
10113+
/// # Safety
10114+
///
10115+
/// - The data must not be moved or dropped while being used by the CUDA kernel.
10116+
/// - This function assumes exclusive access to the passed data; violating this may lead to
10117+
/// undefined behavior.
10118+
pub(crate) unsafe fn cuda_backend_cast_to_signed<T: UnsignedInteger, B: Numeric>(
10119+
streams: &CudaStreams,
10120+
output: &mut CudaRadixCiphertext,
10121+
input: &CudaRadixCiphertext,
10122+
input_is_signed: bool,
10123+
bootstrapping_key: &CudaVec<B>,
10124+
keyswitch_key: &CudaVec<T>,
10125+
message_modulus: MessageModulus,
10126+
carry_modulus: CarryModulus,
10127+
glwe_dimension: GlweDimension,
10128+
polynomial_size: PolynomialSize,
10129+
big_lwe_dimension: LweDimension,
10130+
ks_level: DecompositionLevelCount,
10131+
ks_base_log: DecompositionBaseLog,
10132+
pbs_level: DecompositionLevelCount,
10133+
pbs_base_log: DecompositionBaseLog,
10134+
pbs_type: PBSType,
10135+
grouping_factor: LweBskGroupingFactor,
10136+
ms_noise_reduction_configuration: Option<&CudaModulusSwitchNoiseReductionConfiguration>,
10137+
) {
10138+
assert_eq!(streams.gpu_indexes[0], bootstrapping_key.gpu_index(0));
10139+
assert_eq!(streams.gpu_indexes[0], keyswitch_key.gpu_index(0));
10140+
10141+
let num_input_blocks = input.d_blocks.lwe_ciphertext_count().0 as u32;
10142+
let target_num_blocks = output.d_blocks.lwe_ciphertext_count().0 as u32;
10143+
10144+
let noise_reduction_type = resolve_ms_noise_reduction_config(ms_noise_reduction_configuration);
10145+
10146+
let mut input_degrees = input.info.blocks.iter().map(|b| b.degree.0).collect();
10147+
let mut input_noise_levels = input.info.blocks.iter().map(|b| b.noise_level.0).collect();
10148+
let cuda_ffi_input = prepare_cuda_radix_ffi(input, &mut input_degrees, &mut input_noise_levels);
10149+
10150+
let mut output_degrees = output.info.blocks.iter().map(|b| b.degree.0).collect();
10151+
let mut output_noise_levels = output.info.blocks.iter().map(|b| b.noise_level.0).collect();
10152+
let mut cuda_ffi_output =
10153+
prepare_cuda_radix_ffi(output, &mut output_degrees, &mut output_noise_levels);
10154+
10155+
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
10156+
10157+
scratch_cuda_cast_to_signed_64(
10158+
streams.ffi(),
10159+
std::ptr::addr_of_mut!(mem_ptr),
10160+
glwe_dimension.0 as u32,
10161+
polynomial_size.0 as u32,
10162+
big_lwe_dimension.0 as u32,
10163+
ks_level.0 as u32,
10164+
ks_base_log.0 as u32,
10165+
pbs_level.0 as u32,
10166+
pbs_base_log.0 as u32,
10167+
grouping_factor.0 as u32,
10168+
num_input_blocks,
10169+
target_num_blocks,
10170+
message_modulus.0 as u32,
10171+
carry_modulus.0 as u32,
10172+
pbs_type as u32,
10173+
true,
10174+
noise_reduction_type as u32,
10175+
);
10176+
10177+
cuda_cast_to_signed_64(
10178+
streams.ffi(),
10179+
&raw mut cuda_ffi_output,
10180+
&raw const cuda_ffi_input,
10181+
mem_ptr,
10182+
input_is_signed,
10183+
bootstrapping_key.ptr.as_ptr(),
10184+
keyswitch_key.ptr.as_ptr(),
10185+
);
10186+
10187+
cleanup_cuda_cast_to_signed_64(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr));
10188+
10189+
update_noise_degree(output, &cuda_ffi_output);
10190+
}

0 commit comments

Comments
 (0)