
How to deploy Vision Transformer with ANE to Achieve Faster Uncached Load Speed #2300

Open · cvv-student opened this issue Aug 8, 2024 · 0 comments
Labels: question (Response providing clarification needed. Will not be assigned to a release.)
❓Question

I want to deploy some ViT models on an iPhone. I followed https://machinelearning.apple.com/research/vision-transformers and wrote a simple demo based on the code from https://github.com/apple/ml-vision-transformers-ane. However, the uncached load time on the phone is very long. As the blog post recommends, the inputs are already 64-byte aligned, but loading is still very slow. Is there any way to speed it up? This is my test case:

import torch
import coremltools as ct
import math
from torch import nn


class SelfAttn(torch.nn.Module):
    """Self-attention in the ANE-friendly layout: 1x1 convolutions stand in
    for the linear Q/K/V projections so activations stay channels-first."""

    def __init__(self, window_size, num_heads, dim, dim_out):
        super().__init__()
        self.window_size = window_size
        self.num_heads = num_heads
        self.dim = dim
        self.dim_out = dim_out
        self.q_proj = nn.Conv2d(
            in_channels=dim,
            out_channels=dim_out,
            kernel_size=1,
        )
        self.k_proj = nn.Conv2d(
            in_channels=dim,
            out_channels=dim_out,
            kernel_size=1,
        )
        self.v_proj = nn.Conv2d(
            in_channels=dim,
            out_channels=dim_out,
            kernel_size=1,
        )

    def forward(self, x):
        B, HW, C = x.shape
        image_shape = (B, C, self.window_size, self.window_size)
        x_2d = x.permute((0, 2, 1)).reshape(image_shape)  # BCHW
        x_flat = torch.unsqueeze(x.permute((0, 2, 1)), 2)  # BC1L, sequence on the last axis
        q, k, v_2d = self.q_proj(x_flat), self.k_proj(x_flat), self.v_proj(x_2d)

        # Split the projections into per-head chunks along the channel axis.
        mh_q = torch.split(q, self.dim_out // self.num_heads, dim=1)  # BC1L
        mh_v = torch.split(
            v_2d.reshape(B, -1, x_flat.shape[2], x_flat.shape[3]),
            self.dim_out // self.num_heads,
            dim=1,
        )
        mh_k = torch.split(
            torch.permute(k, (0, 3, 2, 1)), self.dim_out // self.num_heads, dim=3
        )
        scale_factor = 1 / math.sqrt(mh_q[0].size(1))
        # Q^T K per head: (B, C, 1, L) x (B, L, 1, C) -> (B, L, 1, L)
        attn_weights = [
            torch.einsum("bchq, bkhc->bkhq", qi, ki) * scale_factor
            for qi, ki in zip(mh_q, mh_k)
        ]
        # Softmax over the key positions, which sit on the channel axis
        # (dim=1) in this layout.
        attn_weights = [torch.softmax(aw, dim=1) for aw in attn_weights]
        mh_x = [
            torch.einsum("bkhq,bchk->bchq", wi, vi)
            for wi, vi in zip(attn_weights, mh_v)
        ]
        x = torch.cat(mh_x, dim=1)
        return x


window_size = 8
patch_batch = 1024  # number of 8x8 windows, used as the batch axis
emb_dim = 96
emb_dim_out = 96
x = torch.rand(patch_batch, window_size * window_size, emb_dim)
qkv_layer = SelfAttn(window_size, 1, emb_dim, emb_dim_out)
jit = torch.jit.trace(qkv_layer.eval(), (x,))

mlmod_fixed_shape = ct.convert(
    jit,
    inputs=[
        ct.TensorType(name="x", shape=x.shape),
    ],
    convert_to="mlprogram",
)
mlmodel_path = "test_ane.mlpackage"
mlmod_fixed_shape.save(mlmodel_path)
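
For reference, here is a variant of the conversion that pins the compute units; compute_units and minimum_deployment_target are standard ct.convert options, but the iOS 16 floor is just my guess:

# Same conversion, restricted to CPU + Neural Engine. Whether this changes
# the uncached load time is exactly what I am unsure about.
mlmod_ane = ct.convert(
    jit,
    inputs=[ct.TensorType(name="x", shape=x.shape)],
    convert_to="mlprogram",
    compute_units=ct.ComputeUnit.CPU_AND_NE,
    minimum_deployment_target=ct.target.iOS16,  # assumption: iOS 16 floor
)
mlmod_ane.save("test_ane_cpu_ne.mlpackage")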

These are my profiler results:
[screenshot: Core ML performance report, 2024-08-08]
The uncached load took nearly 36 seconds, and the model is essentially just a handful of matrix multiplications.
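
For anyone trying to reproduce this, here is a minimal sketch of how the load can be timed from Python on a Mac; the 36 s figure above comes from the profiler, and CPU_AND_NE is my assumption about the intended compute units:

import time

import coremltools as ct

# The first MLModel construction compiles the package (the uncached load);
# running the script a second time exercises the cached path.
start = time.perf_counter()
mlmod = ct.models.MLModel(
    "test_ane.mlpackage",
    compute_units=ct.ComputeUnit.CPU_AND_NE,
)
print(f"load time: {time.perf_counter() - start:.1f} s")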
