Skip to content

Commit cea702e

Browse files
committed
refactor(gpu): unchecked_match_value_or to backend
1 parent 1a7efbc commit cea702e

File tree

20 files changed

+1363
-88
lines changed

20 files changed

+1363
-88
lines changed

backends/tfhe-cuda-backend/cuda/include/integer/cast.h

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,3 +75,55 @@ template <typename Torus> struct int_extend_radix_with_sign_msb_buffer {
7575
cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0));
7676
}
7777
};
78+
79+
template <typename Torus> struct int_cast_to_unsigned_buffer {
80+
int_radix_params params;
81+
bool allocate_gpu_memory;
82+
83+
bool requires_full_propagate;
84+
bool requires_sign_extension;
85+
86+
int_fullprop_buffer<Torus> *prop_buffer;
87+
int_extend_radix_with_sign_msb_buffer<Torus> *extend_buffer;
88+
89+
int_cast_to_unsigned_buffer(CudaStreams streams, int_radix_params params,
90+
uint32_t num_input_blocks,
91+
uint32_t target_num_blocks, bool input_is_signed,
92+
bool requires_full_propagate,
93+
bool allocate_gpu_memory,
94+
uint64_t &size_tracker) {
95+
this->params = params;
96+
this->allocate_gpu_memory = allocate_gpu_memory;
97+
this->requires_full_propagate = requires_full_propagate;
98+
99+
this->prop_buffer = nullptr;
100+
this->extend_buffer = nullptr;
101+
102+
if (requires_full_propagate) {
103+
this->prop_buffer = new int_fullprop_buffer<Torus>(
104+
streams, params, allocate_gpu_memory, size_tracker);
105+
}
106+
107+
this->requires_sign_extension =
108+
(target_num_blocks > num_input_blocks) && input_is_signed;
109+
110+
if (this->requires_sign_extension) {
111+
uint32_t num_blocks_to_add = target_num_blocks - num_input_blocks;
112+
this->extend_buffer = new int_extend_radix_with_sign_msb_buffer<Torus>(
113+
streams, params, num_input_blocks, num_blocks_to_add,
114+
allocate_gpu_memory, size_tracker);
115+
}
116+
}
117+
118+
void release(CudaStreams streams) {
119+
if (this->prop_buffer) {
120+
this->prop_buffer->release(streams);
121+
delete this->prop_buffer;
122+
}
123+
if (this->extend_buffer) {
124+
this->extend_buffer->release(streams);
125+
delete this->extend_buffer;
126+
}
127+
cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0));
128+
}
129+
};

backends/tfhe-cuda-backend/cuda/include/integer/integer.h

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -569,6 +569,10 @@ void trim_radix_blocks_lsb_64(CudaRadixCiphertextFFI *output,
569569
CudaRadixCiphertextFFI const *input,
570570
CudaStreamsFFI streams);
571571

572+
// 64-bit entry point that trims most-significant radix blocks: `output`
// receives the leading `output->num_radix_blocks` blocks of `input`.
// The definition (cast.cu) forwards to host_trim_radix_blocks_msb<uint64_t>
// and synchronizes stream 0 before returning.
void trim_radix_blocks_msb_64(
    CudaRadixCiphertextFFI *output,
    CudaRadixCiphertextFFI const *input,
    CudaStreamsFFI streams);
575+
572576
uint64_t scratch_cuda_apply_noise_squashing(
573577
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t lwe_dimension,
574578
uint32_t glwe_dimension, uint32_t polynomial_size,
@@ -850,6 +854,46 @@ void cuda_unchecked_match_value_64(
850854

851855
void cleanup_cuda_unchecked_match_value_64(CudaStreamsFFI streams,
852856
int8_t **mem_ptr_void);
857+
858+
uint64_t scratch_cuda_cast_to_unsigned_64(
859+
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
860+
uint32_t polynomial_size, uint32_t big_lwe_dimension,
861+
uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
862+
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
863+
uint32_t num_input_blocks, uint32_t target_num_blocks, bool input_is_signed,
864+
bool requires_full_propagate, uint32_t message_modulus,
865+
uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory,
866+
PBS_MS_REDUCTION_T noise_reduction_type);
867+
868+
void cuda_cast_to_unsigned_64(CudaStreamsFFI streams,
869+
CudaRadixCiphertextFFI *output,
870+
CudaRadixCiphertextFFI *input, int8_t *mem_ptr,
871+
uint32_t target_num_blocks, bool input_is_signed,
872+
void *const *bsks, void *const *ksks);
873+
874+
void cleanup_cuda_cast_to_unsigned_64(CudaStreamsFFI streams,
875+
int8_t **mem_ptr_void);
876+
877+
uint64_t scratch_cuda_unchecked_match_value_or_64(
878+
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
879+
uint32_t polynomial_size, uint32_t big_lwe_dimension,
880+
uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
881+
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
882+
uint32_t num_matches, uint32_t num_input_blocks,
883+
uint32_t num_match_packed_blocks, uint32_t num_final_blocks,
884+
uint32_t max_output_is_zero, uint32_t message_modulus,
885+
uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory,
886+
PBS_MS_REDUCTION_T noise_reduction_type);
887+
888+
void cuda_unchecked_match_value_or_64(
889+
CudaStreamsFFI streams, CudaRadixCiphertextFFI *lwe_array_out,
890+
CudaRadixCiphertextFFI const *lwe_array_in_ct,
891+
const uint64_t *h_match_inputs, const uint64_t *h_match_outputs,
892+
const uint64_t *h_or_value, int8_t *mem, void *const *bsks,
893+
void *const *ksks);
894+
895+
void cleanup_cuda_unchecked_match_value_or_64(CudaStreamsFFI streams,
896+
int8_t **mem_ptr_void);
853897
} // extern C
854898

855899
#endif // CUDA_INTEGER_H

backends/tfhe-cuda-backend/cuda/include/integer/vector_find.h

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#pragma once
2+
#include "cast.h"
23
#include "integer/comparison.h"
34
#include "integer/radix_ciphertext.cuh"
45
#include "integer_utilities.h"
@@ -593,3 +594,98 @@ template <typename Torus> struct int_unchecked_match_buffer {
593594
delete this->packed_selectors_ct;
594595
}
595596
};
597+
598+
template <typename Torus> struct int_unchecked_match_value_or_buffer {
599+
int_radix_params params;
600+
bool allocate_gpu_memory;
601+
602+
uint32_t num_matches;
603+
uint32_t num_input_blocks;
604+
uint32_t num_match_packed_blocks;
605+
uint32_t num_final_blocks;
606+
bool max_output_is_zero;
607+
608+
int_unchecked_match_buffer<Torus> *match_buffer;
609+
int_cmux_buffer<Torus> *cmux_buffer;
610+
611+
CudaRadixCiphertextFFI *tmp_match_result;
612+
CudaRadixCiphertextFFI *tmp_match_bool;
613+
CudaRadixCiphertextFFI *tmp_or_value;
614+
615+
Torus *d_or_value;
616+
617+
int_unchecked_match_value_or_buffer(
618+
CudaStreams streams, int_radix_params params, uint32_t num_matches,
619+
uint32_t num_input_blocks, uint32_t num_match_packed_blocks,
620+
uint32_t num_final_blocks, bool max_output_is_zero,
621+
bool allocate_gpu_memory, uint64_t &size_tracker) {
622+
this->params = params;
623+
this->allocate_gpu_memory = allocate_gpu_memory;
624+
this->num_matches = num_matches;
625+
this->num_input_blocks = num_input_blocks;
626+
this->num_match_packed_blocks = num_match_packed_blocks;
627+
this->num_final_blocks = num_final_blocks;
628+
this->max_output_is_zero = max_output_is_zero;
629+
630+
this->match_buffer = new int_unchecked_match_buffer<Torus>(
631+
streams, params, num_matches, num_input_blocks, num_match_packed_blocks,
632+
max_output_is_zero, allocate_gpu_memory, size_tracker);
633+
634+
this->cmux_buffer = new int_cmux_buffer<Torus>(
635+
streams, [](Torus x) -> Torus { return x == 1; }, params,
636+
num_final_blocks, allocate_gpu_memory, size_tracker);
637+
638+
this->tmp_match_result = new CudaRadixCiphertextFFI;
639+
this->tmp_match_bool = new CudaRadixCiphertextFFI;
640+
this->tmp_or_value = new CudaRadixCiphertextFFI;
641+
642+
this->d_or_value = (Torus *)cuda_malloc_with_size_tracking_async(
643+
num_final_blocks * sizeof(Torus), streams.stream(0),
644+
streams.gpu_index(0), size_tracker, allocate_gpu_memory);
645+
646+
if (!max_output_is_zero) {
647+
create_zero_radix_ciphertext_async<Torus>(
648+
streams.stream(0), streams.gpu_index(0), this->tmp_match_result,
649+
num_final_blocks, params.big_lwe_dimension, size_tracker,
650+
allocate_gpu_memory);
651+
}
652+
653+
create_zero_radix_ciphertext_async<Torus>(
654+
streams.stream(0), streams.gpu_index(0), this->tmp_match_bool, 1,
655+
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
656+
657+
create_zero_radix_ciphertext_async<Torus>(
658+
streams.stream(0), streams.gpu_index(0), this->tmp_or_value,
659+
num_final_blocks, params.big_lwe_dimension, size_tracker,
660+
allocate_gpu_memory);
661+
}
662+
663+
void release(CudaStreams streams) {
664+
this->match_buffer->release(streams);
665+
delete this->match_buffer;
666+
667+
this->cmux_buffer->release(streams);
668+
delete this->cmux_buffer;
669+
670+
if (!max_output_is_zero) {
671+
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
672+
this->tmp_match_result,
673+
this->allocate_gpu_memory);
674+
}
675+
delete this->tmp_match_result;
676+
677+
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
678+
this->tmp_match_bool,
679+
this->allocate_gpu_memory);
680+
delete this->tmp_match_bool;
681+
682+
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
683+
this->tmp_or_value,
684+
this->allocate_gpu_memory);
685+
delete this->tmp_or_value;
686+
687+
cuda_drop_async(this->d_or_value, streams.stream(0), streams.gpu_index(0));
688+
689+
cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0));
690+
}
691+
};

backends/tfhe-cuda-backend/cuda/src/integer/cast.cu

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,15 @@ void trim_radix_blocks_lsb_64(CudaRadixCiphertextFFI *output,
1818
cuda_synchronize_stream(cuda_streams.stream(0), cuda_streams.gpu_index(0));
1919
}
2020

21+
// 64-bit entry point: wraps the FFI stream set, forwards to
// host_trim_radix_blocks_msb<uint64_t>, and waits on stream 0 before
// returning to the caller.
void trim_radix_blocks_msb_64(CudaRadixCiphertextFFI *output,
                              CudaRadixCiphertextFFI const *input,
                              CudaStreamsFFI streams) {
  CudaStreams cuda_streams(streams);

  host_trim_radix_blocks_msb<uint64_t>(output, input, cuda_streams);

  cuda_synchronize_stream(cuda_streams.stream(0), cuda_streams.gpu_index(0));
}
29+
2130
uint64_t scratch_cuda_extend_radix_with_sign_msb_64(
2231
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
2332
uint32_t polynomial_size, uint32_t lwe_dimension, uint32_t ks_level,
@@ -64,3 +73,46 @@ void cleanup_cuda_extend_radix_with_sign_msb_64(CudaStreamsFFI streams,
6473
delete mem_ptr;
6574
*mem_ptr_void = nullptr;
6675
}
76+
77+
// 64-bit entry point for the cast-to-unsigned scratch allocation.
// Packs the cryptographic parameters into an int_radix_params, reinterprets
// mem_ptr as the typed buffer slot, and returns whatever
// scratch_cuda_cast_to_unsigned<uint64_t> reports.
uint64_t scratch_cuda_cast_to_unsigned_64(
    CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
    uint32_t polynomial_size, uint32_t big_lwe_dimension,
    uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
    uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
    uint32_t num_input_blocks, uint32_t target_num_blocks, bool input_is_signed,
    bool requires_full_propagate, uint32_t message_modulus,
    uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory,
    PBS_MS_REDUCTION_T noise_reduction_type) {

  int_radix_params params(pbs_type, glwe_dimension, polynomial_size,
                          big_lwe_dimension, small_lwe_dimension, ks_level,
                          ks_base_log, pbs_level, pbs_base_log, grouping_factor,
                          message_modulus, carry_modulus, noise_reduction_type);

  // The opaque FFI slot actually holds a typed buffer pointer.
  auto **buffer_slot = (int_cast_to_unsigned_buffer<uint64_t> **)mem_ptr;

  return scratch_cuda_cast_to_unsigned<uint64_t>(
      CudaStreams(streams), buffer_slot, params, num_input_blocks,
      target_num_blocks, input_is_signed, requires_full_propagate,
      allocate_gpu_memory);
}
97+
98+
// 64-bit entry point for the cast itself: recovers the typed scratch buffer
// from the opaque mem_ptr and forwards everything to
// host_cast_to_unsigned<uint64_t>.
void cuda_cast_to_unsigned_64(CudaStreamsFFI streams,
                              CudaRadixCiphertextFFI *output,
                              CudaRadixCiphertextFFI *input, int8_t *mem_ptr,
                              uint32_t target_num_blocks, bool input_is_signed,
                              void *const *bsks, void *const *ksks) {

  auto *cast_buffer = (int_cast_to_unsigned_buffer<uint64_t> *)mem_ptr;
  auto typed_ksks = (uint64_t **)ksks;

  host_cast_to_unsigned<uint64_t>(CudaStreams(streams), output, input,
                                  cast_buffer, target_num_blocks,
                                  input_is_signed, bsks, typed_ksks);
}
109+
110+
// Tears down the scratch buffer created by scratch_cuda_cast_to_unsigned_64:
// releases its resources, deletes the object, and clears the caller's slot.
void cleanup_cuda_cast_to_unsigned_64(CudaStreamsFFI streams,
                                      int8_t **mem_ptr_void) {
  auto *cast_buffer =
      reinterpret_cast<int_cast_to_unsigned_buffer<uint64_t> *>(*mem_ptr_void);

  cast_buffer->release(CudaStreams(streams));
  delete cast_buffer;

  *mem_ptr_void = nullptr;
}

backends/tfhe-cuda-backend/cuda/src/integer/cast.cuh

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,23 @@ __host__ void host_trim_radix_blocks_lsb(CudaRadixCiphertextFFI *output,
3636
input->num_radix_blocks);
3737
}
3838

39+
template <typename Torus>
40+
__host__ void
41+
host_trim_radix_blocks_msb(CudaRadixCiphertextFFI *output_radix,
42+
const CudaRadixCiphertextFFI *input_radix,
43+
CudaStreams streams) {
44+
45+
PANIC_IF_FALSE(input_radix->num_radix_blocks >=
46+
output_radix->num_radix_blocks,
47+
"Cuda error: input radix ciphertext has fewer blocks than "
48+
"required to keep");
49+
50+
copy_radix_ciphertext_slice_async<Torus>(
51+
streams.stream(0), streams.gpu_index(0), output_radix, 0,
52+
output_radix->num_radix_blocks, input_radix, 0,
53+
output_radix->num_radix_blocks);
54+
}
55+
3956
template <typename Torus>
4057
__host__ uint64_t scratch_extend_radix_with_sign_msb(
4158
CudaStreams streams, int_extend_radix_with_sign_msb_buffer<Torus> **mem_ptr,
@@ -91,4 +108,56 @@ __host__ void host_extend_radix_with_sign_msb(
91108
POP_RANGE()
92109
}
93110

111+
template <typename Torus>
112+
uint64_t scratch_cuda_cast_to_unsigned(
113+
CudaStreams streams, int_cast_to_unsigned_buffer<Torus> **mem_ptr,
114+
int_radix_params params, uint32_t num_input_blocks,
115+
uint32_t target_num_blocks, bool input_is_signed,
116+
bool requires_full_propagate, bool allocate_gpu_memory) {
117+
118+
uint64_t size_tracker = 0;
119+
*mem_ptr = new int_cast_to_unsigned_buffer<Torus>(
120+
streams, params, num_input_blocks, target_num_blocks, input_is_signed,
121+
requires_full_propagate, allocate_gpu_memory, size_tracker);
122+
123+
return size_tracker;
124+
}
125+
126+
template <typename Torus>
127+
__host__ void
128+
host_cast_to_unsigned(CudaStreams streams, CudaRadixCiphertextFFI *output,
129+
CudaRadixCiphertextFFI *input,
130+
int_cast_to_unsigned_buffer<Torus> *mem_ptr,
131+
uint32_t target_num_blocks, bool input_is_signed,
132+
void *const *bsks, Torus *const *ksks) {
133+
134+
uint32_t current_num_blocks = input->num_radix_blocks;
135+
136+
if (mem_ptr->requires_full_propagate) {
137+
host_full_propagate_inplace<Torus>(streams, input, mem_ptr->prop_buffer,
138+
ksks, bsks, current_num_blocks);
139+
}
140+
141+
if (target_num_blocks > current_num_blocks) {
142+
uint32_t num_blocks_to_add = target_num_blocks - current_num_blocks;
143+
144+
if (input_is_signed) {
145+
host_extend_radix_with_sign_msb<Torus>(
146+
streams, output, input, mem_ptr->extend_buffer, num_blocks_to_add,
147+
bsks, (Torus **)ksks);
148+
} else {
149+
host_extend_radix_with_trivial_zero_blocks_msb<Torus>(output, input,
150+
streams);
151+
}
152+
153+
} else if (target_num_blocks < current_num_blocks) {
154+
host_trim_radix_blocks_msb<Torus>(output, input, streams);
155+
156+
} else {
157+
copy_radix_ciphertext_slice_async<Torus>(
158+
streams.stream(0), streams.gpu_index(0), output, 0, current_num_blocks,
159+
input, 0, current_num_blocks);
160+
}
161+
}
162+
94163
#endif

0 commit comments

Comments
 (0)