Skip to content

Commit

Permalink
Created a proper gate manifest.
Browse files Browse the repository at this point in the history
  • Loading branch information
Iluvmagick committed Aug 4, 2023
1 parent c7736d3 commit 3e2b053
Show file tree
Hide file tree
Showing 45 changed files with 1,201 additions and 497 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,19 @@ namespace nil {
using var = typename component_type::var;
using manifest_type = nil::blueprint::plonk_component_manifest;

class gate_manifest_type : public component_gate_manifest {
public:
std::uint32_t gates_amount() const override {
return bool_scalar_multiplication::gates_amount;
}
};

static gate_manifest get_gate_manifest(std::size_t witness_amount,
std::size_t lookup_column_amount) {
static gate_manifest manifest = gate_manifest(gate_manifest_type());
return manifest;
}

static manifest_type get_manifest() {
static manifest_type manifest = manifest_type(
std::shared_ptr<manifest_param>(
Expand All @@ -71,11 +84,6 @@ namespace nil {
return 2;
}

constexpr static std::size_t get_total_gates_amount(std::size_t witness_amount,
std::size_t lookup_column_amount) {
return gates_amount;
}

const std::size_t rows_amount = get_rows_amount(this->witness_amount(), 0);

constexpr static const std::size_t gates_amount = 1;
Expand Down

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,19 @@ namespace nil {
using var = typename component_type::var;
using manifest_type = nil::blueprint::plonk_component_manifest;

class gate_manifest_type : public component_gate_manifest {
public:
std::uint32_t gates_amount() const override {
return scalar_non_native_range::gates_amount;
}
};

static gate_manifest get_gate_manifest(std::size_t witness_amount,
std::size_t lookup_column_amount) {
static gate_manifest manifest = gate_manifest(gate_manifest_type());
return manifest;
}

static manifest_type get_manifest() {
static manifest_type manifest = manifest_type(
std::shared_ptr<nil::blueprint::manifest_param>(
Expand All @@ -67,11 +80,6 @@ namespace nil {
return 3;
}

constexpr static std::size_t get_total_gates_amount(std::size_t witness_amount,
std::size_t lookup_column_amount) {
return gates_amount;
}

const std::size_t rows_amount = get_rows_amount(this->witness_amount(), 0);
constexpr static const std::size_t gates_amount = 1;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,21 +32,16 @@
#include <nil/blueprint/components/algebra/curves/edwards/plonk/non_native/variable_base_multiplication_per_bit.hpp>
#include <nil/blueprint/components/algebra/curves/edwards/plonk/non_native/bool_scalar_multiplication.hpp>
#include <nil/blueprint/components/algebra/fields/plonk/non_native/bit_decomposition.hpp>
#include <nil/blueprint/components/algebra/fields/plonk/bit_shift_constant.hpp>

namespace nil {
namespace blueprint {
namespace components {

namespace detail {
enum bit_shift_mode {
LEFT,
RIGHT,
};
} // namespace detail
using detail::bit_shift_mode;

template<typename ArithmetizationType, typename CurveType, typename Ed25519Type,
std::uint32_t WitnessesAmount, typename NonNativePolicyType>
typename NonNativePolicyType>
class variable_base_multiplication;

template<typename BlueprintFieldType,
Expand All @@ -55,45 +50,83 @@ namespace nil {
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
CurveType,
Ed25519Type,
9,
basic_non_native_policy<BlueprintFieldType>>:
public plonk_component<BlueprintFieldType, ArithmetizationParams, 9, 1, 0> {

constexpr static const std::uint32_t WitnessesAmount = 9;

using component_type = plonk_component<BlueprintFieldType, ArithmetizationParams, WitnessesAmount, 1, 0>;
public plonk_component<BlueprintFieldType, ArithmetizationParams, 1, 0> {

using component_type = plonk_component<BlueprintFieldType, ArithmetizationParams, 1, 0>;

constexpr static const std::size_t rows_amount_internal(std::size_t witness_amount,
std::size_t lookup_column_amount,
std::size_t bits_amount) {
return
decomposition_component_type::get_rows_amount(witness_amount, lookup_column_amount,
bits_amount) +
252 * mul_per_bit_component::get_rows_amount(witness_amount, lookup_column_amount) +
bool_scalar_mul_component::get_rows_amount(witness_amount, lookup_column_amount);
}

public:
using var = typename component_type::var;
using manifest_type = typename component_type::manifest_type;
using non_native_policy_type = basic_non_native_policy<BlueprintFieldType>;

typedef crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>
ArithmetizationType;

using mul_per_bit_component = variable_base_multiplication_per_bit<
ArithmetizationType, CurveType, Ed25519Type, 9, non_native_policy_type>;
ArithmetizationType, CurveType, Ed25519Type, non_native_policy_type>;

using decomposition_component_type = bit_decomposition<ArithmetizationType, 9>;
using decomposition_component_type = bit_decomposition<ArithmetizationType>;

using bool_scalar_mul_component = bool_scalar_multiplication<
ArithmetizationType, Ed25519Type, 9, non_native_policy_type>;
ArithmetizationType, Ed25519Type, non_native_policy_type>;

const decomposition_component_type decomposition_subcomponent;
const mul_per_bit_component mul_per_bit_subcomponent;
const bool_scalar_mul_component bool_scalar_mul_subcomponent;
class gate_manifest_type : public component_gate_manifest {
public:
std::uint32_t gates_amount() const override {
return variable_base_multiplication::gates_amount;
}
};

const std::size_t rows_amount;
static gate_manifest get_gate_manifest(std::size_t witness_amount,
std::size_t lookup_column_amount,
std::size_t bits_amount) {
static gate_manifest manifest =
gate_manifest(gate_manifest_type())
.merge_with(
bool_scalar_mul_component::get_gate_manifest(witness_amount, lookup_column_amount))
.merge_with(mul_per_bit_component::get_gate_manifest(witness_amount, lookup_column_amount))
.merge_with(
decomposition_component_type::get_gate_manifest(witness_amount, lookup_column_amount,
bits_amount));

return manifest;
}

static manifest_type get_manifest() {
static manifest_type manifest = manifest_type(
std::shared_ptr<manifest_param>(new manifest_single_value_param(9)),
false
).merge_with(mul_per_bit_component::get_manifest())
.merge_with(decomposition_component_type::get_manifest())
.merge_with(bool_scalar_mul_component::get_manifest());
return manifest;
}

constexpr static const std::size_t rows(
const decomposition_component_type& decomposition_subcomponent,
const mul_per_bit_component& mul_per_bit_subcomponent,
const bool_scalar_mul_component& bool_scalar_mul_subcomponent
) {
return decomposition_subcomponent.rows_amount
+ mul_per_bit_subcomponent.rows_amount * 252
+ bool_scalar_mul_subcomponent.rows_amount;
constexpr static std::size_t get_rows_amount(std::size_t witness_amount,
std::size_t lookup_column_amount,
std::size_t bits_amount) {
return rows_amount_internal(witness_amount, lookup_column_amount, bits_amount);
}

// We use bits_amount from decomposition subcomponent to initialize rows_amount
// CRITICAL: do not move decomposition_subcomponent below rows_amount
const decomposition_component_type decomposition_subcomponent;
// CRITICAL: do not move decomposition_subcomponent below rows_amount
const mul_per_bit_component mul_per_bit_subcomponent;
const bool_scalar_mul_component bool_scalar_mul_subcomponent;

const std::size_t rows_amount = rows_amount_internal(this->witness_amount(), 0, decomposition_subcomponent.bits_amount);
constexpr static const std::size_t gates_amount = 0;

struct input_type {
Expand All @@ -116,7 +149,7 @@ namespace nil {
result_type(const variable_base_multiplication &component, std::uint32_t start_row_index) {
using mul_per_bit_component =
components::variable_base_multiplication_per_bit<ArithmetizationType,
CurveType, Ed25519Type, 9, non_native_policy_type>;
CurveType, Ed25519Type, non_native_policy_type>;
mul_per_bit_component component_instance({0, 1, 2, 3, 4, 5, 6, 7, 8}, {0}, {});

auto final_mul_per_bit_res = typename plonk_ed25519_mul_per_bit<BlueprintFieldType, ArithmetizationParams, CurveType>::result_type(
Expand All @@ -136,23 +169,21 @@ namespace nil {

template<typename ContainerType>
variable_base_multiplication(ContainerType witness, std::uint32_t bits_amount, bit_shift_mode mode_) :
component_type(witness, {}, {}),
component_type(witness, {}, {}, get_manifest()),
decomposition_subcomponent(witness, bits_amount, bit_composition_mode::MSB),
mul_per_bit_subcomponent(witness),
bool_scalar_mul_subcomponent(witness),
rows_amount(rows(decomposition_subcomponent, mul_per_bit_subcomponent, bool_scalar_mul_subcomponent)) {};
bool_scalar_mul_subcomponent(witness) {};

template<typename WitnessContainerType, typename ConstantContainerType,
typename PublicInputContainerType>
variable_base_multiplication(WitnessContainerType witness, ConstantContainerType constant,
PublicInputContainerType public_input, std::uint32_t bits_amount,
bit_shift_mode mode_) :
component_type(witness, constant, public_input),
component_type(witness, constant, public_input, get_manifest()),
decomposition_subcomponent(witness, constant, public_input,
bits_amount, bit_composition_mode::MSB),
mul_per_bit_subcomponent(witness, constant, public_input),
bool_scalar_mul_subcomponent(witness, constant, public_input),
rows_amount(rows(decomposition_subcomponent, mul_per_bit_subcomponent, bool_scalar_mul_subcomponent)) {};
bool_scalar_mul_subcomponent(witness, constant, public_input) {};

variable_base_multiplication(
std::initializer_list<typename component_type::witness_container_type::value_type>
Expand All @@ -162,20 +193,18 @@ namespace nil {
std::initializer_list<typename component_type::public_input_container_type::value_type>
public_inputs,
std::uint32_t bits_amount = 253, bit_shift_mode mode_ = bit_shift_mode::RIGHT) :
component_type(witnesses, constants, public_inputs),
component_type(witnesses, constants, public_inputs, get_manifest()),
decomposition_subcomponent(witnesses, constants, public_inputs,
bits_amount, bit_composition_mode::MSB),
mul_per_bit_subcomponent(witnesses, constants, public_inputs),
bool_scalar_mul_subcomponent(witnesses, constants, public_inputs),
rows_amount(rows(decomposition_subcomponent, mul_per_bit_subcomponent, bool_scalar_mul_subcomponent)) {};
bool_scalar_mul_subcomponent(witnesses, constants, public_inputs) {};
};

template<typename BlueprintFieldType, typename ArithmetizationParams, typename CurveType>
using plonk_ed25519_var_base_mul = variable_base_multiplication<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
CurveType,
typename crypto3::algebra::curves::ed25519,
9,
basic_non_native_policy<BlueprintFieldType>>;

template<typename BlueprintFieldType, typename ArithmetizationParams, typename CurveType>
Expand All @@ -186,13 +215,15 @@ namespace nil {
const typename plonk_ed25519_var_base_mul<BlueprintFieldType, ArithmetizationParams, CurveType>::input_type instance_input,
const std::uint32_t start_row_index) {

using non_native_policy_type = basic_non_native_policy<BlueprintFieldType>;
using component_type =
plonk_ed25519_var_base_mul<BlueprintFieldType, ArithmetizationParams, CurveType>;
using non_native_policy_type = typename component_type::non_native_policy_type;
using var = typename plonk_ed25519_mul_per_bit<BlueprintFieldType, ArithmetizationParams, CurveType>::var;
using Ed25519Type = typename crypto3::algebra::curves::ed25519;
typedef crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>
ArithmetizationType;

using component_type = plonk_ed25519_var_base_mul<BlueprintFieldType, ArithmetizationParams, CurveType>;


using mul_per_bit_component = typename component_type::mul_per_bit_component;
using decomposition_component_type = typename component_type::decomposition_component_type;
Expand Down Expand Up @@ -238,14 +269,14 @@ namespace nil {
const typename plonk_ed25519_var_base_mul<BlueprintFieldType, ArithmetizationParams, CurveType>::input_type instance_input,
const std::uint32_t start_row_index) {

using non_native_policy_type = basic_non_native_policy<BlueprintFieldType>;
using component_type =
plonk_ed25519_var_base_mul<BlueprintFieldType, ArithmetizationParams, CurveType>;
using non_native_policy_type = typename component_type::non_native_policy_type;
using var = typename plonk_ed25519_mul_per_bit<BlueprintFieldType, ArithmetizationParams, CurveType>::var;
using Ed25519Type = typename crypto3::algebra::curves::ed25519;
typedef crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>
ArithmetizationType;

using component_type = plonk_ed25519_var_base_mul<BlueprintFieldType, ArithmetizationParams, CurveType>;

using mul_per_bit_component = typename component_type::mul_per_bit_component;
using decomposition_component_type = typename component_type::decomposition_component_type;
using bool_scalar_mul_component = typename component_type::bool_scalar_mul_component;
Expand Down
Loading

0 comments on commit 3e2b053

Please sign in to comment.