Skip to content

Commit

Permalink
Ollama: Pull models automatically at startup
Browse files Browse the repository at this point in the history
* Introduce support for Ollama model auto-pull at startup time
* Enhance support for Ollama model auto-pull at run time
* Update documentation about integrating with Ollama and managing models
* Adopt Builder pattern in Ollama Model classes for better code readability
* Unify Ollama model auto-pull functionality in production and test code
* Improve integration tests for Ollama with Testcontainers
  • Loading branch information
ThomasVitale authored and tzolov committed Oct 18, 2024
1 parent 5dfdd8d commit 8eef6e6
Show file tree
Hide file tree
Showing 31 changed files with 878 additions and 327 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,10 @@
import org.springframework.ai.ollama.api.OllamaApi.Message.Role;
import org.springframework.ai.ollama.api.OllamaApi.Message.ToolCall;
import org.springframework.ai.ollama.api.OllamaApi.Message.ToolCallFunction;
import org.springframework.ai.ollama.api.OllamaModelPuller;
import org.springframework.ai.ollama.management.ModelManagementOptions;
import org.springframework.ai.ollama.management.OllamaModelManager;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.ollama.management.PullModelStrategy;
import org.springframework.ai.ollama.metadata.OllamaChatUsage;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
Expand All @@ -59,10 +61,9 @@
/**
* {@link ChatModel} implementation for {@literal Ollama}. Ollama allows developers to run
* large language models and generate embeddings locally. It supports open-source models
* available on [Ollama AI Library](<a href="https://ollama.ai/library">...</a>). - Llama
* 2 (7B parameters, 3.8GB size) - Mistral (7B parameters, 4.1GB size) Please refer to the
* <a href="https://ollama.ai/">official Ollama website</a> for the most up-to-date
* information on available models.
* available on [Ollama AI Library](<a href="https://ollama.ai/library">...</a>) and on
* Hugging Face. Please refer to the <a href="https://ollama.ai/">official Ollama
* website</a> for the most up-to-date information on available models.
*
* @author Christian Tzolov
* @author luocongqiu
Expand All @@ -73,57 +74,33 @@ public class OllamaChatModel extends AbstractToolCallSupport implements ChatMode

private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention();

/**
* Low-level Ollama API library.
*/
private final OllamaApi chatApi;

/**
* Default options to be used for all chat requests.
*/
private final OllamaOptions defaultOptions;

/**
* Observation registry used for instrumentation.
*/
private final ObservationRegistry observationRegistry;

/**
* Conventions to use for generating observations.
*/
private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;

private final OllamaModelPuller modelPuller;

public OllamaChatModel(OllamaApi ollamaApi) {
this(ollamaApi, OllamaOptions.create().withModel(OllamaOptions.DEFAULT_MODEL));
}

public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions) {
this(ollamaApi, defaultOptions, null);
}
private final OllamaModelManager modelManager;

public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions,
FunctionCallbackContext functionCallbackContext) {
this(ollamaApi, defaultOptions, functionCallbackContext, List.of());
}
private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;

public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions,
FunctionCallbackContext functionCallbackContext, List<FunctionCallback> toolFunctionCallbacks) {
this(ollamaApi, defaultOptions, functionCallbackContext, toolFunctionCallbacks, ObservationRegistry.NOOP);
}

/**
 * Create a new {@link OllamaChatModel}.
 * @param ollamaApi the low-level Ollama API client used for all chat requests
 * @param defaultOptions default options applied to every chat request
 * @param functionCallbackContext context used to resolve function callbacks
 * @param toolFunctionCallbacks tool callbacks registered for function calling
 * @param observationRegistry registry used for observation instrumentation
 * @param modelManagementOptions options controlling model auto-pull and management
 */
public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions,
		FunctionCallbackContext functionCallbackContext, List<FunctionCallback> toolFunctionCallbacks,
		ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions) {
	super(functionCallbackContext, defaultOptions, toolFunctionCallbacks);
	Assert.notNull(ollamaApi, "ollamaApi must not be null");
	Assert.notNull(defaultOptions, "defaultOptions must not be null");
	Assert.notNull(observationRegistry, "observationRegistry must not be null");
	// Bug fix: this assertion previously re-checked observationRegistry while its
	// message referred to modelManagementOptions, leaving the latter unvalidated.
	Assert.notNull(modelManagementOptions, "modelManagementOptions must not be null");
	this.chatApi = ollamaApi;
	this.defaultOptions = defaultOptions;
	this.observationRegistry = observationRegistry;
	this.modelManager = new OllamaModelManager(ollamaApi, modelManagementOptions);
	// Eagerly pull the default model at startup when the configured strategy requires it.
	initializeModelIfEnabled(defaultOptions.getModel(), modelManagementOptions.pullModelStrategy());
}

/**
 * Create a new {@link Builder} for fluently constructing an {@link OllamaChatModel}.
 * @return a fresh builder instance
 */
public static Builder builder() {
return new Builder();
}

@Override
Expand Down Expand Up @@ -324,9 +301,9 @@ else if (message instanceof ToolResponseMessage toolMessage) {
}
OllamaOptions mergedOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, OllamaOptions.class);

mergedOptions.setPullMissingModel(this.defaultOptions.isPullMissingModel());
if (runtimeOptions != null && runtimeOptions.isPullMissingModel() != null) {
mergedOptions.setPullMissingModel(runtimeOptions.isPullMissingModel());
mergedOptions.setPullModelStrategy(this.defaultOptions.getPullModelStrategy());
if (runtimeOptions != null && runtimeOptions.getPullModelStrategy() != null) {
mergedOptions.setPullModelStrategy(runtimeOptions.getPullModelStrategy());
}

// Override the model.
Expand All @@ -353,9 +330,7 @@ else if (message instanceof ToolResponseMessage toolMessage) {
requestBuilder.withTools(this.getFunctionTools(functionsForThisRequest));
}

if (mergedOptions.isPullMissingModel()) {
this.modelPuller.pullModel(mergedOptions.getModel(), true);
}
initializeModelIfEnabled(mergedOptions.getModel(), mergedOptions.getPullModelStrategy());

return requestBuilder.build();
}
Expand Down Expand Up @@ -400,6 +375,15 @@ public ChatOptions getDefaultOptions() {
return OllamaOptions.fromOptions(this.defaultOptions);
}

/**
 * Pull the given model into Ollama based on the specified strategy.
 * <p>
 * No action is taken when the strategy is {@code null} or
 * {@link PullModelStrategy#NEVER}.
 * @param model the name of the model to (potentially) pull
 * @param pullModelStrategy the strategy driving whether the model is pulled
 */
private void initializeModelIfEnabled(String model, PullModelStrategy pullModelStrategy) {
	// Guard against a null strategy (e.g. options merged from unset values) instead
	// of forwarding null to the model manager, which would risk an NPE downstream.
	if (pullModelStrategy != null && !PullModelStrategy.NEVER.equals(pullModelStrategy)) {
		this.modelManager.pullModel(model, pullModelStrategy);
	}
}

/**
* Use the provided convention for reporting observation data
* @param observationConvention The provided convention
Expand All @@ -409,4 +393,58 @@ public void setObservationConvention(ChatModelObservationConvention observationC
this.observationConvention = observationConvention;
}

/**
 * Fluent builder for {@link OllamaChatModel} instances.
 * <p>
 * All properties have sensible defaults except {@code ollamaApi}, which must be
 * supplied before calling {@link #build()}.
 */
public static class Builder {

	private OllamaApi ollamaApi;

	private OllamaOptions defaultOptions = OllamaOptions.create().withModel(OllamaOptions.DEFAULT_MODEL);

	private FunctionCallbackContext functionCallbackContext;

	private List<FunctionCallback> toolFunctionCallbacks = List.of();

	private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;

	private ModelManagementOptions modelManagementOptions = ModelManagementOptions.defaults();

	private Builder() {
	}

	/**
	 * Set the low-level Ollama API client (required).
	 */
	public Builder withOllamaApi(OllamaApi ollamaApi) {
		this.ollamaApi = ollamaApi;
		return this;
	}

	/**
	 * Set the default options applied to every chat request.
	 */
	public Builder withDefaultOptions(OllamaOptions defaultOptions) {
		this.defaultOptions = defaultOptions;
		return this;
	}

	/**
	 * Set the context used to resolve function callbacks.
	 */
	public Builder withFunctionCallbackContext(FunctionCallbackContext functionCallbackContext) {
		this.functionCallbackContext = functionCallbackContext;
		return this;
	}

	/**
	 * Set the tool callbacks registered for function calling.
	 */
	public Builder withToolFunctionCallbacks(List<FunctionCallback> toolFunctionCallbacks) {
		this.toolFunctionCallbacks = toolFunctionCallbacks;
		return this;
	}

	/**
	 * Set the registry used for observation instrumentation.
	 */
	public Builder withObservationRegistry(ObservationRegistry observationRegistry) {
		this.observationRegistry = observationRegistry;
		return this;
	}

	/**
	 * Set the options controlling model auto-pull and management.
	 */
	public Builder withModelManagementOptions(ModelManagementOptions modelManagementOptions) {
		this.modelManagementOptions = modelManagementOptions;
		return this;
	}

	/**
	 * Build the configured {@link OllamaChatModel}; argument validation happens in
	 * the model's constructor.
	 */
	public OllamaChatModel build() {
		return new OllamaChatModel(this.ollamaApi, this.defaultOptions, this.functionCallbackContext,
				this.toolFunctionCallbacks, this.observationRegistry, this.modelManagementOptions);
	}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import java.util.regex.Pattern;

import io.micrometer.observation.ObservationRegistry;
import org.springframework.ai.chat.metadata.EmptyUsage;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.*;
import org.springframework.ai.embedding.observation.DefaultEmbeddingModelObservationConvention;
Expand All @@ -32,24 +31,20 @@
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaApi.EmbeddingsResponse;
import org.springframework.ai.ollama.api.OllamaModelPuller;
import org.springframework.ai.ollama.management.ModelManagementOptions;
import org.springframework.ai.ollama.management.OllamaModelManager;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.ollama.management.PullModelStrategy;
import org.springframework.ai.ollama.metadata.OllamaEmbeddingUsage;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

/**
* {@link EmbeddingModel} implementation for {@literal Ollama}.
*
* Ollama allows developers to run large language models and generate embeddings locally.
* It supports open-source models available on [Ollama AI
* Library](https://ollama.ai/library).
*
* Examples of models supported: - Llama 2 (7B parameters, 3.8GB size) - Mistral (7B
* parameters, 4.1GB size)
*
* Please refer to the <a href="https://ollama.ai/">official Ollama website</a> for the
* most up-to-date information on available models.
* {@link EmbeddingModel} implementation for {@literal Ollama}. Ollama allows developers
* to run large language models and generate embeddings locally. It supports open-source
* models available on [Ollama AI Library](<a href="https://ollama.ai/library">...</a>)
* and on Hugging Face. Please refer to the <a href="https://ollama.ai/">official Ollama
* website</a> for the most up-to-date information on available models.
*
* @author Christian Tzolov
* @author Thomas Vitale
Expand All @@ -61,41 +56,31 @@ public class OllamaEmbeddingModel extends AbstractEmbeddingModel {

private final OllamaApi ollamaApi;

/**
* Default options to be used for all chat requests.
*/
private final OllamaOptions defaultOptions;

/**
* Observation registry used for instrumentation.
*/
private final ObservationRegistry observationRegistry;

/**
* Conventions to use for generating observations.
*/
private EmbeddingModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;

private final OllamaModelPuller modelPuller;
private final OllamaModelManager modelManager;

public OllamaEmbeddingModel(OllamaApi ollamaApi) {
this(ollamaApi, OllamaOptions.create().withModel(OllamaOptions.DEFAULT_MODEL));
}

public OllamaEmbeddingModel(OllamaApi ollamaApi, OllamaOptions defaultOptions) {
this(ollamaApi, defaultOptions, ObservationRegistry.NOOP);
}
private EmbeddingModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;

/**
 * Create a new {@link OllamaEmbeddingModel}.
 * @param ollamaApi the low-level Ollama API client used for all embedding requests
 * @param defaultOptions default options applied to every embedding request
 * @param observationRegistry registry used for observation instrumentation
 * @param modelManagementOptions options controlling model auto-pull and management
 */
public OllamaEmbeddingModel(OllamaApi ollamaApi, OllamaOptions defaultOptions,
		ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions) {
	Assert.notNull(ollamaApi, "ollamaApi must not be null");
	// Message aligned with OllamaChatModel's equivalent check (was "options must not be null").
	Assert.notNull(defaultOptions, "defaultOptions must not be null");
	Assert.notNull(observationRegistry, "observationRegistry must not be null");
	Assert.notNull(modelManagementOptions, "modelManagementOptions must not be null");

	this.ollamaApi = ollamaApi;
	this.defaultOptions = defaultOptions;
	this.observationRegistry = observationRegistry;
	this.modelManager = new OllamaModelManager(ollamaApi, modelManagementOptions);

	// Eagerly pull the default model at startup when the configured strategy requires it.
	initializeModelIfEnabled(defaultOptions.getModel(), modelManagementOptions.pullModelStrategy());
}

/**
 * Create a new {@link Builder} for fluently constructing an {@link OllamaEmbeddingModel}.
 * @return a fresh builder instance
 */
public static Builder builder() {
return new Builder();
}

@Override
Expand Down Expand Up @@ -153,9 +138,9 @@ OllamaApi.EmbeddingsRequest ollamaEmbeddingRequest(List<String> inputContent, Em

OllamaOptions mergedOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, OllamaOptions.class);

mergedOptions.setPullMissingModel(this.defaultOptions.isPullMissingModel());
if (runtimeOptions != null && runtimeOptions.isPullMissingModel() != null) {
mergedOptions.setPullMissingModel(runtimeOptions.isPullMissingModel());
mergedOptions.setPullModelStrategy(this.defaultOptions.getPullModelStrategy());
if (runtimeOptions != null && runtimeOptions.getPullModelStrategy() != null) {
mergedOptions.setPullModelStrategy(runtimeOptions.getPullModelStrategy());
}

// Override the model.
Expand All @@ -164,9 +149,7 @@ OllamaApi.EmbeddingsRequest ollamaEmbeddingRequest(List<String> inputContent, Em
}
String model = mergedOptions.getModel();

if (mergedOptions.isPullMissingModel()) {
this.modelPuller.pullModel(model, true);
}
initializeModelIfEnabled(mergedOptions.getModel(), mergedOptions.getPullModelStrategy());

return new OllamaApi.EmbeddingsRequest(model, inputContent, DurationParser.parse(mergedOptions.getKeepAlive()),
OllamaOptions.filterNonSupportedFields(mergedOptions.toMap()), mergedOptions.getTruncate());
Expand All @@ -176,6 +159,15 @@ private EmbeddingOptions buildRequestOptions(OllamaApi.EmbeddingsRequest request
return EmbeddingOptionsBuilder.builder().withModel(request.model()).build();
}

/**
 * Pull the given model into Ollama based on the specified strategy.
 * <p>
 * No action is taken when the strategy is {@code null} or
 * {@link PullModelStrategy#NEVER}.
 * @param model the name of the model to (potentially) pull
 * @param pullModelStrategy the strategy driving whether the model is pulled
 */
private void initializeModelIfEnabled(String model, PullModelStrategy pullModelStrategy) {
	// Guard against a null strategy (e.g. options merged from unset values) instead
	// of forwarding null to the model manager, which would risk an NPE downstream.
	if (pullModelStrategy != null && !PullModelStrategy.NEVER.equals(pullModelStrategy)) {
		this.modelManager.pullModel(model, pullModelStrategy);
	}
}

/**
* Use the provided convention for reporting observation data
* @param observationConvention The provided convention
Expand Down Expand Up @@ -216,4 +208,43 @@ public static Duration parse(String input) {

}

/**
 * Fluent builder for {@link OllamaEmbeddingModel} instances.
 * <p>
 * All properties have sensible defaults except {@code ollamaApi}, which must be
 * supplied before calling {@link #build()}.
 */
public static class Builder {

	private OllamaApi ollamaApi;

	private OllamaOptions defaultOptions = OllamaOptions.create().withModel(OllamaOptions.DEFAULT_MODEL);

	private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;

	private ModelManagementOptions modelManagementOptions = ModelManagementOptions.defaults();

	private Builder() {
	}

	/**
	 * Set the low-level Ollama API client (required).
	 */
	public Builder withOllamaApi(OllamaApi ollamaApi) {
		this.ollamaApi = ollamaApi;
		return this;
	}

	/**
	 * Set the default options applied to every embedding request.
	 */
	public Builder withDefaultOptions(OllamaOptions defaultOptions) {
		this.defaultOptions = defaultOptions;
		return this;
	}

	/**
	 * Set the registry used for observation instrumentation.
	 */
	public Builder withObservationRegistry(ObservationRegistry observationRegistry) {
		this.observationRegistry = observationRegistry;
		return this;
	}

	/**
	 * Set the options controlling model auto-pull and management.
	 */
	public Builder withModelManagementOptions(ModelManagementOptions modelManagementOptions) {
		this.modelManagementOptions = modelManagementOptions;
		return this;
	}

	/**
	 * Build the configured {@link OllamaEmbeddingModel}; argument validation happens
	 * in the model's constructor.
	 */
	public OllamaEmbeddingModel build() {
		return new OllamaEmbeddingModel(this.ollamaApi, this.defaultOptions, this.observationRegistry,
				this.modelManagementOptions);
	}

}

}
Loading

0 comments on commit 8eef6e6

Please sign in to comment.