Pluggable middleware system for language models to add custom behavior like logging, caching, request modification, and response transformation.
Core middleware interface that allows intercepting and modifying language model operations.
/**
* Middleware interface for language models
*/
interface LanguageModelV2Middleware {
/** Middleware API version */
middlewareVersion?: 'v2' | undefined;
/** Override the provider name returned by the model */
overrideProvider?: (options: { model: LanguageModelV2 }) => string;
/** Override the model ID returned by the model */
overrideModelId?: (options: { model: LanguageModelV2 }) => string;
/** Override the supported URLs for the model */
overrideSupportedUrls?: (options: {
model: LanguageModelV2;
}) => PromiseLike<Record<string, RegExp[]>> | Record<string, RegExp[]>;
/** Transform call parameters before they reach the model */
transformParams?: (options: {
type: 'generate' | 'stream';
params: LanguageModelV2CallOptions;
model: LanguageModelV2;
}) => PromiseLike<LanguageModelV2CallOptions>;
/** Wrap the doGenerate method with custom logic */
wrapGenerate?: (options: {
doGenerate: () => ReturnType<LanguageModelV2['doGenerate']>;
doStream: () => ReturnType<LanguageModelV2['doStream']>;
params: LanguageModelV2CallOptions;
model: LanguageModelV2;
}) => Promise<Awaited<ReturnType<LanguageModelV2['doGenerate']>>>;
/** Wrap the doStream method with custom logic */
wrapStream?: (options: {
doGenerate: () => ReturnType<LanguageModelV2['doGenerate']>;
doStream: () => ReturnType<LanguageModelV2['doStream']>;
params: LanguageModelV2CallOptions;
model: LanguageModelV2;
}) => PromiseLike<Awaited<ReturnType<LanguageModelV2['doStream']>>>;
}

Usage Examples:
import {
LanguageModelV2Middleware,
LanguageModelV2,
LanguageModelV2CallOptions
} from "@ai-sdk/provider";
// Logging middleware
const loggingMiddleware: LanguageModelV2Middleware = {
middlewareVersion: 'v2',
transformParams: async ({ type, params, model }) => {
console.log(`[${model.provider}:${model.modelId}] ${type} call with params:`, {
  promptMessages: params.prompt.length,
  maxOutputTokens: params.maxOutputTokens,
  temperature: params.temperature
});
return params;
},
wrapGenerate: async ({ doGenerate, params, model }) => {
const startTime = Date.now();
console.log(`[${model.provider}:${model.modelId}] Starting generation...`);
try {
const result = await doGenerate();
const duration = Date.now() - startTime;
console.log(`[${model.provider}:${model.modelId}] Generation completed in ${duration}ms`, {
tokensGenerated: result.usage.outputTokens,
finishReason: result.finishReason
});
return result;
} catch (error) {
const duration = Date.now() - startTime;
console.error(`[${model.provider}:${model.modelId}] Generation failed after ${duration}ms:`, error);
throw error;
}
},
wrapStream: async ({ doStream, params, model }) => {
console.log(`[${model.provider}:${model.modelId}] Starting stream...`);
const streamResult = await doStream();
// Wrap the stream to log stream events
const originalStream = streamResult.stream;
const wrappedStream = new ReadableStream({
start(controller) {
const reader = originalStream.getReader();
const pump = async () => {
try {
while (true) {
const { done, value } = await reader.read();
if (done) {
console.log(`[${model.provider}:${model.modelId}] Stream completed`);
controller.close();
break;
}
if (value.type === 'text-delta') {
  console.log(`[${model.provider}:${model.modelId}] Text delta:`, value.delta);
} else if (value.type === 'finish') {
console.log(`[${model.provider}:${model.modelId}] Stream finished:`, value.finishReason);
}
controller.enqueue(value);
}
} catch (error) {
console.error(`[${model.provider}:${model.modelId}] Stream error:`, error);
controller.error(error);
}
};
pump();
}
});
return { ...streamResult, stream: wrappedStream };
}
};
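When using the AI SDK itself, middleware is usually attached with the wrapLanguageModel helper from the ai package rather than with a hand-rolled wrapper. A minimal sketch (the openai provider and the 'gpt-4o' model id are illustrative):

import { wrapLanguageModel } from 'ai';
import { openai } from '@ai-sdk/openai';

const loggedModel = wrapLanguageModel({
  model: openai('gpt-4o'),
  middleware: loggingMiddleware,
});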
// Caching middleware
class CachingMiddleware implements LanguageModelV2Middleware {
private cache = new Map<string, any>();
middlewareVersion: 'v2' = 'v2';
private getCacheKey(params: LanguageModelV2CallOptions, modelId: string): string {
return JSON.stringify({
modelId,
prompt: params.prompt,
temperature: params.temperature,
maxOutputTokens: params.maxOutputTokens,
seed: params.seed
});
}
wrapGenerate: NonNullable<LanguageModelV2Middleware['wrapGenerate']> = async ({ doGenerate, params, model }) => {
const cacheKey = this.getCacheKey(params, model.modelId);
// Check cache first
if (this.cache.has(cacheKey)) {
console.log('Cache hit for generation request');
return this.cache.get(cacheKey);
}
// Generate and cache result
const result = await doGenerate();
this.cache.set(cacheKey, result);
console.log('Cached generation result');
return result;
};
}
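Note: this middleware caches completed generations only. wrapStream is deliberately omitted, because a ReadableStream can be consumed once and would need to be buffered and replayed to be cacheable. The Map also grows without bound, so production code would add TTL or LRU eviction.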
// Rate limiting middleware
class RateLimitingMiddleware implements LanguageModelV2Middleware {
private lastCall = new Map<string, number>();
private minInterval: number;
constructor(minIntervalMs = 1000) {
this.minInterval = minIntervalMs;
}
middlewareVersion: 'v2' = 'v2';
private async enforceRateLimit(modelId: string): Promise<void> {
const now = Date.now();
const lastCallTime = this.lastCall.get(modelId) || 0;
const timeSinceLastCall = now - lastCallTime;
if (timeSinceLastCall < this.minInterval) {
const waitTime = this.minInterval - timeSinceLastCall;
console.log(`Rate limiting: waiting ${waitTime}ms for model ${modelId}`);
await new Promise(resolve => setTimeout(resolve, waitTime));
}
this.lastCall.set(modelId, Date.now());
}
wrapGenerate: NonNullable<LanguageModelV2Middleware['wrapGenerate']> = async ({ doGenerate, model }) => {
await this.enforceRateLimit(model.modelId);
return doGenerate();
};
wrapStream: NonNullable<LanguageModelV2Middleware['wrapStream']> = async ({ doStream, model }) => {
await this.enforceRateLimit(model.modelId);
return doStream();
};
}
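Note: enforceRateLimit reads and updates lastCall non-atomically, so two concurrent calls for the same model can both pass the check and proceed together. If strict pacing matters, serialize calls per model id (for example with a per-model promise queue).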
// Parameter transformation middleware
const parameterEnhancementMiddleware: LanguageModelV2Middleware = {
middlewareVersion: 'v2',
transformParams: async ({ params, model }) => {
// Add system message if none exists
const hasSystemMessage = params.prompt.some(msg => msg.role === 'system');
if (!hasSystemMessage) {
return {
...params,
prompt: [
{
role: 'system',
content: 'You are a helpful AI assistant. Please provide accurate and helpful responses.'
},
...params.prompt
]
};
}
// Ensure temperature is set for creative tasks
if (params.temperature == null && params.prompt.some(msg =>
msg.role === 'user' &&
JSON.stringify(msg.content).toLowerCase().includes('creative')
)) {
return {
...params,
temperature: 0.8
};
}
return params;
}
};
// Error handling middleware
const errorHandlingMiddleware: LanguageModelV2Middleware = {
middlewareVersion: 'v2',
wrapGenerate: async ({ doGenerate, params, model }) => {
try {
return await doGenerate();
} catch (error) {
console.error(`Generation error for ${model.modelId}:`, error);
// Retry once after a fixed delay when a rate limit is hit
if (error instanceof Error && error.message.includes('rate limit')) {
  console.log('Rate limit detected, retrying once after a 2s delay...');
  await new Promise(resolve => setTimeout(resolve, 2000));
  return doGenerate();
}
throw error;
}
},
wrapStream: async ({ doStream, params, model }) => {
try {
return await doStream();
} catch (error) {
console.error(`Streaming error for ${model.modelId}:`, error);
throw error;
}
}
};
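The single fixed-delay retry above is deliberately minimal. A reusable helper with true exponential backoff might look like this (a sketch; withBackoff and its default retry count and base delay are illustrative, not part of the SDK):

async function withBackoff<T>(
  fn: () => Promise<T>,
  retries = 3,
  baseDelayMs = 1000
): Promise<T> {
  for (let attempt = 0; ; attempt++) {
    try {
      return await fn();
    } catch (error) {
      if (attempt >= retries) throw error;
      // Wait 1s, 2s, 4s, ... before the next attempt
      await new Promise(resolve => setTimeout(resolve, baseDelayMs * 2 ** attempt));
    }
  }
}
// Inside wrapGenerate: return withBackoff(() => doGenerate());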
// Model aliasing middleware
const modelAliasingMiddleware: LanguageModelV2Middleware = {
middlewareVersion: 'v2',
overrideModelId: ({ model }) => {
const aliases: Record<string, string> = {
'gpt-4-latest': 'gpt-4-turbo-2024-04-09',
'claude-latest': 'claude-3-opus-20240229',
'gemini-latest': 'gemini-1.5-pro'
};
return aliases[model.modelId] || model.modelId;
}
};
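Resolving an alias through a wrapper (this uses the WrappedLanguageModel class defined further below; provider stands in for any provider instance):

const aliased = new WrappedLanguageModel(
  provider.languageModel('gpt-4-latest'),
  modelAliasingMiddleware
);
console.log(aliased.modelId); // 'gpt-4-turbo-2024-04-09'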
// Metrics collection middleware
class MetricsMiddleware implements LanguageModelV2Middleware {
private metrics = {
totalCalls: 0,
totalTokensUsed: 0,
averageLatency: 0,
errorCount: 0
};
middlewareVersion: 'v2' = 'v2';
wrapGenerate: NonNullable<LanguageModelV2Middleware['wrapGenerate']> = async ({ doGenerate, model }) => {
const startTime = Date.now();
this.metrics.totalCalls++;
try {
const result = await doGenerate();
const latency = Date.now() - startTime;
// Update metrics
this.metrics.totalTokensUsed += result.usage.totalTokens || 0;
this.metrics.averageLatency =
(this.metrics.averageLatency * (this.metrics.totalCalls - 1) + latency) /
this.metrics.totalCalls;
return result;
} catch (error) {
this.metrics.errorCount++;
throw error;
}
};
getMetrics() {
return { ...this.metrics };
}
}
// Composing multiple middleware
class MiddlewareChain implements LanguageModelV2Middleware {
constructor(private middlewares: LanguageModelV2Middleware[]) {}
middlewareVersion: 'v2' = 'v2';
transformParams: NonNullable<LanguageModelV2Middleware['transformParams']> = async options => {
let params = options.params;
for (const middleware of this.middlewares) {
if (middleware.transformParams) {
params = await middleware.transformParams({
...options,
params
});
}
}
return params;
};
wrapGenerate: NonNullable<LanguageModelV2Middleware['wrapGenerate']> = async options => {
  let wrappedGenerate = options.doGenerate;
  // Wrap in reverse order so the first middleware in the list becomes the outermost wrapper
for (let i = this.middlewares.length - 1; i >= 0; i--) {
const middleware = this.middlewares[i];
if (middleware.wrapGenerate) {
const currentGenerate = wrappedGenerate;
wrappedGenerate = () => middleware.wrapGenerate!({
...options,
doGenerate: currentGenerate
});
}
}
return wrappedGenerate();
};
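  // Compose wrapStream the same way; without this, streaming calls would
  // bypass every chained middleware's wrapStream hook (a sketch):
  wrapStream: NonNullable<LanguageModelV2Middleware['wrapStream']> = async options => {
    let wrappedStream = options.doStream;
    for (let i = this.middlewares.length - 1; i >= 0; i--) {
      const middleware = this.middlewares[i];
      if (middleware.wrapStream) {
        const currentStream = wrappedStream;
        wrappedStream = async () =>
          middleware.wrapStream!({ ...options, doStream: currentStream });
      }
    }
    return wrappedStream();
  };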
}
// Usage: Applying middleware to a model
const metricsMiddleware = new MetricsMiddleware();
const cachingMiddleware = new CachingMiddleware();
const middlewareChain = new MiddlewareChain([
loggingMiddleware,
metricsMiddleware,
cachingMiddleware,
parameterEnhancementMiddleware,
errorHandlingMiddleware
]);
// Apply middleware to a model. The AI SDK's `ai` package ships a wrapLanguageModel
// helper that does this wrapping for you; the class below sketches the mechanics.
class WrappedLanguageModel implements LanguageModelV2 {
constructor(
private baseModel: LanguageModelV2,
private middleware: LanguageModelV2Middleware
) {}
get specificationVersion() { return this.baseModel.specificationVersion; }
get provider() {
return this.middleware.overrideProvider?.({ model: this.baseModel }) ||
this.baseModel.provider;
}
get modelId() {
return this.middleware.overrideModelId?.({ model: this.baseModel }) ||
this.baseModel.modelId;
}
get supportedUrls() {
return this.middleware.overrideSupportedUrls?.({ model: this.baseModel }) ||
this.baseModel.supportedUrls;
}
async doGenerate(options: LanguageModelV2CallOptions) {
// Transform parameters
const transformedOptions = this.middleware.transformParams ?
await this.middleware.transformParams({
type: 'generate',
params: options,
model: this.baseModel
}) : options;
// Wrap generate
if (this.middleware.wrapGenerate) {
return this.middleware.wrapGenerate({
doGenerate: () => this.baseModel.doGenerate(transformedOptions),
doStream: () => this.baseModel.doStream(transformedOptions),
params: transformedOptions,
model: this.baseModel
});
}
return this.baseModel.doGenerate(transformedOptions);
}
async doStream(options: LanguageModelV2CallOptions) {
// Same pattern as doGenerate, with type: 'stream'
const transformedOptions = this.middleware.transformParams ?
await this.middleware.transformParams({
type: 'stream',
params: options,
model: this.baseModel
}) : options;
if (this.middleware.wrapStream) {
return this.middleware.wrapStream({
doGenerate: () => this.baseModel.doGenerate(transformedOptions),
doStream: () => this.baseModel.doStream(transformedOptions),
params: transformedOptions,
model: this.baseModel
});
}
return this.baseModel.doStream(transformedOptions);
}
}
// Using the wrapped model ('provider' stands in for any provider instance)
const originalModel = provider.languageModel('gpt-4');
const enhancedModel = new WrappedLanguageModel(originalModel, middlewareChain);
const result = await enhancedModel.doGenerate({
prompt: [{ role: 'user', content: [{ type: 'text', text: 'Hello!' }] }]
});
console.log('Metrics:', metricsMiddleware.getMetrics());