From 1c22e32e53d7a573d44d3639b3ed90b7bb563741 Mon Sep 17 00:00:00 2001 From: Erika Hunhoff Date: Fri, 20 Sep 2024 19:13:11 -0600 Subject: [PATCH] Some fun with types, working on getting matrix vector working still --- .../matrix_vector/aie2.py | 24 +++++++-------- .../matrix_multiplication/single_core/aie2.py | 25 +++++++--------- .../basic/matrix_scalar_add/aie2.py | 22 +++++--------- python/api/dataflow/inout/inout.py | 4 +-- python/api/dataflow/inout/simplefifoinout.py | 9 +++--- python/api/dataflow/objectfifo.py | 5 ++-- python/api/dataflow/objectfifolink.py | 29 ++++++++++--------- python/api/kernels/binkernel.py | 6 ++-- python/api/phys/device.py | 4 +-- python/api/phys/tile.py | 3 +- python/api/worker.py | 10 +++---- python/dialects/aie.py | 24 +++++++-------- python/dialects/aiex.py | 7 ++--- python/extras/context.py | 5 ++-- python/extras/dialects/ext/arith.py | 18 ++++++------ python/extras/dialects/ext/func.py | 7 +++-- python/extras/dialects/ext/memref.py | 12 ++++---- python/extras/dialects/ext/scf.py | 4 +-- python/extras/dialects/ext/tensor.py | 26 ++++++++--------- python/extras/runtime/passes.py | 6 ++-- python/extras/util.py | 21 ++++++++++---- python/utils/trace.py | 4 +-- 22 files changed, 133 insertions(+), 142 deletions(-) diff --git a/programming_examples/basic/matrix_multiplication/matrix_vector/aie2.py b/programming_examples/basic/matrix_multiplication/matrix_vector/aie2.py index 556a82b539..5b0fd6f046 100644 --- a/programming_examples/basic/matrix_multiplication/matrix_vector/aie2.py +++ b/programming_examples/basic/matrix_multiplication/matrix_vector/aie2.py @@ -9,7 +9,7 @@ from aie.extras.dialects.ext.scf import _for as range_ from aie.dialects.aiex import npu_dma_memcpy_nd, npu_sync -from aie.api.dataflow.inout.inout import MyInOutProgram +from aie.api.dataflow.inout.inout import MyInOutSequence from aie.api.dataflow.objectfifo import MyObjectFifo from aie.api.dataflow.objectfifolink import MyObjectFifoLink from 
aie.api.kernels.binkernel import BinKernel @@ -46,9 +46,9 @@ dtype_out_str = "i32" # Input/output tensor definitions # TODO: can simplify if single value? -inA_ty = np.ndarray(dtype_in, (M * K)) -inB_ty = np.ndarray(dtype_in, (K,)) -outC_ty = np.ndarray(dtype_out, (M,)) +inA_ty = np.ndarray[dtype_in, (M * K,)] +inB_ty = np.ndarray[dtype_in, (K,)] +outC_ty = np.ndarray[dtype_out, (M,)] a_ty = np.ndarray[dtype_in, (m, k)] a_flat_ty = np.ndarray[dtype_in, (m * k,)] b_ty = np.ndarray[dtype_in, (k,)] @@ -59,7 +59,7 @@ zero = BinKernel(f"zero_{scalar_str}{dtype_out_str}", f"mv_{m}x{k}.o", [c_ty]) matvec = BinKernel( f"matvec_{scalar_str}{dtype_in_str}_{dtype_out_str}", - f"mm_{m}x{k}x{n}.o", + f"mv_{m}x{k}.o", [a_ty, b_ty, c_ty], ) @@ -88,13 +88,13 @@ def core_body(a_in, b_in, c_out, zero, matvec): # Setup workers + per-worker dataflow -inB_fifo = MyObjectFifo(2, b_ty, name="inB", end_first=(1, 0)) +inB_fifo = MyObjectFifo(2, b_ty, name="inB", shim_endpoint=(1, 0)) for i in range(n_cores): # Create object fifos for per-code dataflow - memA = MyObjectFifo(2, a_flat_ty, name=f"memA{i}", end_first=(i, 0)) - toStreamA = [(k // 2 // 2, 2), (m, k), (2, 1)] if vectorized else [] - inA = MyObjectFifo(2, a_ty, name=f"inA{i}", toStream=toStreamA) - outC = MyObjectFifo(2, c_ty, end_second=(i, 0)) + memA = MyObjectFifo(2, a_flat_ty, name=f"memA{i}", shim_endpoint=(i, 0)) + dimensionsToStreamA = [(k // 2 // 2, 2), (m, k), (2, 1)] if vectorized else [] + inA = MyObjectFifo(2, a_ty, name=f"inA{i}", dimensionsToStream=dimensionsToStreamA) + outC = MyObjectFifo(2, c_ty, shim_endpoint=(i, 0)) # Create per-core worker program worker_programs.append( @@ -147,7 +147,7 @@ def sequence_fn(A, B, C, memA, inB, memC): npu_sync(column=i, row=0, direction=0, channel=0) -inout_program = MyInOutProgram( +inout_sequence = MyInOutSequence( sequence_fn, [inA_ty, inB_ty, outC_ty], [memA_fifos, inB_fifo, outC_fifos], @@ -157,7 +157,7 @@ def sequence_fn(A, B, C, memA, inB, memC): NPU1Col4(), 
worker_programs=worker_programs, links=A_links, - inout_program=inout_program, + inout_sequence=inout_sequence, ) my_program.resolve_program() diff --git a/programming_examples/basic/matrix_multiplication/single_core/aie2.py b/programming_examples/basic/matrix_multiplication/single_core/aie2.py index 8def3589fa..3d3b07034b 100644 --- a/programming_examples/basic/matrix_multiplication/single_core/aie2.py +++ b/programming_examples/basic/matrix_multiplication/single_core/aie2.py @@ -15,7 +15,7 @@ from aie.extras.dialects.ext.scf import _for as range_ from aie.dialects.aiex import npu_dma_memcpy_nd, npu_sync -from aie.api.dataflow.inout.inout import MyInOutProgram +from aie.api.dataflow.inout.inout import MyInOutSequence from aie.api.dataflow.objectfifo import MyObjectFifo from aie.api.dataflow.objectfifolink import MyObjectFifoLink from aie.api.kernels.binkernel import BinKernel @@ -112,26 +112,24 @@ def my_matmul(M, K, N, m, k, n, dtype_in_str, dtype_out_str, vectorized): [a_ty, b_ty, c_ty], ) - inA = MyObjectFifo(2, a_ty) + inA = MyObjectFifo(2, a_ty, shim_endpoint=(0, 0)) memAToStream = [(m // r, r * k), (k // s, s), (r, k), (s, 1)] if vectorized else [] memA = MyObjectFifo(2, a_ty, dimensionsToStream=memAToStream) - inALink = MyObjectFifoLink([inA.second], [memA.first], coords=(0, 1)) # AnyMemtile + inALink = MyObjectFifoLink([inA.second], [memA.first], coords=(0, 1)) # Input B - inB = MyObjectFifo(2, b_ty) + inB = MyObjectFifo(2, b_ty, shim_endpoint=(0, 0)) memBToStream = [(k // s, s * n), (n // t, t), (s, n), (t, 1)] if vectorized else [] memB = MyObjectFifo(2, b_ty, dimensionsToStream=memBToStream) - inBLink = MyObjectFifoLink([inB.second], [memB.first], coords=(0, 1)) # AnyMemtile + inBLink = MyObjectFifoLink([inB.second], [memB.first], coords=(0, 1)) # Output C memC = MyObjectFifo(2, c_ty) memCToStream = ( [(m // r, r * n), (r, t), (n // t, r * t), (t, 1)] if vectorized else [] ) - outC = MyObjectFifo(2, c_ty, dimensionsToStream=memCToStream) - outCLink = 
MyObjectFifoLink( - [memC.second], [outC.first], coords=(0, 1) - ) # AnyMemtile + outC = MyObjectFifo(2, c_ty, dimensionsToStream=memCToStream, shim_endpoint=(0, 0)) + outCLink = MyObjectFifoLink([memC.second], [outC.first], coords=(0, 1)) def core_fn(a, b, c, zero, matmul): for _ in range_(0xFFFFFFFF): @@ -199,28 +197,25 @@ def sequence_fn(A, B, C, inA, inB, outC): npu_sync(column=0, row=0, direction=0, channel=0) npu_sync(column=0, row=0, direction=0, channel=0) - inout_program = MyInOutProgram( + inout_sequence = MyInOutSequence( sequence_fn, [A_ty, B_ty, C_ty], [inA.first, inB.first, outC.second], - coords=(0, 0), # AnyShim ) worker_program = MyWorker( core_fn, [memA.second, memB.second, memC.first, zero, matmul], - coords=(0, 2), # AnyCore + coords=(0, 2), ) my_program = MyProgram( NPU1Col1(), worker_programs=[worker_program], links=[inALink, inBLink, outCLink], - inout_program=inout_program, - # placer=SequentialPlacer(pack=True) + inout_sequence=inout_sequence, ) - # g = my_program.get_dataflow_graph() my_program.resolve_program() diff --git a/programming_examples/basic/matrix_scalar_add/aie2.py b/programming_examples/basic/matrix_scalar_add/aie2.py index 1d910f390b..4f53850413 100644 --- a/programming_examples/basic/matrix_scalar_add/aie2.py +++ b/programming_examples/basic/matrix_scalar_add/aie2.py @@ -12,7 +12,7 @@ from aie.extras.dialects.ext.arith import constant from aie.extras.dialects.ext.func import func from aie.extras.dialects.ext.scf import _for as range_ -from aie.api.dataflow.inout.simplefifoinout import SimpleFifoInOutProgram +from aie.api.dataflow.inout.simplefifoinout import SimpleFifoInOutSequence from aie.api.dataflow.objectfifo import MyObjectFifo from aie.api.phys.device import NPU1Col1, XCVC1902 from aie.api.program import MyProgram @@ -42,12 +42,13 @@ else: raise ValueError("[ERROR] Device name {} is unknown".format(sys.argv[1])) +col = int(sys.argv[2]) my_dtype = np.int32 tile_ty = np.ndarray[my_dtype, (TILE_SIZE,)] # AIE-array data 
movement with object fifos -of_in = MyObjectFifo(objfifo_capacity, tile_ty) -of_out = MyObjectFifo(objfifo_capacity, tile_ty) +of_in = MyObjectFifo(objfifo_capacity, tile_ty, shim_endpoint=(col, 0)) +of_out = MyObjectFifo(objfifo_capacity, tile_ty, shim_endpoint=(col, 0)) @func @@ -66,15 +67,15 @@ def core_fn(of_in, of_out, add_kernel): of_out.release(1) -# Set up compute tile 2 TODO: clean up placement +# Set up worker worker_program = MyWorker( core_fn, [of_in.second, of_out.first, add_kernel], - coords=(int(sys.argv[2]), 2), + coords=(col, 2), ) # To/from AIE-array data movement -inout_program = SimpleFifoInOutProgram( +inout_sequence = SimpleFifoInOutSequence( of_in.first, TILE_SIZE, of_out.second, @@ -84,16 +85,9 @@ def core_fn(of_in, of_out, add_kernel): out_sizes=[1, 1, TILE_HEIGHT, TILE_WIDTH], out_strides=[1, 1, IMAGE_WIDTH, 1], dtype=my_dtype, - coords=(int(sys.argv[2]), 0), ) my_program = MyProgram( - dev, worker_programs=[worker_program], inout_program=inout_program + dev, worker_programs=[worker_program], inout_sequence=inout_sequence ) my_program.resolve_program() - -""" -TODOs: -* look into # @canonicalize(using=scf_canonicalizer) shoudl decorate this after func if we want control flow -* we need emit = true because must be emited in outer loop (not deferred) to have access to symbol table -""" diff --git a/python/api/dataflow/inout/inout.py b/python/api/dataflow/inout/inout.py index f90395ff38..88d9d4115a 100644 --- a/python/api/dataflow/inout/inout.py +++ b/python/api/dataflow/inout/inout.py @@ -29,10 +29,10 @@ def __init__( ): self.sequence_fn = sequence_fn self.inout_types = inout_types - self.fifos = fifos + self.fifos = fifos.copy() def get_fifos(self) -> list[ObjectFifoHandle]: - return self.fifos + return self.fifos.copy() def resolve( self, diff --git a/python/api/dataflow/inout/simplefifoinout.py b/python/api/dataflow/inout/simplefifoinout.py index 8e4631e81b..ed4a351e46 100644 --- a/python/api/dataflow/inout/simplefifoinout.py +++ 
b/python/api/dataflow/inout/simplefifoinout.py @@ -5,7 +5,6 @@ """ import numpy as np -from typing import Optional from .... import ir from ....dialects.aiex import runtime_sequence, npu_sync, npu_dma_memcpy_nd @@ -21,10 +20,10 @@ def __init__( bytes_in: int, fifo_out: ObjectFifoHandle, bytes_out: int, - in_sizes: Optional[list[int]] = None, - in_strides: Optional[list[int]] = None, - out_sizes: Optional[list[int]] = None, - out_strides: Optional[list[int]] = None, + in_sizes: list[int] | None = None, + in_strides: list[int] | None = None, + out_sizes: list[int] | None = None, + out_strides: list[int] | None = None, dtype: np.generic = np.uint8, ): assert bytes_in % np.prod(get_np_ndarray_type_shape(fifo_in.obj_type)) == 0 diff --git a/python/api/dataflow/objectfifo.py b/python/api/dataflow/objectfifo.py index 9ef1729f38..1ea8e062a3 100644 --- a/python/api/dataflow/objectfifo.py +++ b/python/api/dataflow/objectfifo.py @@ -7,7 +7,6 @@ # Address circular dependency between MyObjectFifo and ObjectFifoHandle from __future__ import annotations -from typing import Optional import numpy as np from ... 
import ir @@ -39,7 +38,7 @@ def __init__( end2: MyObjectFifoEndpoint = None, dimensionsToStream=None, # TODO(erika): needs a type dimensionsFromStreamPerConsumer=None, # TODO(erika): needs a type - shim_endpoint: Optional[tuple[int, int]] = None, + shim_endpoint: tuple[int, int] | None = None, ): self.__depth = depth self.__obj_type = obj_type @@ -52,7 +51,7 @@ def __init__( self.name = f"myof{MyObjectFifo.__get_index()}" else: self.name = name - self.__op: Optional[ObjectFifoCreateOp] = None + self.__op: ObjectFifoCreateOp | None = None self.__first: ObjectFifoHandle = ObjectFifoHandle(self, True) self.__second: ObjectFifoHandle = ObjectFifoHandle(self, False) if shim_endpoint: diff --git a/python/api/dataflow/objectfifolink.py b/python/api/dataflow/objectfifolink.py index 38f9751313..b94bdc1a61 100644 --- a/python/api/dataflow/objectfifolink.py +++ b/python/api/dataflow/objectfifolink.py @@ -1,5 +1,5 @@ import numpy as np -from typing import Optional +from collections.abc import Sequence from ... 
import ir from ...dialects._aie_ops_gen import ObjectFifoLinkOp @@ -8,31 +8,34 @@ from ..phys.tile import MyTile from .endpoint import MyObjectFifoEndpoint from .objectfifo import ObjectFifoHandle +from ...extras.util import single_elem_or_list_to_list class MyObjectFifoLink(MyObjectFifoEndpoint): def __init__( self, - seconds: list[ObjectFifoHandle] = [], - firsts: list[ObjectFifoHandle] = [], - coords: Optional[tuple[int, int]] = None, + seconds: Sequence[ObjectFifoHandle] | ObjectFifoHandle = [], + firsts: Sequence[ObjectFifoHandle] | ObjectFifoHandle = [], + coords: tuple[int, int] | None = None, ): column, row = coords self.__tile = MyTile(column, row) - self.__seconds = [] - self.__firsts = [] + self.__seconds = single_elem_or_list_to_list(seconds) + self.__firsts = single_elem_or_list_to_list(firsts) self.__op = None - self.__obj_type = seconds[0].obj_type - for s in seconds: - assert s.obj_type == self.__obj_type - s.set_endpoint(self) - self.__seconds.append(s) - for f in firsts: + assert len(self.__firsts) > 0 + assert len(self.__seconds) > 0 + + self.__obj_type = self.__seconds[0].obj_type + for f in self.__firsts: + # TODO: need to check size not exactness assert f.obj_type == self.__obj_type f.set_endpoint(self) - self.__firsts.append(f) + for s in self.__seconds: + assert s.obj_type == self.__obj_type + s.set_endpoint(self) @property def tile(self) -> MyTile: diff --git a/python/api/kernels/binkernel.py b/python/api/kernels/binkernel.py index 5cd70af085..ed32bd258d 100644 --- a/python/api/kernels/binkernel.py +++ b/python/api/kernels/binkernel.py @@ -5,7 +5,7 @@ """ import numpy as np -from typing import get_origin, Optional, Union +from typing import get_origin from ... 
import ir @@ -22,13 +22,13 @@ def __init__( name: str, bin_name: str, inout_types: list[ - Union[np.ndarray[np.generic.dtype, np.generic.shape], np.dtype] + np.ndarray[np.generic.dtype, np.generic.shape] | np.dtype ] = [], ) -> None: self.__name = name self.__bin_name = bin_name self.__inout_types = inout_types - self.__op: Optional[FuncOp] = None + self.__op: FuncOp | None = None @property def bin_name(self) -> str: diff --git a/python/api/phys/device.py b/python/api/phys/device.py index 2554dd6677..5aa59e28e4 100644 --- a/python/api/phys/device.py +++ b/python/api/phys/device.py @@ -14,8 +14,6 @@ } """ -from typing import Optional - from ... import ir from ...dialects.aie import AIEDevice, tile, TileOp from ..resolvable import Resolvable @@ -35,7 +33,7 @@ class __MyDeviceTile(Resolvable): def __init__(self, col: int, row: int) -> None: self.__col: int = col self.__row: int = row - self.__op: Optional[TileOp] = None + self.__op: TileOp | None = None super().__init__() def resolve( diff --git a/python/api/phys/tile.py b/python/api/phys/tile.py index 3dba22722c..bd870954b7 100644 --- a/python/api/phys/tile.py +++ b/python/api/phys/tile.py @@ -5,7 +5,6 @@ * tile types" """ -from typing import Optional from ...dialects.aie import TileOp @@ -13,7 +12,7 @@ class MyTile: def __init__(self, col: int, row: int) -> None: self.col: int = col self.row: int = row - self.__op: Optional[TileOp] = None + self.__op: TileOp | None = None @property def op(self) -> TileOp: diff --git a/python/api/worker.py b/python/api/worker.py index bff6031063..9888a89837 100644 --- a/python/api/worker.py +++ b/python/api/worker.py @@ -6,7 +6,7 @@ """ import sys -from typing import Callable, Optional, Union +from typing import Callable from .. 
import ir from ..dialects.aie import core @@ -20,8 +20,8 @@ class MyWorker(MyObjectFifoEndpoint): def __init__( self, - core_fn: Optional[Callable[[Union[ObjectFifoHandle, MyKernel]], None]], - fn_args: list[Union[ObjectFifoHandle, MyKernel]] = [], + core_fn: Callable[[ObjectFifoHandle | MyKernel], None] | None, + fn_args: list[ObjectFifoHandle | MyKernel] = [], coords: tuple[int, int] = None, ): column, row = coords @@ -35,7 +35,7 @@ def do_nothing_core_fun() -> None: self.core_fn = do_nothing_core_fun else: self.core_fn = core_fn - self.link_with: Optional[str] = None + self.link_with: str | None = None self.fn_args = fn_args bin_names = set() self.__fifos = [] @@ -56,7 +56,7 @@ def tile(self) -> MyTile: return self.__tile def get_fifos(self) -> list[ObjectFifoHandle]: - return self.__fifos + return self.__fifos.copy() def resolve( self, diff --git a/python/dialects/aie.py b/python/dialects/aie.py index 936a8071a5..5b13d64f4a 100644 --- a/python/dialects/aie.py +++ b/python/dialects/aie.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from dataclasses import dataclass import inspect -from typing import List, Optional, Tuple, Union, Dict, Any +from typing import List, Tuple, Dict, Any import contextlib import numpy as np @@ -106,9 +106,7 @@ def bd_dim_layout(size, stride): @register_attribute_builder("BDDimLayoutArrayAttr") -def bd_dim_layout_array_attr_builder( - tups: List[Union[Attribute, Tuple[int]]], context=None -): +def bd_dim_layout_array_attr_builder(tups: List[Attribute | Tuple[int]], context=None): if isinstance(tups, list) and all(isinstance(t, tuple) for t in tups): tups = list(map(lambda t: bd_dim_layout(*t), tups)) return Attribute.parse( @@ -374,7 +372,7 @@ def __init__( dest, dest_port, dest_channel, - keep_pkt_header: Optional[bool] = None, + keep_pkt_header: bool | None = None, ): super().__init__(ID=pkt_id, keep_pkt_header=keep_pkt_header) bb = Block.create_at_start(self.ports) @@ -448,9 +446,9 @@ def __init__( 
channel_dir, channel_index, *, - dest: Optional[Union[Successor, Block]] = None, - chain: Optional[Union[Successor, Block]] = None, - repeat_count: Optional[int] = None, + dest: Successor | Block | None = None, + chain: Successor | Block | None = None, + repeat_count: int | None = None, loc=None, ip=None, ): @@ -485,8 +483,8 @@ def dma_start( channel_dir, channel_index, *, - dest: Optional[Union[Successor, Block]] = None, - chain: Optional[Union[Successor, Block]] = None, + dest: Successor | Block | None = None, + chain: Successor | Block | None = None, loc=None, ip=None, ): @@ -496,9 +494,7 @@ def dma_start( @_cext.register_operation(_Dialect, replace=True) class NextBDOp(NextBDOp): - def __init__( - self, dest: Optional[Union[Successor, Block]] = None, *, loc=None, ip=None - ): + def __init__(self, dest: Successor | Block | None = None, *, loc=None, ip=None): if isinstance(dest, Successor): dest = dest.block if dest is None: @@ -513,7 +509,7 @@ def dest(self): def next_bd( - dest: Optional[Union[Successor, Block, ContextManagedBlock]] = None, + dest: Successor | Block | ContextManagedBlock | None = None, loc=None, ip=None, ): diff --git a/python/dialects/aiex.py b/python/dialects/aiex.py index c23866b4c9..db86db5165 100644 --- a/python/dialects/aiex.py +++ b/python/dialects/aiex.py @@ -4,7 +4,6 @@ from functools import partial import itertools from operator import itemgetter -from typing import Union, Optional import numpy as np @@ -47,7 +46,7 @@ def __init__( offsets: MixedValues = None, sizes: MixedValues = None, strides: MixedValues = None, - issue_token: Optional[bool] = None, + issue_token: bool | None = None, ): x = 0 y = 0 @@ -685,8 +684,8 @@ def __repr__(self): def broadcast_flow( - source: Union[np.ndarray, TileOp], - dest: Union[np.ndarray, TileOp], + source: np.ndarray | TileOp, + dest: np.ndarray | TileOp, source_bundle=None, source_channel=None, dest_bundle=None, diff --git a/python/extras/context.py b/python/extras/context.py index 
5f053a07cd..5190597858 100644 --- a/python/extras/context.py +++ b/python/extras/context.py @@ -1,7 +1,6 @@ import contextlib from contextlib import ExitStack, contextmanager from dataclasses import dataclass -from typing import Optional from .. import ir @@ -17,7 +16,7 @@ def __str__(self): @contextmanager def mlir_mod_ctx( - src: Optional[str] = None, + src: str | None = None, context: ir.Context = None, location: ir.Location = None, allow_unregistered_dialects=False, @@ -45,7 +44,7 @@ class RAIIMLIRContext: context: ir.Context location: ir.Location - def __init__(self, location: Optional[ir.Location] = None): + def __init__(self, location: ir.Location | None = None): self.context = ir.Context() self.context.__enter__() if location is None: diff --git a/python/extras/dialects/ext/arith.py b/python/extras/dialects/ext/arith.py index 6debc7557d..46ca570b35 100644 --- a/python/extras/dialects/ext/arith.py +++ b/python/extras/dialects/ext/arith.py @@ -2,7 +2,7 @@ from abc import abstractmethod from copy import deepcopy from functools import cached_property, partialmethod -from typing import Optional, Tuple +from typing import Tuple from ...util import get_user_code_loc, infer_mlir_type, mlir_type_to_np_dtype from ...._mlir_libs._mlir import register_value_caster @@ -43,11 +43,11 @@ def constant( - value: Union[int, float, bool, np.ndarray], - type: Optional[Type] = None, - index: Optional[bool] = None, + value: int | float | bool | np.ndarray, + type: Type | None = None, + index: bool | None = None, *, - vector: Optional[bool] = False, + vector: bool | None = False, loc: Location = None, ip: InsertionPoint = None, ) -> Value: @@ -215,7 +215,7 @@ def __call__(cls, *args, **kwargs): @register_attribute_builder("Arith_CmpIPredicateAttr", replace=True) -def _arith_CmpIPredicateAttr(predicate: Union[str, Attribute], context: Context): +def _arith_CmpIPredicateAttr(predicate: str | Attribute, context: Context): predicates = { "eq": CmpIPredicate.eq, "ne": CmpIPredicate.ne, 
@@ -235,7 +235,7 @@ def _arith_CmpIPredicateAttr(predicate: Union[str, Attribute], context: Context) @register_attribute_builder("Arith_CmpFPredicateAttr", replace=True) -def _arith_CmpFPredicateAttr(predicate: Union[str, Attribute], context: Context): +def _arith_CmpFPredicateAttr(predicate: str | Attribute, context: Context): predicates = { "false": CmpFPredicate.AlwaysFalse, # ordered comparison @@ -373,7 +373,7 @@ class ArithValue(Value, metaclass=ArithValueMeta): Value.__init__ """ - def __init__(self, val, *, fold: Optional[bool] = None): + def __init__(self, val, *, fold: bool | None = None): self._fold = fold if fold is not None else False super().__init__(val) @@ -464,7 +464,7 @@ def dtype(self) -> Type: return self.type @cached_property - def literal_value(self) -> Union[int, float, bool]: + def literal_value(self) -> int | float | bool: if not self.is_constant(): raise ValueError("Can't build literal from non-constant Scalar") return self.owner.opview.literal_value diff --git a/python/extras/dialects/ext/func.py b/python/extras/dialects/ext/func.py index 6229f575c9..19f452d6d4 100644 --- a/python/extras/dialects/ext/func.py +++ b/python/extras/dialects/ext/func.py @@ -21,9 +21,10 @@ def call( - callee_or_results: Union[FuncOp, List[Type]], - arguments_or_callee: Union[List[Value], FlatSymbolRefAttr, str], - arguments: Optional[list] = None, + callee_or_results: FuncOp | List[Type], + arguments_or_callee: List[Value] | FlatSymbolRefAttr + | str, + arguments: list | None = None, *, call_op_ctor=CallOp.__base__, loc=None, diff --git a/python/extras/dialects/ext/memref.py b/python/extras/dialects/ext/memref.py index a17350bf09..a833a19358 100644 --- a/python/extras/dialects/ext/memref.py +++ b/python/extras/dialects/ext/memref.py @@ -1,5 +1,5 @@ import inspect -from typing import Sequence, Union +from typing import Sequence import numpy as np @@ -32,7 +32,7 @@ def _alloc( op_ctor, - *sizes_element_type: Sequence[Union[int, Value]], + *sizes_element_type: 
Sequence[int | Value], loc=None, ip=None, ): @@ -55,17 +55,17 @@ def _alloc( ) -def alloc(*sizes: Union[int, Value], element_type: Type = None): +def alloc(*sizes: int | Value, element_type: Type = None): loc = get_user_code_loc() return _alloc(AllocOp, *sizes, element_type, loc=loc, ip=None) -def alloca(*sizes: Union[int, Value], element_type: Type = None): +def alloca(*sizes: int | Value, element_type: Type = None): loc = get_user_code_loc() return _alloc(AllocaOp, *sizes, element_type, loc=loc, ip=None) -def load(mem: Value, indices: Sequence[Union[Value, int]], *, loc=None, ip=None): +def load(mem: Value, indices: Sequence[Value | int], *, loc=None, ip=None): if loc is None: loc = get_user_code_loc() indices = list(indices) @@ -76,7 +76,7 @@ def load(mem: Value, indices: Sequence[Union[Value, int]], *, loc=None, ip=None) def store( - value: Value, mem: Value, indices: Sequence[Union[Value, int]], *, loc=None, ip=None + value: Value, mem: Value, indices: Sequence[Value | int], *, loc=None, ip=None ): if loc is None: loc = get_user_code_loc() diff --git a/python/extras/dialects/ext/scf.py b/python/extras/dialects/ext/scf.py index 793d485dac..b2b660f12c 100644 --- a/python/extras/dialects/ext/scf.py +++ b/python/extras/dialects/ext/scf.py @@ -1,4 +1,4 @@ -from typing import Optional, Sequence +from typing import Sequence from ....ir import InsertionPoint, Value from ....dialects.linalg.opdsl.lang.emitter import _is_index_type @@ -11,7 +11,7 @@ def _for( start, stop=None, step=None, - iter_args: Optional[Sequence[Value]] = None, + iter_args: Sequence[Value] | None = None, insert_yield: bool = True, *, loc=None, diff --git a/python/extras/dialects/ext/tensor.py b/python/extras/dialects/ext/tensor.py index ebc760eead..cb725d6505 100644 --- a/python/extras/dialects/ext/tensor.py +++ b/python/extras/dialects/ext/tensor.py @@ -1,6 +1,6 @@ import inspect from dataclasses import dataclass -from typing import Any, List, Optional, Tuple +from typing import Any, List, Tuple, 
Union # noinspection PyUnresolvedReferences import numpy as np @@ -20,7 +20,7 @@ S = ShapedType.get_dynamic_size() -def empty(*sizes: Union[int, Value], element_type: Type = None, loc=None, ip=None): +def empty(*sizes: int | Value, element_type: Type = None, loc=None, ip=None): if loc is None: loc = get_user_code_loc() if element_type is None: @@ -32,11 +32,11 @@ def empty(*sizes: Union[int, Value], element_type: Type = None, loc=None, ip=Non def extract_slice( source: "Tensor", - offsets: Optional[Sequence[Value]] = None, - strides: Optional[Sequence[Value]] = None, - static_offsets: Optional[Sequence[int]] = None, - static_sizes: Optional[Sequence[int]] = None, - static_strides: Optional[Sequence[int]] = None, + offsets: Sequence[Value] | None = None, + strides: Sequence[Value] | None = None, + static_offsets: Sequence[int] | None = None, + static_sizes: Sequence[int] | None = None, + static_strides: Sequence[int] | None = None, *, loc=None, ip=None, @@ -69,11 +69,11 @@ def extract_slice( def insert_slice( source: Value, dest: Value, - offsets: Optional[Sequence[Value]] = None, - strides: Optional[Sequence[Value]] = None, - static_offsets: Optional[Sequence[int]] = None, - static_sizes: Optional[Sequence[int]] = None, - static_strides: Optional[Sequence[int]] = None, + offsets: Sequence[Value] | None = None, + strides: Sequence[Value] | None = None, + static_offsets: Sequence[int] | None = None, + static_sizes: Sequence[int] | None = None, + static_strides: Sequence[int] | None = None, *, loc=None, ip=None, @@ -205,7 +205,7 @@ def coerce( class _Indexer: indices: Tuple[Union[int, Scalar, slice, "Ellipsis", None]] newaxis_dims: Tuple[int, "Ellipsis"] - in_shape: Tuple[Union[Value, int]] + in_shape: Tuple[Value | int] def is_constant(self): return all(_is_constant_index(i) for i in self.indices) diff --git a/python/extras/runtime/passes.py b/python/extras/runtime/passes.py index b5c8e678fa..4914bddd36 100644 --- a/python/extras/runtime/passes.py +++ 
b/python/extras/runtime/passes.py @@ -4,7 +4,7 @@ import tempfile from contextlib import ExitStack from io import StringIO -from typing import List, Optional, Union +from typing import List from ..context import disable_multithreading from ...ir import Module, StringAttr @@ -25,8 +25,8 @@ def get_module_name_for_debug_dump(module): def run_pipeline( module, - pipeline: Union[str, "Pipeline"], - description: Optional[str] = None, + pipeline: "str | Pipeline", + description: str | None = None, enable_ir_printing=False, print_pipeline=False, verify=True, diff --git a/python/extras/util.py b/python/extras/util.py index 11deb9ee63..45c3f13447 100644 --- a/python/extras/util.py +++ b/python/extras/util.py @@ -9,7 +9,7 @@ from dataclasses import dataclass from functools import wraps from pathlib import Path -from typing import Callable, List, Optional, Sequence, Tuple, Union, get_args +from typing import Callable, List, Sequence, Tuple, get_args, TypeVar import tensorflow as tf import numpy as np @@ -46,12 +46,21 @@ ) TypeID = object +E = TypeVar("E") + + +def single_elem_or_list_to_list(val: list[E] | E) -> list[E]: + """does not work for list of lists but still useful""" + if not isinstance(val, list): + return [val] + return val + def is_relative_to(self, other): return other == self or other in self.parents -def get_user_code_loc(user_base: Optional[Path] = None): +def get_user_code_loc(user_base: Path | None = None): from .. import extras if Context.current is None: @@ -215,8 +224,8 @@ def mlir_type_to_ctype(mlir_type): def infer_mlir_type( - py_val: Union[int, float, bool, np.ndarray], memref=False, vector=False -) -> Union[IntegerType, F32Type, F64Type, RankedTensorType]: + py_val: int | float | bool | np.ndarray, memref=False, vector=False +) -> IntegerType | F32Type | F64Type | RankedTensorType: """Infer MLIR type (`ir.Type`) from supported python values. Note ints and floats are mapped to 64-bit types. 
@@ -336,7 +345,7 @@ def new_dec(*args, **kwargs): @dataclass class Successor: - op: Union[OpView, Operation] + op: OpView | Operation operands: List[Value] block: Block pos: int @@ -350,7 +359,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): @contextlib.contextmanager -def bb(*preds: Tuple[Union[Successor, OpView]]): +def bb(*preds: Tuple[Successor | OpView]): current_ip = InsertionPoint.current op = current_ip.block.owner op_region = op.regions[0] diff --git a/python/utils/trace.py b/python/utils/trace.py index 04cfe3f280..382f5c81d9 100644 --- a/python/utils/trace.py +++ b/python/utils/trace.py @@ -13,11 +13,11 @@ class GenericEvent: - def __init__(self, code: typing.Union[CoreEvent, MemEvent, PLEvent, MemTileEvent]): + def __init__(self, code: CoreEvent | MemEvent | PLEvent | MemTileEvent): # For backwards compatibility, allow integer as event if isinstance(code, int): code = CoreEvent(code) - self.code: typing.Union[CoreEvent, MemEvent, PLEvent, MemTileEvent] = code + self.code: CoreEvent | MemEvent | PLEvent | MemTileEvent = code def get_register_writes(self): """