Skip to content

Commit

Permalink
Add mpc.peek() for "informational" purposes.
Browse files Browse the repository at this point in the history
Insert calls pretty much anywhere in your code to inspect values of secure (secret-shared) objects.
  • Loading branch information
lschoe authored Jul 17, 2024
1 parent a4f87df commit 94f820e
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 2 deletions.
2 changes: 1 addition & 1 deletion mpyc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
and statistics (securely mimicking Python’s statistics module).
"""

__version__ = '0.10.1'
__version__ = '0.10.2'

import os
import sys
Expand Down
9 changes: 9 additions & 0 deletions mpyc/asyncoro.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,9 @@ def _add_callbacks(self, obj):
else:
self.tally += 1
obj.share.add_done_callback(self._decrement)
elif isinstance(obj.share, tuple):
for x in obj.share:
self._add_callbacks(x)
elif isinstance(obj, Future) and not obj.done():
self.tally += 1
obj.add_done_callback(self._decrement)
Expand All @@ -230,6 +233,9 @@ def _get_results(obj):
if isinstance(obj.share, Future):
return obj.share.result()

elif isinstance(obj.share, tuple):
return tuple(map(_get_results, obj.share))

return obj.share

if isinstance(obj, Future):
Expand All @@ -255,6 +261,9 @@ def gather_shares(rt, *obj):
if isinstance(obj.share, Future):
return obj.share

elif isinstance(obj.share, tuple):
return gather_shares(rt, obj.share)

return _AwaitableFuture(obj.share)

if not rt.options.no_async:
Expand Down
22 changes: 22 additions & 0 deletions mpyc/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,28 @@ def logging(self, enable=None):
else:
logging.disable(logging.INFO)

@asyncoro.mpc_coro
async def peek(self, x, label='') -> None:
"""Peek at the value of secret-shared x,
for "informational" and debugging purposes only.
For secure object x (or list of secure objects compatible with
runtime.output()), the value is logged once its computation is done.
In debug mode, the moment at which the task for the computation of x is
scheduled is logged, and the value of this party's secret share of x is
logged once the task for the computation of x has completed.
To facilitate matching of scheduled with completed tasks, an "address"
based on the program counter is included as well the given label, if any.
"""
txt = f'Peek at {abs(mpc._program_counter[0]) % (1<<24):#08x}:'
if label:
txt += f' {label}'
logging.debug(f'{txt} Task scheduled')
logging.info(f'{txt} Task output {await self.output(x)}')
logging.debug(f'{txt} Party {self.pid}\'s share {await self.gather(x)}')

async def start(self):
"""Start the MPyC runtime.
Expand Down
5 changes: 5 additions & 0 deletions tests/test_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def test_secint_array(self):
a = FF.array([[[-1, 1], [1, -1]]]) # 3D array
np.assertEqual(mpc.run(mpc.output(mpc.run(mpc.transfer(a, senders=0)))), a)
c = secint.array(a)
mpc.peek(c, label='secint.array')
a = a.copy() # NB: needed to pass tests with inplace operations
self.assertTrue((a == mpc.run(mpc.output(mpc.np_sgn(c)))).all()) # via FF array __eq__
self.assertTrue((mpc.run(mpc.output(mpc.np_sgn(c))) == a).all()) # via FF array ufunc equal
Expand Down Expand Up @@ -251,6 +252,7 @@ def test_secfxp_array(self):
secfxp = mpc.SecFxp(12)
a = np.array([[-1.5, 2.5], [4.5, -8.5]])
c = secfxp.array(a)
mpc.peek(c)

np.assertEqual(mpc.run(mpc.output(c + np.array([1, 2]))), a + np.array([1, 2]))
np.assertEqual(mpc.run(mpc.output(c * np.array([1, 2]))), a * np.array([1, 2]))
Expand Down Expand Up @@ -464,6 +466,7 @@ def test_secfld_array(self):

secfld = mpc.SecFld(2**2)
c = secfld.array(np.array([[-3, 0], [1, 2]]))
mpc.peek(c)
np.assertEqual(mpc.run(mpc.output(mpc.np_to_bits(c))), [[[1, 1], [0, 0]], [[1, 0], [0, 1]]])
np.assertEqual(mpc.run(mpc.output(mpc.np_from_bits(mpc.np_to_bits(c[1])))), [1, 2])
c = mpc._np_randoms(secfld, 5)
Expand Down Expand Up @@ -681,6 +684,7 @@ def test_secint(self):
secint = mpc.SecInt()
a = secint(12)
b = secint(13)
mpc.peek(b)
self.assertEqual(mpc.run(mpc.output(mpc.input(a, 0))), 12)
self.assertEqual(mpc.run(mpc.output(mpc.input([a, b], 0))), [12, 13])
self.assertEqual(mpc.run(mpc.output(-a)), -12)
Expand Down Expand Up @@ -911,6 +915,7 @@ def test_secflt(self):
secflt = mpc.SecFlt()
a = secflt(1.25)
b = secflt(2.5)
mpc.peek(b)
self.assertEqual(mpc.run(mpc.output(mpc.input(a, 0))), 1.25)
self.assertEqual(mpc.run(mpc.output(a + b)), 3.75)
self.assertEqual(mpc.run(mpc.output(-a + -b)), -3.75)
Expand Down
7 changes: 6 additions & 1 deletion tests/test_secgroups.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ class Arithmetic(unittest.TestCase):

@classmethod
def setUpClass(cls):
pass
mpc.logging(False)

@classmethod
def tearDownClass(cls):
Expand All @@ -22,7 +22,9 @@ def test_Sn(self):
b = a @ a
secgrp = mpc.SecGrp(group)
c = secgrp(a)
mpc.peek(c)
d = a @ c
mpc.peek([c, d])
self.assertEqual(mpc.run(mpc.output(d)), b)
e = ~c
f = e @ b
Expand Down Expand Up @@ -70,6 +72,7 @@ def test_QR_SG(self):
self.assertEqual(mpc.run(mpc.output(h)), g2)

a = secgrp(g)
mpc.peek([a, a])
self.assertRaises(TypeError, operator.truediv, 2, a)
self.assertRaises(TypeError, operator.add, a, a)
self.assertRaises(TypeError, operator.add, g, a)
Expand Down Expand Up @@ -99,6 +102,7 @@ def test_EC(self):
self.assertEqual(mpc.run(mpc.output(2*secgrp(g))), g^2)
bp4 = 4*g
sec_bp4 = 4*secgrp(g) + secgrp.identity
mpc.peek([sec_bp4, sec_bp4])
self.assertEqual(mpc.run(mpc.output(-sec_bp4)), -bp4)
sec_bp8 = secgrp.repeat(bp4, secfld(2))
self.assertEqual(mpc.run(mpc.output(sec_bp8)), bp4 + bp4)
Expand Down Expand Up @@ -141,6 +145,7 @@ def test_Cl(self):
secgrp = mpc.SecGrp(group)
g = group.generator
a = secgrp(g)^6
mpc.peek(a)
self.assertEqual(mpc.run(mpc.output(a)), g^6)
self.assertEqual(mpc.run(mpc.output(a * (a^-1))), group.identity)
m, z = group.encode(5)
Expand Down
5 changes: 5 additions & 0 deletions tests/test_seclists.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@

class Arithmetic(unittest.TestCase):

@classmethod
def setUpClass(cls):
mpc.logging(False)

def test_secfld(self):
secfld = mpc.SecFld(101)
s = seclist([], secfld)
Expand All @@ -25,6 +29,7 @@ def test_secfld(self):
s[5] = 9
del s[2:4]
self.assertEqual(mpc.run(mpc.output(list(s))), [1, 2, 6, 9])
mpc.peek(list(s))

secfld2 = mpc.SecFld()
self.assertRaises(TypeError, seclist, [secfld(1)], secfld2)
Expand Down

0 comments on commit 94f820e

Please sign in to comment.