Skip to content

Commit

Permalink
FIlter test plan by INPUT_SHAPES and KWARGS
Browse files Browse the repository at this point in the history
  • Loading branch information
vbrkicTT committed Nov 6, 2024
1 parent 11f0704 commit 77c847a
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions forge/test/operators/pytorch/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,10 @@ def build_filtered_collection(cls) -> TestCollection:
Query criterias are defined by the following environment variables:
- OPERATORS: List of operators to filter
- INPUT_SOURCES: List of input sources to filter
- INPUT_SHAPES: List of input shapes to filter
- DEV_DATA_FORMATS: List of data formats to filter
- MATH_FIDELITIES: List of math fidelities to filter
- KWARGS: Dictionary of kwargs to filter
"""
operators = os.getenv("OPERATORS", None)
if operators:
Expand All @@ -88,7 +90,9 @@ def build_filtered_collection(cls) -> TestCollection:
input_sources = input_sources.split(",")
input_sources = [getattr(InputSource, input_source) for input_source in input_sources]

# TODO INPUT_SHAPES
input_shapes = os.getenv("INPUT_SHAPES", None)
if input_shapes:
input_shapes = eval(input_shapes)

dev_data_formats = os.getenv("DEV_DATA_FORMATS", None)
if dev_data_formats:
Expand All @@ -100,13 +104,17 @@ def build_filtered_collection(cls) -> TestCollection:
math_fidelities = math_fidelities.split(",")
math_fidelities = [getattr(forge.MathFidelity, math_fidelity) for math_fidelity in math_fidelities]

# TODO KWARGS
kwargs = os.getenv("KWARGS", None)
if kwargs:
kwargs = eval(kwargs)

filtered_collection = TestCollection(
operators=operators,
input_sources=input_sources,
input_shapes=input_shapes,
dev_data_formats=dev_data_formats,
math_fidelities=math_fidelities,
kwargs=kwargs,
)

return filtered_collection
Expand Down

0 comments on commit 77c847a

Please sign in to comment.