Developer Tutorial: Adding a new TensorFlow operation in CrypTFlow
In this tutorial we will see how to add front-end support for a new TensorFlow operation in Athos.
Say we want to implement tf.math.sin. First we write a unit test for it and place it in test_sin.py in the Athos/tests/tf/unittests directory:
```python
import tensorflow as tf
import numpy as np
import pytest

import sys
import os

# Athos DIR
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
from tests.utils import Config, Compiler, assert_almost_equal


def test_sin(test_dir, backend):
    graph = tf.Graph()
    a_inp = np.single(np.random.randn(2, 2))
    with graph.as_default():
        a = tf.compat.v1.placeholder(tf.as_dtype(np.single), shape=a_inp.shape, name="a")
        output = tf.math.sin(a, name="output")
    with tf.compat.v1.Session(graph=graph) as sess:
        expected_output = sess.run(output, feed_dict={a: a_inp})

    config = Config(backend).add_input(a).add_output(output)
    compiler = Compiler(graph, config, test_dir)
    mpc_output = compiler.compile_and_run([a_inp])
    assert_almost_equal(tf_output=expected_output, mpc_tensor=mpc_output, precision=2)
    return
```
When we run `pytest -rs . -k "test_sin" --backend="CPP"` from the Athos/tests directory, the test fails with the following output:
```
    def generateASTForNode(graph, curNode, dictNodeNameToOutVarStr, extraNodeInfoDict):
        curNodeOp = curNode.getOp()
        ast = None
>       func = getattr(TFNodesAST, curNodeOp)
E       AttributeError: type object 'TFNodesAST' has no attribute 'Sin'

../../TFCompiler/ProcessTFGraph.py:
```
This tells us that the Sin method is missing from the TFNodesAST class, so we add it. In TFCompiler/TFNodesAST.py we add the following to the TFNodesAST class:
```python
def Sin(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict):
    inputsRef = curNode.getInputsRef()
    assert(len(inputsRef)==1)
    return (None, { curNode.getName() : AST.Func(TFNodesAST.getOperatorsIdx('sin'), AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]))})
```
This method takes the following arguments:
- graph - the computation graph
- curNode - the current node being compiled
- dictNodeNameToOutVarStr - a mapping from node names to their output variables (e.g. "input" -> "J0"). These variables are part of the generated program, and there is one for each node in the computation graph.
- extraNodeInfoDict - contains size information for the outputs of each node in the graph.
As the return value, we return a function AST (AST.Func) with the sin operator, passing the output variable of the input node (dictNodeNameToOutVarStr[inputsRef[0]]) to sin as its argument. This AST is defined in SeeDot/AST/AST.py.
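To see how these pieces fit together, here is a self-contained toy model of the dispatch. It is a sketch, not Athos code: `ToyNode`, `ID`, `Func`, `ToyNodesAST` and `generate_ast_for_node` are hypothetical stand-ins for `Graph.Node`, `AST.ID`, `AST.Func`, `TFNodesAST` and `generateASTForNode`. It shows two things. First, a node's TensorFlow op name (here `Sin`) is looked up as an attribute of the class with `getattr`, which is why the missing handler surfaced as an `AttributeError`. Second, the handler maps the node's name to a `Func` AST whose argument is the input node's output variable.

```python
from dataclasses import dataclass


# Hypothetical stand-in for Graph.Node.
@dataclass
class ToyNode:
    name: str
    op: str
    inputs: list

    def getName(self):
        return self.name

    def getOp(self):
        return self.op

    def getInputsRef(self):
        return self.inputs


# Hypothetical stand-ins for AST.ID and AST.Func.
@dataclass
class ID:
    name: str


@dataclass
class Func:
    op: str
    expr: ID


class ToyNodesAST:
    # Handlers live on the class and are looked up by TF op name.
    @staticmethod
    def Tanh(curNode, dictNodeNameToOutVarStr):
        (inp,) = curNode.getInputsRef()
        return (None, {curNode.getName(): Func("tanh", ID(dictNodeNameToOutVarStr[inp]))})


def generate_ast_for_node(curNode, dictNodeNameToOutVarStr):
    # getattr raises AttributeError if no handler is named after the op.
    func = getattr(ToyNodesAST, curNode.getOp())
    return func(curNode, dictNodeNameToOutVarStr)


def Sin(curNode, dictNodeNameToOutVarStr):
    inputsRef = curNode.getInputsRef()
    assert len(inputsRef) == 1
    # Map this node's name to sin(<output variable of the input node>).
    return (None, {curNode.getName(): Func("sin", ID(dictNodeNameToOutVarStr[inputsRef[0]]))})


node = ToyNode(name="output", op="Sin", inputs=["a"])
out_vars = {"a": "J0"}

try:
    generate_ast_for_node(node, out_vars)
except AttributeError as e:
    print(e)  # type object 'ToyNodesAST' has no attribute 'Sin'

# Equivalent of adding the Sin method to the class.
ToyNodesAST.Sin = staticmethod(Sin)
print(generate_ast_for_node(node, out_vars))
```

Attaching `Sin` after the class is defined only keeps the demo short. In TFNodesAST.py the method is written directly inside the class body.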
Now we need to add the sin operator to the SeeDot AST in AST.py:
```diff
 OperatorsSymbolDict = {
     ..
     "TANH": 'tanh',
+    "SIN": 'sin',
     "SIGMOID": 'sigmoid',
     ..
 }

 class Operators(Enum):
     ..
     TANH = auto()
+    SIN = auto()
     SIGMOID = auto()
     ..
```
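Python's `auto()` numbers enum members in declaration order. Inserting `SIN` between `TANH` and `SIGMOID` therefore renumbers the members that follow it. That is harmless as long as code compares members rather than raw integers. The sketch below models one plausible way a lookup such as `getOperatorsIdx('sin')` could map a lowercase symbol back to an `Operators` member by searching `OperatorsSymbolDict`. The helper `get_operators_idx` and the trimmed-down tables are hypothetical and are not copied from AST.py.

```python
from enum import Enum, auto

# Trimmed-down, hypothetical versions of the two tables in AST.py.
OperatorsSymbolDict = {
    "TANH": "tanh",
    "SIN": "sin",
    "SIGMOID": "sigmoid",
}


class Operators(Enum):
    TANH = auto()
    SIN = auto()
    SIGMOID = auto()


def get_operators_idx(symbol):
    # Hypothetical lookup: find the key whose symbol matches, then return
    # the enum member with that name.
    for key, value in OperatorsSymbolDict.items():
        if value == symbol:
            return Operators[key]
    raise KeyError(f"unknown operator symbol: {symbol}")


print(get_operators_idx("sin"))        # Operators.SIN
print([op.value for op in Operators])  # [1, 2, 3]
```

This is why both tables are edited together. A symbol with no enum member of the same name would make `Operators[key]` raise `KeyError`.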
Because Sin is a simple function, we reuse the existing AST.Func node. If you were adding a more complex operator, you would have to add your own custom node to the SeeDot AST (e.g. AST.ArgMax). We also need to teach the type inference pass about Sin. In SeeDot/Type.py we add:
```diff
 def visitFunc(self, node:AST.Func, args=None):
     ..
     elif node.op == AST.Operators.TANH:
         assert isTensor(eType)
         node.type = copy.copy(eType)
+    elif node.op == AST.Operators.SIN:
+        assert isTensor(eType)
+        node.type = copy.copy(eType)
     elif node.op == AST.Operators.SIGMOID:
         assert isTensor(eType)
         node.type = copy.copy(eType)
```
We assert that the input to Sin is in fact a tensor and use the input's type as the type of the Sin node (the sin of a 2D tensor is a 2D tensor). If we had instead added a new node to the SeeDot AST, we would have had to add a visitor for it in AST/ASTVisitor.py and implement the visit methods in MtdAST.py, PrintAST.py, Type.py and GarbageCollector.py. Since we are reusing the Func node, none of that is needed.
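The rule added for Sin is the same shape-preserving rule used for TANH and SIGMOID. Here is a minimal sketch of that rule in isolation. It assumes a hypothetical `Tensor` type that records only a shape; the real types in Type.py carry more information. An element-wise unary operator accepts only a tensor and returns a copy of its argument's type.

```python
import copy
from dataclasses import dataclass


# Hypothetical, simplified types: a tensor type with a shape, and a scalar type.
@dataclass
class Tensor:
    shape: list


@dataclass
class Int:
    pass


# Unary element-wise operators whose output type equals their input type.
ELEMENTWISE_OPS = {"TANH", "SIN", "SIGMOID"}


def is_tensor(t):
    return isinstance(t, Tensor)


def type_of_func(op, e_type):
    if op in ELEMENTWISE_OPS:
        assert is_tensor(e_type), f"{op} expects a tensor argument"
        # Shallow copy gives the node its own type object, mirroring
        # copy.copy(eType) in Type.py.
        return copy.copy(e_type)
    raise NotImplementedError(op)


print(type_of_func("SIN", Tensor([2, 2])))  # Tensor(shape=[2, 2])
```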
So far we have added front-end support in Athos and SeeDot. However, SeeDot does not yet know how to generate code for Sin. If we run the pytest command at this point, we see the following error:
```
    (prog_1, expr_1) = self.visit(node.decl)
  File "/home/bhatu/upstream/EzPC/Athos/SeeDot/AST/IRBuilderAST.py", line 31, in visit
    ret = super().visit( node, args)
  File "/home/bhatu/upstream/EzPC/Athos/SeeDot/AST/ASTVisitor.py", line 112, in visit
    return self.visitFunc(node, args)
  File "/home/bhatu/upstream/EzPC/Athos/SeeDot/IR/IRBuilderCSF.py", line 1013, in visitFunc
    AST.Operators.ClearMemSecret, AST.Operators.ClearMemPublic])
AssertionError
```
This is an assertion failure in visitFunc: it cannot handle an AST.Operators.SIN node. IRBuilderCSF.py lowers the SeeDot AST to a low-level IR, which in turn is compiled down to EzPC code. Most of the code generation happens in this file. Let's add support for Sin:
```diff
 def visitFunc(self, node:AST.Func, args=None):
     op = node.op
     assert(op in [AST.Operators.Floor, AST.Operators.Shape, AST.Operators.RELU, AST.Operators.TANH,
-                  AST.Operators.SIGMOID, AST.Operators.SQRT, AST.Operators.RSQRT,
+                  AST.Operators.SIN, AST.Operators.SIGMOID, AST.Operators.SQRT, AST.Operators.RSQRT,
                   AST.Operators.ClearMemSecret, AST.Operators.ClearMemPublic])
     return self.visitFloorLike(node)

 def visitFloorLike(self, node:AST.Func, args=None):
     ..
     elif node.op == AST.Operators.TANH:
         funcName = "Tanh"
+    elif node.op == AST.Operators.SIN:
+        funcName = "Sin"
     elif node.op == AST.Operators.SIGMOID:
         funcName = "Sigmoid"
     ..
```
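The generated EzPC program calls `Sin2(2, 2, tmp0, tmp1)` for our 2x2 input. The library function name is the `funcName` chosen in visitFloorLike plus the rank of the input, followed by the dimensions, the input variable and the output variable. The helper `emit_floor_like_call` below is a hypothetical sketch of that naming convention only. The real IRBuilderCSF.py emits IR rather than text and also handles truncation.

```python
# Hypothetical mapping mirroring the elif chain in visitFloorLike.
FUNC_NAMES = {"TANH": "Tanh", "SIN": "Sin", "SIGMOID": "Sigmoid"}


def emit_floor_like_call(op, shape, in_var, out_var):
    func_name = FUNC_NAMES[op]
    # Library functions are specialised per rank, e.g. Sin2 for 2D tensors.
    callee = f"{func_name}{len(shape)}"
    # Dimensions first, then the input and output variables.
    args = [str(d) for d in shape] + [in_var, out_var]
    return f"{callee}({', '.join(args)});"


print(emit_floor_like_call("SIN", [2, 2], "tmp0", "tmp1"))  # Sin2(2, 2, tmp0, tmp1);
```

This convention means the compiler can emit a call to `Sin2` even though no EzPC library defines it yet. That is exactly the error that appears next.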
And that's it. With this, the compiler can insert calls to Sin in the generated program. (Note: to keep this tutorial brief, we skip the truncation-related handling that must also be implemented in visitFloorLike.) If we now run the same pytest command, we get this error:
```
Running ezpc compiler on generated file :::
./ezpc --bitlen 64 --codegen CPP --disable-tac --o_prefix ./model_64_cpp ./model_64_cpp__temp2.ezpc
Read file ./model_64_cpp__temp2.ezpc ...
Parsed file ./model_64_cpp__temp2.ezpc ...
Running the inference pass ...
Inferred binop labels and coercions ...
Type_error(2240,1-2240,23): Function Sin2 not found
```
We can inspect the generated EzPC code in /tmp/cryptflow_tests/athos_test_sin/model_64_cpp.ezpc:
```
def void main(){
    (* {'TFOpName': 'Placeholder', 'TFNodeName': 'a'} *)
    input(CLIENT, tmp0, int64_al[2][2]);
    StartComputation();
    int64_al[2][2] tmp1;
    (* {'TFOpName': 'Sin', 'TFNodeName': 'output'} *)
    Sin2(2, 2, tmp0, tmp1);
    (* {'TFOpName': 'No-op: ClearMem', 'TFNodeName': ''} *)
    ClearMemSecret2(2, 2, tmp0);
    EndComputation();
    output(CLIENT, tmp1);
}
```
We can see that Sin2 takes the input tmp0 of shape [2,2] and writes its output to tmp1. The 2 in Sin2 indicates that it operates on 2D tensors. Now we can implement Sin2 in the Athos/TFEzPCLibrary/Library64_common_cpp_pre.ezpc file. We write a dummy version that just copies the input to the output:
```
def void Sin2(int32_pl s1, int32_pl s2, int64_al[s1][s2] inArr, int64_al[s1][s2] outArr)
{
    for i1=[0:s1]{
        for i2=[0:s2]{
            outArr[i1][i2] = inArr[i1][i2];
        };
    };
}
```
Now when we run the command, the program compiles and runs successfully. The test still fails because the sin values are incorrect. We can confirm the rest of the pipeline works by modifying the test to compare the program's output with the input to sin:
```diff
-    assert_almost_equal(tf_output=expected_output, mpc_tensor=mpc_output, precision=2)
+    assert_almost_equal(tf_output=a_inp, mpc_tensor=mpc_output, precision=2)
```
We see that the test passes now.
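The dummy Sin2 passes only because the modified test compares against the input itself. A real Sin2 has to compute the sine of the fixed-point integers the program manipulates (the generated code works on int64_al values). The cleartext sketch below shows what such a function should produce, for use as a test reference. It assumes values are encoded as round(x * 2^scale) with a hypothetical scale of 12 fractional bits; check the scale your Athos configuration actually uses. It is not a secure implementation, and it does not model 64-bit overflow.

```python
import math

SCALE = 12  # Assumption: number of fractional bits in the fixed-point encoding.


def encode(x, scale=SCALE):
    # Real number -> fixed-point integer.
    return round(x * (1 << scale))


def decode(v, scale=SCALE):
    # Fixed-point integer -> real number.
    return v / (1 << scale)


def sin2_reference(s1, s2, in_arr, scale=SCALE):
    # Cleartext reference for Sin2: decode, apply sin, re-encode each element.
    return [[encode(math.sin(decode(in_arr[i][j], scale)), scale) for j in range(s2)]
            for i in range(s1)]


xs = [[0.5, -1.0], [2.0, 0.0]]
inp = [[encode(x) for x in row] for row in xs]
out = sin2_reference(2, 2, inp)
print([[decode(v) for v in row] for row in out])
```

Decoding the outputs and comparing them with TensorFlow's result on the same inputs is what the original assertion in test_sin does.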
In this tutorial we took a simple operation and showed how to add compiler support for it. More complex operations can be added too, but they are more involved, as described above. Study the code for other operations, such as add, in the files mentioned above to get a better idea of how to implement operations.
The overall flow of the compiler is:
- Athos/TFCompiler/ProcessTFGraph.py processes TensorFlow graph_def files and creates an in-memory graph as defined in Graph.py.
- Athos/TFCompiler/TFNodesAST.py then visits all the nodes in this graph, creates a SeeDot AST, and dumps it.
- SeeDot/Compiler.py takes this AST, runs optimisations such as relumaxpoolopt and garbage collection, and then runs type inference on it.
- SeeDot/IR/IRBuilderCSF.py then visits the AST and generates low-level IR.
- SeeDot/Codegen/EzPC.py then iterates through this IR and dumps EzPC code.
- We link the generated EzPC code with the corresponding .ezpc libraries in Athos/TFEzPCLibrary.
- We run the EzPC compiler to generate target code. This code is then linked with Porthos/SCI, or compiled directly when targeting the C++ backend for debugging.