Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pymbolic typing #868

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,15 @@
# As of 2022-06-22, it doesn't look like there's sphinx documentation
# available.
["py:class", r"immutables\.(.+)"],

# Reference not found from "<unknown>"? I'm not even sure where to look.
["py:class", r"Expression"],
]

autodoc_type_aliases = {
"ToLoopyTypeConvertible": "ToLoopyTypeConvertible",
"ExpressionT": "ExpressionT",
"InameStr": "InameStr",
"ShapeType": "ShapeType",
"StridesType": "StridesType",
}
5 changes: 4 additions & 1 deletion loopy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@
)
from loopy.translation_unit import TranslationUnit, for_each_kernel, make_program
from loopy.type_inference import infer_unknown_types
from loopy.types import to_loopy_type
from loopy.types import LoopyType, NumpyType, ToLoopyTypeConvertible, to_loopy_type
from loopy.typing import auto
from loopy.version import MOST_RECENT_LANGUAGE_VERSION, VERSION

Expand Down Expand Up @@ -248,12 +248,14 @@
"LinearSubscript",
"LoopKernel",
"LoopyError",
"LoopyType",
"LoopyWarning",
"MemAccess",
"MemoryOrdering",
"MemoryScope",
"MultiAssignmentBase",
"NoOpInstruction",
"NumpyType",
"Op",
"OpenCLTarget",
"Optional",
Expand All @@ -270,6 +272,7 @@
"TemporaryVariable",
"ToCountMap",
"ToCountPolynomialMap",
"ToLoopyTypeConvertible",
"TranslationUnit",
"TypeCast",
"UniqueName",
Expand Down
3 changes: 3 additions & 0 deletions loopy/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

import islpy as isl
from islpy import dim_type
from pymbolic.primitives import Variable
from pytools import memoize_method

from loopy.diagnostic import (
Expand Down Expand Up @@ -1669,6 +1670,8 @@ def _are_sub_array_refs_equivalent(
if len(sar1.swept_inames) != len(sar2.swept_inames):
return False

assert isinstance(sar1.subscript.aggregate, Variable)
assert isinstance(sar2.subscript.aggregate, Variable)
if sar1.subscript.aggregate.name != sar2.subscript.aggregate.name:
return False

Expand Down
16 changes: 15 additions & 1 deletion loopy/codegen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from immutables import Map

from loopy.codegen.result import CodeGenerationResult
from loopy.library.reduction import ReductionOpFunction
from loopy.translation_unit import CallablesTable, TranslationUnit


Expand Down Expand Up @@ -86,6 +87,12 @@
.. automodule:: loopy.codegen.result

.. automodule:: loopy.codegen.tools

References
^^^^^^^^^^
.. class:: Expression

See :class:`pymbolic.Expression`.
"""


Expand Down Expand Up @@ -661,8 +668,15 @@ def generate_code_v2(t_unit: TranslationUnit) -> CodeGenerationResult:
ast=t_unit.target.get_device_ast_builder().ast_module.Collection(
callee_fdecls+[device_programs[0].ast]))] +
device_programs[1:])

def not_reduction_op(name: str | ReductionOpFunction) -> str:
assert isinstance(name, str)
return name

cgr = TranslationUnitCodeGenerationResult(
host_programs=host_programs,
host_programs={
not_reduction_op(name): prg
for name, prg in host_programs.items()},
device_programs=device_programs,
device_preambles=device_preambles)

Expand Down
3 changes: 2 additions & 1 deletion loopy/frontend/fortran/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ class FortranExpressionParser(ExpressionParserBase):
(_not, pytools.lex.RE(r"\.not\.", re.I)),
(_and, pytools.lex.RE(r"\.and\.", re.I)),
(_or, pytools.lex.RE(r"\.or\.", re.I)),
] + ExpressionParserBase.lex_table
*ExpressionParserBase.lex_table,
]

def __init__(self, tree_walker):
self.tree_walker = tree_walker
Expand Down
32 changes: 13 additions & 19 deletions loopy/kernel/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,8 @@
from pytools.tag import Tag, Taggable

from loopy.diagnostic import LoopyError
from loopy.tools import is_integer
from loopy.types import LoopyType
from loopy.typing import ExpressionT, ShapeType, auto
from loopy.typing import ExpressionT, ShapeType, auto, is_integer


if TYPE_CHECKING:
Expand Down Expand Up @@ -624,16 +623,20 @@ def _parse_shape_or_strides(
return auto

if isinstance(x, str):
x = parse(x)
x_parsed = parse(x)
else:
x_parsed = x

if isinstance(x, list):
if isinstance(x_parsed, list):
raise ValueError("shape can't be a list")

if not isinstance(x, tuple):
if isinstance(x_parsed, tuple):
x_tup: tuple[ExpressionT, ...] = x_parsed
else:
assert x is not auto
x = (x,)
x_tup = (x_parsed,)

return tuple(parse(xi) if isinstance(xi, str) else xi for xi in x)
return tuple(parse(xi) if isinstance(xi, str) else xi for xi in x_tup)


class ArrayBase(ImmutableRecord, Taggable):
Expand Down Expand Up @@ -1024,16 +1027,6 @@ def __str__(self):
def __repr__(self):
return "<%s>" % self.__str__()

def update_persistent_hash_for_shape(self, key_hash, key_builder, shape):
if isinstance(shape, tuple):
for shape_i in shape:
if shape_i is None:
key_builder.rec(key_hash, shape_i)
else:
key_builder.update_for_pymbolic_expression(key_hash, shape_i)
else:
key_builder.rec(key_hash, shape)

def update_persistent_hash(self, key_hash, key_builder):
"""Custom hash computation function for use with
:class:`pytools.persistent_dict.PersistentDict`.
Expand All @@ -1042,7 +1035,7 @@ def update_persistent_hash(self, key_hash, key_builder):
key_builder.rec(key_hash, type(self).__name__)
key_builder.rec(key_hash, self.name)
key_builder.rec(key_hash, self.dtype)
self.update_persistent_hash_for_shape(key_hash, key_builder, self.shape)
key_builder.rec(key_hash, self.shape)
key_builder.rec(key_hash, self.dim_tags)
key_builder.rec(key_hash, self.offset)
key_builder.rec(key_hash, self.dim_names)
Expand Down Expand Up @@ -1230,11 +1223,12 @@ def get_access_info(kernel: "LoopKernel",

import loopy as lp

def eval_expr_assert_integer_constant(i, expr):
def eval_expr_assert_integer_constant(i, expr) -> int:
from pymbolic.mapper.evaluator import UnknownVariableError
try:
result = eval_expr(expr)
except UnknownVariableError as e:
assert ary.dim_tags is not None
raise LoopyError("When trying to index the array '%s' along axis "
"%d (tagged '%s'), the index was not a compile-time "
"constant (but it has to be in order for code to be "
Expand Down
12 changes: 9 additions & 3 deletions loopy/kernel/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,13 @@
.. autoclass:: UnrollTag

.. autoclass:: Iname

References
^^^^^^^^^^

.. class:: ToLoopyTypeConvertible

See :class:`loopy.ToLoopyTypeConvertible`.
"""

# This docstring is included in ref_internals. Do not include parts of the public
Expand Down Expand Up @@ -853,8 +860,7 @@ def update_persistent_hash(self, key_hash, key_builder):
"""

super().update_persistent_hash(key_hash, key_builder)
self.update_persistent_hash_for_shape(key_hash, key_builder,
self.storage_shape)
key_builder.rec(key_hash, self.storage_shape)
key_builder.rec(key_hash, self.base_indices)
key_builder.rec(key_hash, self.address_space)
key_builder.rec(key_hash, self.base_storage)
Expand Down Expand Up @@ -899,7 +905,7 @@ def copy(self, **kwargs: Any) -> SubstitutionRule:
def update_persistent_hash(self, key_hash, key_builder):
key_builder.rec(key_hash, self.name)
key_builder.rec(key_hash, self.arguments)
key_builder.update_for_pymbolic_expression(key_hash, self.expression)
key_builder.rec(key_hash, self.expression)


# }}}
Expand Down
2 changes: 1 addition & 1 deletion loopy/kernel/function_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def depends_on(self):
return frozenset(var.name for var in result)

def update_persistent_hash(self, key_hash, key_builder):
key_builder.update_for_pymbolic_expression(key_hash, self.shape)
key_builder.rec(key_hash, self.shape)
key_builder.rec(key_hash, self.address_space)
key_builder.rec(key_hash, self.dim_tags)

Expand Down
37 changes: 5 additions & 32 deletions loopy/kernel/instruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,12 +264,6 @@ class InstructionBase(ImmutableRecord, Taggable):
"within_inames_is_final within_inames "
"priority".split())

# Names of fields that are pymbolic expressions. Needed for key building
pymbolic_fields = set("")

# Names of fields that are sets of pymbolic expressions. Needed for key building
pymbolic_set_fields = {"predicates"}

def __init__(self,
id: Optional[str],
happens_after: Union[
Expand Down Expand Up @@ -545,25 +539,7 @@ def _key_builder(self):
key_builder.update_for_class(self.__class__)

for field_name in self.fields:
field_value = getattr(self, field_name)
if field_name in self.pymbolic_fields:
key_builder.update_for_pymbolic_field(field_name, field_value)
elif field_name in self.pymbolic_set_fields:
# First sort the fields, as a canonical form
items = tuple(sorted(field_value, key=str))
key_builder.update_for_pymbolic_field(field_name, items)

# from CExpression
elif field_name == "iname_exprs":
from loopy.symbolic import EqualityPreservingStringifyMapper
key_builder.field_dict[field_name] = [
(iname, EqualityPreservingStringifyMapper()(expr)
.encode("utf-8"))
for iname, expr in self.iname_exprs
]

else:
key_builder.update_for_field(field_name, field_value)
key_builder.update_for_field(field_name, getattr(self, field_name))

return key_builder

Expand Down Expand Up @@ -841,7 +817,6 @@ class MultiAssignmentBase(InstructionBase):
"""An assignment instruction with an expression as a right-hand side."""

fields = InstructionBase.fields | {"expression"}
pymbolic_fields = InstructionBase.pymbolic_fields | {"expression"}

@memoize_method
def read_dependency_names(self):
Expand Down Expand Up @@ -933,7 +908,6 @@ class Assignment(MultiAssignmentBase):

fields = MultiAssignmentBase.fields | \
set("assignee temp_var_type atomicity".split())
pymbolic_fields = MultiAssignmentBase.pymbolic_fields | {"assignee"}

def __init__(self,
assignee: Union[str, ExpressionT],
Expand Down Expand Up @@ -979,7 +953,9 @@ def __init__(self,
if isinstance(assignee, str):
assignee = parse(assignee)
if isinstance(expression, str):
expression = parse(expression)
parsed_expression = parse(expression)
else:
parsed_expression = expression

from pymbolic.primitives import Lookup, Subscript, Variable

Expand All @@ -988,7 +964,7 @@ def __init__(self,
raise LoopyError("invalid lvalue '%s'" % assignee)

self.assignee = assignee
self.expression = expression
self.expression = parsed_expression

self.temp_var_type = _check_and_fix_temp_var_type(temp_var_type)
self.atomicity = atomicity
Expand Down Expand Up @@ -1092,7 +1068,6 @@ class CallInstruction(MultiAssignmentBase):

fields = MultiAssignmentBase.fields | \
set("assignees temp_var_types".split())
pymbolic_fields = MultiAssignmentBase.pymbolic_fields | {"assignees"}

def __init__(self,
assignees, expression,
Expand Down Expand Up @@ -1404,8 +1379,6 @@ class CInstruction(InstructionBase):

fields = InstructionBase.fields | \
set("iname_exprs code read_variables assignees".split())
pymbolic_fields = InstructionBase.pymbolic_fields | \
set("assignees".split())

def __init__(self,
iname_exprs, code,
Expand Down
23 changes: 9 additions & 14 deletions loopy/library/reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@
"""


from typing import ClassVar, Tuple

import numpy as np

from pymbolic import var
from pymbolic.primitives import expr_dataclass
from pytools.persistent_dict import Hash, KeyBuilder

from loopy.diagnostic import LoopyError
from loopy.kernel.function_interface import ScalarCallable
Expand Down Expand Up @@ -129,6 +129,10 @@ def __str__(self):

return result

def update_persistent_hash(self, key_hash: Hash, key_builder: KeyBuilder) -> None:
# They're all stateless.
key_builder.rec(key_hash, type(self))


class SumReductionOperation(ScalarReductionOperation):
def neutral_element(self, dtype, callables_table, target):
Expand Down Expand Up @@ -276,14 +280,9 @@ def __call__(self, dtype, operand1, operand2, callables_table, target):

# {{{ base class for symbolic reduction ops

@expr_dataclass()
class ReductionOpFunction(FunctionIdentifier):
init_arg_names: ClassVar[Tuple[str, ...]] = ("reduction_op",)

def __init__(self, reduction_op):
self.reduction_op = reduction_op

def __getinitargs__(self):
return (self.reduction_op,)
reduction_op: ReductionOperation

@property
def name(self):
Expand All @@ -295,11 +294,6 @@ def copy(self, reduction_op=None):

return type(self)(reduction_op)

hash_fields = (
"reduction_op",)

update_persistent_hash = update_persistent_hash

# }}}


Expand Down Expand Up @@ -413,6 +407,7 @@ class SegmentedProductReductionOperation(_SegmentedScalarReductionOperation):

# {{{ argmin/argmax

@expr_dataclass()
class ArgExtOp(ReductionOpFunction):
pass

Expand Down
Loading
Loading