From ec63d10303d52f1bd9d1b56823da6241365f817b Mon Sep 17 00:00:00 2001 From: Ashwini Khade Date: Wed, 8 Sep 2021 11:47:01 -0700 Subject: [PATCH] add model local function support (#8540) * updates for picking pnnx commit * add tests filter to c# tests * plus test fixes * fix versioning for contrib ops * fix tests * test filter for optional ops * more versioning related updates * fix test * fix layernorm spec * more updates * update docs * add more test filters * more filters * update binary size threshold * update docs * draft - enable model local function * enable model local functions in ORT * update to latest rel onnx commit * plus tests * plus more updates * plus updates * test updates * Fix for nested functions + shape inference * plus bug fix and updates per review * plus fixes per review * plus test updates * plus updates per review * plus fixes * fix a test --- include/onnxruntime/core/graph/function.h | 3 + include/onnxruntime/core/graph/graph.h | 36 +- .../core/framework/tensorprotoutils.cc | 10 +- onnxruntime/core/framework/tensorprotoutils.h | 5 + onnxruntime/core/graph/function.cc | 342 ++++++++++++++++-- onnxruntime/core/graph/function_impl.h | 24 +- onnxruntime/core/graph/graph.cc | 158 ++++++-- onnxruntime/core/graph/model.cc | 18 +- onnxruntime/core/graph/model.h | 2 +- onnxruntime/core/graph/model_load_utils.h | 6 + onnxruntime/test/ir/onnx_model_test.cc | 236 ++++++++++++ .../test/gradient/function_ops_test.cc | 2 +- .../linux/docker/scripts/requirements.txt | 2 +- 13 files changed, 753 insertions(+), 91 deletions(-) diff --git a/include/onnxruntime/core/graph/function.h b/include/onnxruntime/core/graph/function.h index cf9eefdd500a0..c07b7aaeacc1e 100644 --- a/include/onnxruntime/core/graph/function.h +++ b/include/onnxruntime/core/graph/function.h @@ -28,6 +28,9 @@ class Function { /** Gets the Graph instance for the Function body subgraph. */ virtual const onnxruntime::Graph& Body() const = 0; + + /** Gets the Mutable Graph instance for the Function body subgraph. */ + virtual onnxruntime::Graph& MutableBody() = 0; }; /** diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index af9601bdc11c5..898e2e5af0442 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -145,17 +145,19 @@ class Node { /** Gets the function body if applicable otherwise nullptr - @param try_init_func_body If not already intialized, initialize the function body - (only applicable to operators which are defined as function in ONNX spec). - Function body can be initialized in 2 cases : + @param try_init_func_body If not already initialized, initialize the function body + (This is not applicable for primitive operators.) + Function body can be initialized in 3 cases : 1. For nodes of type "Fused" - 2. For nodes which are defined as functions in ONNX spec (example: DynamicQuantizeLinear) + 2. For nodes which are defined as functions in the spec (example: DynamicQuantizeLinear) + 3. For nodes which reference a model local function. These functions are defined in the model itself and + do not have any schema associated with them. For all other cases this will always return nullptr. Nodes of type "Fused" are created during partitioning and the function body initialization for such nodes also happens during node creation. Therefore, - initialization of function body will happen via this method only in case 2 mentioned above. + initialization of function body will happen via this method only in cases 2 and 3 mentioned above. */ - const Function* GetFunctionBody(bool try_init_func_body = true); + Function* GetMutableFunctionBody(bool try_init_func_body = true); /** Gets the function body if applicable otherwise nullptr. */ const Function* GetFunctionBody() const noexcept { return func_body_; } @@ -395,6 +397,11 @@ class Node { execution_provider_type_ = execution_provider_type; } + /** Sets initialized function body for node. This is called right after function body initialization for a node. + * or during function inlining when a nested function is encountered. + */ + void SetFunctionBody(Function& func); + /** Call the provided function for all explicit inputs, implicit inputs, and outputs of this Node. If the NodeArg is an explicit or implicit input, is_input will be true when func is called. @param include_missing_optional_defs Include NodeArgs that are optional and were not provided @@ -527,8 +534,6 @@ class Node { // validate and update the input arg count common::Status UpdateInputArgCount(); - void SetFunctionBody(const Function& func); - const Definitions& GetDefinitions() const noexcept { return definitions_; } const Relationships& GetRelationships() const noexcept { return relationships_; } @@ -558,7 +563,7 @@ class Node { Node::Type node_type_ = Node::Type::Primitive; // The function body is owned by graph_ - const Function* func_body_ = nullptr; + Function* func_body_ = nullptr; // Node doc string. std::string description_; @@ -1022,6 +1027,9 @@ class Graph { /** Initialize function body for the given node */ void InitFunctionBodyForNode(Node& node); + /** Gets Model local functions from the root/parent graph.*/ + const std::unordered_map& GetModelLocalFunctions() const; + /** Mark a NodeArg name as coming from the outer scope when programmatically constructing a Graph that will be used as a GraphProto attribute in another Node.. e.g. when creating a Graph instance that will be used as a subgraph in a control flow operator, it is necessary to @@ -1045,7 +1053,7 @@ class Graph { void SetOutputs(const std::vector& outputs); /** Sets the type of a NodeArg, replacing existing type/shape if any */ - void SetNodeArgType(NodeArg& arg, const onnx::TypeProto& type_proto); + void SetNodeArgType(NodeArg& arg, const ONNX_NAMESPACE::TypeProto& type_proto); const Node* GetProducerNode(const std::string& node_arg_name) const { return GetProducerNodeImpl(*this, node_arg_name); @@ -1107,6 +1115,9 @@ class Graph { // Whether to set that no proto sync is required after resolving. // Useful for resolving right after loading from a GraphProto. bool no_proto_sync_required = false; + // When set to true, graph resolve will be called for initialized function bodies as well. This is used + // in case of nested model local functions. + bool traverse_function_body = false; }; /** @@ -1203,6 +1214,7 @@ class Graph { const std::unordered_map& domain_to_version, Version ir_version, IOnnxRuntimeOpSchemaCollectionPtr schema_registry, + const std::vector& model_functions, const logging::Logger& logger); // internal use by the Graph class only @@ -1213,6 +1225,7 @@ class Graph { IOnnxRuntimeOpSchemaCollectionPtr schema_registry, Graph* parent_graph, const Node* parent_node, + const std::vector& model_functions, const logging::Logger& logger); void InitializeStateFromModelFileGraphProto(); @@ -1391,7 +1404,10 @@ class Graph { #if !defined(ORT_MINIMAL_BUILD) IOnnxRuntimeOpSchemaCollectionPtr schema_registry_; + // Container to hold initialized function bodies std::vector> function_container_; + + std::unordered_map model_local_functions_; #endif // Graph nodes. diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index 1619402cab63c..ce23ae1de5192 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -791,7 +791,7 @@ ONNX_NAMESPACE::TensorProto TensorToTensorProto(const Tensor& tensor, const std: common::Status ConstantNodeProtoToTensorProto(const ONNX_NAMESPACE::NodeProto& node, const Path& model_path, - ONNX_NAMESPACE::TensorProto& tensor) { + ONNX_NAMESPACE::TensorProto& tensor, const std::string& tensor_name) { const AttributeProto& constant_attribute = node.attribute(0); switch (constant_attribute.type()) { @@ -836,11 +836,17 @@ common::Status ConstantNodeProtoToTensorProto(const ONNX_NAMESPACE::NodeProto& n } // set name last in case attribute type was tensor (would copy over name) - *(tensor.mutable_name()) = node.output(0); + *(tensor.mutable_name()) = tensor_name; return Status::OK(); } +common::Status ConstantNodeProtoToTensorProto(const ONNX_NAMESPACE::NodeProto& node, + const Path& model_path, + ONNX_NAMESPACE::TensorProto& tensor) { + return ConstantNodeProtoToTensorProto(node, model_path, tensor, node.output(0)); +} + #if !defined(DISABLE_SPARSE_TENSORS) static Status CopySparseData(size_t n_sparse_elements, const ONNX_NAMESPACE::TensorProto& indices, diff --git a/onnxruntime/core/framework/tensorprotoutils.h b/onnxruntime/core/framework/tensorprotoutils.h index ade59cc3013df..365e27ac2daa0 100644 --- a/onnxruntime/core/framework/tensorprotoutils.h +++ b/onnxruntime/core/framework/tensorprotoutils.h @@ -81,6 +81,11 @@ common::Status GetSizeInBytesFromTensorProto(const ONNX_NAMESPACE::TensorProto& // However if AttributeProto contains SparseTensorProto then it converts the data into dense tensor proto // (including loading external data when applicable). // model_path is used for contructing full path for external_data +// tensor_name specifies the name for the new TensorProto TensorProto +common::Status ConstantNodeProtoToTensorProto(const ONNX_NAMESPACE::NodeProto& node, + const Path& model_path, + ONNX_NAMESPACE::TensorProto& tensor, const std::string& tensor_name); + common::Status ConstantNodeProtoToTensorProto(const ONNX_NAMESPACE::NodeProto& node, const Path& model_path, ONNX_NAMESPACE::TensorProto& tensor); diff --git a/onnxruntime/core/graph/function.cc b/onnxruntime/core/graph/function.cc index 04900501e1a91..a6fcada57183b 100644 --- a/onnxruntime/core/graph/function.cc +++ b/onnxruntime/core/graph/function.cc @@ -7,6 +7,129 @@ #include "core/graph/model.h" #include "onnx/shape_inference/implementation.h" +namespace ONNX_NAMESPACE { +// Infer shape for functions. This also supports +// nested model local functions. +// TODO: Add this to onnx instead of adding it here. +void InferShapeForFunctionNode( + InferenceContext& ctx, + const FunctionProto& func_proto, + const std::unordered_map& func_opset_imports, + const ShapeInferenceOptions& options, + const ISchemaRegistry* schema_registry, + const std::unordered_map& in_model_functions, + std::function get_func_id) { + GraphProto g; + // Get a temporary tensor-shape map + const auto num_func_inputs = func_proto.input_size(); + std::unordered_map value_types_by_name; + std::vector types_cache(func_proto.input_size()); + for (int i = 0; i < num_func_inputs; ++i) { + types_cache[i] = *ctx.getInputType(i); + value_types_by_name[func_proto.input().Get(i)] = &types_cache[i]; + } + + // Get a temporary initial value map + std::unordered_map initializers_by_name; + std::unordered_map sparse_initializers_by_name; + for (int i = 0; i < static_cast(ctx.getNumInputs()) && i < num_func_inputs; ++i) { + const TypeProto* type = ctx.getInputType(i); + if (type->value_case() == TypeProto::kTensorType && ctx.getInputData(i) != nullptr) { + initializers_by_name[func_proto.input().Get(i)] = ctx.getInputData(i); + } else if (type->value_case() == TypeProto::kSparseTensorType && + ctx.getInputSparseData(i) != nullptr) { + sparse_initializers_by_name[func_proto.input().Get(i)] = ctx.getInputSparseData(i); + } + } + std::unordered_map attr_map; + for (auto& attr : func_proto.attribute()) { + if (ctx.getAttribute(attr) != nullptr) { + attr_map[attr] = ctx.getAttribute(attr); + } + } + + for (auto& n : func_proto.node()) { + NodeProto copy_n(n); + // Add attribute information into the temporary node + copy_n.clear_attribute(); + for (const auto& attr : n.attribute()) { + if (attr.has_ref_attr_name()) { + if (attr_map.count(attr.ref_attr_name())) { + auto copy_attr = *attr_map[attr.ref_attr_name()]; + copy_attr.set_name(attr.name()); + copy_n.add_attribute()->CopyFrom(copy_attr); + } + } else { + copy_n.add_attribute()->CopyFrom(attr); + } + } + ONNX_NAMESPACE::shape_inference::InferenceContextImpl func_node_ctx( + copy_n, value_types_by_name, initializers_by_name, sparse_initializers_by_name, {}); + + // Resolve domain for node + auto it = func_opset_imports.find(n.domain()); + if (it == func_opset_imports.end()) { + fail_type_inference("Cannot infer type and shape for function", func_proto.name(), + ". No opset import for domain", n.domain(), " referenced by function body node ", + n.name(), " optype ", n.op_type()); + } + auto domain_version = it->second; + const auto schema = schema_registry->GetSchema(n.op_type(), domain_version, n.domain()); + if (schema) { + schema->GetTypeAndShapeInferenceFunction()(func_node_ctx); + } else { + // check model local functions for FunctionProto + auto iter = in_model_functions.find(get_func_id(n.domain(), n.op_type())); + if (iter == in_model_functions.end()) { + return; + } + + std::unordered_map func_node_opset_imports; + for (const auto& opset_import : iter->second->opset_import()) { + // If graph imports does not contain opset_import then insert it otherwise the one in graph imports overrides. + // If the opset imports are not compatible then this will be caught during function body inline. + func_node_opset_imports.insert({opset_import.domain(), static_cast(opset_import.version())}); + } + + InferShapeForFunctionNode(func_node_ctx, *iter->second, func_node_opset_imports, options, schema_registry, in_model_functions, get_func_id); + } + + for (int i = 0; i < copy_n.output_size(); ++i) { + TypeProto* inferred_output_type = func_node_ctx.getOutputType(i); + // Checking, Storing the inferred information + auto iter = value_types_by_name.find(n.output(i)); + TypeProto* existingType = nullptr; + if (iter != value_types_by_name.end()) { + existingType = iter->second; + shape_inference::checkShapesAndTypes(*inferred_output_type, *existingType); + } else { + // Store the inferred type info in the + // subgraph temporarily + auto vi = g.add_value_info(); + vi->set_name(copy_n.output(i)); + existingType = vi->mutable_type(); + } + + shape_inference::mergeShapesAndTypes(*inferred_output_type, existingType); + + // Make merged info available to further inference. + value_types_by_name[copy_n.output(i)] = existingType; + } + } + for (int i = 0; i < func_proto.output_size(); ++i) { + const std::string& output_name = func_proto.output().Get(i); + // Skip if no type inferred for the tensor + auto iter = value_types_by_name.find(output_name); + if (iter != value_types_by_name.cend()) { + // Copy the type info to ctx + // to pass back to main graph + auto type_proto = ctx.getOutputType(i); + type_proto->CopyFrom(*(iter->second)); + } + } +} + +} // namespace ONNX_NAMESPACE namespace onnxruntime { // Utilify function to get the imported version of domain from opset imports @@ -29,6 +152,15 @@ void IOTypeConstraintHelper(const ONNX_NAMESPACE::FunctionProto& onnx_func_proto std::vector> output_types_list(onnx_func_proto_.output_size()); std::unordered_map> type_constraint_map; std::unordered_map attribute_type_map; + + // Create an all permissive list of data types. This will be used in case of model local functions + // when we cannot infer the type constraints from function proto body + std::unordered_set all_types; + all_types.insert(ONNX_NAMESPACE::OpSchema::all_tensor_types_with_bfloat().cbegin(), + ONNX_NAMESPACE::OpSchema::all_tensor_types_with_bfloat().cend()); + all_types.insert(ONNX_NAMESPACE::OpSchema::all_tensor_sequence_types().cbegin(), + ONNX_NAMESPACE::OpSchema::all_tensor_sequence_types().cend()); + auto schema_registry = ONNX_NAMESPACE::OpSchemaRegistry::Instance(); std::unordered_map opset_imports; for (auto& relied_opset : onnx_func_proto_.opset_import()) { @@ -44,12 +176,21 @@ void IOTypeConstraintHelper(const ONNX_NAMESPACE::FunctionProto& onnx_func_proto auto iter = input_name_idx_map.find(in_name); if (iter != input_name_idx_map.end()) { int idx = iter->second; - const auto& p = node_op_schema->inputs().at(i); - std::string type_str = p.GetTypeStr() + "in" + std::to_string(i); + std::string type_str = node_op_schema ? node_op_schema->inputs().at(i).GetTypeStr() + "in" + std::to_string(idx) : "Tin" + std::to_string(idx); input_types_list[idx] = std::make_pair(in_name, type_str); if (!type_constraint_map.count(type_str)) { - for (auto s : p.GetTypes()) { - type_constraint_map[type_str].emplace_back(*s); + // If schema is available for the node then get the allowed types from the schema + // else add all types to allowed types list. It is OK to add all types. Any issues will be + // caught later if we try to inline the nodes and there is no kernl available for + // the requested types. + if (node_op_schema) { + for (auto s : node_op_schema->inputs().at(i).GetTypes()) { + type_constraint_map[type_str].emplace_back(*s); + } + } else { + for (const auto& s : all_types) { + type_constraint_map[type_str].emplace_back(s); + } } } } @@ -59,12 +200,21 @@ void IOTypeConstraintHelper(const ONNX_NAMESPACE::FunctionProto& onnx_func_proto auto iter = output_name_idx_map.find(out_name); if (iter != output_name_idx_map.end()) { int idx = iter->second; - const auto& p = node_op_schema->outputs().at(i); - std::string type_str = p.GetTypeStr() + "out" + std::to_string(i); + std::string type_str = node_op_schema ? node_op_schema->outputs().at(i).GetTypeStr() + "out" + std::to_string(i) : "Tout" + std::to_string(i); output_types_list[idx] = std::make_pair(out_name, type_str); if (!type_constraint_map.count(type_str)) { - for (auto s : p.GetTypes()) { - type_constraint_map[type_str].emplace_back(*s); + // If schema is available for the node then get the allowed types from the schema + // else add all types to allowed types list. It is OK to add all types. Any issues will be + // caught later if we try to inline the nodes and there is no kernel available for + // the requested types. + if (node_op_schema) { + for (auto data_type : node_op_schema->outputs().at(i).GetTypes()) { + type_constraint_map[type_str].emplace_back(*data_type); + } + } else { + for (const auto& data_type : all_types) { + type_constraint_map[type_str].emplace_back(data_type); + } } } } @@ -100,6 +250,40 @@ void IOTypeConstraintHelper(const ONNX_NAMESPACE::FunctionProto& onnx_func_proto } } +/** Utility function to initialize function body for nested model local function + @param graph Graph in which this node belongs too. For nested functions, graph is the parent function body graph + @param node_index index of the node in graph + @param onnx_function_proto FunctionProto for the function + @param in_model_function_protos Model local functions. These are schema less functions which are defined in the ModelProto of the main/parent model. + @param function_container graph level function container which will own the initialized function body + @param logger instance of Logger + @param is_nested_function True if this is a nested function. For nested functions graph resolved is delayed until parent function body is fully initialized. +*/ +static void InitNestedModelLocalFunction(onnxruntime::Graph& graph, + const onnxruntime::NodeIndex& node_index, + ONNX_NAMESPACE::FunctionProto& onnx_function_proto, + const std::unordered_map& in_model_function_protos, + std::vector>& function_container, + const logging::Logger& logger, + bool is_nested_function) { + ORT_TRY { + auto func_ptr = std::make_unique(graph, node_index, onnx_function_proto, + in_model_function_protos, function_container, + logger, is_nested_function); + function_container.emplace_back(std::move(func_ptr)); + auto* node_in_graph = graph.GetNode(node_index); + node_in_graph->SetFunctionBody(*function_container.back()); + } + ORT_CATCH(const std::exception& e) { + LOGS(logger, WARNING) << "Function body initialization failed for Function '" + << onnx_function_proto.name() << "'. Error message " << e.what() + << ". Execution will fail if ORT does not have a specialized kernel for this op"; + // Return without using this function op's expansion. No need to fail just yet. + // If ORT has a specialized kernel for this op then execution will proceed + return; + } +} + // This method updates the names of inputs/outputs of nodes in subgraphs // within nodes in an op that has a FunctionBody. // Subgraphs within an op with a FunctionBody could be referencing inputs/outputs in the OpSchema @@ -113,7 +297,7 @@ void IOTypeConstraintHelper(const ONNX_NAMESPACE::FunctionProto& onnx_func_proto // E.g. For Range-11, {"start" : 0, "limit": 1, "delta": 2} // (5) A map containing the output name from the op schema to the corresponding index // E.g. For Range-11, {"output" : 0} -static void update_subgraphs_within_function_body(ONNX_NAMESPACE::GraphProto& subgraph_proto, +static void UpdateSubgraphsWithinFunctionBody(ONNX_NAMESPACE::GraphProto& subgraph_proto, const Graph& parent_graph, const ONNX_NAMESPACE::NodeProto& function_node_in_parent_graph, const std::unordered_map& input_name_idx_map, @@ -153,7 +337,7 @@ static void update_subgraphs_within_function_body(ONNX_NAMESPACE::GraphProto& su } // Recurse into any subgraphs in the current subgraph being processed if ((*subgraph_node_attr).has_g()) { - update_subgraphs_within_function_body(*(*subgraph_node_attr).mutable_g(), + UpdateSubgraphsWithinFunctionBody(*(*subgraph_node_attr).mutable_g(), parent_graph, function_node_in_parent_graph, input_name_idx_map, output_name_idx_map); } @@ -212,7 +396,7 @@ FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph, body_("fused_function_subgraph", false, onnxruntime::ModelMetaData(), graph.ModelPath().ToPathString(), IOnnxRuntimeOpSchemaRegistryList({graph.GetSchemaRegistry()}), - graph.DomainToVersionMap(), {}, logger) { + graph.DomainToVersionMap(), {} , logger) { auto& function_body_graph = body_.MainGraph(); auto* meta_def = nodes_to_fuse.GetMetaDef(); @@ -275,10 +459,13 @@ FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph, ORT_ENFORCE(status.IsOK(), status.ErrorMessage()); } -FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph, +FunctionImpl::FunctionImpl(onnxruntime::Graph& graph, const onnxruntime::NodeIndex& node_index, const ONNX_NAMESPACE::FunctionProto& onnx_func_proto, - const logging::Logger& logger) + const std::unordered_map& model_local_functions, + std::vector>& function_container, + const logging::Logger& logger, + bool is_nested_function) : parent_graph_(&graph), body_(onnx_func_proto.name(), false, onnxruntime::ModelMetaData(), graph.ModelPath().ToPathString(), IOnnxRuntimeOpSchemaRegistryList(), @@ -291,12 +478,13 @@ FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph, // as we might make some modifications to the FunctionProto along the way const auto* node_in_parent_graph = parent_graph_->GetNode(node_index); + // For schema defined functions get the version from the node in parent graph. - // For the functions which do not have schema defined (model local functions) + // For the functions which do not have schema defined (model local functions) // get the since version from the version in opset imports using the domain. - auto since_version = node_in_parent_graph->SinceVersion() == -1 - ? GetVersionForDomain(node_in_parent_graph->Domain(), body_.MainGraph().DomainToVersionMap()) - : node_in_parent_graph->SinceVersion(); + auto since_version = node_in_parent_graph->SinceVersion() == -1 + ? GetVersionForDomain(node_in_parent_graph->Domain(), body_.MainGraph().DomainToVersionMap()) + : node_in_parent_graph->SinceVersion(); op_schema_ = std::make_unique(); op_schema_->SetName(onnx_func_proto_.name()); op_schema_->SetDomain(node_in_parent_graph->Domain()); @@ -304,6 +492,10 @@ FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph, op_schema_->SinceVersion(static_cast(since_version)); std::unordered_map input_name_idx_map; std::unordered_map output_name_idx_map; + std::unordered_map internal_input_output_updates; + + auto& function_body_graph = body_.MainGraph(); + for (int i = 0; i < onnx_func_proto_.input_size(); ++i) { input_name_idx_map[onnx_func_proto_.input().Get(i)] = i; } @@ -341,10 +533,10 @@ FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph, if (!cached_op_schema || !cached_op_schema->has_type_and_shape_inference_function()) { op_schema_->TypeAndShapeInferenceFunction( - [this](ONNX_NAMESPACE::InferenceContext& ctx) { + [this, &model_local_functions](ONNX_NAMESPACE::InferenceContext& ctx) { auto schema_registry = ONNX_NAMESPACE::OpSchemaRegistry::Instance(); ONNX_NAMESPACE::ShapeInferenceOptions options {true, 1, false}; - ONNX_NAMESPACE::shape_inference::InferShapeForFunctionNode(&onnx_func_proto_, body_.MainGraph().DomainToVersionMap(), schema_registry, ctx, options); + InferShapeForFunctionNode(ctx, onnx_func_proto_, body_.MainGraph().DomainToVersionMap(), options, schema_registry, model_local_functions, function_utils::GetFunctionIdentifier); }); } else { op_schema_->TypeAndShapeInferenceFunction(cached_op_schema->GetTypeAndShapeInferenceFunction()); @@ -352,7 +544,6 @@ FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph, op_schema_->Finalize(); //construct body - auto& function_body_graph = body_.MainGraph(); std::vector graph_inputs(node_in_parent_graph->InputDefs().size(), nullptr), graph_outputs(node_in_parent_graph->OutputDefs().size(), nullptr); @@ -363,6 +554,19 @@ FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph, ONNX_NAMESPACE::NodeProto function_op_node_proto; // NodeProto pertaining to the op with a FunctionBody node_in_parent_graph->ToProto(function_op_node_proto); + std::unordered_set node_input_outputs; + + for (const auto* input_def : node_in_parent_graph->InputDefs()) { + if (input_def->Exists()) { + node_input_outputs.insert(input_def->Name()); + } + } + + for (const auto* output_def : node_in_parent_graph->OutputDefs()) { + if (output_def->Exists()) { + node_input_outputs.insert(output_def->Name()); + } + } ONNX_NAMESPACE::TypeProto tensor_int32; // dummy type used for unused formal parameters tensor_int32.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_INT32); @@ -383,7 +587,7 @@ FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph, const std::string& tensor_name = (*node).input().Get(idx); auto iter = input_name_idx_map.find(tensor_name); if (iter != input_name_idx_map.end()) { - // Preserving NodeArg and input/output names + // If input is part of function inputs, preserve NodeArg and input/output names const std::string& actual_parameter_name = function_op_node_proto.input().Get(iter->second); if (!actual_parameter_name.empty()) { const onnxruntime::NodeArg* node_arg = parent_graph_->GetNodeArg(actual_parameter_name); @@ -399,11 +603,34 @@ FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph, graph_inputs[iter->second] = &unused_formal_param; } } else { - auto& n_input = function_body_graph.GetOrCreateNodeArg( - tensor_name + "_" + std::to_string(node_index), nullptr); - inputs.push_back(&n_input); + // If input is part of function outputs, preserve NodeArg and input/output names + iter = output_name_idx_map.find(tensor_name); + if (iter != output_name_idx_map.end()) { + const std::string& actual_parameter_name = function_op_node_proto.output().Get(iter->second); + const onnxruntime::NodeArg* node_arg = parent_graph_->GetNodeArg(actual_parameter_name); + const ONNX_NAMESPACE::TypeProto* actual_type = node_arg->TypeAsProto(); + auto& n_input = function_body_graph.GetOrCreateNodeArg(actual_parameter_name, actual_type); + inputs.push_back(&n_input); + } else { + // Input is intermediate input in function body. + // Check if input name needs to be mapped to a new unique name (this is required when node input\outputs + // have same names as intermediate input\outputs. + auto it = internal_input_output_updates.find(tensor_name); + if (it != internal_input_output_updates.end()) { + auto& n_input = function_body_graph.GetOrCreateNodeArg( + it->second, nullptr); + inputs.push_back(&n_input); + } else { + // Input is intermediate function body input and has no name collision with node input\output + // It can be added to the graph without any modification + auto& n_input = function_body_graph.GetOrCreateNodeArg( + tensor_name, nullptr); + inputs.push_back(&n_input); + } + } } } + for (int idx = 0; idx < (*node).output_size(); ++idx) { std::string tensor_name = (*node).output().Get(idx); auto iter = output_name_idx_map.find(tensor_name); @@ -424,9 +651,20 @@ FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph, graph_outputs[iter->second] = &unused_formal_param; } } else { - auto& n_output = function_body_graph.GetOrCreateNodeArg( - tensor_name + "_" + std::to_string(node_index), nullptr); - outputs.push_back(&n_output); + // Output is intermediate output in function body. + // Check if output name needs to be mapped to a new unique name (this is required when node input\outputs + // have same names as intermediate input\outputs. + auto it = node_input_outputs.find(tensor_name); + if (it != node_input_outputs.end()) { + auto& n_output = function_body_graph.GetOrCreateNodeArg( + tensor_name + uniq_identifier, nullptr); + outputs.push_back(&n_output); + internal_input_output_updates.insert({tensor_name, tensor_name + uniq_identifier}); + } else { + auto& n_output = function_body_graph.GetOrCreateNodeArg( + tensor_name, nullptr); + outputs.push_back(&n_output); + } } } @@ -454,7 +692,7 @@ FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph, node_attr != (*node).mutable_attribute()->end(); ++node_attr) { // If this node contains subgraphs, the node inputs/outputs within them needs to be fixed as well if ((*node_attr).has_g()) { - update_subgraphs_within_function_body(*(*node_attr).mutable_g(), + UpdateSubgraphsWithinFunctionBody(*(*node_attr).mutable_g(), *parent_graph_, function_op_node_proto, input_name_idx_map, output_name_idx_map); } @@ -470,15 +708,53 @@ FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph, new_attr_map[(*node_attr).name()] = *node_attr; } } - function_body_graph.AddNode(uniq_identifier + "_" + std::to_string(node_index), (*node).op_type(), + function_body_graph.AddNode(uniq_identifier, (*node).op_type(), (*node).doc_string(), inputs, outputs, &new_attr_map, (*node).domain()); } function_body_graph.SetInputs(graph_inputs); function_body_graph.SetOutputs(graph_outputs); - auto status = function_body_graph.Resolve(); - ORT_ENFORCE(status.IsOK(), "Resolve subgraph failed:", status.ErrorMessage()); + // Nested model local functions need to be initialized before Graph::Resolve() can be called for the function body. + // Parse the graph and initialize functions for nodes which reference model local functions. + // Only parse the graph if the model contains model local functions. + // Once all model local functions within function body are initialized, Graph Resolve for parent function body is called + // During graph resolve for parent function, graph resolve for every nested model local function is called too... + // Such a top down approach is required to successfully carry out type inference for schema less functions. + // Schema defined functions are treated a bit different from model local aka schema less functions. These are initialized + // during graph resolve of parent functions. + if (model_local_functions.size() > 0) { + for (auto node = function_body_graph.Nodes().begin(); node != function_body_graph.Nodes().end(); ++node) { + // Init nested functions + std::string func_identifier = function_utils::GetFunctionIdentifier(node->Domain(), node->OpType()); + auto iter = model_local_functions.find(func_identifier); + if (iter == model_local_functions.end()) { + continue; + } + + // This node has a model local function proto. + auto onnx_function_proto = *(iter->second); + InitNestedModelLocalFunction(function_body_graph, node->Index(), onnx_function_proto, model_local_functions, function_container, logger, true); + } + } + + // Graph resolve should be called on the parent functions only. Skip resolve if this is a nested function. + // Nested function bodies will be resolved along with parent function body as we set traverse_function_body to true. + // This is only applicable for model local functions which are schema less. + if (!is_nested_function) { + onnxruntime::Graph::ResolveOptions options; + options.traverse_function_body = true; + auto status = function_body_graph.Resolve(options); + + ORT_ENFORCE(status.IsOK(), "Resolve subgraph failed:", status.ErrorMessage()); + + ORT_ENFORCE(node_in_parent_graph->InputDefs().size() == function_body_graph.GetInputsIncludingInitializers().size(), + "Node " + node_in_parent_graph->Name() + "'s number of inputs is different from function body graph's number of input."); + + ORT_ENFORCE(node_in_parent_graph->OutputDefs().size() == function_body_graph.GetOutputs().size(), + "Node ", node_in_parent_graph->Name(), "'s number of outputs is different from function body graph's number of outputs."); + } + } // namespace onnxruntime FunctionImpl::~FunctionImpl() = default; @@ -491,6 +767,10 @@ const onnxruntime::Graph& FunctionImpl::Body() const { return body_.MainGraph(); } +onnxruntime::Graph& FunctionImpl::MutableBody() { + return body_.MainGraph(); +} + ViewerFunctionImpl::ViewerFunctionImpl(const onnxruntime::Graph& graph, const IndexedSubGraph& nodes_to_fuse, const logging::Logger& /*logger*/) { diff --git a/onnxruntime/core/graph/function_impl.h b/onnxruntime/core/graph/function_impl.h index a913e9029ea75..20843788a9713 100644 --- a/onnxruntime/core/graph/function_impl.h +++ b/onnxruntime/core/graph/function_impl.h @@ -27,10 +27,15 @@ class FunctionImpl final : public Function { // a Function Op. This takes in a FunctionProto and constructs function body // from it. The function body initialization happens during model load in graph resolve // phase. - FunctionImpl(const onnxruntime::Graph& graph, + // model_local_functions contains domain:optype to model_local_functions map. This is + // used to resolve and initialize nested functions. + FunctionImpl(onnxruntime::Graph& graph, const onnxruntime::NodeIndex& node_index, const ONNX_NAMESPACE::FunctionProto& onnx_func, - const logging::Logger& logger); + const std::unordered_map& in_model_function_protos, + std::vector>& function_container, + const logging::Logger& logger, + bool is_nested_function = false); ~FunctionImpl() override; @@ -38,6 +43,8 @@ class FunctionImpl final : public Function { const onnxruntime::Graph& Body() const override; + onnxruntime::Graph& MutableBody() override; + private: const onnxruntime::Graph* const parent_graph_; std::unique_ptr op_schema_; @@ -59,8 +66,21 @@ class ViewerFunctionImpl final : public Function { const onnxruntime::Graph& Body() const override { ORT_THROW("Not supported"); } + onnxruntime::Graph& MutableBody() override { ORT_THROW("Not supported"); } + private: std::unique_ptr op_schema_; }; +namespace function_utils { +/** Get the unique id for function. This is used as a key to find the +* relevant model local function from it's container. +* @param function_domain Domain for the function. +* @param function_name Name of the function. Name should match the OpType of the node which references the function. +*/ +inline std::string GetFunctionIdentifier(const std::string& function_domain, const std::string& function_name) { + return function_domain + ":" + function_name; +} +} + } // namespace onnxruntime diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 15295d42e5d5e..a6e45ce93a550 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -474,7 +474,7 @@ const Path& Node::ModelPath() const noexcept { #if !defined(ORT_MINIMAL_BUILD) -const Function* Node::GetFunctionBody(bool try_init_func_body) { +Function* Node::GetMutableFunctionBody(bool try_init_func_body) { if (nullptr != func_body_) { return func_body_; } @@ -487,7 +487,7 @@ const Function* Node::GetFunctionBody(bool try_init_func_body) { return func_body_; } -void Node::SetFunctionBody(const Function& func) { +void Node::SetFunctionBody(Function& func) { func_body_ = &func; op_ = &func.OpSchema(); since_version_ = op_->since_version(); @@ -1003,12 +1003,14 @@ Graph::Graph(const Model& owning_model, const std::unordered_map& domain_to_version, Version ir_version, IOnnxRuntimeOpSchemaCollectionPtr schema_registry, + const std::vector& model_functions, const logging::Logger& logger) - : Graph(owning_model, graph_proto, domain_to_version, ir_version, schema_registry, nullptr, nullptr, logger) {} + : Graph(owning_model, graph_proto, domain_to_version, ir_version, schema_registry, nullptr, nullptr, model_functions, logger) {} Graph::Graph(const Model& owning_model, GraphProto* graph_proto, const std::unordered_map& domain_to_version, Version ir_version, IOnnxRuntimeOpSchemaCollectionPtr schema_registry, Graph* parent_graph, const Node* parent_node, + const std::vector& model_functions, const logging::Logger& logger) : owning_model_(owning_model), graph_proto_(graph_proto), @@ -1025,6 +1027,10 @@ Graph::Graph(const Model& owning_model, ArgNameToTypeMap name_to_type_map; const auto& model_path = ModelPath(); + for (auto func : model_functions) { + model_local_functions_[function_utils::GetFunctionIdentifier(func->domain(), func->name())] = func; + } + // Process 'Constant' nodes // Put the 'TensorProto' stored in the 'Constant' nodes attribute into the graphs initializer list for (auto& node : graph_proto_->node()) { @@ -1167,7 +1173,8 @@ Graph::Graph(Graph& parent_graph, const Node& parent_node, ONNX_NAMESPACE::Graph &subgraph_proto, parent_graph.DomainToVersionMap(), parent_graph.IrVersion(), parent_graph.schema_registry_, &parent_graph, - &parent_node, parent_graph.logger_) { + &parent_node, {}, + parent_graph.logger_) { } void Graph::InitializeStateFromModelFileGraphProto() { @@ -2414,7 +2421,9 @@ Status Graph::VerifyNodeAndOpMatch(const ResolveOptions& options) { } } - InitFunctionBodyForNode(node); + if (!node.op_ || (node.op_ && (node.op_->HasFunction() || node.op_->HasContextDependentFunction()))) { + InitFunctionBodyForNode(node); + } if (!node.op_) { return Status(ONNXRUNTIME, FAIL, "Fatal error: " + node.OpType() + " is not a registered function/op"); @@ -2425,6 +2434,13 @@ Status Graph::VerifyNodeAndOpMatch(const ResolveOptions& options) { if (node.since_version_ == -1) { node.since_version_ = node.op_->since_version(); } + } else { + // This is only applicable for model local functions. + // In case of nested model local functions, graph resolve is called during resolve for parent + // function body graph otherwise type inference for nest function cannot happen. + if (options.traverse_function_body && node.GetFunctionBody() != nullptr) { + node.GetMutableFunctionBody()->MutableBody().Resolve(options); + } } ORT_RETURN_IF_ERROR(node.UpdateInputArgCount()); @@ -2469,9 +2485,18 @@ Status Graph::VerifyNodeAndOpMatch(const ResolveOptions& options) { return Status::OK(); } +const std::unordered_map& Graph::GetModelLocalFunctions() const { + if (parent_graph_ == nullptr) { + return model_local_functions_; + } + return parent_graph_->GetModelLocalFunctions(); +} + void Graph::InitFunctionBodyForNode(Node& node) { + ONNX_NAMESPACE::FunctionProto onnx_function_proto; if (node.op_ && (node.op_->HasFunction() || node.op_->HasContextDependentFunction())) { - onnx::FunctionProto onnx_function_proto; + // This node has a schema defined function proto. If it is a context dependent function + // then build it otherwise fetch the FunctionProto from schema. if (node.op_->HasContextDependentFunction()) { NodeProto node_proto; node.ToProto(node_proto); @@ -2484,24 +2509,39 @@ void Graph::InitFunctionBodyForNode(Node& node) { } else input_types.emplace_back(); } - onnx::FunctionBodyBuildContextImpl function_body_ctx(node_proto, input_types); + ONNX_NAMESPACE::FunctionBodyBuildContextImpl function_body_ctx(node_proto, input_types); if (!node.op_->BuildContextDependentFunction(function_body_ctx, onnx_function_proto)) return; } else { onnx_function_proto = *(node.op_->GetFunction()); } - - ORT_TRY { - auto func_ptr = std::make_unique(*this, node.Index(), onnx_function_proto, - logger_); - function_container_.emplace_back(std::move(func_ptr)); - node.SetFunctionBody(*function_container_.back()); - } - ORT_CATCH(const std::exception&) { - // Return without using this function op's expansion. No need to fail just yet. - // If ORT has a specialized kernel for this op then execution will proceed + } else { + std::string func_identifier = function_utils::GetFunctionIdentifier(node.Domain(), node.OpType()); + const auto& model_local_functions = GetModelLocalFunctions(); + auto iter = model_local_functions.find(func_identifier); + if (iter == model_local_functions.end()) { return; } + + // This node has a model local function proto. + onnx_function_proto = *(iter->second); + } + + ORT_TRY { + // Explicitly pass the model local functions as t + auto func_ptr = std::make_unique(*this, node.Index(), onnx_function_proto, + GetModelLocalFunctions(), function_container_, logger_); + function_container_.emplace_back(std::move(func_ptr)); + node.SetFunctionBody(*function_container_.back()); + } + ORT_CATCH(const std::exception& e) { + LOGS(logger_, WARNING) << "Function body initialization failed for node '" + << node.Name() << "' optype " << node.OpType() + << ". Error message " << e.what() + << ". Execution will fail if ORT does not have a specialized kernel for this op"; + // Return without using this function op's expansion. No need to fail just yet. + // If ORT has a specialized kernel for this op then execution will proceed + return; } } @@ -3761,24 +3801,37 @@ Status Graph::InlineFunction(Node& node) { for (auto output_edge : output_edges) { RemoveEdge(node.Index(), output_edge.GetNode().Index(), output_edge.GetSrcArgIndex(), output_edge.GetDstArgIndex()); } + + // Map of function input outputs to nodes input/outputs std::unordered_map remap_input_output; - if (node.MutableInputDefs().size() != subgraph.GetInputsIncludingInitializers().size()) - return Status(ONNXRUNTIME, FAIL, "Node " + node.Name() + "'s number of inputs is different from function body graph's number of input."); + // Set of node input output names as these names need to be preserved during inlining + std::unordered_set func_input_output_names; + for (size_t i = 0; i < subgraph.GetInputsIncludingInitializers().size(); ++i) { auto* input = subgraph.GetInputsIncludingInitializers()[i]; - if (input->Name() != node.MutableInputDefs()[i]->Name()) + if (input->Name() != node.MutableInputDefs()[i]->Name()) { remap_input_output[input->Name()] = node.MutableInputDefs()[i]; + } + func_input_output_names.insert(input->Name()); } - ORT_ENFORCE(node.MutableOutputDefs().size() == subgraph.GetOutputs().size(), - "Node ", node.Name(), "'s number of outputs is different from function body graph's number of outputs."); for (size_t i = 0; i < subgraph.GetOutputs().size(); ++i) { auto* output = subgraph.GetOutputs()[i]; - if (output->Name() != node.MutableOutputDefs()[i]->Name()) + if (output->Name() != node.MutableOutputDefs()[i]->Name()) { remap_input_output[output->Name()] = node.MutableOutputDefs()[i]; + } + func_input_output_names.insert(output->Name()); } + + // create a uniq_identifier to append to every node name and intermediate input\outputs + // to make sure there are no unintended duplicates + std::stringstream ss; + ss << static_cast(&node); + auto uniq_identifier = ss.str(); + RemoveNode(node.Index()); + const auto& model_path = ModelPath(); for (const auto& subgraph_node : subgraph.Nodes()) { if (subgraph_node.OpType() == kConstant) { @@ -3786,31 +3839,56 @@ Status Graph::InlineFunction(Node& node) { ONNX_NAMESPACE::NodeProto subgraph_node_proto{}; subgraph_node.ToProto(subgraph_node_proto); const gsl::not_null tensor{graph_proto_->add_initializer()}; - ORT_RETURN_IF_ERROR(utils::ConstantNodeProtoToTensorProto(subgraph_node_proto, model_path, *tensor)); + ORT_RETURN_IF_ERROR(utils::ConstantNodeProtoToTensorProto(subgraph_node_proto, model_path, *tensor, subgraph_node_proto.output(0) + uniq_identifier)); name_to_initial_tensor_[tensor->name()] = tensor; } else { std::vector inputs, outputs; for (auto* input : subgraph_node.InputDefs()) { - auto it = remap_input_output.find(input->Name()); - if (it != remap_input_output.end()) - inputs.push_back(it->second); - else - inputs.push_back(const_cast(input)); + if (func_input_output_names.find(input->Name()) != func_input_output_names.end()) { + auto it = remap_input_output.find(input->Name()); + if (it != remap_input_output.end()) { + // This is a function input/output and needs to be remapped to node input for correctness + inputs.push_back(it->second); + } else { + // This is a function input/output so preserve the existing name + auto& n_input = GetOrCreateNodeArg(input->Name(), input->TypeAsProto()); + inputs.push_back(&n_input); + } + } else { + // This is an intermediate input. Add a unique identifier as suffix to make sure + // there is no name collision with names in parent graph + auto& n_input = GetOrCreateNodeArg(input->Name() + uniq_identifier, input->TypeAsProto()); + inputs.push_back(&n_input); + } } for (auto* output : subgraph_node.OutputDefs()) { - auto it = remap_input_output.find(output->Name()); - if (it != remap_input_output.end()) - outputs.push_back(it->second); - else - outputs.push_back(const_cast(output)); + if (func_input_output_names.find(output->Name()) != func_input_output_names.end()) { + auto it = remap_input_output.find(output->Name()); + if (it != remap_input_output.end()) { + outputs.push_back(it->second); + } else { + auto& n_output = GetOrCreateNodeArg(output->Name(), output->TypeAsProto()); + outputs.push_back(&n_output); + } + } else { + auto& n_output = GetOrCreateNodeArg(output->Name() + uniq_identifier, output->TypeAsProto()); + outputs.push_back(&n_output); + } + } + + auto& new_node = AddNode(subgraph_node.Name() + uniq_identifier, subgraph_node.OpType(), subgraph_node.Description(), + inputs, + outputs, + &subgraph_node.GetAttributes(), + subgraph_node.Domain()); + + // If this node has an initialized function body add it to the new node so that reinitialization is not required. + if (subgraph_node.GetFunctionBody() != nullptr) { + new_node.SetFunctionBody(*(const_cast(subgraph_node.GetFunctionBody()))); } - AddNode(subgraph_node.Name(), subgraph_node.OpType(), subgraph_node.Description(), - inputs, - outputs, - &subgraph_node.GetAttributes(), - subgraph_node.Domain()); } } + ORT_RETURN_IF_ERROR(this->Resolve()); return Status::OK(); } @@ -3847,7 +3925,7 @@ void Graph::SetOutputs(const std::vector& outputs) { GraphResolveNeeded(true); } -void Graph::SetNodeArgType(NodeArg& arg, const onnx::TypeProto& type_proto) { +void Graph::SetNodeArgType(NodeArg& arg, const ONNX_NAMESPACE::TypeProto& type_proto) { arg.SetType(type_proto); GraphResolveNeeded(true); } diff --git a/onnxruntime/core/graph/model.cc b/onnxruntime/core/graph/model.cc index f6bc837690cd9..45a7b9434b6a4 100644 --- a/onnxruntime/core/graph/model.cc +++ b/onnxruntime/core/graph/model.cc @@ -45,7 +45,7 @@ Model::Model(const std::string& graph_name, const PathString& model_path, const IOnnxRuntimeOpSchemaRegistryList& local_registries, const std::unordered_map& domain_to_version, - const std::vector&, + const std::vector& model_local_functions, const logging::Logger& logger) : model_path_(Path::Parse(model_path)) { model_proto_.set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); @@ -82,10 +82,17 @@ Model::Model(const std::string& graph_name, opset_id_proto->set_version(domain.second); } + std::vector model_functions; + for (auto& func : model_local_functions) { + auto func_ptr = model_proto_.add_functions(); + func_ptr->CopyFrom(func); + model_functions.emplace_back(func_ptr); + } + // need to call private ctor so can't use make_shared GSL_SUPPRESS(r .11) graph_.reset(new Graph(*this, model_proto_.mutable_graph(), *p_domain_to_version, IrVersion(), schema_registry, - logger)); + model_functions, logger)); } Model::Model(const ModelProto& model_proto, const PathString& model_path, @@ -176,9 +183,14 @@ Model::Model(ModelProto&& model_proto, const PathString& model_path, } } + std::vector model_local_functions; + for (auto& func : model_proto_.functions()) { + model_local_functions.emplace_back(&func); + } + // create instance. need to call private ctor so can't use make_unique GSL_SUPPRESS(r .11) - graph_.reset(new Graph(*this, model_proto_.mutable_graph(), domain_to_version, IrVersion(), schema_registry, logger)); + graph_.reset(new Graph(*this, model_proto_.mutable_graph(), domain_to_version, IrVersion(), schema_registry, model_local_functions, logger)); } Version Model::IrVersion() const { diff --git a/onnxruntime/core/graph/model.h b/onnxruntime/core/graph/model.h index 37ca8159910b8..85bdab8c06be4 100644 --- a/onnxruntime/core/graph/model.h +++ b/onnxruntime/core/graph/model.h @@ -50,7 +50,7 @@ class Model { const PathString& model_path, const IOnnxRuntimeOpSchemaRegistryList& local_registries, const std::unordered_map& domain_to_version, - const std::vector& model_specific_functions, + const std::vector& model_local_functions, const logging::Logger& logger); // NOTE: after calling this constructor, <*this> model will diff --git a/onnxruntime/core/graph/model_load_utils.h b/onnxruntime/core/graph/model_load_utils.h index 5b1c793aeeab6..f5aa45d19164f 100644 --- a/onnxruntime/core/graph/model_load_utils.h +++ b/onnxruntime/core/graph/model_load_utils.h @@ -72,5 +72,11 @@ inline void ValidateOpsetForDomain(const std::unordered_map& o } } } + +/** Generates a unique identifier for the given FunctionProto using the function proto domain and name. +*/ +inline std::string GetModelLocalFuncId(const ONNX_NAMESPACE::FunctionProto& function_proto) { + return function_proto.domain() + ":" + function_proto.name(); +} } //namespace model_load_utils } // namespace onnxruntime diff --git a/onnxruntime/test/ir/onnx_model_test.cc b/onnxruntime/test/ir/onnx_model_test.cc index 09e3b74f47e2b..aea38d17c0412 100644 --- a/onnxruntime/test/ir/onnx_model_test.cc +++ b/onnxruntime/test/ir/onnx_model_test.cc @@ -11,6 +11,8 @@ #include "test/providers/provider_test_utils.h" //For ASSERT_STATUS_OK #include "test/test_environment.h" #include "gtest/gtest.h" +#include "onnx/defs/function.h" +#include "onnx/defs/parser.h" using namespace onnxruntime; using namespace ONNX_NAMESPACE; @@ -187,5 +189,239 @@ TEST_F(ONNXModelsTest, TestModelsWithAnOpContainingAFunctionBody) { ASSERT_STATUS_OK(model->MainGraph().Resolve()); } +// The following tests verify ORT can successfully load models which reference functions +// present in the ModelProto aka model local functions. This feature was added to ONNX standard starting IRv8 + +void BuildFunction(FunctionProto& function_proto, + const std::string& name, const std::string& domain, + const std::vector& nodes, + const std::vector& inputs, + const std::vector& outputs, + const std::unordered_map& opset_imports) { + for (const auto& node : nodes) { + auto new_node = function_proto.add_node(); + new_node->CopyFrom(node); + } + + function_proto.set_name(name); + function_proto.set_domain(domain); + function_proto.set_doc_string("Test function proto"); + + for (auto& input : inputs) + function_proto.add_input(input); + + for (auto& output : outputs) + function_proto.add_output(output); + + for (auto& opset_import : opset_imports) { + auto* func_opset_import = function_proto.mutable_opset_import()->Add(); + func_opset_import->set_domain(opset_import.first); + func_opset_import->set_version(opset_import.second); + } +} + +void BuildFunctionFoo(FunctionProto& function_proto, const std::string& domain) { + auto func_body_nodes = FunctionBodyHelper::BuildNodes( + {// nodes: {outputs, op, inputs, attributes} + FunctionBodyHelper::Const("Q_Min", 0.f), + FunctionBodyHelper::Const("Q_Max", 255.f), + {{"X_Min"}, "ReduceMin", {"x"}, {MakeAttribute("keepdims", int64_t(0))}}, + {{"X_Max"}, "ReduceMax", {"x"}, {MakeAttribute("keepdims", int64_t(0))}}, + {{"X_Range"}, "Sub", {"X_Max", "X_Min"}}, + {{"s"}, "Div", {"X_Range", "Q_Max"}}, + {{"zp_fp"}, "Sub", {"Q_Min", "s"}}, + {{"zp"}, "Cast", {"zp_fp"}, {MakeAttribute("to", int64_t(2))}}, + {{"y"}, "QuantizeLinear", {"x", "s", "zp"}}}); + + BuildFunction(function_proto, "foo", domain, func_body_nodes, {"x"}, {"y"}, {{"", 13}}); +} + +void RunFunctionTests(ModelProto&& model_proto) { + std::shared_ptr model; + std::shared_ptr registry = std::make_shared(); + std::list> regs = {registry}; + ASSERT_STATUS_OK(Model::Load(std::move(model_proto), model, ®s, + *(DefaultLoggingManager().CreateLogger("GraphTest")))); + + // Test function inline + auto& graph = model->MainGraph(); + bool function_inlined = false; + do { + function_inlined = false; + for (auto& node : graph.Nodes()) { + if (node.GetFunctionBody() != nullptr) { + ASSERT_STATUS_OK(graph.InlineFunction(node)); + function_inlined = true; + break; + } + } + } while (function_inlined); + + ASSERT_STATUS_OK(graph.Resolve()); +} + +// Tests: +// 1. Function initialization and inlining. +// 2. Input\output name handling when intermediate function body node input\outputs have same names as outer graph. +// 3. Input\output name handling when function body input output names don't match node input output names +TEST(FunctionVerification, TestModelLocalFunctions) { + const char* code = R"ONNX( +< + ir_version: 8, + opset_import: [ "" : 13, "custom_domain" : 1], + producer_name: "FunctionProtoTest", + producer_version: "1.0", + model_version: 1, + doc_string: "A test model for model local functions." +> +agraph (float[N] x) => (uint8[N] s) +{ + t = custom_domain.foo(x) + s = Identity(t) +} +)ONNX"; + + ModelProto model_proto; + ONNX_NAMESPACE::OnnxParser parser(code); + auto status = parser.Parse(model_proto); + EXPECT_TRUE(status.IsOK()); + EXPECT_TRUE(parser.EndOfInput()); + + auto* function_proto = model_proto.mutable_functions()->Add(); + BuildFunctionFoo(*function_proto, "custom_domain"); + + RunFunctionTests(std::move(model_proto)); +} + +// Tests Input\Output name handling where function output is consumed by function body node as well. +// This is treated as a special case because we need to test that the node arg name is +// handled properly. Specially when this function output is also remapped to node output. +TEST(FunctionVerification, TestModelLocalFunctionsWithMultipleOutputs) { + const char* code = R"ONNX( +< + ir_version: 8, + opset_import: [ "" : 13, "custom_domain" : 1], + producer_name: "FunctionProtoTest", + producer_version: "1.0", + model_version: 1, + doc_string: "A test model for model local functions." +> +agraph (float[N] x) => (float[N] out) +{ + o1, o2 = custom_domain.bar(x) + out = Add(o1, o2) +} +)ONNX"; + + ModelProto model_proto; + ONNX_NAMESPACE::OnnxParser parser(code); + auto status = parser.Parse(model_proto); + EXPECT_TRUE(status.IsOK()); + EXPECT_TRUE(parser.EndOfInput()); + + auto function_proto = model_proto.mutable_functions()->Add(); + auto func_body_nodes = FunctionBodyHelper::BuildNodes( + {// nodes: {outputs, op, inputs, attributes, domain} + {{"o2"}, "Identity", {"x"}}, + {{"o1"}, "Identity", {"o2"}}}); + BuildFunction(*function_proto, "bar", "custom_domain", + func_body_nodes, {"x"}, {"o1", "o2"}, {{"", 13}}); + + RunFunctionTests(std::move(model_proto)); +} + +// Tests: +// 1. Nested functions initialization and inlining. +// 1. Input\output name handling when intermediate function body node input\outputs have same names as outer graph. +// 2. Input\output name handling when function body input output names don't match node input output names +TEST(FunctionVerification, TestNestedModelLocalFunctions) { + const char* code = R"ONNX( +< + ir_version: 8, + opset_import: [ "" : 13, "custom_domain" : 1], + producer_name: "FunctionProtoTest", + producer_version: "1.0", + model_version: 1, + doc_string: "A test model for model local functions." +> +agraph (float[N] x) => (uint8[N] zp) +{ + c = custom_domain.foo(x) + zp = Identity(c) +} +)ONNX"; + + ModelProto model_proto; + ONNX_NAMESPACE::OnnxParser parser(code); + auto status = parser.Parse(model_proto); + EXPECT_TRUE(status.IsOK()); + EXPECT_TRUE(parser.EndOfInput()); + + auto* function_proto = model_proto.mutable_functions()->Add(); + BuildFunctionFoo(*function_proto, "custom_domainA"); + + // Build second function proto + // intentionally using same function name to test + // that domainA:name and domainB:name are allowed. + function_proto = model_proto.mutable_functions()->Add(); + auto func_body_nodes = FunctionBodyHelper::BuildNodes( + {// nodes: {outputs, op, inputs, attributes, domain} + {{"out"}, "foo", {"x"}, {}, "custom_domainA"}, + {{"s"}, "Identity", {"out"}}}); + BuildFunction(*function_proto, "foo", "custom_domain", + func_body_nodes, {"x"}, {"s"}, {{"", 13}, {"custom_domainA", 1}}); + + RunFunctionTests(std::move(model_proto)); +} + +// Tests: +// 1. Function initialization and inlining when there are multiple references to the same function +// from within a function body and directly from a graph +// 2. Input\output and node names are handled correctly (.i.e unique names are generated where necessary) when inlining the +// same function multiple times in the graph. +// 3. Unique names are generated for intermediate node input\outputs when they match the names of node input\outputs +TEST(FunctionVerification, TestNestedModelLocalFunctionsWithMultipleReferences) { + const char* code = R"ONNX( +< + ir_version: 8, + opset_import: [ "" : 13, "custom_domain" : 1, "custom_domainA" : 1], + producer_name: "FunctionProtoTest", + producer_version: "1.0", + model_version: 1, + doc_string: "A test model for model local functions." +> +agraph (float[N] x, float[N] y) => (float[N] zp) +{ + c = custom_domain.bar(x) + zp1 = Cast(c) + d = custom_domainA.foo(y) + zp2 = Cast(d) + zp = Sub(zp1, zp2) +} +)ONNX"; + + ModelProto model_proto; + ONNX_NAMESPACE::OnnxParser parser(code); + auto status = parser.Parse(model_proto); + EXPECT_TRUE(status.IsOK()); + EXPECT_TRUE(parser.EndOfInput()); + + auto* function_proto = model_proto.mutable_functions()->Add(); + BuildFunctionFoo(*function_proto, "custom_domainA"); + + // Build second function proto + // intentionally using same function name to test + // that domainA:name and domainB:name are allowed. + function_proto = model_proto.mutable_functions()->Add(); + auto func_body_nodes = FunctionBodyHelper::BuildNodes( + {// nodes: {outputs, op, inputs, attributes, domain} + {{"s"}, "foo", {"x"}, {}, "custom_domainA"}, + {{"out"}, "Identity", {"s"}}}); + BuildFunction(*function_proto, "bar", "custom_domain", func_body_nodes, + {"x"}, {"out"}, {{"", 13}, {"custom_domainA", 1}}); + + RunFunctionTests(std::move(model_proto)); +} + } // namespace test } // namespace onnxruntime diff --git a/orttraining/orttraining/test/gradient/function_ops_test.cc b/orttraining/orttraining/test/gradient/function_ops_test.cc index cdb92b8742b05..665114df71572 100644 --- a/orttraining/orttraining/test/gradient/function_ops_test.cc +++ b/orttraining/orttraining/test/gradient/function_ops_test.cc @@ -141,7 +141,7 @@ void CheckDropoutGradWithoutRatio(bool inline_call) { auto model = testCase.CreateModel(inline_call); if (!inline_call) { auto& node = *model->MainGraph().Nodes().begin(); - auto* fnbody = node.GetFunctionBody(true); + auto* fnbody = node.GetMutableFunctionBody(true); EXPECT_EQ(fnbody, nullptr); } } diff --git a/tools/ci_build/github/linux/docker/scripts/requirements.txt b/tools/ci_build/github/linux/docker/scripts/requirements.txt index 6d26f52fa4728..110e3471dc112 100644 --- a/tools/ci_build/github/linux/docker/scripts/requirements.txt +++ b/tools/ci_build/github/linux/docker/scripts/requirements.txt @@ -4,7 +4,7 @@ mypy pytest setuptools>=41.4.0 wheel -git+http://github.com/onnx/onnx.git@d75fb0502c9d8fef817d82c15223b4aaae8e8b6e#egg=onnx +git+http://github.com/onnx/onnx.git@1f63dcb7fcc3a8bf5c3c8e326867ecd6f5c43f35#egg=onnx argparse sympy==1.1.1 flatbuffers