Skip to content

Commit

Permalink
Merge pull request #896 from Undertone0809/v1.18.3/change-storage-path
Browse files Browse the repository at this point in the history
feat: Add new beta test file for streamlit sidebar
  • Loading branch information
Undertone0809 authored Sep 6, 2024
2 parents 8b06e45 + 6d6b0cd commit 765d6ba
Show file tree
Hide file tree
Showing 40 changed files with 1,240 additions and 1,141 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ jobs:
- name: Install dependencies
run: |
poetry config virtualenvs.in-project true
poetry install
make install
- name: Run tests
run: |
Expand Down
9 changes: 6 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,19 @@ DEV_TEST_TOOL_FILES := ./tests/tools/test_human_feedback_tool.py ./tests/tools/t
DEV_TEST_HOOK_FILES := ./tests/hook/test_llm.py ./tests/hook/test_tool_hook.py
DEV_TEST_LLM_FILES := ./tests/llms/test_openai.py ./tests/llms/test_factory.py
DEV_TEST_AGENT_FILES := ./tests/agents/test_tool_agent.py ./tests/agents/test_assistant_agent.py
DEV_TEST_FILES := $(DEV_TEST_TOOL_FILES) $(DEV_TEST_HOOK_FILES) $(DEV_TEST_LLM_FILES) $(DEV_TEST_AGENT_FILES) ./tests/test_chat.py ./tests/output_formatter ./tests/test_import.py ./tests/utils/test_string_template.py
DEV_TEST_BETA := ./tests/beta/test_st.py
DEV_TEST_FILES := $(DEV_TEST_BETA) $(DEV_TEST_TOOL_FILES) $(DEV_TEST_HOOK_FILES) $(DEV_TEST_LLM_FILES) $(DEV_TEST_AGENT_FILES) ./tests/test_chat.py ./tests/output_formatter ./tests/test_import.py ./tests/utils/test_string_template.py


ifeq ($(OS),win32)
PYTHONPATH := $(shell python -c "import os; print(os.getcwd())")
TEST_COMMAND := set PYTHONPATH=$(PYTHONPATH) && poetry run pytest -c pyproject.toml --cov-report=html --cov=promptulate $(DEV_TEST_FILES)
# TEST_COMMAND := set PYTHONPATH=$(PYTHONPATH) && poetry run pytest -c pyproject.toml --cov-report=html --cov=promptulate $(DEV_TEST_FILES)
TEST_COMMAND := set PYTHONPATH=$(PYTHONPATH) && poetry run pytest -c pyproject.toml --cov-report=html --cov=promptulate tests/basic
TEST_PROD_COMMAND := set PYTHONPATH=$(PYTHONPATH) && poetry run pytest -c pyproject.toml --cov-report=html --cov=promptulate tests
else
PYTHONPATH := `pwd`
TEST_COMMAND := PYTHONPATH=$(PYTHONPATH) poetry run pytest -c pyproject.toml --cov-report=html --cov=promptulate $(DEV_TEST_FILES)
# TEST_COMMAND := PYTHONPATH=$(PYTHONPATH) poetry run pytest -c pyproject.toml --cov-report=html --cov=promptulate $(DEV_TEST_FILES)
TEST_COMMAND := PYTHONPATH=$(PYTHONPATH) poetry run pytest -c pyproject.toml --cov-report=html --cov=promptulate tests/basic
TEST_PROD_COMMAND := PYTHONPATH=$(PYTHONPATH) poetry run pytest -c pyproject.toml --cov-report=html --cov=promptulate tests
endif

Expand Down
777 changes: 379 additions & 398 deletions poetry.lock

Large diffs are not rendered by default.

26 changes: 25 additions & 1 deletion promptulate/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def __init__(self, **kwargs):
hooks(List[Callable]): for adding to hook_manager
"""
warnings.warn(
"BaseTool is deprecated at v1.7.0. promptulate.tools.base.Tool is recommended.", # noqa: E501
"BaseTool is deprecated at v1.7.0. promptulate.tools.base.Tool and function type declaration is recommended.", # noqa: E501
DeprecationWarning,
)
super().__init__(**kwargs)
Expand Down Expand Up @@ -233,6 +233,30 @@ def to_schema(self) -> Dict[str, Any]:
f"The 'parameters' attribute of {self.__class__.__name__} must be either a subclass of BaseModel or a dictionary representing a schema." # noqa: E501
)

def _args_to_kwargs(self, *args, **kwargs) -> Dict:
"""Converts positional arguments to keyword arguments based on tool parameters.
This method takes in both positional and keyword arguments. It then attempts to
match the positional arguments to the tool's parameters, converting them to
keyword arguments. Any additional keyword arguments are also included in the
final dictionary.
Returns:
Dict: A dictionary containing the converted keyword arguments.
"""
all_kwargs = {}

if isinstance(self.parameters, dict) and "properties" in self.parameters:
all_kwargs.update(dict(zip(self.parameters["properties"].keys(), args)))
elif isinstance(self.parameters, type) and issubclass(
self.parameters, BaseModel
):
all_kwargs.update(dict(zip(self.parameters.__fields__.keys(), args)))

all_kwargs.update(kwargs)

return all_kwargs


class ToolImpl(Tool):
def __init__(
Expand Down
18 changes: 18 additions & 0 deletions promptulate/tools/human_feedback/tools.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from typing import Callable

from promptulate.tools import Tool
Expand All @@ -19,6 +20,16 @@ class HumanFeedBackTool(Tool):
"is lacking or reasoning cannot continue. Please enter the content you wish for"
"human feedback and interaction, but do not ask for knowledge or let humans reason." # noqa
)
parameters: dict = {
"type": "object",
"properties": {
"content": {
"type": "string",
"description": "The content to be presented for human feedback",
}
},
"required": ["content"],
}

def __init__(
self,
Expand All @@ -27,6 +38,13 @@ def __init__(
*args,
**kwargs,
):
warnings.warn(
(
"HumanFeedBackTool will be removed in v1.21.0. "
"You can custom tool if you want to get human feedback"
),
DeprecationWarning,
)
super().__init__(*args, **kwargs)
self.output_func = output_func
self.input_func = input_func
Expand Down
9 changes: 7 additions & 2 deletions promptulate/tools/math/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
import math
import re

import numexpr

from promptulate.llms.base import BaseLLM
from promptulate.llms.openai import ChatOpenAI
from promptulate.tools.base import Tool
Expand All @@ -27,6 +25,13 @@ def _evaluate_expression(expression: str) -> str:
Raises:
ValueError: If the evaluation fails.
"""
try:
import numexpr
except ImportError:
raise ValueError(
" Please install the numexpr package using `pip install numexpr`."
)

try:
local_dict = {"pi": math.pi, "e": math.e}
output = str(
Expand Down
6 changes: 4 additions & 2 deletions promptulate/tools/paper/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,10 @@ class PaperSummaryTool(BaseTool):
class Config:
"""Configuration for this pydantic object."""

extra = Extra.forbid
arbitrary_types_allowed = True
model_config = {
"extra": "forbid",
"arbitrary_types_allowed": True,
}

def _run(self, query: str, **kwargs) -> str:
"""A paper summary tool that passes in the article name (or arxiv id) and
Expand Down
22 changes: 10 additions & 12 deletions promptulate/utils/core_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,37 +108,35 @@ def get_cache():
return cache


def set_openai_api_key(value: str):
cache["OPENAI_API_KEY"] = value


def convert_backslashes(path: str):
"""Convert all \\ to / of file path."""
return path.replace("\\", "/")


def get_default_storage_path(module_name: str = "") -> str:
"""Get the default storage path for the current module. The storage path is
created in the user's home directory, or in a temporary directory if permission
is denied.
"""Get the default storage path for the current module.
The storage path is created in the user's home directory under ~/.zeeland/pne,
or in a temporary directory if permission is denied.
Args:
module_name(str): The name of the module to create a storage path for.
module_name (str): The name of the module to create a storage path for.
Returns:
str: The default storage path for the current module.
"""
storage_path = os.path.expanduser("~/.pne")
storage_path = os.path.expanduser("~/.zeeland/pne")

if module_name:
storage_path = os.path.join(storage_path, module_name)

# Try to create the storage path (with module subdirectory if specified)
# Use a temporary directory instead if permission is denied,
try:
os.makedirs(storage_path, exist_ok=True)
except PermissionError:
storage_path = os.path.join(tempfile.gettempdir(), "pne", module_name)
temp_path = os.path.join(tempfile.gettempdir(), "zeeland", "pne")
storage_path = (
os.path.join(temp_path, module_name) if module_name else temp_path
)
os.makedirs(storage_path, exist_ok=True)

return convert_backslashes(storage_path)
Expand Down
2 changes: 0 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,11 @@ broadcast-service = "1.3.2"
click = "^8.1.7"
cushy-storage = "^1.3.7"
litellm = "^1.39.6"
numexpr = "^2.8.4"
pydantic = ">=1,<3"
python = ">=3.8.1,<4.0"
python-dotenv = "^1.0.0"
questionary = "^2.0.1"
requests = "^2.31.0"
jinja2 = "^3.1.3"
typing-extensions = "^4.10.0"


Expand Down
55 changes: 31 additions & 24 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,48 +1,55 @@
aiohttp==3.9.5 ; python_full_version >= "3.8.1" and python_version < "4.0"
aiohappyeyeballs==2.4.0 ; python_full_version >= "3.8.1" and python_version < "4.0"
aiohttp==3.10.5 ; python_full_version >= "3.8.1" and python_version < "4.0"
aiosignal==1.3.1 ; python_full_version >= "3.8.1" and python_version < "4.0"
annotated-types==0.7.0 ; python_full_version >= "3.8.1" and python_version < "4.0"
anyio==4.4.0 ; python_full_version >= "3.8.1" and python_version < "4.0"
async-timeout==4.0.3 ; python_full_version >= "3.8.1" and python_version < "3.11"
attrs==23.2.0 ; python_full_version >= "3.8.1" and python_version < "4.0"
attrs==24.2.0 ; python_full_version >= "3.8.1" and python_version < "4.0"
broadcast-service==1.3.2 ; python_full_version >= "3.8.1" and python_version < "4.0"
certifi==2024.6.2 ; python_full_version >= "3.8.1" and python_version < "4.0"
certifi==2024.8.30 ; python_full_version >= "3.8.1" and python_version < "4.0"
charset-normalizer==3.3.2 ; python_full_version >= "3.8.1" and python_version < "4.0"
click==8.1.7 ; python_full_version >= "3.8.1" and python_version < "4.0"
colorama==0.4.6 ; python_full_version >= "3.8.1" and python_version < "4.0" and platform_system == "Windows"
cushy-storage==1.3.8 ; python_full_version >= "3.8.1" and python_version < "4.0"
distro==1.9.0 ; python_full_version >= "3.8.1" and python_version < "4.0"
exceptiongroup==1.2.1 ; python_full_version >= "3.8.1" and python_version < "3.11"
filelock==3.14.0 ; python_full_version >= "3.8.1" and python_version < "4.0"
exceptiongroup==1.2.2 ; python_full_version >= "3.8.1" and python_version < "3.11"
filelock==3.15.4 ; python_full_version >= "3.8.1" and python_version < "4.0"
frozenlist==1.4.1 ; python_full_version >= "3.8.1" and python_version < "4.0"
fsspec==2024.6.0 ; python_full_version >= "3.8.1" and python_version < "4.0"
fsspec==2024.9.0 ; python_full_version >= "3.8.1" and python_version < "4.0"
h11==0.14.0 ; python_full_version >= "3.8.1" and python_version < "4.0"
httpcore==1.0.5 ; python_full_version >= "3.8.1" and python_version < "4.0"
httpx==0.27.0 ; python_full_version >= "3.8.1" and python_version < "4.0"
huggingface-hub==0.23.2 ; python_full_version >= "3.8.1" and python_version < "4.0"
idna==3.7 ; python_full_version >= "3.8.1" and python_version < "4.0"
importlib-metadata==7.1.0 ; python_full_version >= "3.8.1" and python_version < "4.0"
httpx==0.27.2 ; python_full_version >= "3.8.1" and python_version < "4.0"
huggingface-hub==0.24.6 ; python_full_version >= "3.8.1" and python_version < "4.0"
idna==3.8 ; python_full_version >= "3.8.1" and python_version < "4.0"
importlib-metadata==8.4.0 ; python_full_version >= "3.8.1" and python_version < "4.0"
importlib-resources==6.4.4 ; python_full_version >= "3.8.1" and python_version < "3.9"
jinja2==3.1.4 ; python_full_version >= "3.8.1" and python_version < "4.0"
litellm==1.40.0 ; python_full_version >= "3.8.1" and python_version < "4.0"
jiter==0.5.0 ; python_full_version >= "3.8.1" and python_version < "4.0"
jsonschema-specifications==2023.12.1 ; python_full_version >= "3.8.1" and python_version < "4.0"
jsonschema==4.23.0 ; python_full_version >= "3.8.1" and python_version < "4.0"
litellm==1.44.19 ; python_full_version >= "3.8.1" and python_version < "4.0"
markupsafe==2.1.5 ; python_full_version >= "3.8.1" and python_version < "4.0"
multidict==6.0.5 ; python_full_version >= "3.8.1" and python_version < "4.0"
numexpr==2.8.6 ; python_full_version >= "3.8.1" and python_version < "4.0"
numpy==1.24.4 ; python_full_version >= "3.8.1" and python_version < "4.0"
openai==1.31.0 ; python_full_version >= "3.8.1" and python_version < "4.0"
openai==1.43.1 ; python_full_version >= "3.8.1" and python_version < "4.0"
packaging==23.2 ; python_full_version >= "3.8.1" and python_version < "4.0"
pkgutil-resolve-name==1.3.10 ; python_full_version >= "3.8.1" and python_version < "3.9"
prompt-toolkit==3.0.36 ; python_full_version >= "3.8.1" and python_version < "4.0"
pydantic-core==2.18.4 ; python_full_version >= "3.8.1" and python_version < "4.0"
pydantic==2.7.3 ; python_full_version >= "3.8.1" and python_version < "4.0"
pydantic-core==2.23.2 ; python_full_version >= "3.8.1" and python_version < "4.0"
pydantic==2.9.0 ; python_full_version >= "3.8.1" and python_version < "4.0"
python-dotenv==1.0.1 ; python_full_version >= "3.8.1" and python_version < "4.0"
pyyaml==6.0.1 ; python_full_version >= "3.8.1" and python_version < "4.0"
pyyaml==6.0.2 ; python_full_version >= "3.8.1" and python_version < "4.0"
questionary==2.0.1 ; python_full_version >= "3.8.1" and python_version < "4.0"
regex==2024.5.15 ; python_full_version >= "3.8.1" and python_version < "4.0"
referencing==0.35.1 ; python_full_version >= "3.8.1" and python_version < "4.0"
regex==2024.7.24 ; python_full_version >= "3.8.1" and python_version < "4.0"
requests==2.32.3 ; python_full_version >= "3.8.1" and python_version < "4.0"
rpds-py==0.20.0 ; python_full_version >= "3.8.1" and python_version < "4.0"
sniffio==1.3.1 ; python_full_version >= "3.8.1" and python_version < "4.0"
tiktoken==0.7.0 ; python_full_version >= "3.8.1" and python_version < "4.0"
tokenizers==0.19.1 ; python_full_version >= "3.8.1" and python_version < "4.0"
tqdm==4.66.4 ; python_full_version >= "3.8.1" and python_version < "4.0"
typing-extensions==4.12.1 ; python_full_version >= "3.8.1" and python_version < "4.0"
urllib3==2.2.1 ; python_full_version >= "3.8.1" and python_version < "4.0"
tokenizers==0.20.0 ; python_full_version >= "3.8.1" and python_version < "4.0"
tqdm==4.66.5 ; python_full_version >= "3.8.1" and python_version < "4.0"
typing-extensions==4.12.2 ; python_full_version >= "3.8.1" and python_version < "4.0"
tzdata==2024.1 ; python_version >= "3.9" and python_version < "4.0"
urllib3==2.2.2 ; python_full_version >= "3.8.1" and python_version < "4.0"
wcwidth==0.2.13 ; python_full_version >= "3.8.1" and python_version < "4.0"
yarl==1.9.4 ; python_full_version >= "3.8.1" and python_version < "4.0"
zipp==3.19.1 ; python_full_version >= "3.8.1" and python_version < "4.0"
yarl==1.9.11 ; python_full_version >= "3.8.1" and python_version < "4.0"
zipp==3.20.1 ; python_full_version >= "3.8.1" and python_version < "4.0"
Original file line number Diff line number Diff line change
@@ -1,41 +1,41 @@
import json
from typing import Optional

from promptulate import BaseMessage, MessageSet
from promptulate.beta.agents.assistant_agent.agent import AssistantAgent
from promptulate.beta.agents.assistant_agent.schema import (
AgentPlanResponse,
Plan,
)
from promptulate.llms.base import BaseLLM


class FakeLLM(BaseLLM):
def _predict(
self, messages: MessageSet, *args, **kwargs
) -> Optional[type(BaseMessage)]:
return None

def __call__(self, instruction: str, *args, **kwargs):
return "FakeLLM output"


def fake_tool_1():
"""Fake tool 1"""
return "Fake tool 1 output"


def test_init_assistant_agent():
llm: BaseLLM = FakeLLM()
agent = AssistantAgent(llm=llm, tools=[fake_tool_1])

assert len(agent.tool_manager.tools) == 1


def test_schema():
raw_data = """{\n \"goals\": [\"Find the hometown of the 2024 Australian Open winner\"],\n \"tasks\": [\n {\n \"task_id\": 1,\n \"description\": \"Identify the winner of the 2024 Australian Open.\"\n },\n {\n \"task_id\": 2,\n \"description\": \"Search for the biography or profile of the identified winner.\"\n },\n {\n \"task_id\": 3,\n \"description\": \"Locate the section of the biography or profile that specifies the player's hometown.\"\n },\n {\n \"task_id\": 4,\n \"description\": \"Record the hometown of the 2024 Australian Open winner.\"\n }\n ]\n}""" # noqa
json_data = json.loads(raw_data)
plan_resp = AgentPlanResponse(**json_data)

plan = Plan.parse_obj({**plan_resp.dict(), "next_task_id": 1})
assert plan.next_task_id == 1
import json
from typing import Optional

from promptulate import BaseMessage, MessageSet
from promptulate.beta.agents.assistant_agent.agent import AssistantAgent
from promptulate.beta.agents.assistant_agent.schema import (
AgentPlanResponse,
Plan,
)
from promptulate.llms.base import BaseLLM


class FakeLLM(BaseLLM):
def _predict(
self, messages: MessageSet, *args, **kwargs
) -> Optional[type(BaseMessage)]:
return None

def __call__(self, instruction: str, *args, **kwargs):
return "FakeLLM output"


def fake_tool_1():
"""Fake tool 1"""
return "Fake tool 1 output"


def test_init_assistant_agent():
llm: BaseLLM = FakeLLM()
agent = AssistantAgent(llm=llm, tools=[fake_tool_1])

assert len(agent.tool_manager.tools) == 1


def test_schema():
raw_data = """{\n \"goals\": [\"Find the hometown of the 2024 Australian Open winner\"],\n \"tasks\": [\n {\n \"task_id\": 1,\n \"description\": \"Identify the winner of the 2024 Australian Open.\"\n },\n {\n \"task_id\": 2,\n \"description\": \"Search for the biography or profile of the identified winner.\"\n },\n {\n \"task_id\": 3,\n \"description\": \"Locate the section of the biography or profile that specifies the player's hometown.\"\n },\n {\n \"task_id\": 4,\n \"description\": \"Record the hometown of the 2024 Australian Open winner.\"\n }\n ]\n}""" # noqa
json_data = json.loads(raw_data)
plan_resp = AgentPlanResponse(**json_data)

plan = Plan.parse_obj({**plan_resp.dict(), "next_task_id": 1})
assert plan.next_task_id == 1
Loading

0 comments on commit 765d6ba

Please sign in to comment.