Skip to content

Commit e4a76c2

Browse files
committed
Imported existing methods from decision_tree
1 parent caf1818 commit e4a76c2

File tree

2 files changed

+2
-109
lines changed

2 files changed

+2
-109
lines changed

Compiler/decision_tree_optimized.py

Lines changed: 2 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from Compiler.types import *
22
from Compiler.sorting import *
33
from Compiler.library import *
4-
from Compiler.decision_tree import PrefixSum, PrefixSumR, PrefixSum_inv, PrefixSumR_inv, SortPerm, GroupSum, GroupPrefixSum, GroupFirstOne
4+
from Compiler.decision_tree import get_type, PrefixSum, PrefixSumR, PrefixSum_inv, PrefixSumR_inv, SortPerm, GroupSum, GroupPrefixSum, GroupFirstOne, output_decision_tree, pick, run_decision_tree, test_decision_tree
55
from Compiler import util, oram
66

77
from itertools import accumulate
@@ -11,18 +11,6 @@
1111
debug_split = False
1212
max_leaves = None
1313

14-
def get_type(x):
15-
if isinstance(x, (Array, SubMultiArray)):
16-
return x.value_type
17-
elif isinstance(x, (tuple, list)):
18-
x = x[0] + x[-1]
19-
if util.is_constant(x):
20-
return cint
21-
else:
22-
return type(x)
23-
else:
24-
return type(x)
25-
2614
def GetSortPerm(keys, *to_sort, n_bits=None, time=False):
2715
"""
2816
Compute and return secret shared permutation that stably sorts :param keys.
@@ -36,7 +24,7 @@ def GetSortPerm(keys, *to_sort, n_bits=None, time=False):
3624
res = Matrix.create_from(get_vec(x).v if isinstance(get_vec(x), sfix) else x
3725
for x in to_sort)
3826
res = res.transpose()
39-
return radix_sort_permutation_from_matrix(bs, res)
27+
return radix_sort_from_matrix(bs, res)
4028

4129
def ApplyPermutation(perm, x):
4230
res = Array.create_from(x)
@@ -374,81 +362,6 @@ def get_tree(self, h, Label):
374362
def DecisionTreeTraining(x, y, h, binary=False):
375363
return TreeTrainer(x, y, h, binary=binary).train()
376364

377-
def output_decision_tree(layers):
378-
""" Print decision tree output by :py:class:`TreeTrainer`. """
379-
380-
print_ln('full model %s', util.reveal(layers))
381-
for i, layer in enumerate(layers[:-1]):
382-
print_ln('level %s:', i)
383-
for j, x in enumerate(('NID', 'AID', 'Thr')):
384-
print_ln(' %s: %s', x, util.reveal(layer[j]))
385-
print_ln('leaves:')
386-
for j, x in enumerate(('NID', 'result')):
387-
print_ln(' %s: %s', x, util.reveal(layers[-1][j]))
388-
389-
def pick(bits, x):
390-
if len(bits) == 1:
391-
return bits[0] * x[0]
392-
else:
393-
try:
394-
return x[0].dot_product(bits, x)
395-
except:
396-
return sum(aa * bb for aa, bb in zip(bits, x))
397-
398-
def run_decision_tree(layers, data):
399-
""" Run decision tree against sample data.
400-
401-
:param layers: tree output by :py:class:`TreeTrainer`
402-
:param data: sample data (:py:class:`~Compiler.types.Array`)
403-
:returns: binary label
404-
405-
"""
406-
h = len(layers) - 1
407-
index = 1
408-
for k, layer in enumerate(layers[:-1]):
409-
assert len(layer) == 3
410-
for x in layer:
411-
assert len(x) <= 2 ** k
412-
bits = layer[0].equal(index, k)
413-
threshold = pick(bits, layer[2])
414-
key_index = pick(bits, layer[1])
415-
if key_index.is_clear:
416-
key = data[key_index]
417-
else:
418-
key = pick(
419-
oram.demux(key_index.bit_decompose(util.log2(len(data)))), data)
420-
child = 2 * key < threshold
421-
index += child * 2 ** k
422-
bits = layers[h][0].equal(index, h)
423-
return pick(bits, layers[h][1])
424-
425-
def test_decision_tree(name, layers, y, x, n_threads=None, time=False):
426-
if time:
427-
start_timer(100)
428-
n = len(y)
429-
x = x.transpose().reveal()
430-
y = y.reveal()
431-
guess = regint.Array(n)
432-
truth = regint.Array(n)
433-
correct = regint.Array(2)
434-
parts = regint.Array(2)
435-
layers = [[Array.create_from(util.reveal(x)) for x in layer]
436-
for layer in layers]
437-
@for_range_multithread(n_threads, 1, n)
438-
def _(i):
439-
guess[i] = run_decision_tree([[part[:] for part in layer]
440-
for layer in layers], x[i]).reveal()
441-
truth[i] = y[i].reveal()
442-
@for_range(n)
443-
def _(i):
444-
parts[truth[i]] += 1
445-
c = (guess[i].bit_xor(truth[i]).bit_not())
446-
correct[truth[i]] += c
447-
print_ln('%s for height %s: %s/%s (%s/%s, %s/%s)', name, len(layers) - 1,
448-
sum(correct), n, correct[0], parts[0], correct[1], parts[1])
449-
if time:
450-
stop_timer(100)
451-
452365
class TreeClassifier:
453366
""" Tree classification that uses
454367
:py:class:`TreeTrainer` internally.

Compiler/sorting.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -73,24 +73,4 @@ def _():
7373
@library.else_
7474
def _():
7575
reveal_sort(h, D, reverse=True)
76-
77-
def radix_sort_permutation_from_matrix(bs, D):
78-
n = len(D)
79-
for b in bs:
80-
assert(len(b) == n)
81-
B = types.sint.Matrix(n, 2)
82-
h = types.Array.create_from(types.sint(types.regint.inc(n)))
83-
@library.for_range(len(bs))
84-
def _(i):
85-
b = bs[i]
86-
B.set_column(0, 1 - b.get_vector())
87-
B.set_column(1, b.get_vector())
88-
c = types.Array.create_from(dest_comp(B))
89-
reveal_sort(c, h, reverse=False)
90-
@library.if_e(i < len(bs) - 1)
91-
def _():
92-
reveal_sort(h, bs[i + 1], reverse=True)
93-
@library.else_
94-
def _():
95-
reveal_sort(h, D, reverse=True)
9676
return h

0 commit comments

Comments
 (0)