forked from lss233/chatgpt-mirai-qq-bot
-
Notifications
You must be signed in to change notification settings - Fork 0
/
conversation.py
115 lines (94 loc) · 3.7 KB
/
conversation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
from typing import List
import manager.bot
from adapter.botservice import BotAdapter
from adapter.chatgpt.api import ChatGPTAPIAdapter
from adapter.chatgpt.web import ChatGPTWebAdapter
from adapter.ms.bing import BingAdapter
from constants import config
from exceptions import PresetNotFoundException, BotTypeNotFoundException
from renderer.renderer import Renderer, FullTextRenderer, MarkdownImageRenderer
# Process-wide registry of chat windows: session_id -> ConversationHandler.
# Populated lazily by ConversationHandler.get_handler().
handlers = {}
class ConversationContext:
    """One conversation with a single bot backend.

    Binds a bot adapter (selected by ``_type``) to a renderer and exposes the
    conversation operations (reset, ask, rollback, preset loading) as async
    generators that yield the values to send back to the chat.
    """

    type: str  # bot type key: 'chatgpt-web', 'chatgpt-api' or 'bing'
    adapter: BotAdapter  # backend adapter created in __init__ from the bot type
    renderer: Renderer  # converts adapter output into sendable results
    preset: str = None  # keyword passed to the last successful load_preset() run

    def __init__(self, _type: str, session_id: str):
        """Create a context for bot type ``_type`` in chat session ``session_id``.

        Raises:
            BotTypeNotFoundException: if ``_type`` is not a supported bot type.
        """
        self.session_id = session_id

        # Markdown-to-image rendering is used when config enables it by default.
        wants_image = config.text_to_image.default
        self.renderer = MarkdownImageRenderer() if wants_image else FullTextRenderer()

        adapter_classes = {
            'chatgpt-web': ChatGPTWebAdapter,
            'chatgpt-api': ChatGPTAPIAdapter,
            'bing': BingAdapter,
        }
        adapter_cls = adapter_classes.get(_type)
        if adapter_cls is None:
            raise BotTypeNotFoundException(_type)
        self.adapter = adapter_cls(self.session_id)
        self.type = _type

    async def reset(self):
        """Call the adapter's ``on_reset()`` hook, then yield the configured reset reply."""
        await self.adapter.on_reset()
        yield config.response.reset

    async def ask(self, prompt: str, name: str = None):
        """Send ``prompt`` to the bot and yield rendered output.

        Every item streamed by the adapter is passed through the renderer.
        Once the stream ends, the renderer's final result is yielded while the
        renderer is still open. ``name`` is accepted but not used here.
        """
        async with self.renderer:
            async for chunk in self.adapter.ask(prompt):
                yield await self.renderer.render(chunk)
            yield await self.renderer.result()

    async def rollback(self):
        """Ask the adapter to roll back, then yield a single reply.

        A boolean result maps to the configured success or failure message.
        The failure message is formatted with the reset command. Any
        non-boolean result is yielded unchanged.
        """
        outcome = await self.adapter.rollback()
        if not isinstance(outcome, bool):
            yield outcome
        elif outcome:
            yield config.response.rollback_success
        else:
            yield config.response.rollback_fail.format(reset=config.trigger.reset_command)

    async def load_preset(self, keyword: str):
        """Feed the preset named ``keyword`` to the adapter.

        Yields whatever the adapter's ``preset_ask`` yields for each line.
        Lines starting with '#' are skipped. A line containing ':' is split at
        the first ':' into ``role`` (lower-cased, stripped) and text; any other
        line is sent with the 'system' role. The keyword 'default' is accepted
        even when it is not a configured preset, in which case nothing is sent.
        ``self.preset`` is updated afterwards.

        Raises:
            PresetNotFoundException: if ``keyword`` is unknown and not 'default'.
        """
        is_known = keyword in config.presets.keywords
        if not is_known and keyword != 'default':
            raise PresetNotFoundException(keyword)

        if is_known:
            for line in config.load_preset(keyword):
                if line.startswith('#'):
                    continue
                # Check whether the line has the form "role: text".
                head, sep, tail = line.partition(':')
                if sep:
                    role, body = head, tail
                else:
                    role, body = 'system', line
                async for reply in self.adapter.preset_ask(role=role.lower().strip(), text=body.strip()):
                    yield reply

        self.preset = keyword
class ConversationHandler:
    """
    Each chat window owns one ConversationHandler, which manages that
    window's ConversationContext objects (at most one per bot type).
    """

    # All conversations of this chat window, keyed by bot type. Only declared
    # here; every instance gets its own dict in __init__, because a class-level
    # dict would be shared by (and leak contexts between) all chat windows.
    conversations: dict

    # The conversation selected via switch(); None until one is selected.
    current_conversation: "ConversationContext" = None

    session_id: str = 'unknown'

    def __init__(self, session_id: str):
        self.session_id = session_id
        self.conversations = {}

    def list(self) -> List["ConversationContext"]:
        """Return this window's conversations in creation order."""
        return [*self.conversations.values()]

    async def create(self, _type: str):
        """Return the conversation for bot type ``_type``, creating it on first use.

        Raises:
            BotTypeNotFoundException: propagated from ConversationContext when
                ``_type`` is not a supported bot type.
        """
        if _type not in self.conversations:
            self.conversations[_type] = ConversationContext(_type, self.session_id)
        return self.conversations[_type]

    def switch(self, index: int) -> bool:
        """Make the ``index``-th conversation (creation order) the current one.

        Returns True on success, False if ``index`` is out of range
        (negative indices are rejected rather than wrapping around).
        """
        contexts = [*self.conversations.values()]
        if 0 <= index < len(contexts):
            self.current_conversation = contexts[index]
            return True
        return False

    @classmethod
    async def get_handler(cls, session_id: str):
        """Return the handler for ``session_id``, creating and registering it if needed."""
        if session_id not in handlers:
            handlers[session_id] = cls(session_id)
        return handlers[session_id]