Skip to content

Commit 51b7388

Browse files
committed
refactor(gpu): moving cast_to_signed to the backend
1 parent 1d98c2b commit 51b7388

File tree

7 files changed

+314
-281
lines changed

7 files changed

+314
-281
lines changed

backends/tfhe-cuda-backend/cuda/include/integer/cast.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,38 @@ template <typename Torus> struct int_cast_to_unsigned_buffer {
127127
cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0));
128128
}
129129
};
130+
131+
template <typename Torus> struct int_cast_to_signed_buffer {
132+
int_radix_params params;
133+
bool allocate_gpu_memory;
134+
uint32_t num_input_blocks;
135+
uint32_t target_num_blocks;
136+
137+
int_extend_radix_with_sign_msb_buffer<Torus> *extend_buffer;
138+
139+
int_cast_to_signed_buffer(CudaStreams streams, int_radix_params params,
140+
uint32_t num_input_blocks,
141+
uint32_t target_num_blocks, bool input_is_signed,
142+
bool allocate_gpu_memory, uint64_t &size_tracker) {
143+
this->params = params;
144+
this->allocate_gpu_memory = allocate_gpu_memory;
145+
this->num_input_blocks = num_input_blocks;
146+
this->target_num_blocks = target_num_blocks;
147+
this->extend_buffer = nullptr;
148+
149+
if (input_is_signed && target_num_blocks > num_input_blocks) {
150+
uint32_t num_additional_blocks = target_num_blocks - num_input_blocks;
151+
this->extend_buffer = new int_extend_radix_with_sign_msb_buffer<Torus>(
152+
streams, params, num_input_blocks, num_additional_blocks,
153+
allocate_gpu_memory, size_tracker);
154+
}
155+
}
156+
157+
void release(CudaStreams streams) {
158+
if (this->extend_buffer) {
159+
this->extend_buffer->release(streams);
160+
delete this->extend_buffer;
161+
}
162+
cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0));
163+
}
164+
};

backends/tfhe-cuda-backend/cuda/include/integer/integer.h

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -610,25 +610,6 @@ void cuda_integer_unsigned_scalar_div_radix_64(
610610
void cleanup_cuda_integer_unsigned_scalar_div_radix_64(CudaStreamsFFI streams,
611611
int8_t **mem_ptr_void);
612612

613-
uint64_t scratch_cuda_extend_radix_with_sign_msb_64(
614-
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
615-
uint32_t polynomial_size, uint32_t lwe_dimension, uint32_t ks_level,
616-
uint32_t ks_base_log, uint32_t pbs_level, uint32_t pbs_base_log,
617-
uint32_t grouping_factor, uint32_t num_blocks,
618-
uint32_t num_additional_blocks, uint32_t message_modulus,
619-
uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory,
620-
PBS_MS_REDUCTION_T noise_reduction_type);
621-
622-
void cuda_extend_radix_with_sign_msb_64(CudaStreamsFFI streams,
623-
CudaRadixCiphertextFFI *output,
624-
CudaRadixCiphertextFFI const *input,
625-
int8_t *mem_ptr,
626-
uint32_t num_additional_blocks,
627-
void *const *bsks, void *const *ksks);
628-
629-
void cleanup_cuda_extend_radix_with_sign_msb_64(CudaStreamsFFI streams,
630-
int8_t **mem_ptr_void);
631-
632613
uint64_t scratch_cuda_integer_signed_scalar_div_radix_64(
633614
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
634615
uint32_t polynomial_size, uint32_t lwe_dimension, uint32_t ks_level,
@@ -1004,6 +985,24 @@ void cuda_unchecked_index_of_clear_64(
1004985

1005986
void cleanup_cuda_unchecked_index_of_clear_64(CudaStreamsFFI streams,
1006987
int8_t **mem_ptr_void);
988+
989+
uint64_t scratch_cuda_cast_to_signed_64(
990+
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
991+
uint32_t polynomial_size, uint32_t lwe_dimension, uint32_t ks_level,
992+
uint32_t ks_base_log, uint32_t pbs_level, uint32_t pbs_base_log,
993+
uint32_t grouping_factor, uint32_t num_input_blocks,
994+
uint32_t target_num_blocks, uint32_t message_modulus,
995+
uint32_t carry_modulus, PBS_TYPE pbs_type, bool input_is_signed,
996+
bool allocate_gpu_memory, PBS_MS_REDUCTION_T noise_reduction_type);
997+
998+
void cuda_cast_to_signed_64(CudaStreamsFFI streams,
999+
CudaRadixCiphertextFFI *output,
1000+
CudaRadixCiphertextFFI const *input, int8_t *mem,
1001+
bool input_is_signed, void *const *bsks,
1002+
void *const *ksks);
1003+
1004+
void cleanup_cuda_cast_to_signed_64(CudaStreamsFFI streams,
1005+
int8_t **mem_ptr_void);
10071006
} // extern C
10081007

10091008
#endif // CUDA_INTEGER_H

backends/tfhe-cuda-backend/cuda/src/integer/cast.cu

Lines changed: 43 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -27,53 +27,6 @@ void trim_radix_blocks_msb_64(CudaRadixCiphertextFFI *output,
2727
cuda_synchronize_stream(cuda_streams.stream(0), cuda_streams.gpu_index(0));
2828
}
2929

30-
uint64_t scratch_cuda_extend_radix_with_sign_msb_64(
31-
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
32-
uint32_t polynomial_size, uint32_t lwe_dimension, uint32_t ks_level,
33-
uint32_t ks_base_log, uint32_t pbs_level, uint32_t pbs_base_log,
34-
uint32_t grouping_factor, uint32_t num_blocks,
35-
uint32_t num_additional_blocks, uint32_t message_modulus,
36-
uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory,
37-
PBS_MS_REDUCTION_T noise_reduction_type) {
38-
39-
int_radix_params params(pbs_type, glwe_dimension, polynomial_size,
40-
glwe_dimension * polynomial_size, lwe_dimension,
41-
ks_level, ks_base_log, pbs_level, pbs_base_log,
42-
grouping_factor, message_modulus, carry_modulus,
43-
noise_reduction_type);
44-
45-
return scratch_extend_radix_with_sign_msb<uint64_t>(
46-
CudaStreams(streams),
47-
(int_extend_radix_with_sign_msb_buffer<uint64_t> **)mem_ptr, params,
48-
num_blocks, num_additional_blocks, allocate_gpu_memory);
49-
}
50-
51-
void cuda_extend_radix_with_sign_msb_64(CudaStreamsFFI streams,
52-
CudaRadixCiphertextFFI *output,
53-
CudaRadixCiphertextFFI const *input,
54-
int8_t *mem_ptr,
55-
uint32_t num_additional_blocks,
56-
void *const *bsks, void *const *ksks) {
57-
PUSH_RANGE("cast")
58-
host_extend_radix_with_sign_msb<uint64_t>(
59-
CudaStreams(streams), output, input,
60-
(int_extend_radix_with_sign_msb_buffer<uint64_t> *)mem_ptr,
61-
num_additional_blocks, bsks, (uint64_t **)ksks);
62-
POP_RANGE()
63-
}
64-
65-
void cleanup_cuda_extend_radix_with_sign_msb_64(CudaStreamsFFI streams,
66-
int8_t **mem_ptr_void) {
67-
PUSH_RANGE("clean cast")
68-
int_extend_radix_with_sign_msb_buffer<uint64_t> *mem_ptr =
69-
(int_extend_radix_with_sign_msb_buffer<uint64_t> *)(*mem_ptr_void);
70-
71-
mem_ptr->release(CudaStreams(streams));
72-
POP_RANGE()
73-
delete mem_ptr;
74-
*mem_ptr_void = nullptr;
75-
}
76-
7730
uint64_t scratch_cuda_cast_to_unsigned_64(
7831
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
7932
uint32_t polynomial_size, uint32_t big_lwe_dimension,
@@ -116,3 +69,46 @@ void cleanup_cuda_cast_to_unsigned_64(CudaStreamsFFI streams,
11669
delete mem_ptr;
11770
*mem_ptr_void = nullptr;
11871
}
72+
73+
uint64_t scratch_cuda_cast_to_signed_64(
74+
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
75+
uint32_t polynomial_size, uint32_t lwe_dimension, uint32_t ks_level,
76+
uint32_t ks_base_log, uint32_t pbs_level, uint32_t pbs_base_log,
77+
uint32_t grouping_factor, uint32_t num_input_blocks,
78+
uint32_t target_num_blocks, uint32_t message_modulus,
79+
uint32_t carry_modulus, PBS_TYPE pbs_type, bool input_is_signed,
80+
bool allocate_gpu_memory, PBS_MS_REDUCTION_T noise_reduction_type) {
81+
82+
int_radix_params params(pbs_type, glwe_dimension, polynomial_size,
83+
glwe_dimension * polynomial_size, lwe_dimension,
84+
ks_level, ks_base_log, pbs_level, pbs_base_log,
85+
grouping_factor, message_modulus, carry_modulus,
86+
noise_reduction_type);
87+
88+
return scratch_cuda_cast_to_signed<uint64_t>(
89+
CudaStreams(streams), (int_cast_to_signed_buffer<uint64_t> **)mem_ptr,
90+
params, num_input_blocks, target_num_blocks, input_is_signed,
91+
allocate_gpu_memory);
92+
}
93+
94+
void cuda_cast_to_signed_64(CudaStreamsFFI streams,
95+
CudaRadixCiphertextFFI *output,
96+
CudaRadixCiphertextFFI const *input, int8_t *mem,
97+
bool input_is_signed, void *const *bsks,
98+
void *const *ksks) {
99+
100+
host_cast_to_signed<uint64_t>(CudaStreams(streams), output, input,
101+
(int_cast_to_signed_buffer<uint64_t> *)mem,
102+
input_is_signed, bsks, (uint64_t **)ksks);
103+
}
104+
105+
void cleanup_cuda_cast_to_signed_64(CudaStreamsFFI streams,
106+
int8_t **mem_ptr_void) {
107+
int_cast_to_signed_buffer<uint64_t> *mem_ptr =
108+
(int_cast_to_signed_buffer<uint64_t> *)(*mem_ptr_void);
109+
110+
mem_ptr->release(CudaStreams(streams));
111+
112+
delete mem_ptr;
113+
*mem_ptr_void = nullptr;
114+
}

backends/tfhe-cuda-backend/cuda/src/integer/cast.cuh

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,4 +160,49 @@ host_cast_to_unsigned(CudaStreams streams, CudaRadixCiphertextFFI *output,
160160
}
161161
}
162162

163+
template <typename Torus>
164+
uint64_t
165+
scratch_cuda_cast_to_signed(CudaStreams streams,
166+
int_cast_to_signed_buffer<Torus> **mem_ptr,
167+
int_radix_params params, uint32_t num_input_blocks,
168+
uint32_t target_num_blocks, bool input_is_signed,
169+
bool allocate_gpu_memory) {
170+
171+
uint64_t size_tracker = 0;
172+
*mem_ptr = new int_cast_to_signed_buffer<Torus>(
173+
streams, params, num_input_blocks, target_num_blocks, input_is_signed,
174+
allocate_gpu_memory, size_tracker);
175+
176+
return size_tracker;
177+
}
178+
179+
template <typename Torus>
180+
__host__ void
181+
host_cast_to_signed(CudaStreams streams, CudaRadixCiphertextFFI *output,
182+
CudaRadixCiphertextFFI const *input,
183+
int_cast_to_signed_buffer<Torus> *mem_ptr,
184+
bool input_is_signed, void *const *bsks, Torus **ksks) {
185+
186+
uint32_t current_num_blocks = input->num_radix_blocks;
187+
uint32_t target_num_blocks = mem_ptr->target_num_blocks;
188+
189+
if (input_is_signed) {
190+
if (target_num_blocks > current_num_blocks) {
191+
uint32_t num_blocks_to_add = target_num_blocks - current_num_blocks;
192+
host_extend_radix_with_sign_msb<Torus>(streams, output, input,
193+
mem_ptr->extend_buffer,
194+
num_blocks_to_add, bsks, ksks);
195+
} else {
196+
host_trim_radix_blocks_msb<Torus>(output, input, streams);
197+
}
198+
} else {
199+
if (target_num_blocks > current_num_blocks) {
200+
host_extend_radix_with_trivial_zero_blocks_msb<Torus>(output, input,
201+
streams);
202+
} else {
203+
host_trim_radix_blocks_msb<Torus>(output, input, streams);
204+
}
205+
}
206+
}
207+
163208
#endif

backends/tfhe-cuda-backend/src/bindings.rs

Lines changed: 36 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1353,44 +1353,6 @@ unsafe extern "C" {
13531353
mem_ptr_void: *mut *mut i8,
13541354
);
13551355
}
1356-
unsafe extern "C" {
1357-
pub fn scratch_cuda_extend_radix_with_sign_msb_64(
1358-
streams: CudaStreamsFFI,
1359-
mem_ptr: *mut *mut i8,
1360-
glwe_dimension: u32,
1361-
polynomial_size: u32,
1362-
lwe_dimension: u32,
1363-
ks_level: u32,
1364-
ks_base_log: u32,
1365-
pbs_level: u32,
1366-
pbs_base_log: u32,
1367-
grouping_factor: u32,
1368-
num_blocks: u32,
1369-
num_additional_blocks: u32,
1370-
message_modulus: u32,
1371-
carry_modulus: u32,
1372-
pbs_type: PBS_TYPE,
1373-
allocate_gpu_memory: bool,
1374-
noise_reduction_type: PBS_MS_REDUCTION_T,
1375-
) -> u64;
1376-
}
1377-
unsafe extern "C" {
1378-
pub fn cuda_extend_radix_with_sign_msb_64(
1379-
streams: CudaStreamsFFI,
1380-
output: *mut CudaRadixCiphertextFFI,
1381-
input: *const CudaRadixCiphertextFFI,
1382-
mem_ptr: *mut i8,
1383-
num_additional_blocks: u32,
1384-
bsks: *const *mut ffi::c_void,
1385-
ksks: *const *mut ffi::c_void,
1386-
);
1387-
}
1388-
unsafe extern "C" {
1389-
pub fn cleanup_cuda_extend_radix_with_sign_msb_64(
1390-
streams: CudaStreamsFFI,
1391-
mem_ptr_void: *mut *mut i8,
1392-
);
1393-
}
13941356
unsafe extern "C" {
13951357
pub fn scratch_cuda_integer_signed_scalar_div_radix_64(
13961358
streams: CudaStreamsFFI,
@@ -2186,6 +2148,42 @@ unsafe extern "C" {
21862148
mem_ptr_void: *mut *mut i8,
21872149
);
21882150
}
2151+
unsafe extern "C" {
2152+
pub fn scratch_cuda_cast_to_signed_64(
2153+
streams: CudaStreamsFFI,
2154+
mem_ptr: *mut *mut i8,
2155+
glwe_dimension: u32,
2156+
polynomial_size: u32,
2157+
lwe_dimension: u32,
2158+
ks_level: u32,
2159+
ks_base_log: u32,
2160+
pbs_level: u32,
2161+
pbs_base_log: u32,
2162+
grouping_factor: u32,
2163+
num_input_blocks: u32,
2164+
target_num_blocks: u32,
2165+
message_modulus: u32,
2166+
carry_modulus: u32,
2167+
pbs_type: PBS_TYPE,
2168+
input_is_signed: bool,
2169+
allocate_gpu_memory: bool,
2170+
noise_reduction_type: PBS_MS_REDUCTION_T,
2171+
) -> u64;
2172+
}
2173+
unsafe extern "C" {
2174+
pub fn cuda_cast_to_signed_64(
2175+
streams: CudaStreamsFFI,
2176+
output: *mut CudaRadixCiphertextFFI,
2177+
input: *const CudaRadixCiphertextFFI,
2178+
mem: *mut i8,
2179+
input_is_signed: bool,
2180+
bsks: *const *mut ffi::c_void,
2181+
ksks: *const *mut ffi::c_void,
2182+
);
2183+
}
2184+
unsafe extern "C" {
2185+
pub fn cleanup_cuda_cast_to_signed_64(streams: CudaStreamsFFI, mem_ptr_void: *mut *mut i8);
2186+
}
21892187
unsafe extern "C" {
21902188
pub fn scratch_cuda_integer_compress_radix_ciphertext_64(
21912189
streams: CudaStreamsFFI,

0 commit comments

Comments
 (0)