Skip to content

Commit

Permalink
chore: Fix resource group header (#185)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomfrenken authored Oct 2, 2024
1 parent 8a96e9e commit 771f986
Show file tree
Hide file tree
Showing 12 changed files with 179 additions and 25 deletions.
9 changes: 9 additions & 0 deletions .changeset/two-bottles-juggle.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
---
'@sap-ai-sdk/foundation-models': minor
'@sap-ai-sdk/orchestration': minor
'@sap-ai-sdk/langchain': minor
'@sap-ai-sdk/ai-api': minor
'@sap-ai-sdk/core': minor
---

[Fixed Issue] Send the correct resource group header when a custom resource group is set.
11 changes: 11 additions & 0 deletions packages/ai-api/src/utils/deployment-resolver.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,17 @@ export type ModelDeployment<ModelNameT = string> =
| ModelNameT
| ((ModelConfig<ModelNameT> | DeploymentIdConfig) & ResourceGroupConfig);

/**
 * Extract the resource group from a model deployment configuration, if one
 * was provided. A plain model-name string carries no resource group.
 * @param modelDeployment - Model deployment configuration or plain model name.
 * @returns The configured resource group, or `undefined` when none is set.
 * @internal
 */
export function getResourceGroup(
  modelDeployment: ModelDeployment
): string | undefined {
  if (typeof modelDeployment === 'object') {
    return modelDeployment.resourceGroup;
  }
  return undefined;
}

/**
* Type guard to check if the given deployment configuration is a deployment ID configuration.
* @param modelDeployment - Configuration to check.
Expand Down
27 changes: 27 additions & 0 deletions packages/core/src/http-client.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,31 @@ describe('http-client', () => {
expect(res.status).toBe(200);
expect(res.data).toEqual(mockPromptResponse);
}, 10000);

it('should execute a request to the AI Core service with a custom resource group', async () => {
const mockPrompt = { prompt: 'some test prompt' };
const mockPromptResponse = { completion: 'some test completion' };

const scope = nock(aiCoreDestination.url, {
reqheaders: {
'ai-resource-group': 'custom-resource-group',
'ai-client-type': 'AI SDK JavaScript'
}
})
.post('/v2/some/endpoint', mockPrompt)
.query({ 'api-version': 'mock-api-version' })
.reply(200, mockPromptResponse);

const res = await executeRequest(
{
url: '/some/endpoint',
apiVersion: 'mock-api-version',
resourceGroup: 'custom-resource-group'
},
mockPrompt
);
expect(scope.isDone()).toBe(true);
expect(res.status).toBe(200);
expect(res.data).toEqual(mockPromptResponse);
});
});
12 changes: 8 additions & 4 deletions packages/core/src/http-client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,14 @@ export interface EndpointOptions {
* The specific endpoint to call.
*/
url: string;

/**
* The API version to use.
*/
apiVersion?: string;
/**
* The resource group to use.
*/
resourceGroup?: string;
}
/**
* Executes a request to the AI Core service.
Expand All @@ -49,10 +52,10 @@ export async function executeRequest(
requestConfig?: CustomRequestConfig
): Promise<HttpResponse> {
const aiCoreDestination = await getAiCoreDestination();
const { url, apiVersion } = endpointOptions;
const { url, apiVersion, resourceGroup = 'default' } = endpointOptions;

const mergedRequestConfig = {
...mergeWithDefaultRequestConfig(apiVersion, requestConfig),
...mergeWithDefaultRequestConfig(apiVersion, resourceGroup, requestConfig),
data: JSON.stringify(data)
};

Expand All @@ -69,13 +72,14 @@ export async function executeRequest(

function mergeWithDefaultRequestConfig(
apiVersion?: string,
resourceGroup?: string,
requestConfig?: CustomRequestConfig
): HttpRequestConfig {
const defaultConfig: HttpRequestConfig = {
method: 'post',
headers: {
'content-type': 'application/json',
'ai-resource-group': 'default',
'ai-resource-group': resourceGroup,
'ai-client-type': 'AI SDK JavaScript'
},
params: apiVersion ? { 'api-version': apiVersion } : {}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import nock from 'nock';
import {
mockClientCredentialsGrantCall,
mockDeploymentsList,
mockInference,
parseMockResponse
} from '../../../../test-util/mock-http.js';
Expand Down Expand Up @@ -75,4 +76,55 @@ describe('Azure OpenAI chat client', () => {

await expect(client.run(prompt)).rejects.toThrow('status code 400');
});

it('executes a request with the custom resource group', async () => {
const customChatCompletionEndpoint = {
url: 'inference/deployments/1234/chat/completions',
apiVersion,
resourceGroup: 'custom-resource-group'
};

const mockResponse =
parseMockResponse<AzureOpenAiCreateChatCompletionResponse>(
'foundation-models',
'azure-openai-chat-completion-success-response.json'
);

const prompt = {
messages: [
{
role: 'user' as const,
content: 'Where is the deepest place on earth located'
}
]
};

mockDeploymentsList(
{
scenarioId: 'foundation-models',
resourceGroup: 'custom-resource-group',
executableId: 'azure-openai'
},
{ id: '1234', model: { name: 'gpt-4o', version: 'latest' } }
);

mockInference(
{
data: prompt
},
{
data: mockResponse,
status: 200
},
customChatCompletionEndpoint
);

const clientWithResourceGroup = new AzureOpenAiChatClient({
modelName: 'gpt-4o',
resourceGroup: 'custom-resource-group'
});

const response = await clientWithResourceGroup.run(prompt);
expect(response.data).toEqual(mockResponse);
});
});
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { type CustomRequestConfig, executeRequest } from '@sap-ai-sdk/core';
import {
getDeploymentId,
getResourceGroup,
type ModelDeployment
} from '@sap-ai-sdk/ai-api/internal.js';
import type { AzureOpenAiCreateChatCompletionRequest } from './client/inference/schema/index.js';
Expand Down Expand Up @@ -31,10 +32,12 @@ export class AzureOpenAiChatClient {
this.modelDeployment,
'azure-openai'
);
const resourceGroup = getResourceGroup(this.modelDeployment);
const response = await executeRequest(
{
url: `/inference/deployments/${deploymentId}/chat/completions`,
apiVersion
apiVersion,
resourceGroup
},
data,
requestConfig
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { type CustomRequestConfig, executeRequest } from '@sap-ai-sdk/core';
import {
getDeploymentId,
getResourceGroup,
type ModelDeployment
} from '@sap-ai-sdk/ai-api/internal.js';
import { AzureOpenAiEmbeddingResponse } from './azure-openai-embedding-response.js';
Expand Down Expand Up @@ -34,8 +35,13 @@ export class AzureOpenAiEmbeddingClient {
this.modelDeployment,
'azure-openai'
);
const resourceGroup = getResourceGroup(this.modelDeployment);
const response = await executeRequest(
{ url: `/inference/deployments/${deploymentId}/embeddings`, apiVersion },
{
url: `/inference/deployments/${deploymentId}/embeddings`,
apiVersion,
resourceGroup
},
data,
requestConfig
);
Expand Down
12 changes: 1 addition & 11 deletions packages/langchain/src/openai/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ import { BaseMessage } from '@langchain/core/messages';
import type { ChatResult } from '@langchain/core/outputs';
import { AzureOpenAiChatClient as AzureOpenAiChatClientBase } from '@sap-ai-sdk/foundation-models';
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { AzureOpenAiChatModel } from '@sap-ai-sdk/core';
import { mapLangchainToAiClient, mapOutputToChatResult } from './util.js';
import type {
AzureOpenAiChatCallOptions,
Expand All @@ -13,13 +12,7 @@ import type {
/**
* LangChain chat client for Azure OpenAI consumption on SAP BTP.
*/
export class AzureOpenAiChatClient
extends BaseChatModel<AzureOpenAiChatCallOptions>
implements AzureOpenAiChatModelParams
{
modelName: AzureOpenAiChatModel;
modelVersion?: string;
resourceGroup?: string;
export class AzureOpenAiChatClient extends BaseChatModel<AzureOpenAiChatCallOptions> {
temperature?: number | null;
top_p?: number | null;
logit_bias?: Record<string, any> | null | undefined;
Expand All @@ -33,9 +26,6 @@ export class AzureOpenAiChatClient
constructor(fields: AzureOpenAiChatModelParams) {
super(fields);
this.openAiChatClient = new AzureOpenAiChatClientBase(fields);
this.modelName = fields.modelName;
this.modelVersion = fields.modelVersion;
this.resourceGroup = fields.resourceGroup;
this.temperature = fields.temperature;
this.top_p = fields.top_p;
this.logit_bias = fields.logit_bias;
Expand Down
5 changes: 1 addition & 4 deletions packages/langchain/src/openai/embedding.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,7 @@ import { AzureOpenAiEmbeddingModelParams } from './types.js';
/**
* LangChain embedding client for Azure OpenAI consumption on SAP BTP.
*/
export class AzureOpenAiEmbeddingClient
extends Embeddings
implements AzureOpenAiEmbeddingModelParams
{
export class AzureOpenAiEmbeddingClient extends Embeddings {
modelName: AzureOpenAiEmbeddingModel;
modelVersion?: string;
resourceGroup?: string;
Expand Down
56 changes: 55 additions & 1 deletion packages/orchestration/src/orchestration-client.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import {
} from './orchestration-client.js';
import { buildAzureContentFilter } from './orchestration-filter-utility.js';
import { OrchestrationResponse } from './orchestration-response.js';
import { OrchestrationModuleConfig } from './orchestration-types.js';
import { OrchestrationModuleConfig, Prompt } from './orchestration-types.js';

describe('orchestration service client', () => {
beforeEach(() => {
Expand Down Expand Up @@ -220,4 +220,58 @@ describe('orchestration service client', () => {
);
expect(response.data).toEqual(mockResponse);
});

it('executes a request with the custom resource group', async () => {
const prompt: Prompt = {
messagesHistory: [
{
role: 'user',
content: 'Where is the deepest place on earth located'
}
]
};

const config: OrchestrationModuleConfig = {
llm: {
model_name: 'gpt-4o',
model_params: {}
},
templating: {
template: [{ role: 'user', content: "What's my name?" }]
}
};

const customChatCompletionEndpoint = {
url: 'inference/deployments/1234/completion',
resourceGroup: 'custom-resource-group'
};

const mockResponse = parseMockResponse<CompletionPostResponse>(
'orchestration',
'orchestration-chat-completion-message-history.json'
);

mockDeploymentsList(
{ scenarioId: 'orchestration', resourceGroup: 'custom-resource-group' },
{ id: '1234', model: { name: 'gpt-4o', version: 'latest' } }
);

mockInference(
{
data: constructCompletionPostRequest(config, prompt)
},
{
data: mockResponse,
status: 200
},
customChatCompletionEndpoint
);

const clientWithResourceGroup = new OrchestrationClient(config, {
resourceGroup: 'custom-resource-group'
});

const response = await clientWithResourceGroup.chatCompletion(prompt);
expect(response.data).toEqual(mockResponse);
});
});
3 changes: 2 additions & 1 deletion packages/orchestration/src/orchestration-client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ export class OrchestrationClient {

const response = await executeRequest(
{
url: `/inference/deployments/${deploymentId}/completion`
url: `/inference/deployments/${deploymentId}/completion`,
resourceGroup: this.deploymentConfig?.resourceGroup
},
body,
requestConfig
Expand Down
4 changes: 2 additions & 2 deletions test-util/mock-http.ts
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,11 @@ export function mockInference(
},
endpoint: EndpointOptions = mockEndpoint
): nock.Scope {
const { url, apiVersion } = endpoint;
const { url, apiVersion, resourceGroup = 'default' } = endpoint;
const destination = getMockedAiCoreDestination();
return nock(destination.url, {
reqheaders: {
'ai-resource-group': 'default',
'ai-resource-group': resourceGroup,
authorization: `Bearer ${destination.authTokens?.[0].value}`
}
})
Expand Down

0 comments on commit 771f986

Please sign in to comment.