[DO NOT CLOSE] Easy CLA signal for Python docs push #1281

Draft · wants to merge 6 commits into base: pytorchbot/base
4 changes: 2 additions & 2 deletions _posts/2024-04-04-accelerating-moe-model.md
@@ -8,7 +8,7 @@ author: Adnan Hoque, Less Wright, Antoni Virós Martin, Chih-Chieh Yang

We show that by implementing column-major scheduling to improve data locality, we can accelerate the core Triton GEMM (General Matrix-Matrix Multiply) kernel for MoEs (Mixture of Experts) up to 4x on A100, and up to 4.4x on H100 Nvidia GPUs. This post demonstrates several different work decomposition and scheduling algorithms for MoE GEMMs and shows, at the hardware level, why column-major scheduling produces the highest speedup.

-Repo and code available at: [https://github.com/pytorch-labs/applied-ai/tree/main/triton/](https://github.com/pytorch-labs/applied-ai/tree/main/triton/inference/col_major_moe_gemm).
+Repo and code available at: [https://github.com/pytorch-labs/applied-ai/tree/main/kernels/triton/inference/col_major_moe_gemm](https://github.com/pytorch-labs/applied-ai/tree/main/kernels/triton/inference/col_major_moe_gemm).


![Figure 1A. Optimized Fused MoE GEMM Kernel TFLOPs on A100 for varying Batch Sizes M](/assets/images/accelerating-moe-model/fig-7.png){:style="width:100%;display: block; max-width: 600px; margin-right: auto; margin-left: auto"}
@@ -128,4 +128,4 @@ We have [open sourced](https://github.com/pytorch-labs/applied-ai/tree/main/kern

## Acknowledgements

-We want to thank Daniel Han, Raghu Ganti, Mudhakar Srivatsa, Bert Maher, Gregory Chanan, Eli Uriegas, and Geeta Chauhan for their review of the presented material and Woo Suk from the vLLM team as we built on his implementation of the Fused MoE kernel.
+We want to thank Daniel Han, Raghu Ganti, Mudhakar Srivatsa, Bert Maher, Gregory Chanan, Eli Uriegas, and Geeta Chauhan for their review of the presented material and Woosuk from the vLLM team as we built on his implementation of the Fused MoE kernel.
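
As context for the excerpt above, a minimal sketch of what column-major scheduling means for a tiled GEMM launch grid (an illustration of the idea, not the authors' actual Triton kernel; all names are invented):

```python
# Sketch: map a linear thread-block id to an output tile (pid_m, pid_n).
# In column-major order, consecutive blocks share pid_n, so the same tile
# of the weight matrix B stays hot in L2 cache; row-major order shares
# pid_m and reuses A instead. Illustrative only.
def tile_coords(pid: int, num_pid_m: int, num_pid_n: int, column_major: bool):
    if column_major:
        pid_m = pid % num_pid_m   # walk down a column of C tiles first
        pid_n = pid // num_pid_m
    else:
        pid_m = pid // num_pid_n  # default row-major launch order
        pid_n = pid % num_pid_n
    return pid_m, pid_n
```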
Binary file modified docs/2.3/_images/RReLU.png
22 changes: 11 additions & 11 deletions docs/2.3/_sources/generated/exportdb/index.rst.txt
@@ -19,8 +19,8 @@ support in export please create an issue in the pytorch/pytorch repo with a modul
:caption: Tags

torch.escape-hatch
-torch.dynamic-shape
torch.cond
+torch.dynamic-shape
python.closure
torch.dynamic-value
python.data-structure
@@ -203,7 +203,7 @@ cond_branch_class_method

.. note::

-Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`, :doc:`torch.cond <torch.cond>`
+Tags: :doc:`torch.cond <torch.cond>`, :doc:`torch.dynamic-shape <torch.dynamic-shape>`

Support Level: SUPPORTED

@@ -284,7 +284,7 @@ cond_branch_nested_function

.. note::

-Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`, :doc:`torch.cond <torch.cond>`
+Tags: :doc:`torch.cond <torch.cond>`, :doc:`torch.dynamic-shape <torch.dynamic-shape>`

Support Level: SUPPORTED

@@ -363,7 +363,7 @@ cond_branch_nonlocal_variables

.. note::

-Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`, :doc:`torch.cond <torch.cond>`
+Tags: :doc:`torch.cond <torch.cond>`, :doc:`torch.dynamic-shape <torch.dynamic-shape>`

Support Level: SUPPORTED

@@ -528,7 +528,7 @@ cond_operands

.. note::

-Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`, :doc:`torch.cond <torch.cond>`
+Tags: :doc:`torch.cond <torch.cond>`, :doc:`torch.dynamic-shape <torch.dynamic-shape>`

Support Level: SUPPORTED

@@ -602,7 +602,7 @@ cond_predicate

.. note::

-Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`, :doc:`torch.cond <torch.cond>`
+Tags: :doc:`torch.cond <torch.cond>`, :doc:`torch.dynamic-shape <torch.dynamic-shape>`

Support Level: SUPPORTED

@@ -981,7 +981,7 @@ dynamic_shape_if_guard

.. note::

-Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`, :doc:`python.control-flow <python.control-flow>`
+Tags: :doc:`python.control-flow <python.control-flow>`, :doc:`torch.dynamic-shape <torch.dynamic-shape>`

Support Level: SUPPORTED

@@ -1027,7 +1027,7 @@ dynamic_shape_map

.. note::

-Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`, :doc:`torch.map <torch.map>`
+Tags: :doc:`torch.map <torch.map>`, :doc:`torch.dynamic-shape <torch.dynamic-shape>`

Support Level: SUPPORTED

@@ -1240,7 +1240,7 @@ list_contains

.. note::

-Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`, :doc:`python.assert <python.assert>`, :doc:`python.data-structure <python.data-structure>`
+Tags: :doc:`python.assert <python.assert>`, :doc:`python.data-structure <python.data-structure>`, :doc:`torch.dynamic-shape <torch.dynamic-shape>`

Support Level: SUPPORTED

@@ -1865,7 +1865,7 @@ dynamic_shape_round

.. note::

-Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`, :doc:`python.builtin <python.builtin>`
+Tags: :doc:`python.builtin <python.builtin>`, :doc:`torch.dynamic-shape <torch.dynamic-shape>`

Support Level: NOT_SUPPORTED_YET

@@ -2005,6 +2005,6 @@ Result:

.. code-block::

-Unsupported: torch.* op returned non-Tensor int call_function <function sym_min at 0x7f4d9cf5cd30>
+Unsupported: torch.* op returned non-Tensor int call_function <function sym_min at 0x7f0885c63d30>


2 changes: 1 addition & 1 deletion docs/2.3/_sources/generated/exportdb/python.assert.rst.txt
@@ -51,7 +51,7 @@ list_contains

.. note::

-Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`, :doc:`python.assert <python.assert>`, :doc:`python.data-structure <python.data-structure>`
+Tags: :doc:`python.assert <python.assert>`, :doc:`python.data-structure <python.data-structure>`, :doc:`torch.dynamic-shape <torch.dynamic-shape>`

Support Level: SUPPORTED

@@ -5,7 +5,7 @@ dynamic_shape_round

.. note::

-Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`, :doc:`python.builtin <python.builtin>`
+Tags: :doc:`python.builtin <python.builtin>`, :doc:`torch.dynamic-shape <torch.dynamic-shape>`

Support Level: NOT_SUPPORTED_YET

@@ -5,7 +5,7 @@ dynamic_shape_if_guard

.. note::

-Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`, :doc:`python.control-flow <python.control-flow>`
+Tags: :doc:`python.control-flow <python.control-flow>`, :doc:`torch.dynamic-shape <torch.dynamic-shape>`

Support Level: SUPPORTED

@@ -116,7 +116,7 @@ list_contains

.. note::

-Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`, :doc:`python.assert <python.assert>`, :doc:`python.data-structure <python.data-structure>`
+Tags: :doc:`python.assert <python.assert>`, :doc:`python.data-structure <python.data-structure>`, :doc:`torch.dynamic-shape <torch.dynamic-shape>`

Support Level: SUPPORTED

10 changes: 5 additions & 5 deletions docs/2.3/_sources/generated/exportdb/torch.cond.rst.txt
@@ -5,7 +5,7 @@ cond_branch_class_method

.. note::

-Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`, :doc:`torch.cond <torch.cond>`
+Tags: :doc:`torch.cond <torch.cond>`, :doc:`torch.dynamic-shape <torch.dynamic-shape>`

Support Level: SUPPORTED

@@ -86,7 +86,7 @@ cond_branch_nested_function

.. note::

-Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`, :doc:`torch.cond <torch.cond>`
+Tags: :doc:`torch.cond <torch.cond>`, :doc:`torch.dynamic-shape <torch.dynamic-shape>`

Support Level: SUPPORTED

@@ -165,7 +165,7 @@ cond_branch_nonlocal_variables

.. note::

-Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`, :doc:`torch.cond <torch.cond>`
+Tags: :doc:`torch.cond <torch.cond>`, :doc:`torch.dynamic-shape <torch.dynamic-shape>`

Support Level: SUPPORTED

@@ -330,7 +330,7 @@ cond_operands

.. note::

-Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`, :doc:`torch.cond <torch.cond>`
+Tags: :doc:`torch.cond <torch.cond>`, :doc:`torch.dynamic-shape <torch.dynamic-shape>`

Support Level: SUPPORTED

@@ -404,7 +404,7 @@ cond_predicate

.. note::

-Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`, :doc:`torch.cond <torch.cond>`
+Tags: :doc:`torch.cond <torch.cond>`, :doc:`torch.dynamic-shape <torch.dynamic-shape>`

Support Level: SUPPORTED

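
All of the cond_* entries in this file exercise the same control-flow operator. As a reminder of its shape, a minimal sketch assuming the torch.cond surface of recent PyTorch releases (earlier versions exposed it through functorch.experimental.control_flow):

```python
import torch

def true_fn(x):
    return x.cos()

def false_fn(x):
    return x.sin()

class CondExample(torch.nn.Module):
    def forward(self, x):
        # The predicate is a boolean scalar tensor; both branches take the
        # same operands and must return tensors of matching shape and dtype.
        return torch.cond(x.sum() > 0, true_fn, false_fn, (x,))

ep = torch.export.export(CondExample(), (torch.randn(3, 4),))
print(ep)
```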
18 changes: 9 additions & 9 deletions docs/2.3/_sources/generated/exportdb/torch.dynamic-shape.rst.txt
@@ -5,7 +5,7 @@ cond_branch_class_method

.. note::

-Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`, :doc:`torch.cond <torch.cond>`
+Tags: :doc:`torch.cond <torch.cond>`, :doc:`torch.dynamic-shape <torch.dynamic-shape>`

Support Level: SUPPORTED

@@ -86,7 +86,7 @@ cond_branch_nested_function

.. note::

-Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`, :doc:`torch.cond <torch.cond>`
+Tags: :doc:`torch.cond <torch.cond>`, :doc:`torch.dynamic-shape <torch.dynamic-shape>`

Support Level: SUPPORTED

@@ -165,7 +165,7 @@ cond_branch_nonlocal_variables

.. note::

-Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`, :doc:`torch.cond <torch.cond>`
+Tags: :doc:`torch.cond <torch.cond>`, :doc:`torch.dynamic-shape <torch.dynamic-shape>`

Support Level: SUPPORTED

@@ -269,7 +269,7 @@ cond_operands

.. note::

-Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`, :doc:`torch.cond <torch.cond>`
+Tags: :doc:`torch.cond <torch.cond>`, :doc:`torch.dynamic-shape <torch.dynamic-shape>`

Support Level: SUPPORTED

@@ -343,7 +343,7 @@ cond_predicate

.. note::

-Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`, :doc:`torch.cond <torch.cond>`
+Tags: :doc:`torch.cond <torch.cond>`, :doc:`torch.dynamic-shape <torch.dynamic-shape>`

Support Level: SUPPORTED

@@ -451,7 +451,7 @@ dynamic_shape_if_guard

.. note::

-Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`, :doc:`python.control-flow <python.control-flow>`
+Tags: :doc:`python.control-flow <python.control-flow>`, :doc:`torch.dynamic-shape <torch.dynamic-shape>`

Support Level: SUPPORTED

@@ -497,7 +497,7 @@ dynamic_shape_map

.. note::

-Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`, :doc:`torch.map <torch.map>`
+Tags: :doc:`torch.map <torch.map>`, :doc:`torch.dynamic-shape <torch.dynamic-shape>`

Support Level: SUPPORTED

@@ -552,7 +552,7 @@ dynamic_shape_round

.. note::

-Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`, :doc:`python.builtin <python.builtin>`
+Tags: :doc:`python.builtin <python.builtin>`, :doc:`torch.dynamic-shape <torch.dynamic-shape>`

Support Level: NOT_SUPPORTED_YET

@@ -686,7 +686,7 @@ list_contains

.. note::

-Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`, :doc:`python.assert <python.assert>`, :doc:`python.data-structure <python.data-structure>`
+Tags: :doc:`python.assert <python.assert>`, :doc:`python.data-structure <python.data-structure>`, :doc:`torch.dynamic-shape <torch.dynamic-shape>`

Support Level: SUPPORTED

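
Most entries in this file pair torch.dynamic-shape with a second tag. For orientation, a minimal dynamic_shape_if_guard-style sketch using the torch.export.Dim API (the module, names, and bounds here are illustrative, not the exportdb source):

```python
import torch
from torch.export import Dim, export

class DynamicShapeIfGuard(torch.nn.Module):
    def forward(self, x):
        # A Python branch on a tensor shape: export traces one branch and
        # installs a guard on the symbolic batch dimension.
        if x.shape[0] <= 16:
            return x.cos()
        return x.sin()

batch = Dim("batch", min=2, max=16)  # bound chosen to match the guard
ep = export(DynamicShapeIfGuard(), (torch.randn(4, 8),),
            dynamic_shapes={"x": {0: batch}})
```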
2 changes: 1 addition & 1 deletion docs/2.3/_sources/generated/exportdb/torch.map.rst.txt
@@ -5,7 +5,7 @@ dynamic_shape_map

.. note::

-Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`, :doc:`torch.map <torch.map>`
+Tags: :doc:`torch.map <torch.map>`, :doc:`torch.dynamic-shape <torch.dynamic-shape>`

Support Level: SUPPORTED

@@ -30,4 +30,4 @@ Result:

.. code-block::

-Unsupported: torch.* op returned non-Tensor int call_function <function sym_min at 0x7f4d9cf5cd30>
+Unsupported: torch.* op returned non-Tensor int call_function <function sym_min at 0x7f0885c63d30>
4 changes: 2 additions & 2 deletions docs/2.3/ddp_comm_hooks.html
@@ -527,7 +527,7 @@ <h2>What Does a Communication Hook Operate On?<a class="headerlink" href="#what-

<dl class="py function">
<dt class="sig sig-object py" id="torch.distributed.GradBucket.gradients">
-<span class="sig-prename descclassname"><span class="pre">torch.distributed.GradBucket.</span></span><span class="sig-name descname"><span class="pre">gradients</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">self</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference internal" href="#torch.distributed.GradBucket" title="torch._C._distributed_c10d.GradBucket"><span class="pre">torch._C._distributed_c10d.GradBucket</span></a></span></em><span class="sig-paren">)</span> <span class="sig-return"><span class="sig-return-icon">&#x2192;</span> <span class="sig-return-typehint"><span class="pre">List</span><span class="p"><span class="pre">[</span></span><a class="reference internal" href="tensors.html#torch.Tensor" title="torch.Tensor"><span class="pre">torch.Tensor</span></a><span class="p"><span class="pre">]</span></span></span></span><a class="headerlink" href="#torch.distributed.GradBucket.gradients" title="Permalink to this definition">¶</a></dt>
+<span class="sig-prename descclassname"><span class="pre">torch.distributed.GradBucket.</span></span><span class="sig-name descname"><span class="pre">gradients</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">self</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference internal" href="#torch.distributed.GradBucket" title="torch._C._distributed_c10d.GradBucket"><span class="pre">torch._C._distributed_c10d.GradBucket</span></a></span></em><span class="sig-paren">)</span> <span class="sig-return"><span class="sig-return-icon">&#x2192;</span> <span class="sig-return-typehint"><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#list" title="(in Python v3.12)"><span class="pre">list</span></a><span class="p"><span class="pre">[</span></span><a class="reference internal" href="tensors.html#torch.Tensor" title="torch.Tensor"><span class="pre">torch.Tensor</span></a><span class="p"><span class="pre">]</span></span></span></span><a class="headerlink" href="#torch.distributed.GradBucket.gradients" title="Permalink to this definition">¶</a></dt>
<dd><dl class="field-list simple">
<dt class="field-odd">Returns</dt>
<dd class="field-odd"><p>A list of <code class="docutils literal notranslate"><span class="pre">torch.Tensor</span></code>. Each tensor in the list corresponds to a gradient.</p>
@@ -554,7 +554,7 @@ <h2>What Does a Communication Hook Operate On?<a class="headerlink" href="#what-

<dl class="py function">
<dt class="sig sig-object py" id="torch.distributed.GradBucket.parameters">
-<span class="sig-prename descclassname"><span class="pre">torch.distributed.GradBucket.</span></span><span class="sig-name descname"><span class="pre">parameters</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">self</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference internal" href="#torch.distributed.GradBucket" title="torch._C._distributed_c10d.GradBucket"><span class="pre">torch._C._distributed_c10d.GradBucket</span></a></span></em><span class="sig-paren">)</span> <span class="sig-return"><span class="sig-return-icon">&#x2192;</span> <span class="sig-return-typehint"><span class="pre">List</span><span class="p"><span class="pre">[</span></span><a class="reference internal" href="tensors.html#torch.Tensor" title="torch.Tensor"><span class="pre">torch.Tensor</span></a><span class="p"><span class="pre">]</span></span></span></span><a class="headerlink" href="#torch.distributed.GradBucket.parameters" title="Permalink to this definition">¶</a></dt>
+<span class="sig-prename descclassname"><span class="pre">torch.distributed.GradBucket.</span></span><span class="sig-name descname"><span class="pre">parameters</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">self</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference internal" href="#torch.distributed.GradBucket" title="torch._C._distributed_c10d.GradBucket"><span class="pre">torch._C._distributed_c10d.GradBucket</span></a></span></em><span class="sig-paren">)</span> <span class="sig-return"><span class="sig-return-icon">&#x2192;</span> <span class="sig-return-typehint"><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#list" title="(in Python v3.12)"><span class="pre">list</span></a><span class="p"><span class="pre">[</span></span><a class="reference internal" href="tensors.html#torch.Tensor" title="torch.Tensor"><span class="pre">torch.Tensor</span></a><span class="p"><span class="pre">]</span></span></span></span><a class="headerlink" href="#torch.distributed.GradBucket.parameters" title="Permalink to this definition">¶</a></dt>
<dd><dl class="field-list simple">
<dt class="field-odd">Returns</dt>
<dd class="field-odd"><p>A list of <code class="docutils literal notranslate"><span class="pre">torch.Tensor</span></code>. Each tensor in the list corresponds to a model
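
The GradBucket accessors whose signatures change above are what a DDP communication hook consumes. A minimal hook sketch, modeled on the built-in fp16 compression hook (bucket.buffer() and register_comm_hook are real APIs; the hook body here is illustrative):

```python
import torch
import torch.distributed as dist

def fp16_compress_hook(state, bucket):
    # Compress the bucket's flattened gradients to fp16 before all-reduce.
    compressed = bucket.buffer().to(torch.float16)
    fut = dist.all_reduce(compressed, async_op=True).get_future()

    def decompress(fut):
        # Average across ranks and restore the original dtype.
        return fut.value()[0].div_(dist.get_world_size()).to(bucket.buffer().dtype)

    return fut.then(decompress)

# ddp_model = torch.nn.parallel.DistributedDataParallel(model)
# ddp_model.register_comm_hook(state=None, hook=fp16_compress_hook)
```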
4 changes: 2 additions & 2 deletions docs/2.3/distributed.html
@@ -1348,7 +1348,7 @@ <h2>Distributed Key-Value Store<a class="headerlink" href="#distributed-key-valu
<span class="sig-prename descclassname"><span class="pre">torch.distributed.Store.</span></span><span class="sig-name descname"><span class="pre">wait</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#torch.distributed.Store.wait" title="Permalink to this definition">¶</a></dt>
<dd><p>Overloaded function.</p>
<ol class="arabic simple">
-<li><p>wait(self: torch._C._distributed_c10d.Store, arg0: List[str]) -&gt; None</p></li>
+<li><p>wait(self: torch._C._distributed_c10d.Store, arg0: list[str]) -&gt; None</p></li>
</ol>
<p>Waits for each key in <code class="docutils literal notranslate"><span class="pre">keys</span></code> to be added to the store. If not all keys are
set before the <code class="docutils literal notranslate"><span class="pre">timeout</span></code> (set during store initialization), then <code class="docutils literal notranslate"><span class="pre">wait</span></code>
@@ -1370,7 +1370,7 @@ <h2>Distributed Key-Value Store<a class="headerlink" href="#distributed-key-valu
</dd>
</dl>
<ol class="arabic simple" start="2">
-<li><p>wait(self: torch._C._distributed_c10d.Store, arg0: List[str], arg1: datetime.timedelta) -&gt; None</p></li>
+<li><p>wait(self: torch._C._distributed_c10d.Store, arg0: list[str], arg1: datetime.timedelta) -&gt; None</p></li>
</ol>
<p>Waits for each key in <code class="docutils literal notranslate"><span class="pre">keys</span></code> to be added to the store, and throws an exception
if the keys have not been set by the supplied <code class="docutils literal notranslate"><span class="pre">timeout</span></code>.</p>
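
For context on the two wait overloads documented above, a short usage sketch (the address, port, and world size are placeholders):

```python
import datetime
import torch.distributed as dist

# One process hosts the store; peers connect as clients.
store = dist.TCPStore("127.0.0.1", 29500, world_size=2, is_master=True)
store.set("ready", "1")

# Overload 1: block until every key exists, using the store-level timeout.
store.wait(["ready"])

# Overload 2: the same, with an explicit per-call timeout.
store.wait(["ready"], datetime.timedelta(seconds=30))
```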