Merge pull request #79 from JHU-CLSP/kevin
Start of comprehensive test, hotfix on clear text
danyaljj authored Sep 27, 2023
2 parents 7d45cbf + a0a4daa commit 3d3b238
Showing 3 changed files with 91 additions and 9 deletions.
50 changes: 44 additions & 6 deletions src/4_run_evaluation.py
@@ -130,10 +130,10 @@ def extract_input_values_from_url(self, url, task_name, input_names=None) -> Lis
"""
# TODO I think we can drop "url" parameter later.

inputs = []
        # if a list of input names is provided, then extract the input fields with those names;
        # otherwise, look for elements that look like input fields
if input_names:
inputs = []
for name in input_names:
# use selenium to find the input field
try:
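
The rest of this try block is collapsed in the view above. A minimal sketch of the kind of name-based lookup it performs with Selenium, assuming a plain find_element call and dictionary results (the helper name, the dict fields, and the error handling below are assumptions, not the repository's code):

from typing import Dict, List
from selenium.common.exceptions import NoSuchElementException
from selenium.webdriver.common.by import By


def find_named_inputs(driver, input_names: List[str]) -> List[Dict[str, str]]:
    # Hypothetical helper: locate each named form field on the currently loaded page.
    inputs = []
    for name in input_names:
        try:
            element = driver.find_element(By.NAME, name)  # look the field up by its name attribute
            inputs.append({"name": name, "type": element.get_attribute("type")})
        except NoSuchElementException:
            print(f"Could not find an input named `{name}`; skipping it")
    return inputs
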
@@ -498,12 +498,10 @@ def enumerate_tasks(self, max_instance_count: int):
first_instance_id = min(instance_ids)
print("First instance id:", first_instance_id)

# if maximum is less than the number of instances, we sample a random subset of instances
if max_instance_count < len(instance_ids):
# random sample
instance_ids = random.sample(instance_ids, max_instance_count)
# Create a random sample
instance_ids = random.sample(instance_ids, min(max_instance_count, len(instance_ids)))

# Sample random instances of each task
# Go through the instances of each task in this random sample
for instance_id in instance_ids:

# wait for a keyboard press before continuing
@@ -705,6 +703,46 @@ def enumerate_tasks(self, max_instance_count: int):
print("----------------------------------------------")
print(f'Field statistics per task: {task_field_statistics}')

def enumerate_comprehensive_tests(self, max_instance_count: int):
"""
        Enumerate all the tasks comprehensively, going up to max_instance_count (which should be set high).
        It keeps going despite failures and errors, and does not skip any available task.
        :param max_instance_count: maximum number of instances to evaluate per task
        :returns:
            a list of per-task tuples (task name, % completed, avg score)
            - % completed is the percentage of instances completed with a score of 1
            - avg score is a running mean of the instance scores
"""

input_format = "both"

tasks = self.load_task_names()
ret = []
self.driver.get(TURKLE_URL)

for task_name in tqdm(tasks):
print(f"{Fore.BLUE} = = = = = = = = = = = = starting new task: `{task_name}` = = = = = = = = = = = = ")
instance_ids = self.tasks_ids[task_name]
            first_instance_id = min(instance_ids)  # TODO: check whether min() always matches the first instance; this may depend on how the JSON is formatted

instance_ids = random.sample(instance_ids, min(max_instance_count, len(instance_ids)))

for instance_id in instance_ids:
row_num = instance_id - first_instance_id

url = f'{TURKLE_URL}/task/{instance_id}/iframe/'
self.driver.get(url)

# get the name of the fields
df = pd.read_csv(f'../tasks/{task_name}/batch.csv', nrows=0)
input_names = [col[len('Answer.'):] for col in df.columns if col.startswith('Answer.')]
inputs = self.extract_input_values_from_url(url=url, task_name=task_name, input_names=input_names)

return



if __name__ == "__main__":
    # use argparse to receive the input parameters
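In this commit the new comprehensive loop breaks off with a bare return after the first task, so the per-task summary promised in its docstring is not assembled yet. A minimal sketch of how the (task name, % completed, avg score) tuples could be accumulated, assuming a list of per-instance scores is already available (the helper and its names are hypothetical, not part of the diff):

from typing import List, Tuple


def summarize_task(task_name: str, instance_scores: List[float]) -> Tuple[str, float, float]:
    # Hypothetical helper: collapse one task's sampled instance scores into the
    # (task name, % completed, avg score) tuple described in the docstring above.
    completed = sum(1 for score in instance_scores if score == 1.0)  # instances finished with a perfect score
    pct_completed = 100.0 * completed / len(instance_scores)
    avg_score = sum(instance_scores) / len(instance_scores)          # running mean over the sampled instances
    return task_name, pct_completed, avg_score


# usage sketch: ret.append(summarize_task(task_name, scores)) once a task's sampled instances are scored
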
42 changes: 42 additions & 0 deletions src/comprehensive_test.py
@@ -0,0 +1,42 @@
eval = __import__('4_run_evaluation')
from evaluation.actions import MyActions
from evaluation.baselines import Baseline
from utils.hidden_prints import HiddenPrints
import statistics

evaluation = eval.Evaluation(solver_type="oracle", tasks="all",
do_eval=True, dump_features=True, report_field_stats=True)


def test_actions():
baseline = Baseline(driver=evaluation.driver, actions=evaluation.actions)
action_list = baseline.get_action_list()
print(action_list)
assert len(action_list) > 0, f"The action list should not be empty: {action_list}"

encoded_actions_prompt = baseline.get_encoded_action_list()
print(encoded_actions_prompt)
assert len(encoded_actions_prompt) > 0, f"The encoded actions prompt should not be empty: {encoded_actions_prompt}"


def test_evaluation():
evaluation.enumerate_tasks(max_instance_count=1)


if __name__ == "__main__":
print("Running initial pass on tests without logs")
try:
with HiddenPrints():
test_evaluation()
test_actions()
except Exception as error:
print("An error occurred:", error)
print("Rerunning tests with logs now")
test_evaluation()
test_actions()

# TODO: test that we can apply the gold labels on the tasks

# TODO: test the actions

# TODO: test the evaluation
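
comprehensive_test.py imports HiddenPrints from utils.hidden_prints, which is not part of this diff. A typical minimal implementation of such a context manager, assuming it simply redirects stdout to os.devnull while the with-block runs (the repository's version may differ):

import os
import sys


class HiddenPrints:
    # Context manager that silences print output inside its with-block.
    def __enter__(self):
        self._original_stdout = sys.stdout
        sys.stdout = open(os.devnull, "w")   # discard everything written to stdout

    def __exit__(self, exc_type, exc_val, exc_tb):
        sys.stdout.close()                   # close the devnull handle
        sys.stdout = self._original_stdout   # restore normal printing

This is what lets the harness above run the tests silently first and rerun them with full logs only if something fails.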
8 changes: 5 additions & 3 deletions src/evaluation/actions.py
@@ -4,6 +4,7 @@
import numpy as np
from PIL import Image, ImageDraw
import requests
import platform
from selenium.webdriver.common.action_chains import ActionChains
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
@@ -57,11 +58,12 @@ def is_float(element: any) -> bool:

@staticmethod
def clear_text(action: ActionChains):
key = Keys.COMMAND if platform.system() == "Darwin" else Keys.CONTROL

# Perform Ctrl+A (select all)
action.key_down(Keys.CONTROL).send_keys('a').key_up(Keys.CONTROL)
action.key_down(key).send_keys('a').key_up(key)
# Perform Delete
action.send_keys(Keys.DELETE)

action.send_keys(Keys.BACKSPACE)

class MyActions:
"""
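The patched clear_text only queues keystrokes on the ActionChains object, so nothing reaches the browser until the caller invokes perform(). A hedged usage sketch of the same platform-aware select-all-and-erase chord, assuming a Chrome driver, a local Turkle URL, and a text field named "answer" (all three are placeholders, not taken from the repository):

import platform

from selenium import webdriver
from selenium.webdriver.common.action_chains import ActionChains
from selenium.webdriver.common.by import By
from selenium.webdriver.common.keys import Keys

driver = webdriver.Chrome()
driver.get("http://localhost:8000/task/1/iframe/")   # assumed URL, for illustration only
field = driver.find_element(By.NAME, "answer")        # assumed field name
field.click()                                         # focus the text field first

key = Keys.COMMAND if platform.system() == "Darwin" else Keys.CONTROL
action = ActionChains(driver)
action.key_down(key).send_keys("a").key_up(key)       # select all existing text
action.send_keys(Keys.BACKSPACE)                      # erase the selection
action.perform()                                      # key events are only sent when perform() runs

The swap from CONTROL to a platform-dependent key matters because Ctrl+A does not select all inside macOS browsers, where the shortcut is Cmd+A.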
