Skip to content

Commit

Permalink
bug fix after merge
Browse files Browse the repository at this point in the history
  • Loading branch information
chenyushuo committed Sep 5, 2024
1 parent d1ab450 commit 9c2cd1b
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 33 deletions.
102 changes: 73 additions & 29 deletions src/agentscope/cpp_server/worker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ using std::to_string;
using namespace pybind11::literals;

using WorkerArgs::AgentArgs;
using WorkerArgs::AgentMemoryReturn;
using WorkerArgs::CreateAgentArgs;
using WorkerArgs::ModelConfigsArgs;
using WorkerArgs::ObserveArgs;
using WorkerArgs::ReplyArgs;
using WorkerArgs::ReplyReturn;
using WorkerArgs::MsgReturn;
using WorkerArgs::AgentListReturn;

Task::Task(const int task_id)
: _task_id(task_id),
Expand Down Expand Up @@ -85,6 +85,9 @@ Worker::Worker(
_max_tasks(std::max(max_tasks, 1u)),
_max_timeout_seconds(std::max(max_timeout_seconds, 1u))
{
py::object serialize_lib = py::module::import("agentscope.serialize");
_serialize = serialize_lib.attr("serialize");
_deserialize = serialize_lib.attr("deserialize");
py::gil_scoped_release release;
struct stat info;
if (stat("./logs/", &info) != 0)
Expand Down Expand Up @@ -526,7 +529,7 @@ pair<bool, string> Worker::get_task_result(const int task_id)
{
string result_str = _tasks[idx].second->get_result();
logger("get_task_result 3: task_id = " + to_string(task_id) + " idx = " + to_string(idx) + " result_str = [" + result_str + "]");
ReplyReturn result;
MsgReturn result;
result.ParseFromString(result_str);
logger("get_task_result 4: task_id = " + to_string(task_id) + " idx = " + to_string(idx) + " result_ok = " + to_string(result.ok()) + " result_str = [" + result_str + "]");
return make_pair(result.ok(), result.message());
Expand Down Expand Up @@ -730,33 +733,49 @@ string Worker::call_get_agent_list()
call_id_list.push_back(call_id);
}
}
string final_result = "[";
// string final_result = "[";
vector<string> result_list;
for (auto call_id : call_id_list)
{
string result = get_result(call_id);
logger("call_get_agent_list 1: call_id = " + to_string(call_id) + " result = [" + result + "]");
if (final_result != "[" && !result.empty())
final_result += ",";
final_result += result;
string result_str = get_result(call_id);
AgentListReturn result;
result.ParseFromString(result_str);
for (const auto &agent_str : result.agent_str_list())
{
result_list.push_back(agent_str);
}
// logger("call_get_agent_list 1: call_id = " + to_string(call_id) + " result = [" + result + "]");
// if (final_result != "[" && !result.empty())
// final_result += ",";
// final_result += result;
}
final_result += "]";
// final_result += "]";
py::gil_scoped_acquire acquire;
logger("call_get_agent_list 1: result_list.size() = [" + to_string(result_list.size()) + "]");
// py::object serialize_lib = py::module::import("agentscope.serialize");
string final_result = _serialize(result_list).cast<string>();
logger("call_get_agent_list 2: result = [" + final_result + "]");
return final_result;
}

void Worker::get_agent_list_worker(const int call_id)
{
py::gil_scoped_acquire acquire;
vector<string> agent_str_list;
// vector<string> agent_str_list;
AgentListReturn result;
{
shared_lock<shared_mutex> lock(_agent_pool_mutex);
for (auto &iter : _agent_pool)
{
agent_str_list.push_back(iter.second.attr("__str__")().cast<string>());
// agent_str_list.push_back(iter.second.attr("__str__")().cast<string>());
result.add_agent_str_list(iter.second.attr("__str__")().cast<string>());
}
}
string result = py::module::import("json").attr("dumps")(agent_str_list).cast<string>();
set_result(call_id, result.substr(1, result.size() - 2));
// string result = py::module::import("json").attr("dumps")(agent_str_list).cast<string>();
// py::object serialize_lib = py::module::import("agentscope.serialize");
// string result = serialize_lib.attr("serialize")(agent_str_list).cast<string>();
// set_result(call_id, result.substr(1, result.size() - 2));
set_result(call_id, result.SerializeAsString());
}

string Worker::call_set_model_configs(const string &model_configs)
Expand Down Expand Up @@ -800,8 +819,10 @@ pair<bool, string> Worker::call_get_agent_memory(const string &agent_id)
AgentArgs args;
args.set_agent_id(agent_id);
int call_id = call_worker_func(worker_id, function_ids::get_agent_memory, &args);
string result = get_result(call_id);
return make_pair(result[0] == 'T', result.substr(1, result.size() - 1));
string result_str = get_result(call_id);
MsgReturn result;
result.ParseFromString(result_str);
return make_pair(result.ok(), result.message());
}

void Worker::get_agent_memory_worker(const int call_id)
Expand All @@ -814,15 +835,23 @@ void Worker::get_agent_memory_worker(const int call_id)
shared_lock<shared_mutex> lock(_agent_pool_mutex);
py::object agent = _agent_pool[agent_id];
py::object memory = agent.attr("memory");
MsgReturn result;
if (memory.is_none())
{
set_result(call_id, "FAgent [" + agent_id + "] has no memory.");
// set_result(call_id, "FAgent [" + agent_id + "] has no memory.");
result.set_ok(false);
result.set_message("Agent [" + agent_id + "] has no memory.");
}
else
{
py::object memory_info = memory.attr("get_memory")();
set_result(call_id, "T" + py::module::import("json").attr("dumps")(memory_info).cast<string>());
// py::object serialize_lib = py::module::import("agentscope.serialize");
string memory_msg = _serialize(memory_info).cast<string>();
result.set_ok(true);
result.set_message(memory_msg);
// set_result(call_id, "T" + py::module::import("json").attr("dumps")(memory_info).cast<string>());
}
set_result(call_id, result.SerializeAsString());
}

pair<bool, string> Worker::call_reply(const string &agent_id, const string &message)
Expand Down Expand Up @@ -863,12 +892,10 @@ void Worker::reply_worker(const int call_id)
shared_lock<shared_mutex> lock(_agent_pool_mutex);
py::object agent = _agent_pool[agent_id];
py::object message_lib = py::module::import("agentscope.message");
py::object py_message = message.size() ? message_lib.attr("deserialize")(message) : py::none();
// py::object serialize_lib = py::module::import("agentscope.serialize");
py::object py_message = message.size() ? _deserialize(message) : py::none();

py::object msg_class = message_lib.attr("Msg");
py::object msg = msg_class(
"name"_a = agent.attr("name"), "content"_a = py::none(), "task_id"_a = task_id);
string msg_str = msg.attr("serialize")().cast<string>();
string msg_str = to_string(task_id);
logger("reply_worker 3: call_id = " + to_string(call_id) + " agent_id = " + agent_id + " task_id = " + to_string(task_id) + " callback_id = " + to_string(callback_id) + " msg_str = " + msg_str);
set_result(call_id, msg_str);

Expand All @@ -878,12 +905,13 @@ void Worker::reply_worker(const int call_id)
{
py_message.attr("update_value")();
}
ReplyReturn result;
MsgReturn result;
try
{
logger("reply_worker 3.1: call_id = " + to_string(call_id) + " agent_id = " + agent_id + " task_id = " + to_string(task_id) + " callback_id = " + to_string(callback_id) + " call reply");
result.set_ok(true);
result.set_message(agent.attr("reply")(py_message).attr("serialize")().cast<string>());
py::object reply_msg = agent.attr("reply")(py_message);
result.set_message(_serialize(reply_msg).cast<string>());
}
catch (const std::exception &e)
{
Expand All @@ -907,6 +935,7 @@ pair<bool, string> Worker::call_observe(const string &agent_id, const string &me
args.set_message(message);
int call_id = call_worker_func(worker_id, function_ids::observe, &args);
string result = get_result(call_id);
logger("call_observe 2: call_id = " + to_string(call_id) + " result = " + result);
return make_pair(true, result);
}

Expand All @@ -922,13 +951,28 @@ void Worker::observe_worker(const int call_id)
py::object agent = _agent_pool[agent_id];
py::object message_lib = py::module::import("agentscope.message");
py::object PlaceholderMessage_class = message_lib.attr("PlaceholderMessage");
// py::object serialize_lib = py::module::import("agentscope.serialize");
logger("observe_worker 1: call_id = " + to_string(call_id) + " message = " + message);
py::object py_messages = message.size() ? message_lib.attr("deserialize")(message) : py::list();
for (auto &py_message : py_messages)
py::object py_messages = message.size() ? _deserialize(message) : py::list();
// if (py::isinstance(py_messages, py::list()))
// {
// py_messages.attr("update_value")();
// }
if (py::isinstance<py::list>(py_messages))
{
for (auto &py_message : py_messages)
{
if (py::isinstance(py_message, PlaceholderMessage_class))
{
py_message.attr("update_value")();
}
}
}
else
{
if (py::isinstance(py_message, PlaceholderMessage_class))
if (py::isinstance(py_messages, PlaceholderMessage_class))
{
py_message.attr("update_value")();
py_messages.attr("update_value")();
}
}
py::print("observe_worker: py_messages = ", py_messages);
Expand Down
3 changes: 3 additions & 0 deletions src/agentscope/cpp_server/worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,9 @@ class Worker
const unsigned int _max_tasks;
const unsigned int _max_timeout_seconds;

// common used functions
py::object _serialize, _deserialize;

enum function_ids
{
create_agent = 0,
Expand Down
7 changes: 3 additions & 4 deletions src/agentscope/rpc/worker_args.proto
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,11 @@ message ModelConfigsArgs {
bytes model_configs = 1;
}

message AgentMemoryReturn {
bool ok = 1;
string message = 2;
message AgentListReturn {
repeated string agent_str_list = 1;
}

message ReplyReturn {
message MsgReturn {
bool ok = 1;
string message = 2;
}

0 comments on commit 9c2cd1b

Please sign in to comment.