Commit

Created a run_single.py to run a single task, bypassing restrictions in 4_run_evaluation
klxu03 committed Sep 14, 2023
1 parent 4cd4a82 commit 1a4d930
Showing 2 changed files with 266 additions and 1 deletion.
6 changes: 5 additions & 1 deletion .gitignore
@@ -4,4 +4,8 @@ src/Turkle/
 src/Turkle/*
 /tasks/*/input.csv
 tasks/*/input.csv
-.idea
+.idea
+/src/__pycache__/
+/src/evaluation/__pycache__/
+/src/oracle_baseline_scores.csv
+/src/task_field_statistics.csv
261 changes: 261 additions & 0 deletions src/run_single.py
@@ -0,0 +1,261 @@
import argparse
from colorama import Fore
import random
import pandas as pd
import os
import shutil
from selenium.webdriver.common.by import By
import json
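# note: the module name "4_run_evaluation" starts with a digit, so it cannot be
# loaded with a regular `import` statement; __import__ works around that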
run_eval = __import__('4_run_evaluation')

TURKLE_URL = "http://localhost:8000"
TEST_NAME = "Script KD eval LONG V2 - disc result eval 1"

class Run(run_eval.Evaluation):
def run_task(self, task_name: str, max_instance_count: int):
results = {}
self.driver.get(TURKLE_URL)
aggregate_field_statistics = {} # We store the stats related to the field types/frequency here
task_field_statistics = {}
print(f"{Fore.BLUE} = = = = = = = = = = = = starting new task: `{task_name}` = = = = = = = = = = = = ")

instance_ids = self.task_ids[task_name]
first_instance_id = min(instance_ids)
print("First instance id:", first_instance_id)

# if maximum is less than the number of instances, we sample a random subset of instances
if max_instance_count < len(instance_ids):
# random sample
instance_ids = random.sample(instance_ids, max_instance_count)

# Sample random instances of each task
for instance_id in instance_ids:

# wait for a keyboard press before continuing
# input("Press Enter to continue...")

row_number = instance_id - first_instance_id
print(f"instance_id: {instance_id} <-> row_number: {row_number}")

url = f'{TURKLE_URL}/task/{instance_id}/iframe/'
self.driver.get(url)

# get the name of the fields
df = pd.read_csv(f'../tasks/{task_name}/batch.csv', nrows=0)
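            # input field names are derived from the "Answer."-prefixed columns of batch.csv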
input_names = [col.replace('Answer.', '') for col in df.columns if col.startswith('Answer.')]
inputs = self.extract_input_values_from_url(url=url, task_name=task_name, input_names=input_names)

print(" --> inputs: {}".format([x.name for x in inputs]))

answers_map = self.retrieve_gold_labels(
task_name, row_number, [x.name for x in inputs]
)

print(" --> input labels: {}".format(answers_map))

# TODO: check if all the files (images, videos, audio, css, etc.) in the HTML are accessible
# TODO: find all the URLS in the HTML and check if they are accessible

if self.dump_features:
directory = f'features/{task_name}'
images_directory = f'{directory}/images'
html_directory = f'{directory}/HTML'

if os.path.exists(directory):
shutil.rmtree(directory)
os.makedirs(directory)

if not os.path.exists(html_directory):
os.makedirs(html_directory)

# for counting overall statistics
if self.report_field_stats:
if task_name not in task_field_statistics:
task_field_statistics[task_name] = {}

for i in inputs:
if i.type not in aggregate_field_statistics:
aggregate_field_statistics[i.type] = 0

aggregate_field_statistics[i.type] += 1

if i.type not in task_field_statistics[task_name]:
task_field_statistics[task_name][i.type] = 0
task_field_statistics[task_name][i.type] += 1

if self.dump_features:
data_to_be_dumped = []

for input_idx, i in enumerate(inputs):
print(f"{Fore.GREEN} - - - - - - starting a new element: `{i}` - - - - - - ")

# make sure that the element is visible
element = self.driver.find_element(By.NAME, i.name)
if not element.is_displayed() or element.size['width'] <= 0 or element.size['height'] <= 0:
print(f'{Fore.RED}Skipping element `{i.name}` since it is not visible.')
continue

if self.dump_features and i.type != 'hidden':
image_format = "bordered_div" # the most reasonable option
# create directory if needed
if not os.path.exists(f'{images_directory}_{image_format}'):
os.makedirs(f'{images_directory}_{image_format}')
if image_format == 'full_page':
task_image = self.actions.take_page_screenshots().outcome
elif image_format == 'bordered_div':
task_image = self.actions.take_element_screenshot_with_border(i).outcome
else:
raise Exception(f"{Fore.RED}to be implemented for image format `{image_format}`")

if isinstance(task_image, list):
img_ids = []
for j, image in enumerate(task_image):
image_id = f'{instance_id}_{input_idx}_{i.name}_{j}.png'
image.save(f'{images_directory}_{image_format}/{image_id}')
img_ids.append(image_id)
image_id = img_ids
else:
image_id = f'{instance_id}_{input_idx}_{i.name}.png'
task_image.save(f'{images_directory}_{image_format}/{image_id}')

html_id = f'{instance_id}_{i.name}.html'
with open(f'{html_directory}/{html_id}', 'w') as f:
# note, we can't use "driver.page_source" since it would return the default source without any changes
# TODO: double-check that this HTML code indeed contains the latest changes
f.write(self.driver.execute_script("return document.documentElement.outerHTML;"))

# *after* we dump *input* features, we execute the action
if self.solver_type == 'oracle':
kwargs = {'answers': answers_map[i.name]}
oracle_action_sequence = self.solver.solve(i, **kwargs)
else:
self.solver.solve(i)

# *after* we execute the action, we dump the *output* features
if self.dump_features:
data_to_be_dumped.append({
'input_type': i.type,
'input_name': i.name,
'image_id': image_id,
'html_id': html_id,
'output': oracle_action_sequence
})

# get the input values from the web page
inputs_with_values = self.extract_values(inputs)

# collecting field statistics
if task_name not in results:
results[task_name] = {}

            # TODO: move this into an evaluation function to keep this block clean
score = 0.0
for i in inputs_with_values:
if i.name in self.excluded_input_names:
continue
                # for checkboxes, sort the pipe-separated values alphabetically so ordering does not affect scoring
if i.type == "checkbox":
i.values = "|".join(sorted(i.values))
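                    # e.g. a prediction ["dog", "cat"] and a gold label "dog|cat" both normalize to "cat|dog"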
for idx in range(len(answers_map[i.name])):
x = answers_map[i.name][idx]
if type(x) == str and "|" in x:
answers_map[i.name][idx] = "|".join(sorted(x.split("|")))
else:
if len(i.values) > 0:
i.values = i.values[0]
else:
i.values = ''
score_per_field = self.calculate_rouge(answers_map[i.name], i.type, i.values)

if i.type not in results[task_name]:
results[task_name][i.type] = []

results[task_name][i.type].append(score_per_field)

score += score_per_field

score /= len(inputs_with_values)
print(f"{Fore.CYAN} --> Overall score: {score}")

if self.solver_type == 'oracle':
assert score > 0.99, f"{Fore.RED}The oracle baseline should always get a score of 1.0"

if self.dump_features:
with open(f'{directory}/{task_name}.json', 'w') as f:
json.dump(data_to_be_dumped, f, indent=4)

df = pd.DataFrame()
for task_name, inputs in results.items():
for input_type, scores in inputs.items():
# print(scores)
avg_score = sum(scores) / len(scores)
# TODO: check if we can safely change the "projects" in the following lines to tasks
df = pd.concat(
[
df, pd.DataFrame({
'project': [task_name],
'input_type': [input_type],
'score': [avg_score]
})
],
ignore_index=True)

if 'project' not in df.columns:
df.insert(0, 'project', '')
if 'input_type' not in df.columns:
df.insert(1, 'input_type', '')
if 'score' not in df.columns:
df.insert(1, 'score', '')

        # pivot to one row per task and one column per input type (cell = average score)
        df = df.pivot(index='project', columns='input_type', values='score')
df.to_csv('oracle_baseline_scores.csv', index=True)

# Close the driver
self.driver.quit()

print("Now let's print the field statistics")

        # save task_field_statistics (a dict of task -> field type -> count) as a CSV file
        # by first converting it to a DataFrame
results = pd.DataFrame.from_dict(task_field_statistics)
results.to_csv('task_field_statistics.csv', index=True)

print("----------------------------------------------")
print(f'Number of tasks: {len(task_field_statistics.keys())}')
print("----------------------------------------------")
print(f'Number of fields: {len(aggregate_field_statistics.keys())}')
print("----------------------------------------------")
print(f'Overall field statistics: {aggregate_field_statistics}')
print("----------------------------------------------")
print(f'Field statistics per task: {task_field_statistics}')

if __name__ == "__main__":
    # use argparse to receive the input parameters
parser = argparse.ArgumentParser()
parser.add_argument("--solver_type", help="random or oracle", default="random")
parser.add_argument("--tasks", help="train, test, or subjective_test", default="test")
parser.add_argument("--max_instance_count", help="maximum number of instances per task", default=1)
    parser.add_argument("--do_eval", help="whether to compute the quality against the gold data", default=True)
parser.add_argument("--dump_features", help="whether to dump the features", default=False)
parser.add_argument("--report_field_stats", help="whether to collect statistics for the HTML fields", default=True)

args = parser.parse_args()
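    # hard-coded overrides: this script always runs the oracle solver on 5 instances per task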
args.solver_type = "oracle"
args.max_instance_count = 5
print(f"{Fore.BLUE}Solver: {args.solver_type}")
max_instance_count = int(args.max_instance_count)

do_eval = args.do_eval
dump_features = args.dump_features
report_field_stats = args.report_field_stats

    if dump_features and args.solver_type != "oracle":
raise Exception(f"{Fore.RED}dump_features can only be used with oracle solver")

    evaluation = Run(solver_type=args.solver_type, tasks=args.tasks,
                     do_eval=do_eval, dump_features=dump_features, report_field_stats=report_field_stats)

# input_format = config.get('DEFAULT', 'input_format')
# image_format = config.get('DEFAULT', 'image_format', fallback='full_page')
    evaluation.run_task(TEST_NAME, max_instance_count)
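
For reference, a minimal sketch of driving the Run class directly, without the argparse block (this mirrors the __main__ section above; it assumes the script is run from src/ so that run_single and 4_run_evaluation are importable, and that a Turkle server is up at localhost:8000):

from run_single import Run

# oracle solver over at most 5 randomly sampled instances of the hard-coded task
evaluation = Run(solver_type="oracle", tasks="test", do_eval=True,
                 dump_features=False, report_field_stats=True)
evaluation.run_task("Script KD eval LONG V2 - disc result eval 1", max_instance_count=5)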
