Skip to content

Commit

Permalink
Merge pull request #203 from EYBlockchain/swati/publicParam
Browse files Browse the repository at this point in the history
Public Params in Function Signature
  • Loading branch information
MirandaWood authored Jul 10, 2023
2 parents a47a3cd + 8e7261b commit ddc0644
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,12 @@ class FunctionBoilerplateGenerator {
},

getIndicators() {
const { indicators } = this.scope;
const { indicators, msgSigRequired } = this.scope;
const isConstructor = this.scope.path.node.kind === 'constructor' ? true : false;



const { nullifiersRequired, oldCommitmentAccessRequired, msgSenderParam, msgValueParam, containsAccessedOnlyState, encryptionRequired } = indicators;
const newCommitmentsRequired = indicators.newCommitmentsRequired;
const nullifierRootRequired = indicators.nullifiersRequired;
return { nullifierRootRequired,nullifiersRequired, oldCommitmentAccessRequired, newCommitmentsRequired, msgSenderParam, msgValueParam, containsAccessedOnlyState, isConstructor, encryptionRequired };
return { nullifiersRequired, oldCommitmentAccessRequired, newCommitmentsRequired, msgSenderParam, msgValueParam, containsAccessedOnlyState, isConstructor, encryptionRequired };
},

parameters() {
Expand All @@ -91,38 +88,47 @@ class FunctionBoilerplateGenerator {
postStatements(customInputs: any[] = []) {
const { scope } = this;
const { path } = scope;


const customInputsMap = (node: any) => {
if (path.isStruct(node)) {
const structDef = path.getStructDeclaration(node);
const names = structDef.members.map((mem: any) => {
return { name: `${node.name}.${mem.name}`, type: mem.typeName.name };
});
return { structName: structDef.name, properties: names, isParam: path.isFunctionParameter(node) };
return { structName: structDef.name, properties: names, isParam: path.isFunctionParameter(node), inCircuit: node.interactsWithSecret };
}
return { name: node.name, type: node.typeName.name, isParam: path.isFunctionParameter(node) };
return { name: node.name, type: node.typeName.name, isParam: path.isFunctionParameter(node), inCircuit: node.interactsWithSecret };
}

const params = path.getFunctionParameters();
const publicParams = params?.filter((p: any) => (!p.isSecret && p.interactsWithSecret)).map((p: any) => customInputsMap(p)).concat(customInputs);

const publicParams = params?.filter((p: any) => !p.isSecret).map((p: any) => customInputsMap(p)).concat(customInputs);
const functionName = path.getUniqueFunctionName();
const indicators = this.customFunction.getIndicators.bind(this)();



// special check for msgSender and msgValue param. If msgsender is found, prepend a msgSender uint256 param to the contact's function.
if (indicators.msgSenderParam) publicParams.unshift({ name: 'msg.sender', type:'address' , dummy: true});
if (indicators.msgValueParam) publicParams.unshift({ name: 'msg.value', type:'uint256' , dummy: true});
if (indicators.msgSenderParam) publicParams.unshift({ name: 'msg.sender', type:'address' , dummy: true, inCircuit: true});
if (indicators.msgValueParam) publicParams.unshift({ name: 'msg.value', type:'uint256' , dummy: true, inCircuit: true});
let internalFunctionEncryptionRequired = false;


path.node._newASTPointer.body.statements?.forEach((node) => {
if(node.expression?.nodeType === 'InternalFunctionCall')
if(node.expression.parameters.includes('cipherText') )
internalFunctionEncryptionRequired = true
if(node.expression?.nodeType === 'InternalFunctionCall'){
if(node.expression.parameters.includes('cipherText') )
internalFunctionEncryptionRequired = true

}

})


if(path.node.returnParameters.parameters.length === 0 && !indicators.encryptionRequired && !internalFunctionEncryptionRequired) {
publicParams?.push({ name: 1, type: 'uint256', dummy: true });
publicParams?.push({ name: 1, type: 'uint256', dummy: true , inCircuit: true });
}

return {
...(publicParams?.length && { customInputs: publicParams }),
functionName,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,32 +55,35 @@ class FunctionBoilerplateGenerator {
postStatements({
functionName,
customInputs, // array of custom input names
nullifierRootRequired : nullifierRootRequired,
isConstructor,
nullifiersRequired: newNullifiers,
oldCommitmentAccessRequired: commitmentRoot,
newCommitmentsRequired: newCommitments,
encryptionRequired,
isConstructor
encryptionRequired
}): string[] {
// prettier-ignore

let parameter = [
...(customInputs ? customInputs.filter(input => !input.dummy && input.isParam).map(input => input.structName ? `(${input.properties.map(p => p.type)})` : input.type) : []),
...(nullifierRootRequired ? [`uint256`] : []),
...(nullifierRootRequired ? [`uint256`] : []),
...(newNullifiers ? [`uint256`] : []),
...(newNullifiers ? [`uint256`] : []),
...(newNullifiers ? [`uint256[]`] : []),
...(commitmentRoot ? [`uint256`] : []),
...(newCommitments ? [`uint256[]`] : []),
...(encryptionRequired ? [`uint256[][]`] : []),
...(commitmentRoot ? [`uint256`] : []),
...(newCommitments ? [`uint256[]`] : []),
...(encryptionRequired ? [`uint256[][]`] : []),
...(encryptionRequired ? [`uint256[2][]`] : []),
`uint256[]`,
].filter(para => para !== undefined); // Added for return parameter


customInputs?.forEach((input, i) => {
if (input.structName) customInputs[i] = input.properties;
});

let msgSigCheck = ([...(isConstructor ? [] : [`bytes4 sig = bytes4(keccak256("${functionName}(${parameter})")) ; \n \t \t \t if (sig == msg.sig)`])]);

let msgSigCheck = ([...(isConstructor ? [] : [`bytes4 sig = bytes4(keccak256("${functionName}(${parameter})")) ; \n \t \t \t if (sig == msg.sig)`])]);

customInputs = customInputs?.filter(p => p.inCircuit);

return [
`
Expand All @@ -96,10 +99,10 @@ class FunctionBoilerplateGenerator {
}).join('\n')}`]
: []),

...(nullifierRootRequired ? [`
...(newNullifiers ? [`
inputs.nullifierRoot = nullifierRoot; `] : []),

...(nullifierRootRequired ? [`
...(newNullifiers ? [`
inputs.latestNullifierRoot = latestNullifierRoot; `] : []),


Expand Down
9 changes: 6 additions & 3 deletions src/codeGenerators/contract/solidity/toContract.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import FunctionBP from '../../../boilerplate/contract/solidity/raw/FunctionBoile
const contractBP = new ContractBP();
const functionBP = new FunctionBP();


function codeGenerator(node: any) {
// We'll break things down by the `type` of the `node`.
switch (node.nodeType) {
Expand Down Expand Up @@ -80,10 +81,12 @@ function codeGenerator(node: any) {
break;

}
const functionSignature = `${functionType} (${codeGenerator(node.parameters)}) ${node.visibility} ${node.stateMutability} {`;
const body = codeGenerator(node.body);


const functionSignature = `${functionType} (${codeGenerator(node.parameters)}) ${node.visibility} ${node.stateMutability} {`;
let body = codeGenerator(node.body);
let msgSigCheck = body.slice(body.indexOf('bytes4 sig'), body.indexOf('verify') )
if(!node.msgSigRequired)
body = body.replace(msgSigCheck, ' ');
return `
${functionSignature}
Expand Down
19 changes: 4 additions & 15 deletions src/transformers/visitors/circuitInternalFunctionCallVisitor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,7 @@ const internalCallVisitor = {
if(childNode.nodeType === 'FunctionDefinition'){
state.newParameterList = cloneDeep(childNode.parameters.parameters);
state.newReturnParameterList = cloneDeep(childNode.returnParameters.parameters);
// node._newASTPointer.forEach(file => {
// if(file.fileName === state.callingFncName[index].name){
// file.nodes.forEach(childNode => {
// if(childNode.nodeType === 'FunctionDefinition'){
// let callParameterList = cloneDeep(childNode.parameters.parameters);
// }
// })
// }
// })

state.newParameterList.forEach((node, nodeIndex) => {
if(node.nodeType === 'Boilerplate') {
for(const [id, oldStateName] of state.oldStateArray.entries()) {
Expand All @@ -47,11 +39,9 @@ const internalCallVisitor = {
for(const [id, oldStateName] of state.oldStateArray.entries()) {
if(oldStateName !== state.newStateArray[name][id].name)
node.name = state.newStateArray[name][id].name;
node.name = node.name.replace('_'+oldStateName, '_'+state.newStateArray[name][id].name)
if(state.newStateArray[name][id].memberName)
node.name = node.name.replace('_'+oldStateName, '_'+state.newStateArray[name][id].name)
if(state.newStateArray[name][id].memberName)
state.newParameterList.splice(nodeIndex,1);
else
node.name = node.name.replace(oldStateName, state.newStateArray[name][id].name)
}
}
})
Expand All @@ -62,8 +52,6 @@ const internalCallVisitor = {
node.name = node.name.replace('_'+oldStateName, '_'+state.newStateArray[name][id].name)
if(state.newStateArray[name][id].memberName)
state.state.newReturnParameterList.splice(nodeIndex,1);
else
node.name = node.name.replace(oldStateName, state.newStateArray[name][id].name)
}
})
}
Expand Down Expand Up @@ -124,6 +112,7 @@ const internalCallVisitor = {
state.circuitArguments.push(param);
}
});

node._newASTPointer.forEach(file => {
if(file.fileName === state.callingFncName[index].name){
file.nodes.forEach(childNode => {
Expand Down
15 changes: 11 additions & 4 deletions src/transformers/visitors/toContractVisitor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ export default {
if(node.nodeType === 'FunctionDefinition' && node.kind === 'function'){
state.internalFncName?.forEach( (name, index) => {
if(node.name === name) {
node.msgSigRequired = true;
state.postStatements ??= [];
state.postStatements = cloneDeep(node.body.postStatements);
}
Expand Down Expand Up @@ -271,20 +272,23 @@ export default {
enter(path: NodePath, state: any) {
const { node, parent } = path;
const isConstructor = node.kind === 'constructor';
state.msgSigRequired = false;
if(node.kind === 'fallback' || node.kind === 'receive')
{
node.fileName = node.kind;
state.functionName = node.kind;
}
else
state.functionName = path.getUniqueFunctionName();

const newNode = buildNode('FunctionDefinition', {
name: node.fileName || state.functionName,
id: node.id,
kind: node.kind,
stateMutability: node.stateMutability === 'payable'? node.stateMutability : '',
visibility: node.kind ==='function' ? 'public' : node.kind === 'constructor'? '': 'external',
isConstructor,
msgSigRequired: state.msgSigRequired,
});

node._newASTPointer = newNode;
Expand All @@ -302,12 +306,13 @@ export default {
exit(path: NodePath, state: any) {
// We populate the entire shield contract upon exit, having populated the FunctionDefinition's scope by this point.
const { node, scope } = path;

const newFunctionDefinitionNode = node._newASTPointer;

// Let's populate the `parameters` and `body`:
const { parameters } = newFunctionDefinitionNode.parameters;
const { postStatements, preStatements } = newFunctionDefinitionNode.body;


// if contract is entirely public, we don't want zkp related boilerplate
if (!path.scope.containsSecret && !(node.kind === 'constructor')) return;
Expand All @@ -334,9 +339,9 @@ export default {
bpSection: 'postStatements',
scope,
customInputs: state.customInputs,

}),
);

delete state?.customInputs;
},
},
Expand Down Expand Up @@ -891,7 +896,9 @@ DoWhileStatement: {
state.fnParameters.push(args[index]);

});
const params = [...(internalfnDefIndicators.nullifiersRequired? [`nullifierRoot, latestNullifierRoot, newNullifiers`] : []),
const params = [...(internalfnDefIndicators.nullifiersRequired? [`nullifierRoot`] : []),
...(internalfnDefIndicators.nullifiersRequired? [`latestNullifierRoot`] : []),
...(internalfnDefIndicators.nullifiersRequired? [`newNullifiers`] : []),
...(internalfnDefIndicators.oldCommitmentAccessRequired ? [`commitmentRoot`] : []),
...(internalfnDefIndicators.newCommitmentsRequired ? [`newCommitments`] : []),
...(internalfnDefIndicators.containsAccessedOnlyState ? [`checkNullifiers`] : []),
Expand Down
2 changes: 2 additions & 0 deletions src/types/solidity-types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ export function buildNode(nodeType: string, fields: any = {}): any {
isConstructor,
kind,
stateMutability,
msgSigRequired,
body = buildNode('Block'),
parameters = buildNode('ParameterList'),
returnParameters = buildNode('ParameterList'), // TODO
Expand All @@ -161,6 +162,7 @@ export function buildNode(nodeType: string, fields: any = {}): any {
isConstructor,
kind,
stateMutability,
msgSigRequired,
body,
parameters,
returnParameters,
Expand Down

0 comments on commit ddc0644

Please sign in to comment.