Skip to content

Commit

Permalink
remove mutate, fix t_unit_to_python
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiasdiener committed Mar 6, 2024
1 parent 0dc0c75 commit c190e51
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 10 deletions.
18 changes: 9 additions & 9 deletions loopy/target/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,15 +806,15 @@ def get_typed_and_scheduled_translation_unit_uncached(
# FIXME: This is not so nice. This transfers types from the
# subarrays of sep-tagged arrays to the 'main' array, because
# type inference fails otherwise.
with arg_to_dtype.mutate() as mm:
for name, sep_info in self.sep_info.items():
if entry_knl.arg_dict[name].dtype is None:
for sep_name in sep_info.subarray_names.values():
if sep_name in arg_to_dtype:
mm.set(name, arg_to_dtype[sep_name])
del mm[sep_name]

arg_to_dtype = mm.finish()
mm = dict(arg_to_dtype)
for name, sep_info in self.sep_info.items():
if entry_knl.arg_dict[name].dtype is None:
for sep_name in sep_info.subarray_names.values():
if sep_name in arg_to_dtype:
mm[name] = arg_to_dtype[sep_name]
del mm[sep_name]

arg_to_dtype = Map(mm)

from loopy.kernel.tools import add_dtypes
t_unit = t_unit.with_kernel(add_dtypes(entry_knl, arg_to_dtype))
Expand Down
2 changes: 1 addition & 1 deletion loopy/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -864,7 +864,7 @@ def t_unit_to_python(t_unit, var_name="t_unit",
"import loopy as lp",
"import numpy as np",
"from pymbolic.primitives import *",
"import constantdict",
"from constantdict import constantdict",
])
body_str = "\n".join(knl_python_code_srcs + ["\n", merge_stmt])

Expand Down

0 comments on commit c190e51

Please sign in to comment.