Add implementation of WebGPU EP #22591

Open
wants to merge 10 commits into
base: main

fs-eire
Contributor

@fs-eire commented Oct 24, 2024

Description

This PR adds the actual implementation of the WebGPU EP based on #22318.

This change includes the following:

Core framework of WebGPU EP
  • WebGPU EP factory classes for:
    • handling WebGPU options
    • creating WebGPU EP instance
    • creating WebGPU context
  • WebGPU Execution Provider classes
    • GPU Buffer allocator
    • data transfer
  • Buffer management classes
    • Buffer Manager
    • BufferCacheManager
      • DisabledCacheManager
      • SimpleCacheManager
      • LazyReleaseCacheManager
      • BucketCacheManager
  • Program classes
    • Program (base)
    • Program Cache Key
    • Program Manager
  • Shader helper classes
    • Shader Helper
    • ShaderIndicesHelper
    • ShaderVariableHelper
  • Utils
    • GPU Query based profiler
    • compute context
    • string utils
  • Misc
    • Python binding webgpu support (basic)
Kernel implementation
  • ai.onnx (default opset):
    • Elementwise (math): Abs, Neg, Floor, Ceil, Reciprocal, Sqrt, Exp, Erf, Log, Sin, Cos, Tan, Asin, Acos, Atan, Sinh, Cosh, Asinh, Acosh, Atanh, Tanh, Not, Cast
    • Elementwise (activation): Sigmoid, HardSigmoid, Clip, Elu, Relu, LeakyRelu, ThresholdedRelu, Gelu
    • Binary (math): Add, Sub, Mul, Div, Pow, Equal, Greater, GreaterOrEqual, Less, LessOrEqual
    • (Tensors): Shape, Reshape, Squeeze, Unsqueeze
    • Where
    • Transpose
    • Concat
    • Expand
    • Gather
    • Tile
    • Range
    • LayerNormalization
  • com.microsoft:
    • FastGelu
    • MatMulNBits
    • MultiHeadAttention
    • RotaryEmbedding
    • SkipLayerNormalization
    • LayerNormalization
    • SimplifiedLayerNormalization
    • SkipSimplifiedLayerNormalization
Build, test and CI pipeline integration
  • build works for Windows, macOS and iOS
  • support onnxruntime_test_all and python node test
  • added a new unit test for the --use_external_dawn build flag
  • updated macOS pipeline to build with WebGPU support
  • added a new pipeline for WebGPU Windows

This change does not include:

  • Node.js binding support for WebGPU (will be a separate PR)

@fs-eire requested a review from a team as a code owner October 24, 2024 20:20
@snnn closed this Oct 25, 2024
@snnn reopened this Oct 25, 2024
guschmue
guschmue previously approved these changes Oct 25, 2024
Contributor

@guschmue left a comment

Lgtm

", Actual: ", shape.NumDimensions());

std::vector<uint32_t> dims(expected_rank);
std::vector<uint32_t> stride(expected_rank - 1);

Check warning: Code scanning / PREfast

Arithmetic overflow: Using operator '-' on a 4 byte value and then casting the result to a 8 byte value. Cast the value to the wider type before calling operator '-' to avoid overflow (io.2).
std::vector<uint32_t> stride(expected_rank - 1);
for (size_t j = 0; j < expected_rank; ++j) {
dims[j] = gsl::narrow<uint32_t>(shape[j]);
if (j < expected_rank - 1) {

Check warning: Code scanning / PREfast

Arithmetic overflow: Using operator '-' on a 4 byte value and then casting the result to a 8 byte value. Cast the value to the wider type before calling operator '-' to avoid overflow (io.2).
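One possible way to address the PREfast warnings is to widen the rank before subtracting, as the warning text suggests. The sketch below is illustrative only: `MakeStrideStorage` is a hypothetical helper that does not appear in the PR, and `expected_rank` is assumed to be a 4-byte `int`, which is what the warning implies but the snippet does not show.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper (not from the PR): allocates storage for rank - 1 strides.
// Assumption: expected_rank is a 4-byte signed integer, as the warning implies.
std::vector<uint32_t> MakeStrideStorage(int expected_rank) {
  // Widen to size_t BEFORE subtracting, so the subtraction is done in the
  // 8-byte type rather than in the 4-byte type and cast afterwards.
  // Ranks <= 0 are handled explicitly: size_t(0) - 1 would wrap to SIZE_MAX.
  const size_t stride_count =
      expected_rank > 0 ? static_cast<size_t>(expected_rank) - 1 : 0;
  return std::vector<uint32_t>(stride_count);
}
```

The same widened value could be reused for the loop comparison `j < expected_rank - 1`, which would keep both sides of the comparison in `size_t`.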
3 participants