Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,10 @@ export type ChatCompleteOptions = {
* Defaults to false.
*/
stream?: boolean;
/**
* The timeout for the chat completion request.
*/
timeout?: number;
} & ToolOptions;

export interface ChatCompleteRetryConfiguration {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,7 @@ describe('InferenceChatModel', () => {
model: 'super-duper-model',
functionCallingMode: 'simulated',
signal: abortCtrl.signal,
timeout: 60000,
telemetryMetadata,
});

Expand All @@ -355,6 +356,34 @@ describe('InferenceChatModel', () => {
temperature: 0.7,
modelName: 'super-duper-model',
abortSignal: abortCtrl.signal,
timeout: 60000,
maxRetries: undefined,
stream: false,
metadata,
});
});

it('accepts timeout argument in constructor', async () => {
const timeout = 60000;
const chatModel = new InferenceChatModel({
chatComplete,
connector,
timeout,
telemetryMetadata,
});

const response = createResponse({ content: 'dummy' });
chatComplete.mockResolvedValue(response);

await chatModel.invoke('question');

// Verify the instance was created successfully and can make calls
expect(chatComplete).toHaveBeenCalledTimes(1);
expect(chatComplete).toHaveBeenCalledWith({
connectorId: connector.connectorId,
messages: [{ role: MessageRole.User, content: 'question' }],
timeout: 60000,
maxRetries: undefined,
stream: false,
metadata,
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ export interface InferenceChatModelParams extends BaseChatModelParams {
temperature?: number;
model?: string;
signal?: AbortSignal;
timeout?: number;
telemetryMetadata?: ConnectorTelemetryMetadata;
}

Expand All @@ -72,6 +73,7 @@ export interface InferenceChatModelCallOptions extends BaseChatModelCallOptions
tool_choice?: ToolChoice;
temperature?: number;
model?: string;
timeout?: number;
}

type InvocationParams = Omit<ChatCompleteOptions, 'messages' | 'system' | 'stream'>;
Expand Down Expand Up @@ -102,6 +104,7 @@ export class InferenceChatModel extends BaseChatModel<InferenceChatModelCallOpti
protected maxRetries?: number;
protected model?: string;
protected signal?: AbortSignal;
protected timeout?: number;

constructor(args: InferenceChatModelParams) {
super(args);
Expand All @@ -113,6 +116,7 @@ export class InferenceChatModel extends BaseChatModel<InferenceChatModelCallOpti
this.functionCallingMode = args.functionCallingMode;
this.model = args.model;
this.signal = args.signal;
this.timeout = args.timeout;
this.maxRetries = args.maxRetries;
}

Expand Down Expand Up @@ -190,6 +194,7 @@ export class InferenceChatModel extends BaseChatModel<InferenceChatModelCallOpti
abortSignal: options.signal ?? this.signal,
maxRetries: this.maxRetries,
metadata: { connectorTelemetry: this.telemetryMetadata },
timeout: options.timeout ?? this.timeout,
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ export const bedrockClaudeAdapter: InferenceConnectorAdapter = {
modelName,
abortSignal,
metadata,
timeout,
}) => {
const noToolUsage = toolChoice === ToolChoiceType.none;

Expand All @@ -59,6 +60,7 @@ export const bedrockClaudeAdapter: InferenceConnectorAdapter = {
model: modelName,
stopSequences: ['\n\nHuman:'],
signal: abortSignal,
...(typeof timeout === 'number' && isFinite(timeout) ? { timeout } : {}),
};

return defer(async () => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ export const geminiAdapter: InferenceConnectorAdapter = {
modelName,
abortSignal,
metadata,
timeout,
}) => {
const connector = executor.getConnector();
const useThoughtSignature = mustUseThoughtSignature(
Expand All @@ -49,6 +50,7 @@ export const geminiAdapter: InferenceConnectorAdapter = {
...(metadata?.connectorTelemetry
? { telemetryMetadata: metadata.connectorTelemetry }
: {}),
...(typeof timeout === 'number' && isFinite(timeout) ? { timeout } : {}),
},
});
}).pipe(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ export const inferenceAdapter: InferenceConnectorAdapter = {
logger,
abortSignal,
metadata,
timeout,
}) => {
const useSimulatedFunctionCalling =
functionCalling === 'auto'
Expand All @@ -52,6 +53,7 @@ export const inferenceAdapter: InferenceConnectorAdapter = {
...(metadata?.connectorTelemetry
? { telemetryMetadata: metadata.connectorTelemetry }
: {}),
...(typeof timeout === 'number' && isFinite(timeout) ? { timeout } : {}),
},
});
}).pipe(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ export const openAIAdapter: InferenceConnectorAdapter = {
logger,
abortSignal,
metadata,
timeout,
}) => {
const connector = executor.getConnector();

Expand Down Expand Up @@ -76,6 +77,7 @@ export const openAIAdapter: InferenceConnectorAdapter = {
...(metadata?.connectorTelemetry
? { telemetryMetadata: metadata.connectorTelemetry }
: {}),
...(typeof timeout === 'number' && isFinite(timeout) ? { timeout } : {}),
},
});
}).pipe(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ export function createChatCompleteCallbackApi({
temperature,
toolChoice,
tools,
timeout,
} = callback(executor);

const messages = givenMessages.map((message) => {
Expand Down Expand Up @@ -176,6 +177,7 @@ export function createChatCompleteCallbackApi({
modelName,
abortSignal,
metadata,
timeout,
})
.pipe(
chunksIntoMessage({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ export type InferenceAdapterChatCompleteOptions = {
modelName?: string;
abortSignal?: AbortSignal;
metadata?: ChatCompleteMetadata;
timeout?: number;
} & ToolOptions;

/**
Expand Down