
Commit

- drop the timeout.
danyaljj committed Jan 11, 2024
1 parent 26e19af commit 75b51f7
Showing 1 changed file with 29 additions and 6 deletions.
35 changes: 29 additions & 6 deletions src/evaluation_class.py
@@ -105,7 +105,8 @@ def __init__(self, solver_type: str, tasks: str, do_eval: bool, dump_features: b
         self.excluded_input_names = [
             'csrfmiddlewaretoken', # hidden field automatically added external css files
             'worker_ip', # hidden field for bookkeeping
-            'ee'
+            'ee',
+            'submit'
         ]

     def filter_TAP_tasks(self, task_name):
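
This first hunk adds 'submit' to the input names excluded from extraction and scoring. A minimal sketch of how such a name-based exclusion list is applied, assuming a simple stand-in for the repository's Input type (the dataclass and sample fields below are illustrative, not copied from the repo):

# Minimal sketch of name-based exclusion; the Input stand-in and the
# sample fields are illustrative, not copied from the repository.
from dataclasses import dataclass

@dataclass
class Input:
    name: str
    type: str

excluded_input_names = ['csrfmiddlewaretoken', 'worker_ip', 'ee', 'submit']
inputs = [Input('answer', 'textarea'), Input('submit', 'submit')]
relevant = [i for i in inputs if i.name not in excluded_input_names]
print([i.name for i in relevant])  # ['answer']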
@@ -370,13 +371,14 @@ def extract_values(self, inputs: List[Input]):
             ]:

                 values = self.driver.execute_script(
-                    f"return Array.from(document.getElementsByName(`{input.name}`)).map((element) => element.value);"
+                    f"return Array.from(document.getElementsByName(`{input.name}`)).filter((element) => element.readOnly == false).map((element) => element.value);"
                 )

                 if input.type in ['textarea']:
                     visible_values = self.driver.execute_script(
-                        f"return Array.from(document.getElementsByName(`{input.name}`)).map((element) => element.innerHTML);"
+                        f"return Array.from(document.getElementsByName(`{input.name}`)).filter((element) => element.readOnly == false).map((element) => element.innerHTML);"
                     )
+
                 elif input.type == 'select':
                     visible_values = self.driver.execute_script(
                         f"return Array.from(document.getElementsByName(`{input.name}`)[0].children).filter((el) => el.selected == true).map((el) => el.value);"
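
Both rewritten scripts gain a .filter((element) => element.readOnly == false) step, so read-only fields no longer contribute values or visible values. A standalone sketch of the same pattern, assuming chromedriver is available; the URL and the field name are placeholders:

# Sketch of the readOnly filter run through Selenium's execute_script;
# assumes chromedriver is installed, and the URL and field name below
# are placeholders rather than anything from the repository.
from selenium import webdriver

driver = webdriver.Chrome()
driver.get("https://example.com/form")  # placeholder page

name = "answer"  # placeholder field name
values = driver.execute_script(
    f"return Array.from(document.getElementsByName(`{name}`))"
    ".filter((element) => element.readOnly == false)"
    ".map((element) => element.value);"
)
print(values)  # values of the writable elements only
driver.quit()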
@@ -400,15 +402,25 @@ def extract_values(self, inputs: List[Input]):
                 assert len(values) <= 1, f"The number of values should be 1 or 0 but it is `{len(values)}` for {input}"
                 assert len(visible_values) <= 1, \
                     f"The number of visible values should be 1 or 0 but it is `{len(visible_values)}` for {input}"
+
             elif input.type in ['checkbox']:
-                command = f"""return Array.from(document.querySelectorAll(`input[name='{input.name}']:checked`)).map(element => element.value);"""
+                if "'" in input.name:
+                    command = f"""return Array.from(document.querySelectorAll(`input[name="{input.name}"]:checked`)).map(element => element.value);"""
+                else:
+                    command = f"""return Array.from(document.querySelectorAll(`input[name='{input.name}']:checked`)).map(element => element.value);"""
                 values = self.driver.execute_script(command)
+
                 command = f"""return Array.from(document.getElementsByName(`{input.name}`)).filter(element => element.checked).map(element => element.value);"""
                 visible_values = self.driver.execute_script(command)
+
+            elif input.type in ['submit']:
+                # do nothing
+                values = []
+                visible_values = []
             else:
                 raise Exception(
                     f"{Fore.RED}To be implemented for type `{input.type}`")

             clean_visible_values = clean_values(visible_values)
             clean_visible_values = [
                 html.unescape(v) for v in clean_visible_values
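
The checkbox branch now chooses the attribute-selector quoting based on the field name: a single quote inside input[name='...'] would terminate the CSS string early, so names containing ' are wrapped in double quotes instead. The same rule as a small helper (the function name is ours; like the commit, it does not handle names containing both quote characters):

# Sketch: pick the CSS attribute-selector quoting for a field name.
# The helper name is ours; as in the commit, names containing both
# ' and " would still need proper escaping.
def checked_selector(name: str) -> str:
    if "'" in name:
        return f"""return Array.from(document.querySelectorAll(`input[name="{name}"]:checked`)).map(element => element.value);"""
    return f"""return Array.from(document.querySelectorAll(`input[name='{name}']:checked`)).map(element => element.value);"""

print(checked_selector("worker's_pick"))  # falls back to double quotes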
@@ -652,6 +664,9 @@ def score_outputs(self, inputs: List[Input], answers_map: Dict, task_results: Di
             if i.name in self.excluded_input_names:
                 continue

+            if i.type == 'submit':
+                continue
+
             element = self.driver.find_element(By.NAME, i.name)
             if not element.is_displayed() or element.size['width'] <= 0 or element.size['height'] <= 0:
                 print(f'{Fore.RED}Skipping element `{i.name}` since it is not visible.')
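
score_outputs now also skips submit fields by type, matching the extract_values hunk above where they yield no values: a submit button carries no annotator answer, so scoring it would only dilute the per-instance average. A tiny sketch of the two skip conditions, with (name, type) tuples standing in for the repo's Input objects:

# Sketch: the two skip conditions applied per field; (name, type)
# tuples stand in for the repository's Input objects.
excluded_input_names = ['csrfmiddlewaretoken', 'worker_ip', 'ee', 'submit']
fields = [('answer', 'textarea'), ('go', 'submit'), ('worker_ip', 'hidden')]
for name, type_ in fields:
    if name in excluded_input_names:
        continue
    if type_ == 'submit':
        continue
    print(f"scoring `{name}`")  # only 'answer' reaches this point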
@@ -697,6 +712,8 @@ def score_outputs(self, inputs: List[Input], answers_map: Dict, task_results: Di
             answers = clean_values(answers)
             answers = list(set(answers))

+            answers_map[i.name] = answers
+
             if answers == [] or answers == [""]:
                 continue
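
The new bookkeeping line stores the cleaned, de-duplicated gold answers for every field into the answers_map argument, and it runs before the empty-answer check, so even fields skipped below get recorded. A sketch of the resulting shape (field names and answers are made up):

# Sketch of what the new bookkeeping line accumulates; the field
# names and answers here are made up.
answers_map = {}
for name, answers in [('q1', ['yes', 'yes']), ('q2', [''])]:
    answers = list(set(answers))
    answers_map[name] = answers  # recorded even if skipped below
    if answers == [] or answers == ['']:
        continue
print(answers_map)  # {'q1': ['yes'], 'q2': ['']}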

@@ -713,7 +730,10 @@ def score_outputs(self, inputs: List[Input], answers_map: Dict, task_results: Di
                 score += score_per_field
                 count += 1

-        score /= count  # average score for this instance
+        # There are difficult tasks where you need to do some movements for the inputs to appear.
+        # Otherwise, nothing would be visible and the score would be 0.
+        if count > 0:
+            score /= count  # average score for this instance

         return score
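
The guard fixes a ZeroDivisionError: when every input is skipped (for example, on tasks where the inputs only appear after some movement on the page), count stays 0 and the old unguarded division crashed. A minimal sketch of the failure mode and the fix:

# Minimal sketch: with zero scored fields, the old unguarded division
# raised ZeroDivisionError; the guard leaves the score at 0 instead.
score, count = 0.0, 0
per_field_scores = []  # e.g. every input was skipped as invisible
for s in per_field_scores:
    score += s
    count += 1
if count > 0:
    score /= count  # average score for this instance
print(score)  # 0.0, not a crash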

@@ -867,8 +887,11 @@ def enumerate_tasks(self, max_instance_count: int, **kwargs):
                 kwargs = {'answers': answers_map[i.name]}
                 oracle_action_sequence = self.solver.solve(i, **kwargs)
             elif self.solver_type == 'model':
+                # TODO: should the name be "offline" here?
+                # TODO: check whether we really need to pass "answer_map" here
                 self.solver.solve(i, output=answer_map[instance_id][i.name])
             else:
+                # 'random' or 'do nothing' solvers, or model solvers that don't need to be trained
                 kwargs = {'url': url}
                 self.solver.solve(i, **kwargs)

@@ -896,7 +919,7 @@ def enumerate_tasks(self, max_instance_count: int, **kwargs):
                 per_task_score += score

             if self.solver_type == 'oracle':
-                assert score > 0.99, f"{Fore.RED}The oracle baseline should always get a score of 1.0"
+                assert score > 0.99, f"{Fore.RED}The oracle baseline should always get a score of 1.0. Instead got `{score}`."
             elif self.solver_type == 'model':
                 kwargs["scores"].append(score)
