diff --git a/utils_inference.py b/utils_inference.py
new file mode 100644
index 0000000..04886a4
--- /dev/null
+++ b/utils_inference.py
@@ -0,0 +1,57 @@
+import torch
+from typing import List, Tuple
+from transformers import LogitsProcessor, LogitsProcessorList
+
+
+class InvalidScoreLogitsProcessor(LogitsProcessor):
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        if torch.isnan(scores).any() or torch.isinf(scores).any():
+            scores.zero_()
+            scores[..., 20005] = 5e4
+        return scores
+
+
+@torch.no_grad()
+def stream_chat_continue(
+        self, tokenizer, query: str, history: List[Tuple[str, str]] = None,
+        max_length: int = 2048, do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
+    if history is None:
+        history = []
+    if logits_processor is None:
+        logits_processor = LogitsProcessorList()
+    if len(history) > 0:
+        answer = history[-1][1]
+    else:
+        answer = ''
+    logits_processor.append(InvalidScoreLogitsProcessor())
+    gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
+                  "temperature": temperature, "logits_processor": logits_processor, **kwargs}
+    if not history:
+        prompt = query
+    else:
+        prompt = ""
+        for i, (old_query, response) in enumerate(history):
+            if i != len(history) - 1:
+                prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
+            else:
+                prompt += "[Round {}]\n问:{}\n答:".format(i, old_query)
+    batch_input = tokenizer([prompt], return_tensors="pt", padding=True)
+    batch_input = batch_input.to(self.device)
+
+    batch_answer = tokenizer(answer, return_tensors="pt")
+    batch_answer = batch_answer.to(self.device)
+
+    input_length = len(batch_input['input_ids'][0])
+    final_input_ids = torch.cat([batch_input['input_ids'], batch_answer['input_ids'][:, :-2]], dim=-1).cuda()
+    attention_mask = torch.ones_like(final_input_ids).bool().cuda()
+    attention_mask[:, input_length:] = False
+
+    batch_input['input_ids'] = final_input_ids
+    batch_input['attention_mask'] = attention_mask
+
+    for outputs in self.stream_generate(**batch_input, **gen_kwargs):
+        outputs = outputs.tolist()[0][input_length:]
+        response = tokenizer.decode(outputs)
+        response = self.process_response(response)
+        new_history = history + [(query, response)]
+        yield response, new_history
diff --git a/web_demo.py b/web_demo.py
index b4b51f2..d4974c9 100644
--- a/web_demo.py
+++ b/web_demo.py
@@ -26,19 +26,21 @@
     def inference(*args, **kwargs):
         for i in range(len(one_output)):
             yield one_output[:i + 1]
 else:
-    from chatglm.modeling_chatglm import ChatGLMForConditionalGeneration
-    from chatglm.tokenization_chatglm import ChatGLMTokenizer
-    tokenizer = ChatGLMTokenizer.from_pretrained(model_name, trust_remote_code=True, resume_download=True)
-    model = ChatGLMForConditionalGeneration.from_pretrained(
-        model_name, trust_remote_code=True, resume_download=True).half().cuda()
+    from transformers import AutoModel, AutoTokenizer
+    from utils_inference import stream_chat_continue
+    print('Loading model')
+    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, resume_download=True)
+    model = AutoModel.from_pretrained(model_name, trust_remote_code=True, resume_download=True).half().cuda()
     model = model.eval()
+    print(f'Successfully loaded model {model_name}')
 
     def inference(input, max_length, top_p, temperature, allow_generate, history=None):
         if history is None:
             history = []
-        for response, history in model.stream_chat_continue(tokenizer, input, history, max_length=max_length,
-                                                            top_p=top_p, temperature=temperature):
+        for response, history in stream_chat_continue(
+                model, tokenizer, input, history, max_length=max_length,
+                top_p=top_p, temperature=temperature):
             yield response
             if not allow_generate[0]:
                 break
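
Example usage of the new helper, as a minimal standalone sketch: the checkpoint name, the prompt strings, and the pre-seeded history entry below are illustrative, and it assumes a CUDA device plus a ChatGLM checkpoint that exposes stream_generate() and process_response() through trust_remote_code, which is what utils_inference.py relies on.

from transformers import AutoModel, AutoTokenizer

from utils_inference import stream_chat_continue

# Illustrative checkpoint name; any ChatGLM model loaded with trust_remote_code
# that provides stream_generate() and process_response() should behave the same.
model_name = 'THUDM/chatglm-6b'
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True).half().cuda().eval()

# The last history entry holds the partial answer to continue. stream_chat_continue
# re-tokenizes it, drops the trailing special tokens, and appends it to the prompt,
# so generation resumes from that prefix instead of starting a fresh reply.
query = 'Write a short poem about spring.'
history = [(query, 'The breeze returns, the rivers wake,')]

for response, new_history in stream_chat_continue(
        model, tokenizer, query, history, max_length=2048, top_p=0.7, temperature=0.95):
    print(response)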