From fb503c211deb3a8a4c8a7e047d7e5994d4fb5e4a Mon Sep 17 00:00:00 2001 From: Vladimir Brkic Date: Wed, 6 Nov 2024 08:43:20 +0000 Subject: [PATCH] Filter test plan by INPUT_SHAPES and KWARGS --- forge/test/operators/pytorch/test_all.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/forge/test/operators/pytorch/test_all.py b/forge/test/operators/pytorch/test_all.py index cc65a30c..15b7428a 100644 --- a/forge/test/operators/pytorch/test_all.py +++ b/forge/test/operators/pytorch/test_all.py @@ -6,12 +6,6 @@ # Examples -# pytest -svv forge/test/operators/pytorch/test_all.py::test_unique --collect-only -# TEST_ID='no_device-ge-FROM_HOST-None-(1, 2, 3, 4)-Float16_b-HiFi4' pytest -svv forge/test/operators/pytorch/test_all.py::test_single -# OPERATORS=add,div FILTERS=HAS_DATA_FORMAT,QUICK DEV_DATA_FORMATS=Float16_b,Int8 MATH_FIDELITIES=HiFi4,HiFi3 RANGE=5 pytest -svv forge/test/operators/pytorch/test_all.py::test_query --collect-only -# OPERATORS=add,div FAILING_REASONS=DATA_MISMATCH,UNSUPPORTED_DATA_FORMAT SKIP_REASONS=FATAL_ERROR RANGE=5 pytest -svv forge/test/operators/pytorch/test_all.py::test_query --collect-only -# FAILING_REASONS=NOT_IMPLEMENTED INPUT_SOURCES=FROM_HOST pytest -svv forge/test/operators/pytorch/test_all.py::test_query --collect-only - # pytest -svv forge/test/operators/pytorch/test_all.py::test_plan # pytest -svv forge/test/operators/pytorch/test_all.py::test_failed # pytest -svv forge/test/operators/pytorch/test_all.py::test_skipped @@ -20,6 +14,9 @@ # pytest -svv forge/test/operators/pytorch/test_all.py::test_data_mismatch # pytest -svv forge/test/operators/pytorch/test_all.py::test_unsupported_df # pytest -svv forge/test/operators/pytorch/test_all.py::test_custom +# pytest -svv forge/test/operators/pytorch/test_all.py::test_query +# pytest -svv forge/test/operators/pytorch/test_all.py::test_unique +# pytest -svv forge/test/operators/pytorch/test_all.py::test_single import os @@ -74,8 +71,10 @@ def build_filtered_collection(cls) -> TestCollection: Query criterias are defined by the following environment variables: - OPERATORS: List of operators to filter - INPUT_SOURCES: List of input sources to filter + - INPUT_SHAPES: List of input shapes to filter - DEV_DATA_FORMATS: List of data formats to filter - MATH_FIDELITIES: List of math fidelities to filter + - KWARGS: List of kwargs dictionaries to filter """ operators = os.getenv("OPERATORS", None) if operators: @@ -88,7 +87,9 @@ def build_filtered_collection(cls) -> TestCollection: input_sources = input_sources.split(",") input_sources = [getattr(InputSource, input_source) for input_source in input_sources] - # TODO INPUT_SHAPES + input_shapes = os.getenv("INPUT_SHAPES", None) + if input_shapes: + input_shapes = eval(input_shapes) dev_data_formats = os.getenv("DEV_DATA_FORMATS", None) if dev_data_formats: @@ -100,13 +101,17 @@ def build_filtered_collection(cls) -> TestCollection: math_fidelities = math_fidelities.split(",") math_fidelities = [getattr(forge.MathFidelity, math_fidelity) for math_fidelity in math_fidelities] - # TODO KWARGS + kwargs = os.getenv("KWARGS", None) + if kwargs: + kwargs = eval(kwargs) filtered_collection = TestCollection( operators=operators, input_sources=input_sources, + input_shapes=input_shapes, dev_data_formats=dev_data_formats, math_fidelities=math_fidelities, + kwargs=kwargs, ) return filtered_collection