Skip to content

Commit

Permalink
DETR Model Export
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #169

Make d2go DETR exportable (TorchScript-compatible).
Move padding-mask generation into preprocessing.

Reviewed By: sstsai-adl

Differential Revision: D33798073

fbshipit-source-id: d629b0c9cbdb67060982be717c7138a0e7e9adbc
  • Loading branch information
zhanghang1989 authored and facebook-github-bot committed Feb 7, 2022
1 parent 6791682 commit 5aadaaa
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 60 deletions.
21 changes: 17 additions & 4 deletions projects_oss/detr/detr/d2/detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,10 @@ def forward(self, batched_inputs):
dict[str: Tensor]:
mapping from a named loss to a tensor storing the loss. Used during training only.
"""
images = self.preprocess_image(batched_inputs)
output = self.detr(images)
images_lists = self.preprocess_image(batched_inputs)
# Convert the ImageList into a NestedTensor (batched tensor plus mask).
nested_images = self.imagelist_to_nestedtensor(images_lists)
output = self.detr(nested_images)

if self.training:
gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
Expand All @@ -144,10 +146,12 @@ def forward(self, batched_inputs):
box_cls = output["pred_logits"]
box_pred = output["pred_boxes"]
mask_pred = output["pred_masks"] if self.mask_on else None
results = self.inference(box_cls, box_pred, mask_pred, images.image_sizes)
results = self.inference(
box_cls, box_pred, mask_pred, images_lists.image_sizes
)
processed_results = []
for results_per_image, input_per_image, image_size in zip(
results, batched_inputs, images.image_sizes
results, batched_inputs, images_lists.image_sizes
):
height = input_per_image.get("height", image_size[0])
width = input_per_image.get("width", image_size[1])
Expand Down Expand Up @@ -239,3 +243,12 @@ def preprocess_image(self, batched_inputs):
images = [self.normalizer(x["image"].to(self.device)) for x in batched_inputs]
images = ImageList.from_tensors(images)
return images

def imagelist_to_nestedtensor(self, images):
    """Wrap an image list into a ``NestedTensor`` with a per-pixel mask.

    Args:
        images: object exposing ``tensor`` (a 4-D batch, unpacked here as
            ``(N, C, H, W)``) and ``image_sizes`` (an iterable of ``(h, w)``
            pairs, one per image).

    Returns:
        ``NestedTensor(images.tensor, mask)`` where ``mask`` is a boolean
        ``(N, H, W)`` tensor on the same device as the batch. It is ``False``
        inside the top-left ``h x w`` window of each image and ``True``
        everywhere else (presumably the padded area -- confirm against the
        producer of ``images``).
    """
    batch = images.tensor
    num_images, _, max_h, max_w = batch.shape
    # Start with everything flagged, then clear the region each image covers.
    padding_mask = torch.ones(
        (num_images, max_h, max_w), dtype=torch.bool, device=batch.device
    )
    for img_idx, (img_h, img_w) in enumerate(images.image_sizes):
        padding_mask[img_idx, :img_h, :img_w] = False
    return NestedTensor(batch, padding_mask)
67 changes: 12 additions & 55 deletions projects_oss/detr/detr/models/build.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from typing import Dict

import numpy as np
import torch
import torch.nn.functional as F
from detectron2.modeling import build_backbone
from detectron2.utils.registry import Registry
from detr.models.backbone import Joiner
Expand Down Expand Up @@ -59,35 +62,15 @@ def __init__(self, cfg):
self.feature_strides = [backbone_shape[f].stride for f in backbone_shape.keys()]
self.num_channels = [backbone_shape[k].channels for k in backbone_shape.keys()]

def forward(self, images):
    """Run the backbone and pair every feature level with its padding mask.

    Args:
        images: object exposing ``tensor`` (the batched input) and
            ``image_sizes`` (per-image ``(h, w)`` pairs).

    Returns:
        The backbone's own feature mapping, with each value replaced in
        place by ``NestedTensor(feature, mask)``.
    """
    features = self.backbone(images.tensor)
    level_shapes = [level.shape for level in features.values()]
    # One mask per feature level; each mask has shape (B, maxH, maxW).
    masks = self.mask_out_padding(
        level_shapes, images.image_sizes, images.tensor.device
    )
    assert len(features) == len(masks)
    # Snapshot the keys so the values can be swapped while walking the dict.
    for name, level_mask in zip(list(features.keys()), masks):
        features[name] = NestedTensor(features[name], level_mask)
    return features

def mask_out_padding(self, feature_shapes, image_sizes, device):
    """Build one boolean padding mask per feature level.

    Args:
        feature_shapes: one 4-element shape ``(N, C, H, W)`` per feature
            level; must have as many entries as ``self.feature_strides``.
        image_sizes: per-image ``(h, w)`` sizes in input pixels.
        device: device on which the masks are allocated.

    Returns:
        A list of ``(N, H, W)`` bool tensors. In each, the top-left
        ``ceil(h / stride) x ceil(w / stride)`` window of every image is
        ``False`` and the remainder is ``True``.
    """
    assert len(feature_shapes) == len(self.feature_strides)
    per_level_masks = []
    for level, (num_images, _, feat_h, feat_w) in enumerate(feature_shapes):
        stride = self.feature_strides[level]
        level_mask = torch.ones(
            (num_images, feat_h, feat_w), dtype=torch.bool, device=device
        )
        for img_idx, (img_h, img_w) in enumerate(image_sizes):
            # Round up so a partially covered feature cell still counts as valid.
            valid_h = int(np.ceil(float(img_h) / stride))
            valid_w = int(np.ceil(float(img_w) / stride))
            level_mask[img_idx, :valid_h, :valid_w] = 0
        per_level_masks.append(level_mask)
    return per_level_masks
def forward(self, tensor_list: NestedTensor):
    """Run the backbone and resize the input mask to every feature level.

    Args:
        tensor_list: nested tensor whose ``tensors`` attribute is fed to the
            backbone and whose ``mask`` must not be ``None``.

    Returns:
        Mapping from feature name to ``NestedTensor(feature, mask)``, where
        ``mask`` is the input mask interpolated to the feature's last two
        dimensions and cast back to bool.
    """
    features = self.backbone(tensor_list.tensors)
    out: Dict[str, NestedTensor] = {}
    for name, feature in features.items():
        input_mask = tensor_list.mask
        # Kept inside the loop, next to the use, so the non-None check
        # precedes the indexing below exactly as before.
        assert input_mask is not None
        # interpolate needs a float tensor with a leading batch dimension.
        resized = F.interpolate(input_mask[None].float(), size=feature.shape[-2:])
        level_mask = resized.to(torch.bool)[0]
        out[name] = NestedTensor(feature, level_mask)
    return out


class FBNetMaskedBackbone(ResNetMaskedBackbone):
Expand All @@ -105,20 +88,6 @@ def __init__(self, cfg):
self.backbone._out_feature_strides[k] for k in self.out_features
]

def forward(self, images):
    """Run the backbone and keep only the configured output features.

    Args:
        images: object exposing ``tensor`` (the batched input) and
            ``image_sizes`` (per-image ``(h, w)`` pairs).

    Returns:
        A new dict mapping each feature name found in ``self.out_features``
        to ``NestedTensor(feature, mask)``.
    """
    features = self.backbone(images.tensor)
    level_shapes = [level.shape for level in features.values()]
    masks = self.mask_out_padding(
        level_shapes, images.image_sizes, images.tensor.device
    )
    assert len(features) == len(masks)
    # Masks are positionally aligned with the backbone's feature order.
    return {
        name: NestedTensor(feature, level_mask)
        for (name, feature), level_mask in zip(features.items(), masks)
        if name in self.out_features
    }


class SimpleSingleStageBackbone(ResNetMaskedBackbone):
"""This is a simple wrapper for single stage backbone,
Expand All @@ -135,15 +104,3 @@ def __init__(self, cfg):
self.feature_strides = [cfg.MODEL.BACKBONE.STRIDE]
self.num_channels = [cfg.MODEL.BACKBONE.CHANNEL]
self.strides = [cfg.MODEL.BACKBONE.STRIDE]

def forward(self, images):
    """Run the single-stage backbone and attach its padding mask.

    Args:
        images: object exposing ``tensor`` (the batched input) and
            ``image_sizes`` (per-image ``(h, w)`` pairs).

    Returns:
        A one-entry dict keyed by ``self.out_features[0]`` holding
        ``NestedTensor(feature, mask)``.
    """
    feature = self.backbone(images.tensor)
    masks = self.mask_out_padding(
        [feature.shape], images.image_sizes, images.tensor.device
    )
    # A single feature level must yield exactly one mask.
    assert len(masks) == 1
    return {self.out_features[0]: NestedTensor(feature, masks[0])}
1 change: 0 additions & 1 deletion projects_oss/detr/test_all.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import io
import unittest
from typing import List
Expand Down
67 changes: 67 additions & 0 deletions projects_oss/detr/test_detr_export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import unittest

import torch
from d2go.runner import create_runner
from detr.util.misc import nested_tensor_from_tensor_list
from fvcore.nn import flop_count_table, FlopCountAnalysis


class Tester(unittest.TestCase):
    """Checks that the DETR module scripts and matches its eager outputs."""

    @staticmethod
    def _set_detr_cfg(cfg, enc_layers, dec_layers, num_queries, dim_feedforward):
        """Select the ``Detr`` meta-architecture and size its transformer."""
        cfg.MODEL.META_ARCHITECTURE = "Detr"
        detr_cfg = cfg.MODEL.DETR
        detr_cfg.NUM_OBJECT_QUERIES = num_queries
        detr_cfg.ENC_LAYERS = enc_layers
        detr_cfg.DEC_LAYERS = dec_layers
        detr_cfg.DEEP_SUPERVISION = False
        detr_cfg.DIM_FEEDFORWARD = dim_feedforward  # 2048

    def _assert_model_output(self, model, scripted_model):
        """Require exactly equal logits and boxes from eager and scripted runs."""
        # Two differently sized images, so the nested tensor contains padding.
        samples = [torch.rand(3, 200, 200), torch.rand(3, 200, 250)]
        batch = nested_tensor_from_tensor_list(samples)
        eager_out = model(batch)
        scripted_out = scripted_model(batch)
        for key in ("pred_logits", "pred_boxes"):
            self.assertTrue(torch.equal(eager_out[key], scripted_out[key]))

    def test_detr_res50_export(self):
        """Script DETR on a ResNet-50 backbone with four output stages."""
        runner = create_runner("d2go.projects.detr.runner.DETRRunner")
        cfg = runner.get_default_cfg()
        cfg.MODEL.DEVICE = "cpu"
        # DETR
        self._set_detr_cfg(cfg, 6, 6, 100, 2048)
        # backbone
        cfg.MODEL.BACKBONE.NAME = "build_resnet_backbone"
        cfg.MODEL.RESNETS.DEPTH = 50
        cfg.MODEL.RESNETS.STRIDE_IN_1X1 = False
        cfg.MODEL.RESNETS.OUT_FEATURES = ["res2", "res3", "res4", "res5"]
        # build model; only the inner DETR module is scripted
        detr = runner.build_model(cfg).eval().detr
        scripted_detr = torch.jit.script(detr)
        self._assert_model_output(detr, scripted_detr)

    def test_detr_fbnet_export(self):
        """Script DETR on an FBNet backbone and print its FLOP table."""
        runner = create_runner("d2go.projects.detr.runner.DETRRunner")
        cfg = runner.get_default_cfg()
        cfg.MODEL.DEVICE = "cpu"
        # DETR
        self._set_detr_cfg(cfg, 3, 3, 50, 256)
        # backbone
        cfg.MODEL.BACKBONE.NAME = "FBNetV2C4Backbone"
        cfg.MODEL.FBNET_V2.ARCH = "FBNetV3_A_dsmask_C5"
        cfg.MODEL.FBNET_V2.WIDTH_DIVISOR = 8
        cfg.MODEL.FBNET_V2.OUT_FEATURES = ["trunk4"]
        # build model; only the inner DETR module is scripted
        detr = runner.build_model(cfg).eval().detr
        print(detr)
        scripted_detr = torch.jit.script(detr)
        self._assert_model_output(detr, scripted_detr)
        # print flops
        flop_inputs = ([torch.rand(3, 224, 320)],)
        table = flop_count_table(FlopCountAnalysis(detr, flop_inputs))
        print(table)

0 comments on commit 5aadaaa

Please sign in to comment.