Skip to content

Commit

Permalink
Merge pull request #637 from dcSpark/nico/fix_job_scope_architecture
Browse files Browse the repository at this point in the history
Nico/fix job scope architecture
  • Loading branch information
nicarq authored Oct 31, 2024
2 parents cf05063 + 8d38415 commit 8885f07
Show file tree
Hide file tree
Showing 18 changed files with 368 additions and 183 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,4 @@ shinkai-bin/shinkai-node/ocrs/text-recognition.rten
logs.json
example.db-shm
example.db-wal
libpdfium.dylib
5 changes: 0 additions & 5 deletions docs/openapi/jobs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1065,7 +1065,6 @@ components:
- local_vrpack
- vector_fs_items
- vector_fs_folders
- network_folders
properties:
local_vrkai:
type: array
Expand All @@ -1075,10 +1074,6 @@ components:
type: array
items:
$ref: '#/components/schemas/LocalScopeVRPackEntry'
network_folders:
type: array
items:
$ref: '#/components/schemas/NetworkFolderScopeEntry'
vector_fs_folders:
type: array
items:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ impl AsyncFunction for OpinionatedInferenceFunction {
// If the scope is not empty and a custom_system_prompt is provided, we use an empty string
// for the user_message and the custom_system_prompt as the query_text.
// This allows for more focused searches based on the system prompt when a scope is provided.
let (effective_user_message, query_text) = if !full_job.scope().is_empty() && custom_system_prompt.is_some() {
let (effective_user_message, query_text) = if !full_job.scope_with_files().unwrap().is_empty() && custom_system_prompt.is_some() {
("".to_string(), custom_system_prompt.clone().unwrap_or_default())
} else {
(user_message.clone(), user_message.clone())
Expand All @@ -411,15 +411,15 @@ impl AsyncFunction for OpinionatedInferenceFunction {
// TODO: extract files from args

// If we need to search for nodes using the scope
let scope_is_empty = full_job.scope().is_empty();
let scope_is_empty = full_job.scope_with_files().unwrap().is_empty();
let mut ret_nodes: Vec<RetrievedNode> = vec![];
let mut summary_node_text = None;
if !scope_is_empty {
// TODO: this should also be a generic fn
let (ret, summary) = JobManager::keyword_chained_job_scope_vector_search(
db.clone(),
vector_fs.clone(),
full_job.scope(),
full_job.scope_with_files().unwrap(),
query_text.clone(),
user_profile,
generator.clone(),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use futures::{future::join_all, StreamExt};
use serde_json::json;
use shinkai_message_primitives::{
schemas::subprompts::SubPrompt,
schemas::{job::JobLike, subprompts::SubPrompt},
shinkai_utils::shinkai_logging::{shinkai_log, ShinkaiLogLevel, ShinkaiLogOption},
};
use std::{any::Any, collections::HashMap};
Expand Down Expand Up @@ -220,7 +220,7 @@ pub fn process_embeddings_in_job_scope(
.block_on(async {
let vector_fs = context.vector_fs();
let user_profile = context.user_profile();
let scope = context.full_job().scope.clone();
let scope = context.full_job().scope_with_files().clone().unwrap();

let resource_stream =
JobManager::retrieve_all_resources_in_job_scope_stream(vector_fs.clone(), &scope, user_profile)
Expand Down Expand Up @@ -285,7 +285,7 @@ pub fn process_embeddings_in_job_scope_with_metadata(
.block_on(async {
let vector_fs = context.vector_fs();
let user_profile = context.user_profile();
let scope = context.full_job().scope.clone();
let scope = context.full_job().scope_with_files().clone().unwrap();

let resource_stream =
JobManager::retrieve_all_resources_in_job_scope_stream(vector_fs.clone(), &scope, user_profile)
Expand Down Expand Up @@ -357,7 +357,7 @@ pub fn search_embeddings_in_job_scope(
let db = context.db();
let vector_fs = context.vector_fs();
let user_profile = context.user_profile();
let job_scope = context.full_job().scope.clone();
let job_scope = context.full_job().scope_with_files().clone().unwrap();
let generator = context.generator();

let result = JobManager::keyword_chained_job_scope_vector_search(
Expand Down Expand Up @@ -637,7 +637,6 @@ mod tests {
path: VRPath::root(),
name: "/".to_string(),
}],
network_folders: Vec::new(),
vector_search_mode: Vec::new(),
};
shinkai_db_arc
Expand Down Expand Up @@ -773,7 +772,6 @@ mod tests {
path: VRPath::root(),
name: "/".to_string(),
}],
network_folders: Vec::new(),
vector_search_mode: Vec::new(),
};
shinkai_db_arc
Expand Down Expand Up @@ -924,7 +922,6 @@ mod tests {
path: VRPath::root(),
name: "/".to_string(),
}],
network_folders: Vec::new(),
vector_search_mode: Vec::new(),
};
shinkai_db_arc
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,14 +149,14 @@ impl GenericInferenceChain {
*/

// 1) Vector search for knowledge if the scope isn't empty
let scope_is_empty = full_job.scope().is_empty();
let scope_is_empty = full_job.scope_with_files().unwrap().is_empty();
let mut ret_nodes: Vec<RetrievedNode> = vec![];
let mut summary_node_text = None;
if !scope_is_empty {
let (ret, summary) = JobManager::keyword_chained_job_scope_vector_search(
db.clone(),
vector_fs.clone(),
full_job.scope(),
full_job.scope_with_files().unwrap(),
user_message.clone(),
&user_profile,
generator.clone(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ impl SheetUIInferenceChain {
// 1) Vector search for knowledge if the scope isn't empty
// Check if the sheet has uploaded files and add them to the job scope
let job_scope = if let Some(sheet_manager) = &sheet_manager {
let mut job_scope = full_job.scope().clone();
let mut job_scope = full_job.scope_with_files.clone().unwrap();

let sheet = {
let sheet_manager_guard = sheet_manager.lock().await;
Expand Down Expand Up @@ -189,7 +189,7 @@ impl SheetUIInferenceChain {
}
job_scope
} else {
full_job.scope().clone()
full_job.scope_with_files.clone().unwrap()
};

let scope_is_empty = job_scope.is_empty();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,9 @@ impl JobManager {
.map_err(|e| LLMProviderError::InvalidVRPath(e.to_string()))?,
source: VRSourceReference::None,
};
mutable_job.scope.vector_fs_items.push(vector_fs_entry);

// scope_with_files is expected to always be Some at this point, so the unwrap below should not panic
mutable_job.scope_with_files.as_mut().unwrap().vector_fs_items.push(vector_fs_entry);
}

// Determine the workflow to use
Expand Down Expand Up @@ -900,67 +902,66 @@ impl JobManager {

match new_scope_entries_result {
Ok(new_scope_entries) => {
// Update the job scope with new entries
for (_filename, scope_entry) in new_scope_entries {
match scope_entry {
ScopeEntry::LocalScopeVRKai(local_entry) => {
if !full_job.scope.local_vrkai.contains(&local_entry) {
full_job.scope.local_vrkai.push(local_entry);
} else {
shinkai_log(
ShinkaiLogOption::JobExecution,
ShinkaiLogLevel::Error,
"Duplicate LocalScopeVRKaiEntry detected",
);
// Only update the scope when scope_with_files is present; otherwise log an error instead of unwrapping
let job_id = full_job.job_id().to_string();
if let Some(ref mut scope_with_files) = full_job.scope_with_files {
// Update the job scope with new entries
for (_filename, scope_entry) in new_scope_entries {
match scope_entry {
ScopeEntry::LocalScopeVRKai(local_entry) => {
if !scope_with_files.local_vrkai.contains(&local_entry) {
scope_with_files.local_vrkai.push(local_entry);
} else {
shinkai_log(
ShinkaiLogOption::JobExecution,
ShinkaiLogLevel::Error,
"Duplicate LocalScopeVRKaiEntry detected",
);
}
}
}
ScopeEntry::LocalScopeVRPack(local_entry) => {
if !full_job.scope.local_vrpack.contains(&local_entry) {
full_job.scope.local_vrpack.push(local_entry);
} else {
shinkai_log(
ShinkaiLogOption::JobExecution,
ShinkaiLogLevel::Error,
"Duplicate LocalScopeVRPackEntry detected",
);
ScopeEntry::LocalScopeVRPack(local_entry) => {
if !scope_with_files.local_vrpack.contains(&local_entry) {
scope_with_files.local_vrpack.push(local_entry);
} else {
shinkai_log(
ShinkaiLogOption::JobExecution,
ShinkaiLogLevel::Error,
"Duplicate LocalScopeVRPackEntry detected",
);
}
}
}
ScopeEntry::VectorFSItem(fs_entry) => {
if !full_job.scope.vector_fs_items.contains(&fs_entry) {
full_job.scope.vector_fs_items.push(fs_entry);
} else {
shinkai_log(
ShinkaiLogOption::JobExecution,
ShinkaiLogLevel::Error,
"Duplicate VectorFSScopeEntry detected",
);
ScopeEntry::VectorFSItem(fs_entry) => {
if !scope_with_files.vector_fs_items.contains(&fs_entry) {
scope_with_files.vector_fs_items.push(fs_entry);
} else {
shinkai_log(
ShinkaiLogOption::JobExecution,
ShinkaiLogLevel::Error,
"Duplicate VectorFSScopeEntry detected",
);
}
}
}
ScopeEntry::VectorFSFolder(fs_entry) => {
if !full_job.scope.vector_fs_folders.contains(&fs_entry) {
full_job.scope.vector_fs_folders.push(fs_entry);
} else {
shinkai_log(
ShinkaiLogOption::JobExecution,
ShinkaiLogLevel::Error,
"Duplicate VectorFSScopeEntry detected",
);
}
}
ScopeEntry::NetworkFolder(nf_entry) => {
if !full_job.scope.network_folders.contains(&nf_entry) {
full_job.scope.network_folders.push(nf_entry);
} else {
shinkai_log(
ShinkaiLogOption::JobExecution,
ShinkaiLogLevel::Error,
"Duplicate VectorFSScopeEntry detected",
);
ScopeEntry::VectorFSFolder(fs_entry) => {
if !scope_with_files.vector_fs_folders.contains(&fs_entry) {
scope_with_files.vector_fs_folders.push(fs_entry);
} else {
shinkai_log(
ShinkaiLogOption::JobExecution,
ShinkaiLogLevel::Error,
"Duplicate VectorFSScopeEntry detected",
);
}
}
}
}
db.update_job_scope(job_id, scope_with_files.clone())?;
} else {
shinkai_log(
ShinkaiLogOption::JobExecution,
ShinkaiLogLevel::Error,
"No scope_with_files found in full_job",
);
}
db.update_job_scope(full_job.job_id().to_string(), full_job.scope.clone())?;
}
Err(e) => {
shinkai_log(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pub fn gemini_prepare_messages(model: &LLMProviderInterface, prompt: Prompt) ->
let remaining_output_tokens = ModelCapabilitiesManager::get_remaining_output_tokens(model, used_tokens);

// Separate messages into those with a user / assistant / system role and those without
let (mut messages_with_role, _tools): (Vec<_>, Vec<_>) = chat_completion_messages
let (mut messages_with_role, tools): (Vec<_>, Vec<_>) = chat_completion_messages
.into_iter()
.partition(|message| message.role.is_some());

Expand Down Expand Up @@ -73,6 +73,18 @@ pub fn gemini_prepare_messages(model: &LLMProviderInterface, prompt: Prompt) ->
_ => vec![],
};

// Extract functions from tools
let functions_vec = tools.into_iter().filter_map(|tool| {
if let Some(function_call) = tool.function_call {
Some(serde_json::json!({
"name": function_call.name,
"arguments": function_call.arguments,
}))
} else {
None
}
}).collect::<Vec<_>>();

// Separate system instruction from other messages
let system_instruction = messages_vec
.clone()
Expand Down Expand Up @@ -130,7 +142,8 @@ pub fn gemini_prepare_messages(model: &LLMProviderInterface, prompt: Prompt) ->
"role": role,
"parts": content
})
}).collect::<Vec<_>>()
}).collect::<Vec<_>>(),
"functions": functions_vec
});

Ok(PromptResult {
Expand Down Expand Up @@ -209,7 +222,8 @@ mod tests {
}
}]
}
]
],
"functions": []
});

// Assert the results
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ use shinkai_message_primitives::{
};
use shinkai_vector_fs::welcome_files::welcome_message::WELCOME_MESSAGE;
use shinkai_vector_resources::vector_resource::VRPath;
use std::time::Instant;
use std::{io::Error, net::SocketAddr};
use std::{str::FromStr, sync::Arc};
use tokio::sync::Mutex;
Expand Down Expand Up @@ -125,6 +126,8 @@ impl Node {
return Vec::new();
}
};
// Start the timer
let start = Instant::now();
let result = match db.get_inboxes_for_profile(standard_identity) {
Ok(inboxes) => inboxes,
Err(e) => {
Expand All @@ -136,6 +139,9 @@ impl Node {
return Vec::new();
}
};
// Measure the elapsed time
let duration = start.elapsed();
println!("Time taken to get all inboxes: {:?}", duration);

result
}
Expand Down Expand Up @@ -435,7 +441,6 @@ impl Node {
local_vrpack: vec![],
vector_fs_items: vec![],
vector_fs_folders: vec![shinkai_folder_fs],
network_folders: vec![],
vector_search_mode: vec![],
};
let job_creation = JobCreationInfo {
Expand Down
Loading

0 comments on commit 8885f07

Please sign in to comment.