Commit 7181aa6

deploy: da05f77

github-actions[bot] committed Dec 6, 2023
1 parent 66b3fc9
Showing 5 changed files with 83 additions and 9 deletions.
41 changes: 39 additions & 2 deletions api/_modules/opacus/layers/dp_multihead_attention.html
@@ -124,17 +124,21 @@ Source code for opacus.layers.dp_multihead_attention
<span class="n">add_zero_attn</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">kdim</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">vdim</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">batch_first</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="p">):</span>
<span class="nb">super</span><span class="p">(</span><span class="n">DPMultiheadAttention</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">embed_dim</span> <span class="o">=</span> <span class="n">embed_dim</span>
<span class="bp">self</span><span class="o">.</span><span class="n">kdim</span> <span class="o">=</span> <span class="n">kdim</span> <span class="k">if</span> <span class="n">kdim</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">embed_dim</span>
<span class="bp">self</span><span class="o">.</span><span class="n">vdim</span> <span class="o">=</span> <span class="n">vdim</span> <span class="k">if</span> <span class="n">vdim</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">embed_dim</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_qkv_same_embed_dim</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">kdim</span> <span class="o">==</span> <span class="n">embed_dim</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">vdim</span> <span class="o">==</span> <span class="n">embed_dim</span>

<span class="c1"># when self._qkv_same_embed_dim = True, "in_proj_weight" rather than "q,k,v_weight" and fast path calculation will be used in "nn.transformer", which should be avoided. This is why we force self._qkv_same_embed_dim = False.</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_qkv_same_embed_dim</span> <span class="o">=</span> <span class="kc">False</span>

<span class="bp">self</span><span class="o">.</span><span class="n">num_heads</span> <span class="o">=</span> <span class="n">num_heads</span>
<span class="bp">self</span><span class="o">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">dropout</span>
<span class="bp">self</span><span class="o">.</span><span class="n">batch_first</span> <span class="o">=</span> <span class="n">batch_first</span>
<span class="bp">self</span><span class="o">.</span><span class="n">head_dim</span> <span class="o">=</span> <span class="n">embed_dim</span> <span class="o">//</span> <span class="n">num_heads</span>
<span class="k">assert</span> <span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">head_dim</span> <span class="o">*</span> <span class="n">num_heads</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">embed_dim</span>
@@ -155,6 +159,10 @@

<span class="bp">self</span><span class="o">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout</span><span class="p">)</span>

<span class="c1"># to avoid null pointers in Transformer.forward</span>
<span class="bp">self</span><span class="o">.</span><span class="n">in_proj_weight</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">in_proj_bias</span> <span class="o">=</span> <span class="kc">None</span>

<div class="viewcode-block" id="DPMultiheadAttention.load_state_dict">
<a class="viewcode-back" href="../../../dp_multihead_attention.html#opacus.layers.dp_multihead_attention.DPMultiheadAttention.load_state_dict">[docs]</a>
<span class="k">def</span> <span class="nf">load_state_dict</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">state_dict</span><span class="p">):</span>
@@ -218,7 +226,33 @@
<span class="n">key_padding_mask</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">need_weights</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="n">attn_mask</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">is_causal</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="p">):</span>
<span class="n">is_batched</span> <span class="o">=</span> <span class="n">query</span><span class="o">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">3</span>

<span class="k">assert</span> <span class="n">is_batched</span> <span class="o">==</span> <span class="kc">True</span><span class="p">,</span> <span class="s2">"The query must have a dimension of 3."</span>

<span class="w"> </span><span class="sa">r</span><span class="sd">"""</span>
<span class="sd"> As per https://github.com/pytorch/opacus/issues/596, we have to include ``is_causal`` as a dummy parameter of the function,</span>
<span class="sd"> since it is used in the ``forward`` function of parent class ``nn.TransformerEncoderLayer``.</span>
<span class="sd"> """</span>
<span class="k">assert</span> <span class="p">(</span>
<span class="n">is_causal</span> <span class="o">==</span> <span class="kc">False</span>
<span class="p">),</span> <span class="s2">"We currently do not support causal mask. Will fix it in the future."</span>

<span class="w"> </span><span class="sa">r</span><span class="sd">"""</span>
<span class="sd"> Using the same logic with ``nn.MultiheadAttention`` (https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html).</span>
<span class="sd"> """</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_first</span><span class="p">:</span>
<span class="k">if</span> <span class="n">key</span> <span class="ow">is</span> <span class="n">value</span><span class="p">:</span>
<span class="k">if</span> <span class="n">query</span> <span class="ow">is</span> <span class="n">key</span><span class="p">:</span>
<span class="n">query</span> <span class="o">=</span> <span class="n">key</span> <span class="o">=</span> <span class="n">value</span> <span class="o">=</span> <span class="n">query</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">query</span><span class="p">,</span> <span class="n">key</span> <span class="o">=</span> <span class="p">[</span><span class="n">x</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="p">(</span><span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">)]</span>
<span class="n">value</span> <span class="o">=</span> <span class="n">key</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span> <span class="o">=</span> <span class="p">[</span><span class="n">x</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="p">(</span><span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span><span class="p">)]</span>

<span class="n">tgt_len</span><span class="p">,</span> <span class="n">bsz</span><span class="p">,</span> <span class="n">embed_dim</span> <span class="o">=</span> <span class="n">query</span><span class="o">.</span><span class="n">size</span><span class="p">()</span>
<span class="k">if</span> <span class="n">embed_dim</span> <span class="o">!=</span> <span class="bp">self</span><span class="o">.</span><span class="n">embed_dim</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
@@ -363,6 +397,9 @@
<span class="p">)</span>
<span class="n">attn_output</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">out_proj</span><span class="p">(</span><span class="n">attn_output</span><span class="p">)</span>

<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_first</span><span class="p">:</span>
<span class="n">attn_output</span> <span class="o">=</span> <span class="n">attn_output</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>

<span class="k">if</span> <span class="n">need_weights</span><span class="p">:</span>
<span class="c1"># average attention weights over heads</span>
<span class="n">attn_output_weights</span> <span class="o">=</span> <span class="n">attn_output_weights</span><span class="o">.</span><span class="n">view</span><span class="p">(</span>
@@ -404,7 +441,7 @@
<span class="n">keep_vars</span><span class="o">=</span><span class="n">keep_vars</span><span class="p">,</span>
<span class="p">)</span>

<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_qkv_same_embed_dim</span><span class="p">:</span>
<span class="k">if</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">kdim</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">embed_dim</span><span class="p">)</span> <span class="ow">and</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">vdim</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">embed_dim</span><span class="p">):</span>
<span class="n">destination_alter</span><span class="p">[</span><span class="n">prefix</span> <span class="o">+</span> <span class="s2">"in_proj_weight"</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span>
<span class="p">(</span>
<span class="n">destination</span><span class="p">[</span><span class="n">prefix</span> <span class="o">+</span> <span class="s2">"qlinear.weight"</span><span class="p">],</span>