Skip to content

Commit

Permalink
Improve UX. Restructure type generation function. (#15)
Browse files Browse the repository at this point in the history
* Improved UX. Restructure type generation function.

* Updated versions

* Made type generation backwards compatible with previous version of package
  • Loading branch information
KyrylR authored Oct 17, 2024
1 parent 7a360e8 commit 9735c9c
Show file tree
Hide file tree
Showing 11 changed files with 118 additions and 144 deletions.
4 changes: 2 additions & 2 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@solarity/zktype",
"version": "0.4.0-rc.0",
"version": "0.4.0-rc.1",
"description": "Unleash TypeScript bindings for Circom circuits",
"main": "dist/index.js",
"types": "dist/index.d.ts",
Expand Down
5 changes: 3 additions & 2 deletions src/core/BaseTSGenerator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,12 @@ export default class BaseTSGenerator {
* Extracts the type name from the circuit artifact.
*
* @param {CircuitArtifact} circuitArtifact - The circuit artifact from which the type name is extracted.
* @param protocolType - The protocol type to be added to the type name.
* @param {string} [prefix=""] - The prefix to be added to the type name.
* @returns {string} The extracted type name.
*/
protected _getTypeName(circuitArtifact: CircuitArtifact, prefix: string = ""): string {
return `${prefix}${circuitArtifact.circuitTemplateName.replace(path.extname(circuitArtifact.circuitTemplateName), "")}`;
protected _getTypeName(circuitArtifact: CircuitArtifact, protocolType: string, prefix: string = ""): string {
return `${prefix}${circuitArtifact.circuitTemplateName.replace(path.extname(circuitArtifact.circuitTemplateName), "")}${protocolType}`;
}

/**
Expand Down
142 changes: 77 additions & 65 deletions src/core/CircuitTypesGenerator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import ZkitTSGenerator from "./ZkitTSGenerator";
import { normalizeName } from "../utils";

import { Formats } from "../constants";
import { CircuitArtifact, ArtifactWithPath, GeneratedCircuitWrapperResult } from "../types";
import { CircuitArtifact, GeneratedCircuitWrapperResult, CircuitSet } from "../types";

/**
* `CircuitTypesGenerator` is need for generating TypeScript bindings based on circuit artifacts.
Expand Down Expand Up @@ -65,7 +65,7 @@ export class CircuitTypesGenerator extends ZkitTSGenerator {
fs.mkdirSync(this.getOutputTypesDir(), { recursive: true });

const isNameExist: Map<string, boolean> = new Map();
const typePathsToResolve: ArtifactWithPath[] = [];
const circuitSet: CircuitSet = {};

for (let i = 0; i < circuitArtifacts.length; i++) {
const circuitName = circuitArtifacts[i].circuitTemplateName;
Expand Down Expand Up @@ -102,15 +102,21 @@ export class CircuitTypesGenerator extends ZkitTSGenerator {

this._saveFileContent(circuitTypePath, preparedNode.content);

typePathsToResolve.push({
if (!circuitSet[circuitName]) {
circuitSet[circuitName] = [];
}

circuitSet[circuitName].push({
circuitArtifact: circuitArtifacts[i],
pathToGeneratedFile: path.join(this.getOutputTypesDir(), circuitTypePath),
protocol: circuitArtifacts[i].baseCircuitInfo.protocol.length > 1 ? preparedNode.prefix : undefined,
protocol: preparedNode.protocol,
});
}
}

await this._resolveTypePaths(typePathsToResolve);
await this._resolveTypePaths(circuitSet);
await this._saveMainIndexFile(circuitSet);
await this._saveHardhatZkitTypeExtensionFile(circuitSet);

// copy utils to types output dir
const utilsDirPath = this.getOutputTypesDir();
Expand All @@ -119,81 +125,87 @@ export class CircuitTypesGenerator extends ZkitTSGenerator {
}

/**
* Generates the index files in the `TYPES_DIR` directory and its subdirectories.
*
* @param {ArtifactWithPath[]} typePaths - The paths to the generated files and the corresponding circuit artifacts.
* Generates the index files in the subdirectories of the `TYPES_DIR` directory.
*/
private async _resolveTypePaths(typePaths: ArtifactWithPath[]): Promise<void> {
private async _resolveTypePaths(circuitSet: CircuitSet): Promise<void> {
const rootTypesDirPath = this.getOutputTypesDir();
const pathToMainIndexFile = path.join(rootTypesDirPath, "index.ts");

// index file path => its content
const indexFilesMap: Map<string, string[]> = new Map();
const isCircuitNameExist: Map<string, number> = new Map();

const topLevelCircuits: {
[circuitName: string]: ArtifactWithPath[];
} = {};

for (const typePath of typePaths) {
const levels: string[] = typePath.pathToGeneratedFile
.replace(this.getOutputTypesDir(), "")
.split(path.sep)
.filter((level) => level !== "");

for (let i = 0; i < levels.length; i++) {
const pathToIndexFile =
i === 0
? path.join(rootTypesDirPath, "index.ts")
: path.join(rootTypesDirPath, levels.slice(0, i).join(path.sep), "index.ts");

const exportDeclaration =
path.extname(levels[i]) === ".ts"
? this._getExportDeclarationForFile(levels[i])
: this._getExportDeclarationForDirectory(levels[i]);

if (
indexFilesMap.get(pathToIndexFile) === undefined ||
!indexFilesMap.get(pathToIndexFile)?.includes(exportDeclaration)
) {
indexFilesMap.set(pathToIndexFile, [
...(indexFilesMap.get(pathToIndexFile) === undefined ? [] : indexFilesMap.get(pathToIndexFile)!),
exportDeclaration,
]);
const indexFilesMap: Map<string, Set<string>> = new Map();

for (const [, artifactWithPaths] of Object.entries(circuitSet)) {
for (const artifactWithPath of artifactWithPaths) {
const levels: string[] = artifactWithPath.pathToGeneratedFile
.replace(this.getOutputTypesDir(), "")
.split(path.sep)
.filter((level) => level !== "");

for (let i = 1; i < levels.length; i++) {
const pathToIndexFile = path.join(rootTypesDirPath, levels.slice(0, i).join(path.sep), "index.ts");

if (!indexFilesMap.has(pathToIndexFile)) {
indexFilesMap.set(pathToIndexFile, new Set());
}

const exportDeclaration =
path.extname(levels[i]) === ".ts"
? this._getExportDeclarationForFile(levels[i])
: this._getExportDeclarationForDirectory(levels[i]);

if (
indexFilesMap.get(pathToIndexFile) === undefined ||
!indexFilesMap.get(pathToIndexFile)!.has(exportDeclaration)
) {
indexFilesMap.set(pathToIndexFile, indexFilesMap.get(pathToIndexFile)!.add(exportDeclaration));
}
}
}
}

for (const [absolutePath, content] of indexFilesMap) {
this._saveFileContent(path.relative(this.getOutputTypesDir(), absolutePath), Array.from(content).join("\n"));
}
}

private async _saveMainIndexFile(circuitSet: CircuitSet): Promise<void> {
let mainIndexFileContent = this._getExportDeclarationForDirectory(CircuitTypesGenerator.DOMAIN_SEPARATOR) + "\n";

for (const [, artifactWithPaths] of Object.entries(circuitSet)) {
let isCircuitNameOverlaps = false;
const seenProtocols: string[] = [];

for (const artifactWithPath of artifactWithPaths) {
if (seenProtocols.includes(artifactWithPath.protocol)) {
isCircuitNameOverlaps = true;
break;
}

const circuitName = typePath.circuitArtifact.circuitTemplateName;
seenProtocols.push(artifactWithPath.protocol);
}

if (
isCircuitNameExist.get(circuitName) === undefined ||
isCircuitNameExist.get(circuitName)! < typePath.circuitArtifact.baseCircuitInfo.protocol.length
) {
indexFilesMap.set(pathToMainIndexFile, [
...(indexFilesMap.get(pathToMainIndexFile) === undefined ? [] : indexFilesMap.get(pathToMainIndexFile)!),
this._getExportDeclarationForFile(path.relative(this._projectRoot, levels.join(path.sep))),
]);
if (isCircuitNameOverlaps) {
continue;
}

isCircuitNameExist.set(
circuitName,
isCircuitNameExist.get(circuitName) === undefined ? 1 : isCircuitNameExist.get(circuitName)! + 1,
);
for (const artifactWithPath of artifactWithPaths) {
const levels: string[] = artifactWithPath.pathToGeneratedFile
.replace(this.getOutputTypesDir(), "")
.split(path.sep)
.filter((level) => level !== "");

topLevelCircuits[circuitName] =
topLevelCircuits[circuitName] === undefined ? [typePath] : [...topLevelCircuits[circuitName], typePath];
}
const exportPathToCircuitType = this._getExportDeclarationForFile(
path.relative(this._projectRoot, levels.join(path.sep)),
);

for (const [absolutePath, content] of indexFilesMap) {
this._saveFileContent(path.relative(this.getOutputTypesDir(), absolutePath), content.join("\n"));
mainIndexFileContent += exportPathToCircuitType + "\n";
}
}

const pathToTypesExtensionFile = path.join(rootTypesDirPath, "hardhat.d.ts");
this._saveFileContent("index.ts", mainIndexFileContent);
}

this._saveFileContent(
path.relative(this.getOutputTypesDir(), pathToTypesExtensionFile),
await this._genHardhatZkitTypeExtension(topLevelCircuits),
);
private async _saveHardhatZkitTypeExtensionFile(circuitSet: CircuitSet): Promise<void> {
this._saveFileContent("hardhat.d.ts", await this._genHardhatZkitTypeExtension(circuitSet));
}

/**
Expand Down
37 changes: 13 additions & 24 deletions src/core/ZkitTSGenerator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,23 @@ import prettier from "prettier";
import BaseTSGenerator from "./BaseTSGenerator";

import {
ArtifactWithPath,
CircuitArtifact,
CircuitClass,
Inputs,
TypeExtensionTemplateParams,
DefaultWrapperTemplateParams,
WrapperTemplateParams,
SignalInfo,
GeneratedCircuitWrapperResult,
CircuitSet,
ProtocolType,
} from "../types";

import { normalizeName } from "../utils";
import { SignalTypeNames, SignalVisibilityNames } from "../constants";
import { Groth16CalldataPointsType, PlonkCalldataPointsType } from "../constants/protocol";

export default class ZkitTSGenerator extends BaseTSGenerator {
protected async _genHardhatZkitTypeExtension(circuits: {
[circuitName: string]: ArtifactWithPath[];
}): Promise<string> {
protected async _genHardhatZkitTypeExtension(circuits: CircuitSet): Promise<string> {
const template = fs.readFileSync(path.join(__dirname, "templates", "type-extension.ts.ejs"), "utf8");

const circuitClasses: CircuitClass[] = [];
Expand Down Expand Up @@ -90,11 +88,9 @@ export default class ZkitTSGenerator extends BaseTSGenerator {
circuitArtifact: CircuitArtifact,
pathToGeneratedFile: string,
): Promise<GeneratedCircuitWrapperResult[]> {
this._validateCircuitArtifact(circuitArtifact);

const result: GeneratedCircuitWrapperResult[] = [];

const unifiedProtocolType = new Set(circuitArtifact.baseCircuitInfo.protocol);
const unifiedProtocolType = this._getUnifiedProtocolType(circuitArtifact);
for (const protocolType of unifiedProtocolType) {
const content = await this._genSingleCircuitWrapperClassContent(
circuitArtifact,
Expand All @@ -109,16 +105,6 @@ export default class ZkitTSGenerator extends BaseTSGenerator {
return result;
}

protected async _genDefaultCircuitWrapperClassContent(circuitArtifact: CircuitArtifact): Promise<string> {
const template = fs.readFileSync(path.join(__dirname, "templates", "default-circuit-wrapper.ts.ejs"), "utf8");

const templateParams: DefaultWrapperTemplateParams = {
circuitClassName: this._getCircuitName(circuitArtifact),
};

return await prettier.format(ejs.render(template, templateParams), { parser: "typescript" });
}

private async _genSingleCircuitWrapperClassContent(
circuitArtifact: CircuitArtifact,
pathToGeneratedFile: string,
Expand Down Expand Up @@ -175,20 +161,21 @@ export default class ZkitTSGenerator extends BaseTSGenerator {
protocolImplementerName: this._getProtocolImplementerName(protocolType),
proofTypeInternalName: this._getProofTypeInternalName(protocolType),
circuitClassName,
publicInputsTypeName: this._getTypeName(circuitArtifact, "Public"),
publicInputsTypeName: this._getTypeName(circuitArtifact, this._getPrefix(protocolType), "Public"),
calldataPubSignalsType: this._getCalldataPubSignalsType(calldataPubSignalsCount),
publicInputs,
privateInputs,
calldataPointsType: this._getCalldataPointsType(protocolType),
proofTypeName: this._getTypeName(circuitArtifact, "Proof"),
privateInputsTypeName: this._getTypeName(circuitArtifact, "Private"),
proofTypeName: this._getTypeName(circuitArtifact, this._getPrefix(protocolType), "Proof"),
calldataTypeName: this._getTypeName(circuitArtifact, this._getPrefix(protocolType), "Calldata"),
privateInputsTypeName: this._getTypeName(circuitArtifact, this._getPrefix(protocolType), "Private"),
pathToUtils: path.relative(path.dirname(pathToGeneratedFile), pathToUtils),
};

return {
content: await prettier.format(ejs.render(template, templateParams), { parser: "typescript" }),
className: circuitClassName,
prefix: this._getPrefix(protocolType).toLowerCase(),
protocol: protocolType,
};
}

Expand Down Expand Up @@ -250,9 +237,11 @@ export default class ZkitTSGenerator extends BaseTSGenerator {
}
}

private _validateCircuitArtifact(circuitArtifact: CircuitArtifact): void {
private _getUnifiedProtocolType(circuitArtifact: CircuitArtifact): Set<ProtocolType> {
if (!circuitArtifact.baseCircuitInfo.protocol) {
throw new Error(`ZKType: Protocol is missing in the circuit artifact: ${circuitArtifact.circuitTemplateName}`);
return new Set(["groth16"]);
}

return new Set(circuitArtifact.baseCircuitInfo.protocol);
}
}
4 changes: 2 additions & 2 deletions src/core/templates/circuit-wrapper.ts.ejs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ export type <%= proofTypeName %> = {
publicSignals: <%= publicInputsTypeName %>;
}

export type Calldata = [
export type <%= calldataTypeName %> = [
<%= calldataPointsType %>,
<%= calldataPubSignalsType %>,
];
Expand Down Expand Up @@ -59,7 +59,7 @@ export class <%= circuitClassName %> extends CircuitZKit<"<%= protocolTypeName %
});
}

public async generateCalldata(proof: <%= proofTypeName %>): Promise<Calldata> {
public async generateCalldata(proof: <%= proofTypeName %>): Promise<<%= calldataTypeName %>> {
return super.generateCalldata({
proof: proof.proof,
publicSignals: this._denormalizePublicSignals(proof.publicSignals),
Expand Down
37 changes: 0 additions & 37 deletions src/core/templates/default-circuit-wrapper.ts.ejs

This file was deleted.

2 changes: 2 additions & 0 deletions src/types/circuitArtifact.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
export type FormatTypes = "hh-zkit-artifacts-1";

export type ProtocolType = "groth16" | "plonk";

export type SignalType = "Output" | "Input" | "Intermediate";

export type VisibilityType = "Public" | "Private";
Expand Down
Loading

0 comments on commit 9735c9c

Please sign in to comment.