Skip to content

Commit

Permalink
Fixed incorrect export of CircuitTypesGenerator class (#5)
Browse files Browse the repository at this point in the history
* Fixed incorrect export of `CircuitTypesGenerator` class

* Fixed paths in package.json

* Added an ability to set project root path in `ZKTypeConfig`

* Added validation if CircuitAST is undefined

* Added validation for ast.circomCompilerOutput

* Fixed `_findTemplateForCircuit` function

* Updated publish-to-npm script

* Added lint to publish-to-npm

* Updated publish-to-npm

* Deleted `_nameToObjectNameMap` field

* Cleaned up project
  • Loading branch information
KyrylR authored Jul 9, 2024
1 parent 3f85ef2 commit ef158b7
Show file tree
Hide file tree
Showing 12 changed files with 139 additions and 62 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# Changelog

## [v0.2.1]

- Fixed incorrect export of `CircuitTypesGenerator` class
- Added an ability to set project root path in `ZKTypeConfig`

## [v0.2.0]

- Resolved an issue where inputs could have the wrong number of dimensions, such as `bigint[]` when `bigint[][]` was expected.
Expand Down
8 changes: 4 additions & 4 deletions package.json
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
{
"name": "@solarity/zktype",
"version": "0.2.0",
"version": "0.2.1",
"description": "Unleash TypeScript bindings for Circom circuits",
"main": "dist/src/index.js",
"types": "dist/src/index.d.ts",
"main": "dist/index.js",
"types": "dist/index.d.ts",
"engines": {
"node": ">=18"
},
Expand All @@ -20,7 +20,7 @@
"test": "npm run prepare-test && mocha --recursive 'test/**/*.ts' --exit",
"coverage": "prepare-test && nyc mocha --recursive 'test/**/*.ts' --exit",
"lint-fix": "prettier --write '**/*.ts'",
"publish-to-npm": "npm run build && npm run lint-fix && npm publish ./ --access public"
"publish-to-npm": "npm run build && npm run lint-fix && rm -rf dist/core/templates && cp -rf src/core/templates dist/core/templates && npm publish ./ --access public"
},
"nyc": {
"reporter": [
Expand Down
1 change: 1 addition & 0 deletions src/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { ZKTypeConfig } from "./types";

export const defaultCircuitArtifactGeneratorConfig: ZKTypeConfig = {
basePath: "circuits",
projectRoot: process.cwd(),
circuitsASTPaths: [],
outputArtifactsDir: "artifacts/circuits",
outputTypesDir: "generated-types/circuits",
Expand Down
61 changes: 58 additions & 3 deletions src/core/BaseTSGenerator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ import CircuitArtifactGenerator from "./CircuitArtifactGenerator";

import { CircuitArtifact, ZKTypeConfig } from "../types";

import { findProjectRoot } from "../utils";

/**
* `BaseTSGenerator` is a base class for all TypeScript generators.
*
Expand All @@ -33,7 +31,7 @@ export default class BaseTSGenerator {

this._artifactsGenerator = new CircuitArtifactGenerator(config);

this._projectRoot = findProjectRoot(process.cwd());
this._projectRoot = config.projectRoot;

this._printer = ts.createPrinter({ newLine: ts.NewLineKind.LineFeed });
this._resultFile = ts.createSourceFile("", "", ts.ScriptTarget.Latest, false, ts.ScriptKind.TS);
Expand Down Expand Up @@ -128,4 +126,61 @@ export default class BaseTSGenerator {

return preambleNodes.join("\n");
}

/**
* Returns the long path to the circuit type.
*
* The long path is the path that includes the source name and the circuit name.
*/
protected _getCircuitTypeLongPath(basePath: string, sourceName: string, circuitName: string): string {
return path.join(BaseTSGenerator.DOMAIN_SEPARATOR, sourceName.replace(basePath, ""), `${circuitName}.ts`);
}

/**
* Returns the short path to the circuit type.
*
* The short path is the path that includes ONLY the circuit name.
*/
protected _getCircuitTypeShortPath(basePath: string, sourceName: string, circuitName: string): string {
return path
.join(BaseTSGenerator.DOMAIN_SEPARATOR, sourceName.replace(basePath, ""))
.replace(path.basename(sourceName), `${circuitName}.ts`);
}

/**
* Returns the path to the generated file for the given circuit.
*
* The path can be either long or short, depending on the existence of the long path.
*/
protected _getPathToGeneratedFile(basePath: string, sourceName: string, circuitName: string): string {
const longObjectPath = this._getCircuitTypeLongPath(basePath, sourceName, circuitName);
const shortObjectPath = this._getCircuitTypeShortPath(basePath, sourceName, circuitName);

const isLongPathExist = this._checkIfCircuitExists(longObjectPath);
const isShortPathExist = this._checkIfCircuitExists(shortObjectPath);

if (!isLongPathExist && !isShortPathExist) {
throw new Error(`Circuit ${circuitName} type does not exist.`);
}

return isLongPathExist ? longObjectPath : shortObjectPath;
}

/**
* Checks if the circuit name is fully qualified.
*/
protected _isFullyQualifiedCircuitName(circuitName: string): boolean {
return circuitName.includes(":");
}

/**
* Checks if the circuit exists.
*
* Expects to get the path to the circuit file, relative to the directory where the generated types are stored.
*/
protected _checkIfCircuitExists(pathToCircuit: string): boolean {
const pathFromRoot = path.join(this._projectRoot, this.getOutputTypesDir(), pathToCircuit);

return fs.existsSync(pathFromRoot);
}
}
30 changes: 18 additions & 12 deletions src/core/CircuitArtifactGenerator.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import fs from "fs";
import path from "path";

import { findProjectRoot } from "../utils";

import { InternalType, SignalTypeNames, SignalVisibilityNames } from "../constants";
import {
Stmt,
Expand Down Expand Up @@ -40,7 +38,7 @@ export default class CircuitArtifactGenerator {
* @param {ArtifactGeneratorConfig} circuitArtifactGeneratorConfig - The configuration for the `CircuitArtifactGenerator`.
*/
constructor(circuitArtifactGeneratorConfig: ArtifactGeneratorConfig) {
this._projectRoot = findProjectRoot(process.cwd());
this._projectRoot = circuitArtifactGeneratorConfig.projectRoot;
this._circuitArtifactGeneratorConfig = circuitArtifactGeneratorConfig;
}

Expand Down Expand Up @@ -79,7 +77,11 @@ export default class CircuitArtifactGenerator {
* @returns {Promise<CircuitArtifact>} A promise that resolves to the extracted circuit artifact.
*/
public async extractArtifact(pathToTheAST: string): Promise<CircuitArtifact> {
const ast: CircuitAST = JSON.parse(fs.readFileSync(pathToTheAST, "utf-8"));
const ast: CircuitAST | undefined = JSON.parse(fs.readFileSync(pathToTheAST, "utf-8"));

if (!ast) {
throw new Error(`The circuit AST is missing. Path: ${pathToTheAST}`);
}

this._validateCircuitAST(ast);

Expand Down Expand Up @@ -192,18 +194,18 @@ export default class CircuitArtifactGenerator {
*/
private _findTemplateForCircuit(compilerOutputs: CircomCompilerOutput[], circuitName: string): Template {
for (const compilerOutput of compilerOutputs) {
if (
!compilerOutput.definitions ||
compilerOutput.definitions.length < 1 ||
!compilerOutput.definitions[0].Template
) {
if (!compilerOutput.definitions || compilerOutput.definitions.length < 1) {
continue;
}

const template = compilerOutput.definitions[0].Template;
for (const definition of compilerOutput.definitions) {
if (!definition.Template) {
continue;
}

if (template.name === circuitName) {
return template;
if (definition.Template.name === circuitName) {
return definition.Template;
}
}
}

Expand All @@ -218,6 +220,10 @@ export default class CircuitArtifactGenerator {
* @throws {Error} If the AST does not meet the expected structure.
*/
private _validateCircuitAST(ast: CircuitAST): void {
if (!ast.circomCompilerOutput) {
throw new Error(`The circomCompilerOutput field is missing in the circuit AST`);
}

if (
ast.circomCompilerOutput.length < 1 ||
!ast.circomCompilerOutput[0].main_component ||
Expand Down
42 changes: 23 additions & 19 deletions src/core/CircuitTypesGenerator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,26 +17,31 @@ import { CircuitArtifact, ArtifactWithPath } from "../types";
*
* Note: Currently, all signals are considered as `bigint` type.
*/
export default class CircuitTypesGenerator extends ZkitTSGenerator {
export class CircuitTypesGenerator extends ZkitTSGenerator {
/**
* Returns an object that represents the circuit class based on the circuit name.
*/
public async getCircuitObject(circuitName: string): Promise<any> {
const pathToGeneratedTypes = path.join(this._projectRoot, this.getOutputTypesDir());

if (this._nameToObjectNameMap.size === 0) {
throw new Error("No circuit types have been generated.");
}

const module = await import(pathToGeneratedTypes);

const circuitObjectPath = this._nameToObjectNameMap.get(circuitName);
if (!this._isFullyQualifiedCircuitName(circuitName)) {
if (!module[circuitName]) {
throw new Error(`Circuit ${circuitName} type does not exist.`);
}

if (!circuitObjectPath) {
throw new Error(`Circuit ${circuitName} type does not exist.`);
return module[circuitName];
}

return circuitObjectPath.split(".").reduce((acc, key) => acc[key], module as any);
const parts = circuitName.split(":");
const pathsToModule = this._getPathToGeneratedFile(this._zktypeConfig.basePath, parts[0], parts[1]);

return this._getObjectFromModule(module, this._getObjectPath(pathsToModule));
}

private _getObjectFromModule(module: any, path: string): any {
return path.split(".").reduce((acc, key) => acc[key], module as any);
}

/**
Expand Down Expand Up @@ -66,18 +71,17 @@ export default class CircuitTypesGenerator extends ZkitTSGenerator {
const isNameAlreadyExist = isNameExist.has(circuitName);
isNameExist.set(circuitName, true);

let circuitTypePath = path
.join(
BaseTSGenerator.DOMAIN_SEPARATOR,
circuitArtifacts[i].sourceName.replace(circuitArtifacts[i].basePath, ""),
)
.replace(path.basename(circuitArtifacts[i].sourceName), `${circuitName}.ts`);
let circuitTypePath = this._getCircuitTypeShortPath(
circuitArtifacts[i].basePath,
circuitArtifacts[i].sourceName,
circuitName,
);

if (isNameAlreadyExist) {
circuitTypePath = path.join(
BaseTSGenerator.DOMAIN_SEPARATOR,
circuitArtifacts[i].sourceName.replace(circuitArtifacts[i].basePath, ""),
`${circuitName}.ts`,
circuitTypePath = this._getCircuitTypeLongPath(
circuitArtifacts[i].basePath,
circuitArtifacts[i].sourceName,
circuitName,
);
}

Expand Down
29 changes: 10 additions & 19 deletions src/core/ZkitTSGenerator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ import { normalizeName } from "../utils";
import { SignalTypeNames, SignalVisibilityNames } from "../constants";

export default class ZkitTSGenerator extends BaseTSGenerator {
protected _nameToObjectNameMap: Map<string, string> = new Map();

protected async _genHardhatZkitTypeExtension(circuits: {
[circuitName: string]: ArtifactWithPath[];
}): Promise<string> {
Expand All @@ -29,8 +27,6 @@ export default class ZkitTSGenerator extends BaseTSGenerator {

const keys = Object.keys(circuits);

const outputTypesDir = this.getOutputTypesDir();

for (let i = 0; i < keys.length; i++) {
const artifacts = circuits[keys[i]];

Expand All @@ -40,28 +36,14 @@ export default class ZkitTSGenerator extends BaseTSGenerator {
object: this._getCircuitName(artifacts[0].circuitArtifact),
});

this._nameToObjectNameMap.set(
this._getCircuitName(artifacts[0].circuitArtifact),
this._getCircuitName(artifacts[0].circuitArtifact),
);

continue;
}

for (const artifact of artifacts) {
const objectName = path
.normalize(artifact.pathToGeneratedFile.replace(outputTypesDir, ""))
.split(path.sep)
.filter((level) => level !== "")
.map((level, index, array) => (index !== array.length - 1 ? normalizeName(level) : level.replace(".ts", "")))
.join(".");

circuitClasses.push({
name: this._getFullCircuitName(artifact.circuitArtifact),
object: objectName,
object: this._getObjectPath(artifact.pathToGeneratedFile),
});

this._nameToObjectNameMap.set(this._getFullCircuitName(artifact.circuitArtifact), objectName);
}
}

Expand All @@ -72,6 +54,15 @@ export default class ZkitTSGenerator extends BaseTSGenerator {
return await prettier.format(ejs.render(template, templateParams), { parser: "typescript" });
}

protected _getObjectPath(pathToGeneratedFile: string): string {
return path
.normalize(pathToGeneratedFile.replace(this.getOutputTypesDir(), ""))
.split(path.sep)
.filter((level) => level !== "")
.map((level, index, array) => (index !== array.length - 1 ? normalizeName(level) : level.replace(".ts", "")))
.join(".");
}

protected async _genCircuitWrapperClassContent(circuitArtifact: CircuitArtifact): Promise<string> {
const template = fs.readFileSync(path.join(__dirname, "templates", "circuit-wrapper.ts.ejs"), "utf8");

Expand Down
5 changes: 5 additions & 0 deletions src/types/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ export interface ArtifactGeneratorConfig {
*/
circuitsASTPaths: string[];

/**
* The absolute path to the root directory of the project.
*/
projectRoot: string;

/**
* The path to the directory where the generated artifacts will be stored.
*/
Expand Down
1 change: 1 addition & 0 deletions test/CircuitArtifactGenerator.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ describe("Circuit Artifact Generation", function () {
const astDir = "test/cache/circuits-ast";
const artifactGenerator = new CircuitArtifactGenerator({
basePath: "test/fixture",
projectRoot: findProjectRoot(process.cwd()),
circuitsASTPaths: [
"test/cache/circuits-ast/Basic.json",
"test/cache/circuits-ast/credentialAtomicQueryMTPV2OnChainVoting.json",
Expand Down
11 changes: 8 additions & 3 deletions test/CircuitProofGeneration.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@ import { CircuitZKitConfig } from "@solarity/zkit";

import { generateAST } from "./helpers/generator";

import CircuitTypesGenerator from "../src/core/CircuitTypesGenerator";
import { CircuitTypesGenerator } from "../src";
import { findProjectRoot } from "../src/utils";

describe("Circuit Proof Generation", function () {
const astDir = "test/cache/circuits-ast";

const circuitTypesGenerator = new CircuitTypesGenerator({
basePath: "test/fixture",
projectRoot: findProjectRoot(process.cwd()),
circuitsASTPaths: [
"test/cache/circuits-ast/Basic.json",
"test/cache/circuits-ast/credentialAtomicQueryMTPV2OnChainVoting.json",
Expand Down Expand Up @@ -49,8 +51,11 @@ describe("Circuit Proof Generation", function () {
new (await circuitTypesGenerator.getCircuitObject("CredentialAtomicQueryMTPOnChainVoting"))();
new (await circuitTypesGenerator.getCircuitObject("EnhancedMultiplier"))();

await expect(circuitTypesGenerator.getCircuitObject("Multiplier2")).to.be.rejectedWith(
"Circuit Multiplier2 type does not exist.",
await expect(circuitTypesGenerator.getCircuitObject("Multiplier3")).to.be.rejectedWith(
"Circuit Multiplier3 type does not exist.",
);
await expect(circuitTypesGenerator.getCircuitObject("test/fixture/Basic.circom:Multiplier3")).to.be.rejectedWith(
"Circuit Multiplier3 type does not exist.",
);
});

Expand Down
3 changes: 2 additions & 1 deletion test/CircuitTypesGenerator.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import { generateAST } from "./helpers/generator";

import { findProjectRoot } from "../src/utils";

import CircuitTypesGenerator from "../src/core/CircuitTypesGenerator";
import { CircuitTypesGenerator } from "../src";

describe("Circuit Types Generation", function () {
const expectedTypes = ["core/CredentialAtomicQueryMTPOnChainVoting.ts", "core/Multiplier2.ts"];
Expand All @@ -18,6 +18,7 @@ describe("Circuit Types Generation", function () {

const circuitTypesGenerator = new CircuitTypesGenerator({
basePath: "test/fixture",
projectRoot: findProjectRoot(process.cwd()),
circuitsASTPaths: [
"test/cache/circuits-ast/Basic.json",
"test/cache/circuits-ast/credentialAtomicQueryMTPV2OnChainVoting.json",
Expand Down
Loading

0 comments on commit ef158b7

Please sign in to comment.