Skip to content

Commit

Permalink
Extend mpc.trunc() to secure integers.
Browse files Browse the repository at this point in the history
  • Loading branch information
lschoe authored Apr 25, 2024
1 parent 085790a commit ba3c572
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
12 changes: 8 additions & 4 deletions mpyc/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -770,6 +770,8 @@ async def trunc(self, x, f=None, l=None):
l = l or sftype.bit_length
if f is None:
f = sftype.frac_length
if issubclass(sftype, self.SecureFixedPoint):
l += f
else:
await self.returnType(Future)
Zp = sftype
Expand All @@ -783,12 +785,12 @@ async def trunc(self, x, f=None, l=None):
s <<= 1
s += r_bits[f * j + i].value
r_modf[j] = Zp(s)
r_divf = self._randoms(Zp, n, 1 << k + l)
r_divf = self._randoms(Zp, n, 1 << k + l - f)
if self.options.no_prss:
r_divf = await r_divf
if issubclass(sftype, self.SecureObject):
x = await self.gather(x)
c = await self.output([a + ((1 << l-1 + f) + (q.value << f) + r.value)
c = await self.output([a + ((1 << l-1) + (q.value << f) + r.value)
for a, q, r in zip(x, r_divf, r_modf)])
c = [c.value % (1<<f) for c in c]
y = [(a - c + r.value) >> f for a, c, r in zip(x, c, r_modf)]
Expand All @@ -810,6 +812,8 @@ async def np_trunc(self, a, f=None, l=None):
l = l or sftype.sectype.bit_length
if f is None:
f = sftype.frac_length
if issubclass(sftype, self.SecureFixedPoint):
l += f
else:
await self.returnType(Future)
Zp = sftype.field
Expand All @@ -818,14 +822,14 @@ async def np_trunc(self, a, f=None, l=None):
r_bits = await self.np_random_bits(Zp, f * n)
r_modf = np.sum(r_bits.value.reshape((n, f)) << np.arange(f), axis=1)
r_modf = r_modf.reshape(a.shape)
r_divf = self._np_randoms(Zp, n, 1 << k + l)
r_divf = self._np_randoms(Zp, n, 1 << k + l - f)
if self.options.no_prss:
r_divf = await r_divf
r_divf = r_divf.value
r_divf = r_divf.reshape(a.shape)
if issubclass(sftype, self.SecureObject):
a = await self.gather(a)
c = await self.output(Zp.array(a.value + (1 << l-1 + f) + (r_divf << f) + r_modf))
c = await self.output(Zp.array(a.value + (1 << l-1) + (r_divf << f) + r_modf))
c = c.value & ((1<<f) - 1)
y = Zp.array(a.value + r_modf - c) >> f
return y
Expand Down
4 changes: 4 additions & 0 deletions tests/test_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ def test_secint_array(self):
np.assertEqual(mpc.run(mpc.output(c @ c)), a @ a)
np.assertEqual(mpc.run(mpc.output(c @ a)), a @ a)
np.assertEqual(mpc.run(mpc.output(a @ c)), a @ a)
self.assertAlmostEqual(mpc.run(mpc.output(mpc.trunc(np.sum(np.abs(d)), 3))), 1, delta=1)
self.assertAlmostEqual(mpc.run(mpc.output(mpc.trunc(-np.sum(np.abs(d)), 3))), -2, delta=1)

self.assertEqual(mpc.run(mpc.output(c)).dtype, object)

Expand Down Expand Up @@ -727,6 +729,8 @@ def test_secint(self):
self.assertEqual(mpc.run(mpc.output(secint(5) // 2)), 2)
self.assertEqual(mpc.run(mpc.output(secint(50) // 2)), 25)
self.assertEqual(mpc.run(mpc.output(secint(50) // 4)), 12)
self.assertAlmostEqual(mpc.run(mpc.output(mpc.trunc(secint(50), 2))), 12, delta=1)
self.assertAlmostEqual(mpc.run(mpc.output(mpc.trunc(secint(-50), 2))), -13, delta=1)
self.assertEqual(mpc.run(mpc.output(secint(11) << 3)), 88)
self.assertEqual(mpc.run(mpc.output(secint(-11) << 3)), -88)
self.assertEqual(mpc.run(mpc.output(secint(70) >> 2)), 17)
Expand Down

0 comments on commit ba3c572

Please sign in to comment.