Skip to content

Commit

Permalink
#30462: Fix for cycling calls when non SystemHost AppConfig objects w…
Browse files Browse the repository at this point in the history
…ere instantiated (#30463)
  • Loading branch information
victoralfaro-dotcms authored Oct 25, 2024
1 parent 14430e3 commit cad5532
Show file tree
Hide file tree
Showing 13 changed files with 81 additions and 82 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,10 @@ public void summarizeStream(final CompletionsForm summaryRequest, final OutputSt

@Override
public JSONObject raw(final JSONObject json, final String userId) {
AppConfig.debugLogger(this.getClass(), () -> "OpenAI request:" + json.toString(2));
config.debugLogger(this.getClass(), () -> "OpenAI request:" + json.toString(2));

final String response = sendRequest(config, json, userId).getResponse();
AppConfig.debugLogger(this.getClass(), () -> "OpenAI response:" + response);
config.debugLogger(this.getClass(), () -> "OpenAI response:" + response);

return new JSONObject(response);
}
Expand Down
9 changes: 4 additions & 5 deletions dotCMS/src/main/java/com/dotcms/ai/api/EmbeddingsAPIImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@
import java.util.Optional;
import java.util.stream.Collectors;

import static com.dotcms.ai.app.AppConfig.debugLogger;
import static com.liferay.util.StringPool.BLANK;

/**
Expand Down Expand Up @@ -335,7 +334,7 @@ public Tuple2<Integer, List<Float>> pullOrGenerateEmbeddings(final String conten
.map(encoding -> encoding.encode(content))
.orElse(List.of());
if (tokens.isEmpty()) {
debugLogger(this.getClass(), () -> String.format("No tokens for content ID '%s' were encoded: %s", contentId, content));
config.debugLogger(this.getClass(), () -> String.format("No tokens for content ID '%s' were encoded: %s", contentId, content));
return Tuple.of(0, List.of());
}

Expand Down Expand Up @@ -432,15 +431,15 @@ private List<Float> sendTokensToOpenAI(final String contentId,
final JSONObject json = new JSONObject();
json.put(AiKeys.MODEL, config.getEmbeddingsModel().getCurrentModel());
json.put(AiKeys.INPUT, tokens);
debugLogger(this.getClass(), () -> String.format("Content tokens for content ID '%s': %s", contentId, tokens));
config.debugLogger(this.getClass(), () -> String.format("Content tokens for content ID '%s': %s", contentId, tokens));
final String responseString = AIProxyClient.get()
.callToAI(JSONObjectAIRequest.quickEmbeddings(config, json, userId))
.getResponse();
debugLogger(this.getClass(), () -> String.format("OpenAI Response for content ID '%s': %s",
config.debugLogger(this.getClass(), () -> String.format("OpenAI Response for content ID '%s': %s",
contentId, responseString.replace("\n", BLANK)));
final JSONObject jsonResponse = Try.of(() -> new JSONObject(responseString)).getOrElseThrow(e -> {
Logger.error(this, "OpenAI Response String is not a valid JSON", e);
debugLogger(this.getClass(), () -> String.format("Invalid JSON Response: %s", responseString));
config.debugLogger(this.getClass(), () -> String.format("Invalid JSON Response: %s", responseString));
return new DotCorruptedDataException(e);
});
if (jsonResponse.containsKey(AiKeys.ERROR)) {
Expand Down
6 changes: 3 additions & 3 deletions dotCMS/src/main/java/com/dotcms/ai/api/EmbeddingsRunner.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.dotcms.ai.api;

import com.dotcms.ai.app.AppConfig;
import com.dotcms.ai.app.AppKeys;
import com.dotcms.ai.app.ConfigService;
import com.dotcms.ai.db.EmbeddingsDTO;
Expand All @@ -17,7 +18,6 @@
import java.util.List;
import java.util.Locale;

import static com.dotcms.ai.app.AppConfig.debugLogger;
import static com.liferay.util.StringPool.SPACE;

/**
Expand Down Expand Up @@ -86,9 +86,9 @@ public void run() {
}

if (buffer.toString().split("\\s+").length > 0) {
debugLogger(this.getClass(), () -> String.format("Saving embeddings for contentlet ID '%s'", this.contentlet.getIdentifier()));
AppConfig.debugLogger(embeddingsAPI.config, this.getClass(), () -> String.format("Saving embeddings for contentlet ID '%s'", this.contentlet.getIdentifier()));
this.saveEmbedding(buffer.toString());
debugLogger(this.getClass(), () -> String.format("Embeddings for contentlet ID '%s' were saved", this.contentlet.getIdentifier()));
AppConfig.debugLogger(embeddingsAPI.config, this.getClass(), () -> String.format("Embeddings for contentlet ID '%s' were saved", this.contentlet.getIdentifier()));
}
} catch (final Exception e) {
final String errorMsg = String.format("Failed to generate embeddings for contentlet ID " +
Expand Down
17 changes: 9 additions & 8 deletions dotCMS/src/main/java/com/dotcms/ai/app/AIModels.java
Original file line number Diff line number Diff line change
Expand Up @@ -59,19 +59,19 @@ public static AIModels get() {
return INSTANCE.get();
}

private static CircuitBreakerUrl.Response<OpenAIModels> fetchOpenAIModels(final String apiKey) {
private static CircuitBreakerUrl.Response<OpenAIModels> fetchOpenAIModels(final AppConfig appConfig) {
final CircuitBreakerUrl.Response<OpenAIModels> response = CircuitBreakerUrl.builder()
.setMethod(CircuitBreakerUrl.Method.GET)
.setUrl(AI_MODELS_API_URL)
.setTimeout(AI_MODELS_FETCH_TIMEOUT)
.setTryAgainAttempts(AI_MODELS_FETCH_ATTEMPTS)
.setHeaders(CircuitBreakerUrl.authHeaders("Bearer " + apiKey))
.setHeaders(CircuitBreakerUrl.authHeaders("Bearer " + appConfig.getApiKey()))
.setThrowWhenNot2xx(true)
.build()
.doResponse(OpenAIModels.class);

if (!CircuitBreakerUrl.isSuccessResponse(response)) {
AppConfig.debugLogger(
appConfig.debugLogger(
AIModels.class,
() -> String.format(
"Error fetching OpenAI supported models from [%s] (status code: [%d])",
Expand All @@ -98,10 +98,11 @@ private AIModels() {
* are already loaded, this method does nothing. It also maps model names to their
* corresponding AIModel instances.
*
* @param host the host for which the models are being loaded
* @param appConfig app config
* @param loading the list of AI models to load
*/
public void loadModels(final String host, final List<AIModel> loading) {
public void loadModels(final AppConfig appConfig, final List<AIModel> loading) {
final String host = appConfig.getHost();
final List<Tuple2<AIModelType, AIModel>> added = internalModels.putIfAbsent(
host,
loading.stream()
Expand All @@ -112,7 +113,7 @@ public void loadModels(final String host, final List<AIModel> loading) {
.forEach(model -> {
final Tuple3<String, Model, AIModelType> key = Tuple.of(host, model, aiModel.getType());
if (modelsByName.containsKey(key)) {
AppConfig.debugLogger(
appConfig.debugLogger(
getClass(),
() -> String.format(
"Model [%s] already exists for host [%s], ignoring it",
Expand Down Expand Up @@ -230,11 +231,11 @@ public Set<String> getOrPullSupportedModels(final AppConfig appConfig) {
}

if (!appConfig.isEnabled()) {
AppConfig.debugLogger(getClass(), () -> "dotAI is not enabled, returning empty set of supported models");
appConfig.debugLogger(getClass(), () -> "dotAI is not enabled, returning empty set of supported models");
return Set.of();
}

final CircuitBreakerUrl.Response<OpenAIModels> response = fetchOpenAIModels(appConfig.getApiKey());
final CircuitBreakerUrl.Response<OpenAIModels> response = fetchOpenAIModels(appConfig);
if (Objects.nonNull(response.getResponse().getError())) {
throw new DotRuntimeException("Found error in AI response: " + response.getResponse().getError().getMessage());
}
Expand Down
41 changes: 15 additions & 26 deletions dotCMS/src/main/java/com/dotcms/ai/app/AppConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.dotcms.ai.domain.Model;
import com.dotcms.security.apps.Secret;
import com.dotmarketing.util.Config;
import com.dotmarketing.util.Logger;
import com.dotmarketing.util.UtilMethods;
import com.liferay.util.StringPool;
Expand All @@ -12,9 +13,7 @@
import java.io.Serializable;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
Expand All @@ -29,8 +28,7 @@ public class AppConfig implements Serializable {
private static final String AI_API_URL_KEY = "AI_API_URL";
private static final String AI_IMAGE_API_URL_KEY = "AI_IMAGE_API_URL";
private static final String AI_EMBEDDINGS_API_URL_KEY = "AI_EMBEDDINGS_API_URL";
private static final String SYSTEM_HOST = "System Host";
private static final AtomicReference<AppConfig> SYSTEM_HOST_CONFIG = new AtomicReference<>();
private static final String AI_DEBUG_LOGGING_KEY = "AI_DEBUG_LOGGING";

public static final Pattern SPLITTER = Pattern.compile("\\s?,\\s?");

Expand All @@ -51,9 +49,6 @@ public class AppConfig implements Serializable {

public AppConfig(final String host, final Map<String, Secret> secrets) {
this.host = host;
if (SYSTEM_HOST.equalsIgnoreCase(host)) {
setSystemHostConfig(this);
}

final AIAppUtil aiAppUtil = AIAppUtil.get();
apiKey = aiAppUtil.discoverSecret(secrets, AppKeys.API_KEY);
Expand All @@ -63,7 +58,7 @@ public AppConfig(final String host, final Map<String, Secret> secrets) {

if (!secrets.isEmpty() || isEnabled()) {
AIModels.get().loadModels(
this.host,
this,
List.of(
aiAppUtil.createTextModel(secrets),
aiAppUtil.createImageModel(secrets),
Expand All @@ -85,35 +80,25 @@ public AppConfig(final String host, final Map<String, Secret> secrets) {
Logger.debug(this, this::toString);
}

/**
* Retrieves the system host configuration.
*
* @return the system host configuration
*/
public static AppConfig getSystemHostConfig() {
if (Objects.isNull(SYSTEM_HOST_CONFIG.get())) {
setSystemHostConfig(ConfigService.INSTANCE.config());
}
return SYSTEM_HOST_CONFIG.get();
}

/**
* Prints a specific error message to the log, based on the {@link AppKeys#DEBUG_LOGGING}
* property instead of the usual Log4j configuration.
*
* @param appConfig The {#link AppConfig} to be used when logging.
* @param clazz The {@link Class} to log the message for.
* @param message The {@link Supplier} with the message to log.
*/
public static void debugLogger(final Class<?> clazz, final Supplier<String> message) {
if (getSystemHostConfig().getConfigBoolean(AppKeys.DEBUG_LOGGING)) {
public static void debugLogger(final AppConfig appConfig, final Class<?> clazz, final Supplier<String> message) {
if (appConfig == null) {
Logger.debug(clazz, message);
return;
}
if (appConfig.getConfigBoolean(AppKeys.DEBUG_LOGGING)
|| Config.getBooleanProperty(AI_DEBUG_LOGGING_KEY, false)) {
Logger.info(clazz, message.get());
}
}

public static void setSystemHostConfig(final AppConfig systemHostConfig) {
AppConfig.SYSTEM_HOST_CONFIG.set(systemHostConfig);
}

/**
* Retrieves the host.
*
Expand Down Expand Up @@ -318,6 +303,10 @@ public boolean isEnabled() {
return Stream.of(apiUrl, apiImageUrl, apiEmbeddingsUrl, apiKey).allMatch(StringUtils::isNotBlank);
}

public void debugLogger(final Class<?> clazz, final Supplier<String> message) {
debugLogger(this, clazz, message);
}

@Override
public String toString() {
return "AppConfig{\n" +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ public class AIModelFallbackStrategy implements AIClientStrategy {
*/
@Override
public AIResponseData applyStrategy(final AIClient client,
final AIResponseEvaluator handler,
final AIRequest<? extends Serializable> request,
final OutputStream incoming) {
final AIResponseEvaluator handler,
final AIRequest<? extends Serializable> request,
final OutputStream incoming) {
final JSONObjectAIRequest jsonRequest = AIClient.useRequestOrThrow(request);
final Tuple2<AIModel, Model> modelTuple = resolveModel(jsonRequest);

Expand Down Expand Up @@ -95,9 +95,9 @@ private static boolean isSameAsFirst(final Model firstAttempt, final Model model
return firstAttempt.equals(model);
}

private static boolean isOperational(final Model model) {
private static boolean isOperational(final Model model, final AppConfig config) {
if (!model.isOperational()) {
AppConfig.debugLogger(
config.debugLogger(
AIModelFallbackStrategy.class,
() -> String.format("Model [%s] is not operational. Skipping.", model.getName()));
return false;
Expand All @@ -117,7 +117,7 @@ private static AIResponseData doSend(final AIClient client,
}

private static void notifyFailure(final AIModel aiModel, final JSONObjectAIRequest request) {
AIAppValidator.get().validateModelsUsage(aiModel, request.getUserId());
AIAppValidator.get().validateModelsUsage(aiModel, request);
}

private static void handleFailure(final Tuple2<AIModel, Model> modelTuple,
Expand All @@ -127,7 +127,7 @@ private static void handleFailure(final Tuple2<AIModel, Model> modelTuple,
final Model model = modelTuple._2;

if (!responseData.getStatus().doesNeedToThrow()) {
AppConfig.debugLogger(
request.getConfig().debugLogger(
AIModelFallbackStrategy.class,
() -> String.format(
"Model [%s] failed then setting its status to [%s].",
Expand All @@ -138,7 +138,7 @@ private static void handleFailure(final Tuple2<AIModel, Model> modelTuple,

if (model.getIndex() == aiModel.getModels().size() - 1) {
aiModel.setCurrentModelIndex(AIModel.NOOP_INDEX);
AppConfig.debugLogger(
request.getConfig().debugLogger(
AIModelFallbackStrategy.class,
() -> String.format(
"Model [%s] is the last one. Cannot fallback anymore.",
Expand Down Expand Up @@ -167,7 +167,7 @@ private static AIResponseData sendRequest(final AIClient client,
if (!responseData.isSuccess()) {
if (responseData.getStatus().doesNeedToThrow()) {
if (!modelTuple._1.isOperational()) {
AppConfig.debugLogger(
request.getConfig().debugLogger(
AIModelFallbackStrategy.class,
() -> String.format(
"All models from type [%s] are not operational. Throwing exception.",
Expand All @@ -181,11 +181,11 @@ private static AIResponseData sendRequest(final AIClient client,
}

if (responseData.isSuccess()) {
AppConfig.debugLogger(
request.getConfig().debugLogger(
AIModelFallbackStrategy.class,
() -> String.format("Model [%s] succeeded. No need to fallback.", modelTuple._2.getName()));
} else {
logFailure(modelTuple, responseData);
logFailure(modelTuple, request, responseData);

handleFailure(modelTuple, request, responseData);
}
Expand All @@ -198,18 +198,20 @@ private static AIResponseData sendRequest(final AIClient client,
return responseData;
}

private static void logFailure(final Tuple2<AIModel, Model> modelTuple, final AIResponseData responseData) {
private static void logFailure(final Tuple2<AIModel, Model> modelTuple,
final JSONObjectAIRequest request,
final AIResponseData responseData) {
Optional
.ofNullable(responseData.getResponse())
.ifPresentOrElse(
response -> AppConfig.debugLogger(
response -> request.getConfig().debugLogger(
AIModelFallbackStrategy.class,
() -> String.format(
"Model [%s] failed with response:%s%sTrying next model.",
modelTuple._2.getName(),
System.lineSeparator(),
response)),
() -> AppConfig.debugLogger(
() -> request.getConfig().debugLogger(
AIModelFallbackStrategy.class,
() -> String.format(
"Model [%s] failed with error: [%s]. Trying next model.",
Expand All @@ -229,7 +231,7 @@ private static AIResponseData runFallbacks(final AIClient client,
final OutputStream output,
final Tuple2<AIModel, Model> modelTuple) {
for(final Model model : modelTuple._1.getModels()) {
if (isSameAsFirst(modelTuple._2, model) || !isOperational(model)) {
if (isSameAsFirst(modelTuple._2, model) || !isOperational(model, request.getConfig())) {
continue;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ public <T extends Serializable> void sendRequest(final AIRequest<T> request, fin
final JSONObjectAIRequest jsonRequest = AIClient.useRequestOrThrow(request);
final AppConfig appConfig = jsonRequest.getConfig();

AppConfig.debugLogger(
request.getConfig().debugLogger(
OpenAIClient.class,
() -> String.format(
"Posting to [%s] with method [%s]%s with app config:%s%s the payload: %s",
Expand All @@ -94,7 +94,7 @@ public <T extends Serializable> void sendRequest(final AIRequest<T> request, fin
jsonRequest.payloadToString()));

if (!appConfig.isEnabled()) {
AppConfig.debugLogger(OpenAIClient.class, () -> "App dotAI is not enabled and will not send request.");
request.getConfig().debugLogger(OpenAIClient.class, () -> "App dotAI is not enabled and will not send request.");
throw new DotAIAppConfigDisabledException("App dotAI config without API urls or API key");
}

Expand All @@ -106,7 +106,7 @@ public <T extends Serializable> void sendRequest(final AIRequest<T> request, fin
final AIModel aiModel = modelTuple._1;

if (!modelTuple._2.isOperational()) {
AppConfig.debugLogger(
request.getConfig().debugLogger(
getClass(),
() -> String.format("Resolved model [%s] is not operational, avoiding its usage", modelName));
throw new DotAIModelNotOperationalException(String.format("Model [%s] is not operational", modelName));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ private AppConfig getAppConfig(final String hostId) {

final AppConfig appConfig = ConfigService.INSTANCE.config(host);
if (!appConfig.isEnabled()) {
AppConfig.debugLogger(
appConfig.debugLogger(
getClass(),
() -> "dotAI is not enabled since no API urls or API key found in app config");
throw new DotAIAppConfigDisabledException("App dotAI config without API urls or API key");
Expand Down
Loading

0 comments on commit cad5532

Please sign in to comment.