Skip to content

Commit

Permalink
Filter test plan by INPUT_SHAPES and KWARGS
Browse files Browse the repository at this point in the history
  • Loading branch information
vbrkicTT committed Nov 7, 2024
1 parent f0f64e6 commit 302cc32
Showing 1 changed file with 13 additions and 8 deletions.
21 changes: 13 additions & 8 deletions forge/test/operators/pytorch/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,6 @@


# Examples
# pytest -svv forge/test/operators/pytorch/test_all.py::test_unique --collect-only
# TEST_ID='no_device-ge-FROM_HOST-None-(1, 2, 3, 4)-Float16_b-HiFi4' pytest -svv forge/test/operators/pytorch/test_all.py::test_single
# OPERATORS=add,div FILTERS=HAS_DATA_FORMAT,QUICK DEV_DATA_FORMATS=Float16_b,Int8 MATH_FIDELITIES=HiFi4,HiFi3 RANGE=5 pytest -svv forge/test/operators/pytorch/test_all.py::test_query --collect-only
# OPERATORS=add,div FAILING_REASONS=DATA_MISMATCH,UNSUPPORTED_DATA_FORMAT SKIP_REASONS=FATAL_ERROR RANGE=5 pytest -svv forge/test/operators/pytorch/test_all.py::test_query --collect-only
# FAILING_REASONS=NOT_IMPLEMENTED INPUT_SOURCES=FROM_HOST pytest -svv forge/test/operators/pytorch/test_all.py::test_query --collect-only

# pytest -svv forge/test/operators/pytorch/test_all.py::test_plan
# pytest -svv forge/test/operators/pytorch/test_all.py::test_failed
# pytest -svv forge/test/operators/pytorch/test_all.py::test_skipped
Expand All @@ -20,6 +14,9 @@
# pytest -svv forge/test/operators/pytorch/test_all.py::test_data_mismatch
# pytest -svv forge/test/operators/pytorch/test_all.py::test_unsupported_df
# pytest -svv forge/test/operators/pytorch/test_all.py::test_custom
# pytest -svv forge/test/operators/pytorch/test_all.py::test_query
# pytest -svv forge/test/operators/pytorch/test_all.py::test_unique
# pytest -svv forge/test/operators/pytorch/test_all.py::test_single


import os
Expand Down Expand Up @@ -74,8 +71,10 @@ def build_filtered_collection(cls) -> TestCollection:
Query criterias are defined by the following environment variables:
- OPERATORS: List of operators to filter
- INPUT_SOURCES: List of input sources to filter
- INPUT_SHAPES: List of input shapes to filter
- DEV_DATA_FORMATS: List of data formats to filter
- MATH_FIDELITIES: List of math fidelities to filter
- KWARGS: List of kwargs dictionaries to filter
"""
operators = os.getenv("OPERATORS", None)
if operators:
Expand All @@ -88,7 +87,9 @@ def build_filtered_collection(cls) -> TestCollection:
input_sources = input_sources.split(",")
input_sources = [getattr(InputSource, input_source) for input_source in input_sources]

# TODO INPUT_SHAPES
input_shapes = os.getenv("INPUT_SHAPES", None)
if input_shapes:
input_shapes = eval(input_shapes)

dev_data_formats = os.getenv("DEV_DATA_FORMATS", None)
if dev_data_formats:
Expand All @@ -100,13 +101,17 @@ def build_filtered_collection(cls) -> TestCollection:
math_fidelities = math_fidelities.split(",")
math_fidelities = [getattr(forge.MathFidelity, math_fidelity) for math_fidelity in math_fidelities]

# TODO KWARGS
kwargs = os.getenv("KWARGS", None)
if kwargs:
kwargs = eval(kwargs)

filtered_collection = TestCollection(
operators=operators,
input_sources=input_sources,
input_shapes=input_shapes,
dev_data_formats=dev_data_formats,
math_fidelities=math_fidelities,
kwargs=kwargs,
)

return filtered_collection
Expand Down

0 comments on commit 302cc32

Please sign in to comment.