Skip to content

Commit

Permalink
Complete mpc.np_divide().
Browse files Browse the repository at this point in the history
  • Loading branch information
lschoe authored Mar 11, 2024
1 parent e77f9a8 commit 10950f7
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 13 deletions.
6 changes: 3 additions & 3 deletions mpyc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
and statistics (securely mimicking Python’s statistics module).
"""

__version__ = '0.9.9'
__version__ = '0.9.10'
__license__ = 'MIT License'

import os
Expand Down Expand Up @@ -170,10 +170,10 @@ def get_arg_parser():
if importlib.util.find_spec('winloop' if sys.platform.startswith('win32') else 'uvloop'):
# uvloop (winloop) package available
if options.no_uvloop or env_no_uvloop:
logging.info(f'Use of package uvloop (winloop) inside MPyC disabled.')
logging.info('Use of package uvloop (winloop) inside MPyC disabled.')
elif sys.platform.startswith('win32'):
from winloop import EventLoopPolicy
logging.debug(f'Load winloop')
logging.debug('Load winloop')
else:
from uvloop import EventLoopPolicy, _version
logging.debug(f'Load uvloop version {_version.__version__}')
Expand Down
8 changes: 2 additions & 6 deletions mpyc/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -1115,8 +1115,6 @@ def div(self, a, b):
if f:
if isinstance(b, (int, float)):
c = 1/b
if c.is_integer():
c = round(c)
else:
c = b.reciprocal() << f
else:
Expand All @@ -1140,17 +1138,15 @@ def np_divide(self, a, b):

# isinstance(a, self.SecureArray) ensured
if f:
if isinstance(b, (int, float)):
if isinstance(b, (int, float, np.ndarray)):
c = 1/b
if c.is_integer():
c = round(c)
elif isinstance(b, self.SecureFixedPoint):
c = self._rec(b)
else:
c = b.reciprocal() << f
else:
if not isinstance(b, field.array):
b = field.array(b) # TODO: see if this can be used for case f != 0 as well
b = field.array(b)
c = b.reciprocal()
return self.np_multiply(a, c)

Expand Down
3 changes: 3 additions & 0 deletions mpyc/sectypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
if op == operator.sub:
return inputs[1].__rsub__(inputs[0])

if op == operator.truediv:
return inputs[1].__rtruediv__(inputs[0])

return op(inputs[1], inputs[0])

if op := unary_ops.get(ufunc):
Expand Down
19 changes: 15 additions & 4 deletions tests/test_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ def __lt__(self, other):
@unittest.skipIf(not np, 'NumPy not available or inside MPyC disabled')
def test_secfxp_array(self):
np.assertEqual = np.testing.assert_array_equal
np.assertAlmostEqual = np.testing.assert_allclose

secfxp = mpc.SecFxp(12)
a = np.array([[-1.5, 2.5], [4.5, -8.5]])
Expand All @@ -252,6 +253,8 @@ def test_secfxp_array(self):
np.assertEqual(mpc.run(mpc.output(c * np.array([1.5, 2.5]))), a * np.array([1.5, 2.5]))
np.assertEqual(mpc.run(mpc.output(c * secfxp(2.5))), a * 2.5)
np.assertEqual(mpc.run(mpc.output(c * 2.5)), a * 2.5)
np.assertEqual(mpc.run(mpc.output(c / secfxp.field(2))), a / 2)
np.assertEqual(mpc.run(mpc.output(c / secfxp.field.array([2]))), a / 2)

# NB: NumPy dispatcher converts np.int8 to int
np.assertEqual(mpc.run(mpc.output(c * np.int8(2))), a * 2)
Expand All @@ -278,10 +281,17 @@ def test_secfxp_array(self):
f = 32
secfxp = mpc.SecFxp(2*f)
c = secfxp.array(a)
np.testing.assert_allclose(mpc.run(mpc.output(c / 2.45)), a / 2.45, rtol=0, atol=2**(1-f))
np.testing.assert_allclose(mpc.run(mpc.output(c / 2.5)), a / 2.5, rtol=0, atol=2**(2-f))
np.testing.assert_allclose(mpc.run(mpc.output(1 / c)), 1 / a, rtol=0, atol=2**(1-f))
np.testing.assert_allclose(mpc.run(mpc.output(c / c)), 1, rtol=0, atol=2**(3-f))
np.assertAlmostEqual(mpc.run(mpc.output(c / 0.5)), a / 0.5, rtol=0, atol=0)
np.assertAlmostEqual(mpc.run(mpc.output(c / 2.45)), a / 2.45, rtol=0, atol=2**(1-f))
np.assertAlmostEqual(mpc.run(mpc.output(c / 2.5)), a / 2.5, rtol=0, atol=2**(2-f))
np.assertAlmostEqual(mpc.run(mpc.output(c / c[0, 1])), a / 2.5, rtol=0, atol=2**(3-f))
np.assertAlmostEqual(mpc.run(mpc.output(1 / c)), 1 / a, rtol=0, atol=2**(1-f))
np.assertAlmostEqual(mpc.run(mpc.output(secfxp(1.5) / c)), 1.5 / a, rtol=0, atol=2**(1-f))
np.assertAlmostEqual(mpc.run(mpc.output(1.5 / c)), 1.5 / a, rtol=0, atol=2**(1-f))
np.assertAlmostEqual(mpc.run(mpc.output(a / c)), 1, rtol=0, atol=2**(3-f))
np.assertAlmostEqual(mpc.run(mpc.output((2*a).astype(int) / c)), 2, rtol=0, atol=2**(4-f))
np.assertAlmostEqual(mpc.run(mpc.output(c / a)), 1, rtol=0, atol=2**(0-f))
np.assertAlmostEqual(mpc.run(mpc.output(c / c)), 1, rtol=0, atol=2**(3-f))
np.assertEqual(mpc.run(mpc.output(np.equal(c, c))), True)
np.assertEqual(mpc.run(mpc.output(np.equal(c, 0))), False)
np.assertEqual(mpc.run(mpc.output(np.sum(c, axis=(-2, 1)))), np.sum(a, axis=(-2, 1)))
Expand Down Expand Up @@ -824,6 +834,7 @@ def test_secfxp(self):
self.assertAlmostEqual(mpc.run(mpc.output(c / d)), t, delta=2**(3-f))
t = -s[3] / s[2]
self.assertAlmostEqual(mpc.run(mpc.output(-d / c)), t, delta=2**(3-f))
self.assertEqual(mpc.run(mpc.output(secfxp(2) / secfxp.field(2))), 1)

self.assertEqual(mpc.run(mpc.output(mpc.sgn(+a))), s[0] > 0)
self.assertEqual(mpc.run(mpc.output(mpc.sgn(-a))), -(s[0] > 0))
Expand Down

0 comments on commit 10950f7

Please sign in to comment.