Enable RGG
Issue #313
vbrkicTT committed Sep 27, 2024
1 parent f001bcf commit e3df179
Showing 20 changed files with 2,286 additions and 0 deletions.
4 changes: 4 additions & 0 deletions forge/forge/op_repo/__init__.py
@@ -15,6 +15,8 @@
from .datatypes import OperandNumInt, OperandNumTuple, OperandNumRange
from .datatypes import TensorShape, OperatorParam, OperatorParamNumber, OperatorDefinition, OperatorRepository
from .datatypes import ShapeCalculationContext
from .pybuda_operators import pybuda_operator_repository
from .pytorch_operators import pytorch_operator_repository

__all__ = [
"OperandNumInt",
@@ -26,4 +28,6 @@
"OperatorDefinition",
"OperatorRepository",
"ShapeCalculationContext",
"pybuda_operator_repository",
"pytorch_operator_repository",
]
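
These exports let downstream code import both repositories directly from the package; a minimal usage sketch (the module path is inferred from the file location in this diff):

from forge.op_repo import pybuda_operator_repository, pytorch_operator_repository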
81 changes: 81 additions & 0 deletions forge/forge/op_repo/pybuda_operators.py
@@ -0,0 +1,81 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC

# SPDX-License-Identifier: Apache-2.0
# PyBuda repository operators


from .datatypes import OperatorDefinition, OperatorRepository
from .datatypes import OperatorParamNumber


# TODO describe operand and shapes
_OPERATORS = [

# Unary operators
OperatorDefinition("exp", "forge.op.Exp", 1),
OperatorDefinition("reciprocal", "forge.op.Reciprocal", 1),
OperatorDefinition("buffer", "forge.op.Buffer", 1),
OperatorDefinition("sqrt", "forge.op.Sqrt", 1),
OperatorDefinition("relu", "forge.op.Relu", 1),
OperatorDefinition("leaky_relu", "forge.op.LeakyRelu", 1, forward_params=[
OperatorParamNumber("alpha", float, 0, 100),
]),
OperatorDefinition("nop", "forge.op.Identity", 1),
OperatorDefinition("gelu", "forge.op.Gelu", 1),
OperatorDefinition("log", "forge.op.Log", 1),
OperatorDefinition("sigmoid", "forge.op.Sigmoid", 1),
OperatorDefinition("clip", "forge.op.Clip", 1, forward_params=[
OperatorParamNumber("min", float, 0, 100),
OperatorParamNumber("max", float, 0, 100),
]),
OperatorDefinition("sine", "forge.op.Sine", 1),
OperatorDefinition("cosine", "forge.op.Cosine", 1),
OperatorDefinition("abs", "forge.op.Abs", 1),
OperatorDefinition("tanh", "forge.op.Tanh", 1),
OperatorDefinition("cumsum", "forge.op.CumSum", 1),
OperatorDefinition("argmax", "forge.op.Argmax", 1),
OperatorDefinition("logical_not", "forge.op.LogicalNot", 1),
OperatorDefinition("dropout", "forge.op.Dropout", 1),
OperatorDefinition("pow", "forge.op.Pow", 1, forward_params=[
OperatorParamNumber("exponent", float, 0, 100),
]),
OperatorDefinition("tilizer", "forge.op.Tilize", 1),

# Binary operators
OperatorDefinition("add", "forge.op.Add", 2),
OperatorDefinition("divide", "forge.op.Divide", 2),
OperatorDefinition("subtract", "forge.op.Subtract", 2),
OperatorDefinition("multiply", "forge.op.Multiply", 2),
OperatorDefinition("maximum", "forge.op.Max", 2),
OperatorDefinition("minimum", "forge.op.Min", 2),
OperatorDefinition("heaviside", "forge.op.Heaviside", 2),
OperatorDefinition("binary_stack", "forge.op.BinaryStack", 2),
OperatorDefinition("power", "forge.op.Power", 2),
OperatorDefinition("greater", "forge.op.Greater", 2),
OperatorDefinition("greater_equal", "forge.op.GreaterEqual", 2),
OperatorDefinition("less", "forge.op.Less", 2),
OperatorDefinition("less_equal", "forge.op.LessEqual", 2),
OperatorDefinition("equal", "forge.op.Equal", 2),
OperatorDefinition("not_equal", "forge.op.NotEqual", 2),
OperatorDefinition("logical_and", "forge.op.LogicalAnd", 2),

# Nary operators
OperatorDefinition("where", "forge.op.Where", 3),
# OperatorDefinition("index_copy", "forge.op.IndexCopy", 3), # Bug #2705
OperatorDefinition("interleave", "forge.op.Interleave", (1,10), forward_params=[
OperatorParamNumber("axis", int, -3, -3),
OperatorParamNumber("stride", int, 1, 1),
]),
OperatorDefinition("concatenate", "forge.op.Concatenate", (1, 10), forward_params=[
OperatorParamNumber("axis", int, -10, 10),
]),
OperatorDefinition("stack", "forge.op.Stack", (2,4), forward_params=[
OperatorParamNumber("axis", int, 1, 10),
]),

OperatorDefinition("matmul", "forge.op.Matmul", 2),
# OperatorDefinition("sparse_matmul", "forge.op.SparseMatmul", 2),
]


pybuda_operator_repository = OperatorRepository(list(_OPERATORS))
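
For illustration, a self-contained sketch of the shape these definitions take. The dataclasses below are hypothetical stand-ins for forge.op_repo.datatypes, written only to show how name, op path, operand count, and parameter ranges fit together; they are not the real implementation:

from dataclasses import dataclass, field
from typing import List, Tuple, Union

@dataclass
class ParamNumberSketch:
    # Hypothetical mirror of OperatorParamNumber: a numeric parameter
    # constrained to an inclusive [min_value, max_value] range.
    name: str
    value_type: type
    min_value: float
    max_value: float

@dataclass
class OperatorDefinitionSketch:
    # Hypothetical mirror of OperatorDefinition: short name, fully qualified
    # op path, and operand count (an int, or an inclusive (min, max) range
    # for nary ops such as interleave and concatenate).
    name: str
    full_name: str
    input_num: Union[int, Tuple[int, int]]
    forward_params: List[ParamNumberSketch] = field(default_factory=list)

clip = OperatorDefinitionSketch("clip", "forge.op.Clip", 1, forward_params=[
    ParamNumberSketch("min", float, 0, 100),
    ParamNumberSketch("max", float, 0, 100),
])
assert clip.input_num == 1  # unary: exactly one input operand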
73 changes: 73 additions & 0 deletions forge/forge/op_repo/pytorch_operators.py
@@ -0,0 +1,73 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC

# SPDX-License-Identifier: Apache-2.0
# PyTorch repository operators


from .datatypes import OperatorDefinition, OperatorRepository
from .datatypes import OperatorParamNumber


# TODO describe operand and shapes
_OPERATORS = [
OperatorDefinition("linear", "torch.nn.Linear", 1, instantiate=True, constructor_params=[
OperatorParamNumber("in_features", int, 10, 50),
OperatorParamNumber("out_features", int, 10, 50),
]),
OperatorDefinition("conv2d", "torch.nn.Conv2d", 1, instantiate=True, constructor_params=[
OperatorParamNumber("in_channels", int, 10, 50),
OperatorParamNumber("out_channels", int, 10, 50),
OperatorParamNumber("kernel_size", int, 3, 3),
OperatorParamNumber("stride", int, 1, 1),
OperatorParamNumber("padding", int, 1, 1),
]),
OperatorDefinition("relu", "torch.relu", 1),
OperatorDefinition("sqrt", "torch.sqrt", 1),
OperatorDefinition("tanh", "torch.tanh", 1),
# OperatorDefinition("add", "torch.add", 1),
OperatorDefinition("add", "torch.add", 2),
OperatorDefinition("sub", "torch.sub", 2),
OperatorDefinition("mul", "torch.mul", 2),
OperatorDefinition("div", "torch.div", 2),
OperatorDefinition("ge", "torch.ge", 2),

# Non-linear activation functions
# HARDTANH = OperatorDefinition("hardtanh", 1)
# HARDWISH = OperatorDefinition("hardwish", 1)
# RELU6 = OperatorDefinition("relu6", 1)
# ELU = OperatorDefinition("elu", 1)
# SELU = OperatorDefinition("selu", 1)
# CELU = OperatorDefinition("celu", 1)
# LEACKY_RELU = OperatorDefinition("leaky_relu", 1)
# PRELU = OperatorDefinition("prelu", 1)
# RRELU = OperatorDefinition("rrelu", 1)
# GLU = OperatorDefinition("glu", 1)
# GELU = OperatorDefinition("gelu", 1)
# LOGSIGMOID = OperatorDefinition("logsigmoid", 1)
# HARDSHRINK = OperatorDefinition("hardshrink", 1)
# TANHSHRINK = OperatorDefinition("tanhshrink", 1)
# SOFTSIGN = OperatorDefinition("softsign", 1)
# SOFTPLUS = OperatorDefinition("softplus", 1)
# SOFTMIN = OperatorDefinition("softmin", 1)
# SOFTMAX = OperatorDefinition("softmax", 1)
# SOFTSHRINK = OperatorDefinition("softshrink", 1)
# GUMBEL_SOFTMAX = OperatorDefinition("gumbel_softmax", 1)
# LOG_SOFTMAX = OperatorDefinition("log_softmax", 1)
# TANH = OperatorDefinition("tanh", 1)
# SIGMOID = OperatorDefinition("sigmoid", 1)
# HARDSIGMOID = OperatorDefinition("hardsigmoid", 1)
# SILU = OperatorDefinition("silu", 1)
# MISH = OperatorDefinition("mish", 1)
# BATCH_NORM = OperatorDefinition("batch_norm", 1)
# GROUP_NORM = OperatorDefinition("group_norm", 1)
# INSTANCE_NORM = OperatorDefinition("instance_norm", 1)
# LAYER_NORM = OperatorDefinition("layer_norm", 1)
# LOCAL_RESPONSE_NORM = OperatorDefinition("local_response_norm", 1)
# NORMALIZE = OperatorDefinition("normalize", 1)

OperatorDefinition("matmul", "torch.matmul", 2),
OperatorDefinition("eltwise", "torch.add", 2),
]


pytorch_operator_repository = OperatorRepository(list(_OPERATORS))
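
Entries with instantiate=True describe modules that must be constructed before the forward pass. A minimal sketch of how a generator might consume the "linear" entry above, assuming constructor_params ranges are sampled inclusively (the sampling policy is an assumption, not forge's documented behavior):

import random

import torch

# Sample constructor parameters from the [10, 50] ranges declared in the
# "linear" definition; a fixed seed keeps the sketch reproducible.
rng = random.Random(0)
in_features = rng.randint(10, 50)
out_features = rng.randint(10, 50)

linear = torch.nn.Linear(in_features, out_features)  # op path from the definition
x = torch.rand(4, in_features)  # batch of 4 random inputs
y = linear(x)
assert y.shape == (4, out_features)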
1 change: 1 addition & 0 deletions forge/test/operators/utils/compat.py
@@ -7,6 +7,7 @@
import forge
import torch

from loguru import logger
from typing import Optional, List

14 changes: 14 additions & 0 deletions forge/test/random/conftest.py
@@ -6,6 +6,8 @@
import os
import forge

from .rgg import get_randomizer_config_default

test_rg = random.Random()
seeds = []

@@ -27,6 +29,18 @@ def run_test(test_index, random_seeds):
yield

def pytest_generate_tests(metafunc):
if "randomizer_config" in metafunc.fixturenames:
configs = []
for (build_model_from_code,) in [
(True,),
# (False,),
]:
config = get_randomizer_config_default()
# config.build_model_from_code = build_model_from_code
# config.debug_forward = not build_model_from_code
# config.print_code = not build_model_from_code
configs.append(config)
metafunc.parametrize("randomizer_config", configs)
if "test_index" in metafunc.fixturenames:
if "RANDOM_TEST_COUNT" in os.environ:
test_count = int(os.environ["RANDOM_TEST_COUNT"])
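
The hook above parametrizes every test that requests the randomizer_config fixture, one test instance per generated config. A standalone illustration of the same pytest pattern, with hypothetical config values standing in for RandomizerConfig objects:

# test_rgg_sketch.py -- hypothetical minimal example; run with pytest.

def pytest_generate_tests(metafunc):
    # Mirror of the conftest hook: parametrize any test that declares
    # a "randomizer_config" argument.
    if "randomizer_config" in metafunc.fixturenames:
        metafunc.parametrize("randomizer_config", ["cfg-a", "cfg-b"])

def test_consumes_config(randomizer_config):
    # Collected twice, once per config value.
    assert randomizer_config in ("cfg-a", "cfg-b")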
48 changes: 48 additions & 0 deletions forge/test/random/rgg/__init__.py
@@ -0,0 +1,48 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC

# SPDX-License-Identifier: Apache-2.0


from .datatypes import TensorShape
from .datatypes import RandomizerConstantNode
from .datatypes import RandomizerInputNode, RandomizerNode, ExecutionContext, RandomizerParameters, RandomizerGraph, RandomizerConfig
from .datatypes import NodeShapeCalculationContext
from .datatypes import RandomizerTestContext
from .datatypes import ModelBuilder, Framework
from .config import get_randomizer_config_default
from .utils import StrUtils, GraphUtils
from .utils import DebugUtils
from .base import GraphBuilder
from .base import RandomizerRunner, RandomizerCodeGenerator, process_test
from .frameworks import Frameworks
from .frameworks import FrameworkTestUtils
from .algorithms import GraphNodeSetup
from .algorithms import RandomGraphAlgorithm

__all__ = [
"TensorShape",
"RandomizerConstantNode",
"RandomizerInputNode",
"RandomizerNode",
"ExecutionContext",
"RandomizerParameters",
"RandomizerGraph",
"RandomizerConfig",
"NodeShapeCalculationContext",
"RandomizerTestContext",
"ModelBuilder",
"Framework",
"get_randomizer_config_default",
"StrUtils",
"GraphUtils",
"DebugUtils",
"Framework",
"GraphBuilder",
"RandomizerRunner",
"RandomizerCodeGenerator",
"process_test",
"Frameworks",
"FrameworkTestUtils"
"GraphNodeSetup",
"RandomGraphAlgorithm",
]