diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java index 8d6b60b23d..da46f5adca 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java @@ -48,6 +48,7 @@ import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.ChatOptionsBuilder; import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.Media; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallbackContext; @@ -69,6 +70,7 @@ import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.MimeType; +import org.springframework.util.MimeTypeUtils; import org.springframework.util.MultiValueMap; import org.springframework.util.StringUtils; @@ -408,7 +410,7 @@ private OpenAiApi.ChatCompletion chunkToChatCompletion(OpenAiApi.ChatCompletionC chunkChoice.logprobs())) .toList(); - return new OpenAiApi.ChatCompletion(chunk.id(), choices, chunk.created(), chunk.model(), + return new OpenAiApi.ChatCompletion(chunk.id(), choices, chunk.created(), chunk.model(), chunk.serviceTier(), chunk.systemFingerprint(), "chat.completion", chunk.usage()); } @@ -425,11 +427,7 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) { List contentList = new ArrayList<>( List.of(new MediaContent(message.getContent()))); - contentList.addAll(userMessage.getMedia() - .stream() - .map(media -> new MediaContent(new MediaContent.ImageUrl( - this.fromMediaData(media.getMimeType(), media.getData())))) - .toList()); + contentList.addAll(userMessage.getMedia().stream().map(this::mapToMediaContent).toList()); content = 
contentList; } @@ -448,7 +446,7 @@ else if (message.getMessageType() == MessageType.ASSISTANT) { }).toList(); } return List.of(new ChatCompletionMessage(assistantMessage.getContent(), - ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls, null)); + ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls, null, null)); } else if (message.getMessageType() == MessageType.TOOL) { ToolResponseMessage toolMessage = (ToolResponseMessage) message; @@ -460,7 +458,7 @@ else if (message.getMessageType() == MessageType.TOOL) { return toolMessage.getResponses() .stream() .map(tr -> new ChatCompletionMessage(tr.responseData(), ChatCompletionMessage.Role.TOOL, tr.name(), - tr.id(), null, null)) + tr.id(), null, null, null)) .toList(); } else { @@ -512,6 +510,29 @@ else if (prompt.getOptions() instanceof OpenAiChatOptions) { return request; } + private MediaContent mapToMediaContent(Media media) { + var mimeType = media.getMimeType(); + if (MimeTypeUtils.parseMimeType("audio/mp3").equals(mimeType)) { + return new MediaContent( + new MediaContent.InputAudio(fromAudioData(media.getData()), MediaContent.InputAudio.Format.MP3)); + } + if (MimeTypeUtils.parseMimeType("audio/wav").equals(mimeType)) { + return new MediaContent( + new MediaContent.InputAudio(fromAudioData(media.getData()), MediaContent.InputAudio.Format.WAV)); + } + else { + return new MediaContent( + new MediaContent.ImageUrl(this.fromMediaData(media.getMimeType(), media.getData()))); + } + } + + private String fromAudioData(Object audioData) { + if (audioData instanceof byte[] bytes) { + return Base64.getEncoder().encodeToString(bytes); + } + throw new IllegalArgumentException("Unsupported audio data type: " + audioData.getClass().getSimpleName()); + } + private String fromMediaData(MimeType mimeType, Object mediaContentData) { if (mediaContentData instanceof byte[] bytes) { // Assume the bytes are an image. 
So, convert the bytes to a base64 encoded diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java index 3141826826..4f7846de82 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java @@ -210,6 +210,11 @@ public enum ChatModel implements ChatModelDescription { */ GPT_4_O("gpt-4o"), + /** + * Preview release for audio inputs in chat completions. + */ + GPT_4_O_AUDIO_PREVIEW("gpt-4o-audio-preview"), + /** * Affordable and intelligent small model for fast, lightweight tasks. GPT-4o mini * is cheaper and more capable than GPT-3.5 Turbo. Currently points to @@ -345,11 +350,17 @@ public enum Type { * with a maximum length of 64. * @param parameters The parameters the functions accepts, described as a JSON Schema object. To describe a * function that accepts no parameters, provide the value {"type": "object", "properties": {}}. + * @param strict Whether to enable strict schema adherence when generating the function call. If set to true, + * the model will follow the exact schema defined in the parameters field. Only a subset of JSON Schema + * is supported when strict is true. */ + @JsonInclude(Include.NON_NULL) public record Function( @JsonProperty("description") String description, @JsonProperty("name") String name, - @JsonProperty("parameters") Map parameters) { + @JsonProperty("parameters") Map parameters, + @JsonProperty("strict") Boolean strict + ) { /** * Create tool function definition. 
@@ -360,16 +371,32 @@ public record Function( */ @ConstructorBinding public Function(String description, String name, String jsonSchema) { - this(description, name, ModelOptionsUtils.jsonToMap(jsonSchema)); + this(description, name, ModelOptionsUtils.jsonToMap(jsonSchema), null); } } }// @formatter:on + /** + * The type of modality for the model completion. + */ + public enum OutputModality { + + // @formatter:off + @JsonProperty("audio") AUDIO, + @JsonProperty("text") TEXT; + // @formatter:on + + } + /** * Creates a model response for the given chat conversation. * * @param messages A list of messages comprising the conversation so far. * @param model ID of the model to use. + * @param store Whether to store the output of this chat completion request for use in + * OpenAI's model distillation or evals products. + * @param metadata Developer-defined tags and values used for filtering completions in + * the OpenAI's dashboard. * @param frequencyPenalty Number between -2.0 and 2.0. Positive values penalize new * tokens based on their existing frequency in the text so far, decreasing the model's * likelihood to repeat the same line verbatim. @@ -386,14 +413,22 @@ public Function(String description, String name, String jsonSchema) { * @param topLogprobs An integer between 0 and 5 specifying the number of most likely * tokens to return at each token position, each with an associated log probability. * 'logprobs' must be set to 'true' if this parameter is used. - * @param maxTokens The maximum number of tokens to generate in the chat completion. - * The total length of input tokens and generated tokens is limited by the model's - * context length. + * @param maxTokens The maximum number of tokens that can be generated in the chat + * completion. This value can be used to control costs for text generated via API. + * This value is now deprecated in favor of max_completion_tokens, and is not + * compatible with o1 series models. 
* @param maxCompletionTokens An upper bound for the number of tokens that can be * generated for a completion, including visible output tokens and reasoning tokens. * @param n How many chat completion choices to generate for each input message. Note * that you will be charged based on the number of generated tokens across all the * choices. Keep n as 1 to minimize costs. + * @param outputModalities Output types that you would like the model to generate for + * this request. Most models are capable of generating text, which is the default: + * ["text"]. The gpt-4o-audio-preview model can also be used to generate audio. To + * request that this model generate both text and audio responses, you can use: + * ["text", "audio"]. + * @param audioParameters Parameters for audio output. Required when audio output is + * requested with outputModalities: ["audio"]. * @param presencePenalty Number between -2.0 and 2.0. Positive values penalize new * tokens based on whether they appear in the text so far, increasing the model's * likelihood to talk about new topics. @@ -405,6 +440,9 @@ public Function(String description, String name, String jsonSchema) { * and parameters should return the same result. Determinism is not guaranteed, and * you should refer to the system_fingerprint response parameter to monitor changes in * the backend. + * @param serviceTier Specifies the latency tier to use for processing the request. + * This parameter is relevant for customers subscribed to the scale tier service. When + * this parameter is set, the response body will include the service_tier utilized. * @param stop Up to 4 sequences where the API will stop generating further tokens. 
* @param stream If set, partial message deltas will be sent.Tokens will be sent as * data-only server-sent events as they become available, with the stream terminated @@ -438,16 +476,21 @@ public Function(String description, String name, String jsonSchema) { public record ChatCompletionRequest(// @formatter:off @JsonProperty("messages") List messages, @JsonProperty("model") String model, + @JsonProperty("store") Boolean store, + @JsonProperty("metadata") Object metadata, @JsonProperty("frequency_penalty") Double frequencyPenalty, @JsonProperty("logit_bias") Map logitBias, @JsonProperty("logprobs") Boolean logprobs, @JsonProperty("top_logprobs") Integer topLogprobs, - @JsonProperty("max_tokens") Integer maxTokens, + @JsonProperty("max_tokens") @Deprecated Integer maxTokens, // Use maxCompletionTokens instead @JsonProperty("max_completion_tokens") Integer maxCompletionTokens, @JsonProperty("n") Integer n, + @JsonProperty("modalities") List outputModalities, + @JsonProperty("audio") AudioParameters audioParameters, @JsonProperty("presence_penalty") Double presencePenalty, @JsonProperty("response_format") ResponseFormat responseFormat, @JsonProperty("seed") Integer seed, + @JsonProperty("service_tier") String serviceTier, @JsonProperty("stop") List stop, @JsonProperty("stream") Boolean stream, @JsonProperty("stream_options") StreamOptions streamOptions, @@ -466,11 +509,25 @@ public record ChatCompletionRequest(// @formatter:off * @param temperature What sampling temperature to use, between 0 and 1. */ public ChatCompletionRequest(List messages, String model, Double temperature) { - this(messages, model, null, null, null, null, null, null, null, null, + this(messages, model, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, false, null, temperature, null, null, null, null, null); } + /** + * Shortcut constructor for a chat completion request with text and audio output. 
+ * + * @param messages A list of messages comprising the conversation so far. + * @param model ID of the model to use. + * @param audio Parameters for audio output. Required when audio output is requested with outputModalities: ["audio"]. + */ + public ChatCompletionRequest(List messages, String model, AudioParameters audio) { + this(messages, model, null, null, null, null, null, null, + null, null, null, List.of(OutputModality.AUDIO, OutputModality.TEXT), audio, null, null, + null, null, null, false, null, null, null, + null, null, null, null); + } + /** * Shortcut constructor for a chat completion request with the given messages, model, temperature and control for streaming. * @@ -481,9 +538,9 @@ public ChatCompletionRequest(List messages, String model, * as they become available, with the stream terminated by a data: [DONE] message. */ public ChatCompletionRequest(List messages, String model, Double temperature, boolean stream) { - this(messages, model, null, null, null, null, null, null, null, null, - null, null, null, stream, null, temperature, null, - null, null, null, null); + this(messages, model, null, null, null, null, null, null, null, null, null, + null, null, null, null, null, null, null, stream, null, temperature, null, + null, null, null, null); } /** @@ -497,8 +554,8 @@ public ChatCompletionRequest(List messages, String model, */ public ChatCompletionRequest(List messages, String model, List tools, Object toolChoice) { - this(messages, model, null, null, null, null, null, null, null, null, - null, null, null, false, null, 0.8, null, + this(messages, model, null, null, null, null, null, null, null, null, null, + null, null, null, null, null, null, null, false, null, 0.8, null, tools, toolChoice, null, null); } @@ -510,8 +567,8 @@ public ChatCompletionRequest(List messages, String model, * as they become available, with the stream terminated by a data: [DONE] message. 
*/ public ChatCompletionRequest(List messages, Boolean stream) { - this(messages, null, null, null, null, null, null, null, null, - null, null, null, null, stream, null, null, null, + this(messages, null, null, null, null, null, null, null, null, null, null, + null, null, null, null, null, null, null, stream, null, null, null, null, null, null, null); } @@ -522,8 +579,9 @@ public ChatCompletionRequest(List messages, Boolean strea * @return A new {@link ChatCompletionRequest} with the specified stream options. */ public ChatCompletionRequest withStreamOptions(StreamOptions streamOptions) { - return new ChatCompletionRequest(messages, model, frequencyPenalty, logitBias, logprobs, topLogprobs, maxTokens, maxCompletionTokens, n, presencePenalty, - responseFormat, seed, stop, stream, streamOptions, temperature, topP, + return new ChatCompletionRequest(messages, model, store, metadata, frequencyPenalty, logitBias, logprobs, + topLogprobs, maxTokens, maxCompletionTokens, n, outputModalities, audioParameters, presencePenalty, + responseFormat, seed, serviceTier, stop, stream, streamOptions, temperature, topP, tools, toolChoice, parallelToolCalls, user); } @@ -548,6 +606,40 @@ public static Object FUNCTION(String functionName) { } } + /** + * Parameters for audio output. Required when audio output is requested with outputModalities: ["audio"]. + * @param voice Specifies the voice type. + * @param format Specifies the output audio format. + */ + @JsonInclude(Include.NON_NULL) + public record AudioParameters( + @JsonProperty("voice") Voice voice, + @JsonProperty("format") AudioResponseFormat format) { + + /** + * Specifies the voice type. + */ + public enum Voice { + @JsonProperty("alloy") ALLOY, + @JsonProperty("echo") ECHO, + @JsonProperty("fable") FABLE, + @JsonProperty("onyx") ONYX, + @JsonProperty("nova") NOVA, + @JsonProperty("shimmer") SHIMMER; + } + + /** + * Specifies the output audio format. 
+ */ + public enum AudioResponseFormat { + @JsonProperty("mp3") MP3, + @JsonProperty("flac") FLAC, + @JsonProperty("opus") OPUS, + @JsonProperty("pcm16") PCM16, + @JsonProperty("wav") WAV; + } + } + /** * An object specifying the format that the model must output. * @param type Must be one of 'text' or 'json_object'. @@ -647,6 +739,9 @@ public record StreamOptions( * the {@link Role#TOOL} role and null otherwise. * @param toolCalls The tool calls generated by the model, such as function calls. * Applicable only for {@link Role#ASSISTANT} role and null otherwise. + * @param refusal The refusal message by the assistant. Applicable only for + * {@link Role#ASSISTANT} role and null otherwise. + * @param audioOutput Audio response from the model. */ @JsonInclude(Include.NON_NULL) public record ChatCompletionMessage(// @formatter:off @@ -655,7 +750,9 @@ public record ChatCompletionMessage(// @formatter:off @JsonProperty("name") String name, @JsonProperty("tool_call_id") String toolCallId, @JsonProperty("tool_calls") List toolCalls, - @JsonProperty("refusal") String refusal) {// @formatter:on + @JsonProperty("refusal") String refusal, + @JsonProperty("audio") AudioOutput audioOutput + ) {// @formatter:on /** * Get message content as String. @@ -677,7 +774,7 @@ public String content() { * @param role The role of the author of this message. */ public ChatCompletionMessage(Object content, Role role) { - this(content, role, null, null, null, null); + this(content, role, null, null, null, null, null); } /** @@ -710,19 +807,22 @@ public enum Role { /** * An array of content parts with a defined type. Each MediaContent can be of - * either "text" or "image_url" type. Not both. + * either "text", "image_url", or "input_audio" type. Only one option allowed. * * @param type Content type, each can be of type text or image_url. * @param text The text content of the message. * @param imageUrl The image content of the message. 
You can pass multiple images * by adding multiple image_url content parts. Image input is only supported when * using the gpt-4-visual-preview model. + * @param inputAudio Audio content part. */ @JsonInclude(Include.NON_NULL) public record MediaContent(// @formatter:off @JsonProperty("type") String type, @JsonProperty("text") String text, - @JsonProperty("image_url") ImageUrl imageUrl) { + @JsonProperty("image_url") ImageUrl imageUrl, + @JsonProperty("input_audio") InputAudio inputAudio + ) { // @formatter:on /** * @param url Either a URL of the image or the base64 encoded image data. The @@ -738,12 +838,29 @@ public ImageUrl(String url) { } } + /** + * @param data Base64 encoded audio data. + * @param format The format of the encoded audio data. Currently supports + * "wav" and "mp3". + */ + @JsonInclude(Include.NON_NULL) + public record InputAudio(@JsonProperty("data") String data, @JsonProperty("format") Format format) { + public enum Format { + + // @formatter:off + @JsonProperty("mp3") MP3, + @JsonProperty("wav") WAV; + // @formatter:on + + } + } + /** * Shortcut constructor for a text content. * @param text The text content of the message. */ public MediaContent(String text) { - this("text", text, null); + this("text", text, null, null); } /** @@ -751,7 +868,15 @@ public MediaContent(String text) { * @param imageUrl The image content of the message. */ public MediaContent(ImageUrl imageUrl) { - this("image_url", null, imageUrl); + this("image_url", null, imageUrl, null); + } + + /** + * Shortcut constructor for an audio content. + * @param inputAudio The audio content of the message. + */ + public MediaContent(InputAudio inputAudio) { + this("input_audio", null, null, inputAudio); } } @@ -790,6 +915,24 @@ public record ChatCompletionFunction(// @formatter:off @JsonProperty("name") String name, @JsonProperty("arguments") String arguments) {// @formatter:on } + + /** + * Audio response from the model. 
+ * + * @param id Unique identifier for the audio response from the model. + * @param data Audio output from the model. + * @param expiresAt When the audio content will no longer be available on the + * server. + * @param transcript Transcript of the audio output from the model. + */ + @JsonInclude(Include.NON_NULL) + public record AudioOutput(// @formatter:off + @JsonProperty("id") String id, + @JsonProperty("data") String data, + @JsonProperty("expires_at") Long expiresAt, + @JsonProperty("transcript") String transcript + ) {// @formatter:on + } } public static String getTextContent(List content) { @@ -847,6 +990,8 @@ public enum ChatCompletionFinishReason { * @param created The Unix timestamp (in seconds) of when the chat completion was * created. * @param model The model used for the chat completion. + * @param serviceTier The service tier used for processing the request. This field is + * only included if the service_tier parameter is specified in the request. * @param systemFingerprint This fingerprint represents the backend configuration that * the model runs with. Can be used in conjunction with the seed request parameter to * understand when backend changes have been made that might impact determinism. @@ -859,9 +1004,11 @@ public record ChatCompletion(// @formatter:off @JsonProperty("choices") List choices, @JsonProperty("created") Long created, @JsonProperty("model") String model, + @JsonProperty("service_tier") String serviceTier, @JsonProperty("system_fingerprint") String systemFingerprint, @JsonProperty("object") String object, - @JsonProperty("usage") Usage usage) {// @formatter:on + @JsonProperty("usage") Usage usage + ) {// @formatter:on /** * Chat completion choice. 
@@ -877,7 +1024,6 @@ public record Choice(// @formatter:off @JsonProperty("index") Integer index, @JsonProperty("message") ChatCompletionMessage message, @JsonProperty("logprobs") LogProbs logprobs) {// @formatter:on - } } @@ -885,9 +1031,11 @@ public record Choice(// @formatter:off * Log probability information for the choice. * * @param content A list of message content tokens with log probability information. + * @param refusal A list of message refusal tokens with log probability information. */ @JsonInclude(Include.NON_NULL) - public record LogProbs(@JsonProperty("content") List content) { + public record LogProbs(@JsonProperty("content") List content, + @JsonProperty("refusal") List refusal) { /** * Message content tokens with log probability information. @@ -945,20 +1093,38 @@ public record Usage(// @formatter:off @JsonProperty("completion_tokens") Integer completionTokens, @JsonProperty("prompt_tokens") Integer promptTokens, @JsonProperty("total_tokens") Integer totalTokens, - @JsonProperty("completion_tokens_details") CompletionTokenDetails completionTokenDetails) {// @formatter:on + @JsonProperty("completion_tokens_details") CompletionTokenDetails completionTokenDetails, + @JsonProperty("prompt_tokens_details") PromptTokenDetails promptTokenDetails + ) {// @formatter:on public Usage(Integer completionTokens, Integer promptTokens, Integer totalTokens) { - this(completionTokens, promptTokens, totalTokens, null); + this(completionTokens, promptTokens, totalTokens, null, null); } /** - * Breakdown of tokens used in a completion + * Breakdown of tokens used in a completion. * + * @param audioTokens Audio input tokens generated by the model. * @param reasoningTokens Number of tokens generated by the model for reasoning. 
*/ @JsonInclude(Include.NON_NULL) public record CompletionTokenDetails(// @formatter:off - @JsonProperty("reasoning_tokens") Integer reasoningTokens) {// @formatter:on + @JsonProperty("audio_tokens") Integer audioTokens, + @JsonProperty("reasoning_tokens") Integer reasoningTokens + ) {// @formatter:on + } + + /** + * Breakdown of tokens used in the prompt. + * + * @param audioTokens Audio input tokens present in the prompt. + * @param cachedTokens Cached tokens present in the prompt. + */ + @JsonInclude(Include.NON_NULL) + public record PromptTokenDetails(// @formatter:off + @JsonProperty("audio_tokens") Integer audioTokens, + @JsonProperty("cached_tokens") Integer cachedTokens + ) {// @formatter:on } } @@ -973,6 +1139,8 @@ public record CompletionTokenDetails(// @formatter:off * @param created The Unix timestamp (in seconds) of when the chat completion was * created. Each chunk has the same timestamp. * @param model The model used for the chat completion. + * @param serviceTier The service tier used for processing the request. This field is + * only included if the service_tier parameter is specified in the request. * @param systemFingerprint This fingerprint represents the backend configuration that * the model runs with. Can be used in conjunction with the seed request parameter to * understand when backend changes have been made that might impact determinism. 
@@ -986,6 +1154,7 @@ public record ChatCompletionChunk(// @formatter:off @JsonProperty("choices") List choices, @JsonProperty("created") Long created, @JsonProperty("model") String model, + @JsonProperty("service_tier") String serviceTier, @JsonProperty("system_fingerprint") String systemFingerprint, @JsonProperty("object") String object, @JsonProperty("usage") Usage usage) {// @formatter:on @@ -1101,7 +1270,7 @@ public Flux chatCompletionStream(ChatCompletionRequest chat // Flux> -> Flux> .concatMapIterable(window -> { Mono monoChunk = window.reduce( - new ChatCompletionChunk(null, null, null, null, null, null, null), + new ChatCompletionChunk(null, null, null, null, null, null, null, null), (previous, current) -> this.chunkMerger.merge(previous, current)); return List.of(monoChunk); }) diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelper.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelper.java index 02bfd31080..5e37dc0381 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelper.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelper.java @@ -37,6 +37,7 @@ * It can merge the streamed ChatCompletionChunk in case of function calling message. * * @author Christian Tzolov + * @author Thomas Vitale * @since 0.8.1 */ public class OpenAiStreamFunctionCallingHelper { @@ -56,6 +57,7 @@ public ChatCompletionChunk merge(ChatCompletionChunk previous, ChatCompletionChu String id = (current.id() != null ? current.id() : previous.id()); Long created = (current.created() != null ? current.created() : previous.created()); String model = (current.model() != null ? current.model() : previous.model()); + String serviceTier = (current.serviceTier() != null ? 
current.serviceTier() : previous.serviceTier()); String systemFingerprint = (current.systemFingerprint() != null ? current.systemFingerprint() : previous.systemFingerprint()); String object = (current.object() != null ? current.object() : previous.object()); @@ -66,7 +68,7 @@ public ChatCompletionChunk merge(ChatCompletionChunk previous, ChatCompletionChu ChunkChoice choice = merge(previousChoice0, currentChoice0); List chunkChoices = choice == null ? List.of() : List.of(choice); - return new ChatCompletionChunk(id, chunkChoices, created, model, systemFingerprint, object, usage); + return new ChatCompletionChunk(id, chunkChoices, created, model, serviceTier, systemFingerprint, object, usage); } private ChunkChoice merge(ChunkChoice previous, ChunkChoice current) { @@ -92,6 +94,8 @@ private ChatCompletionMessage merge(ChatCompletionMessage previous, ChatCompleti String name = (current.name() != null ? current.name() : previous.name()); String toolCallId = (current.toolCallId() != null ? current.toolCallId() : previous.toolCallId()); String refusal = (current.refusal() != null ? current.refusal() : previous.refusal()); + ChatCompletionMessage.AudioOutput audioOutput = (current.audioOutput() != null ? 
current.audioOutput() + : previous.audioOutput()); List toolCalls = new ArrayList<>(); ToolCall lastPreviousTooCall = null; @@ -121,7 +125,7 @@ private ChatCompletionMessage merge(ChatCompletionMessage previous, ChatCompleti toolCalls.add(lastPreviousTooCall); } } - return new ChatCompletionMessage(content, role, name, toolCallId, toolCalls, refusal); + return new ChatCompletionMessage(content, role, name, toolCallId, toolCalls, refusal, audioOutput); } private ToolCall merge(ToolCall previous, ToolCall current) { @@ -196,7 +200,7 @@ public ChatCompletion chunkToChatCompletion(ChatCompletionChunk chunk) { chunkChoice.logprobs())) .toList(); - return new OpenAiApi.ChatCompletion(chunk.id(), choices, chunk.created(), chunk.model(), + return new OpenAiApi.ChatCompletion(chunk.id(), choices, chunk.created(), chunk.model(), chunk.serviceTier(), chunk.systemFingerprint(), "chat.completion", null); } diff --git a/models/spring-ai-openai/src/main/resources/speech1.mp3 b/models/spring-ai-openai/src/main/resources/speech1.mp3 new file mode 100644 index 0000000000..68cb2a30da Binary files /dev/null and b/models/spring-ai-openai/src/main/resources/speech1.mp3 differ diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java index 17a0022763..d5ff4cf628 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java @@ -15,12 +15,8 @@ */ package org.springframework.ai.openai.api; -import java.util.List; - import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; -import reactor.core.publisher.Flux; - import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionChunk; import 
org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage; @@ -28,12 +24,19 @@ import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest; import org.springframework.ai.openai.api.OpenAiApi.Embedding; import org.springframework.ai.openai.api.OpenAiApi.EmbeddingList; +import org.springframework.core.io.ClassPathResource; import org.springframework.http.ResponseEntity; +import reactor.core.publisher.Flux; + +import java.io.IOException; +import java.util.Base64; +import java.util.List; import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov + * @author Thomas Vitale */ @EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+") public class OpenAiApiIT { @@ -70,4 +73,48 @@ void embeddings() { assertThat(response.getBody().data().get(0).embedding()).hasSize(1536); } + @Test + void inputAudio() throws IOException { + var audioData = new ClassPathResource("speech1.mp3").getContentAsByteArray(); + List content = List + .of(new ChatCompletionMessage.MediaContent("What is this recording about?"), + new ChatCompletionMessage.MediaContent(new ChatCompletionMessage.MediaContent.InputAudio( + Base64.getEncoder().encodeToString(audioData), + ChatCompletionMessage.MediaContent.InputAudio.Format.MP3))); + ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage(content, Role.USER); + ChatCompletionRequest chatCompletionRequest = new ChatCompletionRequest(List.of(chatCompletionMessage), + OpenAiApi.ChatModel.GPT_4_O_AUDIO_PREVIEW.getValue(), 0.0); + ResponseEntity response = openAiApi.chatCompletionEntity(chatCompletionRequest); + + assertThat(response).isNotNull(); + assertThat(response.getBody()).isNotNull(); + + assertThat(response.getBody().usage().promptTokenDetails().audioTokens()).isGreaterThan(0); + assertThat(response.getBody().usage().completionTokenDetails().audioTokens()).isEqualTo(0); + + assertThat(response.getBody().choices().get(0).message().content()).containsIgnoringCase("hobbits"); 
+ } + + @Test + void outputAudio() { + ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage( + "What is the magic spell to make objects fly?", Role.USER); + ChatCompletionRequest.AudioParameters audioParameters = new ChatCompletionRequest.AudioParameters( + ChatCompletionRequest.AudioParameters.Voice.NOVA, + ChatCompletionRequest.AudioParameters.AudioResponseFormat.MP3); + ChatCompletionRequest chatCompletionRequest = new ChatCompletionRequest(List.of(chatCompletionMessage), + OpenAiApi.ChatModel.GPT_4_O_AUDIO_PREVIEW.getValue(), audioParameters); + ResponseEntity response = openAiApi.chatCompletionEntity(chatCompletionRequest); + + assertThat(response).isNotNull(); + assertThat(response.getBody()).isNotNull(); + + assertThat(response.getBody().usage().promptTokenDetails().audioTokens()).isEqualTo(0); + assertThat(response.getBody().usage().completionTokenDetails().audioTokens()).isGreaterThan(0); + + assertThat(response.getBody().choices().get(0).message().audioOutput().data()).isNotNull(); + assertThat(response.getBody().choices().get(0).message().audioOutput().transcript()) + .containsIgnoringCase("leviosa"); + } + } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/OpenAiApiToolFunctionCallIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/OpenAiApiToolFunctionCallIT.java index f8d0f20316..c2f64b6788 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/OpenAiApiToolFunctionCallIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/OpenAiApiToolFunctionCallIT.java @@ -44,6 +44,7 @@ * https://platform.openai.com/docs/guides/function-calling/parallel-function-calling * * @author Christian Tzolov + * @author Thomas Vitale */ @EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+") public class OpenAiApiToolFunctionCallIT { @@ -87,7 +88,7 @@ public void toolFunctionCall() { }, 
"required": ["location", "lat", "lon", "unit"] } - """))); + """), null)); List messages = new ArrayList<>(List.of(message)); @@ -120,7 +121,7 @@ public void toolFunctionCall() { // extend conversation with function response. messages.add(new ChatCompletionMessage("" + weatherResponse.temp() + weatherRequest.unit(), Role.TOOL, - functionName, toolCall.id(), null, null)); + functionName, toolCall.id(), null, null, null)); } } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/MessageTypeContentTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/MessageTypeContentTests.java index 8a393c8aaf..658baecfbd 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/MessageTypeContentTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/MessageTypeContentTests.java @@ -48,6 +48,7 @@ /** * @author Christian Tzolov + * @author Thomas Vitale */ @ExtendWith(MockitoExtension.class) @SuppressWarnings("unchecked") @@ -64,12 +65,12 @@ public class MessageTypeContentTests { @Captor ArgumentCaptor> headersCaptor; - Flux fluxResponse = Flux - .generate(() -> new ChatCompletionChunk("id", List.of(), 0l, "model", "fp", "object", null), (state, sink) -> { - sink.next(state); - sink.complete(); - return state; - }); + Flux fluxResponse = Flux.generate( + () -> new ChatCompletionChunk("id", List.of(), 0l, "model", null, "fp", "object", null), (state, sink) -> { + sink.next(state); + sink.complete(); + return state; + }); @BeforeEach public void beforeEach() { diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java index 37a619ab92..75852acc9b 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java +++ 
b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java @@ -437,6 +437,42 @@ void streamingMultiModalityImageUrl() throws IOException { assertThat(content).containsAnyOf("bowl", "basket", "fruit stand"); } + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "gpt-4o-audio-preview" }) + void multiModalityInputAudio(String modelName) { + var audioResource = new ClassPathResource("speech1.mp3"); + var userMessage = new UserMessage("What is this recording about?", + List.of(new Media(MimeTypeUtils.parseMimeType("audio/mp3"), audioResource))); + + ChatResponse response = chatModel + .call(new Prompt(List.of(userMessage), OpenAiChatOptions.builder().withModel(modelName).build())); + + logger.info(response.getResult().getOutput().getContent()); + assertThat(response.getResult().getOutput().getContent()).containsIgnoringCase("hobbits"); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "gpt-4o-audio-preview" }) + void streamingMultiModalityInputAudio(String modelName) { + var audioResource = new ClassPathResource("speech1.mp3"); + var userMessage = new UserMessage("What is this recording about?", + List.of(new Media(MimeTypeUtils.parseMimeType("audio/mp3"), audioResource))); + + Flux response = chatModel + .stream(new Prompt(List.of(userMessage), OpenAiChatOptions.builder().withModel(modelName).build())); + + String content = response.collectList() + .block() + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) + .collect(Collectors.joining()); + logger.info("Response: {}", content); + assertThat(content).containsIgnoringCase("hobbits"); + } + @Test void validateCallResponseMetadata() { String model = OpenAiApi.ChatModel.GPT_3_5_TURBO.getName(); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiRetryTests.java 
b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiRetryTests.java index bae33e60c8..38e0fd2085 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiRetryTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiRetryTests.java @@ -75,6 +75,7 @@ /** * @author Christian Tzolov + * @author Thomas Vitale */ @SuppressWarnings("unchecked") @ExtendWith(MockitoExtension.class) @@ -141,7 +142,7 @@ public void openAiChatTransientError() { var choice = new ChatCompletion.Choice(ChatCompletionFinishReason.STOP, 0, new ChatCompletionMessage("Response", Role.ASSISTANT), null); ChatCompletion expectedChatCompletion = new ChatCompletion("id", List.of(choice), 666l, "model", null, null, - new OpenAiApi.Usage(10, 10, 10)); + null, new OpenAiApi.Usage(10, 10, 10)); when(openAiApi.chatCompletionEntity(isA(ChatCompletionRequest.class), any())) .thenThrow(new TransientAiException("Transient Error 1")) @@ -170,7 +171,7 @@ public void openAiChatStreamTransientError() { var choice = new ChatCompletionChunk.ChunkChoice(ChatCompletionFinishReason.STOP, 0, new ChatCompletionMessage("Response", Role.ASSISTANT), null); ChatCompletionChunk expectedChatCompletion = new ChatCompletionChunk("id", List.of(choice), 666l, "model", null, - null, null); + null, null, null); when(openAiApi.chatCompletionStream(isA(ChatCompletionRequest.class), any())) .thenThrow(new TransientAiException("Transient Error 1")) diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/metadata/OpenAiUsageTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/metadata/OpenAiUsageTests.java index b9215b4c3d..9018fa1eea 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/metadata/OpenAiUsageTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/metadata/OpenAiUsageTests.java @@ -56,7 +56,7 @@ void 
whenTotalTokensIsNull() { @Test void whenCompletionTokenDetailsIsNull() { - OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300, null); + OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300, null, null); OpenAiUsage usage = OpenAiUsage.from(openAiUsage); assertThat(usage.getTotalTokens()).isEqualTo(300); assertThat(usage.getReasoningTokens()).isEqualTo(0); @@ -65,7 +65,7 @@ void whenCompletionTokenDetailsIsNull() { @Test void whenReasoningTokensIsNull() { OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300, - new OpenAiApi.Usage.CompletionTokenDetails(null)); + new OpenAiApi.Usage.CompletionTokenDetails(null, null), null); OpenAiUsage usage = OpenAiUsage.from(openAiUsage); assertThat(usage.getReasoningTokens()).isEqualTo(0); } @@ -73,7 +73,7 @@ void whenReasoningTokensIsNull() { @Test void whenCompletionTokenDetailsIsPresent() { OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300, - new OpenAiApi.Usage.CompletionTokenDetails(50)); + new OpenAiApi.Usage.CompletionTokenDetails(null, 50), null); OpenAiUsage usage = OpenAiUsage.from(openAiUsage); assertThat(usage.getReasoningTokens()).isEqualTo(50); } diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/openai-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/openai-chat.adoc index f74a63e18e..8d5d6b8f42 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/openai-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/openai-chat.adoc @@ -159,7 +159,11 @@ Read more about xref:api/chat/functions/openai-chat-functions.adoc[OpenAI Functi == Multimodal Multimodality refers to a model's ability to simultaneously understand and process information from various sources, including text, images, audio, and other data formats. -OpenAI models that offer multimodal support include `gpt-4`, `gpt-4o`, and `gpt-4o-mini`. +OpenAI supports text, vision, and audio input modalities. 
+ +=== Vision + +OpenAI models that offer vision multimodal support include `gpt-4`, `gpt-4o`, and `gpt-4o-mini`. Refer to the link:https://platform.openai.com/docs/guides/vision[Vision] guide for more information. The OpenAI link:https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages[User Message API] can incorporate a list of base64-encoded images or image urls with the message. @@ -210,6 +214,31 @@ for carrying. The bowl is placed on a flat surface with a neutral-colored backgr view of the fruit inside. ---- +=== Audio + +OpenAI models that offer audio multimodal support include `gpt-4o-audio-preview`. +Refer to the link:https://platform.openai.com/docs/guides/audio[Audio] guide for more information. + +The OpenAI link:https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages[User Message API] can incorporate a list of base64-encoded audio files with the message. +Spring AI’s link:https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/Message.java[Message] interface facilitates multimodal AI models by introducing the link:https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/Media.java[Media] type. +This type encompasses data and details regarding media attachments in messages, utilizing Spring’s `org.springframework.util.MimeType` and a `org.springframework.core.io.Resource` for the raw media data. +Currently, OpenAI supports only the following media types: `audio/mp3` and `audio/wav`. + +Below is a code example excerpted from link:https://github.com/spring-projects/spring-ai/blob/c9a3e66f90187ce7eae7eb78c462ec622685de6c/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java#L442[OpenAiChatModelIT.java], illustrating the fusion of user text with an audio file using the `gpt-4o-audio-preview` model.
+ +[source,java] +---- +var audioResource = new ClassPathResource("speech1.mp3"); + +var userMessage = new UserMessage("What is this recording about?", + List.of(new Media(MimeTypeUtils.parseMimeType("audio/mp3"), audioResource))); + +ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), + OpenAiChatOptions.builder().withModel(OpenAiApi.ChatModel.GPT_4_O_AUDIO_PREVIEW).build())); +---- + +TIP: You can pass multiple audio files as well. + == Structured Outputs OpenAI provides custom https://platform.openai.com/docs/guides/structured-outputs[Structured Outputs] APIs that ensure your model generates responses conforming strictly to your provided `JSON Schema`.