Common AI framework utilities for the Embabel Agent system including LLM configuration, output converters, prompt contributors, and embedding service abstractions.
Common patterns for integrating Embabel Agent Common with Spring, reactive streams, and LLM clients.
import org.springframework.context.annotation.Bean
import org.springframework.context.annotation.Configuration
import com.fasterxml.jackson.databind.ObjectMapper
/**
 * Spring configuration exposing shared AI infrastructure beans:
 * a Kotlin-aware [ObjectMapper], default [LlmOptions], and the
 * embedding-service wrapper.
 */
@Configuration
class AIConfiguration {

    /**
     * JSON mapper with Kotlin module registered; unknown properties are
     * tolerated so LLM output carrying extra fields still deserializes.
     */
    @Bean
    fun objectMapper(): ObjectMapper =
        ObjectMapper()
            .registerKotlinModule()
            .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)

    /** Default model settings used when a caller supplies none of its own. */
    @Bean
    fun defaultLlmOptions(): LlmOptions =
        LlmOptions.withModel("gpt-4")
            .withTemperature(0.7)
            .withMaxTokens(1000)

    /** Adapts the Spring AI [EmbeddingModel] to the Embabel [EmbeddingService] abstraction. */
    @Bean
    fun embeddingService(embeddingModel: EmbeddingModel): EmbeddingService =
        SpringAiEmbeddingService(
            name = "text-embedding-ada-002",
            provider = "openai",
            model = embeddingModel
        )
}
import org.springframework.stereotype.Service
/**
 * Thin service over [ChatClient] that appends the converter's format
 * instructions to the caller's prompt and parses the reply into
 * an instance of [responseType].
 *
 * NOTE(review): the [options] parameter is accepted but not currently
 * forwarded to the chat client — confirm whether model options should
 * be applied per call.
 */
@Service
class LLMService(
    private val chatClient: ChatClient,
    private val objectMapper: ObjectMapper
) {
    fun <T> callWithStructuredOutput(
        prompt: String,
        responseType: Class<T>,
        options: LlmOptions = LlmOptions.withDefaultLlm()
    ): T? {
        val converter = JacksonOutputConverter(responseType, objectMapper)
        // Prompt plus JSON-format instructions for the model.
        val fullPrompt = """
            $prompt
            ${converter.getFormat()}
        """.trimIndent()
        val raw = chatClient
            .prompt()
            .user(fullPrompt)
            .call()
            .content()
        // Null when the model's reply cannot be parsed into the target type.
        return converter.convert(raw)
    }
}
@Component
// Contributes a fixed system persona at the beginning of every prompt.
class SystemPromptContributor : PromptContributor {
    override val role: String = "system"
    override val promptContributionLocation: PromptContributionLocation =
        PromptContributionLocation.BEGINNING

    /** The system instruction injected ahead of user input. */
    override fun contribution(): String = "You are a helpful AI assistant."
}
/**
 * Assembles the final prompt from registered [PromptContributor]s:
 * BEGINNING contributions precede the user input, END contributions
 * follow it, each group joined with blank lines.
 */
@Service
class PromptService(
    private val contributors: List<PromptContributor>
) {
    /** Joins the contributions registered for [location] with blank lines. */
    private fun joined(location: PromptContributionLocation): String =
        contributors
            .filter { it.promptContributionLocation == location }
            .joinToString("\n\n") { it.contribution() }

    fun buildPrompt(userInput: String): String {
        val beginning = joined(PromptContributionLocation.BEGINNING)
        val end = joined(PromptContributionLocation.END)
        return buildString {
            if (beginning.isNotEmpty()) {
                appendLine(beginning)
                appendLine()
            }
            appendLine(userInput)
            if (end.isNotEmpty()) {
                appendLine()
                appendLine(end)
            }
        }.trim()
    }
}
Complete pattern for LLM calls with typed responses.
/**
 * Structured result extracted from free text by the LLM:
 * named entities, an overall sentiment label, and a short summary.
 */
data class ExtractedData(
val entities: List<String>,
val sentiment: String,
val summary: String
)
/**
 * Extracts structured [ExtractedData] from free text via the LLM,
 * recording reported token usage with [costTracker].
 */
@Service
class ExtractionService(
    private val chatClient: ChatClient,
    private val objectMapper: ObjectMapper,
    private val costTracker: CostTracker
) {
    /**
     * Runs extraction over [text].
     * @return the parsed data, or null when the reply cannot be converted.
     */
    fun extractData(text: String): ExtractedData? {
        val converter = JacksonOutputConverter(ExtractedData::class.java, objectMapper)
        val prompt = """
            Extract information from the following text.
            ${converter.getFormat()}
            Text: $text
        """.trimIndent()
        // NOTE(review): model options (gpt-4, temp 0.3, 500 tokens) were built
        // here but never applied — removed as dead code; confirm whether they
        // should be forwarded to the chat client.
        //
        // Keep the call response so usage metadata is reachable: the previous
        // code called .content() first and then read .metadata off the
        // returned String, which could never work.
        val response = chatClient
            .prompt()
            .user(prompt)
            .call()
        // Usage metadata may be absent depending on the provider.
        val usage = response.metadata?.usage
        if (usage != null) {
            costTracker.recordUsage(usage)
        }
        return converter.convert(response.content())
    }
}
Process streaming LLM responses reactively.
/**
 * Streams structured events from the LLM.
 *
 * Both entry points aggregate the streamed text chunks into a single
 * string and hand it to a streaming converter that re-emits parsed items.
 */
@Service
class StreamingService(
    private val chatClient: ChatClient,
    private val objectMapper: ObjectMapper
) {
    /** Aggregates all streamed content chunks for [prompt] into one string. */
    private fun aggregatedContent(prompt: String): Mono<String> =
        chatClient
            .prompt()
            .user(prompt)
            .stream()
            .content()
            // collectList() yields a Mono, so flatMapMany at the call sites is
            // valid. The previous buffer()/map() chain produced a Flux<String>,
            // and Flux has no flatMapMany operator — that code could not compile.
            .collectList()
            .map { chunks -> chunks.joinToString("") }

    /** Streams parsed [Event]s; format instructions are appended to the prompt. */
    fun streamEvents(prompt: String): Flux<Event> {
        val converter = StreamingJacksonOutputConverter(
            Event::class.java,
            objectMapper
        )
        val fullPrompt = """
            $prompt
            ${converter.getFormat()}
        """.trimIndent()
        return aggregatedContent(fullPrompt)
            .flatMapMany { jsonl -> converter.convertStream(jsonl) }
    }

    /** Streams events with thinking/metadata wrappers; the prompt is sent as-is. */
    fun streamWithThinking(prompt: String): Flux<StreamingEvent<Event>> {
        val converter = StreamingJacksonOutputConverter(
            Event::class.java,
            objectMapper
        )
        return aggregatedContent(prompt)
            .flatMapMany { converter.convertStreamWithThinking(it) }
    }
}
Track costs across requests and sessions.
/**
 * LLM calls with per-request and per-session cost tracking.
 *
 * NOTE(review): the `options` parameter of callWithTracking is never
 * applied to the chat client — confirm intended behavior.
 * NOTE(review): `logger` is not declared in this class — assumed to be
 * a file-level logger; verify.
 */
@Service
class ManagedLLMService(
private val chatClient: ChatClient,
private val pricing: PricingModel
) {
// Accumulates usage/cost across the lifetime of this bean.
private val sessionTracker = CostTracker(pricing)
// Sends the prompt, records any reported token usage, returns the content.
fun callWithTracking(prompt: String, options: LlmOptions): String {
val response = chatClient
.prompt()
.user(prompt)
.call()
// Usage metadata may be absent depending on the provider.
val usage = response.metadata?.usage
if (usage != null) {
sessionTracker.recordUsage(usage)
logger.info("Request cost: $${pricing.costOf(usage)}")
logger.info("Session total: $${sessionTracker.totalCost()}")
}
return response.content()
}
// Total cost accumulated in this session so far.
fun getSessionCost(): Double = sessionTracker.totalCost()
// Clears the session accumulator.
fun resetSession() {
sessionTracker.reset()
}
}Route requests to different models based on criteria.
/**
 * Routes requests to a model configuration chosen by named strategy.
 *
 * NOTE(review): the selected [LlmOptions] are validated but never forwarded
 * to [chatClient], so every strategy currently runs on the client's default
 * model — confirm how options should be applied per call.
 */
@Service
class MultiModelService(
    private val chatClient: ChatClient,
    private val objectMapper: ObjectMapper
) {
    // Named strategies mapping to model settings.
    private val modelStrategy = mapOf(
        "quick" to LlmOptions.withModel("gpt-3.5-turbo").withMaxTokens(500),
        "balanced" to LlmOptions.withModel("gpt-4").withMaxTokens(1000),
        "creative" to LlmOptions.withModel("claude-3-opus").withTemperature(0.9),
        "analytical" to LlmOptions.withModel("gpt-4").withTemperature(0.2)
    )

    /**
     * Sends [prompt] using the given [strategy].
     * @throws IllegalArgumentException for an unknown strategy name.
     */
    fun call(prompt: String, strategy: String = "balanced"): String {
        // The lookup validates the strategy name; see class note about
        // the options not actually being applied.
        modelStrategy[strategy]
            ?: throw IllegalArgumentException("Unknown strategy: $strategy")
        return chatClient
            .prompt()
            .user(prompt)
            .call()
            .content()
    }

    /**
     * Sends [prompt] plus format instructions and parses the reply into
     * [responseType]. Strategy validation is delegated to [call] — the
     * duplicate lookup previously done here was redundant.
     */
    fun <T> callStructured(
        prompt: String,
        responseType: Class<T>,
        strategy: String = "balanced"
    ): T? {
        val converter = JacksonOutputConverter(responseType, objectMapper)
        val fullPrompt = "$prompt\n\n${converter.getFormat()}"
        val response = call(fullPrompt, strategy)
        return converter.convert(response)
    }
}
Semantic search and similarity detection.
/**
 * In-memory semantic index over raw documents: documents are embedded on
 * insert and ranked by cosine similarity at query time.
 *
 * NOTE(review): the backing list is not thread-safe — confirm
 * single-threaded use or synchronize externally.
 */
@Service
class SemanticSearchService(
    private val embeddingService: EmbeddingService
) {
    // (text, embedding) pairs; searches do a linear scan.
    private val documents = mutableListOf<Pair<String, FloatArray>>()

    /** Embeds and indexes a single document. */
    fun indexDocument(text: String) {
        val embedding = embeddingService.embed(text)
        documents.add(text to embedding)
    }

    /** Embeds and indexes several documents with one batch embed call. */
    fun indexDocuments(texts: List<String>) {
        val embeddings = embeddingService.embed(texts)
        texts.zip(embeddings).forEach { (text, embedding) ->
            documents.add(text to embedding)
        }
    }

    /** Returns the [topK] documents most similar to [query]. */
    fun search(query: String, topK: Int = 5): List<String> {
        val queryEmbedding = embeddingService.embed(query)
        return documents
            .map { (text, embedding) ->
                text to cosineSimilarity(queryEmbedding, embedding)
            }
            .sortedByDescending { it.second }
            .take(topK)
            .map { it.first }
    }

    /** Returns all documents whose similarity to [text] meets [threshold]. */
    fun findSimilar(text: String, threshold: Double = 0.7): List<String> {
        val embedding = embeddingService.embed(text)
        return documents
            .filter { (_, docEmbedding) ->
                cosineSimilarity(embedding, docEmbedding) >= threshold
            }
            .map { it.first }
    }

    /**
     * Cosine similarity of two vectors; returns 0.0 when either vector has
     * zero magnitude (the previous version divided by zero, yielding NaN,
     * which corrupts sorting and threshold comparisons above).
     * Mismatched lengths are silently truncated by zip.
     */
    private fun cosineSimilarity(a: FloatArray, b: FloatArray): Double {
        val dotProduct = a.zip(b).sumOf { (x, y) -> (x * y).toDouble() }
        val magA = kotlin.math.sqrt(a.sumOf { (it * it).toDouble() })
        val magB = kotlin.math.sqrt(b.sumOf { (it * it).toDouble() })
        val denominator = magA * magB
        return if (denominator == 0.0) 0.0 else dotProduct / denominator
    }
}
Retrieval-Augmented Generation with embeddings.
/**
 * Minimal Retrieval-Augmented Generation: retrieves the top-K most similar
 * knowledge snippets and supplies them as context to a chat completion.
 */
@Service
class RAGService(
    private val embeddingService: EmbeddingService,
    private val chatClient: ChatClient
) {
    // (text, embedding) knowledge snippets; queries do a linear scan.
    private val knowledgeBase = mutableListOf<Pair<String, FloatArray>>()

    /** Embeds and stores one knowledge snippet. */
    fun addKnowledge(text: String) {
        val embedding = embeddingService.embed(text)
        knowledgeBase.add(text to embedding)
    }

    /** Embeds and stores several snippets with one batch embed call. */
    fun addKnowledge(texts: List<String>) {
        val embeddings = embeddingService.embed(texts)
        texts.zip(embeddings).forEach { (text, embedding) ->
            knowledgeBase.add(text to embedding)
        }
    }

    /** Answers [question] using the [topK] most relevant snippets as context. */
    fun query(question: String, topK: Int = 3): String {
        // 1. Embed question
        val questionEmbedding = embeddingService.embed(question)
        // 2. Rank knowledge by similarity and keep the best topK
        val relevantDocs = knowledgeBase
            .map { (text, embedding) ->
                text to cosineSimilarity(questionEmbedding, embedding)
            }
            .sortedByDescending { it.second }
            .take(topK)
            .map { it.first }
        // 3. Build prompt with context
        val context = relevantDocs.joinToString("\n\n")
        val prompt = """
            Use the following context to answer the question.
            Context:
            $context
            Question: $question
            Answer:
        """.trimIndent()
        // 4. Generate answer
        return chatClient
            .prompt()
            .user(prompt)
            .call()
            .content()
    }

    /**
     * Cosine similarity; returns 0.0 for zero-magnitude vectors (previously
     * NaN via division by zero, which breaks the ranking above).
     * NOTE(review): duplicated in SemanticSearchService — consider extracting
     * a shared helper.
     */
    private fun cosineSimilarity(a: FloatArray, b: FloatArray): Double {
        val dotProduct = a.zip(b).sumOf { (x, y) -> (x * y).toDouble() }
        val magA = kotlin.math.sqrt(a.sumOf { (it * it).toDouble() })
        val magB = kotlin.math.sqrt(b.sumOf { (it * it).toDouble() })
        val denominator = magA * magB
        return if (denominator == 0.0) 0.0 else dotProduct / denominator
    }
}
Process multiple requests efficiently.
/**
 * Sends many items through the LLM: sequentially in chunks, or
 * concurrently on Reactor's parallel scheduler.
 */
@Service
class BatchProcessingService(
    private val chatClient: ChatClient,
    private val objectMapper: ObjectMapper
) {
    /** Builds the per-item prompt including the converter's format instructions. */
    private fun itemPrompt(item: String, format: String): String =
        "Process: $item\n\n$format"

    /**
     * Processes items sequentially in chunks of [batchSize]; items whose
     * response cannot be converted are dropped from the result.
     */
    fun <T> processBatch(
        items: List<String>,
        responseType: Class<T>,
        batchSize: Int = 10
    ): List<T?> {
        val converter = JacksonOutputConverter(responseType, objectMapper)
        val results = mutableListOf<T>()
        for (batch in items.chunked(batchSize)) {
            for (item in batch) {
                val reply = chatClient
                    .prompt()
                    .user(itemPrompt(item, converter.getFormat()))
                    .call()
                    .content()
                // Skip items whose response fails conversion.
                converter.convert(reply)?.let(results::add)
            }
        }
        return results
    }

    /**
     * Processes all items concurrently on the parallel scheduler, emitting
     * only successfully converted results.
     */
    fun <T> processBatchParallel(
        items: List<String>,
        responseType: Class<T>
    ): Flux<T> {
        val converter = JacksonOutputConverter(responseType, objectMapper)
        return Flux.fromIterable(items)
            .parallel()
            .runOn(Schedulers.parallel())
            .flatMap { item ->
                Mono.fromCallable {
                    val reply = chatClient
                        .prompt()
                        .user(itemPrompt(item, converter.getFormat()))
                        .call()
                        .content()
                    converter.convert(reply)
                }
            }
            .sequential()
            .filter { it != null }
            .map { it!! }
    }
}
@Service
/**
 * LLM calls with retry. The blocking variant retries with exponential
 * backoff; the reactive variant delegates to Reactor's Retry.backoff.
 *
 * NOTE(review): `logger` is not declared in this class — assumed to be a
 * file-level logger; verify.
 */
class ResilientLLMService(
    private val chatClient: ChatClient
) {
    /**
     * Calls the LLM, retrying on any exception.
     * @return the response content, or null when all [maxRetries] attempts fail.
     */
    fun callWithRetry(
        prompt: String,
        maxRetries: Int = 3
    ): String? {
        var lastError: Exception? = null
        repeat(maxRetries) { attempt ->
            try {
                return chatClient
                    .prompt()
                    .user(prompt)
                    .call()
                    .content()
            } catch (e: Exception) {
                lastError = e
                logger.warn("Attempt ${attempt + 1} failed: ${e.message}")
                if (attempt < maxRetries - 1) {
                    // True exponential backoff: 1s, 2s, 4s, ... The previous
                    // `1000L * (attempt + 1)` was linear despite its
                    // "Exponential backoff" comment, and inconsistent with the
                    // Retry.backoff used in the reactive variant below.
                    val delay = 1000L * (1L shl attempt)
                    Thread.sleep(delay)
                }
            }
        }
        logger.error("All retries failed", lastError)
        return null
    }

    /** Reactive variant: 3 attempts, exponential backoff from 1s; empty Mono on failure. */
    fun callWithReactiveRetry(prompt: String): Mono<String> {
        return Mono.fromCallable {
            chatClient
                .prompt()
                .user(prompt)
                .call()
                .content()
        }
            .retryWhen(Retry.backoff(3, Duration.ofSeconds(1)))
            .onErrorResume { error ->
                logger.error("All retries failed", error)
                Mono.empty()
            }
    }
}
@Service
/**
 * Caches LLM responses in memory: a permanent cache keyed by prompt, and a
 * TTL-bucketed cache that keeps at most one live entry per prompt.
 */
class CachedLLMService(
    private val chatClient: ChatClient
) {
    // Permanent prompt -> response cache.
    private val cache = ConcurrentHashMap<String, String>()

    // prompt -> (ttl bucket, response). Storing the bucket alongside the value
    // lets a stale entry be overwritten in place. The previous scheme keyed
    // entries on "prompt:bucket" and never evicted old buckets, so the cache
    // grew without bound over time (memory leak).
    private val timedCache = ConcurrentHashMap<String, Pair<Long, String>>()

    /** Returns the cached response for [prompt], calling the LLM on a miss. */
    fun callWithCache(prompt: String): String {
        return cache.getOrPut(prompt) {
            chatClient
                .prompt()
                .user(prompt)
                .call()
                .content()
        }
    }

    /**
     * Like [callWithCache], but the entry expires when the current [ttl]
     * bucket rolls over; only the latest bucket per prompt is retained.
     */
    fun callWithTimedCache(
        prompt: String,
        ttl: Duration = Duration.ofMinutes(10)
    ): String {
        val bucket = System.currentTimeMillis() / ttl.toMillis()
        val cached = timedCache[prompt]
        if (cached != null && cached.first == bucket) {
            return cached.second
        }
        val fresh = chatClient
            .prompt()
            .user(prompt)
            .call()
            .content()
        timedCache[prompt] = bucket to fresh
        return fresh
    }

    /** Drops all entries from both caches. */
    fun clearCache() {
        cache.clear()
        timedCache.clear()
    }
}
class MockConverter<T>(private val mockResult: T) : StructuredOutputConverter<T> {
// Always returns the canned result, ignoring the model's text.
override fun convert(text: String): T = mockResult
// Placeholder format instructions for prompts built in tests.
override fun getFormat(): String = "Mock format"
}
// Example unit test: the canned converter lets the service be exercised
// without any real LLM call.
// NOTE(review): `service` and `Person` are not defined in this snippet —
// assumed to come from the surrounding test fixture; verify.
@Test
fun `test service with mock converter`() {
val mockPerson = Person("Test", 30, "test@example.com")
val mockConverter = MockConverter(mockPerson)
// Use in tests
val result = service.process(mockConverter)
assertEquals(mockPerson, result)
}@SpringBootTest
// Integration test: boots the Spring context and calls the real LLM.
// NOTE(review): asserting exact values from a live model is inherently
// flaky — consider stubbing ChatClient for deterministic runs.
class LLMServiceIntegrationTest {
@Autowired
private lateinit var llmService: LLMService
@Test
fun `test structured output extraction`() {
// Expect the model to extract name/age/email into a Person.
val result = llmService.callWithStructuredOutput(
"Extract: John Doe, 30, john@example.com",
Person::class.java
)
assertNotNull(result)
assertEquals("John Doe", result?.name)
assertEquals(30, result?.age)
}
}tessl i tessl/maven-com-embabel-agent--embabel-agent-common@0.3.1