Skip to content

Commit 0a4417d

Browse files
committed
Unified decision_tree and decision_tree_optimized, testing due
1 parent 0a9d5e8 commit 0a4417d

File tree

1 file changed

+53
-78
lines changed

1 file changed

+53
-78
lines changed

Compiler/decision_tree.py

Lines changed: 53 additions & 78 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,31 @@
1010
debug_split = False
1111
max_leaves = None
1212

13+
def GetSortPerm(keys, *to_sort, n_bits=None, time=False):
14+
"""
15+
Compute and return the secret-shared permutation that stably sorts ``keys``.
16+
"""
17+
for k in keys:
18+
assert len(k) == len(keys[0])
19+
n_bits = n_bits or [None] * len(keys)
20+
bs = Matrix.create_from(sum([k.get_vector().bit_decompose(nb)
21+
for k, nb in reversed(list(zip(keys, n_bits)))], []))
22+
get_vec = lambda x: x[:] if isinstance(x, Array) else x
23+
res = Matrix.create_from(get_vec(x).v if isinstance(get_vec(x), sfix) else x
24+
for x in to_sort)
25+
res = res.transpose()
26+
return radix_sort_from_matrix(bs, res)
27+
28+
def ApplyPermutation(perm, x):
29+
res = Array.create_from(x)
30+
reveal_sort(perm, res, False)
31+
return res
32+
33+
def ApplyInversePermutation(perm, x):
34+
res = Array.create_from(x)
35+
reveal_sort(perm, res, True)
36+
return res
37+
1338
def get_type(x):
1439
if isinstance(x, (Array, SubMultiArray)):
1540
return x.value_type
@@ -22,6 +47,12 @@ def get_type(x):
2247
else:
2348
return type(x)
2449

50+
51+
def Custom_GT_Fractions(x_num, x_den, y_num, y_den, n_threads=2):
52+
b = (x_num*y_den) > (x_den*y_num)
53+
b = Array.create_from(b).get_vector()
54+
return b
55+
2556
def PrefixSum(x):
2657
return x.get_vector().prefix_sum()
2758

@@ -86,17 +117,9 @@ def Sort(keys, *to_sort, n_bits=None, time=False):
86117

87118
def VectMax(key, *data, debug=False):
88119
def reducer(x, y):
89-
b = x[0] > y[0]
90-
if debug:
91-
print_ln('max b=%s', b.reveal())
120+
b = x[0]*y[1] > y[0]*x[1]
92121
return [b.if_else(xx, yy) for xx, yy in zip(x, y)]
93-
if debug:
94-
key = list(key)
95-
data = [list(x) for x in data]
96-
print_ln('vect max key=%s data=%s', util.reveal(key), util.reveal(data))
97122
res = util.tree_reduce(reducer, zip(key, *data))[1:]
98-
if debug:
99-
print_ln('vect max res=%s', util.reveal(res))
100123
return res
101124

102125
def GroupSum(g, x):
@@ -119,9 +142,6 @@ def GroupPrefixSum(g, x):
119142
return s.get_vector(size=len(x), base=1) - GroupSum(g, q)
120143

121144
def GroupMax(g, keys, *x):
122-
if debug:
123-
print_ln('group max input g=%s keys=%s x=%s', util.reveal(g),
124-
util.reveal(keys), util.reveal(x))
125145
assert len(keys) == len(g)
126146
for xx in x:
127147
assert len(xx) == len(g)
@@ -138,23 +158,17 @@ def GroupMax(g, keys, *x):
138158
vsize = n - w
139159
g_new.assign_vector(g_old.get_vector(size=vsize).bit_or(
140160
g_old.get_vector(size=vsize, base=w)), base=w)
141-
b = keys.get_vector(size=vsize) > keys.get_vector(size=vsize, base=w)
161+
b = Custom_GT_Fractions(keys.get_vector(size=vsize), x[0].get_vector(size=vsize), keys.get_vector(size=vsize, base=w), x[0].get_vector(size=vsize, base=w))
162+
#b = keys.get_vector(size=vsize) > keys.get_vector(size=vsize, base=w)
142163
for xx in [keys] + x:
143164
a = b.if_else(xx.get_vector(size=vsize),
144165
xx.get_vector(size=vsize, base=w))
145166
xx.assign_vector(g_old.get_vector(size=vsize, base=w).if_else(
146167
xx.get_vector(size=vsize, base=w), a), base=w)
147168
break_point()
148-
if debug:
149-
print_ln('group max w=%s b=%s a=%s keys=%s x=%s g=%s', w, b.reveal(),
150-
util.reveal(a), util.reveal(keys),
151-
util.reveal(x), g_new.reveal())
152169
t = sint.Array(len(g))
153170
t[-1] = 1
154171
t.assign_vector(g.get_vector(size=n - 1, base=1))
155-
if debug:
156-
print_ln('group max end g=%s t=%s keys=%s x=%s', util.reveal(g),
157-
util.reveal(t), util.reveal(keys), util.reveal(x))
158172
return [GroupSum(g, t[:] * xx) for xx in [keys] + x]
159173

160174
def ComputeGini(g, x, y, notysum, ysum, debug=False):
@@ -196,12 +210,8 @@ def FormatLayer_without_crop(g, *a, debug=False):
196210
for x in a:
197211
assert len(x) == len(g)
198212
v = [g.if_else(aa, 0) for aa in a]
199-
if debug:
200-
print_ln('format in %s', util.reveal(a))
201-
print_ln('format mux %s', util.reveal(v))
202-
v = Sort([g.bit_not()], *v, n_bits=[1])
203-
if debug:
204-
print_ln('format sort %s', util.reveal(v))
213+
p = SortPerm(g.get_vector().bit_not())
214+
v = [p.apply(vv) for vv in v]
205215
return v
206216

207217
def CropLayer(k, *v):
@@ -216,36 +226,12 @@ def TrainLeafNodes(h, g, y, NID):
216226
assert len(g) == len(NID)
217227
return FormatLayer(h, g, NID, Label)
218228

219-
def GroupSame(g, y):
220-
assert len(g) == len(y)
221-
s = GroupSum(g, [sint(1)] * len(g))
222-
s0 = GroupSum(g, y.bit_not())
223-
s1 = GroupSum(g, y)
224-
if debug_split:
225-
print_ln('group same g=%s', util.reveal(g))
226-
print_ln('group same y=%s', util.reveal(y))
227-
return (s == s0).bit_or(s == s1)
228-
229229
def GroupFirstOne(g, b):
230230
assert len(g) == len(b)
231231
s = GroupPrefixSum(g, b)
232232
return s * b == 1
233233

234234
class TreeTrainer:
235-
""" Decision tree training by `Hamada et al.`_
236-
237-
:param x: sample data (by attribute, list or
238-
:py:obj:`~Compiler.types.Matrix`)
239-
:param y: binary labels (list or sint vector)
240-
:param h: height (int)
241-
:param binary: binary attributes instead of continuous
242-
:param attr_lengths: attribute description for mixed data
243-
(list of 0/1 for continuous/binary)
244-
:param n_threads: number of threads (default: single thread)
245-
246-
.. _`Hamada et al.`: https://arxiv.org/abs/2112.12906
247-
248-
"""
249235
def GetInversePermutation(self, perm):
250236
res = Array.create_from(self.identity_permutation)
251237
reveal_sort(perm, res)
@@ -258,14 +244,10 @@ def ApplyTests(self, x, AID, Threshold):
258244
for xx in x:
259245
assert len(xx) == len(AID)
260246
e = sint.Matrix(m, n)
261-
AID = Array.create_from(AID)
262247
@for_range_multithread(self.n_threads, 1, m)
263248
def _(j):
264249
e[j][:] = AID[:] == j
265250
xx = sum(x[j] * e[j] for j in range(m))
266-
if self.debug > 1:
267-
print_ln('apply e=%s xx=%s', util.reveal(e), util.reveal(xx))
268-
print_ln('threshold %s', util.reveal(Threshold))
269251
return 2 * xx < Threshold
270252

271253
def TestSelection(self, g, x, y, pis, notysum, ysum, time=False):
@@ -353,36 +335,13 @@ def _(j):
353335

354336
return [g, x, y, NID, pis]
355337

356-
def TrainInternalNodes(self, k, x, y, g, NID):
357-
assert len(g) == len(y)
358-
for xx in x:
359-
assert len(xx) == len(g)
360-
AID, Threshold = self.GlobalTestSelection(x, y, g)
361-
s = GroupSame(g[:], y[:])
362-
if self.debug > 1 or debug_split:
363-
print_ln('AID=%s', util.reveal(AID))
364-
print_ln('Threshold=%s', util.reveal(Threshold))
365-
print_ln('GroupSame=%s', util.reveal(s))
366-
AID, Threshold = s.if_else(0, AID), s.if_else(MIN_VALUE, Threshold)
367-
if self.debug > 1 or debug_split:
368-
print_ln('AID=%s', util.reveal(AID))
369-
print_ln('Threshold=%s', util.reveal(Threshold))
370-
b = self.ApplyTests(x, AID, Threshold)
371-
layer = FormatLayer_without_crop(g[:], NID, AID, Threshold,
372-
debug=self.debug > 1)
373-
return *layer, b
374-
375338
@method_block
376339
def train_layer(self, k):
377340
x = self.x
378341
y = self.y
379342
g = self.g
380343
NID = self.NID
381344
pis = self.pis
382-
if self.debug > 1:
383-
print_ln('g=%s', g.reveal())
384-
print_ln('y=%s', y.reveal())
385-
print_ln('x=%s', x.reveal_nested())
386345

387346
s0 = GroupSum(g, y.get_vector().bit_not())
388347
s1 = GroupSum(g, y.get_vector())
@@ -400,6 +359,21 @@ def _():
400359

401360
def __init__(self, x, y, h, binary=False, attr_lengths=None,
402361
n_threads=None):
362+
""" Securely Training Decision Trees Efficiently by `Bhardwaj et al.`_ : https://eprint.iacr.org/2024/1077.pdf
363+
364+
This protocol has communication complexity O(mN log N + hmN + hN log N), which is an improvement of ~min(h, m, log N) over `Hamada et al.`_ : https://petsymposium.org/popets/2023/popets-2023-0021.pdf
365+
366+
To run this protocol, at the root of the MP-SPDZ repo, run Scripts/compile-run.py -H HOSTS -E ring custom_data_dt $((2**13)) 11 4 -Z 3 -R 128
367+
368+
:param x: Attribute values
369+
:param y: Binary labels
370+
:param h: Height of the decision tree
371+
:param binary: Binary attributes instead of continuous
372+
:param attr_lengths: Attribute description for mixed data
373+
(list of 0/1 for continuous/binary)
374+
:param n_threads: Number of threads
375+
376+
"""
403377
assert not (binary and attr_lengths)
404378
if binary:
405379
attr_lengths = [1] * len(x)
@@ -412,6 +386,7 @@ def __init__(self, x, y, h, binary=False, attr_lengths=None,
412386
Matrix.disable_index_checks()
413387
for xx in x:
414388
assert len(xx) == len(y)
389+
m = len(x)
415390
n = len(y)
416391
self.g = sint.Array(n)
417392
self.g.assign_all(0)
@@ -459,7 +434,7 @@ def train_with_testing(self, *test_set, output=False):
459434
"""
460435
for k in range(len(self.nids)):
461436
self.train_layer(k)
462-
tree = self.get_tree(k + 1)
437+
tree = self.get_tree(k + 1, self.label)
463438
if output:
464439
output_decision_tree(tree)
465440
test_decision_tree('train', tree, self.y, self.x,

0 commit comments

Comments
 (0)