Skip to content

Commit

Permalink
Nico/api cron devops (#159)
Browse files Browse the repository at this point in the history
* cp

* new endpoint
  • Loading branch information
nicarq authored Nov 29, 2023
1 parent 9999ba5 commit e539110
Show file tree
Hide file tree
Showing 12 changed files with 264 additions and 51 deletions.
5 changes: 3 additions & 2 deletions src/agent/execution/job_execution_core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,9 +249,10 @@ impl JobManager {
return Err(AgentError::AgentNotFound);
}
};
// new code
shinkai_log(ShinkaiLogOption::JobExecution, ShinkaiLogLevel::Debug, format!("KaiJobFile: {:?}", kai_file).as_str());
match kai_file.schema {
KaiSchemaType::CronJobRequest(cron_task_request) => {
shinkai_log(ShinkaiLogOption::JobExecution, ShinkaiLogLevel::Debug, format!("CronJobRequest: {:?}", cron_task_request).as_str());
// Handle CronJobRequest
JobManager::handle_cron_job_request(
db.clone(),
Expand All @@ -266,7 +267,7 @@ impl JobManager {
return Ok(true);
}
KaiSchemaType::CronJob(cron_task) => {
eprintln!("CronJob: {:?}", cron_task);
shinkai_log(ShinkaiLogOption::JobExecution, ShinkaiLogLevel::Debug, format!("CronJob: {:?}", cron_task).as_str());
// Handle CronJob
JobManager::handle_cron_job(
db.clone(),
Expand Down
6 changes: 3 additions & 3 deletions src/agent/execution/job_execution_handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use crate::{
agent::{
error::AgentError, execution::chains::inference_chain_router::InferenceChain, job::Job, job_manager::JobManager,
},
cron_tasks::web_scrapper::{CronTaskRequest, CronTaskResponse, WebScraper},
cron_tasks::web_scrapper::{CronTaskRequest, CronTaskRequestResponse, WebScraper},
db::{db_cron_task::CronTask, db_errors::ShinkaiDBError, ShinkaiDB},
planner::kai_files::{KaiJobFile, KaiSchemaType},
};
Expand Down Expand Up @@ -50,7 +50,7 @@ impl JobManager {
.await?;

// Prepare data to save inference response to the DB
let cron_task_response = CronTaskResponse {
let cron_task_response = CronTaskRequestResponse {
cron_task_request: cron_task_request,
cron_description: inference_response_content.cron_expression.to_string(),
pddl_plan_problem: inference_response_content.pddl_plan_problem.to_string(),
Expand All @@ -62,7 +62,7 @@ impl JobManager {
let agent = agent_found.ok_or(AgentError::AgentNotFound)?;

let kai_file = KaiJobFile {
schema: KaiSchemaType::CronJobResponse(cron_task_response.clone()),
schema: KaiSchemaType::CronJobRequestResponse(cron_task_response.clone()),
shinkai_profile: Some(profile.clone()),
agent_id: agent.id.clone(),
};
Expand Down
5 changes: 3 additions & 2 deletions src/cron_tasks/web_scrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,21 @@ use crate::db::db_cron_task::CronTask;

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CronTaskRequest {
pub crawl_links: bool,
pub cron_description: String,
pub task_description: String,
pub object_description: Option<String>,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CronTaskResponse {
pub struct CronTaskRequestResponse {
pub cron_task_request: CronTaskRequest,
pub cron_description: String,
pub pddl_plan_problem: String,
pub pddl_plan_domain: Option<String>,
}

impl fmt::Display for CronTaskResponse {
impl fmt::Display for CronTaskRequestResponse {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
Expand Down
8 changes: 7 additions & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use crate::utils::qr_code_setup::generate_qr_codes;
use async_channel::{bounded, Receiver, Sender};
use ed25519_dalek::{PublicKey as SignaturePublicKey, SecretKey as SignatureStaticKey};
use network::Node;
use network::node_api::ExtraAPIConfig;
use shinkai_message_primitives::shinkai_message::shinkai_message_schemas::{IdentityPermissions, RegistrationCodeType};
use shinkai_message_primitives::shinkai_utils::encryption::{
encryption_public_key_to_string, encryption_secret_key_to_string,
Expand Down Expand Up @@ -141,9 +142,14 @@ fn main() {
let _ = generate_qr_codes(&node_commands_sender, &node_env, &node_keys, global_identity_name.as_str(), identity_public_key_string.as_str()).await;
}

let extra_api_config = ExtraAPIConfig {
cron_devops_api_enabled: node_env.cron_devops_api_enabled,
cron_devops_api_token: node_env.cron_devops_api_token.clone(),
};

// API Server task
let api_server = tokio::spawn(async move {
node_api::run_api(node_commands_sender, node_env.api_listen_address).await;
node_api::run_api(node_commands_sender, node_env.api_listen_address, Some(extra_api_config)).await;
});

let _ = tokio::try_join!(api_server, node_task);
Expand Down
3 changes: 2 additions & 1 deletion src/network/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ pub mod node_internal_commands;
pub mod node_api_commands;
pub mod node_local_commands;
pub mod node_api;
pub mod node_error;
pub mod node_error;
pub mod node_devops_api_commands;
6 changes: 4 additions & 2 deletions src/network/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,8 @@ use tokio::net::{TcpListener, TcpStream};
use tokio::sync::{mpsc, Mutex};
use x25519_dalek::{PublicKey as EncryptionPublicKey, StaticSecret as EncryptionStaticKey};

use crate::agent::error::AgentError;
use crate::agent::job_manager::JobManager;
use crate::cron_tasks::cron_manager::CronManager;
use crate::db::db_errors::ShinkaiDBError;
use crate::db::db_retry::RetryMessage;
use crate::db::ShinkaiDB;
use crate::managers::identity_manager::{self};
Expand Down Expand Up @@ -225,6 +223,9 @@ pub enum NodeCommand {
full_profile_name: String,
res: Sender<Result<Vec<SerializedAgent>, String>>,
},
APIPrivateDevopsCronList {
res: Sender<Result<String, APIError>>,
}
}

// A type alias for a string that represents a profile name.
Expand Down Expand Up @@ -438,6 +439,7 @@ impl Node {
Some(NodeCommand::APIGetAllSmartInboxesForProfile { msg, res }) => self.api_get_all_smart_inboxes_for_profile(msg, res).await?,
Some(NodeCommand::APIUpdateSmartInboxName { msg, res }) => self.api_update_smart_inbox_name(msg, res).await?,
Some(NodeCommand::APIUpdateJobToFinished { msg, res }) => self.api_update_job_to_finished(msg, res).await?,
Some(NodeCommand::APIPrivateDevopsCronList { res }) => self.api_private_devops_cron_list(res).await?,
_ => break,
}
}
Expand Down
119 changes: 111 additions & 8 deletions src/network/node_api.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::node::NodeCommand;
use async_channel::Sender;
use chrono::format;
use futures::Stream;
use futures::Future;
use futures::FutureExt;
use futures::StreamExt;
use futures::TryFutureExt;
use futures::TryStreamExt;
Expand All @@ -14,9 +14,11 @@ use shinkai_message_primitives::shinkai_utils::shinkai_logging::shinkai_log;
use shinkai_message_primitives::shinkai_utils::shinkai_logging::ShinkaiLogLevel;
use shinkai_message_primitives::shinkai_utils::shinkai_logging::ShinkaiLogOption;
use shinkai_message_primitives::shinkai_utils::signatures::signature_public_key_to_string;
use std::collections::HashMap;
use std::net::SocketAddr;
use warp::fs::file;
use std::pin::Pin;
use std::sync::Arc;
use warp::filters::BoxedFilter;
use warp::reply::Reply;
use warp::Buf;
use warp::Filter;

Expand Down Expand Up @@ -70,7 +72,19 @@ impl From<&str> for APIError {
}
}

pub async fn run_api(node_commands_sender: Sender<NodeCommand>, address: SocketAddr) {
impl warp::reject::Reject for APIError {}

#[derive(Clone, Debug)]
pub struct ExtraAPIConfig {
pub cron_devops_api_enabled: bool,
pub cron_devops_api_token: String,
}

pub async fn run_api(
node_commands_sender: Sender<NodeCommand>,
address: SocketAddr,
extra_config: Option<ExtraAPIConfig>,
) {
println!("Starting Node API server at: {}", &address);

let log = warp::log::custom(|info| {
Expand Down Expand Up @@ -331,12 +345,43 @@ pub async fn run_api(node_commands_sender: Sender<NodeCommand>, address: SocketA
})
};

// GET v1/private_devops_cron_list
let extra_config_clone = Arc::new(extra_config.clone());
let private_devops_cron_list = {
let node_commands_sender = node_commands_sender.clone();
let extra_config = extra_config_clone.clone();
warp::path!("v1" / "private_devops_cron_list")
.and(warp::get())
.and(warp::header::<String>("Authorization"))
.and_then(move |auth_header: String| {
let token = auth_header.strip_prefix("Bearer ").unwrap_or("");
let binding = "".to_string();
let expected_token = (*extra_config)
.as_ref()
.map(|config| &config.cron_devops_api_token)
.unwrap_or(&binding);
if token == expected_token {
Box::pin(private_devops_cron_list_handler(node_commands_sender.clone()))
as Pin<Box<dyn Future<Output = Result<_, _>> + Send>>
} else {
Box::pin(async {
Err(warp::reject::custom(APIError::new(
StatusCode::UNAUTHORIZED,
"Unauthorized",
"Invalid token",
)))
})
as Pin<Box<dyn Future<Output = Result<Box<dyn warp::Reply>, warp::Rejection>> + Send>>
}
})
};

let cors = warp::cors() // build the CORS filter
.allow_any_origin() // allow requests from any origin
.allow_methods(vec!["GET", "POST", "OPTIONS"]) // allow GET, POST, and OPTIONS methods
.allow_headers(vec!["Content-Type"]); // allow the Content-Type header
.allow_headers(vec!["Content-Type", "Authorization"]); // allow the Content-Type and Authorization headers

let routes = ping_all
let routes_without_private = ping_all
.or(send_msg)
.or(get_peers)
.or(identity_name_to_external_profile_data)
Expand All @@ -362,7 +407,25 @@ pub async fn run_api(node_commands_sender: Sender<NodeCommand>, address: SocketA
.or(update_job_to_finished)
.recover(handle_rejection)
.with(log)
.with(cors);
.with(cors)
.boxed();

let routes_without_private = routes_without_private.boxed();

let noop_route = warp::any()
.map(warp::reply)
.and_then(|reply| async move { Ok::<_, warp::Rejection>(Box::new(reply) as Box<dyn warp::Reply>) })
.boxed();

let routes = if let Some(config) = extra_config.clone() {
if config.cron_devops_api_enabled {
routes_without_private.or(private_devops_cron_list).boxed()
} else {
routes_without_private.or(noop_route).boxed()
}
} else {
routes_without_private.or(noop_route).boxed()
};

warp::serve(routes).run(address).await;

Expand Down Expand Up @@ -402,6 +465,34 @@ where
}
}

async fn handle_node_command_without_message<T, U>(
node_commands_sender: Sender<NodeCommand>,
command: T,
) -> Result<Box<dyn warp::Reply>, warp::Rejection>
where
T: FnOnce(Sender<NodeCommand>, Sender<Result<U, APIError>>) -> NodeCommand,
U: Serialize,
{
let (res_sender, res_receiver) = async_channel::bounded(1);
node_commands_sender
.clone()
.send(command(node_commands_sender, res_sender))
.await
.map_err(|_| warp::reject::reject())?;
let result = res_receiver.recv().await.map_err(|_| warp::reject::reject())?;

match result {
Ok(message) => Ok(Box::new(warp::reply::with_status(
warp::reply::json(&json!({"status": "success", "data": message})),
StatusCode::OK,
)) as Box<dyn warp::Reply>),
Err(error) => Ok(Box::new(warp::reply::with_status(
warp::reply::json(&json!({"status": "error", "error": error})),
StatusCode::from_u16(error.code).unwrap(),
)) as Box<dyn warp::Reply>),
}
}

async fn ping_all_handler(node_commands_sender: Sender<NodeCommand>) -> Result<impl warp::Reply, warp::Rejection> {
match node_commands_sender.send(NodeCommand::PingAll).await {
Ok(_) => Ok(warp::reply::json(&json!({
Expand All @@ -413,6 +504,18 @@ async fn ping_all_handler(node_commands_sender: Sender<NodeCommand>) -> Result<i
}
}

async fn private_devops_cron_list_handler(
node_commands_sender: Sender<NodeCommand>,
) -> Result<Box<dyn warp::Reply>, warp::Rejection> {
handle_node_command_without_message(
node_commands_sender,
|node_commands_sender, res_sender| NodeCommand::APIPrivateDevopsCronList {
res: res_sender,
},
)
.await
}

async fn send_msg_handler(
node_commands_sender: Sender<NodeCommand>,
message: ShinkaiMessage,
Expand Down
63 changes: 63 additions & 0 deletions src/network/node_devops_api_commands.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
use super::{
node_api::{APIError, APIUseRegistrationCodeSuccessResponse},
node_error::NodeError,
Node,
};
use crate::{
managers::identity_manager::{self, IdentityManager},
network::node_message_handlers::{ping_pong, PingPong},
planner::{kai_files::KaiJobFile, kai_manager::KaiJobFileManager},
schemas::{
identity::{DeviceIdentity, Identity, IdentityType, RegistrationCode, StandardIdentity, StandardIdentityType},
inbox_permission::InboxPermission,
smart_inbox::SmartInbox,
},
};
use async_channel::Sender;
use blake3::Hasher;
use ed25519_dalek::{PublicKey as SignaturePublicKey, SecretKey as SignatureStaticKey};
use reqwest::StatusCode;
use shinkai_message_primitives::{
schemas::shinkai_name::{ShinkaiName, ShinkaiNameError, ShinkaiSubidentityType},
shinkai_message::{
shinkai_message::{MessageBody, MessageData, ShinkaiMessage},
shinkai_message_schemas::{
APIAddAgentRequest, APIGetMessagesFromInboxRequest, APIReadUpToTimeRequest, IdentityPermissions,
MessageSchemaType, RegistrationCodeRequest, RegistrationCodeType,
},
},
shinkai_utils::{
encryption::{
clone_static_secret_key, encryption_public_key_to_string, encryption_secret_key_to_string,
string_to_encryption_public_key, EncryptionMethod,
},
shinkai_logging::{shinkai_log, ShinkaiLogLevel, ShinkaiLogOption},
signatures::{clone_signature_secret_key, signature_public_key_to_string, string_to_signature_public_key},
},
};
use x25519_dalek::{PublicKey as EncryptionPublicKey, StaticSecret as EncryptionStaticKey};

impl Node {
pub async fn api_private_devops_cron_list(&self, res: Sender<Result<String, APIError>>) -> Result<(), NodeError> {
// Call the get_all_cron_tasks_from_all_profiles function
match self.db.lock().await.get_all_cron_tasks_from_all_profiles() {
Ok(tasks) => {
eprintln!("Got {} cron tasks", tasks.len());
// If everything went well, send the tasks back as a JSON string
let tasks_json = serde_json::to_string(&tasks).unwrap();
let _ = res.send(Ok(tasks_json)).await;
Ok(())
}
Err(err) => {
// If there was an error, send the error message
let api_error = APIError {
code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
error: "Internal Server Error".to_string(),
message: format!("{}", err),
};
let _ = res.send(Err(api_error)).await;
return Ok(());
}
}
}
}
6 changes: 3 additions & 3 deletions src/planner/kai_files.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use serde::{Serialize, Deserialize};
use serde_json::Value;
use shinkai_message_primitives::schemas::shinkai_name::ShinkaiName;

use crate::{cron_tasks::web_scrapper::{CronTaskRequest, CronTaskResponse}, db::db_cron_task::CronTask};
use crate::{cron_tasks::web_scrapper::{CronTaskRequest, CronTaskRequestResponse}, db::db_cron_task::CronTask};

// Define your schema types here
#[derive(Debug, Serialize, Deserialize, Clone)]
Expand All @@ -11,7 +11,7 @@ pub enum KaiSchemaType {
#[serde(rename = "cronjobrequest")]
CronJobRequest(CronTaskRequest),
#[serde(rename = "cronjobresponse")]
CronJobResponse(CronTaskResponse),
CronJobRequestResponse(CronTaskRequestResponse),
#[serde(rename = "cronjob")]
CronJob(CronTask),
}
Expand All @@ -30,7 +30,7 @@ impl KaiJobFile {
KaiSchemaType::CronJobRequest(cron_task_request) => {
serde_json::to_value(cron_task_request)
},
KaiSchemaType::CronJobResponse(cron_task_response) => {
KaiSchemaType::CronJobRequestResponse(cron_task_response) => {
serde_json::to_value(cron_task_response)
},
KaiSchemaType::CronJob(cron_job) => {
Expand Down
Loading

0 comments on commit e539110

Please sign in to comment.