diff --git a/mpyc/runtime.py b/mpyc/runtime.py index f0e2f10..7d071f2 100644 --- a/mpyc/runtime.py +++ b/mpyc/runtime.py @@ -770,6 +770,8 @@ async def trunc(self, x, f=None, l=None): l = l or sftype.bit_length if f is None: f = sftype.frac_length + if issubclass(sftype, self.SecureFixedPoint): + l += f else: await self.returnType(Future) Zp = sftype @@ -783,12 +785,12 @@ async def trunc(self, x, f=None, l=None): s <<= 1 s += r_bits[f * j + i].value r_modf[j] = Zp(s) - r_divf = self._randoms(Zp, n, 1 << k + l) + r_divf = self._randoms(Zp, n, 1 << k + l - f) if self.options.no_prss: r_divf = await r_divf if issubclass(sftype, self.SecureObject): x = await self.gather(x) - c = await self.output([a + ((1 << l-1 + f) + (q.value << f) + r.value) + c = await self.output([a + ((1 << l-1) + (q.value << f) + r.value) for a, q, r in zip(x, r_divf, r_modf)]) c = [c.value % (1<> f for a, c, r in zip(x, c, r_modf)] @@ -810,6 +812,8 @@ async def np_trunc(self, a, f=None, l=None): l = l or sftype.sectype.bit_length if f is None: f = sftype.frac_length + if issubclass(sftype, self.SecureFixedPoint): + l += f else: await self.returnType(Future) Zp = sftype.field @@ -818,14 +822,14 @@ async def np_trunc(self, a, f=None, l=None): r_bits = await self.np_random_bits(Zp, f * n) r_modf = np.sum(r_bits.value.reshape((n, f)) << np.arange(f), axis=1) r_modf = r_modf.reshape(a.shape) - r_divf = self._np_randoms(Zp, n, 1 << k + l) + r_divf = self._np_randoms(Zp, n, 1 << k + l - f) if self.options.no_prss: r_divf = await r_divf r_divf = r_divf.value r_divf = r_divf.reshape(a.shape) if issubclass(sftype, self.SecureObject): a = await self.gather(a) - c = await self.output(Zp.array(a.value + (1 << l-1 + f) + (r_divf << f) + r_modf)) + c = await self.output(Zp.array(a.value + (1 << l-1) + (r_divf << f) + r_modf)) c = c.value & ((1<> f return y diff --git a/tests/test_runtime.py b/tests/test_runtime.py index 6df15a3..b051070 100644 --- a/tests/test_runtime.py +++ b/tests/test_runtime.py @@ -99,6 +99,8 @@ def test_secint_array(self): np.assertEqual(mpc.run(mpc.output(c @ c)), a @ a) np.assertEqual(mpc.run(mpc.output(c @ a)), a @ a) np.assertEqual(mpc.run(mpc.output(a @ c)), a @ a) + self.assertAlmostEqual(mpc.run(mpc.output(mpc.trunc(np.sum(np.abs(d)), 3))), 1, delta=1) + self.assertAlmostEqual(mpc.run(mpc.output(mpc.trunc(-np.sum(np.abs(d)), 3))), -2, delta=1) self.assertEqual(mpc.run(mpc.output(c)).dtype, object) @@ -727,6 +729,8 @@ def test_secint(self): self.assertEqual(mpc.run(mpc.output(secint(5) // 2)), 2) self.assertEqual(mpc.run(mpc.output(secint(50) // 2)), 25) self.assertEqual(mpc.run(mpc.output(secint(50) // 4)), 12) + self.assertAlmostEqual(mpc.run(mpc.output(mpc.trunc(secint(50), 2))), 12, delta=1) + self.assertAlmostEqual(mpc.run(mpc.output(mpc.trunc(secint(-50), 2))), -13, delta=1) self.assertEqual(mpc.run(mpc.output(secint(11) << 3)), 88) self.assertEqual(mpc.run(mpc.output(secint(-11) << 3)), -88) self.assertEqual(mpc.run(mpc.output(secint(70) >> 2)), 17)