From 1fe6bae43e62eb91b307ea7875af162721898e3a Mon Sep 17 00:00:00 2001 From: Jian Chen Date: Thu, 24 Oct 2024 13:13:23 -0400 Subject: [PATCH 01/11] Adding new Python package testing pipeline for CUda Alt --- .../py-cuda-alt-package-test-pipeline.yml | 24 +++++++++++++++++++ .../py-cuda-package-test-pipeline.yml | 2 +- .../py-package-test-pipeline.yml | 16 ------------- 3 files changed, 25 insertions(+), 17 deletions(-) create mode 100644 tools/ci_build/github/azure-pipelines/py-cuda-alt-package-test-pipeline.yml diff --git a/tools/ci_build/github/azure-pipelines/py-cuda-alt-package-test-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-cuda-alt-package-test-pipeline.yml new file mode 100644 index 0000000000000..0a8fe2f50a29f --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/py-cuda-alt-package-test-pipeline.yml @@ -0,0 +1,24 @@ +resources: + pipelines: + - pipeline: build + source: 'Python CUDA ALT Packaging Pipeline' + trigger: true + branch: main # branch to pick the artifact, Used only for manual triggered pipeline runs for testing the pipeline itself + +stages: + # ****The following Stage depend on all previous tags. *** + # GPU resources are very limited, + # To utilize gpu resource more efficiently, run GPU job only after all cpus jobs succeed + - stage: Linux_Test_CUDA_Alt_x86_64_stage + dependsOn: + jobs: + - template: templates/py-packaging-linux-test-cuda.yml + parameters: + arch: 'x86_64' + machine_pool: 'Onnxruntime-Linux-GPU' + python_wheel_suffix: '_gpu' + timeout: 480 + docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20241020.1 + trt_version: '10.4.0.26-1.cuda11.8' + cuda_version: '11.8' + diff --git a/tools/ci_build/github/azure-pipelines/py-cuda-package-test-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-cuda-package-test-pipeline.yml index e946fedd07a27..5094c56956978 100644 --- a/tools/ci_build/github/azure-pipelines/py-cuda-package-test-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-cuda-package-test-pipeline.yml @@ -9,7 +9,7 @@ stages: # ****The following Stage depend on all previous tags. *** # GPU resources are very limited, # To utilize gpu resource more efficiently, run GPU job only after all cpus jobs succeed - - stage: Linux_Test_GPU_x86_64_stage + - stage: Linux_Test_CUDA_x86_64_stage dependsOn: jobs: - template: templates/py-packaging-linux-test-cuda.yml diff --git a/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml index c458f0cf4bfe2..d85b725181181 100644 --- a/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml @@ -42,22 +42,6 @@ stages: # GPU resources are very limited, # To utilize gpu resource more efficiently, run GPU job only after all cpus jobs succeed -- stage: Linux_Test_GPU_x86_64_stage - dependsOn: - - Linux_Test_CPU_x86_64_stage - - Linux_Test_CPU_aarch64_stage - - Packages_Somking_Test - jobs: - - template: templates/py-packaging-linux-test-cuda.yml - parameters: - arch: 'x86_64' - machine_pool: 'Onnxruntime-Linux-GPU' - python_wheel_suffix: '_gpu' - timeout: 480 - docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20241020.1 - trt_version: '10.4.0.26-1.cuda11.8' - cuda_version: '11.8' - # if final job not extecuted, it will not run nightlly build - stage: Final From e97374203bb555872120b1370f7b3c72c4904743 Mon Sep 17 00:00:00 2001 From: Jian Chen Date: Thu, 24 Oct 2024 13:17:43 -0400 Subject: [PATCH 02/11] Remove Linux_Test_GPU_x86_64_stage from stage final --- .../ci_build/github/azure-pipelines/py-package-test-pipeline.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml index d85b725181181..622bf5bc8387a 100644 --- a/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml @@ -46,7 +46,6 @@ stages: # if final job not extecuted, it will not run nightlly build - stage: Final dependsOn: - - Linux_Test_GPU_x86_64_stage jobs: - job: Final # Run this step only if all previous steps are succeeded and (this build was triggered by a resource trigger or it was triggered by another build). From 2a5be8ca7a02a165efdc9719563ff3cba81d6536 Mon Sep 17 00:00:00 2001 From: Jian Chen Date: Thu, 24 Oct 2024 13:20:57 -0400 Subject: [PATCH 03/11] Adding dependencies --- .../github/azure-pipelines/py-package-test-pipeline.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml index 622bf5bc8387a..a0e49692220f9 100644 --- a/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml @@ -46,6 +46,9 @@ stages: # if final job not extecuted, it will not run nightlly build - stage: Final dependsOn: + - Linux_Test_CPU_x86_64_stage + - Linux_Test_CPU_aarch64_stage + - Packages_Somking_Test jobs: - job: Final # Run this step only if all previous steps are succeeded and (this build was triggered by a resource trigger or it was triggered by another build). From c5b6be045ff58b706390bd7504e7d865e451a689 Mon Sep 17 00:00:00 2001 From: Frank Dong <123416088+frank-dong-ms@users.noreply.github.com> Date: Thu, 24 Oct 2024 22:24:48 -0700 Subject: [PATCH 04/11] enable serialize prepacked weights into data file (#22256) ### Description part of https://github.com/microsoft/onnxruntime/issues/21448 This change is intend to save CPU memory during model load for inference. Added session option save_prepacked_constant_initializers, with save_prepacked_constant_initializers turn on: 1. optimize model with inference session, prepacked external initializer will be saved into data file. 2. load optimized model and external data file with prepacked initializer, no prepack is needed 3. run inference with optimized model and data file Tested with model Phi-3-mini-instruct-onnx, with ORT 1.12.0: ![image](https://github.com/user-attachments/assets/3c0337be-f340-4bb7-8f9f-30f3552072ef) with this change: ![image](https://github.com/user-attachments/assets/23282990-2e1e-4a1f-92de-afa8ed7e6a43) Peak memory usage dropped from **5.438 GB to 2.726GB**. This change takes advantage of ORT loads external initializer with mmap on CPU. Prepack will use extra memory on heap, omit prepack process can save this part of memory (roughly same size as external initializers). next step: Change all the kernels on CPU with PrePack method implemented and test properly. Will do in next PR. ### Motivation and Context --- .../onnxruntime/core/framework/op_kernel.h | 22 +++ include/onnxruntime/core/graph/graph.h | 29 ++- .../onnxruntime_session_options_config_keys.h | 6 + onnxruntime/contrib_ops/cpu/bert/attention.cc | 2 + .../cpu/quantization/attention_quant.cc | 2 + .../cpu/quantization/dynamic_quantize_lstm.cc | 3 +- .../cpu/quantization/matmul_nbits.cc | 56 ++++++ .../contrib_ops/cpu/skip_layer_norm.cc | 1 + onnxruntime/contrib_ops/cpu/skip_layer_norm.h | 2 +- .../contrib_ops/cuda/diffusion/group_norm.cc | 1 + .../contrib_ops/cuda/diffusion/group_norm.h | 1 + .../qordered_ops/qordered_attention.cc | 1 + .../qordered_ops/qordered_attention.h | 1 + .../qordered_ops/qordered_matmul.cc | 1 + .../qordered_ops/qordered_matmul.h | 1 + onnxruntime/core/framework/session_options.h | 6 + onnxruntime/core/framework/session_state.cc | 85 +++++++-- onnxruntime/core/framework/session_state.h | 33 +++- .../core/framework/session_state_utils.cc | 13 +- .../core/framework/session_state_utils.h | 4 +- .../framework/tensor_external_data_info.cc | 2 + .../framework/tensor_external_data_info.h | 3 + .../core/framework/tensorprotoutils.cc | 29 ++- onnxruntime/core/framework/tensorprotoutils.h | 12 +- onnxruntime/core/framework/utils.cc | 6 + onnxruntime/core/framework/utils.h | 2 + onnxruntime/core/graph/graph.cc | 175 ++++++++++++------ onnxruntime/core/graph/model.cc | 29 ++- onnxruntime/core/graph/model.h | 24 ++- .../core/providers/cpu/fp16/fp16_conv.cc | 2 + onnxruntime/core/providers/cpu/math/gemm.cc | 3 +- onnxruntime/core/providers/cpu/math/gemm.h | 1 + onnxruntime/core/providers/cpu/math/matmul.cc | 1 + onnxruntime/core/providers/cpu/math/matmul.h | 1 + .../core/providers/cpu/nn/conv_transpose.cc | 2 + .../core/providers/cpu/nn/conv_transpose.h | 1 + .../core/providers/cpu/nn/layer_norm_impl.cc | 1 + .../core/providers/cpu/nn/layer_norm_impl.h | 2 +- .../cpu/quantization/matmul_integer_base.h | 1 + .../providers/cpu/quantization/qlinearconv.cc | 2 + .../core/providers/cpu/rnn/deep_cpu_gru.cc | 1 + .../core/providers/cpu/rnn/deep_cpu_gru.h | 3 +- .../core/providers/cpu/rnn/deep_cpu_lstm.cc | 4 +- .../core/providers/cpu/rnn/deep_cpu_lstm.h | 1 + onnxruntime/core/providers/cuda/nn/conv.cc | 1 + onnxruntime/core/providers/cuda/nn/conv.h | 1 + .../core/providers/cuda/nn/conv_transpose.cc | 3 +- .../core/providers/cuda/nn/conv_transpose.h | 1 + .../core/providers/js/operators/conv.h | 1 + .../providers/js/operators/conv_transpose.h | 2 + .../core/providers/xnnpack/math/gemm.cc | 1 + .../core/providers/xnnpack/math/gemm.h | 1 + .../core/providers/xnnpack/math/matmul.cc | 1 + .../core/providers/xnnpack/math/matmul.h | 1 + onnxruntime/core/providers/xnnpack/nn/conv.cc | 1 + onnxruntime/core/providers/xnnpack/nn/conv.h | 1 + .../providers/xnnpack/nn/conv_transpose.cc | 1 + .../providers/xnnpack/nn/conv_transpose.h | 1 + onnxruntime/core/session/inference_session.cc | 40 +++- .../test/framework/inference_session_test.cc | 54 ++++++ .../save_model_with_external_initializers.cc | 59 +++++- .../test/framework/session_state_test.cc | 66 ++++++- onnxruntime/test/shared_lib/test_inference.cc | 83 +++++++++ .../model_with_external_initializers.onnx | 9 +- .../model_with_external_initializers.py | 3 +- .../testdata/model_with_orig_ext_data.onnx | 9 +- .../test/testdata/prepack/MatMul.Weight.bin | Bin 0 -> 8 bytes ...xternal_initializers_and_prepack_kernel.py | 88 +++++++++ .../prepack/model_with_matmul_nbits.onnx | Bin 0 -> 333 bytes orttraining/orttraining/models/bert/main.cc | 1 + .../orttraining/models/pipeline_poc/main.cc | 1 + .../models/runner/training_runner.cc | 1 + 72 files changed, 872 insertions(+), 137 deletions(-) create mode 100644 onnxruntime/test/testdata/prepack/MatMul.Weight.bin create mode 100644 onnxruntime/test/testdata/prepack/model_with_external_initializers_and_prepack_kernel.py create mode 100644 onnxruntime/test/testdata/prepack/model_with_matmul_nbits.onnx diff --git a/include/onnxruntime/core/framework/op_kernel.h b/include/onnxruntime/core/framework/op_kernel.h index 07625c38d8474..a17da2a19bb99 100644 --- a/include/onnxruntime/core/framework/op_kernel.h +++ b/include/onnxruntime/core/framework/op_kernel.h @@ -79,6 +79,7 @@ class OpKernel { // the allocator tied to the session if the kernel owns the pre-packed buffer or an // allocator shared between sessions if the pre-packed buffer is to be shared across sessions // (i.e.) the kernel does not own the buffer. + // @param save_prepacked_initializers: Set it to true if intend to save prepacked initializers to external data file. // @param is_packed: Set it to true if the kernel packed the tensor or to false // The kernel is responsible for keeping the packed data and related metadata if is_packed is true, // and the original initialized constant tensor will be released and not accessible anymore in @@ -88,6 +89,7 @@ class OpKernel { virtual Status PrePack(const Tensor& /*tensor*/, int /*input_idx*/, AllocatorPtr /*alloc*/, + bool, /*save_prepacked_initializers*/ /*out*/ bool& is_packed, /*out*/ PrePackedWeights* /*prepacked_weights*/) { is_packed = false; return Status::OK(); @@ -129,6 +131,26 @@ class OpKernel { return Status::OK(); } + // Override this function to get pre-packed tensors from this kernel. + // Only useful for models run on PC with CPU so ORT could load prepacked weights directly from + // ONNX data file with mmap and no need to do prepacking on fly to save a lot of heap memory. + // @param input_idx : The index of input we prepacked before and intend to get packed tensor back. + // Please refer to matmul_nbits kernel for a complete example. + virtual std::optional GetPrePackTensor(int /*input_idx*/) { + return std::nullopt; + } + + // Override this function to set pre-packed tensors to this kernel and restore prepacked weight buffer. + // Only useful for models run on PC with CPU so ORT could load prepacked weights directly from + // ONNX data file with mmap and no need to do prepacking on fly to save a lot of heap memory. + // Please refer to matmul_nbits kernel for a complete example. + // @param input_idx : The input index of the tensor in this kernel. + // @param pre_packed_tensor: The prepacked tensor read from onnx data file and use the prepacked tensor + // to restore prepacked weight buffer. + virtual Status SetPrePackTensor(int /*input_idx*/, const Tensor& /*pre_packed_tensor*/) { + return Status::OK(); + } + const OrtDevice GetDevice(OrtMemType mem_type) const; const OpKernelInfo& Info() const { return *op_kernel_info_; diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index eb9581e8018d1..69af3c93d7a07 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -1148,6 +1148,11 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi void FinalizeFuseSubGraph(const IndexedSubGraph& sub_graph, Node& fused_node); #endif + // Since one constant initializer could be used by different kernels + // and prepacked differently, use an unordered_map to store prepacked + // initializer in format of <[initializer_name], <[node_name], [prepacked_initializer]>> + typedef std::unordered_map> PrePackedTensorProtoToSave; + #if !defined(ORT_MINIMAL_BUILD) /** Gets the GraphProto representation of this Graph. */ const ONNX_NAMESPACE::GraphProto& ToGraphProto(); @@ -1182,18 +1187,26 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi @param initializer_size_threshold initializers larger or equal to this threshold (in bytes) are saved in the external file. Initializer smaller than this threshold are included in the onnx file. @param align_info offset alignment info. + @param save_prepacked_constant_initializers whether to save prepacked initializer into external data file. + If set false to this boolean, prepacked initializer will not be saved into onnxruntime data file, + we keep constant initializer as it is. + @param pre_packed_initializers struct used to store all the prepacked initializers. @returns GraphProto serialization of the graph. */ ONNX_NAMESPACE::GraphProto ToGraphProtoWithExternalInitializers(const std::filesystem::path& external_file_path, const std::filesystem::path& model_file_path, size_t initializer_size_threshold, - const OffsetAlignmentInfo& align_info) const; + const OffsetAlignmentInfo& align_info, + bool save_prepacked_constant_initializers, + PrePackedTensorProtoToSave& pre_packed_initializers) const; ONNX_NAMESPACE::GraphProto ToGraphProtoWithExternalInitializers(const std::filesystem::path& external_file_path, const std::filesystem::path& model_file_path, size_t initializer_size_threshold) const { OffsetAlignmentInfo default_options; - return ToGraphProtoWithExternalInitializers(external_file_path, model_file_path, initializer_size_threshold, default_options); + PrePackedTensorProtoToSave pre_packed_initializers; + return ToGraphProtoWithExternalInitializers(external_file_path, model_file_path, initializer_size_threshold, default_options, + false, pre_packed_initializers); } /** Gets the ISchemaRegistry instances being used with this Graph. */ @@ -1508,6 +1521,18 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi private: void InitializeStateFromModelFileGraphProto(); + // Private method used to setup external initializer properly during model save, + // this external initializer could be oroginal initializer or prepacked initializer. + static void SetUpExternalInitializer(const Graph::OffsetAlignmentInfo& align_info, + size_t tensor_bytes_size, + int64_t& external_offset, + std::ofstream& external_stream, + gsl::span raw_data, + ONNX_NAMESPACE::TensorProto& output_proto, + const std::filesystem::path& external_file_path, + const ONNX_NAMESPACE::TensorProto& initializer, + bool is_prepacked); + // Add node with specified . Node& AddNode(const ONNX_NAMESPACE::NodeProto& node_proto, const ArgNameToTypeMap& name_to_type); diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index 6a01602e634f8..086919913cbea 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -246,6 +246,12 @@ static const char* const kOrtSessionOptionsDisableCPUEPFallback = "session.disab static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersFileName = "session.optimized_model_external_initializers_file_name"; +// Use this config when save prepacked constant initializers to onnx external data file. +// Default is not save prepacked initializers to onnx data file. +// Sample usage: sess_options.add_session_config_entry('session.save_prepacked_constant_initializers', "1") +static const char* const kOrtSessionOptionsSavePrePackedConstantInitializers = + "session.save_prepacked_constant_initializers"; + // Use this config to control the minimum size of the initializer when externalizing it during serialization static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersMinSizeInBytes = "session.optimized_model_external_initializers_min_size_in_bytes"; diff --git a/onnxruntime/contrib_ops/cpu/bert/attention.cc b/onnxruntime/contrib_ops/cpu/bert/attention.cc index ad14fb8258656..b15e865aa423c 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention.cc @@ -30,6 +30,7 @@ class Attention : public OpKernel, public AttentionCPUBase { Status Compute(OpKernelContext* context) const override; Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; @@ -101,6 +102,7 @@ bool Attention::IsPackWeightsSuccessful(int qkv_index, template Status Attention::PrePack(const Tensor& weights, int input_idx, AllocatorPtr alloc, + bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { /* The PrePack() massages the weights to speed up Compute(), there is an option to diff --git a/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc b/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc index 2c897f183164f..71a66ea368943 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc @@ -24,6 +24,7 @@ class QAttention : public OpKernel, public AttentionCPUBase { Status Compute(OpKernelContext* context) const override; Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool save_prepacked_initializers, bool& /*out*/ is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; @@ -58,6 +59,7 @@ QAttention::QAttention(const OpKernelInfo& info) : OpKernel(info), AttentionC template Status QAttention::PrePack(const Tensor& weights, int input_idx, AllocatorPtr alloc, + bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { if (1 != input_idx) { diff --git a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_lstm.cc b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_lstm.cc index aa47f365c0005..4148aae4b9a35 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_lstm.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_lstm.cc @@ -13,7 +13,7 @@ class DynamicQuantizeLSTM : public OpKernel, public LSTMBase { DynamicQuantizeLSTM(const OpKernelInfo& info) : OpKernel(info), LSTMBase(info) {} Status PrePack(const Tensor& tensor, int input_idx, - AllocatorPtr alloc, /*out*/ bool& is_packed, + AllocatorPtr alloc, bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, @@ -91,6 +91,7 @@ static void UseSharedPrePackedBuffersImpl(std::vector& prepacke } Status DynamicQuantizeLSTM::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { is_packed = false; diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 89e96543c4729..cee3dfc6b3f28 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -98,12 +98,19 @@ class MatMulNBits final : public OpKernel { Status Compute(OpKernelContext* context) const override; Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; + void ConvertPrepackWeightIntoTensor(const onnxruntime::Tensor& tensor, int input_idx); + Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, int input_idx, /*out*/ bool& used_shared_buffers) override; + std::optional GetPrePackTensor(int /*input_idx*/) override; + + Status SetPrePackTensor(int input_idx, const Tensor& pre_packed_tensor) override; + private: const size_t K_; const size_t N_; @@ -119,6 +126,8 @@ class MatMulNBits final : public OpKernel { size_t packed_b_size_{0}; IAllocatorUniquePtr scales_fp32_{}; IAllocatorUniquePtr bias_fp32_{}; + std::optional packed_tensor_{std::nullopt}; + MLDataType prepack_tensor_data_type_; bool has_zp_input_{false}; @@ -148,8 +157,22 @@ class MatMulNBits final : public OpKernel { } }; +template +void MatMulNBits::ConvertPrepackWeightIntoTensor(const onnxruntime::Tensor& tensor, int input_idx) { + if (input_idx == InputIndex::B) { + prepack_tensor_data_type_ = tensor.DataType(); + } + + TensorShapeVector weights_dims = {static_cast((packed_b_size_ - 1) / prepack_tensor_data_type_->Size()) + 1}; + packed_tensor_ = Tensor(prepack_tensor_data_type_, + TensorShape(weights_dims), + packed_b_.get(), + OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator)); +} + template Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, + bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { ORT_UNUSED_PARAMETER(prepacked_weights); @@ -185,11 +208,16 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All #endif // MLAS_TARGET_AMD64_IX86 } + if (save_prepacked_initializers) { + ConvertPrepackWeightIntoTensor(tensor, input_idx); + } + return Status::OK(); } template <> Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, + bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { ORT_UNUSED_PARAMETER(prepacked_weights); @@ -239,6 +267,34 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*ou #endif // MLAS_TARGET_AMD64_IX86 } + if (save_prepacked_initializers) { + ConvertPrepackWeightIntoTensor(tensor, input_idx); + } + + return Status::OK(); +} + +template +std::optional MatMulNBits::GetPrePackTensor(int input_idx) { + // For this kernel, prepack is performed on input_B, and possibly scales, zeros_points. + // During compute process, scales and zeros_points will keep as it is and only use prepacked + // buffer to replace input_B. + // Inorder to cope with this logic, we need to return latest prepacked buffer and only serialize + // the latest one. So, we need to always return packed_tensor_ here not only for input_B. + ORT_UNUSED_PARAMETER(input_idx); + return std::move(packed_tensor_); +} + +template +Status MatMulNBits::SetPrePackTensor(int input_idx, const Tensor& pre_packed_tensor) { + if (input_idx == 1) { + // pre_packed_tensor is constant initialized tensor and its lifecycle is managed by session_state, + // session_state will release memory from pre_packed_tensor. packed_b_ will not release memory so + // pass empty/default buffer deleter here. + // const_cast here is temporary, will fix in follow up PR. + packed_b_ = BufferUniquePtr(const_cast(pre_packed_tensor.DataRaw()), BufferDeleter()); + } + return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc b/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc index 67b4950af73bf..c9ee9e2cb760d 100644 --- a/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc @@ -278,6 +278,7 @@ Status SkipLayerNorm::Compute(OpKernelContext* p_ctx) const { template Status SkipLayerNorm::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool /*save_prepacked_initializers*/, bool& is_packed, PrePackedWeights* prepacked_weights) { ORT_UNUSED_PARAMETER(prepacked_weights); diff --git a/onnxruntime/contrib_ops/cpu/skip_layer_norm.h b/onnxruntime/contrib_ops/cpu/skip_layer_norm.h index 08e2276c3d9d5..d904c14857437 100644 --- a/onnxruntime/contrib_ops/cpu/skip_layer_norm.h +++ b/onnxruntime/contrib_ops/cpu/skip_layer_norm.h @@ -16,7 +16,7 @@ class SkipLayerNorm final : public OpKernel { SkipLayerNorm(const OpKernelInfo& op_kernel_info); Status Compute(OpKernelContext* p_op_kernel_context) const override; - Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool save_prepacked_initializers, bool& is_packed, PrePackedWeights* prepacked_weights) override; private: diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc index dea5391c7629b..d190ed389f3e9 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc @@ -95,6 +95,7 @@ GroupNorm::GroupNorm(const OpKernelInfo& op_info) : CudaKernel(op_info) { } Status GroupNorm::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr /*alloc*/, + bool /*save_prepacked_initializers*/, bool& is_packed, PrePackedWeights* /*prepacked_weights*/) { is_packed = false; diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h index b408b3c1ee79b..4505c066baedb 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h @@ -17,6 +17,7 @@ class GroupNorm final : public CudaKernel { Status ComputeInternal(OpKernelContext* context) const override; Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool save_prepacked_initializers, bool& is_packed, PrePackedWeights* prepacked_weights) override; private: diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc index 3e93a527877c5..aa2c8755f6536 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc @@ -99,6 +99,7 @@ Status QOrderedAttention::PutIntoMergedBias(const Tensor& tensor, AllocatorPtr a } Status QOrderedAttention::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, + bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* /*prepacked_weights*/) { is_packed = false; diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.h b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.h index 9d4e563c1feab..529fd00307d66 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.h +++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.h @@ -20,6 +20,7 @@ class QOrderedAttention final : public CudaKernel, public AttentionBase { public: Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_matmul.cc b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_matmul.cc index a64f628f245e6..351e36b884540 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_matmul.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_matmul.cc @@ -51,6 +51,7 @@ QOrderedMatMul::QOrderedMatMul(const OpKernelInfo& info) : CudaKernel(info) { } Status QOrderedMatMul::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* /* prepacked_weights */) { is_packed = false; diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_matmul.h b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_matmul.h index dcb6cc6374be1..d1cef99779e09 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_matmul.h +++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_matmul.h @@ -18,6 +18,7 @@ class QOrderedMatMul final : public CudaKernel { Status ComputeInternal(OpKernelContext* context) const override; Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; diff --git a/onnxruntime/core/framework/session_options.h b/onnxruntime/core/framework/session_options.h index 8d4db36106f28..18405231750ba 100644 --- a/onnxruntime/core/framework/session_options.h +++ b/onnxruntime/core/framework/session_options.h @@ -83,6 +83,11 @@ struct SessionOptions { // enable profiling for this session. bool enable_profiling = false; + // save pre-packed constant external initializers instead of original initializers to onnxruntime data file. + // Only useful for models run on PC with CPU so ORT could load prepacked weights directly from + // ONNX data file with mmap and no need to do prepacking on fly to save a lot of heap memory. + bool save_prepacked_constant_initializers = false; + // Non empty filepath enables serialization of the transformed optimized model to the specified filepath. // // Set session config value for ORT_SESSION_OPTIONS_CONFIG_SAVE_MODEL_FORMAT to 'ORT' or 'ONNX' to explicitly @@ -191,6 +196,7 @@ inline std::ostream& operator<<(std::ostream& os, const SessionOptions& session_ << " execution_mode:" << session_options.execution_mode << " execution_order:" << session_options.execution_order << " enable_profiling:" << session_options.enable_profiling + << " save_prepacked_constant_initializers:" << session_options.save_prepacked_constant_initializers << " optimized_model_filepath:" << ORT_TSTR_CONVERT_TO_PRINTABLE_STRING(session_options.optimized_model_filepath) << " enable_mem_pattern:" << session_options.enable_mem_pattern << " enable_mem_reuse:" << session_options.enable_mem_reuse diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index 0d0b22ff61e01..943db091b341f 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -14,6 +14,7 @@ #include "core/framework/op_kernel.h" #include "core/framework/ort_value_pattern_planner.h" #include "core/framework/session_state_utils.h" +#include "core/framework/tensorprotoutils.h" #include "core/framework/utils.h" #include "core/providers/cpu/controlflow/utils.h" #include "core/session/onnxruntime_session_options_config_keys.h" @@ -397,12 +398,18 @@ static std::string GenerateKeyForPrepackedWeightsMap(const std::string& op_type, } Status SessionState::PrepackConstantInitializedTensors(InlinedHashMap& constant_initializers_use_count, - const std::unordered_map& initializers_to_share_map) { - auto prepacked_constant_weights = [this, &constant_initializers_use_count, &initializers_to_share_map]( + const std::unordered_map& initializers_to_share_map, + bool save_prepacked_constant_initializers, + PrePackInitializers& pre_packed_initializers) { + auto prepacked_constant_weights = [this, &constant_initializers_use_count, &initializers_to_share_map, + save_prepacked_constant_initializers, &pre_packed_initializers]( bool should_cache_prepacked_weights_for_shared_initializers) -> Status { + std::unordered_map pre_packed_kernel_input_map; for (auto& node : GetGraphViewer().Nodes()) { auto kernel = GetMutableKernel(node.Index()); + auto kernel_name = kernel->Info().node().Name(); int input_idx = 0; + bool is_kernel_prepacked = false; for (auto& input_def : node.InputDefs()) { if (input_def->Exists()) { const std::string& input_name = input_def->Name(); @@ -414,16 +421,27 @@ Status SessionState::PrepackConstantInitializedTensors(InlinedHashMapGetOrtValueNameIdxMap().GetIdx(input_name, ort_value_idx).IsOK()) { std::unordered_map& constant_initialized_tensors = st->constant_initialized_tensors_; - if (constant_initialized_tensors.count(ort_value_idx)) { + if (constant_initialized_tensors.count(ort_value_idx) && !is_kernel_prepacked) { bool is_packed = false; const Tensor& const_initialized_tensor = constant_initialized_tensors[ort_value_idx].Get(); auto iter = initializers_to_share_map.find(input_name); bool is_shared_initializer = (iter != initializers_to_share_map.end()); + // found pre-packed constant initializers from data file, no need to do pre-packing again + // apply pre-packed tensor to kernel so kernel can use it directly + if (pre_packed_initializers.pre_packed_initializer_names_read_from_file.count(input_name) != 0) { + is_packed = true; + + // kernel like Matmul_nbits will call prepack multiple times with input_B and possibly scales/zero_points. + // If prepacked weights already read from ONNX data file (this happens we ORT reads data file with prepacked + // weights serialized), only need to set prepacked weights once to kernel. + is_kernel_prepacked = true; + ORT_THROW_IF_ERROR(kernel->SetPrePackTensor(input_idx, const_initialized_tensor)); + } // Caching pre-packed weights is limited to shared initializers associated with the CPU EP for now - if (is_shared_initializer && should_cache_prepacked_weights_for_shared_initializers && - node.GetExecutionProviderType() == kCpuExecutionProvider) { // caching of pre-packed weights' turned ON + else if (is_shared_initializer && should_cache_prepacked_weights_for_shared_initializers && + node.GetExecutionProviderType() == kCpuExecutionProvider) { // caching of pre-packed weights' turned ON AllocatorPtr allocator_for_caching = prepacked_weights_container_->GetOrCreateAllocator(CPU); ORT_ENFORCE(allocator_for_caching.get() != nullptr); @@ -435,7 +453,7 @@ Status SessionState::PrepackConstantInitializedTensors(InlinedHashMapPrePack(const_initialized_tensor, input_idx, allocator_for_caching, - is_packed, + save_prepacked_constant_initializers, is_packed, &weights_to_be_filled_in)); if (is_packed) { @@ -482,18 +500,50 @@ Status SessionState::PrepackConstantInitializedTensors(InlinedHashMapInfo().GetDevice(OrtMemType::OrtMemTypeDefault)); ORT_RETURN_IF_ERROR(kernel->PrePack(const_initialized_tensor, input_idx, session_cpu_alloc, // use allocator tied to this session + save_prepacked_constant_initializers, is_packed, nullptr // no caching required )); } if (is_packed) { + // if intended to save prepacked initializers, get prepacked tensors from kernel and save in hashmap, + // will save to data file later + if (save_prepacked_constant_initializers) { + auto tensor = kernel->GetPrePackTensor(input_idx); + + if (tensor != std::nullopt) { + // save prepacked initializers per initializer and kernel since one initializer could + // be used by multiple kernels + pre_packed_initializers.pre_packed_initializers_to_save[input_name][kernel_name] = std::move(tensor.value()); + + pre_packed_kernel_input_map[kernel_name] = input_name; + } + } + ++number_of_prepacks_counter_; - if (constant_initializers_use_count.count(input_name) && --constant_initializers_use_count[input_name] == 0) { + // if constant_initialized_tensor is already pre-packed, don't need to remove it + if (pre_packed_initializers.pre_packed_initializer_names_read_from_file.count(input_name) == 0 && + constant_initializers_use_count.count(input_name) && --constant_initializers_use_count[input_name] == 0) { // release the constant initialized tensor st->initialized_tensors_.erase(ort_value_idx); constant_initialized_tensors.erase(ort_value_idx); } + } else { + // handle prepack for matmul_nbits, it will prepack several times but set is_packed + // to false for scales and zero_points, we keep scales and zero_points as it is only + // update packed_tensor to input_B. + // TODO: this logic works with matmul_nbits kernel but if other kernels also call prepack + // multiple times and use different initializers to store prepacked weights, this piece of logic + // might introduce bug and need a per kernel strategy to update prepacked weights. + if (save_prepacked_constant_initializers && pre_packed_kernel_input_map.count(kernel_name)) { + auto tensor = kernel->GetPrePackTensor(input_idx); + + if (tensor != std::nullopt) { + auto existing_input_name = pre_packed_kernel_input_map[kernel_name]; + pre_packed_initializers.pre_packed_initializers_to_save[existing_input_name][kernel_name] = std::move(tensor.value()); + } + } } } // stop searching in 2 cases: @@ -1176,6 +1226,7 @@ static Status VerifyEachNodeIsAssignedToAnEp(const Graph& graph, const logging:: Status SessionState::FinalizeSessionState(const std::basic_string& graph_location, const KernelRegistryManager& kernel_registry_manager, + PrePackInitializers& pre_packed_initializers, bool remove_initializers, bool saving_ort_format) { // recursively create the subgraph session state instances and populate the kernel create info in them. @@ -1189,7 +1240,7 @@ Status SessionState::FinalizeSessionState(const std::basic_string constant_initializers_use_count; ComputeConstantInitializerUseCount(graph_, constant_initializers_use_count); return FinalizeSessionStateImpl(graph_location, kernel_registry_manager, nullptr, sess_options_, - remove_initializers, constant_initializers_use_count); + remove_initializers, constant_initializers_use_count, pre_packed_initializers); } static Status Index(const OrtValueNameIdxMap& ort_value_name_idx_map, @@ -1323,6 +1374,7 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string& constant_initializers_use_count, + PrePackInitializers& pre_packed_initializers, const InlinedHashMap& outer_scope_node_arg_to_location_map, bool graph_info_already_created) { if (!graph_info_already_created) { @@ -1422,6 +1474,8 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string> + typedef std::unordered_map> PrePackedTensorsToSave; + PrePackedTensorsToSave pre_packed_initializers_to_save; + + // This set is used during model load with prepacked initializer serialized in external data file. + // ORT reads prepacked initializers and store their name into this set so we could skip PrePack + // process later to save heap memory. Prepacked tensor itself is saved in session state's constant_initialized_tensors_. + typedef std::unordered_set PrePackedTensorNamesReadFromFile; + PrePackedTensorNamesReadFromFile pre_packed_initializer_names_read_from_file; + }; + Status FinalizeSessionState(const std::basic_string& graph_loc, const KernelRegistryManager& kernel_registry_manager, + PrePackInitializers& pre_packed_initializers, bool remove_initializers = true, bool saving_ort_format = false); @@ -321,6 +338,15 @@ class SessionState { return parent_; } + Status FinalizeSessionState(const std::basic_string& graph_loc, + const KernelRegistryManager& kernel_registry_manager, + bool remove_initializers = true, + bool saving_ort_format = false) { + PrePackInitializers pre_packed_initializers; + return FinalizeSessionState(graph_loc, kernel_registry_manager, pre_packed_initializers, + remove_initializers, saving_ort_format); + } + // Clear all removable attributes if they exists. // The function logs the list of removable attributes for every node. void PruneRemovableAttributes(); @@ -380,9 +406,13 @@ class SessionState { /** * Prepack the constant initialized tensors for better performance. * The original constant initialized tensors will be removed to save memory. + * For model with prepacked initializer serialized into ONNX data file, + * PrePack will be skipped to save memory. */ Status PrepackConstantInitializedTensors(InlinedHashMap& constant_initializers_use_count, - const std::unordered_map& initializers_to_share_map); + const std::unordered_map& initializers_to_share_map, + bool save_prepacked_constant_initializers, + PrePackInitializers& pre_packed_initializers); SessionState* GetMutableSubgraphSessionState(onnxruntime::NodeIndex index, const std::string& attribute_name); @@ -400,6 +430,7 @@ class SessionState { const SessionOptions& session_options, bool remove_initializers, InlinedHashMap& constant_initializers_use_count, + PrePackInitializers& pre_packed_initializers, const InlinedHashMap& outer_scope_node_arg_to_location_map = {}, bool graph_info_already_created = false); diff --git a/onnxruntime/core/framework/session_state_utils.cc b/onnxruntime/core/framework/session_state_utils.cc index 2c74805c57dce..3424f40e79c01 100644 --- a/onnxruntime/core/framework/session_state_utils.cc +++ b/onnxruntime/core/framework/session_state_utils.cc @@ -21,7 +21,6 @@ #include "core/framework/ort_value_pattern_planner.h" #include "core/framework/ort_value_name_idx_map.h" #include "core/framework/sequential_execution_plan.h" -#include "core/framework/session_state.h" #include "core/framework/tensorprotoutils.h" #include "core/framework/utils.h" #include "core/framework/bfc_arena.h" @@ -72,6 +71,7 @@ static inline common::Status ExtDataTensorProtoToTensor(const Env& env, const std::basic_string& proto_path, const ONNX_NAMESPACE::TensorProto& tensor_proto, Tensor& tensor, OrtCallback& ext_data_deleter, + SessionState::PrePackInitializers::PrePackedTensorNamesReadFromFile& pre_packed_initializers_name_set, Tensor* buffered_tensor = nullptr) { ORT_ENFORCE(utils::HasExternalData(tensor_proto)); @@ -79,7 +79,7 @@ static inline common::Status ExtDataTensorProtoToTensor(const Env& env, SafeInt ext_data_len = 0; ORT_RETURN_IF_ERROR(utils::GetExtDataFromTensorProto(env, proto_path.c_str(), tensor_proto, ext_data_buf, ext_data_len, ext_data_deleter, - buffered_tensor)); + &pre_packed_initializers_name_set, buffered_tensor)); // NB: creating a do-nothing allocator per tensor is wasteful; can perhaps be // avoided if the Tensor class implements the do-nothing behavior when given a @@ -100,6 +100,7 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st const AllocatorPtr& alloc, const AllocatorPtr& default_cpu_alloc, OrtValue& ort_value, const DataTransferManager& data_transfer_mgr, const ExternalDataLoaderManager& external_data_loader_mgr, + SessionState::PrePackInitializers::PrePackedTensorNamesReadFromFile& pre_packed_initializers_name_set, bool use_device_allocator_for_initializers = false, Tensor* buffered_tensor = nullptr) { if (bool(alloc) == (m != nullptr)) { @@ -139,7 +140,7 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st // TensorProtoToTensor it would copy the data, causing unnecessary overhead OrtCallback ext_data_deleter; ORT_RETURN_IF_ERROR(ExtDataTensorProtoToTensor(env, proto_path, tensor_proto, *p_tensor, - ext_data_deleter, buffered_tensor)); + ext_data_deleter, pre_packed_initializers_name_set, buffered_tensor)); ExtDataValueDeleter deleter{ext_data_deleter, p_tensor.get()}; MLDataType ml_tensor_type = DataTypeImpl::GetType(); @@ -163,7 +164,7 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st OrtCallback ext_data_deleter; std::optional scoped_ort_callback_invoker; ORT_RETURN_IF_ERROR(ExtDataTensorProtoToTensor(env, proto_path, tensor_proto, *p_deserialize_tensor, - ext_data_deleter, buffered_tensor)); + ext_data_deleter, pre_packed_initializers_name_set, buffered_tensor)); scoped_ort_callback_invoker = ScopedOrtCallbackInvoker(ext_data_deleter); // TODO!! Need a temp buffer allocator for non-escape buffers that maybe too big for stack allocation. @@ -272,7 +273,8 @@ common::Status SaveInitializedTensors( const ExecutionPlanBase& exec_plan, const SessionOptions& session_options, const MemoryProfileFunction& memory_profile_func, - std::unordered_map>& buffered_tensors) { + std::unordered_map>& buffered_tensors, + SessionState::PrePackInitializers::PrePackedTensorNamesReadFromFile& pre_packed_initializers_name_set) { LOGS(logger, INFO) << "Saving initialized tensors."; ORT_ENFORCE(ort_value_name_idx_map.MaxIdx() > -1, "OrtValue indexes should have been populated."); @@ -401,6 +403,7 @@ common::Status SaveInitializedTensors( Status st = DeserializeTensorProto(env, graph_loc, tensor_proto, (m.has_value()) ? &*m : nullptr, alloc, default_cpu_alloc, ort_value, data_transfer_mgr, external_data_loader_mgr, + pre_packed_initializers_name_set, use_device_allocator_for_initializers, p_tensor); if (!st.IsOK()) { std::ostringstream oss; diff --git a/onnxruntime/core/framework/session_state_utils.h b/onnxruntime/core/framework/session_state_utils.h index af27f5caba0f4..4de501b6f7429 100644 --- a/onnxruntime/core/framework/session_state_utils.h +++ b/onnxruntime/core/framework/session_state_utils.h @@ -12,6 +12,7 @@ #include "core/framework/tensor.h" #include "core/framework/tensor_allocator.h" #include "core/framework/session_options.h" +#include "core/framework/session_state.h" #include "core/framework/sequential_execution_plan.h" #include "core/platform/path_lib.h" @@ -50,7 +51,8 @@ common::Status SaveInitializedTensors( const ExecutionPlanBase& exec_plan, const SessionOptions& session_options, const MemoryProfileFunction& memory_profile_func, - std::unordered_map>& buffered_tensors); + std::unordered_map>& buffered_tensors, + SessionState::PrePackInitializers::PrePackedTensorNamesReadFromFile& pre_packed_initializers_name_set); common::Status AllocateTensor( const onnxruntime::MemBuffer* m, diff --git a/onnxruntime/core/framework/tensor_external_data_info.cc b/onnxruntime/core/framework/tensor_external_data_info.cc index 93146e66d9f24..bcd04effe2bd4 100644 --- a/onnxruntime/core/framework/tensor_external_data_info.cc +++ b/onnxruntime/core/framework/tensor_external_data_info.cc @@ -40,6 +40,8 @@ Status ExternalDataInfo::Create(const RepeatedPtrField& return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "parsing ", stringmap.value(), " failed"); } else if (stringmap.key() == "checksum" && !stringmap.value().empty()) { out->checksum_ = stringmap.value(); + } else if (stringmap.key() == "prepacked" && !stringmap.value().empty()) { + out->prepacked_ = stringmap.value() == "1"; } else { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "model format error!"); } diff --git a/onnxruntime/core/framework/tensor_external_data_info.h b/onnxruntime/core/framework/tensor_external_data_info.h index afc8fda6c3037..c2490f5cc5bc2 100644 --- a/onnxruntime/core/framework/tensor_external_data_info.h +++ b/onnxruntime/core/framework/tensor_external_data_info.h @@ -23,6 +23,8 @@ class ExternalDataInfo { const std::string& GetChecksum() const { return checksum_; } + bool GetPrePacked() const noexcept { return prepacked_; } + // If the value of 'offset' or 'length' field is larger the max value of ssize_t, this function will treat it as a // wrong value and return FAIL. static common::Status Create( @@ -36,5 +38,6 @@ class ExternalDataInfo { // 0 means the whole file size_t length_ = 0; std::string checksum_; + bool prepacked_ = false; }; } // namespace onnxruntime diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index 2af9f95ad059e..0c69ee11f62bc 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -230,11 +230,12 @@ Status TensorProtoToOrtValueImpl(const Env& env, const std::filesystem::path& mo namespace utils { -Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto, - const std::filesystem::path& tensor_proto_dir, - std::basic_string& external_file_path, - onnxruntime::FileOffsetType& file_offset, - SafeInt& tensor_byte_size) { +static Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto, + const std::filesystem::path& tensor_proto_dir, + std::basic_string& external_file_path, + onnxruntime::FileOffsetType& file_offset, + SafeInt& tensor_byte_size, + bool& pre_packed) { ORT_RETURN_IF_NOT(onnxruntime::utils::HasExternalData(tensor_proto), "Tensor does not have external data to read from."); @@ -244,6 +245,8 @@ Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto, std::unique_ptr external_data_info; ORT_RETURN_IF_ERROR(onnxruntime::ExternalDataInfo::Create(tensor_proto.external_data(), external_data_info)); + pre_packed = external_data_info->GetPrePacked(); + const auto& location = external_data_info->GetRelPath(); external_file_path = location == onnxruntime::utils::kTensorProtoMemoryAddressTag ? std::filesystem::path(location) @@ -265,6 +268,11 @@ void SetRawDataInTensorProto(ONNX_NAMESPACE::TensorProto& tensor_proto, std::str tensor_proto.set_raw_data(std::move(param)); } +Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto, const std::filesystem::path& tensor_proto_dir, std::basic_string& external_file_path, onnxruntime::FileOffsetType& file_offset, SafeInt& tensor_byte_size) { + bool pre_packed = false; + return GetExternalDataInfo(tensor_proto, tensor_proto_dir, external_file_path, file_offset, tensor_byte_size, pre_packed); +} + void ConvertRawDataInTensorProto(TensorProto* tensor) { size_t element_size = 1; char* bytes = NULL; @@ -988,7 +996,7 @@ static Status GetFileContent(const Env& env, const std::filesystem::path& file_p Status GetExtDataFromTensorProto(const Env& env, const std::filesystem::path& model_path, const ONNX_NAMESPACE::TensorProto& tensor_proto, void*& ext_data_buf, SafeInt& ext_data_len, OrtCallback& ext_data_deleter, - Tensor* buffered_tensor) { + SessionState::PrePackInitializers::PrePackedTensorNamesReadFromFile* pre_packed_initializers_name_set, Tensor* buffered_tensor) { ORT_ENFORCE(utils::HasExternalData(tensor_proto)); std::basic_string tensor_proto_dir; if (!model_path.empty()) { @@ -997,8 +1005,13 @@ Status GetExtDataFromTensorProto(const Env& env, const std::filesystem::path& mo std::basic_string external_data_file_path; FileOffsetType file_offset; SafeInt raw_data_safe_len = 0; + bool pre_packed = false; ORT_RETURN_IF_ERROR( - GetExternalDataInfo(tensor_proto, tensor_proto_dir, external_data_file_path, file_offset, raw_data_safe_len)); + GetExternalDataInfo(tensor_proto, tensor_proto_dir, external_data_file_path, file_offset, raw_data_safe_len, pre_packed)); + + if (pre_packed && pre_packed_initializers_name_set != nullptr) { + (*pre_packed_initializers_name_set).insert(tensor_proto.name()); + } if (external_data_file_path == onnxruntime::utils::kTensorProtoMemoryAddressTag) { // the value in location is the memory address of the data @@ -1108,7 +1121,7 @@ Status TensorProtoToTensor(const Env& env, const std::filesystem::path& model_pa OrtCallback& d = deleter_for_file_data.d; if (utils::HasExternalData(tensor_proto)) { - ORT_RETURN_IF_ERROR(GetExtDataFromTensorProto(env, model_path, tensor_proto, raw_data, raw_data_len, d)); + ORT_RETURN_IF_ERROR(GetExtDataFromTensorProto(env, model_path, tensor_proto, raw_data, raw_data_len, d, nullptr)); } else if (utils::HasRawData(tensor_proto)) { raw_data = const_cast(tensor_proto.raw_data().data()); // TODO The line above has const-correctness issues. Below is a possible fix which copies the tensor_proto data diff --git a/onnxruntime/core/framework/tensorprotoutils.h b/onnxruntime/core/framework/tensorprotoutils.h index 262f7adaca1cb..770132f8e95fc 100644 --- a/onnxruntime/core/framework/tensorprotoutils.h +++ b/onnxruntime/core/framework/tensorprotoutils.h @@ -17,26 +17,19 @@ #include "core/framework/external_data_loader.h" #include "core/framework/ort_value.h" #include "core/framework/mem_buffer.h" +#include "core/framework/session_state.h" #include "core/framework/tensor_external_data_info.h" #include "core/graph/onnx_protobuf.h" #include "core/platform/env.h" namespace onnxruntime { namespace utils { -/** - * This function is used to get the external data info from the given tensor proto. - * @param tensor_proto given initializer tensor - * @param tensor_proto_dir directory of the tensor proto file - * @param external_file_path output external file path - * @param file_offset output tensor offset - * @param tensor_byte_size output tensor byte size - * @returns Status::OK() if the function is executed successfully - */ Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto, const std::filesystem::path& tensor_proto_dir, std::basic_string& external_file_path, onnxruntime::FileOffsetType& file_offset, SafeInt& tensor_byte_size); + /** * This function is used to convert the endianess of Tensor data. * Mostly, will be used in big endian system to support the model file @@ -172,6 +165,7 @@ common::Status GetExtDataFromTensorProto(const Env& env, const std::filesystem:: const ONNX_NAMESPACE::TensorProto& tensor_proto, void*& ext_data_buf, SafeInt& ext_data_len, OrtCallback& ext_data_deleter, + SessionState::PrePackInitializers::PrePackedTensorNamesReadFromFile* pre_packed_initializers_name_set, Tensor* buffered_tensor = nullptr); // Given a tensor proto with external data obtain a tensor using the specified custom external data loader. diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index 9eed0249711f9..5402345447706 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -1064,5 +1064,11 @@ bool IsOutputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index return false; } +std::string GetPrepackedInitializerName(const std::string& initializer_name, const std::string& node_name) { + const std::string seperator = ":"; + + return initializer_name + seperator + node_name; +} + } // namespace utils } // namespace onnxruntime diff --git a/onnxruntime/core/framework/utils.h b/onnxruntime/core/framework/utils.h index afdb5a2cb27f5..db38ef1675595 100644 --- a/onnxruntime/core/framework/utils.h +++ b/onnxruntime/core/framework/utils.h @@ -234,6 +234,8 @@ constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() { int32_t ONNXTensorElementDataTypeToProtoTensorType(ONNXTensorElementDataType); +std::string GetPrepackedInitializerName(const std::string& initializer_name, const std::string& node_name); + #ifdef ENABLE_TRAINING common::Status VerifyInputTensorsAllocatedContiguously(OpKernelContext* context); #endif diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index e8a5855b36496..3f50841f50913 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -4084,10 +4084,75 @@ ONNX_NAMESPACE::GraphProto Graph::ToGraphProto() const { return result; } +void Graph::SetUpExternalInitializer(const Graph::OffsetAlignmentInfo& align_info, + size_t tensor_bytes_size, + int64_t& external_offset, + std::ofstream& external_stream, + gsl::span raw_data, + ONNX_NAMESPACE::TensorProto& output_proto, + const std::filesystem::path& external_file_path, + const ONNX_NAMESPACE::TensorProto& initializer, + bool is_prepacked) { + // update external_offset for alignment + // need to do padding before write actual tensor data as we do offset alignment at the begin of + // large tensors (offset need to be page aligned and alloction granularity aligned) like below: + // \242\2557\256\023.\031&0000000000000000\332)k+\253\246\342\246(&\006!\347\232\374\236\325\026\032+\36XXXX + // |<---small tensor---->|<---padding--->|<------------------large tensor----------------------------->| + if (align_info.align_offset && static_cast(tensor_bytes_size) > align_info.align_threshold) { + // Align to the larger of the page size or the allocation granularity + int64_t alignment_factor = std::max(static_cast(4096), align_info.allocation_granularity); + // Align to the next page or alloc granularity boundary + int64_t new_external_offset = static_cast( + std::floor((external_offset + alignment_factor - 1) / alignment_factor)) * + alignment_factor; + + // padding tensor with zeros for alignment + InlinedVector paddings; + size_t padding_size = SafeInt(new_external_offset - external_offset); + paddings.reserve(padding_size); + for (size_t index = 0; index != padding_size; ++index) { + paddings.push_back(0x0); + } + external_stream.write(reinterpret_cast(paddings.data()), padding_size); + + external_offset = new_external_offset; + } + + external_stream.write(reinterpret_cast(raw_data.data()), tensor_bytes_size); + + output_proto.set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation::TensorProto_DataLocation_EXTERNAL); + ONNX_NAMESPACE::StringStringEntryProto* location = output_proto.add_external_data(); + location->set_key("location"); + location->set_value(ToUTF8String(external_file_path.native())); + ONNX_NAMESPACE::StringStringEntryProto* offset = output_proto.add_external_data(); + offset->set_key("offset"); + offset->set_value(std::to_string(external_offset)); + ONNX_NAMESPACE::StringStringEntryProto* length = output_proto.add_external_data(); + length->set_key("length"); + length->set_value(std::to_string(tensor_bytes_size)); + + if (is_prepacked) { + ONNX_NAMESPACE::StringStringEntryProto* pre_packed = output_proto.add_external_data(); + pre_packed->set_key("prepacked"); + pre_packed->set_value("1"); + } + + output_proto.set_name(initializer.name()); + output_proto.set_data_type(initializer.data_type()); + for (int i = 0; i != initializer.dims_size(); ++i) { + output_proto.add_dims(initializer.dims(i)); + } + output_proto.set_doc_string(initializer.doc_string()); + + external_offset += tensor_bytes_size; +} + ONNX_NAMESPACE::GraphProto Graph::ToGraphProtoWithExternalInitializers(const std::filesystem::path& external_file_path, const std::filesystem::path& model_file_path, size_t initializer_size_threshold, - const OffsetAlignmentInfo& align_info) const { + const OffsetAlignmentInfo& align_info, + bool save_prepacked_constant_initializers, + PrePackedTensorProtoToSave& pre_packed_initializers) const { GraphProto result; ToGraphProtoInternal(result); ORT_ENFORCE(external_file_path.is_relative()); @@ -4106,6 +4171,34 @@ ONNX_NAMESPACE::GraphProto Graph::ToGraphProtoWithExternalInitializers(const std #endif for (const auto& initializer : graph_proto_->initializer()) { + bool use_pre_packed_initializer = false; + InlinedVector pre_packed_initializers_tensor_proto; + // If this initializer has been prepacked, saved prepacked external initializer instead of original one. + // Since one initializer could be used by multiple kernels and been prepacked differently, + // Save each prepacked initializers seperately, chagne the initializer name to [initializer_name]:[node_name] + // to avoid conflict. Change the node input name accordingly. + // IT could potentially make the ONNX data file larger since we store multiple prepacked initializers into disk + // but this could be rare case. + if (save_prepacked_constant_initializers && pre_packed_initializers.count(initializer.name())) { + for (const auto& item : pre_packed_initializers[initializer.name()]) { + auto& node_name = item.first; + std::string prepacked_initializer_name = utils::GetPrepackedInitializerName(initializer.name(), node_name); + pre_packed_initializers_tensor_proto.push_back(item.second); + use_pre_packed_initializer = true; + + for (auto& node : *result.mutable_node()) { + if (node.name() == node_name) { + int input_index = 0; + for (const auto& input : node.input()) { + if (input == initializer.name()) { + node.set_input(input_index, prepacked_initializer_name); + } + input_index += 1; + } + } + } + } + } #if !defined(DISABLE_SPARSE_TENSORS) if (sparse_end != sparse_tensor_names_.find(initializer.name())) { // Sparse tensors are added to the ONNX file. @@ -4114,61 +4207,39 @@ ONNX_NAMESPACE::GraphProto Graph::ToGraphProtoWithExternalInitializers(const std ORT_ENFORCE(status.IsOK(), "Failed to convert dense initializer to sparse"); } else { #endif - // Dense tensors larger than the threshold are added to the external file. - TensorProto* output_proto = result.add_initializer(); - - std::vector raw_data; - ORT_THROW_IF_ERROR(utils::UnpackInitializerData(initializer, model_path, raw_data)); - size_t tensor_bytes_size = raw_data.size(); - if (tensor_bytes_size < initializer_size_threshold) { - *output_proto = initializer; - continue; - } + if (use_pre_packed_initializer) { + for (const auto& pre_packed_initializer : pre_packed_initializers_tensor_proto) { + // Dense tensors larger than the threshold are added to the external file. + TensorProto* output_proto = result.add_initializer(); + std::vector raw_data; + size_t tensor_bytes_size = 0; + + ORT_THROW_IF_ERROR(utils::UnpackInitializerData(pre_packed_initializer, model_path, raw_data)); + tensor_bytes_size = raw_data.size(); + if (tensor_bytes_size < initializer_size_threshold) { + *output_proto = pre_packed_initializer; + continue; + } - // update external_offset for alignment - // need to do padding before write actual tensor data as we do offset alignment at the begin of - // large tensors (offset need to be page aligned and alloction granularity aligned) like below: - // \242\2557\256\023.\031&0000000000000000\332)k+\253\246\342\246(&\006!\347\232\374\236\325\026\032+\36XXXX - // |<---small tensor---->|<---padding--->|<------------------large tensor----------------------------->| - if (align_info.align_offset && static_cast(tensor_bytes_size) > align_info.align_threshold) { - // Align to the larger of the page size or the allocation granularity - int64_t alignment_factor = std::max(static_cast(4096), align_info.allocation_granularity); - // Align to the next page or alloc granularity boundary - int64_t new_external_offset = static_cast( - std::floor((external_offset + alignment_factor - 1) / alignment_factor)) * - alignment_factor; - - // padding tensor with zeros for alignment - for (int64_t index = external_offset; index != new_external_offset; ++index) { - external_stream << '0'; + SetUpExternalInitializer(align_info, tensor_bytes_size, external_offset, external_stream, + raw_data, *output_proto, external_file_path, pre_packed_initializer, true); + } + } else { + // Dense tensors larger than the threshold are added to the external file. + TensorProto* output_proto = result.add_initializer(); + std::vector raw_data; + size_t tensor_bytes_size = 0; + + ORT_THROW_IF_ERROR(utils::UnpackInitializerData(initializer, model_path, raw_data)); + tensor_bytes_size = raw_data.size(); + if (tensor_bytes_size < initializer_size_threshold) { + *output_proto = initializer; + continue; } - external_offset = new_external_offset; - } - - for (size_t index = 0; index != tensor_bytes_size; ++index) { - external_stream << raw_data[index]; - } - - output_proto->set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation::TensorProto_DataLocation_EXTERNAL); - ONNX_NAMESPACE::StringStringEntryProto* location = output_proto->add_external_data(); - location->set_key("location"); - location->set_value(ToUTF8String(external_file_path.native())); - ONNX_NAMESPACE::StringStringEntryProto* offset = output_proto->add_external_data(); - offset->set_key("offset"); - offset->set_value(std::to_string(external_offset)); - ONNX_NAMESPACE::StringStringEntryProto* length = output_proto->add_external_data(); - length->set_key("length"); - length->set_value(std::to_string(tensor_bytes_size)); - - output_proto->set_name(initializer.name()); - output_proto->set_data_type(initializer.data_type()); - for (int i = 0; i != initializer.dims_size(); ++i) { - output_proto->add_dims(initializer.dims(i)); + SetUpExternalInitializer(align_info, tensor_bytes_size, external_offset, external_stream, + raw_data, *output_proto, external_file_path, initializer, false); } - output_proto->set_doc_string(initializer.doc_string()); - - external_offset += tensor_bytes_size; #if !defined(DISABLE_SPARSE_TENSORS) } #endif diff --git a/onnxruntime/core/graph/model.cc b/onnxruntime/core/graph/model.cc index 1bae63b510563..ad1ec9c8dedb3 100644 --- a/onnxruntime/core/graph/model.cc +++ b/onnxruntime/core/graph/model.cc @@ -384,13 +384,17 @@ ModelProto Model::ToProto() const { ModelProto Model::ToGraphProtoWithExternalInitializers(const std::filesystem::path& external_file_name, const std::filesystem::path& file_path, size_t initializer_size_threshold, - const Graph::OffsetAlignmentInfo& align_info) const { + const Graph::OffsetAlignmentInfo& align_info, + bool save_prepacked_constant_initializers, + Graph::PrePackedTensorProtoToSave& pre_packed_initializers) const { ModelProto result(model_proto_); const auto& graph = *graph_; *(result.mutable_graph()) = graph.ToGraphProtoWithExternalInitializers(external_file_name, file_path, initializer_size_threshold, - align_info); + align_info, + save_prepacked_constant_initializers, + pre_packed_initializers); return result; } @@ -608,7 +612,9 @@ static Status SaveModelWithExternalInitializers(Model& model, const T& file_path, const std::filesystem::path& external_file_name, size_t initializer_size_threshold, - const Graph::OffsetAlignmentInfo& align_info) { + const Graph::OffsetAlignmentInfo& align_info, + bool save_prepacked_constant_initializers, + Graph::PrePackedTensorProtoToSave& pre_packed_initializers) { int fd = 0; Status status = Env::Default().FileOpenWr(file_path, fd); ORT_RETURN_IF_ERROR(status); @@ -616,7 +622,8 @@ static Status SaveModelWithExternalInitializers(Model& model, ORT_TRY { status = Model::SaveWithExternalInitializers(model, fd, file_path, external_file_name, initializer_size_threshold, - align_info); + align_info, save_prepacked_constant_initializers, + pre_packed_initializers); } ORT_CATCH(const std::exception& ex) { ORT_HANDLE_EXCEPTION([&]() { @@ -647,9 +654,12 @@ Status Model::Load(const PathString& file_path, std::shared_ptr& p_model, Status Model::SaveWithExternalInitializers(Model& model, const std::filesystem::path& file_path, const std::filesystem::path& external_file_name, size_t initializer_size_threshold, - const Graph::OffsetAlignmentInfo& align_info) { + const Graph::OffsetAlignmentInfo& align_info, + bool save_prepacked_constant_initializers, + Graph::PrePackedTensorProtoToSave& pre_packed_initializers) { return SaveModelWithExternalInitializers(model, file_path, external_file_name, initializer_size_threshold, - align_info); + align_info, save_prepacked_constant_initializers, + pre_packed_initializers); } Status Model::LoadFromBytes(int count, const void* p_bytes, /*out*/ ONNX_NAMESPACE::ModelProto& model_proto) { @@ -766,7 +776,9 @@ Status Model::SaveWithExternalInitializers(Model& model, const std::filesystem::path& file_path, const std::filesystem::path& external_file_name, size_t initializer_size_threshold, - const Graph::OffsetAlignmentInfo& align_info) { + const Graph::OffsetAlignmentInfo& align_info, + bool save_prepacked_constant_initializers, + Graph::PrePackedTensorProtoToSave& pre_packed_initializers) { if (fd < 0) { return Status(ONNXRUNTIME, INVALID_ARGUMENT, " is less than 0."); } @@ -775,7 +787,8 @@ Status Model::SaveWithExternalInitializers(Model& model, auto model_proto = model.ToGraphProtoWithExternalInitializers(external_file_name, file_path, initializer_size_threshold, - align_info); + align_info, save_prepacked_constant_initializers, + pre_packed_initializers); google::protobuf::io::FileOutputStream output(fd); const bool result = model_proto.SerializeToZeroCopyStream(&output) && output.Flush(); if (result) { diff --git a/onnxruntime/core/graph/model.h b/onnxruntime/core/graph/model.h index 9bcec6f78ca08..38d9044ff9d31 100644 --- a/onnxruntime/core/graph/model.h +++ b/onnxruntime/core/graph/model.h @@ -191,13 +191,17 @@ class Model { ONNX_NAMESPACE::ModelProto ToGraphProtoWithExternalInitializers(const std::filesystem::path& external_file_name, const std::filesystem::path& file_path, size_t initializer_size_threshold, - const Graph::OffsetAlignmentInfo& align_info) const; + const Graph::OffsetAlignmentInfo& align_info, + bool save_prepacked_constant_initializers, + Graph::PrePackedTensorProtoToSave& pre_packed_initializers) const; ONNX_NAMESPACE::ModelProto ToGraphProtoWithExternalInitializers(const std::filesystem::path& external_file_name, const std::filesystem::path& file_path, size_t initializer_size_threshold) const { Graph::OffsetAlignmentInfo default_align_info; - return ToGraphProtoWithExternalInitializers(external_file_name, file_path, initializer_size_threshold, default_align_info); + Graph::PrePackedTensorProtoToSave pre_packed_initializers; + return ToGraphProtoWithExternalInitializers(external_file_name, file_path, initializer_size_threshold, default_align_info, + false, pre_packed_initializers); } static common::Status Save(Model& model, const PathString& file_path); @@ -210,14 +214,18 @@ class Model { const std::filesystem::path& file_path, const std::filesystem::path& external_file_path, size_t initializer_size_threshold, - const Graph::OffsetAlignmentInfo& align_info); + const Graph::OffsetAlignmentInfo& align_info, + bool save_prepacked_constant_initializers, + Graph::PrePackedTensorProtoToSave& pre_packed_initializers); static common::Status SaveWithExternalInitializers(Model& model, const std::filesystem::path& file_path, const std::filesystem::path& external_file_path, size_t initializer_size_threshold) { Graph::OffsetAlignmentInfo default_align_info; - return SaveWithExternalInitializers(model, file_path, external_file_path, initializer_size_threshold, default_align_info); + Graph::PrePackedTensorProtoToSave pre_packed_initializers; + return SaveWithExternalInitializers(model, file_path, external_file_path, initializer_size_threshold, default_align_info, + false, pre_packed_initializers); } static common::Status SaveWithExternalInitializers(Model& model, @@ -225,7 +233,9 @@ class Model { const std::filesystem::path& file_path, const std::filesystem::path& external_file_path, size_t initializer_size_threshold, - const Graph::OffsetAlignmentInfo& align_info); + const Graph::OffsetAlignmentInfo& align_info, + bool save_prepacked_constant_initializers, + Graph::PrePackedTensorProtoToSave& pre_packed_initializers); static common::Status SaveWithExternalInitializers(Model& model, int fd, @@ -233,7 +243,9 @@ class Model { const std::filesystem::path& external_file_path, size_t initializer_size_threshold) { Graph::OffsetAlignmentInfo default_align_info; - return SaveWithExternalInitializers(model, fd, file_path, external_file_path, initializer_size_threshold, default_align_info); + Graph::PrePackedTensorProtoToSave pre_packed_initializers; + return SaveWithExternalInitializers(model, fd, file_path, external_file_path, initializer_size_threshold, default_align_info, + false, pre_packed_initializers); } static common::Status Load(std::istream& model_istream, ONNX_NAMESPACE::ModelProto* p_model_proto); diff --git a/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc b/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc index 37db095e92570..0a1a3a5995872 100644 --- a/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc +++ b/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc @@ -51,6 +51,7 @@ class FusedConvFp16 final : public OpKernel { Status Compute(OpKernelContext* context) const override; Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, @@ -101,6 +102,7 @@ class FusedConvFp16 final : public OpKernel { }; Status FusedConvFp16::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { is_packed = false; diff --git a/onnxruntime/core/providers/cpu/math/gemm.cc b/onnxruntime/core/providers/cpu/math/gemm.cc index 5406dd1a40446..dbc7becdf2397 100644 --- a/onnxruntime/core/providers/cpu/math/gemm.cc +++ b/onnxruntime/core/providers/cpu/math/gemm.cc @@ -248,6 +248,7 @@ template void Gemm::ComputeGemm(CBLAS_TRANSPOSE trans_a, CBLAS_TRANSPOSE template Status Gemm::PrePack(const Tensor& /* tensor */, int /* input_idx */, AllocatorPtr /*alloc_for_caching*/, + bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* /*prepacked_weight_for_caching*/) { is_packed = false; @@ -256,7 +257,7 @@ Status Gemm::PrePack(const Tensor& /* tensor */, int /* input_idx */, Allocat template <> Status Gemm::PrePack(const Tensor& tensor, int input_idx, - AllocatorPtr alloc, /*out*/ bool& is_packed, + AllocatorPtr alloc, bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { is_packed = false; diff --git a/onnxruntime/core/providers/cpu/math/gemm.h b/onnxruntime/core/providers/cpu/math/gemm.h index 953949732560d..92f05a7921f8b 100644 --- a/onnxruntime/core/providers/cpu/math/gemm.h +++ b/onnxruntime/core/providers/cpu/math/gemm.h @@ -21,6 +21,7 @@ class Gemm : protected GemmBase, public OpKernel { Status Compute(OpKernelContext* context) const override; Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; diff --git a/onnxruntime/core/providers/cpu/math/matmul.cc b/onnxruntime/core/providers/cpu/math/matmul.cc index 2c6d23e4de908..8f2c2c53b188b 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.cc +++ b/onnxruntime/core/providers/cpu/math/matmul.cc @@ -173,6 +173,7 @@ bool GemmPackBBfloat16(AllocatorPtr& alloc, #endif Status MatMul::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, + bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { is_packed = false; diff --git a/onnxruntime/core/providers/cpu/math/matmul.h b/onnxruntime/core/providers/cpu/math/matmul.h index b9bbe36583879..0bb0e6c2ef596 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.h +++ b/onnxruntime/core/providers/cpu/math/matmul.h @@ -37,6 +37,7 @@ class MatMul final : public OpKernel { } Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; diff --git a/onnxruntime/core/providers/cpu/nn/conv_transpose.cc b/onnxruntime/core/providers/cpu/nn/conv_transpose.cc index f0c1b0b409831..2c7afddf38070 100644 --- a/onnxruntime/core/providers/cpu/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/cpu/nn/conv_transpose.cc @@ -38,6 +38,7 @@ ONNX_CPU_OPERATOR_KERNEL( template Status ConvTranspose::PrePack(const Tensor& /*tensor*/, int /*input_idx*/, AllocatorPtr /*alloc*/, + bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* /*prepacked_weights*/ ) { @@ -47,6 +48,7 @@ Status ConvTranspose::PrePack(const Tensor& /*tensor*/, int /*input_idx*/, Al template <> Status ConvTranspose::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { is_packed = false; diff --git a/onnxruntime/core/providers/cpu/nn/conv_transpose.h b/onnxruntime/core/providers/cpu/nn/conv_transpose.h index c82cd5ad49d7e..d03b5566e334f 100644 --- a/onnxruntime/core/providers/cpu/nn/conv_transpose.h +++ b/onnxruntime/core/providers/cpu/nn/conv_transpose.h @@ -28,6 +28,7 @@ class ConvTranspose : public OpKernel { ConvTranspose(const OpKernelInfo& info) : OpKernel(info), conv_transpose_attrs_(info) {} Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; diff --git a/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc b/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc index 24a5dcab225c4..fe2bf1035bb65 100644 --- a/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc +++ b/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc @@ -229,6 +229,7 @@ Status LayerNormImpl::Compute(OpKernelContext* p_ctx) const { } Status LayerNormImpl::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool /*save_prepacked_initializers*/, bool& is_packed, PrePackedWeights* prepacked_weights) { ORT_UNUSED_PARAMETER(prepacked_weights); diff --git a/onnxruntime/core/providers/cpu/nn/layer_norm_impl.h b/onnxruntime/core/providers/cpu/nn/layer_norm_impl.h index f8b528b398cba..abce87d03c14b 100644 --- a/onnxruntime/core/providers/cpu/nn/layer_norm_impl.h +++ b/onnxruntime/core/providers/cpu/nn/layer_norm_impl.h @@ -15,7 +15,7 @@ class LayerNormImpl : public OpKernel { LayerNormImpl(const OpKernelInfo& op_kernel_info, bool simplified = false, bool contrib_op = false); Status Compute(OpKernelContext* p_op_kernel_context) const override; - Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool save_prepacked_initializers, bool& is_packed, PrePackedWeights* prepacked_weights) override; // This method was created so that it can be called directly from `test/onnx/microbenchmark/layer_normalization.cc`. diff --git a/onnxruntime/core/providers/cpu/quantization/matmul_integer_base.h b/onnxruntime/core/providers/cpu/quantization/matmul_integer_base.h index e26eae19b8fd4..8a8ce27990069 100644 --- a/onnxruntime/core/providers/cpu/quantization/matmul_integer_base.h +++ b/onnxruntime/core/providers/cpu/quantization/matmul_integer_base.h @@ -14,6 +14,7 @@ class MatMulIntegerBase : public OpKernel { MatMulIntegerBase(const OpKernelInfo& info) : OpKernel(info) {} Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override { is_packed = false; diff --git a/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc b/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc index 7797cbe678bd4..736cde24591ff 100644 --- a/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc +++ b/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc @@ -25,6 +25,7 @@ class QLinearConv : public OpKernel { Status Compute(OpKernelContext* context) const override; Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; @@ -360,6 +361,7 @@ REGISTER_QLINEARCONV_INT8_KERNEL(kMSDomain, 1); template Status QLinearConv::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { is_packed = false; diff --git a/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc b/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc index b78c5236e6fab..7afd00eacef89 100644 --- a/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc +++ b/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc @@ -284,6 +284,7 @@ bool DeepCpuGruOp::TryPackRecurrentWeights(const Tensor& weights, AllocatorPtr& } Status DeepCpuGruOp::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool /*save_prepacked_initializers*/, bool& is_packed, PrePackedWeights* prepacked_weights) { is_packed = false; diff --git a/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.h b/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.h index 5a6dd97c7c3f2..914077b2f2c15 100644 --- a/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.h +++ b/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.h @@ -62,6 +62,7 @@ class DeepCpuGruOp final : public OpKernel { private: Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; @@ -197,4 +198,4 @@ class UniDirectionalGru { }; } // namespace detail -} // namespace onnxruntime +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc b/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc index 09bbf6c4c79e6..e4082e5d7634a 100644 --- a/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc +++ b/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc @@ -225,7 +225,9 @@ static void UseSharedPrePackedBuffersImpl(std::vector& prepacke } Status DeepCpuLstmOp::PrePack(const Tensor& tensor, int input_idx, - AllocatorPtr alloc, /*out*/ bool& is_packed, + AllocatorPtr alloc, + bool /*save_prepacked_initializers*/, + /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { is_packed = false; diff --git a/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.h b/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.h index 9c4c12954022a..ff8ab9abf0eed 100644 --- a/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.h +++ b/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.h @@ -19,6 +19,7 @@ class DeepCpuLstmOp final : public OpKernel, public LSTMBase { DeepCpuLstmOp(const OpKernelInfo& info) : OpKernel(info), LSTMBase(info) {} Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; diff --git a/onnxruntime/core/providers/cuda/nn/conv.cc b/onnxruntime/core/providers/cuda/nn/conv.cc index 3129f519da2e5..45a1d3bbc0414 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.cc +++ b/onnxruntime/core/providers/cuda/nn/conv.cc @@ -52,6 +52,7 @@ REGISTER_KERNEL_TYPED(MLFloat16, kMSInternalNHWCDomain, true) // First input (in this case X) is in case NHWC == true also in NHWC format, the other inputs in NCHW template Status Conv::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool /*save_prepacked_initializers*/, bool& is_packed, PrePackedWeights* /*prepacked_weights*/) { is_packed = false; // only layout of weight input is adjusted via PrePack diff --git a/onnxruntime/core/providers/cuda/nn/conv.h b/onnxruntime/core/providers/cuda/nn/conv.h index e4047a6af272e..6294566af3cb9 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.h +++ b/onnxruntime/core/providers/cuda/nn/conv.h @@ -219,6 +219,7 @@ class Conv : public CudaKernel { } Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool save_prepacked_initializers, bool& is_packed, PrePackedWeights* prepacked_weights) override; Status ComputeInternal(OpKernelContext* context) const override; diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc index 2972ae999adc4..9c9a83460daeb 100644 --- a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc @@ -45,7 +45,8 @@ REGISTER_KERNEL_TYPED(MLFloat16, kMSInternalNHWCDomain, true) // First input (in this case X) is in case NHWC == true also in NHWC format, the other inputs in NCHW template -Status ConvTranspose::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool& is_packed, +Status ConvTranspose::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool /*save_prepacked_initializers*/, bool& is_packed, [[maybe_unused]] PrePackedWeights* prepacked_weights) { is_packed = false; // only layout of weight input is adjusted via PrePack diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose.h b/onnxruntime/core/providers/cuda/nn/conv_transpose.h index 1a6957164d22f..f23c2b94501f2 100644 --- a/onnxruntime/core/providers/cuda/nn/conv_transpose.h +++ b/onnxruntime/core/providers/cuda/nn/conv_transpose.h @@ -22,6 +22,7 @@ class ConvTranspose : public CudaKernel { ConvTranspose(const OpKernelInfo& info) : CudaKernel(info), conv_transpose_attrs_(info) {}; Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool save_prepacked_initializers, bool& is_packed, [[maybe_unused]] PrePackedWeights* prepacked_weights) override; Status ComputeInternal(OpKernelContext* context) const override; Status DoConvTranspose(OpKernelContext* context, bool dynamic_padding) const; diff --git a/onnxruntime/core/providers/js/operators/conv.h b/onnxruntime/core/providers/js/operators/conv.h index b04df44954295..276b600cf40d2 100644 --- a/onnxruntime/core/providers/js/operators/conv.h +++ b/onnxruntime/core/providers/js/operators/conv.h @@ -78,6 +78,7 @@ class ConvBase : public JsKernel { } Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* /* prepacked_weights */) override { is_packed = false; diff --git a/onnxruntime/core/providers/js/operators/conv_transpose.h b/onnxruntime/core/providers/js/operators/conv_transpose.h index 5ff52e8fda4fa..baa93f825a203 100644 --- a/onnxruntime/core/providers/js/operators/conv_transpose.h +++ b/onnxruntime/core/providers/js/operators/conv_transpose.h @@ -126,8 +126,10 @@ class ConvTranspose : public JsKernel { } Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* /* prepacked_weights */) override { + ORT_UNUSED_PARAMETER(save_prepacked_initializers); is_packed = false; if (input_idx == 1) { diff --git a/onnxruntime/core/providers/xnnpack/math/gemm.cc b/onnxruntime/core/providers/xnnpack/math/gemm.cc index 35a06cb7eb89f..68b55030c7363 100644 --- a/onnxruntime/core/providers/xnnpack/math/gemm.cc +++ b/onnxruntime/core/providers/xnnpack/math/gemm.cc @@ -117,6 +117,7 @@ Gemm::Gemm(const OpKernelInfo& info) : GemmBase(info), XnnpackKernel(info, /*ena } Status Gemm::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr, + bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights*) { is_packed = false; diff --git a/onnxruntime/core/providers/xnnpack/math/gemm.h b/onnxruntime/core/providers/xnnpack/math/gemm.h index 954aab0698b9c..d632eef015f9a 100644 --- a/onnxruntime/core/providers/xnnpack/math/gemm.h +++ b/onnxruntime/core/providers/xnnpack/math/gemm.h @@ -23,6 +23,7 @@ class Gemm : protected GemmBase, public XnnpackKernel { static bool IsOnnxNodeSupported(const NodeUnit& node_unit, const GraphViewer& graph); Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; diff --git a/onnxruntime/core/providers/xnnpack/math/matmul.cc b/onnxruntime/core/providers/xnnpack/math/matmul.cc index 44a6fb4ee835a..71a11cb05d9af 100644 --- a/onnxruntime/core/providers/xnnpack/math/matmul.cc +++ b/onnxruntime/core/providers/xnnpack/math/matmul.cc @@ -78,6 +78,7 @@ MatMul::MatMul(const OpKernelInfo& info) : XnnpackKernel(info, /*enable_caches*/ } Status MatMul::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* /*Not used*/) { is_packed = false; diff --git a/onnxruntime/core/providers/xnnpack/math/matmul.h b/onnxruntime/core/providers/xnnpack/math/matmul.h index 188cc73189af5..31a8c36ad418b 100644 --- a/onnxruntime/core/providers/xnnpack/math/matmul.h +++ b/onnxruntime/core/providers/xnnpack/math/matmul.h @@ -23,6 +23,7 @@ class MatMul : public XnnpackKernel { // Required for checking XNNpack restrictions on ORT side static bool IsOnnxNodeSupported(const NodeUnit& node_unit, const GraphViewer& graph); Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; diff --git a/onnxruntime/core/providers/xnnpack/nn/conv.cc b/onnxruntime/core/providers/xnnpack/nn/conv.cc index 4e6b308e28ae5..f2e697df475da 100644 --- a/onnxruntime/core/providers/xnnpack/nn/conv.cc +++ b/onnxruntime/core/providers/xnnpack/nn/conv.cc @@ -18,6 +18,7 @@ namespace xnnpack { // use PrePack to handle the weight layout change as that's not a simple NCHW -> NHWC transpose Status Conv::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* /*prepacked_weights*/) { is_packed = false; diff --git a/onnxruntime/core/providers/xnnpack/nn/conv.h b/onnxruntime/core/providers/xnnpack/nn/conv.h index 3630aae208d49..762b68c8bd49a 100644 --- a/onnxruntime/core/providers/xnnpack/nn/conv.h +++ b/onnxruntime/core/providers/xnnpack/nn/conv.h @@ -19,6 +19,7 @@ class Conv : public ConvBase { // use PrePack to handle the weight layout change as that's not a simple NCHW -> NHWC transpose Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; }; diff --git a/onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc b/onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc index b6930a5fc92d1..5729565b2feb9 100644 --- a/onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc @@ -15,6 +15,7 @@ namespace xnnpack { // use PrePack to handle the weight layout change as that's not a simple NCHW -> NHWC transpose Status ConvTranspose::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* /*prepacked_weights*/) { is_packed = false; diff --git a/onnxruntime/core/providers/xnnpack/nn/conv_transpose.h b/onnxruntime/core/providers/xnnpack/nn/conv_transpose.h index 866b9b6b98365..0313515d10fa1 100644 --- a/onnxruntime/core/providers/xnnpack/nn/conv_transpose.h +++ b/onnxruntime/core/providers/xnnpack/nn/conv_transpose.h @@ -18,6 +18,7 @@ class ConvTranspose : public ConvBase { // use PrePack to handle the weight layout change as that's not a simple NCHW -> NHWC transpose Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; }; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index f5f12c206ebad..e6aafaa1f2283 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -2027,9 +2027,11 @@ common::Status InferenceSession::Initialize() { #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) } + SessionState::PrePackInitializers pre_packed_initializers; ORT_RETURN_IF_ERROR_SESSIONID_( session_state_->FinalizeSessionState(model_location_, kernel_registry_manager_, // need to keep the initializers if saving the optimized model + pre_packed_initializers, !saving_model, saving_ort_format)); @@ -2065,11 +2067,47 @@ common::Status InferenceSession::Initialize() { kOrtSessionOptionsOptimizedModelExternalInitializersMinSizeInBytes, "1024")); Graph::OffsetAlignmentInfo align_info; align_info.align_offset = true; + bool save_prepacked_constant_initializers = + session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsSavePrePackedConstantInitializers, "0") == "1" ? true : false; + Graph::PrePackedTensorProtoToSave pre_packed_initializers_tensor_proto; + if (save_prepacked_constant_initializers) { + LOGS(*session_logger_, WARNING) << "Serialize prepacked initializers option has been turn on." + << "Use this option only when run model inference on PC with CPU." + << "Make sure to save and load model in same device as prepack is device specific." + << "Note: this feature in only work with ONNX model format." + << "Process of use this option is like below:" + << "1. Optimize model with external data file with save_prepacked_constant_initializers on:" + << " sample: sess_options.add_session_config_entry('session.save_prepacked_constant_initializers', ' 1 ')" + << " With save_prepacked_constant_initializers option, prepacked initializer will be serialized into data file." + << "2. Load optimized model and external data file in same device, no prepack is need." + << "3. Run inference with optimized model."; + + if (fbs::utils::IsOrtFormatModel(session_options_.optimized_model_filepath)) { + ORT_RETURN_IF_ERROR_SESSIONID_( + ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Unable to serialize prepacked external constant initializer for ORT format model." + "Please use ONNX format model with save_prepacked_constant_initializers.")); + } + + // convert pre_packed_initializers to tensorproto format and save to external data file + for (const auto& name_item_pair : pre_packed_initializers.pre_packed_initializers_to_save) { + auto initializer_name = name_item_pair.first; + + for (const auto& kernel_name_initializer_item_pair : name_item_pair.second) { + auto kernel_name = kernel_name_initializer_item_pair.first; + auto prepacked_initializer_name = utils::GetPrepackedInitializerName(initializer_name, kernel_name); + + pre_packed_initializers_tensor_proto[initializer_name][kernel_name] = utils::TensorToTensorProto(kernel_name_initializer_item_pair.second, prepacked_initializer_name); + } + } + } ORT_RETURN_IF_ERROR_SESSIONID_(Model::SaveWithExternalInitializers(*model_, session_options_.optimized_model_filepath, optimized_model_external_initializers_file_name, optimized_model_external_initializers_min_size_in_bytes, - align_info)); + align_info, + save_prepacked_constant_initializers, + pre_packed_initializers_tensor_proto)); } } } diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index 61a8f7e23fe87..da5fa2c3a5a24 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -45,6 +45,7 @@ #include "core/session/environment.h" #include "core/session/IOBinding.h" #include "core/session/inference_session_utils.h" +#include "core/session/onnxruntime_cxx_api.h" #include "core/session/onnxruntime_session_options_config_keys.h" #include "core/session/onnxruntime_run_options_config_keys.h" #include "dummy_provider.h" @@ -64,6 +65,8 @@ using namespace ONNX_NAMESPACE; using namespace onnxruntime::logging; using namespace onnxruntime::concurrency; +extern std::unique_ptr ort_env; + namespace { struct KernelRegistryAndStatus { std::shared_ptr kernel_registry = std::make_shared(); @@ -496,6 +499,57 @@ TEST(InferenceSessionTests, TestModelSerialization) { ASSERT_TRUE(session_object_emptyValidation.Initialize().IsOK()); } +// Test feature serialize prepack weight is only used in PC with CPU on inference, +// disable this test for training, other device and eps +#if !ENABLE_TRAINING && !defined(USE_CUDA) && !defined(__wasm__) && !defined(USE_DNNL) && !defined(USE_QNN) && !defined(__ANDROID__) && !defined(USE_COREML) +// MLAS dispatcher used in matmul_nbits kernels here is 64 bit only +#if defined(__amd64__) || defined(_M_AMD64) || defined(__aarch64__) || defined(_M_ARM64) +TEST(InferenceSessionTests, TestPrePackSerialization) { + SessionOptions so; + std::string model_name = "model_with_matmul_nbits"; + + const std::string test_model = "testdata/prepack/" + model_name + ".onnx"; + const std::string optimized_model = "testdata/prepack/" + model_name + "_opt.onnx"; + + so.session_logid = "InferenceSessionTests.TestPrepackSerialization"; + so.enable_cpu_mem_arena = false; + so.graph_optimization_level = TransformerLevel::Default; + so.optimized_model_filepath = optimized_model; + std::string external_initializer_file_name = model_name + "_opt.onnx.data"; + + // enable serialize prepack initializer to data file + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsSavePrePackedConstantInitializers, + "1")); + // always save external initializer to data file for test + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsOptimizedModelExternalInitializersMinSizeInBytes, + "0")); + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsOptimizedModelExternalInitializersFileName, + external_initializer_file_name.c_str())); + + // optimize model with serialize prepack constant initializers + InferenceSessionWrapper session_object{so, GetEnvironment()}; + ASSERT_TRUE(session_object.Load(test_model).IsOK()); + ASSERT_TRUE(session_object.Initialize().IsOK()); + + // Verify prepack initializers are serialized into optimized model and data file + // load optimized model and check initializer are prepacked + auto logger = DefaultLoggingManager().CreateLogger("TestPrepackSerialization"); + std::shared_ptr model; + auto load_status = Model::Load(ToWideString(optimized_model), model, nullptr, *logger); + ASSERT_EQ(Status::OK(), load_status); + Graph& graph = model->MainGraph(); + + bool found_prepack_initializer = false; + for (const auto& item : graph.GetAllInitializedTensors()) { + if (item.first.find(':') != std::string::npos) { + found_prepack_initializer = true; + } + } + ASSERT_TRUE(found_prepack_initializer); +} +#endif +#endif + #ifdef ORT_RUN_EXTERNAL_ONNX_TESTS static bool Compare(const InputDefList& f_arg, const InputDefList& s_arg) { if (f_arg.size() != s_arg.size()) { diff --git a/onnxruntime/test/framework/save_model_with_external_initializers.cc b/onnxruntime/test/framework/save_model_with_external_initializers.cc index d0bc088175755..0f76cb61ace74 100644 --- a/onnxruntime/test/framework/save_model_with_external_initializers.cc +++ b/onnxruntime/test/framework/save_model_with_external_initializers.cc @@ -7,6 +7,7 @@ #include "core/framework/data_types.h" #include "core/graph/model.h" #include "core/framework/tensorprotoutils.h" +#include "core/framework/session_state.h" #include "test/test_environment.h" #include "test_utils.h" #include "test/util/include/asserts.h" @@ -19,19 +20,34 @@ using namespace onnxruntime; namespace onnxruntime { namespace test { +std::vector split(const std::string& str, char delimiter) { + std::vector result; + std::stringstream ss(str); + std::string token; + + // Use getline with a delimiter to split the string + while (std::getline(ss, token, delimiter)) { + result.push_back(token); + } + + return result; +} + Status LoadSaveAndCompareModel(const std::filesystem::path& input_onnx, const std::filesystem::path& input_external_init_file, const std::filesystem::path& output_onnx, const std::filesystem::path& output_external_init_file, size_t initializer_size_threshold, - const Graph::OffsetAlignmentInfo& align_info) { + const Graph::OffsetAlignmentInfo& align_info, + Graph::PrePackedTensorProtoToSave& pre_packed_initializers_tensor_proto, + bool save_prepacked_constant_initializers = false) { auto logger = DefaultLoggingManager().CreateLogger("LoadSaveAndCompareModel"); std::shared_ptr model; ORT_RETURN_IF_ERROR(Model::Load(input_onnx, model, nullptr, *logger)); std::filesystem::remove(output_onnx); std::filesystem::remove(output_external_init_file); ORT_RETURN_IF_ERROR(Model::SaveWithExternalInitializers(*model, output_onnx, output_external_init_file, initializer_size_threshold, - align_info)); + align_info, save_prepacked_constant_initializers, pre_packed_initializers_tensor_proto)); std::shared_ptr model_from_external; ORT_RETURN_IF_ERROR(Model::Load(output_onnx.native(), model_from_external, nullptr, *logger)); @@ -50,10 +66,11 @@ Status LoadSaveAndCompareModel(const std::filesystem::path& input_onnx, // Compare the initializers of the two versions. std::filesystem::path model_path{}; std::filesystem::path external_data_path{}; - for (const auto& i : initializers) { + for (const auto& i : initializers_from_external) { const std::string kInitName = i.first; - const ONNX_NAMESPACE::TensorProto* tensor_proto = i.second; - const ONNX_NAMESPACE::TensorProto* from_external_tensor_proto = initializers_from_external[kInitName]; + const ONNX_NAMESPACE::TensorProto* from_external_tensor_proto = i.second; + // prepack initializer will have name as [original name]:[kernel name] in case initializer used by multiple kernels + const ONNX_NAMESPACE::TensorProto* tensor_proto = save_prepacked_constant_initializers ? initializers[split(kInitName, ':')[0]] : initializers[kInitName]; std::vector tensor_proto_data; model_path = input_onnx; @@ -75,8 +92,12 @@ Status LoadSaveAndCompareModel(const std::filesystem::path& input_onnx, ORT_RETURN_IF_NOT(from_external_tensor_proto->data_location() == ONNX_NAMESPACE::TensorProto_DataLocation::TensorProto_DataLocation_EXTERNAL, "location mismatch"); } - ORT_RETURN_IF_NOT(tensor_proto_size == from_external_tensor_proto_size, "size mismatch"); - ORT_RETURN_IF_NOT(memcmp(tensor_proto_data.data(), from_external_tensor_proto_data.data(), tensor_proto_size) == 0, "data mismatch"); + if (!save_prepacked_constant_initializers) { + ORT_RETURN_IF_NOT(tensor_proto_size == from_external_tensor_proto_size, "size mismatch"); + ORT_RETURN_IF_NOT(memcmp(tensor_proto_data.data(), from_external_tensor_proto_data.data(), tensor_proto_size) == 0, "data mismatch"); + } else { + ORT_RETURN_IF_NOT(from_external_tensor_proto_size >= tensor_proto_size, "prepack initializer's size is at least same as original tensor, might be larger"); + } if (align_info.align_offset) { for (const StringStringEntryProto& entry : from_external_tensor_proto->external_data()) { @@ -89,6 +110,7 @@ Status LoadSaveAndCompareModel(const std::filesystem::path& input_onnx, } } } + // Cleanup. ORT_RETURN_IF_NOT(std::filesystem::remove(output_onnx), "delete file failed"); ORT_RETURN_IF_NOT(std::filesystem::remove(external_data_path), "delete file failed"); @@ -98,13 +120,15 @@ Status LoadSaveAndCompareModel(const std::filesystem::path& input_onnx, // Original model does not have external initializers TEST(SaveWithExternalInitializers, Mnist) { Graph::OffsetAlignmentInfo align_info; - ASSERT_STATUS_OK(LoadSaveAndCompareModel(ORT_TSTR("testdata/mnist.onnx"), ORT_TSTR(""), ORT_TSTR("testdata/mnist_with_external_initializers.onnx"), ORT_TSTR("mnist_external_initializers.bin"), 100, align_info)); + Graph::PrePackedTensorProtoToSave pre_packed_initializers_tensor_proto; + ASSERT_STATUS_OK(LoadSaveAndCompareModel(ORT_TSTR("testdata/mnist.onnx"), ORT_TSTR(""), ORT_TSTR("testdata/mnist_with_external_initializers.onnx"), ORT_TSTR("mnist_external_initializers.bin"), 100, align_info, pre_packed_initializers_tensor_proto)); } // Original model has external initializers TEST(SaveWithExternalInitializers, ModelWithOriginalExternalData) { Graph::OffsetAlignmentInfo align_info; - ASSERT_STATUS_OK(LoadSaveAndCompareModel(ORT_TSTR("testdata/model_with_orig_ext_data.onnx"), ORT_TSTR("model_with_orig_ext_data.onnx.data"), ORT_TSTR("testdata/model_with_new_external_initializers.onnx"), ORT_TSTR("model_with_new_external_initializers.bin"), 0, align_info)); + Graph::PrePackedTensorProtoToSave pre_packed_initializers_tensor_proto; + ASSERT_STATUS_OK(LoadSaveAndCompareModel(ORT_TSTR("testdata/model_with_orig_ext_data.onnx"), ORT_TSTR("model_with_orig_ext_data.onnx.data"), ORT_TSTR("testdata/model_with_new_external_initializers.onnx"), ORT_TSTR("model_with_new_external_initializers.bin"), 0, align_info, pre_packed_initializers_tensor_proto)); } // Original model has external initializers, align offset @@ -112,7 +136,22 @@ TEST(SaveWithExternalInitializers, ModelWithOriginalExternalDataAlignOffset) { Graph::OffsetAlignmentInfo align_info; align_info.align_offset = true; align_info.align_threshold = 0; - ASSERT_STATUS_OK(LoadSaveAndCompareModel(ORT_TSTR("testdata/model_with_orig_ext_data.onnx"), ORT_TSTR("model_with_orig_ext_data.onnx.data"), ORT_TSTR("testdata/model_with_new_external_initializers.onnx"), ORT_TSTR("model_with_new_external_initializers.bin"), 0, align_info)); + Graph::PrePackedTensorProtoToSave pre_packed_initializers_tensor_proto; + ASSERT_STATUS_OK(LoadSaveAndCompareModel(ORT_TSTR("testdata/model_with_orig_ext_data.onnx"), ORT_TSTR("model_with_orig_ext_data.onnx.data"), ORT_TSTR("testdata/model_with_new_external_initializers.onnx"), ORT_TSTR("model_with_new_external_initializers.bin"), 0, align_info, pre_packed_initializers_tensor_proto)); +} + +// Original model has external initializers, align offset and serialize prepacked external initializer to model file +TEST(SaveWithExternalInitializers, ModelWithOriginalExternalDataAlignOffsetAndSavePrepackTensors) { + Graph::OffsetAlignmentInfo align_info; + align_info.align_offset = true; + align_info.align_threshold = 0; + std::shared_ptr alloc = std::make_shared(); + TensorShape shape = {178}; + // prepack both initializers for test purpose + Graph::PrePackedTensorProtoToSave pre_packed_initializers_tensor_proto; + pre_packed_initializers_tensor_proto["MatMul.Weight"]["MatMul_0"] = utils::TensorToTensorProto(Tensor(DataTypeImpl::GetType(), shape, alloc), "MatMul.Weight:MatMul_0"); + pre_packed_initializers_tensor_proto["scales"]["MatMul_0"] = utils::TensorToTensorProto(Tensor(DataTypeImpl::GetType(), shape, alloc), "scales:MatMul_0"); + ASSERT_STATUS_OK(LoadSaveAndCompareModel(ORT_TSTR("testdata/prepack/model_with_matmul_nbits.onnx"), ORT_TSTR("model_with_matmul_nbits.onnx.data"), ORT_TSTR("testdata/prepack/model_with_matmul_nbits_opt.onnx"), ORT_TSTR("model_with_matmul_nbits_opt.onnx.data"), 0, align_info, pre_packed_initializers_tensor_proto, true)); } } // namespace test diff --git a/onnxruntime/test/framework/session_state_test.cc b/onnxruntime/test/framework/session_state_test.cc index b94d24a1b180b..6265eccb7bd9b 100644 --- a/onnxruntime/test/framework/session_state_test.cc +++ b/onnxruntime/test/framework/session_state_test.cc @@ -372,10 +372,11 @@ class PrePackingTestOpKernel : public OpKernel { return Status::OK(); } - Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override { ORT_UNUSED_PARAMETER(tensor); ORT_UNUSED_PARAMETER(input_idx); + ORT_UNUSED_PARAMETER(save_prepacked_initializers); size_t weight_packed_len = 8; weight_packed_ = IAllocator::MakeUniquePtr(alloc, weight_packed_len, true); @@ -393,9 +394,20 @@ class PrePackingTestOpKernel : public OpKernel { return Status::OK(); } + std::optional GetPrePackTensor(int input_idx) override { + ORT_UNUSED_PARAMETER(input_idx); + ++get_prepack_tensors_count; + + TensorShape shape = {2}; + packed_tensor = Tensor(DataTypeImpl::GetType(), shape, std::make_shared()); + return std::move(packed_tensor); + } + int prepack_calls_count = 0; int store_pre_packed_weight_calls_count = 0; + int get_prepack_tensors_count = 0; IAllocatorUniquePtr weight_packed_; + Tensor packed_tensor; }; static void CreateSimpleGraph(Graph& graph) { @@ -530,6 +542,7 @@ static void PlaceAllNodesToCPUEP(Graph& graph) { struct PrepackingTestParam { bool test_subgraph; bool test_prepacking; + bool test_save_prepack_initializer; }; class SessionStatePrepackingTest : public testing::TestWithParam {}; @@ -572,6 +585,8 @@ TEST_P(SessionStatePrepackingTest, PrePackingTest) { sess_options.enable_mem_reuse = true; sess_options.config_options.configurations[kOrtSessionOptionsConfigDisablePrepacking] = test_param.test_prepacking ? "0" : "1"; + sess_options.config_options.configurations[kOrtSessionOptionsSavePrePackedConstantInitializers] = + test_param.test_save_prepack_initializer ? "1" : "0"; SessionState session_state(model.MainGraph(), execution_providers, @@ -597,12 +612,47 @@ TEST_P(SessionStatePrepackingTest, PrePackingTest) { kernel_registry_manager.RegisterKernelRegistry(kernel_registry); PlaceAllNodesToCPUEP(model.MainGraph()); + SessionState::PrePackInitializers pre_packed_initializers; ASSERT_STATUS_OK(session_state.FinalizeSessionState(std::basic_string(), - kernel_registry_manager)); + kernel_registry_manager, + pre_packed_initializers)); const auto& const_initialized_tensors = session_state.GetConstantInitializedTensors(); // check prepacking ASSERT_EQ(const_initialized_tensors.size(), size_t(test_param.test_prepacking ? 0 : 1)); + + // check get prepack tensor method called when set save_prepacked_constant_initializers + if (!test_param.test_subgraph) { + const auto* kernel = reinterpret_cast(session_state.GetKernel(0)); + ASSERT_EQ(kernel->get_prepack_tensors_count, (test_param.test_prepacking && test_param.test_save_prepack_initializer) ? 1 : 0); + } else { + auto if_index = 1; + if (session_state.GetKernel(0)->Node().OpType() == "If") { + if_index = 0; + } + + const auto& subgraph_session_states = session_state.GetSubgraphSessionStateMap(); + const auto& if_node_session_states = subgraph_session_states.at(if_index); + const auto& session_state_1_then_branch_session_state = *if_node_session_states.at("then_branch"); + const auto& session_state_1_else_branch_session_state = *if_node_session_states.at("else_branch"); + + const auto* kernel_if_0 = reinterpret_cast(session_state_1_then_branch_session_state.GetKernel(0)); + const auto* kernel_if_1 = reinterpret_cast(session_state_1_else_branch_session_state.GetKernel(0)); + ASSERT_EQ(kernel_if_0->get_prepack_tensors_count, (test_param.test_prepacking && test_param.test_save_prepack_initializer) ? 1 : 0); + ASSERT_EQ(kernel_if_1->get_prepack_tensors_count, (test_param.test_prepacking && test_param.test_save_prepack_initializer) ? 1 : 0); + } + + // check pre_packed_initializers_to_save will be set properly when set save_prepacked_constant_initializers + if (!test_param.test_subgraph && test_param.test_prepacking && test_param.test_save_prepack_initializer) { + ASSERT_EQ(pre_packed_initializers.pre_packed_initializers_to_save.size(), size_t(1)); + ASSERT_EQ(pre_packed_initializers.pre_packed_initializers_to_save.count("node_0_input_1"), size_t(1)); + ASSERT_EQ(pre_packed_initializers.pre_packed_initializers_to_save["node_0_input_1"].count("node_0"), size_t(1)); + } else if (test_param.test_subgraph && test_param.test_prepacking && test_param.test_save_prepack_initializer) { + ASSERT_EQ(pre_packed_initializers.pre_packed_initializers_to_save.size(), size_t(1)); + ASSERT_EQ(pre_packed_initializers.pre_packed_initializers_to_save.count("if_shared"), size_t(1)); + ASSERT_EQ(pre_packed_initializers.pre_packed_initializers_to_save["if_shared"].count("if_node_1"), size_t(1)); + ASSERT_EQ(pre_packed_initializers.pre_packed_initializers_to_save["if_shared"].count("if_node_0"), size_t(1)); + } } class SessionStateTestSharedInitalizersWithPrePacking : public ::testing::Test { @@ -1000,10 +1050,14 @@ TEST_F(SessionStateTestSharedInitalizersWithPrePacking, test4) { INSTANTIATE_TEST_SUITE_P(SessionStateTests, SessionStatePrepackingTest, - testing::Values(PrepackingTestParam{false, false}, - PrepackingTestParam{false, true}, - PrepackingTestParam{true, false}, - PrepackingTestParam{true, true})); + testing::Values(PrepackingTestParam{false, false, false}, + PrepackingTestParam{false, true, false}, + PrepackingTestParam{true, false, false}, + PrepackingTestParam{true, true, false}, + PrepackingTestParam{false, false, true}, + PrepackingTestParam{false, true, true}, + PrepackingTestParam{true, false, true}, + PrepackingTestParam{true, true, true})); #endif } // namespace test diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 0be1c0b1965ac..e19362e0ec32d 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -4600,3 +4600,86 @@ TEST(CApiTest, OrtCustomOp_GetInPlace) { ASSERT_EQ(len, static_cast(2)); mock_gqa.ReleaseAliasMap(input_index, output_index); } + +TEST(CApiTest, Serialize_PrePack_Initializers) { + std::string model_name = "model_with_matmul_nbits"; + + const std::string test_model = "testdata/prepack/" + model_name + ".onnx"; + const std::string optimized_model = "testdata/prepack/" + model_name + "_opt.onnx"; + std::string external_initializer_file_name = model_name + "_opt.onnx.data"; + + // Generate optimized with prepacked weights serialized + Ort::SessionOptions session_options_opt; + session_options_opt.AddConfigEntry(kOrtSessionOptionsOptimizedModelExternalInitializersFileName, external_initializer_file_name.c_str()); + session_options_opt.AddConfigEntry(kOrtSessionOptionsOptimizedModelExternalInitializersMinSizeInBytes, "0"); + session_options_opt.AddConfigEntry(kOrtSessionOptionsSavePrePackedConstantInitializers, "1"); + +#if defined(_WIN32) || defined(_WIN64) + std::wstring test_model_wide = onnxruntime::ToWideString(test_model); + session_options_opt.SetOptimizedModelFilePath(onnxruntime::ToWideString(optimized_model).c_str()); + Ort::Session session_opt_model(*ort_env, test_model_wide.c_str(), session_options_opt); +#else + session_options_opt.SetOptimizedModelFilePath(optimized_model.c_str()); + Ort::Session session_opt_model(*ort_env, test_model.c_str(), session_options_opt); +#endif + + // Do inference with original model and optimized model and check output is identical + // set inputs and session options + Ort::SessionOptions session_options; + const char* input_names[] = {"A"}; + const char* const output_names[] = {"Y"}; + Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault); + + std::vector ort_inputs; + std::vector input_0_data = {1.3f}; + std::vector input_0_dims = {1, 1}; + ort_inputs.emplace_back( + Ort::Value::CreateTensor(info, const_cast(input_0_data.data()), + input_0_data.size(), input_0_dims.data(), input_0_dims.size())); + + // run inference with original model + // Convert std::string to std::wstring +#if defined(_WIN32) || defined(_WIN64) + Ort::Session session(*ort_env, test_model_wide.c_str(), session_options); +#else + Ort::Session session(*ort_env, test_model.c_str(), session_options); +#endif + auto ort_outputs = session.Run(Ort::RunOptions{}, input_names, ort_inputs.data(), ort_inputs.size(), + output_names, 1); + + // run inference with optimized model which load serialized prepack initializer +#if defined(_WIN32) || defined(_WIN64) + std::wstring optimized_model_wide = onnxruntime::ToWideString(optimized_model); + Ort::Session session_opt(*ort_env, optimized_model_wide.c_str(), session_options); +#else + Ort::Session session_opt(*ort_env, optimized_model.c_str(), session_options); +#endif + auto ort_outputs_opt = session_opt.Run(Ort::RunOptions{}, input_names, ort_inputs.data(), ort_inputs.size(), + output_names, 1); + + // check output of original model and optimized model are equal + ASSERT_EQ(ort_outputs.size(), ort_outputs_opt.size()); + + for (size_t i = 0; i < ort_outputs.size(); ++i) { + const auto& sequences = ort_outputs[i]; + ASSERT_TRUE(sequences.IsTensor()); + + const auto& sequences_opt = ort_outputs_opt[i]; + ASSERT_TRUE(sequences_opt.IsTensor()); + + auto result_ts = sequences.GetTensorTypeAndShapeInfo(); + auto result_ts_opt = sequences_opt.GetTensorTypeAndShapeInfo(); + + ASSERT_EQ(result_ts.GetElementType(), result_ts_opt.GetElementType()); + + ASSERT_EQ(result_ts.GetShape(), result_ts_opt.GetShape()); + + const auto* result_vals = sequences.GetTensorData(); + auto result_span = gsl::make_span(result_vals, ort_outputs.size()); + + const auto* result_vals_opt = sequences_opt.GetTensorData(); + auto result_span_opt = gsl::make_span(result_vals_opt, ort_outputs_opt.size()); + + ASSERT_TRUE(std::equal(result_span_opt.begin(), result_span_opt.end(), result_span.begin(), result_span.end())); + } +} \ No newline at end of file diff --git a/onnxruntime/test/testdata/model_with_external_initializers.onnx b/onnxruntime/test/testdata/model_with_external_initializers.onnx index f815b4000f98f..3538f01b53c18 100644 --- a/onnxruntime/test/testdata/model_with_external_initializers.onnx +++ b/onnxruntime/test/testdata/model_with_external_initializers.onnx @@ -1,7 +1,8 @@ - onnx-example:– -& + + onnx-example:œ +, X -PadsY"Pad* +PadsYpad0"Pad* mode"constant  test-model*"BPadsj locationPads.binpZ @@ -16,4 +17,4 @@ test-model*"BPadsj Y   -B \ No newline at end of file +B \ No newline at end of file diff --git a/onnxruntime/test/testdata/model_with_external_initializers.py b/onnxruntime/test/testdata/model_with_external_initializers.py index 8d2589a9e6564..dc64d4a41424a 100644 --- a/onnxruntime/test/testdata/model_with_external_initializers.py +++ b/onnxruntime/test/testdata/model_with_external_initializers.py @@ -35,9 +35,10 @@ def GenerateModel(model_name, external_data_name): # noqa: N802 # Create a node (NodeProto) node_def = helper.make_node( - "Pad", # node name + "Pad", # op type ["X", external_data_name], # inputs ["Y"], # outputs + "pad0", # node name mode="constant", # Attributes ) diff --git a/onnxruntime/test/testdata/model_with_orig_ext_data.onnx b/onnxruntime/test/testdata/model_with_orig_ext_data.onnx index 6f9cce0bc5b4f..47d0c68235099 100644 --- a/onnxruntime/test/testdata/model_with_orig_ext_data.onnx +++ b/onnxruntime/test/testdata/model_with_orig_ext_data.onnx @@ -1,7 +1,8 @@ -  onnx-example:æ -: + + onnx-example:ì +@ X -model_with_orig_ext_dataY"Pad* +model_with_orig_ext_dataYpad0"Pad* mode"constant  test-model*JBmodel_with_orig_ext_dataj( locationmodel_with_orig_ext_data.binpZ @@ -16,4 +17,4 @@ test-model*JBmodel_with_orig_ext_dataj( Y   -B \ No newline at end of file +B \ No newline at end of file diff --git a/onnxruntime/test/testdata/prepack/MatMul.Weight.bin b/onnxruntime/test/testdata/prepack/MatMul.Weight.bin new file mode 100644 index 0000000000000000000000000000000000000000..0f8a571589c1050d3b3e512801441efcb22cdf3c GIT binary patch literal 8 KcmZ3@00966U;wND literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/prepack/model_with_external_initializers_and_prepack_kernel.py b/onnxruntime/test/testdata/prepack/model_with_external_initializers_and_prepack_kernel.py new file mode 100644 index 0000000000000..86af461edc2c4 --- /dev/null +++ b/onnxruntime/test/testdata/prepack/model_with_external_initializers_and_prepack_kernel.py @@ -0,0 +1,88 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os + +import numpy as np +import onnx +from onnx import TensorProto, helper +from onnx.external_data_helper import set_external_data +from onnx.numpy_helper import from_array + +M = 1 +K = 1 +N = 1 +q_cols = 1 +q_rows = 1 +q_scale_size = 1 + + +def create_external_data_tensor(value, tensor_name, data_type): + tensor = from_array(np.array(value)) + tensor.name = tensor_name + tensor_filename = f"{tensor_name}.bin" + set_external_data(tensor, location=tensor_filename) + + with open(os.path.join(tensor_filename), "wb") as data_file: + data_file.write(tensor.raw_data) + tensor.ClearField("raw_data") + tensor.data_location = onnx.TensorProto.EXTERNAL + tensor.data_type = data_type + return tensor + + +def create_internal_data_tensor(value, tensor_name, data_type): + tensor = helper.make_tensor(name=tensor_name, data_type=data_type, dims=value.shape, vals=value.flatten().tolist()) + print(tensor) + tensor.data_location = onnx.TensorProto.DEFAULT + return tensor + + +def GenerateMatmulNBitsModel(model_name, external_data_name): # noqa: N802 + A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [M, K]) # noqa: N806 + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [M, N]) # noqa: N806 + + # Create a node (NodeProto) + node_def = helper.make_node( + op_type="MatMulNBits", # op type + inputs=["A", external_data_name, "scales"], # inputs + outputs=["Y"], # outputs + name="MatMul_0", # node name + domain="com.microsoft", # Custom domain for this operator + accuracy_level=4, # Attributes + bits=4, # Attributes + block_size=32, # Attributes + K=K, # Attributes + N=N, # Attributes + ) + + # Create the graph (GraphProto) + graph_def = helper.make_graph( + [node_def], + "test-model-matmul4bits", + [A], + [Y], + [ + create_external_data_tensor([[171]], external_data_name, TensorProto.UINT8), + create_internal_data_tensor(np.array([1.5], dtype=np.float32), "scales", TensorProto.FLOAT), + ], + ) + + # Create the model + model_def = helper.make_model( + graph_def, + producer_name="onnx-example", + opset_imports=[helper.make_operatorsetid("", 14), helper.make_operatorsetid("com.microsoft", 1)], + ) + + print(f"The ir_version in model: {model_def.ir_version}\n") + print(f"The producer_name in model: {model_def.producer_name}\n") + print(f"The graph in model:\n{model_def.graph}") + onnx.checker.check_model(model_def) + print("The model is checked!") + with open(model_name, "wb") as model_file: + model_file.write(model_def.SerializeToString()) + + +if __name__ == "__main__": + GenerateMatmulNBitsModel("model_with_matmul_nbits.onnx", "MatMul.Weight") diff --git a/onnxruntime/test/testdata/prepack/model_with_matmul_nbits.onnx b/onnxruntime/test/testdata/prepack/model_with_matmul_nbits.onnx new file mode 100644 index 0000000000000000000000000000000000000000..0e06a75a5a7e84e1fe2f090a4c7c6a513ed6344f GIT binary patch literal 333 zcmZ8cO-sZu6znTS!rR@_#t(`h$Q}zV9>l|5#9n3hD(j`BF={kv$|jZ7AK`D%f8tLw zrC#dc!MvF_j~Rk=ZrXNVh&|Jt607eJKLOze7i;F$y(;g7e0p|xU^!F5QrMo7QK>JM zvk`47>1<9AZZr6Ta6p?89b?Qm?{|#9*Gjwzl|{qB45P+d#wA5;l;N+nl^-HI_xftV zjV`t1J7dkGqbE*SS7`GfRH2#Ey}BIi`4s^INmxyzzMLWP|Cp1erRk(a*~qqo{K> o83n=5b@kV)3+@knYZ~L603{d_7^d;$_CHxg7$k9(;xuLgzX5qu+5i9m literal 0 HcmV?d00001 diff --git a/orttraining/orttraining/models/bert/main.cc b/orttraining/orttraining/models/bert/main.cc index c4c7a98ba116a..ec7a458237c77 100644 --- a/orttraining/orttraining/models/bert/main.cc +++ b/orttraining/orttraining/models/bert/main.cc @@ -42,6 +42,7 @@ static SessionOptions session_options = { ExecutionMode::ORT_SEQUENTIAL, // execution_mode ExecutionOrder::PRIORITY_BASED, // execution_order false, // enable_profiling + false, // save prepacked initializer ORT_TSTR(""), // optimized_model_filepath true, // enable_mem_pattern true, // enable_mem_reuse diff --git a/orttraining/orttraining/models/pipeline_poc/main.cc b/orttraining/orttraining/models/pipeline_poc/main.cc index 1b7d6b9ea26f6..0e40d04ddac8c 100644 --- a/orttraining/orttraining/models/pipeline_poc/main.cc +++ b/orttraining/orttraining/models/pipeline_poc/main.cc @@ -89,6 +89,7 @@ int main(int argc, char* argv[]) { ExecutionMode::ORT_SEQUENTIAL, // execution_mode ExecutionOrder::DEFAULT, // execution_order false, // enable_profiling + false, // save prepacked initializer ORT_TSTR(""), // optimized_model_filepath true, // enable_mem_pattern true, // enable_mem_reuse diff --git a/orttraining/orttraining/models/runner/training_runner.cc b/orttraining/orttraining/models/runner/training_runner.cc index dae6f613f4329..5a2f1cd13683e 100644 --- a/orttraining/orttraining/models/runner/training_runner.cc +++ b/orttraining/orttraining/models/runner/training_runner.cc @@ -37,6 +37,7 @@ static SessionOptions SESSION_OPTION = { ExecutionMode::ORT_SEQUENTIAL, // execution_mode ExecutionOrder::PRIORITY_BASED, // execution_order false, // enable_profiling + false, // save prepacked initializer ORT_TSTR(""), // optimized_model_filepath true, // enable_mem_pattern true, // enable_mem_reuse From 10bdf6e7977e19c064ddb914dc38eaa84f5ee4cc Mon Sep 17 00:00:00 2001 From: Kyle <92152685+idiskyle@users.noreply.github.com> Date: Fri, 25 Oct 2024 23:13:02 +0800 Subject: [PATCH 05/11] Fix Maven Sha256 Checksum Issue (#22600) ### Description **Changes applied to maven related signing:** * Windows sha256 file encoded by utf8(no BOM) * powershell script task used latest version, previous 5.1 version only supports utf8 with BOM. * Windows sha256 file content in format 'sha256value *filename.extension'. * Linux sha256 file content in format 'sha256value *filename.extension'. **More information about powershell encoding:** Windows powershell encoding reference: [about_Character_Encoding - PowerShell | Microsoft Learn](https://learn.microsoft.com/en-us/powershell/module/microsoft.powershell.core/about/about_character_encoding?view=powershell-7.4) - for version 5.1, it only has 'UTF8 Uses UTF-8 (with BOM).' - for version v7.1 and higher, it has: utf8: Encodes in UTF-8 format (no BOM). utf8BOM: Encodes in UTF-8 format with Byte Order Mark (BOM) utf8NoBOM: Encodes in UTF-8 format without Byte Order Mark (BOM) --- .../templates/jar-maven-signing-linux.yml | 3 ++- .../templates/jar-maven-signing-win.yml | 20 +++++++++++++------ 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/templates/jar-maven-signing-linux.yml b/tools/ci_build/github/azure-pipelines/templates/jar-maven-signing-linux.yml index ca7e3f6148e26..d14952e544e5e 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jar-maven-signing-linux.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jar-maven-signing-linux.yml @@ -45,7 +45,8 @@ steps: for file in $(find $jar_file_directory -type f); do echo "Adding checksum of sha256 to file: $file" - sha256sum $file | awk '{print $1}' >$file.sha256 + sha256_value=$(sha256sum $file | awk '{print $1}') + echo $sha256_value" *"$(basename "$file") >$file.sha256 echo "Added checksum of sha256 to file: $file" done diff --git a/tools/ci_build/github/azure-pipelines/templates/jar-maven-signing-win.yml b/tools/ci_build/github/azure-pipelines/templates/jar-maven-signing-win.yml index 182a2ebe3b4c9..5681b3568bae1 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jar-maven-signing-win.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jar-maven-signing-win.yml @@ -15,6 +15,7 @@ steps: displayName: 'Sign jar files: GnuPG and sha256' inputs: targetType: 'inline' + pwsh: true workingDirectory: '$(Build.SourcesDirectory)' script: | $jar_file_directory = '${{ parameters.JarFileDirectory }}' @@ -53,15 +54,22 @@ steps: Write-Host "GnuPG signed to file: "$file_path } + $PSDefaultParameterValues['Out-File:Encoding'] = 'utf8NoBOM' + $sha256sum_exe_path = "C:\Program Files\Git\usr\bin\sha256sum.exe" $targeting_asc_files = Get-ChildItem $jar_file_directory -Recurse -Force -File -Name + $original_location = Get-Location + Set-Location $jar_file_directory foreach ($file in $targeting_asc_files) { - $file_path = Join-Path $jar_file_directory -ChildPath $file - Write-Host "Adding checksum of sha256 to file: "$file_path - $file_path_sha256 = $file_path + ".sha256" - CertUtil -hashfile $file_path SHA256 - CertUtil -hashfile $file_path SHA256 | find /v `"hash`" | Out-File -FilePath $file_path_sha256 - Write-Host "Added checksum of sha256 to file: "$file_path + Write-Host "Adding checksum of sha256 to file: "$file + $file_path_sha256 = $file + ".sha256" + & $sha256sum_exe_path $file 1>$file_path_sha256 + if ($lastExitCode -ne 0) { + Write-Host -Object "sha256sum command failed. Exitcode: $exitCode" + exit $lastExitCode + } + Write-Host "Added checksum of sha256 to file: "$file } + Set-Location $original_location Write-Host "GnuPG and sha256 signing to files completed." Write-Host "Deleting GnuPG key files." From 6ea9065b833179a8dcf14cddb39dc4fb3dda81c6 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Fri, 25 Oct 2024 09:18:30 -0700 Subject: [PATCH 06/11] Add an 1ES PT baseline file (#22587) This branch is auto-generated by microsoft-github-policy-service[bot] --- .../1espt/PipelineAutobaseliningConfig.yml | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 .config/1espt/PipelineAutobaseliningConfig.yml diff --git a/.config/1espt/PipelineAutobaseliningConfig.yml b/.config/1espt/PipelineAutobaseliningConfig.yml new file mode 100644 index 0000000000000..daa9b73d5971a --- /dev/null +++ b/.config/1espt/PipelineAutobaseliningConfig.yml @@ -0,0 +1,34 @@ +## DO NOT MODIFY THIS FILE MANUALLY. This is part of auto-baselining from 1ES Pipeline Templates. Go to [https://aka.ms/1espt-autobaselining] for more details. + +pipelines: + 1624: + retail: + source: + credscan: + lastModifiedDate: 2024-10-24 + policheck: + lastModifiedDate: 2024-10-24 + eslint: + lastModifiedDate: 2024-10-24 + psscriptanalyzer: + lastModifiedDate: 2024-10-24 + armory: + lastModifiedDate: 2024-10-24 + 1299: + retail: + source: + credscan: + lastModifiedDate: 2024-10-25 + eslint: + lastModifiedDate: 2024-10-25 + psscriptanalyzer: + lastModifiedDate: 2024-10-25 + armory: + lastModifiedDate: 2024-10-25 + binary: + credscan: + lastModifiedDate: 2024-10-25 + binskim: + lastModifiedDate: 2024-10-25 + spotbugs: + lastModifiedDate: 2024-10-25 From 5b4e2a636b77978c4742e73057182540254f25e3 Mon Sep 17 00:00:00 2001 From: dtang317 Date: Fri, 25 Oct 2024 09:21:19 -0700 Subject: [PATCH 07/11] DML EP Register Opset 21 (#22547) ### Description This PR registers the following opset 21 operators: - Size-21 - CastLike-21 - ConstantOfShape-21 - Flatten-21 - Pad-21 - Transpose-21 ### Motivation and Context --- docs/OperatorKernels.md | 18 ++++++++++++------ .../src/Operators/DmlOperatorCast.cpp | 1 + .../src/Operators/DmlOperatorPadding.cpp | 1 + .../src/Operators/OperatorRegistration.cpp | 8 ++++++++ .../dml/OperatorAuthorHelper/OperatorHelper.h | 3 +++ .../OperatorAuthorHelper/OperatorVersions.h | 6 ++++++ 6 files changed, 31 insertions(+), 6 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index ddf37cfded77d..bd886abc98a89 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -969,7 +969,8 @@ Do not modify directly.* |||13+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||9+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||6+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|CastLike|*in* input:**T1**
*in* target_type:**T2**
*out* output:**T2**|19+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|CastLike|*in* input:**T1**
*in* target_type:**T2**
*out* output:**T2**|21+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||19+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||15+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Ceil|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(float), tensor(float16)| |||6+|**T** = tensor(float), tensor(float16)| @@ -983,7 +984,8 @@ Do not modify directly.* |||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||4+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |ConcatFromSequence|*in* input_sequence:**S**
*out* concat_result:**T**|11+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|ConstantOfShape|*in* input:**T1**
*out* output:**T2**|9+|**T1** = tensor(int64)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|ConstantOfShape|*in* input:**T1**
*out* output:**T2**|21+|**T1** = tensor(int64)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||9+|**T1** = tensor(int64)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Conv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|11+|**T** = tensor(float), tensor(float16)| |||1+|**T** = tensor(float), tensor(float16)| |ConvInteger|*in* x:**T1**
*in* w:**T2**
*in* x_zero_point:**T1**
*in* w_zero_point:**T2**
*out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int32)| @@ -1021,7 +1023,8 @@ Do not modify directly.* |Expand|*in* input:**T**
*in* shape:**tensor(int64)**
*out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||8+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |EyeLike|*in* input:**T1**
*out* output:**T2**|9+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|Flatten|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Flatten|*in* input:**T**
*out* output:**T**|21+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||9+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||1+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| @@ -1141,7 +1144,8 @@ Do not modify directly.* |PRelu|*in* X:**T**
*in* slope:**T**
*out* Y:**T**|16+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8)| |||9+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8)| |||7+|**T** = tensor(float), tensor(float16)| -|Pad|*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*in* axes:**Tind**
*out* output:**T**

or

*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*out* output:**T**

or

*in* data:**T**
*out* output:**T**|19+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Pad|*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*in* axes:**Tind**
*out* output:**T**

or

*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*out* output:**T**

or

*in* data:**T**
*out* output:**T**|21+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||19+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||18+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| @@ -1253,7 +1257,8 @@ Do not modify directly.* |SimplifiedLayerNormalization|*in* X:**T**
*in* scale:**V**
*out* Y:**V**
*out* inv_std_var:**U**|1+|**T** = tensor(float), tensor(float16)
**U** = tensor(float), tensor(float16)
**V** = tensor(float), tensor(float16)| |Sin|*in* input:**T**
*out* output:**T**|7+|**T** = tensor(float), tensor(float16)| |Sinh|*in* input:**T**
*out* output:**T**|9+|**T** = tensor(float), tensor(float16)| -|Size|*in* data:**T**
*out* size:**T1**|19+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +|Size|*in* data:**T**
*out* size:**T1**|21+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +|||19+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |||13+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |||1+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |Slice|*in* data:**T**
*in* starts:**Tind**
*in* ends:**Tind**
*in* axes:**Tind**
*in* steps:**Tind**
*out* output:**T**

or

*in* data:**T**
*out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| @@ -1293,7 +1298,8 @@ Do not modify directly.* |TopK|*in* X:**T**
*in* K:**tensor(int64)**
*out* Values:**T**
*out* Indices:**I**

or

*in* X:**T**
*out* Values:**T**
*out* Indices:**I**|11+|**I** = tensor(int64)
**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||10+|**I** = tensor(int64)
**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||1+|**I** = tensor(int64)
**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|Transpose|*in* data:**T**
*out* transposed:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Transpose|*in* data:**T**
*out* transposed:**T**|21+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||1+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Trilu|*in* input:**T**
*in* k:**tensor(int64)**
*out* output:**T**|14+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Unsqueeze|*in* data:**T**
*in* axes:**tensor(int64)**
*out* expanded:**T**

or

*in* data:**T**
*out* expanded:**T**|21+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCast.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCast.cpp index 45ff25c4fdd90..02fb72b5a073a 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCast.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCast.cpp @@ -50,5 +50,6 @@ class DmlOperatorCast : public DmlOperator DML_OP_DEFINE_CREATION_FUNCTION(Cast, DmlOperatorCast); DML_OP_DEFINE_CREATION_FUNCTION(CastLike15, DmlOperatorCast); DML_OP_DEFINE_CREATION_FUNCTION(CastLike19, DmlOperatorCast); +DML_OP_DEFINE_CREATION_FUNCTION(CastLike21, DmlOperatorCast); } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPadding.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPadding.cpp index 9b7ad9aa9e088..f8710fd266c07 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPadding.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPadding.cpp @@ -123,5 +123,6 @@ DML_OP_DEFINE_CREATION_FUNCTION(Pad11, VersionedKernel); DML_OP_DEFINE_CREATION_FUNCTION(Pad13, VersionedKernel); DML_OP_DEFINE_CREATION_FUNCTION(Pad18, VersionedKernel); DML_OP_DEFINE_CREATION_FUNCTION(Pad19, VersionedKernel); +DML_OP_DEFINE_CREATION_FUNCTION(Pad21, VersionedKernel); } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index 2375131cb34ea..ceed388bb0a6f 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -365,6 +365,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(Pad11); DML_OP_EXTERN_CREATION_FUNCTION(Pad13); DML_OP_EXTERN_CREATION_FUNCTION(Pad18); DML_OP_EXTERN_CREATION_FUNCTION(Pad19); +DML_OP_EXTERN_CREATION_FUNCTION(Pad21); DML_OP_EXTERN_CREATION_FUNCTION(SpaceToDepth); DML_OP_EXTERN_CREATION_FUNCTION(DepthToSpace); DML_OP_EXTERN_CREATION_FUNCTION(Sqrt); @@ -445,6 +446,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(DynamicQuantizeMatMul); DML_OP_EXTERN_CREATION_FUNCTION(Cast); DML_OP_EXTERN_CREATION_FUNCTION(CastLike15); DML_OP_EXTERN_CREATION_FUNCTION(CastLike19); +DML_OP_EXTERN_CREATION_FUNCTION(CastLike21); DML_OP_EXTERN_CREATION_FUNCTION(MemcpyFromHost); DML_OP_EXTERN_CREATION_FUNCTION(MemcpyToHost); DML_OP_EXTERN_CREATION_FUNCTION(TopK7); @@ -792,6 +794,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO_VER( 18, Split, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, {REG_INFO( 7, Transpose, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO( 13, Transpose, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, + {REG_INFO( 21, Transpose, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO( 7, Concat, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO( 11, Concat, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, // Adds negative axis. {REG_INFO( 13, Concat, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, // Adds negative axis. @@ -804,6 +807,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO_VER( 11, Pad, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2) /*pads, value*/)}, // https://microsoft.visualstudio.com/OS/_workitems/edit/26007728 {REG_INFO_VER( 13, Pad, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2) /*pads, value*/)}, // https://microsoft.visualstudio.com/OS/_workitems/edit/26007728 {REG_INFO_VER( 18, Pad, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*pads, value, axes*/)}, + {REG_INFO_VER( 21, Pad, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*pads, value, axes*/)}, #if DML_TARGET_VERSION >= 0x6400 {REG_INFO_VER( 19, Pad, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*pads, value, axes*/)}, @@ -819,6 +823,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 8, Expand, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, {REG_INFO( 13, Expand, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, {REG_INFO( 9, ConstantOfShape, typeNameListConstantOfShape, supportedTypeListConstantOfShape, DmlGraphSupport::Supported, requiredConstantCpuInputs(0))}, + {REG_INFO( 21, ConstantOfShape, typeNameListConstantOfShape, supportedTypeListConstantOfShape, DmlGraphSupport::Supported, requiredConstantCpuInputs(0))}, {REG_INFO( 7, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported)}, {REG_INFO( 11, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported)}, {REG_INFO( 13, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported)}, @@ -853,6 +858,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO_COPY( 9, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO_COPY(11, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO_COPY(13, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, + {REG_INFO_COPY(21, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO_COPY( 7, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO_COPY(11, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO_COPY(13, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, @@ -1087,6 +1093,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 21, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)}, {REG_INFO_VER( 15, CastLike, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)}, {REG_INFO_VER( 19, CastLike, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)}, + {REG_INFO_VER( 21, CastLike, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)}, {REG_INFO( 7, MemcpyFromHost, typeNameListDefault, supportedTypeListAll)}, {REG_INFO( 7, MemcpyToHost, typeNameListDefault, supportedTypeListAll)}, {REG_INFO_VER( 7, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported)}, @@ -1102,6 +1109,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 7, Size, typeNameSize, supportedTypeListSize, DmlGraphSupport::NotSupported)}, {REG_INFO( 13, Size, typeNameSize, supportedTypeListSize, DmlGraphSupport::NotSupported)}, {REG_INFO( 19, Size, typeNameSize, supportedTypeListSize, DmlGraphSupport::NotSupported)}, + {REG_INFO( 21, Size, typeNameSize, supportedTypeListSize, DmlGraphSupport::NotSupported)}, {REG_INFO_DYNAMIC_OUTPUTS( 9, NonZero, typeNameListDefault, supportedTypeListNonZero, DmlGraphSupport::NotSupported)}, {REG_INFO_DYNAMIC_OUTPUTS(13, NonZero, typeNameListDefault, supportedTypeListNonZero, DmlGraphSupport::NotSupported)}, diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h index 323fcc779d98d..c1ea69ab35374 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h @@ -1673,6 +1673,7 @@ using ShapeInferenceHelper_Flatten7 = FlattenHelper; using ShapeInferenceHelper_Flatten9 = FlattenHelper; using ShapeInferenceHelper_Flatten11 = FlattenHelper; using ShapeInferenceHelper_Flatten13 = FlattenHelper; +using ShapeInferenceHelper_Flatten21 = FlattenHelper; using ShapeInferenceHelper_Split7 = VersionedOpsetHelper; using ShapeInferenceHelper_Split11 = VersionedOpsetHelper; using ShapeInferenceHelper_Split13 = VersionedOpsetHelper; @@ -1689,6 +1690,7 @@ using ShapeInferenceHelper_Pad11 = VersionedOpsetHelper; using ShapeInferenceHelper_Pad13 = VersionedOpsetHelper; using ShapeInferenceHelper_Pad18 = VersionedOpsetHelper; using ShapeInferenceHelper_Pad19 = VersionedOpsetHelper; +using ShapeInferenceHelper_Pad21 = VersionedOpsetHelper; using ShapeInferenceHelper_SpaceToDepth = SpaceToDepthHelper; using ShapeInferenceHelper_DepthToSpace = DepthToSpaceHelper; @@ -1865,6 +1867,7 @@ using ShapeInferenceHelper_Range = RangeHelper; using ShapeInferenceHelper_CastLike15 = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_CastLike19 = GetOutputShapeAsInputShapeHelper; +using ShapeInferenceHelper_CastLike21 = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_DmlFusedConv = ConvHelper; using ShapeInferenceHelper_DmlFusedConvTranspose = ConvTransposeHelper; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h index 26529c0d59dd6..c2a6d57fca0a9 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h @@ -446,6 +446,12 @@ namespace OperatorHelper static const int sc_sinceVer_Reshape = 21; static const int sc_sinceVer_Cast = 21; static const int sc_sinceVer_Shape = 21; + static const int sc_sinceVer_Size = 21; + static const int sc_sinceVer_CastLike = 21; + static const int sc_sinceVer_ConstantOfShape = 21; + static const int sc_sinceVer_Flatten = 21; + static const int sc_sinceVer_Pad = 21; + static const int sc_sinceVer_Transpose = 21; } namespace MsftOperatorSet1 From 7acbd51912b3f24f473df0047284ddf216968832 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 25 Oct 2024 10:03:43 -0700 Subject: [PATCH 08/11] Bump onnx from 1.16.1 to 1.17.0 in /tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/scripts (#22593) Bumps [onnx](https://github.com/onnx/onnx) from 1.16.1 to 1.17.0.
Release notes

Sourced from onnx's releases.

v1.17.0

ONNX v1.17.0 is now available with exciting new features! We would like to thank everyone who contributed to this release! Please visit onnx.ai to learn more about ONNX and associated projects.

Key Updates

ai.onnx Opset 22

Python Changes

  • Support for numpy >= 2.0

Bug fixes and infrastructure improvements

  • Fix Check URLs errors 5972
  • Use CMAKE_PREFIX_PATH in finding libprotobuf 5975
  • Bump main VERSION_NUMBER to 1.17.0 5968
  • Fix source and pip tar.gz builds on s390x systems 5984
  • Fix unique_name 5992
  • Fix SegFault bug in shape inference 5990
  • Fix onnx.compose when connecting subgraphs 5991
  • Fix conversion from split 11 to split 18 6020
  • Update error messages for NegativeLogLikelihoodLoss inference function 6021
  • Generalize input/output number check in shape inference 6005
  • Replace rank inference with shape inference for Einsum op 6010
  • build from source instruction with latest cmake change 6038
  • Handle OneHot's depth value during shape inference 5963
  • Not to install cmake in pyproject.toml on Windows 6045
  • fix a skipped shape infer code 6049
  • Include the ".onnxtext" extension in supported serialization format 6051
  • Allow ReferenceEvaluator to return intermediate results 6066
  • Fix 1 typo in numpy_helper.py 6041
  • Remove benchmarking code 6076
  • Prevent crash on import after GCC 8 builds 6048
  • Check graph outputs are defined 6083
  • Enable additional ruff rules 6032
  • Add missing shape inference check for DequantizeLinear 6080
  • Add bfloat16 to all relevant ops 6099
  • fix(ci): install python dependencies with --only-binary :all: in manylinux 6120
  • fix: install google-re2 with --only-binary option 6129
  • Specify axis parameter for DequantizeLinear when input rank is 1 6095
  • Pin onnxruntime to 1.17.3 for release CIs 6143
  • Fix INT4 TensorProto byte size is 5x larger than expected with negative values 6161
  • Mitigate tarball directory traversal risks 6164
  • Fix reference implementation for ScatterND with 4D tensors 6174
  • Addition of group > 1 in test and in backend for ConvTranspose 6175
  • Support for bfloat16 for binary, unary operators in reference implementation 6166
  • Refactor windows workflow to work on standard windows 6190
  • Fix a few crashes while running shape inference 6195
  • Update onnx to work with numpy>=2.0 6196
  • Use sets to improve performance of dfs search 6213

... (truncated)

Commits

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=onnx&package-manager=pip&previous-version=1.16.1&new-version=1.17.0)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself) You can disable automated security fix PRs for this repo from the [Security Alerts page](https://github.com/microsoft/onnxruntime/network/alerts).
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .../inference/aarch64/python/cpu/scripts/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/scripts/requirements.txt b/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/scripts/requirements.txt index 07a9f3f481aa8..a0c9a4326aec3 100644 --- a/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/scripts/requirements.txt +++ b/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/scripts/requirements.txt @@ -4,7 +4,7 @@ mypy pytest setuptools>=68.2.2 wheel -onnx==1.16.1 +onnx==1.17.0 protobuf==4.21.12 sympy==1.12 flatbuffers From 28efacfd5a9f3421ee2aae8759f0b14eb96bd338 Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Fri, 25 Oct 2024 13:19:59 -0500 Subject: [PATCH 09/11] [MigraphX] Fix potential synchronization problem when ORT_ENABLE_STREAM is true (#22589) ### Description Replace `hipMemcpy` with `hipMemcpyWithStream` ### Motivation and Context `hipMemcpy` uses default stream, which may be out of synchronization with the current stream when ORT_ENABLE_STREAM is defined. --- onnxruntime/core/providers/migraphx/gpu_data_transfer.cc | 2 +- .../core/providers/migraphx/migraphx_execution_provider.cc | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc b/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc index 94480c308b99f..51625b83b8f61 100644 --- a/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc +++ b/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc @@ -57,7 +57,7 @@ common::Status GPUDataTransfer::CopyTensorAsync(const Tensor& src, Tensor& dst, HIP_CALL_THROW(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToDevice, static_cast(stream.GetHandle()))); } else { // copy from other CPU memory to GPU, this is blocking - HIP_CALL_THROW(hipMemcpy(dst_data, src_data, bytes, hipMemcpyHostToDevice)); + HIP_CALL_THROW(hipMemcpyWithStream(dst_data, src_data, bytes, hipMemcpyHostToDevice, static_cast(stream.GetHandle()))); } } else if (src_device.Type() == OrtDevice::GPU) { HIP_CALL_THROW(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToHost, static_cast(stream.GetHandle()))); diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index e41cd577b0b21..dca38480434fe 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -1445,7 +1445,11 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& std::vector ort_shape{res_lens.begin(), res_lens.end()}; auto output_tensor = ctx.GetOutput(i, ort_shape.data(), ort_shape.size()); void* output_data = output_tensor.GetTensorMutableRawData(); - HIP_CALL_THROW(hipMemcpy(output_data, gpu_res.data(), res_shape.bytes(), hipMemcpyDeviceToDevice)); + HIP_CALL_THROW(hipMemcpyWithStream(output_data, + gpu_res.data(), + res_shape.bytes(), + hipMemcpyDeviceToDevice, + static_cast(rocm_stream))); } } }; From b4afc6266f7ff20e7b79eaea7fa62f3e30b7474f Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 25 Oct 2024 11:47:16 -0700 Subject: [PATCH 10/11] [ROCm] Python 3.10 in ROCm CI, and ROCm 6.2.3 in MigraphX CI (#22527) ### Description Upgrade python from 3.9 to 3.10 in ROCm and MigraphX docker files and CI pipelines. Upgrade ROCm version to 6.2.3 in most places except ROCm CI, see comment below. Some improvements/upgrades on ROCm/Migraphx docker or pipeline: * rocm 6.0/6.1.3 => 6.2.3 * python 3.9 => 3.10 * Ubuntu 20.04 => 22.04 * Also upgrade ml_dtypes, numpy and scipy packages. * Fix message "ROCm version from ..." with correct file path in CMakeList.txt * Exclude some NHWC tests since ROCm EP lacks support for NHWC convolution. #### ROCm CI Pipeline: ROCm 6.1.3 is kept in the pipeline for now. - Failed after upgrading to ROCm 6.2.3: `HIPBLAS_STATUS_INVALID_VALUE ; GPU=0 ; hostname=76123b390aed ; file=/onnxruntime_src/onnxruntime/core/providers/rocm/rocm_execution_provider.cc ; line=170 ; expr=hipblasSetStream(hipblas_handle_, stream);` . It need further investigation. - cupy issues: (1) It currently supports numpy < 1.27, might not work with numpy 2.x. So we locked numpy==1.26.4 for now. (2) cupy support of ROCm 6.2 is still in progress: https://github.com/cupy/cupy/issues/8606. Note that miniconda issues: its libstdc++.so.6 and libgcc_s.so.1 might have conflict with the system ones. So we created links to use the system ones. #### MigraphX CI pipeline MigraphX CI does not use cupy, and we are able to use ROCm 6.2.3 and numpy 2.x in the pipeline. #### Other attempts Other things that I've tried which might help in the future: Attempt to use a single docker file for both ROCm and Migraphx: https://github.com/microsoft/onnxruntime/pull/22478 Upgrade to ubuntu 24.04 and python 3.12, and use venv like [this](https://github.com/microsoft/onnxruntime/blob/27903e7ff1dd7256cd2b277c03766b4f2ad9e2f1/tools/ci_build/github/linux/docker/rocm-ci-pipeline-env.Dockerfile). ### Motivation and Context In 1.20 release, ROCm nuget packaging pipeline will use 6.2: https://github.com/microsoft/onnxruntime/pull/22461. This upgrades rocm to 6.2.3 in CI pipelines to be consistent. --- cmake/CMakeLists.txt | 69 +++++++++++-------- dockerfiles/Dockerfile.migraphx | 2 +- dockerfiles/Dockerfile.rocm | 2 +- dockerfiles/README.md | 4 +- .../internal_testing_tests.cc | 6 +- .../linux-migraphx-ci-pipeline.yml | 8 +-- .../linux-rocm-ci-pipeline.yml | 14 ++-- .../docker/Dockerfile.manylinux2_28_rocm | 2 +- .../migraphx-ci-pipeline-env.Dockerfile | 6 +- .../docker/rocm-ci-pipeline-env.Dockerfile | 16 +++-- .../docker/scripts/setup_rocm_yum_repo.sh | 2 +- 11 files changed, 70 insertions(+), 61 deletions(-) diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 9d1b39143016b..1070627d5e7da 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -291,12 +291,50 @@ if (onnxruntime_USE_ROCM) message(FATAL_ERROR "ROCM does not support build with CUDA!") endif() + # replicate strategy used by pytorch to get ROCM_VERSION + # https://github.com/pytorch/pytorch/blob/5c5b71b6eebae76d744261715231093e62f0d090/cmake/public/LoadHIP.cmake + # with modification + if (EXISTS "${onnxruntime_ROCM_HOME}/.info/version") + message("\n***** ROCm version from ${onnxruntime_ROCM_HOME}/.info/version ****\n") + file(READ "${onnxruntime_ROCM_HOME}/.info/version" ROCM_VERSION_DEV_RAW) + string(REGEX MATCH "^([0-9]+)\.([0-9]+)\.([0-9]+)-.*$" ROCM_VERSION_MATCH ${ROCM_VERSION_DEV_RAW}) + elseif (EXISTS "${onnxruntime_ROCM_HOME}/include/rocm_version.h") + message("\n***** ROCm version from ${onnxruntime_ROCM_HOME}/include/rocm_version.h ****\n") + file(READ "${onnxruntime_ROCM_HOME}/include/rocm_version.h" ROCM_VERSION_H_RAW) + string(REGEX MATCH "\"([0-9]+)\.([0-9]+)\.([0-9]+).*\"" ROCM_VERSION_MATCH ${ROCM_VERSION_H_RAW}) + elseif (EXISTS "${onnxruntime_ROCM_HOME}/include/rocm-core/rocm_version.h") + message("\n***** ROCm version from ${onnxruntime_ROCM_HOME}/include/rocm-core/rocm_version.h ****\n") + file(READ "${onnxruntime_ROCM_HOME}/include/rocm-core/rocm_version.h" ROCM_VERSION_H_RAW) + string(REGEX MATCH "\"([0-9]+)\.([0-9]+)\.([0-9]+).*\"" ROCM_VERSION_MATCH ${ROCM_VERSION_H_RAW}) + endif() + + if (ROCM_VERSION_MATCH) + set(ROCM_VERSION_DEV_MAJOR ${CMAKE_MATCH_1}) + set(ROCM_VERSION_DEV_MINOR ${CMAKE_MATCH_2}) + set(ROCM_VERSION_DEV_PATCH ${CMAKE_MATCH_3}) + set(ROCM_VERSION_DEV "${ROCM_VERSION_DEV_MAJOR}.${ROCM_VERSION_DEV_MINOR}.${ROCM_VERSION_DEV_PATCH}") + math(EXPR ROCM_VERSION_DEV_INT "(${ROCM_VERSION_DEV_MAJOR}*10000) + (${ROCM_VERSION_DEV_MINOR}*100) + ${ROCM_VERSION_DEV_PATCH}") + + message("ROCM_VERSION_DEV: ${ROCM_VERSION_DEV}") + message("ROCM_VERSION_DEV_MAJOR: ${ROCM_VERSION_DEV_MAJOR}") + message("ROCM_VERSION_DEV_MINOR: ${ROCM_VERSION_DEV_MINOR}") + message("ROCM_VERSION_DEV_PATCH: ${ROCM_VERSION_DEV_PATCH}") + message("ROCM_VERSION_DEV_INT: ${ROCM_VERSION_DEV_INT}") + else() + message(FATAL_ERROR "Cannot determine ROCm version string") + endif() + + if (NOT CMAKE_HIP_COMPILER) set(CMAKE_HIP_COMPILER "${onnxruntime_ROCM_HOME}/llvm/bin/clang++") endif() if (NOT CMAKE_HIP_ARCHITECTURES) - set(CMAKE_HIP_ARCHITECTURES "gfx908;gfx90a;gfx1030;gfx1100;gfx1101;gfx940;gfx941;gfx942;gfx1200;gfx1201") + if (ROCM_VERSION_DEV VERSION_LESS "6.2") + message(FATAL_ERROR "CMAKE_HIP_ARCHITECTURES is not set when ROCm version < 6.2") + else() + set(CMAKE_HIP_ARCHITECTURES "gfx908;gfx90a;gfx1030;gfx1100;gfx1101;gfx940;gfx941;gfx942;gfx1200;gfx1201") + endif() endif() file(GLOB rocm_cmake_components ${onnxruntime_ROCM_HOME}/lib/cmake/*) @@ -328,35 +366,6 @@ if (onnxruntime_USE_ROCM) set(onnxruntime_HIPIFY_PERL ${HIPIFY_PERL_PATH}/hipify-perl) endif() - # replicate strategy used by pytorch to get ROCM_VERSION - # https://github.com/pytorch/pytorch/blob/5c5b71b6eebae76d744261715231093e62f0d090/cmake/public/LoadHIP.cmake - # with modification - if (EXISTS "${onnxruntime_ROCM_HOME}/.info/version") - file(READ "${onnxruntime_ROCM_HOME}/.info/version" ROCM_VERSION_DEV_RAW) - string(REGEX MATCH "^([0-9]+)\.([0-9]+)\.([0-9]+)-.*$" ROCM_VERSION_MATCH ${ROCM_VERSION_DEV_RAW}) - elseif (EXISTS "${onnxruntime_ROCM_HOME}/include/rocm_version.h") - file(READ "${onnxruntime_ROCM_HOME}/include/rocm_version.h" ROCM_VERSION_H_RAW) - string(REGEX MATCH "\"([0-9]+)\.([0-9]+)\.([0-9]+).*\"" ROCM_VERSION_MATCH ${ROCM_VERSION_H_RAW}) - elseif (EXISTS "${onnxruntime_ROCM_HOME}/include/rocm-core/rocm_version.h") - file(READ "${onnxruntime_ROCM_HOME}/include/rocm-core/rocm_version.h" ROCM_VERSION_H_RAW) - string(REGEX MATCH "\"([0-9]+)\.([0-9]+)\.([0-9]+).*\"" ROCM_VERSION_MATCH ${ROCM_VERSION_H_RAW}) - endif() - - if (ROCM_VERSION_MATCH) - set(ROCM_VERSION_DEV_MAJOR ${CMAKE_MATCH_1}) - set(ROCM_VERSION_DEV_MINOR ${CMAKE_MATCH_2}) - set(ROCM_VERSION_DEV_PATCH ${CMAKE_MATCH_3}) - set(ROCM_VERSION_DEV "${ROCM_VERSION_DEV_MAJOR}.${ROCM_VERSION_DEV_MINOR}.${ROCM_VERSION_DEV_PATCH}") - math(EXPR ROCM_VERSION_DEV_INT "(${ROCM_VERSION_DEV_MAJOR}*10000) + (${ROCM_VERSION_DEV_MINOR}*100) + ${ROCM_VERSION_DEV_PATCH}") - else() - message(FATAL_ERROR "Cannot determine ROCm version string") - endif() - message("\n***** ROCm version from ${onnxruntime_ROCM_HOME}/.info/version ****\n") - message("ROCM_VERSION_DEV: ${ROCM_VERSION_DEV}") - message("ROCM_VERSION_DEV_MAJOR: ${ROCM_VERSION_DEV_MAJOR}") - message("ROCM_VERSION_DEV_MINOR: ${ROCM_VERSION_DEV_MINOR}") - message("ROCM_VERSION_DEV_PATCH: ${ROCM_VERSION_DEV_PATCH}") - message("ROCM_VERSION_DEV_INT: ${ROCM_VERSION_DEV_INT}") message("\n***** HIP LANGUAGE CONFIG INFO ****\n") message("CMAKE_HIP_COMPILER: ${CMAKE_HIP_COMPILER}") message("CMAKE_HIP_ARCHITECTURES: ${CMAKE_HIP_ARCHITECTURES}") diff --git a/dockerfiles/Dockerfile.migraphx b/dockerfiles/Dockerfile.migraphx index c3541a8bd3425..c5d998d503899 100644 --- a/dockerfiles/Dockerfile.migraphx +++ b/dockerfiles/Dockerfile.migraphx @@ -5,7 +5,7 @@ # Dockerfile to run ONNXRuntime with MIGraphX integration #-------------------------------------------------------------------------- -FROM rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1 +FROM rocm/pytorch:rocm6.2.3_ubuntu22.04_py3.10_pytorch_release_2.3.0 ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime ARG ONNXRUNTIME_BRANCH=main diff --git a/dockerfiles/Dockerfile.rocm b/dockerfiles/Dockerfile.rocm index c242933f677f0..bef8d7a5f47d2 100644 --- a/dockerfiles/Dockerfile.rocm +++ b/dockerfiles/Dockerfile.rocm @@ -5,7 +5,7 @@ # Dockerfile to run ONNXRuntime with ROCm integration #-------------------------------------------------------------------------- -FROM rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1 +FROM rocm/pytorch:rocm6.2.3_ubuntu22.04_py3.10_pytorch_release_2.3.0 ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime ARG ONNXRUNTIME_BRANCH=main diff --git a/dockerfiles/README.md b/dockerfiles/README.md index 7825940571769..9f83fc390eee7 100644 --- a/dockerfiles/README.md +++ b/dockerfiles/README.md @@ -292,7 +292,7 @@ Nothing else from ONNX Runtime source tree will be copied/installed to the image Note: When running the container you built in Docker, please either use 'nvidia-docker' command instead of 'docker', or use Docker command-line options to make sure NVIDIA runtime will be used and appropriate files mounted from host. Otherwise, CUDA libraries won't be found. You can also [set NVIDIA runtime as default in Docker](https://github.com/dusty-nv/jetson-containers#docker-default-runtime). ## MIGraphX -**Ubuntu 20.04, ROCm6.0, MIGraphX** +**Ubuntu 22.04, ROCm6.2.3, MIGraphX** 1. Build the docker image from the Dockerfile in this repository. ``` @@ -306,7 +306,7 @@ Note: When running the container you built in Docker, please either use 'nvidia- ``` ## ROCm -**Ubuntu 20.04, ROCm6.0** +**Ubuntu 22.04, ROCm6.2.3** 1. Build the docker image from the Dockerfile in this repository. ``` diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc b/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc index 67fb35d26e6dc..559b521f18782 100644 --- a/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc +++ b/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc @@ -159,7 +159,7 @@ TEST(InternalTestingEP, PreventSaveOfModelWithCompiledOps) { // the internal NHWC operators are only included as part of contrib ops currently. as the EP requests the NHWC // version of the ONNX operator when matching a static kernel, those are required. -#if !defined(DISABLE_CONTRIB_OPS) +#if !defined(DISABLE_CONTRIB_OPS) && !defined(USE_ROCM) TEST(InternalTestingEP, TestMixOfStaticAndCompiledKernels) { const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "transform/fusion/conv_relu_opset12.onnx"; @@ -256,10 +256,6 @@ TEST(InternalTestingEP, TestNhwcConversionOfStaticKernels) { run_test(ort_model_path); } -// This test can be deprecated now as the code logic has been changed so the model is not applicable -// TEST(InternalTestingEP, TestRegisterAllocatorHandlesUsageInMultipleSessions) { -//} - // make sure allocators returned by SessionState::GetAllocator are valid when IExecutionProvider::ReplaceAllocator // is used. if something is off InferenceSession::Initialize will fail. TEST(InternalTestingEP, TestReplaceAllocatorDoesntBreakDueToLocalAllocatorStorage) { diff --git a/tools/ci_build/github/azure-pipelines/linux-migraphx-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-migraphx-ci-pipeline.yml index 1cf60b47b4ded..9e2d8e49a2292 100644 --- a/tools/ci_build/github/azure-pipelines/linux-migraphx-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-migraphx-ci-pipeline.yml @@ -37,9 +37,7 @@ variables: - name: render value: 109 - name: RocmVersion - value: 6.1 - - name: RocmVersionPatchSuffix - value: ".3" + value: 6.2.3 jobs: - job: Linux_Build @@ -66,7 +64,7 @@ jobs: parameters: Dockerfile: tools/ci_build/github/linux/docker/migraphx-ci-pipeline-env.Dockerfile Context: tools/ci_build/github/linux/docker - DockerBuildArgs: "--build-arg ROCM_VERSION=$(RocmVersion)$(RocmVersionPatchSuffix)" + DockerBuildArgs: "--build-arg ROCM_VERSION=$(RocmVersion)" Repository: onnxruntimetrainingmigraphx-cibuild-rocm$(RocmVersion) - task: Cache@2 @@ -165,7 +163,7 @@ jobs: parameters: Dockerfile: tools/ci_build/github/linux/docker/migraphx-ci-pipeline-env.Dockerfile Context: tools/ci_build/github/linux/docker - DockerBuildArgs: "--build-arg ROCM_VERSION=$(RocmVersion)$(RocmVersionPatchSuffix)" + DockerBuildArgs: "--build-arg ROCM_VERSION=$(RocmVersion)" Repository: onnxruntimetrainingmigraphx-cibuild-rocm$(RocmVersion) - task: CmdLine@2 diff --git a/tools/ci_build/github/azure-pipelines/linux-rocm-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-rocm-ci-pipeline.yml index 50f3862761320..c730cc2548038 100644 --- a/tools/ci_build/github/azure-pipelines/linux-rocm-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-rocm-ci-pipeline.yml @@ -37,9 +37,7 @@ variables: - name: render value: 109 - name: RocmVersion - value: 6.1 - - name: RocmVersionPatchSuffix - value: ".3" + value: 6.1.3 jobs: - job: Linux_Build @@ -66,7 +64,7 @@ jobs: parameters: Dockerfile: tools/ci_build/github/linux/docker/rocm-ci-pipeline-env.Dockerfile Context: tools/ci_build/github/linux/docker - DockerBuildArgs: "--build-arg ROCM_VERSION=$(RocmVersion)$(RocmVersionPatchSuffix)" + DockerBuildArgs: "--build-arg ROCM_VERSION=$(RocmVersion)" Repository: onnxruntimerocm-cibuild-rocm$(RocmVersion) - task: Cache@2 @@ -166,7 +164,7 @@ jobs: parameters: Dockerfile: tools/ci_build/github/linux/docker/rocm-ci-pipeline-env.Dockerfile Context: tools/ci_build/github/linux/docker - DockerBuildArgs: "--build-arg ROCM_VERSION=$(RocmVersion)$(RocmVersionPatchSuffix)" + DockerBuildArgs: "--build-arg ROCM_VERSION=$(RocmVersion)" Repository: onnxruntimerocm-cibuild-rocm$(RocmVersion) - task: CmdLine@2 @@ -231,7 +229,11 @@ jobs: -e KERNEL_EXPLORER_TEST_USE_CUPY=1 \ -e CUPY_CACHE_DIR=/build/Release \ onnxruntimerocm-cibuild-rocm$(RocmVersion) \ - pytest /onnxruntime_src/onnxruntime/python/tools/kernel_explorer/ -n 4 --reruns 1 --durations=100 + /bin/bash -c " + set -ex; \ + python --version; \ + ls /opt/miniconda/envs/rocm-ci/lib/; \ + pytest /onnxruntime_src/onnxruntime/python/tools/kernel_explorer/ -n 4 --reruns 1 --durations=100" workingDirectory: $(Build.SourcesDirectory) displayName: 'Run kernel explorer tests' condition: succeededOrFailed() diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm index f63f508852fc2..e4c3af05053ba 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm @@ -6,7 +6,7 @@ ARG LD_LIBRARY_PATH_ARG=${DEVTOOLSET_ROOTPATH}/usr/lib64:${DEVTOOLSET_ROOTPATH}/ ARG PREPEND_PATH=${DEVTOOLSET_ROOTPATH}/usr/bin: FROM $BASEIMAGE AS base_image -ARG ROCM_VERSION=5.5 +ARG ROCM_VERSION=6.2.3 #Add our own dependencies ADD scripts /tmp/scripts diff --git a/tools/ci_build/github/linux/docker/migraphx-ci-pipeline-env.Dockerfile b/tools/ci_build/github/linux/docker/migraphx-ci-pipeline-env.Dockerfile index 98ea5e119c319..51591e11ea2e9 100644 --- a/tools/ci_build/github/linux/docker/migraphx-ci-pipeline-env.Dockerfile +++ b/tools/ci_build/github/linux/docker/migraphx-ci-pipeline-env.Dockerfile @@ -1,7 +1,7 @@ # Refer to https://github.com/RadeonOpenCompute/ROCm-docker/blob/master/dev/Dockerfile-ubuntu-22.04-complete FROM ubuntu:22.04 -ARG ROCM_VERSION=6.0 +ARG ROCM_VERSION=6.2.3 ARG AMDGPU_VERSION=${ROCM_VERSION} ARG APT_PREF='Package: *\nPin: release o=repo.radeon.com\nPin-Priority: 600' @@ -68,7 +68,7 @@ RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86 # Create migraphx-ci environment ENV CONDA_ENVIRONMENT_PATH /opt/miniconda/envs/migraphx-ci ENV CONDA_DEFAULT_ENV migraphx-ci -RUN conda create -y -n ${CONDA_DEFAULT_ENV} python=3.9 +RUN conda create -y -n ${CONDA_DEFAULT_ENV} python=3.10 ENV PATH ${CONDA_ENVIRONMENT_PATH}/bin:${PATH} # Enable migraphx-ci environment @@ -80,4 +80,4 @@ RUN ln -sf /usr/lib/x86_64-linux-gnu/libstdc++.so.6 ${CONDA_ENVIRONMENT_PATH}/bi # Install migraphx RUN apt update && apt install -y migraphx -RUN pip install numpy packaging ml_dtypes==0.3.0 +RUN pip install numpy packaging ml_dtypes==0.5.0 diff --git a/tools/ci_build/github/linux/docker/rocm-ci-pipeline-env.Dockerfile b/tools/ci_build/github/linux/docker/rocm-ci-pipeline-env.Dockerfile index 749e222aff499..f74c5c7b0295e 100644 --- a/tools/ci_build/github/linux/docker/rocm-ci-pipeline-env.Dockerfile +++ b/tools/ci_build/github/linux/docker/rocm-ci-pipeline-env.Dockerfile @@ -1,7 +1,7 @@ # Refer to https://github.com/RadeonOpenCompute/ROCm-docker/blob/master/dev/Dockerfile-ubuntu-22.04-complete FROM ubuntu:22.04 -ARG ROCM_VERSION=6.0 +ARG ROCM_VERSION=6.1.3 ARG AMDGPU_VERSION=${ROCM_VERSION} ARG APT_PREF='Package: *\nPin: release o=repo.radeon.com\nPin-Priority: 600' @@ -67,26 +67,30 @@ RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86 # Create rocm-ci environment ENV CONDA_ENVIRONMENT_PATH /opt/miniconda/envs/rocm-ci ENV CONDA_DEFAULT_ENV rocm-ci -RUN conda create -y -n ${CONDA_DEFAULT_ENV} python=3.9 +RUN conda create -y -n ${CONDA_DEFAULT_ENV} python=3.10 ENV PATH ${CONDA_ENVIRONMENT_PATH}/bin:${PATH} # Enable rocm-ci environment SHELL ["conda", "run", "-n", "rocm-ci", "/bin/bash", "-c"] -# ln -sf is needed to make sure that version `GLIBCXX_3.4.30' is found +# Some DLLs in the conda environment have conflict with the one installed in Ubuntu system. +# For example, the GCC version in the conda environment is 12.x, while the one in the Ubuntu 22.04 is 11.x. +# ln -sf to make sure we always use libstdc++.so.6 and libgcc_s.so.1 in the system. RUN ln -sf /usr/lib/x86_64-linux-gnu/libstdc++.so.6 ${CONDA_ENVIRONMENT_PATH}/bin/../lib/libstdc++.so.6 +RUN ln -sf /usr/lib/x86_64-linux-gnu/libgcc_s.so.1 ${CONDA_ENVIRONMENT_PATH}/bin/../lib/libgcc_s.so.1 RUN pip install packaging \ - ml_dtypes==0.3.0 \ + ml_dtypes==0.5.0 \ pytest==7.4.4 \ pytest-xdist \ pytest-rerunfailures \ - scipy==1.10.0 \ - numpy==1.24.1 + scipy==1.14.1 \ + numpy==1.26.4 RUN apt install -y git # Install Cupy to decrease CPU utilization +# Note that the version of Cupy requires numpy < 1.27 RUN git clone https://github.com/ROCm/cupy && cd cupy && \ git checkout 432a8683351d681e00903640489cb2f4055d2e09 && \ export CUPY_INSTALL_USE_HIP=1 && \ diff --git a/tools/ci_build/github/linux/docker/scripts/setup_rocm_yum_repo.sh b/tools/ci_build/github/linux/docker/scripts/setup_rocm_yum_repo.sh index 269337bbba042..0be64d96f3a34 100755 --- a/tools/ci_build/github/linux/docker/scripts/setup_rocm_yum_repo.sh +++ b/tools/ci_build/github/linux/docker/scripts/setup_rocm_yum_repo.sh @@ -2,7 +2,7 @@ set -e -x # version -ROCM_VERSION=6.0 +ROCM_VERSION=6.2.3 while getopts "r:" parameter_Option do case "${parameter_Option}" From 5ba0c12da3b7588039d7a23acc997080d94ef5ce Mon Sep 17 00:00:00 2001 From: Jian Chen Date: Fri, 25 Oct 2024 21:01:16 -0400 Subject: [PATCH 11/11] Split DML test out of cuda --- .../test/python/onnx_backend_test_series.py | 32 ++++++++++++------- .../stages/py-win-gpu-stage.yml | 20 +++++++++--- 2 files changed, 35 insertions(+), 17 deletions(-) diff --git a/onnxruntime/test/python/onnx_backend_test_series.py b/onnxruntime/test/python/onnx_backend_test_series.py index 9b1e87f6ec02e..64fbf5ad15e07 100644 --- a/onnxruntime/test/python/onnx_backend_test_series.py +++ b/onnxruntime/test/python/onnx_backend_test_series.py @@ -105,7 +105,7 @@ def load_jsonc(basename: str): return json.loads("\n".join(lines)) -def create_backend_test(test_name=None): +def create_backend_test(devices:list[str] = None, test_name=None): """Creates an OrtBackendTest and adds its TestCase's to global scope so unittest will find them.""" overrides = load_jsonc("onnx_backend_test_series_overrides.jsonc") @@ -126,36 +126,35 @@ def create_backend_test(test_name=None): else: filters = load_jsonc("onnx_backend_test_series_filters.jsonc") current_failing_tests = apply_filters(filters, "current_failing_tests") - if platform.architecture()[0] == "32bit": current_failing_tests += apply_filters(filters, "current_failing_tests_x86") - if backend.supports_device("DNNL"): + if backend.supports_device("DNNL") or "DNNL" in devices: current_failing_tests += apply_filters(filters, "current_failing_tests_DNNL") - if backend.supports_device("NNAPI"): + if backend.supports_device("NNAPI") or "NNAPI" in devices: current_failing_tests += apply_filters(filters, "current_failing_tests_NNAPI") - if backend.supports_device("OPENVINO_GPU"): + if backend.supports_device("OPENVINO_GPU") or "OPENVINO_GPU" in devices: current_failing_tests += apply_filters(filters, "current_failing_tests_OPENVINO_GPU") - if backend.supports_device("OPENVINO_CPU"): + if backend.supports_device("OPENVINO_CPU") or "OPENVINO_CPU" in devices: current_failing_tests += apply_filters(filters, "current_failing_tests_OPENVINO_CPU_FP32") current_failing_tests += apply_filters(filters, "current_failing_tests_OPENVINO_CPU_FP16") - if backend.supports_device("OPENVINO_NPU"): + if backend.supports_device("OPENVINO_NPU") or "OPENVINO_NPU" in devices: current_failing_tests += apply_filters(filters, "current_failing_tests_OPENVINO_NPU") - if backend.supports_device("OPENVINO"): + if backend.supports_device("OPENVINO") or "OPENVINO" in devices: current_failing_tests += apply_filters(filters, "current_failing_tests_OPENVINO_opset18") - if backend.supports_device("MIGRAPHX"): + if backend.supports_device("MIGRAPHX") or "MIGRAPHX" in devices: current_failing_tests += apply_filters(filters, "current_failing_tests_MIGRAPHX") # Skip these tests for a "pure" DML onnxruntime python wheel. We keep these tests enabled for instances where both DML and CUDA # EPs are available (Windows GPU CI pipeline has this config) - these test will pass because CUDA has higher precedence than DML # and the nodes are assigned to only the CUDA EP (which supports these tests) - if backend.supports_device("DML") and not backend.supports_device("GPU"): + if (backend.supports_device("DML") and not backend.supports_device("GPU")) or "DML" in devices: current_failing_tests += apply_filters(filters, "current_failing_tests_pure_DML") filters = ( @@ -196,6 +195,15 @@ def parse_args(): help="Only run tests that match this value. Matching is regex based, and '.*' is automatically appended", ) + parser.add_argument( + "--devices", + type=str, + choices=["CPU", "CUDA", "MIGRAPHX", "DNNL", "DML", "OPENVINO_GPU", "OPENVINO_CPU", "OPENVINO_NPU","OPENVINO"], + nargs="+", # allows multiple values + default=["CPU"], # default to ["CPU"] if no input is given + help="Select one or more devices CPU, CUDA, MIGRAPHX, DNNL, DML, OPENVINO_GPU, OPENVINO_CPU, OPENVINO_NPU, OPENVINO" + ) + # parse just our args. python unittest has its own args and arg parsing, and that runs inside unittest.main() parsed, unknown = parser.parse_known_args() sys.argv = sys.argv[:1] + unknown @@ -206,5 +214,5 @@ def parse_args(): if __name__ == "__main__": args = parse_args() - create_backend_test(args.test_name) - unittest.main() + create_backend_test(args.devices,args.test_name) + unittest.main() \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/stages/py-win-gpu-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-win-gpu-stage.yml index 88937cc2e154d..18de988af9cbc 100644 --- a/tools/ci_build/github/azure-pipelines/stages/py-win-gpu-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-win-gpu-stage.yml @@ -65,9 +65,19 @@ stages: targetPath: $(Build.ArtifactStagingDirectory) artifactName: win_${{ parameters.EP_NAME }}_wheel_${{ parameters.PYTHON_VERSION }} variables: - GRADLE_OPTS: '-Dorg.gradle.daemon=false' - VSGenerator: 'Visual Studio 17 2022' - CUDA_MODULE_LOADING: 'LAZY' + - name: GRADLE_OPTS + value: '-Dorg.gradle.daemon=false' + - name: VSGenerator + value: 'Visual Studio 17 2022' + - name: CUDA_MODULE_LOADING + value: 'LAZY' + - name: ep_name_alt + value: + ${{ if eq(parameters.EP_NAME, directml) }}: + 'DML' + ${{ else }}: + 'CUDA' + steps: - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 displayName: 'Clean Agent Directories' @@ -190,7 +200,7 @@ stages: TargetPath: '$(Build.ArtifactStagingDirectory)' - task: PowerShell@2 - displayName: 'Install ONNX' + displayName: 'Install Third Party Dependencies' inputs: filePath: '$(Build.SourcesDirectory)/tools/ci_build/github/windows/install_third_party_deps.ps1' workingDirectory: '$(Build.BinariesDirectory)' @@ -203,6 +213,6 @@ stages: Copy-Item -Path $(Build.sourcesDirectory)/onnxruntime/test/python/onnx_backend_test_series.py -Destination $(Agent.TempDirectory)\ort_test_data Copy-Item -Recurse -Path $(Build.sourcesDirectory)/onnxruntime/test/testdata -Destination $(Agent.TempDirectory)\ort_test_data cd $(Agent.TempDirectory)\ort_test_data - python onnx_backend_test_series.py + python onnx_backend_test_series.py --devices $(ep_name_alt) workingDirectory: '$(Build.sourcesDirectory)' displayName: 'Run Python Tests'