diff --git a/allennlp_semparse/domain_languages/__init__.py b/allennlp_semparse/domain_languages/__init__.py index 61b90d1..7784643 100644 --- a/allennlp_semparse/domain_languages/__init__.py +++ b/allennlp_semparse/domain_languages/__init__.py @@ -5,4 +5,5 @@ predicate_with_side_args, ) from allennlp_semparse.domain_languages.nlvr_language import NlvrLanguage +from allennlp_semparse.domain_languages.nlvr_language_v2 import NlvrLanguageFuncComposition from allennlp_semparse.domain_languages.wikitables_language import WikiTablesLanguage diff --git a/allennlp_semparse/domain_languages/nlvr_language_v2.py b/allennlp_semparse/domain_languages/nlvr_language_v2.py new file mode 100644 index 0000000..d75dce7 --- /dev/null +++ b/allennlp_semparse/domain_languages/nlvr_language_v2.py @@ -0,0 +1,715 @@ +from collections import defaultdict +from typing import Callable, Dict, List, NamedTuple, Set + +from allennlp.common.util import JsonDict + +from allennlp_semparse.domain_languages.domain_language import DomainLanguage, predicate + + +class Object: + """ + ``Objects`` are the geometric shapes in the NLVR domain. They have values for attributes shape, + color, x_loc, y_loc and size. We take a dict read from the JSON file and store it here, and + define a get method for getting the attribute values. We need this to be hashable because need + to make sets of ``Objects`` during execution, which get passed around between functions. + + Parameters + ---------- + attributes : ``JsonDict`` + The dict for each object from the json file. + """ + + def __init__(self, attributes: JsonDict, box_id: str) -> None: + object_color = attributes["color"].lower() + # The dataset has a hex code only for blue for some reason. + if object_color.startswith("#"): + self.color = "blue" + else: + self.color = object_color + object_shape = attributes["type"].lower() + self.shape = object_shape + self.x_loc = attributes["x_loc"] + self.y_loc = attributes["y_loc"] + self.size = attributes["size"] + self._box_id = box_id + + def __str__(self): + if self.size == 10: + size = "small" + elif self.size == 20: + size = "medium" + else: + size = "big" + return f"{size} {self.color} {self.shape} at ({self.x_loc}, {self.y_loc}) in {self._box_id}" + + def __hash__(self): + return hash(str(self)) + + def __eq__(self, other): + return str(self) == str(other) + + +class Box: + """ + This class represents each box containing objects in NLVR. + + Parameters + ---------- + objects_list : ``List[JsonDict]`` + List of objects in the box, as given by the json file. + box_id : ``int`` + An integer identifying the box index (0, 1 or 2). + """ + + def __init__(self, objects_list: List[JsonDict], box_id: int) -> None: + self._name = f"box {box_id + 1}" + self._objects_string = str([str(_object) for _object in objects_list]) + self.objects = {Object(object_dict, self._name) for object_dict in objects_list} + self.colors = {obj.color for obj in self.objects} + self.shapes = {obj.shape for obj in self.objects} + + def __str__(self): + # Add box_name to str to differentiate boxes if object set if exactly same + return self._name + ": " + self._objects_string + + def __hash__(self): + return hash(str(self)) + + def __eq__(self, other): + return str(self) == str(other) + + +class Color(NamedTuple): + color: str + + +class Shape(NamedTuple): + shape: str + + +class NlvrLanguageFuncComposition(DomainLanguage): + def __init__( + self, + boxes: Set[Box], + allow_function_currying: bool = True, + allow_function_composition: bool = True, + metadata=None, + ) -> None: + self.boxes = boxes + self.objects: Set[Object] = set() + for box in self.boxes: + self.objects.update(box.objects) + allowed_constants = { + "color_blue": Color("blue"), + "color_black": Color("black"), + "color_yellow": Color("yellow"), + "shape_triangle": Shape("triangle"), + "shape_square": Shape("square"), + "shape_circle": Shape("circle"), + "1": 1, + "2": 2, + "3": 3, + "4": 4, + "5": 5, + "6": 6, + "7": 7, + "8": 8, + "9": 9, + } + super().__init__( + start_types={bool}, + allowed_constants=allowed_constants, + allow_function_currying=allow_function_currying, + allow_function_composition=allow_function_composition, + ) + + # Removing the "Set[Object] -> [, Set[Object]]" production from grammar + # calling to populate productions dictionary + _ = self.get_nonterminal_productions() + self._nonterminal_productions["Set[Object]"].remove( + "Set[Object] -> [, Set[Object]]" + ) + + # Mapping from terminal strings to productions that produce them. + # Eg.: "yellow" -> " -> yellow" + # We use this in the agenda-related methods, and some models that use this language look at + # this field to know how many terminals to plan for. + self.terminal_productions: Dict[str, str] = {} + for name, types in self._function_types.items(): + self.terminal_productions[name] = f"{types[0]} -> {name}" + + self.metadata = metadata + + # These first two methods are about getting an "agenda", which, given an input utterance, + # tries to guess what production rules should be needed in the logical form. + + def get_agenda_for_sentence(self, sentence: str) -> List[str]: + """ + Given a ``sentence``, returns a list of actions the sentence triggers as an ``agenda``. The + ``agenda`` can be used while by a parser to guide the decoder. sequences as possible. This + is a simplistic mapping at this point, and can be expanded. + + Parameters + ---------- + sentence : ``str`` + The sentence for which an agenda will be produced. + """ + agenda = [] + sentence = sentence.lower() + + if sentence.startswith("there is a box") or sentence.startswith("there is a tower "): + agenda.append(self.terminal_productions["box_exists"]) + agenda.append(self.terminal_productions["box_filter"]) + agenda.append(self.terminal_productions["all_boxes"]) + # # TODO(nitish): v3, v4, v5 - added this elif; v5 added the "of a box" condition + elif ("box" in sentence or "tower" in sentence) and not ( + "of a box" in sentence or "of a tower" in sentence + ): + agenda.append(self.terminal_productions["box_filter"]) + agenda.append(self.terminal_productions["all_boxes"]) + elif sentence.startswith("there is a "): + agenda.append(self.terminal_productions["object_exists"]) + + # TODO(nitish): v2, v3, v4, v5 - removed the if-condition; object-filters can be used inside box_filter + # if " -> box_exists" not in agenda: + # These are object filters and do not apply if we have a box_exists at the top. + if "touch" in sentence: + if "top" in sentence: + agenda.append(self.terminal_productions["touch_top"]) + elif "bottom" in sentence or "base" in sentence: + agenda.append(self.terminal_productions["touch_bottom"]) + elif "corner" in sentence: + agenda.append(self.terminal_productions["touch_corner"]) + elif "right" in sentence: + agenda.append(self.terminal_productions["touch_right"]) + elif "left" in sentence: + agenda.append(self.terminal_productions["touch_left"]) + elif "wall" in sentence or "edge" in sentence: + agenda.append(self.terminal_productions["touch_wall"]) + else: + agenda.append(self.terminal_productions["touch_object"]) + else: + # The words "top" and "bottom" may be referring to top and bottom blocks in a tower. + if "top" in sentence: + agenda.append(self.terminal_productions["top"]) + elif "bottom" in sentence or "base" in sentence: + agenda.append(self.terminal_productions["bottom"]) + + if " not " in sentence: + agenda.append(self.terminal_productions["negate_filter"]) + + if self.terminal_productions["box_filter"] not in agenda: + if " contains " in sentence or " has " in sentence: + agenda.append(self.terminal_productions["all_boxes"]) + agenda.append(self.terminal_productions["box_filter"]) + + # This takes care of shapes, colors, top, bottom, big, small etc. + for constant, production in self.terminal_productions.items(): + # TODO(pradeep): Deal with constant names with underscores. + if "top" in constant or "bottom" in constant: + # We already dealt with top, bottom, touch_top and touch_bottom above. + continue + if constant in sentence: + # TODO(nitish) v4,v5 -- choose `yellow()` instead of `color_yellow` action + agenda.append(production) + + number_productions = self._get_number_productions(sentence) + for production in number_productions: + agenda.append(production) + if not agenda: + # None of the rules above was triggered! + if "box" in sentence: + agenda.append(self.terminal_productions["box_filter"]) + agenda.append(self.terminal_productions["all_boxes"]) + else: + agenda.append(self.terminal_productions["all_objects"]) + return agenda + + @staticmethod + def _get_number_productions(sentence: str) -> List[str]: + """ + Gathers all the numbers in the sentence, and returns productions that lead to them. + """ + # The mapping here is very simple and limited, which also shouldn't be a problem + # because numbers seem to be represented fairly regularly. + number_strings = { + "one": "1", + "two": "2", + "three": "3", + "four": "4", + "five": "5", + "six": "6", + "seven": "7", + "eight": "8", + "nine": "9", + "ten": "10", + } + number_productions = [] + tokens = sentence.split() + numbers = number_strings.values() + for token in tokens: + if token in numbers: + number_productions.append(f"int -> {token}") + elif token in number_strings: + number_productions.append(f"int -> {number_strings[token]}") + return number_productions + + def __eq__(self, other): + if isinstance(self, other.__class__): + return self.boxes == other.boxes and self.objects == other.objects + return NotImplemented + + # All methods below here are predicates in the NLVR language, or helper methods for them. + + @predicate + def all_boxes(self) -> Set[Box]: + return self.boxes + + @predicate + def all_objects(self) -> Set[Object]: + return self.objects + + @predicate + def box_exists(self, boxes: Set[Box]) -> bool: + return len(boxes) > 0 + + @predicate + def object_exists(self, objects: Set[Object]) -> bool: + return len(objects) > 0 + + @predicate + def object_in_box(self, box: Set[Box]) -> Set[Object]: + return_set: Set[Object] = set() + for box_ in box: + return_set.update(box_.objects) + return return_set + + @predicate + def black(self, objects: Set[Object]) -> Set[Object]: + return {obj for obj in objects if obj.color == "black"} + + @predicate + def blue(self, objects: Set[Object]) -> Set[Object]: + # print("Blue input:{}".format(len(objects))) + # print([str(o) for o in objects]) + output_set = {obj for obj in objects if obj.color == "blue"} + # print("output:{}".format(len(output_set))) + # print([str(o) for o in output_set]) + return {obj for obj in objects if obj.color == "blue"} + + @predicate + def yellow(self, objects: Set[Object]) -> Set[Object]: + return {obj for obj in objects if obj.color == "yellow"} + + @predicate + def circle(self, objects: Set[Object]) -> Set[Object]: + return {obj for obj in objects if obj.shape == "circle"} + + @predicate + def square(self, objects: Set[Object]) -> Set[Object]: + return {obj for obj in objects if obj.shape == "square"} + + @predicate + def triangle(self, objects: Set[Object]) -> Set[Object]: + return {obj for obj in objects if obj.shape == "triangle"} + + @predicate + def same_color(self, objects: Set[Object]) -> Set[Object]: + """ + Filters the set of objects, and returns those objects whose color is the most frequent + color in the initial set of objects, if the highest frequency is greater than 1, or an + empty set otherwise. + + This is an unusual name for what the method does, but just as ``blue`` filters objects to + those that are blue, this filters objects to those that are of the same color. + """ + return self._get_objects_with_same_attribute(objects, lambda x: x.color) + + @predicate + def same_shape(self, objects: Set[Object]) -> Set[Object]: + """ + Filters the set of objects, and returns those objects whose color is the most frequent + color in the initial set of objects, if the highest frequency is greater than 1, or an + empty set otherwise. + + This is an unusual name for what the method does, but just as ``triangle`` filters objects + to those that are triangles, this filters objects to those that are of the same shape. + """ + return self._get_objects_with_same_attribute(objects, lambda x: x.shape) + + @predicate + def touch_bottom(self, objects: Set[Object]) -> Set[Object]: + return {obj for obj in objects if obj.y_loc + obj.size == 100} + + @predicate + def touch_left(self, objects: Set[Object]) -> Set[Object]: + return {obj for obj in objects if obj.x_loc == 0} + + @predicate + def touch_top(self, objects: Set[Object]) -> Set[Object]: + return {obj for obj in objects if obj.y_loc == 0} + + @predicate + def touch_right(self, objects: Set[Object]) -> Set[Object]: + return {obj for obj in objects if obj.x_loc + obj.size == 100} + + @predicate + def touch_wall(self, objects: Set[Object]) -> Set[Object]: + return_set: Set[Object] = set() + return return_set.union( + self.touch_top(objects), + self.touch_left(objects), + self.touch_right(objects), + self.touch_bottom(objects), + ) + + @predicate + def touch_corner(self, objects: Set[Object]) -> Set[Object]: + return_set: Set[Object] = set() + return return_set.union( + self.touch_top(objects).intersection(self.touch_right(objects)), + self.touch_top(objects).intersection(self.touch_left(objects)), + self.touch_bottom(objects).intersection(self.touch_right(objects)), + self.touch_bottom(objects).intersection(self.touch_left(objects)), + ) + + @predicate + def touch_object(self, objects: Set[Object]) -> Set[Object]: + """ + Returns all objects that touch the given set of objects. + """ + objects_per_box = self._separate_objects_by_boxes(objects) + return_set = set() + for box, box_objects in objects_per_box.items(): + candidate_objects = box.objects + for object_ in box_objects: + for candidate_object in candidate_objects: + if self._objects_touch_each_other(object_, candidate_object): + return_set.add(candidate_object) + return return_set + + @predicate + def top(self, objects: Set[Object]) -> Set[Object]: + """ + Return the topmost objects (i.e. minimum y_loc). The comparison is done separately for each + box. + """ + objects_per_box = self._separate_objects_by_boxes(objects) + return_set: Set[Object] = set() + for _, box_objects in objects_per_box.items(): + min_y_loc = min([obj.y_loc for obj in box_objects]) + return_set.update({obj for obj in box_objects if obj.y_loc == min_y_loc}) + return return_set + + @predicate + def bottom(self, objects: Set[Object]) -> Set[Object]: + """ + Return the bottom most objects(i.e. maximum y_loc). The comparison is done separately for + each box. + """ + objects_per_box = self._separate_objects_by_boxes(objects) + return_set: Set[Object] = set() + for _, box_objects in objects_per_box.items(): + max_y_loc = max([obj.y_loc for obj in box_objects]) + return_set.update({obj for obj in box_objects if obj.y_loc == max_y_loc}) + return return_set + + @predicate + def above(self, objects: Set[Object]) -> Set[Object]: + """ + Returns the set of objects in the same boxes that are above the given objects. That is, if + the input is a set of two objects, one in each box, we will return a union of the objects + above the first object in the first box, and those above the second object in the second box. + """ + # print("Above in:{}".format(len(objects))) + # print([str(o) for o in objects]) + objects_per_box: Dict[Box, List[Object]] = self._separate_objects_by_boxes(objects) + return_set = set() + for box in objects_per_box: + # min_y_loc corresponds to the top-most object. + # min_y_loc = min([obj.y_loc for obj in objects_per_box[box]]) + # TODO(nitish): changing from returning objs above the top-most input object, return objs that are above + # any input object + y_locs = [obj.y_loc for obj in objects_per_box[box]] + for candidate_obj in box.objects: + if any(candidate_obj.y_loc < y_loc for y_loc in y_locs): + # if candidate_obj.y_loc < min_y_loc: + return_set.add(candidate_obj) + # print("Above out:{}".format(len(return_set))) + # print([str(o) for o in return_set]) + return return_set + + @predicate + def below(self, objects: Set[Object]) -> Set[Object]: + """ + Returns the set of objects in the same boxes that are below the given objects. That is, if + the input is a set of two objects, one in each box, we will return a union of the objects + below the first object in the first box, and those below the second object in the second box. + """ + objects_per_box = self._separate_objects_by_boxes(objects) + return_set = set() + for box in objects_per_box: + # max_y_loc corresponds to the bottom-most object. + # max_y_loc = max([obj.y_loc for obj in objects_per_box[box]]) + # TODO(nitish): changing from returning objs above the top-most input object, return objs that are above + # any input object + y_locs = [obj.y_loc for obj in objects_per_box[box]] + for candidate_obj in box.objects: + if any(candidate_obj.y_loc > y_loc for y_loc in y_locs): + # if candidate_obj.y_loc > max_y_loc: + return_set.add(candidate_obj) + return return_set + + @predicate + def small(self, objects: Set[Object]) -> Set[Object]: + return {obj for obj in objects if obj.size == 10} + + @predicate + def medium(self, objects: Set[Object]) -> Set[Object]: + return {obj for obj in objects if obj.size == 20} + + @predicate + def big(self, objects: Set[Object]) -> Set[Object]: + return {obj for obj in objects if obj.size == 30} + + @predicate + def box_count_equals(self, count: int, boxes: Set[Box]) -> bool: + return len(boxes) == count + + @predicate + def box_count_not_equals(self, count: int, boxes: Set[Box]) -> bool: + return len(boxes) != count + + @predicate + def box_count_greater(self, count: int, boxes: Set[Box]) -> bool: + return len(boxes) > count + + @predicate + def box_count_greater_equals(self, count: int, boxes: Set[Box]) -> bool: + return len(boxes) >= count + + @predicate + def box_count_lesser(self, count: int, boxes: Set[Box]) -> bool: + return len(boxes) < count + + @predicate + def box_count_lesser_equals(self, count: int, boxes: Set[Box]) -> bool: + return len(boxes) <= count + + @predicate + def object_color_all_equals(self, color: Color, objects: Set[Object]) -> bool: + return all([obj.color == color.color for obj in objects]) + + @predicate + def object_color_any_equals(self, color: Color, objects: Set[Object]) -> bool: + return any([obj.color == color.color for obj in objects]) + + @predicate + def object_color_none_equals(self, color: Color, objects: Set[Object]) -> bool: + return all([obj.color != color.color for obj in objects]) + + @predicate + def object_shape_all_equals(self, shape: Shape, objects: Set[Object]) -> bool: + return all([obj.shape == shape.shape for obj in objects]) + + @predicate + def object_shape_any_equals(self, shape: Shape, objects: Set[Object]) -> bool: + return any([obj.shape == shape.shape for obj in objects]) + + @predicate + def object_shape_none_equals(self, shape: Shape, objects: Set[Object]) -> bool: + return all([obj.shape != shape.shape for obj in objects]) + + @predicate + def object_count_equals(self, count: int, objects: Set[Object]) -> bool: + return len(objects) == count + + @predicate + def object_count_not_equals(self, count: int, objects: Set[Object]) -> bool: + return len(objects) != count + + @predicate + def object_count_greater(self, count: int, objects: Set[Object]) -> bool: + return len(objects) > count + + @predicate + def object_count_greater_equals(self, count: int, objects: Set[Object]) -> bool: + return len(objects) >= count + + @predicate + def object_count_lesser(self, count: int, objects: Set[Object]) -> bool: + return len(objects) < count + + @predicate + def object_count_lesser_equals(self, count: int, objects: Set[Object]) -> bool: + return len(objects) <= count + + @predicate + def object_color_count_equals(self, count: int, objects: Set[Object]) -> bool: + return len({obj.color for obj in objects}) == count + + @predicate + def object_color_count_not_equals(self, count: int, objects: Set[Object]) -> bool: + return len({obj.color for obj in objects}) != count + + @predicate + def object_color_count_greater(self, count: int, objects: Set[Object]) -> bool: + return len({obj.color for obj in objects}) > count + + @predicate + def object_color_count_greater_equals(self, count: int, objects: Set[Object]) -> bool: + return len({obj.color for obj in objects}) >= count + + @predicate + def object_color_count_lesser(self, count: int, objects: Set[Object]) -> bool: + return len({obj.color for obj in objects}) < count + + @predicate + def object_color_count_lesser_equals(self, count: int, objects: Set[Object]) -> bool: + return len({obj.color for obj in objects}) <= count + + @predicate + def object_shape_count_equals(self, count: int, objects: Set[Object]) -> bool: + return len({obj.shape for obj in objects}) == count + + @predicate + def object_shape_count_not_equals(self, count: int, objects: Set[Object]) -> bool: + return len({obj.shape for obj in objects}) != count + + @predicate + def object_shape_count_greater(self, count: int, objects: Set[Object]) -> bool: + return len({obj.shape for obj in objects}) > count + + @predicate + def object_shape_count_greater_equals(self, count: int, objects: Set[Object]) -> bool: + return len({obj.shape for obj in objects}) >= count + + @predicate + def object_shape_count_lesser(self, count: int, objects: Set[Object]) -> bool: + return len({obj.shape for obj in objects}) < count + + @predicate + def object_shape_count_lesser_equals(self, count: int, objects: Set[Object]) -> bool: + return len({obj.shape for obj in objects}) <= count + + @predicate + def box_filter( + self, boxes: Set[Box], filter_function: Callable[[Set[Object]], bool] + ) -> Set[Box]: + filtered_boxes = set() + for box in boxes: + # if self.metadata is not None and self.metadata["identifier"] == "3840": + # print("\nBOX - {}".format(box_num)) + # Wrapping a single box in a {set} + objects = self.object_in_box(box={box}) + if filter_function(objects): + filtered_boxes.add(box) + # if self.metadata is not None and self.metadata["identifier"] == "3840": + # import pdb + # pdb.set_trace() + return filtered_boxes + + @predicate + def box_filter_and( + self, + box_filter_1: Callable[[Set[Object]], bool], + box_filter_2: Callable[[Set[Object]], bool], + ) -> Callable[[Set[Object]], bool]: + def new_box_filter(objects: Set[Object]) -> bool: + return box_filter_1(objects) and box_filter_2(objects) + + return new_box_filter + + @predicate + def object_shape_same(self, objects: Set[Object]) -> bool: + # Empty set is True + if len(objects) == 0: + return True + else: + return self.object_shape_count_equals(1, objects) + + @predicate + def object_color_same(self, objects: Set[Object]) -> bool: + # Empty set is True + if len(objects) == 0: + return True + else: + return self.object_color_count_equals(1, objects) + + @predicate + def object_shape_different(self, objects: Set[Object]) -> bool: + # Empty set is False + if len(objects) == 0: + return False + else: + return self.object_shape_count_not_equals(1, objects) + + @predicate + def object_color_different(self, objects: Set[Object]) -> bool: + # Empty set is False + if len(objects) == 0: + return False + else: + return self.object_color_count_not_equals(1, objects) + + @predicate + def negate_filter( + self, filter_function: Callable[[Set[Object]], Set[Object]] + ) -> Callable[[Set[Object]], Set[Object]]: + def negated_filter(objects: Set[Object]) -> Set[Object]: + return objects.difference(filter_function(objects)) + + return negated_filter + + def _objects_touch_each_other(self, object1: Object, object2: Object) -> bool: + """ + Returns true iff the objects touch each other. + """ + in_vertical_range = ( + object1.y_loc <= object2.y_loc + object2.size + and object1.y_loc + object1.size >= object2.y_loc + ) + in_horizantal_range = ( + object1.x_loc <= object2.x_loc + object2.size + and object1.x_loc + object1.size >= object2.x_loc + ) + touch_side = ( + object1.x_loc + object1.size == object2.x_loc + or object2.x_loc + object2.size == object1.x_loc + ) + touch_top_or_bottom = ( + object1.y_loc + object1.size == object2.y_loc + or object2.y_loc + object2.size == object1.y_loc + ) + return (in_vertical_range and touch_side) or (in_horizantal_range and touch_top_or_bottom) + + def _separate_objects_by_boxes(self, objects: Set[Object]) -> Dict[Box, List[Object]]: + """ + Given a set of objects, separate them by the boxes they belong to and return a dict. + """ + objects_per_box: Dict[Box, List[Object]] = defaultdict(list) + for box in self.boxes: + for object_ in objects: + if object_ in box.objects: + objects_per_box[box].append(object_) + return objects_per_box + + def _get_objects_with_same_attribute( + self, objects: Set[Object], attribute_function: Callable[[Object], str] + ) -> Set[Object]: + """ + Returns the set of objects for which the attribute function returns an attribute value that + is most frequent in the initial set, if the frequency is greater than 1. If not, all + objects have different attribute values, and this method returns an empty set. + """ + objects_of_attribute: Dict[str, Set[Object]] = defaultdict(set) + for entity in objects: + objects_of_attribute[attribute_function(entity)].add(entity) + if not objects_of_attribute: + return set() + most_frequent_attribute = max( + objects_of_attribute, key=lambda x: len(objects_of_attribute[x]) + ) + if len(objects_of_attribute[most_frequent_attribute]) <= 1: + return set() + return objects_of_attribute[most_frequent_attribute] diff --git a/tests/domain_languages/nlvr_language_v2_test.py b/tests/domain_languages/nlvr_language_v2_test.py new file mode 100644 index 0000000..0f5416b --- /dev/null +++ b/tests/domain_languages/nlvr_language_v2_test.py @@ -0,0 +1,481 @@ +import json + +from .. import SemparseTestCase + +from allennlp_semparse.domain_languages import NlvrLanguageFuncComposition +from allennlp_semparse.domain_languages.nlvr_language_v2 import Box + + +class TestNlvrLanguage(SemparseTestCase): + def setup_method(self): + super().setup_method() + test_filename = self.FIXTURES_ROOT / "data" / "nlvr" / "sample_ungrouped_data.jsonl" + data = [json.loads(line)["structured_rep"] for line in open(test_filename).readlines()] + box_lists = [ + [Box(object_reps, i) for i, object_reps in enumerate(box_rep)] for box_rep in data + ] + self.languages = [NlvrLanguageFuncComposition(boxes) for boxes in box_lists] + # y_loc increases as we go down from top to bottom, and x_loc from left to right. That is, + # the origin is at the top-left corner. + custom_rep = [ + [ + {"y_loc": 79, "size": 20, "type": "triangle", "x_loc": 27, "color": "Yellow"}, + {"y_loc": 55, "size": 10, "type": "circle", "x_loc": 47, "color": "Black"}, + ], + [ + {"y_loc": 44, "size": 30, "type": "square", "x_loc": 10, "color": "#0099ff"}, + {"y_loc": 74, "size": 30, "type": "square", "x_loc": 40, "color": "Yellow"}, + ], + [{"y_loc": 60, "size": 10, "type": "triangle", "x_loc": 12, "color": "#0099ff"}], + ] + self.custom_language = NlvrLanguageFuncComposition( + [Box(object_rep, i) for i, object_rep in enumerate(custom_rep)] + ) + + def test_logical_form_with_assert_executes_correctly(self): + executor = self.languages[0] + # Utterance is "There is a circle closely touching a corner of a box." and label is "True". + logical_form_true = "(object_count_greater_equals 1 (touch_corner (circle all_objects)))" + assert executor.execute(logical_form_true) is True + logical_form_false = "(object_count_equals 9 (touch_corner (circle all_objects)))" + assert executor.execute(logical_form_false) is False + + def test_logical_form_with_box_filter_executes_correctly(self): + executor = self.languages[2] + # Utterance is "There is a box without a blue item." and label is "False". + logical_form = "(box_exists (box_filter all_boxes (object_color_none_equals color_blue)))" + assert executor.execute(logical_form) is False + + def test_logical_form_with_box_filter_within_object_filter_executes_correctly(self): + executor = self.languages[2] + # Utterance is "There are at least three blue items in boxes with blue items" and label + # is "True". + # TODO(nitish): I've just converted the old lf into the new one (which I think is wrong; see below) + # Original logical_form below I think is wrong; this would select box with at least 1 blue item, + # and then check if the total number of objects in those boxes is greater than 3 + logical_form = "(object_count_greater_equals 3 \ + (object_in_box \ + (box_filter all_boxes (object_color_any_equals color_blue))))" + # TODO(nitish): Original logical_form below I think is wrong; this would select box with at least 1 blue item, + # and then check if the total number of objects in those boxes is greater than 3 + assert executor.execute(logical_form) is True + + def test_logical_form_with_same_color_executes_correctly(self): + executor = self.languages[1] + # Utterance is "There are exactly two blocks of the same color." and label is "True". + logical_form = "(object_count_equals 2 (same_color all_objects))" + assert executor.execute(logical_form) is True + + def test_logical_form_with_same_shape_executes_correctly(self): + executor = self.languages[0] + # Utterance is "There are less than three black objects of the same shape" and label is "False". + logical_form = "(object_count_lesser 3 (same_shape (black (all_objects))))" + assert executor.execute(logical_form) is False + + def test_logical_form_with_touch_wall_executes_correctly(self): + executor = self.languages[0] + # Utterance is "There are two black circles touching a wall" and label is "False". + logical_form = "(object_count_greater_equals 2 (touch_wall (black (circle (all_objects)))))" + assert executor.execute(logical_form) is False + + def test_logical_form_with_not_executes_correctly(self): + executor = self.languages[2] + # Utterance is "There are at most two medium triangles not touching a wall." and label is "True". + logical_form = ( + "(object_count_lesser_equals 2 ((negate_filter touch_wall) " + "(medium (triangle (all_objects)))))" + ) + assert executor.execute(logical_form) is True + + def test_logical_form_with_color_comparison_executes_correctly(self): + executor = self.languages[0] + # Utterance is "The color of the circle touching the wall is black." and label is "True". + logical_form = "(object_color_all_equals color_black (circle (touch_wall (all_objects))))" + assert executor.execute(logical_form) is True + + def test_spatial_relations_return_objects_in_the_same_box(self): + # "above", "below", "top", "bottom" are relations defined only for objects within the same + # box. So they should not return objects from other boxes. + # Asserting that the color of the objects above the yellow triangle is only black (it is not + # yellow or blue, which are colors of objects from other boxes) + assert ( + self.custom_language.execute( + "(object_color_all_equals color_black (above (yellow (triangle all_objects))))" + ) + is True + ) + # Asserting that the only shape below the blue square is a square. + assert ( + self.custom_language.execute( + "(object_shape_all_equals shape_square (below (blue (square all_objects))))" + ) + is True + ) + # Asserting the shape of the object at the bottom in the box with a circle is triangle. + logical_form = ( + "(object_shape_all_equals shape_triangle (bottom (object_in_box" + " (box_filter all_boxes (object_shape_any_equals shape_circle)))))" + ) + assert self.custom_language.execute(logical_form) is True + + # Asserting the shape of the object at the top of the box with all squares is a square (!). + logical_form = ( + "(object_shape_all_equals shape_square (top (object_in_box" + " (box_filter all_boxes (object_shape_all_equals shape_square)))))" + ) + assert self.custom_language.execute(logical_form) is True + + def test_touch_object_executes_correctly(self): + # Assert that there is a yellow square touching a blue square. + assert ( + self.custom_language.execute( + "(object_exists (yellow (square (touch_object (blue (square all_objects))))))" + ) + is True + ) + # Assert that the triangle does not touch the circle (they are out of vertical range). + assert ( + self.custom_language.execute( + "(object_shape_none_equals shape_circle (touch_object (triangle all_objects)))" + ) + is True + ) + + def test_spatial_relations_with_objects_from_different_boxes(self): + # When the objects are from different boxes, top and bottom should return objects from + # respective boxes. + # There are triangles in two boxes, so top should return the top objects from both boxes. + assert ( + self.custom_language.execute( + "(object_count_equals 2 (top (object_in_box" + " (box_filter all_boxes (object_shape_any_equals shape_triangle)))))" + ) + is True + ) + + def test_same_and_different_execute_correctly(self): + # All the objects in the box with two objects of the same shape are squares. + # TODO(nitish): This seems like the incorrect parse of the utterance above; a box with 3 triangles and 2 squares + # would not satisfy the logical-form, but should satisfy the utterance. Another read of the utterance could be + # ((box with two objects) of the same shape). Nevertheless, the logical form can be checked for execution. + assert ( + self.custom_language.execute( + "(object_shape_all_equals shape_square (object_in_box" + " (box_filter all_boxes" + " (box_filter_and object_shape_same (object_count_equals 2)))))" + ) + is True + ) + # There is a circle in the box with objects of different shapes. + assert ( + self.custom_language.execute( + "(object_shape_any_equals shape_circle (object_in_box " + "(box_filter all_boxes object_shape_different)))" + ) + is True + ) + + def test_get_action_sequence_handles_multi_arg_functions(self): + language = self.languages[0] + # box_color_filter + logical_form = "(box_exists (box_filter all_boxes (object_color_all_equals color_blue)))" + action_sequence = language.logical_form_to_action_sequence(logical_form) + assert ( + "Set[Box] -> [:Set[Box]>, Set[Box], ]" + in action_sequence + ) + assert ":Set[Box]> -> box_filter" in action_sequence + assert " -> [, Color]" in action_sequence + assert " -> object_color_all_equals" in action_sequence + + # box_shape_filter + logical_form = "(box_exists (box_filter all_boxes (object_shape_none_equals shape_square)))" + action_sequence = language.logical_form_to_action_sequence(logical_form) + assert " -> [, Shape]" in action_sequence + assert " -> object_shape_none_equals" + assert "Shape -> shape_square" in action_sequence + + # box_count_filter + logical_form = "(box_exists (box_filter all_boxes (object_count_equals 3)))" + action_sequence = language.logical_form_to_action_sequence(logical_form) + assert " -> [, int]" in action_sequence + assert " -> object_count_equals" + assert "int -> 3" in action_sequence + + # box_object_shape_different_filter + logical_form = "(box_exists (box_filter all_boxes (* object_shape_different yellow)))" + action_sequence = language.logical_form_to_action_sequence(logical_form) + assert ( + " -> [*, , ]" + in action_sequence + ) + assert " -> object_shape_different" in action_sequence + assert " -> yellow" + + # assert_color + logical_form = "(object_color_all_equals color_blue all_objects)" + action_sequence = language.logical_form_to_action_sequence(logical_form) + assert "bool -> [, Color, Set[Object]]" in action_sequence + + # assert_shape + logical_form = "(object_shape_all_equals shape_square all_objects)" + action_sequence = language.logical_form_to_action_sequence(logical_form) + assert "bool -> [, Shape, Set[Object]]" in action_sequence + + # assert_box_count + logical_form = "(box_count_equals 1 all_boxes)" + action_sequence = language.logical_form_to_action_sequence(logical_form) + assert "bool -> [, int, Set[Box]]" in action_sequence + + # assert_object_count + logical_form = "(object_count_equals 1 all_objects)" + action_sequence = language.logical_form_to_action_sequence(logical_form) + assert "bool -> [, int, Set[Object]]" in action_sequence + + def test_logical_form_with_object_filter_returns_correct_action_sequence(self): + language = self.languages[0] + logical_form = "(object_color_all_equals color_black ((* circle touch_wall) all_objects))" + action_sequence = language.logical_form_to_action_sequence(logical_form) + assert action_sequence == [ + "@start@ -> bool", + "bool -> [, Color, Set[Object]]", + " -> object_color_all_equals", + "Color -> color_black", + "Set[Object] -> [, Set[Object]]", + " -> [*, , ]", + " -> circle", + " -> touch_wall", + "Set[Object] -> all_objects", + ] + + def test_logical_form_with_negate_filter_returns_correct_action_sequence(self): + language = self.languages[0] + logical_form = "(object_exists ((negate_filter touch_wall) all_objects))" + action_sequence = language.logical_form_to_action_sequence(logical_form) + negate_filter_production = ( + " -> " + "[<:>, " + "]" + ) + assert action_sequence == [ + "@start@ -> bool", + "bool -> [, Set[Object]]", + " -> object_exists", + "Set[Object] -> [, Set[Object]]", + negate_filter_production, + "<:> -> negate_filter", + " -> touch_wall", + "Set[Object] -> all_objects", + ] + + def test_logical_form_with_box_filter_returns_correct_action_sequence(self): + language = self.languages[0] + logical_form = "(box_exists (box_filter all_boxes (object_color_none_equals color_blue)))" + action_sequence = language.logical_form_to_action_sequence(logical_form) + assert action_sequence == [ + "@start@ -> bool", + "bool -> [, Set[Box]]", + " -> box_exists", + "Set[Box] -> [:Set[Box]>, Set[Box], ]", + ":Set[Box]> -> box_filter", + "Set[Box] -> all_boxes", + " -> [, Color]", + " -> object_color_none_equals", + "Color -> color_blue", + ] + + def test_logical_form_with_box_filter_and_returns_correct_action_sequence(self): + language = self.languages[0] + logical_form = ( + "(box_count_greater 3 (box_filter all_boxes" + " (box_filter_and (object_count_not_equals 4)" + " (object_shape_any_equals shape_circle))))" + ) + action_sequence = language.logical_form_to_action_sequence(logical_form) + assert action_sequence == [ + "@start@ -> bool", + "bool -> [, int, Set[Box]]", + " -> box_count_greater", + "int -> 3", + "Set[Box] -> [:Set[Box]>, Set[Box], ]", + ":Set[Box]> -> box_filter", + "Set[Box] -> all_boxes", + " -> [<,:>, , ]", + "<,:> -> box_filter_and", + " -> [, int]", + " -> object_count_not_equals", + "int -> 4", + " -> [, Shape]", + " -> object_shape_any_equals", + "Shape -> shape_circle", + ] + + def test_complex_action_sequence_execution(self): + language = self.languages[0] + action_sequence = [ + "@start@ -> bool", + "bool -> [, int, Set[Box]]", + " -> box_count_equals", + "int -> 2", + "Set[Box] -> [:Set[Box]>, Set[Box], ]", + ":Set[Box]> -> [*, , :Set[Box]>]", + " -> [:Set[Box]>, ]", + ":Set[Box]> -> [*, , :Set[Box]>]", + " -> [:Set[Box]>, ]", + ":Set[Box]> -> box_filter", + " -> object_exists", + ":Set[Box]> -> box_filter", + " -> object_exists", + ":Set[Box]> -> box_filter", + "Set[Box] -> all_boxes", + " -> [*, , ]", + " -> object_exists", + " -> yellow", + ] + + # Logical-form + # '(box_count_equals 2 ((* ((* (box_filter object_exists) box_filter) object_exists) box_filter) all_boxes (* object_exists yellow)))' + result = language.execute_action_sequence(action_sequence=action_sequence) + assert result is False + + action_sequence = [ + "@start@ -> bool", + "bool -> [, int, Set[Box]]", + " -> box_count_equals", + "int -> 1", + "Set[Box] -> [:Set[Box]>, Set[Box], ]", + ":Set[Box]> -> [*, , :Set[Box]>]", + " -> [:Set[Box]>, ]", + ":Set[Box]> -> [*, , :Set[Box]>]", + " -> [:Set[Box]>, ]", + ":Set[Box]> -> box_filter", + " -> object_exists", + ":Set[Box]> -> box_filter", + " -> object_exists", + ":Set[Box]> -> box_filter", + "Set[Box] -> all_boxes", + " -> [*, , ]", + " -> object_exists", + " -> black", + ] + # logical-form + # '(box_count_equals 1 ((* ((* (box_filter object_exists) box_filter) object_exists) box_filter) all_boxes (* object_exists black)))' + + result = language.execute_action_sequence(action_sequence=action_sequence) + assert result is True + + def test_get_agenda_for_sentence(self): + language = self.languages[0] + agenda = language.get_agenda_for_sentence("there is a tower with exactly two yellow blocks") + assert set(agenda) == { + " -> box_exists", + ":Set[Box]> -> box_filter", + "Set[Box] -> all_boxes", + " -> yellow", + "int -> 2", + } + agenda = language.get_agenda_for_sentence( + "There is at most one yellow item closely touching the bottom of a box." + ) + assert set(agenda) == { + " -> touch_bottom", + " -> yellow", + "int -> 1", + } + agenda = language.get_agenda_for_sentence( + "There is at most one yellow item closely touching " "the right wall of a box." + ) + assert set(agenda) == { + " -> touch_right", + " -> yellow", + "int -> 1", + } + agenda = language.get_agenda_for_sentence( + "There is at most one yellow item closely touching " "the left wall of a box." + ) + assert set(agenda) == { + " -> yellow", + " -> touch_left", + "int -> 1", + } + agenda = language.get_agenda_for_sentence( + "There is at most one yellow item closely touching " "a wall of a box." + ) + assert set(agenda) == { + " -> yellow", + " -> touch_wall", + "int -> 1", + } + agenda = language.get_agenda_for_sentence("There is exactly one square touching any edge") + assert set(agenda) == { + " -> square", + " -> touch_wall", + "int -> 1", + } + agenda = language.get_agenda_for_sentence( + "There is exactly one square not touching any edge" + ) + assert set(agenda) == { + " -> square", + " -> touch_wall", + "int -> 1", + "<:> -> negate_filter", + } + agenda = language.get_agenda_for_sentence( + "There is only 1 tower with 1 blue block at the base" + ) + assert set(agenda) == { + ":Set[Box]> -> box_filter", + "Set[Box] -> all_boxes", + " -> blue", + "int -> 1", + " -> bottom", + "int -> 1", + } + agenda = language.get_agenda_for_sentence( + "There is only 1 tower that has 1 blue block at the top" + ) + assert set(agenda) == { + ":Set[Box]> -> box_filter", + "Set[Box] -> all_boxes", + " -> blue", + "int -> 1", + " -> top", + "int -> 1", + "Set[Box] -> all_boxes", + } + agenda = language.get_agenda_for_sentence( + "There is exactly one square touching the blue " "triangle" + ) + assert set(agenda) == { + " -> square", + " -> blue", + " -> triangle", + " -> touch_object", + "int -> 1", + } + + def test_get_agenda_for_sentence_correctly_adds_object_filters(self): + # In logical forms that contain "box_exists" at the top, there can never be object filtering + # operations like "blue", "square" etc. In those cases, strings like "blue" and "square" in + # sentences should map to "color_blue" and "shape_square" respectively. + language = self.languages[0] + agenda = language.get_agenda_for_sentence( + "there is a box with exactly two yellow triangles touching the top edge" + ) + assert " -> box_exists" in agenda + assert ":Set[Box]> -> box_filter" in agenda + assert "Set[Box] -> all_boxes" in agenda + assert " -> yellow" in agenda + assert " -> triangle" in agenda + assert " -> touch_top" in agenda + assert "int -> 2" in agenda + + agenda = language.get_agenda_for_sentence( + "there are exactly two yellow triangles touching the top edge" + ) + assert " -> yellow" in agenda + assert "Color -> color_yellow" not in agenda + assert " -> triangle" in agenda + assert "Shape -> shape_triangle" not in agenda + assert " -> touch_top" in agenda