Skip to content

Commit

Permalink
Machine learning functionality, dishonest-majority binary secret shar…
Browse files Browse the repository at this point in the history
…ing.
  • Loading branch information
mkskeller committed Oct 11, 2019
1 parent 5f0a7ad commit 7a5195d
Show file tree
Hide file tree
Showing 203 changed files with 6,255 additions and 1,484 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.

## 0.1.2

- Machine learning capabilities used for [MobileNets inference](https://eprint.iacr.org/2019/131) and the iDASH submission
- Binary computation for dishonest majority using secret sharing
- Mathematical functions from [SCALE-MAMBA](https://github.com/KULeuven-COSIC/SCALE-MAMBA)
- Fixed security bug: CowGear would reuse triples.

## 0.1.1 (Aug 6, 2019)

- ECDSA
Expand Down
26 changes: 15 additions & 11 deletions Compiler/allocator.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,12 @@ def determine_scope(block, options):
used_from_scope = set()

def find_in_scope(reg, scope):
if scope is None:
return False
elif reg in scope.defined_registers:
return True
else:
return find_in_scope(reg, scope.scope)
while True:
if scope is None:
return False
elif reg in scope.defined_registers:
return True
scope = scope.scope

def read(reg, n):
if last_def[reg] == -1:
Expand Down Expand Up @@ -386,7 +386,7 @@ def dependency_graph(self, merge_classes):
last_print_str = None
last = defaultdict(lambda: defaultdict(lambda: None))
last_open = deque()
last_text_input = None
last_text_input = [None, None]

depths = [0] * len(block.instructions)
self.depths = depths
Expand Down Expand Up @@ -474,10 +474,14 @@ def keep_order(instr, n, t, arg_index=None):

# will be merged
if isinstance(instr, TextInputInstruction):
if last_text_input is not None and \
type(block.instructions[last_text_input]) is not type(instr):
add_edge(last_text_input, n)
last_text_input = n
if last_text_input[0] is not None:
if instr.merge_id() != \
block.instructions[last_text_input[0]].merge_id():
add_edge(last_text_input[0], n)
last_text_input[1] = last_text_input[0]
elif last_text_input[1] is not None:
add_edge(last_text_input[1], n)
last_text_input[0] = n

if isinstance(instr, merge_classes):
open_nodes.add(n)
Expand Down
24 changes: 20 additions & 4 deletions Compiler/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,12 @@ def LTZ(s, a, k, kappa):
Trunc(t, a, k, k - 1, kappa, True)
subsfi(s, t, 0)

def LessThanZero(a, k, kappa):
import types
res = types.sint()
LTZ(res, a, k, kappa)
return res

def Trunc(d, a, k, m, kappa, signed):
"""
d = a >> m
Expand Down Expand Up @@ -153,6 +159,8 @@ def TruncRoundNearest(a, k, m, kappa, signed=False):
k: bit length of a
m: compile-time integer
"""
if m == 0:
return a
if k == int(program.options.ring):
# cannot work with bit length k+1
tmp = TruncRing(None, a, k, m - 1, signed)
Expand Down Expand Up @@ -359,7 +367,7 @@ def CarryOutAux(d, a, kappa):
movs(d, a[0][1])

# carry out with carry-in bit c
def CarryOut(res, a, b, c, kappa):
def CarryOut(res, a, b, c=0, kappa=None):
"""
res = last carry bit in addition of a and b
Expand All @@ -368,21 +376,29 @@ def CarryOut(res, a, b, c, kappa):
c: initial carry-in bit
"""
k = len(a)
import types
d = [program.curr_block.new_reg('s') for i in range(k)]
t = [[program.curr_block.new_reg('s') for i in range(k)] for i in range(4)]
t = [[types.sint() for i in range(k)] for i in range(4)]
s = [program.curr_block.new_reg('s') for i in range(3)]
for i in range(k):
mulm(t[0][i], b[i], a[i])
mulsi(t[1][i], t[0][i], 2)
addm(t[2][i], b[i], a[i])
subs(t[3][i], t[2][i], t[1][i])
d[i] = [t[3][i], t[0][i]]
mulsi(s[0], d[-1][0], c)
adds(s[1], d[-1][1], s[0])
s[0] = d[-1][0] * c
s[1] = d[-1][1] + s[0]
d[-1][1] = s[1]

CarryOutAux(res, d[::-1], kappa)

def CarryOutLE(a, b, c=0):
""" Little-endian version """
import types
res = types.sint()
CarryOut(res, a[::-1], b[::-1], c)
return res

def BitLTL(res, a, b, kappa):
"""
res = a <? b (logarithmic rounds version)
Expand Down
5 changes: 4 additions & 1 deletion Compiler/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@
'bittriple': 0.00004828818388140422,
'bitgf2ntriple': 0.00020716801325875284,
'PreMulC': 2 * 0.00020716801325875284,
})
}),
'all': { 'round': 0,
'inv': 0,
}
}


Expand Down
61 changes: 43 additions & 18 deletions Compiler/floatingpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,11 @@ def BitDecField(a, k, m, kappa, bits_to_compute=None):
def Pow2(a, l, kappa):
m = int(ceil(log(l, 2)))
t = BitDec(a, m, m, kappa)
x = [types.sint() for i in range(m)]
return Pow2_from_bits(t)

def Pow2_from_bits(bits):
m = len(bits)
t = list(bits)
pow2k = [types.cint() for i in range(m)]
for i in range(m):
pow2k[i] = two_power(2**i)
Expand Down Expand Up @@ -353,13 +357,20 @@ def B2U_from_Pow2(pow2a, l, kappa):
#print ' '.join(str(b.value) for b in y)
return [1 - y[i] for i in range(l)]

def Trunc(a, l, m, kappa, compute_modulo=False):
def Trunc(a, l, m, kappa, compute_modulo=False, signed=False):
""" Oblivious truncation by secret m """
if util.is_constant(m) and not compute_modulo:
# cheaper
res = type(a)(size=a.size)
comparison.Trunc(res, a, l, m, kappa, signed=signed)
return res
if l == 1:
if compute_modulo:
return a * m, 1 + m
else:
return a * (1 - m)
if program.Program.prog.options.ring and not compute_modulo:
return TruncInRing(a, l, Pow2(m, l, kappa))
r = [types.sint() for i in range(l)]
r_dprime = types.sint(0)
r_prime = types.sint(0)
Expand All @@ -370,8 +381,6 @@ def Trunc(a, l, m, kappa, compute_modulo=False):
x, pow2m = B2U(m, l, kappa)
#assert(pow2m.value == 2**m.value)
#assert(sum(b.value for b in x) == m.value)
if program.Program.prog.options.ring and not compute_modulo:
return TruncInRing(a, l, pow2m)
for i in range(l):
bit(r[i])
t1 = two_power(i) * r[i]
Expand Down Expand Up @@ -495,17 +504,28 @@ def TruncPrRing(a, k, m, signed=True):
return comparison.TruncLeakyInRing(a, k, m, signed=signed)
else:
from types import sint
# extra bit to mask overflow
r_bits = [sint.get_random_bit() for i in range(k + 1)]
n_shift = n_ring - len(r_bits)
tmp = a + sint.bit_compose(r_bits)
masked = (tmp << n_shift).reveal()
shifted = (masked << 1 >> (n_shift + m + 1))
overflow = r_bits[-1].bit_xor(masked >> (n_ring - 1))
res = shifted - sint.bit_compose(r_bits[m:k]) + (overflow << (k - m))
if signed:
a += (1 << (k - 1))
if program.Program.prog.use_trunc_pr:
res = sint()
trunc_pr(res, a, k, m)
else:
# extra bit to mask overflow
r_bits = [sint.get_random_bit() for i in range(k + 1)]
n_shift = n_ring - len(r_bits)
tmp = a + sint.bit_compose(r_bits)
masked = (tmp << n_shift).reveal()
shifted = (masked << 1 >> (n_shift + m + 1))
overflow = r_bits[-1].bit_xor(masked >> (n_ring - 1))
res = shifted - sint.bit_compose(r_bits[m:k]) + \
(overflow << (k - m))
if signed:
res -= (1 << (k - m - 1))
return res

def TruncPrField(a, k, m, kappa=None):
if m == 0:
return a
if kappa is None:
kappa = 40

Expand All @@ -527,19 +547,24 @@ def SDiv(a, b, l, kappa, round_nearest=False):
w = types.cint(int(2.9142 * two_power(l))) - 2 * b
x = alpha - b * w
y = a * w
y = y.round(2 * l + 1, l, kappa, round_nearest)
y = y.round(2 * l + 1, l, kappa, round_nearest, signed=False)
x2 = types.sint()
comparison.Mod2m(x2, x, 2 * l + 1, l, kappa, False)
x1 = comparison.TruncZeroes(x - x2, 2 * l + 1, l, True)
for i in range(theta-1):
y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l, kappa, round_nearest)
y = y.round(2 * l + 1, l + 1, kappa, round_nearest)
x = x1 * x2 + (x2**2).round(2 * l + 1, l + 1, kappa, round_nearest)
x = x1 * x1 + x.round(2 * l + 1, l - 1, kappa, round_nearest)
y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l, kappa,
round_nearest,
signed=False)
y = y.round(2 * l + 1, l + 1, kappa, round_nearest, signed=False)
x = x1 * x2 + (x2**2).round(2 * l + 1, l + 1, kappa, round_nearest,
signed=False)
x = x1 * x1 + x.round(2 * l + 1, l - 1, kappa, round_nearest,
signed=False)
x2 = types.sint()
comparison.Mod2m(x2, x, 2 * l, l, kappa, False)
x1 = comparison.TruncZeroes(x - x2, 2 * l + 1, l, True)
y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l, kappa, round_nearest)
y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l, kappa,
round_nearest, signed=False)
y = y.round(2 * l + 1, l - 1, kappa, round_nearest)
return y

Expand Down
65 changes: 65 additions & 0 deletions Compiler/instructions.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,6 +894,55 @@ def add_usage(self, req_node):
req_node.increment((self.field_type, 'input', player), \
4 * self.get_size())

@base.vectorize
class inputmixed(base.TextInputInstruction):
__slots__ = []
code = base.opcodes['INPUTMIXED']
field_type = 'modp'
# the following has to match TYPE: (N_DEST, N_PARAM)
types = {
0: (1, 0),
1: (1, 1),
2: (4, 1)
}
type_ids = {
'int': 0,
'fix': 1,
'float': 2
}

def __init__(self, name, *args):
try:
type_id = self.type_ids[name]
except:
pass
super(inputmixed_class, self).__init__(type_id, *args)

@property
def arg_format(self):
for i in self.bases():
t = self.args[i]
yield 'int'
for j in range(self.types[t][0]):
yield 'sw'
for j in range(self.types[t][1]):
yield 'int'
yield 'p'

def bases(self):
i = 0
while i < len(self.args):
yield i
i += sum(self.types[self.args[i]]) + 2

def add_usage(self, req_node):
for i in self.bases():
t = self.args[i]
player = self.args[i + sum(self.types[t]) + 1]
n_dest = self.types[t][0]
req_node.increment((self.field_type, 'input', player), \
n_dest * self.get_size())

@base.gf2n
class startinput(base.RawInputInstruction):
r""" Receive inputs from player $p$. """
Expand Down Expand Up @@ -957,6 +1006,11 @@ class print_reg_plain(base.IOInstruction):
code = base.opcodes['PRINTREGPLAIN']
arg_format = ['c']

class cond_print_plain(base.IOInstruction):
r""" Conditionally print the value of a register. """
code = base.opcodes['CONDPRINTPLAIN']
arg_format = ['c', 'c']

class print_int(base.IOInstruction):
r""" Print only the value of register \verb|ci| to stdout. """
__slots__ = []
Expand Down Expand Up @@ -1383,6 +1437,9 @@ def get_repeat(self):

def merge_id(self):
# can merge different sizes
# but not if large
if self.get_size() > 100:
return type(self), self.get_size()
return type(self)

# def expand(self):
Expand Down Expand Up @@ -1468,6 +1525,14 @@ def get_used(self):
for reg in self.args[i + 2:i + self.args[i]]:
yield reg

@base.vectorize
class trunc_pr(base.VarArgsInstruction):
""" Probalistic truncation for semi-honest computation """
""" with honest majority """
__slots__ = []
code = base.opcodes['TRUNC_PR']
arg_format = tools.cycle(['sw','s','int','int'])

###
### CISC-style instructions
###
Expand Down
18 changes: 3 additions & 15 deletions Compiler/instructions_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
MULS = 0xA6,
MULRS = 0xA7,
DOTPRODS = 0xA8,
TRUNC_PR = 0xA9,
# Data access
TRIPLE = 0x50,
BIT = 0x51,
Expand All @@ -102,6 +103,7 @@
INPUT = 0x60,
INPUTFIX = 0xF0,
INPUTFLOAT = 0xF1,
INPUTMIXED = 0xF2,
STARTINPUT = 0x61,
STOPINPUT = 0x62,
READSOCKETC = 0x63,
Expand Down Expand Up @@ -168,6 +170,7 @@
READFILESHARE = 0xBE,
CONDPRINTSTR = 0xBF,
PRINTFLOATPREC = 0xE0,
CONDPRINTPLAIN = 0xE1,
GBITDEC = 0x184,
GBITCOM = 0x185,
# Secure socket
Expand Down Expand Up @@ -767,21 +770,6 @@ def check_args(self):
### Jumps etc
###

class dummywrite(Instruction):
""" Dummy instruction to create source node in the dependency graph,
preventing read-before-write warnings. """
__slots__ = []

def __init__(self, *args, **kwargs):
self.arg_format = [arg.reg_type + 'w' for arg in args]
super(dummywrite, self).__init__(*args, **kwargs)

def execute(self):
pass

def get_encoding(self):
return []

class JumpInstruction(Instruction):
__slots__ = ['jump_arg']

Expand Down
Loading

0 comments on commit 7a5195d

Please sign in to comment.