Skip to content

Commit

Permalink
Merge pull request #625 from dcSpark/bence/remove-job-api
Browse files Browse the repository at this point in the history
Remove job API
  • Loading branch information
nicarq authored Oct 24, 2024
2 parents 10e4188 + a7c0165 commit a0d4789
Show file tree
Hide file tree
Showing 6 changed files with 211 additions and 1 deletion.
6 changes: 6 additions & 0 deletions shinkai-bin/shinkai-node/src/network/handle_commands_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2011,6 +2011,12 @@ impl Node {
.await;
});
}
NodeCommand::V2ApiRemoveJob { bearer, job_id, res } => {
let db_clone = self.db.clone();
tokio::spawn(async move {
let _ = Node::v2_remove_job(db_clone, bearer, job_id, res).await;
});
}
NodeCommand::V2ApiVecFSRetrievePathSimplifiedJson { bearer, payload, res } => {
let db_clone = Arc::clone(&self.db);
let vector_fs_clone = self.vector_fs.clone();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1348,4 +1348,33 @@ impl Node {
let _ = res.send(Ok(forked_job_id)).await;
Ok(())
}

pub async fn v2_remove_job(
db: Arc<ShinkaiDB>,
bearer: String,
job_id: String,
res: Sender<Result<(), APIError>>,
) -> Result<(), NodeError> {
// Validate the bearer token
if Self::validate_bearer_token(&bearer, db.clone(), &res).await.is_err() {
return Ok(());
}

// Remove the job
match db.remove_job(&job_id) {
Ok(_) => {
let _ = res.send(Ok(())).await;
Ok(())
}
Err(err) => {
let api_error = APIError {
code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
error: "Internal Server Error".to_string(),
message: format!("Failed to remove job: {}", err),
};
let _ = res.send(Err(api_error)).await;
Ok(())
}
}
}
}
29 changes: 29 additions & 0 deletions shinkai-bin/shinkai-node/tests/it/db_job_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -975,4 +975,33 @@ mod tests {
assert_eq!(job.forked_jobs[1].job_id, forked_job2_id);
assert_eq!(job.forked_jobs[1].message_id, forked_message2_id);
}

#[test]
fn test_remove_job() {
setup();
let job_id = "job1".to_string();
let agent_id = "agent1".to_string();
let scope = JobScope::new_default();
let db_path = format!("db_tests/{}", hash_string(&agent_id.clone().to_string()));
let mut shinkai_db = ShinkaiDB::new(&db_path).unwrap();

// Create a new job
create_new_job(&mut shinkai_db, job_id.clone(), agent_id.clone(), scope);

// Retrieve all jobs
let jobs = shinkai_db.get_all_jobs().unwrap();

// Check if the job exists
let job_ids: Vec<String> = jobs.iter().map(|job| job.job_id().to_string()).collect();
assert!(job_ids.contains(&job_id));

// Remove the job
shinkai_db.remove_job(&job_id).unwrap();

// Check if the job is removed
match shinkai_db.get_job(&job_id) {
Ok(_) => panic!("Expected an error when getting a removed job"),
Err(e) => assert_eq!(e, ShinkaiDBError::DataNotFound),
}
}
}
88 changes: 88 additions & 0 deletions shinkai-libs/shinkai-db/src/db/db_jobs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -724,6 +724,94 @@ impl ShinkaiDB {
Ok(())
}

/// Removes a job from the DB
pub fn remove_job(&self, job_id: &str) -> Result<(), ShinkaiDBError> {
let cf_inbox = self.get_cf_handle(Topic::Inbox)?;

// Construct keys with job_id as part of the key
let job_scope_key = format!("jobinbox_{}_scope", job_id);
let job_is_finished_key = format!("jobinbox_{}_is_finished", job_id);
let job_datetime_created_key = format!("jobinbox_{}_datetime_created", job_id);
let job_parent_providerid = format!("jobinbox_{}_agentid", job_id);
let job_parent_llm_provider_id_key =
format!("jobinbox_agent_{}_{}", Self::llm_provider_id_to_hash(&job_id), job_id);
let job_inbox_name = format!("jobinbox_{}_inboxname", job_id);
let job_conversation_inbox_name_key = format!("jobinbox_{}_conversation_inbox_name", job_id);
let all_jobs_time_keyed = format!("all_jobs_time_keyed_placeholder_to_fit_prefix__{}", job_id);
let job_smart_inbox_name_key = format!("{}_smart_inbox_name", job_id);
let job_is_hidden_key = format!("jobinbox_{}_is_hidden", job_id);
let job_read_list_key = format!("jobinbox_{}_read_list", job_id);
let job_config_key = format!("jobinbox_{}_config", job_id);
let job_associated_ui_key = format!("jobinbox_{}_associated_ui", job_id);

// Start a write batch
let mut batch = rocksdb::WriteBatch::default();

// Delete the job attributes from the database
batch.delete_cf(cf_inbox, job_scope_key.as_bytes());
batch.delete_cf(cf_inbox, job_is_finished_key.as_bytes());
batch.delete_cf(cf_inbox, job_datetime_created_key.as_bytes());
batch.delete_cf(cf_inbox, job_parent_providerid.as_bytes());
batch.delete_cf(cf_inbox, job_parent_llm_provider_id_key.as_bytes());
batch.delete_cf(cf_inbox, job_inbox_name.as_bytes());
batch.delete_cf(cf_inbox, job_conversation_inbox_name_key.as_bytes());
batch.delete_cf(cf_inbox, all_jobs_time_keyed.as_bytes());
batch.delete_cf(cf_inbox, job_smart_inbox_name_key.as_bytes());
batch.delete_cf(cf_inbox, job_is_hidden_key.as_bytes());
batch.delete_cf(cf_inbox, job_read_list_key.as_bytes());
batch.delete_cf(cf_inbox, job_config_key.as_bytes());
batch.delete_cf(cf_inbox, job_associated_ui_key.as_bytes());

// Remove step history
let inbox_name = InboxName::get_job_inbox_name_from_params(job_id.to_string())?;
let mut until_offset_key: Option<String> = None;
loop {
let messages = self.get_last_messages_from_inbox(inbox_name.to_string(), 2, until_offset_key.clone())?;
if messages.is_empty() {
break;
}

for message_path in &messages {
if let Some(message) = message_path.first() {
let message_key = message.calculate_message_hash_for_pagination();
let hash_message_key = Self::message_key_to_hash(message_key);
let prefix = format!("step_history__{}_", hash_message_key);
let iter = self.db.prefix_iterator_cf(cf_inbox, prefix.as_bytes());
for item in iter {
if let Ok((key, _)) = item {
batch.delete_cf(cf_inbox, &key);
}
}
}
}

if let Some(last_message_path) = messages.last() {
if let Some(last_message) = last_message_path.first() {
until_offset_key = Some(last_message.calculate_message_hash_for_pagination());
} else {
break;
}
} else {
break;
}
}

// Remove unprocessed messages
let job_hash = Self::job_id_to_hash(job_id);
let prefix = format!("job_unprocess_{}_", job_hash);
let iter = self.db.prefix_iterator_cf(cf_inbox, prefix.as_bytes());
for item in iter {
if let Ok((key, _)) = item {
batch.delete_cf(cf_inbox, &key);
}
}

// Commit the write batch
self.db.write(batch)?;

Ok(())
}

pub fn add_forked_job(&self, job_id: &str, forked_job: ForkedJob) -> Result<(), ShinkaiDBError> {
let cf_inbox = self.get_cf_handle(Topic::Inbox).unwrap();
let forked_jobs_key = format!("jobinbox_{}_forked_jobs", job_id);
Expand Down
55 changes: 54 additions & 1 deletion shinkai-libs/shinkai-http-api/src/api_v2/api_v2_handlers_jobs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,13 @@ pub fn job_routes(
.and(warp::body::json())
.and_then(fork_job_messages_handler);

let remove_job_route = warp::path("remove_job")
.and(warp::post())
.and(with_sender(node_commands_sender.clone()))
.and(warp::header::<String>("authorization"))
.and(warp::body::json())
.and_then(remove_job_handler);

create_job_route
.or(job_message_route)
.or(get_last_messages_route)
Expand All @@ -173,6 +180,7 @@ pub fn job_routes(
.or(get_job_scope_route)
.or(get_tooling_logs_route)
.or(fork_job_messages_route)
.or(remove_job_route)
}

#[derive(Deserialize, ToSchema)]
Expand Down Expand Up @@ -227,6 +235,11 @@ pub struct ForkJobMessagesRequest {
pub message_id: String,
}

#[derive(Deserialize, ToSchema)]
pub struct RemoveJobRequest {
pub job_id: String,
}

#[utoipa::path(
post,
path = "/v2/retry_message",
Expand Down Expand Up @@ -1030,6 +1043,45 @@ pub async fn fork_job_messages_handler(
}
}

#[utoipa::path(
post,
path = "/v2/remove_job",
request_body = RemoveJobRequest,
responses(
(status = 200, description = "Successfully removed job", body = Value),
(status = 400, description = "Bad request", body = APIError),
(status = 500, description = "Internal server error", body = APIError)
)
)]
pub async fn remove_job_handler(
node_commands_sender: Sender<NodeCommand>,
authorization: String,
payload: RemoveJobRequest,
) -> Result<impl warp::Reply, warp::Rejection> {
let bearer = authorization.strip_prefix("Bearer ").unwrap_or("").to_string();
let (res_sender, res_receiver) = async_channel::bounded(1);
node_commands_sender
.send(NodeCommand::V2ApiRemoveJob {
bearer,
job_id: payload.job_id,
res: res_sender,
})
.await
.map_err(|_| warp::reject::reject())?;
let result = res_receiver.recv().await.map_err(|_| warp::reject::reject())?;

match result {
Ok(response) => {
let response = create_success_response(response);
Ok(warp::reply::with_status(warp::reply::json(&response), StatusCode::OK))
}
Err(error) => Ok(warp::reply::with_status(
warp::reply::json(&error),
StatusCode::from_u16(error.code).unwrap(),
)),
}
}

#[derive(OpenApi)]
#[openapi(
paths(
Expand All @@ -1050,6 +1102,7 @@ pub async fn fork_job_messages_handler(
get_job_scope_handler,
get_tooling_logs_handler,
fork_job_messages_handler,
remove_job_handler,
),
components(
schemas(AddFileToInboxRequest, V2SmartInbox, APIChangeJobAgentRequest, CreateJobRequest, JobConfig,
Expand All @@ -1059,7 +1112,7 @@ pub async fn fork_job_messages_handler(
VectorFSItemScopeEntry, VectorFSFolderScopeEntry, NetworkFolderScopeEntry, CallbackAction, ShinkaiName,
LLMProviderInterface, RetryMessageRequest, UpdateJobScopeRequest,
ShinkaiSubidentityType, OpenAI, Ollama, LocalLLM, Groq, Gemini, Exo, ShinkaiBackend, SheetManagerAction,
SheetJobAction, SendResponseBody, SendResponseBodyData, APIError, GetToolingLogsRequest, ForkJobMessagesRequest)
SheetJobAction, SendResponseBody, SendResponseBodyData, APIError, GetToolingLogsRequest, ForkJobMessagesRequest, RemoveJobRequest)
),
tags(
(name = "jobs", description = "Job API endpoints")
Expand Down
5 changes: 5 additions & 0 deletions shinkai-libs/shinkai-http-api/src/node_commands.rs
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,11 @@ pub enum NodeCommand {
message_id: String,
res: Sender<Result<String, APIError>>,
},
V2ApiRemoveJob {
bearer: String,
job_id: String,
res: Sender<Result<(), APIError>>,
},
V2ApiVecFSRetrievePathSimplifiedJson {
bearer: String,
payload: APIVecFsRetrievePathSimplifiedJson,
Expand Down

0 comments on commit a0d4789

Please sign in to comment.