# Middleware System

Pluggable middleware system for language models to add custom behavior like logging, caching, request modification, and response transformation.

## Capabilities

### LanguageModelV2Middleware Interface

Core middleware interface that allows intercepting and modifying language model operations.

```typescript
/**
 * Middleware interface for language models
 */
interface LanguageModelV2Middleware {
  /** Middleware API version */
  middlewareVersion?: 'v2' | undefined;
  
  /** Override the provider name returned by the model */
  overrideProvider?: (options: { model: LanguageModelV2 }) => string;
  
  /** Override the model ID returned by the model */
  overrideModelId?: (options: { model: LanguageModelV2 }) => string;
  
  /** Override the supported URLs for the model */
  overrideSupportedUrls?: (options: { 
    model: LanguageModelV2 
  }) => PromiseLike<Record<string, RegExp[]>> | Record<string, RegExp[]>;
  
  /** Transform call parameters before they reach the model */
  transformParams?: (options: {
    type: 'generate' | 'stream';
    params: LanguageModelV2CallOptions;
    model: LanguageModelV2;
  }) => PromiseLike<LanguageModelV2CallOptions>;
  
  /** Wrap the doGenerate method with custom logic */
  wrapGenerate?: (options: {
    doGenerate: () => ReturnType<LanguageModelV2['doGenerate']>;
    doStream: () => ReturnType<LanguageModelV2['doStream']>;
    params: LanguageModelV2CallOptions;
    model: LanguageModelV2;
  }) => Promise<Awaited<ReturnType<LanguageModelV2['doGenerate']>>>;
  
  /** Wrap the doStream method with custom logic */
  wrapStream?: (options: {
    doGenerate: () => ReturnType<LanguageModelV2['doGenerate']>;
    doStream: () => ReturnType<LanguageModelV2['doStream']>;
    params: LanguageModelV2CallOptions;
    model: LanguageModelV2;
  }) => PromiseLike<Awaited<ReturnType<LanguageModelV2['doStream']>>>;
}
```
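
In the AI SDK itself, this interface is typically applied with the `wrapLanguageModel` helper from the `ai` package. A minimal sketch, assuming that helper is available (this document also builds a manual wrapper below):

```typescript
import { wrapLanguageModel } from "ai"; // assumes the AI SDK `ai` package
import type {
  LanguageModelV2,
  LanguageModelV2Middleware,
} from "@ai-sdk/provider";

// A pass-through middleware: returns the params unchanged.
const passthrough: LanguageModelV2Middleware = {
  middlewareVersion: 'v2',
  transformParams: async ({ params }) => params,
};

// `baseModel` stands in for any LanguageModelV2 instance.
declare const baseModel: LanguageModelV2;
const model = wrapLanguageModel({ model: baseModel, middleware: passthrough });
```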

#### Usage Examples

```typescript
import { 
  LanguageModelV2Middleware,
  LanguageModelV2,
  LanguageModelV2CallOptions 
} from "@ai-sdk/provider";

// Logging middleware
const loggingMiddleware: LanguageModelV2Middleware = {
  middlewareVersion: 'v2',
  
  transformParams: async ({ type, params, model }) => {
    console.log(`[${model.provider}:${model.modelId}] ${type} call with params:`, {
      messageCount: params.prompt.length,
      maxOutputTokens: params.maxOutputTokens,
      temperature: params.temperature
    });
    return params;
  },
  
  wrapGenerate: async ({ doGenerate, params, model }) => {
    const startTime = Date.now();
    console.log(`[${model.provider}:${model.modelId}] Starting generation...`);
    
    try {
      const result = await doGenerate();
      const duration = Date.now() - startTime;
      console.log(`[${model.provider}:${model.modelId}] Generation completed in ${duration}ms`, {
        tokensGenerated: result.usage.outputTokens,
        finishReason: result.finishReason
      });
      return result;
    } catch (error) {
      const duration = Date.now() - startTime;
      console.error(`[${model.provider}:${model.modelId}] Generation failed after ${duration}ms:`, error);
      throw error;
    }
  },
  
  wrapStream: async ({ doStream, params, model }) => {
    console.log(`[${model.provider}:${model.modelId}] Starting stream...`);
    
    const streamResult = await doStream();
    
    // Wrap the stream to log stream events
    const originalStream = streamResult.stream;
    const wrappedStream = new ReadableStream({
      start(controller) {
        const reader = originalStream.getReader();
        
        const pump = async () => {
          try {
            while (true) {
              const { done, value } = await reader.read();
              if (done) {
                console.log(`[${model.provider}:${model.modelId}] Stream completed`);
                controller.close();
                break;
              }
              
              if (value.type === 'text-delta') {
                console.log(`[${model.provider}:${model.modelId}] Text delta:`, value.delta);
              } else if (value.type === 'finish') {
                console.log(`[${model.provider}:${model.modelId}] Stream finished:`, value.finishReason);
              }
              
              controller.enqueue(value);
            }
          } catch (error) {
            console.error(`[${model.provider}:${model.modelId}] Stream error:`, error);
            controller.error(error);
          }
        };
        
        pump();
      }
    });
    
    return { ...streamResult, stream: wrappedStream };
  }
};
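
// Design note: the manual pump above could also be expressed as
// originalStream.pipeThrough(new TransformStream({ transform(chunk, controller) { ... } })),
// which propagates backpressure and cancellation automatically.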

// Caching middleware
class CachingMiddleware implements LanguageModelV2Middleware {
  private cache = new Map<string, any>();
  
  middlewareVersion: 'v2' = 'v2';
  
  private getCacheKey(params: LanguageModelV2CallOptions, modelId: string): string {
    return JSON.stringify({
      modelId,
      prompt: params.prompt,
      temperature: params.temperature,
      maxOutputTokens: params.maxOutputTokens,
      seed: params.seed
    });
  }
  
  wrapGenerate: LanguageModelV2Middleware['wrapGenerate'] = async ({ doGenerate, params, model }) => {
    const cacheKey = this.getCacheKey(params, model.modelId);
    
    // Check cache first
    if (this.cache.has(cacheKey)) {
      console.log('Cache hit for generation request');
      return this.cache.get(cacheKey);
    }
    
    // Generate and cache result
    const result = await doGenerate();
    this.cache.set(cacheKey, result);
    console.log('Cached generation result');
    
    return result;
  };
}
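
// Note: this sketch only caches doGenerate results and never evicts entries;
// a production cache would bound the Map (LRU/TTL), and caching streams would
// additionally require buffering and replaying the stream parts.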

// Rate limiting middleware
class RateLimitingMiddleware implements LanguageModelV2Middleware {
  private lastCall = new Map<string, number>();
  private minInterval: number;
  
  constructor(minIntervalMs = 1000) {
    this.minInterval = minIntervalMs;
  }
  
  middlewareVersion: 'v2' = 'v2';
  
  private async enforceRateLimit(modelId: string): Promise<void> {
    const now = Date.now();
    const lastCallTime = this.lastCall.get(modelId) || 0;
    const timeSinceLastCall = now - lastCallTime;
    
    if (timeSinceLastCall < this.minInterval) {
      const waitTime = this.minInterval - timeSinceLastCall;
      console.log(`Rate limiting: waiting ${waitTime}ms for model ${modelId}`);
      await new Promise(resolve => setTimeout(resolve, waitTime));
    }
    
    this.lastCall.set(modelId, Date.now());
  }
  
  wrapGenerate: LanguageModelV2Middleware['wrapGenerate'] = async ({ doGenerate, model }) => {
    await this.enforceRateLimit(model.modelId);
    return doGenerate();
  };
  
  wrapStream: LanguageModelV2Middleware['wrapStream'] = async ({ doStream, model }) => {
    await this.enforceRateLimit(model.modelId);
    return doStream();
  };
}
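
// Note: the limiter above is per-process and keyed by modelId only; separate
// server instances each enforce the interval independently. A shared store
// would be needed for a global limit.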

// Parameter transformation middleware
const parameterEnhancementMiddleware: LanguageModelV2Middleware = {
  middlewareVersion: 'v2',
  
  transformParams: async ({ params, model }) => {
    let enhanced = params;
    
    // Add a system message if none exists
    const hasSystemMessage = enhanced.prompt.some(msg => msg.role === 'system');
    
    if (!hasSystemMessage) {
      enhanced = {
        ...enhanced,
        prompt: [
          {
            role: 'system',
            content: 'You are a helpful AI assistant. Please provide accurate and helpful responses.'
          },
          ...enhanced.prompt
        ]
      };
    }
    
    // Default the temperature for creative tasks
    // (the == null check keeps an explicit temperature of 0 intact)
    if (enhanced.temperature == null && enhanced.prompt.some(msg => 
      msg.role === 'user' && 
      JSON.stringify(msg.content).toLowerCase().includes('creative')
    )) {
      enhanced = {
        ...enhanced,
        temperature: 0.8
      };
    }
    
    return enhanced;
  }
};

// Error handling middleware
const errorHandlingMiddleware: LanguageModelV2Middleware = {
  middlewareVersion: 'v2',
  
  wrapGenerate: async ({ doGenerate, params, model }) => {
    try {
      return await doGenerate();
    } catch (error) {
      console.error(`Generation error for ${model.modelId}:`, error);
      
      // Retry once after a short fixed delay when a rate limit is detected
      if (error instanceof Error && error.message.includes('rate limit')) {
        console.log('Rate limit detected, retrying once after 2s...');
        await new Promise(resolve => setTimeout(resolve, 2000));
        return doGenerate();
      }
      }
      
      throw error;
    }
  },
  
  wrapStream: async ({ doStream, params, model }) => {
    try {
      return await doStream();
    } catch (error) {
      console.error(`Streaming error for ${model.modelId}:`, error);
      throw error;
    }
  }
};

// Model aliasing middleware
const modelAliasingMiddleware: LanguageModelV2Middleware = {
  middlewareVersion: 'v2',
  
  overrideModelId: ({ model }) => {
    const aliases: Record<string, string> = {
      'gpt-4-latest': 'gpt-4-turbo-2024-04-09',
      'claude-latest': 'claude-3-opus-20240229',
      'gemini-latest': 'gemini-1.5-pro'
    };
    
    return aliases[model.modelId] || model.modelId;
  }
};
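
// Note: with the WrappedLanguageModel shown later, overrideModelId only changes
// the modelId the wrapper reports; requests still go to the underlying model.
// Real aliasing would have to happen where the base model is resolved.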

// Metrics collection middleware
class MetricsMiddleware implements LanguageModelV2Middleware {
  private metrics = {
    totalCalls: 0,
    totalTokensUsed: 0,
    averageLatency: 0,
    errorCount: 0
  };
  
  middlewareVersion: 'v2' = 'v2';
  
  wrapGenerate: LanguageModelV2Middleware['wrapGenerate'] = async ({ doGenerate, model }) => {
    const startTime = Date.now();
    this.metrics.totalCalls++;
    
    try {
      const result = await doGenerate();
      const latency = Date.now() - startTime;
      
      // Update metrics
      this.metrics.totalTokensUsed += result.usage.totalTokens || 0;
      this.metrics.averageLatency = 
        (this.metrics.averageLatency * (this.metrics.totalCalls - 1) + latency) / 
        this.metrics.totalCalls;
      
      return result;
    } catch (error) {
      this.metrics.errorCount++;
      throw error;
    }
  };
  
  getMetrics() {
    return { ...this.metrics };
  }
}

// Composing multiple middleware
class MiddlewareChain implements LanguageModelV2Middleware {
  constructor(private middlewares: LanguageModelV2Middleware[]) {}
  
  middlewareVersion: 'v2' = 'v2';
  
  transformParams: LanguageModelV2Middleware['transformParams'] = async (options) => {
    let params = options.params;
    
    for (const middleware of this.middlewares) {
      if (middleware.transformParams) {
        params = await middleware.transformParams({
          ...options,
          params
        });
      }
    }
    
    return params;
  };
  
  wrapGenerate: LanguageModelV2Middleware['wrapGenerate'] = async (options) => {
    let wrappedGenerate = options.doGenerate;
    
    // Wrap in reverse order so the first middleware in the list runs outermost
    for (let i = this.middlewares.length - 1; i >= 0; i--) {
      const middleware = this.middlewares[i];
      if (middleware.wrapGenerate) {
        const currentGenerate = wrappedGenerate;
        wrappedGenerate = () => middleware.wrapGenerate!({
          ...options,
          doGenerate: currentGenerate
        });
      }
    }
    
    return wrappedGenerate();
  };
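  
  // wrapStream chains the same way; omitting it would silently skip
  // middleware on streaming calls (mirrors wrapGenerate above)
  wrapStream: LanguageModelV2Middleware['wrapStream'] = async (options) => {
    let wrappedStream = options.doStream;
    
    for (let i = this.middlewares.length - 1; i >= 0; i--) {
      const middleware = this.middlewares[i];
      if (middleware.wrapStream) {
        const currentStream = wrappedStream;
        wrappedStream = () => middleware.wrapStream!({
          ...options,
          doStream: currentStream
        });
      }
    }
    
    return wrappedStream();
  };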
}

// Usage: Applying middleware to a model
const metricsMiddleware = new MetricsMiddleware();
const cachingMiddleware = new CachingMiddleware();

const middlewareChain = new MiddlewareChain([
  loggingMiddleware,
  metricsMiddleware,
  cachingMiddleware,
  parameterEnhancementMiddleware,
  errorHandlingMiddleware
]);

// Apply middleware to model (implementation would depend on your model wrapper)
class WrappedLanguageModel implements LanguageModelV2 {
  constructor(
    private baseModel: LanguageModelV2,
    private middleware: LanguageModelV2Middleware
  ) {}
  
  get specificationVersion() { return this.baseModel.specificationVersion; }
  get provider() { 
    return this.middleware.overrideProvider?.({ model: this.baseModel }) || 
           this.baseModel.provider; 
  }
  get modelId() { 
    return this.middleware.overrideModelId?.({ model: this.baseModel }) || 
           this.baseModel.modelId; 
  }
  get supportedUrls() {
    return this.middleware.overrideSupportedUrls?.({ model: this.baseModel }) ||
           this.baseModel.supportedUrls;
  }
  
  async doGenerate(options: LanguageModelV2CallOptions) {
    // Transform parameters
    const transformedOptions = this.middleware.transformParams ?
      await this.middleware.transformParams({
        type: 'generate',
        params: options,
        model: this.baseModel
      }) : options;
    
    // Wrap generate
    if (this.middleware.wrapGenerate) {
      return this.middleware.wrapGenerate({
        doGenerate: () => this.baseModel.doGenerate(transformedOptions),
        doStream: () => this.baseModel.doStream(transformedOptions),
        params: transformedOptions,
        model: this.baseModel
      });
    }
    
    return this.baseModel.doGenerate(transformedOptions);
  }
  
  async doStream(options: LanguageModelV2CallOptions) {
    // Similar implementation for streaming...
    const transformedOptions = this.middleware.transformParams ?
      await this.middleware.transformParams({
        type: 'stream',
        params: options,
        model: this.baseModel
      }) : options;
    
    if (this.middleware.wrapStream) {
      return this.middleware.wrapStream({
        doGenerate: () => this.baseModel.doGenerate(transformedOptions),
        doStream: () => this.baseModel.doStream(transformedOptions),
        params: transformedOptions,
        model: this.baseModel
      });
    }
    
    return this.baseModel.doStream(transformedOptions);
  }
}

// Using the wrapped model (assuming `provider` is a ProviderV2 instance)
const originalModel = provider.languageModel('gpt-4');
const enhancedModel = new WrappedLanguageModel(originalModel, middlewareChain);

const result = await enhancedModel.doGenerate({
  prompt: [{ role: 'user', content: [{ type: 'text', text: 'Hello!' }] }]
});

console.log('Metrics:', metricsMiddleware.getMetrics());
```