From 28b7cf5aaf957e7311b501fe336c30330d398334 Mon Sep 17 00:00:00 2001 From: Ted Goddard Date: Fri, 26 Jun 2020 16:09:55 -0600 Subject: [PATCH] Update to support Allen NLP 1.0 (#25) * Update requirements.txt * Rename decode to make_output_human_readable * Rename decode to make_output_human_readable * Rename decode to make_output_human_readable * Rename decode to make_output_human_readable * remove conversion to float * fix tests * fix configs * flake * more flake... * change test names * fix more tests Co-authored-by: Matt Gardner --- allennlp_semparse/dataset_readers/atis.py | 2 +- .../domain_languages/domain_language.py | 5 + .../fields/knowledge_graph_field.py | 3 + .../models/atis/atis_semantic_parser.py | 9 +- .../nlvr/nlvr_coverage_semantic_parser.py | 20 +--- .../models/nlvr/nlvr_semantic_parser.py | 6 +- allennlp_semparse/models/text2sql_parser.py | 6 +- .../wikitables_mml_semantic_parser.py | 3 + .../wikitables/wikitables_semantic_parser.py | 8 +- .../worlds/atis_world.py | 6 +- requirements.txt | 6 +- test_fixtures/atis/experiment.json | 3 +- .../config/characters_token_embedder.json | 2 +- .../experiment.json | 11 +- .../mml_init_experiment.json | 11 +- .../ungrouped_experiment.json | 11 +- .../experiment.json | 10 +- test_fixtures/text2sql/experiment.json | 3 +- .../experiment-elmo-no-features.json | 10 +- test_fixtures/wikitables/experiment-erm.json | 6 +- .../wikitables/experiment-mixture.json | 10 +- test_fixtures/wikitables/experiment.json | 10 +- tests/common/action_space_walker_test.py | 110 +++++++----------- tests/common/sql/text2sql_utils_test.py | 6 +- .../wikitables/table_question_context_test.py | 4 +- .../grammar_based_text2sql_test.py | 4 +- tests/dataset_readers/wikitables_test.py | 4 +- .../domain_languages/domain_language_test.py | 6 +- tests/domain_languages/nlvr_language_test.py | 4 +- .../wikitables_language_test.py | 42 +++---- tests/fields/knowledge_graph_field_test.py | 6 +- tests/fields/production_rule_field_test.py | 4 +- .../models/atis/atis_grammar_statelet_test.py | 3 +- .../models/atis/atis_semantic_parser_test.py | 6 +- .../nlvr_coverage_semantic_parser_test.py | 25 +--- .../nlvr/nlvr_direct_semantic_parser_test.py | 6 +- tests/models/text2sql_parser_test.py | 6 +- .../wikitables_erm_semantic_parser_test.py | 6 +- .../wikitables_mml_semantic_parser_test.py | 17 +-- tests/nltk_languages/worlds/world_test.py | 4 +- .../executors/sql_executor_test.py | 6 +- .../worlds/atis_world_test.py | 8 +- .../worlds/text2sql_world_test.py | 7 +- tests/predictors/nlvr_parser_test.py | 4 +- .../expected_risk_minimization_test.py | 4 +- .../maximum_marginal_likelihood_test.py | 4 +- .../basic_transition_function_test.py | 10 +- 47 files changed, 218 insertions(+), 249 deletions(-) diff --git a/allennlp_semparse/dataset_readers/atis.py b/allennlp_semparse/dataset_readers/atis.py index c25cc0b..138aee5 100644 --- a/allennlp_semparse/dataset_readers/atis.py +++ b/allennlp_semparse/dataset_readers/atis.py @@ -142,7 +142,7 @@ def text_to_instance( # type: ignore action_sequence = world.get_action_sequence(sql_query) except ParseError: action_sequence = [] - logger.debug(f"Parsing error") + logger.debug("Parsing error") tokenized_utterance = self._tokenizer.tokenize(utterance.lower()) utterance_field = TextField(tokenized_utterance, self._token_indexers) diff --git a/allennlp_semparse/domain_languages/domain_language.py b/allennlp_semparse/domain_languages/domain_language.py index fc144d6..09977fd 100644 --- a/allennlp_semparse/domain_languages/domain_language.py +++ b/allennlp_semparse/domain_languages/domain_language.py @@ -746,3 +746,8 @@ def _construct_node_from_actions( Tree(right_side, []) ) # you add a child to an nltk.Tree with `append` return remaining_actions + + def __len__(self): + # This method exists just to make it easier to use this in a MetadataField. Kind of + # annoying, but oh well. + return 0 diff --git a/allennlp_semparse/fields/knowledge_graph_field.py b/allennlp_semparse/fields/knowledge_graph_field.py index 36d06a0..68d1e41 100644 --- a/allennlp_semparse/fields/knowledge_graph_field.py +++ b/allennlp_semparse/fields/knowledge_graph_field.py @@ -185,6 +185,9 @@ def count_vocab_items(self, counter: Dict[str, Dict[str, int]]): def index(self, vocab: Vocabulary): self._entity_text_field.index(vocab) + def __len__(self) -> int: + return len(self.utterance_tokens) + @overrides def get_padding_lengths(self) -> Dict[str, int]: padding_lengths = { diff --git a/allennlp_semparse/models/atis/atis_semantic_parser.py b/allennlp_semparse/models/atis/atis_semantic_parser.py index 679cd6a..161c966 100644 --- a/allennlp_semparse/models/atis/atis_semantic_parser.py +++ b/allennlp_semparse/models/atis/atis_semantic_parser.py @@ -280,7 +280,7 @@ def _get_initial_state( linking_scores: torch.Tensor, ) -> GrammarBasedState: embedded_utterance = self._utterance_embedder(utterance) - utterance_mask = util.get_text_field_mask(utterance).float() + utterance_mask = util.get_text_field_mask(utterance) batch_size = embedded_utterance.size(0) num_entities = max([len(world.entities) for world in worlds]) @@ -546,7 +546,9 @@ def _create_grammar_state( return GrammarStatelet(["statement"], translated_valid_actions, self.is_nonterminal) @overrides - def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + def make_output_human_readable( + self, output_dict: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: """ This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test time, to finalize predictions. This is (confusingly) a separate notion from the "decoder" @@ -581,3 +583,6 @@ def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor batch_action_info.append(instance_action_info) output_dict["predicted_actions"] = batch_action_info return output_dict + + +default_predictor = "atis-parser" diff --git a/allennlp_semparse/models/nlvr/nlvr_coverage_semantic_parser.py b/allennlp_semparse/models/nlvr/nlvr_coverage_semantic_parser.py index 79a7ac1..8a37c3c 100644 --- a/allennlp_semparse/models/nlvr/nlvr_coverage_semantic_parser.py +++ b/allennlp_semparse/models/nlvr/nlvr_coverage_semantic_parser.py @@ -207,27 +207,19 @@ def forward( agenda: torch.LongTensor, identifier: List[str] = None, labels: torch.LongTensor = None, - epoch_num: List[int] = None, metadata: List[Dict[str, Any]] = None, ) -> Dict[str, torch.Tensor]: """ Decoder logic for producing type constrained target sequences that maximize coverage of their respective agendas, and minimize a denotation based loss. """ - # We look at the epoch number and adjust the checklist cost weight if needed here. - instance_epoch_num = epoch_num[0] if epoch_num is not None else None if self._dynamic_cost_rate is not None: - if self.training and instance_epoch_num is None: - raise RuntimeError( - "If you want a dynamic cost weight, use the " - "BucketIterator with track_epoch=True." - ) - if instance_epoch_num != self._last_epoch_in_forward: - if instance_epoch_num >= self._dynamic_cost_wait_epochs: - decrement = self._checklist_cost_weight * self._dynamic_cost_rate - self._checklist_cost_weight -= decrement - logger.info("Checklist cost weight is now %f", self._checklist_cost_weight) - self._last_epoch_in_forward = instance_epoch_num + # This could be added back pretty easily with an EpochCallback passed to the Trainer (it + # just has to set the epoch number on the model, which could then be queried in here). + logger.warning( + "Dynamic cost rate functionality was removed in AllenNLP 1.0. If you want this, " + "use version 0.9. We will just use the static checklist cost weight." + ) batch_size = len(worlds) initial_rnn_state = self._get_initial_rnn_state(sentence) diff --git a/allennlp_semparse/models/nlvr/nlvr_semantic_parser.py b/allennlp_semparse/models/nlvr/nlvr_semantic_parser.py index ac55d9a..ae0a934 100644 --- a/allennlp_semparse/models/nlvr/nlvr_semantic_parser.py +++ b/allennlp_semparse/models/nlvr/nlvr_semantic_parser.py @@ -84,7 +84,7 @@ def forward(self): # type: ignore def _get_initial_rnn_state(self, sentence: Dict[str, torch.LongTensor]): embedded_input = self._sentence_embedder(sentence) # (batch_size, sentence_length) - sentence_mask = util.get_text_field_mask(sentence).float() + sentence_mask = util.get_text_field_mask(sentence) batch_size = embedded_input.size(0) @@ -217,7 +217,9 @@ def _create_grammar_state( return GrammarStatelet([START_SYMBOL], translated_valid_actions, world.is_nonterminal) @overrides - def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + def make_output_human_readable( + self, output_dict: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: """ This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test time, to finalize predictions. We only transform the action string sequences into logical diff --git a/allennlp_semparse/models/text2sql_parser.py b/allennlp_semparse/models/text2sql_parser.py index 55e4c6b..7111b9e 100644 --- a/allennlp_semparse/models/text2sql_parser.py +++ b/allennlp_semparse/models/text2sql_parser.py @@ -143,7 +143,7 @@ def forward( trailing dimension. """ embedded_utterance = self._utterance_embedder(tokens) - mask = util.get_text_field_mask(tokens).float() + mask = util.get_text_field_mask(tokens) batch_size = embedded_utterance.size(0) # (batch_size, num_tokens, encoder_output_dim) @@ -385,7 +385,9 @@ def _create_grammar_state(self, possible_actions: List[ProductionRule]) -> Gramm ) @overrides - def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + def make_output_human_readable( + self, output_dict: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: """ This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test time, to finalize predictions. This is (confusingly) a separate notion from the "decoder" diff --git a/allennlp_semparse/models/wikitables/wikitables_mml_semantic_parser.py b/allennlp_semparse/models/wikitables/wikitables_mml_semantic_parser.py index a71e7bf..59945fb 100644 --- a/allennlp_semparse/models/wikitables/wikitables_mml_semantic_parser.py +++ b/allennlp_semparse/models/wikitables/wikitables_mml_semantic_parser.py @@ -233,3 +233,6 @@ def forward( actions, best_final_states, world, target_values, metadata, outputs ) return outputs + + +default_predictor = "wikitables-parser" diff --git a/allennlp_semparse/models/wikitables/wikitables_semantic_parser.py b/allennlp_semparse/models/wikitables/wikitables_semantic_parser.py index 49960b1..cb7d211 100644 --- a/allennlp_semparse/models/wikitables/wikitables_semantic_parser.py +++ b/allennlp_semparse/models/wikitables/wikitables_semantic_parser.py @@ -169,10 +169,10 @@ def _get_initial_rnn_and_grammar_state( table_text = table["text"] # (batch_size, question_length, embedding_dim) embedded_question = self._question_embedder(question) - question_mask = util.get_text_field_mask(question).float() + question_mask = util.get_text_field_mask(question) # (batch_size, num_entities, num_entity_tokens, embedding_dim) embedded_table = self._question_embedder(table_text, num_wrapping_dims=1) - table_mask = util.get_text_field_mask(table_text, num_wrapping_dims=1).float() + table_mask = util.get_text_field_mask(table_text, num_wrapping_dims=1) batch_size, num_entities, num_entity_tokens, _ = embedded_table.size() num_question_tokens = embedded_question.size(1) @@ -740,7 +740,9 @@ def _compute_validation_outputs( outputs["question_tokens"] = [x["question_tokens"] for x in metadata] @overrides - def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + def make_output_human_readable( + self, output_dict: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: """ This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test time, to finalize predictions. This is (confusingly) a separate notion from the "decoder" diff --git a/allennlp_semparse/parsimonious_languages/worlds/atis_world.py b/allennlp_semparse/parsimonious_languages/worlds/atis_world.py index c9a79d7..996c513 100644 --- a/allennlp_semparse/parsimonious_languages/worlds/atis_world.py +++ b/allennlp_semparse/parsimonious_languages/worlds/atis_world.py @@ -113,7 +113,7 @@ def _update_grammar(self): new_grammar["col_ref"], Literal("BETWEEN"), new_grammar["time_range_start"], - Literal(f"AND"), + Literal("AND"), new_grammar["time_range_end"], ], ), @@ -124,7 +124,7 @@ def _update_grammar(self): Literal("NOT"), Literal("BETWEEN"), new_grammar["time_range_start"], - Literal(f"AND"), + Literal("AND"), new_grammar["time_range_end"], ], ), @@ -135,7 +135,7 @@ def _update_grammar(self): Literal("not"), Literal("BETWEEN"), new_grammar["time_range_start"], - Literal(f"AND"), + Literal("AND"), new_grammar["time_range_end"], ], ), diff --git a/requirements.txt b/requirements.txt index e3e638e..cccf7af 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,4 @@ -# To be changed to allennlp>=1.0 once that's released. -git+git://github.com/allenai/allennlp@fa14bd8c177a12709a2fac0616eafaa9ad430b0a +allennlp==1.0 # Used to create grammars for parsing SQL parsimonious>=0.8.0 @@ -15,3 +14,6 @@ editdistance # Used for the type system for some languages nltk + +# Used for some tests +flaky diff --git a/test_fixtures/atis/experiment.json b/test_fixtures/atis/experiment.json index 34799f8..0f02b81 100644 --- a/test_fixtures/atis/experiment.json +++ b/test_fixtures/atis/experiment.json @@ -38,8 +38,7 @@ "dropout": 0.5, "database_file": "https://allennlp.s3.amazonaws.com/datasets/atis/atis.db" }, - "iterator": { - "type": "basic", + "data_loader": { "batch_size" : 4 }, "trainer": { diff --git a/test_fixtures/elmo/config/characters_token_embedder.json b/test_fixtures/elmo/config/characters_token_embedder.json index 2d9daeb..d31576b 100644 --- a/test_fixtures/elmo/config/characters_token_embedder.json +++ b/test_fixtures/elmo/config/characters_token_embedder.json @@ -41,7 +41,7 @@ ["transitions$", {"type": "l2", "alpha": 0.01}] ] }, - "iterator": {"type": "basic", "batch_size": 32}, + "data_loader": {"batch_size": 32}, "trainer": { "optimizer": "adam", "num_epochs": 5, diff --git a/test_fixtures/nlvr_coverage_semantic_parser/experiment.json b/test_fixtures/nlvr_coverage_semantic_parser/experiment.json index 1b6a352..f00b517 100644 --- a/test_fixtures/nlvr_coverage_semantic_parser/experiment.json +++ b/test_fixtures/nlvr_coverage_semantic_parser/experiment.json @@ -38,11 +38,12 @@ "dropout": 0.3, "penalize_non_agenda_actions": true }, - "iterator": { - "type": "bucket", - "track_epoch": true, - "padding_noise": 0.0, - "batch_size" : 4 + "data_loader": { + "batch_sampler": { + "type": "bucket", + "padding_noise": 0.0, + "batch_size": 4 + } }, "trainer": { "num_epochs": 1, diff --git a/test_fixtures/nlvr_coverage_semantic_parser/mml_init_experiment.json b/test_fixtures/nlvr_coverage_semantic_parser/mml_init_experiment.json index b0baa2f..f99566f 100644 --- a/test_fixtures/nlvr_coverage_semantic_parser/mml_init_experiment.json +++ b/test_fixtures/nlvr_coverage_semantic_parser/mml_init_experiment.json @@ -37,11 +37,12 @@ "penalize_non_agenda_actions": true, "initial_mml_model_file": "test_fixtures/semantic_parsing/nlvr_direct_semantic_parser/serialization/model.tar.gz" }, - "iterator": { - "type": "bucket", - "track_epoch": true, - "padding_noise": 0.0, - "batch_size" : 2 + "data_loader": { + "batch_sampler": { + "type": "bucket", + "padding_noise": 0.0, + "batch_size": 2 + } }, "trainer": { "num_epochs": 1, diff --git a/test_fixtures/nlvr_coverage_semantic_parser/ungrouped_experiment.json b/test_fixtures/nlvr_coverage_semantic_parser/ungrouped_experiment.json index ed39921..f56c203 100644 --- a/test_fixtures/nlvr_coverage_semantic_parser/ungrouped_experiment.json +++ b/test_fixtures/nlvr_coverage_semantic_parser/ungrouped_experiment.json @@ -36,11 +36,12 @@ }, "penalize_non_agenda_actions": true }, - "iterator": { - "type": "bucket", - "track_epoch": true, - "padding_noise": 0.0, - "batch_size" : 2 + "data_loader": { + "batch_sampler": { + "type": "bucket", + "padding_noise": 0.0, + "batch_size": 2 + } }, "trainer": { "num_epochs": 1, diff --git a/test_fixtures/nlvr_direct_semantic_parser/experiment.json b/test_fixtures/nlvr_direct_semantic_parser/experiment.json index a3042fc..88d3f20 100644 --- a/test_fixtures/nlvr_direct_semantic_parser/experiment.json +++ b/test_fixtures/nlvr_direct_semantic_parser/experiment.json @@ -33,10 +33,12 @@ "attention": {"type": "dot_product"}, "dropout": 0.2 }, - "iterator": { - "type": "bucket", - "padding_noise": 0.0, - "batch_size" : 2 + "data_loader": { + "batch_sampler": { + "type": "bucket", + "padding_noise": 0.0, + "batch_size": 2 + } }, "trainer": { "num_epochs": 1, diff --git a/test_fixtures/text2sql/experiment.json b/test_fixtures/text2sql/experiment.json index 9a0de9a..d63d49a 100644 --- a/test_fixtures/text2sql/experiment.json +++ b/test_fixtures/text2sql/experiment.json @@ -32,8 +32,7 @@ "input_attention": {"type": "dot_product"}, "dropout": 0.0 }, - "iterator": { - "type": "basic", + "data_loader": { "batch_size" : 4 }, "trainer": { diff --git a/test_fixtures/wikitables/experiment-elmo-no-features.json b/test_fixtures/wikitables/experiment-elmo-no-features.json index d3692a0..8fc1059 100644 --- a/test_fixtures/wikitables/experiment-elmo-no-features.json +++ b/test_fixtures/wikitables/experiment-elmo-no-features.json @@ -49,10 +49,12 @@ "use_neighbor_similarity_for_linking": true, "tables_directory": "test_fixtures/data/wikitables/" }, - "iterator": { - "type": "bucket", - "padding_noise": 0.0, - "batch_size" : 2 + "data_loader": { + "batch_sampler": { + "type": "bucket", + "padding_noise": 0.0, + "batch_size": 2 + } }, "trainer": { "num_epochs": 2, diff --git a/test_fixtures/wikitables/experiment-erm.json b/test_fixtures/wikitables/experiment-erm.json index 9885ae6..3bc8d32 100644 --- a/test_fixtures/wikitables/experiment-erm.json +++ b/test_fixtures/wikitables/experiment-erm.json @@ -35,10 +35,8 @@ "decoder_num_finished_states": 100, "attention": {"type": "dot_product"} }, - "iterator": { - "type": "bucket", - "padding_noise": 0.0, - "batch_size" : 2 + "data_loader": { + "batch_size": 2 }, "trainer": { "num_epochs": 2, diff --git a/test_fixtures/wikitables/experiment-mixture.json b/test_fixtures/wikitables/experiment-mixture.json index 7c6ab36..676fa96 100644 --- a/test_fixtures/wikitables/experiment-mixture.json +++ b/test_fixtures/wikitables/experiment-mixture.json @@ -43,10 +43,12 @@ "max_decoding_steps": 200, "attention": {"type": "dot_product"} }, - "iterator": { - "type": "bucket", - "padding_noise": 0.0, - "batch_size" : 2 + "data_loader": { + "batch_sampler": { + "type": "bucket", + "padding_noise": 0.0, + "batch_size": 2 + } }, "trainer": { "num_epochs": 2, diff --git a/test_fixtures/wikitables/experiment.json b/test_fixtures/wikitables/experiment.json index 256525b..6323d96 100644 --- a/test_fixtures/wikitables/experiment.json +++ b/test_fixtures/wikitables/experiment.json @@ -35,10 +35,12 @@ "max_decoding_steps": 200, "attention": {"type": "dot_product"} }, - "iterator": { - "type": "bucket", - "padding_noise": 0.0, - "batch_size" : 2 + "data_loader": { + "batch_sampler": { + "type": "bucket", + "padding_noise": 0.0, + "batch_size": 2 + } }, "trainer": { "num_epochs": 2, diff --git a/tests/common/action_space_walker_test.py b/tests/common/action_space_walker_test.py index 9c14ae4..83da2b7 100644 --- a/tests/common/action_space_walker_test.py +++ b/tests/common/action_space_walker_test.py @@ -31,9 +31,9 @@ def all_objects(self) -> Set[Object]: return set() -class ActionSpaceWalkerTest(SemparseTestCase): - def setUp(self): - super(ActionSpaceWalkerTest, self).setUp() +class TestActionSpaceWalker(SemparseTestCase): + def setup_method(self): + super().setup_method() self.world = FakeLanguageWithAssertions(start_types={bool}) self.walker = ActionSpaceWalker(self.world, max_path_length=10) @@ -120,77 +120,57 @@ def test_get_logical_forms_with_agenda_and_partial_match(self): ] ) - def test_get_logical_forms_with_empty_agenda_returns_all_logical_forms(self): - with self.assertLogs("allennlp_semparse.common.action_space_walker") as log: - empty_agenda_logical_forms = self.walker.get_logical_forms_with_agenda( - [], allow_partial_match=True - ) - first_four_logical_forms = empty_agenda_logical_forms[:4] - assert set(first_four_logical_forms) == { - "(object_exists all_objects)", - "(object_exists (black all_objects))", - "(object_exists (touch_wall all_objects))", - "(object_exists (triangle all_objects))", - } - self.assertEqual( - log.output, - [ - "WARNING:allennlp_semparse.common.action_space_walker:" - "Agenda is empty! Returning all paths instead." - ], + def test_get_logical_forms_with_empty_agenda_returns_all_logical_forms(self, caplog): + empty_agenda_logical_forms = self.walker.get_logical_forms_with_agenda( + [], allow_partial_match=True ) + first_four_logical_forms = empty_agenda_logical_forms[:4] + assert set(first_four_logical_forms) == { + "(object_exists all_objects)", + "(object_exists (black all_objects))", + "(object_exists (touch_wall all_objects))", + "(object_exists (triangle all_objects))", + } + assert "Agenda is empty! Returning all paths instead." in caplog.text - def test_get_logical_forms_with_unmatched_agenda_returns_all_logical_forms(self): + def test_get_logical_forms_with_unmatched_agenda_returns_all_logical_forms(self, caplog): agenda = [" -> purple"] - with self.assertLogs("allennlp_semparse.common.action_space_walker") as log: - empty_agenda_logical_forms = self.walker.get_logical_forms_with_agenda( - agenda, allow_partial_match=True - ) - first_four_logical_forms = empty_agenda_logical_forms[:4] - assert set(first_four_logical_forms) == { - "(object_exists all_objects)", - "(object_exists (black all_objects))", - "(object_exists (touch_wall all_objects))", - "(object_exists (triangle all_objects))", - } - self.assertEqual( - log.output, - [ - "WARNING:allennlp_semparse.common.action_space_walker:" - "Agenda items not in any of the paths found. Returning all paths." - ], + empty_agenda_logical_forms = self.walker.get_logical_forms_with_agenda( + agenda, allow_partial_match=True ) + first_four_logical_forms = empty_agenda_logical_forms[:4] + assert set(first_four_logical_forms) == { + "(object_exists all_objects)", + "(object_exists (black all_objects))", + "(object_exists (touch_wall all_objects))", + "(object_exists (triangle all_objects))", + } + assert "Agenda items not in any of the paths found. Returning all paths." in caplog.text empty_set = self.walker.get_logical_forms_with_agenda(agenda, allow_partial_match=False) assert empty_set == [] - def test_get_logical_forms_with_agenda_ignores_null_set_item(self): - with self.assertLogs("allennlp_semparse.common.action_space_walker") as log: - agenda = [ - " -> yellow", - " -> black", - " -> triangle", - " -> touch_wall", - ] - yellow_black_triangle_touch_forms = self.walker.get_logical_forms_with_agenda(agenda) - # Permutations of the three functions, after ignoring yellow. There will not be repetitions - # of any functions because we limit the length of paths to 10 above. - assert set(yellow_black_triangle_touch_forms) == set( - [ - "(object_exists (black (triangle (touch_wall all_objects))))", - "(object_exists (black (touch_wall (triangle all_objects))))", - "(object_exists (triangle (black (touch_wall all_objects))))", - "(object_exists (triangle (touch_wall (black all_objects))))", - "(object_exists (touch_wall (black (triangle all_objects))))", - "(object_exists (touch_wall (triangle (black all_objects))))", - ] - ) - self.assertEqual( - log.output, + def test_get_logical_forms_with_agenda_ignores_null_set_item(self, caplog): + agenda = [ + " -> yellow", + " -> black", + " -> triangle", + " -> touch_wall", + ] + yellow_black_triangle_touch_forms = self.walker.get_logical_forms_with_agenda(agenda) + # Permutations of the three functions, after ignoring yellow. There will not be repetitions + # of any functions because we limit the length of paths to 10 above. + assert set(yellow_black_triangle_touch_forms) == set( [ - "WARNING:allennlp_semparse.common.action_space_walker:" - " -> yellow is not in any of the paths found! Ignoring it." - ], + "(object_exists (black (triangle (touch_wall all_objects))))", + "(object_exists (black (touch_wall (triangle all_objects))))", + "(object_exists (triangle (black (touch_wall all_objects))))", + "(object_exists (triangle (touch_wall (black all_objects))))", + "(object_exists (touch_wall (black (triangle all_objects))))", + "(object_exists (touch_wall (triangle (black all_objects))))", + ] ) + log = " -> yellow is not in any of the paths found! Ignoring it." + assert log in caplog.text def test_get_all_logical_forms(self): # get_all_logical_forms should sort logical forms by length. diff --git a/tests/common/sql/text2sql_utils_test.py b/tests/common/sql/text2sql_utils_test.py index dcc8dea..ce8ec07 100644 --- a/tests/common/sql/text2sql_utils_test.py +++ b/tests/common/sql/text2sql_utils_test.py @@ -4,9 +4,9 @@ from allennlp_semparse.common.sql import text2sql_utils -class Text2SqlUtilsTest(SemparseTestCase): - def setUp(self): - super().setUp() +class TestText2SqlUtils(SemparseTestCase): + def setup_method(self): + super().setup_method() self.data = self.FIXTURES_ROOT / "data" / "text2sql" / "restaurants_tiny.json" def test_process_sql_data_blob(self): diff --git a/tests/common/wikitables/table_question_context_test.py b/tests/common/wikitables/table_question_context_test.py index e3ac70a..457ea96 100644 --- a/tests/common/wikitables/table_question_context_test.py +++ b/tests/common/wikitables/table_question_context_test.py @@ -6,8 +6,8 @@ class TestTableQuestionContext(SemparseTestCase): - def setUp(self): - super().setUp() + def setup_method(self): + super().setup_method() self.tokenizer = SpacyTokenizer(pos_tags=True) def test_table_data(self): diff --git a/tests/dataset_readers/grammar_based_text2sql_test.py b/tests/dataset_readers/grammar_based_text2sql_test.py index 1f98b53..e0cbb16 100644 --- a/tests/dataset_readers/grammar_based_text2sql_test.py +++ b/tests/dataset_readers/grammar_based_text2sql_test.py @@ -9,8 +9,8 @@ @pytest.mark.skip(reason="Mark will fix in a nearby PR.") class TestGrammarBasedText2SqlDatasetReader(SemparseTestCase): - def setUp(self): - super().setUp() + def setup_method(self): + super().setup_method() self.data_path = str(self.FIXTURES_ROOT / "data" / "text2sql" / "*.json") self.schema = str(self.FIXTURES_ROOT / "data" / "text2sql" / "restaurants-schema.csv") self.database = str(self.FIXTURES_ROOT / "data" / "text2sql" / "restaurants.db") diff --git a/tests/dataset_readers/wikitables_test.py b/tests/dataset_readers/wikitables_test.py index ad34e57..55ded5b 100644 --- a/tests/dataset_readers/wikitables_test.py +++ b/tests/dataset_readers/wikitables_test.py @@ -62,7 +62,7 @@ def assert_dataset_correct(dataset): # first one in the file, or the shortest logical form by _string length_. It's also a totally # made up logical form, just to demonstrate that we're sorting things correctly. action_sequence = instance.fields["target_action_sequences"].field_list[0] - action_indices = [l.sequence_index for l in action_sequence.field_list] + action_indices = [action.sequence_index for action in action_sequence.field_list] actions = [actions[i] for i in action_indices] assert actions == [ "@start@ -> Number", @@ -79,7 +79,7 @@ def assert_dataset_correct(dataset): ] -class WikiTablesDatasetReaderTest(SemparseTestCase): +class TestWikiTablesDatasetReader(SemparseTestCase): def test_reader_reads(self): offline_search_directory = ( self.FIXTURES_ROOT / "data" / "wikitables" / "action_space_walker_output" diff --git a/tests/domain_languages/domain_language_test.py b/tests/domain_languages/domain_language_test.py index 883e405..e57e91a 100644 --- a/tests/domain_languages/domain_language_test.py +++ b/tests/domain_languages/domain_language_test.py @@ -108,9 +108,9 @@ def check_productions_match(actual_rules: List[str], expected_right_sides: List[ assert set(actual_right_sides) == set(expected_right_sides) -class DomainLanguageTest(SemparseTestCase): - def setUp(self): - super().setUp() +class TestDomainLanguage(SemparseTestCase): + def setup_method(self): + super().setup_method() self.language = Arithmetic() def test_constant_logical_form(self): diff --git a/tests/domain_languages/nlvr_language_test.py b/tests/domain_languages/nlvr_language_test.py index 049da33..bf9b69b 100644 --- a/tests/domain_languages/nlvr_language_test.py +++ b/tests/domain_languages/nlvr_language_test.py @@ -7,8 +7,8 @@ class TestNlvrLanguage(SemparseTestCase): - def setUp(self): - super().setUp() + def setup_method(self): + super().setup_method() test_filename = self.FIXTURES_ROOT / "data" / "nlvr" / "sample_ungrouped_data.jsonl" data = [json.loads(line)["structured_rep"] for line in open(test_filename).readlines()] box_lists = [ diff --git a/tests/domain_languages/wikitables_language_test.py b/tests/domain_languages/wikitables_language_test.py index 620b596..6bc76df 100644 --- a/tests/domain_languages/wikitables_language_test.py +++ b/tests/domain_languages/wikitables_language_test.py @@ -14,8 +14,8 @@ class TestWikiTablesLanguage(SemparseTestCase): # TODO(mattg, pradeep): Add tests for the ActionSpaceWalker as well. - def setUp(self): - super().setUp() + def setup_method(self): + super().setup_method() # Adding a bunch of random tokens in here so we get them as constants in the language. question_tokens = [ Token(x) @@ -268,20 +268,13 @@ def test_execute_works_with_first(self): cell_list = self.language.execute(logical_form) assert cell_list == ["4th_western"] - def test_execute_logs_warning_with_first_on_empty_list(self): + def test_execute_logs_warning_with_first_on_empty_list(self, caplog): # Selecting "regular season" from the first row where year is greater than 2010. - with self.assertLogs("allennlp_semparse.domain_languages.wikitables_language") as log: - logical_form = """(select_string (first (filter_date_greater all_rows date_column:year - (date 2010 -1 -1))) - string_column:regular_season)""" - self.language.execute(logical_form) - self.assertEqual( - log.output, - [ - "WARNING:allennlp_semparse.domain_languages.wikitables_language:" - "Trying to get first row from an empty list" - ], - ) + logical_form = """(select_string (first (filter_date_greater all_rows date_column:year + (date 2010 -1 -1))) + string_column:regular_season)""" + self.language.execute(logical_form) + assert "Trying to get first row from an empty list" in caplog.text def test_execute_works_with_last(self): # Selecting "regular season" from the last row where year is not equal to 2010. @@ -291,20 +284,13 @@ def test_execute_works_with_last(self): cell_list = self.language.execute(logical_form) assert cell_list == ["5th"] - def test_execute_logs_warning_with_last_on_empty_list(self): + def test_execute_logs_warning_with_last_on_empty_list(self, caplog): # Selecting "regular season" from the last row where year is greater than 2010. - with self.assertLogs("allennlp_semparse.domain_languages.wikitables_language") as log: - logical_form = """(select_string (last (filter_date_greater all_rows date_column:year - (date 2010 -1 -1))) - string_column:regular_season)""" - self.language.execute(logical_form) - self.assertEqual( - log.output, - [ - "WARNING:allennlp_semparse.domain_languages.wikitables_language:" - "Trying to get last row from an empty list" - ], - ) + logical_form = """(select_string (last (filter_date_greater all_rows date_column:year + (date 2010 -1 -1))) + string_column:regular_season)""" + self.language.execute(logical_form) + assert "Trying to get last row from an empty list" in caplog.text def test_execute_works_with_previous(self): # Selecting "regular season" from the row before last where year is not equal to 2010. diff --git a/tests/fields/knowledge_graph_field_test.py b/tests/fields/knowledge_graph_field_test.py index e28a297..85f9f2a 100644 --- a/tests/fields/knowledge_graph_field_test.py +++ b/tests/fields/knowledge_graph_field_test.py @@ -14,8 +14,8 @@ from allennlp_semparse.fields import KnowledgeGraphField -class KnowledgeGraphFieldTest(SemparseTestCase): - def setUp(self): +class TestKnowledgeGraphField(SemparseTestCase): + def setup_method(self): self.tokenizer = SpacyTokenizer(pos_tags=True) self.utterance = self.tokenizer.tokenize("where is mersin?") self.token_indexers = {"tokens": SingleIdTokenIndexer("tokens")} @@ -37,7 +37,7 @@ def setUp(self): self.graph, self.utterance, self.token_indexers, self.tokenizer ) - super(KnowledgeGraphFieldTest, self).setUp() + super().setup_method() def test_count_vocab_items(self): namespace_token_counts = defaultdict(lambda: defaultdict(int)) diff --git a/tests/fields/production_rule_field_test.py b/tests/fields/production_rule_field_test.py index 8c213dd..83a28a6 100644 --- a/tests/fields/production_rule_field_test.py +++ b/tests/fields/production_rule_field_test.py @@ -10,8 +10,8 @@ class TestProductionRuleField(SemparseTestCase): - def setUp(self): - super(TestProductionRuleField, self).setUp() + def setup_method(self): + super(TestProductionRuleField, self).setup_method() self.vocab = Vocabulary() self.s_rule_index = self.vocab.add_token_to_namespace( "S -> [NP, VP]", namespace="rule_labels" diff --git a/tests/models/atis/atis_grammar_statelet_test.py b/tests/models/atis/atis_grammar_statelet_test.py index cf5f2ce..9a29ff7 100644 --- a/tests/models/atis/atis_grammar_statelet_test.py +++ b/tests/models/atis/atis_grammar_statelet_test.py @@ -2,7 +2,6 @@ import torch from allennlp.common import Params -from allennlp.modules import SimilarityFunction from allennlp_semparse.models.atis.atis_semantic_parser import AtisSemanticParser from allennlp_semparse.parsimonious_languages.worlds import AtisWorld @@ -10,7 +9,7 @@ from ... import SemparseTestCase -class AtisGrammarStateletTest(SemparseTestCase): +class TestAtisGrammarStatelet(SemparseTestCase): def test_atis_grammar_statelet(self): world = AtisWorld( [("give me all flights from boston to " "philadelphia next week arriving after lunch")] diff --git a/tests/models/atis/atis_semantic_parser_test.py b/tests/models/atis/atis_semantic_parser_test.py index 2095d62..b5db0ba 100644 --- a/tests/models/atis/atis_semantic_parser_test.py +++ b/tests/models/atis/atis_semantic_parser_test.py @@ -6,9 +6,9 @@ ) -class AtisSemanticParserTest(ModelTestCase): - def setUp(self): - super(AtisSemanticParserTest, self).setUp() +class TestAtisSemanticParser(ModelTestCase): + def setup_method(self): + super().setup_method() self.set_up_model( str(self.FIXTURES_ROOT / "atis" / "experiment.json"), str(self.FIXTURES_ROOT / "data" / "atis" / "sample.json"), diff --git a/tests/models/nlvr/nlvr_coverage_semantic_parser_test.py b/tests/models/nlvr/nlvr_coverage_semantic_parser_test.py index a657624..4919aee 100644 --- a/tests/models/nlvr/nlvr_coverage_semantic_parser_test.py +++ b/tests/models/nlvr/nlvr_coverage_semantic_parser_test.py @@ -1,17 +1,17 @@ from numpy.testing import assert_almost_equal import torch +import pytest from allennlp.common import Params from ... import ModelTestCase from allennlp.data import Vocabulary -from allennlp.data.iterators import BucketIterator from allennlp.models import Model from allennlp.models.archival import load_archive -class NlvrCoverageSemanticParserTest(ModelTestCase): - def setUp(self): - super(NlvrCoverageSemanticParserTest, self).setUp() +class TestNlvrCoverageSemanticParser(ModelTestCase): + def setup_method(self): + super().setup_method() self.set_up_model( self.FIXTURES_ROOT / "nlvr_coverage_semantic_parser" / "experiment.json", self.FIXTURES_ROOT / "data" / "nlvr" / "sample_grouped_data.jsonl", @@ -48,21 +48,6 @@ def test_get_checklist_info(self): assert_almost_equal(terminal_actions.data.numpy(), [[0], [2], [4]]) assert_almost_equal(checklist_mask.data.numpy(), [[1], [1], [1]]) - def test_forward_with_epoch_num_changes_cost_weight(self): - # Redefining model. We do not want this to change the state of ``self.model``. - params = Params.from_file(self.param_file) - model = Model.from_params(vocab=self.vocab, params=params["model"]) - # Initial cost weight, before forward is called. - assert model._checklist_cost_weight == 0.8 - iterator = BucketIterator(track_epoch=True) - cost_weights = [] - for epoch_data in iterator(self.dataset, num_epochs=4): - model.forward(**epoch_data) - cost_weights.append(model._checklist_cost_weight) - # The config file has ``wait_num_epochs`` set to 0, so the model starts decreasing the cost - # weight at epoch 0 itself. - assert_almost_equal(cost_weights, [0.72, 0.648, 0.5832, 0.52488]) - def test_initialize_weights_from_archive(self): original_model_parameters = self.model.named_parameters() original_model_weights = { @@ -81,7 +66,7 @@ def test_initialize_weights_from_archive(self): changed_weight = changed_model_parameters[name].data.numpy() # We want to make sure that the weights in the original model have indeed been changed # after a call to ``_initialize_weights_from_archive``. - with self.assertRaises(AssertionError, msg=f"{name} has not changed"): + with pytest.raises(AssertionError, match="Arrays are not almost equal"): assert_almost_equal(original_weight, changed_weight) # This also includes the sentence token embedder. Those weights will be the same # because the two models have the same vocabulary. diff --git a/tests/models/nlvr/nlvr_direct_semantic_parser_test.py b/tests/models/nlvr/nlvr_direct_semantic_parser_test.py index 30a7ee6..99152d7 100644 --- a/tests/models/nlvr/nlvr_direct_semantic_parser_test.py +++ b/tests/models/nlvr/nlvr_direct_semantic_parser_test.py @@ -1,9 +1,9 @@ from ... import ModelTestCase -class NlvrDirectSemanticParserTest(ModelTestCase): - def setUp(self): - super(NlvrDirectSemanticParserTest, self).setUp() +class TestNlvrDirectSemanticParser(ModelTestCase): + def setup_method(self): + super().setup_method() self.set_up_model( self.FIXTURES_ROOT / "nlvr_direct_semantic_parser" / "experiment.json", self.FIXTURES_ROOT / "data" / "nlvr" / "sample_processed_data.jsonl", diff --git a/tests/models/text2sql_parser_test.py b/tests/models/text2sql_parser_test.py index c399531..1de3a3a 100644 --- a/tests/models/text2sql_parser_test.py +++ b/tests/models/text2sql_parser_test.py @@ -5,9 +5,9 @@ from allennlp_semparse.parsimonious_languages.worlds.text2sql_world import Text2SqlWorld -class Text2SqlParserTest(ModelTestCase): - def setUp(self): - super().setUp() +class TestText2SqlParser(ModelTestCase): + def setup_method(self): + super().setup_method() self.set_up_model( str(self.FIXTURES_ROOT / "text2sql" / "experiment.json"), diff --git a/tests/models/wikitables/wikitables_erm_semantic_parser_test.py b/tests/models/wikitables/wikitables_erm_semantic_parser_test.py index 243f120..93cb4b5 100644 --- a/tests/models/wikitables/wikitables_erm_semantic_parser_test.py +++ b/tests/models/wikitables/wikitables_erm_semantic_parser_test.py @@ -3,9 +3,9 @@ from ... import ModelTestCase -class WikiTablesVariableFreeErmTest(ModelTestCase): - def setUp(self): - super(WikiTablesVariableFreeErmTest, self).setUp() +class TestWikiTablesVariableFreeErm(ModelTestCase): + def setup_method(self): + super().setup_method() config_path = self.FIXTURES_ROOT / "wikitables" / "experiment-erm.json" data_path = self.FIXTURES_ROOT / "data" / "wikitables" / "sample_data.examples" self.set_up_model(config_path, data_path) diff --git a/tests/models/wikitables/wikitables_mml_semantic_parser_test.py b/tests/models/wikitables/wikitables_mml_semantic_parser_test.py index 0a3a36f..43fabd3 100644 --- a/tests/models/wikitables/wikitables_mml_semantic_parser_test.py +++ b/tests/models/wikitables/wikitables_mml_semantic_parser_test.py @@ -5,12 +5,11 @@ from allennlp.common import Params from ... import ModelTestCase -from allennlp.data.iterators import DataIterator -class WikiTablesMmlSemanticParserTest(ModelTestCase): - def setUp(self): - super(WikiTablesMmlSemanticParserTest, self).setUp() +class TestWikiTablesMmlSemanticParser(ModelTestCase): + def setup_method(self): + super().setup_method() print(self.FIXTURES_ROOT) config_path = self.FIXTURES_ROOT / "wikitables" / "experiment.json" data_path = self.FIXTURES_ROOT / "data" / "wikitables" / "sample_data.examples" @@ -20,15 +19,11 @@ def setUp(self): def test_model_can_train_save_and_load(self): self.ensure_model_can_train_save_and_load(self.param_file) - def test_model_decode(self): - params = Params.from_file(self.param_file) - iterator_params = params["iterator"] - iterator = DataIterator.from_params(iterator_params) - iterator.index_with(self.model.vocab) - model_batch = next(iterator(self.dataset, shuffle=False)) + def test_make_output_human_readable(self): + model_batch = self.dataset.as_tensor_dict(self.dataset.get_padding_lengths()) self.model.training = False forward_output = self.model(**model_batch) - decode_output = self.model.decode(forward_output) + decode_output = self.model.make_output_human_readable(forward_output) assert "predicted_actions" in decode_output def test_get_neighbor_indices(self): diff --git a/tests/nltk_languages/worlds/world_test.py b/tests/nltk_languages/worlds/world_test.py index 9ac5471..e4b56b9 100644 --- a/tests/nltk_languages/worlds/world_test.py +++ b/tests/nltk_languages/worlds/world_test.py @@ -33,8 +33,8 @@ def all_possible_actions(self): class TestWorld(SemparseTestCase): - def setUp(self): - super().setUp() + def setup_method(self): + super().setup_method() self.world_without_recursion = FakeWorldWithoutRecursion() self.world_with_recursion = FakeWorldWithRecursion() diff --git a/tests/parsimonious_languages/executors/sql_executor_test.py b/tests/parsimonious_languages/executors/sql_executor_test.py index dd74c1b..f6817d2 100644 --- a/tests/parsimonious_languages/executors/sql_executor_test.py +++ b/tests/parsimonious_languages/executors/sql_executor_test.py @@ -3,9 +3,9 @@ from allennlp_semparse.parsimonious_languages.executors import SqlExecutor -class SqlExecutorTest(SemparseTestCase): - def setUp(self): - super().setUp() +class TestSqlExecutor(SemparseTestCase): + def setup_method(self): + super().setup_method() self._database_file = "https://allennlp.s3.amazonaws.com/datasets/atis/atis.db" def test_sql_accuracy_is_scored_correctly(self): diff --git a/tests/parsimonious_languages/worlds/atis_world_test.py b/tests/parsimonious_languages/worlds/atis_world_test.py index 2376976..fd36925 100644 --- a/tests/parsimonious_languages/worlds/atis_world_test.py +++ b/tests/parsimonious_languages/worlds/atis_world_test.py @@ -14,8 +14,8 @@ class TestAtisWorld(SemparseTestCase): - def setUp(self): - super().setUp() + def setup_method(self): + super().setup_method() test_filename = self.FIXTURES_ROOT / "data" / "atis" / "sample.json" self.data = open(test_filename).readlines() self.database_file = cached_path("https://allennlp.s3.amazonaws.com/datasets/atis/atis.db") @@ -862,7 +862,7 @@ def test_atis_helper_methods(self): world.grammar["col_ref"], Literal("BETWEEN"), world.grammar["time_range_start"], - Literal(f"AND"), + Literal("AND"), world.grammar["time_range_end"], ], ) == Sequence( @@ -872,7 +872,7 @@ def test_atis_helper_methods(self): world.grammar["ws"], world.grammar["time_range_start"], world.grammar["ws"], - Literal(f"AND"), + Literal("AND"), world.grammar["ws"], world.grammar["time_range_end"], world.grammar["ws"], diff --git a/tests/parsimonious_languages/worlds/text2sql_world_test.py b/tests/parsimonious_languages/worlds/text2sql_world_test.py index 0e09162..316d4a1 100644 --- a/tests/parsimonious_languages/worlds/text2sql_world_test.py +++ b/tests/parsimonious_languages/worlds/text2sql_world_test.py @@ -1,6 +1,7 @@ import sqlite3 from parsimonious import Grammar, ParseError +import pytest from ... import SemparseTestCase @@ -12,8 +13,8 @@ class TestText2SqlWorld(SemparseTestCase): - def setUp(self): - super().setUp() + def setup_method(self): + super().setup_method() self.schema = str(self.FIXTURES_ROOT / "data" / "text2sql" / "restaurants-schema.csv") self.database_path = str(self.FIXTURES_ROOT / "data" / "text2sql" / "restaurants.db") @@ -106,7 +107,7 @@ def test_variable_free_world_cannot_parse_as_statements(self): grammar = Grammar(format_grammar_string(world.base_grammar_dictionary)) sql_visitor = SqlVisitor(grammar) - with self.assertRaises(ParseError): + with pytest.raises(ParseError): sql_visitor.parse(" ".join(sql_with_as)) sql = [ diff --git a/tests/predictors/nlvr_parser_test.py b/tests/predictors/nlvr_parser_test.py index 0f51446..d2cb3f5 100644 --- a/tests/predictors/nlvr_parser_test.py +++ b/tests/predictors/nlvr_parser_test.py @@ -7,8 +7,8 @@ class TestNlvrParserPredictor(SemparseTestCase): - def setUp(self): - super().setUp() + def setup_method(self): + super().setup_method() self.inputs = { "worlds": [ [ diff --git a/tests/state_machines/trainers/expected_risk_minimization_test.py b/tests/state_machines/trainers/expected_risk_minimization_test.py index ee16091..31f205d 100644 --- a/tests/state_machines/trainers/expected_risk_minimization_test.py +++ b/tests/state_machines/trainers/expected_risk_minimization_test.py @@ -9,8 +9,8 @@ class TestExpectedRiskMinimization(SemparseTestCase): - def setUp(self): - super().setUp() + def setup_method(self): + super().setup_method() self.initial_state = SimpleState([0], [[0]], [torch.Tensor([0.0])]) self.decoder_step = SimpleTransitionFunction() # Cost is the number of odd elements in the action history. diff --git a/tests/state_machines/trainers/maximum_marginal_likelihood_test.py b/tests/state_machines/trainers/maximum_marginal_likelihood_test.py index c34d076..1f4e7ae 100644 --- a/tests/state_machines/trainers/maximum_marginal_likelihood_test.py +++ b/tests/state_machines/trainers/maximum_marginal_likelihood_test.py @@ -10,8 +10,8 @@ class TestMaximumMarginalLikelihood(SemparseTestCase): - def setUp(self): - super().setUp() + def setup_method(self): + super().setup_method() self.initial_state = SimpleState( [0, 1], [[], []], [torch.Tensor([0.0]), torch.Tensor([0.0])], [0, 1] ) diff --git a/tests/state_machines/transition_functions/basic_transition_function_test.py b/tests/state_machines/transition_functions/basic_transition_function_test.py index eaf32d4..6dc241e 100644 --- a/tests/state_machines/transition_functions/basic_transition_function_test.py +++ b/tests/state_machines/transition_functions/basic_transition_function_test.py @@ -9,9 +9,9 @@ from allennlp_semparse.state_machines.transition_functions import BasicTransitionFunction -class BasicTransitionFunctionTest(SemparseTestCase): - def setUp(self): - super().setUp() +class TestBasicTransitionFunction(SemparseTestCase): + def setup_method(self): + super().setup_method() self.decoder_step = BasicTransitionFunction( encoder_output_dim=2, action_embedding_dim=2, @@ -130,8 +130,8 @@ def test_take_step(self): # For batch instance 1, we should have selected action 0 from group index 1 - there was # only one allowed action. assert new_state.batch_indices == [1] - # These two have values taken from what's defined in setUp() - the prior action history - # ([3, 4]) and the nonterminals corresponding to the action we picked ('q'). + # These two have values taken from what's defined in setup_method() - the prior action + # history ([3, 4]) and the nonterminals corresponding to the action we picked ('q'). assert new_state.action_history == [[3, 4, 0]] assert new_state.grammar_state[0]._nonterminal_stack == ["q"] # And these should just be copied from the prior state.