-
Notifications
You must be signed in to change notification settings - Fork 339
Rabbit implementation #1613
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Rabbit implementation #1613
Changes from all commits
d9c7e2b
b84f89c
d948fbc
c9eda84
00fef6a
bf3aa31
293b0da
76f017f
69903f7
56159b9
79eb954
d6a6180
6f2509d
d93a837
87aef70
33a657e
fcd3560
1c0feb0
0769176
bc5fd15
dc9b9af
62557d6
1f1d02a
f1fd91c
1af4340
440aa9c
3fef1ce
94ab935
aa28c06
dbf3c41
c22ccd9
afc4d43
c233a0c
44d300c
e858c43
de4a1d2
0ca76cf
dd7f9c3
10cb1da
8a88dab
75b5c97
8fd63ef
82054bf
5ebccab
5fcc8af
c439ccc
53d9d22
0ebd7cf
73987ac
816499c
c782d37
8efeb88
a2b4679
cf02aff
1066833
f944659
210c21a
a9b579e
931246b
6241f2c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -131,6 +131,25 @@ def print_ln(s='', *args, **kwargs): | |
| """ | ||
| print_str(str(s) + '\n', *args, **kwargs) | ||
|
|
||
|
|
||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. will also be deleted |
||
| def print_without_ln(s='', *args, **kwargs): | ||
| """ Print line, with optional args for adding variables/registers | ||
| with ``%s``. By default only player 0 outputs, but the ``-I`` | ||
| command-line option changes that. | ||
|
|
||
| :param s: Python string with same number of ``%s`` as length of :py:obj:`args` | ||
| :param args: list of public values (regint/cint/int/cfix/cfloat/localint) | ||
| :param print_secrets: whether to output secret shares | ||
|
|
||
| Example: | ||
|
|
||
| .. code:: | ||
|
|
||
| print_ln('a is %s.', a.reveal()) | ||
| """ | ||
| print_str(str(s), *args, **kwargs) | ||
|
|
||
|
|
||
| def print_both(s, end='\n'): | ||
| """ Print line during compilation and execution. """ | ||
| print(s, end=end) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -44,8 +44,52 @@ def trunc(self, a, k, m, signed): | |
| if m == 0: | ||
| return a | ||
| return self._trunc(a, k, m, signed) | ||
|
|
||
| def LTBits(self, R, x, BIT_SIZE): | ||
| library.print_ln("in LTBits") | ||
| R_bits = cint.bit_decompose(R, BIT_SIZE) | ||
| y = [x[i].bit_xor(R_bits[i]) for i in range(BIT_SIZE)] | ||
| z = floatingpoint.PreOpL(floatingpoint.or_op, y[::-1])[::-1] + [0] | ||
| w = [z[i] - z[i + 1] for i in range(BIT_SIZE)] | ||
|
|
||
| return types.sintbit(1) - types.sintbit(sum((R_bits[i] & w[i]) for i in range(BIT_SIZE))) | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This line was initially
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also, now the result of the comparison is no longer 0 or 1, but I had values like -1 and 2, so I think that there might be something wrong with the typing there. |
||
|
|
||
| def rabbitLTZ(self, x, BIT_SIZE = 64): | ||
| """ | ||
| s = (x <? 0) | ||
| BIT_SIZE: bit length of x | ||
| """ | ||
| length_eda = 64 # BIT_SIZE | ||
|
|
||
| M = P_VALUES[64] # TODO: get program.prime | ||
| R = 0 | ||
|
|
||
| r, r_bits = sint.get_edabit(length_eda, True) | ||
| masked_a = (x + r).reveal() | ||
| masked_b = (x + r + M - R).reveal() | ||
| w = [None, None, None, None] | ||
|
|
||
| library.print_ln("w1, comparing: masked_a=%s edabit=%s", masked_a, r.reveal()) | ||
| w[1] = self.LTBits(masked_a, r_bits, BIT_SIZE) | ||
|
|
||
| library.print_ln("w2, comparing: masked_b=%s edabit=%s", masked_b, r.reveal()) | ||
| w[2] = self.LTBits(masked_b, r_bits, BIT_SIZE) | ||
|
|
||
| library.print_ln("w3, comparing: masked_b=%s with zero", masked_a) | ||
| w[3] = cint(masked_b < 0) | ||
|
|
||
| result = w[1] - w[2] + w[3] | ||
|
|
||
| library.print_ln("w1=%s w2=%s w3=%s result=%s", w[1].reveal(), w[2].reveal(), w[3], result.reveal()) | ||
| return sint(1 - result) | ||
|
|
||
| def ltz(self, a, k): | ||
| library.print_ln("a=%s k=%s", a.reveal(), k) | ||
| prog = program.Program.prog | ||
| if prog.options.comparison_rabbit: | ||
| return self.rabbitLTZ(a, k) | ||
|
|
||
| # else, use truncation | ||
| return -self.trunc(a, k, k - 1, True) | ||
|
|
||
| class Masking(NonLinear): | ||
|
|
@@ -126,6 +170,10 @@ def eqz(self, a, k): | |
| return 1 - types.sintbit.conv(KORL(self.bit_dec(a, k, k, True))) | ||
|
|
||
| def ltz(self, a, k): | ||
| prog = program.Program.prog | ||
| if prog.options.comparison_rabbit: | ||
| return -20 | ||
|
|
||
| if k + 1 < self.prime.bit_length(): | ||
| # https://dl.acm.org/doi/10.1145/3474123.3486757 | ||
| # "negative" values wrap around when doubling, thus becoming odd | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5107,6 +5107,141 @@ def update(self, other): | |
| assert self.m == other.m | ||
| self.v.update(other.v) | ||
|
|
||
|
|
||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This class will be removed from types.py as it's not needed. The rabbit comparison logic is added in non_linear.py |
||
| class custom_sfix(sfix): | ||
|
|
||
| # R: clear-text; x: edabit in binary format | ||
| @vectorize | ||
| def LTBits(self, R, x, r, BIT_SIZE): | ||
| R_bits = cint.bit_decompose(R, BIT_SIZE) | ||
| library.print_ln("\nLTBits: R_bits= ") | ||
| for i in range(BIT_SIZE): | ||
| library.print_without_ln("%s", R_bits[i]) | ||
|
|
||
| edabit = cint.bit_decompose(r.reveal(), BIT_SIZE) | ||
| library.print_ln("\nLTBits: edabit= ") | ||
| for i in range(BIT_SIZE): | ||
| library.print_without_ln("%s", edabit[i].reveal()) | ||
|
|
||
| library.print_ln("\nLTBits from bits: x = ") | ||
| for i in range(BIT_SIZE): | ||
| library.print_without_ln("%s", x[i].reveal()) | ||
|
|
||
| y = [x[i].bit_xor(R_bits[i]) for i in range(BIT_SIZE)] | ||
| library.print_ln("\nLTBits: y= ") | ||
| for i in range(BIT_SIZE): | ||
| library.print_without_ln("%s", y[i].reveal()) | ||
|
|
||
| #z = floatingpoint.PreOpL(floatingpoint.or_op, y[::-1])[::-1] + [0] | ||
| z = floatingpoint.PreOpL(floatingpoint.or_op, y) | ||
| library.print_ln("\nLTBits: z= ") | ||
| for i in range(BIT_SIZE): | ||
| library.print_without_ln("%s", z[i].reveal()) | ||
|
|
||
| z = [0] + z | ||
| w = [z[i] - z[i - 1] for i in range(BIT_SIZE, 0, -1)] | ||
| w = w[::-1] | ||
| library.print_ln("\nLTBits: w= ") | ||
| for i in range(BIT_SIZE): | ||
| library.print_without_ln("%s", w[i].reveal()) | ||
|
|
||
| s = sum((R_bits[i] & w[i]) for i in range(BIT_SIZE)) | ||
| library.print_ln("\nSum=%s", s.reveal()) | ||
| return_value = 1 - sum((R_bits[i] & w[i]) for i in range(BIT_SIZE)) | ||
| return return_value | ||
|
|
||
| @vectorize | ||
| def rabbitLTZ(self, x, BIT_SIZE = 64): | ||
| """ | ||
| s = (c ?< a) | ||
|
|
||
| BIT_SIZE: bit length of a | ||
| """ | ||
| length_eda = BIT_SIZE | ||
| library.print_ln("custom sfix: bitsize = %s", BIT_SIZE) | ||
|
|
||
| M = P_VALUES[64] | ||
| R = (M - 1) // 2 # for field; use 0 for ring | ||
|
|
||
| r, r_bits = sint.get_edabit(length_eda, True) | ||
| masked_a = (x + r).reveal() | ||
| masked_b = (x + r + M - R).reveal() # masked_a | ||
| w = [None, None, None, None] | ||
|
|
||
| w[1] = self.LTBits(masked_a, r_bits, r, BIT_SIZE) | ||
| library.print_ln("w1, comparing: masked_a=%s edabit=%s w1=%s", masked_a, r.reveal(), w[1].reveal()) | ||
|
|
||
| w[2] = self.LTBits(masked_b, r_bits, r, BIT_SIZE) | ||
| library.print_ln("w2, comparing: masked_b=%s edabit=%s w2=%s", masked_b, r.reveal(), w[2].reveal()) | ||
|
|
||
| w[3] = cint(masked_b < 0) | ||
| library.print_ln("w3, comparing: masked_b=%s with %s, w3=%s", masked_b, M - R, w[3].reveal()) | ||
|
|
||
| result = w[1] - w[2] + w[3] | ||
|
|
||
| library.print_ln("final result = %s and 1- result=%s", result.reveal(), (1-result).reveal()) | ||
| return sint(1 - result) | ||
|
|
||
| @vectorize | ||
| def rabbitLTS_fix(self, a, b): | ||
| return 1 - self.rabbitLTS(b, a) | ||
|
|
||
| @vectorize | ||
| def rabbitLTS(self, a, b): | ||
| res = self.rabbitLTZ(a - b) | ||
| return res | ||
|
|
||
|
|
||
| # These are based on the implementation | ||
| # self.rabbitLTS(a, b) = rabbitLTC(a-b, 0) | ||
| @vectorize | ||
| def __lt__(self, other): | ||
| a = self.v | ||
| b = other.v | ||
| result = self.rabbitLTS_fix(a, b) | ||
| return result | ||
|
|
||
| @vectorize | ||
| def __le__(self, other): | ||
| a = self.v | ||
| b = other.v | ||
| result = 1 - self.rabbitLTS_fix(b, a) | ||
| return result | ||
|
|
||
| @vectorize | ||
| def __gt__(self, other): | ||
| a = self.v | ||
| b = other.v | ||
| result = self.rabbitLTS_fix(b, a) | ||
| return result | ||
|
|
||
| @vectorize | ||
| def __ge__(self, other): | ||
| a = self.v | ||
| b = other.v | ||
| result = 1 - self.rabbitLTS_fix(a, b) | ||
| return result | ||
|
|
||
| @vectorize | ||
| def __eq__(self, other): | ||
| a = self.v | ||
| b = other.v | ||
| result = (1 - self.rabbitLTS_fix(a, b)) * (1 - self.rabbitLTS_fix(b, a)) | ||
| return result | ||
|
|
||
| @vectorize | ||
| def __ne__(self, other): | ||
| a = self.v | ||
| b = other.v | ||
| result = 1 - (1 - self.rabbitLTS_fix(a, b)) * (1 - self.rabbitLTS_fix(b, a)) | ||
| return result | ||
|
|
||
| @vectorize | ||
| def __abs__(self): | ||
| """ Absolute value. """ | ||
| return (self < custom_sfix(0)).if_else(-self, self) | ||
|
|
||
|
|
||
| sfix.unreduced_type = unreduced_sfix | ||
|
|
||
| sfix.set_precision(16, 31) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tried setting the flag for comparison_rabbit with
program.use_comparison_rabbit(True)in the .mpc file (similar touse_edabits), but it always remained false, and I only managed to set it true by adding--comparison_rabitin the compilation instruction. Do you see what I am doing wrong there?