From 3c69da139ae36d6fe7a3e68da455a37f83c06a0f Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Fri, 12 Oct 2018 12:15:34 +1100 Subject: [PATCH] More examples. --- .gitignore | 1 - Programs/Source/blink.mpc | 100 ++++++++++++++++++++ Programs/Source/gc_and.mpc | 26 +++++ Programs/Source/gc_fixed_point_tutorial.mpc | 44 +++++++++ Programs/Source/gc_tutorial.mpc | 50 ++++++++++ Programs/Source/test_sbitfix.mpc | 53 +++++++++++ Programs/Source/test_sbitint.mpc | 38 ++++++++ 7 files changed, 311 insertions(+), 1 deletion(-) create mode 100644 Programs/Source/blink.mpc create mode 100644 Programs/Source/gc_and.mpc create mode 100644 Programs/Source/gc_fixed_point_tutorial.mpc create mode 100644 Programs/Source/gc_tutorial.mpc create mode 100644 Programs/Source/test_sbitfix.mpc create mode 100644 Programs/Source/test_sbitint.mpc diff --git a/.gitignore b/.gitignore index aa53a0df1..134c8c088 100644 --- a/.gitignore +++ b/.gitignore @@ -45,7 +45,6 @@ callgrind.out.* # Compiled source # ################### -Programs/Source/* Programs/Bytecode/* Programs/Schedules/* Programs/Public-Input/* diff --git a/Programs/Source/blink.mpc b/Programs/Source/blink.mpc new file mode 100644 index 000000000..4e632994b --- /dev/null +++ b/Programs/Source/blink.mpc @@ -0,0 +1,100 @@ +import math +import util + +n_threads = 64 +xor_op = lambda x, y: x ^ y +n_bits = 64 +full_t = sbits.get_type(64) +sbits.n = n_bits + +if len(program.args) > 1: + n_batches = int(program.args[1]) +else: + n_batches = 78 + +batch_size = 64 +n = n_batches * batch_size +l = 16 +a = Matrix(n, l, full_t) +b = Matrix(n, l, full_t) +t = sbitint.get_type(int(math.ceil(math.log(batch_size * l, 2))) + 1) +matches = Matrix(n, n, t.bit_type) +mismatches = Matrix(n, n, t) +threshold = MemValue(t(10)) + +for i in range(n): + for j in range(l): + a[i][j] = full_t.get_input_from(0) + b[i][j] = full_t.get_input_from(1) + +# test, create match between a[0] and b[1] but no match for a[1] +a.assign_all(0) +b.assign_all(0) +a[0][0] = -1 +b[1][0] = -1 +a[1][1] = -1 + +@for_range_multithread(n_batches, 1, n) +def _(i): + print_ln('%s', i) + @for_range_parallel(100, n_batches) + def _(j): + j = j * batch_size + av = sbitintvec.from_matrix((a[i][kk] for _ in range(batch_size)) \ + for kk in range(l)) + bv = sbitintvec.from_matrix((b[j + k][kk] for k in range(batch_size)) \ + for kk in range(l)) + res = xor_op(av, bv).popcnt() + mismatches[i].set_range(j, (t(x) for x in res.elements())) + +@for_range_multithread(n_batches, 8, n) +def _(i): + print_ln('%s', i) + @for_range_parallel(100, n_batches) + def _(j): + j = j * batch_size + v = sbitintvec(mismatches[i].get_range(j, batch_size)) + vv = sbitintvec([threshold.read()] * batch_size) + matches[i].set_range(j, v.less_than(vv, 10).elements()) + +mg = MultiArray([n_batches, n, t.n], full_t) +ag = Matrix(n_batches, n, full_t) + +@for_range_multithread(n_batches, 1, n_batches) +def _(i): + m = mg[i] + a = ag[i] + i = i * batch_size + print_ln('best %s', i) + @for_range(n) + def _(j): + m[j].assign(sbitintvec(mismatches[i + k][j] + for k in range(batch_size)).v) + m = [sbitintvec.from_vec(m[j]) for j in range(n)] + def reducer(a, b): + c = a[0].less_than(b[0]) + return util.if_else(c, (a[0], a[1] + [0] * len(b[1])), + (b[0], [0] * len(a[1]) + b[1])) + mm = util.tree_reduce(reducer, ((x, [2**batch_size - 1]) for x in m)) + a.assign(mm[1]) + @for_range_parallel(100, len(a)) + def _(j): + x = a[j] + pm = sbitintvec(matches[i + k][j] for k in range(batch_size)) + x = sbitintvec.from_vec([x]) + for k, y in enumerate((pm & x).elements()): + matches[i + k][j] = y + +def test(result, expected): + print_ln('%s ?= %s', result.reveal(), expected) + +test(matches[0][1], 1) +test(matches[0][0], 0) +test(matches[1][0], 0) +test(matches[1][1], 0) +test(sum(matches[2]), 1) + +test(mismatches[0][1], 0) +test(mismatches[0][0], 64) +test(mismatches[1][0], 64) +test(mismatches[1][1], 128) diff --git a/Programs/Source/gc_and.mpc b/Programs/Source/gc_and.mpc new file mode 100644 index 000000000..be9fd8f0a --- /dev/null +++ b/Programs/Source/gc_and.mpc @@ -0,0 +1,26 @@ +from Compiler.GC.types import sbits, sbit, cbits + + +import random + +n = 4096 +m = 1 + +if len(program.args) > 1: + n = int(program.args[1]) + +if len(program.args) > 2: + m = int(program.args[2]) + +pack = min(n, 50) +n = (n + pack - 1) / pack + +a = sbit(1) +b = sbit(1, n=pack) + +start_timer(1) +@for_range(m) +def f(_): + for i in range(n): + a * b +stop_timer(1) diff --git a/Programs/Source/gc_fixed_point_tutorial.mpc b/Programs/Source/gc_fixed_point_tutorial.mpc new file mode 100644 index 000000000..fd7f392ff --- /dev/null +++ b/Programs/Source/gc_fixed_point_tutorial.mpc @@ -0,0 +1,44 @@ +sfix = sbitfix +sint = sbitint.get_type(20) + +sfix.set_precision(16, 32) + +n = 10 +m = 5 + +# array of fixed points +A = Array(n, sfix) + +for i in range(n): + A[i] = sfix(i) + +print_ln('mrray of fixed points') +for i in range(n): + print_ln('%s', A[i].reveal()) + +# matrix of fixed points +M = Matrix(n, m, sfix) + +for i in range(n): + for j in range(m): + M[i][j] = sfix(i*j) + +print_ln('matrix of fixed points') +for i in range(n): + for j in range(m): + print_str('%s ', M[i][j].reveal()) + print_ln(' ') + + +# assign scalar to sfix +A[5] = sfix(1.12345) +print_ln('%s', A[5].reveal()) + +# assign sint to sfix +s = sint(10) +sa = sfix(); sa.load_int(s) +print_ln('successfully assigned sint to sfix %s', sa.reveal()) + +# division between fixed points +sb = sfix(2.5) +print_ln('division between %s %s = %s', sa.reveal(), sb.reveal(), (sa/sb).reveal()) diff --git a/Programs/Source/gc_tutorial.mpc b/Programs/Source/gc_tutorial.mpc new file mode 100644 index 000000000..48591c33a --- /dev/null +++ b/Programs/Source/gc_tutorial.mpc @@ -0,0 +1,50 @@ +# sbitint: factory for signed integer types + +sint = sbitint.get_type(32) + +def test(a, b, value_type=None): + try: + a = a.reveal() + except AttributeError: + pass + import inspect + print_ln('line %s: diff %s, got %s, expected %s', + inspect.currentframe().f_back.f_lineno, \ + (a ^ cbits(b, n=a.n)).reveal(), a, hex(b)) + +a = sint(1) +b = sint(2) + +test(a + b, 3) +test(a + a, 2) +test(a * b, 2) +test(a * a, 1) +test(a - b, -1) +test(a < b, 1) +test(a <= b, 1) +test(a >= b, 0) +test(a > b, 0) +test(a == b, 0) +test(a != b, 1) + +clear_a = a.reveal() + +# arrays and loops + +a = Array(100, sint) + +@for_range(100) +def f(i): + a[i] = sint(i)**2 + +test(a[99], 99**2) + +# conditional + +if_then(regint(0)) +a[0] = 123 +else_then() +a[0] = 789 +end_if() + +test(a[0], 789) diff --git a/Programs/Source/test_sbitfix.mpc b/Programs/Source/test_sbitfix.mpc new file mode 100644 index 000000000..ac9a197fa --- /dev/null +++ b/Programs/Source/test_sbitfix.mpc @@ -0,0 +1,53 @@ +from Compiler.GC.types import sbitfix, cbits + +#sbitfix.set_precision(3, 7) + +def test(a, b, value_type=None): + try: + b = int(round((b * (1 << a.f)))) + a = a.v.reveal() + except AttributeError: + pass + try: + a = a.reveal() + except AttributeError: + pass + import inspect + print_ln('%s: %s %s %s', inspect.currentframe().f_back.f_lineno, \ + (a ^ cbits(b)).reveal(), a, (b)) + +aa = 5321.0 +bb = 142.0 + +for a_sign, b_sign in (1, -1), (-1, -1): + a = a_sign * aa + b = b_sign * bb + + sa = sbitfix(a) + sb = sbitfix(b) + + test(sa + sb, a+b) + test(sa - sb, a-b) + test(sa * sb, a*b) + test(sa / sb, a/b) + + test(-sa, -a) + +a = 126 +b = 125 +sa = sbitfix(a) +sb = sbitfix(b) + +test(sa < sb, int(a sb, int(a>b)) +test(sa <= sb, int(a<=b)) +test(sa >= sb, int(a>=b)) +test(sa == sb, int(a==b)) +test(sa != sb, int(a!=b)) +test(sa != sa, int(a!=a)) diff --git a/Programs/Source/test_sbitint.mpc b/Programs/Source/test_sbitint.mpc new file mode 100644 index 000000000..ff4a4e13d --- /dev/null +++ b/Programs/Source/test_sbitint.mpc @@ -0,0 +1,38 @@ +program.options.merge_opens = False + +from Compiler.GC.types import * + +def test(a, b, value_type=None): + try: + a = a.reveal() + except AttributeError: + pass + import inspect + print_ln('%s: %s %s %s', inspect.currentframe().f_back.f_lineno, \ + (a ^ cbits(b)).reveal(), a, hex(b)) + +si32 = sbitint.get_type(32) + +test(si32(3) + si32(2), 5) +test(si32(3) - si32(2), 1) +test(si32(3) < si32(2), 0) +test(si32(3) > si32(2), 1) +test(si32(2) <= si32(2), 1) +test((si32(0) < si32(1)).if_else(si32(1), si32(2)) + si32(3), 4) + +test(si32(3) * si32(2), 6) +test(3 * si32(2), 6) +test(si32(3) * 2, 6) + +test(si32(-1), 2**32 - 1) +test(si32(-1) + si32(3), 2) +test(si32(-1) - si32(-2), 1) + +test(si32(1) * 2 * 2, 4) + +for i in range(3, 32): + t = sbitint.get_type(i) + test(t(3) + t(2), 5) + +test(abs(si32(-2)), 2) +test(abs(si32(2)), 2)