diff --git a/.config/1espt/PipelineAutobaseliningConfig.yml b/.config/1espt/PipelineAutobaseliningConfig.yml
new file mode 100644
index 0000000000000..daa9b73d5971a
--- /dev/null
+++ b/.config/1espt/PipelineAutobaseliningConfig.yml
@@ -0,0 +1,34 @@
+## DO NOT MODIFY THIS FILE MANUALLY. This is part of auto-baselining from 1ES Pipeline Templates. Go to [https://aka.ms/1espt-autobaselining] for more details.
+
+pipelines:
+ 1624:
+ retail:
+ source:
+ credscan:
+ lastModifiedDate: 2024-10-24
+ policheck:
+ lastModifiedDate: 2024-10-24
+ eslint:
+ lastModifiedDate: 2024-10-24
+ psscriptanalyzer:
+ lastModifiedDate: 2024-10-24
+ armory:
+ lastModifiedDate: 2024-10-24
+ 1299:
+ retail:
+ source:
+ credscan:
+ lastModifiedDate: 2024-10-25
+ eslint:
+ lastModifiedDate: 2024-10-25
+ psscriptanalyzer:
+ lastModifiedDate: 2024-10-25
+ armory:
+ lastModifiedDate: 2024-10-25
+ binary:
+ credscan:
+ lastModifiedDate: 2024-10-25
+ binskim:
+ lastModifiedDate: 2024-10-25
+ spotbugs:
+ lastModifiedDate: 2024-10-25
diff --git a/.github/codeql/codeql-config.yml b/.github/codeql/codeql-config.yml
new file mode 100644
index 0000000000000..6a76f7bcdbcb0
--- /dev/null
+++ b/.github/codeql/codeql-config.yml
@@ -0,0 +1,7 @@
+name: "CodeQL config"
+queries:
+ - uses: security-extended
+ - uses: security-and-quality
+paths-ignore:
+ - tests
+  - build
diff --git a/.github/workflows/cffconvert.yml b/.github/workflows/cffconvert.yml
index 7144363717749..0cbaf24059390 100644
--- a/.github/workflows/cffconvert.yml
+++ b/.github/workflows/cffconvert.yml
@@ -8,7 +8,7 @@ on:
jobs:
validate:
name: "validate"
- runs-on: ubuntu-latest
+ runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"]
steps:
- name: Check out a copy of the repository
uses: actions/checkout@v4
diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml
index e4d1b91bab736..d3b51c0681a20 100644
--- a/.github/workflows/codeql.yml
+++ b/.github/workflows/codeql.yml
@@ -18,7 +18,7 @@ on:
jobs:
analyze:
name: Analyze
- runs-on: ubuntu-latest
+ runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"]
permissions:
actions: read
contents: read
@@ -55,6 +55,11 @@ jobs:
java-version: '11'
distribution: 'microsoft'
+ - if: ${{ matrix.language == 'javascript' }}
+ uses: actions/setup-node@v4
+ with:
+ node-version: 20
+
# Autobuild attempts to build any compiled languages (C/C++, C#, or Java).
# If this step fails, then you should remove it and run the build manually (see below)
- if: ${{ matrix.language != 'cpp' }}
diff --git a/.github/workflows/gradle-wrapper-validation.yml b/.github/workflows/gradle-wrapper-validation.yml
index 32aed81092774..cf3bc598d02bb 100644
--- a/.github/workflows/gradle-wrapper-validation.yml
+++ b/.github/workflows/gradle-wrapper-validation.yml
@@ -8,7 +8,7 @@ on: [push, pull_request]
jobs:
validation:
name: "Validation"
- runs-on: ubuntu-latest
+ runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"]
steps:
- uses: actions/checkout@v4
- uses: gradle/actions/wrapper-validation@v4
diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml
index a196226a4b836..00960c848b107 100644
--- a/.github/workflows/labeler.yml
+++ b/.github/workflows/labeler.yml
@@ -8,7 +8,7 @@ permissions:
jobs:
triage:
- runs-on: ubuntu-latest
+ runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"]
steps:
- uses: github/issue-labeler@v3.4
with:
diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index 74fd5d7dfdb28..ec834b07b2c78 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -36,7 +36,7 @@ jobs:
lint-python-format:
# Required workflow
name: Python format
- runs-on: ubuntu-latest
+ runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"]
steps:
- uses: actions/checkout@v4
- name: Setup Python
@@ -114,9 +114,12 @@ jobs:
lint-js:
name: Lint JavaScript
- runs-on: ubuntu-latest
+ runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"]
steps:
- uses: actions/checkout@v4
+ - uses: actions/setup-node@v4
+ with:
+ node-version: 20
- uses: reviewdog/action-eslint@v1
with:
reporter: github-pr-check
diff --git a/.github/workflows/linux_training.yml b/.github/workflows/linux_training.yml
new file mode 100644
index 0000000000000..d382cdf476283
--- /dev/null
+++ b/.github/workflows/linux_training.yml
@@ -0,0 +1,55 @@
+name: orttraining-linux-ci-pipeline
+on:
+ push:
+ branches:
+ - main
+ - rel-*
+ pull_request:
+
+concurrency:
+ group: ${{ github.workflow }}-${{ github.ref }}
+ cancel-in-progress: true
+
+jobs:
+ orttraining-linux-ci-pipeline:
+ runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"]
+ permissions:
+ actions: read
+ contents: read
+ security-events: write
+ steps:
+ - uses: actions/checkout@v4
+ - run: |
+ python3 -m pip install --user -r tools/ci_build/github/linux/python/requirements.txt
+ - name: Initialize CodeQL
+ uses: github/codeql-action/init@v3
+ with:
+ config-file: ./.github/codeql/codeql-config.yml
+ languages: 'cpp'
+ - run: |
+ set -e -x
+ rm -rf build
+ python3 tools/ci_build/build.py --build_dir build --config Release --enable_training --skip_submodule_sync --parallel --update --build
+
+ - name: Perform CodeQL Analysis
+ uses: github/codeql-action/analyze@v3
+ with:
+ category: "/language:cpp"
+ output: sarif-results
+ upload: failure-only
+
+ - name: filter-sarif
+ uses: advanced-security/filter-sarif@v1
+ with:
+ patterns: |
+ +**/*.cc
+ +**/*.h
+ -tests/**/*.*
+ -build/**/*.*
+ input: sarif-results/cpp.sarif
+ output: sarif-results/cpp.sarif
+
+ - name: Upload SARIF
+ uses: github/codeql-action/upload-sarif@v3
+ with:
+        sarif_file: sarif-results/cpp.sarif
diff --git a/.github/workflows/pr_checks.yml b/.github/workflows/pr_checks.yml
index a711b753c492d..af3f00c4e35ab 100644
--- a/.github/workflows/pr_checks.yml
+++ b/.github/workflows/pr_checks.yml
@@ -19,7 +19,7 @@ concurrency:
jobs:
auto-apply-fixes:
name: Suggest fixes
- runs-on: ubuntu-latest
+ runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"]
permissions:
contents: read
pull-requests: write
@@ -34,14 +34,17 @@ jobs:
with:
toolchain: stable
components: rustfmt
- - name: Install dependencies
+
+ - name: Update PATH
run: |
- python -m pip install -r requirements-dev.txt
- python -m pip install lintrunner lintrunner-adapters
- lintrunner init
- - name: Run lintrunner on all files
+ echo "$HOME/.local/bin" >> "$GITHUB_PATH"
+
+ - name: Install dependencies and run lintrunner on all files
run: |
- set +e
+ set -e
+ python -m pip install --user -r requirements-dev.txt
+ python -m pip install --user lintrunner lintrunner-adapters
+ lintrunner init
lintrunner f --all-files -v
exit 0
- uses: parkerbxyz/suggest-changes@v1
diff --git a/.github/workflows/publish-c-apidocs.yml b/.github/workflows/publish-c-apidocs.yml
index 72e69f6117ce9..6d3e593d8694e 100644
--- a/.github/workflows/publish-c-apidocs.yml
+++ b/.github/workflows/publish-c-apidocs.yml
@@ -22,7 +22,7 @@ permissions:
jobs:
build:
name: Generate C/C++ API docs
- runs-on: ubuntu-latest
+ runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"]
steps:
- uses: actions/checkout@v4
- name: Install doxygen and dependencies
diff --git a/.github/workflows/publish-csharp-apidocs.yml b/.github/workflows/publish-csharp-apidocs.yml
index 81ba703e8d5c1..c704adb263db4 100644
--- a/.github/workflows/publish-csharp-apidocs.yml
+++ b/.github/workflows/publish-csharp-apidocs.yml
@@ -20,7 +20,7 @@ permissions:
jobs:
build:
- runs-on: ubuntu-latest
+ runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"]
env:
DOCFXVERSION: 2.62.2
steps:
diff --git a/.github/workflows/publish-gh-pages.yml b/.github/workflows/publish-gh-pages.yml
index 1818261b4b766..11745ce24f9e5 100644
--- a/.github/workflows/publish-gh-pages.yml
+++ b/.github/workflows/publish-gh-pages.yml
@@ -8,7 +8,7 @@ on:
jobs:
placeholder:
- runs-on: ubuntu-latest
+ runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"]
steps:
- name: Placeholder step to have workflow included in the GitHub web UI
run: |
diff --git a/.github/workflows/publish-java-apidocs.yml b/.github/workflows/publish-java-apidocs.yml
index bed96b1be7027..d04669a13aab7 100644
--- a/.github/workflows/publish-java-apidocs.yml
+++ b/.github/workflows/publish-java-apidocs.yml
@@ -21,7 +21,7 @@ permissions:
jobs:
build:
name: Generate Java docs
- runs-on: ubuntu-latest
+ runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"]
steps:
- uses: actions/checkout@v4
- name: Set up JDK 11
diff --git a/.github/workflows/publish-js-apidocs.yml b/.github/workflows/publish-js-apidocs.yml
index 7af635f3eb50a..a6749b42adc35 100644
--- a/.github/workflows/publish-js-apidocs.yml
+++ b/.github/workflows/publish-js-apidocs.yml
@@ -21,7 +21,7 @@ permissions:
jobs:
build:
name: Generate JS API docs
- runs-on: ubuntu-latest
+ runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"]
steps:
- uses: actions/checkout@v4
- name: Setup Node.js
diff --git a/.github/workflows/publish-python-apidocs.yml b/.github/workflows/publish-python-apidocs.yml
index 352fd3e948b4b..2be9ad957c5cb 100644
--- a/.github/workflows/publish-python-apidocs.yml
+++ b/.github/workflows/publish-python-apidocs.yml
@@ -22,7 +22,7 @@ permissions:
jobs:
build:
name: Generate Python API docs
- runs-on: ubuntu-latest
+ runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"]
steps:
- uses: actions/checkout@v4
- name: Install tools
diff --git a/.github/workflows/sca.yml b/.github/workflows/sca.yml
index 0867d4c343e91..51166293f06ac 100644
--- a/.github/workflows/sca.yml
+++ b/.github/workflows/sca.yml
@@ -30,7 +30,7 @@ jobs:
- uses: actions/setup-node@v4
with:
- node-version: 18
+ node-version: 20
- name: Download cuda
run: azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/cuda_sdk/v11.8" cuda_sdk
@@ -57,6 +57,45 @@ jobs:
sarif_file: ${{ github.workspace }}\output\MergeResult.sarif
category: VS_SCA
+ # With WebGPU, Without python
+ Onnxruntime-SCA-win32-WebGPU-x64:
+ permissions:
+ security-events: write
+ runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-vs2022-mms"]
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ submodules: false
+ - uses: actions/setup-python@v5
+ with:
+ python-version: '3.11.x'
+ architecture: 'x64'
+
+ - uses: actions/setup-node@v4
+ with:
+ node-version: 20
+
+ - name: Delete build folder
+ run: |
+ if (Test-Path D:\b) { Remove-Item -Recurse -Force D:\b }
+
+
+ - name: Build code
+ env:
+ CAExcludePath: 'C:\Program Files;D:\b;${{ github.workspace }}\cmake'
+ run: python tools\ci_build\build.py --compile_no_warning_as_error --config Debug --build_dir D:\b --skip_submodule_sync --update --build --parallel --cmake_generator "Visual Studio 17 2022" --build_shared_lib --cmake_extra_defines onnxruntime_USE_CUSTOM_STATIC_ANALYSIS_RULES=ON --cmake_extra_defines onnxruntime_ENABLE_STATIC_ANALYSIS=ON --cmake_extra_defines onnxruntime_REDIRECT_STATIC_ANALYSIS_OUTPUTS_TO_FILE=ON --use_webgpu
+
+ - name: Generate sarif
+ working-directory: D:\b
+ run: npx @microsoft/sarif-multitool merge *.sarif --recurse --output-directory=${{ github.workspace }}\output --output-file=MergeResult.sarif --merge-runs && dir ${{ github.workspace }}\output
+
+ - name: Upload SARIF to GitHub
+ uses: github/codeql-action/upload-sarif@v3
+ continue-on-error: true
+ with:
+ sarif_file: ${{ github.workspace }}\output\MergeResult.sarif
+ category: VS_SCA_WIN32_WEBGPU_X64
+
# No python
Onnxruntime-SCA-win32-WINML-x64:
permissions:
@@ -73,7 +112,7 @@ jobs:
- uses: actions/setup-node@v4
with:
- node-version: 18
+ node-version: 20
- name: Delete build folder
run: |
@@ -113,7 +152,7 @@ jobs:
- uses: actions/setup-node@v4
with:
- node-version: 18
+ node-version: 20
- name: Delete build folder
run: |
diff --git a/.github/workflows/skip-doc-change.yml.j2 b/.github/workflows/skip-doc-change.yml.j2
index 58f048122a87e..04f77e5d28713 100644
--- a/.github/workflows/skip-doc-change.yml.j2
+++ b/.github/workflows/skip-doc-change.yml.j2
@@ -14,7 +14,7 @@ jobs:
{%- for name in job_names %}
job{{ loop.index }}:
name: {{ name }}
- runs-on: ubuntu-latest
+ runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"]
steps:
- run: 'echo "No build required, only documentation changed"'
{% endfor %}
diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml
index 181f3fb17d332..14cf0825873a0 100644
--- a/.github/workflows/stale.yml
+++ b/.github/workflows/stale.yml
@@ -8,7 +8,7 @@ on:
jobs:
close-stale-issues:
- runs-on: ubuntu-latest
+ runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"]
permissions:
issues: write
pull-requests: write
diff --git a/.github/workflows/title-only-labeler.yml b/.github/workflows/title-only-labeler.yml
index e0af2dd06b1b7..7ee9f3917a901 100644
--- a/.github/workflows/title-only-labeler.yml
+++ b/.github/workflows/title-only-labeler.yml
@@ -8,7 +8,7 @@ permissions:
jobs:
triage:
- runs-on: ubuntu-latest
+ runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"]
steps:
- uses: github/issue-labeler@v3.4
with:
diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt
index 15b5e42b1f2e2..9d1b39143016b 100644
--- a/cmake/CMakeLists.txt
+++ b/cmake/CMakeLists.txt
@@ -200,6 +200,7 @@ option(onnxruntime_WEBASSEMBLY_RUN_TESTS_IN_BROWSER "Enable this option to run t
option(onnxruntime_ENABLE_WEBASSEMBLY_DEBUG_INFO "Enable this option to turn on DWARF format debug info" OFF)
option(onnxruntime_ENABLE_WEBASSEMBLY_PROFILING "Enable this option to turn on WebAssembly profiling and preserve function names" OFF)
option(onnxruntime_ENABLE_WEBASSEMBLY_OUTPUT_OPTIMIZED_MODEL "Enable this option to allow WebAssembly to output optimized model" OFF)
+option(onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64 "Enable this option to allow WebAssembly to use 64bit memory" OFF)
# Enable bitcode for iOS
option(onnxruntime_ENABLE_BITCODE "Enable bitcode for iOS only" OFF)
diff --git a/cmake/adjust_global_compile_flags.cmake b/cmake/adjust_global_compile_flags.cmake
index c04d67ea4ce3f..dbbf685346532 100644
--- a/cmake/adjust_global_compile_flags.cmake
+++ b/cmake/adjust_global_compile_flags.cmake
@@ -60,6 +60,11 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
string(APPEND CMAKE_CXX_FLAGS " -s DISABLE_EXCEPTION_CATCHING=0")
endif()
+ if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64)
+ string(APPEND CMAKE_C_FLAGS " -DORT_WASM64")
+ string(APPEND CMAKE_CXX_FLAGS " -DORT_WASM64")
+ endif()
+
# Build WebAssembly with multi-threads support.
if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS)
string(APPEND CMAKE_C_FLAGS " -pthread -Wno-pthreads-mem-growth")
diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake
index 54a65b57301cc..66268cefac9ef 100644
--- a/cmake/onnxruntime_webassembly.cmake
+++ b/cmake/onnxruntime_webassembly.cmake
@@ -97,7 +97,6 @@ target_compile_options(onnx PRIVATE -Wno-unused-parameter -Wno-unused-variable)
if (onnxruntime_BUILD_WEBASSEMBLY_STATIC_LIB)
bundle_static_library(onnxruntime_webassembly
-
${PROTOBUF_LIB}
onnx
onnx_proto
@@ -175,7 +174,6 @@ else()
endif()
target_link_libraries(onnxruntime_webassembly PRIVATE
-
${PROTOBUF_LIB}
onnx
onnx_proto
@@ -194,9 +192,7 @@ else()
onnxruntime_util
re2::re2
)
-
- set(EXPORTED_RUNTIME_METHODS "'stackAlloc','stackRestore','stackSave','UTF8ToString','stringToUTF8','lengthBytesUTF8'")
-
+ set(EXPORTED_RUNTIME_METHODS "'stackAlloc','stackRestore','stackSave','UTF8ToString','stringToUTF8','lengthBytesUTF8','getValue','setValue'")
if (onnxruntime_USE_XNNPACK)
target_link_libraries(onnxruntime_webassembly PRIVATE XNNPACK)
string(APPEND EXPORTED_RUNTIME_METHODS ",'addFunction'")
@@ -217,10 +213,114 @@ else()
set(EXPORTED_FUNCTIONS "_malloc,_free")
endif()
+ if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64)
+ set(MAXIMUM_MEMORY "17179869184")
+ target_link_options(onnxruntime_webassembly PRIVATE
+ "SHELL:-s MEMORY64=1"
+ )
+ string(APPEND CMAKE_C_FLAGS " -sMEMORY64 -Wno-experimental")
+ string(APPEND CMAKE_CXX_FLAGS " -sMEMORY64 -Wno-experimental")
+ set(SMEMORY_FLAG "-sMEMORY64")
+
+ target_compile_options(onnx PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(onnxruntime_common PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(onnxruntime_session PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(onnxruntime_framework PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(nsync_cpp PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(onnx_proto PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ # target_compile_options(protoc PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(libprotobuf-lite PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(onnxruntime_providers PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(onnxruntime_optimizer PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(onnxruntime_mlas PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+    # NOTE: onnxruntime_optimizer is already configured two lines above; duplicate call removed
+ target_compile_options(onnxruntime_graph PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(onnxruntime_flatbuffers PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(onnxruntime_util PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(re2 PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_flags_private_handle_accessor PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_flags_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_flags_commandlineflag PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_flags_commandlineflag_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_flags_marshalling PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_flags_reflection PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_flags_config PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_flags_program_name PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_cord PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_cordz_info PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_cord_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_cordz_functions PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_cordz_handle PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_crc_cord_state PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_crc32c PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_crc_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_crc_cpu_detect PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_raw_hash_set PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_hashtablez_sampler PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_exponential_biased PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_log_internal_conditions PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_log_internal_check_op PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_log_internal_message PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_log_internal_format PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_str_format_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_log_internal_log_sink_set PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_log_internal_globals PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_log_sink PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_log_entry PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_log_globals PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_hash PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_city PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_low_level_hash PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_bad_variant_access PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_vlog_config_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_synchronization PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_kernel_timeout_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_time PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_time_zone PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_civil_time PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_graphcycles_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_bad_optional_access PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_log_internal_fnmatch PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_examine_stack PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_symbolize PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_malloc_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_demangle_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_demangle_rust PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_decode_rust_punycode PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_utf8_for_code_point PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_stacktrace PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_debugging_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_log_internal_proto PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_strerror PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_log_internal_nullguard PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_strings PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_strings_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_int128 PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_string_view PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_base PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_spinlock_wait PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_throw_delegate PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_raw_logging_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(absl_log_severity PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ if (onnxruntime_USE_EXTENSIONS)
+ target_compile_options(ortcustomops PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(ocos_operators PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ target_compile_options(noexcep_operators PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
+ endif()
+ target_link_options(onnxruntime_webassembly PRIVATE
+ --post-js "${ONNXRUNTIME_ROOT}/wasm/js_post_js_64.js"
+ )
+ else ()
+ set(MAXIMUM_MEMORY "4294967296")
+ target_link_options(onnxruntime_webassembly PRIVATE
+ --post-js "${ONNXRUNTIME_ROOT}/wasm/js_post_js.js"
+ )
+ endif ()
+
target_link_options(onnxruntime_webassembly PRIVATE
"SHELL:-s EXPORTED_RUNTIME_METHODS=[${EXPORTED_RUNTIME_METHODS}]"
"SHELL:-s EXPORTED_FUNCTIONS=${EXPORTED_FUNCTIONS}"
- "SHELL:-s MAXIMUM_MEMORY=4294967296"
+ "SHELL:-s MAXIMUM_MEMORY=${MAXIMUM_MEMORY}"
"SHELL:-s EXIT_RUNTIME=0"
"SHELL:-s ALLOW_MEMORY_GROWTH=1"
"SHELL:-s MODULARIZE=1"
@@ -233,6 +333,41 @@ else()
--no-entry
"SHELL:--pre-js \"${ONNXRUNTIME_ROOT}/wasm/pre.js\""
)
+ if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64)
+ set(SIGNATURE_CONVERSIONS "OrtRun:_pppppppp,\
+OrtRunWithBinding:_ppppp,\
+OrtGetTensorData:_ppppp,\
+OrtCreateTensor:p_pppp_,\
+OrtCreateSession:pppp,\
+OrtReleaseSession:_p,\
+OrtGetInputOutputCount:_ppp,\
+OrtCreateSessionOptions:pp__p_ppppp,\
+OrtReleaseSessionOptions:_p,\
+OrtAppendExecutionProvider:_pp,\
+OrtAddSessionConfigEntry:_ppp,\
+OrtGetInputName:ppp,\
+OrtGetOutputName:ppp,\
+OrtCreateRunOptions:ppp_p,\
+OrtReleaseRunOptions:_p,\
+OrtReleaseTensor:_p,\
+OrtFree:_p,\
+OrtCreateBinding:_p,\
+OrtBindInput:_ppp,\
+OrtBindOutput:_ppp_,\
+OrtClearBoundOutputs:_p,\
+OrtReleaseBinding:_p,\
+OrtGetLastError:_pp,\
+JsepOutput:pp_p,\
+JsepGetNodeName:pp,\
+jsepCopy:_pp_,\
+jsepCopyAsync:_pp_,\
+jsepDownload:\
+_pp_")
+ target_link_options(onnxruntime_webassembly PRIVATE
+ "SHELL:-s ERROR_ON_UNDEFINED_SYMBOLS=0"
+ "SHELL:-s SIGNATURE_CONVERSIONS='${SIGNATURE_CONVERSIONS}'"
+ )
+ endif ()
set_target_properties(onnxruntime_webassembly PROPERTIES LINK_DEPENDS ${ONNXRUNTIME_ROOT}/wasm/pre.js)
if (onnxruntime_USE_JSEP)
@@ -245,6 +380,8 @@ else()
"SHELL:--pre-js \"${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js\""
"SHELL:-s ASYNCIFY=1"
"SHELL:-s ASYNCIFY_STACK_SIZE=65536"
+ "SHELL:-s ASYNCIFY_EXPORTS=['OrtRun']"
+ "SHELL:-s ASYNCIFY_IMPORTS=['Module.jsepCopy','Module.jsepCopyAsync','jsepDownload']"
)
set_target_properties(onnxruntime_webassembly PROPERTIES LINK_DEPENDS ${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js)
endif()
@@ -281,7 +418,9 @@ else()
endif()
# Set link flag to enable exceptions support, this will override default disabling exception throwing behavior when disable exceptions.
- target_link_options(onnxruntime_webassembly PRIVATE "SHELL:-s DISABLE_EXCEPTION_THROWING=0")
+ if (NOT onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64)
+ target_link_options(onnxruntime_webassembly PRIVATE "SHELL:-s DISABLE_EXCEPTION_THROWING=0")
+ endif()
if (onnxruntime_ENABLE_WEBASSEMBLY_PROFILING)
target_link_options(onnxruntime_webassembly PRIVATE --profiling --profiling-funcs)
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index ddf37cfded77d..bd886abc98a89 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -969,7 +969,8 @@ Do not modify directly.*
|||13+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||9+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||6+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|CastLike|*in* input:**T1**
*in* target_type:**T2**
*out* output:**T2**|19+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|CastLike|*in* input:**T1**
*in* target_type:**T2**
*out* output:**T2**|21+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||19+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||15+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Ceil|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(float), tensor(float16)|
|||6+|**T** = tensor(float), tensor(float16)|
@@ -983,7 +984,8 @@ Do not modify directly.*
|||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||4+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|ConcatFromSequence|*in* input_sequence:**S**
*out* concat_result:**T**|11+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|ConstantOfShape|*in* input:**T1**
*out* output:**T2**|9+|**T1** = tensor(int64)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|ConstantOfShape|*in* input:**T1**
*out* output:**T2**|21+|**T1** = tensor(int64)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||9+|**T1** = tensor(int64)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Conv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|11+|**T** = tensor(float), tensor(float16)|
|||1+|**T** = tensor(float), tensor(float16)|
|ConvInteger|*in* x:**T1**
*in* w:**T2**
*in* x_zero_point:**T1**
*in* w_zero_point:**T2**
*out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int32)|
@@ -1021,7 +1023,8 @@ Do not modify directly.*
|Expand|*in* input:**T**
*in* shape:**tensor(int64)**
*out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||8+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|EyeLike|*in* input:**T1**
*out* output:**T2**|9+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|Flatten|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|Flatten|*in* input:**T**
*out* output:**T**|21+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||9+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||1+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
@@ -1141,7 +1144,8 @@ Do not modify directly.*
|PRelu|*in* X:**T**
*in* slope:**T**
*out* Y:**T**|16+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8)|
|||9+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8)|
|||7+|**T** = tensor(float), tensor(float16)|
-|Pad|*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*in* axes:**Tind**
*out* output:**T**
or
*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*out* output:**T**
or
*in* data:**T**
*out* output:**T**|19+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|Pad|*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*in* axes:**Tind**
*out* output:**T**
or
*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*out* output:**T**
or
*in* data:**T**
*out* output:**T**|21+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||19+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||18+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
@@ -1253,7 +1257,8 @@ Do not modify directly.*
|SimplifiedLayerNormalization|*in* X:**T**
*in* scale:**V**
*out* Y:**V**
*out* inv_std_var:**U**|1+|**T** = tensor(float), tensor(float16)
**U** = tensor(float), tensor(float16)
**V** = tensor(float), tensor(float16)|
|Sin|*in* input:**T**
*out* output:**T**|7+|**T** = tensor(float), tensor(float16)|
|Sinh|*in* input:**T**
*out* output:**T**|9+|**T** = tensor(float), tensor(float16)|
-|Size|*in* data:**T**
*out* size:**T1**|19+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)|
+|Size|*in* data:**T**
*out* size:**T1**|21+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)|
+|||19+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)|
|||13+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)|
|||1+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)|
|Slice|*in* data:**T**
*in* starts:**Tind**
*in* ends:**Tind**
*in* axes:**Tind**
*in* steps:**Tind**
*out* output:**T**
or
*in* data:**T**
*out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)|
@@ -1293,7 +1298,8 @@ Do not modify directly.*
|TopK|*in* X:**T**
*in* K:**tensor(int64)**
*out* Values:**T**
*out* Indices:**I**
or
*in* X:**T**
*out* Values:**T**
*out* Indices:**I**|11+|**I** = tensor(int64)
**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||10+|**I** = tensor(int64)
**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||1+|**I** = tensor(int64)
**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|Transpose|*in* data:**T**
*out* transposed:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|Transpose|*in* data:**T**
*out* transposed:**T**|21+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||1+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Trilu|*in* input:**T**
*in* k:**tensor(int64)**
*out* output:**T**|14+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Unsqueeze|*in* data:**T**
*in* axes:**tensor(int64)**
*out* expanded:**T**
or
*in* data:**T**
*out* expanded:**T**|21+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
diff --git a/include/onnxruntime/core/framework/op_kernel.h b/include/onnxruntime/core/framework/op_kernel.h
index 07625c38d8474..a17da2a19bb99 100644
--- a/include/onnxruntime/core/framework/op_kernel.h
+++ b/include/onnxruntime/core/framework/op_kernel.h
@@ -79,6 +79,7 @@ class OpKernel {
// the allocator tied to the session if the kernel owns the pre-packed buffer or an
// allocator shared between sessions if the pre-packed buffer is to be shared across sessions
// (i.e.) the kernel does not own the buffer.
+ // @param save_prepacked_initializers: Set it to true if intend to save prepacked initializers to external data file.
// @param is_packed: Set it to true if the kernel packed the tensor or to false
// The kernel is responsible for keeping the packed data and related metadata if is_packed is true,
// and the original initialized constant tensor will be released and not accessible anymore in
@@ -88,6 +89,7 @@ class OpKernel {
virtual Status
PrePack(const Tensor& /*tensor*/, int /*input_idx*/, AllocatorPtr /*alloc*/,
+ bool /*save_prepacked_initializers*/,
/*out*/ bool& is_packed, /*out*/ PrePackedWeights* /*prepacked_weights*/) {
is_packed = false;
return Status::OK();
@@ -129,6 +131,26 @@ class OpKernel {
return Status::OK();
}
+ // Override this function to get pre-packed tensors from this kernel.
+ // Only useful for models run on PC with CPU so ORT could load prepacked weights directly from
+ // ONNX data file with mmap and no need to do prepacking on the fly to save a lot of heap memory.
+ // @param input_idx : The index of input we prepacked before and intend to get packed tensor back.
+ // Please refer to matmul_nbits kernel for a complete example.
+ virtual std::optional<Tensor> GetPrePackTensor(int /*input_idx*/) {
+ return std::nullopt;
+ }
+
+ // Override this function to set pre-packed tensors to this kernel and restore prepacked weight buffer.
+ // Only useful for models run on PC with CPU so ORT could load prepacked weights directly from
+ // ONNX data file with mmap and no need to do prepacking on the fly to save a lot of heap memory.
+ // Please refer to matmul_nbits kernel for a complete example.
+ // @param input_idx : The input index of the tensor in this kernel.
+ // @param pre_packed_tensor: The prepacked tensor read from onnx data file and use the prepacked tensor
+ // to restore prepacked weight buffer.
+ virtual Status SetPrePackTensor(int /*input_idx*/, const Tensor& /*pre_packed_tensor*/) {
+ return Status::OK();
+ }
+
const OrtDevice GetDevice(OrtMemType mem_type) const;
const OpKernelInfo& Info() const {
return *op_kernel_info_;
diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h
index eb9581e8018d1..69af3c93d7a07 100644
--- a/include/onnxruntime/core/graph/graph.h
+++ b/include/onnxruntime/core/graph/graph.h
@@ -1148,6 +1148,11 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
void FinalizeFuseSubGraph(const IndexedSubGraph& sub_graph, Node& fused_node);
#endif
+ // Since one constant initializer could be used by different kernels
+ // and prepacked differently, use an unordered_map to store prepacked
+ // initializer in format of <[initializer_name], <[node_name], [prepacked_initializer]>>
+ typedef std::unordered_map<std::string, std::unordered_map<std::string, ONNX_NAMESPACE::TensorProto>> PrePackedTensorProtoToSave;
+
#if !defined(ORT_MINIMAL_BUILD)
/** Gets the GraphProto representation of this Graph. */
const ONNX_NAMESPACE::GraphProto& ToGraphProto();
@@ -1182,18 +1187,26 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
@param initializer_size_threshold initializers larger or equal to this threshold (in bytes) are saved
in the external file. Initializer smaller than this threshold are included in the onnx file.
@param align_info offset alignment info.
+ @param save_prepacked_constant_initializers whether to save prepacked initializer into external data file.
+ If this is set to false, prepacked initializers will not be saved into the onnxruntime data file;
+ we keep constant initializers as they are.
+ @param pre_packed_initializers struct used to store all the prepacked initializers.
@returns GraphProto serialization of the graph.
*/
ONNX_NAMESPACE::GraphProto ToGraphProtoWithExternalInitializers(const std::filesystem::path& external_file_path,
const std::filesystem::path& model_file_path,
size_t initializer_size_threshold,
- const OffsetAlignmentInfo& align_info) const;
+ const OffsetAlignmentInfo& align_info,
+ bool save_prepacked_constant_initializers,
+ PrePackedTensorProtoToSave& pre_packed_initializers) const;
ONNX_NAMESPACE::GraphProto ToGraphProtoWithExternalInitializers(const std::filesystem::path& external_file_path,
const std::filesystem::path& model_file_path,
size_t initializer_size_threshold) const {
OffsetAlignmentInfo default_options;
- return ToGraphProtoWithExternalInitializers(external_file_path, model_file_path, initializer_size_threshold, default_options);
+ PrePackedTensorProtoToSave pre_packed_initializers;
+ return ToGraphProtoWithExternalInitializers(external_file_path, model_file_path, initializer_size_threshold, default_options,
+ false, pre_packed_initializers);
}
/** Gets the ISchemaRegistry instances being used with this Graph. */
@@ -1508,6 +1521,18 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
private:
void InitializeStateFromModelFileGraphProto();
+ // Private method used to setup external initializer properly during model save,
+ // this external initializer could be original initializer or prepacked initializer.
+ static void SetUpExternalInitializer(const Graph::OffsetAlignmentInfo& align_info,
+ size_t tensor_bytes_size,
+ int64_t& external_offset,
+ std::ofstream& external_stream,
+ gsl::span<const uint8_t> raw_data,
+ ONNX_NAMESPACE::TensorProto& output_proto,
+ const std::filesystem::path& external_file_path,
+ const ONNX_NAMESPACE::TensorProto& initializer,
+ bool is_prepacked);
+
// Add node with specified .
Node& AddNode(const ONNX_NAMESPACE::NodeProto& node_proto,
const ArgNameToTypeMap& name_to_type);
diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
index 6a01602e634f8..086919913cbea 100644
--- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
+++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
@@ -246,6 +246,12 @@ static const char* const kOrtSessionOptionsDisableCPUEPFallback = "session.disab
static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersFileName =
"session.optimized_model_external_initializers_file_name";
+// Use this config when saving prepacked constant initializers to the ONNX external data file.
+// Default is to not save prepacked initializers to the ONNX data file.
+// Sample usage: sess_options.add_session_config_entry('session.save_prepacked_constant_initializers', "1")
+static const char* const kOrtSessionOptionsSavePrePackedConstantInitializers =
+ "session.save_prepacked_constant_initializers";
+
// Use this config to control the minimum size of the initializer when externalizing it during serialization
static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersMinSizeInBytes =
"session.optimized_model_external_initializers_min_size_in_bytes";
diff --git a/js/node/README.md b/js/node/README.md
index 3f4da7ddd4135..abb91bf05ddf1 100644
--- a/js/node/README.md
+++ b/js/node/README.md
@@ -14,7 +14,7 @@ Refer to [ONNX Runtime JavaScript examples](https://github.com/microsoft/onnxrun
## Requirements
-ONNXRuntime works on Node.js v16.x+ (recommend v18.x+) or Electron v15.x+ (recommend v28.x+).
+ONNXRuntime works on Node.js v16.x+ (recommend v20.x+) or Electron v15.x+ (recommend v28.x+).
The following table lists the supported versions of ONNX Runtime Node.js binding provided with pre-built binaries.
diff --git a/js/package-lock.json b/js/package-lock.json
index 58a13a9112116..594d0584ad80e 100644
--- a/js/package-lock.json
+++ b/js/package-lock.json
@@ -7,6 +7,7 @@
"license": "MIT",
"devDependencies": {
"@types/fs-extra": "^11.0.2",
+ "@types/global-agent": "^2.1.3",
"@types/mocha": "^10.0.2",
"@types/node": "^18.14.6",
"@types/npmlog": "^4.1.4",
@@ -23,6 +24,7 @@
"eslint-plugin-prefer-arrow": "^1.2.3",
"eslint-plugin-unicorn": "^48.0.1",
"fs-extra": "^11.1.1",
+ "global-agent": "^3.0",
"jszip": "^3.10.1",
"mocha": "^10.2.0",
"npmlog": "^7.0.1",
@@ -710,6 +712,13 @@
"@types/node": "*"
}
},
+ "node_modules/@types/global-agent": {
+ "version": "2.1.3",
+ "resolved": "https://registry.npmjs.org/@types/global-agent/-/global-agent-2.1.3.tgz",
+ "integrity": "sha512-rGtZZcgZcKWuKNTkGBGsqyOQ7Nn2MjXh4+xeZbf+5b5KMUx8H1rTqLRackxos7pUlreszbYjQcop5JvqCnZlLw==",
+ "dev": true,
+ "license": "MIT"
+ },
"node_modules/@types/json-schema": {
"version": "7.0.15",
"resolved": "https://registry.npmjs.org/@types/json-schema/-/json-schema-7.0.15.tgz",
@@ -1289,6 +1298,14 @@
"node": ">=8"
}
},
+ "node_modules/boolean": {
+ "version": "3.2.0",
+ "resolved": "https://registry.npmjs.org/boolean/-/boolean-3.2.0.tgz",
+ "integrity": "sha512-d0II/GO9uf9lfUHH2BQsjxzRJZBdsjgsBiW4BvhWk/3qoKwQFjIDVN19PfX8F2D/r9PCMTtLWjYVCFrpeYUzsw==",
+ "deprecated": "Package no longer supported. Contact Support at https://www.npmjs.com/support for more info.",
+ "dev": true,
+ "license": "MIT"
+ },
"node_modules/brace-expansion": {
"version": "1.1.11",
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.11.tgz",
@@ -1640,6 +1657,13 @@
"integrity": "sha512-bd2L678uiWATM6m5Z1VzNCErI3jiGzt6HGY8OVICs40JQq/HALfbyNJmp0UDakEY4pMMaN0Ly5om/B1VI/+xfQ==",
"dev": true
},
+ "node_modules/detect-node": {
+ "version": "2.1.0",
+ "resolved": "https://registry.npmjs.org/detect-node/-/detect-node-2.1.0.tgz",
+ "integrity": "sha512-T0NIuQpnTvFDATNuHN5roPwSBG83rFsuO+MXXH9/3N1eFbn4wcPjttvjMLEPWJ0RGUYgQE7cGgS3tNxbqCGM7g==",
+ "dev": true,
+ "license": "MIT"
+ },
"node_modules/diff": {
"version": "5.0.0",
"resolved": "https://registry.npmjs.org/diff/-/diff-5.0.0.tgz",
@@ -1791,6 +1815,13 @@
"url": "https://github.com/sponsors/ljharb"
}
},
+ "node_modules/es6-error": {
+ "version": "4.1.1",
+ "resolved": "https://registry.npmjs.org/es6-error/-/es6-error-4.1.1.tgz",
+ "integrity": "sha512-Um/+FxMr9CISWh0bi5Zv0iOD+4cFh5qLeks1qhAopKVAJw3drgKbKySikp7wGhDL0HPeaja0P5ULZrxLkniUVg==",
+ "dev": true,
+ "license": "MIT"
+ },
"node_modules/esbuild": {
"version": "0.19.3",
"resolved": "https://registry.npmjs.org/esbuild/-/esbuild-0.19.3.tgz",
@@ -2504,6 +2535,24 @@
"node": ">=10.13.0"
}
},
+ "node_modules/global-agent": {
+ "version": "3.0.0",
+ "resolved": "https://registry.npmjs.org/global-agent/-/global-agent-3.0.0.tgz",
+ "integrity": "sha512-PT6XReJ+D07JvGoxQMkT6qji/jVNfX/h364XHZOWeRzy64sSFr+xJ5OX7LI3b4MPQzdL4H8Y8M0xzPpsVMwA8Q==",
+ "dev": true,
+ "license": "BSD-3-Clause",
+ "dependencies": {
+ "boolean": "^3.0.1",
+ "es6-error": "^4.1.1",
+ "matcher": "^3.0.0",
+ "roarr": "^2.15.3",
+ "semver": "^7.3.2",
+ "serialize-error": "^7.0.1"
+ },
+ "engines": {
+ "node": ">=10.0"
+ }
+ },
"node_modules/globals": {
"version": "13.24.0",
"resolved": "https://registry.npmjs.org/globals/-/globals-13.24.0.tgz",
@@ -3153,6 +3202,13 @@
"integrity": "sha512-Bdboy+l7tA3OGW6FjyFHWkP5LuByj1Tk33Ljyq0axyzdk9//JSi2u3fP1QSmd1KNwq6VOKYGlAu87CisVir6Pw==",
"dev": true
},
+ "node_modules/json-stringify-safe": {
+ "version": "5.0.1",
+ "resolved": "https://registry.npmjs.org/json-stringify-safe/-/json-stringify-safe-5.0.1.tgz",
+ "integrity": "sha512-ZClg6AaYvamvYEE82d3Iyd3vSSIjQ+odgjaTzRuO3s7toCdFKczob2i0zCh7JE8kWn17yvAWhUVxvqGwUalsRA==",
+ "dev": true,
+ "license": "ISC"
+ },
"node_modules/json5": {
"version": "1.0.2",
"resolved": "https://registry.npmjs.org/json5/-/json5-1.0.2.tgz",
@@ -3272,6 +3328,19 @@
"node": ">=10"
}
},
+ "node_modules/matcher": {
+ "version": "3.0.0",
+ "resolved": "https://registry.npmjs.org/matcher/-/matcher-3.0.0.tgz",
+ "integrity": "sha512-OkeDaAZ/bQCxeFAozM55PKcKU0yJMPGifLwV4Qgjitu+5MoAfSQN4lsLJeXZ1b8w0x+/Emda6MZgXS1jvsapng==",
+ "dev": true,
+ "license": "MIT",
+ "dependencies": {
+ "escape-string-regexp": "^4.0.0"
+ },
+ "engines": {
+ "node": ">=10"
+ }
+ },
"node_modules/merge2": {
"version": "1.4.1",
"resolved": "https://registry.npmjs.org/merge2/-/merge2-1.4.1.tgz",
@@ -4075,6 +4144,24 @@
"url": "https://github.com/sponsors/isaacs"
}
},
+ "node_modules/roarr": {
+ "version": "2.15.4",
+ "resolved": "https://registry.npmjs.org/roarr/-/roarr-2.15.4.tgz",
+ "integrity": "sha512-CHhPh+UNHD2GTXNYhPWLnU8ONHdI+5DI+4EYIAOaiD63rHeYlZvyh8P+in5999TTSFgUYuKUAjzRI4mdh/p+2A==",
+ "dev": true,
+ "license": "BSD-3-Clause",
+ "dependencies": {
+ "boolean": "^3.0.1",
+ "detect-node": "^2.0.4",
+ "globalthis": "^1.0.1",
+ "json-stringify-safe": "^5.0.1",
+ "semver-compare": "^1.0.0",
+ "sprintf-js": "^1.1.2"
+ },
+ "engines": {
+ "node": ">=8.0"
+ }
+ },
"node_modules/run-parallel": {
"version": "1.2.0",
"resolved": "https://registry.npmjs.org/run-parallel/-/run-parallel-1.2.0.tgz",
@@ -4157,6 +4244,42 @@
"node": ">=10"
}
},
+ "node_modules/semver-compare": {
+ "version": "1.0.0",
+ "resolved": "https://registry.npmjs.org/semver-compare/-/semver-compare-1.0.0.tgz",
+ "integrity": "sha512-YM3/ITh2MJ5MtzaM429anh+x2jiLVjqILF4m4oyQB18W7Ggea7BfqdH/wGMK7dDiMghv/6WG7znWMwUDzJiXow==",
+ "dev": true,
+ "license": "MIT"
+ },
+ "node_modules/serialize-error": {
+ "version": "7.0.1",
+ "resolved": "https://registry.npmjs.org/serialize-error/-/serialize-error-7.0.1.tgz",
+ "integrity": "sha512-8I8TjW5KMOKsZQTvoxjuSIa7foAwPWGOts+6o7sgjz41/qMD9VQHEDxi6PBvK2l0MXUmqZyNpUK+T2tQaaElvw==",
+ "dev": true,
+ "license": "MIT",
+ "dependencies": {
+ "type-fest": "^0.13.1"
+ },
+ "engines": {
+ "node": ">=10"
+ },
+ "funding": {
+ "url": "https://github.com/sponsors/sindresorhus"
+ }
+ },
+ "node_modules/serialize-error/node_modules/type-fest": {
+ "version": "0.13.1",
+ "resolved": "https://registry.npmjs.org/type-fest/-/type-fest-0.13.1.tgz",
+ "integrity": "sha512-34R7HTnG0XIJcBSn5XhDd7nNFPRcXYRZrBB2O2jdKqYODldSzBAqzsWoZYYvduky73toYS/ESqxPvkDf/F0XMg==",
+ "dev": true,
+ "license": "(MIT OR CC0-1.0)",
+ "engines": {
+ "node": ">=10"
+ },
+ "funding": {
+ "url": "https://github.com/sponsors/sindresorhus"
+ }
+ },
"node_modules/set-blocking": {
"version": "2.0.0",
"resolved": "https://registry.npmjs.org/set-blocking/-/set-blocking-2.0.0.tgz",
@@ -4284,6 +4407,13 @@
"integrity": "sha512-rr+VVSXtRhO4OHbXUiAF7xW3Bo9DuuF6C5jH+q/x15j2jniycgKbxU09Hr0WqlSLUs4i4ltHGXqTe7VHclYWyA==",
"dev": true
},
+ "node_modules/sprintf-js": {
+ "version": "1.1.3",
+ "resolved": "https://registry.npmjs.org/sprintf-js/-/sprintf-js-1.1.3.tgz",
+ "integrity": "sha512-Oo+0REFV59/rz3gfJNKQiBlwfHaSESl1pcGyABQsnnIfWOFt6JNj5gCog2U6MLZ//IGYD+nA8nI+mTShREReaA==",
+ "dev": true,
+ "license": "BSD-3-Clause"
+ },
"node_modules/string_decoder": {
"version": "1.1.1",
"resolved": "https://registry.npmjs.org/string_decoder/-/string_decoder-1.1.1.tgz",
@@ -5198,6 +5328,12 @@
"@types/node": "*"
}
},
+ "@types/global-agent": {
+ "version": "2.1.3",
+ "resolved": "https://registry.npmjs.org/@types/global-agent/-/global-agent-2.1.3.tgz",
+ "integrity": "sha512-rGtZZcgZcKWuKNTkGBGsqyOQ7Nn2MjXh4+xeZbf+5b5KMUx8H1rTqLRackxos7pUlreszbYjQcop5JvqCnZlLw==",
+ "dev": true
+ },
"@types/json-schema": {
"version": "7.0.15",
"resolved": "https://registry.npmjs.org/@types/json-schema/-/json-schema-7.0.15.tgz",
@@ -5588,6 +5724,12 @@
"integrity": "sha512-jDctJ/IVQbZoJykoeHbhXpOlNBqGNcwXJKJog42E5HDPUwQTSdjCHdihjj0DlnheQ7blbT6dHOafNAiS8ooQKA==",
"dev": true
},
+ "boolean": {
+ "version": "3.2.0",
+ "resolved": "https://registry.npmjs.org/boolean/-/boolean-3.2.0.tgz",
+ "integrity": "sha512-d0II/GO9uf9lfUHH2BQsjxzRJZBdsjgsBiW4BvhWk/3qoKwQFjIDVN19PfX8F2D/r9PCMTtLWjYVCFrpeYUzsw==",
+ "dev": true
+ },
"brace-expansion": {
"version": "1.1.11",
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.11.tgz",
@@ -5838,6 +5980,12 @@
"integrity": "sha512-bd2L678uiWATM6m5Z1VzNCErI3jiGzt6HGY8OVICs40JQq/HALfbyNJmp0UDakEY4pMMaN0Ly5om/B1VI/+xfQ==",
"dev": true
},
+ "detect-node": {
+ "version": "2.1.0",
+ "resolved": "https://registry.npmjs.org/detect-node/-/detect-node-2.1.0.tgz",
+ "integrity": "sha512-T0NIuQpnTvFDATNuHN5roPwSBG83rFsuO+MXXH9/3N1eFbn4wcPjttvjMLEPWJ0RGUYgQE7cGgS3tNxbqCGM7g==",
+ "dev": true
+ },
"diff": {
"version": "5.0.0",
"resolved": "https://registry.npmjs.org/diff/-/diff-5.0.0.tgz",
@@ -5965,6 +6113,12 @@
"is-symbol": "^1.0.2"
}
},
+ "es6-error": {
+ "version": "4.1.1",
+ "resolved": "https://registry.npmjs.org/es6-error/-/es6-error-4.1.1.tgz",
+ "integrity": "sha512-Um/+FxMr9CISWh0bi5Zv0iOD+4cFh5qLeks1qhAopKVAJw3drgKbKySikp7wGhDL0HPeaja0P5ULZrxLkniUVg==",
+ "dev": true
+ },
"esbuild": {
"version": "0.19.3",
"resolved": "https://registry.npmjs.org/esbuild/-/esbuild-0.19.3.tgz",
@@ -6511,6 +6665,20 @@
"is-glob": "^4.0.3"
}
},
+ "global-agent": {
+ "version": "3.0.0",
+ "resolved": "https://registry.npmjs.org/global-agent/-/global-agent-3.0.0.tgz",
+ "integrity": "sha512-PT6XReJ+D07JvGoxQMkT6qji/jVNfX/h364XHZOWeRzy64sSFr+xJ5OX7LI3b4MPQzdL4H8Y8M0xzPpsVMwA8Q==",
+ "dev": true,
+ "requires": {
+ "boolean": "^3.0.1",
+ "es6-error": "^4.1.1",
+ "matcher": "^3.0.0",
+ "roarr": "^2.15.3",
+ "semver": "^7.3.2",
+ "serialize-error": "^7.0.1"
+ }
+ },
"globals": {
"version": "13.24.0",
"resolved": "https://registry.npmjs.org/globals/-/globals-13.24.0.tgz",
@@ -6956,6 +7124,12 @@
"integrity": "sha512-Bdboy+l7tA3OGW6FjyFHWkP5LuByj1Tk33Ljyq0axyzdk9//JSi2u3fP1QSmd1KNwq6VOKYGlAu87CisVir6Pw==",
"dev": true
},
+ "json-stringify-safe": {
+ "version": "5.0.1",
+ "resolved": "https://registry.npmjs.org/json-stringify-safe/-/json-stringify-safe-5.0.1.tgz",
+ "integrity": "sha512-ZClg6AaYvamvYEE82d3Iyd3vSSIjQ+odgjaTzRuO3s7toCdFKczob2i0zCh7JE8kWn17yvAWhUVxvqGwUalsRA==",
+ "dev": true
+ },
"json5": {
"version": "1.0.2",
"resolved": "https://registry.npmjs.org/json5/-/json5-1.0.2.tgz",
@@ -7052,6 +7226,15 @@
"yallist": "^4.0.0"
}
},
+ "matcher": {
+ "version": "3.0.0",
+ "resolved": "https://registry.npmjs.org/matcher/-/matcher-3.0.0.tgz",
+ "integrity": "sha512-OkeDaAZ/bQCxeFAozM55PKcKU0yJMPGifLwV4Qgjitu+5MoAfSQN4lsLJeXZ1b8w0x+/Emda6MZgXS1jvsapng==",
+ "dev": true,
+ "requires": {
+ "escape-string-regexp": "^4.0.0"
+ }
+ },
"merge2": {
"version": "1.4.1",
"resolved": "https://registry.npmjs.org/merge2/-/merge2-1.4.1.tgz",
@@ -7636,6 +7819,20 @@
"glob": "^7.1.3"
}
},
+ "roarr": {
+ "version": "2.15.4",
+ "resolved": "https://registry.npmjs.org/roarr/-/roarr-2.15.4.tgz",
+ "integrity": "sha512-CHhPh+UNHD2GTXNYhPWLnU8ONHdI+5DI+4EYIAOaiD63rHeYlZvyh8P+in5999TTSFgUYuKUAjzRI4mdh/p+2A==",
+ "dev": true,
+ "requires": {
+ "boolean": "^3.0.1",
+ "detect-node": "^2.0.4",
+ "globalthis": "^1.0.1",
+ "json-stringify-safe": "^5.0.1",
+ "semver-compare": "^1.0.0",
+ "sprintf-js": "^1.1.2"
+ }
+ },
"run-parallel": {
"version": "1.2.0",
"resolved": "https://registry.npmjs.org/run-parallel/-/run-parallel-1.2.0.tgz",
@@ -7691,6 +7888,29 @@
"lru-cache": "^6.0.0"
}
},
+ "semver-compare": {
+ "version": "1.0.0",
+ "resolved": "https://registry.npmjs.org/semver-compare/-/semver-compare-1.0.0.tgz",
+ "integrity": "sha512-YM3/ITh2MJ5MtzaM429anh+x2jiLVjqILF4m4oyQB18W7Ggea7BfqdH/wGMK7dDiMghv/6WG7znWMwUDzJiXow==",
+ "dev": true
+ },
+ "serialize-error": {
+ "version": "7.0.1",
+ "resolved": "https://registry.npmjs.org/serialize-error/-/serialize-error-7.0.1.tgz",
+ "integrity": "sha512-8I8TjW5KMOKsZQTvoxjuSIa7foAwPWGOts+6o7sgjz41/qMD9VQHEDxi6PBvK2l0MXUmqZyNpUK+T2tQaaElvw==",
+ "dev": true,
+ "requires": {
+ "type-fest": "^0.13.1"
+ },
+ "dependencies": {
+ "type-fest": {
+ "version": "0.13.1",
+ "resolved": "https://registry.npmjs.org/type-fest/-/type-fest-0.13.1.tgz",
+ "integrity": "sha512-34R7HTnG0XIJcBSn5XhDd7nNFPRcXYRZrBB2O2jdKqYODldSzBAqzsWoZYYvduky73toYS/ESqxPvkDf/F0XMg==",
+ "dev": true
+ }
+ }
+ },
"set-blocking": {
"version": "2.0.0",
"resolved": "https://registry.npmjs.org/set-blocking/-/set-blocking-2.0.0.tgz",
@@ -7800,6 +8020,12 @@
"integrity": "sha512-rr+VVSXtRhO4OHbXUiAF7xW3Bo9DuuF6C5jH+q/x15j2jniycgKbxU09Hr0WqlSLUs4i4ltHGXqTe7VHclYWyA==",
"dev": true
},
+ "sprintf-js": {
+ "version": "1.1.3",
+ "resolved": "https://registry.npmjs.org/sprintf-js/-/sprintf-js-1.1.3.tgz",
+ "integrity": "sha512-Oo+0REFV59/rz3gfJNKQiBlwfHaSESl1pcGyABQsnnIfWOFt6JNj5gCog2U6MLZ//IGYD+nA8nI+mTShREReaA==",
+ "dev": true
+ },
"string_decoder": {
"version": "1.1.1",
"resolved": "https://registry.npmjs.org/string_decoder/-/string_decoder-1.1.1.tgz",
diff --git a/js/package.json b/js/package.json
index a3bd18adce98e..7385ed31eb075 100644
--- a/js/package.json
+++ b/js/package.json
@@ -1,6 +1,7 @@
{
"devDependencies": {
"@types/fs-extra": "^11.0.2",
+ "@types/global-agent": "^2.1.3",
"@types/mocha": "^10.0.2",
"@types/node": "^18.14.6",
"@types/npmlog": "^4.1.4",
@@ -17,6 +18,7 @@
"eslint-plugin-prefer-arrow": "^1.2.3",
"eslint-plugin-unicorn": "^48.0.1",
"fs-extra": "^11.1.1",
+ "global-agent": "^3.0",
"jszip": "^3.10.1",
"mocha": "^10.2.0",
"npmlog": "^7.0.1",
diff --git a/js/scripts/utils.ts b/js/scripts/utils.ts
index e22eeb1bd9217..5d032dc01957c 100644
--- a/js/scripts/utils.ts
+++ b/js/scripts/utils.ts
@@ -2,9 +2,15 @@
// Licensed under the MIT License.
import { WriteStream } from 'fs';
+import { bootstrap as globalAgentBootstrap } from 'global-agent';
import * as https from 'https';
import { JSZipObject } from 'jszip';
+// Bootstrap global-agent to honor the proxy settings in
+// environment variables, e.g. GLOBAL_AGENT_HTTPS_PROXY.
+// See https://github.com/gajus/global-agent/blob/v3.0.0/README.md#environment-variables for details.
+globalAgentBootstrap();
+
export const downloadZip = async (url: string): Promise<Buffer> =>
new Promise((resolve, reject) => {
https.get(url, (res) => {
diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts
index bfb74355b0d70..50d83f5af26e0 100644
--- a/js/web/lib/wasm/jsep/backend-webgpu.ts
+++ b/js/web/lib/wasm/jsep/backend-webgpu.ts
@@ -902,6 +902,10 @@ export class WebGpuBackend {
this.sessionStatus = 'default';
}
+ onCreateSession(): void {
+ this.gpuDataManager.onCreateSession();
+ }
+
onReleaseSession(sessionId: number): void {
this.unregisterBuffers(sessionId);
if (this.capturedCommandList.has(sessionId)) {
diff --git a/js/web/lib/wasm/jsep/backend-webnn.ts b/js/web/lib/wasm/jsep/backend-webnn.ts
index d13136d252d2a..37eb0e0edc67c 100644
--- a/js/web/lib/wasm/jsep/backend-webnn.ts
+++ b/js/web/lib/wasm/jsep/backend-webnn.ts
@@ -163,6 +163,69 @@ export class WebNNBackend {
return id;
}
+ // Register WebNN Constant operands from external data.
+ public registerMLConstant(
+ externalFilePath: string,
+ dataOffset: number,
+ dataLength: number,
+ builder: MLGraphBuilder,
+ desc: MLOperandDescriptor,
+ mountedFiles: Map | undefined,
+ ): MLOperand {
+ // If available, "Module.MountedFiles" is a Map for all preloaded files.
+ if (!mountedFiles) {
+ throw new Error('External mounted files are not available.');
+ }
+
+ let filePath = externalFilePath;
+ if (externalFilePath.startsWith('./')) {
+ filePath = externalFilePath.substring(2);
+ }
+ const fileData = mountedFiles.get(filePath);
+ if (!fileData) {
+ throw new Error(`File with name ${filePath} not found in preloaded files.`);
+ }
+
+ if (dataOffset + dataLength > fileData.byteLength) {
+ throw new Error('Out of bounds: data offset and length exceed the external file data size.');
+ }
+
+ const buffer = fileData.slice(dataOffset, dataOffset + dataLength).buffer;
+ let bufferView: ArrayBufferView;
+ switch (desc.dataType) {
+ case 'float32':
+ bufferView = new Float32Array(buffer);
+ break;
+ case 'float16':
+ bufferView = new Uint16Array(buffer);
+ break;
+ case 'int32':
+ bufferView = new Int32Array(buffer);
+ break;
+ case 'uint32':
+ bufferView = new Uint32Array(buffer);
+ break;
+ case 'int64':
+ bufferView = new BigInt64Array(buffer);
+ break;
+ case 'uint64':
+ bufferView = new BigUint64Array(buffer);
+ break;
+ case 'int8':
+ bufferView = new Int8Array(buffer);
+ break;
+ case 'uint8':
+ bufferView = new Uint8Array(buffer);
+ break;
+ default:
+ throw new Error(`Unsupported data type: ${desc.dataType} in creating WebNN Constant from external data.`);
+ }
+
+ LOG_DEBUG('verbose', () => `[WebNN] registerMLConstant {dataType: ${desc.dataType}, shape: ${desc.shape}}}`);
+
+ return builder.constant(desc, bufferView);
+ }
+
public flush(): void {
// Unlike the WebGPU backend, the WebNN backend does not need to flush any pending operations.
}
diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts
index 7bce5ff9390e8..fddc061cd775a 100644
--- a/js/web/lib/wasm/jsep/init.ts
+++ b/js/web/lib/wasm/jsep/init.ts
@@ -87,24 +87,25 @@ class ComputeContextImpl implements ComputeContext {
contextDataOffset: number,
) {
this.adapterInfo = backend.adapterInfo;
- const heapU32 = module.HEAPU32;
// extract context data
- let dataIndex = contextDataOffset >>> 2;
- this.opKernelContext = heapU32[dataIndex++];
- const inputCount = heapU32[dataIndex++];
- this.outputCount = heapU32[dataIndex++];
- this.customDataOffset = heapU32[dataIndex++];
- this.customDataSize = heapU32[dataIndex++];
+ const ptrSize = module.PTR_SIZE;
+ let dataIndex = contextDataOffset / module.PTR_SIZE;
+ const type = ptrSize === 4 ? 'i32' : 'i64';
+ this.opKernelContext = Number(module.getValue(ptrSize * dataIndex++, type));
+ const inputCount = Number(module.getValue(ptrSize * dataIndex++, type));
+ this.outputCount = Number(module.getValue(ptrSize * dataIndex++, type));
+ this.customDataOffset = Number(module.getValue(ptrSize * dataIndex++, '*'));
+ this.customDataSize = Number(module.getValue(ptrSize * dataIndex++, type));
const inputs: TensorView[] = [];
for (let i = 0; i < inputCount; i++) {
- const dataType = heapU32[dataIndex++];
- const data = heapU32[dataIndex++];
- const dim = heapU32[dataIndex++];
+ const dataType = Number(module.getValue(ptrSize * dataIndex++, type));
+ const data = Number(module.getValue(ptrSize * dataIndex++, '*'));
+ const dim = Number(module.getValue(ptrSize * dataIndex++, type));
const dims: number[] = [];
for (let d = 0; d < dim; d++) {
- dims.push(heapU32[dataIndex++]);
+ dims.push(Number(module.getValue(ptrSize * dataIndex++, type)));
}
inputs.push(new TensorViewImpl(module, dataType, data, dims));
}
@@ -152,11 +153,12 @@ class ComputeContextImpl implements ComputeContext {
output(index: number, dims: readonly number[]): number {
const stack = this.module.stackSave();
try {
- const data = this.module.stackAlloc((1 + dims.length) * 4 /* sizeof(size_t) */);
- let offset = data >> 2;
- this.module.HEAPU32[offset++] = dims.length;
+ const ptrSize = this.module.PTR_SIZE;
+ const type = ptrSize === 4 ? 'i32' : 'i64';
+ const data = this.module.stackAlloc((1 + dims.length) * ptrSize /* sizeof(size_t) */);
+ this.module.setValue(data, dims.length, type);
for (let i = 0; i < dims.length; i++) {
- this.module.HEAPU32[offset++] = dims[i];
+ this.module.setValue(data + ptrSize * (i + 1), dims[i], type);
}
return this.module._JsepOutput!(this.opKernelContext, index, data);
} catch (e) {
@@ -215,7 +217,7 @@ export const init = async (
backend,
// jsepAlloc()
- (size: number) => backend.alloc(size),
+ (size: number) => backend.alloc(Number(size)),
// jsepFree()
(ptr: number) => backend.free(ptr),
@@ -223,12 +225,19 @@ export const init = async (
// jsepCopy(src, dst, size, isSourceGpu)
(src: number, dst: number, size: number, isSourceGpu = false) => {
if (isSourceGpu) {
- LOG_DEBUG('verbose', () => `[WebGPU] jsepCopyGpuToGpu: src=${src}, dst=${dst}, size=${size}`);
- backend.memcpy(src, dst);
+ LOG_DEBUG(
+ 'verbose',
+ () => `[WebGPU] jsepCopyGpuToGpu: src=${Number(src)}, dst=${Number(dst)}, size=${Number(size)}`,
+ );
+ backend.memcpy(Number(src), Number(dst));
} else {
- LOG_DEBUG('verbose', () => `[WebGPU] jsepCopyCpuToGpu: dataOffset=${src}, gpuDataId=${dst}, size=${size}`);
- const data = module.HEAPU8.subarray(src >>> 0, (src >>> 0) + size);
- backend.upload(dst, data);
+ LOG_DEBUG(
+ 'verbose',
+ () =>
+ `[WebGPU] jsepCopyCpuToGpu: dataOffset=${Number(src)}, gpuDataId=${Number(dst)}, size=${Number(size)}`,
+ );
+ const data = module.HEAPU8.subarray(Number(src >>> 0), Number(src >>> 0) + Number(size));
+ backend.upload(Number(dst), data);
}
},
@@ -239,12 +248,19 @@ export const init = async (
() => `[WebGPU] jsepCopyGpuToCpu: gpuDataId=${gpuDataId}, dataOffset=${dataOffset}, size=${size}`,
);
- await backend.download(gpuDataId, () => module.HEAPU8.subarray(dataOffset >>> 0, (dataOffset >>> 0) + size));
+ await backend.download(Number(gpuDataId), () =>
+ module.HEAPU8.subarray(Number(dataOffset) >>> 0, Number(dataOffset + size) >>> 0),
+ );
},
// jsepCreateKernel
(kernelType: string, kernelId: number, attribute: unknown) =>
- backend.createKernel(kernelType, kernelId, attribute, module.UTF8ToString(module._JsepGetNodeName!(kernelId))),
+ backend.createKernel(
+ kernelType,
+ Number(kernelId),
+ attribute,
+ module.UTF8ToString(module._JsepGetNodeName!(Number(kernelId))),
+ ),
// jsepReleaseKernel
(kernel: number) => backend.releaseKernel(kernel),
@@ -256,8 +272,8 @@ export const init = async (
() =>
`[WebGPU] jsepRun: sessionHandle=${sessionHandle}, kernel=${kernel}, contextDataOffset=${contextDataOffset}`,
);
- const context = new ComputeContextImpl(module, backend, contextDataOffset);
- return backend.computeKernel(kernel, context, errors);
+ const context = new ComputeContextImpl(module, backend, Number(contextDataOffset));
+ return backend.computeKernel(Number(kernel), context, errors);
},
// jsepCaptureBegin
() => backend.captureBegin(),
diff --git a/js/web/lib/wasm/jsep/util.ts b/js/web/lib/wasm/jsep/util.ts
index 5ae16d5625dc8..85aca96057df2 100644
--- a/js/web/lib/wasm/jsep/util.ts
+++ b/js/web/lib/wasm/jsep/util.ts
@@ -167,7 +167,7 @@ export class ShapeUtil {
'cannot get valid size from specified dimension range. Most likely the range contains negative values in them.',
);
}
- size *= dims[i];
+ size *= Number(dims[i]);
}
return size;
}
diff --git a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts
index 33e8c95c141ee..1860870a1130b 100644
--- a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts
+++ b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts
@@ -64,6 +64,11 @@ export interface GpuDataManager {
*/
dispose(): void;
+ /**
+ * create session related data.
+ */
+ onCreateSession(): void;
+
/**
* release session related data.
* @param sessionId - specify the session ID.
@@ -112,7 +117,7 @@ const bucketArr: number[] = [];
/**
* normalize the buffer size so that it fits the 128-bits (16 bytes) alignment.
*/
-const calcNormalizedBufferSize = (size: number) => Math.ceil(size / 16) * 16;
+const calcNormalizedBufferSize = (size: number) => Math.ceil(Number(size) / 16) * 16;
/**
* calculate the buffer size so that it fits into buckets.
@@ -200,6 +205,9 @@ class GpuDataManagerImpl implements GpuDataManager {
// a SessionID -> GPUBuffer[] mapping.
private capturedPendingBuffers: Map;
+ // The session count.
+ private sessionCount: number;
+
constructor(private backend: WebGpuBackend) {
this.storageCache = new Map();
this.freeBuffers = new Map();
@@ -213,6 +221,8 @@ class GpuDataManagerImpl implements GpuDataManager {
this.freeBuffers.set(key, []);
this.freeUniformBuffers.set(key, []);
}
+
+ this.sessionCount = 0;
}
upload(id: GpuDataId, data: Uint8Array): void {
@@ -226,7 +236,7 @@ class GpuDataManagerImpl implements GpuDataManager {
if (!gpuDataCache) {
throw new Error('gpu data for uploading does not exist');
}
- if (gpuDataCache.originalSize !== srcLength) {
+ if (Number(gpuDataCache.originalSize) !== srcLength) {
throw new Error(`inconsistent data size. gpu data size=${gpuDataCache.originalSize}, data size=${srcLength}`);
}
@@ -288,9 +298,7 @@ class GpuDataManagerImpl implements GpuDataManager {
LOG_DEBUG(
'verbose',
() =>
- `[WebGPU] GpuDataManager.registerExternalBuffer(size=${originalSize}) => id=${
- id
- }, buffer is the same, skip.`,
+ `[WebGPU] GpuDataManager.registerExternalBuffer(size=${originalSize}) => id=${id}, buffer is the same, skip.`,
);
return id;
} else if (this.backend.capturedCommandList.has(this.backend.currentSessionId!)) {
@@ -347,7 +355,7 @@ class GpuDataManagerImpl implements GpuDataManager {
}
const gpuData = { id: createNewGpuDataId(), type: GpuDataType.default, buffer: gpuBuffer };
- this.storageCache.set(gpuData.id, { gpuData, originalSize: size });
+ this.storageCache.set(gpuData.id, { gpuData, originalSize: Number(size) });
LOG_DEBUG('verbose', () => `[WebGPU] GpuDataManager.create(size=${size}) => id=${gpuData.id}`);
return gpuData;
@@ -357,10 +365,16 @@ class GpuDataManagerImpl implements GpuDataManager {
return this.storageCache.get(id)?.gpuData;
}
- release(id: GpuDataId): number {
+ release(idInput: GpuDataId): number {
+ const id = typeof idInput === 'bigint' ? Number(idInput) : idInput;
const cachedData = this.storageCache.get(id);
if (!cachedData) {
- throw new Error('releasing data does not exist');
+ if (this.storageCache.size === 0) {
+ // cache was previously cleared, no need to release anything.
+ return 0;
+ } else {
+ throw new Error('releasing data does not exist');
+ }
}
LOG_DEBUG('verbose', () => `[WebGPU] GpuDataManager.release(id=${id}), gpuDataId=${cachedData.gpuData.id}`);
@@ -373,7 +387,7 @@ class GpuDataManagerImpl implements GpuDataManager {
}
async download(id: GpuDataId, getTargetBuffer: () => Uint8Array): Promise {
- const cachedData = this.storageCache.get(id);
+ const cachedData = this.storageCache.get(Number(id));
if (!cachedData) {
throw new Error('data does not exist');
}
@@ -460,6 +474,10 @@ class GpuDataManagerImpl implements GpuDataManager {
this.capturedPendingBuffers = new Map();
}
+ onCreateSession() {
+ this.sessionCount += 1;
+ }
+
onReleaseSession(sessionId: number) {
// release the captured pending buffers.
const pendingBuffers = this.capturedPendingBuffers.get(sessionId);
@@ -469,6 +487,16 @@ class GpuDataManagerImpl implements GpuDataManager {
});
this.capturedPendingBuffers.delete(sessionId);
}
+
+ // release the storage cache if no active sessions.
+ this.sessionCount -= 1;
+ if (this.sessionCount === 0) {
+ LOG_DEBUG('warning', () => '[WebGPU] Clearing webgpu buffer cache');
+ this.storageCache.forEach((storage) => {
+ storage.gpuData.buffer.destroy();
+ });
+ this.storageCache = new Map();
+ }
}
}
diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
index fe824a5c4558a..09c786daa3fcd 100644
--- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
+++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
@@ -19,7 +19,7 @@ import { gather, parseGatherAttributes } from './ops/gather';
import { gatherBlockQuantized, parseGatherBlockQuantizedAttributes } from './ops/gather-block-quantized';
import { gatherElements, parseGatherElementsAttributes } from './ops/gather-elements';
import { gemm, parseGemmAttributes } from './ops/gemm';
-import { groupQueryAttention, parseGroupQueryAttentionAttributes } from './ops/group-query-attention';
+import { groupQueryAttention } from './ops/group-query-attention';
import { instanceNorm } from './ops/instance-norm';
import { layerNorm } from './ops/layer-norm';
import { matMul } from './ops/matmul';
@@ -104,7 +104,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new
['GlobalMaxPool', [pool.globalMaxPool, pool.parseGlobalMaxPoolAttributes]],
['Greater', [binaryOps.greater]],
['GreaterOrEqual', [binaryOps.greaterOrEqual]],
- ['GroupQueryAttention', [groupQueryAttention, parseGroupQueryAttentionAttributes]],
+ ['GroupQueryAttention', [groupQueryAttention]],
['HardSigmoid', [unaryOps.hardSigmoid, unaryOps.parseHardSigmoidAttributes]],
['InstanceNormalization', [instanceNorm]],
['LayerNormalization', [layerNorm]],
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts
index 832f6e132901e..6a78c8ae3b190 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts
@@ -8,6 +8,7 @@ import { ComputeContext, GpuDataType, ProgramInputTensorInfoDependency, ProgramU
import {
getMaxComponents,
+ IndicesHelper,
inputVariable,
outputVariable,
ShaderHelper,
@@ -65,14 +66,17 @@ export interface AttentionParameters {
broadcastResPosBias: boolean;
passPastInKv: boolean;
qkvFormat: AttentionQkvFormat;
- isPastkvBSNH?: boolean;
+ softcap?: number;
+ doRotary?: number;
+ rotaryInterLeaved?: number;
+ sommoothSoftmax?: number;
+ localWindowsSize?: number;
}
export interface AttentionAttrs {
numHeads: number;
- kvNumHeads?: number;
- isUnidirectional?: number;
- maskFilterValue?: number;
+ isUnidirectional: number;
+ maskFilterValue: number;
scale: number;
doRotary: number;
qkvHiddenSizes: number[];
@@ -258,41 +262,106 @@ const validateAttentionInputs = (inputs: readonly TensorView[], attributes: Atte
};
};
-const createInPlaceSoftmaxProgramInfo = (input: TensorView, n: number, d: number) => {
- const components = getMaxComponents(d);
+const initVarStub = (
+ seqLensInput: IndicesHelper | undefined,
+ totalSequenceLengthInput: IndicesHelper | undefined,
+ initPastSequenceLength: boolean,
+) => {
+ // In the case of GQA, redefine total_sequence_length, present_sequence_length and past_sequence_length based on seqlen_k input
+ if (totalSequenceLengthInput && seqLensInput) {
+ return `
+ let total_sequence_length_input = u32(${totalSequenceLengthInput.getByOffset('0')});
+ let present_sequence_length = max(total_sequence_length_input, uniforms.past_sequence_length);
+ let is_subsequent_prompt: bool = sequence_length > 1 && sequence_length != total_sequence_length_input;
+ let is_first_prompt: bool = is_subsequent_prompt == false && sequence_length == total_sequence_length_input;
+ total_sequence_length = u32(${seqLensInput?.getByOffset('batchIdx')}) + 1;
+ var past_sequence_length: u32 = 0;
+ if (is_first_prompt == false) {
+ past_sequence_length = total_sequence_length - sequence_length;
+ }
+ `;
+ } else {
+ return `
+ ${initPastSequenceLength ? 'let past_sequence_length = uniforms.past_sequence_length' : ''};
+ let present_sequence_length = total_sequence_length;
+ `;
+ }
+};
+
+const createInPlaceSoftmaxProgramInfo = (
+ input: TensorView,
+ batchSize: number,
+ numHeads: number,
+ pastSequenceLength: number,
+ sequenceLength: number,
+ totalSequenceLength: number,
+ seqLens: TensorView | undefined,
+ totalSequenceLengthInput: TensorView | undefined,
+) => {
+ // Set components to 1 if seqLens is specified, i.e. GroupQueryAttention.
+ const components = getMaxComponents(seqLens ? 1 : totalSequenceLength);
let WG = 64;
- const dComp = d / components;
- if (dComp < WG) {
+ const totalSequenceLengthComp = totalSequenceLength / components;
+ if (totalSequenceLengthComp < WG) {
WG = 32;
}
- const elementsPerThread = Math.ceil(d / components / WG);
+ const elementsPerThread = Math.ceil(totalSequenceLength / components / WG);
const programUniforms: ProgramUniform[] = [
- { type: DataType.float, data: 1 / d },
- { type: DataType.uint32, data: dComp },
+ { type: DataType.uint32, data: batchSize },
+ { type: DataType.uint32, data: numHeads },
+ { type: DataType.uint32, data: pastSequenceLength },
+ { type: DataType.uint32, data: sequenceLength },
+ { type: DataType.uint32, data: totalSequenceLengthComp },
{ type: DataType.uint32, data: elementsPerThread },
];
const dataType = tensorTypeToWsglStorageType(input.dataType, components);
const f32Type = tensorTypeToWsglValueType(DataType.float, components);
const inputDependencies: ProgramInputTensorInfoDependency[] = ['type'];
+ if (seqLens) {
+ inputDependencies.push('type');
+ }
+ if (totalSequenceLengthInput) {
+ inputDependencies.push('type');
+ }
const getShaderSource = (shaderHelper: ShaderHelper) => {
const inputHelper = outputVariable('x', input.dataType, input.dims, components);
+ const inputHelpers = [inputHelper];
+ const seqLensInputHelper = seqLens ? inputVariable('seq_lens', seqLens.dataType, seqLens.dims) : undefined;
+ if (seqLensInputHelper) {
+ inputHelpers.push(seqLensInputHelper);
+ }
+
+ const totalSequenceLengthInputHelper = totalSequenceLengthInput
+ ? inputVariable('total_sequence_length_input', totalSequenceLengthInput.dataType, totalSequenceLengthInput.dims)
+ : undefined;
+ if (totalSequenceLengthInputHelper) {
+ inputHelpers.push(totalSequenceLengthInputHelper);
+ }
const elemValueType = tensorTypeToWsglValueType(input.dataType);
const uniforms: UniformsArrayType = [
- { name: 'd_inv', type: 'f32' },
- { name: 'd_comp', type: 'u32' },
+ { name: 'batch_size', type: 'u32' },
+ { name: 'num_heads', type: 'u32' },
+ { name: 'past_sequence_length', type: 'u32' },
+ { name: 'sequence_length', type: 'u32' },
+ { name: 'total_sequence_length', type: 'u32' },
{ name: 'elements_per_thread', type: 'u32' },
];
return `
var thread_max: array;
var thread_sum: array;
- ${shaderHelper.registerUniforms(uniforms).declareVariables(inputHelper)}
+ ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputHelpers)}
${shaderHelper.mainStart([WG, 1, 1])}
+ let batchIdx = workgroup_id.z / uniforms.num_heads;
+ let headIdx = workgroup_id.z % uniforms.num_heads;
+ let sequence_length = uniforms.sequence_length;
+ var total_sequence_length = uniforms.total_sequence_length;
+ ${initVarStub(seqLensInputHelper, totalSequenceLengthInputHelper, false)}
let local_offset = local_idx * uniforms.elements_per_thread;
- let offset = (global_idx / ${WG}) * uniforms.d_comp + local_offset;
-
+ let offset = (global_idx / ${WG}) * uniforms.total_sequence_length + local_offset;
+ let seq_causal_length = ${seqLens ? 'u32(past_sequence_length + workgroup_id.y + 1)' : 'total_sequence_length'};
var thread_max_vector = ${f32Type}(-3.402823e+38f);
- for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {
+ for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {
thread_max_vector = max(${f32Type}(x[offset + i]), thread_max_vector);
}
thread_max[local_idx] = ${(() => {
@@ -315,7 +384,7 @@ const createInPlaceSoftmaxProgramInfo = (input: TensorView, n: number, d: number
}
var sum_vector = ${f32Type}(0);
- for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {
+ for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {
sum_vector += exp(${f32Type}(x[offset + i]) - max_value);
}
thread_sum[local_idx] = ${(() => {
@@ -338,15 +407,23 @@ const createInPlaceSoftmaxProgramInfo = (input: TensorView, n: number, d: number
}
if (sum == 0) {
- for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {
- x[offset + i] = ${inputHelper.type.value}(${elemValueType}(uniforms.d_inv));
+ for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {
+ x[offset + i] = ${inputHelper.type.value}(${elemValueType}(1.0) / ${elemValueType}(seq_causal_length));
}
} else {
- for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {
+ for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {
var f32input = ${f32Type}(x[offset + i]);
x[offset + i] = ${inputHelper.type.value}(exp(f32input - max_value) / sum);
}
}
+ ${
+ seqLens
+ ? `
+ for (var total_seq_id: u32 = seq_causal_length; total_seq_id + local_offset < uniforms.total_sequence_length; total_seq_id++) {
+ x[offset + total_seq_id] = ${inputHelper.type.value}(${elemValueType}(0));
+ }`
+ : ''
+ };
}`;
};
@@ -354,7 +431,11 @@ const createInPlaceSoftmaxProgramInfo = (input: TensorView, n: number, d: number
name: 'AttentionProbsSoftmax',
shaderCache: { hint: `${WG};${dataType};${components}`, inputDependencies },
getShaderSource,
- getRunData: () => ({ outputs: [], dispatchGroup: { x: n }, programUniforms }),
+ getRunData: () => ({
+ outputs: [],
+ dispatchGroup: { x: Math.ceil(totalSequenceLength / WG), y: sequenceLength, z: batchSize * numHeads },
+ programUniforms,
+ }),
};
};
@@ -365,19 +446,21 @@ const createAttentionProbsProgramInfo = (
pastKey: TensorView | undefined,
attentionBias: TensorView | undefined,
parameters: AttentionParameters,
- attributes: AttentionAttrs,
pastSequenceLength: number,
+ seqLens: TensorView | undefined,
+ totalSequenceLengthInput: TensorView | undefined,
) => {
const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength;
const probsShape = [parameters.batchSize, parameters.numHeads, parameters.sequenceLength, totalSequenceLength];
- const presentKey = parameters.kvNumHeads === undefined && outputCount > 1 && pastKey;
+ const presentKey = outputCount > 1 && pastKey;
+ const kvNumHeads = parameters.kvNumHeads ? parameters.kvNumHeads : parameters.numHeads;
const presentKeyShape = presentKey
- ? [parameters.batchSize, parameters.numHeads, totalSequenceLength, parameters.headSize]
+ ? [parameters.batchSize, kvNumHeads, totalSequenceLength, parameters.headSize]
: undefined;
-
+ const nReps = parameters.nReps ? parameters.nReps : 1;
// TODO: handle mask
- const alpha = attributes.scale === 0 ? 1.0 / Math.sqrt(parameters.headSize) : attributes.scale;
+ const alpha = parameters.scale === 0 ? 1.0 / Math.sqrt(parameters.headSize) : parameters.scale;
const components = getMaxComponents(parameters.headSize);
const vectorizedHeadSize = parameters.headSize / components;
const TILE_SIZE = 12;
@@ -391,9 +474,11 @@ const createAttentionProbsProgramInfo = (
{ type: DataType.uint32, data: vectorizedHeadSize },
{ type: DataType.uint32, data: totalSequenceLength },
{ type: DataType.uint32, data: parameters.numHeads },
+ { type: DataType.uint32, data: parameters.headSize },
{ type: DataType.float, data: alpha },
{ type: DataType.uint32, data: pastSequenceLength },
{ type: DataType.uint32, data: parameters.kvSequenceLength },
+ { type: DataType.uint32, data: nReps },
];
// Feed pastKey to the shader-code only if it is non-zero and presentKey is being produced
const feedPastKey = presentKey && pastKey && ShapeUtil.size(pastKey.dims) > 0;
@@ -404,6 +489,12 @@ const createAttentionProbsProgramInfo = (
if (attentionBias) {
inputDependencies.push('type');
}
+ if (seqLens) {
+ inputDependencies.push('type');
+ }
+ if (totalSequenceLengthInput) {
+ inputDependencies.push('type');
+ }
const outputs = [{ dims: probsShape, dataType: q.dataType, gpuDataType: GpuDataType.default }];
if (presentKey) {
outputs.push({ dims: presentKeyShape!, dataType: q.dataType, gpuDataType: GpuDataType.default });
@@ -419,6 +510,16 @@ const createAttentionProbsProgramInfo = (
if (attentionBias) {
inputVars.push(inputVariable('attention_bias', attentionBias.dataType, attentionBias.dims));
}
+ const seqLensInputVariable = seqLens ? inputVariable('seq_lens', seqLens.dataType, seqLens.dims) : undefined;
+ if (seqLensInputVariable) {
+ inputVars.push(seqLensInputVariable);
+ }
+ const totalSequenceLengthInputVariable = totalSequenceLengthInput
+ ? inputVariable('total_sequence_length_input', totalSequenceLengthInput.dataType, totalSequenceLengthInput.dims)
+ : undefined;
+ if (totalSequenceLengthInputVariable) {
+ inputVars.push(totalSequenceLengthInputVariable);
+ }
const output = outputVariable('output', q.dataType, probsShape);
const outputVars = [output];
if (presentKey) {
@@ -431,9 +532,11 @@ const createAttentionProbsProgramInfo = (
{ name: 'K', type: 'u32' },
{ name: 'N', type: 'u32' },
{ name: 'num_heads', type: 'u32' },
+ { name: 'head_size', type: 'u32' },
{ name: 'alpha', type: 'f32' as UniformDataElementType },
{ name: 'past_sequence_length', type: 'u32' },
{ name: 'kv_sequence_length', type: 'u32' },
+ { name: 'n_reps', type: 'u32' },
];
return `
const TILE_SIZE = ${TILE_SIZE}u;
@@ -443,21 +546,20 @@ const createAttentionProbsProgramInfo = (
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, ...outputVars)}
${shaderHelper.mainStart([TILE_SIZE, TILE_SIZE, 1])}
// x holds the N and y holds the M
- let headIdx = workgroup_id.z;
+ let headIdx = workgroup_id.z % uniforms.num_heads;
+ let kvHeadIdx = ${nReps === 1 ? 'headIdx' : 'headIdx / uniforms.n_reps'};
+ let kv_num_heads = ${nReps === 1 ? 'uniforms.num_heads' : 'uniforms.num_heads / uniforms.n_reps'};
+ let batchIdx = workgroup_id.z / uniforms.num_heads;
let m = workgroup_id.y * TILE_SIZE;
let n = workgroup_id.x * TILE_SIZE;
- let qOffset = uniforms.M * uniforms.K * headIdx + m * uniforms.K;
- ${(() => {
- if (feedPastKey && presentKey) {
- return `
- let kOffset = uniforms.kv_sequence_length * uniforms.K * headIdx;
- let pastKeyOffset = uniforms.past_sequence_length * uniforms.K * headIdx;`;
- } else {
- return `
- let kOffset = uniforms.N * uniforms.K * headIdx + n * uniforms.K;`;
- }
- })()}
- ${presentKey ? 'let presentKeyOffset = headIdx * uniforms.N * uniforms.K;' : ''}
+ let sequence_length = uniforms.M;
+ var total_sequence_length = uniforms.N;
+ ${initVarStub(seqLensInputVariable, totalSequenceLengthInputVariable, true)}
+ let absKvHeadIdx = batchIdx * kv_num_heads + kvHeadIdx;
+ let qOffset = workgroup_id.z * uniforms.M * uniforms.K + m * uniforms.K;
+ ${feedPastKey && presentKey ? 'let pastKeyOffset = absKvHeadIdx * uniforms.past_sequence_length * uniforms.K;' : ''};
+ let kOffset = absKvHeadIdx * uniforms.kv_sequence_length * uniforms.K;
+ ${presentKey ? 'let presentKeyOffset = absKvHeadIdx * uniforms.N * uniforms.K;' : ''}
var value = ${f32Type}(0);
for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {
if (global_id.y < uniforms.M && w + local_id.x < uniforms.K) {
@@ -468,31 +570,37 @@ const createAttentionProbsProgramInfo = (
${(() => {
if (feedPastKey && presentKey) {
return `
- if (n + local_id.y < uniforms.past_sequence_length) {
+ if (n + local_id.y < past_sequence_length) {
tileK[idx] = past_key[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];
- } else {
- tileK[idx] =
- key[kOffset + (n + local_id.y - uniforms.past_sequence_length) * uniforms.K + w + local_id.x];
+ } else if (n + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {
+ tileK[idx] = key[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];
}`;
} else {
- return 'tileK[idx] = key[kOffset + local_id.y * uniforms.K + w + local_id.x];';
+ return `
+ if (n + local_id.y < uniforms.kv_sequence_length) {
+ tileK[idx] = key[kOffset + (n + local_id.y) * uniforms.K + w + local_id.x];
+ }`;
}
})()}
${
- presentKey ? 'present_key[presentKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x] = tileK[idx];' : ''
+ presentKey
+ ? `if (n + local_id.y < present_sequence_length) {
+ present_key[presentKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x] = tileK[idx];
+ }`
+ : ''
}
}
workgroupBarrier();
for (var k: u32 = 0u; k < TILE_SIZE && w+k < uniforms.K; k++) {
- value += ${f32Type}(tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * local_id.x + k]);
+ value += ${f32Type}(tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * local_id.x + k]);
}
workgroupBarrier();
}
- let headOffset = headIdx * uniforms.M * uniforms.N;
- if (global_id.y < uniforms.M && global_id.x < uniforms.N) {
+ if (global_id.y < uniforms.M && global_id.x < total_sequence_length) {
+ let headOffset = workgroup_id.z * uniforms.M * uniforms.N;
let outputIdx = headOffset + global_id.y * uniforms.N + global_id.x;
var sum: f32 = ${(() => {
switch (components) {
@@ -530,13 +638,16 @@ const createVxAttentionScoreProgramInfo = (
pastValue: TensorView | undefined,
params: AttentionParameters,
pastSequenceLength: number,
+ seqLens: TensorView | undefined = undefined,
+ totalSequenceLengthInput: TensorView | undefined = undefined,
) => {
const totalSequenceLength = pastSequenceLength + params.kvSequenceLength;
const nReps = params.nReps ? params.nReps : 1;
const repeatedVHiddenSize = params.vHiddenSize * nReps;
- const presentValue = params.kvNumHeads == null && outputCount > 1 && pastValue;
+ const presentValue = outputCount > 1 && pastValue;
+ const kvNumHeads = params.kvNumHeads ? params.kvNumHeads : params.numHeads;
const presentValueShape = presentValue
- ? [params.batchSize, params.numHeads, totalSequenceLength, params.headSize]
+ ? [params.batchSize, kvNumHeads, totalSequenceLength, params.headSize]
: undefined;
const outputShape = [params.batchSize, params.sequenceLength, repeatedVHiddenSize];
const TILE_SIZE = 12;
@@ -551,9 +662,11 @@ const createVxAttentionScoreProgramInfo = (
{ type: DataType.uint32, data: totalSequenceLength },
{ type: DataType.uint32, data: params.vHeadSize },
{ type: DataType.uint32, data: params.numHeads },
+ { type: DataType.uint32, data: params.headSize },
{ type: DataType.uint32, data: repeatedVHiddenSize },
{ type: DataType.uint32, data: pastSequenceLength },
{ type: DataType.uint32, data: params.kvSequenceLength },
+ { type: DataType.uint32, data: nReps },
];
// Feed pastValue to the shader-code only if it is non-empty and presentValue is being produced
const feedPastValue = presentValue && pastValue && ShapeUtil.size(pastValue.dims) > 0;
@@ -561,6 +674,12 @@ const createVxAttentionScoreProgramInfo = (
if (feedPastValue) {
inputDependencies.push('type');
}
+ if (seqLens) {
+ inputDependencies.push('type');
+ }
+ if (totalSequenceLengthInput) {
+ inputDependencies.push('type');
+ }
const outputs = [{ dims: outputShape, dataType: probs.dataType, gpuDataType: GpuDataType.default }];
if (presentValue) {
outputs.push({ dims: presentValueShape!, dataType: probs.dataType, gpuDataType: GpuDataType.default });
@@ -572,6 +691,16 @@ const createVxAttentionScoreProgramInfo = (
if (feedPastValue) {
inputVars.push(inputVariable('past_value', pastValue.dataType, pastValue.dims));
}
+ const seqLensInputVariable = seqLens ? inputVariable('seq_lens', seqLens.dataType, seqLens.dims) : undefined;
+ if (seqLens) {
+ inputVars.push(seqLensInputVariable!);
+ }
+ const totalSequenceLengthInputVariable = totalSequenceLengthInput
+ ? inputVariable('total_sequence_length_input', totalSequenceLengthInput.dataType, totalSequenceLengthInput.dims)
+ : undefined;
+ if (totalSequenceLengthInput) {
+ inputVars.push(totalSequenceLengthInputVariable!);
+ }
const output = outputVariable('output', probs.dataType, outputShape);
const outputVars = [output];
if (presentValue) {
@@ -582,34 +711,32 @@ const createVxAttentionScoreProgramInfo = (
{ name: 'K', type: 'u32' },
{ name: 'N', type: 'u32' },
{ name: 'num_heads', type: 'u32' },
+ { name: 'head_size', type: 'u32' },
{ name: 'v_hidden_size', type: 'u32' },
{ name: 'past_sequence_length', type: 'u32' },
{ name: 'kv_sequence_length', type: 'u32' },
+ { name: 'n_reps', type: 'u32' },
];
return `
const TILE_SIZE = ${TILE_SIZE}u;
var tileQ: array<${probsHelper.type.value}, ${TILE_SIZE * TILE_SIZE}>;
- var tileK: array<${probsHelper.type.value}, ${TILE_SIZE * TILE_SIZE}>;
+ var tileV: array<${probsHelper.type.value}, ${TILE_SIZE * TILE_SIZE}>;
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, ...outputVars)}
${shaderHelper.mainStart([TILE_SIZE, TILE_SIZE, 1])}
- let headIdx = workgroup_id.z;
+ let headIdx = workgroup_id.z % uniforms.num_heads;
+ let batchIdx = workgroup_id.z / uniforms.num_heads;
+ let kvHeadIdx = ${nReps === 1 ? 'headIdx' : 'headIdx / uniforms.n_reps'};
+ let kv_num_heads = ${nReps === 1 ? 'uniforms.num_heads' : 'uniforms.num_heads / uniforms.n_reps'};
let m = global_id.y;
let n = global_id.x;
-
- let offsetA = headIdx * (uniforms.M * uniforms.K) + m * uniforms.K;
- ${(() => {
- if (feedPastValue && presentValue) {
- return `
- let pastValueOffset = headIdx * uniforms.N * uniforms.past_sequence_length + n;
- let vOffset = headIdx * uniforms.N * uniforms.kv_sequence_length + n;
- `;
- } else {
- return `
- let offsetB = headIdx * uniforms.N * uniforms.K + n;
- `;
- }
- })()}
- ${presentValue ? 'let presentValueOffset = headIdx * uniforms.N * uniforms.K + n;' : ''}
+ let sequence_length = uniforms.M;
+ var total_sequence_length = uniforms.K;
+ ${initVarStub(seqLensInputVariable, totalSequenceLengthInputVariable, true)}
+ let offsetA = workgroup_id.z * uniforms.M * uniforms.K + m * uniforms.K;
+ let absKvHeadIdx = batchIdx * kv_num_heads + kvHeadIdx; // kvHeadIdx is relative to the batch
+ ${feedPastValue && presentValue ? 'let pastValueOffset = absKvHeadIdx * uniforms.N * uniforms.past_sequence_length + n;' : ''};
+ let vOffset = absKvHeadIdx * uniforms.N * uniforms.kv_sequence_length + n;
+ ${presentValue ? 'let presentValueOffset = absKvHeadIdx * uniforms.N * uniforms.K + n;' : ''}
var value = ${probsHelper.type.storage}(0);
for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {
if (m < uniforms.M && w + local_id.x < uniforms.K) {
@@ -620,33 +747,39 @@ const createVxAttentionScoreProgramInfo = (
${(() => {
if (feedPastValue && presentValue) {
return `
- if (w + local_id.y < uniforms.past_sequence_length) {
- tileK[idx] = past_value[pastValueOffset + (w + local_id.y) * uniforms.N];
- } else {
- tileK[idx] = v[vOffset + (w + local_id.y - uniforms.past_sequence_length) * uniforms.N];
+ if (w + local_id.y < past_sequence_length) {
+ tileV[idx] = past_value[pastValueOffset + (w + local_id.y) * uniforms.N];
+ } else if (w + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {
+ tileV[idx] = v[vOffset + (w + local_id.y - past_sequence_length) * uniforms.N];
}
`;
} else {
return `
- tileK[idx] = v[offsetB + (w + local_id.y) * uniforms.N];
- `;
+ if (w + local_id.y < uniforms.kv_sequence_length) {
+ tileV[idx] = v[vOffset + (w + local_id.y) * uniforms.N];
+ }`;
}
})()}
- ${presentValue ? 'present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileK[idx];' : ''}
+ ${
+ presentValue
+ ? `
+ if (w + local_id.y < present_sequence_length) {
+ present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileV[idx];
+ }`
+ : ''
+ }
}
workgroupBarrier();
- for (var k: u32 = 0u; k < TILE_SIZE && w+k < uniforms.K; k++) {
- value += tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * k + local_id.x];
+ for (var k: u32 = 0u; k < TILE_SIZE && w+k < total_sequence_length; k++) {
+ value += tileQ[TILE_SIZE * local_id.y + k] * tileV[TILE_SIZE * k + local_id.x];
}
workgroupBarrier();
}
// we need to transpose output from BNSH_v to BSND_v
- let batchIdx = workgroup_id.z / uniforms.num_heads;
- let currentBatchHeadNumber = workgroup_id.z % uniforms.num_heads;
if (m < uniforms.M && n < uniforms.N) {
let outputIdx = batchIdx * uniforms.M * uniforms.v_hidden_size + m * uniforms.v_hidden_size
- + currentBatchHeadNumber * uniforms.N + n;
+ + headIdx * uniforms.N + n;
output[outputIdx] = value;
}
}`;
@@ -671,23 +804,29 @@ export const applyAttention = (
pastValue: TensorView | undefined,
attentionBiasInput: TensorView | undefined,
parameters: AttentionParameters,
- attributes: AttentionAttrs,
+ seqLens: TensorView | undefined = undefined,
+ totalSequenceLengthInput: TensorView | undefined = undefined,
) => {
- // Assumption is that presentKey/presentValue exists only if pastKey/pastValue exists.
+ // Assumption is that presentKey/presentValue exists only if pastKey/pastValue exists.
const outputCount = Math.min(context.outputCount, 1 + (pastKey ? 1 : 0) + (pastValue ? 1 : 0));
- const pastSequenceLength = parameters.kvNumHeads !== undefined || outputCount > 1 ? parameters.pastSequenceLength : 0;
+ const pastSequenceLength = outputCount > 1 ? parameters.pastSequenceLength : 0;
const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength;
const attentionBias =
attentionBiasInput && ShapeUtil.size(attentionBiasInput.dims) > 0 ? attentionBiasInput : undefined;
const inputsK = [q, k];
- if (parameters.kvNumHeads === undefined && outputCount > 1 && pastKey && ShapeUtil.size(pastKey.dims) > 0) {
+ if (outputCount > 1 && pastKey && ShapeUtil.size(pastKey.dims) > 0) {
inputsK.push(pastKey);
}
if (attentionBias) {
inputsK.push(attentionBias);
}
-
+ if (seqLens) {
+ inputsK.push(seqLens);
+ }
+ if (totalSequenceLengthInput) {
+ inputsK.push(totalSequenceLengthInput);
+ }
// Run AttentionProbs
const probs = context.compute(
createAttentionProbsProgramInfo(
@@ -697,31 +836,55 @@ export const applyAttention = (
pastKey,
attentionBias,
parameters,
- attributes,
pastSequenceLength,
+ seqLens,
+ totalSequenceLengthInput,
),
- { inputs: inputsK, outputs: parameters.kvNumHeads === undefined && outputCount > 1 ? [-1, 1] : [-1] },
+ { inputs: inputsK, outputs: outputCount > 1 ? [-1, 1] : [-1] },
)[0];
// Run Softmax
context.compute(
createInPlaceSoftmaxProgramInfo(
probs,
- parameters.batchSize * parameters.numHeads * parameters.sequenceLength,
+ parameters.batchSize,
+ parameters.numHeads,
+ pastSequenceLength,
+ parameters.sequenceLength,
totalSequenceLength,
+ seqLens,
+ totalSequenceLengthInput,
),
- { inputs: [probs], outputs: [] },
+ { inputs: seqLens && totalSequenceLengthInput ? [probs, seqLens, totalSequenceLengthInput] : [probs], outputs: [] },
);
- // Run AttrionScore
+ // Run AttentionScore
const inputsV = [probs, v];
- if (parameters.kvNumHeads === undefined && outputCount > 1 && pastValue && ShapeUtil.size(pastValue.dims) > 0) {
+ if (outputCount > 1 && pastValue && ShapeUtil.size(pastValue.dims) > 0) {
inputsV.push(pastValue);
}
- context.compute(createVxAttentionScoreProgramInfo(outputCount, probs, v, pastValue, parameters, pastSequenceLength), {
- inputs: inputsV,
- outputs: parameters.kvNumHeads === undefined && outputCount > 1 ? [0, 2] : [0],
- });
+ if (seqLens) {
+ inputsV.push(seqLens);
+ }
+ if (totalSequenceLengthInput) {
+ inputsV.push(totalSequenceLengthInput);
+ }
+ context.compute(
+ createVxAttentionScoreProgramInfo(
+ outputCount,
+ probs,
+ v,
+ pastValue,
+ parameters,
+ pastSequenceLength,
+ seqLens,
+ totalSequenceLengthInput,
+ ),
+ {
+ inputs: inputsV,
+ outputs: outputCount > 1 ? [0, 2] : [0],
+ },
+ );
};
const prepare = (context: ComputeContext, parameters: AttentionParameters) => {
@@ -857,6 +1020,5 @@ export const attention = (context: ComputeContext, attributes: AttentionAttrs):
undefined,
context.inputs[5],
params,
- attributes,
);
};
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts
index 53c2ca2fa47d6..c695a71568c97 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts
@@ -143,9 +143,11 @@ const createBinaryOpProgramInfo = (
additionalImplementation?: string,
outputDataType: number = a.dataType,
): ProgramInfo => {
- const isBroadcast = !ShapeUtil.areEqual(a.dims, b.dims);
- let outputShape = a.dims;
- let outputSize = ShapeUtil.size(a.dims);
+ const aDims = a.dims.map((x) => Number(x) ?? 1);
+ const bDims = b.dims.map((x) => Number(x) ?? 1);
+ const isBroadcast = !ShapeUtil.areEqual(aDims, bDims);
+ let outputShape = aDims;
+ let outputSize = ShapeUtil.size(aDims);
let vectorize = false;
let sharedDimensionDivisibleBy4 = false;
@@ -153,16 +155,16 @@ const createBinaryOpProgramInfo = (
// TODO: deal with zero-sized tensors (eg. dims=[1,0])
const cacheKeyAux = [isBroadcast];
if (isBroadcast) {
- const calculatedShape = BroadcastUtil.calcShape(a.dims, b.dims, false);
+ const calculatedShape = BroadcastUtil.calcShape(aDims, bDims, false);
if (!calculatedShape) {
throw new Error("Can't perform binary op on the given tensors");
}
- outputShape = calculatedShape;
+ outputShape = calculatedShape.slice();
outputSize = ShapeUtil.size(outputShape);
- const isAOneElement = ShapeUtil.size(a.dims) === 1;
- const isBOneElement = ShapeUtil.size(b.dims) === 1;
- const aLastDimDivisibleBy4 = a.dims.length > 0 && a.dims[a.dims.length - 1] % 4 === 0;
- const bLastDimDivisibleBy4 = b.dims.length > 0 && b.dims[b.dims.length - 1] % 4 === 0;
+ const isAOneElement = ShapeUtil.size(aDims) === 1;
+ const isBOneElement = ShapeUtil.size(bDims) === 1;
+ const aLastDimDivisibleBy4 = aDims.length > 0 && aDims[aDims.length - 1] % 4 === 0;
+ const bLastDimDivisibleBy4 = bDims.length > 0 && bDims[bDims.length - 1] % 4 === 0;
cacheKeyAux.push(isAOneElement);
cacheKeyAux.push(isBOneElement);
cacheKeyAux.push(aLastDimDivisibleBy4);
@@ -170,8 +172,8 @@ const createBinaryOpProgramInfo = (
// check whether vectorize can be enabled
let sharedDimension = 1;
for (let i = 1; i < outputShape.length; i++) {
- const dimA = a.dims[a.dims.length - i] ?? 1;
- const dimB = b.dims[b.dims.length - i] ?? 1;
+ const dimA = aDims[aDims.length - i];
+ const dimB = bDims[bDims.length - i];
if (dimA === dimB) {
sharedDimension *= dimA;
} else {
@@ -199,8 +201,8 @@ const createBinaryOpProgramInfo = (
getShaderSource: (shaderHelper) =>
createBinaryOpProgramShader(
shaderHelper,
- a.dims,
- b.dims,
+ aDims,
+ bDims,
outputShape,
vectorize,
isBroadcast,
@@ -216,7 +218,7 @@ const createBinaryOpProgramInfo = (
dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* component size */) },
programUniforms: [
{ type: DataType.uint32, data: Math.ceil(ShapeUtil.size(outputShape) / 4) },
- ...createTensorShapeVariables(a.dims, b.dims, outputShape),
+ ...createTensorShapeVariables(aDims, bDims, outputShape),
],
}),
};
@@ -280,9 +282,7 @@ export const pow = (context: ComputeContext): void => {
} else if (a < ${type}(0.0) && f32(b) != floor(f32(b))) {
return ${type}(pow(f32(a), f32(b))); // NaN
}
- return select(sign(a), ${type}(1.0), round(f32(abs(b) % ${type}(2.0))) != 1.0) * ${type}(${
- roundStr
- }(pow(f32(abs(a)), f32(b))));
+ return select(sign(a), ${type}(1.0), round(f32(abs(b) % ${type}(2.0))) != 1.0) * ${type}(${roundStr}(pow(f32(abs(a)), f32(b))));
}
fn pow_vector_custom(a : vec4<${type}>, b : vec4<${type}>) -> vec4<${type}> {
// TODO: implement vectorized pow
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts
index f2057df533ca7..793f26fe901e3 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts
@@ -219,7 +219,7 @@ const getWgslMappedType = (type: number, components: 1 | 2 | 3 | 4): string | [s
}
// return type is [ storage type, runtime type ] or a single string for both
- switch (type) {
+ switch (Number(type)) {
case DataType.float16:
return components > 1 ? `vec${components}` : 'f16';
case DataType.float:
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts
index 56291c037b7da..bbe25460d6fd3 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts
@@ -1,31 +1,49 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
-import { DataType } from '../../../wasm-common';
import { TensorView } from '../../tensor-view';
-import { ShapeUtil } from '../../util';
import { createAttributeWithCacheKey } from '../attribute-with-cache-key';
-import { ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform } from '../types';
+import { ComputeContext } from '../types';
-import {
- applyAttention,
- AttentionAttrs,
- AttentionMaskType,
- AttentionParameters,
- AttentionQkvFormat,
-} from './attention';
-import { createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, UniformsArrayType } from './common';
+import { applyAttention, AttentionMaskType, AttentionParameters, AttentionQkvFormat } from './attention';
import { maybeTransposeToBNSHAndAddBias } from './multihead-attention';
-import { createTileProgramInfo } from './tile';
+import { createSplitProgramInfo, SplitAttributes } from './split';
import { createTransposeProgramInfo, TransposeAttributes } from './transpose';
-
-export const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttrs): AttentionParameters => {
+export interface GroupQueryAttentionAttributes {
+ numHeads: number;
+ kvNumHeads: number;
+ scale: number;
+ softcap: number;
+ doRotary: number;
+ rotaryInterleaved: number;
+ smoothSoftmax: boolean;
+ localWindowSize: number;
+}
+
+export const validateInputs = (
+ inputs: readonly TensorView[],
+ attributes: GroupQueryAttentionAttributes,
+): AttentionParameters => {
+ if (attributes.doRotary && inputs.length <= 7) {
+ throw new Error('cos_cache and sin_cache inputs are required if do_rotary is specified');
+ }
const query = inputs[0];
const key = inputs[1];
const value = inputs[2];
const pastKey = inputs[3];
const pastValue = inputs[4];
-
+ if (attributes.localWindowSize !== -1) {
+ throw new Error('Local attention is not supported');
+ }
+ if (attributes.softcap !== 0) {
+ throw new Error('Softcap is not supported');
+ }
+ if (attributes.rotaryInterleaved !== 0) {
+ throw new Error('Rotary interleaved is not supported');
+ }
+ if (attributes.smoothSoftmax) {
+ throw new Error('Smooth softmax is not supported');
+ }
// Abbreviation and Meanings:
// B: batch_size
// S: sequence_length (input sequence length of query)
@@ -62,17 +80,32 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent
const dmmhaPacking = false;
const batchSize = query.dims[0];
const sequenceLength = query.dims[1];
- const hiddenSize =
+ let hiddenSize =
query.dims.length === 3 ? (dmmhaPacking ? query.dims[2] / 3 : query.dims[2]) : attributes.numHeads * query.dims[4];
let kvSequenceLength = sequenceLength;
let pastSequenceLength = 0;
- let maxSequenceLength = 0;
- const headSize = Math.floor(hiddenSize / attributes.numHeads);
+ const packedQKV = !key || key.dims.length === 0;
+ const headSize = !packedQKV
+ ? Math.floor(hiddenSize / attributes.numHeads)
+ : Math.floor(hiddenSize / (attributes.numHeads + 2 * attributes.kvNumHeads));
+ if (packedQKV) {
+ hiddenSize = headSize * attributes.numHeads;
+ }
const hasPastKey = pastKey && pastKey.dims.length !== 0;
const hasPastValue = pastValue && pastValue.dims.length !== 0;
- // TODO : this should be from attributes.
- const isPastkvBSNH = true;
+ // Currenly the onnxruntime GQA specification only support key/value BNSH format.
+ const isPastkvBSNH =
+ hasPastKey &&
+ pastKey.dims.length === 4 &&
+ pastKey.dims[0] === batchSize &&
+ pastKey.dims[1] !== attributes.kvNumHeads &&
+ pastKey.dims[2] === attributes.kvNumHeads &&
+ pastKey.dims[3] === headSize;
+
+ if (isPastkvBSNH) {
+ throw new Error('BSNH pastKey/pastValue is not supported');
+ }
if (hasPastKey && hasPastValue) {
if (pastKey.dims.length !== 4) {
throw new Error('Input "past_key" is expected to have 4 dimensions');
@@ -80,21 +113,13 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent
if (pastValue.dims.length !== 4) {
throw new Error('Input "past_value" is expected to have 4 dimensions');
}
- if (isPastkvBSNH) {
- // For BSNH
- pastSequenceLength = pastKey.dims[1];
- maxSequenceLength = pastKey.dims[1];
- } else {
- // For BNSH
- pastSequenceLength = pastKey.dims[2];
- maxSequenceLength = pastKey.dims[2];
- }
+ pastSequenceLength = pastKey.dims[2];
} else if (hasPastKey || hasPastValue) {
throw new Error('Input "past_key" and "past_value" shall be both present or both absent');
}
- let qkvFormat: AttentionQkvFormat;
- if (key) {
+ let qkvFormat: AttentionQkvFormat = AttentionQkvFormat.qkvBNSH;
+ if (key && key.dims.length > 0) {
if (query.dims.length !== 3) {
throw new Error('Input "query" is expected to have 3 dimensions when key is given');
}
@@ -109,7 +134,6 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent
if (query.dims[2] % key.dims[2] !== 0) {
throw new Error('Dimension 2 of "query" should be a multiple of "key"');
}
- qkvFormat = AttentionQkvFormat.qkvBSNH;
kvSequenceLength = key.dims[1];
} else if (key.dims.length === 5) {
if (key.dims[2] !== attributes.numHeads || key.dims[3] !== 2 || key.dims[4] !== headSize) {
@@ -118,15 +142,12 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent
if (value) {
throw new Error('Expect "value" be none when "key" has packed kv format.');
}
- qkvFormat = AttentionQkvFormat.qKvBSNHxBSN2H;
kvSequenceLength = key.dims[1];
} else {
// key_dims.size() == 4 (cross-attention with past_key)
if (key.dims[1] !== attributes.numHeads || key.dims[3] !== headSize) {
throw new Error('Expect "key" shape (batch_size, num_heads, kv_sequence_length, head_size) for past_key');
}
-
- qkvFormat = AttentionQkvFormat.unknown;
kvSequenceLength = key.dims[2];
}
} else {
@@ -143,8 +164,8 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent
const maskType: AttentionMaskType = AttentionMaskType.none;
let passPastInKv = false;
- let vHiddenSize = hiddenSize;
- if (value) {
+ let vHiddenSize = attributes.kvNumHeads ? headSize * attributes.kvNumHeads : hiddenSize;
+ if (value && value.dims.length > 0) {
if (value.dims.length !== 3 && value.dims.length !== 4) {
throw new Error('Input "value" is expected to have 3 or 4 dimensions');
}
@@ -166,7 +187,12 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent
passPastInKv = true;
}
}
- const totalSequenceLength = pastSequenceLength + kvSequenceLength;
+ const seqlLens = inputs.length > 4 ? inputs[5] : undefined;
+ if (seqlLens && seqlLens.dims.length !== 1 && seqlLens.dims[0] !== batchSize) {
+ throw new Error('Input "seqlens" is expected to have 1 dimension and the same dim 0 as batch_size');
+ }
+ const totalSequenceLength = -1;
+ const maxSequenceLength = -1;
const broadcastResPosBias = false;
return {
@@ -180,181 +206,36 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent
hiddenSize,
vHiddenSize,
headSize,
- vHeadSize: Math.floor(vHiddenSize / attributes.kvNumHeads!),
+ vHeadSize: Math.floor(vHiddenSize / attributes.kvNumHeads),
numHeads: attributes.numHeads,
kvNumHeads: attributes.kvNumHeads,
- nReps: attributes.numHeads / attributes.kvNumHeads!,
+ nReps: attributes.numHeads / attributes.kvNumHeads,
pastPresentShareBuffer: false,
maskType,
scale: attributes.scale,
broadcastResPosBias,
passPastInKv,
qkvFormat,
- isPastkvBSNH,
};
};
-const createConcatProgramInfo = (
- a: TensorView,
- b: TensorView | undefined,
- dataType: DataType,
- params: AttentionParameters,
-): ProgramInfo => {
- const outputShape = [params.batchSize, params.totalSequenceLength, params.kvNumHeads!, params.headSize];
- const component = 4;
- const outputSize = ShapeUtil.size(outputShape) / component;
- const presentSequenceLength = params.totalSequenceLength;
- const output = outputVariable('present_kv', dataType, outputShape.length, component);
- const inputA = inputVariable('new_kv', a.dataType, a.dims.length, component);
- const inputB = b ? inputVariable('past_kv', b.dataType, b.dims.length, component) : undefined;
-
- const H = Math.ceil(params.headSize / component);
- const dispatch = { x: presentSequenceLength, y: a.dims[0], z: 1 };
-
- const inputDependencies: ProgramInputTensorInfoDependency[] = b ? ['rank', 'rank'] : ['rank'];
-
- const programUniforms: ProgramUniform[] = [
- { type: DataType.uint32, data: outputSize },
- { type: DataType.uint32, data: params.pastSequenceLength },
- { type: DataType.uint32, data: params.kvSequenceLength },
- { type: DataType.uint32, data: params.totalSequenceLength },
- ];
-
- const inputs = [inputA];
- if (inputB) {
- programUniforms.push(
- ...createTensorShapeVariables(a.dims),
- ...createTensorShapeVariables(b!.dims),
- ...createTensorShapeVariables(outputShape),
- );
- inputs.push(inputB);
- } else {
- programUniforms.push(...createTensorShapeVariables(a.dims), ...createTensorShapeVariables(outputShape));
- }
- const uniforms: UniformsArrayType = [
- { name: 'output_size', type: 'u32' },
- { name: 'past_seqlen', type: 'u32' },
- { name: 'new_seqlen', type: 'u32' },
- { name: 'present_seqlen', type: 'u32' },
- ];
-
- const pastStr = ` let past_batch_stride = uniforms.past_seqlen * num_heads * H;
- var past_head_stride = uniforms.past_seqlen * H;
- if (is_bsnh) {
- past_head_stride = H;
- }
- let in_offset = b * past_batch_stride + s * row_stride + n * past_head_stride + h;
- present_kv[out_offset] = past_kv[in_offset];`;
- const newStr = ` let new_batch_stride = uniforms.new_seqlen * num_heads * H;
- let new_row_stride = num_heads * H;
- let new_head_stride = H;
- let in_offset = b * new_batch_stride + (s - past_seqlen) * new_row_stride + n * new_head_stride + h;
- present_kv[out_offset] = new_kv[in_offset];`;
- const concatStr = b
- ? `if (s < past_seqlen) {
- ${pastStr}
- } else if (s < past_seqlen + uniforms.new_seqlen) {
- ${newStr}
- }`
- : `if (s < past_seqlen + uniforms.new_seqlen) {
- ${newStr}
- }`;
-
- // TODO: handle H * params.kvNumHeads greater than maxComputeInvocationsPerWorkgroup limit.
- const getShaderSource = (shaderHelper: ShaderHelper) => `
-
- ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputs, output)}
- ${shaderHelper.mainStart([H, params.kvNumHeads!, 1])}
- ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
- var indices = ${output.offsetToIndices('global_idx')};
- let h = local_id.x;
- let n = local_id.y;
- let s = workgroup_id.x;
- let b = workgroup_id.y;
- let num_heads = ${params.kvNumHeads!}u;
- let H = ${H}u;
-
- let present_seqlen = uniforms.present_seqlen;
- let present_batch_stride = present_seqlen * num_heads * H;
- var row_stride = H;
- let is_bsnh = ${params.isPastkvBSNH};
-
- if (is_bsnh) {
- row_stride = num_heads * H;
- }
- var present_head_stride = present_seqlen * H;
- if (is_bsnh) {
- present_head_stride = H;
- }
-
- let past_seqlen = uniforms.past_seqlen;
-
- let out_offset = b * present_batch_stride + s * row_stride + n * present_head_stride + h;
- ${concatStr}
- }`;
-
- return {
- name: 'ConcatPastNew',
- shaderCache: { hint: `${params.kvNumHeads!}${H}${!!b}`, inputDependencies },
- getRunData: () => ({
- outputs: [{ dims: outputShape, dataType }],
- dispatchGroup: dispatch,
- programUniforms,
- }),
- getShaderSource,
- };
-};
-
-export const parseGroupQueryAttentionAttributes = (attributes: AttentionAttrs): AttentionAttrs =>
- createAttributeWithCacheKey({ ...attributes });
-
const weightTransposeAttribute: TransposeAttributes = createAttributeWithCacheKey({ perm: [0, 2, 1, 3] });
-const maybeExpandAndTransposeToBNSH = (
- context: ComputeContext,
- input: TensorView,
- pastKV: TensorView | undefined,
- params: AttentionParameters,
- outputIndex: number,
-) => {
+const maybeTransposeToBNSH = (context: ComputeContext, input: TensorView, params: AttentionParameters) => {
let reshapedInput = input;
const numHeads = params.kvNumHeads!;
- const nReps = params.nReps!;
if (input.dims.length === 3 && params.kvSequenceLength !== 0) {
reshapedInput = input.reshape([params.batchSize, params.kvSequenceLength, numHeads, params.headSize]);
- }
-
- if (pastKV) {
- reshapedInput = context.compute(createConcatProgramInfo(reshapedInput, pastKV, reshapedInput.dataType, params), {
- inputs: [reshapedInput, pastKV],
- outputs: [params.isPastkvBSNH ? outputIndex : -1],
- })[0];
- } else {
- reshapedInput = context.compute(createConcatProgramInfo(reshapedInput, undefined, reshapedInput.dataType, params), {
- inputs: [reshapedInput],
- outputs: [params.isPastkvBSNH ? outputIndex : -1],
- })[0];
- }
- if (nReps !== 1) {
- reshapedInput = context.compute(createTileProgramInfo([reshapedInput], [1, 1, 1, nReps]), {
+ reshapedInput = context.compute(createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm), {
inputs: [reshapedInput],
outputs: [-1],
})[0];
- reshapedInput = reshapedInput.reshape([
- params.batchSize,
- params.totalSequenceLength,
- numHeads * nReps,
- params.headSize,
- ]);
}
- return context.compute(createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm), {
- inputs: [reshapedInput],
- outputs: [-1],
- })[0];
+ return reshapedInput;
};
-export const groupQueryAttention = (context: ComputeContext, attributes: AttentionAttrs): void => {
+export const groupQueryAttention = (context: ComputeContext, attributes: GroupQueryAttentionAttributes): void => {
const params = validateInputs(context.inputs, attributes);
if (context.inputs[0].dims.length === 5) {
throw new Error('Packed QKV is not implemented');
@@ -364,19 +245,49 @@ export const groupQueryAttention = (context: ComputeContext, attributes: Attenti
throw new Error('Packed KV is not implemented');
}
+ const q = context.inputs[0];
+ const k = context.inputs[1] && context.inputs[1].dims.length > 0 ? context.inputs[1] : undefined;
+ const v = context.inputs[2] && context.inputs[2].dims.length > 0 ? context.inputs[2] : undefined;
+ const pastKey = context.inputs[3] && context.inputs[3].dims.length !== 0 ? context.inputs[3] : undefined;
+ const pastValue = context.inputs[4] && context.inputs[4].dims.length !== 0 ? context.inputs[4] : undefined;
+ const seqLens = context.inputs.length > 4 ? context.inputs[5] : undefined;
+ const totalSequenceLengthInput = context.inputs.length > 5 ? context.inputs[6] : undefined;
+ const kvNumHeads = params.kvNumHeads ? params.kvNumHeads : params.numHeads;
+
+ // TODO Remove explicit split operation and use indexing in Attention implementation to avoid overhead.
+
+ const splitAttributes: SplitAttributes = createAttributeWithCacheKey({
+ axis: 2,
+ numOutputs: 3,
+ splitSizes: [params.numHeads * params.headSize, kvNumHeads * params.headSize, kvNumHeads * params.headSize],
+ });
+ const [query, key, value] =
+ !k && !v
+ ? context.compute(createSplitProgramInfo([q], splitAttributes), { inputs: [q], outputs: [-1, -1, -1] })
+ : [q, k!, v!];
+
const Q = maybeTransposeToBNSHAndAddBias(
context,
params.batchSize,
params.numHeads,
params.sequenceLength,
params.headSize,
- context.inputs[0],
+ query,
undefined,
0,
);
- const pastKey = context.inputs[3] && context.inputs[3].dims.length !== 0 ? context.inputs[3] : undefined;
- const pastValue = context.inputs[4] && context.inputs[4].dims.length !== 0 ? context.inputs[4] : undefined;
- const K = maybeExpandAndTransposeToBNSH(context, context.inputs[1], pastKey, params, 1);
- const V = maybeExpandAndTransposeToBNSH(context, context.inputs[2], pastValue, params, 2);
- applyAttention(context, Q, K, V, undefined, undefined, undefined, undefined, undefined, params, attributes);
+ applyAttention(
+ context,
+ Q,
+ maybeTransposeToBNSH(context, key, params),
+ maybeTransposeToBNSH(context, value, params),
+ undefined,
+ undefined,
+ pastKey,
+ pastValue,
+ undefined,
+ params,
+ seqLens,
+ totalSequenceLengthInput,
+ );
};
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts
index 1a31253905694..db7a4b8e68b79 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts
@@ -403,19 +403,7 @@ export const multiHeadAttention = (context: ComputeContext, attributes: Attentio
);
if (kvBNSH) {
- return applyAttention(
- context,
- Q,
- key,
- value,
- keyPaddingMask,
- undefined,
- pastKey,
- pastValue,
- attentionBias,
- params,
- attributes,
- );
+ return applyAttention(context, Q, key, value, keyPaddingMask, undefined, pastKey, pastValue, attentionBias, params);
}
if (!key || !value) {
throw new Error('key and value must be provided');
@@ -442,5 +430,5 @@ export const multiHeadAttention = (context: ComputeContext, attributes: Attentio
2 * params.hiddenSize,
);
- applyAttention(context, Q, K, V, keyPaddingMask, undefined, pastKey, pastValue, attentionBias, params, attributes);
+ applyAttention(context, Q, K, V, keyPaddingMask, undefined, pastKey, pastValue, attentionBias, params);
};
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/split.ts b/js/web/lib/wasm/jsep/webgpu/ops/split.ts
index 1dc3a206cf94b..8c39505734e41 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/split.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/split.ts
@@ -71,7 +71,7 @@ const writeBufferDataImpl = (outputs: readonly IndicesHelper[]) => {
}`;
};
-const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: SplitAttributes): ProgramInfo => {
+export const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: SplitAttributes): ProgramInfo => {
const inputShape = inputs[0].dims;
const inputSize = ShapeUtil.size(inputShape);
const dataType = inputs[0].dataType;
diff --git a/js/web/lib/wasm/session-options.ts b/js/web/lib/wasm/session-options.ts
index b2594267a595a..17e564247863d 100644
--- a/js/web/lib/wasm/session-options.ts
+++ b/js/web/lib/wasm/session-options.ts
@@ -200,7 +200,9 @@ export const setSessionOptions = (options?: InferenceSession.SessionOptions): [n
return [sessionOptionsHandle, allocs];
} catch (e) {
if (sessionOptionsHandle !== 0) {
- wasm._OrtReleaseSessionOptions(sessionOptionsHandle);
+ if (wasm._OrtReleaseSessionOptions(sessionOptionsHandle) !== 0) {
+ checkLastError("Can't release session options.");
+ }
}
allocs.forEach((alloc) => wasm._free(alloc));
throw e;
diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts
index 5f219f63aaf61..eb74aa44b3a72 100644
--- a/js/web/lib/wasm/wasm-core-impl.ts
+++ b/js/web/lib/wasm/wasm-core-impl.ts
@@ -207,12 +207,14 @@ const getSessionInputOutputCount = (sessionHandle: number): [number, number] =>
const wasm = getInstance();
const stack = wasm.stackSave();
try {
- const dataOffset = wasm.stackAlloc(8);
- const errorCode = wasm._OrtGetInputOutputCount(sessionHandle, dataOffset, dataOffset + 4);
+ const ptrSize = wasm.PTR_SIZE;
+ const dataOffset = wasm.stackAlloc(2 * ptrSize);
+ const errorCode = wasm._OrtGetInputOutputCount(sessionHandle, dataOffset, dataOffset + ptrSize);
if (errorCode !== 0) {
checkLastError("Can't get session input/output count.");
}
- return [wasm.HEAP32[dataOffset / 4], wasm.HEAP32[dataOffset / 4 + 1]];
+ const type = ptrSize === 4 ? 'i32' : 'i64';
+ return [Number(wasm.getValue(dataOffset, type)), Number(wasm.getValue(dataOffset + ptrSize, type))];
} finally {
wasm.stackRestore(stack);
}
@@ -317,6 +319,8 @@ export const createSession = async (
checkLastError("Can't create a session.");
}
+ wasm.jsepOnCreateSession?.();
+
// clear current MLContext after session creation
if (wasm.currentContext) {
wasm.jsepRegisterMLContext!(sessionHandle, wasm.currentContext);
@@ -398,17 +402,23 @@ export const createSession = async (
outputNamesUTF8Encoded.forEach((buf) => wasm._OrtFree(buf));
if (ioBindingHandle !== 0) {
- wasm._OrtReleaseBinding(ioBindingHandle);
+ if (wasm._OrtReleaseBinding(ioBindingHandle) !== 0) {
+ checkLastError("Can't release IO binding.");
+ }
}
if (sessionHandle !== 0) {
- wasm._OrtReleaseSession(sessionHandle);
+ if (wasm._OrtReleaseSession(sessionHandle) !== 0) {
+ checkLastError("Can't release session.");
+ }
}
throw e;
} finally {
wasm._free(modelDataOffset);
if (sessionOptionsHandle !== 0) {
- wasm._OrtReleaseSessionOptions(sessionOptionsHandle);
+ if (wasm._OrtReleaseSessionOptions(sessionOptionsHandle) !== 0) {
+ checkLastError("Can't release session options.");
+ }
}
allocs.forEach((alloc) => wasm._free(alloc));
@@ -427,16 +437,22 @@ export const releaseSession = (sessionId: number): void => {
if (ioBindingState) {
if (enableGraphCapture) {
- wasm._OrtClearBoundOutputs(ioBindingState.handle);
+ if (wasm._OrtClearBoundOutputs(ioBindingState.handle) !== 0) {
+ checkLastError("Can't clear bound outputs.");
+ }
+ }
+ if (wasm._OrtReleaseBinding(ioBindingState.handle) !== 0) {
+ checkLastError("Can't release IO binding.");
}
- wasm._OrtReleaseBinding(ioBindingState.handle);
}
wasm.jsepOnReleaseSession?.(sessionId);
inputNamesUTF8Encoded.forEach((buf) => wasm._OrtFree(buf));
outputNamesUTF8Encoded.forEach((buf) => wasm._OrtFree(buf));
- wasm._OrtReleaseSession(sessionHandle);
+ if (wasm._OrtReleaseSession(sessionHandle) !== 0) {
+ checkLastError("Can't release session.");
+ }
activeSessions.delete(sessionId);
};
@@ -454,6 +470,7 @@ export const prepareInputOutputTensor = (
}
const wasm = getInstance();
+ const ptrSize = wasm.PTR_SIZE;
const dataType = tensor[0];
const dims = tensor[1];
@@ -495,15 +512,14 @@ export const prepareInputOutputTensor = (
if (Array.isArray(data)) {
// string tensor
- dataByteLength = 4 * data.length;
+ dataByteLength = ptrSize * data.length;
rawData = wasm._malloc(dataByteLength);
allocs.push(rawData);
- let dataIndex = rawData / 4;
for (let i = 0; i < data.length; i++) {
if (typeof data[i] !== 'string') {
throw new TypeError(`tensor data at index ${i} is not a string`);
}
- wasm.HEAPU32[dataIndex++] = allocWasmString(data[i], allocs);
+ wasm.setValue(rawData + i * ptrSize, allocWasmString(data[i], allocs), '*');
}
} else {
dataByteLength = data.byteLength;
@@ -516,8 +532,7 @@ export const prepareInputOutputTensor = (
const stack = wasm.stackSave();
const dimsOffset = wasm.stackAlloc(4 * dims.length);
try {
- let dimIndex = dimsOffset / 4;
- dims.forEach((d) => (wasm.HEAP32[dimIndex++] = d));
+ dims.forEach((d, index) => wasm.setValue(dimsOffset + index * ptrSize, d, ptrSize === 4 ? 'i32' : 'i64'));
const tensor = wasm._OrtCreateTensor(
tensorDataTypeStringToEnum(dataType),
rawData,
@@ -547,6 +562,7 @@ export const run = async (
options: InferenceSession.RunOptions,
): Promise => {
const wasm = getInstance();
+ const ptrSize = wasm.PTR_SIZE;
const session = activeSessions.get(sessionId);
if (!session) {
throw new Error(`cannot run inference. invalid session id: ${sessionId}`);
@@ -569,10 +585,10 @@ export const run = async (
const inputOutputAllocs: number[] = [];
const beforeRunStack = wasm.stackSave();
- const inputValuesOffset = wasm.stackAlloc(inputCount * 4);
- const inputNamesOffset = wasm.stackAlloc(inputCount * 4);
- const outputValuesOffset = wasm.stackAlloc(outputCount * 4);
- const outputNamesOffset = wasm.stackAlloc(outputCount * 4);
+ const inputValuesOffset = wasm.stackAlloc(inputCount * ptrSize);
+ const inputNamesOffset = wasm.stackAlloc(inputCount * ptrSize);
+ const outputValuesOffset = wasm.stackAlloc(outputCount * ptrSize);
+ const outputNamesOffset = wasm.stackAlloc(outputCount * ptrSize);
try {
// WebNN backend needs the active session to check MLTensors with the current context.
@@ -604,17 +620,13 @@ export const run = async (
);
}
- let inputValuesIndex = inputValuesOffset / 4;
- let inputNamesIndex = inputNamesOffset / 4;
- let outputValuesIndex = outputValuesOffset / 4;
- let outputNamesIndex = outputNamesOffset / 4;
for (let i = 0; i < inputCount; i++) {
- wasm.HEAPU32[inputValuesIndex++] = inputTensorHandles[i];
- wasm.HEAPU32[inputNamesIndex++] = inputNamesUTF8Encoded[inputIndices[i]];
+ wasm.setValue(inputValuesOffset + i * ptrSize, inputTensorHandles[i], '*');
+ wasm.setValue(inputNamesOffset + i * ptrSize, inputNamesUTF8Encoded[inputIndices[i]], '*');
}
for (let i = 0; i < outputCount; i++) {
- wasm.HEAPU32[outputValuesIndex++] = outputTensorHandles[i];
- wasm.HEAPU32[outputNamesIndex++] = outputNamesUTF8Encoded[outputIndices[i]];
+ wasm.setValue(outputValuesOffset + i * ptrSize, outputTensorHandles[i], '*');
+ wasm.setValue(outputNamesOffset + i * ptrSize, outputNamesUTF8Encoded[outputIndices[i]], '*');
}
if (!BUILD_DEFS.DISABLE_JSEP && ioBindingState && !inputOutputBound) {
@@ -698,7 +710,7 @@ export const run = async (
const output: TensorMetadata[] = [];
for (let i = 0; i < outputCount; i++) {
- const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i];
+ const tensor = Number(wasm.getValue(outputValuesOffset + i * ptrSize, '*'));
if (tensor === outputTensorHandles[i]) {
// output tensor is pre-allocated. no need to copy data.
output.push(outputTensors[i]!);
@@ -707,7 +719,7 @@ export const run = async (
const beforeGetTensorDataStack = wasm.stackSave();
// stack allocate 4 pointer value
- const tensorDataOffset = wasm.stackAlloc(4 * 4);
+ const tensorDataOffset = wasm.stackAlloc(4 * ptrSize);
let keepOutputTensor = false;
let type: Tensor.Type | undefined,
@@ -716,24 +728,26 @@ export const run = async (
const errorCode = wasm._OrtGetTensorData(
tensor,
tensorDataOffset,
- tensorDataOffset + 4,
- tensorDataOffset + 8,
- tensorDataOffset + 12,
+ tensorDataOffset + ptrSize,
+ tensorDataOffset + 2 * ptrSize,
+
+ tensorDataOffset + 3 * ptrSize,
);
if (errorCode !== 0) {
checkLastError(`Can't access output tensor data on index ${i}.`);
}
- let tensorDataIndex = tensorDataOffset / 4;
- const dataType = wasm.HEAPU32[tensorDataIndex++];
- dataOffset = wasm.HEAPU32[tensorDataIndex++];
- const dimsOffset = wasm.HEAPU32[tensorDataIndex++];
- const dimsLength = wasm.HEAPU32[tensorDataIndex++];
+ const valueType = ptrSize === 4 ? 'i32' : 'i64';
+ const dataType = Number(wasm.getValue(tensorDataOffset, valueType));
+ dataOffset = wasm.getValue(tensorDataOffset + ptrSize, '*');
+ const dimsOffset = wasm.getValue(tensorDataOffset + ptrSize * 2, '*');
+ const dimsLength = Number(wasm.getValue(tensorDataOffset + ptrSize * 3, valueType));
const dims = [];
for (let i = 0; i < dimsLength; i++) {
- dims.push(wasm.HEAPU32[dimsOffset / 4 + i]);
+ dims.push(Number(wasm.getValue(dimsOffset + i * ptrSize, valueType)));
+ }
+ if (wasm._OrtFree(dimsOffset) !== 0) {
+ checkLastError("Can't free memory for tensor dims.");
}
- wasm._OrtFree(dimsOffset);
-
const size = dims.reduce((a, b) => a * b, 1);
type = tensorDataTypeEnumToString(dataType);
@@ -744,10 +758,10 @@ export const run = async (
throw new Error('String tensor is not supported on GPU.');
}
const stringData: string[] = [];
- let dataIndex = dataOffset / 4;
for (let i = 0; i < size; i++) {
- const offset = wasm.HEAPU32[dataIndex++];
- const maxBytesToRead = i === size - 1 ? undefined : wasm.HEAPU32[dataIndex] - offset;
+ const offset = wasm.getValue(dataOffset + i * ptrSize, '*');
+ const nextOffset = wasm.getValue(dataOffset + (i + 1) * ptrSize, '*');
+ const maxBytesToRead = i === size - 1 ? undefined : nextOffset - offset;
stringData.push(wasm.UTF8ToString(offset, maxBytesToRead));
}
output.push([type, dims, stringData, 'cpu']);
@@ -775,7 +789,9 @@ export const run = async (
gpuBuffer,
download: wasm.jsepCreateDownloader!(gpuBuffer, bufferSize, type),
dispose: () => {
- wasm._OrtReleaseTensor(tensor);
+ if (wasm._OrtReleaseTensor(tensor) !== 0) {
+ checkLastError("Can't release tensor.");
+ }
},
},
'gpu-buffer',
@@ -832,7 +848,9 @@ export const run = async (
}
if (ioBindingState && !enableGraphCapture) {
- wasm._OrtClearBoundOutputs(ioBindingState.handle);
+ if (wasm._OrtClearBoundOutputs(ioBindingState.handle) !== 0) {
+ checkLastError("Can't clear bound outputs.");
+ }
activeSessions.set(sessionId, [
sessionHandle,
inputNamesUTF8Encoded,
diff --git a/js/web/lib/wasm/wasm-types.ts b/js/web/lib/wasm/wasm-types.ts
index 3e08fe97f559d..dff3ca74de5a4 100644
--- a/js/web/lib/wasm/wasm-types.ts
+++ b/js/web/lib/wasm/wasm-types.ts
@@ -141,6 +141,12 @@ export declare namespace JSEP {
* @param sessionId - specify the session ID.
*/
jsepOnRunStart: (sessionId: number) => void;
+ /**
+ * [exported from pre-jsep.js] Create a session. This function will be called after _OrtCreateSession() is
+ * called.
+ * @returns
+ */
+ jsepOnCreateSession: () => void;
/**
* [exported from pre-jsep.js] Release a session. This function will be called before _OrtReleaseSession() is
* called.
@@ -225,15 +231,15 @@ export declare namespace JSEP {
export interface OrtInferenceAPIs {
_OrtInit(numThreads: number, loggingLevel: number): number;
- _OrtGetLastError(errorCodeOffset: number, errorMessageOffset: number): void;
+ _OrtGetLastError(errorCodeOffset: number, errorMessageOffset: number): number;
_OrtCreateSession(dataOffset: number, dataLength: number, sessionOptionsHandle: number): Promise;
- _OrtReleaseSession(sessionHandle: number): void;
+ _OrtReleaseSession(sessionHandle: number): number;
_OrtGetInputOutputCount(sessionHandle: number, inputCountOffset: number, outputCountOffset: number): number;
_OrtGetInputName(sessionHandle: number, index: number): number;
_OrtGetOutputName(sessionHandle: number, index: number): number;
- _OrtFree(stringHandle: number): void;
+ _OrtFree(stringHandle: number): number;
_OrtCreateTensor(
dataType: number,
@@ -250,12 +256,12 @@ export interface OrtInferenceAPIs {
dimsOffset: number,
dimsLength: number,
): number;
- _OrtReleaseTensor(tensorHandle: number): void;
+ _OrtReleaseTensor(tensorHandle: number): number;
_OrtCreateBinding(sessionHandle: number): number;
_OrtBindInput(bindingHandle: number, nameOffset: number, tensorHandle: number): Promise;
_OrtBindOutput(bindingHandle: number, nameOffset: number, tensorHandle: number, location: number): number;
- _OrtClearBoundOutputs(ioBindingHandle: number): void;
- _OrtReleaseBinding(ioBindingHandle: number): void;
+ _OrtClearBoundOutputs(ioBindingHandle: number): number;
+ _OrtReleaseBinding(ioBindingHandle: number): number;
_OrtRunWithBinding(
sessionHandle: number,
ioBindingHandle: number,
@@ -289,11 +295,11 @@ export interface OrtInferenceAPIs {
_OrtAppendExecutionProvider(sessionOptionsHandle: number, name: number): number;
_OrtAddFreeDimensionOverride(sessionOptionsHandle: number, name: number, dim: number): number;
_OrtAddSessionConfigEntry(sessionOptionsHandle: number, configKey: number, configValue: number): number;
- _OrtReleaseSessionOptions(sessionOptionsHandle: number): void;
+ _OrtReleaseSessionOptions(sessionOptionsHandle: number): number;
_OrtCreateRunOptions(logSeverityLevel: number, logVerbosityLevel: number, terminate: boolean, tag: number): number;
_OrtAddRunConfigEntry(runOptionsHandle: number, configKey: number, configValue: number): number;
- _OrtReleaseRunOptions(runOptionsHandle: number): void;
+ _OrtReleaseRunOptions(runOptionsHandle: number): number;
_OrtEndProfiling(sessionHandle: number): number;
}
@@ -302,10 +308,13 @@ export interface OrtInferenceAPIs {
* The interface of the WebAssembly module for ONNX Runtime, compiled from C++ source code by Emscripten.
*/
export interface OrtWasmModule extends EmscriptenModule, OrtInferenceAPIs, Partial {
+ PTR_SIZE: number;
// #region emscripten functions
stackSave(): number;
stackRestore(stack: number): void;
stackAlloc(size: number): number;
+ getValue(ptr: number, type: string): number;
+ setValue(ptr: number, value: number, type: string): void;
UTF8ToString(offset: number, maxBytesToRead?: number): string;
lengthBytesUTF8(str: string): number;
diff --git a/js/web/lib/wasm/wasm-utils.ts b/js/web/lib/wasm/wasm-utils.ts
index a820fd216ee03..9ce39c366dc77 100644
--- a/js/web/lib/wasm/wasm-utils.ts
+++ b/js/web/lib/wasm/wasm-utils.ts
@@ -55,10 +55,11 @@ export const checkLastError = (message: string): void => {
const stack = wasm.stackSave();
try {
- const paramsOffset = wasm.stackAlloc(8);
- wasm._OrtGetLastError(paramsOffset, paramsOffset + 4);
- const errorCode = wasm.HEAP32[paramsOffset / 4];
- const errorMessagePointer = wasm.HEAPU32[paramsOffset / 4 + 1];
+ const ptrSize = wasm.PTR_SIZE;
+ const paramsOffset = wasm.stackAlloc(2 * ptrSize);
+ wasm._OrtGetLastError(paramsOffset, paramsOffset + ptrSize);
+ const errorCode = Number(wasm.getValue(paramsOffset, ptrSize === 4 ? 'i32' : 'i64'));
+ const errorMessagePointer = wasm.getValue(paramsOffset + ptrSize, '*');
const errorMessage = errorMessagePointer ? wasm.UTF8ToString(errorMessagePointer) : '';
throw new Error(`${message} ERROR_CODE: ${errorCode}, ERROR_MESSAGE: ${errorMessage}`);
} finally {
diff --git a/js/web/script/pull-prebuilt-wasm-artifacts.ts b/js/web/script/pull-prebuilt-wasm-artifacts.ts
index 5b8b0d27c88db..a07849a154e01 100644
--- a/js/web/script/pull-prebuilt-wasm-artifacts.ts
+++ b/js/web/script/pull-prebuilt-wasm-artifacts.ts
@@ -14,6 +14,7 @@
//
import fs from 'fs';
+import { bootstrap as globalAgentBootstrap } from 'global-agent';
import https from 'https';
import jszip from 'jszip';
import path from 'path';
@@ -111,6 +112,11 @@ console.log(
} ===`,
);
+// Bootstrap global-agent to honor the proxy settings in
+// environment variables, e.g. GLOBAL_AGENT_HTTPS_PROXY.
+// See https://github.com/gajus/global-agent/blob/v3.0.0/README.md#environment-variables for details.
+globalAgentBootstrap();
+
const filter = buildId
? `&buildIds=${buildId}`
: '&definitions=161' +
diff --git a/js/web/test/data/ops/group-query-attention.jsonc b/js/web/test/data/ops/group-query-attention.jsonc
index 2a4b265078456..036069f43eb54 100644
--- a/js/web/test/data/ops/group-query-attention.jsonc
+++ b/js/web/test/data/ops/group-query-attention.jsonc
@@ -1,6 +1,316 @@
[
{
- "name": "GroupQueryAttention Basic",
+ "name": "GroupQueryAttention 0",
+ "operator": "GroupQueryAttention",
+ "opset": { "domain": "com.microsoft", "version": 1 },
+ "attributes": [
+ { "name": "num_heads", "data": 1, "type": "int" },
+ { "name": "kv_num_heads", "data": 1, "type": "int" }
+ ],
+ "cases": [
+ {
+ "name": "T[0]",
+ "inputs": [
+ {
+ "data": [0, 1, 2, 3, 4, 5, 6, 7],
+ "dims": [1, 1, 8],
+ "type": "float32"
+ },
+ // key, BS*
+ {
+ "data": [16, 17, 18, 19, 20, 21, 22, 23],
+ "dims": [1, 1, 8],
+ "type": "float32"
+ },
+ // value, BS*
+ {
+ "data": [32, 33, 34, 35, 36, 37, 38, 39],
+ "dims": [1, 1, 8],
+ "type": "float32"
+ },
+ // past key, BNSH
+ {
+ "data": [],
+ "dims": [1, 1, 0, 8],
+ "type": "float32"
+ },
+ // past value, BNSH
+ {
+ "data": [],
+ "dims": [1, 1, 0, 8],
+ "type": "float32"
+ },
+ // seqlens_k
+ {
+ "data": [1],
+ "dims": [1],
+ "type": "int32"
+ },
+ // total_sequence_length
+ {
+ "data": [1],
+ "dims": [1],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [32, 33, 34, 35, 36, 37, 38, 39],
+ "dims": [1, 1, 8],
+ "type": "float32"
+ },
+ {
+ // present key, BNSH
+ "data": [16, 17, 18, 19, 20, 21, 22, 23],
+ "dims": [1, 1, 1, 8],
+ "type": "float32"
+ },
+ {
+ // present value, BNSH
+ "data": [32, 33, 34, 35, 36, 37, 38, 39],
+ "dims": [1, 1, 1, 8],
+ "type": "float32"
+ }
+ ]
+ }
+ ]
+ },
+ {
+ "name": "GroupQueryAttention 1",
+ "operator": "GroupQueryAttention",
+ "opset": { "domain": "com.microsoft", "version": 1 },
+ "attributes": [
+ { "name": "num_heads", "data": 1, "type": "int" },
+ { "name": "kv_num_heads", "data": 1, "type": "int" }
+ ],
+ "cases": [
+ {
+ "name": "T[0]",
+ "inputs": [
+ {
+ "data": [0, 1, 2, 3, 4, 5, 6, 7],
+ "dims": [1, 1, 8],
+ "type": "float32"
+ },
+ // key, BS*
+ {
+ "data": [16, 17, 18, 19, 20, 21, 22, 23],
+ "dims": [1, 1, 8],
+ "type": "float32"
+ },
+ // value, BS*
+ {
+ "data": [32, 33, 34, 35, 36, 37, 38, 39],
+ "dims": [1, 1, 8],
+ "type": "float32"
+ },
+ // past key, BS*
+ {
+ "data": [40, 41, 42, 43, 44, 45, 46, 47],
+ "dims": [1, 1, 1, 8],
+ "type": "float32"
+ },
+ // past value, BS*
+ {
+ "data": [48, 49, 50, 51, 52, 53, 54, 55],
+ "dims": [1, 1, 1, 8],
+ "type": "float32"
+ },
+ // seqlens_k, unimplemented
+ {
+ "data": [1],
+ "dims": [1],
+ "type": "int32"
+ },
+ // total_sequence_length, unimplemented
+ {
+ "data": [2],
+ "dims": [1],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [48, 49, 50, 51, 52, 53, 54, 55],
+ "dims": [1, 1, 8],
+ "type": "float32"
+ },
+ {
+ // present key, BNSH
+ "data": [40, 41, 42, 43, 44, 45, 46, 47, 16, 17, 18, 19, 20, 21, 22, 23],
+ "dims": [1, 1, 2, 8],
+ "type": "float32"
+ },
+ {
+ // present value, BNSH
+ "data": [48, 49, 50, 51, 52, 53, 54, 55, 32, 33, 34, 35, 36, 37, 38, 39],
+ "dims": [1, 1, 2, 8],
+ "type": "float32"
+ }
+ ]
+ }
+ ]
+ },
+ {
+ "name": "GroupQueryAttention 2",
+ "operator": "GroupQueryAttention",
+ "opset": { "domain": "com.microsoft", "version": 1 },
+ "attributes": [
+ { "name": "num_heads", "data": 2, "type": "int" },
+ { "name": "kv_num_heads", "data": 1, "type": "int" }
+ ],
+ "cases": [
+ {
+ "name": "T[0]",
+ "inputs": [
+ {
+ "data": [
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
+ 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47
+ ],
+ "dims": [1, 3, 16],
+ "type": "float32"
+ },
+ // key, BS*
+ {
+ "data": [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71],
+ "dims": [1, 3, 8],
+ "type": "float32"
+ },
+ // value, BS*
+ {
+ "data": [72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95],
+ "dims": [1, 3, 8],
+ "type": "float32"
+ },
+ // past key, BNSH
+ {
+ "data": [],
+ "dims": [1, 1, 0, 8],
+ "type": "float32"
+ },
+ // past value, BNSH
+ {
+ "data": [],
+ "dims": [1, 1, 0, 8],
+ "type": "float32"
+ },
+ // seqlens_k
+ {
+ "data": [3],
+ "dims": [1],
+ "type": "int32"
+ },
+ // total_sequence_length
+ {
+ "data": [3],
+ "dims": [1],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [
+ 72, 73, 74, 75, 76, 77, 78, 79, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 80, 81,
+ 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 88, 89, 90, 91, 92, 93, 94, 95
+ ],
+ "dims": [1, 3, 16],
+ "type": "float32"
+ },
+ {
+ // present key, BNSH
+ "data": [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71],
+ "dims": [1, 1, 3, 8],
+ "type": "float32"
+ },
+ {
+ // present value, BNSH
+ "data": [72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95],
+ "dims": [1, 1, 3, 8],
+ "type": "float32"
+ }
+ ]
+ }
+ ]
+ },
+ {
+ "name": "GroupQueryAttention 3",
+ "operator": "GroupQueryAttention",
+ "opset": { "domain": "com.microsoft", "version": 1 },
+ "attributes": [
+ { "name": "num_heads", "data": 1, "type": "int" },
+ { "name": "kv_num_heads", "data": 1, "type": "int" }
+ ],
+ "cases": [
+ {
+ "name": "T[0]",
+ "inputs": [
+ {
+ "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23],
+ "dims": [1, 3, 8],
+ "type": "float32"
+ },
+ // key, BS*
+ {
+ "data": [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
+ "dims": [1, 3, 8],
+ "type": "float32"
+ },
+ // value, BS*
+ {
+ "data": [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71],
+ "dims": [1, 3, 8],
+ "type": "float32"
+ },
+ // past key, BNSH
+ {
+ "data": [],
+ "dims": [1, 1, 0, 8],
+ "type": "float32"
+ },
+ // past value, BNSH
+ {
+ "data": [],
+ "dims": [1, 1, 0, 8],
+ "type": "float32"
+ },
+ // seqlens_k
+ {
+ "data": [3],
+ "dims": [1],
+ "type": "int32"
+ },
+ // total_sequence_length
+ {
+ "data": [3],
+ "dims": [1],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71],
+ "dims": [1, 3, 8],
+ "type": "float32"
+ },
+ {
+ // present key, BNSH
+ "data": [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
+ "dims": [1, 1, 3, 8],
+ "type": "float32"
+ },
+ {
+ // present value, BNSH
+ "data": [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71],
+ "dims": [1, 1, 3, 8],
+ "type": "float32"
+ }
+ ]
+ }
+ ]
+ },
+ {
+ "name": "GroupQueryAttention 4",
"operator": "GroupQueryAttention",
"opset": { "domain": "com.microsoft", "version": 1 },
"attributes": [
@@ -12,44 +322,293 @@
"name": "T[0]",
"inputs": [
{
- "data": [
- 1, 1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4,
- 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4
- ],
- "dims": [1, 3, 16],
+ "data": [
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
+ 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
+ 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80,
+ 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95
+ ],
+ "dims": [1, 3, 32],
+ "type": "float32"
+ },
+ // key, BS*
+ {
+ "data": [
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
+ 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47
+ ],
+ "dims": [1, 3, 16],
+ "type": "float32"
+ },
+ // value, BS*
+ {
+ "data": [
+ 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73,
+ 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95
+ ],
+ "dims": [1, 3, 16],
+ "type": "float32"
+ },
+ // past key, BNSH
+ {
+ "data": [],
+ "dims": [1, 2, 0, 8],
+ "type": "float32"
+ },
+ // past value, BNSH
+ {
+ "data": [],
+ "dims": [1, 2, 0, 8],
+ "type": "float32"
+ },
+ // seqlens_k
+ {
+ "data": [3],
+ "dims": [1],
+ "type": "int32"
+ },
+ // total_sequence_length
+ {
+ "data": [3],
+ "dims": [1],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [
+ 48, 49, 50, 51, 52, 53, 54, 55, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 56, 57,
+ 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75,
+ 76, 77, 78, 79, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 80, 81, 82, 83, 84, 85,
+ 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 88, 89, 90, 91, 92, 93, 94, 95
+ ],
+ "dims": [1, 3, 32],
+ "type": "float32"
+ },
+ {
+ // present key, BNSH
+ "data": [
+ 0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23, 32, 33, 34, 35, 36, 37, 38, 39, 8, 9, 10, 11, 12,
+ 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31, 40, 41, 42, 43, 44, 45, 46, 47
+ ],
+ "dims": [1, 2, 3, 8],
+ "type": "float32"
+ },
+ {
+ // present value, BNSH
+ "data": [
+ 48, 49, 50, 51, 52, 53, 54, 55, 64, 65, 66, 67, 68, 69, 70, 71, 80, 81, 82, 83, 84, 85, 86, 87, 56, 57,
+ 58, 59, 60, 61, 62, 63, 72, 73, 74, 75, 76, 77, 78, 79, 88, 89, 90, 91, 92, 93, 94, 95
+ ],
+ "dims": [1, 2, 3, 8],
+ "type": "float32"
+ }
+ ]
+ }
+ ]
+ },
+ {
+ "name": "GroupQueryAttention 5",
+ "operator": "GroupQueryAttention",
+ "opset": { "domain": "com.microsoft", "version": 1 },
+ "attributes": [
+ { "name": "num_heads", "data": 2, "type": "int" },
+ { "name": "kv_num_heads", "data": 1, "type": "int" }
+ ],
+ "cases": [
+ {
+ "name": "T[0]",
+ "inputs": [
+ {
+ "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
+ "dims": [1, 1, 16],
+ "type": "float32"
+ },
+ // key, BS*
+ {
+ "data": [16, 17, 18, 19, 20, 21, 22, 23],
+ "dims": [1, 1, 8],
+ "type": "float32"
+ },
+ // value, BS*
+ {
+ "data": [24, 25, 26, 27, 28, 29, 30, 31],
+ "dims": [1, 1, 8],
+ "type": "float32"
+ },
+ // past key, BNSH
+ {
+ "data": [],
+ "dims": [1, 1, 0, 8],
+ "type": "float32"
+ },
+ // past value, BNSH
+ {
+ "data": [],
+ "dims": [1, 1, 0, 8],
+ "type": "float32"
+ },
+ // seqlens_k
+ {
+ "data": [1],
+ "dims": [1],
+ "type": "int32"
+ },
+ // total_sequence_length
+ {
+ "data": [1],
+ "dims": [1],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [24, 25, 26, 27, 28, 29, 30, 31, 24, 25, 26, 27, 28, 29, 30, 31],
+ "dims": [1, 1, 16],
+ "type": "float32"
+ },
+ {
+ // present key, BNSH
+ "data": [16, 17, 18, 19, 20, 21, 22, 23],
+ "dims": [1, 1, 1, 8],
+ "type": "float32"
+ },
+ {
+ // present value, BNSH
+ "data": [24, 25, 26, 27, 28, 29, 30, 31],
+ "dims": [1, 1, 1, 8],
+ "type": "float32"
+ }
+ ]
+ }
+ ]
+ },
+ {
+ "name": "GroupQueryAttention 6",
+ "operator": "GroupQueryAttention",
+ "opset": { "domain": "com.microsoft", "version": 1 },
+ "attributes": [
+ { "name": "num_heads", "data": 1, "type": "int" },
+ { "name": "kv_num_heads", "data": 1, "type": "int" }
+ ],
+ "cases": [
+ {
+ "name": "T[0]",
+ "inputs": [
+ {
+ "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23],
+ "dims": [1, 3, 8],
+ "type": "float32"
+ },
+ // key, BS*
+ {
+ "data": [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71],
+ "dims": [1, 3, 8],
+ "type": "float32"
+ },
+ // value, BS*
+ {
+ "data": [72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95],
+ "dims": [1, 3, 8],
+ "type": "float32"
+ },
+ // past key, BNSH
+ {
+ "data": [],
+ "dims": [1, 1, 0, 8],
+ "type": "float32"
+ },
+ // past value, BNSH
+ {
+ "data": [],
+ "dims": [1, 1, 0, 8],
+ "type": "float32"
+ },
+ // seqlens_k
+ {
+ "data": [3],
+ "dims": [1],
+ "type": "int32"
+ },
+ // total_sequence_length
+ {
+ "data": [3],
+ "dims": [1],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95],
+ "dims": [1, 3, 8],
+ "type": "float32"
+ },
+ {
+ // present key, BNSH
+ "data": [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71],
+ "dims": [1, 1, 3, 8],
+ "type": "float32"
+ },
+ {
+ // present value, BNSH
+ "data": [72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95],
+ "dims": [1, 1, 3, 8],
+ "type": "float32"
+ }
+ ]
+ }
+ ]
+ },
+ {
+ "name": "GroupQueryAttention 7",
+ "operator": "GroupQueryAttention",
+ "opset": { "domain": "com.microsoft", "version": 1 },
+ "attributes": [
+ { "name": "num_heads", "data": 1, "type": "int" },
+ { "name": "kv_num_heads", "data": 1, "type": "int" }
+ ],
+ "cases": [
+ {
+ "name": "T[0]",
+ "inputs": [
+ {
+ "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23],
+ "dims": [1, 3, 8],
"type": "float32"
},
// key, BS*
{
- "data": [1, 9, 1, 1, 2, 2, 2, 2, 1, 12, 21, 131, 22, 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21],
+ "data": [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71],
"dims": [1, 3, 8],
"type": "float32"
},
// value, BS*
{
- "data": [1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21],
+ "data": [72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95],
"dims": [1, 3, 8],
"type": "float32"
},
// past key, BS*
{
- "data": null,
+ "data": [96, 97, 98, 99, 100, 101, 102, 103],
+ "dims": [1, 1, 1, 8],
"type": "float32"
},
// past value, BS*
{
- "data": null,
+ "data": [104, 105, 106, 107, 108, 109, 110, 111],
+ "dims": [1, 1, 1, 8],
"type": "float32"
},
// seqlens_k, unimplemented
{
- "data": [1],
+ "data": [3],
"dims": [1],
"type": "int32"
},
// total_sequence_length, unimplemented
{
- "data": [1],
+ "data": [4],
"dims": [1],
"type": "int32"
}
@@ -57,22 +616,28 @@
"outputs": [
{
"data": [
- 1, 1, 1, 1, 1, 1, 1, 1, 2, 131, 22, 21, 2, 131, 22, 21, 131, 22, 21, 2, 1, 1, 1, 1, 2, 131, 22, 21, 2,
- 131, 22, 21, 131, 22, 21, 2, 1, 1, 1, 1, 2, 131, 22, 21, 2, 131, 22, 21
+ 104, 105, 106, 107, 108, 109, 110, 111, 104, 105, 106, 107, 108, 109, 110, 111, 104, 105, 106, 107, 108,
+ 109, 110, 111
],
- "dims": [1, 3, 16],
+ "dims": [1, 3, 8],
"type": "float32"
},
{
- // present key, BS*
- "data": [1, 9, 1, 1, 2, 2, 2, 2, 1, 12, 21, 131, 22, 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21],
- "dims": [1, 3, 2, 4],
+ // present key, BNSH
+ "data": [
+ 96, 97, 98, 99, 100, 101, 102, 103, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,
+ 65, 66, 67, 68, 69, 70, 71
+ ],
+ "dims": [1, 1, 4, 8],
"type": "float32"
},
{
- // present value, BS*
- "data": [1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21],
- "dims": [1, 3, 2, 4],
+ // present value, BNSH
+ "data": [
+ 104, 105, 106, 107, 108, 109, 110, 111, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87,
+ 88, 89, 90, 91, 92, 93, 94, 95
+ ],
+ "dims": [1, 1, 4, 8],
"type": "float32"
}
]
@@ -80,13 +645,12 @@
]
},
{
- "name": "GroupQueryAttention Scale",
+ "name": "GroupQueryAttention 8",
"operator": "GroupQueryAttention",
"opset": { "domain": "com.microsoft", "version": 1 },
"attributes": [
{ "name": "num_heads", "data": 4, "type": "int" },
- { "name": "kv_num_heads", "data": 2, "type": "int" },
- { "name": "scale", "data": 2.0, "type": "float" }
+ { "name": "kv_num_heads", "data": 2, "type": "int" }
],
"cases": [
{
@@ -94,38 +658,43 @@
"inputs": [
{
"data": [
- 1, 1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
+ 29, 30, 31
],
- "dims": [1, 4, 8],
+ "dims": [1, 1, 32],
"type": "float32"
},
+ // key, BS*
{
- "data": [1, 9, 1, 1, 2, 2, 2, 2],
- "dims": [1, 2, 4],
+ "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
+ "dims": [1, 1, 16],
"type": "float32"
},
+ // value, BS*
{
- "data": [1, 1, 1, 1, 2, 2, 2, 2],
- "dims": [1, 2, 4],
+ "data": [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63],
+ "dims": [1, 1, 16],
"type": "float32"
},
- // past key, BS*
+ // past key, BNSH
{
- "data": null,
+ "data": [],
+ "dims": [1, 2, 0, 8],
"type": "float32"
},
- // past value, BS*
+ // past value, BNSH
{
- "data": null,
+ "data": [],
+ "dims": [1, 2, 0, 8],
"type": "float32"
},
- // seqlens_k, unimplemented
+ // seqlens_k
{
"data": [1],
"dims": [1],
"type": "int32"
},
- // total_sequence_length, unimplemented
+ // total_sequence_length
{
"data": [1],
"dims": [1],
@@ -135,35 +704,34 @@
"outputs": [
{
"data": [
- 1.000006079673767, 1.000006079673767, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 2, 2, 2, 2, 2, 2, 1,
- 1, 1, 1, 1.9820137023925781, 1.9820137023925781, 1.9999991655349731, 1.9999991655349731
+ 48, 49, 50, 51, 52, 53, 54, 55, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 56, 57,
+ 58, 59, 60, 61, 62, 63
],
- "dims": [1, 4, 8],
+ "dims": [1, 1, 32],
"type": "float32"
},
{
- // present key, BS*
- "data": [1, 9, 1, 1, 2, 2, 2, 2],
- "dims": [1, 2, 2, 2],
+ // present key, BNSH
+ "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
+ "dims": [1, 2, 1, 8],
"type": "float32"
},
{
- // present value, BS*
- "data": [1, 1, 1, 1, 2, 2, 2, 2],
- "dims": [1, 2, 2, 2],
+ // present value, BNSH
+ "data": [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63],
+ "dims": [1, 2, 1, 8],
"type": "float32"
}
]
}
]
},
-
{
- "name": "GroupQueryAttention, different sequence length",
+ "name": "GroupQueryAttention 9",
"operator": "GroupQueryAttention",
"opset": { "domain": "com.microsoft", "version": 1 },
"attributes": [
- { "name": "num_heads", "data": 4, "type": "int" },
+ { "name": "num_heads", "data": 2, "type": "int" },
{ "name": "kv_num_heads", "data": 2, "type": "int" }
],
"cases": [
@@ -171,39 +739,41 @@
"name": "T[0]",
"inputs": [
{
- "data": [
- 1, 1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4
- ],
- "dims": [1, 4, 8],
+ "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
+ "dims": [1, 1, 16],
"type": "float32"
},
+ // key, BS*
{
- "data": [1, 9, 1, 1, 2, 2, 2, 2],
- "dims": [1, 2, 4],
+ "data": [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
+ "dims": [1, 1, 16],
"type": "float32"
},
+ // value, BS*
{
- "data": [1, 1, 1, 1, 2, 2, 2, 2],
- "dims": [1, 2, 4],
+ "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
+ "dims": [1, 1, 16],
"type": "float32"
},
- // past key, BS*
+ // past key, BNSH
{
- "data": null,
+ "data": [],
+ "dims": [1, 2, 0, 8],
"type": "float32"
},
- // past value, BS*
+ // past value, BNSH
{
- "data": null,
+ "data": [],
+ "dims": [1, 2, 0, 8],
"type": "float32"
},
- // seqlens_k, unimplemented
+ // seqlens_k
{
"data": [1],
"dims": [1],
"type": "int32"
},
- // total_sequence_length, unimplemented
+ // total_sequence_length
{
"data": [1],
"dims": [1],
@@ -212,23 +782,20 @@
],
"outputs": [
{
- "data": [
- 1.014165997505188, 1.014165997505188, 1.0000015497207642, 1.0000015497207642, 1.99828040599823,
- 1.99828040599823, 1.9998981952667236, 1.9998981952667236, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 2, 2,
- 1.9995813369750977, 1.9995813369750977, 1.9999752044677734, 1.9999752044677734, 1, 1, 1, 1,
- 1.8044296503067017, 1.8044296503067017, 1.9929646253585815, 1.9929646253585815
- ],
- "dims": [1, 4, 8],
+ "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
+ "dims": [1, 1, 16],
"type": "float32"
},
{
- "data": [1, 9, 1, 1, 2, 2, 2, 2],
- "dims": [1, 2, 2, 2],
+ // present key, BNSH
+ "data": [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
+ "dims": [1, 2, 1, 8],
"type": "float32"
},
{
- "data": [1, 1, 1, 1, 2, 2, 2, 2],
- "dims": [1, 2, 2, 2],
+ // present value, BNSH
+ "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
+ "dims": [1, 2, 1, 8],
"type": "float32"
}
]
@@ -236,12 +803,164 @@
]
},
{
- "name": "GroupQueryAttention Basic, q k v same head number",
+ "name": "GroupQueryAttention 10",
"operator": "GroupQueryAttention",
"opset": { "domain": "com.microsoft", "version": 1 },
"attributes": [
- { "name": "num_heads", "data": 4, "type": "int" },
- { "name": "kv_num_heads", "data": 4, "type": "int" }
+ { "name": "num_heads", "data": 1, "type": "int" },
+ { "name": "kv_num_heads", "data": 1, "type": "int" }
+ ],
+ "cases": [
+ {
+ "name": "T[0]",
+ "inputs": [
+ {
+ "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
+ "dims": [1, 1, 16],
+ "type": "float32"
+ },
+ // key, BS*
+ {
+ "data": [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
+ "dims": [1, 1, 16],
+ "type": "float32"
+ },
+ // value, BS*
+ {
+ "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
+ "dims": [1, 1, 16],
+ "type": "float32"
+ },
+ // past key, BNSH
+ {
+ "data": [],
+ "dims": [1, 1, 0, 16],
+ "type": "float32"
+ },
+ // past value, BNSH
+ {
+ "data": [],
+ "dims": [1, 1, 0, 16],
+ "type": "float32"
+ },
+ // seqlens_k
+ {
+ "data": [1],
+ "dims": [1],
+ "type": "int32"
+ },
+ // total_sequence_length
+ {
+ "data": [1],
+ "dims": [1],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
+ "dims": [1, 1, 16],
+ "type": "float32"
+ },
+ {
+ // present key, BNSH
+ "data": [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
+ "dims": [1, 1, 1, 16],
+ "type": "float32"
+ },
+ {
+ // present value, BNSH
+ "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
+ "dims": [1, 1, 1, 16],
+ "type": "float32"
+ }
+ ]
+ }
+ ]
+ },
+ {
+ "name": "GroupQueryAttention 11",
+ "operator": "GroupQueryAttention",
+ "opset": { "domain": "com.microsoft", "version": 1 },
+ "attributes": [
+ { "name": "num_heads", "data": 1, "type": "int" },
+ { "name": "kv_num_heads", "data": 1, "type": "int" }
+ ],
+ "cases": [
+ {
+ "name": "T[0]",
+ "inputs": [
+ {
+ "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
+ "dims": [1, 2, 8],
+ "type": "float32"
+ },
+ // key, BS*
+ {
+ "data": [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
+ "dims": [1, 2, 8],
+ "type": "float32"
+ },
+ // value, BS*
+ {
+ "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
+ "dims": [1, 2, 8],
+ "type": "float32"
+ },
+ // past key, BNSH
+ {
+ "data": [],
+ "dims": [1, 1, 0, 8],
+ "type": "float32"
+ },
+ // past value, BNSH
+ {
+ "data": [],
+ "dims": [1, 1, 0, 8],
+ "type": "float32"
+ },
+ // seqlens_k
+ {
+ "data": [2],
+ "dims": [1],
+ "type": "int32"
+ },
+ // total_sequence_length
+ {
+ "data": [2],
+ "dims": [1],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
+ "dims": [1, 2, 8],
+ "type": "float32"
+ },
+ {
+ // present key, BNSH
+ "data": [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
+ "dims": [1, 1, 2, 8],
+ "type": "float32"
+ },
+ {
+ // present value, BNSH
+ "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
+ "dims": [1, 1, 2, 8],
+ "type": "float32"
+ }
+ ]
+ }
+ ]
+ },
+ {
+ "name": "GroupQueryAttention 12",
+ "operator": "GroupQueryAttention",
+ "opset": { "domain": "com.microsoft", "version": 1 },
+ "attributes": [
+ { "name": "num_heads", "data": 1, "type": "int" },
+ { "name": "kv_num_heads", "data": 1, "type": "int" }
],
"cases": [
{
@@ -249,45 +968,49 @@
"inputs": [
{
"data": [
- 1, 1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4,
- 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
+ 29, 30, 31
],
- "dims": [1, 3, 16],
+ "dims": [1, 1, 32],
"type": "float32"
},
+ // key, BS*
{
"data": [
- 1, 9, 1, 1, 2, 2, 2, 2, 1, 12, 21, 131, 22, 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21, 1, 9, 1, 1, 2, 2, 2,
- 2, 1, 12, 21, 131, 22, 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21
+ 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57,
+ 58, 59, 60, 61, 62, 63
],
- "dims": [1, 3, 16],
+ "dims": [1, 1, 32],
"type": "float32"
},
+ // value, BS*
{
"data": [
- 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21, 1, 9, 1, 1, 2, 2, 2, 2, 1,
- 12, 21, 131, 22, 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21
+ 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89,
+ 90, 91, 92, 93, 94, 95
],
- "dims": [1, 3, 16],
+ "dims": [1, 1, 32],
"type": "float32"
},
- // past key, BS*
+ // past key, BNSH
{
- "data": null,
+ "data": [],
+ "dims": [1, 1, 0, 32],
"type": "float32"
},
- // past value, BS*
+ // past value, BNSH
{
- "data": null,
+ "data": [],
+ "dims": [1, 1, 0, 32],
"type": "float32"
},
- // seqlens_k, unimplemented
+ // seqlens_k
{
"data": [1],
"dims": [1],
"type": "int32"
},
- // total_sequence_length, unimplemented
+ // total_sequence_length
{
"data": [1],
"dims": [1],
@@ -297,26 +1020,28 @@
"outputs": [
{
"data": [
- 1, 12, 21, 131, 2, 131, 22, 21, 1, 1, 1, 1, 2, 131, 22, 21, 131, 22, 21, 2, 2, 131, 22, 21, 1, 1, 1, 1, 2,
- 131, 22, 21, 131, 22, 21, 2, 2, 131, 22, 21, 1, 1, 1, 1, 2, 131, 22, 21
+ 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89,
+ 90, 91, 92, 93, 94, 95
],
- "dims": [1, 3, 16],
+ "dims": [1, 1, 32],
"type": "float32"
},
{
+ // present key, BNSH
"data": [
- 1, 9, 1, 1, 2, 2, 2, 2, 1, 12, 21, 131, 22, 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21, 1, 9, 1, 1, 2, 2, 2,
- 2, 1, 12, 21, 131, 22, 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21
+ 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57,
+ 58, 59, 60, 61, 62, 63
],
- "dims": [1, 3, 4, 4],
+ "dims": [1, 1, 1, 32],
"type": "float32"
},
{
+ // present value, BNSH
"data": [
- 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21, 1, 9, 1, 1, 2, 2, 2, 2, 1,
- 12, 21, 131, 22, 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21
+ 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89,
+ 90, 91, 92, 93, 94, 95
],
- "dims": [1, 3, 4, 4],
+ "dims": [1, 1, 1, 32],
"type": "float32"
}
]
@@ -324,12 +1049,12 @@
]
},
{
- "name": "GroupQueryAttention, no past kv, used as reference",
+ "name": "GroupQueryAttention 13",
"operator": "GroupQueryAttention",
"opset": { "domain": "com.microsoft", "version": 1 },
"attributes": [
- { "name": "num_heads", "data": 4, "type": "int" },
- { "name": "kv_num_heads", "data": 2, "type": "int" }
+ { "name": "num_heads", "data": 1, "type": "int" },
+ { "name": "kv_num_heads", "data": 1, "type": "int" }
],
"cases": [
{
@@ -337,50 +1062,51 @@
"inputs": [
{
"data": [
- 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
- 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
- 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
- 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106,
- 107, 108, 109, 110, 111, 112
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
+ 29, 30, 31
],
- "dims": [1, 7, 16],
+ "dims": [1, 4, 8],
"type": "float32"
},
+ // key, BS*
{
"data": [
- 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
- 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56
+ 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57,
+ 58, 59, 60, 61, 62, 63
],
- "dims": [1, 7, 8],
+ "dims": [1, 4, 8],
"type": "float32"
},
+ // value, BS*
{
"data": [
- 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
- 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55
+ 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89,
+ 90, 91, 92, 93, 94, 95
],
- "dims": [1, 7, 8],
+ "dims": [1, 4, 8],
"type": "float32"
},
- // past key, BS*
+ // past key, BNSH
{
- "data": null,
+ "data": [],
+ "dims": [1, 1, 0, 8],
"type": "float32"
},
- // past value, BS*
+ // past value, BNSH
{
- "data": null,
+ "data": [],
+ "dims": [1, 1, 0, 8],
"type": "float32"
},
- // seqlens_k, unimplemented
+ // seqlens_k
{
- "data": [1],
+ "data": [4],
"dims": [1],
"type": "int32"
},
- // total_sequence_length, unimplemented
+ // total_sequence_length
{
- "data": [1],
+ "data": [4],
"dims": [1],
"type": "int32"
}
@@ -388,29 +1114,28 @@
"outputs": [
{
"data": [
- 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53,
- 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51,
- 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53,
- 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51,
- 52, 53, 54, 55, 52, 53, 54, 55
+ 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89,
+ 90, 91, 92, 93, 94, 95
],
- "dims": [1, 7, 16],
+ "dims": [1, 4, 8],
"type": "float32"
},
{
+ // present key, BNSH
"data": [
- 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
- 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56
+ 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57,
+ 58, 59, 60, 61, 62, 63
],
- "dims": [1, 7, 2, 4],
+ "dims": [1, 1, 4, 8],
"type": "float32"
},
{
+ // present value, BNSH
"data": [
- 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
- 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55
+ 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89,
+ 90, 91, 92, 93, 94, 95
],
- "dims": [1, 7, 2, 4],
+ "dims": [1, 1, 4, 8],
"type": "float32"
}
]
@@ -418,12 +1143,12 @@
]
},
{
- "name": "GroupQueryAttention Past&Present KV BSNH, key seqlen = 1",
+ "name": "GroupQueryAttention PackedQKV 14",
"operator": "GroupQueryAttention",
"opset": { "domain": "com.microsoft", "version": 1 },
"attributes": [
- { "name": "num_heads", "data": 4, "type": "int" },
- { "name": "kv_num_heads", "data": 2, "type": "int" }
+ { "name": "num_heads", "data": 2, "type": "int" },
+ { "name": "kv_num_heads", "data": 1, "type": "int" }
],
"cases": [
{
@@ -431,52 +1156,41 @@
"inputs": [
{
"data": [
- 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
- 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
- 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
- 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106,
- 107, 108, 109, 110, 111, 112
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
+ 29, 30, 31
],
- "dims": [1, 7, 16],
+ "dims": [1, 1, 32],
"type": "float32"
},
- // new key, BS*
+ // key, BS*
{
- "data": [
- 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
- 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56
- ],
- "dims": [1, 6, 8],
+ "data": null,
"type": "float32"
},
- // new value, BS*
+ // value, BS*
{
- "data": [
- 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34,
- 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55
- ],
- "dims": [1, 6, 8],
+ "data": null,
"type": "float32"
},
- // past key, BS*
+ // past key, BNSH
{
- "data": [1, 2, 3, 4, 5, 6, 7, 8],
- "dims": [1, 1, 2, 4],
+ "data": [],
+ "dims": [1, 1, 0, 8],
"type": "float32"
},
- // past value, BS*
+ // past value, BNSH
{
- "data": [0, 1, 2, 3, 4, 5, 6, 7],
- "dims": [1, 1, 2, 4],
+ "data": [],
+ "dims": [1, 1, 0, 8],
"type": "float32"
},
- // seqlens_k, unimplemented
+ // seqlens_k
{
"data": [1],
"dims": [1],
"type": "int32"
},
- // total_sequence_length, unimplemented
+ // total_sequence_length
{
"data": [1],
"dims": [1],
@@ -485,30 +1199,20 @@
],
"outputs": [
{
- "data": [
- 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53,
- 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51,
- 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53,
- 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51,
- 52, 53, 54, 55, 52, 53, 54, 55
- ],
- "dims": [1, 7, 16],
+ "data": [24, 25, 26, 27, 28, 29, 30, 31, 24, 25, 26, 27, 28, 29, 30, 31],
+ "dims": [1, 1, 16],
"type": "float32"
},
{
- "data": [
- 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
- 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56
- ],
- "dims": [1, 7, 2, 4],
+ // present key, BNSH
+ "data": [16, 17, 18, 19, 20, 21, 22, 23],
+ "dims": [1, 1, 1, 8],
"type": "float32"
},
{
- "data": [
- 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
- 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55
- ],
- "dims": [1, 7, 2, 4],
+ // present value, BNSH
+ "data": [24, 25, 26, 27, 28, 29, 30, 31],
+ "dims": [1, 1, 1, 8],
"type": "float32"
}
]
@@ -516,7 +1220,7 @@
]
},
{
- "name": "GroupQueryAttention Past&Present KV BSNH, key seqlen = 2",
+ "name": "GroupQueryAttention PackedQKV 15",
"operator": "GroupQueryAttention",
"opset": { "domain": "com.microsoft", "version": 1 },
"attributes": [
@@ -529,54 +1233,48 @@
"inputs": [
{
"data": [
- 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
- 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
- 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
- 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106,
- 107, 108, 109, 110, 111, 112
+ 1, 1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4,
+ 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4, 1, 9, 1, 1, 2, 2, 2, 2, 1, 12, 21, 131, 22, 21, 2, 2,
+ 131, 22, 21, 2, 2, 131, 22, 21, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 131, 22, 21, 2, 2, 131,
+ 22, 21, 1, 1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1,
+ 1, 3, 4, 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4, 1, 9, 1, 1, 2, 2, 2, 2, 1, 12, 21, 131, 22,
+ 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 131, 22, 21, 2,
+ 2, 131, 22, 21
],
- "dims": [1, 7, 16],
+ "dims": [1, 3, 64],
"type": "float32"
},
- // new key, BS*
+ // key
{
- "data": [
- 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42,
- 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56
- ],
- "dims": [1, 5, 8],
+ "data": null,
"type": "float32"
},
- // new value, BS*
+ // value
{
- "data": [
- 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41,
- 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55
- ],
- "dims": [1, 5, 8],
+ "data": null,
"type": "float32"
},
- // past key, BS*
+ // past key, BNSH
{
- "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
- "dims": [1, 2, 2, 4],
+ "data": [],
+ "dims": [1, 2, 0, 8],
"type": "float32"
},
- // past value, BS*
+ // past value, BNSH
{
- "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
- "dims": [1, 2, 2, 4],
+ "data": [],
+ "dims": [1, 2, 0, 8],
"type": "float32"
},
- // seqlens_k, unimplemented
+ // seqlens_k
{
- "data": [1],
+ "data": [3],
"dims": [1],
"type": "int32"
},
- // total_sequence_length, unimplemented
+ // total_sequence_length
{
- "data": [1],
+ "data": [3],
"dims": [1],
"type": "int32"
}
@@ -584,29 +1282,29 @@
"outputs": [
{
"data": [
- 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53,
- 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51,
- 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53,
- 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51,
- 52, 53, 54, 55, 52, 53, 54, 55
+ 1, 9, 1, 1, 2, 2, 2, 2, 1, 9, 1, 1, 2, 2, 2, 2, 1, 12, 21, 131, 22, 21, 2, 2, 1, 12, 21, 131, 22, 21, 2,
+ 2, 8, 12, 233, 4, 5, 6, 7, 8, 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4, 5, 6, 7, 8, 1, 1, 3, 4,
+ 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 5, 6, 7, 8, 1, 1, 3, 4, 5, 6, 7, 8, 1, 1, 3, 4
],
- "dims": [1, 7, 16],
+ "dims": [1, 3, 32],
"type": "float32"
},
{
+ // present key, BNSH
"data": [
- 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
- 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56
+ 8, 12, 233, 4, 5, 6, 7, 8, 1, 1, 2, 3, 4, 5, 6, 7, 131, 22, 21, 2, 2, 131, 22, 21, 5, 6, 7, 8, 1, 1, 3, 4,
+ 8, 11, 12, 13, 14, 15, 16, 17, 1, 1, 1, 1, 2, 2, 2, 2
],
- "dims": [1, 7, 2, 4],
+ "dims": [1, 2, 3, 8],
"type": "float32"
},
{
+ // present value, BNSH
"data": [
- 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
- 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55
+ 1, 9, 1, 1, 2, 2, 2, 2, 8, 12, 233, 4, 5, 6, 7, 8, 1, 1, 1, 1, 2, 2, 2, 2, 1, 12, 21, 131, 22, 21, 2, 2,
+ 5, 6, 7, 8, 1, 1, 3, 4, 131, 22, 21, 2, 2, 131, 22, 21
],
- "dims": [1, 7, 2, 4],
+ "dims": [1, 2, 3, 8],
"type": "float32"
}
]
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention.cc b/onnxruntime/contrib_ops/cpu/bert/attention.cc
index ad14fb8258656..b15e865aa423c 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/attention.cc
@@ -30,6 +30,7 @@ class Attention : public OpKernel, public AttentionCPUBase {
Status Compute(OpKernelContext* context) const override;
Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
+ bool save_prepacked_initializers,
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* prepacked_weights) override;
@@ -101,6 +102,7 @@ bool Attention::IsPackWeightsSuccessful(int qkv_index,
template
Status Attention::PrePack(const Tensor& weights, int input_idx, AllocatorPtr alloc,
+ bool /*save_prepacked_initializers*/,
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* prepacked_weights) {
/* The PrePack() massages the weights to speed up Compute(), there is an option to
diff --git a/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc b/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc
index 2c897f183164f..71a66ea368943 100644
--- a/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc
+++ b/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc
@@ -24,6 +24,7 @@ class QAttention : public OpKernel, public AttentionCPUBase {
Status Compute(OpKernelContext* context) const override;
Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
+ bool save_prepacked_initializers,
bool& /*out*/ is_packed,
/*out*/ PrePackedWeights* prepacked_weights) override;
@@ -58,6 +59,7 @@ QAttention::QAttention(const OpKernelInfo& info) : OpKernel(info), AttentionC
template
Status QAttention::PrePack(const Tensor& weights, int input_idx, AllocatorPtr alloc,
+ bool /*save_prepacked_initializers*/,
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* prepacked_weights) {
if (1 != input_idx) {
diff --git a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_lstm.cc b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_lstm.cc
index aa47f365c0005..4148aae4b9a35 100644
--- a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_lstm.cc
+++ b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_lstm.cc
@@ -13,7 +13,7 @@ class DynamicQuantizeLSTM : public OpKernel, public LSTMBase {
DynamicQuantizeLSTM(const OpKernelInfo& info) : OpKernel(info), LSTMBase(info) {}
Status PrePack(const Tensor& tensor, int input_idx,
- AllocatorPtr alloc, /*out*/ bool& is_packed,
+ AllocatorPtr alloc, bool save_prepacked_initializers, /*out*/ bool& is_packed,
/*out*/ PrePackedWeights* prepacked_weights) override;
Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers,
@@ -91,6 +91,7 @@ static void UseSharedPrePackedBuffersImpl(std::vector& prepacke
}
Status DynamicQuantizeLSTM::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
+ bool /*save_prepacked_initializers*/,
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* prepacked_weights) {
is_packed = false;
diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
index 89e96543c4729..cee3dfc6b3f28 100644
--- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
+++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
@@ -98,12 +98,19 @@ class MatMulNBits final : public OpKernel {
Status Compute(OpKernelContext* context) const override;
Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
+ bool save_prepacked_initializers,
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* prepacked_weights) override;
+ void ConvertPrepackWeightIntoTensor(const onnxruntime::Tensor& tensor, int input_idx);
+
Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, int input_idx,
/*out*/ bool& used_shared_buffers) override;
+ std::optional GetPrePackTensor(int /*input_idx*/) override;
+
+ Status SetPrePackTensor(int input_idx, const Tensor& pre_packed_tensor) override;
+
private:
const size_t K_;
const size_t N_;
@@ -119,6 +126,8 @@ class MatMulNBits final : public OpKernel {
size_t packed_b_size_{0};
IAllocatorUniquePtr scales_fp32_{};
IAllocatorUniquePtr bias_fp32_{};
+ std::optional packed_tensor_{std::nullopt};
+ MLDataType prepack_tensor_data_type_;
bool has_zp_input_{false};
@@ -148,8 +157,22 @@ class MatMulNBits final : public OpKernel {
}
};
+template
+void MatMulNBits::ConvertPrepackWeightIntoTensor(const onnxruntime::Tensor& tensor, int input_idx) {
+ if (input_idx == InputIndex::B) {
+ prepack_tensor_data_type_ = tensor.DataType();
+ }
+
+ TensorShapeVector weights_dims = {static_cast((packed_b_size_ - 1) / prepack_tensor_data_type_->Size()) + 1};
+ packed_tensor_ = Tensor(prepack_tensor_data_type_,
+ TensorShape(weights_dims),
+ packed_b_.get(),
+ OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator));
+}
+
template
Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc,
+ bool save_prepacked_initializers,
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* prepacked_weights) {
ORT_UNUSED_PARAMETER(prepacked_weights);
@@ -185,11 +208,16 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All
#endif // MLAS_TARGET_AMD64_IX86
}
+ if (save_prepacked_initializers) {
+ ConvertPrepackWeightIntoTensor(tensor, input_idx);
+ }
+
return Status::OK();
}
template <>
Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc,
+ bool save_prepacked_initializers,
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* prepacked_weights) {
ORT_UNUSED_PARAMETER(prepacked_weights);
@@ -239,6 +267,34 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*ou
#endif // MLAS_TARGET_AMD64_IX86
}
+ if (save_prepacked_initializers) {
+ ConvertPrepackWeightIntoTensor(tensor, input_idx);
+ }
+
+ return Status::OK();
+}
+
+template
+std::optional MatMulNBits::GetPrePackTensor(int input_idx) {
+ // For this kernel, prepack is performed on input_B, and possibly scales, zeros_points.
+ // During compute process, scales and zeros_points will keep as it is and only use prepacked
+ // buffer to replace input_B.
+ // In order to cope with this logic, we need to return the latest prepacked buffer and only serialize
+ // the latest one. So, we need to always return packed_tensor_ here not only for input_B.
+ ORT_UNUSED_PARAMETER(input_idx);
+ return std::move(packed_tensor_);
+}
+
+template
+Status MatMulNBits::SetPrePackTensor(int input_idx, const Tensor& pre_packed_tensor) {
+ if (input_idx == 1) {
+ // pre_packed_tensor is a constant initialized tensor and its lifecycle is managed by session_state,
+ // session_state will release memory from pre_packed_tensor. packed_b_ will not release memory so
+ // pass empty/default buffer deleter here.
+ // const_cast here is temporary, will fix in follow up PR.
+ packed_b_ = BufferUniquePtr(const_cast(pre_packed_tensor.DataRaw()), BufferDeleter());
+ }
+
return Status::OK();
}
diff --git a/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc b/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc
index 67b4950af73bf..c9ee9e2cb760d 100644
--- a/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc
+++ b/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc
@@ -278,6 +278,7 @@ Status SkipLayerNorm::Compute(OpKernelContext* p_ctx) const {
template
Status SkipLayerNorm::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
+ bool /*save_prepacked_initializers*/,
bool& is_packed, PrePackedWeights* prepacked_weights) {
ORT_UNUSED_PARAMETER(prepacked_weights);
diff --git a/onnxruntime/contrib_ops/cpu/skip_layer_norm.h b/onnxruntime/contrib_ops/cpu/skip_layer_norm.h
index 08e2276c3d9d5..d904c14857437 100644
--- a/onnxruntime/contrib_ops/cpu/skip_layer_norm.h
+++ b/onnxruntime/contrib_ops/cpu/skip_layer_norm.h
@@ -16,7 +16,7 @@ class SkipLayerNorm final : public OpKernel {
SkipLayerNorm(const OpKernelInfo& op_kernel_info);
Status Compute(OpKernelContext* p_op_kernel_context) const override;
- Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
+ Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool save_prepacked_initializers,
bool& is_packed, PrePackedWeights* prepacked_weights) override;
private:
diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc
index dea5391c7629b..d190ed389f3e9 100644
--- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc
+++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc
@@ -95,6 +95,7 @@ GroupNorm::GroupNorm(const OpKernelInfo& op_info) : CudaKernel(op_info) {
}
Status GroupNorm::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr /*alloc*/,
+ bool /*save_prepacked_initializers*/,
bool& is_packed, PrePackedWeights* /*prepacked_weights*/) {
is_packed = false;
diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h
index b408b3c1ee79b..4505c066baedb 100644
--- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h
+++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h
@@ -17,6 +17,7 @@ class GroupNorm final : public CudaKernel {
Status ComputeInternal(OpKernelContext* context) const override;
Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
+ bool save_prepacked_initializers,
bool& is_packed, PrePackedWeights* prepacked_weights) override;
private:
diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc
index 3e93a527877c5..aa2c8755f6536 100644
--- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc
@@ -99,6 +99,7 @@ Status QOrderedAttention::PutIntoMergedBias(const Tensor& tensor, AllocatorPtr a
}
Status QOrderedAttention::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc,
+ bool /*save_prepacked_initializers*/,
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* /*prepacked_weights*/) {
is_packed = false;
diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.h b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.h
index 9d4e563c1feab..529fd00307d66 100644
--- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.h
+++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.h
@@ -20,6 +20,7 @@ class QOrderedAttention final : public CudaKernel, public AttentionBase {
public:
Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
+ bool save_prepacked_initializers,
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* prepacked_weights) override;
diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_matmul.cc b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_matmul.cc
index a64f628f245e6..351e36b884540 100644
--- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_matmul.cc
+++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_matmul.cc
@@ -51,6 +51,7 @@ QOrderedMatMul::QOrderedMatMul(const OpKernelInfo& info) : CudaKernel(info) {
}
Status QOrderedMatMul::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
+ bool /*save_prepacked_initializers*/,
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* /* prepacked_weights */) {
is_packed = false;
diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_matmul.h b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_matmul.h
index dcb6cc6374be1..d1cef99779e09 100644
--- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_matmul.h
+++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_matmul.h
@@ -18,6 +18,7 @@ class QOrderedMatMul final : public CudaKernel {
Status ComputeInternal(OpKernelContext* context) const override;
Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
+ bool save_prepacked_initializers,
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* prepacked_weights) override;
diff --git a/onnxruntime/contrib_ops/js/bert/group_query_attention.h b/onnxruntime/contrib_ops/js/bert/group_query_attention.h
index 7553883a2478d..dff8663133c31 100644
--- a/onnxruntime/contrib_ops/js/bert/group_query_attention.h
+++ b/onnxruntime/contrib_ops/js/bert/group_query_attention.h
@@ -2,7 +2,7 @@
// Licensed under the MIT License.
#pragma once
-
+#include "contrib_ops/cpu/bert/gqa_attention_base.h"
#include "core/providers/js/js_kernel.h"
namespace onnxruntime {
@@ -11,31 +11,29 @@ namespace js {
using onnxruntime::js::JsKernel;
-class GroupQueryAttention : public JsKernel {
+class GroupQueryAttention : public JsKernel, GQAAttentionBase {
public:
explicit GroupQueryAttention(const OpKernelInfo& info)
- : JsKernel(info) {
- int64_t num_heads = 0;
- int64_t kv_num_heads = 0;
- ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0);
- ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0 && num_heads % kv_num_heads == 0);
- num_heads_ = static_cast(num_heads);
- kv_num_heads_ = static_cast(kv_num_heads);
- scale_ = info.GetAttrOrDefault("scale", 0.0f);
+ : JsKernel(info), GQAAttentionBase(info, false) {
JSEP_INIT_KERNEL_ATTRIBUTE(GroupQueryAttention, ({
"numHeads" : $1,
"kvNumHeads" : $2,
"scale" : $3,
+ "softcap" : $4,
+ "doRotary" : $5,
+ "rotaryInterleaved" : $6,
+ "smoothSoftmax" : $7,
+ "localWindowSize" : $8
}),
static_cast(num_heads_),
static_cast(kv_num_heads_),
- static_cast(scale_));
+ static_cast(scale_),
+ static_cast(softcap_),
+ static_cast(do_rotary_),
+ static_cast(rotary_interleaved_),
+ static_cast(use_smooth_softmax_),
+ static_cast(local_window_size_));
}
-
- protected:
- int num_heads_; // number of attention heads
- int kv_num_heads_; // number of k and v heads
- float scale_; // custom scale will be used if specified. Default value is 1/sqrt(head_size)
};
} // namespace js
diff --git a/onnxruntime/core/framework/external_data_loader.cc b/onnxruntime/core/framework/external_data_loader.cc
index ea6c499829391..fe73a55735631 100644
--- a/onnxruntime/core/framework/external_data_loader.cc
+++ b/onnxruntime/core/framework/external_data_loader.cc
@@ -32,7 +32,7 @@ common::Status LoadWebAssemblyExternalData(const Env& env,
if (typeof Module == 'undefined' || !Module.MountedFiles) {
return 1; // "Module.MountedFiles" is not available.
}
- let fileName = UTF8ToString($0 >>> 0);
+ let fileName = UTF8ToString(Number($0 >>> 0));
if (fileName.startsWith('./')) {
fileName = fileName.substring(2);
}
@@ -40,9 +40,9 @@ common::Status LoadWebAssemblyExternalData(const Env& env,
if (!fileData) {
return 2; // File not found in preloaded files.
}
- const offset = $1 >>> 0;
- const length = $2 >>> 0;
- const dataIdOrBuffer = $3 >>> 0;
+ const offset = Number($1 >>> 0);
+ const length = Number($2 >>> 0);
+ const dataIdOrBuffer = Number($3 >>> 0);
const loadType = $4;
if (offset + length > fileData.byteLength) {
diff --git a/onnxruntime/core/framework/session_options.h b/onnxruntime/core/framework/session_options.h
index 8d4db36106f28..18405231750ba 100644
--- a/onnxruntime/core/framework/session_options.h
+++ b/onnxruntime/core/framework/session_options.h
@@ -83,6 +83,11 @@ struct SessionOptions {
// enable profiling for this session.
bool enable_profiling = false;
+ // Save pre-packed constant external initializers instead of the original initializers to the onnxruntime data file.
+ // Only useful for models run on a PC with CPU, so ORT can load prepacked weights directly from the
+ // ONNX data file with mmap, with no need to do prepacking on the fly, saving a lot of heap memory.
+ bool save_prepacked_constant_initializers = false;
+
// Non empty filepath enables serialization of the transformed optimized model to the specified filepath.
//
// Set session config value for ORT_SESSION_OPTIONS_CONFIG_SAVE_MODEL_FORMAT to 'ORT' or 'ONNX' to explicitly
@@ -191,6 +196,7 @@ inline std::ostream& operator<<(std::ostream& os, const SessionOptions& session_
<< " execution_mode:" << session_options.execution_mode
<< " execution_order:" << session_options.execution_order
<< " enable_profiling:" << session_options.enable_profiling
+ << " save_prepacked_constant_initializers:" << session_options.save_prepacked_constant_initializers
<< " optimized_model_filepath:" << ORT_TSTR_CONVERT_TO_PRINTABLE_STRING(session_options.optimized_model_filepath)
<< " enable_mem_pattern:" << session_options.enable_mem_pattern
<< " enable_mem_reuse:" << session_options.enable_mem_reuse
diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc
index 0d0b22ff61e01..943db091b341f 100644
--- a/onnxruntime/core/framework/session_state.cc
+++ b/onnxruntime/core/framework/session_state.cc
@@ -14,6 +14,7 @@
#include "core/framework/op_kernel.h"
#include "core/framework/ort_value_pattern_planner.h"
#include "core/framework/session_state_utils.h"
+#include "core/framework/tensorprotoutils.h"
#include "core/framework/utils.h"
#include "core/providers/cpu/controlflow/utils.h"
#include "core/session/onnxruntime_session_options_config_keys.h"
@@ -397,12 +398,18 @@ static std::string GenerateKeyForPrepackedWeightsMap(const std::string& op_type,
}
Status SessionState::PrepackConstantInitializedTensors(InlinedHashMap& constant_initializers_use_count,
- const std::unordered_map& initializers_to_share_map) {
- auto prepacked_constant_weights = [this, &constant_initializers_use_count, &initializers_to_share_map](
+ const std::unordered_map& initializers_to_share_map,
+ bool save_prepacked_constant_initializers,
+ PrePackInitializers& pre_packed_initializers) {
+ auto prepacked_constant_weights = [this, &constant_initializers_use_count, &initializers_to_share_map,
+ save_prepacked_constant_initializers, &pre_packed_initializers](
bool should_cache_prepacked_weights_for_shared_initializers) -> Status {
+ std::unordered_map pre_packed_kernel_input_map;
for (auto& node : GetGraphViewer().Nodes()) {
auto kernel = GetMutableKernel(node.Index());
+ auto kernel_name = kernel->Info().node().Name();
int input_idx = 0;
+ bool is_kernel_prepacked = false;
for (auto& input_def : node.InputDefs()) {
if (input_def->Exists()) {
const std::string& input_name = input_def->Name();
@@ -414,16 +421,27 @@ Status SessionState::PrepackConstantInitializedTensors(InlinedHashMapGetOrtValueNameIdxMap().GetIdx(input_name, ort_value_idx).IsOK()) {
std::unordered_map& constant_initialized_tensors = st->constant_initialized_tensors_;
- if (constant_initialized_tensors.count(ort_value_idx)) {
+ if (constant_initialized_tensors.count(ort_value_idx) && !is_kernel_prepacked) {
bool is_packed = false;
const Tensor& const_initialized_tensor = constant_initialized_tensors[ort_value_idx].Get();
auto iter = initializers_to_share_map.find(input_name);
bool is_shared_initializer = (iter != initializers_to_share_map.end());
+ // Found pre-packed constant initializers in the data file; no need to do pre-packing again.
+ // Apply the pre-packed tensor to the kernel so the kernel can use it directly.
+ if (pre_packed_initializers.pre_packed_initializer_names_read_from_file.count(input_name) != 0) {
+ is_packed = true;
+
+ // A kernel like Matmul_nbits will call prepack multiple times with input_B and possibly scales/zero_points.
+ // If prepacked weights were already read from the ONNX data file (this happens when ORT reads a data file
+ // with prepacked weights serialized), we only need to set the prepacked weights once on the kernel.
+ is_kernel_prepacked = true;
+ ORT_THROW_IF_ERROR(kernel->SetPrePackTensor(input_idx, const_initialized_tensor));
+ }
// Caching pre-packed weights is limited to shared initializers associated with the CPU EP for now
- if (is_shared_initializer && should_cache_prepacked_weights_for_shared_initializers &&
- node.GetExecutionProviderType() == kCpuExecutionProvider) { // caching of pre-packed weights' turned ON
+ else if (is_shared_initializer && should_cache_prepacked_weights_for_shared_initializers &&
+ node.GetExecutionProviderType() == kCpuExecutionProvider) { // caching of pre-packed weights' turned ON
AllocatorPtr allocator_for_caching = prepacked_weights_container_->GetOrCreateAllocator(CPU);
ORT_ENFORCE(allocator_for_caching.get() != nullptr);
@@ -435,7 +453,7 @@ Status SessionState::PrepackConstantInitializedTensors(InlinedHashMapPrePack(const_initialized_tensor, input_idx, allocator_for_caching,
- is_packed,
+ save_prepacked_constant_initializers, is_packed,
&weights_to_be_filled_in));
if (is_packed) {
@@ -482,18 +500,50 @@ Status SessionState::PrepackConstantInitializedTensors(InlinedHashMapInfo().GetDevice(OrtMemType::OrtMemTypeDefault));
ORT_RETURN_IF_ERROR(kernel->PrePack(const_initialized_tensor, input_idx,
session_cpu_alloc, // use allocator tied to this session
+ save_prepacked_constant_initializers,
is_packed,
nullptr // no caching required
));
}
if (is_packed) {
+ // If we intend to save prepacked initializers, get the prepacked tensors from the kernel and save
+ // them in a hashmap; they will be saved to the data file later.
+ if (save_prepacked_constant_initializers) {
+ auto tensor = kernel->GetPrePackTensor(input_idx);
+
+ if (tensor != std::nullopt) {
+ // save prepacked initializers per initializer and kernel since one initializer could
+ // be used by multiple kernels
+ pre_packed_initializers.pre_packed_initializers_to_save[input_name][kernel_name] = std::move(tensor.value());
+
+ pre_packed_kernel_input_map[kernel_name] = input_name;
+ }
+ }
+
++number_of_prepacks_counter_;
- if (constant_initializers_use_count.count(input_name) && --constant_initializers_use_count[input_name] == 0) {
+ // if constant_initialized_tensor is already pre-packed, don't need to remove it
+ if (pre_packed_initializers.pre_packed_initializer_names_read_from_file.count(input_name) == 0 &&
+ constant_initializers_use_count.count(input_name) && --constant_initializers_use_count[input_name] == 0) {
// release the constant initialized tensor
st->initialized_tensors_.erase(ort_value_idx);
constant_initialized_tensors.erase(ort_value_idx);
}
+ } else {
+ // Handle prepack for matmul_nbits: it will prepack several times but sets is_packed
+ // to false for scales and zero_points. We keep scales and zero_points as-is and only
+ // update packed_tensor for input_B.
+ // TODO: this logic works with matmul_nbits kernel but if other kernels also call prepack
+ // multiple times and use different initializers to store prepacked weights, this piece of logic
+ // might introduce bug and need a per kernel strategy to update prepacked weights.
+ if (save_prepacked_constant_initializers && pre_packed_kernel_input_map.count(kernel_name)) {
+ auto tensor = kernel->GetPrePackTensor(input_idx);
+
+ if (tensor != std::nullopt) {
+ auto existing_input_name = pre_packed_kernel_input_map[kernel_name];
+ pre_packed_initializers.pre_packed_initializers_to_save[existing_input_name][kernel_name] = std::move(tensor.value());
+ }
+ }
}
}
// stop searching in 2 cases:
@@ -1176,6 +1226,7 @@ static Status VerifyEachNodeIsAssignedToAnEp(const Graph& graph, const logging::
Status SessionState::FinalizeSessionState(const std::basic_string& graph_location,
const KernelRegistryManager& kernel_registry_manager,
+ PrePackInitializers& pre_packed_initializers,
bool remove_initializers,
bool saving_ort_format) {
// recursively create the subgraph session state instances and populate the kernel create info in them.
@@ -1189,7 +1240,7 @@ Status SessionState::FinalizeSessionState(const std::basic_string constant_initializers_use_count;
ComputeConstantInitializerUseCount(graph_, constant_initializers_use_count);
return FinalizeSessionStateImpl(graph_location, kernel_registry_manager, nullptr, sess_options_,
- remove_initializers, constant_initializers_use_count);
+ remove_initializers, constant_initializers_use_count, pre_packed_initializers);
}
static Status Index(const OrtValueNameIdxMap& ort_value_name_idx_map,
@@ -1323,6 +1374,7 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string& constant_initializers_use_count,
+ PrePackInitializers& pre_packed_initializers,
const InlinedHashMap& outer_scope_node_arg_to_location_map,
bool graph_info_already_created) {
if (!graph_info_already_created) {
@@ -1422,6 +1474,8 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string>
+ typedef std::unordered_map> PrePackedTensorsToSave;
+ PrePackedTensorsToSave pre_packed_initializers_to_save;
+
+ // This set is used during model load when prepacked initializers are serialized in the external data file.
+ // ORT reads the prepacked initializers and stores their names into this set so we can skip the PrePack
+ // process later to save heap memory. The prepacked tensor itself is saved in the session state's constant_initialized_tensors_.
+ typedef std::unordered_set PrePackedTensorNamesReadFromFile;
+ PrePackedTensorNamesReadFromFile pre_packed_initializer_names_read_from_file;
+ };
+
Status FinalizeSessionState(const std::basic_string& graph_loc,
const KernelRegistryManager& kernel_registry_manager,
+ PrePackInitializers& pre_packed_initializers,
bool remove_initializers = true,
bool saving_ort_format = false);
@@ -321,6 +338,15 @@ class SessionState {
return parent_;
}
+ Status FinalizeSessionState(const std::basic_string& graph_loc,
+ const KernelRegistryManager& kernel_registry_manager,
+ bool remove_initializers = true,
+ bool saving_ort_format = false) {
+ PrePackInitializers pre_packed_initializers;
+ return FinalizeSessionState(graph_loc, kernel_registry_manager, pre_packed_initializers,
+ remove_initializers, saving_ort_format);
+ }
+
// Clear all removable attributes if they exists.
// The function logs the list of removable attributes for every node.
void PruneRemovableAttributes();
@@ -380,9 +406,13 @@ class SessionState {
/**
* Prepack the constant initialized tensors for better performance.
* The original constant initialized tensors will be removed to save memory.
+ * For models with prepacked initializers serialized into the ONNX data file,
+ * PrePack will be skipped to save memory.
*/
Status PrepackConstantInitializedTensors(InlinedHashMap& constant_initializers_use_count,
- const std::unordered_map& initializers_to_share_map);
+ const std::unordered_map& initializers_to_share_map,
+ bool save_prepacked_constant_initializers,
+ PrePackInitializers& pre_packed_initializers);
SessionState* GetMutableSubgraphSessionState(onnxruntime::NodeIndex index, const std::string& attribute_name);
@@ -400,6 +430,7 @@ class SessionState {
const SessionOptions& session_options,
bool remove_initializers,
InlinedHashMap& constant_initializers_use_count,
+ PrePackInitializers& pre_packed_initializers,
const InlinedHashMap& outer_scope_node_arg_to_location_map = {},
bool graph_info_already_created = false);
diff --git a/onnxruntime/core/framework/session_state_utils.cc b/onnxruntime/core/framework/session_state_utils.cc
index 2c74805c57dce..3424f40e79c01 100644
--- a/onnxruntime/core/framework/session_state_utils.cc
+++ b/onnxruntime/core/framework/session_state_utils.cc
@@ -21,7 +21,6 @@
#include "core/framework/ort_value_pattern_planner.h"
#include "core/framework/ort_value_name_idx_map.h"
#include "core/framework/sequential_execution_plan.h"
-#include "core/framework/session_state.h"
#include "core/framework/tensorprotoutils.h"
#include "core/framework/utils.h"
#include "core/framework/bfc_arena.h"
@@ -72,6 +71,7 @@ static inline common::Status ExtDataTensorProtoToTensor(const Env& env,
const std::basic_string& proto_path,
const ONNX_NAMESPACE::TensorProto& tensor_proto,
Tensor& tensor, OrtCallback& ext_data_deleter,
+ SessionState::PrePackInitializers::PrePackedTensorNamesReadFromFile& pre_packed_initializers_name_set,
Tensor* buffered_tensor = nullptr) {
ORT_ENFORCE(utils::HasExternalData(tensor_proto));
@@ -79,7 +79,7 @@ static inline common::Status ExtDataTensorProtoToTensor(const Env& env,
SafeInt ext_data_len = 0;
ORT_RETURN_IF_ERROR(utils::GetExtDataFromTensorProto(env, proto_path.c_str(), tensor_proto,
ext_data_buf, ext_data_len, ext_data_deleter,
- buffered_tensor));
+ &pre_packed_initializers_name_set, buffered_tensor));
// NB: creating a do-nothing allocator per tensor is wasteful; can perhaps be
// avoided if the Tensor class implements the do-nothing behavior when given a
@@ -100,6 +100,7 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st
const AllocatorPtr& alloc, const AllocatorPtr& default_cpu_alloc,
OrtValue& ort_value, const DataTransferManager& data_transfer_mgr,
const ExternalDataLoaderManager& external_data_loader_mgr,
+ SessionState::PrePackInitializers::PrePackedTensorNamesReadFromFile& pre_packed_initializers_name_set,
bool use_device_allocator_for_initializers = false,
Tensor* buffered_tensor = nullptr) {
if (bool(alloc) == (m != nullptr)) {
@@ -139,7 +140,7 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st
// TensorProtoToTensor it would copy the data, causing unnecessary overhead
OrtCallback ext_data_deleter;
ORT_RETURN_IF_ERROR(ExtDataTensorProtoToTensor(env, proto_path, tensor_proto, *p_tensor,
- ext_data_deleter, buffered_tensor));
+ ext_data_deleter, pre_packed_initializers_name_set, buffered_tensor));
ExtDataValueDeleter deleter{ext_data_deleter, p_tensor.get()};
MLDataType ml_tensor_type = DataTypeImpl::GetType();
@@ -163,7 +164,7 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st
OrtCallback ext_data_deleter;
std::optional scoped_ort_callback_invoker;
ORT_RETURN_IF_ERROR(ExtDataTensorProtoToTensor(env, proto_path, tensor_proto, *p_deserialize_tensor,
- ext_data_deleter, buffered_tensor));
+ ext_data_deleter, pre_packed_initializers_name_set, buffered_tensor));
scoped_ort_callback_invoker = ScopedOrtCallbackInvoker(ext_data_deleter);
// TODO!! Need a temp buffer allocator for non-escape buffers that maybe too big for stack allocation.
@@ -272,7 +273,8 @@ common::Status SaveInitializedTensors(
const ExecutionPlanBase& exec_plan,
const SessionOptions& session_options,
const MemoryProfileFunction& memory_profile_func,
- std::unordered_map>& buffered_tensors) {
+ std::unordered_map>& buffered_tensors,
+ SessionState::PrePackInitializers::PrePackedTensorNamesReadFromFile& pre_packed_initializers_name_set) {
LOGS(logger, INFO) << "Saving initialized tensors.";
ORT_ENFORCE(ort_value_name_idx_map.MaxIdx() > -1, "OrtValue indexes should have been populated.");
@@ -401,6 +403,7 @@ common::Status SaveInitializedTensors(
Status st = DeserializeTensorProto(env, graph_loc, tensor_proto, (m.has_value()) ? &*m : nullptr, alloc,
default_cpu_alloc, ort_value, data_transfer_mgr, external_data_loader_mgr,
+ pre_packed_initializers_name_set,
use_device_allocator_for_initializers, p_tensor);
if (!st.IsOK()) {
std::ostringstream oss;
diff --git a/onnxruntime/core/framework/session_state_utils.h b/onnxruntime/core/framework/session_state_utils.h
index af27f5caba0f4..4de501b6f7429 100644
--- a/onnxruntime/core/framework/session_state_utils.h
+++ b/onnxruntime/core/framework/session_state_utils.h
@@ -12,6 +12,7 @@
#include "core/framework/tensor.h"
#include "core/framework/tensor_allocator.h"
#include "core/framework/session_options.h"
+#include "core/framework/session_state.h"
#include "core/framework/sequential_execution_plan.h"
#include "core/platform/path_lib.h"
@@ -50,7 +51,8 @@ common::Status SaveInitializedTensors(
const ExecutionPlanBase& exec_plan,
const SessionOptions& session_options,
const MemoryProfileFunction& memory_profile_func,
- std::unordered_map>& buffered_tensors);
+ std::unordered_map>& buffered_tensors,
+ SessionState::PrePackInitializers::PrePackedTensorNamesReadFromFile& pre_packed_initializers_name_set);
common::Status AllocateTensor(
const onnxruntime::MemBuffer* m,
diff --git a/onnxruntime/core/framework/tensor_external_data_info.cc b/onnxruntime/core/framework/tensor_external_data_info.cc
index 93146e66d9f24..bcd04effe2bd4 100644
--- a/onnxruntime/core/framework/tensor_external_data_info.cc
+++ b/onnxruntime/core/framework/tensor_external_data_info.cc
@@ -40,6 +40,8 @@ Status ExternalDataInfo::Create(const RepeatedPtrField&
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "parsing ", stringmap.value(), " failed");
} else if (stringmap.key() == "checksum" && !stringmap.value().empty()) {
out->checksum_ = stringmap.value();
+ } else if (stringmap.key() == "prepacked" && !stringmap.value().empty()) {
+ out->prepacked_ = stringmap.value() == "1";
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "model format error!");
}
diff --git a/onnxruntime/core/framework/tensor_external_data_info.h b/onnxruntime/core/framework/tensor_external_data_info.h
index afc8fda6c3037..c2490f5cc5bc2 100644
--- a/onnxruntime/core/framework/tensor_external_data_info.h
+++ b/onnxruntime/core/framework/tensor_external_data_info.h
@@ -23,6 +23,8 @@ class ExternalDataInfo {
const std::string& GetChecksum() const { return checksum_; }
+ bool GetPrePacked() const noexcept { return prepacked_; }
+
// If the value of 'offset' or 'length' field is larger the max value of ssize_t, this function will treat it as a
// wrong value and return FAIL.
static common::Status Create(
@@ -36,5 +38,6 @@ class ExternalDataInfo {
// 0 means the whole file
size_t length_ = 0;
std::string checksum_;
+ bool prepacked_ = false;
};
} // namespace onnxruntime
diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc
index 74c359881a1d7..0c69ee11f62bc 100644
--- a/onnxruntime/core/framework/tensorprotoutils.cc
+++ b/onnxruntime/core/framework/tensorprotoutils.cc
@@ -165,37 +165,6 @@ Status UnpackTensorWithRawData(const void* raw_data, size_t raw_data_len, size_t
DEFINE_INT4_UNPACK_TENSOR_WITH_RAW_DATA_IMPL(Int4x2)
DEFINE_INT4_UNPACK_TENSOR_WITH_RAW_DATA_IMPL(UInt4x2)
-static Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto,
- const std::filesystem::path& tensor_proto_dir,
- std::basic_string& external_file_path,
- onnxruntime::FileOffsetType& file_offset,
- SafeInt& tensor_byte_size) {
- ORT_RETURN_IF_NOT(onnxruntime::utils::HasExternalData(tensor_proto),
- "Tensor does not have external data to read from.");
-
- ORT_RETURN_IF(!onnxruntime::utils::HasDataType(tensor_proto) || onnxruntime::utils::HasString(tensor_proto),
- "External data type cannot be UNDEFINED or STRING.");
-
- std::unique_ptr external_data_info;
- ORT_RETURN_IF_ERROR(onnxruntime::ExternalDataInfo::Create(tensor_proto.external_data(), external_data_info));
-
- const auto& location = external_data_info->GetRelPath();
-
- external_file_path = location == onnxruntime::utils::kTensorProtoMemoryAddressTag ? std::filesystem::path(location)
- : (tensor_proto_dir / location);
-
- ORT_RETURN_IF_ERROR(onnxruntime::utils::GetSizeInBytesFromTensorProto<0>(tensor_proto, &tensor_byte_size));
- const size_t external_data_length = external_data_info->GetLength();
- ORT_RETURN_IF_NOT(external_data_length == 0 || external_data_length == tensor_byte_size,
- "TensorProto: ", tensor_proto.name(),
- " external data size mismatch. Computed size: ", *&tensor_byte_size,
- ", external_data.length: ", external_data_length);
-
- file_offset = external_data_info->GetOffset();
-
- return Status::OK();
-}
-
// Read external data for tensor in unint8_t* form and return Status::OK() if the data is read successfully.
// Uses the tensor_proto_dir to construct the full path for external data. If tensor_proto_dir == nullptr
// then uses the current directory instead.
@@ -261,10 +230,49 @@ Status TensorProtoToOrtValueImpl(const Env& env, const std::filesystem::path& mo
namespace utils {
+static Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto,
+ const std::filesystem::path& tensor_proto_dir,
+ std::basic_string& external_file_path,
+ onnxruntime::FileOffsetType& file_offset,
+ SafeInt& tensor_byte_size,
+ bool& pre_packed) {
+ ORT_RETURN_IF_NOT(onnxruntime::utils::HasExternalData(tensor_proto),
+ "Tensor does not have external data to read from.");
+
+ ORT_RETURN_IF(!onnxruntime::utils::HasDataType(tensor_proto) || onnxruntime::utils::HasString(tensor_proto),
+ "External data type cannot be UNDEFINED or STRING.");
+
+ std::unique_ptr external_data_info;
+ ORT_RETURN_IF_ERROR(onnxruntime::ExternalDataInfo::Create(tensor_proto.external_data(), external_data_info));
+
+ pre_packed = external_data_info->GetPrePacked();
+
+ const auto& location = external_data_info->GetRelPath();
+
+ external_file_path = location == onnxruntime::utils::kTensorProtoMemoryAddressTag ? std::filesystem::path(location)
+ : (tensor_proto_dir / location);
+
+ ORT_RETURN_IF_ERROR(onnxruntime::utils::GetSizeInBytesFromTensorProto<0>(tensor_proto, &tensor_byte_size));
+ const size_t external_data_length = external_data_info->GetLength();
+ ORT_RETURN_IF_NOT(external_data_length == 0 || external_data_length == tensor_byte_size,
+ "TensorProto: ", tensor_proto.name(),
+ " external data size mismatch. Computed size: ", *&tensor_byte_size,
+ ", external_data.length: ", external_data_length);
+
+ file_offset = external_data_info->GetOffset();
+
+ return Status::OK();
+}
+
void SetRawDataInTensorProto(ONNX_NAMESPACE::TensorProto& tensor_proto, std::string&& param) {
tensor_proto.set_raw_data(std::move(param));
}
+Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto, const std::filesystem::path& tensor_proto_dir, std::basic_string& external_file_path, onnxruntime::FileOffsetType& file_offset, SafeInt& tensor_byte_size) {
+ bool pre_packed = false;
+ return GetExternalDataInfo(tensor_proto, tensor_proto_dir, external_file_path, file_offset, tensor_byte_size, pre_packed);
+}
+
void ConvertRawDataInTensorProto(TensorProto* tensor) {
size_t element_size = 1;
char* bytes = NULL;
@@ -988,7 +996,7 @@ static Status GetFileContent(const Env& env, const std::filesystem::path& file_p
Status GetExtDataFromTensorProto(const Env& env, const std::filesystem::path& model_path,
const ONNX_NAMESPACE::TensorProto& tensor_proto, void*& ext_data_buf,
SafeInt& ext_data_len, OrtCallback& ext_data_deleter,
- Tensor* buffered_tensor) {
+ SessionState::PrePackInitializers::PrePackedTensorNamesReadFromFile* pre_packed_initializers_name_set, Tensor* buffered_tensor) {
ORT_ENFORCE(utils::HasExternalData(tensor_proto));
std::basic_string tensor_proto_dir;
if (!model_path.empty()) {
@@ -997,8 +1005,13 @@ Status GetExtDataFromTensorProto(const Env& env, const std::filesystem::path& mo
std::basic_string external_data_file_path;
FileOffsetType file_offset;
SafeInt raw_data_safe_len = 0;
+ bool pre_packed = false;
ORT_RETURN_IF_ERROR(
- GetExternalDataInfo(tensor_proto, tensor_proto_dir, external_data_file_path, file_offset, raw_data_safe_len));
+ GetExternalDataInfo(tensor_proto, tensor_proto_dir, external_data_file_path, file_offset, raw_data_safe_len, pre_packed));
+
+ if (pre_packed && pre_packed_initializers_name_set != nullptr) {
+ (*pre_packed_initializers_name_set).insert(tensor_proto.name());
+ }
if (external_data_file_path == onnxruntime::utils::kTensorProtoMemoryAddressTag) {
// the value in location is the memory address of the data
@@ -1108,7 +1121,7 @@ Status TensorProtoToTensor(const Env& env, const std::filesystem::path& model_pa
OrtCallback& d = deleter_for_file_data.d;
if (utils::HasExternalData(tensor_proto)) {
- ORT_RETURN_IF_ERROR(GetExtDataFromTensorProto(env, model_path, tensor_proto, raw_data, raw_data_len, d));
+ ORT_RETURN_IF_ERROR(GetExtDataFromTensorProto(env, model_path, tensor_proto, raw_data, raw_data_len, d, nullptr));
} else if (utils::HasRawData(tensor_proto)) {
raw_data = const_cast(tensor_proto.raw_data().data());
// TODO The line above has const-correctness issues. Below is a possible fix which copies the tensor_proto data
diff --git a/onnxruntime/core/framework/tensorprotoutils.h b/onnxruntime/core/framework/tensorprotoutils.h
index 227ba0706197e..770132f8e95fc 100644
--- a/onnxruntime/core/framework/tensorprotoutils.h
+++ b/onnxruntime/core/framework/tensorprotoutils.h
@@ -17,12 +17,19 @@
#include "core/framework/external_data_loader.h"
#include "core/framework/ort_value.h"
#include "core/framework/mem_buffer.h"
+#include "core/framework/session_state.h"
#include "core/framework/tensor_external_data_info.h"
#include "core/graph/onnx_protobuf.h"
#include "core/platform/env.h"
namespace onnxruntime {
namespace utils {
+Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto,
+ const std::filesystem::path& tensor_proto_dir,
+ std::basic_string& external_file_path,
+ onnxruntime::FileOffsetType& file_offset,
+ SafeInt& tensor_byte_size);
+
/**
* This function is used to convert the endianess of Tensor data.
* Mostly, will be used in big endian system to support the model file
@@ -158,6 +165,7 @@ common::Status GetExtDataFromTensorProto(const Env& env, const std::filesystem::
const ONNX_NAMESPACE::TensorProto& tensor_proto,
void*& ext_data_buf, SafeInt& ext_data_len,
OrtCallback& ext_data_deleter,
+ SessionState::PrePackInitializers::PrePackedTensorNamesReadFromFile* pre_packed_initializers_name_set,
Tensor* buffered_tensor = nullptr);
// Given a tensor proto with external data obtain a tensor using the specified custom external data loader.
diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc
index 9eed0249711f9..5402345447706 100644
--- a/onnxruntime/core/framework/utils.cc
+++ b/onnxruntime/core/framework/utils.cc
@@ -1064,5 +1064,11 @@ bool IsOutputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index
return false;
}
+std::string GetPrepackedInitializerName(const std::string& initializer_name, const std::string& node_name) {
+ const std::string seperator = ":";
+
+ return initializer_name + seperator + node_name;
+}
+
} // namespace utils
} // namespace onnxruntime
diff --git a/onnxruntime/core/framework/utils.h b/onnxruntime/core/framework/utils.h
index afdb5a2cb27f5..db38ef1675595 100644
--- a/onnxruntime/core/framework/utils.h
+++ b/onnxruntime/core/framework/utils.h
@@ -234,6 +234,8 @@ constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() {
int32_t ONNXTensorElementDataTypeToProtoTensorType(ONNXTensorElementDataType);
+std::string GetPrepackedInitializerName(const std::string& initializer_name, const std::string& node_name);
+
#ifdef ENABLE_TRAINING
common::Status VerifyInputTensorsAllocatedContiguously(OpKernelContext* context);
#endif
diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc
index e8a5855b36496..3f50841f50913 100644
--- a/onnxruntime/core/graph/graph.cc
+++ b/onnxruntime/core/graph/graph.cc
@@ -4084,10 +4084,75 @@ ONNX_NAMESPACE::GraphProto Graph::ToGraphProto() const {
return result;
}
+void Graph::SetUpExternalInitializer(const Graph::OffsetAlignmentInfo& align_info,
+ size_t tensor_bytes_size,
+ int64_t& external_offset,
+ std::ofstream& external_stream,
+ gsl::span raw_data,
+ ONNX_NAMESPACE::TensorProto& output_proto,
+ const std::filesystem::path& external_file_path,
+ const ONNX_NAMESPACE::TensorProto& initializer,
+ bool is_prepacked) {
+ // update external_offset for alignment
+ // need to do padding before write actual tensor data as we do offset alignment at the begin of
+ // large tensors (offset needs to be page aligned and allocation granularity aligned) like below:
+ // \242\2557\256\023.\031&0000000000000000\332)k+\253\246\342\246(&\006!\347\232\374\236\325\026\032+\36XXXX
+ // |<---small tensor---->|<---padding--->|<------------------large tensor----------------------------->|
+ if (align_info.align_offset && static_cast(tensor_bytes_size) > align_info.align_threshold) {
+ // Align to the larger of the page size or the allocation granularity
+ int64_t alignment_factor = std::max(static_cast(4096), align_info.allocation_granularity);
+ // Align to the next page or alloc granularity boundary
+ int64_t new_external_offset = static_cast(
+ std::floor((external_offset + alignment_factor - 1) / alignment_factor)) *
+ alignment_factor;
+
+ // padding tensor with zeros for alignment
+ InlinedVector paddings;
+ size_t padding_size = SafeInt(new_external_offset - external_offset);
+ paddings.reserve(padding_size);
+ for (size_t index = 0; index != padding_size; ++index) {
+ paddings.push_back(0x0);
+ }
+ external_stream.write(reinterpret_cast(paddings.data()), padding_size);
+
+ external_offset = new_external_offset;
+ }
+
+ external_stream.write(reinterpret_cast(raw_data.data()), tensor_bytes_size);
+
+ output_proto.set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation::TensorProto_DataLocation_EXTERNAL);
+ ONNX_NAMESPACE::StringStringEntryProto* location = output_proto.add_external_data();
+ location->set_key("location");
+ location->set_value(ToUTF8String(external_file_path.native()));
+ ONNX_NAMESPACE::StringStringEntryProto* offset = output_proto.add_external_data();
+ offset->set_key("offset");
+ offset->set_value(std::to_string(external_offset));
+ ONNX_NAMESPACE::StringStringEntryProto* length = output_proto.add_external_data();
+ length->set_key("length");
+ length->set_value(std::to_string(tensor_bytes_size));
+
+ if (is_prepacked) {
+ ONNX_NAMESPACE::StringStringEntryProto* pre_packed = output_proto.add_external_data();
+ pre_packed->set_key("prepacked");
+ pre_packed->set_value("1");
+ }
+
+ output_proto.set_name(initializer.name());
+ output_proto.set_data_type(initializer.data_type());
+ for (int i = 0; i != initializer.dims_size(); ++i) {
+ output_proto.add_dims(initializer.dims(i));
+ }
+ output_proto.set_doc_string(initializer.doc_string());
+
+ external_offset += tensor_bytes_size;
+}
+
ONNX_NAMESPACE::GraphProto Graph::ToGraphProtoWithExternalInitializers(const std::filesystem::path& external_file_path,
const std::filesystem::path& model_file_path,
size_t initializer_size_threshold,
- const OffsetAlignmentInfo& align_info) const {
+ const OffsetAlignmentInfo& align_info,
+ bool save_prepacked_constant_initializers,
+ PrePackedTensorProtoToSave& pre_packed_initializers) const {
GraphProto result;
ToGraphProtoInternal(result);
ORT_ENFORCE(external_file_path.is_relative());
@@ -4106,6 +4171,34 @@ ONNX_NAMESPACE::GraphProto Graph::ToGraphProtoWithExternalInitializers(const std
#endif
for (const auto& initializer : graph_proto_->initializer()) {
+ bool use_pre_packed_initializer = false;
+ InlinedVector<ONNX_NAMESPACE::TensorProto> pre_packed_initializers_tensor_proto;
+ // If this initializer has been prepacked, save the prepacked external initializer instead of the original one.
+ // Since one initializer could be used by multiple kernels and be prepacked differently,
+ // save each prepacked initializer separately, changing the initializer name to [initializer_name]:[node_name]
+ // to avoid conflicts. Change the node input name accordingly.
+ // It could potentially make the ONNX data file larger since we store multiple prepacked initializers on disk,
+ // but this should be a rare case.
+ if (save_prepacked_constant_initializers && pre_packed_initializers.count(initializer.name())) {
+ for (const auto& item : pre_packed_initializers[initializer.name()]) {
+ auto& node_name = item.first;
+ std::string prepacked_initializer_name = utils::GetPrepackedInitializerName(initializer.name(), node_name);
+ pre_packed_initializers_tensor_proto.push_back(item.second);
+ use_pre_packed_initializer = true;
+
+ for (auto& node : *result.mutable_node()) {
+ if (node.name() == node_name) {
+ int input_index = 0;
+ for (const auto& input : node.input()) {
+ if (input == initializer.name()) {
+ node.set_input(input_index, prepacked_initializer_name);
+ }
+ input_index += 1;
+ }
+ }
+ }
+ }
+ }
#if !defined(DISABLE_SPARSE_TENSORS)
if (sparse_end != sparse_tensor_names_.find(initializer.name())) {
// Sparse tensors are added to the ONNX file.
@@ -4114,61 +4207,39 @@ ONNX_NAMESPACE::GraphProto Graph::ToGraphProtoWithExternalInitializers(const std
ORT_ENFORCE(status.IsOK(), "Failed to convert dense initializer to sparse");
} else {
#endif
- // Dense tensors larger than the threshold are added to the external file.
- TensorProto* output_proto = result.add_initializer();
-
- std::vector<uint8_t> raw_data;
- ORT_THROW_IF_ERROR(utils::UnpackInitializerData(initializer, model_path, raw_data));
- size_t tensor_bytes_size = raw_data.size();
- if (tensor_bytes_size < initializer_size_threshold) {
- *output_proto = initializer;
- continue;
- }
+ if (use_pre_packed_initializer) {
+ for (const auto& pre_packed_initializer : pre_packed_initializers_tensor_proto) {
+ // Dense tensors larger than the threshold are added to the external file.
+ TensorProto* output_proto = result.add_initializer();
+ std::vector<uint8_t> raw_data;
+ size_t tensor_bytes_size = 0;
+
+ ORT_THROW_IF_ERROR(utils::UnpackInitializerData(pre_packed_initializer, model_path, raw_data));
+ tensor_bytes_size = raw_data.size();
+ if (tensor_bytes_size < initializer_size_threshold) {
+ *output_proto = pre_packed_initializer;
+ continue;
+ }
- // update external_offset for alignment
- // need to do padding before write actual tensor data as we do offset alignment at the begin of
- // large tensors (offset need to be page aligned and alloction granularity aligned) like below:
- // \242\2557\256\023.\031&0000000000000000\332)k+\253\246\342\246(&\006!\347\232\374\236\325\026\032+\36XXXX
- // |<---small tensor---->|<---padding--->|<------------------large tensor----------------------------->|
- if (align_info.align_offset && static_cast(tensor_bytes_size) > align_info.align_threshold) {
- // Align to the larger of the page size or the allocation granularity
- int64_t alignment_factor = std::max(static_cast(4096), align_info.allocation_granularity);
- // Align to the next page or alloc granularity boundary
- int64_t new_external_offset = static_cast(
- std::floor((external_offset + alignment_factor - 1) / alignment_factor)) *
- alignment_factor;
-
- // padding tensor with zeros for alignment
- for (int64_t index = external_offset; index != new_external_offset; ++index) {
- external_stream << '0';
+ SetUpExternalInitializer(align_info, tensor_bytes_size, external_offset, external_stream,
+ raw_data, *output_proto, external_file_path, pre_packed_initializer, true);
+ }
+ } else {
+ // Dense tensors larger than the threshold are added to the external file.
+ TensorProto* output_proto = result.add_initializer();
+ std::vector<uint8_t> raw_data;
+ size_t tensor_bytes_size = 0;
+
+ ORT_THROW_IF_ERROR(utils::UnpackInitializerData(initializer, model_path, raw_data));
+ tensor_bytes_size = raw_data.size();
+ if (tensor_bytes_size < initializer_size_threshold) {
+ *output_proto = initializer;
+ continue;
}
- external_offset = new_external_offset;
- }
-
- for (size_t index = 0; index != tensor_bytes_size; ++index) {
- external_stream << raw_data[index];
- }
-
- output_proto->set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation::TensorProto_DataLocation_EXTERNAL);
- ONNX_NAMESPACE::StringStringEntryProto* location = output_proto->add_external_data();
- location->set_key("location");
- location->set_value(ToUTF8String(external_file_path.native()));
- ONNX_NAMESPACE::StringStringEntryProto* offset = output_proto->add_external_data();
- offset->set_key("offset");
- offset->set_value(std::to_string(external_offset));
- ONNX_NAMESPACE::StringStringEntryProto* length = output_proto->add_external_data();
- length->set_key("length");
- length->set_value(std::to_string(tensor_bytes_size));
-
- output_proto->set_name(initializer.name());
- output_proto->set_data_type(initializer.data_type());
- for (int i = 0; i != initializer.dims_size(); ++i) {
- output_proto->add_dims(initializer.dims(i));
+ SetUpExternalInitializer(align_info, tensor_bytes_size, external_offset, external_stream,
+ raw_data, *output_proto, external_file_path, initializer, false);
}
- output_proto->set_doc_string(initializer.doc_string());
-
- external_offset += tensor_bytes_size;
#if !defined(DISABLE_SPARSE_TENSORS)
}
#endif
diff --git a/onnxruntime/core/graph/model.cc b/onnxruntime/core/graph/model.cc
index e67884e3875d8..ad1ec9c8dedb3 100644
--- a/onnxruntime/core/graph/model.cc
+++ b/onnxruntime/core/graph/model.cc
@@ -384,13 +384,17 @@ ModelProto Model::ToProto() const {
ModelProto Model::ToGraphProtoWithExternalInitializers(const std::filesystem::path& external_file_name,
const std::filesystem::path& file_path,
size_t initializer_size_threshold,
- const Graph::OffsetAlignmentInfo& align_info) const {
+ const Graph::OffsetAlignmentInfo& align_info,
+ bool save_prepacked_constant_initializers,
+ Graph::PrePackedTensorProtoToSave& pre_packed_initializers) const {
ModelProto result(model_proto_);
const auto& graph = *graph_;
*(result.mutable_graph()) = graph.ToGraphProtoWithExternalInitializers(external_file_name,
file_path,
initializer_size_threshold,
- align_info);
+ align_info,
+ save_prepacked_constant_initializers,
+ pre_packed_initializers);
return result;
}
@@ -554,8 +558,8 @@ static Status SaveModel(Model& model, const T& file_path) {
model_proto.SerializeToArray(buffer, buffer_size);
EM_ASM(({
- const buffer = $0;
- const buffer_size = $1;
+ const buffer = Number($0);
+ const buffer_size = Number($1);
const file_path = UTF8ToString($2);
const bytes = new Uint8Array(buffer_size);
bytes.set(HEAPU8.subarray(buffer, buffer + buffer_size));
@@ -570,9 +574,9 @@ static Status SaveModel(Model& model, const T& file_path) {
window.open(url, '_blank');
}
}),
- reinterpret_cast(buffer),
- static_cast(buffer_size),
- reinterpret_cast(file_path.c_str()));
+ buffer,
+ buffer_size,
+ file_path.c_str());
free(buffer);
return Status::OK();
@@ -608,7 +612,9 @@ static Status SaveModelWithExternalInitializers(Model& model,
const T& file_path,
const std::filesystem::path& external_file_name,
size_t initializer_size_threshold,
- const Graph::OffsetAlignmentInfo& align_info) {
+ const Graph::OffsetAlignmentInfo& align_info,
+ bool save_prepacked_constant_initializers,
+ Graph::PrePackedTensorProtoToSave& pre_packed_initializers) {
int fd = 0;
Status status = Env::Default().FileOpenWr(file_path, fd);
ORT_RETURN_IF_ERROR(status);
@@ -616,7 +622,8 @@ static Status SaveModelWithExternalInitializers(Model& model,
ORT_TRY {
status = Model::SaveWithExternalInitializers(model, fd, file_path, external_file_name,
initializer_size_threshold,
- align_info);
+ align_info, save_prepacked_constant_initializers,
+ pre_packed_initializers);
}
ORT_CATCH(const std::exception& ex) {
ORT_HANDLE_EXCEPTION([&]() {
@@ -647,9 +654,12 @@ Status Model::Load(const PathString& file_path, std::shared_ptr& p_model,
Status Model::SaveWithExternalInitializers(Model& model, const std::filesystem::path& file_path,
const std::filesystem::path& external_file_name,
size_t initializer_size_threshold,
- const Graph::OffsetAlignmentInfo& align_info) {
+ const Graph::OffsetAlignmentInfo& align_info,
+ bool save_prepacked_constant_initializers,
+ Graph::PrePackedTensorProtoToSave& pre_packed_initializers) {
return SaveModelWithExternalInitializers(model, file_path, external_file_name, initializer_size_threshold,
- align_info);
+ align_info, save_prepacked_constant_initializers,
+ pre_packed_initializers);
}
Status Model::LoadFromBytes(int count, const void* p_bytes, /*out*/ ONNX_NAMESPACE::ModelProto& model_proto) {
@@ -766,7 +776,9 @@ Status Model::SaveWithExternalInitializers(Model& model,
const std::filesystem::path& file_path,
const std::filesystem::path& external_file_name,
size_t initializer_size_threshold,
- const Graph::OffsetAlignmentInfo& align_info) {
+ const Graph::OffsetAlignmentInfo& align_info,
+ bool save_prepacked_constant_initializers,
+ Graph::PrePackedTensorProtoToSave& pre_packed_initializers) {
if (fd < 0) {
return Status(ONNXRUNTIME, INVALID_ARGUMENT, " is less than 0.");
}
@@ -775,7 +787,8 @@ Status Model::SaveWithExternalInitializers(Model& model,
auto model_proto = model.ToGraphProtoWithExternalInitializers(external_file_name, file_path,
initializer_size_threshold,
- align_info);
+ align_info, save_prepacked_constant_initializers,
+ pre_packed_initializers);
google::protobuf::io::FileOutputStream output(fd);
const bool result = model_proto.SerializeToZeroCopyStream(&output) && output.Flush();
if (result) {
diff --git a/onnxruntime/core/graph/model.h b/onnxruntime/core/graph/model.h
index 9bcec6f78ca08..38d9044ff9d31 100644
--- a/onnxruntime/core/graph/model.h
+++ b/onnxruntime/core/graph/model.h
@@ -191,13 +191,17 @@ class Model {
ONNX_NAMESPACE::ModelProto ToGraphProtoWithExternalInitializers(const std::filesystem::path& external_file_name,
const std::filesystem::path& file_path,
size_t initializer_size_threshold,
- const Graph::OffsetAlignmentInfo& align_info) const;
+ const Graph::OffsetAlignmentInfo& align_info,
+ bool save_prepacked_constant_initializers,
+ Graph::PrePackedTensorProtoToSave& pre_packed_initializers) const;
ONNX_NAMESPACE::ModelProto ToGraphProtoWithExternalInitializers(const std::filesystem::path& external_file_name,
const std::filesystem::path& file_path,
size_t initializer_size_threshold) const {
Graph::OffsetAlignmentInfo default_align_info;
- return ToGraphProtoWithExternalInitializers(external_file_name, file_path, initializer_size_threshold, default_align_info);
+ Graph::PrePackedTensorProtoToSave pre_packed_initializers;
+ return ToGraphProtoWithExternalInitializers(external_file_name, file_path, initializer_size_threshold, default_align_info,
+ false, pre_packed_initializers);
}
static common::Status Save(Model& model, const PathString& file_path);
@@ -210,14 +214,18 @@ class Model {
const std::filesystem::path& file_path,
const std::filesystem::path& external_file_path,
size_t initializer_size_threshold,
- const Graph::OffsetAlignmentInfo& align_info);
+ const Graph::OffsetAlignmentInfo& align_info,
+ bool save_prepacked_constant_initializers,
+ Graph::PrePackedTensorProtoToSave& pre_packed_initializers);
static common::Status SaveWithExternalInitializers(Model& model,
const std::filesystem::path& file_path,
const std::filesystem::path& external_file_path,
size_t initializer_size_threshold) {
Graph::OffsetAlignmentInfo default_align_info;
- return SaveWithExternalInitializers(model, file_path, external_file_path, initializer_size_threshold, default_align_info);
+ Graph::PrePackedTensorProtoToSave pre_packed_initializers;
+ return SaveWithExternalInitializers(model, file_path, external_file_path, initializer_size_threshold, default_align_info,
+ false, pre_packed_initializers);
}
static common::Status SaveWithExternalInitializers(Model& model,
@@ -225,7 +233,9 @@ class Model {
const std::filesystem::path& file_path,
const std::filesystem::path& external_file_path,
size_t initializer_size_threshold,
- const Graph::OffsetAlignmentInfo& align_info);
+ const Graph::OffsetAlignmentInfo& align_info,
+ bool save_prepacked_constant_initializers,
+ Graph::PrePackedTensorProtoToSave& pre_packed_initializers);
static common::Status SaveWithExternalInitializers(Model& model,
int fd,
@@ -233,7 +243,9 @@ class Model {
const std::filesystem::path& external_file_path,
size_t initializer_size_threshold) {
Graph::OffsetAlignmentInfo default_align_info;
- return SaveWithExternalInitializers(model, fd, file_path, external_file_path, initializer_size_threshold, default_align_info);
+ Graph::PrePackedTensorProtoToSave pre_packed_initializers;
+ return SaveWithExternalInitializers(model, fd, file_path, external_file_path, initializer_size_threshold, default_align_info,
+ false, pre_packed_initializers);
}
static common::Status Load(std::istream& model_istream, ONNX_NAMESPACE::ModelProto* p_model_proto);
diff --git a/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc b/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc
index 37db095e92570..0a1a3a5995872 100644
--- a/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc
+++ b/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc
@@ -51,6 +51,7 @@ class FusedConvFp16 final : public OpKernel {
Status Compute(OpKernelContext* context) const override;
Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
+ bool save_prepacked_initializers,
/*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override;
Status UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers,
@@ -101,6 +102,7 @@ class FusedConvFp16 final : public OpKernel {
};
Status FusedConvFp16::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
+ bool /*save_prepacked_initializers*/,
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* prepacked_weights) {
is_packed = false;
diff --git a/onnxruntime/core/providers/cpu/math/gemm.cc b/onnxruntime/core/providers/cpu/math/gemm.cc
index 5406dd1a40446..dbc7becdf2397 100644
--- a/onnxruntime/core/providers/cpu/math/gemm.cc
+++ b/onnxruntime/core/providers/cpu/math/gemm.cc
@@ -248,6 +248,7 @@ template void Gemm::ComputeGemm(CBLAS_TRANSPOSE trans_a, CBLAS_TRANSPOSE
template <typename T>
Status Gemm<T>::PrePack(const Tensor& /* tensor */, int /* input_idx */, AllocatorPtr /*alloc_for_caching*/,
+ bool /*save_prepacked_initializers*/,
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* /*prepacked_weight_for_caching*/) {
is_packed = false;
@@ -256,7 +257,7 @@ Status Gemm::PrePack(const Tensor& /* tensor */, int /* input_idx */, Allocat
template <>
Status Gemm<MLFloat16>::PrePack(const Tensor& tensor, int input_idx,
- AllocatorPtr alloc, /*out*/ bool& is_packed,
+ AllocatorPtr alloc, bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed,
/*out*/ PrePackedWeights* prepacked_weights) {
is_packed = false;
diff --git a/onnxruntime/core/providers/cpu/math/gemm.h b/onnxruntime/core/providers/cpu/math/gemm.h
index 953949732560d..92f05a7921f8b 100644
--- a/onnxruntime/core/providers/cpu/math/gemm.h
+++ b/onnxruntime/core/providers/cpu/math/gemm.h
@@ -21,6 +21,7 @@ class Gemm : protected GemmBase, public OpKernel {
Status Compute(OpKernelContext* context) const override;
Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
+ bool save_prepacked_initializers,
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* prepacked_weights) override;
diff --git a/onnxruntime/core/providers/cpu/math/matmul.cc b/onnxruntime/core/providers/cpu/math/matmul.cc
index 2c6d23e4de908..8f2c2c53b188b 100644
--- a/onnxruntime/core/providers/cpu/math/matmul.cc
+++ b/onnxruntime/core/providers/cpu/math/matmul.cc
@@ -173,6 +173,7 @@ bool GemmPackBBfloat16(AllocatorPtr& alloc,
#endif
Status MatMul::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc,
+ bool /*save_prepacked_initializers*/,
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* prepacked_weights) {
is_packed = false;
diff --git a/onnxruntime/core/providers/cpu/math/matmul.h b/onnxruntime/core/providers/cpu/math/matmul.h
index b9bbe36583879..0bb0e6c2ef596 100644
--- a/onnxruntime/core/providers/cpu/math/matmul.h
+++ b/onnxruntime/core/providers/cpu/math/matmul.h
@@ -37,6 +37,7 @@ class MatMul final : public OpKernel {
}
Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
+ bool save_prepacked_initializers,
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* prepacked_weights) override;
diff --git a/onnxruntime/core/providers/cpu/nn/conv_transpose.cc b/onnxruntime/core/providers/cpu/nn/conv_transpose.cc
index f0c1b0b409831..2c7afddf38070 100644
--- a/onnxruntime/core/providers/cpu/nn/conv_transpose.cc
+++ b/onnxruntime/core/providers/cpu/nn/conv_transpose.cc
@@ -38,6 +38,7 @@ ONNX_CPU_OPERATOR_KERNEL(
template <typename T>
Status ConvTranspose<T>::PrePack(const Tensor& /*tensor*/, int /*input_idx*/, AllocatorPtr /*alloc*/,
+ bool /*save_prepacked_initializers*/,
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* /*prepacked_weights*/
) {
@@ -47,6 +48,7 @@ Status ConvTranspose::PrePack(const Tensor& /*tensor*/, int /*input_idx*/, Al
template <>
Status ConvTranspose::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
+ bool /*save_prepacked_initializers*/,
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* prepacked_weights) {
is_packed = false;
diff --git a/onnxruntime/core/providers/cpu/nn/conv_transpose.h b/onnxruntime/core/providers/cpu/nn/conv_transpose.h
index c82cd5ad49d7e..d03b5566e334f 100644
--- a/onnxruntime/core/providers/cpu/nn/conv_transpose.h
+++ b/onnxruntime/core/providers/cpu/nn/conv_transpose.h
@@ -28,6 +28,7 @@ class ConvTranspose : public OpKernel {
ConvTranspose(const OpKernelInfo& info) : OpKernel(info), conv_transpose_attrs_(info) {}
Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
+ bool save_prepacked_initializers,
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* prepacked_weights) override;
diff --git a/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc b/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc
index 24a5dcab225c4..fe2bf1035bb65 100644
--- a/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc
+++ b/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc
@@ -229,6 +229,7 @@ Status LayerNormImpl::Compute(OpKernelContext* p_ctx) const {
}
Status LayerNormImpl::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
+ bool /*save_prepacked_initializers*/,
bool& is_packed, PrePackedWeights* prepacked_weights) {
ORT_UNUSED_PARAMETER(prepacked_weights);
diff --git a/onnxruntime/core/providers/cpu/nn/layer_norm_impl.h b/onnxruntime/core/providers/cpu/nn/layer_norm_impl.h
index f8b528b398cba..abce87d03c14b 100644
--- a/onnxruntime/core/providers/cpu/nn/layer_norm_impl.h
+++ b/onnxruntime/core/providers/cpu/nn/layer_norm_impl.h
@@ -15,7 +15,7 @@ class LayerNormImpl : public OpKernel {
LayerNormImpl(const OpKernelInfo& op_kernel_info, bool simplified = false, bool contrib_op = false);
Status Compute(OpKernelContext* p_op_kernel_context) const override;
- Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
+ Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool save_prepacked_initializers,
bool& is_packed, PrePackedWeights* prepacked_weights) override;
// This method was created so that it can be called directly from `test/onnx/microbenchmark/layer_normalization.cc`.
diff --git a/onnxruntime/core/providers/cpu/quantization/matmul_integer_base.h b/onnxruntime/core/providers/cpu/quantization/matmul_integer_base.h
index e26eae19b8fd4..8a8ce27990069 100644
--- a/onnxruntime/core/providers/cpu/quantization/matmul_integer_base.h
+++ b/onnxruntime/core/providers/cpu/quantization/matmul_integer_base.h
@@ -14,6 +14,7 @@ class MatMulIntegerBase : public OpKernel {
MatMulIntegerBase(const OpKernelInfo& info) : OpKernel(info) {}
Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
+ bool /*save_prepacked_initializers*/,
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* prepacked_weights) override {
is_packed = false;
diff --git a/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc b/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc
index 7797cbe678bd4..736cde24591ff 100644
--- a/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc
+++ b/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc
@@ -25,6 +25,7 @@ class QLinearConv : public OpKernel {
Status Compute(OpKernelContext* context) const override;
Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
+ bool save_prepacked_initializers,
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* prepacked_weights) override;
@@ -360,6 +361,7 @@ REGISTER_QLINEARCONV_INT8_KERNEL(kMSDomain, 1);
template <typename ActType>
Status QLinearConv<ActType>::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
+ bool /*save_prepacked_initializers*/,
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* prepacked_weights) {
is_packed = false;
diff --git a/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc b/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc
index b78c5236e6fab..7afd00eacef89 100644
--- a/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc
+++ b/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc
@@ -284,6 +284,7 @@ bool DeepCpuGruOp::TryPackRecurrentWeights(const Tensor& weights, AllocatorPtr&
}
Status DeepCpuGruOp::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
+ bool /*save_prepacked_initializers*/,
bool& is_packed, PrePackedWeights* prepacked_weights) {
is_packed = false;
diff --git a/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.h b/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.h
index 5a6dd97c7c3f2..914077b2f2c15 100644
--- a/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.h
+++ b/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.h
@@ -62,6 +62,7 @@ class DeepCpuGruOp final : public OpKernel {
private:
Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
+ bool save_prepacked_initializers,
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* prepacked_weights) override;
@@ -197,4 +198,4 @@ class UniDirectionalGru {
};
} // namespace detail
-} // namespace onnxruntime
+} // namespace onnxruntime
\ No newline at end of file
diff --git a/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc b/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc
index 09bbf6c4c79e6..e4082e5d7634a 100644
--- a/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc
+++ b/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc
@@ -225,7 +225,9 @@ static void UseSharedPrePackedBuffersImpl(std::vector& prepacke
}
Status DeepCpuLstmOp::PrePack(const Tensor& tensor, int input_idx,
- AllocatorPtr alloc, /*out*/ bool& is_packed,
+ AllocatorPtr alloc,
+ bool /*save_prepacked_initializers*/,
+ /*out*/ bool& is_packed,
/*out*/ PrePackedWeights* prepacked_weights) {
is_packed = false;
diff --git a/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.h b/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.h
index 9c4c12954022a..ff8ab9abf0eed 100644
--- a/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.h
+++ b/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.h
@@ -19,6 +19,7 @@ class DeepCpuLstmOp final : public OpKernel, public LSTMBase {
DeepCpuLstmOp(const OpKernelInfo& info) : OpKernel(info), LSTMBase(info) {}
Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
+ bool save_prepacked_initializers,
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* prepacked_weights) override;
diff --git a/onnxruntime/core/providers/cuda/nn/conv.cc b/onnxruntime/core/providers/cuda/nn/conv.cc
index 3129f519da2e5..45a1d3bbc0414 100644
--- a/onnxruntime/core/providers/cuda/nn/conv.cc
+++ b/onnxruntime/core/providers/cuda/nn/conv.cc
@@ -52,6 +52,7 @@ REGISTER_KERNEL_TYPED(MLFloat16, kMSInternalNHWCDomain, true)
// First input (in this case X) is in case NHWC == true also in NHWC format, the other inputs in NCHW
template <typename T, bool NHWC>
Status Conv<T, NHWC>::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
+ bool /*save_prepacked_initializers*/,
bool& is_packed, PrePackedWeights* /*prepacked_weights*/) {
is_packed = false;
// only layout of weight input is adjusted via PrePack
diff --git a/onnxruntime/core/providers/cuda/nn/conv.h b/onnxruntime/core/providers/cuda/nn/conv.h
index e4047a6af272e..6294566af3cb9 100644
--- a/onnxruntime/core/providers/cuda/nn/conv.h
+++ b/onnxruntime/core/providers/cuda/nn/conv.h
@@ -219,6 +219,7 @@ class Conv : public CudaKernel {
}
Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
+ bool save_prepacked_initializers,
bool& is_packed, PrePackedWeights* prepacked_weights) override;
Status ComputeInternal(OpKernelContext* context) const override;
diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc
index 2972ae999adc4..9c9a83460daeb 100644
--- a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc
+++ b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc
@@ -45,7 +45,8 @@ REGISTER_KERNEL_TYPED(MLFloat16, kMSInternalNHWCDomain, true)
// First input (in this case X) is in case NHWC == true also in NHWC format, the other inputs in NCHW
template <typename T, bool NHWC>
-Status ConvTranspose<T, NHWC>::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool& is_packed,
+Status ConvTranspose<T, NHWC>::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
+ bool /*save_prepacked_initializers*/, bool& is_packed,
[[maybe_unused]] PrePackedWeights* prepacked_weights) {
is_packed = false;
// only layout of weight input is adjusted via PrePack
diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose.h b/onnxruntime/core/providers/cuda/nn/conv_transpose.h
index 3b8f117522210..b41e715c060be 100644
--- a/onnxruntime/core/providers/cuda/nn/conv_transpose.h
+++ b/onnxruntime/core/providers/cuda/nn/conv_transpose.h
@@ -22,6 +22,7 @@ class ConvTranspose : public CudaKernel {
ConvTranspose(const OpKernelInfo& info) : CudaKernel(info), conv_transpose_attrs_(info){};
Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
+ bool save_prepacked_initializers,
bool& is_packed, [[maybe_unused]] PrePackedWeights* prepacked_weights) override;
Status ComputeInternal(OpKernelContext* context) const override;
Status DoConvTranspose(OpKernelContext* context, bool dynamic_padding) const;
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCast.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCast.cpp
index 45ff25c4fdd90..02fb72b5a073a 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCast.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCast.cpp
@@ -50,5 +50,6 @@ class DmlOperatorCast : public DmlOperator
DML_OP_DEFINE_CREATION_FUNCTION(Cast, DmlOperatorCast);
DML_OP_DEFINE_CREATION_FUNCTION(CastLike15, DmlOperatorCast);
DML_OP_DEFINE_CREATION_FUNCTION(CastLike19, DmlOperatorCast);
+DML_OP_DEFINE_CREATION_FUNCTION(CastLike21, DmlOperatorCast);
} // namespace Dml
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPadding.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPadding.cpp
index 9b7ad9aa9e088..f8710fd266c07 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPadding.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPadding.cpp
@@ -123,5 +123,6 @@ DML_OP_DEFINE_CREATION_FUNCTION(Pad11, VersionedKernel);
DML_OP_DEFINE_CREATION_FUNCTION(Pad13, VersionedKernel);
DML_OP_DEFINE_CREATION_FUNCTION(Pad18, VersionedKernel);
DML_OP_DEFINE_CREATION_FUNCTION(Pad19, VersionedKernel);
+DML_OP_DEFINE_CREATION_FUNCTION(Pad21, VersionedKernel);
} // namespace Dml
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
index 2375131cb34ea..ceed388bb0a6f 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
@@ -365,6 +365,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(Pad11);
DML_OP_EXTERN_CREATION_FUNCTION(Pad13);
DML_OP_EXTERN_CREATION_FUNCTION(Pad18);
DML_OP_EXTERN_CREATION_FUNCTION(Pad19);
+DML_OP_EXTERN_CREATION_FUNCTION(Pad21);
DML_OP_EXTERN_CREATION_FUNCTION(SpaceToDepth);
DML_OP_EXTERN_CREATION_FUNCTION(DepthToSpace);
DML_OP_EXTERN_CREATION_FUNCTION(Sqrt);
@@ -445,6 +446,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(DynamicQuantizeMatMul);
DML_OP_EXTERN_CREATION_FUNCTION(Cast);
DML_OP_EXTERN_CREATION_FUNCTION(CastLike15);
DML_OP_EXTERN_CREATION_FUNCTION(CastLike19);
+DML_OP_EXTERN_CREATION_FUNCTION(CastLike21);
DML_OP_EXTERN_CREATION_FUNCTION(MemcpyFromHost);
DML_OP_EXTERN_CREATION_FUNCTION(MemcpyToHost);
DML_OP_EXTERN_CREATION_FUNCTION(TopK7);
@@ -792,6 +794,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO_VER( 18, Split, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))},
{REG_INFO( 7, Transpose, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
{REG_INFO( 13, Transpose, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
+ {REG_INFO( 21, Transpose, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
{REG_INFO( 7, Concat, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
{REG_INFO( 11, Concat, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, // Adds negative axis.
{REG_INFO( 13, Concat, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, // Adds negative axis.
@@ -804,6 +807,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO_VER( 11, Pad, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2) /*pads, value*/)}, // https://microsoft.visualstudio.com/OS/_workitems/edit/26007728
{REG_INFO_VER( 13, Pad, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2) /*pads, value*/)}, // https://microsoft.visualstudio.com/OS/_workitems/edit/26007728
{REG_INFO_VER( 18, Pad, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*pads, value, axes*/)},
+ {REG_INFO_VER( 21, Pad, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*pads, value, axes*/)},
#if DML_TARGET_VERSION >= 0x6400
{REG_INFO_VER( 19, Pad, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*pads, value, axes*/)},
@@ -819,6 +823,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 8, Expand, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))},
{REG_INFO( 13, Expand, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))},
{REG_INFO( 9, ConstantOfShape, typeNameListConstantOfShape, supportedTypeListConstantOfShape, DmlGraphSupport::Supported, requiredConstantCpuInputs(0))},
+ {REG_INFO( 21, ConstantOfShape, typeNameListConstantOfShape, supportedTypeListConstantOfShape, DmlGraphSupport::Supported, requiredConstantCpuInputs(0))},
{REG_INFO( 7, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported)},
{REG_INFO( 11, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported)},
{REG_INFO( 13, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported)},
@@ -853,6 +858,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO_COPY( 9, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
{REG_INFO_COPY(11, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
{REG_INFO_COPY(13, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
+ {REG_INFO_COPY(21, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
{REG_INFO_COPY( 7, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
{REG_INFO_COPY(11, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
{REG_INFO_COPY(13, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))},
@@ -1087,6 +1093,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 21, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)},
{REG_INFO_VER( 15, CastLike, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)},
{REG_INFO_VER( 19, CastLike, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)},
+ {REG_INFO_VER( 21, CastLike, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)},
{REG_INFO( 7, MemcpyFromHost, typeNameListDefault, supportedTypeListAll)},
{REG_INFO( 7, MemcpyToHost, typeNameListDefault, supportedTypeListAll)},
{REG_INFO_VER( 7, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported)},
@@ -1102,6 +1109,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 7, Size, typeNameSize, supportedTypeListSize, DmlGraphSupport::NotSupported)},
{REG_INFO( 13, Size, typeNameSize, supportedTypeListSize, DmlGraphSupport::NotSupported)},
{REG_INFO( 19, Size, typeNameSize, supportedTypeListSize, DmlGraphSupport::NotSupported)},
+ {REG_INFO( 21, Size, typeNameSize, supportedTypeListSize, DmlGraphSupport::NotSupported)},
{REG_INFO_DYNAMIC_OUTPUTS( 9, NonZero, typeNameListDefault, supportedTypeListNonZero, DmlGraphSupport::NotSupported)},
{REG_INFO_DYNAMIC_OUTPUTS(13, NonZero, typeNameListDefault, supportedTypeListNonZero, DmlGraphSupport::NotSupported)},
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h
index 323fcc779d98d..c1ea69ab35374 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h
@@ -1673,6 +1673,7 @@ using ShapeInferenceHelper_Flatten7 = FlattenHelper;
using ShapeInferenceHelper_Flatten9 = FlattenHelper;
using ShapeInferenceHelper_Flatten11 = FlattenHelper;
using ShapeInferenceHelper_Flatten13 = FlattenHelper;
+using ShapeInferenceHelper_Flatten21 = FlattenHelper;
using ShapeInferenceHelper_Split7 = VersionedOpsetHelper;
using ShapeInferenceHelper_Split11 = VersionedOpsetHelper;
using ShapeInferenceHelper_Split13 = VersionedOpsetHelper;
@@ -1689,6 +1690,7 @@ using ShapeInferenceHelper_Pad11 = VersionedOpsetHelper;
using ShapeInferenceHelper_Pad13 = VersionedOpsetHelper;
using ShapeInferenceHelper_Pad18 = VersionedOpsetHelper;
using ShapeInferenceHelper_Pad19 = VersionedOpsetHelper;
+using ShapeInferenceHelper_Pad21 = VersionedOpsetHelper;
using ShapeInferenceHelper_SpaceToDepth = SpaceToDepthHelper;
using ShapeInferenceHelper_DepthToSpace = DepthToSpaceHelper;
@@ -1865,6 +1867,7 @@ using ShapeInferenceHelper_Range = RangeHelper;
using ShapeInferenceHelper_CastLike15 = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_CastLike19 = GetOutputShapeAsInputShapeHelper;
+using ShapeInferenceHelper_CastLike21 = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_DmlFusedConv = ConvHelper;
using ShapeInferenceHelper_DmlFusedConvTranspose = ConvTransposeHelper;
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
index 26529c0d59dd6..c2a6d57fca0a9 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
@@ -446,6 +446,12 @@ namespace OperatorHelper
static const int sc_sinceVer_Reshape = 21;
static const int sc_sinceVer_Cast = 21;
static const int sc_sinceVer_Shape = 21;
+ static const int sc_sinceVer_Size = 21;
+ static const int sc_sinceVer_CastLike = 21;
+ static const int sc_sinceVer_ConstantOfShape = 21;
+ static const int sc_sinceVer_Flatten = 21;
+ static const int sc_sinceVer_Pad = 21;
+ static const int sc_sinceVer_Transpose = 21;
}
namespace MsftOperatorSet1
diff --git a/onnxruntime/core/providers/js/data_transfer.cc b/onnxruntime/core/providers/js/data_transfer.cc
index ebea041b80128..3809df2c82e4c 100644
--- a/onnxruntime/core/providers/js/data_transfer.cc
+++ b/onnxruntime/core/providers/js/data_transfer.cc
@@ -6,7 +6,7 @@
#include "core/providers/js/data_transfer.h"
EM_ASYNC_JS(void, jsepDownload, (const void* src_data, void* dst_data, size_t bytes), {
- await Module.jsepCopyAsync(src_data, dst_data, bytes);
+ await Module.jsepCopyAsync(Number(src_data), Number(dst_data), Number(bytes));
});
namespace onnxruntime {
@@ -30,10 +30,10 @@ common::Status DataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const {
if (dst_device.Type() == OrtDevice::GPU) {
if (src_device.Type() == OrtDevice::GPU) {
// copy from GPU to GPU
- EM_ASM({ Module.jsepCopy($0, $1, $2, true); }, src_data, dst_data, bytes);
+ EM_ASM({ Module.jsepCopy(Number($0), Number($1), Number($2), true); }, src_data, dst_data, bytes);
} else {
// copy from CPU to GPU
- EM_ASM({ Module.jsepCopy($0, $1, $2); }, src_data, dst_data, bytes);
+ EM_ASM({ Module.jsepCopy(Number($0), Number($1), Number($2)); }, src_data, dst_data, bytes);
}
} else /* if (src_device.Type() == OrtDevice::GPU) */ {
// copy from GPU to CPU
diff --git a/onnxruntime/core/providers/js/js_export.cc b/onnxruntime/core/providers/js/js_export.cc
index 2402bb33ce9d0..f99e90bcb13f6 100644
--- a/onnxruntime/core/providers/js/js_export.cc
+++ b/onnxruntime/core/providers/js/js_export.cc
@@ -6,8 +6,8 @@
#include "core/framework/op_kernel.h"
const void* JsepOutput(void* context, int index, const void* data) {
- const uint32_t* data_offset = reinterpret_cast(data);
- uint32_t dim = *data_offset++;
+ const uintptr_t* data_offset = reinterpret_cast(data);
+ uintptr_t dim = *data_offset++;
size_t dim_size = static_cast(dim);
std::vector dims(dim_size);
for (size_t i = 0; i < dim_size; i++) {
diff --git a/onnxruntime/core/providers/js/js_kernel.h b/onnxruntime/core/providers/js/js_kernel.h
index 7324b0d69474c..68d89c96d96f7 100644
--- a/onnxruntime/core/providers/js/js_kernel.h
+++ b/onnxruntime/core/providers/js/js_kernel.h
@@ -110,16 +110,17 @@ class JsKernel : public OpKernel {
temp_data_size += sizeof(size_t) * 3;
}
}
- uint32_t* p_serialized_kernel_context = reinterpret_cast(alloc->Alloc(temp_data_size));
+ uintptr_t* p_serialized_kernel_context = reinterpret_cast(alloc->Alloc(temp_data_size));
if (p_serialized_kernel_context == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to allocate memory for serialized kernel context.");
}
- p_serialized_kernel_context[0] = reinterpret_cast(context);
- p_serialized_kernel_context[1] = static_cast(context->InputCount());
- p_serialized_kernel_context[2] = static_cast(context->OutputCount());
- p_serialized_kernel_context[3] = reinterpret_cast(custom_data_ptr);
- p_serialized_kernel_context[4] = static_cast(custom_data_size);
+ p_serialized_kernel_context[0] = reinterpret_cast(context);
+ p_serialized_kernel_context[1] = static_cast(context->InputCount());
+ p_serialized_kernel_context[2] = static_cast(context->OutputCount());
+ p_serialized_kernel_context[3] = reinterpret_cast(custom_data_ptr);
+ p_serialized_kernel_context[4] = static_cast(custom_data_size);
+
size_t index = 5;
for (int i = 0; i < context->InputCount(); i++) {
const auto* input_ptr = context->Input(i);
@@ -130,11 +131,11 @@ class JsKernel : public OpKernel {
p_serialized_kernel_context[index++] = 0;
continue;
}
- p_serialized_kernel_context[index++] = static_cast(input_ptr->GetElementType());
- p_serialized_kernel_context[index++] = reinterpret_cast(input_ptr->DataRaw());
- p_serialized_kernel_context[index++] = static_cast(input_ptr->Shape().NumDimensions());
+ p_serialized_kernel_context[index++] = static_cast(input_ptr->GetElementType());
+ p_serialized_kernel_context[index++] = reinterpret_cast(input_ptr->DataRaw());
+ p_serialized_kernel_context[index++] = static_cast(input_ptr->Shape().NumDimensions());
for (size_t d = 0; d < input_ptr->Shape().NumDimensions(); d++) {
- p_serialized_kernel_context[index++] = static_cast(input_ptr->Shape()[d]);
+ p_serialized_kernel_context[index++] = static_cast(input_ptr->Shape()[d]);
}
}
@@ -199,9 +200,9 @@ class JsKernel : public OpKernel {
return status;
}
- int status_code = EM_ASM_INT(
- { return Module.jsepRunKernel($0, $1, Module.jsepSessionState.sessionHandle, Module.jsepSessionState.errors); },
- this, reinterpret_cast(p_serialized_kernel_context));
+ intptr_t status_code = EM_ASM_INT(
+ { return Module.jsepRunKernel(Number($0), Number($1), Module.jsepSessionState.sessionHandle, Module.jsepSessionState.errors); },
+ this, reinterpret_cast(p_serialized_kernel_context));
LOGS_DEFAULT(VERBOSE) << "outputs = " << context->OutputCount() << ". Y.data="
<< (size_t)(context->Output(0)->DataRaw()) << ".";
diff --git a/onnxruntime/core/providers/js/operators/conv.h b/onnxruntime/core/providers/js/operators/conv.h
index 0357c2f02a7a2..276b600cf40d2 100644
--- a/onnxruntime/core/providers/js/operators/conv.h
+++ b/onnxruntime/core/providers/js/operators/conv.h
@@ -51,14 +51,14 @@ class ConvBase : public JsKernel {
JSEP_INIT_KERNEL_ATTRIBUTE(Conv, ({
"format" : $11 ? "NHWC" : "NCHW",
"auto_pad" : $1,
- "dilations" : $2 ? Array.from(HEAP32.subarray($2, $3)) : [],
+ "dilations" : $2 ? Array.from(HEAP32.subarray(Number($2), Number($3))) : [],
"group" : $4,
- "kernel_shape" : $5 ? Array.from(HEAP32.subarray($5, $6)) : [],
- "pads" : $7 ? Array.from(HEAP32.subarray($7, $8)) : [],
- "strides" : $9 ? Array.from(HEAP32.subarray($9, $10)) : [],
- "w_is_const" : () JS_ARROW(!!HEAP8[$12]),
+ "kernel_shape" : $5 ? Array.from(HEAP32.subarray(Number($5), Number($6))) : [],
+ "pads" : $7 ? Array.from(HEAP32.subarray(Number($7), Number($8))) : [],
+ "strides" : $9 ? Array.from(HEAP32.subarray(Number($9), Number($10))) : [],
+ "w_is_const" : () JS_ARROW(!!HEAP8[Number($12)]),
"activation" : UTF8ToString($13),
- "activation_params" : $14 ? Array.from(HEAPF32.subarray($14, $15)) : []
+ "activation_params" : $14 ? Array.from(HEAPF32.subarray(Number($14), Number($15))) : []
}),
static_cast(conv_attrs_.auto_pad),
JSEP_HEAP32_INDEX_START(dilations),
@@ -78,6 +78,7 @@ class ConvBase : public JsKernel {
}
Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
+ bool save_prepacked_initializers,
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* /* prepacked_weights */) override {
is_packed = false;
diff --git a/onnxruntime/core/providers/js/operators/conv_transpose.h b/onnxruntime/core/providers/js/operators/conv_transpose.h
index c51bf5ce9d4a6..baa93f825a203 100644
--- a/onnxruntime/core/providers/js/operators/conv_transpose.h
+++ b/onnxruntime/core/providers/js/operators/conv_transpose.h
@@ -48,8 +48,8 @@ class ConvTranspose : public JsKernel {
"pads" : [ $5, $6 ],
"strides" : [$7],
"wIsConst" : () JS_ARROW(!!HEAP8[$9]),
- "outputPadding" : $10 ? Array.from(HEAP32.subarray($10, $11)) : [],
- "outputShape" : $12 ? Array.from(HEAP32.subarray($12, $13)) : [],
+ "outputPadding" : $10 ? Array.from(HEAP32.subarray(Number($10), Number($11))) : [],
+ "outputShape" : $12 ? Array.from(HEAP32.subarray(Number($12), Number($13))) : [],
"activation" : UTF8ToString($14)
}),
static_cast(conv_transpose_attrs_.auto_pad),
@@ -99,14 +99,14 @@ class ConvTranspose : public JsKernel {
JSEP_INIT_KERNEL_ATTRIBUTE(ConvTranspose, ({
"format" : $7 ? "NHWC" : "NCHW",
"autoPad" : $1,
- "dilations" : Array.from(HEAP32.subarray($2, ($2 >>> 0) + /* dialations_vec_size */ 2)),
+ "dilations" : Array.from(HEAP32.subarray(Number($2), (Number($2) >>> 0) + /* dialations_vec_size */ 2)),
"group" : $3,
- "kernelShape" : Array.from(HEAP32.subarray($4, ($4 >>> 0) + /* kernel_shape_vec_size */ 2)),
- "pads" : Array.from(HEAP32.subarray($5, ($5 >>> 0) + /* pads_vec_size */ 4)),
- "strides" : Array.from(HEAP32.subarray($6, ($6 >>> 0) + /* strides_vec_size */ 2)),
+ "kernelShape" : Array.from(HEAP32.subarray(Number($4), (Number($4) >>> 0) + /* kernel_shape_vec_size */ 2)),
+ "pads" : Array.from(HEAP32.subarray(Number($5), (Number($5) >>> 0) + /* pads_vec_size */ 4)),
+ "strides" : Array.from(HEAP32.subarray(Number($6), (Number($6) >>> 0) + /* strides_vec_size */ 2)),
"wIsConst" : () JS_ARROW(!!HEAP8[$8]),
- "outputPadding" : $9 ? Array.from(HEAP32.subarray($9, $10)) : [],
- "outputShape" : $11 ? Array.from(HEAP32.subarray($11, $12)) : [],
+ "outputPadding" : $9 ? Array.from(HEAP32.subarray(Number($9), Number($10))) : [],
+ "outputShape" : $11 ? Array.from(HEAP32.subarray(Number($11), Number($12))) : [],
"activation" : UTF8ToString($13)
}),
static_cast(conv_transpose_attrs_.auto_pad),
@@ -126,8 +126,10 @@ class ConvTranspose : public JsKernel {
}
Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
+ bool save_prepacked_initializers,
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* /* prepacked_weights */) override {
+ ORT_UNUSED_PARAMETER(save_prepacked_initializers);
is_packed = false;
if (input_idx == 1) {
diff --git a/onnxruntime/core/providers/js/operators/gather.cc b/onnxruntime/core/providers/js/operators/gather.cc
index 485cd3da9b91b..e9c6f5c79294f 100644
--- a/onnxruntime/core/providers/js/operators/gather.cc
+++ b/onnxruntime/core/providers/js/operators/gather.cc
@@ -15,11 +15,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
10,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
- .TypeConstraint("T", BuildKernelDefConstraintsFromTypeList>())
+ .TypeConstraint("T", JsepSupportedDataTypes())
.TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList>()),
Gather);
@@ -30,11 +26,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
12,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
- .TypeConstraint("T", BuildKernelDefConstraintsFromTypeList>())
+ .TypeConstraint("T", JsepSupportedDataTypes())
.TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList>()),
Gather);
@@ -44,11 +36,7 @@ ONNX_OPERATOR_KERNEL_EX(
13,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
- .TypeConstraint("T", BuildKernelDefConstraintsFromTypeList>())
+ .TypeConstraint("T", JsepSupportedDataTypes())
.TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList>()),
Gather);
diff --git a/onnxruntime/core/providers/js/operators/pad.h b/onnxruntime/core/providers/js/operators/pad.h
index c18c7dd456dc2..f656462285bc4 100644
--- a/onnxruntime/core/providers/js/operators/pad.h
+++ b/onnxruntime/core/providers/js/operators/pad.h
@@ -22,7 +22,7 @@ class Pad : public JsKernel, public PadBase {
JSEP_INIT_KERNEL_ATTRIBUTE(Pad, ({"mode" : $1,
"value" : $2,
- "pads" : $3 ? Array.from(HEAP32.subarray($3, $4)) : []}),
+ "pads" : $3 ? Array.from(HEAP32.subarray(Number($3), Number($4))) : []}),
static_cast(mode_),
static_cast(value_),
JSEP_HEAP32_INDEX_START(pads),
diff --git a/onnxruntime/core/providers/js/operators/pool.h b/onnxruntime/core/providers/js/operators/pool.h
index 66bcde86020b6..32556eeaeefe4 100644
--- a/onnxruntime/core/providers/js/operators/pool.h
+++ b/onnxruntime/core/providers/js/operators/pool.h
@@ -3,22 +3,22 @@
#pragma once
-#include "core/providers/js/js_kernel.h"
#include "core/providers/cpu/nn/pool_base.h"
+#include "core/providers/js/js_kernel.h"
namespace onnxruntime {
namespace js {
-#define POOL_ATTRIBUTES_JS_OBJ_MAPPING ({ \
- "format" : $13 ? "NHWC" : "NCHW", \
- "auto_pad" : $1, \
- "ceil_mode" : $2, \
- "count_include_pad" : $3, \
- "storage_order" : $4, \
- "dilations" : $5 ? Array.from(HEAP32.subarray($5, $6)) : [], \
- "kernel_shape" : $7 ? Array.from(HEAP32.subarray($7, $8)) : [], \
- "pads" : $9 ? Array.from(HEAP32.subarray($9, $10)) : [], \
- "strides" : $11 ? Array.from(HEAP32.subarray($11, $12)) : [] \
+#define POOL_ATTRIBUTES_JS_OBJ_MAPPING ({ \
+ "format" : $13 ? "NHWC" : "NCHW", \
+ "auto_pad" : $1, \
+ "ceil_mode" : $2, \
+ "count_include_pad" : $3, \
+ "storage_order" : $4, \
+ "dilations" : $5 ? Array.from(HEAP32.subarray(Number($5), Number($6))) : [], \
+ "kernel_shape" : $7 ? Array.from(HEAP32.subarray(Number($7), Number($8))) : [], \
+ "pads" : $9 ? Array.from(HEAP32.subarray(Number($9), Number($10))) : [], \
+ "strides" : $11 ? Array.from(HEAP32.subarray(Number($11), Number($12))) : [] \
})
#define POOL_ATTRIBUTES_PARAM_LIST \
diff --git a/onnxruntime/core/providers/js/operators/reduce.h b/onnxruntime/core/providers/js/operators/reduce.h
index 937f1f990dc67..4ae558f9dfc00 100644
--- a/onnxruntime/core/providers/js/operators/reduce.h
+++ b/onnxruntime/core/providers/js/operators/reduce.h
@@ -8,29 +8,29 @@
namespace onnxruntime {
namespace js {
-#define JSEP_DEFINE_REDUCE_KERNEL(ReduceKernel) \
- template \
- class ReduceKernel : public JsKernel, public ReduceKernelBase { \
- public: \
- using ReduceKernelBase::axes_; \
- using ReduceKernelBase::noop_with_empty_axes_; \
- using ReduceKernelBase::keepdims_; \
- ReduceKernel(const OpKernelInfo& info) : JsKernel(info), ReduceKernelBase(info) { \
- std::vector axes(axes_.size()); \
- if (axes_.size() > 0) { \
- std::transform(axes_.begin(), axes_.end(), axes.begin(), \
- [](int64_t axis) { return gsl::narrow_cast(axis); }); \
- } \
- JSEP_INIT_KERNEL_ATTRIBUTE(ReduceKernel, ({ \
- "keepDims" : !!$1, \
- "noopWithEmptyAxes" : !!$2, \
- "axes" : $3 ? (Array.from(HEAP32.subarray($3, $4))) : [], \
- }), \
- static_cast(keepdims_), \
- static_cast(noop_with_empty_axes_), \
- JSEP_HEAP32_INDEX_START(axes), \
- JSEP_HEAP32_INDEX_END(axes)); \
- } \
+#define JSEP_DEFINE_REDUCE_KERNEL(ReduceKernel) \
+ template \
+ class ReduceKernel : public JsKernel, public ReduceKernelBase { \
+ public: \
+ using ReduceKernelBase::axes_; \
+ using ReduceKernelBase::noop_with_empty_axes_; \
+ using ReduceKernelBase::keepdims_; \
+ ReduceKernel(const OpKernelInfo& info) : JsKernel(info), ReduceKernelBase(info) { \
+ std::vector axes(axes_.size()); \
+ if (axes_.size() > 0) { \
+ std::transform(axes_.begin(), axes_.end(), axes.begin(), \
+ [](int64_t axis) { return gsl::narrow_cast(axis); }); \
+ } \
+ JSEP_INIT_KERNEL_ATTRIBUTE(ReduceKernel, ({ \
+ "keepDims" : !!$1, \
+ "noopWithEmptyAxes" : !!$2, \
+ "axes" : $3 ? (Array.from(HEAP32.subarray(Number($3), Number($4)))) : [], \
+ }), \
+ static_cast(keepdims_), \
+ static_cast(noop_with_empty_axes_), \
+ JSEP_HEAP32_INDEX_START(axes), \
+ JSEP_HEAP32_INDEX_END(axes)); \
+ } \
};
JSEP_DEFINE_REDUCE_KERNEL(ReduceMax);
diff --git a/onnxruntime/core/providers/js/operators/resize.h b/onnxruntime/core/providers/js/operators/resize.h
index 134eb4bf5a7f4..3e8ccf40753c8 100644
--- a/onnxruntime/core/providers/js/operators/resize.h
+++ b/onnxruntime/core/providers/js/operators/resize.h
@@ -23,7 +23,7 @@ class Resize : public JsKernel, public UpsampleBase {
std::transform(axes_.begin(), axes_.end(), std::back_inserter(axes), [](auto& axis) { return gsl::narrow_cast(axis); });
JSEP_INIT_KERNEL_ATTRIBUTE(Resize, ({
"antialias" : $1,
- "axes" : $2 ? Array.from(HEAP32.subarray($2, $3)) : [],
+ "axes" : $2 ? Array.from(HEAP32.subarray(Number($2), Number($3))) : [],
"coordinateTransformMode" : UTF8ToString($4),
"cubicCoeffA" : $5,
"excludeOutside" : $6,
diff --git a/onnxruntime/core/providers/js/operators/slice.h b/onnxruntime/core/providers/js/operators/slice.h
index daeffaa664741..f30e7bf01ec7b 100644
--- a/onnxruntime/core/providers/js/operators/slice.h
+++ b/onnxruntime/core/providers/js/operators/slice.h
@@ -20,9 +20,9 @@ class Slice : public JsKernel, public SliceBase {
std::vector starts(attr_starts.begin(), attr_starts.end());
std::vector ends(attr_ends.begin(), attr_ends.end());
- JSEP_INIT_KERNEL_ATTRIBUTE(Slice, ({"starts" : $1 ? Array.from(HEAP32.subarray($1, $2)) : [],
- "ends" : $3 ? Array.from(HEAP32.subarray($3, $4)) : [],
- "axes" : $5 ? Array.from(HEAP32.subarray($5, $6)) : []}),
+ JSEP_INIT_KERNEL_ATTRIBUTE(Slice, ({"starts" : $1 ? Array.from(HEAP32.subarray(Number($1), Number($2))) : [],
+ "ends" : $3 ? Array.from(HEAP32.subarray(Number($3), Number($4))) : [],
+ "axes" : $5 ? Array.from(HEAP32.subarray(Number($5), Number($6))) : []}),
JSEP_HEAP32_INDEX_START(starts),
JSEP_HEAP32_INDEX_END(starts),
JSEP_HEAP32_INDEX_START(ends),
diff --git a/onnxruntime/core/providers/js/operators/split.h b/onnxruntime/core/providers/js/operators/split.h
index 4fdbab00e739c..3f6cfcb8921f3 100644
--- a/onnxruntime/core/providers/js/operators/split.h
+++ b/onnxruntime/core/providers/js/operators/split.h
@@ -49,7 +49,7 @@ class Split : public JsKernel, public SplitBase {
JSEP_INIT_KERNEL_ATTRIBUTE(Split, ({"axis" : $1,
"numOutputs" : $2,
- "splitSizes" : $3 ? Array.from(HEAP32.subarray($3, $4)) : []}),
+ "splitSizes" : $3 ? Array.from(HEAP32.subarray(Number($3), Number($4))) : []}),
static_cast(axis_),
static_cast(num_outputs_),
JSEP_HEAP32_INDEX_START(split_sizes),
diff --git a/onnxruntime/core/providers/js/operators/transpose.h b/onnxruntime/core/providers/js/operators/transpose.h
index 7a945471c7701..f6b2b4faba850 100644
--- a/onnxruntime/core/providers/js/operators/transpose.h
+++ b/onnxruntime/core/providers/js/operators/transpose.h
@@ -21,7 +21,7 @@ class Transpose final : public JsKernel, public TransposeBase {
}
}
JSEP_INIT_KERNEL_ATTRIBUTE(Transpose, ({
- "perm" : $1 ? Array.from(HEAP32.subarray($1, $2)) : []
+ "perm" : $1 ? Array.from(HEAP32.subarray(Number($1), Number($2))) : []
}),
JSEP_HEAP32_INDEX_START(perm),
JSEP_HEAP32_INDEX_END(perm));
diff --git a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc
index 8da255a288f17..fffe964e6aaf2 100644
--- a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc
+++ b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc
@@ -12,27 +12,6 @@
namespace onnxruntime {
namespace webnn {
-
-// Shared functions.
-bool HasExternalInitializer(const InitializedTensorSet& initializers, const Node& node,
- const logging::Logger& logger) {
- for (const auto* node_arg : node.InputDefs()) {
- const auto& input_name(node_arg->Name());
- if (!Contains(initializers, input_name))
- continue;
-
- const auto& tensor = *initializers.at(input_name);
- if (tensor.has_data_location() &&
- tensor.data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) {
- LOGS(logger, VERBOSE) << "Initializer [" << input_name
- << "] with external data location are not currently supported";
- return true;
- }
- }
-
- return false;
-}
-
// Add operator related.
Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const Node& node,
@@ -58,10 +37,6 @@ bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, cons
if (!HasSupportedOutputsImpl(node, wnn_limits, logger))
return false;
- // We do not support external initializers for now.
- if (HasExternalInitializer(initializers, node, logger))
- return false;
-
if (!HasSupportedOpSet(node, logger))
return false;
diff --git a/onnxruntime/core/providers/webnn/builders/impl/expand_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/expand_op_builder.cc
index c8cea833983b1..5e99551fe6e7d 100644
--- a/onnxruntime/core/providers/webnn/builders/impl/expand_op_builder.cc
+++ b/onnxruntime/core/providers/webnn/builders/impl/expand_op_builder.cc
@@ -95,11 +95,6 @@ bool ExpandOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers
return false;
}
- if (input_shape.empty()) {
- LOGS(logger, VERBOSE) << "Expand does not support empty input's shape.";
- return false;
- }
-
std::vector output_shape;
if (!GetBidirectionalBroadcastShape(input_shape, new_shape, output_shape)) {
LOGS(logger, VERBOSE) << "The input cannot expand to shape " << GetShapeString(new_shape);
diff --git a/onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc
index a7911683f0355..0a438e98ad737 100644
--- a/onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc
+++ b/onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc
@@ -44,21 +44,25 @@ Status ReshapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
const auto& input_defs = node.InputDefs();
const auto& initializers(model_builder.GetInitializerTensors());
const auto& target_shape_tensor = *initializers.at(input_defs[1]->Name());
- const int64_t* raw_target_shape = target_shape_tensor.int64_data().empty()
- ? reinterpret_cast(target_shape_tensor.raw_data().data())
- : target_shape_tensor.int64_data().data();
+ const auto& target_shape_tensor_dims = target_shape_tensor.dims();
+ std::vector new_shape;
+ // Do nothing if target shape is an empty shape, which means converting to a scalar.
+ if (!target_shape_tensor_dims.empty()) {
+ const int64_t* raw_target_shape = target_shape_tensor.int64_data().empty()
+ ? reinterpret_cast(target_shape_tensor.raw_data().data())
+ : target_shape_tensor.int64_data().data();
+
+ const auto size = target_shape_tensor_dims[0];
+ TensorShapeVector target_shape{raw_target_shape, raw_target_shape + size};
+ std::vector input_shape;
+ ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape");
+ ReshapeHelper helper(TensorShape(input_shape), target_shape);
+ std::transform(target_shape.cbegin(), target_shape.cend(),
+ std::back_inserter(new_shape),
+ [](int64_t dim) -> uint32_t { return SafeInt(dim); });
+ }
- const auto size = target_shape_tensor.dims()[0];
- TensorShapeVector target_shape{raw_target_shape, raw_target_shape + size};
- std::vector input_shape;
- ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape");
- ReshapeHelper helper(TensorShape(input_shape), target_shape);
emscripten::val input = model_builder.GetOperand(input_defs[0]->Name());
- std::vector new_shape;
- std::transform(target_shape.cbegin(), target_shape.cend(),
- std::back_inserter(new_shape),
- [](int64_t dim) -> uint32_t { return SafeInt(dim); });
-
emscripten::val options = emscripten::val::object();
options.set("label", node.Name());
emscripten::val output = model_builder.GetBuilder().call("reshape",
@@ -76,6 +80,11 @@ bool ReshapeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializer
const WebnnDeviceType /* device_type */,
const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
+
+ std::vector input_shape;
+ if (!GetShape(*input_defs[0], input_shape, logger))
+ return false;
+
const auto& perm_name = input_defs[1]->Name();
if (!Contains(initializers, perm_name)) {
LOGS(logger, VERBOSE) << "New shape of reshape must be a constant initializer";
@@ -92,24 +101,11 @@ bool ReshapeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializer
const int64_t* raw_new_shape = reinterpret_cast(unpacked_tensor.data());
const auto& perm_dims = perm_tensor.dims();
- if (perm_dims.empty() || perm_dims[0] == 0) {
- LOGS(logger, VERBOSE) << "New shape of reshape cannot be empty";
- return false;
- }
-
- std::vector input_shape;
- if (!GetShape(*input_defs[0], input_shape, logger))
- return false;
-
- if (input_shape.empty()) {
- LOGS(logger, VERBOSE) << "Reshape does not support empty input shape";
- return false;
- }
// WebNN reshape does not support 0 as dimension.
NodeAttrHelper helper(node);
- const bool allow_zero = helper.Get("allowzero ", 0) == 1;
- if (allow_zero) {
+ const bool allow_zero = helper.Get("allowzero", 0) == 1;
+ if (allow_zero && !perm_dims.empty()) {
for (int64_t i = 0; i < perm_dims[0]; i++) {
if (raw_new_shape[i] == 0) {
LOGS_DEFAULT(VERBOSE) << "Reshape doesn't support 0 reshape dimension when allowzero is enabled";
diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc
index 044baa738e8c4..8a7fea0cde431 100644
--- a/onnxruntime/core/providers/webnn/builders/model_builder.cc
+++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc
@@ -112,56 +112,73 @@ Status ModelBuilder::RegisterInitializers() {
auto num_elements = SafeInt(Product(shape));
emscripten::val view = emscripten::val::undefined();
std::byte* tensor_ptr = nullptr;
- if (tensor.has_raw_data()) {
- tensor_ptr = reinterpret_cast(const_cast(tensor.raw_data().c_str()));
+
+ if (utils::HasExternalData(tensor)) {
+ // Create WebNN Constant from external data.
+ std::basic_string external_file_path;
+ onnxruntime::FileOffsetType data_offset;
+ SafeInt tensor_byte_size;
+ ORT_RETURN_IF_ERROR(utils::GetExternalDataInfo(
+ tensor, graph_viewer_.ModelPath(), external_file_path, data_offset, tensor_byte_size));
+
+ auto jsepRegisterMLConstant = emscripten::val::module_property("jsepRegisterMLConstant");
+ operand = jsepRegisterMLConstant(emscripten::val(external_file_path),
+ static_cast(data_offset),
+ static_cast(tensor_byte_size),
+ wnn_builder_,
+ desc);
} else {
- // Store temporary unpacked_tensor.
- unpacked_tensors_.push_back({});
- std::vector& unpacked_tensor = unpacked_tensors_.back();
- ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(tensor, unpacked_tensor));
- tensor_ptr = reinterpret_cast(unpacked_tensor.data());
- }
- switch (data_type) {
- case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
- case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
- view = emscripten::val{emscripten::typed_memory_view(num_elements,
- reinterpret_cast(tensor_ptr))};
- break;
- case ONNX_NAMESPACE::TensorProto_DataType_INT8:
- view = emscripten::val{emscripten::typed_memory_view(num_elements,
- reinterpret_cast(tensor_ptr))};
- break;
- case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
- view = emscripten::val{emscripten::typed_memory_view(num_elements,
- reinterpret_cast