forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_content_store.py
135 lines (117 loc) · 4.73 KB
/
test_content_store.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
# Owner(s): ["oncall: pt2"]
import tempfile
import unittest
import torch
from torch._prims.debug_prims import load_tensor_reader
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.multiprocessing.reductions import StorageWeakRef
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import (
IS_WINDOWS,
run_tests,
skipIfRocm,
TestCase,
)
from torch.utils._content_store import (
ContentStoreReader,
ContentStoreWriter,
hash_storage,
)
@unittest.skipIf(IS_WINDOWS, "Test case not supported on Windows")
class TestContentStore(TestCase):
def test_basic(self, device):
# setup test data
x = torch.randn(4, device=device)
y = torch.randn(6, device=device)
z = x.view(2, 2)
# start writing
with tempfile.TemporaryDirectory() as loc:
writer = ContentStoreWriter(loc)
writer.write_tensor("x", x)
writer.write_tensor("y", y)
writer.write_tensor("z", z)
# do some mutation that is VC UNTRACKED
x.data.add_(1)
writer.write_tensor("x2", x)
writer.write_tensor("y2", y)
writer.write_tensor("z2", z)
del writer
reader = ContentStoreReader(loc)
n_x = reader.read_tensor("x")
n_y = reader.read_tensor("y")
n_z = reader.read_tensor("z")
self.assertEqual(n_x + 1, x)
self.assertEqual(n_y, y)
self.assertEqual(n_z + 1, z)
self.assertEqual(
StorageWeakRef(n_x.untyped_storage()),
StorageWeakRef(n_z.untyped_storage()),
)
n_x2 = reader.read_tensor("x2")
n_y2 = reader.read_tensor("y2")
n_z2 = reader.read_tensor("z2")
self.assertEqual(n_x2, x)
self.assertEqual(n_y2, y)
self.assertEqual(n_z2, z)
self.assertEqual(
StorageWeakRef(n_y2.untyped_storage()),
StorageWeakRef(n_y.untyped_storage()),
)
def test_scalar(self, device):
# Should not raise an error
hash_storage(torch.tensor(2, device=device).untyped_storage())
@torch._dynamo.config.patch(cache_size_limit=1)
def test_repeated_hash(self, device):
# Test that repeated hashing doesn't trigger a recompile in dynamo
# If it does, we will execute prims.xor_sum in eager which fails
for _ in range(4):
hash_storage(torch.tensor(2, device=device).untyped_storage())
@skipIfRocm
def test_load_tensor(self, device):
with tempfile.TemporaryDirectory() as loc:
writer = ContentStoreWriter(loc)
x = torch.randn(4, device=device)
def same_meta_as_x(t):
self.assertEqual(t.size(), x.size())
self.assertEqual(t.stride(), x.stride())
self.assertEqual(t.dtype, x.dtype)
self.assertEqual(t.device, x.device)
writer.write_tensor("x", x)
with load_tensor_reader(loc):
x2 = torch.ops.debugprims.load_tensor.default(
"x", (4,), (1,), dtype=torch.float32, device=device
)
self.assertEqual(x, x2)
x3 = torch.ops.debugprims.load_tensor.default(
"x", (4,), (1,), dtype=torch.float32, device=device
)
self.assertEqual(x, x3)
# Must not alias!
self.assertNotEqual(
StorageWeakRef(x.untyped_storage()),
StorageWeakRef(x2.untyped_storage()),
)
self.assertNotEqual(
StorageWeakRef(x2.untyped_storage()),
StorageWeakRef(x3.untyped_storage()),
)
# Check fake tensor mode works too
with FakeTensorMode():
x4 = torch.ops.debugprims.load_tensor.default(
"x", (4,), (1,), dtype=torch.float32, device=device
)
self.assertIsInstance(x4, FakeTensor)
same_meta_as_x(x4)
# Check fp64 works
x5 = torch.ops.debugprims.load_tensor.default(
"x", (4,), (1,), dtype=torch.float64, device=device
)
self.assertEqual(x5.float(), x)
self.assertEqual(x5.dtype, torch.float64)
x6 = torch.ops.debugprims.load_tensor.default(
"x", (4,), (1,), dtype=torch.float32, device=device
)
same_meta_as_x(x6)
# Generate device-specific variants of the generic TestContentStore class into
# this module's namespace (hence ``globals()``). The generic test methods take
# a ``device`` argument, which these variants supply.
instantiate_device_type_tests(TestContentStore, globals())


# Allow running this test file directly as a script.
if __name__ == "__main__":
    run_tests()