Skip to content

Commit

Permalink
restructure
Browse files Browse the repository at this point in the history
- don't use dataclass anymore, implement update_persistent_hash
- test _sym_tag_to_num_tag
- (hopefully) fix __eq__
  • Loading branch information
matthiasdiener committed Sep 18, 2023
1 parent 9a63df4 commit 2b6630d
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 27 deletions.
12 changes: 10 additions & 2 deletions grudge/trace_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,14 +319,22 @@ def interior_trace_pair(dcoll: DiscretizationCollection, vec) -> TracePair:
return local_interior_trace_pair(dcoll, vec)


@dataclass(frozen=True) # for KeyBuilder support
class CommTag:
"""A communication tag with a hash value that is stable across
runs, even without setting ``PYTHONHASHSEED``."""

@memoize_method
def __hash__(self):
def __hash__(self) -> int:
return hash(tuple(str(type(self)).encode("ascii")))

def __eq__(self, other: object) -> bool:
return type(self) is type(other)

def update_persistent_hash(self, key_hash, key_builder):
key_builder.rec(key_hash, (self.__class__.__module__,
self.__class__.__qualname__))



def interior_trace_pairs(dcoll: DiscretizationCollection, vec, *,
comm_tag: Optional[Hashable] = None, tag: Hashable = None,
Expand Down
30 changes: 5 additions & 25 deletions test/test_trace_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@


import numpy as np
from grudge.trace_pair import TracePair, CommTag
from grudge.trace_pair import TracePair
import meshmode.mesh.generation as mgen
from meshmode.dof_array import DOFArray
from dataclasses import dataclass
Expand Down Expand Up @@ -72,6 +72,8 @@ def rand():

def test_commtag(actx_factory):

from grudge.trace_pair import CommTag, _sym_tag_to_num_tag

class DerivedCommTag(CommTag):
pass

Expand All @@ -86,6 +88,8 @@ class DerivedDerivedCommTag(DerivedCommTag):
dct2 = DerivedCommTag()
ddct = DerivedDerivedCommTag()

assert _sym_tag_to_num_tag(ct) == 441551355

assert ct == ct2
assert ct != dct
assert dct == dct2
Expand All @@ -110,27 +114,3 @@ class DerivedDerivedCommTag(DerivedCommTag):
assert hash((dct, ct)) == 6599529611285265043

# }}}

# {{{ test using derived dataclasses

@dataclass(frozen=True)
class DataCommTag(CommTag):
data: int

@dataclass(frozen=True)
class DataCommTag2(CommTag):
data: int

d1 = DataCommTag(1)
d2 = DataCommTag(2)
d3 = DataCommTag(1)

assert d1 != d2
assert hash(d1) != hash(d2)
assert d1 == d3
assert hash(d1) == hash(d3)

d4 = DataCommTag2(1)
assert d1 != d4

# }}}

0 comments on commit 2b6630d

Please sign in to comment.