Skip to content

Commit

Permalink
Simplify the implementation of ChatClient-based interfaces.
Browse files Browse the repository at this point in the history
Provide sensible defaults for overloaded methods in PromptUserSpec, PromptSystemSpec, AdvisorSpec, CallResponseSpec, ChatClientRequestSpec, and Builder interfaces.

Closes #1663
  • Loading branch information
jxblum committed Nov 4, 2024
1 parent 08a007f commit be38f10
Showing 1 changed file with 103 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@

package org.springframework.ai.chat.client;

import java.lang.reflect.Type;
import java.net.URL;
import java.nio.charset.Charset;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
Expand All @@ -37,21 +39,24 @@
import org.springframework.ai.model.Media;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.core.io.ByteArrayResource;
import org.springframework.core.io.Resource;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.MimeType;

/**
* Client to perform stateless requests to an AI Model, using a fluent API.
*
* Client used to perform stateless requests to an AI Model, using a fluent API.
* <p/>
* Use {@link ChatClient#builder(ChatModel)} to prepare an instance.
*
* @author Mark Pollack
* @author Christian Tzolov
* @author Josh Long
* @author Arjen Poutsma
* @author Thomas Vitale
* @author John Blum
* @see ChatModel
* @since 1.0.0
*/
public interface ChatClient {
Expand Down Expand Up @@ -97,15 +102,22 @@ static Builder builder(ChatModel chatModel, ObservationRegistry observationRegis

interface PromptUserSpec {

PromptUserSpec text(String text);
default PromptUserSpec text(String text) {
Charset defaultCharset = Charset.defaultCharset();
return text(new ByteArrayResource(text.getBytes(defaultCharset)), defaultCharset);
}

PromptUserSpec text(Resource text, Charset charset);
default PromptUserSpec text(Resource text) {
return text(text, Charset.defaultCharset());
}

PromptUserSpec text(Resource text);
PromptUserSpec text(Resource text, Charset charset);

PromptUserSpec params(Map<String, Object> p);
default PromptUserSpec param(String key, Object value) {
return params(Map.of(key, value));
}

PromptUserSpec param(String k, Object v);
PromptUserSpec params(Map<String, Object> params);

PromptUserSpec media(Media... media);

Expand All @@ -117,25 +129,36 @@ interface PromptUserSpec {

interface PromptSystemSpec {

PromptSystemSpec text(String text);
default PromptSystemSpec text(String text) {
Charset defaultCharset = Charset.defaultCharset();
return text(new ByteArrayResource(text.getBytes(defaultCharset)), defaultCharset);
}

PromptSystemSpec text(Resource text, Charset charset);
default PromptSystemSpec text(Resource text) {
return text(text, Charset.defaultCharset());
}

PromptSystemSpec text(Resource text);
PromptSystemSpec text(Resource text, Charset charset);

PromptSystemSpec params(Map<String, Object> p);
default PromptSystemSpec param(String key, Object value) {
return params(Map.of(key, value));
}

PromptSystemSpec param(String k, Object v);
PromptSystemSpec params(Map<String, Object> params);

}

interface AdvisorSpec {

AdvisorSpec param(String k, Object v);
default AdvisorSpec param(String key, Object value) {
return params(Map.of(key, value));
}

AdvisorSpec params(Map<String, Object> p);
AdvisorSpec params(Map<String, Object> params);

AdvisorSpec advisors(Advisor... advisors);
default AdvisorSpec advisors(Advisor... advisors) {
return advisors(Arrays.asList(advisors));
}

AdvisorSpec advisors(List<Advisor> advisors);

Expand All @@ -144,21 +167,39 @@ interface AdvisorSpec {
interface CallResponseSpec {

@Nullable
<T> T entity(ParameterizedTypeReference<T> type);
default <T> T entity(Class<T> type) {

return entity(new ParameterizedTypeReference<>() {

@Override
public Type getType() {
return type;
}
});
}

@Nullable
<T> T entity(StructuredOutputConverter<T> structuredOutputConverter);
<T> T entity(ParameterizedTypeReference<T> type);

@Nullable
<T> T entity(Class<T> type);
<T> T entity(StructuredOutputConverter<T> structuredOutputConverter);

@Nullable
ChatResponse chatResponse();

@Nullable
String content();

<T> ResponseEntity<ChatResponse, T> responseEntity(Class<T> type);
default <T> ResponseEntity<ChatResponse, T> responseEntity(Class<T> type) {

return responseEntity(new ParameterizedTypeReference<T>() {

@Override
public Type getType() {
return type;
}
});
}

<T> ResponseEntity<ChatResponse, T> responseEntity(ParameterizedTypeReference<T> type);

Expand Down Expand Up @@ -202,11 +243,15 @@ interface ChatClientRequestSpec {

ChatClientRequestSpec advisors(Consumer<AdvisorSpec> consumer);

ChatClientRequestSpec advisors(Advisor... advisors);
default ChatClientRequestSpec advisors(Advisor... advisors) {
return advisors(Arrays.asList(advisors));
}

ChatClientRequestSpec advisors(List<Advisor> advisors);

ChatClientRequestSpec messages(Message... messages);
default ChatClientRequestSpec messages(Message... messages) {
return messages(Arrays.asList(messages));
}

ChatClientRequestSpec messages(List<Message> messages);

Expand All @@ -227,19 +272,29 @@ <I, O> ChatClientRequestSpec function(String name, String description, Class<I>

ChatClientRequestSpec toolContext(Map<String, Object> toolContext);

ChatClientRequestSpec system(String text);
default ChatClientRequestSpec system(String text) {
Charset defaultCharset = Charset.defaultCharset();
return system(new ByteArrayResource(text.getBytes(defaultCharset)), defaultCharset);
}

ChatClientRequestSpec system(Resource textResource, Charset charset);
default ChatClientRequestSpec system(Resource text) {
return system(text, Charset.defaultCharset());
}

ChatClientRequestSpec system(Resource text);
ChatClientRequestSpec system(Resource textResource, Charset charset);

ChatClientRequestSpec system(Consumer<PromptSystemSpec> consumer);

ChatClientRequestSpec user(String text);
default ChatClientRequestSpec user(String text) {
Charset defaultCharset = Charset.defaultCharset();
return user(new ByteArrayResource(text.getBytes(defaultCharset)), defaultCharset);
}

ChatClientRequestSpec user(Resource text, Charset charset);
default ChatClientRequestSpec user(Resource text) {
return user(text, Charset.defaultCharset());
}

ChatClientRequestSpec user(Resource text);
ChatClientRequestSpec user(Resource text, Charset charset);

ChatClientRequestSpec user(Consumer<PromptUserSpec> consumer);

Expand All @@ -254,27 +309,39 @@ <I, O> ChatClientRequestSpec function(String name, String description, Class<I>
*/
interface Builder {

Builder defaultAdvisors(Advisor... advisor);

Builder defaultAdvisors(Consumer<AdvisorSpec> advisorSpecConsumer);
default Builder defaultAdvisors(Advisor... advisors) {
return defaultAdvisors(Arrays.asList(advisors));
}

Builder defaultAdvisors(List<Advisor> advisors);

Builder defaultAdvisors(Consumer<AdvisorSpec> advisorSpecConsumer);

Builder defaultOptions(ChatOptions chatOptions);

Builder defaultUser(String text);
default Builder defaultUser(String text) {
Charset defaulCharset = Charset.defaultCharset();
return defaultUser(new ByteArrayResource(text.getBytes(defaulCharset)), defaulCharset);
}

Builder defaultUser(Resource text, Charset charset);
default Builder defaultUser(Resource text) {
return defaultUser(text, Charset.defaultCharset());
}

Builder defaultUser(Resource text);
Builder defaultUser(Resource text, Charset charset);

Builder defaultUser(Consumer<PromptUserSpec> userSpecConsumer);

Builder defaultSystem(String text);
default Builder defaultSystem(String text) {
Charset defaultCharset = Charset.defaultCharset();
return defaultSystem(new ByteArrayResource(text.getBytes(defaultCharset)), defaultCharset);
}

Builder defaultSystem(Resource text, Charset charset);
default Builder defaultSystem(Resource text) {
return defaultSystem(text, Charset.defaultCharset());
}

Builder defaultSystem(Resource text);
Builder defaultSystem(Resource text, Charset charset);

Builder defaultSystem(Consumer<PromptSystemSpec> systemSpecConsumer);

Expand Down

0 comments on commit be38f10

Please sign in to comment.