From 5875df57031bd0c6c258065902f014bb6aad8737 Mon Sep 17 00:00:00 2001
From: Sami Abu-el-haija
Date: Wed, 21 Sep 2022 10:54:38 -0700
Subject: [PATCH] Support "valid_mask" for sampled edges.

The mask is useful in two cases: (1) sampling (outgoing) edges, with or
without replacement, for nodes with zero (outgoing) neighbors; and
(2) sampling without replacement, when asking for more neighbors than the
out-degree.

PiperOrigin-RevId: 475870387
---
 examples/in_memory/int_arithmetic_sampler.py  | 178 +++++++++++++-----
 .../in_memory/int_arithmetic_sampler_test.py  | 100 ++++++++--
 examples/in_memory/keras_minibatch_trainer.py |   4 +-
 3 files changed, 216 insertions(+), 66 deletions(-)

diff --git a/examples/in_memory/int_arithmetic_sampler.py b/examples/in_memory/int_arithmetic_sampler.py
index 8e9e7db9..6e82d00f 100644
--- a/examples/in_memory/int_arithmetic_sampler.py
+++ b/examples/in_memory/int_arithmetic_sampler.py
@@ -16,7 +16,7 @@
 
 The entry point is method `make_sampled_subgraphs_dataset()`, which accepts as
 input an in-memory graph dataset (from dataset.py) and `SamplingSpec`, and
-outputs tf.data.Dataset that generates subgraphs according to `SamplingSpec`.
+returns a tf.data.Dataset that generates subgraphs according to `SamplingSpec`.
 
 Specifically, `tf.data.Dataset` made by `make_sampled_subgraphs_dataset` wraps
 a generator that yields `GraphTensor`, consisting of sub-graphs, rooted at
@@ -32,9 +32,10 @@
 inmem_ds = datasets.get_dataset(dataset_name)
 
 # Craft sampling specification.
+sample_size = sample_size1 = 5
 graph_schema = dataset_wrapper.export_graph_schema()
 sampling_spec = (tfgnn.SamplingSpecBuilder(graph_schema)
-                 .seed().sample([3, 3]).to_sampling_spec())
+                 .seed().sample([sample_size, sample_size1]).to_sampling_spec())
 
 train_data = make_sampled_subgraphs_dataset(inmem_ds, sampling_spec)
 
@@ -46,20 +47,6 @@
 # composed of `tfgnn.keras.layers`.
 ```
 
-# Note
-
-This particular sampler expects that there are *no orphan nodes*. In particular,
-if sampling specification samples from edge-set with name "E", then every node
-must have *at least one* outgoing edge in edge-set "E". This "feature" can be
-fixed, e.g., by allowing zero-degree nodes to jump to special node, then get
-filtered upon output. However, we delay such completeness until we compare other
-sampling implementations e.g. ones that uses RaggedTensors to naturally
-accomodate variable-length neighborhoods.
-
-Nonetheless, if each node has at least one-edge, then sampling will be correct.
-If some node has less neighbors than required samples, then selection will
-contain repeatitions.
-
 # Algorithm & Implementation
 
 `make_sampled_subgraphs_dataset(ds)` returns a generator over object
@@ -99,11 +86,11 @@ class exposes function `random_walk_tree`, which describe below.
 
 Both of which are stored as tf.Tensor.
 
-After initialization, function `random_walk_tree` accepts(*) seed nodes
+After initialization, function `random_walk_tree` accepts seed nodes
 `[n1, n2, n3, ..., nB]`, i.e. with batch size `B`.
 
 
-NOTE: (*) generator `make_sampled_subgraphs_dataset` yield `GraphTensor`
+NOTE: generator `make_sampled_subgraphs_dataset` yields `GraphTensor`
 instances, each instance contains subgraphs rooted at a batch of nodes, which
 cycle from `ds.node_split().train`.
 
@@ -112,20 +99,31 @@ class exposes function `random_walk_tree`, which describe below.
 ```
                      sample(f1, 'cites')
   paper --------------------------> paper
-   \
-    \ sample(f2, 'rev_writes')     sample(f3, 'affiliated_with')
-     ---------------------------> author ------------------> institution
+  V1 \                              V2
+      \ sample(f2, 'rev_writes')     sample(f3, 'affiliated_with')
+       ---------------------------> author ------------------> institution
+                                    V3                        V4
 ```
 
-Instance nodes of `TypedWalkTree` (above) have attribute `nodes` with shapes:
-(B), (B, f1), (B, f2), (B, f2, f3) -- (left-to-right). All are `tf.Tensor`s
-with dtype `tf.int{32, 64}`, matching the dtype of its input argument.
+Instance nodes of `TypedWalkTree` (above) have attribute `nodes`, which is a
+`tf.Tensor`, depicted as V1, V2, V3, V4, with shapes, respectively, (B),
+(B, f1), (B, f2), (B, f2, f3). All have dtype `tf.int{32, 64}`, matching the
+dtype of the input argument to function `random_walk_tree`. For any node
+position `i`, node `V1[i]` has sampled edges pointing to nodes
+`V2[i, 0], V2[i, 1], ...`. The (`int`) `B` corresponds to the batch size, and
+the (`int`s) `f1, f2, ...` correspond to the `sample_size`s that can be
+configured in the `SamplingSpec` proto (below).
+
+Further, if the `sampling` strategy is `EdgeSampling.W_REPLACEMENT_W_ORPHANS`
+or `EdgeSampling.WO_REPLACEMENT_WO_REPEAT`, then each `TypedWalkTree` node will
+also contain attribute `valid_mask` (tf.Tensor with dtype tf.bool) with the
+same shape as `nodes`, which marks the positions in `nodes` that correspond to
+valid edges.
+
+## Building SamplingSpec
 
 Function `random_walk_tree` also requires argument `sampling_spec`, which
 controls the subgraph size, sampled around seed nodes. For the above example,
 `sampling_spec` instance can be built as, e.g.,:
-
 ```
 f2 = f1 = 5
 f3 = 3  # somewhat arbitrary.
@@ -143,7 +141,7 @@ class exposes function `random_walk_tree`, which describe below.
 import collections
 import enum
 import functools
-from typing import Any, Tuple, Callable, Mapping, Optional, MutableMapping, List
+from typing import Any, Tuple, Callable, Mapping, Optional, MutableMapping, List, Union
 
 import numpy as np
 import scipy.sparse as ssp
@@ -225,21 +223,45 @@ class TypedWalkTree:
   `TypedWalkTree`) with node features & labels, into `GraphTensor` instances.
   """
 
-  def __init__(self, nodes, owner=None):
+  def __init__(self, nodes: tf.Tensor, owner: Optional['GraphSampler'] = None,
+               valid_mask: Optional[tf.Tensor] = None):
     self._nodes = nodes
     self._next_steps = []
     self._owner = owner
+    if valid_mask is None:
+      self._valid_mask = tf.ones(shape=nodes.shape, dtype=tf.bool)
+    else:
+      self._valid_mask = valid_mask
 
   @property
   def nodes(self) -> tf.Tensor:
+    """int tf.Tensor with shape `[b, s1, s2, ..., sH]` where `b` is batch size.
+
+    `H` is the number of hops (up to this sampling step). Each int `si`
+    indicates the number of nodes sampled at step `i`.
+    """
     return self._nodes
 
+  @property
+  def valid_mask(self) -> Optional[tf.Tensor]:
+    """bool tf.Tensor with the same shape as `nodes`, marking "correct" samples.
+
+    If entry `valid_mask[i, j, k]` is True, then `nodes[i, j, k]` corresponds to
+    a node that is indeed a sampled neighbor of `previous_step.nodes[i, j]`.
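+
+    For example, under `WO_REPLACEMENT_WO_REPEAT` with `sample_size=3`, if node
+    `previous_step.nodes[i, j]` has only 2 outgoing edges, then row
+    `valid_mask[i, j]` would be `[True, True, False]`.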
+    """
+    return self._valid_mask
+
   @property
   def next_steps(self) -> List[Tuple[tfgnn.EdgeSetName, 'TypedWalkTree']]:
     return self._next_steps
 
-  def add_step(self, edge_set_name: tfgnn.EdgeSetName, nodes: tf.Tensor):
-    child_tree = TypedWalkTree(nodes, owner=self._owner)
+  def add_step(self, edge_set_name: tfgnn.EdgeSetName, nodes: tf.Tensor,
+               valid_mask: Optional[tf.Tensor] = None,
+               propagate_validation: bool = True) -> 'TypedWalkTree':
+    if propagate_validation and valid_mask is not None:
+      valid_mask = tf.logical_and(tf.expand_dims(self.valid_mask, -1),
+                                  valid_mask)
+    child_tree = TypedWalkTree(nodes, owner=self._owner, valid_mask=valid_mask)
     self._next_steps.append((edge_set_name, child_tree))
     return child_tree
 
@@ -342,8 +364,30 @@ def to_graph_tensor(
 
 
 class EdgeSampling(enum.Enum):
-  WITH_REPLACEMENT = 'with_replacement'
-  WITHOUT_REPLACEMENT = 'without_replacement'
+  """Enum of randomized strategies for sampling neighbors."""
+  # Samples each neighbor independently, with replacement. It assumes that
+  # *every node* has at least one outgoing neighbor, for all sampled edge-sets.
+  W_REPLACEMENT = 'w_replacement'
+
+  # Samples each neighbor independently, but tolerates nodes with zero outgoing
+  # edges. This option causes `sample_one_hop()` to also return `valid_mask`
+  # (boolean tf.Tensor) marking the positions that correspond to an actual
+  # edge; a position is False iff it was sampled from an orphan node.
+  W_REPLACEMENT_W_ORPHANS = 'w_replacement_w_orphans'
+
+  # Samples neighbors without replacement. However, if (int) `S` neighbors are
+  # requested and a node has only `s < S` neighbors, then its samples will be
+  # repeated. You *must* ensure that each node has at least one outgoing
+  # neighbor. If your graph has orphan nodes, use `WO_REPLACEMENT_WO_REPEAT` or
+  # `W_REPLACEMENT_W_ORPHANS`.
+  WO_REPLACEMENT = 'wo_replacement'
+
+  # Like the above, except that a node with fewer than `sample_size` neighbors
+  # has each neighbor sampled only once (no repetition). This option also works
+  # when some nodes have zero outgoing edges.
+  # This option causes `sample_one_hop()` to also return `valid_mask` (boolean
+  # tf.Tensor) marking the positions that correspond to an actual edge.
+  WO_REPLACEMENT_WO_REPEAT = 'wo_replacement_wo_repeat'
 
 
 class GraphSampler:
@@ -359,7 +403,7 @@ def __init__(self,
                make_undirected: bool = False,
                ensure_self_loops: bool = False,
                reduce_memory_footprint: bool = True,
-               sampling: EdgeSampling = EdgeSampling.WITHOUT_REPLACEMENT):
+               sampling: EdgeSampling = EdgeSampling.WO_REPLACEMENT):
     self.dataset = dataset
     self.sampling = sampling
     self.edge_types = {}  # edge set name -> (src node set name, dst *).
@@ -416,36 +460,51 @@ def make_sample_layer(self, edge_set_name, sample_size=3, sampling=None):
 
   def sample_one_hop(
       self, source_nodes: tf.Tensor, edge_set_name: tfgnn.EdgeSetName,
-      sample_size: int, sampling: Optional[EdgeSampling] = None) -> tf.Tensor:
+      sample_size: int, sampling: Optional[EdgeSampling] = None,
+      ) -> Union[tf.Tensor, Tuple[tf.Tensor, tf.Tensor]]:
     """Samples one-hop from source-nodes using edge `edge_set_name`."""
     if sampling is None:
-      sampling = EdgeSampling.WITH_REPLACEMENT
+      sampling = EdgeSampling.WO_REPLACEMENT
 
     all_degrees = self.degrees[edge_set_name]
     node_degrees = tf.gather(all_degrees, source_nodes)
 
    offsets = self.degrees_cumsum[edge_set_name]
 
+    next_nodes = valid_mask = None  # Answers, populated below.
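+    # Both branches below pick, per source node, `sample_size` positions into
+    # the flat edge list: node `v`'s outgoing edges occupy the contiguous index
+    # range [offsets[v], offsets[v] + node_degrees[v]).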
+
+    if sampling in (EdgeSampling.W_REPLACEMENT,
+                    EdgeSampling.W_REPLACEMENT_W_ORPHANS):
       sample_indices = tf.random.uniform(
           shape=source_nodes.shape + [sample_size], minval=0, maxval=1,
           dtype=tf.float32)
 
-      sample_indices = sample_indices * tf.cast(
-          tf.expand_dims(node_degrees, -1), tf.float32)
+      node_degrees_expanded = tf.expand_dims(node_degrees, -1)
+      sample_indices = sample_indices * tf.cast(node_degrees_expanded,
+                                                tf.float32)
       # According to https://www.pcg-random.org/posts/bounded-rands.html, this
       # sample is biased. NOTE: we plan to adopt one of the linked alternatives.
       sample_indices = tf.cast(tf.math.floor(sample_indices), tf.int64)
+
+      if sampling == EdgeSampling.W_REPLACEMENT_W_ORPHANS:
+        valid_mask = sample_indices < node_degrees_expanded
+
       # Shape: source_nodes.shape + [sample_size].
       sample_indices += tf.expand_dims(tf.gather(offsets, source_nodes), -1)
       nonzero_cols = self.edge_lists[edge_set_name][1]
+      if sampling == EdgeSampling.W_REPLACEMENT_W_ORPHANS:
+        sample_indices = tf.where(
+            valid_mask, sample_indices, tf.zeros_like(sample_indices))
       next_nodes = tf.gather(nonzero_cols, sample_indices)
-    elif sampling == EdgeSampling.WITHOUT_REPLACEMENT:
+    elif sampling in (EdgeSampling.WO_REPLACEMENT,
+                      EdgeSampling.WO_REPLACEMENT_WO_REPEAT):
       # shape=(total_input_nodes).
       nodes_reshaped = tf.reshape(source_nodes, [-1])
       # shape=(total_input_nodes).
       reshaped_node_degrees = tf.reshape(node_degrees, [-1])
+      reshaped_node_degrees_or_1 = tf.maximum(
+          reshaped_node_degrees, tf.ones_like(reshaped_node_degrees))
       # shape=(sample_size, total_input_nodes).
       sample_upto = tf.stack([reshaped_node_degrees] * sample_size, axis=0)
 
@@ -453,7 +512,13 @@ def sample_one_hop(
       subtract_mod = tf.stack(
           [tf.range(sample_size, dtype=tf.int64)] * nodes_reshaped.shape[0],
           axis=-1)
-      subtract_mod = subtract_mod % sample_upto
+      if sampling == EdgeSampling.WO_REPLACEMENT_WO_REPEAT:
+        valid_mask = subtract_mod < reshaped_node_degrees
+        valid_mask = tf.reshape(
+            tf.transpose(valid_mask), source_nodes.shape + [sample_size])
+
+      subtract_mod = subtract_mod % tf.maximum(
+          sample_upto, tf.ones_like(sample_upto))
 
       # [[d, d-1, d-2, ... 1, d, d-1, ...]].T
       # where 'd' is degree of node in row corresponding to nodes_reshaped.
@@ -475,7 +540,7 @@ def sample_one_hop(
 
       for i in range(1, sample_size):
         already_sampled = tf.where(
-            i % reshaped_node_degrees == 0,
+            i % reshaped_node_degrees_or_1 == 0,
             tf.ones_like(already_sampled) * max_degree,
             already_sampled)
         next_sample = sample_indices[i]
         for j in range(i):
@@ -493,10 +558,13 @@ def sample_one_hop(
       sample_indices += tf.expand_dims(tf.gather(offsets, nodes_reshaped), 0)
 
       sample_indices = tf.reshape(tf.transpose(sample_indices),
-                                  [source_nodes.shape[0], -1])
+                                  source_nodes.shape + [sample_size])
       nonzero_cols = self.edge_lists[edge_set_name][1]
+      if sampling == EdgeSampling.WO_REPLACEMENT_WO_REPEAT:
+        sample_indices = tf.where(
+            valid_mask, sample_indices, tf.zeros_like(sample_indices))
+
       next_nodes = tf.gather(nonzero_cols, sample_indices)
-      next_nodes = tf.reshape(next_nodes, source_nodes.shape + [sample_size])
     else:
       raise ValueError('Unknown sampling ' + str(sampling))
 
@@ -504,13 +572,16 @@ def sample_one_hop(
     # It could happen, e.g., if edge-list is int32 and input seed is int64.
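+    # Hence, cast `next_nodes` to the dtype of `source_nodes`.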
     next_nodes = tf.cast(next_nodes, source_nodes.dtype)
 
-    return next_nodes
+    if valid_mask is None:
+      return next_nodes
+    else:
+      return next_nodes, valid_mask
 
   def generate_subgraphs(
       self, batch_size: int,
       sampling_spec: sampling_spec_pb2.SamplingSpec,
       split: str = 'train',
-      sampling=EdgeSampling.WITH_REPLACEMENT):
+      sampling=EdgeSampling.WO_REPLACEMENT):
     """Infinitely yields random subgraphs, each rooted at a node in train set."""
     if isinstance(split, bytes):
       split = split.decode()
@@ -532,7 +603,7 @@ def generate_subgraphs(
   def random_walk_tree(
       self, node_idx: tf.Tensor,
       sampling_spec: sampling_spec_pb2.SamplingSpec,
-      sampling: EdgeSampling = EdgeSampling.WITH_REPLACEMENT) -> TypedWalkTree:
+      sampling: EdgeSampling = EdgeSampling.WO_REPLACEMENT) -> TypedWalkTree:
     """Returns `TypedWalkTree` where `nodes` are seed root-nodes.
 
     Args:
@@ -566,8 +637,13 @@ def process_sampling_op(sampling_op: sampling_spec_pb2.SamplingOp):
       next_nodes = self.sample_one_hop(
           parent_nodes, sampling_op.edge_set_name,
           sample_size=sampling_op.sample_size, sampling=sampling)
-      child_tree = parent_trees[0].add_step(
-          sampling_op.edge_set_name, next_nodes)
+      if isinstance(next_nodes, tuple):
+        next_nodes, valid_mask = next_nodes
+        child_tree = parent_trees[0].add_step(
+            sampling_op.edge_set_name, next_nodes, valid_mask=valid_mask)
+      else:
+        child_tree = parent_trees[0].add_step(
+            sampling_op.edge_set_name, next_nodes)
 
       op_name_to_tree[sampling_op.op_name] = child_tree
 
@@ -581,7 +657,7 @@ def process_sampling_op(sampling_op: sampling_spec_pb2.SamplingOp):
 
   def sample_sub_graph_tensor(
       self, node_idx: tf.Tensor,
       sampling_spec: sampling_spec_pb2.SamplingSpec,
-      sampling: EdgeSampling = EdgeSampling.WITH_REPLACEMENT
+      sampling: EdgeSampling = EdgeSampling.WO_REPLACEMENT
   ) -> tfgnn.GraphTensor:
     """Samples GraphTensor starting from seed nodes `node_idx`.
 
@@ -589,9 +665,9 @@ def sample_sub_graph_tensor(
       node_idx: (int) tf.Tensor of node indices to seed random-walk trees.
       sampling_spec: Specifies the hops (edge set names) to be sampled, and the
         number of sampled edges per hop.
-      sampling: If `== EdgeSampling.WITH_REPLACEMENT`, then neighbors for a node
+      sampling: If `== EdgeSampling.W_REPLACEMENT`, then neighbors for a node
        will be sampled uniformly and independently. If
-        `== EdgeSampling.WITHOUT_REPLACEMENT`, then a node's neighbors will be
+        `== EdgeSampling.WO_REPLACEMENT`, then a node's neighbors will be
        chosen in (random) round-robin order.
        If more samples are requested than there are neighbors, then the samples
        will be repeated (each time, in a different random order), such that all
        neighbors appear exactly the
@@ -621,7 +697,7 @@ def make_sampled_subgraphs_dataset(
     batch_size: int = 64,
     split='train',
     make_undirected: bool = False,
-    sampling=EdgeSampling.WITH_REPLACEMENT
+    sampling=EdgeSampling.WO_REPLACEMENT
 ) -> Tuple[tf.TensorSpec, tf.data.Dataset]:
   """Infinite tf.data.Dataset wrapping generate_subgraphs."""
   subgraph_generator = GraphSampler(dataset, make_undirected=make_undirected)
diff --git a/examples/in_memory/int_arithmetic_sampler_test.py b/examples/in_memory/int_arithmetic_sampler_test.py
index 6e5047d5..6e509151 100644
--- a/examples/in_memory/int_arithmetic_sampler_test.py
+++ b/examples/in_memory/int_arithmetic_sampler_test.py
@@ -25,6 +25,7 @@
 
 import datasets
 from tensorflow_gnn.examples.in_memory import int_arithmetic_sampler as ia_sampler
+from tensorflow_gnn.sampler import sampling_spec_builder
 
 
 class ToyDataset(datasets.NodeClassificationDatasetWrapper):
@@ -35,6 +36,7 @@ def __init__(self):
         'cat': ['spider', 'rat', 'water'],
         'monkey': ['banana', 'water'],
         'cow': ['banana', 'water'],
+        'unicorn': [],
     }
 
     food_set = set()
@@ -69,6 +71,25 @@ def node_features_dicts(self, add_id=True) -> Mapping[
       }
     }
 
+  def num_classes(self) -> int:
+    return 10
+
+  def node_split(self) -> datasets.NodeSplit:
+    empty_vector = tf.zeros([0], dtype=tf.int32)
+    return datasets.NodeSplit(
+        train=tf.range(len(self.id2animal), dtype=tf.int32),
+        test=empty_vector, valid=empty_vector)
+
+  def labels(self) -> tf.Tensor:
+    return tf.ones(len(self.id2animal), dtype=tf.int32)
+
+  def test_labels(self) -> tf.Tensor:
+    return self.labels()
+
+  @property
+  def labeled_nodeset(self) -> str:
+    return 'animals'
+
   def node_counts(self) -> Mapping[tfgnn.NodeSetName, int]:
     return {'food': len(self.id2food), 'animals': len(self.id2animal)}
 
@@ -78,11 +99,11 @@ def edge_lists(self) -> Mapping[Tuple[str, str, str], tf.Tensor]:
     }
 
 
-class IntArithmeticSamplerTest(parameterized.TestCase):
+class IntArithmeticSamplerTest(tf.test.TestCase, parameterized.TestCase):
 
   @parameterized.named_parameters(
-      ('WithReplacement', ia_sampler.EdgeSampling.WITH_REPLACEMENT),
-      ('WithoutReplacement', ia_sampler.EdgeSampling.WITHOUT_REPLACEMENT))
+      ('WithReplacement', ia_sampler.EdgeSampling.W_REPLACEMENT),
+      ('WithoutReplacement', ia_sampler.EdgeSampling.WO_REPLACEMENT))
   def test_sample_one_hop(self, strategy):
     toy_dataset = ToyDataset()
     sampler = ia_sampler.GraphSampler(toy_dataset)
@@ -95,7 +116,8 @@ def test_sample_one_hop(self, strategy):
     next_hop = sampler.sample_one_hop(
         tf.convert_to_tensor(source_node_ids), 'eats',
         sample_size=sample_size, sampling=strategy)
-
+    self.assertIsInstance(next_hop, tf.Tensor)
+    assert isinstance(next_hop, tf.Tensor)  # Assert needed to access `.shape`.
     self.assertEqual(next_hop.shape[0], source_node_ids.shape[0])
     self.assertEqual(next_hop.shape[1], sample_size)
 
@@ -118,14 +140,14 @@ def test_sample_one_hop(self, strategy):
     # Each node has 2 neighbors. Make sure that nothing is "sampled too much".
     self.assertGreater(min(sampled_counts.values()), sample_size // 10)
 
-    if strategy == ia_sampler.EdgeSampling.WITHOUT_REPLACEMENT:
-      # WITHOUT_REPLACEMENT is fair. It will pick each item once, then
+    if strategy == ia_sampler.EdgeSampling.WO_REPLACEMENT:
+      # WO_REPLACEMENT is fair. It will pick each item once, then
       # re-iterate (in another random order).
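+      # Hence, with 2 neighbors per node and `sample_size` samples, each
+      # neighbor is picked exactly `sample_size // 2` times.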
self.assertEqual(min(sampled_counts.values()), sample_size // 2) @parameterized.named_parameters( - ('WithReplacement', ia_sampler.EdgeSampling.WITH_REPLACEMENT), - ('WithoutReplacement', ia_sampler.EdgeSampling.WITHOUT_REPLACEMENT)) + ('WithReplacement', ia_sampler.EdgeSampling.W_REPLACEMENT), + ('WithoutReplacement', ia_sampler.EdgeSampling.WO_REPLACEMENT)) def test_sample_two_hops(self, strategy): toy_dataset = ToyDataset() sampler = ia_sampler.GraphSampler(toy_dataset) @@ -139,11 +161,13 @@ def test_sample_two_hops(self, strategy): hop1 = sampler.sample_one_hop( tf.convert_to_tensor(source_node_ids), 'eats', sample_size=hop1_size, sampling=strategy) - + self.assertIsInstance(hop1, tf.Tensor) + assert isinstance(hop1, tf.Tensor) # Assert needed to access `.shape`. hop2_size = 20 hop2 = sampler.sample_one_hop( hop1, 'rev_eats', sample_size=hop2_size, sampling=strategy) - + self.assertIsInstance(hop2, tf.Tensor) + assert isinstance(hop2, tf.Tensor) # Assert needed to access `.shape`. self.assertEqual(hop1.shape, (batch_size, hop1_size)) self.assertEqual(hop2.shape, (batch_size, hop1_size, hop2_size)) np_edgelist = toy_dataset.eats_edgelist.numpy() @@ -169,8 +193,8 @@ def test_sample_two_hops(self, strategy): # Each node has 2 neighbors. Make sure that nothing is "sampled too much" self.assertGreater(min(sampled_counts.values()), hop1_size // 10) - if strategy == ia_sampler.EdgeSampling.WITHOUT_REPLACEMENT: - # WITHOUT_REPLACEMENT is fair. It will pick each item once, then + if strategy == ia_sampler.EdgeSampling.WO_REPLACEMENT: + # WO_REPLACEMENT is fair. It will pick each item once, then # re-iterate (in another random order). self.assertEqual(min(sampled_counts.values()), hop1_size // 2) @@ -185,11 +209,61 @@ def test_sample_two_hops(self, strategy): self.assertEmpty( set(sampled_hop2_counts.keys()).difference(actual_edge_set)) - if strategy == ia_sampler.EdgeSampling.WITHOUT_REPLACEMENT: + if strategy == ia_sampler.EdgeSampling.WO_REPLACEMENT: max_count = max(sampled_hop2_counts.values()) min_count = min(sampled_hop2_counts.values()) self.assertLessEqual(max_count - min_count, 1) # Fair sampling. + @parameterized.named_parameters( + ('WithReplacementWithOrphans', + ia_sampler.EdgeSampling.W_REPLACEMENT_W_ORPHANS), + ('WithoutReplacementWithoutRepeat', + ia_sampler.EdgeSampling.WO_REPLACEMENT_WO_REPEAT)) + def test_sample_random_walk_tree_with_validation(self, strategy): + toy_dataset = ToyDataset() + sampler = ia_sampler.GraphSampler(toy_dataset) + source_node_names = ['dog', 'unicorn'] + source_node_ids = [toy_dataset.animal2id[name] + for name in source_node_names] + source_node_ids = np.array(source_node_ids) + source_node_ids = tf.convert_to_tensor(source_node_ids) + + toy_graph_schema = toy_dataset.export_graph_schema() + + hop1_samples = hop2_samples = 10 + spec = sampling_spec_builder.SamplingSpecBuilder( + toy_graph_schema, + default_strategy=sampling_spec_builder.SamplingStrategy.RANDOM_UNIFORM) + spec = (spec.seed('animals').sample(hop1_samples, 'eats') + .sample(hop2_samples, 'rev_eats').to_sampling_spec()) + walk_tree = sampler.random_walk_tree( + source_node_ids, spec, sampling=strategy) + + # Root node contains source nodes, all of which are valid. + self.assertAllEqual(walk_tree.nodes, source_node_ids) + self.assertAllEqual(walk_tree.valid_mask, + tf.ones(shape=source_node_ids.shape, dtype=tf.bool)) + self.assertLen(walk_tree.next_steps, 1) # Sampled one edge from root. + self.assertEqual(walk_tree.next_steps[0][0], 'eats') # Sampled edge 'eats'. 
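+    # Each `next_steps` entry is an (edge_set_name, child TypedWalkTree) pair.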
+    hop1 = walk_tree.next_steps[0][1]
+
+    self.assertLen(hop1.next_steps, 1)  # Sampled one edge from hop1.
+    self.assertEqual(hop1.next_steps[0][0], 'rev_eats')
+    hop2 = hop1.next_steps[0][1]
+
+    if strategy == ia_sampler.EdgeSampling.W_REPLACEMENT_W_ORPHANS:
+      self.assertTrue(np.all(hop1.valid_mask[0]))  # dog eats some things.
+      self.assertFalse(np.any(hop1.valid_mask[1]))  # unicorn eats nothing.
+      # Validity should be propagated.
+      self.assertFalse(np.any(hop2.valid_mask[1]))
+    elif strategy == ia_sampler.EdgeSampling.WO_REPLACEMENT_WO_REPEAT:
+      self.assertTrue(np.all(hop1.valid_mask[0, :2]))   # dog eats 2 things...
+      self.assertFalse(np.any(hop1.valid_mask[0, 2:]))  # ...and nothing more.
+      self.assertFalse(np.any(hop1.valid_mask[1]))  # unicorn eats nothing.
+      # Validity should be propagated.
+      self.assertFalse(np.any(hop2.valid_mask[0, 2:]))
+      self.assertFalse(np.any(hop2.valid_mask[1]))
+
 
 if __name__ == '__main__':
   tf.test.main()
diff --git a/examples/in_memory/keras_minibatch_trainer.py b/examples/in_memory/keras_minibatch_trainer.py
index d0706aa1..ec0012a0 100644
--- a/examples/in_memory/keras_minibatch_trainer.py
+++ b/examples/in_memory/keras_minibatch_trainer.py
@@ -97,7 +97,7 @@ def init_node_state(node_set, node_set_name):
   _, train_dataset = int_arithmetic_sampler.make_sampled_subgraphs_dataset(
       dataset_wrapper, sampling_spec=train_sampling_spec,
       batch_size=FLAGS.batch_size,
-      sampling=int_arithmetic_sampler.EdgeSampling.WITH_REPLACEMENT,
+      sampling=int_arithmetic_sampler.EdgeSampling.W_REPLACEMENT,
       make_undirected=prefers_undirected)
 
   train_labels_dataset = train_dataset.map(
@@ -107,7 +107,7 @@ def init_node_state(node_set, node_set_name):
   _, validation_ds = int_arithmetic_sampler.make_sampled_subgraphs_dataset(
       dataset_wrapper, sampling_spec=train_sampling_spec,
       batch_size=FLAGS.batch_size,
-      sampling=int_arithmetic_sampler.EdgeSampling.WITHOUT_REPLACEMENT,
+      sampling=int_arithmetic_sampler.EdgeSampling.WO_REPLACEMENT,
       split='valid', make_undirected=prefers_undirected)
 
   validation_ds = validation_ds.map(
       functools.partial(reader_utils.pair_graphs_with_labels, num_classes))
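
A minimal sketch of consuming the new mask, assuming the `ToyDataset` class
from `int_arithmetic_sampler_test.py` above; everything else uses only APIs
touched by this patch.

```
import tensorflow as tf

from tensorflow_gnn.examples.in_memory import int_arithmetic_sampler as ia_sampler

toy_dataset = ToyDataset()  # Includes 'unicorn', which eats nothing.
sampler = ia_sampler.GraphSampler(toy_dataset)
source = tf.convert_to_tensor(
    [toy_dataset.animal2id['dog'], toy_dataset.animal2id['unicorn']])

# This strategy returns a (nodes, valid_mask) tuple, per the `EdgeSampling`
# comments in int_arithmetic_sampler.py.
next_nodes, valid_mask = sampler.sample_one_hop(
    source, 'eats', sample_size=4,
    sampling=ia_sampler.EdgeSampling.WO_REPLACEMENT_WO_REPEAT)

# Invalid positions hold a placeholder node (index 0 of the flat edge list is
# gathered there), so downstream code should mask them out, e.g.:
num_valid = tf.reduce_sum(tf.cast(valid_mask, tf.int32), axis=-1)
# num_valid == [2, 0]: 'dog' has out-degree 2; 'unicorn' is an orphan.
```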