-
Notifications
You must be signed in to change notification settings - Fork 202
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: Pull Request resolved: #169 Make d2go DETR exportable (torchscript compatible) Move generating masks to preprocessing Reviewed By: sstsai-adl Differential Revision: D33798073 fbshipit-source-id: d629b0c9cbdb67060982be717c7138a0e7e9adbc
- Loading branch information
1 parent
6791682
commit 5aadaaa
Showing
4 changed files
with
96 additions
and
60 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | ||
|
||
import unittest | ||
|
||
import torch | ||
from d2go.runner import create_runner | ||
from detr.util.misc import nested_tensor_from_tensor_list | ||
from fvcore.nn import flop_count_table, FlopCountAnalysis | ||
|
||
|
||
class Tester(unittest.TestCase): | ||
@staticmethod | ||
def _set_detr_cfg(cfg, enc_layers, dec_layers, num_queries, dim_feedforward): | ||
cfg.MODEL.META_ARCHITECTURE = "Detr" | ||
cfg.MODEL.DETR.NUM_OBJECT_QUERIES = num_queries | ||
cfg.MODEL.DETR.ENC_LAYERS = enc_layers | ||
cfg.MODEL.DETR.DEC_LAYERS = dec_layers | ||
cfg.MODEL.DETR.DEEP_SUPERVISION = False | ||
cfg.MODEL.DETR.DIM_FEEDFORWARD = dim_feedforward # 2048 | ||
|
||
def _assert_model_output(self, model, scripted_model): | ||
x = nested_tensor_from_tensor_list( | ||
[torch.rand(3, 200, 200), torch.rand(3, 200, 250)] | ||
) | ||
out = model(x) | ||
out_script = scripted_model(x) | ||
self.assertTrue(out["pred_logits"].equal(out_script["pred_logits"])) | ||
self.assertTrue(out["pred_boxes"].equal(out_script["pred_boxes"])) | ||
|
||
def test_detr_res50_export(self): | ||
runner = create_runner("d2go.projects.detr.runner.DETRRunner") | ||
cfg = runner.get_default_cfg() | ||
cfg.MODEL.DEVICE = "cpu" | ||
# DETR | ||
self._set_detr_cfg(cfg, 6, 6, 100, 2048) | ||
# backbone | ||
cfg.MODEL.BACKBONE.NAME = "build_resnet_backbone" | ||
cfg.MODEL.RESNETS.DEPTH = 50 | ||
cfg.MODEL.RESNETS.STRIDE_IN_1X1 = False | ||
cfg.MODEL.RESNETS.OUT_FEATURES = ["res2", "res3", "res4", "res5"] | ||
# build model | ||
model = runner.build_model(cfg).eval() | ||
model = model.detr | ||
scripted_model = torch.jit.script(model) | ||
self._assert_model_output(model, scripted_model) | ||
|
||
def test_detr_fbnet_export(self): | ||
runner = create_runner("d2go.projects.detr.runner.DETRRunner") | ||
cfg = runner.get_default_cfg() | ||
cfg.MODEL.DEVICE = "cpu" | ||
# DETR | ||
self._set_detr_cfg(cfg, 3, 3, 50, 256) | ||
# backbone | ||
cfg.MODEL.BACKBONE.NAME = "FBNetV2C4Backbone" | ||
cfg.MODEL.FBNET_V2.ARCH = "FBNetV3_A_dsmask_C5" | ||
cfg.MODEL.FBNET_V2.WIDTH_DIVISOR = 8 | ||
cfg.MODEL.FBNET_V2.OUT_FEATURES = ["trunk4"] | ||
# build model | ||
model = runner.build_model(cfg).eval() | ||
model = model.detr | ||
print(model) | ||
scripted_model = torch.jit.script(model) | ||
self._assert_model_output(model, scripted_model) | ||
# print flops | ||
table = flop_count_table(FlopCountAnalysis(model, ([torch.rand(3, 224, 320)],))) | ||
print(table) |