Common AI framework utilities for the Embabel Agent system, including LLM configuration, output converters, prompt contributors, and embedding service abstractions.

Common patterns for integrating Embabel Agent Common with Spring, reactive streams, and LLM clients.
import org.springframework.context.annotation.Bean
import org.springframework.context.annotation.Configuration
import com.fasterxml.jackson.databind.ObjectMapper
/**
 * Shared Spring beans for the AI layer: JSON mapping, default LLM options,
 * and the embedding service.
 */
@Configuration
class AIConfiguration {

    /** Jackson mapper with Kotlin support; unknown JSON properties are ignored rather than failing. */
    @Bean
    fun objectMapper(): ObjectMapper =
        ObjectMapper()
            .registerKotlinModule()
            .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)

    /** Default model settings used when a caller does not supply its own options. */
    @Bean
    fun defaultLlmOptions(): LlmOptions =
        LlmOptions.withModel("gpt-4")
            .withTemperature(0.7)
            .withMaxTokens(1000)

    /** Embedding service wrapping the injected Spring AI [EmbeddingModel]. */
    @Bean
    fun embeddingService(embeddingModel: EmbeddingModel): EmbeddingService =
        SpringAiEmbeddingService(
            name = "text-embedding-ada-002",
            provider = "openai",
            model = embeddingModel
        )
}
import org.springframework.stereotype.Service
// Thin service wrapper that asks the LLM for JSON and converts it to a typed object.
@Service
class LLMService(
private val chatClient: ChatClient,
private val objectMapper: ObjectMapper
) {
// Sends the caller's prompt plus the converter's format instructions, then
// parses the model's text reply into an instance of [responseType].
// Returns null when the reply cannot be converted.
fun <T> callWithStructuredOutput(
prompt: String,
responseType: Class<T>,
// NOTE(review): `options` is accepted but never applied to the chatClient
// call below — presumably it should configure the model; TODO confirm how
// LlmOptions maps onto the ChatClient request.
options: LlmOptions = LlmOptions.withDefaultLlm()
): T? {
val converter = JacksonOutputConverter(responseType, objectMapper)
// Append the converter's format instructions after the user's prompt.
val fullPrompt = """
$prompt
${converter.getFormat()}
""".trimIndent()
val response = chatClient
.prompt()
.user(fullPrompt)
.call()
.content()
return converter.convert(response)
}
}@Component
// Contributes a fixed system persona at the very start of every assembled prompt.
class SystemPromptContributor : PromptContributor {

    override val role = "system"
    override val promptContributionLocation = PromptContributionLocation.BEGINNING

    override fun contribution(): String = "You are a helpful AI assistant."
}
@Service
class PromptService(
    private val contributors: List<PromptContributor>
) {
    /**
     * Assembles the final prompt: BEGINNING contributions, then the user's
     * input, then END contributions, each section separated by a blank line.
     */
    fun buildPrompt(userInput: String): String {
        val header = joinedContributions(PromptContributionLocation.BEGINNING)
        val footer = joinedContributions(PromptContributionLocation.END)
        return buildString {
            if (header.isNotEmpty()) {
                appendLine(header)
                appendLine()
            }
            appendLine(userInput)
            if (footer.isNotEmpty()) {
                appendLine()
                appendLine(footer)
            }
        }.trim()
    }

    // Joins all contributions registered for the given location, separated by blank lines.
    private fun joinedContributions(location: PromptContributionLocation): String =
        contributors
            .filter { it.promptContributionLocation == location }
            .joinToString("\n\n") { it.contribution() }
}
Complete pattern for LLM calls with typed responses.
// Typed payload the extraction prompt asks the model to return as JSON.
data class ExtractedData(
val entities: List<String>,
val sentiment: String,
val summary: String
)
@Service
class ExtractionService(
    private val chatClient: ChatClient,
    private val objectMapper: ObjectMapper,
    private val costTracker: CostTracker
) {
    /**
     * Extracts entities, sentiment and a summary from [text] via the LLM,
     * recording token usage with [costTracker] when the response exposes it.
     * Returns null when the model's reply cannot be converted.
     *
     * BUG FIX: usage metadata was previously read off the result of
     * `.content()` — a plain String, which has no `metadata` — so cost
     * tracking could never fire. The full chat response is now fetched so
     * both the usage metadata and the reply text come from the same call.
     */
    fun extractData(text: String): ExtractedData? {
        val converter = JacksonOutputConverter(ExtractedData::class.java, objectMapper)
        val prompt = """
Extract information from the following text.
${converter.getFormat()}
Text: $text
""".trimIndent()
        // NOTE(review): these options are built but never applied to the call
        // below — TODO confirm how LlmOptions maps onto the ChatClient request.
        val options = LlmOptions.withModel("gpt-4")
            .withTemperature(0.3)
            .withMaxTokens(500)
        val chatResponse = chatClient
            .prompt()
            .user(prompt)
            .call()
            .chatResponse()
        // Track costs
        chatResponse?.metadata?.usage?.let { costTracker.recordUsage(it) }
        val content = chatResponse?.result?.output?.text ?: return null
        return converter.convert(content)
    }
}
Process streaming LLM responses reactively.
@Service
class StreamingService(
    private val chatClient: ChatClient,
    private val objectMapper: ObjectMapper
) {
    /**
     * Streams structured [Event]s parsed from the model's JSONL output.
     *
     * BUG FIX: `flatMapMany` is a Mono operator, but the original
     * `buffer().map { ... }` kept the pipeline a Flux, so the chain could not
     * compile as written. The chunks are now gathered with `collectList()`,
     * yielding a Mono<String> that `flatMapMany` can fan back out.
     */
    fun streamEvents(prompt: String): Flux<Event> {
        val converter = StreamingJacksonOutputConverter(
            Event::class.java,
            objectMapper
        )
        val fullPrompt = """
$prompt
${converter.getFormat()}
""".trimIndent()
        return collectedContent(fullPrompt)
            .flatMapMany { jsonl -> converter.convertStream(jsonl) }
    }

    /** Like [streamEvents], but also surfaces the model's thinking tokens. */
    fun streamWithThinking(prompt: String): Flux<StreamingEvent<Event>> {
        val converter = StreamingJacksonOutputConverter(
            Event::class.java,
            objectMapper
        )
        return collectedContent(prompt)
            .flatMapMany { jsonl -> converter.convertStreamWithThinking(jsonl) }
    }

    // Collects every streamed chunk and joins them into one string.
    // NOTE(review): this waits for the complete stream before parsing, trading
    // latency for whole-document JSONL input — confirm that is intended.
    private fun collectedContent(prompt: String): Mono<String> =
        chatClient
            .prompt()
            .user(prompt)
            .stream()
            .content()
            .collectList()
            .map { chunks -> chunks.joinToString("") }
}
Track costs across requests and sessions.
@Service
class ManagedLLMService(
    private val chatClient: ChatClient,
    private val pricing: PricingModel
) {
    // Accumulates usage/cost across all calls until resetSession().
    private val sessionTracker = CostTracker(pricing)

    /**
     * Calls the LLM and records token usage and cost for the session.
     *
     * BUG FIX: usage metadata was previously read off `response`, which was
     * the call spec whose `.content()` yields a plain String — neither carries
     * `metadata`. The full chat response is now fetched so usage tracking
     * actually works, and the reply text comes from the same response.
     */
    fun callWithTracking(prompt: String, options: LlmOptions): String {
        // NOTE(review): `options` is not applied to the request — TODO confirm
        // how LlmOptions maps onto the ChatClient call.
        val chatResponse = chatClient
            .prompt()
            .user(prompt)
            .call()
            .chatResponse()
        val usage = chatResponse?.metadata?.usage
        if (usage != null) {
            sessionTracker.recordUsage(usage)
            // NOTE(review): `logger` is not declared in this snippet —
            // presumably a class-level SLF4J logger; TODO confirm.
            logger.info("Request cost: $${pricing.costOf(usage)}")
            logger.info("Session total: $${sessionTracker.totalCost()}")
        }
        return chatResponse?.result?.output?.text.orEmpty()
    }

    /** Total cost accumulated since the last reset. */
    fun getSessionCost(): Double = sessionTracker.totalCost()

    /** Clears the session cost accumulator. */
    fun resetSession() {
        sessionTracker.reset()
    }
}
Route requests to different models based on criteria.
// Routes calls to a named model "strategy".
// NOTE(review): the resolved options are currently only used to validate the
// strategy name — they are never passed to chatClient, so every call uses the
// client's default model. Presumably they should configure the request;
// TODO confirm against the ChatClient options API.
@Service
class MultiModelService(
private val chatClient: ChatClient,
private val objectMapper: ObjectMapper
) {
// Named presets trading off cost, speed and creativity.
private val modelStrategy = mapOf(
"quick" to LlmOptions.withModel("gpt-3.5-turbo").withMaxTokens(500),
"balanced" to LlmOptions.withModel("gpt-4").withMaxTokens(1000),
"creative" to LlmOptions.withModel("claude-3-opus").withTemperature(0.9),
"analytical" to LlmOptions.withModel("gpt-4").withTemperature(0.2)
)
// Plain-text call under the given strategy; throws IllegalArgumentException
// for unknown strategy names.
fun call(prompt: String, strategy: String = "balanced"): String {
val options = modelStrategy[strategy]
?: throw IllegalArgumentException("Unknown strategy: $strategy")
return chatClient
.prompt()
.user(prompt)
.call()
.content()
}
// Structured variant: appends the converter's format instructions and parses
// the reply into [responseType]. Returns null when conversion fails.
fun <T> callStructured(
prompt: String,
responseType: Class<T>,
strategy: String = "balanced"
): T? {
val converter = JacksonOutputConverter(responseType, objectMapper)
// Validates the strategy here; `call` below validates it a second time.
val options = modelStrategy[strategy]
?: throw IllegalArgumentException("Unknown strategy: $strategy")
val fullPrompt = "$prompt\n\n${converter.getFormat()}"
val response = call(fullPrompt, strategy)
return converter.convert(response)
}
}Semantic search and similarity detection.
@Service
class SemanticSearchService(
    private val embeddingService: EmbeddingService
) {
    // In-memory index of (text, embedding) pairs.
    // NOTE(review): a plain MutableList is not safe for concurrent indexing in
    // a singleton Spring service — TODO confirm single-threaded use.
    private val documents = mutableListOf<Pair<String, FloatArray>>()

    /** Embeds [text] and adds it to the in-memory index. */
    fun indexDocument(text: String) {
        val embedding = embeddingService.embed(text)
        documents.add(text to embedding)
    }

    /** Batch-embeds [texts] in one call and indexes them all. */
    fun indexDocuments(texts: List<String>) {
        val embeddings = embeddingService.embed(texts)
        texts.zip(embeddings).forEach { (text, embedding) ->
            documents.add(text to embedding)
        }
    }

    /** Returns the [topK] indexed documents most similar to [query]. */
    fun search(query: String, topK: Int = 5): List<String> {
        val queryEmbedding = embeddingService.embed(query)
        return documents
            .map { (text, embedding) ->
                text to cosineSimilarity(queryEmbedding, embedding)
            }
            .sortedByDescending { it.second }
            .take(topK)
            .map { it.first }
    }

    /** Returns all indexed documents whose similarity to [text] is at least [threshold]. */
    fun findSimilar(text: String, threshold: Double = 0.7): List<String> {
        val embedding = embeddingService.embed(text)
        return documents
            .filter { (_, docEmbedding) ->
                cosineSimilarity(embedding, docEmbedding) >= threshold
            }
            .map { it.first }
    }

    /**
     * Cosine similarity of two equal-length vectors, in [-1, 1].
     *
     * BUG FIX: a zero-magnitude vector previously produced NaN (division by
     * zero), which corrupts sorting and threshold checks; it now yields 0.0.
     * Mismatched lengths — silently truncated by zip before — are rejected,
     * and the index loop avoids the boxing overhead of zip/sumOf.
     */
    private fun cosineSimilarity(a: FloatArray, b: FloatArray): Double {
        require(a.size == b.size) { "Embedding dimensions differ: ${a.size} vs ${b.size}" }
        var dot = 0.0
        var normA = 0.0
        var normB = 0.0
        for (i in a.indices) {
            val x = a[i].toDouble()
            val y = b[i].toDouble()
            dot += x * y
            normA += x * x
            normB += y * y
        }
        val denominator = kotlin.math.sqrt(normA) * kotlin.math.sqrt(normB)
        return if (denominator == 0.0) 0.0 else dot / denominator
    }
}
Retrieval-Augmented Generation with embeddings.
@Service
class RAGService(
    private val embeddingService: EmbeddingService,
    private val chatClient: ChatClient
) {
    // In-memory knowledge base of (text, embedding) pairs.
    // NOTE(review): not safe for concurrent writes in a singleton service —
    // TODO confirm single-threaded use.
    private val knowledgeBase = mutableListOf<Pair<String, FloatArray>>()

    /** Embeds [text] and stores it in the knowledge base. */
    fun addKnowledge(text: String) {
        val embedding = embeddingService.embed(text)
        knowledgeBase.add(text to embedding)
    }

    /** Batch-embeds [texts] and stores them in the knowledge base. */
    fun addKnowledge(texts: List<String>) {
        val embeddings = embeddingService.embed(texts)
        texts.zip(embeddings).forEach { (text, embedding) ->
            knowledgeBase.add(text to embedding)
        }
    }

    /**
     * Answers [question] with retrieval-augmented generation: embed the
     * question, select the [topK] most similar documents, and ask the LLM to
     * answer from that context.
     */
    fun query(question: String, topK: Int = 3): String {
        // 1. Embed question
        val questionEmbedding = embeddingService.embed(question)
        // 2. Find relevant context
        val relevantDocs = knowledgeBase
            .map { (text, embedding) ->
                text to cosineSimilarity(questionEmbedding, embedding)
            }
            .sortedByDescending { it.second }
            .take(topK)
            .map { it.first }
        // 3. Build prompt with context
        val context = relevantDocs.joinToString("\n\n")
        val prompt = """
Use the following context to answer the question.
Context:
$context
Question: $question
Answer:
""".trimIndent()
        // 4. Generate answer
        return chatClient
            .prompt()
            .user(prompt)
            .call()
            .content()
    }

    /**
     * Cosine similarity in [-1, 1], matching the fixed implementation in
     * SemanticSearchService.
     *
     * BUG FIX: zero-magnitude vectors previously produced NaN via division by
     * zero; they now yield 0.0. Mismatched lengths (silently truncated by zip
     * before) are rejected explicitly.
     */
    private fun cosineSimilarity(a: FloatArray, b: FloatArray): Double {
        require(a.size == b.size) { "Embedding dimensions differ: ${a.size} vs ${b.size}" }
        var dot = 0.0
        var normA = 0.0
        var normB = 0.0
        for (i in a.indices) {
            val x = a[i].toDouble()
            val y = b[i].toDouble()
            dot += x * y
            normA += x * x
            normB += y * y
        }
        val denominator = kotlin.math.sqrt(normA) * kotlin.math.sqrt(normB)
        return if (denominator == 0.0) 0.0 else dot / denominator
    }
}
Process multiple requests efficiently.
@Service
class BatchProcessingService(
    private val chatClient: ChatClient,
    private val objectMapper: ObjectMapper
) {
    /**
     * Processes [items] sequentially in chunks of [batchSize], converting each
     * reply into [responseType]. Items whose replies fail conversion are
     * dropped (hence the nullable element type is never actually null here).
     */
    fun <T> processBatch(
        items: List<String>,
        responseType: Class<T>,
        batchSize: Int = 10
    ): List<T?> {
        val converter = JacksonOutputConverter(responseType, objectMapper)
        return items.chunked(batchSize).flatMap { batch ->
            batch.mapNotNull { item ->
                val prompt = "Process: $item\n\n${converter.getFormat()}"
                val response = chatClient
                    .prompt()
                    .user(prompt)
                    .call()
                    .content()
                converter.convert(response)
            }
        }
    }

    /**
     * Processes [items] concurrently on the parallel scheduler; replies that
     * fail conversion are dropped from the resulting stream.
     */
    fun <T> processBatchParallel(
        items: List<String>,
        responseType: Class<T>
    ): Flux<T> {
        val converter = JacksonOutputConverter(responseType, objectMapper)
        return Flux.fromIterable(items)
            .parallel()
            .runOn(Schedulers.parallel())
            .flatMap { item ->
                val prompt = "Process: $item\n\n${converter.getFormat()}"
                Mono.fromCallable {
                    val response = chatClient
                        .prompt()
                        .user(prompt)
                        .call()
                        .content()
                    converter.convert(response)
                }
            }
            .sequential()
            // FIX: mapNotNull replaces the `.filter { it != null }.map { it!! }`
            // anti-pattern with a single null-dropping operator.
            .mapNotNull { it }
    }
}
@Service
class ResilientLLMService(
    private val chatClient: ChatClient
) {
    /**
     * Calls the LLM, retrying up to [maxRetries] times with exponential
     * backoff (1s, 2s, 4s, ...). Returns null when every attempt fails.
     */
    fun callWithRetry(
        prompt: String,
        maxRetries: Int = 3
    ): String? {
        var lastError: Exception? = null
        repeat(maxRetries) { attempt ->
            try {
                return chatClient
                    .prompt()
                    .user(prompt)
                    .call()
                    .content()
            } catch (e: Exception) {
                lastError = e
                // NOTE(review): `logger` is not declared in this snippet —
                // presumably a class-level SLF4J logger; TODO confirm.
                logger.warn("Attempt ${attempt + 1} failed: ${e.message}")
                if (attempt < maxRetries - 1) {
                    // BUG FIX: the old delay (1000 * (attempt + 1)) was linear
                    // despite the "exponential backoff" comment; doubling per
                    // attempt matches the documented (and reactive) behavior.
                    val delay = 1000L shl attempt // Exponential backoff
                    Thread.sleep(delay)
                }
            }
        }
        logger.error("All retries failed", lastError)
        return null
    }

    /**
     * Reactive variant: retries with Reactor's built-in exponential backoff
     * and completes empty when all retries are exhausted.
     */
    fun callWithReactiveRetry(prompt: String): Mono<String> {
        return Mono.fromCallable {
            chatClient
                .prompt()
                .user(prompt)
                .call()
                .content()
        }
            .retryWhen(Retry.backoff(3, Duration.ofSeconds(1)))
            .onErrorResume { error ->
                logger.error("All retries failed", error)
                Mono.empty()
            }
    }
}
@Service
class CachedLLMService(
    private val chatClient: ChatClient
) {
    // Unbounded cache for callWithCache, keyed by the exact prompt.
    private val cache = ConcurrentHashMap<String, String>()
    // TTL cache: prompt -> (insertion time in epoch millis, cached response).
    private val timedCache = ConcurrentHashMap<String, Pair<Long, String>>()

    /** Returns the cached response for [prompt], calling the LLM on first use. */
    fun callWithCache(prompt: String): String =
        cache.getOrPut(prompt) { fetch(prompt) }

    /**
     * Like [callWithCache] but entries expire after [ttl].
     *
     * BUG FIX: the previous time-bucketed key scheme ("prompt:bucket") never
     * evicted old buckets, so the map grew without bound for every repeated
     * prompt; stale entries are now refreshed in place under a stable key.
     */
    fun callWithTimedCache(
        prompt: String,
        ttl: Duration = Duration.ofMinutes(10)
    ): String {
        val now = System.currentTimeMillis()
        val entry = timedCache[prompt]
        if (entry != null && now - entry.first < ttl.toMillis()) {
            return entry.second
        }
        val fresh = fetch(prompt)
        timedCache[prompt] = now to fresh
        return fresh
    }

    /** Drops every cached response, from both the plain and timed caches. */
    fun clearCache() {
        cache.clear()
        timedCache.clear()
    }

    // Single place that actually hits the LLM.
    private fun fetch(prompt: String): String =
        chatClient
            .prompt()
            .user(prompt)
            .call()
            .content()
}
class MockConverter<T>(private val mockResult: T) : StructuredOutputConverter<T> {
// Test double: ignores the LLM text and always returns the canned result.
override fun convert(text: String): T = mockResult
// Format instructions are irrelevant for the mock.
override fun getFormat(): String = "Mock format"
}
// Demonstrates injecting MockConverter so the test needs no live LLM.
@Test
fun `test service with mock converter`() {
// NOTE(review): `service` and `Person` are not defined in this snippet —
// presumably supplied by the enclosing test class; TODO confirm.
val mockPerson = Person("Test", 30, "test@example.com")
val mockConverter = MockConverter(mockPerson)
// Use in tests
val result = service.process(mockConverter)
assertEquals(mockPerson, result)
}@SpringBootTest
// Integration test exercising LLMService end-to-end through the Spring context.
class LLMServiceIntegrationTest {
@Autowired
private lateinit var llmService: LLMService
// NOTE(review): asserts exact values extracted by a live LLM — presumably
// stable for this prompt, but inherently flaky; consider stubbing the
// ChatClient for deterministic CI runs. TODO confirm.
@Test
fun `test structured output extraction`() {
val result = llmService.callWithStructuredOutput(
"Extract: John Doe, 30, john@example.com",
Person::class.java
)
assertNotNull(result)
assertEquals("John Doe", result?.name)
assertEquals(30, result?.age)
}
}Install with Tessl CLI
npx tessl i tessl/maven-com-embabel-agent--embabel-agent-common@0.3.0