# Module-level debugging switch.
# NOTE(review): presumably gates extra printing during split selection;
# verify against the remaining readers of this flag.
debug_split = False
# NOTE(review): no reader of max_leaves is visible in this part of the
# file; None presumably means "no limit" -- confirm before relying on it.
max_leaves = None
1212
def GetSortPerm(keys, *to_sort, n_bits=None, time=False):
    """
    Compute and return secret shared permutation that stably sorts ``keys``.

    Every key is bit-decomposed and the bit rows are stacked into one
    matrix, taking the keys in reverse order (presumably so that the first
    key ends up most significant for the radix sort).  That matrix and the
    transposed ``to_sort`` columns are handed to ``radix_sort_from_matrix``.

    NOTE(review): the return value is whatever ``radix_sort_from_matrix``
    returns; confirm that this is the permutation promised above.

    :param keys: sequence of equally long key vectors (must support
        ``len`` and ``get_vector``)
    :param to_sort: extra columns passed along to the radix sort; for
        ``sfix`` vectors the underlying ``.v`` values are used
    :param n_bits: optional list with one bit length per key, forwarded to
        ``bit_decompose``; defaults to ``None`` for every key
    :param time: accepted for signature compatibility; currently unused
    """
    for k in keys:
        assert len(k) == len(keys[0])
    n_bits = n_bits or [None] * len(keys)

    # Collect the bit rows of all keys, last key first.
    key_bits = []
    for k, nb in reversed(list(zip(keys, n_bits))):
        key_bits.extend(k.get_vector().bit_decompose(nb))
    bs = Matrix.create_from(key_bits)

    def as_column(x):
        # Arrays are read as a vector once; sfix vectors contribute their
        # integer representation, everything else is passed through as is.
        vec = x[:] if isinstance(x, Array) else x
        return vec.v if isinstance(vec, sfix) else x

    res = Matrix.create_from(as_column(x) for x in to_sort)
    res = res.transpose()
    return radix_sort_from_matrix(bs, res)
27+
def ApplyPermutation(perm, x):
    """
    Return a copy of ``x`` rearranged by ``reveal_sort`` under ``perm``.

    ``x`` is first copied with ``Array.create_from``; ``reveal_sort`` is
    then called on the copy with its third argument set to ``False``, and
    the copy is returned.

    :param perm: permutation handed to ``reveal_sort`` as first argument
    :param x: values to rearrange
    """
    permuted = Array.create_from(x)
    reveal_sort(perm, permuted, False)
    return permuted
32+
def ApplyInversePermutation(perm, x):
    """
    Return a copy of ``x`` rearranged by ``reveal_sort`` in reverse mode.

    ``x`` is first copied with ``Array.create_from``; ``reveal_sort`` is
    then called on the copy with its third argument set to ``True``
    (presumably applying the inverse of ``perm``), and the copy is returned.

    :param perm: permutation handed to ``reveal_sort`` as first argument
    :param x: values to rearrange
    """
    permuted = Array.create_from(x)
    reveal_sort(perm, permuted, True)
    return permuted
37+
1338def get_type (x ):
1439 if isinstance (x , (Array , SubMultiArray )):
1540 return x .value_type
@@ -22,6 +47,12 @@ def get_type(x):
2247 else :
2348 return type (x )
2449
50+
def Custom_GT_Fractions(x_num, x_den, y_num, y_den, n_threads=2):
    """
    Compare two fractions by cross-multiplication instead of division.

    Evaluates ``x_num * y_den > x_den * y_num`` and returns the result
    after storing it with ``Array.create_from`` and reading it back with
    ``get_vector``.

    NOTE(review): this equals ``x_num / x_den > y_num / y_den`` only when
    both denominators have the same sign -- confirm that callers guarantee
    this (and that the products fit the value range).

    :param x_num: numerator(s) of the left-hand fraction
    :param x_den: denominator(s) of the left-hand fraction
    :param y_num: numerator(s) of the right-hand fraction
    :param y_den: denominator(s) of the right-hand fraction
    :param n_threads: currently unused
    """
    is_greater = (x_num * y_den) > (x_den * y_num)
    return Array.create_from(is_greater).get_vector()
55+
def PrefixSum(x):
    """
    Return the prefix sum of the vector held by ``x``.

    :param x: container providing ``get_vector()``, whose result must
        provide ``prefix_sum()``
    """
    vector = x.get_vector()
    return vector.prefix_sum()
2758
@@ -86,17 +117,9 @@ def Sort(keys, *to_sort, n_bits=None, time=False):
86117
def VectMax(key, *data, debug=False):
    """
    Pick, by tree reduction, the row whose fraction is largest.

    Rows are formed with ``zip(key, *data)``.  Two rows ``x`` and ``y`` are
    compared via ``x[0] * y[1] > y[0] * x[1]``, i.e. the key is treated as
    a numerator and the first data column as its denominator, so at least
    one data column is required.  On a tie the second row is kept.

    NOTE(review): the cross-multiplied comparison matches the comparison
    of the fractions only for denominators of equal sign -- confirm.

    :param key: numerators to compare
    :param data: columns carried along; ``data[0]`` is used as denominator
    :param debug: currently unused
    :returns: the selected row without its leading key entry
    """
    def pick_larger(x, y):
        x_wins = x[0] * y[1] > y[0] * x[1]
        return [x_wins.if_else(xx, yy) for xx, yy in zip(x, y)]

    winner = util.tree_reduce(pick_larger, zip(key, *data))
    return winner[1:]
101124
102125def GroupSum (g , x ):
@@ -119,9 +142,6 @@ def GroupPrefixSum(g, x):
119142 return s .get_vector (size = len (x ), base = 1 ) - GroupSum (g , q )
120143
121144def GroupMax (g , keys , * x ):
122- if debug :
123- print_ln ('group max input g=%s keys=%s x=%s' , util .reveal (g ),
124- util .reveal (keys ), util .reveal (x ))
125145 assert len (keys ) == len (g )
126146 for xx in x :
127147 assert len (xx ) == len (g )
@@ -138,23 +158,17 @@ def GroupMax(g, keys, *x):
138158 vsize = n - w
139159 g_new .assign_vector (g_old .get_vector (size = vsize ).bit_or (
140160 g_old .get_vector (size = vsize , base = w )), base = w )
141- b = keys .get_vector (size = vsize ) > keys .get_vector (size = vsize , base = w )
161+ b = Custom_GT_Fractions (keys .get_vector (size = vsize ), x [0 ].get_vector (size = vsize ), keys .get_vector (size = vsize , base = w ), x [0 ].get_vector (size = vsize , base = w ))
162+ #b = keys.get_vector(size=vsize) > keys.get_vector(size=vsize, base=w)
142163 for xx in [keys ] + x :
143164 a = b .if_else (xx .get_vector (size = vsize ),
144165 xx .get_vector (size = vsize , base = w ))
145166 xx .assign_vector (g_old .get_vector (size = vsize , base = w ).if_else (
146167 xx .get_vector (size = vsize , base = w ), a ), base = w )
147168 break_point ()
148- if debug :
149- print_ln ('group max w=%s b=%s a=%s keys=%s x=%s g=%s' , w , b .reveal (),
150- util .reveal (a ), util .reveal (keys ),
151- util .reveal (x ), g_new .reveal ())
152169 t = sint .Array (len (g ))
153170 t [- 1 ] = 1
154171 t .assign_vector (g .get_vector (size = n - 1 , base = 1 ))
155- if debug :
156- print_ln ('group max end g=%s t=%s keys=%s x=%s' , util .reveal (g ),
157- util .reveal (t ), util .reveal (keys ), util .reveal (x ))
158172 return [GroupSum (g , t [:] * xx ) for xx in [keys ] + x ]
159173
160174def ComputeGini (g , x , y , notysum , ysum , debug = False ):
@@ -196,12 +210,8 @@ def FormatLayer_without_crop(g, *a, debug=False):
196210 for x in a :
197211 assert len (x ) == len (g )
198212 v = [g .if_else (aa , 0 ) for aa in a ]
199- if debug :
200- print_ln ('format in %s' , util .reveal (a ))
201- print_ln ('format mux %s' , util .reveal (v ))
202- v = Sort ([g .bit_not ()], * v , n_bits = [1 ])
203- if debug :
204- print_ln ('format sort %s' , util .reveal (v ))
213+ p = SortPerm (g .get_vector ().bit_not ())
214+ v = [p .apply (vv ) for vv in v ]
205215 return v
206216
207217def CropLayer (k , * v ):
@@ -216,36 +226,12 @@ def TrainLeafNodes(h, g, y, NID):
216226 assert len (g ) == len (NID )
217227 return FormatLayer (h , g , NID , Label )
218228
219- def GroupSame (g , y ):
220- assert len (g ) == len (y )
221- s = GroupSum (g , [sint (1 )] * len (g ))
222- s0 = GroupSum (g , y .bit_not ())
223- s1 = GroupSum (g , y )
224- if debug_split :
225- print_ln ('group same g=%s' , util .reveal (g ))
226- print_ln ('group same y=%s' , util .reveal (y ))
227- return (s == s0 ).bit_or (s == s1 )
228-
def GroupFirstOne(g, b):
    """
    Return ``GroupPrefixSum(g, b) * b == 1``.

    Presumably this flags, within each group described by ``g``, the first
    position at which ``b`` is set (assuming ``GroupPrefixSum`` is an
    inclusive per-group prefix sum) -- verify against ``GroupPrefixSum``.

    :param g: group indicator vector, same length as ``b``
    :param b: 0/1 vector to scan
    """
    assert len(g) == len(b)
    running = GroupPrefixSum(g, b)
    return running * b == 1
233233
234234class TreeTrainer :
235- """ Decision tree training by `Hamada et al.`_
236-
237- :param x: sample data (by attribute, list or
238- :py:obj:`~Compiler.types.Matrix`)
239- :param y: binary labels (list or sint vector)
240- :param h: height (int)
241- :param binary: binary attributes instead of continuous
242- :param attr_lengths: attribute description for mixed data
243- (list of 0/1 for continuous/binary)
244- :param n_threads: number of threads (default: single thread)
245-
246- .. _`Hamada et al.`: https://arxiv.org/abs/2112.12906
247-
248- """
249235 def GetInversePermutation (self , perm ):
250236 res = Array .create_from (self .identity_permutation )
251237 reveal_sort (perm , res )
@@ -258,14 +244,10 @@ def ApplyTests(self, x, AID, Threshold):
258244 for xx in x :
259245 assert len (xx ) == len (AID )
260246 e = sint .Matrix (m , n )
261- AID = Array .create_from (AID )
262247 @for_range_multithread (self .n_threads , 1 , m )
263248 def _ (j ):
264249 e [j ][:] = AID [:] == j
265250 xx = sum (x [j ] * e [j ] for j in range (m ))
266- if self .debug > 1 :
267- print_ln ('apply e=%s xx=%s' , util .reveal (e ), util .reveal (xx ))
268- print_ln ('threshold %s' , util .reveal (Threshold ))
269251 return 2 * xx < Threshold
270252
271253 def TestSelection (self , g , x , y , pis , notysum , ysum , time = False ):
@@ -353,36 +335,13 @@ def _(j):
353335
354336 return [g , x , y , NID , pis ]
355337
356- def TrainInternalNodes (self , k , x , y , g , NID ):
357- assert len (g ) == len (y )
358- for xx in x :
359- assert len (xx ) == len (g )
360- AID , Threshold = self .GlobalTestSelection (x , y , g )
361- s = GroupSame (g [:], y [:])
362- if self .debug > 1 or debug_split :
363- print_ln ('AID=%s' , util .reveal (AID ))
364- print_ln ('Threshold=%s' , util .reveal (Threshold ))
365- print_ln ('GroupSame=%s' , util .reveal (s ))
366- AID , Threshold = s .if_else (0 , AID ), s .if_else (MIN_VALUE , Threshold )
367- if self .debug > 1 or debug_split :
368- print_ln ('AID=%s' , util .reveal (AID ))
369- print_ln ('Threshold=%s' , util .reveal (Threshold ))
370- b = self .ApplyTests (x , AID , Threshold )
371- layer = FormatLayer_without_crop (g [:], NID , AID , Threshold ,
372- debug = self .debug > 1 )
373- return * layer , b
374-
375338 @method_block
376339 def train_layer (self , k ):
377340 x = self .x
378341 y = self .y
379342 g = self .g
380343 NID = self .NID
381344 pis = self .pis
382- if self .debug > 1 :
383- print_ln ('g=%s' , g .reveal ())
384- print_ln ('y=%s' , y .reveal ())
385- print_ln ('x=%s' , x .reveal_nested ())
386345
387346 s0 = GroupSum (g , y .get_vector ().bit_not ())
388347 s1 = GroupSum (g , y .get_vector ())
@@ -400,6 +359,21 @@ def _():
400359
401360 def __init__ (self , x , y , h , binary = False , attr_lengths = None ,
402361 n_threads = None ):
362+ """ Securely Training Decision Trees Efficiently by `Bhardwaj et al.`_ : https://eprint.iacr.org/2024/1077.pdf
363+
364+ This protocol has communication complexity O( mN logN + hmN + hN log N) which is an improvement of ~min(h, m, log N) over `Hamada et al.`_ : https://petsymposium.org/popets/2023/popets-2023-0021.pdf
365+
366+ To run this protocol, at the root of the MP-SPDZ repo, run Scripts/compile-run.py -H HOSTS -E ring custom_data_dt $((2**13)) 11 4 -Z 3 -R 128
367+
368+ :param x: Attribute values
369+ :param y: Binary labels
370+ :param h: Height of the decision tree
371+ :param binary: Binary attributes instead of continuous
372+ :param attr_lengths: Attribute description for mixed data
373+ (list of 0/1 for continuous/binary)
374+ :param n_threads: Number of threads
375+
376+ """
403377 assert not (binary and attr_lengths )
404378 if binary :
405379 attr_lengths = [1 ] * len (x )
@@ -412,6 +386,7 @@ def __init__(self, x, y, h, binary=False, attr_lengths=None,
412386 Matrix .disable_index_checks ()
413387 for xx in x :
414388 assert len (xx ) == len (y )
389+ m = len (x )
415390 n = len (y )
416391 self .g = sint .Array (n )
417392 self .g .assign_all (0 )
@@ -459,7 +434,7 @@ def train_with_testing(self, *test_set, output=False):
459434 """
460435 for k in range (len (self .nids )):
461436 self .train_layer (k )
462- tree = self .get_tree (k + 1 )
437+ tree = self .get_tree (k + 1 , self . label )
463438 if output :
464439 output_decision_tree (tree )
465440 test_decision_tree ('train' , tree , self .y , self .x ,
0 commit comments