Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
60 commits
Select commit Hold shift + click to select a range
d9c7e2b
custom_sfix using rabbit
Poppy22 Jan 11, 2025
b84f89c
replace not with 1 -
Poppy22 Jan 11, 2025
d948fbc
replace and with *
Poppy22 Jan 11, 2025
c9eda84
adding fixes after fixes
Poppy22 Feb 1, 2025
00fef6a
try to fix it
Poppy22 Feb 1, 2025
bf3aa31
add modulo
Poppy22 Feb 1, 2025
293b0da
add print_ln
Poppy22 Feb 1, 2025
76f017f
new day, new attempt
Poppy22 Feb 2, 2025
69903f7
add self
Poppy22 Feb 2, 2025
56159b9
debug
Poppy22 Feb 2, 2025
79eb954
add reveal
Poppy22 Feb 2, 2025
d6a6180
fix printing
Poppy22 Feb 2, 2025
6f2509d
p of 64 bits
Poppy22 Feb 2, 2025
d93a837
infinite debugging
Poppy22 Feb 2, 2025
87aef70
add reveal
Poppy22 Feb 2, 2025
33a657e
trying to print a binary value.....
Poppy22 Feb 2, 2025
fcd3560
still debugging...
Poppy22 Feb 2, 2025
1c0feb0
a=try dragos
Poppy22 Feb 2, 2025
0769176
going back
Poppy22 Feb 2, 2025
bc5fd15
add .output()
Poppy22 Feb 2, 2025
dc9b9af
trying to print...
Poppy22 Feb 2, 2025
62557d6
use print_ln
Poppy22 Feb 3, 2025
1f1d02a
add %s
Poppy22 Feb 3, 2025
f1fd91c
add prints
Poppy22 Feb 5, 2025
1af4340
add prints
Poppy22 Feb 5, 2025
440aa9c
still
Poppy22 Feb 5, 2025
3fef1ce
fix typo
Poppy22 Feb 5, 2025
94ab935
add y
Poppy22 Feb 5, 2025
aa28c06
return z1
Poppy22 Feb 11, 2025
dbf3c41
trying to fix w3
Poppy22 Feb 11, 2025
c22ccd9
new w3
Poppy22 Feb 11, 2025
afc4d43
fix truth value
Poppy22 Feb 11, 2025
c233a0c
add trick to remove if-branch
Poppy22 Feb 11, 2025
44d300c
add print
Poppy22 Feb 11, 2025
e858c43
try with R 0
Poppy22 Feb 12, 2025
de4a1d2
fixes and remove prints
Poppy22 Feb 12, 2025
0ca76cf
cleanup
Poppy22 Feb 14, 2025
dd7f9c3
override abs
Poppy22 Feb 17, 2025
10cb1da
add vectorise
Poppy22 Mar 1, 2025
8a88dab
WIP: add comparison_rabbit program argument
Poppy22 Mar 22, 2025
75b5c97
add comparison_rabbit in CompilerLib
Poppy22 Mar 22, 2025
8fd63ef
debug attempt 1
Poppy22 Mar 22, 2025
82054bf
add print
Poppy22 Mar 22, 2025
5ebccab
raise comp error for debugging
Poppy22 Mar 22, 2025
5fcc8af
add rabbit code to non_linear.py
Poppy22 Mar 22, 2025
c439ccc
remove sum one-liner
Poppy22 Mar 22, 2025
53d9d22
try sintbit
Poppy22 Mar 22, 2025
0ebd7cf
try converting the other term as well
Poppy22 Mar 22, 2025
73987ac
go back to sum()
Poppy22 Mar 22, 2025
816499c
add prints in non_linear.py
Poppy22 Mar 24, 2025
c782d37
try 32 bits
Poppy22 Mar 24, 2025
8efeb88
debugging
Poppy22 Mar 28, 2025
a2b4679
try to make it work for field
Poppy22 Mar 29, 2025
cf02aff
debug for fields
Poppy22 Mar 29, 2025
1066833
testing a theory
Poppy22 Mar 29, 2025
f944659
use < 0
Poppy22 Mar 29, 2025
210c21a
debug ltbits
Poppy22 Mar 29, 2025
a9b579e
try to fix bitlt
Poppy22 Mar 29, 2025
931246b
reverse w
Poppy22 Mar 29, 2025
6241f2c
even more debugging prints
Poppy22 Mar 29, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions Compiler/compilerLib.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,12 @@ def build_option_parser(self):
dest="edabit",
help="mixing arithmetic and binary computation using edaBits",
)
parser.add_option(
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried setting the flag for comparison_rabbit with program.use_comparison_rabbit(True) in the .mpc file (similar to use_edabits), but it always remained false, and I only managed to set it true by adding --comparison_rabit in the compilation instruction. Do you see what I am doing wrong there?

"--comparison_rabbit",
action="store_true",
dest="comparison_rabbit",
help="using the rabbit comparison protocol for known prime modulus, instead of truncation",
)
parser.add_option(
"-Z",
"--split",
Expand Down
19 changes: 19 additions & 0 deletions Compiler/library.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,25 @@ def print_ln(s='', *args, **kwargs):
"""
print_str(str(s) + '\n', *args, **kwargs)


Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will also be deleted

def print_without_ln(s='', *args, **kwargs):
""" Print line, with optional args for adding variables/registers
with ``%s``. By default only player 0 outputs, but the ``-I``
command-line option changes that.

:param s: Python string with same number of ``%s`` as length of :py:obj:`args`
:param args: list of public values (regint/cint/int/cfix/cfloat/localint)
:param print_secrets: whether to output secret shares

Example:

.. code::

print_ln('a is %s.', a.reveal())
"""
print_str(str(s), *args, **kwargs)


def print_both(s, end='\n'):
""" Print line during compilation and execution. """
print(s, end=end)
Expand Down
48 changes: 48 additions & 0 deletions Compiler/non_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,52 @@ def trunc(self, a, k, m, signed):
if m == 0:
return a
return self._trunc(a, k, m, signed)

def LTBits(self, R, x, BIT_SIZE):
library.print_ln("in LTBits")
R_bits = cint.bit_decompose(R, BIT_SIZE)
y = [x[i].bit_xor(R_bits[i]) for i in range(BIT_SIZE)]
z = floatingpoint.PreOpL(floatingpoint.or_op, y[::-1])[::-1] + [0]
w = [z[i] - z[i + 1] for i in range(BIT_SIZE)]

return types.sintbit(1) - types.sintbit(sum((R_bits[i] & w[i]) for i in range(BIT_SIZE)))
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line was initially 1 - sum(...) but I was getting the error below, and it only worked when I converted 1 and the value of the sum. Could you explain this part?

File "/benchmarks/MP-SPDZ/Compiler/comparison.py", line 85, in LTZ
    movs(s, program.non_linear.ltz(a, k))
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/benchmarks/MP-SPDZ/Compiler/non_linear.py", line 81, in ltz
    return self.rabbitLTZ(a, k)
           ^^^^^^^^^^^^^^^^^^^^
  File "/benchmarks/MP-SPDZ/Compiler/non_linear.py", line 71, in rabbitLTZ
    w[1] = self.LTBits(masked_a, r_bits, BIT_SIZE)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/benchmarks/MP-SPDZ/Compiler/non_linear.py", line 54, in LTBits
    return_value = 1 - sum((R_bits[i] & w[i]) for i in range(BIT_SIZE))
                   ~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  File "/benchmarks/MP-SPDZ/Compiler/types.py", line 229, in read_mem_operation
    return operation(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/benchmarks/MP-SPDZ/Compiler/GC/types.py", line 523, in __add__
    return self.xor_int(other)
           ^^^^^^^^^^^^^^^^^^^
  File "/benchmarks/MP-SPDZ/Compiler/GC/types.py", line 585, in xor_int
    self_bits = self.bit_decompose()
                ^^^^^^^^^^^^^^^^^^^^
  File "/benchmarks/MP-SPDZ/Compiler/GC/types.py", line 86, in bit_decompose
    suffix = [0] * (n - self.n)
                    ~~^~~~~~~~
TypeError: unsupported operand type(s) for -: 'NoneType' and 'NoneType'

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, now the result of the comparison is no longer 0 or 1, but I had values like -1 and 2, so I think that there might be something wrong with the typing there.


def rabbitLTZ(self, x, BIT_SIZE = 64):
"""
s = (x <? 0)
BIT_SIZE: bit length of x
"""
length_eda = 64 # BIT_SIZE

M = P_VALUES[64] # TODO: get program.prime
R = 0

r, r_bits = sint.get_edabit(length_eda, True)
masked_a = (x + r).reveal()
masked_b = (x + r + M - R).reveal()
w = [None, None, None, None]

library.print_ln("w1, comparing: masked_a=%s edabit=%s", masked_a, r.reveal())
w[1] = self.LTBits(masked_a, r_bits, BIT_SIZE)

library.print_ln("w2, comparing: masked_b=%s edabit=%s", masked_b, r.reveal())
w[2] = self.LTBits(masked_b, r_bits, BIT_SIZE)

library.print_ln("w3, comparing: masked_b=%s with zero", masked_a)
w[3] = cint(masked_b < 0)

result = w[1] - w[2] + w[3]

library.print_ln("w1=%s w2=%s w3=%s result=%s", w[1].reveal(), w[2].reveal(), w[3], result.reveal())
return sint(1 - result)

def ltz(self, a, k):
library.print_ln("a=%s k=%s", a.reveal(), k)
prog = program.Program.prog
if prog.options.comparison_rabbit:
return self.rabbitLTZ(a, k)

# else, use truncation
return -self.trunc(a, k, k - 1, True)

class Masking(NonLinear):
Expand Down Expand Up @@ -126,6 +170,10 @@ def eqz(self, a, k):
return 1 - types.sintbit.conv(KORL(self.bit_dec(a, k, k, True)))

def ltz(self, a, k):
prog = program.Program.prog
if prog.options.comparison_rabbit:
return -20

if k + 1 < self.prime.bit_length():
# https://dl.acm.org/doi/10.1145/3474123.3486757
# "negative" values wrap around when doubling, thus becoming odd
Expand Down
21 changes: 21 additions & 0 deletions Compiler/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class defaults:
stop = False
insecure = False
keep_cisc = False
comparison_rabbit = False


class Program(object):
Expand Down Expand Up @@ -203,11 +204,16 @@ def __init__(self, args, options=defaults, name=None):
gc.inputbvec,
gc.reveal,
]

self.use_trunc_pr = False
""" Setting whether to use special probabilistic truncation. """

self.use_dabit = options.mixed
""" Setting whether to use daBits for non-linear functionality. """

self._edabit = options.edabit
self._comparison_rabbit = options.comparison_rabbit

""" Whether to use the low-level INVPERM instruction (only implemented with the assumption of a semi-honest two-party environment)"""
self._invperm = options.invperm
self._split = False
Expand Down Expand Up @@ -676,6 +682,19 @@ def use_edabit(self, change=None):
else:
self._edabit = change

def use_comparison_rabbit(self, change=None):
"""Setting whether to use the rabbit comparison protocol (default: false).

:param change: change setting if not :py:obj:`None`
:returns: setting if :py:obj:`change` is :py:obj:`None`
"""
if change is None:
if not self._comparison_rabbit:
self.relevant_opts.add("comparison_rabbit")
return self._comparison_rabbit
else:
self._comparison_rabbit = change

def use_invperm(self, change=None):
""" Set whether to use the low-level INVPERM instruction to inverse a permutation (see sint.inverse_permutation). The INVPERM instruction assumes a semi-honest two-party environment. If false, a general protocol implemented in the high-level language is used.

Expand Down Expand Up @@ -750,6 +769,8 @@ def options_from_args(self):
self.always_raw(True)
if "edabit" in self.args:
self.use_edabit(True)
if "comparison_rabbit" in self.args:
self.use_comparison_rabbit(True)
if "invperm" in self.args:
self.use_invperm(True)
if "linear_rounds" in self.args:
Expand Down
135 changes: 135 additions & 0 deletions Compiler/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5107,6 +5107,141 @@ def update(self, other):
assert self.m == other.m
self.v.update(other.v)


Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This class will be removed from types.py as it's not needed. The rabbit comparison logic is added in non_linear.py

class custom_sfix(sfix):

# R: clear-text; x: edabit in binary format
@vectorize
def LTBits(self, R, x, r, BIT_SIZE):
R_bits = cint.bit_decompose(R, BIT_SIZE)
library.print_ln("\nLTBits: R_bits= ")
for i in range(BIT_SIZE):
library.print_without_ln("%s", R_bits[i])

edabit = cint.bit_decompose(r.reveal(), BIT_SIZE)
library.print_ln("\nLTBits: edabit= ")
for i in range(BIT_SIZE):
library.print_without_ln("%s", edabit[i].reveal())

library.print_ln("\nLTBits from bits: x = ")
for i in range(BIT_SIZE):
library.print_without_ln("%s", x[i].reveal())

y = [x[i].bit_xor(R_bits[i]) for i in range(BIT_SIZE)]
library.print_ln("\nLTBits: y= ")
for i in range(BIT_SIZE):
library.print_without_ln("%s", y[i].reveal())

#z = floatingpoint.PreOpL(floatingpoint.or_op, y[::-1])[::-1] + [0]
z = floatingpoint.PreOpL(floatingpoint.or_op, y)
library.print_ln("\nLTBits: z= ")
for i in range(BIT_SIZE):
library.print_without_ln("%s", z[i].reveal())

z = [0] + z
w = [z[i] - z[i - 1] for i in range(BIT_SIZE, 0, -1)]
w = w[::-1]
library.print_ln("\nLTBits: w= ")
for i in range(BIT_SIZE):
library.print_without_ln("%s", w[i].reveal())

s = sum((R_bits[i] & w[i]) for i in range(BIT_SIZE))
library.print_ln("\nSum=%s", s.reveal())
return_value = 1 - sum((R_bits[i] & w[i]) for i in range(BIT_SIZE))
return return_value

@vectorize
def rabbitLTZ(self, x, BIT_SIZE = 64):
"""
s = (c ?< a)

BIT_SIZE: bit length of a
"""
length_eda = BIT_SIZE
library.print_ln("custom sfix: bitsize = %s", BIT_SIZE)

M = P_VALUES[64]
R = (M - 1) // 2 # for field; use 0 for ring

r, r_bits = sint.get_edabit(length_eda, True)
masked_a = (x + r).reveal()
masked_b = (x + r + M - R).reveal() # masked_a
w = [None, None, None, None]

w[1] = self.LTBits(masked_a, r_bits, r, BIT_SIZE)
library.print_ln("w1, comparing: masked_a=%s edabit=%s w1=%s", masked_a, r.reveal(), w[1].reveal())

w[2] = self.LTBits(masked_b, r_bits, r, BIT_SIZE)
library.print_ln("w2, comparing: masked_b=%s edabit=%s w2=%s", masked_b, r.reveal(), w[2].reveal())

w[3] = cint(masked_b < 0)
library.print_ln("w3, comparing: masked_b=%s with %s, w3=%s", masked_b, M - R, w[3].reveal())

result = w[1] - w[2] + w[3]

library.print_ln("final result = %s and 1- result=%s", result.reveal(), (1-result).reveal())
return sint(1 - result)

@vectorize
def rabbitLTS_fix(self, a, b):
return 1 - self.rabbitLTS(b, a)

@vectorize
def rabbitLTS(self, a, b):
res = self.rabbitLTZ(a - b)
return res


# These are based on the implementation
# self.rabbitLTS(a, b) = rabbitLTC(a-b, 0)
@vectorize
def __lt__(self, other):
a = self.v
b = other.v
result = self.rabbitLTS_fix(a, b)
return result

@vectorize
def __le__(self, other):
a = self.v
b = other.v
result = 1 - self.rabbitLTS_fix(b, a)
return result

@vectorize
def __gt__(self, other):
a = self.v
b = other.v
result = self.rabbitLTS_fix(b, a)
return result

@vectorize
def __ge__(self, other):
a = self.v
b = other.v
result = 1 - self.rabbitLTS_fix(a, b)
return result

@vectorize
def __eq__(self, other):
a = self.v
b = other.v
result = (1 - self.rabbitLTS_fix(a, b)) * (1 - self.rabbitLTS_fix(b, a))
return result

@vectorize
def __ne__(self, other):
a = self.v
b = other.v
result = 1 - (1 - self.rabbitLTS_fix(a, b)) * (1 - self.rabbitLTS_fix(b, a))
return result

@vectorize
def __abs__(self):
""" Absolute value. """
return (self < custom_sfix(0)).if_else(-self, self)


sfix.unreduced_type = unreduced_sfix

sfix.set_precision(16, 31)
Expand Down
Loading