From 90534fd7ee539c60d694b0fc0b7885cabd490240 Mon Sep 17 00:00:00 2001 From: DAMIE Marc Date: Sat, 22 Apr 2023 23:27:29 +0200 Subject: [PATCH] Add utility functions to runtime --- mpyc/runtime.py | 23 +++++++++++++++++++++-- tests/test_runtime.py | 4 ++++ 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/mpyc/runtime.py b/mpyc/runtime.py index 50ac90af..871c2be5 100644 --- a/mpyc/runtime.py +++ b/mpyc/runtime.py @@ -259,6 +259,25 @@ async def start(self): server.close() self.start_time = time.time() + def elapsed_time(self): + """Return the elapsed time since the MPyC runtime started.""" + return time.time() - self.start_time + + def communication_cost(self, per_user=False): + """Return the number of bytes exchanged since the MPyC runtime started. + + If the per_user parameter is True, the function returns a list with the cost per user. + Otherwise, all costs are aggregated. + """ + per_user_cost = [peer.protocol.nbytes_sent if peer.pid != self.pid else 0 for peer in self.parties] + + if per_user: + return per_user_cost + else: + return sum(per_user_cost) + + + async def shutdown(self): """Shutdown the MPyC runtime. @@ -267,8 +286,8 @@ async def shutdown(self): # Wait for all parties behind a barrier. while self._pc_level > self._program_counter[1]: await asyncio.sleep(0) - elapsed = time.time() - self.start_time - nbytes = [peer.protocol.nbytes_sent if peer.pid != self.pid else 0 for peer in self.parties] + elapsed = self.elapsed_time() + nbytes = self.communication_cost(per_user=True) elapsed = datetime.timedelta(seconds=round(elapsed*1000)/1000) # round to milliseconds logging.info(f'Stop MPyC -- elapsed time: {str(elapsed)[:-3]}|bytes sent: {sum(nbytes)}') logging.debug(f'Bytes sent per party: {" ".join(map(str, nbytes))}') diff --git a/tests/test_runtime.py b/tests/test_runtime.py index 898b31e3..571de0a1 100644 --- a/tests/test_runtime.py +++ b/tests/test_runtime.py @@ -854,6 +854,10 @@ def test_misc(self): cs_f = lambda b, i: [b * (2*i+1) + i**2, (b*2+1) * 3**i] self.assertEqual(mpc.run(mpc.output(mpc.find(x, 2, bits=False, cs_f=cs_f))), [4, 9]) + def test_utils(self): + self.assertEqual(mpc.communication_cost(), 0) + self.assertEqual(mpc.communication_cost(per_user=True), [0]) + self.assertGreater(mpc.elapsed_time(), 0) if __name__ == "__main__": unittest.main()