
Commit

Support more models (#8)
* Update .gitignore

* Create utils_inference.py

* Update web_demo.py
ypwhs authored Mar 29, 2023
1 parent 06b961a commit c30fbed
Showing 2 changed files with 66 additions and 7 deletions.
57 changes: 57 additions & 0 deletions utils_inference.py
@@ -0,0 +1,57 @@
import torch
from typing import List, Tuple
from transformers import LogitsProcessor, LogitsProcessorList


class InvalidScoreLogitsProcessor(LogitsProcessor):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # If the logits contain NaN/Inf, zero them out and force a fixed token id
        # so sampling does not crash on an invalid distribution.
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            scores.zero_()
            scores[..., 20005] = 5e4
        return scores


@torch.no_grad()
def stream_chat_continue(
        self, tokenizer, query: str, history: List[Tuple[str, str]] = None,
        max_length: int = 2048, do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
    # `self` is the ChatGLM model instance; web_demo.py calls this as
    # stream_chat_continue(model, tokenizer, ...), continuing generation from
    # the last (possibly partial) answer stored in `history`.
    if history is None:
        history = []
    if logits_processor is None:
        logits_processor = LogitsProcessorList()
    if len(history) > 0:
        answer = history[-1][1]
    else:
        answer = ''
    logits_processor.append(InvalidScoreLogitsProcessor())
    gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
                  "temperature": temperature, "logits_processor": logits_processor, **kwargs}
    if not history:
        prompt = query
    else:
        prompt = ""
        for i, (old_query, response) in enumerate(history):
            if i != len(history) - 1:
                prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
            else:
                prompt += "[Round {}]\n问:{}\n答:".format(i, old_query)
    batch_input = tokenizer([prompt], return_tensors="pt", padding=True)
    batch_input = batch_input.to(self.device)

    batch_answer = tokenizer(answer, return_tensors="pt")
    batch_answer = batch_answer.to(self.device)

    input_length = len(batch_input['input_ids'][0])
    # Append the previous (partial) answer, minus its two trailing special tokens,
    # so generation resumes from where that answer stopped.
    final_input_ids = torch.cat([batch_input['input_ids'], batch_answer['input_ids'][:, :-2]], dim=-1).cuda()
    attention_mask = torch.ones_like(final_input_ids).bool().cuda()
    attention_mask[:, input_length:] = False

    batch_input['input_ids'] = final_input_ids
    batch_input['attention_mask'] = attention_mask

    for outputs in self.stream_generate(**batch_input, **gen_kwargs):
        outputs = outputs.tolist()[0][input_length:]
        response = tokenizer.decode(outputs)
        response = self.process_response(response)
        new_history = history + [(query, response)]
        yield response, new_history
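
For reference, a minimal sketch of how this helper can be driven outside the Gradio demo, mirroring the call added to web_demo.py below. This is only an assumption-laden example: the checkpoint name THUDM/chatglm-6b is illustrative, and a CUDA GPU is required because the helper calls .cuda() internally.

from transformers import AutoModel, AutoTokenizer
from utils_inference import stream_chat_continue

# Illustrative checkpoint; any ChatGLM-style model with stream_generate should work.
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda().eval()

history = []
for response, history in stream_chat_continue(
        model, tokenizer, "Hello", history, max_length=2048, top_p=0.7, temperature=0.95):
    print(response)  # `response` grows as more tokens are streamed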
16 changes: 9 additions & 7 deletions web_demo.py
@@ -26,19 +26,21 @@ def inference(*args, **kwargs):
         for i in range(len(one_output)):
             yield one_output[:i + 1]
 else:
-    from chatglm.modeling_chatglm import ChatGLMForConditionalGeneration
-    from chatglm.tokenization_chatglm import ChatGLMTokenizer
-    tokenizer = ChatGLMTokenizer.from_pretrained(model_name, trust_remote_code=True, resume_download=True)
-    model = ChatGLMForConditionalGeneration.from_pretrained(
-        model_name, trust_remote_code=True, resume_download=True).half().cuda()
+    from transformers import AutoModel, AutoTokenizer
+    from utils_inference import stream_chat_continue
+
+    print('Loading model')
+    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, resume_download=True)
+    model = AutoModel.from_pretrained(model_name, trust_remote_code=True, resume_download=True).half().cuda()
     model = model.eval()
     print(f'Successfully loaded model {model_name}')

     def inference(input, max_length, top_p, temperature, allow_generate, history=None):
         if history is None:
             history = []
-        for response, history in model.stream_chat_continue(tokenizer, input, history, max_length=max_length,
-                                                            top_p=top_p, temperature=temperature):
+        for response, history in stream_chat_continue(
+                model, tokenizer, input, history, max_length=max_length,
+                top_p=top_p, temperature=temperature):
             yield response
             if not allow_generate[0]:
                 break
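
The allow_generate argument in the loop above is a one-element list shared with the UI: setting its first element to False makes the generator break out of the stream after the next partial response. A rough sketch of that pattern follows; the wiring around this hunk is not shown in the diff, so the handler name below is a hypothetical illustration.

allow_generate = [True]

def on_stop_click():
    # Hypothetical stop handler: the streaming loop checks allow_generate[0]
    # after each yielded response and breaks once it turns False.
    allow_generate[0] = False

for partial in inference("Hello", 2048, 0.7, 0.95, allow_generate):
    print(partial)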
