Skip to content

Commit

Permalink
add model local function support (#8540)
Browse files Browse the repository at this point in the history
* updates for picking pnnx commit

* add tests filter to c# tests

* plus test fixes

* fix versioning for contrib ops

* fix tests

* test filter for optional ops

* more versioning related updates

* fix test

* fix layernorm spec

* more updates

* update docs

* add more test filters

* more filters

* update binary size threshold

* update docs

* draft - enable model local function

* enable model local functions in ORT

* update to latest rel onnx commit

* plus tests

* plus more updates

* plus updates

* test updates

* Fix for nested functions + shape inference

* plus bug fix and updates per review

* plus fixes per review

* plus test updates

* plus updates per review

* plus fixes

* fix a test
  • Loading branch information
askhade authored Sep 8, 2021
1 parent b7b42e0 commit ec63d10
Show file tree
Hide file tree
Showing 13 changed files with 753 additions and 91 deletions.
3 changes: 3 additions & 0 deletions include/onnxruntime/core/graph/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ class Function {

/** Gets the Graph instance for the Function body subgraph. */
virtual const onnxruntime::Graph& Body() const = 0;

/** Gets the Mutable Graph instance for the Function body subgraph. */
virtual onnxruntime::Graph& MutableBody() = 0;
};

/**
Expand Down
36 changes: 26 additions & 10 deletions include/onnxruntime/core/graph/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,17 +145,19 @@ class Node {

/**
Gets the function body if applicable otherwise nullptr
@param try_init_func_body If not already intialized, initialize the function body
(only applicable to operators which are defined as function in ONNX spec).
Function body can be initialized in 2 cases :
@param try_init_func_body If not already initialized, initialize the function body
(This is not applicable for primitive operators.)
Function body can be initialized in 3 cases :
1. For nodes of type "Fused"
2. For nodes which are defined as functions in ONNX spec (example: DynamicQuantizeLinear)
2. For nodes which are defined as functions in the spec (example: DynamicQuantizeLinear)
3. For nodes which reference a model local function. These functions are defined in the model itself and
do not have any schema associated with them.
For all other cases this will always return nullptr.
Nodes of type "Fused" are created during partitioning and the function body
initialization for such nodes also happens during node creation. Therefore,
initialization of function body will happen via this method only in case 2 mentioned above.
initialization of function body will happen via this method only in cases 2 and 3 mentioned above.
*/
const Function* GetFunctionBody(bool try_init_func_body = true);
Function* GetMutableFunctionBody(bool try_init_func_body = true);

/** Gets the function body if applicable otherwise nullptr. */
const Function* GetFunctionBody() const noexcept { return func_body_; }
Expand Down Expand Up @@ -395,6 +397,11 @@ class Node {
execution_provider_type_ = execution_provider_type;
}

/** Sets initialized function body for node. This is called right after function body initialization for a node.
* or during function inlining when a nested function is encountered.
*/
void SetFunctionBody(Function& func);

/** Call the provided function for all explicit inputs, implicit inputs, and outputs of this Node.
If the NodeArg is an explicit or implicit input, is_input will be true when func is called.
@param include_missing_optional_defs Include NodeArgs that are optional and were not provided
Expand Down Expand Up @@ -527,8 +534,6 @@ class Node {
// validate and update the input arg count
common::Status UpdateInputArgCount();

void SetFunctionBody(const Function& func);

const Definitions& GetDefinitions() const noexcept { return definitions_; }
const Relationships& GetRelationships() const noexcept { return relationships_; }

Expand Down Expand Up @@ -558,7 +563,7 @@ class Node {
Node::Type node_type_ = Node::Type::Primitive;

// The function body is owned by graph_
const Function* func_body_ = nullptr;
Function* func_body_ = nullptr;

// Node doc string.
std::string description_;
Expand Down Expand Up @@ -1022,6 +1027,9 @@ class Graph {
/** Initialize function body for the given node */
void InitFunctionBodyForNode(Node& node);

/** Gets Model local functions from the root/parent graph.*/
const std::unordered_map<std::string, const ONNX_NAMESPACE::FunctionProto*>& GetModelLocalFunctions() const;

/** Mark a NodeArg name as coming from the outer scope when programmatically constructing a Graph that will
be used as a GraphProto attribute in another Node..
e.g. when creating a Graph instance that will be used as a subgraph in a control flow operator, it is necessary to
Expand All @@ -1045,7 +1053,7 @@ class Graph {
void SetOutputs(const std::vector<const NodeArg*>& outputs);

/** Sets the type of a NodeArg, replacing existing type/shape if any */
void SetNodeArgType(NodeArg& arg, const onnx::TypeProto& type_proto);
void SetNodeArgType(NodeArg& arg, const ONNX_NAMESPACE::TypeProto& type_proto);

const Node* GetProducerNode(const std::string& node_arg_name) const {
return GetProducerNodeImpl(*this, node_arg_name);
Expand Down Expand Up @@ -1107,6 +1115,9 @@ class Graph {
// Whether to set that no proto sync is required after resolving.
// Useful for resolving right after loading from a GraphProto.
bool no_proto_sync_required = false;
// When set to true, graph resolve will be called for initialized function bodies as well. This is used
// in case of nested model local functions.
bool traverse_function_body = false;
};

/**
Expand Down Expand Up @@ -1203,6 +1214,7 @@ class Graph {
const std::unordered_map<std::string, int>& domain_to_version,
Version ir_version,
IOnnxRuntimeOpSchemaCollectionPtr schema_registry,
const std::vector<const ONNX_NAMESPACE::FunctionProto*>& model_functions,
const logging::Logger& logger);

// internal use by the Graph class only
Expand All @@ -1213,6 +1225,7 @@ class Graph {
IOnnxRuntimeOpSchemaCollectionPtr schema_registry,
Graph* parent_graph,
const Node* parent_node,
const std::vector<const ONNX_NAMESPACE::FunctionProto*>& model_functions,
const logging::Logger& logger);

void InitializeStateFromModelFileGraphProto();
Expand Down Expand Up @@ -1391,7 +1404,10 @@ class Graph {
#if !defined(ORT_MINIMAL_BUILD)
IOnnxRuntimeOpSchemaCollectionPtr schema_registry_;

// Container to hold initialized function bodies
std::vector<std::unique_ptr<onnxruntime::Function>> function_container_;

std::unordered_map<std::string, const ONNX_NAMESPACE::FunctionProto*> model_local_functions_;
#endif

// Graph nodes.
Expand Down
10 changes: 8 additions & 2 deletions onnxruntime/core/framework/tensorprotoutils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -791,7 +791,7 @@ ONNX_NAMESPACE::TensorProto TensorToTensorProto(const Tensor& tensor, const std:

common::Status ConstantNodeProtoToTensorProto(const ONNX_NAMESPACE::NodeProto& node,
const Path& model_path,
ONNX_NAMESPACE::TensorProto& tensor) {
ONNX_NAMESPACE::TensorProto& tensor, const std::string& tensor_name) {
const AttributeProto& constant_attribute = node.attribute(0);

switch (constant_attribute.type()) {
Expand Down Expand Up @@ -836,11 +836,17 @@ common::Status ConstantNodeProtoToTensorProto(const ONNX_NAMESPACE::NodeProto& n
}

// set name last in case attribute type was tensor (would copy over name)
*(tensor.mutable_name()) = node.output(0);
*(tensor.mutable_name()) = tensor_name;

return Status::OK();
}

common::Status ConstantNodeProtoToTensorProto(const ONNX_NAMESPACE::NodeProto& node,
const Path& model_path,
ONNX_NAMESPACE::TensorProto& tensor) {
return ConstantNodeProtoToTensorProto(node, model_path, tensor, node.output(0));
}

#if !defined(DISABLE_SPARSE_TENSORS)
static Status CopySparseData(size_t n_sparse_elements,
const ONNX_NAMESPACE::TensorProto& indices,
Expand Down
5 changes: 5 additions & 0 deletions onnxruntime/core/framework/tensorprotoutils.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ common::Status GetSizeInBytesFromTensorProto(const ONNX_NAMESPACE::TensorProto&
// However if AttributeProto contains SparseTensorProto then it converts the data into dense tensor proto
// (including loading external data when applicable).
// model_path is used for contructing full path for external_data
// tensor_name specifies the name for the new TensorProto TensorProto
common::Status ConstantNodeProtoToTensorProto(const ONNX_NAMESPACE::NodeProto& node,
const Path& model_path,
ONNX_NAMESPACE::TensorProto& tensor, const std::string& tensor_name);

common::Status ConstantNodeProtoToTensorProto(const ONNX_NAMESPACE::NodeProto& node,
const Path& model_path,
ONNX_NAMESPACE::TensorProto& tensor);
Expand Down
Loading

0 comments on commit ec63d10

Please sign in to comment.