// Docs: guide to creating custom middleware for extending agent behavior.
import { createMiddleware } from "langchain";
import { z } from "zod";
/**
 * Skeleton showing every slot a middleware config can declare: a name,
 * optional state/context schemas, optional extra tools, and lifecycle hooks.
 * The hooks here are pass-through placeholders.
 */
const myMiddleware = createMiddleware({
  name: "my-middleware",
  // Optional: Define state schema
  stateSchema: z.object({
    customField: z.string(),
  }),
  // Optional: Define context schema
  contextSchema: z.object({
    requestId: z.string(),
  }),
  // Optional: Add tools
  tools: [/* ... */],
  // Lifecycle hooks
  beforeAgent: async (state, runtime) => {
    // Initialize state (this placeholder returns it unchanged)
    return state;
  },
  wrapToolCall: async (request, handler, runtime) => {
    // Intercept tool execution: `handler` runs the tool; its result is passed through
    const result = await handler(request);
    return result;
  },
});

// beforeAgent: called once at the start of agent invocation.
// `beforeAgent` hook, a property of the object passed to createMiddleware().
// Seeds `customField` while keeping every other key of the incoming state.
beforeAgent: async (state, runtime) => {
  // Initialize middleware state
  return {
    ...state,
    customField: "initialized",
  };
}

// beforeModel: called before each model invocation.
// `beforeModel` hook, a property of the object passed to createMiddleware().
// Logs a message and returns the state unchanged.
beforeModel: async (state, runtime) => {
  // Prepare for model call
  console.log("About to call model");
  return state;
}

// wrapModelCall: wraps the model invocation.
// `wrapModelCall` hook, a property of the object passed to createMiddleware().
// `handler` performs the actual model call; code before/after it runs around
// that call, and the handler's result is returned unchanged.
wrapModelCall: async (state, handler, runtime) => {
  console.log("Before model call");
  const result = await handler(state);
  console.log("After model call");
  return result;
}

// afterModel: called after each model invocation.
// `afterModel` hook, a property of the object passed to createMiddleware().
// Placeholder that returns the state unchanged.
afterModel: async (state, runtime) => {
  // Process model response
  return state;
}

// wrapToolCall: wraps individual tool executions.
// `wrapToolCall` hook, a property of the object passed to createMiddleware().
// Logs each tool call's outcome; failures are logged and rethrown unchanged,
// so callers still observe the original error.
wrapToolCall: async (request, handler, runtime) => {
  console.log(`Calling tool: ${request.toolName}`);
  try {
    const result = await handler(request);
    console.log(`Tool succeeded: ${request.toolName}`);
    return result;
  } catch (error) {
    console.error(`Tool failed: ${request.toolName}`, error);
    throw error;
  }
}

// afterAgent: called once when agent completes.
// `afterAgent` hook, a property of the object passed to createMiddleware().
// Logs completion and returns the state unchanged.
afterAgent: async (state, runtime) => {
  // Cleanup or finalization
  console.log("Agent completed");
  return state;
}

/**
 * Times and logs every tool call (including its arguments) and every model call.
 * Results are passed through unchanged.
 */
const loggingMiddleware = createMiddleware({
  name: "logging",
  wrapToolCall: async (request, handler, runtime) => {
    const start = Date.now();
    console.log(`[Tool] ${request.toolName} called with:`, request.args);
    const result = await handler(request);
    const duration = Date.now() - start;
    console.log(`[Tool] ${request.toolName} completed in ${duration}ms`);
    return result;
  },
  wrapModelCall: async (state, handler, runtime) => {
    const start = Date.now();
    console.log("[Model] Calling model");
    const result = await handler(state);
    const duration = Date.now() - start;
    console.log(`[Model] Completed in ${duration}ms`);
    return result;
  },
});

/**
 * Tracks a per-run session. `beforeAgent` assigns a random session id and start
 * time and zeroes the counter; `afterModel` counts model calls; `afterAgent`
 * logs a summary.
 */
const sessionMiddleware = createMiddleware({
  name: "session",
  stateSchema: z.object({
    sessionId: z.string(),
    startTime: z.number(),
    operationCount: z.number(),
  }),
  beforeAgent: async (state, runtime) => {
    return {
      ...state,
      sessionId: crypto.randomUUID(),
      startTime: Date.now(),
      operationCount: 0,
    };
  },
  afterModel: async (state, runtime) => {
    return {
      ...state,
      operationCount: state.operationCount + 1,
    };
  },
  afterAgent: async (state, runtime) => {
    const duration = Date.now() - state.startTime;
    console.log(`Session ${state.sessionId}: ${state.operationCount} operations in ${duration}ms`);
    return state;
  },
});

/**
 * Turns tool failures into an error result (instead of an exception);
 * model failures are logged and rethrown.
 */
const errorHandlingMiddleware = createMiddleware({
  name: "error-handling",
  wrapToolCall: async (request, handler, runtime) => {
    try {
      return await handler(request);
    } catch (error) {
      console.error(`Tool ${request.toolName} failed:`, error);
      // Non-Error throwables (strings, plain values) have no `.message`;
      // fall back to their string form so the result never says "undefined".
      const message = typeof error?.message === "string" ? error.message : String(error);
      // Return error as tool result
      return {
        content: `Error: ${message}`,
        error: message,
      };
    }
  },
  wrapModelCall: async (state, handler, runtime) => {
    try {
      return await handler(state);
    } catch (error) {
      console.error("Model call failed:", error);
      // Could implement fallback logic here
      throw error;
    }
  },
});

/**
 * Rejects `send_email` calls whose `to` argument is not a string containing
 * "@", returning an error result without running the tool.
 */
const validationMiddleware = createMiddleware({
  name: "validation",
  wrapToolCall: async (request, handler, runtime) => {
    // Validate tool arguments
    if (request.toolName === "send_email") {
      const to = request.args?.to;
      // A missing `args`, or a non-string `to`, must be rejected rather than
      // throw: `.includes` on an array tests element equality, on numbers it throws.
      if (typeof to !== "string" || !to.includes("@")) {
        return {
          content: "Error: Invalid email address",
          error: "Invalid email address",
        };
      }
    }
    return await handler(request);
  },
});

/**
 * Memoizes tool results for the current run, keyed by tool name plus
 * JSON-serialized args (so differently-ordered arg keys are separate entries).
 * `beforeAgent` resets the cache at the start of each invocation.
 */
const cachingMiddleware = createMiddleware({
  name: "caching",
  stateSchema: z.object({
    cache: z.record(z.string(), z.any()),
  }),
  beforeAgent: async (state, runtime) => {
    return {
      ...state,
      cache: {},
    };
  },
  wrapToolCall: async (request, handler, runtime) => {
    // NOTE(review): wrapToolCall receives no `state` argument, so the previous
    // bare `state` reference was a ReferenceError. This assumes the current agent
    // state is exposed as `request.state` — confirm against the createMiddleware
    // API. Without it, caching is skipped rather than failing the tool call.
    const cache = request.state?.cache;
    if (!cache) {
      return handler(request);
    }
    const cacheKey = `${request.toolName}:${JSON.stringify(request.args)}`;
    // Check cache. Use an own-key test so falsy results (0, "", false) still hit.
    if (Object.hasOwn(cache, cacheKey)) {
      console.log(`Cache hit for ${request.toolName}`);
      return cache[cacheKey];
    }
    // Execute and cache (mutates the state's cache object in place;
    // whether the runtime persists that mutation is not shown here)
    const result = await handler(request);
    cache[cacheKey] = result;
    return result;
  },
});

/**
 * Allows at most 10 tool calls in any sliding 60-second window for the
 * current run; the 11th call within the window throws.
 */
const rateLimitMiddleware = createMiddleware({
  name: "rate-limit",
  stateSchema: z.object({
    toolCallTimes: z.array(z.number()),
  }),
  beforeAgent: async (state, runtime) => {
    return {
      ...state,
      toolCallTimes: [],
    };
  },
  wrapToolCall: async (request, handler, runtime) => {
    // NOTE(review): as with caching, this assumes agent state is exposed as
    // `request.state` — confirm. Fail closed: a rate limiter that silently
    // stops limiting is worse than a clear error.
    const state = request.state;
    if (!Array.isArray(state?.toolCallTimes)) {
      throw new Error("rate-limit middleware: request.state.toolCallTimes is not available");
    }
    const now = Date.now();
    // Keep only timestamps from the last minute; storing the pruned list
    // keeps the history bounded instead of growing forever.
    const recentCalls = state.toolCallTimes.filter((t) => now - t < 60000);
    state.toolCallTimes = recentCalls;
    if (recentCalls.length >= 10) {
      throw new Error("Rate limit exceeded: max 10 tool calls per minute");
    }
    recentCalls.push(now);
    return await handler(request);
  },
});