
OpenAI - Support audio input modality
* Extend OpenAiApi to support the latest version of the Chat Completion API, including the input and output audio modalities.
* Support the input audio modality in OpenAiChatModel via the existing multimodality support in Spring AI.

Fixes gh-1560
ThomasVitale committed Oct 18, 2024
1 parent 8eef6e6 commit bdb66e5
Showing 11 changed files with 368 additions and 59 deletions.
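
For context, a minimal sketch of how the new input audio path can be exercised through Spring AI's existing multimodality support. This is not part of the commit: the audio-capable model name, the options builder method, and the Media/UserMessage/Prompt constructors are assumptions that may differ between Spring AI milestones.

    import java.util.List;

    import org.springframework.ai.chat.messages.UserMessage;
    import org.springframework.ai.chat.prompt.Prompt;
    import org.springframework.ai.model.Media;
    import org.springframework.ai.openai.OpenAiChatModel;
    import org.springframework.ai.openai.OpenAiChatOptions;
    import org.springframework.core.io.ClassPathResource;
    import org.springframework.util.MimeTypeUtils;

    class AudioInputSketch {

        // Sketch only: "gpt-4o-audio-preview" and the constructors used here are assumptions.
        String describeRecording(OpenAiChatModel chatModel) {
            // Attach a WAV clip as media on the user message; the MIME type is what the
            // new mapToMediaContent(..) dispatches on (audio/mp3 and audio/wav).
            var audio = new Media(MimeTypeUtils.parseMimeType("audio/wav"),
                    new ClassPathResource("speech.wav"));
            var userMessage = new UserMessage("What is this recording about?", List.of(audio));

            var options = OpenAiChatOptions.builder()
                .withModel("gpt-4o-audio-preview")
                .build();

            var response = chatModel.call(new Prompt(List.of(userMessage), options));
            return response.getResult().getOutput().getContent();
        }
    }

The chat model converts each Media entry with the new mapToMediaContent(..) shown in the diff below: audio/mp3 and audio/wav media become InputAudio content parts, while everything else falls back to the existing ImageUrl mapping.
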
OpenAiChatModel.java
@@ -48,6 +48,7 @@
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.ChatOptionsBuilder;
import org.springframework.ai.chat.prompt.Prompt;
+ import org.springframework.ai.model.Media;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackContext;
@@ -69,6 +70,7 @@
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.MimeType;
+ import org.springframework.util.MimeTypeUtils;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;

@@ -408,7 +410,7 @@ private OpenAiApi.ChatCompletion chunkToChatCompletion(OpenAiApi.ChatCompletionC
chunkChoice.logprobs()))
.toList();

- return new OpenAiApi.ChatCompletion(chunk.id(), choices, chunk.created(), chunk.model(),
+ return new OpenAiApi.ChatCompletion(chunk.id(), choices, chunk.created(), chunk.model(), chunk.serviceTier(),
chunk.systemFingerprint(), "chat.completion", chunk.usage());
}

@@ -425,11 +427,7 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
List<MediaContent> contentList = new ArrayList<>(
List.of(new MediaContent(message.getContent())));

- contentList.addAll(userMessage.getMedia()
- .stream()
- .map(media -> new MediaContent(new MediaContent.ImageUrl(
- this.fromMediaData(media.getMimeType(), media.getData()))))
- .toList());
+ contentList.addAll(userMessage.getMedia().stream().map(this::mapToMediaContent).toList());

content = contentList;
}
@@ -448,7 +446,7 @@ else if (message.getMessageType() == MessageType.ASSISTANT) {
}).toList();
}
return List.of(new ChatCompletionMessage(assistantMessage.getContent(),
- ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls, null));
+ ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls, null, null));
}
else if (message.getMessageType() == MessageType.TOOL) {
ToolResponseMessage toolMessage = (ToolResponseMessage) message;
@@ -460,7 +458,7 @@ else if (message.getMessageType() == MessageType.TOOL) {
return toolMessage.getResponses()
.stream()
.map(tr -> new ChatCompletionMessage(tr.responseData(), ChatCompletionMessage.Role.TOOL, tr.name(),
- tr.id(), null, null))
+ tr.id(), null, null, null))
.toList();
}
else {
@@ -512,6 +510,29 @@ else if (prompt.getOptions() instanceof OpenAiChatOptions) {
return request;
}

+ private MediaContent mapToMediaContent(Media media) {
+ var mimeType = media.getMimeType();
+ if (MimeTypeUtils.parseMimeType("audio/mp3").equals(mimeType)) {
+ return new MediaContent(
+ new MediaContent.InputAudio(fromAudioData(media.getData()), MediaContent.InputAudio.Format.MP3));
+ }
+ if (MimeTypeUtils.parseMimeType("audio/wav").equals(mimeType)) {
+ return new MediaContent(
+ new MediaContent.InputAudio(fromAudioData(media.getData()), MediaContent.InputAudio.Format.WAV));
+ }
+ else {
+ return new MediaContent(
+ new MediaContent.ImageUrl(this.fromMediaData(media.getMimeType(), media.getData())));
+ }
+ }
+
+ private String fromAudioData(Object audioData) {
+ if (audioData instanceof byte[] bytes) {
+ return Base64.getEncoder().encodeToString(bytes);
+ }
+ throw new IllegalArgumentException("Unsupported audio data type: " + audioData.getClass().getSimpleName());
+ }

private String fromMediaData(MimeType mimeType, Object mediaContentData) {
if (mediaContentData instanceof byte[] bytes) {
// Assume the bytes are an image. So, convert the bytes to a base64 encoded
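
For reference, a user message carrying a WAV clip routed through the new mapToMediaContent(..) path is expected to serialize into a content array roughly like the following; the shape follows the OpenAI Chat Completions API and the values are illustrative:

    [
        { "type": "text", "text": "What is this recording about?" },
        {
            "type": "input_audio",
            "input_audio": {
                "data": "<base64-encoded WAV bytes>",
                "format": "wav"
            }
        }
    ]

The base64 payload is produced by the new fromAudioData(..) helper, which accepts only byte[] media data and rejects anything else.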
