Fix generation_until in LM classes. (#78)
* Fix `generation_until` in LM classes.
The requester now considers the tokenizer EOS token.
The evaluator now handles the `generation_until` call.
`gsm8k` changed its metric from `acc` to `exact_match` in the configs.

* Fix EOS token id + fix gen_kwargs in `generate_until` + deactivate tqdm in LMs
AguirreNicolas authored Jun 18, 2024
1 parent 31d5687 commit c7328de
Showing 4 changed files with 61 additions and 70 deletions.
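The core of both fixes is how stop sequences are assembled before a completion request is built: the `until` value from the task's gen_kwargs is normalized to a list, and the tokenizer's EOS string is always appended so generation halts even when a task supplies no stop strings. A minimal standalone sketch of that logic, assuming a plain dict of gen_kwargs and an already-decoded EOS string (the names are illustrative, not the repo's API):

    import copy
    from typing import List

    def normalize_until(gen_kwargs: dict, eos: str) -> List[str]:
        # Work on a copy so repeated requests never see a mutated dict.
        kwargs = copy.deepcopy(gen_kwargs)
        until = kwargs.pop("until", None)
        if isinstance(until, str):
            until = [until]
        elif until is not None and not isinstance(until, list):
            raise ValueError(f"Expected `until` to be str or list, got {until!r}")
        # Always stop on EOS as well, mirroring the requester change below.
        if not until:
            until = [eos]
        else:
            until.append(eos)
        return until

    # normalize_until({"until": "\n\n"}, "<|endoftext|>") -> ["\n\n", "<|endoftext|>"]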
2 changes: 1 addition & 1 deletion apps/python/evaluator/activities/lmeh/evaluate.py
@@ -184,7 +184,7 @@ async def lmeh_evaluate(args: PocketNetworkEvaluationTaskRequest) -> bool:
selected_metrics=open_llm_metrics,
eval_logger=eval_logger,
)
eval_logger.info("Evaluation completed successfully.")
eval_logger.info("Evaluation completed successfully.", task_id=str(args.task_id))
except ApplicationError as e:
raise e

@@ -7,7 +7,7 @@ sleep 3
everything="arc_challenge,hellaswag,truthfulqa_mc2,mmlu_abstract_algebra,mmlu_anatomy,mmlu_astronomy,mmlu_business_ethics,mmlu_clinical_knowledge,mmlu_college_biology,mmlu_college_chemistry,mmlu_college_computer_science,mmlu_college_mathematics,mmlu_college_medicine,mmlu_college_physics,mmlu_computer_security,mmlu_conceptual_physics,mmlu_econometrics,mmlu_electrical_engineering,mmlu_elementary_mathematics,mmlu_formal_logic,mmlu_global_facts,mmlu_high_school_biology,mmlu_high_school_chemistry,mmlu_high_school_computer_science,mmlu_high_school_european_history,mmlu_high_school_geography,mmlu_high_school_government_and_politics,mmlu_high_school_macroeconomics,mmlu_high_school_mathematics,mmlu_high_school_microeconomics,mmlu_high_school_physics,mmlu_high_school_psychology,mmlu_high_school_statistics,mmlu_high_school_us_history,mmlu_high_school_world_history,mmlu_human_aging,mmlu_human_sexuality,mmlu_international_law,mmlu_jurisprudence,mmlu_logical_fallacies,mmlu_machine_learning,mmlu_management,mmlu_marketing,mmlu_medical_genetics,mmlu_miscellaneous,mmlu_moral_disputes,mmlu_moral_scenarios,mmlu_nutrition,mmlu_philosophy,mmlu_prehistory,mmlu_professional_accounting,mmlu_professional_law,mmlu_professional_medicine,mmlu_professional_psychology,mmlu_public_relations,mmlu_security_studies,mmlu_sociology,mmlu_us_foreign_policy,mmlu_virology,mmlu_world_religions,winogrande,gsm8k"
mmlu="mmlu_abstract_algebra,mmlu_anatomy,mmlu_astronomy,mmlu_business_ethics,mmlu_clinical_knowledge,mmlu_college_biology,mmlu_college_chemistry,mmlu_college_computer_science,mmlu_college_mathematics,mmlu_college_medicine,mmlu_college_physics,mmlu_computer_security,mmlu_conceptual_physics,mmlu_econometrics,mmlu_electrical_engineering,mmlu_elementary_mathematics,mmlu_formal_logic,mmlu_global_facts,mmlu_high_school_biology,mmlu_high_school_chemistry,mmlu_high_school_computer_science,mmlu_high_school_european_history,mmlu_high_school_geography,mmlu_high_school_government_and_politics,mmlu_high_school_macroeconomics,mmlu_high_school_mathematics,mmlu_high_school_microeconomics,mmlu_high_school_physics,mmlu_high_school_psychology,mmlu_high_school_statistics,mmlu_high_school_us_history,mmlu_high_school_world_history,mmlu_human_aging,mmlu_human_sexuality,mmlu_international_law,mmlu_jurisprudence,mmlu_logical_fallacies,mmlu_machine_learning,mmlu_management,mmlu_marketing,mmlu_medical_genetics,mmlu_miscellaneous,mmlu_moral_disputes,mmlu_moral_scenarios,mmlu_nutrition,mmlu_philosophy,mmlu_prehistory,mmlu_professional_accounting,mmlu_professional_law,mmlu_professional_medicine,mmlu_professional_psychology,mmlu_public_relations,mmlu_security_studies,mmlu_sociology,mmlu_us_foreign_policy,mmlu_virology,mmlu_world_religions"
heavy="arc_challenge,hellaswag,truthfulqa_mc2,winogrande,gsm8k"
one="mmlu_astronomy"
one="gsm8k"
# change this if you want a different set of datasets, by default it create everything
keys=$one

125 changes: 58 additions & 67 deletions packages/python/lmeh/pocket_lm_eval/models/pocket_network.py
@@ -1,4 +1,5 @@
# Adapted from lm_eval/models/openai_completions.py
+ import copy
import transformers

from typing import List, Optional, Tuple
@@ -87,14 +88,15 @@ async def load_tokenizer(self):
wf_id=self.wf_id,
)
self.vocab_size = self.tokenizer.vocab
- self.end_of_text_token_id = self.tokenizer.eos_token
+ self.end_of_text_token_id = self.tokenizer.eos_token_id
eval_logger.debug(
"Tokenizer loaded successfully.",
adress=self.requester_args.address,
service=self.requester_args.service,
)

def tok_encode(self, string: str, **kwargs) -> List[int]:
+ # TODO: Add options like in lm_eval/models/vllm_causallms.py
if not self.tokenizer:
raise "must call await <instance>.load_tokenizer()"
return self.tokenizer.encode(string)
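For context on the `eos_token` → `eos_token_id` change above: on a Hugging Face tokenizer, `eos_token` is the string form while `eos_token_id` is the integer id, so the old line stored a string in a field that downstream code treats as a token id. A quick illustration with `transformers` (not part of the diff):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    print(tok.eos_token)     # '<|endoftext|>' -- the string form
    print(tok.eos_token_id)  # 50256 -- the integer id the renamed field now holds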
@@ -105,7 +107,7 @@ def tok_decode(self, tokens: List[int]) -> str:
return self.tokenizer.decode(tokens)

def _loglikelihood_tokens(
- self, requests, disable_tqdm: bool = False
+ self, requests,disable_tqdm: bool = True
) -> List[CompletionRequest]:
res = []

@@ -163,7 +165,7 @@ def _collate(x):

return re_ord.get_original(res)

- def generate_until(self, requests, disable_tqdm: bool = False) -> List[CompletionRequest]:
+ def generate_until(self, requests,disable_tqdm: bool = True) -> List[CompletionRequest]:
if not requests:
return []
res = []
@@ -199,8 +201,28 @@ def sameuntil_chunks(xs, size):
context_enc = self.tok_encode(context)
inp = context_enc[-(self.max_length - self.max_gen_toks):]
inps.append(inp)

- until = request_args.get("until", ["<|endoftext|>"])
+ gen_kwargs = request_args
+ until = None
+ if isinstance(gen_kwargs, dict):
+     kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
+     if "until" in kwargs.keys():
+         until = kwargs.pop("until")
+         if isinstance(until, str):
+             until = [until]
+         elif not isinstance(until, list):
+             raise ValueError(
+                 f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
+             )
+ else:
+     raise ValueError(
+         f"Expected `kwargs` to be of type `dict` but got {gen_kwargs}"
+     )
+ # add EOS token to stop sequences
+ eos = self.tokenizer.decode(self.eot_token_id)
+ if not until:
+     until = [eos]
+ else:
+     until.append(eos)
request_args["temperature"] = request_args.get("temperature", 0)
############################################################
# START: POCKET NETWORK CODE
@@ -223,7 +245,10 @@ def sameuntil_chunks(xs, size):
)
for prompt_i, (context, args_) in zip(request.prompt, chunk):
req_dict = request.to_dict(remove_fields=["prompt"])
+ # context is a string
req_dict["prompt"] = prompt_i
req_dict["ctxlen"] = len(prompt_i)
req_dict["context_enc"] = prompt_i
req_i = CompletionRequest(**req_dict)
res.append(req_i)
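With `ctxlen` and `context_enc` stored on each per-prompt record, the evaluator further down can sort and score responses without re-tokenizing the context. The same fan-out pattern in isolation; the `CompletionRequest` fields here are read off the diff above and are not the package's documented API:

    def fan_out(request, chunk):
        # Split one batched request into per-prompt records that carry
        # their own context encoding and length.
        per_prompt = []
        for prompt_i, (context, args_) in zip(request.prompt, chunk):
            req_dict = request.to_dict(remove_fields=["prompt"])
            req_dict["prompt"] = prompt_i
            req_dict["ctxlen"] = len(prompt_i)
            req_dict["context_enc"] = prompt_i
            per_prompt.append(CompletionRequest(**req_dict))
        return per_prompt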
############################################################
Expand All @@ -240,7 +265,7 @@ def _model_generate(self, context, max_length, eos_token_id):
raise NotImplementedError()

def loglikelihood_rolling(
- self, requests, disable_tqdm: bool = False
+ self, requests,disable_tqdm: bool = True
) -> List[float]:
loglikelihoods = []

@@ -273,7 +298,7 @@ def loglikelihood_rolling(
return loglikelihoods

def loglikelihood(
- self, requests, disable_tqdm: bool = False
+ self, requests, disable_tqdm: bool = True
) -> List[CompletionRequest]:
new_reqs = []
for context, continuation in [req.args for req in requests]:
@@ -365,7 +390,7 @@ def tok_decode(self, tokens: List[int]) -> str:
return self.tokenizer.decode(tokens)

def _loglikelihood_tokens(
- self, requests, disable_tqdm: bool = False
+ self, requests,disable_tqdm: bool = True
) -> List[Tuple[float, bool]]:
res = []

@@ -397,40 +422,31 @@ def _collate(x):
ctxlens.append(ctxlen)
response.append(resp)

- # response = oa_completion(
- #     client=self.client,
- #     model=self.model,
- #     prompt=inps,
- #     echo=True,
- #     max_tokens=0,
- #     temperature=0.0,
- #     logprobs=10,
- #     seed=self.seed,
- # )

for resp, ctxlen, (cache_key, context_enc, continuation_enc, resp) in zip(
response, ctxlens, chunk
):
answer = get_result(resp.choices[0], ctxlen)

res.append(answer)
evaluation_logger.debug("Response: ", answer=answer)
# partial caching
#if cache_key is not None:
# self.cache_hook.add_partial("loglikelihood", cache_key, answer)
#return ApplicationError("END OF TEST",non_retryable=True)
return re_ord.get_original(res)

- def generate_until(self, requests, disable_tqdm: bool = False) -> List[CompletionRequest]:
+ def generate_until(self, requests, disable_tqdm: bool = True) -> List[CompletionRequest]:
if not requests:
return []
res = []
- requests = [req.args for req in requests]
+
+ # batch tokenize contexts
+ context, all_gen_kwargs = zip(*(req.args[0] for req in requests))
+ context_encoding = [req.prompt.context_enc for req in requests]
+ responses = [req.resp for req in requests]
+ completion_requests = [req.prompt.data for req in requests]
+ requests = [
+     ((a, b, cr, r), c) for a, b, cr, r, c in zip(context, context_encoding, completion_requests, responses, all_gen_kwargs)
+ ]
+ evaluation_logger.debug("Qty of requests: ", qty_req=len(requests))
def _collate(x):
- toks = self.tok_encode(x[0])
- return len(toks), x[0]
+ toks = x[0][1]
+ return len(toks), x[0][0]
re_ord = utils.Reorderer(requests, _collate)

def sameuntil_chunks(xs, size):
Expand All @@ -445,48 +461,23 @@ def sameuntil_chunks(xs, size):

if ret:
yield ret, lastuntil

# todo: more intelligent batching for heterogeneous `until`
for chunk, request_args in tqdm(
list(sameuntil_chunks(re_ord.get_reordered(), self.batch_size)),
disable=disable_tqdm,
):
- return ApplicationError("Currently evaluation of task with generate_until are not suported",non_retryable=True)
- inps = []
- self._max_gen_toks = request_args.get("max_gen_toks", self.max_gen_toks)
- for context, _ in chunk:
-     context_enc = self.tok_encode(context)
-     inp = context_enc[-(self.max_length - self.max_gen_toks) :]
-     inps.append(inp)
+ context, _, completion_request, response = chunk[0][0]

- until = request_args.get("until", ["<|endoftext|>"])
- request_args["temperature"] = request_args.get("temperature", 0)
- ############################################################
- # START: POCKET NETWORK CODE
- ############################################################
- extra_args = {
-     k: v
-     for k, v in request_args.items()
-     if k not in ["do_sample", "max_gen_toks", "until"]
- }
- evaluation_logger.debug("CompletionRequest: ", model=self.model, prompt=inps, max_tokens=self.max_gen_toks, stop=until, seed=self.seed)
- evaluation_logger.debug("Extra args: ", **extra_args)
- request = CompletionRequest(
-     model=self.model,
-     prompt=inps,
-     max_tokens=self.max_gen_toks,
-     stop=until,
-     seed=self.seed,
-     **extra_args,
- )
- for prompt_i, (context, args_) in zip(request.prompt, chunk):
-     req_dict = request.to_dict(remove_fields=["prompt"])
-     req_dict["prompt"] = prompt_i
-     req_i = CompletionRequest(**req_dict)
-     res.append(req_i)
- ############################################################
- # END: POCKET NETWORK CODE
- ############################################################
+ until = completion_request.stop
+ for resp, (context, args_) in zip(response.choices, chunk):
+     s = getattr(resp, "text")
+
+     until_ = until
+
+     for term in until_:
+         if len(term) > 0:
+             s = s.split(term)[0]
+     res.append(s) #
return re_ord.get_original(res)
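On the evaluator side, `generate_until` no longer errors out with "not supported"; it reads each relayed response and truncates the text at the first stop sequence stored on the `CompletionRequest`. The truncation step in isolation (a sketch, not the repo's API):

    from typing import List

    def truncate_at_stop(text: str, until: List[str]) -> str:
        # Keep only the text that precedes the earliest stop sequence.
        for term in until:
            if term:  # ignore empty stop strings
                text = text.split(term)[0]
        return text

    # truncate_at_stop("42\n\nQ: next?", ["\n\n", "<|endoftext|>"]) -> "42"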

def _model_call(self, inps):
@@ -498,13 +489,13 @@ def _model_generate(self, context, max_length, eos_token_id):
raise NotImplementedError()

def loglikelihood_rolling(
- self, requests, disable_tqdm: bool = False
+ self, requests, disable_tqdm: bool = True
) -> List[float]:
# TODO: Update this method in order to be available for the Pocket Network
return ApplicationError("Currently evaluation of task with loglikelihood_rolling are not suported",non_retryable=True)

def loglikelihood(
- self, requests, disable_tqdm: bool = False
+ self, requests,disable_tqdm: bool = True
) -> List[CompletionRequest]:
# Modify this in order to insted of get contex and continuation,
# get the context, continuation, context_enc and continuation_enc.
2 changes: 1 addition & 1 deletion packages/python/lmeh/utils/task_config.py
@@ -20,7 +20,7 @@
"num_fewshot":5
},
"gsm8k": {
"metric": "acc",
"metric": "exact_match",
"num_fewshot":5
}
}
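For reference, `gsm8k` is a generative task: the harness extracts the model's final answer from the generated text and compares it to the gold answer, which is what `exact_match` measures, while `acc` applies to loglikelihood-scored multiple-choice tasks. A toy sketch of exact-match scoring for this kind of task (the extraction regex is a simplification, not the harness's actual answer filter):

    import re

    def exact_match_last_number(generation: str, gold: str) -> bool:
        # Compare the last number in the generation with the gold answer.
        nums = re.findall(r"-?\d[\d,]*\.?\d*", generation)
        return bool(nums) and nums[-1].rstrip(".").replace(",", "") == gold

    print(exact_match_last_number("... so the answer is 1,234.", "1234"))  # True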