Build LLM-powered applications in Java with support for chatbots, agents, RAG, tools, and much more
Comprehensive RAG framework for augmenting LLM responses with retrieved information. Supports query transformation, routing, content retrieval, aggregation, and injection into prompts.
Main entry point for RAG flow. Augments chat messages with retrieved content.
package dev.langchain4j.rag;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.rag.content.Content;
/**
* Augments a ChatMessage with retrieved Contents.
* <p>
* This is the entry point for the RAG flow in LangChain4j. An instance can be plugged into an
* AI service via {@code AiServices.builder(...).retrievalAugmentor(augmentor)}, after which
* augmentation happens automatically on every call.
*/
public interface RetrievalAugmentor {
/**
* Augments the chat message carried by the request with retrieved contents.
* @param augmentationRequest request carrying the chat message to augment and its metadata
* @return result carrying the augmented message
*/
AugmentationResult augment(AugmentationRequest augmentationRequest);
}Thread Safety: Implementations must be thread-safe as they may be called concurrently by multiple threads in a chat application. The default implementation is thread-safe.
Common Pitfalls:
Edge Cases:
Performance Notes:
Cost Considerations:
Exception Handling:
Related APIs:
- DefaultRetrievalAugmentor - Standard implementation
- ContentRetriever - Core retrieval interface
- AugmentationRequest/Result - Request/response types

Default RAG implementation orchestrating the complete retrieval flow.
package dev.langchain4j.rag;
import dev.langchain4j.rag.content.aggregator.ContentAggregator;
import dev.langchain4j.rag.content.injector.ContentInjector;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.rag.query.router.QueryRouter;
import dev.langchain4j.rag.query.transformer.QueryTransformer;
import java.util.concurrent.Executor;
/**
* Default RetrievalAugmentor implementation, suitable for the majority of RAG use cases.
* <p>
* Each augmentation flows through the pipeline:
* QueryTransformer -> QueryRouter -> ContentRetriever(s) -> ContentAggregator -> ContentInjector.
* Instances are normally created with {@link #builder()}.
*/
public class DefaultRetrievalAugmentor implements RetrievalAugmentor {
/**
* Creates a DefaultRetrievalAugmentor from explicit pipeline components.
* @param queryTransformer transforms the original query into one or more queries
* @param queryRouter routes each query to the retrievers that should handle it
* @param contentAggregator aggregates and ranks the contents retrieved for all queries
* @param contentInjector injects the aggregated contents into the user message
* @param executor executor used to run retrievals in parallel
*/
public DefaultRetrievalAugmentor(
QueryTransformer queryTransformer,
QueryRouter queryRouter,
ContentAggregator contentAggregator,
ContentInjector contentInjector,
Executor executor
);
/**
* Creates a builder for DefaultRetrievalAugmentor.
* @return a new builder
*/
public static Builder builder();
/**
* Fluent builder for DefaultRetrievalAugmentor; setters return the builder for chaining.
*/
public static class Builder {
/**
* Sets the query transformer (first pipeline stage).
* @param queryTransformer query transformer
* @return the builder, for chaining
*/
public Builder queryTransformer(QueryTransformer queryTransformer);
/**
* Sets the query router that selects the retrievers for each query.
* @param queryRouter query router
* @return the builder, for chaining
*/
public Builder queryRouter(QueryRouter queryRouter);
/**
* Sets a single content retriever, for the common one-data-source setup.
* NOTE(review): presumably an alternative to queryRouter(...) (the retriever being wrapped
* in a default router) — confirm before combining the two.
* @param contentRetriever content retriever
* @return the builder, for chaining
*/
public Builder contentRetriever(ContentRetriever contentRetriever);
/**
* Sets the aggregator that merges and ranks contents from all queries and retrievers.
* @param contentAggregator content aggregator
* @return the builder, for chaining
*/
public Builder contentAggregator(ContentAggregator contentAggregator);
/**
* Sets the injector that writes the aggregated contents into the user message.
* @param contentInjector content injector
* @return the builder, for chaining
*/
public Builder contentInjector(ContentInjector contentInjector);
/**
* Sets the executor used to run retrievals in parallel.
* @param executor executor instance
* @return the builder, for chaining
*/
public Builder executor(Executor executor);
/**
* Builds the configured DefaultRetrievalAugmentor.
* @return DefaultRetrievalAugmentor instance
*/
public DefaultRetrievalAugmentor build();
}
}Thread Safety: Thread-safe. Can be used concurrently across multiple threads. Each augmentation is independent.
Common Pitfalls:
Edge Cases:
Performance Notes:
Cost Considerations:
Exception Handling:
Related APIs:
- RetrievalAugmentor - Base interface
- QueryTransformer - Query transformation
- QueryRouter - Query routing
- ContentRetriever - Content retrieval
- ContentAggregator - Content aggregation
- ContentInjector - Content injection

Request and response types for RAG augmentation.
package dev.langchain4j.rag;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.rag.query.Metadata;
/**
* Request passed to RetrievalAugmentor.augment(): the chat message to augment plus metadata.
* Instances are created with {@link #builder()}.
*/
public class AugmentationRequest {
/**
* Gets the chat message to augment.
* @return chat message to augment
*/
public ChatMessage chatMessage();
/**
* Gets the metadata that accompanies the message.
* @return metadata (dev.langchain4j.rag.query.Metadata)
*/
public Metadata metadata();
/**
* Creates a builder for AugmentationRequest.
* @return a new builder
*/
public static Builder builder();
/**
* Fluent builder for AugmentationRequest.
*/
public static class Builder {
/**
* Sets the chat message to augment.
* @param chatMessage chat message
* @return the builder, for chaining
*/
public Builder chatMessage(ChatMessage chatMessage);
/**
* Sets the metadata that accompanies the message.
* @param metadata metadata
* @return the builder, for chaining
*/
public Builder metadata(Metadata metadata);
/**
* Builds the request.
* @return AugmentationRequest
*/
public AugmentationRequest build();
}
}
/**
* Result of RetrievalAugmentor.augment(): the message with retrieved content injected.
*/
public class AugmentationResult {
/**
* Gets the augmented chat message.
* NOTE(review): some LangChain4j versions declare this return type as ChatMessage (and also
* expose the used contents via contents()); confirm against the version in use before
* calling UserMessage-only methods such as singleText().
* @return augmented message with injected content
*/
public UserMessage chatMessage();
/**
* Creates a builder for AugmentationResult.
* @return a new builder
*/
public static Builder builder();
/**
* Fluent builder for AugmentationResult.
*/
public static class Builder {
/**
* Sets the augmented chat message.
* @param chatMessage augmented chat message
* @return the builder, for chaining
*/
public Builder chatMessage(UserMessage chatMessage);
/**
* Builds the result.
* @return AugmentationResult
*/
public AugmentationResult build();
}
}Thread Safety: Immutable after construction. Safe for concurrent access.
Common Pitfalls:
Edge Cases:
Exception Handling:
Related APIs:
- RetrievalAugmentor.augment() - Uses these types
- Metadata - Query metadata
- ChatMessage/UserMessage - Message types

Represents a query used for content retrieval.
package dev.langchain4j.rag.query;
/**
* Represents a query used for content retrieval.
* <p>
* Consists of the query text plus optional metadata; instances are immutable.
*/
public class Query {
/**
* Gets the query text.
* @return query text
*/
public String text();
/**
* Gets the metadata attached to this query.
* NOTE(review): presumably null for queries created with from(String) — confirm before
* dereferencing.
* @return query metadata
*/
public Metadata metadata();
/**
* Creates a query from text only, without metadata.
* @param text query text
* @return query instance
*/
public static Query from(String text);
/**
* Creates a query from text and metadata.
* @param text query text
* @param metadata query metadata
* @return query instance
*/
public static Query from(String text, Metadata metadata);
}Thread Safety: Immutable. Safe for concurrent use across threads.
Common Pitfalls:
Edge Cases:
Performance Notes:
Exception Handling:
Related APIs:
- QueryTransformer - Transforms queries
- ContentRetriever.retrieve() - Accepts Query
- Metadata - Query metadata

Core interface for retrieving relevant content from data sources.
package dev.langchain4j.rag.content.retriever;
import dev.langchain4j.rag.content.Content;
import dev.langchain4j.rag.query.Query;
import java.util.List;
/**
* Retrieves relevant Contents from an underlying data source using a Query.
* <p>
* The data source can be an embedding store, full-text search, web search,
* knowledge graph, SQL database, etc.
*/
public interface ContentRetriever {
/**
* Retrieves contents relevant to the given query.
* @param query query to use for retrieval
* @return retrieved contents, sorted by relevance
*/
List<Content> retrieve(Query query);
}Thread Safety: Implementations must be thread-safe. Standard implementations (EmbeddingStoreContentRetriever, WebSearchContentRetriever) are thread-safe.
Common Pitfalls:
Edge Cases:
Performance Notes:
Cost Considerations:
Exception Handling:
Related APIs:
- EmbeddingStoreContentRetriever - Embedding-based implementation
- WebSearchContentRetriever - Web search implementation
- QueryRouter - Routes queries to retrievers
- Content - Retrieved content type

Content retriever backed by an embedding store for semantic search.
package dev.langchain4j.rag.content.retriever;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.filter.Filter;
/**
* ContentRetriever that performs semantic search over an EmbeddingStore.
* <p>
* The query text is embedded with the configured EmbeddingModel and the nearest stored
* segments are returned. Instances are created with {@link #builder()}.
*/
public class EmbeddingStoreContentRetriever implements ContentRetriever {
/**
* Creates a builder for EmbeddingStoreContentRetriever.
* @return a new builder
*/
public static Builder builder();
/**
* Fluent builder for EmbeddingStoreContentRetriever.
*/
public static class Builder {
/**
* Sets the embedding store to search.
* @param embeddingStore embedding store to search
* @return the builder, for chaining
*/
public Builder embeddingStore(EmbeddingStore<?> embeddingStore);
/**
* Sets the model used to embed queries; it should be the model that embedded the stored segments.
* @param embeddingModel model used to embed queries
* @return the builder, for chaining
*/
public Builder embeddingModel(EmbeddingModel embeddingModel);
/**
* Sets the maximum number of contents returned per query.
* @param maxResults maximum number of results to return
* @return the builder, for chaining
*/
public Builder maxResults(int maxResults);
/**
* Sets the minimum similarity score; matches scoring below it are not returned.
* @param minScore minimum similarity score (0.0 to 1.0)
* @return the builder, for chaining
*/
public Builder minScore(double minScore);
/**
* Sets a filter on segment metadata applied during search.
* @param filter metadata filter
* @return the builder, for chaining
*/
public Builder filter(Filter filter);
/**
* Builds the retriever.
* @return EmbeddingStoreContentRetriever
*/
public EmbeddingStoreContentRetriever build();
}
}Thread Safety: Thread-safe. The underlying EmbeddingStore and EmbeddingModel must also be thread-safe (standard implementations are).
Common Pitfalls:
Edge Cases:
Performance Notes:
Cost Considerations:
Exception Handling:
Related APIs:
- EmbeddingStore - Stores and searches embeddings
- EmbeddingModel - Generates query embeddings
- Filter - Metadata filtering
- Content - Returned content type
- See embedding-store.md for EmbeddingStore details

Content retriever backed by a web search engine.
package dev.langchain4j.rag.content.retriever;
import dev.langchain4j.web.search.WebSearchEngine;
/**
* ContentRetriever that retrieves from a web search engine.
* <p>
* Each web search result is converted into a Content object.
*/
public class WebSearchContentRetriever implements ContentRetriever {
/**
* Creates a WebSearchContentRetriever with the default maximum number of results.
* @param webSearchEngine web search engine to use
*/
public WebSearchContentRetriever(WebSearchEngine webSearchEngine);
/**
* Creates a WebSearchContentRetriever with an explicit result limit.
* @param webSearchEngine web search engine to use
* @param maxResults maximum number of results to retrieve
*/
public WebSearchContentRetriever(WebSearchEngine webSearchEngine, int maxResults);
/**
* Creates a builder for WebSearchContentRetriever.
* @return a new builder
*/
public static Builder builder();
/**
* Fluent builder for WebSearchContentRetriever.
*/
public static class Builder {
/**
* Sets the web search engine.
* @param webSearchEngine web search engine
* @return the builder, for chaining
*/
public Builder webSearchEngine(WebSearchEngine webSearchEngine);
/**
* Sets the maximum number of results to retrieve.
* @param maxResults maximum results
* @return the builder, for chaining
*/
public Builder maxResults(int maxResults);
/**
* Builds the retriever.
* @return WebSearchContentRetriever
*/
public WebSearchContentRetriever build();
}
}Thread Safety: Thread-safe. The underlying WebSearchEngine must also be thread-safe (standard implementations are).
Common Pitfalls:
Edge Cases:
Performance Notes:
Cost Considerations:
Exception Handling:
Related APIs:
- WebSearchEngine - Performs web searches
- Content - Returned content type
- EmbeddingStoreContentRetriever - Alternative retrieval method

Transforms the original query into one or more queries for retrieval.
package dev.langchain4j.rag.query.transformer;
import dev.langchain4j.rag.query.Query;
import java.util.Collection;
/**
* Transforms a Query into one or more Queries.
* <p>
* Examples: query expansion, query compression, multi-query generation.
*/
public interface QueryTransformer {
/**
* Transforms the given query.
* Whether the original query is part of the result depends on the implementation.
* @param query original query
* @return transformed queries
*/
Collection<Query> transform(Query query);
}Thread Safety: Implementations must be thread-safe. LLM-based transformers typically are thread-safe.
Common Pitfalls:
Edge Cases:
Performance Notes:
Cost Considerations:
Exception Handling:
Related APIs:
- Query - Query object
- DefaultRetrievalAugmentor - Uses QueryTransformer
- QueryRouter - Next stage after transformation

Routes queries to appropriate content retrievers.
package dev.langchain4j.rag.query.router;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.rag.query.Query;
import java.util.Collection;
/**
* Routes a Query to the ContentRetrievers that should handle it.
* <p>
* Enables retrieval from multiple specialized data sources.
*/
public interface QueryRouter {
/**
* Selects the retrievers for the given query.
* @param query query to route
* @return retrievers that should be queried for this query
*/
Collection<ContentRetriever> route(Query query);
}Thread Safety: Implementations must be thread-safe. Can be called concurrently for different queries.
Common Pitfalls:
Edge Cases:
Performance Notes:
Cost Considerations:
Exception Handling:
Related APIs:
- ContentRetriever - Target of routing
- Query - Input query
- DefaultRetrievalAugmentor - Uses QueryRouter

Aggregates and ranks retrieved contents.
package dev.langchain4j.rag.content.aggregator;
import dev.langchain4j.rag.content.Content;
import dev.langchain4j.rag.query.Query;
import java.util.Collection;
import java.util.List;
import java.util.Map;
/**
* Aggregates Contents retrieved by multiple ContentRetrievers for multiple queries.
* <p>
* Implementations can perform fusion, re-ranking, filtering and deduplication.
*/
public interface ContentAggregator {
/**
* Aggregates the contents retrieved for all queries and retrievers into one list.
* @param contents map from each query to its retrieved result lists — one List&lt;Content&gt;
*                 per ContentRetriever the query was routed to
* @return aggregated contents, ranked by relevance
*/
List<Content> aggregate(Map<Query, Collection<List<Content>>> contents);
}
Thread Safety: Implementations must be thread-safe. Called once per augmentation with results from all retrievers.
Common Pitfalls:
Edge Cases:
Performance Notes:
Cost Considerations:
Exception Handling:
Related APIs:
- Content - Aggregated content type
- Query - Query that generated results
- DefaultRetrievalAugmentor - Uses ContentAggregator

Injects retrieved contents into the user message.
package dev.langchain4j.rag.content.injector;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.rag.content.Content;
import java.util.List;
/**
* Injects retrieved Contents into a UserMessage.
* <p>
* Formats the contents as context for the LLM.
*/
public interface ContentInjector {
/**
* Injects the contents into the user message.
* @param contents retrieved (already aggregated) contents to inject
* @param userMessage original user message
* @return user message with the contents injected
*/
UserMessage inject(List<Content> contents, UserMessage userMessage);
}Thread Safety: Implementations must be thread-safe. Can be called concurrently for different messages.
Common Pitfalls:
Edge Cases:
Performance Notes:
Cost Considerations:
Exception Handling:
Related APIs:
- Content - Injected content type
- UserMessage - Message with injected content
- DefaultRetrievalAugmentor - Uses ContentInjector

Represents a piece of retrieved content.
package dev.langchain4j.rag.content;
import dev.langchain4j.data.segment.TextSegment;
import java.util.Map;
/**
* Represents a piece of content retrieved during RAG.
* <p>
* Wraps a TextSegment (text plus its document metadata) together with retrieval
* metadata attached by the retriever.
*/
public interface Content {
/**
* Gets the retrieved text segment.
* Use {@code content.textSegment().text()} to obtain the plain text.
* @return text segment, including its document metadata
*/
TextSegment textSegment();
/**
* Gets retrieval metadata attached by the retriever, e.g. the similarity score under
* {@link ContentMetadata#SCORE} set by EmbeddingStoreContentRetriever.
* @return retrieval metadata; may be empty when the retriever attached none
*/
Map<ContentMetadata, Object> metadata();
/**
* Creates content from plain text.
* @param text content text
* @return Content instance
*/
public static Content from(String text);
/**
* Creates content from a TextSegment, keeping its metadata.
* @param textSegment text segment with metadata
* @return Content instance
*/
public static Content from(TextSegment textSegment);
}
Thread Safety: Immutable. Safe for concurrent access.
Common Pitfalls:
Edge Cases:
Exception Handling:
Related APIs:
- TextSegment - Rich text with metadata
- ContentRetriever.retrieve() - Returns Content list
- ContentInjector.inject() - Injects Content

Complete RAG setup with embedding store retrieval:
import dev.langchain4j.rag.DefaultRetrievalAugmentor;
import dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.service.AiServices;
// Create embedding store retriever
EmbeddingStoreContentRetriever retriever = EmbeddingStoreContentRetriever.builder()
.embeddingStore(embeddingStore)
.embeddingModel(embeddingModel)
.maxResults(5)
.minScore(0.7)
.build();
// Create retrieval augmentor
RetrievalAugmentor augmentor = DefaultRetrievalAugmentor.builder()
.contentRetriever(retriever)
.build();
// Integrate with AI service
Assistant assistant = AiServices.builder(Assistant.class)
.chatModel(chatModel)
.retrievalAugmentor(augmentor)
.build();
// Use assistant - RAG happens automatically
String response = assistant.chat("What are the main features?");Production-ready RAG implementation with comprehensive error handling:
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.chat.ChatModel;
import dev.langchain4j.model.embedding.*;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.rag.*;
import dev.langchain4j.rag.content.Content;
import dev.langchain4j.rag.content.aggregator.*;
import dev.langchain4j.rag.content.injector.*;
import dev.langchain4j.rag.content.retriever.*;
import dev.langchain4j.store.embedding.*;
import dev.langchain4j.service.AiServices;
import java.util.List;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;
import java.util.logging.Logger;

/**
 * Production-oriented RAG wiring. It combines:
 * - embedding retrieval with retries,
 * - a dedicated retrieval thread pool,
 * - a token budget for the context injected into the prompt.
 */
public class ProductionRAGSetup {
    private static final Logger logger = Logger.getLogger(ProductionRAGSetup.class.getName());

    /**
     * Builds an Assistant with retrieval-augmented generation enabled.
     * The retrieval pool is never shut down explicitly (its threads are daemons), so create the
     * assistant once and reuse it rather than building one per request.
     *
     * @param chatModel      chat model used by the AI service
     * @param embeddingStore store searched for relevant segments
     * @param embeddingModel model used to embed user queries (wrapped with retries)
     * @return assistant proxy with RAG enabled
     */
    public static Assistant createRobustAssistant(
            ChatModel chatModel,
            EmbeddingStore<TextSegment> embeddingStore,
            EmbeddingModel embeddingModel) {
        // Dedicated pool for parallel retrieval; daemon threads don't prevent JVM shutdown.
        AtomicInteger threadCounter = new AtomicInteger();
        ExecutorService executor = Executors.newFixedThreadPool(4, runnable -> {
            Thread thread = new Thread(runnable, "rag-retrieval-" + threadCounter.incrementAndGet());
            thread.setDaemon(true);
            return thread;
        });

        // Up to 3 attempts per embedding call, backing off 1s, then 2s
        EmbeddingModel resilientEmbeddingModel = new ResilientEmbeddingModel(embeddingModel, 3, 1000);

        EmbeddingStoreContentRetriever retriever = EmbeddingStoreContentRetriever.builder()
                .embeddingStore(embeddingStore)
                .embeddingModel(resilientEmbeddingModel)
                .maxResults(5)   // balance context vs. tokens
                .minScore(0.65)  // not too high (0.65-0.75 is a good range)
                .build();

        ContentAggregator aggregator = new DefaultContentAggregator();

        // Cap injected context at ~4000 tokens
        ContentInjector injector = new TokenLimitingContentInjector(4000);

        RetrievalAugmentor augmentor = DefaultRetrievalAugmentor.builder()
                .contentRetriever(retriever)
                .contentAggregator(aggregator)
                .contentInjector(injector)
                .executor(executor)
                .build();

        return AiServices.builder(Assistant.class)
                .chatModel(chatModel)
                .retrievalAugmentor(augmentor)
                .build();
    }

    /**
     * EmbeddingModel decorator that retries failed calls with exponential backoff.
     * Every RuntimeException is treated as retryable; narrow this if the delegate
     * distinguishes transient failures (timeouts, rate limits) from permanent ones.
     */
    static class ResilientEmbeddingModel implements EmbeddingModel {
        private final EmbeddingModel delegate;
        private final int maxRetries;
        private final long initialBackoffMs;

        /**
         * @param delegate         model that performs the actual embedding
         * @param maxRetries       total number of attempts, must be >= 1
         * @param initialBackoffMs delay before the second attempt; doubled after each failure
         * @throws IllegalArgumentException if an argument is out of range
         */
        ResilientEmbeddingModel(EmbeddingModel delegate, int maxRetries, long initialBackoffMs) {
            if (delegate == null) {
                throw new IllegalArgumentException("delegate must not be null");
            }
            if (maxRetries < 1) {
                throw new IllegalArgumentException("maxRetries must be >= 1, was " + maxRetries);
            }
            if (initialBackoffMs < 0) {
                throw new IllegalArgumentException("initialBackoffMs must be >= 0, was " + initialBackoffMs);
            }
            this.delegate = delegate;
            this.maxRetries = maxRetries;
            this.initialBackoffMs = initialBackoffMs;
        }

        @Override
        public Response<Embedding> embed(String text) {
            return withRetry("Embedding", () -> delegate.embed(text));
        }

        @Override
        public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
            return withRetry("Batch embedding", () -> delegate.embedAll(textSegments));
        }

        /** Runs the call up to maxRetries times, sleeping with exponential backoff between attempts. */
        private <T> T withRetry(String operation, Supplier<T> call) {
            RuntimeException lastException = null;
            for (int attempt = 1; attempt <= maxRetries; attempt++) {
                try {
                    return call.get();
                } catch (RuntimeException e) {
                    lastException = e;
                    logger.warning(operation + " failed (attempt " + attempt + "/" + maxRetries + "): " + e.getMessage());
                    if (attempt < maxRetries) {
                        sleepBeforeAttempt(attempt + 1);
                    }
                }
            }
            throw new RuntimeException(operation + " failed after " + maxRetries + " attempts", lastException);
        }

        /** Sleeps initialBackoffMs * 2^(nextAttempt-2); the shift is capped to avoid overflow. */
        private void sleepBeforeAttempt(int nextAttempt) {
            try {
                Thread.sleep(initialBackoffMs * (1L << Math.min(nextAttempt - 2, 20)));
            } catch (InterruptedException ie) {
                Thread.currentThread().interrupt(); // preserve the interrupt for the caller
                throw new RuntimeException("Interrupted during retry backoff", ie);
            }
        }
    }

    /**
     * ContentInjector that prepends as many retrieved contents as fit a rough token budget,
     * in the order given (most relevant first), and drops the rest.
     */
    static class TokenLimitingContentInjector implements ContentInjector {
        private final int maxTokens;

        /**
         * @param maxTokens approximate token budget for the injected contents, must be > 0
         * @throws IllegalArgumentException if maxTokens is not positive
         */
        TokenLimitingContentInjector(int maxTokens) {
            if (maxTokens <= 0) {
                throw new IllegalArgumentException("maxTokens must be > 0, was " + maxTokens);
            }
            this.maxTokens = maxTokens;
        }

        @Override
        public UserMessage inject(List<Content> contents, UserMessage userMessage) {
            if (contents == null || contents.isEmpty()) {
                return userMessage;
            }
            StringBuilder context = new StringBuilder();
            int tokenCount = 0;
            int includedCount = 0;
            for (Content content : contents) {
                // textSegment() is a TextSegment; text() yields the plain string
                String text = content.textSegment().text();
                int estimatedTokens = estimateTokens(text);
                if (tokenCount + estimatedTokens > maxTokens) {
                    break;
                }
                context.append("---\n").append(text).append("\n\n");
                tokenCount += estimatedTokens;
                includedCount++;
            }
            if (includedCount < contents.size()) {
                logger.info("Token limit reached. Included " + includedCount + " of " + contents.size() + " contents.");
            }
            if (includedCount == 0) {
                // Nothing fits: send the user's message unchanged instead of an empty context section.
                return userMessage;
            }
            // NOTE(review): singleText() assumes a text-only user message; confirm for multimodal input.
            return UserMessage.from("Context information:\n\n" + context + "---\n\n" + "Query: " + userMessage.singleText());
        }

        /** Rough estimate (1 token ≈ 4 characters), rounded up so short snippets are never free. */
        private static int estimateTokens(String text) {
            return (text.length() + 3) / 4;
        }
    }
}
Techniques for optimizing RAG queries:
import dev.langchain4j.rag.query.transformer.QueryTransformer;
import dev.langchain4j.rag.query.Query;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Expands queries with synonyms and related terms to improve recall.
 * <p>
 * For every key that occurs in the query as a whole word (case-insensitive), one variation
 * per synonym is generated with the key replaced. The original query is always returned first.
 * Duplicate variations are dropped. At most {@code maxQueries} queries are returned, to bound
 * retrieval cost. Keys are tried in the iteration order of the supplied map; Map.of() has no
 * defined order, so pass a LinkedHashMap when it matters which variations survive the limit.
 */
public class SynonymQueryExpander implements QueryTransformer {
    private static final int DEFAULT_MAX_QUERIES = 5;

    /** Precompiled whole-word pattern per key, mapped to that key's synonyms. */
    private final Map<Pattern, List<String>> synonymPatterns = new LinkedHashMap<>();
    private final int maxQueries;

    /**
     * @param synonymMap key term -> synonyms to substitute for it
     */
    public SynonymQueryExpander(Map<String, List<String>> synonymMap) {
        this(synonymMap, DEFAULT_MAX_QUERIES);
    }

    /**
     * @param synonymMap key term -> synonyms to substitute for it
     * @param maxQueries maximum number of queries returned, including the original (>= 1)
     * @throws IllegalArgumentException if maxQueries < 1
     */
    public SynonymQueryExpander(Map<String, List<String>> synonymMap, int maxQueries) {
        Objects.requireNonNull(synonymMap, "synonymMap");
        if (maxQueries < 1) {
            throw new IllegalArgumentException("maxQueries must be >= 1, was " + maxQueries);
        }
        // Whole-word match (so "error" does not fire inside "terror"), case-insensitive,
        // key taken literally. Lookarounds instead of \b also work for keys like "c++".
        synonymMap.forEach((term, synonyms) -> synonymPatterns.put(
                Pattern.compile("(?<!\\w)" + Pattern.quote(term) + "(?!\\w)",
                        Pattern.CASE_INSENSITIVE | Pattern.UNICODE_CASE),
                List.copyOf(synonyms)));
        this.maxQueries = maxQueries;
    }

    @Override
    public Collection<Query> transform(Query query) {
        String text = query.text();
        // Variations are built from the original text, so the user's casing is preserved.
        Set<String> variations = new LinkedHashSet<>();
        for (Map.Entry<Pattern, List<String>> entry : synonymPatterns.entrySet()) {
            Matcher matcher = entry.getKey().matcher(text);
            if (!matcher.find()) {
                continue;
            }
            for (String synonym : entry.getValue()) {
                // replaceAll resets the matcher, so it can be reused for each synonym
                variations.add(matcher.replaceAll(Matcher.quoteReplacement(synonym)));
            }
        }
        variations.remove(text); // the original is added below as the original Query object

        List<Query> expanded = new ArrayList<>();
        expanded.add(query); // always include the original
        for (String variation : variations) {
            if (expanded.size() >= maxQueries) {
                break;
            }
            // NOTE(review): Query.from(text, null) may be rejected; keep metadata only when present.
            expanded.add(query.metadata() == null
                    ? Query.from(variation)
                    : Query.from(variation, query.metadata()));
        }
        return expanded;
    }
}

// Usage
Map<String, List<String>> synonyms = Map.of(
    "error", List.of("exception", "failure", "issue"),
    "configure", List.of("setup", "initialize", "set up")
);
QueryTransformer expander = new SynonymQueryExpander(synonyms);
RetrievalAugmentor augmentor = DefaultRetrievalAugmentor.builder()
    .queryTransformer(expander)
    .contentRetriever(retriever)
    .build();
import dev.langchain4j.store.embedding.filter.Filter;
import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo;
import dev.langchain4j.store.embedding.filter.comparison.IsGreaterThanOrEqualTo;
import dev.langchain4j.store.embedding.filter.comparison.IsLessThanOrEqualTo;
import dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingStore;
import java.time.LocalDate;
import java.util.Objects;

/**
 * Factory methods for retrievers that restrict the search by segment metadata.
 * Filtering improves precision by excluding irrelevant segments before similarity ranking.
 */
public class MetadataFilteredRetriever {
    private static final int MAX_RESULTS = 5;
    private static final double MIN_SCORE = 0.7;

    /**
     * Creates a retriever that only searches segments whose "category" metadata equals the given value.
     *
     * @param embeddingStore store to search
     * @param embeddingModel model used to embed queries
     * @param category       required value of the "category" metadata key
     * @return filtered retriever
     * @throws NullPointerException if category is null
     */
    public static EmbeddingStoreContentRetriever createCategoryFilteredRetriever(
            EmbeddingStore<TextSegment> embeddingStore,
            EmbeddingModel embeddingModel,
            String category) {
        Objects.requireNonNull(category, "category");
        return buildRetriever(embeddingStore, embeddingModel, new IsEqualTo("category", category));
    }

    /**
     * Creates a retriever that only searches segments whose "date" metadata lies in [startDate, endDate].
     * Bounds are compared as ISO-8601 strings (yyyy-MM-dd), which sort chronologically.
     * NOTE(review): this assumes the "date" metadata is stored in that same string format — confirm.
     *
     * @param embeddingStore store to search
     * @param embeddingModel model used to embed queries
     * @param startDate      inclusive lower bound
     * @param endDate        inclusive upper bound
     * @return filtered retriever
     * @throws IllegalArgumentException if startDate is after endDate (the range would match nothing)
     */
    public static EmbeddingStoreContentRetriever createDateRangeRetriever(
            EmbeddingStore<TextSegment> embeddingStore,
            EmbeddingModel embeddingModel,
            LocalDate startDate,
            LocalDate endDate) {
        Objects.requireNonNull(startDate, "startDate");
        Objects.requireNonNull(endDate, "endDate");
        if (startDate.isAfter(endDate)) {
            throw new IllegalArgumentException("startDate " + startDate + " is after endDate " + endDate);
        }
        Filter filter = Filter.and(
                new IsGreaterThanOrEqualTo("date", startDate.toString()),
                new IsLessThanOrEqualTo("date", endDate.toString()));
        return buildRetriever(embeddingStore, embeddingModel, filter);
    }

    /** Builds a retriever with the shared result limit and score threshold plus the given filter. */
    private static EmbeddingStoreContentRetriever buildRetriever(
            EmbeddingStore<TextSegment> embeddingStore, EmbeddingModel embeddingModel, Filter filter) {
        return EmbeddingStoreContentRetriever.builder()
                .embeddingStore(embeddingStore)
                .embeddingModel(embeddingModel)
                .maxResults(MAX_RESULTS)
                .minScore(MIN_SCORE)
                .filter(filter)
                .build();
    }
}

// Usage
EmbeddingStoreContentRetriever retriever = MetadataFilteredRetriever
    .createCategoryFilteredRetriever(embeddingStore, embeddingModel, "documentation");
import dev.langchain4j.rag.query.router.QueryRouter;
import dev.langchain4j.rag.query.Query;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import java.util.*;
import java.util.regex.Pattern;

/**
 * Routes queries to specialized retrievers based on keywords in the query text.
 * This improves precision by using domain-specific data sources.
 * <p>
 * A keyword matches case-insensitively where a word begins with it. For example, "api" matches
 * "API" and "apis" but not "rapid". If no keyword matches, the query goes to the default retriever.
 * A retriever selected by several keywords is returned only once.
 */
public class DynamicQueryRouter implements QueryRouter {
    /** Precompiled keyword pattern -> retriever, in the caller's map iteration order. */
    private final Map<Pattern, ContentRetriever> keywordRoutes = new LinkedHashMap<>();
    private final ContentRetriever defaultRetriever;

    /**
     * @param retrieversByKeyword keyword -> retriever to use when the keyword appears in the query
     * @param defaultRetriever    retriever used when no keyword matches
     * @throws NullPointerException if any argument or mapped retriever is null
     */
    public DynamicQueryRouter(
            Map<String, ContentRetriever> retrieversByKeyword,
            ContentRetriever defaultRetriever) {
        Objects.requireNonNull(retrieversByKeyword, "retrieversByKeyword");
        this.defaultRetriever = Objects.requireNonNull(defaultRetriever, "defaultRetriever");
        retrieversByKeyword.forEach((keyword, retriever) -> keywordRoutes.put(
                // Pattern flags avoid locale-sensitive toLowerCase(); the lookbehind anchors at a word start
                Pattern.compile("(?<!\\w)" + Pattern.quote(keyword),
                        Pattern.CASE_INSENSITIVE | Pattern.UNICODE_CASE),
                Objects.requireNonNull(retriever, "retriever")));
    }

    @Override
    public Collection<ContentRetriever> route(Query query) {
        String text = query.text();
        // LinkedHashSet: deterministic order and no duplicate retrieval from the same source
        Set<ContentRetriever> selected = new LinkedHashSet<>();
        for (Map.Entry<Pattern, ContentRetriever> route : keywordRoutes.entrySet()) {
            if (route.getKey().matcher(text).find()) {
                selected.add(route.getValue());
            }
        }
        if (selected.isEmpty()) {
            selected.add(defaultRetriever);
        }
        return new ArrayList<>(selected);
    }
}

// Usage (Map.of has no defined iteration order; use a LinkedHashMap if routing order matters)
Map<String, ContentRetriever> retrievers = Map.of(
    "api", apiDocumentationRetriever,
    "tutorial", tutorialRetriever,
    "error", troubleshootingRetriever
);
QueryRouter router = new DynamicQueryRouter(retrievers, generalRetriever);
RetrievalAugmentor augmentor = DefaultRetrievalAugmentor.builder()
    .queryRouter(router)
    .build();
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.rag.content.Content;
import dev.langchain4j.rag.content.ContentMetadata;
import dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever;
import dev.langchain4j.rag.query.Query;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
 * Adjusts the minimum score threshold based on query length.
 * <p>
 * Short queries (1-3 words) use baseMinScore - 0.1. Medium queries (4-10 words) use baseMinScore.
 * Longer, more specific queries use baseMinScore + 0.05. The delegate's own minScore is applied
 * first, so configure it at or below the lowest threshold for the relaxation to have any effect.
 * The delegate is typically an EmbeddingStoreContentRetriever, which records each result's
 * similarity score under ContentMetadata.SCORE.
 */
public class AdaptiveScoreRetriever implements ContentRetriever {
    private final ContentRetriever delegate;
    private final double baseMinScore;

    /**
     * @param delegate     retriever whose results are filtered (widened from
     *                     EmbeddingStoreContentRetriever; existing callers still compile)
     * @param baseMinScore threshold used for medium-length queries
     */
    public AdaptiveScoreRetriever(
            ContentRetriever delegate,
            double baseMinScore) {
        if (delegate == null) {
            throw new IllegalArgumentException("delegate must not be null");
        }
        this.delegate = delegate;
        this.baseMinScore = baseMinScore;
    }

    @Override
    public List<Content> retrieve(Query query) {
        List<Content> candidates = delegate.retrieve(query);
        double threshold = calculateAdaptiveScore(query);
        return candidates.stream()
                .filter(content -> passesThreshold(content, threshold))
                .collect(Collectors.toList());
    }

    /** Picks the threshold from the number of whitespace-separated words in the query. */
    private double calculateAdaptiveScore(Query query) {
        String trimmed = query.text().trim();
        // trim() first: a leading space would otherwise produce an empty first token
        int wordCount = trimmed.isEmpty() ? 0 : trimmed.split("\\s+").length;
        if (wordCount <= 3) {
            return baseMinScore - 0.1;
        }
        if (wordCount <= 10) {
            return baseMinScore;
        }
        return baseMinScore + 0.05;
    }

    /**
     * Compares the score stored under ContentMetadata.SCORE with the threshold.
     * Contents without a numeric score cannot be judged and are kept rather than dropped.
     */
    private static boolean passesThreshold(Content content, double threshold) {
        Map<ContentMetadata, Object> metadata = content.metadata();
        Object score = metadata == null ? null : metadata.get(ContentMetadata.SCORE);
        if (!(score instanceof Number)) {
            return true;
        }
        return ((Number) score).doubleValue() >= threshold;
    }
}
The DefaultRetrievalAugmentor orchestrates this flow:
Strategies for testing RAG implementations:
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever;
import dev.langchain4j.rag.query.Query;
import dev.langchain4j.rag.content.Content;
import dev.langchain4j.rag.content.ContentMetadata;
import java.util.List;

/**
 * Unit tests for a ContentRetriever.
 * Assumes the test class provides createTestRetriever() (configured with maxResults = 5)
 * and embeddingStore / embeddingModel fixtures populated with test data.
 */
public class ContentRetrieverTest {
    @Test
    void testRetrieverReturnsResults() {
        // Arrange
        ContentRetriever retriever = createTestRetriever();
        Query query = Query.from("test query");
        // Act
        List<Content> results = retriever.retrieve(query);
        // Assert
        assertNotNull(results);
        assertFalse(results.isEmpty());
        assertTrue(results.size() <= 5); // respects maxResults
    }

    @Test
    void testRetrieverWithEmptyQuery() {
        // Arrange
        ContentRetriever retriever = createTestRetriever();
        final Query query;
        try {
            query = Query.from("");
        } catch (IllegalArgumentException expected) {
            // Blank query text may be rejected at construction; then nothing reaches the retriever.
            return;
        }
        // Act
        List<Content> results = retriever.retrieve(query);
        // Assert
        assertNotNull(results);
        assertTrue(results.isEmpty()); // an accepted empty query must return no results
    }

    @Test
    void testRetrieverRespectsMinScore() {
        // Arrange
        double minScore = 0.8;
        EmbeddingStoreContentRetriever retriever = EmbeddingStoreContentRetriever.builder()
            .embeddingStore(embeddingStore)
            .embeddingModel(embeddingModel)
            .maxResults(10)
            .minScore(minScore)
            .build();
        Query query = Query.from("specific technical term");
        // Act
        List<Content> results = retriever.retrieve(query);
        // Assert: every result carries its similarity score and meets the threshold
        assertNotNull(results);
        for (Content content : results) {
            Object score = content.metadata().get(ContentMetadata.SCORE);
            assertNotNull(score, "retrieved content should carry a similarity score");
            assertTrue(((Number) score).doubleValue() >= minScore,
                "score " + score + " is below minScore " + minScore);
        }
    }
}
import org.junit.jupiter.api.Test;
import dev.langchain4j.rag.RetrievalAugmentor;
import dev.langchain4j.rag.AugmentationRequest;
import dev.langchain4j.rag.AugmentationResult;
import dev.langchain4j.data.message.UserMessage;
import static org.junit.jupiter.api.Assertions.*;

/**
 * End-to-end tests of the augmentation flow.
 * Assumes createTestAugmentor() builds an augmentor with two properties:
 * - its injector prefixes context with "Context information:" (like TokenLimitingContentInjector);
 * - its retriever's minScore is high enough that an unrelated query retrieves nothing.
 */
public class RAGIntegrationTest {
    @Test
    void testEndToEndRAGFlow() {
        // Arrange
        RetrievalAugmentor augmentor = createTestAugmentor();
        UserMessage message = UserMessage.from("What is the API for embeddings?");
        AugmentationRequest request = AugmentationRequest.builder()
            .chatMessage(message)
            .build();
        // Act
        AugmentationResult result = augmentor.augment(request);
        // Assert
        assertNotNull(result);
        assertNotNull(result.chatMessage());
        String augmentedText = textOf(result);
        assertTrue(augmentedText.contains("Context information:")); // injected content
        assertTrue(augmentedText.contains("What is the API for embeddings?")); // original query
    }

    @Test
    void testRAGWithNoMatchingContent() {
        // Arrange
        RetrievalAugmentor augmentor = createTestAugmentor();
        UserMessage message = UserMessage.from("completely irrelevant query xyz123");
        AugmentationRequest request = AugmentationRequest.builder()
            .chatMessage(message)
            .build();
        // Act
        AugmentationResult result = augmentor.augment(request);
        // Assert: the message passes through unchanged when no content is retrieved
        assertNotNull(result);
        assertEquals(message.singleText(), textOf(result));
    }

    /** The cast keeps this working whether chatMessage() is declared as UserMessage or ChatMessage. */
    private static String textOf(AugmentationResult result) {
        return ((UserMessage) result.chatMessage()).singleText();
    }
}
import org.junit.jupiter.api.Test;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.rag.AugmentationRequest;
import dev.langchain4j.rag.DefaultRetrievalAugmentor;
import dev.langchain4j.rag.RetrievalAugmentor;
import dev.langchain4j.rag.content.Content;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.rag.query.Query;
import static org.junit.jupiter.api.Assertions.*;

/**
 * Coarse latency checks for retrieval and augmentation.
 * Assumes the test class provides createTestRetriever().
 */
public class RAGPerformanceTest {
    private static final int REQUESTS = 100;

    @Test
    void testRetrievalLatency() {
        // Arrange
        ContentRetriever retriever = createTestRetriever();
        Query query = Query.from("test query");
        // Act
        Instant start = Instant.now();
        List<Content> results = retriever.retrieve(query);
        Duration latency = Duration.between(start, Instant.now());
        // Assert
        assertNotNull(results);
        assertTrue(latency.toMillis() < 1000,
            "Retrieval should complete in <1s, took: " + latency.toMillis() + "ms");
    }

    @Test
    void testParallelRetrievalPerformance() throws Exception {
        ContentRetriever retriever = createTestRetriever();
        // Separate pools: if the request-driving tasks ran on the augmentor's own bounded
        // retrieval pool, all 4 threads could block waiting for retrieval tasks queued
        // behind them (thread-starvation deadlock).
        ExecutorService retrievalExecutor = Executors.newFixedThreadPool(4);
        ExecutorService clientExecutor = Executors.newFixedThreadPool(4);
        try {
            RetrievalAugmentor augmentor = DefaultRetrievalAugmentor.builder()
                .contentRetriever(retriever)
                .executor(retrievalExecutor)
                .build();

            // Act - submit concurrent requests; each task measures its own latency
            List<Future<Long>> futures = new ArrayList<>();
            for (int i = 0; i < REQUESTS; i++) {
                String text = "test query " + i; // deterministic, unique per request
                futures.add(clientExecutor.submit(() -> {
                    AugmentationRequest request = AugmentationRequest.builder()
                        .chatMessage(UserMessage.from(text))
                        .build();
                    long startNanos = System.nanoTime();
                    augmentor.augment(request);
                    return System.nanoTime() - startNanos;
                }));
            }
            long totalNanos = 0;
            for (Future<Long> future : futures) {
                totalNanos += future.get(30, TimeUnit.SECONDS); // fail instead of hanging forever
            }

            // Assert - mean per-request latency (not wall time divided by request count)
            double avgLatencyMs = totalNanos / 1_000_000.0 / REQUESTS;
            assertTrue(avgLatencyMs < 500,
                "Average latency should be <500ms, was: " + avgLatencyMs + "ms");
        } finally {
            clientExecutor.shutdownNow();
            retrievalExecutor.shutdownNow();
        }
    }
}
import org.junit.jupiter.api.Test;
import java.util.Map;
public class RAGQualityTest {
// Golden dataset: query -> text expected to appear in at least one of its top-3 results
private static final Map<String, String> GOLDEN_QUERIES = Map.of(
        "How do I embed text?", "EmbeddingModel",
        "What is RAG?", "Retrieval Augmented Generation",
        "How to store embeddings?", "EmbeddingStore"
);
@Test
void testRetrievalQuality() {
    // Arrange
    ContentRetriever retriever = createTestRetriever();
    int correctRetrievals = 0;
    // Act
    for (Map.Entry<String, String> entry : GOLDEN_QUERIES.entrySet()) {
        Query query = Query.from(entry.getKey());
        List<Content> results = retriever.retrieve(query);
        String expected = entry.getValue();
        // Check if expected content is in top 3 results.
        // textSegment() returns a TextSegment, not a String, so match against its text().
        boolean foundExpected = results.stream()
                .limit(3)
                .anyMatch(content -> content.textSegment().text().contains(expected));
        if (foundExpected) {
            correctRetrievals++;
        }
    }
    // Assert
    double accuracy = (double) correctRetrievals / GOLDEN_QUERIES.size();
    assertTrue(accuracy >= 0.8,
            "Retrieval accuracy should be >=80%, was: " + (accuracy * 100) + "%");
}
}Strategies for handling failures in RAG pipelines:
import dev.langchain4j.rag.RetrievalAugmentor;
import dev.langchain4j.rag.AugmentationRequest;
import dev.langchain4j.rag.AugmentationResult;
import java.util.logging.Logger;
/**
 * RAG augmentor that falls back to no augmentation on failure.
 * Ensures the system remains functional even when retrieval fails: any exception
 * thrown by the delegate is logged (with its stack trace) and the request's original
 * chat message is returned as-is, without retrieved content.
 */
public class GracefullyDegradingAugmentor implements RetrievalAugmentor {
    private static final Logger logger = Logger.getLogger(GracefullyDegradingAugmentor.class.getName());
    private final RetrievalAugmentor delegate;

    /**
     * @param delegate augmentor to call first
     * @throws NullPointerException if {@code delegate} is null
     */
    public GracefullyDegradingAugmentor(RetrievalAugmentor delegate) {
        this.delegate = java.util.Objects.requireNonNull(delegate, "delegate");
    }

    @Override
    public AugmentationResult augment(AugmentationRequest request) {
        try {
            return delegate.augment(request);
        } catch (Exception e) {
            // Broad catch is deliberate: this class is the degradation boundary.
            // Pass the exception itself so the stack trace is not lost.
            logger.log(java.util.logging.Level.SEVERE,
                    "RAG augmentation failed, falling back to no augmentation: " + e.getMessage(), e);
            // Return the original message object unchanged. Rebuilding a UserMessage from
            // its text would drop non-text content and relies on text(), which is not
            // available on ChatMessage in current LangChain4j releases.
            return AugmentationResult.builder()
                    .chatMessage(request.chatMessage())
                    .build();
        }
    }
}
// Usage
// Wrap the standard augmentor so that an exception thrown during augmentation is
// logged and the request falls back to no augmentation instead of failing.
RetrievalAugmentor robustAugmentor = new GracefullyDegradingAugmentor(
DefaultRetrievalAugmentor.builder()
.contentRetriever(retriever)
.build()
);import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
/**
 * Circuit breaker for retrieval operations.
 * Prevents cascading failures by short-circuiting requests after a threshold of
 * consecutive failures.
 *
 * <p>States:
 * <ul>
 *   <li>CLOSED: requests go to the delegate; consecutive failures are counted.</li>
 *   <li>OPEN: requests return an empty list without calling the delegate until
 *       {@code resetTimeoutMs} has elapsed since the last failure.</li>
 *   <li>HALF_OPEN: a single probe request is let through. Success closes the circuit
 *       and failure re-opens it immediately. Other requests arriving during the probe
 *       are short-circuited.</li>
 * </ul>
 *
 * <p>State transitions use compare-and-set, so one instance may be shared by
 * concurrent callers.
 */
public class CircuitBreakerRetriever implements ContentRetriever {
    private static final Logger logger = Logger.getLogger(CircuitBreakerRetriever.class.getName());
    private final ContentRetriever delegate;
    private final int failureThreshold;
    private final long resetTimeoutMs;
    private final AtomicInteger consecutiveFailures = new AtomicInteger(0);
    private final AtomicLong lastFailureTime = new AtomicLong(0);

    private enum State { CLOSED, OPEN, HALF_OPEN }

    // Atomic (not just volatile) so exactly one thread can claim the OPEN -> HALF_OPEN probe
    private final java.util.concurrent.atomic.AtomicReference<State> state =
            new java.util.concurrent.atomic.AtomicReference<>(State.CLOSED);

    /**
     * @param delegate         retriever protected by this breaker
     * @param failureThreshold consecutive failures (while CLOSED) that open the circuit
     * @param resetTimeoutMs   time after the last failure before a HALF_OPEN probe is allowed
     * @throws NullPointerException if {@code delegate} is null
     */
    public CircuitBreakerRetriever(
            ContentRetriever delegate,
            int failureThreshold,
            long resetTimeoutMs
    ) {
        this.delegate = java.util.Objects.requireNonNull(delegate, "delegate");
        this.failureThreshold = failureThreshold;
        this.resetTimeoutMs = resetTimeoutMs;
    }

    @Override
    public List<Content> retrieve(Query query) {
        if (!tryAcquirePermission()) {
            logger.warning("Circuit breaker is OPEN, returning empty results");
            return Collections.emptyList();
        }
        boolean succeeded = false;
        try {
            List<Content> results = delegate.retrieve(query);
            succeeded = true;
            return results;
        } finally {
            // Accounting happens in finally rather than a catch block. An Error thrown by
            // the delegate therefore cannot leave a HALF_OPEN probe unresolved, which
            // would otherwise block every later request.
            if (succeeded) {
                onSuccess();
            } else {
                onFailure();
            }
        }
    }

    /** Returns true if this request may call the delegate; may claim the HALF_OPEN probe. */
    private boolean tryAcquirePermission() {
        State current = state.get();
        if (current == State.CLOSED) {
            return true;
        }
        if (current == State.HALF_OPEN) {
            return false; // a probe is already in flight
        }
        // OPEN: after the timeout, only the thread that wins the CAS becomes the probe
        if (System.currentTimeMillis() - lastFailureTime.get() > resetTimeoutMs
                && state.compareAndSet(State.OPEN, State.HALF_OPEN)) {
            logger.info("Circuit breaker entering HALF_OPEN state");
            return true;
        }
        return false;
    }

    private void onSuccess() {
        consecutiveFailures.set(0);
        if (state.compareAndSet(State.HALF_OPEN, State.CLOSED)) {
            logger.info("Circuit breaker reset to CLOSED state");
        }
    }

    private void onFailure() {
        int failures = consecutiveFailures.incrementAndGet();
        lastFailureTime.set(System.currentTimeMillis());
        // A failed probe re-opens immediately, regardless of the failure count. The count
        // may have been reset by a success that landed while the circuit was OPEN.
        if (state.compareAndSet(State.HALF_OPEN, State.OPEN)) {
            logger.severe("Circuit breaker re-opened after failed HALF_OPEN probe");
        } else if (failures >= failureThreshold && state.compareAndSet(State.CLOSED, State.OPEN)) {
            logger.severe("Circuit breaker opened after " + failures + " consecutive failures");
        }
    }
}
// Usage
// While the circuit is open, retrieve() returns an empty list instead of calling
// embeddingStoreRetriever.
ContentRetriever resilientRetriever = new CircuitBreakerRetriever(
embeddingStoreRetriever,
5, // Open after 5 consecutive failures
30000 // Allow a trial request 30 seconds after the last failure
);/**
* Retries retrieval with exponential backoff on transient failures
*/
public class RetryableRetriever implements ContentRetriever {
    private static final Logger logger = Logger.getLogger(RetryableRetriever.class.getName());
    // Bounds the cause-chain walk in isRetriable, guarding against cyclic cause chains
    private static final int MAX_CAUSE_DEPTH = 10;
    private final ContentRetriever delegate;
    private final int maxRetries;
    private final long initialBackoffMs;

    /**
     * Creates a retrying wrapper.
     *
     * @param delegate         retriever to call
     * @param maxRetries       total number of attempts, including the first call; must be >= 1
     * @param initialBackoffMs delay before the second attempt, doubled for each later attempt; must be >= 0
     * @throws IllegalArgumentException if {@code maxRetries < 1} or {@code initialBackoffMs < 0}
     * @throws NullPointerException     if {@code delegate} is null
     */
    public RetryableRetriever(
            ContentRetriever delegate,
            int maxRetries,
            long initialBackoffMs
    ) {
        if (maxRetries < 1) {
            // With 0 the retry loop would never call the delegate and would always fail
            throw new IllegalArgumentException("maxRetries must be >= 1, was: " + maxRetries);
        }
        if (initialBackoffMs < 0) {
            throw new IllegalArgumentException("initialBackoffMs must be >= 0, was: " + initialBackoffMs);
        }
        this.delegate = java.util.Objects.requireNonNull(delegate, "delegate");
        this.maxRetries = maxRetries;
        this.initialBackoffMs = initialBackoffMs;
    }

    @Override
    public List<Content> retrieve(Query query) {
        Exception lastException = null;
        for (int attempt = 0; attempt < maxRetries; attempt++) {
            try {
                return delegate.retrieve(query);
            } catch (Exception e) {
                lastException = e;
                if (!isRetriable(e)) {
                    logger.severe("Non-retriable exception, not retrying: " + e.getMessage());
                    throw new RuntimeException("Retrieval failed", e);
                }
                if (attempt < maxRetries - 1) {
                    long backoffMs = backoffMs(attempt);
                    logger.warning("Retrieval failed (attempt " + (attempt + 1) + "/" + maxRetries +
                            "), retrying in " + backoffMs + "ms: " + e.getMessage());
                    try {
                        Thread.sleep(backoffMs);
                    } catch (InterruptedException ie) {
                        Thread.currentThread().interrupt();
                        throw new RuntimeException("Interrupted during retry", ie);
                    }
                }
            }
        }
        throw new RuntimeException("Retrieval failed after " + maxRetries + " attempts", lastException);
    }

    /** Returns initialBackoffMs * 2^attempt, saturating at Long.MAX_VALUE instead of overflowing. */
    private long backoffMs(int attempt) {
        if (initialBackoffMs == 0) {
            return 0;
        }
        if (attempt >= Long.SIZE - 1 || initialBackoffMs > (Long.MAX_VALUE >> attempt)) {
            return Long.MAX_VALUE;
        }
        return initialBackoffMs << attempt;
    }

    /**
     * Heuristically decides whether a failure is transient (network errors, timeouts,
     * rate limits). Walks the cause chain, because clients often wrap the underlying
     * I/O exception. At each level it checks known timeout/connect exception types and
     * then the message, case-insensitively. The message check covers JDK wording such
     * as "Read timed out".
     */
    private boolean isRetriable(Exception e) {
        Throwable current = e;
        for (int depth = 0; current != null && depth < MAX_CAUSE_DEPTH; depth++, current = current.getCause()) {
            if (current instanceof java.net.SocketTimeoutException
                    || current instanceof java.net.ConnectException
                    || current instanceof java.util.concurrent.TimeoutException) {
                return true;
            }
            String message = current.getMessage();
            if (message != null) {
                String normalized = message.toLowerCase(java.util.Locale.ROOT);
                if (normalized.contains("timeout")
                        || normalized.contains("timed out")
                        || normalized.contains("network")
                        || normalized.contains("rate limit")
                        || normalized.contains("503")
                        || normalized.contains("429")) {
                    return true;
                }
            }
        }
        return false;
    }
}
// Usage
// Note: the second argument is the total number of attempts, not the number of retries.
ContentRetriever retriever = new RetryableRetriever(
embeddingStoreRetriever,
3, // Up to 3 attempts in total (first call + at most 2 retries)
1000 // Start with 1s backoff, doubled after each failed attempt
);/**
* Tries multiple retrievers in sequence until one succeeds
*/
public class FallbackChainRetriever implements ContentRetriever {
    private static final Logger logger = Logger.getLogger(FallbackChainRetriever.class.getName());
    private final List<ContentRetriever> retrievers;

    /**
     * @param retrievers retrievers to try, in order; must be non-empty and contain no nulls
     * @throws NullPointerException     if the list or any element is null
     * @throws IllegalArgumentException if the list is empty
     */
    public FallbackChainRetriever(List<ContentRetriever> retrievers) {
        java.util.Objects.requireNonNull(retrievers, "retrievers");
        if (retrievers.isEmpty()) {
            throw new IllegalArgumentException("Must provide at least one retriever");
        }
        // Immutable snapshot, so later changes to the caller's list cannot alter the chain
        this.retrievers = List.copyOf(retrievers);
    }

    /**
     * Tries each retriever in order and returns the first non-empty result.
     * A retriever that throws is logged and skipped.
     *
     * @return the first non-empty result, or an empty list if at least one retriever
     *         succeeded but none returned content
     * @throws RuntimeException if every retriever threw. The last failure is the cause;
     *         earlier failures are attached as suppressed exceptions.
     */
    @Override
    public List<Content> retrieve(Query query) {
        List<Exception> failures = new java.util.ArrayList<>();
        for (int i = 0; i < retrievers.size(); i++) {
            ContentRetriever retriever = retrievers.get(i);
            try {
                logger.info("Trying retriever " + (i + 1) + "/" + retrievers.size());
                List<Content> results = retriever.retrieve(query);
                if (results != null && !results.isEmpty()) {
                    logger.info("Retriever " + (i + 1) + " succeeded with " + results.size() + " results");
                    return results;
                }
                logger.info("Retriever " + (i + 1) + " returned no results, trying next");
            } catch (Exception e) {
                failures.add(e);
                logger.warning("Retriever " + (i + 1) + " failed: " + e.getMessage());
            }
        }
        // Only report failure when every retriever actually failed. If one succeeded with
        // no results, "no content" is the correct answer.
        if (failures.size() == retrievers.size()) {
            RuntimeException error = new RuntimeException("All retrievers failed", failures.get(failures.size() - 1));
            for (int j = 0; j < failures.size() - 1; j++) {
                error.addSuppressed(failures.get(j));
            }
            throw error;
        }
        return Collections.emptyList();
    }
}
// Usage
// Retrievers are tried in list order; the first non-empty result is returned, and a
// retriever that throws is logged and skipped.
ContentRetriever fallbackRetriever = new FallbackChainRetriever(List.of(
primaryEmbeddingRetriever, // Try embedding search first
webSearchRetriever, // Fall back to web search
cachedRetriever // Finally try cache
));import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.LongAdder;
/**
 * Retriever decorator that counts requests and failures, tracks latency, and logs
 * alerts when a request is slow or the cumulative failure rate exceeds a threshold.
 *
 * <p>Counters are {@link LongAdder}/{@link AtomicLong}, so one instance can be shared
 * by concurrent callers. {@link #logMetrics()} reads each counter separately, so a
 * snapshot taken under load may be slightly inconsistent.
 */
public class MonitoredRetriever implements ContentRetriever {
    private static final Logger logger = Logger.getLogger(MonitoredRetriever.class.getName());
    private final ContentRetriever delegate;
    private final LongAdder totalRequests = new LongAdder();
    private final LongAdder totalFailures = new LongAdder();
    private final LongAdder totalLatencyMs = new LongAdder();
    private final AtomicLong maxLatencyMs = new AtomicLong(0);
    // Alert thresholds
    private final double failureRateThreshold = 0.1; // 10%
    private final long latencyThresholdMs = 2000; // 2 seconds

    public MonitoredRetriever(ContentRetriever delegate) {
        this.delegate = delegate;
    }

    @Override
    public List<Content> retrieve(Query query) {
        totalRequests.increment();
        // nanoTime is monotonic; currentTimeMillis can jump with wall-clock adjustments
        long startNanos = System.nanoTime();
        try {
            List<Content> results = delegate.retrieve(query);
            recordLatency(startNanos);
            return results;
        } catch (Exception e) {
            recordLatency(startNanos); // failed requests still count toward latency stats
            recordFailure();
            throw e;
        }
    }

    /** Adds the time elapsed since {@code startNanos} to the latency stats; warns if over threshold. */
    private void recordLatency(long startNanos) {
        long latencyMs = (System.nanoTime() - startNanos) / 1_000_000L;
        totalLatencyMs.add(latencyMs);
        maxLatencyMs.accumulateAndGet(latencyMs, Math::max);
        if (latencyMs > latencyThresholdMs) {
            logger.warning("High latency detected: " + latencyMs + "ms (threshold: " + latencyThresholdMs + "ms)");
        }
    }

    /** Counts one failure and logs an alert if the cumulative failure rate exceeds the threshold. */
    private void recordFailure() {
        totalFailures.increment();
        long requests = totalRequests.sum(); // >= 1: retrieve() increments before calling the delegate
        long failures = totalFailures.sum();
        double failureRate = (double) failures / requests;
        if (failureRate > failureRateThreshold) {
            logger.severe("High failure rate detected: " +
                    String.format("%.2f%%", failureRate * 100) +
                    " (" + failures + "/" + requests + " requests failed)");
        }
    }

    /** Logs request count, failure count and rate, and average and maximum latency. */
    public void logMetrics() {
        long requests = totalRequests.sum();
        long failures = totalFailures.sum();
        long avgLatency = requests > 0 ? totalLatencyMs.sum() / requests : 0;
        // Guard the division: before any request, 0/0 would be logged as "NaN%"
        double failurePercent = requests > 0 ? (double) failures / requests * 100 : 0.0;
        logger.info("RAG Retrieval Metrics:\n" +
                "  Total Requests: " + requests + "\n" +
                "  Failures: " + failures + " (" +
                String.format("%.2f%%", failurePercent) + ")\n" +
                "  Avg Latency: " + avgLatency + "ms\n" +
                "  Max Latency: " + maxLatencyMs.get() + "ms");
    }
}
// Usage
MonitoredRetriever retriever = new MonitoredRetriever(embeddingStoreRetriever);
// Log metrics every minute, starting one minute after scheduling
ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(1);
// NOTE(review): threads from the default Executors thread factory are non-daemon; call
// scheduler.shutdown() when the application stops so they do not keep the JVM alive.
scheduler.scheduleAtFixedRate(
() -> retriever.logMetrics(),
1, 1, TimeUnit.MINUTES
);embedding-store.md for details on storing and searching embeddingsembedding.md for embedding generation APIschat.md for LLM integrationai-services.md for high-level service integrationfilters.md for metadata filtering capabilitiesInstall with Tessl CLI
npx tessl i tessl/maven-dev-langchain4j--langchain4j