Merge pull request #559 from dcSpark/nico/fix_rag
fix omni rag
nicarq authored Sep 16, 2024
2 parents 7f2eaf4 + 481da43 commit b362542
Showing 4 changed files with 31 additions and 15 deletions.
Cargo.lock (2 changes: 1 addition & 1 deletion)

Some generated files are not rendered by default.

shinkai-bin/shinkai-node/Cargo.toml (2 changes: 1 addition & 1 deletion)
@@ -1,6 +1,6 @@
 [package]
 name = "shinkai_node"
-version = "0.8.3"
+version = "0.8.4"
 edition = "2021"
 authors.workspace = true
 # this causes `cargo run` in the workspace root to run this package
@@ -32,11 +32,6 @@ impl JobPromptGenerator {
         // Add previous messages
         // TODO: this should be full messages with assets and not just strings
         if let Some(step_history) = job_step_history {
-            for step in &step_history {
-                if let Some(prompt) = step.get_result_prompt() {
-                    println!("Step history content: {:?}", prompt);
-                }
-            }
             prompt.add_step_history(step_history, 97);
         }

@@ -62,7 +62,13 @@ impl Prompt {

     /// Adds a sub-prompt that holds any Omni (String + Assets) content.
     /// Of note, priority value must be between 0-100, where higher is greater priority
-    pub fn add_omni(&mut self, content: String, files: HashMap<String, String>, prompt_type: SubPromptType, priority_value: u8) {
+    pub fn add_omni(
+        &mut self,
+        content: String,
+        files: HashMap<String, String>,
+        prompt_type: SubPromptType,
+        priority_value: u8,
+    ) {
         let capped_priority_value = std::cmp::min(priority_value, 100);
         let assets: Vec<(SubPromptAssetType, SubPromptAssetContent, SubPromptAssetDetail)> = files
             .into_iter()
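
For context, a sketch of a call site for the reformatted signature. The variable names, the encoded-string file map, and the priority value are illustrative assumptions, not taken from this diff:

    // Hypothetical call site; `prompt` is assumed to be a `Prompt`, and the
    // map values are assumed to be encoded file contents keyed by filename.
    let mut files: HashMap<String, String> = HashMap::new();
    files.insert("chart.png".to_string(), "<encoded image data>".to_string());
    prompt.add_omni(
        "Describe the attached chart".to_string(),
        files,
        SubPromptType::UserLastMessage, // routed specially; see the later hunks
        100, // values above 100 are capped via std::cmp::min
    );
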
@@ -306,7 +312,7 @@ impl Prompt {

         // Accumulator for ExtraContext content
         let mut extra_context_content = String::new();
-        let mut last_user_message: Option<String> = None;
+        let mut last_user_message: Option<LlmMessage> = None;
         let mut function_calls: Vec<LlmMessage> = Vec::new();
         let mut function_call_responses: Vec<LlmMessage> = Vec::new();

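Since `last_user_message` now holds a whole message instead of a plain string, the shape of `LlmMessage` matters for the hunks below. A minimal sketch implied by the initializers in this diff, with field names taken from the diff but placeholder types so it stands alone; the repo's real definition will differ:

    // Sketch only: field names match the initializers below; the types
    // are assumptions, not the repo's actual definitions.
    #[derive(Clone, Debug, Default)]
    struct LlmMessage {
        role: Option<String>,
        content: Option<String>,
        name: Option<String>,
        function_call: Option<String>,  // placeholder type
        functions: Option<Vec<String>>, // placeholder type
        images: Option<Vec<String>>,    // assumed: encoded image payloads
    }
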
@@ -366,13 +372,25 @@
                     function_call_responses.push(new_message);
                 }
                 SubPrompt::Content(SubPromptType::UserLastMessage, content, _) => {
-                    last_user_message = Some(content.clone());
+                    last_user_message = Some(LlmMessage {
+                        role: Some(SubPromptType::User.to_string()),
+                        content: Some(content.clone()),
+                        name: None,
+                        function_call: None,
+                        functions: None,
+                        images: None,
+                    });
                 }
-                SubPrompt::Omni(_, _, _, _) => {
+                SubPrompt::Omni(prompt_type, _, _, _) => {
                     // Process the current sub-prompt
                     let new_message = sub_prompt.into_chat_completion_request_message();
                     current_length += sub_prompt.count_tokens_with_pregenerated_completion_message(&new_message);
-                    tiktoken_messages.push(new_message);
+
+                    if let SubPromptType::UserLastMessage = prompt_type {
+                        last_user_message = Some(new_message);
+                    } else {
+                        tiktoken_messages.push(new_message);
+                    }
                 }
                 _ => {
                     // Process the current sub-prompt
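
The net effect of the two rewritten arms: the user's last message, whether plain Content or an Omni message carrying assets, is held back rather than pushed immediately, so the accumulated ExtraContext can be merged into it afterwards. A condensed, self-contained illustration of that deferral pattern, using simplified stand-in types rather than the repo's own:

    #[derive(Clone, Debug, Default)]
    struct Msg {
        content: String,
        images: Option<Vec<String>>,
    }

    // `true` marks the user's last message. Ordinary messages are pushed
    // in order; the last user message is deferred and merged with the
    // accumulated extra context before being appended at the end.
    fn assemble(subprompts: Vec<(bool, Msg)>, extra_context: &str) -> Vec<Msg> {
        let mut out = Vec::new();
        let mut last_user: Option<Msg> = None;
        for (is_last_user, msg) in subprompts {
            if is_last_user {
                last_user = Some(msg); // defer: merged below
            } else {
                out.push(msg);
            }
        }
        let mut combined = last_user.unwrap_or_default();
        combined.content = format!("{}\n{}", extra_context.trim(), combined.content)
            .trim()
            .to_string();
        out.push(combined); // any images travel with the merged message
        out
    }
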
@@ -388,18 +406,21 @@
         let combined_content = format!(
             "{}\n{}",
             extra_context_content.trim(),
-            last_user_message.unwrap_or_default()
+            last_user_message
+                .as_ref()
+                .and_then(|msg| msg.content.clone())
+                .unwrap_or_default()
         )
         .trim()
         .to_string();

-        let combined_message = LlmMessage {
+        let mut combined_message = LlmMessage {
             role: Some(SubPromptType::User.to_string()),
             content: Some(combined_content),
             name: None,
             function_call: None,
             functions: None,
-            images: None,
+            images: last_user_message.and_then(|msg| msg.images),
         };
         current_length += ModelCapabilitiesManager::num_tokens_from_llama3(&[combined_message.clone()]);
         tiktoken_messages.push(combined_message);
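A small ownership detail in this final hunk: the content merge goes through `.as_ref()` so that `last_user_message` is only borrowed there, leaving it available to be moved on the `images` line, where `.and_then(|msg| msg.images)` consumes the Option and carries the last message's images onto the combined message.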
