From dcfc7d9362b32bb66c5e751c9cc35a5f1d09fc3a Mon Sep 17 00:00:00 2001 From: benolt Date: Thu, 31 Oct 2024 15:26:38 +0100 Subject: [PATCH] streaming and tools support --- .../src/llm_provider/providers/claude.rs | 568 ++++++++++++------ .../providers/shared/claude_api.rs | 57 +- .../managers/model_capabilities_manager.rs | 48 +- 3 files changed, 428 insertions(+), 245 deletions(-) diff --git a/shinkai-bin/shinkai-node/src/llm_provider/providers/claude.rs b/shinkai-bin/shinkai-node/src/llm_provider/providers/claude.rs index e6035a285..d3af86791 100644 --- a/shinkai-bin/shinkai-node/src/llm_provider/providers/claude.rs +++ b/shinkai-bin/shinkai-node/src/llm_provider/providers/claude.rs @@ -63,7 +63,7 @@ impl LLMService for Claude { let is_stream = config.as_ref().and_then(|c| c.stream).unwrap_or(true); - let messages_result = claude_prepare_messages(&model, prompt)?; + let (messages_result, system_messages) = claude_prepare_messages(&model, prompt)?; let messages_json = match messages_result.messages { PromptResultEnum::Value(v) => v, _ => { @@ -75,6 +75,16 @@ impl LLMService for Claude { // Extract tools_json from the result let tools_json = messages_result.functions.unwrap_or_else(Vec::new); + let tools_json = tools_json + .into_iter() + .map(|mut tool| { + if let Some(input_schema) = tool.get_mut("parameters") { + tool["input_schema"] = input_schema.clone(); + tool.as_object_mut().unwrap().remove("parameters"); + } + tool + }) + .collect::>(); // Print messages_json as a pretty JSON string match serde_json::to_string_pretty(&messages_json) { @@ -90,13 +100,23 @@ impl LLMService for Claude { let mut payload = json!({ "model": self.model_type, "messages": messages_json, - "max_tokens": messages_result.remaining_tokens, + "max_tokens": messages_result.remaining_output_tokens, "stream": is_stream, + "system": system_messages.into_iter().map(|m| m.content.unwrap_or_default()).collect::>().join(""), }); // Conditionally add functions to the payload if tools_json is not empty if !tools_json.is_empty() { - payload["tools"] = serde_json::Value::Array(tools_json.clone()); + let tools_payload = tools_json + .clone() + .into_iter() + .map(|mut tool| { + tool.as_object_mut().unwrap().remove("tool_router_key"); + tool + }) + .collect::>(); + + payload["tools"] = serde_json::Value::Array(tools_payload); } // Add options to payload @@ -173,7 +193,7 @@ async fn handle_streaming_response( let mut stream = res.bytes_stream(); let mut response_text = String::new(); - let mut previous_json_chunk: String = String::new(); + let mut processed_tool: Option = None; let mut function_call: Option = None; while let Some(item) = stream.next().await { @@ -221,101 +241,116 @@ async fn handle_streaming_response( Ok(chunk) => { let chunk_str = String::from_utf8_lossy(&chunk).to_string(); eprintln!("Chunk: {}", chunk_str); - previous_json_chunk += chunk_str.as_str(); - let trimmed_chunk_str = previous_json_chunk.trim().to_string(); - let data_resp: Result = serde_json::from_str(&trimmed_chunk_str); - match data_resp { - Ok(data) => { - serde_json::to_string_pretty(&data) - .map(|pretty_json| eprintln!("Response JSON: {}", pretty_json)) - .unwrap_or_else(|e| eprintln!("Failed to serialize response_json: {:?}", e)); - - previous_json_chunk = "".to_string(); - if let Some(choices) = data.get("choices") { - for choice in choices.as_array().unwrap_or(&vec![]) { - if let Some(message) = choice.get("message") { - if let Some(content) = message.get("content") { - response_text.push_str(content.as_str().unwrap_or("")); - } - if let Some(fc) = message.get("function_call") { - if let Some(name) = fc.get("name") { - let fc_arguments = fc - .get("arguments") - .and_then(|args| args.as_str()) - .and_then(|args_str| serde_json::from_str(args_str).ok()) - .and_then(|args_value: serde_json::Value| { - args_value.as_object().cloned() - }) - .unwrap_or_else(|| serde_json::Map::new()); - - // Extract tool_router_key - let tool_router_key = tools.as_ref().and_then(|tools_array| { - tools_array.iter().find_map(|tool| { - if tool.get("name")?.as_str()? == name.as_str().unwrap_or("") { - tool.get("tool_router_key") - .and_then(|key| key.as_str().map(|s| s.to_string())) - } else { - None - } - }) - }); - - function_call = Some(FunctionCall { - name: name.as_str().unwrap_or("").to_string(), - arguments: fc_arguments.clone(), - tool_router_key, - }); - } - } - } + let processed_chunk = process_chunk(&chunk)?; + response_text.push_str(&processed_chunk.partial_text); + + if let Some(tool_use) = processed_chunk.tool_use { + match processed_tool { + Some(ref mut tool) => { + if !tool_use.tool_name.is_empty() { + tool.tool_name = tool_use.tool_name; } + tool.partial_tool_arguments.push_str(&tool_use.partial_tool_arguments); } + None => { + processed_tool = Some(tool_use); + } + } + } - // Updated WS message handling for tooling - if let Some(ref manager) = ws_manager_trait { - if let Some(ref inbox_name) = inbox_name { - if let Some(ref function_call) = function_call { - let m = manager.lock().await; - let inbox_name_string = inbox_name.to_string(); - - // Serialize FunctionCall to JSON value - let function_call_json = - serde_json::to_value(function_call).unwrap_or_else(|_| serde_json::json!({})); - - // Prepare ToolMetadata - let tool_metadata = ToolMetadata { - tool_name: function_call.name.clone(), - tool_router_key: function_call.tool_router_key.clone(), - args: function_call_json.as_object().cloned().unwrap_or_default(), - result: None, - status: ToolStatus { - type_: ToolStatusType::Running, - reason: None, - }, - }; - - let ws_message_type = - WSMessageType::Widget(WidgetMetadata::ToolRequest(tool_metadata)); - - let _ = m - .queue_message( - WSTopic::Inbox, - inbox_name_string, - serde_json::to_string(&function_call).unwrap_or_else(|_| "{}".to_string()), - ws_message_type, - true, - ) - .await; - } + if processed_chunk.is_done && processed_tool.is_some() { + let name = processed_tool.as_ref().unwrap().tool_name.clone(); + let arguments = + serde_json::from_str::(&processed_tool.as_ref().unwrap().partial_tool_arguments) + .ok() + .and_then(|args_value| args_value.as_object().cloned()) + .unwrap_or_else(|| serde_json::Map::new()); + let tool_router_key = tools.as_ref().and_then(|tools_array| { + tools_array.iter().find_map(|tool| { + if tool.get("name")?.as_str()? == name { + tool.get("tool_router_key") + .and_then(|key| key.as_str().map(|s| s.to_string())) + } else { + None } + }) + }); + + function_call = Some(FunctionCall { + name, + arguments, + tool_router_key, + }); + + shinkai_log( + ShinkaiLogOption::JobExecution, + ShinkaiLogLevel::Info, + format!("Function Call: {:?}", function_call).as_str(), + ); + } + + if let Some(ref manager) = ws_manager_trait { + if let Some(ref inbox_name) = inbox_name { + if !processed_chunk.partial_text.is_empty() { + let m = manager.lock().await; + let inbox_name_string = inbox_name.to_string(); + let metadata = WSMetadata { + id: Some(session_id.clone()), + is_done: processed_chunk.is_done, + done_reason: if processed_chunk.is_done { + processed_chunk.done_reason.clone() + } else { + None + }, + total_duration: None, + eval_count: None, + }; + + let ws_message_type = WSMessageType::Metadata(metadata); + + let _ = m + .queue_message( + WSTopic::Inbox, + inbox_name_string, + processed_chunk.partial_text.clone(), + ws_message_type, + true, + ) + .await; + } + + if let Some(ref function_call) = function_call { + let m = manager.lock().await; + let inbox_name_string = inbox_name.to_string(); + + // Serialize FunctionCall to JSON value + let function_call_json = + serde_json::to_value(function_call).unwrap_or_else(|_| serde_json::json!({})); + + // Prepare ToolMetadata + let tool_metadata = ToolMetadata { + tool_name: function_call.name.clone(), + tool_router_key: function_call.tool_router_key.clone(), + args: function_call_json.as_object().cloned().unwrap_or_default(), + result: None, + status: ToolStatus { + type_: ToolStatusType::Running, + reason: None, + }, + }; + + let ws_message_type = WSMessageType::Widget(WidgetMetadata::ToolRequest(tool_metadata)); + + let _ = m + .queue_message( + WSTopic::Inbox, + inbox_name_string, + serde_json::to_string(&function_call).unwrap_or_else(|_| "{}".to_string()), + ws_message_type, + true, + ) + .await; } - } - Err(_e) => { - shinkai_log( - ShinkaiLogOption::JobExecution, - ShinkaiLogLevel::Error, - format!("Error while receiving chunk: {:?}", _e).as_str(), - ); } } } @@ -374,117 +409,97 @@ async fn handle_non_streaming_response( let response_body = res.text().await?; let response_json: serde_json::Value = serde_json::from_str(&response_body)?; - serde_json::to_string_pretty(&response_json) - .map(|pretty_json| eprintln!("Response JSON: {}", pretty_json)) - .unwrap_or_else(|e| eprintln!("Failed to serialize response_json: {:?}", e)); - if let Some(content) = response_json.get("content") { - let content_str = content.as_array().and_then(|content_array| { - content_array.iter().find_map(|item| { - if let Some(text) = item.get("text") { - text.as_str() - } else { - None - } - }) - }); - if let Some(content_str) = content_str { - let function_call = response_json - .get("tool_use") - .and_then(|tool_use| { - tool_use.as_array().and_then(|calls| { - calls.iter().find_map(|call| { - let name = call.get("name")?.as_str()?.to_string(); - let arguments = call.get("input") + let mut response_text = String::new(); + let mut function_call = None; + + for content_block in content.as_array().unwrap_or(&vec![]) { + if let Some(content_type) = content_block.get("type") { + match content_type.as_str().unwrap_or("") { + "text" => { + if let Some(text) = content_block.get("text") { + response_text.push_str(text.as_str().unwrap_or("")); + } + } + "tool_use" => { + let name = content_block["name"].as_str().unwrap_or_default().to_string(); + let arguments = content_block.get("input") .and_then(|args_value| args_value.as_object().cloned()) .unwrap_or_else(|| serde_json::Map::new()); - // Search for the tool_router_key in the tools array - let tool_router_key = tools.as_ref().and_then(|tools_array| { - tools_array.iter().find_map(|tool| { - if tool.get("name")?.as_str()? == name { - tool.get("tool_router_key").and_then(|key| key.as_str().map(|s| s.to_string())) - } else { - None - } - }) - }); - - Some(FunctionCall { name, arguments, tool_router_key }) - }) - }) - }); - - shinkai_log( - ShinkaiLogOption::JobExecution, - ShinkaiLogLevel::Info, - format!("Function Call: {:?}", function_call).as_str(), - ); - - - // Send WS message if a function call is detected - if let Some(ref manager) = ws_manager_trait { - if let Some(ref inbox_name) = inbox_name { - if let Some(ref function_call) = function_call { - let m = manager.lock().await; - let inbox_name_string = inbox_name.to_string(); - - // Serialize FunctionCall to JSON value - let function_call_json = serde_json::to_value(function_call) - .unwrap_or_else(|_| serde_json::json!({})); - - // Prepare ToolMetadata - let tool_metadata = ToolMetadata { - tool_name: function_call.name.clone(), - tool_router_key: None, - args: function_call_json - .as_object() - .cloned() - .unwrap_or_default(), - result: None, - status: ToolStatus { - type_: ToolStatusType::Running, - reason: None, - }, - }; - - let ws_message_type = WSMessageType::Widget(WidgetMetadata::ToolRequest(tool_metadata)); - - let _ = m - .queue_message( - WSTopic::Inbox, - inbox_name_string, - serde_json::to_string(&function_call) - .unwrap_or_else(|_| "{}".to_string()), - ws_message_type, - true, - ) - .await; + // Search for the tool_router_key in the tools array + let tool_router_key = tools.as_ref().and_then(|tools_array| { + tools_array.iter().find_map(|tool| { + if tool.get("name")?.as_str()? == name { + tool.get("tool_router_key").and_then(|key| key.as_str().map(|s| s.to_string())) + } else { + None + } + }) + }); + + function_call = Some(FunctionCall { name, arguments, tool_router_key }); + + shinkai_log( + ShinkaiLogOption::JobExecution, + ShinkaiLogLevel::Info, + format!("Function Call: {:?}", function_call).as_str(), + ); + + + // Send WS message if a function call is detected + if let Some(ref manager) = ws_manager_trait { + if let Some(ref inbox_name) = inbox_name { + if let Some(ref function_call) = function_call { + let m = manager.lock().await; + let inbox_name_string = inbox_name.to_string(); + + // Serialize FunctionCall to JSON value + let function_call_json = serde_json::to_value(function_call) + .unwrap_or_else(|_| serde_json::json!({})); + + // Prepare ToolMetadata + let tool_metadata = ToolMetadata { + tool_name: function_call.name.clone(), + tool_router_key: None, + args: function_call_json + .as_object() + .cloned() + .unwrap_or_default(), + result: None, + status: ToolStatus { + type_: ToolStatusType::Running, + reason: None, + }, + }; + + let ws_message_type = WSMessageType::Widget(WidgetMetadata::ToolRequest(tool_metadata)); + + let _ = m + .queue_message( + WSTopic::Inbox, + inbox_name_string, + serde_json::to_string(&function_call) + .unwrap_or_else(|_| "{}".to_string()), + ws_message_type, + true, + ) + .await; + } + } + } } + _ => {} } } - - // Calculate tps - let eval_count = response_json.get("eval_count").and_then(|v| v.as_u64()).unwrap_or(0); - let eval_duration = - response_json.get("eval_duration").and_then(|v| v.as_u64()).unwrap_or(1); // Avoid division by zero - let tps = if eval_duration > 0 { - Some(eval_count as f64 / eval_duration as f64 * 1e9) - } else { - None - }; - - break Ok(LLMInferenceResponse::new( - content_str.to_string(), - json!({}), - function_call, - tps, - )); - } else { - break Err(LLMProviderError::UnexpectedResponseFormat( - "Content is not a string".to_string(), - )); } + + break Ok(LLMInferenceResponse::new( + response_text, + json!({}), + function_call, + None, + )); } else { break Err(LLMProviderError::UnexpectedResponseFormat( "No content field in message".to_string(), @@ -513,6 +528,9 @@ fn add_options_to_payload(payload: &mut serde_json::Value, config: Option<&JobCo if let Some(temp) = get_value("LLM_TEMPERATURE", config.and_then(|c| c.temperature.as_ref())) { payload["temperature"] = serde_json::json!(temp); } + if let Some(top_k) = get_value("LLM_TOP_K", config.and_then(|c| c.top_k.as_ref())) { + payload["top_k"] = serde_json::json!(top_k); + } if let Some(top_p) = get_value("LLM_TOP_P", config.and_then(|c| c.top_p.as_ref())) { payload["top_p"] = serde_json::json!(top_p); } @@ -543,3 +561,159 @@ fn add_options_to_payload(payload: &mut serde_json::Value, config: Option<&JobCo } } } + +#[derive(Debug, Clone)] +struct ProcessedChunk { + partial_text: String, + tool_use: Option, + is_done: bool, + done_reason: Option, +} + +#[derive(Debug, Clone)] +struct ProcessedTool { + tool_name: String, + partial_tool_arguments: String, +} + +// Claude streams chunk of events. Each pack can contain text deltas, name of the tool used or partial JSON of tool arguments. +fn process_chunk(chunk: &[u8]) -> Result> { + let chunk_str = String::from_utf8_lossy(chunk).to_string(); + + let mut text_blocks = Vec::new(); + let mut is_done = false; + let mut done_reason = None; + + let mut content_block_type = String::new(); + let mut content_block_data = String::new(); + let mut current_tool: Option = None; + + let events = chunk_str.split("\n\n").collect::>(); + for event in events { + let event_rows = event.split("\n").collect::>(); + + if event_rows.len() < 2 { + continue; + } + + let event_type = event_rows[0]; + let event_data = event_rows[1]; + + if event_type.starts_with("event: ") { + let event_type = event_type.trim_start_matches("event: "); + + match event_type { + "content_block_start" => { + let data_json: serde_json::Value = serde_json::from_str(event_data.trim_start_matches("data: "))?; + + if data_json + .get("content_block") + .and_then(|block| block.get("type")) + .is_none() + { + continue; + } + + content_block_type = data_json["content_block"]["type"].as_str().unwrap_or("").to_string(); + content_block_data = String::new(); + + if content_block_type == "tool_use" { + let tool_name = data_json["content_block"]["name"].as_str().unwrap_or("").to_string(); + current_tool = Some(ProcessedTool { + tool_name: tool_name, + partial_tool_arguments: String::new(), + }); + } + } + "content_block_delta" => { + let data_json: serde_json::Value = serde_json::from_str(event_data.trim_start_matches("data: "))?; + + let delta_type = data_json + .get("delta") + .and_then(|delta| delta.get("type")) + .unwrap_or(&serde_json::Value::Null); + match delta_type { + serde_json::Value::String(delta_type) => { + if delta_type == "text_delta" { + content_block_type = "text".to_string(); + let text = data_json["delta"]["text"].as_str().unwrap_or(""); + content_block_data.push_str(text); + } else if delta_type == "input_json_delta" { + content_block_type = "tool_use".to_string(); + let input_json = data_json["delta"]["partial_json"].as_str().unwrap_or(""); + content_block_data.push_str(input_json); + } + } + _ => {} + } + } + "content_block_stop" => { + if content_block_type == "text" { + text_blocks.push(content_block_data.clone()); + } else if content_block_type == "tool_use" { + if current_tool.is_none() { + current_tool = Some(ProcessedTool { + tool_name: "".to_string(), + partial_tool_arguments: "".to_string(), + }); + } + current_tool.as_mut().map(|tool| { + tool.partial_tool_arguments = content_block_data.clone(); + }); + } + + content_block_type = String::new(); + content_block_data = String::new(); + } + "message_delta" => { + let data_json: serde_json::Value = serde_json::from_str(event_data.trim_start_matches("data: "))?; + + let stop_reason = data_json + .get("delta") + .and_then(|delta| delta.get("stop_reason")) + .and_then(|reason| reason.as_str()) + .unwrap_or(""); + + if !stop_reason.is_empty() { + done_reason = Some(stop_reason.to_string()); + is_done = true; + } + } + "message_stop" => { + is_done = true; + } + "error" => { + shinkai_log( + ShinkaiLogOption::JobExecution, + ShinkaiLogLevel::Error, + format!("Error in Claude response: {}", event_data).as_str(), + ); + } + _ => {} + } + } + } + + if !content_block_type.is_empty() && !content_block_data.is_empty() { + if content_block_type == "text" { + text_blocks.push(content_block_data); + } else if content_block_type == "tool_use" { + if current_tool.is_none() { + current_tool = Some(ProcessedTool { + tool_name: "".to_string(), + partial_tool_arguments: "".to_string(), + }); + } + current_tool.as_mut().map(|tool| { + tool.partial_tool_arguments = content_block_data; + }); + } + } + + Ok(ProcessedChunk { + partial_text: text_blocks.join(""), + tool_use: current_tool, + is_done, + done_reason, + }) +} diff --git a/shinkai-bin/shinkai-node/src/llm_provider/providers/shared/claude_api.rs b/shinkai-bin/shinkai-node/src/llm_provider/providers/shared/claude_api.rs index f9ac771cf..101d0cabe 100644 --- a/shinkai-bin/shinkai-node/src/llm_provider/providers/shared/claude_api.rs +++ b/shinkai-bin/shinkai-node/src/llm_provider/providers/shared/claude_api.rs @@ -3,48 +3,51 @@ use crate::managers::model_capabilities_manager::ModelCapabilitiesManager; use crate::managers::model_capabilities_manager::PromptResult; use crate::managers::model_capabilities_manager::PromptResultEnum; use serde_json::{self}; +use shinkai_message_primitives::schemas::llm_message::LlmMessage; use shinkai_message_primitives::schemas::llm_providers::serialized_llm_provider::LLMProviderInterface; use shinkai_message_primitives::schemas::prompts::Prompt; use super::shared_model_logic; -pub fn claude_prepare_messages(model: &LLMProviderInterface, prompt: Prompt) -> Result { +pub fn claude_prepare_messages( + model: &LLMProviderInterface, + prompt: Prompt, +) -> Result<(PromptResult, Vec), LLMProviderError> { let max_input_tokens = ModelCapabilitiesManager::get_max_input_tokens(model); // Generate the messages and filter out images - let mut chat_completion_messages = prompt.generate_llm_messages( + let chat_completion_messages = prompt.generate_llm_messages( Some(max_input_tokens), Some("function".to_string()), &ModelCapabilitiesManager::num_tokens_from_llama3, )?; - // Turn system role to assistant. Claude supports only user and assistant roles - for message in &mut chat_completion_messages { - if message.role == Some("system".to_string()) { - message.role = Some("assistant".to_string()); - } - } - - // Filter out empty content messages - chat_completion_messages.retain(|message| { - if let Some(content) = &message.content { - if content.is_empty() { - return false; - } - } - true - }); - // Get a more accurate estimate of the number of used tokens let used_tokens = ModelCapabilitiesManager::num_tokens_from_messages(&chat_completion_messages); // Calculate the remaining output tokens available let remaining_output_tokens = ModelCapabilitiesManager::get_remaining_output_tokens(model, used_tokens); // Separate messages into those with a valid role and those without - let (messages_with_role, tools): (Vec<_>, Vec<_>) = chat_completion_messages + let (mut messages_with_role, tools): (Vec<_>, Vec<_>) = chat_completion_messages .into_iter() .partition(|message| message.role.is_some()); + let mut system_messages = Vec::new(); + + // Collect system messages + for message in &mut messages_with_role { + if message.role == Some("system".to_string()) { + system_messages.push(message.clone()); + } + } + + // Filter out empty content and keep only user and assistant messages + messages_with_role.retain(|message| { + message.content.is_some() + && message.content.as_ref().unwrap().len() > 0 + && (message.role == Some("user".to_string()) || message.role == Some("assistant".to_string())) + }); + // Convert both sets of messages to serde Value let messages_json = serde_json::to_value(messages_with_role)?; let tools_json = serde_json::to_value(tools)?; @@ -106,9 +109,13 @@ pub fn claude_prepare_messages(model: &LLMProviderInterface, prompt: Prompt) -> _ => vec![], }; - Ok(PromptResult { - messages: PromptResultEnum::Value(serde_json::Value::Array(messages_vec)), - functions: Some(tools_vec), - remaining_tokens: remaining_output_tokens, - }) + Ok(( + PromptResult { + messages: PromptResultEnum::Value(serde_json::Value::Array(messages_vec)), + functions: Some(tools_vec), + remaining_output_tokens, + tokens_used: used_tokens, + }, + system_messages, + )) } diff --git a/shinkai-bin/shinkai-node/src/managers/model_capabilities_manager.rs b/shinkai-bin/shinkai-node/src/managers/model_capabilities_manager.rs index aebe3ba9c..df71dc542 100644 --- a/shinkai-bin/shinkai-node/src/managers/model_capabilities_manager.rs +++ b/shinkai-bin/shinkai-node/src/managers/model_capabilities_manager.rs @@ -447,6 +447,7 @@ impl ModelCapabilitiesManager { model_type if model_type.starts_with("llama3.2") => 128_000, model_type if model_type.starts_with("llama3.1") => 128_000, model_type if model_type.starts_with("llama3") || model_type.starts_with("llava-llama3") => 8_000, + model_type if model_type.starts_with("claude") => 200_000, _ => 4096, // Default token count if no specific model type matches } } @@ -633,31 +634,32 @@ impl ModelCapabilitiesManager { LLMProviderInterface::OpenAI(_) => true, LLMProviderInterface::Ollama(model) => { // For Ollama, check model type and respect the passed stream parameter - (model.model_type.starts_with("llama3.1") || - model.model_type.starts_with("llama3.2") || - model.model_type.starts_with("llama-3.1") || - model.model_type.starts_with("llama-3.2") || - model.model_type.starts_with("mistral-nemo") || - model.model_type.starts_with("mistral-small") || - model.model_type.starts_with("mistral-large")) && - stream.map_or(true, |s| !s) - }, + (model.model_type.starts_with("llama3.1") + || model.model_type.starts_with("llama3.2") + || model.model_type.starts_with("llama-3.1") + || model.model_type.starts_with("llama-3.2") + || model.model_type.starts_with("mistral-nemo") + || model.model_type.starts_with("mistral-small") + || model.model_type.starts_with("mistral-large")) + && stream.map_or(true, |s| !s) + } LLMProviderInterface::Groq(model) => { - model.model_type.starts_with("llama-3.2") || - model.model_type.starts_with("llama3.2") || - model.model_type.starts_with("llama-3.1") || - model.model_type.starts_with("llama3.1") - }, + model.model_type.starts_with("llama-3.2") + || model.model_type.starts_with("llama3.2") + || model.model_type.starts_with("llama-3.1") + || model.model_type.starts_with("llama3.1") + } LLMProviderInterface::OpenRouter(model) => { - model.model_type.starts_with("llama-3.2") || - model.model_type.starts_with("llama3.2") || - model.model_type.starts_with("llama-3.1") || - model.model_type.starts_with("llama3.1") || - model.model_type.starts_with("mistral-nemo") || - model.model_type.starts_with("mistral-small") || - model.model_type.starts_with("mistral-large") || - model.model_type.starts_with("mistral-pixtral") - }, + model.model_type.starts_with("llama-3.2") + || model.model_type.starts_with("llama3.2") + || model.model_type.starts_with("llama-3.1") + || model.model_type.starts_with("llama3.1") + || model.model_type.starts_with("mistral-nemo") + || model.model_type.starts_with("mistral-small") + || model.model_type.starts_with("mistral-large") + || model.model_type.starts_with("mistral-pixtral") + } + LLMProviderInterface::Claude(_) => true, _ => false, } }