Skip to content

Commit 19bbc82

Browse files
committed
refactor(gpu): vector_find's functions to backend
1 parent 23c49e2 commit 19bbc82

File tree

9 files changed

+4554
-1756
lines changed

9 files changed

+4554
-1756
lines changed

backends/tfhe-cuda-backend/cuda/include/integer/integer.h

Lines changed: 179 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -491,23 +491,6 @@ void cuda_integer_div_rem_radix_ciphertext_64(
491491
void cleanup_cuda_integer_div_rem(CudaStreamsFFI streams,
492492
int8_t **mem_ptr_void);
493493

494-
uint64_t scratch_cuda_integer_compute_prefix_sum_hillis_steele_64(
495-
CudaStreamsFFI streams, int8_t **mem_ptr, void const *input_lut,
496-
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
497-
uint32_t ks_level, uint32_t ks_base_log, uint32_t pbs_level,
498-
uint32_t pbs_base_log, uint32_t grouping_factor, uint32_t num_radix_blocks,
499-
uint32_t message_modulus, uint32_t carry_modulus, PBS_TYPE pbs_type,
500-
uint64_t lut_degree, bool allocate_gpu_memory,
501-
PBS_MS_REDUCTION_T noise_reduction_type);
502-
503-
void cuda_integer_compute_prefix_sum_hillis_steele_64(
504-
CudaStreamsFFI streams, CudaRadixCiphertextFFI *output_radix_lwe,
505-
CudaRadixCiphertextFFI *generates_or_propagates, int8_t *mem_ptr,
506-
void *const *ksks, void *const *bsks, uint32_t num_blocks);
507-
508-
void cleanup_cuda_integer_compute_prefix_sum_hillis_steele_64(
509-
CudaStreamsFFI streams, int8_t **mem_ptr_void);
510-
511494
void cuda_integer_reverse_blocks_64_inplace(CudaStreamsFFI streams,
512495
CudaRadixCiphertextFFI *lwe_array);
513496

@@ -781,60 +764,6 @@ void cuda_integer_ilog2_64(
781764
void cleanup_cuda_integer_ilog2_64(CudaStreamsFFI streams,
782765
int8_t **mem_ptr_void);
783766

784-
uint64_t scratch_cuda_compute_equality_selectors_64(
785-
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
786-
uint32_t polynomial_size, uint32_t big_lwe_dimension,
787-
uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
788-
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
789-
uint32_t num_possible_values, uint32_t num_blocks, uint32_t message_modulus,
790-
uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory,
791-
PBS_MS_REDUCTION_T noise_reduction_type);
792-
793-
void cuda_compute_equality_selectors_64(
794-
CudaStreamsFFI streams, CudaRadixCiphertextFFI *lwe_array_out_list,
795-
CudaRadixCiphertextFFI const *lwe_array_in, uint32_t num_blocks,
796-
const uint64_t *h_decomposed_cleartexts, int8_t *mem, void *const *bsks,
797-
void *const *ksks);
798-
799-
void cleanup_cuda_compute_equality_selectors_64(CudaStreamsFFI streams,
800-
int8_t **mem_ptr_void);
801-
802-
uint64_t scratch_cuda_create_possible_results_64(
803-
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
804-
uint32_t polynomial_size, uint32_t big_lwe_dimension,
805-
uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
806-
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
807-
uint32_t num_possible_values, uint32_t num_blocks, uint32_t message_modulus,
808-
uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory,
809-
PBS_MS_REDUCTION_T noise_reduction_type);
810-
811-
void cuda_create_possible_results_64(
812-
CudaStreamsFFI streams, CudaRadixCiphertextFFI *lwe_array_out_list,
813-
CudaRadixCiphertextFFI const *lwe_array_in_list,
814-
uint32_t num_possible_values, const uint64_t *h_decomposed_cleartexts,
815-
uint32_t num_blocks, int8_t *mem, void *const *bsks, void *const *ksks);
816-
817-
void cleanup_cuda_create_possible_results_64(CudaStreamsFFI streams,
818-
int8_t **mem_ptr_void);
819-
820-
uint64_t scratch_cuda_aggregate_one_hot_vector_64(
821-
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
822-
uint32_t polynomial_size, uint32_t big_lwe_dimension,
823-
uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
824-
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
825-
uint32_t num_blocks, uint32_t num_matches, uint32_t message_modulus,
826-
uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory,
827-
PBS_MS_REDUCTION_T noise_reduction_type);
828-
829-
void cuda_aggregate_one_hot_vector_64(
830-
CudaStreamsFFI streams, CudaRadixCiphertextFFI *lwe_array_out,
831-
CudaRadixCiphertextFFI const *lwe_array_in_list,
832-
uint32_t num_input_ciphertexts, uint32_t num_blocks, int8_t *mem,
833-
void *const *bsks, void *const *ksks);
834-
835-
void cleanup_cuda_aggregate_one_hot_vector_64(CudaStreamsFFI streams,
836-
int8_t **mem_ptr_void);
837-
838767
uint64_t scratch_cuda_unchecked_match_value_64(
839768
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
840769
uint32_t polynomial_size, uint32_t big_lwe_dimension,
@@ -894,6 +823,185 @@ void cuda_unchecked_match_value_or_64(
894823

895824
void cleanup_cuda_unchecked_match_value_or_64(CudaStreamsFFI streams,
896825
int8_t **mem_ptr_void);
826+
827+
uint64_t scratch_cuda_unchecked_contains_64(
828+
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
829+
uint32_t polynomial_size, uint32_t big_lwe_dimension,
830+
uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
831+
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
832+
uint32_t num_inputs, uint32_t num_blocks, uint32_t message_modulus,
833+
uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory,
834+
PBS_MS_REDUCTION_T noise_reduction_type);
835+
836+
void cuda_unchecked_contains_64(CudaStreamsFFI streams,
837+
CudaRadixCiphertextFFI *output,
838+
CudaRadixCiphertextFFI const *inputs,
839+
CudaRadixCiphertextFFI const *value,
840+
uint32_t num_inputs, uint32_t num_blocks,
841+
int8_t *mem, void *const *bsks,
842+
void *const *ksks);
843+
844+
void cleanup_cuda_unchecked_contains_64(CudaStreamsFFI streams,
845+
int8_t **mem_ptr_void);
846+
847+
uint64_t scratch_cuda_unchecked_contains_clear_64(
848+
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
849+
uint32_t polynomial_size, uint32_t big_lwe_dimension,
850+
uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
851+
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
852+
uint32_t num_inputs, uint32_t num_blocks, uint32_t message_modulus,
853+
uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory,
854+
PBS_MS_REDUCTION_T noise_reduction_type);
855+
856+
void cuda_unchecked_contains_clear_64(CudaStreamsFFI streams,
857+
CudaRadixCiphertextFFI *output,
858+
CudaRadixCiphertextFFI const *inputs,
859+
const uint64_t *h_clear_val,
860+
uint32_t num_inputs, uint32_t num_blocks,
861+
int8_t *mem, void *const *bsks,
862+
void *const *ksks);
863+
864+
void cleanup_cuda_unchecked_contains_clear_64(CudaStreamsFFI streams,
865+
int8_t **mem_ptr_void);
866+
867+
uint64_t scratch_cuda_unchecked_is_in_clears_64(
868+
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
869+
uint32_t polynomial_size, uint32_t big_lwe_dimension,
870+
uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
871+
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
872+
uint32_t num_clears, uint32_t num_blocks, uint32_t message_modulus,
873+
uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory,
874+
PBS_MS_REDUCTION_T noise_reduction_type);
875+
876+
void cuda_unchecked_is_in_clears_64(CudaStreamsFFI streams,
877+
CudaRadixCiphertextFFI *output,
878+
CudaRadixCiphertextFFI const *input,
879+
const uint64_t *h_cleartexts,
880+
uint32_t num_clears, uint32_t num_blocks,
881+
int8_t *mem, void *const *bsks,
882+
void *const *ksks);
883+
884+
void cleanup_cuda_unchecked_is_in_clears_64(CudaStreamsFFI streams,
885+
int8_t **mem_ptr_void);
886+
887+
uint64_t scratch_cuda_compute_final_index_from_selectors_64(
888+
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
889+
uint32_t polynomial_size, uint32_t big_lwe_dimension,
890+
uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
891+
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
892+
uint32_t num_inputs, uint32_t num_blocks_index, uint32_t message_modulus,
893+
uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory,
894+
PBS_MS_REDUCTION_T noise_reduction_type);
895+
896+
void cuda_compute_final_index_from_selectors_64(
897+
CudaStreamsFFI streams, CudaRadixCiphertextFFI *index_ct,
898+
CudaRadixCiphertextFFI *match_ct, CudaRadixCiphertextFFI const *selectors,
899+
uint32_t num_inputs, uint32_t num_blocks_index, int8_t *mem,
900+
void *const *bsks, void *const *ksks);
901+
902+
void cleanup_cuda_compute_final_index_from_selectors_64(CudaStreamsFFI streams,
903+
int8_t **mem_ptr_void);
904+
905+
uint64_t scratch_cuda_unchecked_index_in_clears_64(
906+
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
907+
uint32_t polynomial_size, uint32_t big_lwe_dimension,
908+
uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
909+
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
910+
uint32_t num_clears, uint32_t num_blocks, uint32_t num_blocks_index,
911+
uint32_t message_modulus, uint32_t carry_modulus, PBS_TYPE pbs_type,
912+
bool allocate_gpu_memory, PBS_MS_REDUCTION_T noise_reduction_type);
913+
914+
void cuda_unchecked_index_in_clears_64(CudaStreamsFFI streams,
915+
CudaRadixCiphertextFFI *index_ct,
916+
CudaRadixCiphertextFFI *match_ct,
917+
CudaRadixCiphertextFFI const *input,
918+
const uint64_t *h_cleartexts,
919+
uint32_t num_clears, uint32_t num_blocks,
920+
uint32_t num_blocks_index, int8_t *mem,
921+
void *const *bsks, void *const *ksks);
922+
923+
void cleanup_cuda_unchecked_index_in_clears_64(CudaStreamsFFI streams,
924+
int8_t **mem_ptr_void);
925+
926+
uint64_t scratch_cuda_unchecked_first_index_in_clears_64(
927+
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
928+
uint32_t polynomial_size, uint32_t big_lwe_dimension,
929+
uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
930+
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
931+
uint32_t num_unique, uint32_t num_blocks, uint32_t num_blocks_index,
932+
uint32_t message_modulus, uint32_t carry_modulus, PBS_TYPE pbs_type,
933+
bool allocate_gpu_memory, PBS_MS_REDUCTION_T noise_reduction_type);
934+
935+
void cuda_unchecked_first_index_in_clears_64(
936+
CudaStreamsFFI streams, CudaRadixCiphertextFFI *index_ct,
937+
CudaRadixCiphertextFFI *match_ct, CudaRadixCiphertextFFI const *input,
938+
const uint64_t *h_unique_values, const uint64_t *h_unique_indices,
939+
uint32_t num_unique, uint32_t num_blocks, uint32_t num_blocks_index,
940+
int8_t *mem, void *const *bsks, void *const *ksks);
941+
942+
void cleanup_cuda_unchecked_first_index_in_clears_64(CudaStreamsFFI streams,
943+
int8_t **mem_ptr_void);
944+
945+
uint64_t scratch_cuda_unchecked_first_index_of_clear_64(
946+
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
947+
uint32_t polynomial_size, uint32_t big_lwe_dimension,
948+
uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
949+
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
950+
uint32_t num_inputs, uint32_t num_blocks, uint32_t num_blocks_index,
951+
uint32_t message_modulus, uint32_t carry_modulus, PBS_TYPE pbs_type,
952+
bool allocate_gpu_memory, PBS_MS_REDUCTION_T noise_reduction_type);
953+
954+
void cuda_unchecked_first_index_of_clear_64(
955+
CudaStreamsFFI streams, CudaRadixCiphertextFFI *index_ct,
956+
CudaRadixCiphertextFFI *match_ct, CudaRadixCiphertextFFI const *inputs,
957+
const uint64_t *h_clear_val, uint32_t num_inputs, uint32_t num_blocks,
958+
uint32_t num_blocks_index, int8_t *mem, void *const *bsks,
959+
void *const *ksks);
960+
961+
void cleanup_cuda_unchecked_first_index_of_clear_64(CudaStreamsFFI streams,
962+
int8_t **mem_ptr_void);
963+
964+
uint64_t scratch_cuda_unchecked_first_index_of_64(
965+
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
966+
uint32_t polynomial_size, uint32_t big_lwe_dimension,
967+
uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
968+
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
969+
uint32_t num_inputs, uint32_t num_blocks, uint32_t num_blocks_index,
970+
uint32_t message_modulus, uint32_t carry_modulus, PBS_TYPE pbs_type,
971+
bool allocate_gpu_memory, PBS_MS_REDUCTION_T noise_reduction_type);
972+
973+
void cuda_unchecked_first_index_of_64(CudaStreamsFFI streams,
974+
CudaRadixCiphertextFFI *index_ct,
975+
CudaRadixCiphertextFFI *match_ct,
976+
CudaRadixCiphertextFFI const *inputs,
977+
CudaRadixCiphertextFFI const *value,
978+
uint32_t num_inputs, uint32_t num_blocks,
979+
uint32_t num_blocks_index, int8_t *mem,
980+
void *const *bsks, void *const *ksks);
981+
982+
void cleanup_cuda_unchecked_first_index_of_64(CudaStreamsFFI streams,
983+
int8_t **mem_ptr_void);
984+
985+
uint64_t scratch_cuda_unchecked_index_of_64(
986+
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
987+
uint32_t polynomial_size, uint32_t big_lwe_dimension,
988+
uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
989+
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
990+
uint32_t num_inputs, uint32_t num_blocks, uint32_t num_blocks_index,
991+
uint32_t message_modulus, uint32_t carry_modulus, PBS_TYPE pbs_type,
992+
bool allocate_gpu_memory, PBS_MS_REDUCTION_T noise_reduction_type);
993+
994+
void cuda_unchecked_index_of_64(CudaStreamsFFI streams,
995+
CudaRadixCiphertextFFI *index_ct,
996+
CudaRadixCiphertextFFI *match_ct,
997+
CudaRadixCiphertextFFI const *inputs,
998+
CudaRadixCiphertextFFI const *value,
999+
uint32_t num_inputs, uint32_t num_blocks,
1000+
uint32_t num_blocks_index, int8_t *mem,
1001+
void *const *bsks, void *const *ksks);
1002+
1003+
void cleanup_cuda_unchecked_index_of_64(CudaStreamsFFI streams,
1004+
int8_t **mem_ptr_void);
8971005
} // extern C
8981006

8991007
#endif // CUDA_INTEGER_H

0 commit comments

Comments
 (0)