Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] add learnable agent #4

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ repos:
--disable-error-code=import-untyped,
--disable-error-code=truthy-function,
--follow-imports=skip,
--disable-error-code=override,
]
# - repo: https://github.com/numpy/numpydoc
# rev: v1.6.0
Expand Down
135 changes: 135 additions & 0 deletions src/agentscope/agents/learnable_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
# -*- coding: utf-8 -*-
""" LearnableAgent agent class for Agent """
from abc import ABC
from typing import Optional, Union, Any, Callable, Type
from loguru import logger

from agentscope.message import Msg
from agentscope.memory import MemoryBase, TemporaryMemory
from agentscope.agents.agent import AgentBase
from agentscope.service.retrieval.similarity import cos_sim


VALUE_ASSESSMENT_PROMPT = (
"Please carefully consider the following record and assess whether it "
"contains information of sufficient value to be suitable for storage in "
"a knowledge base. "
"\nExample:\n"
"'The dragon is the only creature in the Chinese Zodiac that is "
"considered a divine animal.' → Answer 'yes' (because this is basic "
"knowledge about Chinese culture with widespread reference value for "
"understanding related topics)\n"
"Following these guidelines, please respond with 'yes' or 'no' to the "
"following record:\n\n"
"{record}"
)

EXTRACTION_SUMMARY_PROMPT = (
"Please read the following record, extract key knowledge points or "
"question-answer pairs, and provide a concise and clear summary. "
"\nExample:\n"
"Record: 'Due to the rotation of the Earth, we experience the "
"alternation of day and night. "
"The Earth completes one rotation every 24 hours.'\n"
"Summary: 'The Earth rotates once every 24 hours, which leads to the "
"phenomenon of day and night alternation.'\n\n"
"{record}"
)


class LearnableAgent(AgentBase, ABC):
"""Class for LearnableAgent"""

def __init__(
self,
name: str,
vdb_path: str,
vdb_cls: Type[MemoryBase] = TemporaryMemory,
config: Optional[dict] = None,
sys_prompt: Optional[str] = None,
model: Optional[Union[Callable[..., Any], str]] = None,
embedding_model: Union[str, Callable] = None,
metric: Callable = cos_sim,
assess_prompt: str = VALUE_ASSESSMENT_PROMPT,
extract_prompt: str = EXTRACTION_SUMMARY_PROMPT,
) -> None:
super().__init__(name, config, sys_prompt, model)
# Notice: [Memory] is for short-term, current conversation, and will
# not persist after the agent is closed.
# [Vector database] is considered long-term, will be reloaded whenever
# agent is invoked
# Build vector database for saving knowledge
self.vdb = vdb_cls(
config,
embedding_model=embedding_model,
vdb_path=vdb_path,
)
self.metric = lambda x, y: metric(x, y).content
self.assess_prompt = assess_prompt
self.extract_prompt = extract_prompt

def reply(self, x: dict = None) -> dict:
"""Forward method for agent"""
# defer the forward function implementation to example agents
raise NotImplementedError

def learn_from_chat(self) -> None:
"""
Iterates through the messages in the learner's memory and processes
each message to potentially learn from it. Messages originating
from the learner itself are ignored. The memory is reset after
processing.

This function calls the `archive_valuable_msg` method on each message
to decide whether to store the message information into the
knowledge base.
"""
if self.memory.size() > 0:
for msg in self.memory:
# Ignore msg from itselves to avoid duplication
if msg.get("name") != self.name:
self.archive_valuable_msg(msg)
self.memory.reset()

def archive_valuable_msg(self, msg: dict) -> None:
"""
Evaluates a single message to determine whether it should be stored
in the knowledge base. The method generates prompts to assess the
value of the message and to extract a summary if the message is
deemed valuable.

Args:
msg (dict): A dictionary representing the message to be
considered for storage. The dictionary typically contains
keys such as 'name' and 'content'.
"""
# Consider whether to deposit message into the knowledge base
prompt = self.assess_prompt.format_map(
{
"record": msg.content,
},
)
res = self.model([Msg(self.name, prompt)])

logger.info(
f"{self.name}:\n {msg.content} \n " f"accessing results: {res}.",
)

if "yes" in res.lower():
prompt = self.extract_prompt.format_map(
{
"record": msg.content,
},
)
res = self.model([Msg(self.name, prompt)])
emb = self._openai_embedding(res)
self.vdb.add(Msg(self.name, res, embedding=emb), embed=False)
logger.info(f"Saving {res} in {self.name}'s vdb.")

def close(self) -> None:
"""
Saves the current state of the vecter database (vdb) to a memory file.
This method should be called before the termination of the program
to ensure that learned information is not lost.
"""
self.vdb.export()
17 changes: 9 additions & 8 deletions src/agentscope/memory/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,7 @@
"""

from abc import ABC, abstractmethod
from typing import Iterable
from typing import Optional
from typing import Union
from typing import Callable
from typing import Iterable, Optional, Union, Callable, Any


class MemoryBase(ABC):
Expand All @@ -21,6 +18,7 @@ class MemoryBase(ABC):
def __init__(
self,
config: Optional[dict] = None,
**kwargs: Any,
) -> None:
"""MemoryBase is a base class for memory of agents.

Expand All @@ -29,6 +27,7 @@ def __init__(
Configuration of this memory.
"""
self.config = {} if config is None else config
self.kwargs = kwargs

def update_config(self, config: dict) -> None:
"""
Expand All @@ -48,13 +47,13 @@ def get_memory(
"""

@abstractmethod
def add(self, memories: Union[list[dict], dict]) -> None:
def add(self, memories: Union[list[dict], dict], **kwargs: Any) -> None:
"""
Adding new memory fragment, depending on how the memory are stored
"""

@abstractmethod
def delete(self, index: Union[Iterable, int]) -> None:
def delete(self, index: Union[Iterable, int], **kwargs: Any) -> None:
"""
Delete memory fragment, depending on how the memory are stored
and matched
Expand All @@ -65,6 +64,7 @@ def load(
self,
memories: Union[str, dict, list],
overwrite: bool = False,
**kwargs: Any,
) -> None:
"""
Load memory, depending on how the memory are passed, design to load
Expand All @@ -76,14 +76,15 @@ def export(
self,
to_mem: bool = False,
file_path: Optional[str] = None,
**kwargs: Any,
) -> Optional[list]:
"""Export memory, depending on how the memory are stored"""

@abstractmethod
def clear(self) -> None:
def clear(self, **kwargs: Any) -> None:
"""Clean memory, depending on how the memory are stored"""

@abstractmethod
def size(self) -> int:
def size(self, **kwargs: Any) -> int:
"""Returns the number of memory segments in memory."""
raise NotImplementedError
14 changes: 13 additions & 1 deletion src/agentscope/memory/temporary_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,16 @@ def __init__(
self,
config: Optional[dict] = None,
embedding_model: Union[str, Callable] = None,
mem_path: Optional[str] = None,
) -> None:
super().__init__(config)

self._content = []

self.mem_path = mem_path
if self.mem_path is not None:
self.load()

# prepare embedding model if needed
if isinstance(embedding_model, str):
self.embedding_model = load_model_by_name(embedding_model)
Expand Down Expand Up @@ -105,6 +110,7 @@ def export(
if to_mem:
return self._content

file_path = file_path or self.mem_path
if to_mem is False and file_path is not None:
with open(file_path, "w", encoding="utf-8") as f:
json.dump(self._content, f, indent=4)
Expand All @@ -117,9 +123,15 @@ def export(

def load(
self,
memories: Union[str, dict, list],
memories: Union[str, dict, list] = None,
overwrite: bool = False,
) -> None:
if memories is None:
if self.mem_path is not None:
memories = self.mem_path
else:
return

if isinstance(memories, str):
if os.path.isfile(memories):
with open(memories, "r", encoding="utf-8") as f:
Expand Down