Skip to content

Commit

Permalink
task configs: metrics and filters as list + clean code
Browse files Browse the repository at this point in the history
  • Loading branch information
AguirreNicolas committed Jun 18, 2024
1 parent c7328de commit bc10890
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 31 deletions.
4 changes: 3 additions & 1 deletion apps/python/evaluator/activities/lmeh/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,8 @@ async def lmeh_evaluate(args: PocketNetworkEvaluationTaskRequest) -> bool:
# generate configurable tasks
try:
open_llm_cfg = open_llm_config.get_task_config(task_names[0])
open_llm_metrics = open_llm_cfg["metric"]
open_llm_filters = open_llm_cfg.get("filters", ["none"])
open_llm_metrics = open_llm_cfg["metrics"]
task_dict = lmeh_generator.get_configurable_task(
tasks=[task_name],
num_fewshot=args.num_fewshot,
Expand Down Expand Up @@ -181,6 +182,7 @@ async def lmeh_evaluate(args: PocketNetworkEvaluationTaskRequest) -> bool:
task_dict=task_dict,
task_id=args.task_id,
mongo_client=mongo_client,
selected_filters=open_llm_filters,
selected_metrics=open_llm_metrics,
eval_logger=eval_logger,
)
Expand Down
12 changes: 0 additions & 12 deletions packages/python/lmeh/pocket_lm_eval/models/pocket_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,6 @@ def _collate(x):
ctxlens = []
response = []
for cache_key, context_enc, continuation_enc, resp in chunk:
evaluation_logger.debug("Loaded Sample: ", cache_key=cache_key, context_enc=context_enc, continuation_enc=continuation_enc, resp=resp)
# max_length+1 because the API takes up to 2049 tokens, including the first context token
inp = (context_enc + continuation_enc)[-(self.max_length + 1):]
# TODO: the logic is much simpler if we just look at the length of continuation tokens
Expand All @@ -428,7 +427,6 @@ def _collate(x):
answer = get_result(resp.choices[0], ctxlen)

res.append(answer)
evaluation_logger.debug("Response: ", answer=answer)
return re_ord.get_original(res)

def generate_until(self, requests, disable_tqdm: bool = True) -> List[CompletionRequest]:
Expand Down Expand Up @@ -503,16 +501,6 @@ def loglikelihood(

new_reqs = []
for (([context, continuation],), context_enc, continuation_enc, resp) in [(req.args, req.prompt.context_enc, req.prompt.continuation_enc, req.resp) for req in requests]:
# for context, continuation in [req.args for req in requests]:
# if context == "":
# # BOS or EOS as context
# context_enc, continuation_enc = (
# [self.prefix_token_id],
# self.tok_encode(continuation),
# )
# else:
# context_enc, continuation_enc = self._encode_pair(context, continuation)

new_reqs.append(((context, continuation), context_enc, continuation_enc, resp))

return self._loglikelihood_tokens(new_reqs, disable_tqdm=disable_tqdm)
Expand Down
15 changes: 9 additions & 6 deletions packages/python/lmeh/utils/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,8 @@ async def evaluate(
task_dict,
task_id: ObjectId,
mongo_client: MongoClient,
selected_metrics: str,
selected_filters: List[str],
selected_metrics: List[str],
limit: Optional[int] = None,
cache_requests: bool = False,
rewrite_requests_cache: bool = False,
Expand Down Expand Up @@ -434,12 +435,10 @@ async def save_results(

# run requests through model
resps = getattr(lm, reqtype)(cloned_reqs)
eval_logger.debug("Response:", resps=resps)

# put responses from model into a list of length K for each request.
for x, req in zip(resps, cloned_reqs):
req.resps.append(x)
eval_logger.debug("Request:", req=req, resps=req.resps, x=x)

RANK = lm.rank
WORLD_SIZE = lm.world_size
Expand All @@ -464,6 +463,10 @@ async def save_results(
scores = []
result_num_samples = set()
for filter_key in task.instances[0].filtered_resps.keys():
if filter_key not in selected_filters:
eval_logger.debug("Skipping Filter Key:", filter_key=filter_key)
continue
eval_logger.debug("Entering Filter Key:", filter_key=filter_key)
doc_iterator = task.doc_iterator(
rank=RANK, limit=limit, world_size=WORLD_SIZE
)
Expand All @@ -490,9 +493,9 @@ async def save_results(
task_output.logged_samples.append(example)
for metric, value in metrics.items():
task_output.sample_metrics[(metric, filter_key)].append(value)
if selected_metrics in metrics:
numericSample = NumericSample(score=example[selected_metrics], id=doc_id)
scores.append(numericSample)
if metric in selected_metrics:
numericSample = NumericSample(score=example[metric], id=doc_id)
scores.append(numericSample)

base_result = PocketNetworkMongoDBResultBase(
task_id=task_id,
Expand Down
25 changes: 13 additions & 12 deletions packages/python/lmeh/utils/task_config.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,28 @@
# Per-task evaluation configuration for the Open LLM leaderboard tasks.
#
# Schema per task:
#   "metrics"     (list[str]) : metric keys to extract from the harness results
#                               (consumed via open_llm_cfg["metrics"]).
#   "num_fewshot" (int)       : number of few-shot examples for the prompt.
#   "filters"     (list[str]) : optional; harness filter keys to score. Callers
#                               fall back to ["none"] when absent
#                               (open_llm_cfg.get("filters", ["none"])).
task_cnfg = {
    "arc_challenge": {
        "metrics": ["acc_norm"],
        "num_fewshot": 25,
    },
    "hellaswag": {
        "metrics": ["acc_norm"],
        "num_fewshot": 10,
    },
    "truthfulqa_mc2": {
        "metrics": ["acc"],
        "num_fewshot": 0,
    },
    "mmlu": {
        "metrics": ["acc"],
        "num_fewshot": 5,
    },
    "winogrande": {
        "metrics": ["acc"],
        "num_fewshot": 5,
    },
    "gsm8k": {
        "metrics": ["exact_match"],
        "num_fewshot": 5,
        # gsm8k answers need extraction from generated text, hence a filter.
        "filters": ["flexible-extract"],
    },
}

Expand Down

0 comments on commit bc10890

Please sign in to comment.