From 08a662e2c64f071af1a7ca520968ae10b40a1eaf Mon Sep 17 00:00:00 2001 From: Billy Lee Date: Wed, 12 Apr 2023 13:52:00 -0400 Subject: [PATCH] added speech2text model support for BT --- docs/source/bettertransformer/overview.mdx | 1 + optimum/bettertransformer/models/__init__.py | 1 + tests/bettertransformer/testing_utils.py | 1 + 3 files changed, 3 insertions(+) diff --git a/docs/source/bettertransformer/overview.mdx b/docs/source/bettertransformer/overview.mdx index 86f3a65d3b..42ad35a80d 100644 --- a/docs/source/bettertransformer/overview.mdx +++ b/docs/source/bettertransformer/overview.mdx @@ -57,6 +57,7 @@ The list of supported model below: - [RoBERTa](https://arxiv.org/abs/1907.11692) - [RoCBert](https://aclanthology.org/2022.acl-long.65.pdf) - [RoFormer](https://arxiv.org/abs/2104.09864) +- [Speech2Text](https://arxiv.org/abs/2010.05171) - [Splinter](https://arxiv.org/abs/2101.00438) - [Tapas](https://arxiv.org/abs/2211.06550) - [ViLT](https://arxiv.org/abs/2102.03334) diff --git a/optimum/bettertransformer/models/__init__.py b/optimum/bettertransformer/models/__init__.py index 5ebd9255af..19783162c6 100644 --- a/optimum/bettertransformer/models/__init__.py +++ b/optimum/bettertransformer/models/__init__.py @@ -85,6 +85,7 @@ class BetterTransformerManager: "roberta": {"RobertaLayer": BertLayerBetterTransformer}, "roc_bert": {"RoCBertLayer": BertLayerBetterTransformer}, "roformer": {"RoFormerLayer": BertLayerBetterTransformer}, + "speech2text": {"Speech2TextEncoderLayer": MBartEncoderLayerBetterTransformer}, "splinter": {"SplinterLayer": BertLayerBetterTransformer}, "tapas": {"TapasLayer": BertLayerBetterTransformer}, "t5": {"T5Attention": T5AttentionLayerBetterTransformer}, diff --git a/tests/bettertransformer/testing_utils.py b/tests/bettertransformer/testing_utils.py index fd1ff135a3..e3c6d60e6f 100644 --- a/tests/bettertransformer/testing_utils.py +++ b/tests/bettertransformer/testing_utils.py @@ -59,6 +59,7 @@ "roberta": "hf-internal-testing/tiny-random-RobertaModel", "rocbert": "hf-internal-testing/tiny-random-RoCBertModel", "roformer": "hf-internal-testing/tiny-random-RoFormerModel", + "speech2text": "hf-internal-testing/tiny-random-Speech2TextModel", "splinter": "hf-internal-testing/tiny-random-SplinterModel", "tapas": "hf-internal-testing/tiny-random-TapasModel", "t5": "hf-internal-testing/tiny-random-t5",