Skip to content

Commit

Permalink
feat: add tool_call_id equivalent to LettaMessage FunctionCallMessage…
Browse files Browse the repository at this point in the history
… and FunctionResponse
  • Loading branch information
cpacker committed Oct 19, 2024
1 parent 4c08015 commit 71df141
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 6 deletions.
12 changes: 8 additions & 4 deletions letta/schemas/letta_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,14 @@ class FunctionCall(BaseModel):

name: str
arguments: str
function_call_id: str


class FunctionCallDelta(BaseModel):

name: Optional[str]
arguments: Optional[str]
function_call_id: Optional[str]

# NOTE: this is a workaround to exclude None values from the JSON dump,
# since the OpenAI style of returning chunks doesn't include keys with null values
Expand Down Expand Up @@ -129,10 +131,10 @@ class Config:
@classmethod
def validate_function_call(cls, v):
if isinstance(v, dict):
if "name" in v and "arguments" in v:
return FunctionCall(name=v["name"], arguments=v["arguments"])
elif "name" in v or "arguments" in v:
return FunctionCallDelta(name=v.get("name"), arguments=v.get("arguments"))
if "name" in v and "arguments" in v and "function_call_id" in v:
return FunctionCall(name=v["name"], arguments=v["arguments"], function_call_id=v["function_call_id"])
elif "name" in v or "arguments" in v or "function_call_id" in v:
return FunctionCallDelta(name=v.get("name"), arguments=v.get("arguments"), function_call_id=v.get("function_call_id"))
else:
raise ValueError("function_call must contain either 'name' or 'arguments'")
return v
Expand All @@ -147,11 +149,13 @@ class FunctionReturn(LettaMessage):
status (Literal["success", "error"]): The status of the function call
id (str): The ID of the message
date (datetime): The date the message was created in ISO format
function_call_id (str): A unique identifier for the function call that generated this message
"""

message_type: Literal["function_return"] = "function_return"
function_return: str
status: Literal["success", "error"]
function_call_id: str


# Legacy Letta API had an additional type "assistant_message" and the "function_call" was a formatted string
Expand Down
3 changes: 3 additions & 0 deletions letta/schemas/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ def to_letta_message(
function_call=FunctionCall(
name=tool_call.function.name,
arguments=tool_call.function.arguments,
function_call_id=tool_call.id,
),
)
)
Expand All @@ -203,6 +204,7 @@ def to_letta_message(
raise ValueError(f"Invalid status: {status}")
except json.JSONDecodeError:
raise ValueError(f"Failed to decode function return: {self.text}")
assert self.tool_call_id is not None
messages.append(
# TODO make sure this is what the API returns
# function_return may not match exactly...
Expand All @@ -211,6 +213,7 @@ def to_letta_message(
date=self.created_at,
function_return=self.text,
status=status_enum,
function_call_id=self.tool_call_id,
)
)
elif self.role == MessageRole.user:
Expand Down
17 changes: 15 additions & 2 deletions letta/server/rest_api/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,11 @@ def _process_chunk_to_letta_style(
processed_chunk = FunctionCallMessage(
id=message_id,
date=message_date,
function_call=FunctionCallDelta(name=tool_call_delta.get("name"), arguments=tool_call_delta.get("arguments")),
function_call=FunctionCallDelta(
name=tool_call_delta.get("name"),
arguments=tool_call_delta.get("arguments"),
function_call_id=tool_call_delta.get("id"),
),
)

else:
Expand All @@ -548,7 +552,11 @@ def _process_chunk_to_letta_style(
processed_chunk = FunctionCallMessage(
id=message_id,
date=message_date,
function_call=FunctionCallDelta(name=tool_call_delta.get("name"), arguments=tool_call_delta.get("arguments")),
function_call=FunctionCallDelta(
name=tool_call_delta.get("name"),
arguments=tool_call_delta.get("arguments"),
function_call_id=tool_call_delta.get("id"),
),
)

elif choice.finish_reason is not None:
Expand Down Expand Up @@ -759,6 +767,7 @@ def function_message(self, msg: str, msg_obj: Optional[Message] = None):
function_call=FunctionCall(
name=function_call.function.name,
arguments=function_call.function.arguments,
function_call_id=function_call.id,
),
)

Expand Down Expand Up @@ -786,21 +795,25 @@ def function_message(self, msg: str, msg_obj: Optional[Message] = None):
elif msg.startswith("Success: "):
msg = msg.replace("Success: ", "")
# new_message = {"function_return": msg, "status": "success"}
assert msg_obj.tool_call_id is not None
new_message = FunctionReturn(
id=msg_obj.id,
date=msg_obj.created_at,
function_return=msg,
status="success",
function_call_id=msg_obj.tool_call_id,
)

elif msg.startswith("Error: "):
msg = msg.replace("Error: ", "")
# new_message = {"function_return": msg, "status": "error"}
assert msg_obj.tool_call_id is not None
new_message = FunctionReturn(
id=msg_obj.id,
date=msg_obj.created_at,
function_return=msg,
status="error",
function_call_id=msg_obj.tool_call_id,
)

else:
Expand Down

0 comments on commit 71df141

Please sign in to comment.