diff --git a/src/base/snark0.ml b/src/base/snark0.ml index 9521fad9f..473b7117a 100644 --- a/src/base/snark0.ml +++ b/src/base/snark0.ml @@ -1393,6 +1393,41 @@ module Run = struct in { run_circuit; finish_computation } + (* start an as_prover / exists block and return a function to finish it and witness a given list of fields *) + let as_prover_manual (size_to_witness : int) : + (field array option -> Field.t array) Staged.t = + let s = !state in + let old_as_prover = Run_state.as_prover s in + (* enter the as_prover block *) + Run_state.set_as_prover s true ; + + let finish_computation (values_to_witness : field array option) = + (* leave the as_prover block *) + Run_state.set_as_prover s old_as_prover ; + + (* return variables *) + match (Run_state.has_witness s, values_to_witness) with + (* in compile mode, we return empty vars *) + | false, None -> + Core_kernel.Array.init size_to_witness ~f:(fun _ -> + Run_state.alloc_var s () ) + (* in prover mode, we expect values to turn into vars *) + | true, Some values_to_witness -> + let store_value = + (* If we're nested in a prover block, create constants instead of + storing. *) + if old_as_prover then Field.constant + else Run_state.store_field_elt s + in + Core_kernel.Array.map values_to_witness ~f:store_value + (* the other cases are invalid *) + | false, Some _ -> + failwith "Did not expect values to witness" + | true, None -> + failwith "Expected values to witness" + in + Staged.stage finish_computation + let run_unchecked x = finalize_is_running (fun () -> Perform.run_unchecked ~run:as_stateful (fun () -> mark_active ~f:x) ) diff --git a/src/base/snark_intf.ml b/src/base/snark_intf.ml index 777ef6c66..a6416ca31 100644 --- a/src/base/snark_intf.ml +++ b/src/base/snark_intf.ml @@ -1402,6 +1402,9 @@ module type Run_basic = sig , Proof_inputs.t * 'return_value ) manual_callbacks + (* Callback, low-level version of [as_prover] and [exists]. *) + val as_prover_manual : int -> (field array option -> Field.t array) Staged.t + (** Generate the public input vector for a given statement. *) val generate_public_input : ('input_var, 'input_value) Typ.t -> 'input_value -> Field.Constant.Vector.t