Skip to content
This repository has been archived by the owner on Apr 9, 2022. It is now read-only.

Commit

Permalink
Update to support AllenNLP 1.0 (#25)
Browse files Browse the repository at this point in the history
* Update requirements.txt

* Rename decode to make_output_human_readable

* Rename decode to make_output_human_readable

* Rename decode to make_output_human_readable

* Rename decode to make_output_human_readable

* remove conversion to float

* fix tests

* fix configs

* flake

* more flake...

* change test names

* fix more tests

Co-authored-by: Matt Gardner <[email protected]>
  • Loading branch information
tedgoddard and matt-gardner authored Jun 26, 2020
1 parent 339e617 commit 28b7cf5
Show file tree
Hide file tree
Showing 47 changed files with 218 additions and 249 deletions.
2 changes: 1 addition & 1 deletion allennlp_semparse/dataset_readers/atis.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def text_to_instance( # type: ignore
action_sequence = world.get_action_sequence(sql_query)
except ParseError:
action_sequence = []
logger.debug(f"Parsing error")
logger.debug("Parsing error")

tokenized_utterance = self._tokenizer.tokenize(utterance.lower())
utterance_field = TextField(tokenized_utterance, self._token_indexers)
Expand Down
5 changes: 5 additions & 0 deletions allennlp_semparse/domain_languages/domain_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,3 +746,8 @@ def _construct_node_from_actions(
Tree(right_side, [])
) # you add a child to an nltk.Tree with `append`
return remaining_actions

def __len__(self):
# This method exists just to make it easier to use this in a MetadataField. Kind of
# annoying, but oh well.
return 0
3 changes: 3 additions & 0 deletions allennlp_semparse/fields/knowledge_graph_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,9 @@ def count_vocab_items(self, counter: Dict[str, Dict[str, int]]):
def index(self, vocab: Vocabulary):
self._entity_text_field.index(vocab)

def __len__(self) -> int:
return len(self.utterance_tokens)

@overrides
def get_padding_lengths(self) -> Dict[str, int]:
padding_lengths = {
Expand Down
9 changes: 7 additions & 2 deletions allennlp_semparse/models/atis/atis_semantic_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def _get_initial_state(
linking_scores: torch.Tensor,
) -> GrammarBasedState:
embedded_utterance = self._utterance_embedder(utterance)
utterance_mask = util.get_text_field_mask(utterance).float()
utterance_mask = util.get_text_field_mask(utterance)

batch_size = embedded_utterance.size(0)
num_entities = max([len(world.entities) for world in worlds])
Expand Down Expand Up @@ -546,7 +546,9 @@ def _create_grammar_state(
return GrammarStatelet(["statement"], translated_valid_actions, self.is_nonterminal)

@overrides
def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
def make_output_human_readable(
self, output_dict: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
"""
This method overrides ``Model.make_output_human_readable``, which gets called after ``Model.forward``, at test
time, to finalize predictions. This is (confusingly) a separate notion from the "decoder"
Expand Down Expand Up @@ -581,3 +583,6 @@ def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor
batch_action_info.append(instance_action_info)
output_dict["predicted_actions"] = batch_action_info
return output_dict


default_predictor = "atis-parser"
20 changes: 6 additions & 14 deletions allennlp_semparse/models/nlvr/nlvr_coverage_semantic_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,27 +207,19 @@ def forward(
agenda: torch.LongTensor,
identifier: List[str] = None,
labels: torch.LongTensor = None,
epoch_num: List[int] = None,
metadata: List[Dict[str, Any]] = None,
) -> Dict[str, torch.Tensor]:
"""
Decoder logic for producing type constrained target sequences that maximize coverage of
their respective agendas, and minimize a denotation based loss.
"""
# We look at the epoch number and adjust the checklist cost weight if needed here.
instance_epoch_num = epoch_num[0] if epoch_num is not None else None
if self._dynamic_cost_rate is not None:
if self.training and instance_epoch_num is None:
raise RuntimeError(
"If you want a dynamic cost weight, use the "
"BucketIterator with track_epoch=True."
)
if instance_epoch_num != self._last_epoch_in_forward:
if instance_epoch_num >= self._dynamic_cost_wait_epochs:
decrement = self._checklist_cost_weight * self._dynamic_cost_rate
self._checklist_cost_weight -= decrement
logger.info("Checklist cost weight is now %f", self._checklist_cost_weight)
self._last_epoch_in_forward = instance_epoch_num
# This could be added back pretty easily with an EpochCallback passed to the Trainer (it
# just has to set the epoch number on the model, which could then be queried in here).
logger.warning(
"Dynamic cost rate functionality was removed in AllenNLP 1.0. If you want this, "
"use version 0.9. We will just use the static checklist cost weight."
)
batch_size = len(worlds)

initial_rnn_state = self._get_initial_rnn_state(sentence)
Expand Down
6 changes: 4 additions & 2 deletions allennlp_semparse/models/nlvr/nlvr_semantic_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def forward(self): # type: ignore
def _get_initial_rnn_state(self, sentence: Dict[str, torch.LongTensor]):
embedded_input = self._sentence_embedder(sentence)
# (batch_size, sentence_length)
sentence_mask = util.get_text_field_mask(sentence).float()
sentence_mask = util.get_text_field_mask(sentence)

batch_size = embedded_input.size(0)

Expand Down Expand Up @@ -217,7 +217,9 @@ def _create_grammar_state(
return GrammarStatelet([START_SYMBOL], translated_valid_actions, world.is_nonterminal)

@overrides
def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
def make_output_human_readable(
self, output_dict: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
"""
This method overrides ``Model.make_output_human_readable``, which gets called after ``Model.forward``, at test
time, to finalize predictions. We only transform the action string sequences into logical
Expand Down
6 changes: 4 additions & 2 deletions allennlp_semparse/models/text2sql_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def forward(
trailing dimension.
"""
embedded_utterance = self._utterance_embedder(tokens)
mask = util.get_text_field_mask(tokens).float()
mask = util.get_text_field_mask(tokens)
batch_size = embedded_utterance.size(0)

# (batch_size, num_tokens, encoder_output_dim)
Expand Down Expand Up @@ -385,7 +385,9 @@ def _create_grammar_state(self, possible_actions: List[ProductionRule]) -> Gramm
)

@overrides
def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
def make_output_human_readable(
self, output_dict: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
"""
This method overrides ``Model.make_output_human_readable``, which gets called after ``Model.forward``, at test
time, to finalize predictions. This is (confusingly) a separate notion from the "decoder"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,3 +233,6 @@ def forward(
actions, best_final_states, world, target_values, metadata, outputs
)
return outputs


default_predictor = "wikitables-parser"
Original file line number Diff line number Diff line change
Expand Up @@ -169,10 +169,10 @@ def _get_initial_rnn_and_grammar_state(
table_text = table["text"]
# (batch_size, question_length, embedding_dim)
embedded_question = self._question_embedder(question)
question_mask = util.get_text_field_mask(question).float()
question_mask = util.get_text_field_mask(question)
# (batch_size, num_entities, num_entity_tokens, embedding_dim)
embedded_table = self._question_embedder(table_text, num_wrapping_dims=1)
table_mask = util.get_text_field_mask(table_text, num_wrapping_dims=1).float()
table_mask = util.get_text_field_mask(table_text, num_wrapping_dims=1)

batch_size, num_entities, num_entity_tokens, _ = embedded_table.size()
num_question_tokens = embedded_question.size(1)
Expand Down Expand Up @@ -740,7 +740,9 @@ def _compute_validation_outputs(
outputs["question_tokens"] = [x["question_tokens"] for x in metadata]

@overrides
def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
def make_output_human_readable(
self, output_dict: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
"""
This method overrides ``Model.make_output_human_readable``, which gets called after ``Model.forward``, at test
time, to finalize predictions. This is (confusingly) a separate notion from the "decoder"
Expand Down
6 changes: 3 additions & 3 deletions allennlp_semparse/parsimonious_languages/worlds/atis_world.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def _update_grammar(self):
new_grammar["col_ref"],
Literal("BETWEEN"),
new_grammar["time_range_start"],
Literal(f"AND"),
Literal("AND"),
new_grammar["time_range_end"],
],
),
Expand All @@ -124,7 +124,7 @@ def _update_grammar(self):
Literal("NOT"),
Literal("BETWEEN"),
new_grammar["time_range_start"],
Literal(f"AND"),
Literal("AND"),
new_grammar["time_range_end"],
],
),
Expand All @@ -135,7 +135,7 @@ def _update_grammar(self):
Literal("not"),
Literal("BETWEEN"),
new_grammar["time_range_start"],
Literal(f"AND"),
Literal("AND"),
new_grammar["time_range_end"],
],
),
Expand Down
6 changes: 4 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# To be changed to allennlp>=1.0 once that's released.
git+git://github.com/allenai/allennlp@fa14bd8c177a12709a2fac0616eafaa9ad430b0a
allennlp==1.0

# Used to create grammars for parsing SQL
parsimonious>=0.8.0
Expand All @@ -15,3 +14,6 @@ editdistance

# Used for the type system for some languages
nltk

# Used for some tests
flaky
3 changes: 1 addition & 2 deletions test_fixtures/atis/experiment.json
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@
"dropout": 0.5,
"database_file": "https://allennlp.s3.amazonaws.com/datasets/atis/atis.db"
},
"iterator": {
"type": "basic",
"data_loader": {
"batch_size" : 4
},
"trainer": {
Expand Down
2 changes: 1 addition & 1 deletion test_fixtures/elmo/config/characters_token_embedder.json
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
["transitions$", {"type": "l2", "alpha": 0.01}]
]
},
"iterator": {"type": "basic", "batch_size": 32},
"data_loader": {"batch_size": 32},
"trainer": {
"optimizer": "adam",
"num_epochs": 5,
Expand Down
11 changes: 6 additions & 5 deletions test_fixtures/nlvr_coverage_semantic_parser/experiment.json
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,12 @@
"dropout": 0.3,
"penalize_non_agenda_actions": true
},
"iterator": {
"type": "bucket",
"track_epoch": true,
"padding_noise": 0.0,
"batch_size" : 4
"data_loader": {
"batch_sampler": {
"type": "bucket",
"padding_noise": 0.0,
"batch_size": 4
}
},
"trainer": {
"num_epochs": 1,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,12 @@
"penalize_non_agenda_actions": true,
"initial_mml_model_file": "test_fixtures/semantic_parsing/nlvr_direct_semantic_parser/serialization/model.tar.gz"
},
"iterator": {
"type": "bucket",
"track_epoch": true,
"padding_noise": 0.0,
"batch_size" : 2
"data_loader": {
"batch_sampler": {
"type": "bucket",
"padding_noise": 0.0,
"batch_size": 2
}
},
"trainer": {
"num_epochs": 1,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,12 @@
},
"penalize_non_agenda_actions": true
},
"iterator": {
"type": "bucket",
"track_epoch": true,
"padding_noise": 0.0,
"batch_size" : 2
"data_loader": {
"batch_sampler": {
"type": "bucket",
"padding_noise": 0.0,
"batch_size": 2
}
},
"trainer": {
"num_epochs": 1,
Expand Down
10 changes: 6 additions & 4 deletions test_fixtures/nlvr_direct_semantic_parser/experiment.json
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,12 @@
"attention": {"type": "dot_product"},
"dropout": 0.2
},
"iterator": {
"type": "bucket",
"padding_noise": 0.0,
"batch_size" : 2
"data_loader": {
"batch_sampler": {
"type": "bucket",
"padding_noise": 0.0,
"batch_size": 2
}
},
"trainer": {
"num_epochs": 1,
Expand Down
3 changes: 1 addition & 2 deletions test_fixtures/text2sql/experiment.json
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@
"input_attention": {"type": "dot_product"},
"dropout": 0.0
},
"iterator": {
"type": "basic",
"data_loader": {
"batch_size" : 4
},
"trainer": {
Expand Down
10 changes: 6 additions & 4 deletions test_fixtures/wikitables/experiment-elmo-no-features.json
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,12 @@
"use_neighbor_similarity_for_linking": true,
"tables_directory": "test_fixtures/data/wikitables/"
},
"iterator": {
"type": "bucket",
"padding_noise": 0.0,
"batch_size" : 2
"data_loader": {
"batch_sampler": {
"type": "bucket",
"padding_noise": 0.0,
"batch_size": 2
}
},
"trainer": {
"num_epochs": 2,
Expand Down
6 changes: 2 additions & 4 deletions test_fixtures/wikitables/experiment-erm.json
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,8 @@
"decoder_num_finished_states": 100,
"attention": {"type": "dot_product"}
},
"iterator": {
"type": "bucket",
"padding_noise": 0.0,
"batch_size" : 2
"data_loader": {
"batch_size": 2
},
"trainer": {
"num_epochs": 2,
Expand Down
10 changes: 6 additions & 4 deletions test_fixtures/wikitables/experiment-mixture.json
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,12 @@
"max_decoding_steps": 200,
"attention": {"type": "dot_product"}
},
"iterator": {
"type": "bucket",
"padding_noise": 0.0,
"batch_size" : 2
"data_loader": {
"batch_sampler": {
"type": "bucket",
"padding_noise": 0.0,
"batch_size": 2
}
},
"trainer": {
"num_epochs": 2,
Expand Down
10 changes: 6 additions & 4 deletions test_fixtures/wikitables/experiment.json
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,12 @@
"max_decoding_steps": 200,
"attention": {"type": "dot_product"}
},
"iterator": {
"type": "bucket",
"padding_noise": 0.0,
"batch_size" : 2
"data_loader": {
"batch_sampler": {
"type": "bucket",
"padding_noise": 0.0,
"batch_size": 2
}
},
"trainer": {
"num_epochs": 2,
Expand Down
Loading

0 comments on commit 28b7cf5

Please sign in to comment.