CtrlK
Blog · Docs · Log in · Get started
Tessl Logo

tessl/maven-com-embabel-agent--embabel-agent-openai

OpenAI compatible model factory for the Embabel Agent Framework

Overview
Eval results
Files

docs/extending.md

Extending the Factory

Guide to subclassing OpenAiCompatibleModelFactory for custom behavior.

Why Extend the Factory

Extend the factory when you need:

  • Provider-specific configuration or initialization
  • Custom model creation logic
  • Additional services or methods
  • Provider-specific authentication schemes
  • Custom observability or logging
  • Specialized model configurations

Protected Members

The factory exposes these protected members for subclasses:

/**
 * Logger instance for the factory class.
 *
 * Available to subclasses, e.g. from an `init` block or from overridden methods.
 */
protected val logger: Logger

/**
 * The configured OpenAI API instance.
 *
 * NOTE(review): presumably built from the constructor's base URL, API key and
 * completions/embeddings paths — confirm against the factory implementation.
 */
protected val openAiApi: OpenAiApi

/**
 * Creates a Spring AI ChatModel with the configured API.
 *
 * Returns a raw ChatModel rather than an LlmService, so subclasses can use it
 * directly with Spring AI or wrap it. Signature only; the body is elided here.
 *
 * @param model Model name/identifier
 * @param retryTemplate Retry template for resilience
 * @return ChatModel configured with the specified model
 */
protected fun chatModelOf(
    model: String,
    retryTemplate: RetryTemplate
): ChatModel

Basic Subclassing

Simple Extension with Custom Logging

import com.embabel.agent.openai.OpenAiCompatibleModelFactory
import io.micrometer.observation.ObservationRegistry
import org.springframework.ai.chat.model.ChatModel
import org.springframework.beans.factory.ObjectProvider
import org.springframework.http.client.ClientHttpRequestFactory
import org.springframework.retry.support.RetryTemplate

/**
 * Minimal subclass showing access to the factory's protected members.
 *
 * Forwards every constructor argument unchanged to [OpenAiCompatibleModelFactory].
 */
class CustomOpenAiFactory(
    baseUrl: String?,
    apiKey: String?,
    completionsPath: String?,
    embeddingsPath: String?,
    observationRegistry: ObservationRegistry,
    requestFactory: ObjectProvider<ClientHttpRequestFactory>
) : OpenAiCompatibleModelFactory(
    baseUrl,
    apiKey,
    completionsPath,
    embeddingsPath,
    observationRegistry,
    requestFactory
) {

    // Access protected logger in init block
    init {
        logger.info("Custom factory initialized with base URL: {}", baseUrl ?: "default")
    }

    /**
     * Creates a raw Spring AI [ChatModel] through the protected [chatModelOf].
     *
     * @param model model name/identifier
     * @param retryTemplate retry policy applied to the model's calls; defaults to a
     *   new [RetryTemplate]. Supply your own to control attempts and backoff.
     * @return the configured chat model
     */
    @JvmOverloads // keeps the single-argument overload callable from Java
    fun createCustomChatModel(
        model: String,
        retryTemplate: RetryTemplate = RetryTemplate()
    ): ChatModel {
        logger.debug("Creating custom chat model: {}", model)
        return chatModelOf(model, retryTemplate)
    }
}

Extension with Default Configuration

/**
 * Factory preconfigured for the public OpenAI endpoint.
 *
 * Passes `null` for the base URL and both paths so the parent factory uses its
 * defaults. Offers convenience methods for common models.
 *
 * @param apiKey OpenAI API key
 * @param observationRegistry registry used for observability
 */
class SimpleOpenAiFactory(
    apiKey: String,
    observationRegistry: ObservationRegistry
) : OpenAiCompatibleModelFactory(
    baseUrl = null,  // Always use OpenAI default
    apiKey = apiKey,
    completionsPath = null,
    embeddingsPath = null,
    observationRegistry = observationRegistry
) {

    init {
        logger.info("Simple OpenAI factory initialized")
    }

    /**
     * GPT-4 at $30 / $60 per million input / output tokens.
     *
     * The `gpt-4` model's training data ends in September 2021. The April 2023
     * date belongs to the GPT-4 Turbo preview models, not to `gpt-4`.
     */
    fun createGpt4() = openAiCompatibleLlm(
        model = "gpt-4",
        pricingModel = PricingModel.usdPer1MTokens(30.0, 60.0),
        provider = "OpenAI",
        knowledgeCutoffDate = LocalDate.of(2021, 9, 1)
    )

    /** GPT-3.5 Turbo at $0.50 / $1.50 per million input / output tokens. */
    fun createGpt35Turbo() = openAiCompatibleLlm(
        model = "gpt-3.5-turbo",
        pricingModel = PricingModel.usdPer1MTokens(0.5, 1.5),
        provider = "OpenAI",
        knowledgeCutoffDate = LocalDate.of(2021, 9, 1)
    )
}

Provider-Specific Factories

Azure OpenAI Factory

/**
 * Factory targeting an Azure OpenAI resource.
 *
 * Requests are addressed by deployment name. The chat and embedding paths embed
 * the deployment names and the `api-version` query parameter.
 *
 * NOTE(review): Azure OpenAI API keys are traditionally sent in an `api-key` header.
 * Confirm that the parent factory's authentication scheme is accepted by your endpoint.
 *
 * @param resourceName Azure resource name, used as the `<name>.openai.azure.com` host
 * @param apiKey Azure OpenAI API key
 * @param gpt4DeploymentName deployment serving chat completions
 * @param embeddingDeploymentName deployment serving embeddings
 * @param observationRegistry registry used for observability
 * @param apiVersion value of the `api-version` query parameter
 */
class AzureOpenAiFactory(
    resourceName: String,
    apiKey: String,
    gpt4DeploymentName: String,
    embeddingDeploymentName: String,
    observationRegistry: ObservationRegistry,
    apiVersion: String = "2024-02-15-preview"
) : OpenAiCompatibleModelFactory(
    baseUrl = "https://$resourceName.openai.azure.com",
    apiKey = apiKey,
    completionsPath = "/openai/deployments/$gpt4DeploymentName/chat/completions?api-version=$apiVersion",
    embeddingsPath = "/openai/deployments/$embeddingDeploymentName/embeddings?api-version=$apiVersion",
    observationRegistry = observationRegistry
) {

    init {
        logger.info("Azure OpenAI factory initialized for resource: {}", resourceName)
    }

    /**
     * GPT-4 service billed at $30 / $60 per million input / output tokens.
     *
     * Uses the `gpt-4` training cutoff of September 2021. Adjust it if your
     * deployment serves a different GPT-4 version (e.g. GPT-4 Turbo).
     */
    fun createGpt4Service() = openAiCompatibleLlm(
        model = "gpt-4",
        pricingModel = PricingModel.usdPer1MTokens(30.0, 60.0),
        provider = "Azure OpenAI",
        knowledgeCutoffDate = LocalDate.of(2021, 9, 1)
    )

    /** Embedding service using `text-embedding-3-small` via the embedding deployment. */
    fun createEmbeddingService() = openAiCompatibleEmbeddingService(
        model = "text-embedding-3-small",
        provider = "Azure OpenAI"
    )
}

// Usage
val factory = AzureOpenAiFactory(
    resourceName = "my-azure-resource",
    // System.getenv returns null for unset variables; fail fast with a clear message
    // instead of a NullPointerException from the non-null apiKey parameter
    apiKey = System.getenv("AZURE_OPENAI_API_KEY")
        ?: error("AZURE_OPENAI_API_KEY environment variable is not set"),
    gpt4DeploymentName = "gpt-4-deployment",
    embeddingDeploymentName = "embedding-deployment",
    observationRegistry = ObservationRegistry.create()
)

val service = factory.createGpt4Service()

Ollama Factory

/**
 * Factory for a local Ollama server, reached at `http://host:port` with no API key.
 *
 * Every model it offers is free to run locally and has no recorded knowledge cutoff.
 */
class OllamaFactory(
    host: String = "localhost",
    port: Int = 11434,
    observationRegistry: ObservationRegistry
) : OpenAiCompatibleModelFactory(
    baseUrl = "http://$host:$port",
    apiKey = null,  // Ollama doesn't need auth
    completionsPath = null,
    embeddingsPath = null,
    observationRegistry = observationRegistry
) {

    init {
        logger.info("Ollama factory initialized at {}:{}", host, port)
    }

    /** Llama 3, 8B parameters. */
    fun llama3_8b() = localModel("llama3:8b")

    /** Llama 3, 70B parameters. */
    fun llama3_70b() = localModel("llama3:70b")

    /** Code Llama, 34B parameters. */
    fun codellama() = localModel("codellama:34b")

    /** Builds an Ollama-hosted LLM service for the given model tag. */
    private fun localModel(tag: String) = openAiCompatibleLlm(
        model = tag,
        pricingModel = PricingModel.ALL_YOU_CAN_EAT,
        provider = "Ollama",
        knowledgeCutoffDate = null
    )
}

// Usage: host and port are spelled out here, but both equal the constructor defaults
val factory = OllamaFactory(
    observationRegistry = ObservationRegistry.create(),
    host = "localhost",
    port = 11434,
)

val llama = factory.llama3_70b()

LM Studio Factory

/**
 * Factory for an LM Studio server exposing its OpenAI-compatible API.
 *
 * No API key is sent, and chat completions go to `/v1/chat/completions`.
 *
 * @param port port LM Studio listens on
 * @param observationRegistry registry used for observability
 * @param host host LM Studio runs on; defaults to `localhost`. This parameter is last
 *   so that existing positional callers keep working.
 */
class LmStudioFactory(
    port: Int = 1234,
    observationRegistry: ObservationRegistry,
    host: String = "localhost"
) : OpenAiCompatibleModelFactory(
    baseUrl = "http://$host:$port",
    apiKey = null,
    completionsPath = "/v1/chat/completions",
    embeddingsPath = null,
    observationRegistry = observationRegistry
) {

    init {
        logger.info("LM Studio factory initialized at {}:{}", host, port)
    }

    /**
     * Creates a service for whichever model is loaded in LM Studio under [modelName].
     *
     * Local inference is priced as free.
     *
     * @param knowledgeCutoffDate training cutoff, if known
     */
    fun createModelService(
        modelName: String,
        knowledgeCutoffDate: LocalDate? = null
    ) = openAiCompatibleLlm(
        model = modelName,
        pricingModel = PricingModel.ALL_YOU_CAN_EAT,
        provider = "LM Studio",
        knowledgeCutoffDate = knowledgeCutoffDate
    )
}

// Usage: port 1234 is LM Studio's default and matches the constructor default
val factory = LmStudioFactory(
    observationRegistry = ObservationRegistry.create(),
    port = 1234,
)

val service = factory.createModelService(modelName = "local-model")

Advanced Extensions

Factory with Monitoring

/**
 * Factory that counts and logs every service it creates.
 *
 * Each LLM or embedding creation increments a Micrometer counter and logs at INFO.
 * The counters are `openai.model.created` and `openai.embedding.created`. Creation
 * itself is then delegated to the parent factory.
 *
 * @param meterRegistry registry the two creation counters are registered with
 */
class MonitoredOpenAiFactory(
    baseUrl: String?,
    apiKey: String?,
    completionsPath: String?,
    embeddingsPath: String?,
    observationRegistry: ObservationRegistry,
    requestFactory: ObjectProvider<ClientHttpRequestFactory>,
    meterRegistry: MeterRegistry
) : OpenAiCompatibleModelFactory(
    baseUrl,
    apiKey,
    completionsPath,
    embeddingsPath,
    observationRegistry,
    requestFactory
) {

    // The registry is only needed to register the counters, so it is not kept as a property
    private val llmCreations = meterRegistry.counter("openai.model.created")
    private val embeddingCreations = meterRegistry.counter("openai.embedding.created")

    override fun openAiCompatibleLlm(
        model: String,
        pricingModel: PricingModel,
        provider: String,
        knowledgeCutoffDate: LocalDate?,
        optionsConverter: OptionsConverter<*>,
        retryTemplate: RetryTemplate
    ): LlmService<*> {
        llmCreations.increment()
        logger.info("Creating LLM service for model: {}", model)
        return super.openAiCompatibleLlm(model, pricingModel, provider, knowledgeCutoffDate, optionsConverter, retryTemplate)
    }

    override fun openAiCompatibleEmbeddingService(model: String, provider: String): EmbeddingService {
        embeddingCreations.increment()
        logger.info("Creating embedding service for model: {}", model)
        return super.openAiCompatibleEmbeddingService(model, provider)
    }
}

Factory with Validation

/**
 * Factory that refuses to create LLM services for models outside an allow-list.
 *
 * @param validModels model names that may be created. It defaults to the original
 *   example list and is copied so that later changes to the caller's set have no effect.
 */
class ValidatingOpenAiFactory(
    baseUrl: String?,
    apiKey: String?,
    completionsPath: String?,
    embeddingsPath: String?,
    observationRegistry: ObservationRegistry,
    requestFactory: ObjectProvider<ClientHttpRequestFactory>,
    validModels: Set<String> = setOf("gpt-3.5-turbo", "gpt-4", "gpt-4-turbo", "gpt-5-turbo")
) : OpenAiCompatibleModelFactory(
    baseUrl,
    apiKey,
    completionsPath,
    embeddingsPath,
    observationRegistry,
    requestFactory
) {

    // Defensive, read-only copy of the constructor argument
    private val validModels: Set<String> = validModels.toSet()

    /**
     * Validates [model] against the allow-list, then delegates to the parent factory.
     *
     * @throws IllegalArgumentException if [model] is not in the allow-list
     */
    override fun openAiCompatibleLlm(
        model: String,
        pricingModel: PricingModel,
        provider: String,
        knowledgeCutoffDate: LocalDate?,
        optionsConverter: OptionsConverter<*>,
        retryTemplate: RetryTemplate
    ): LlmService<*> {
        require(model in validModels) {
            "Model '$model' is not in the allowed list: $validModels"
        }

        logger.info("Validated model '{}' - proceeding with creation", model)

        return super.openAiCompatibleLlm(
            model,
            pricingModel,
            provider,
            knowledgeCutoffDate,
            optionsConverter,
            retryTemplate
        )
    }
}

Factory with Caching

/**
 * Factory that reuses previously created LLM and embedding services.
 *
 * Caches are thread-safe. Each key is created at most once, via
 * [ConcurrentHashMap.computeIfAbsent]; Kotlin's `getOrPut` on a concurrent map may
 * run its factory lambda more than once under contention.
 */
class CachingOpenAiFactory(
    baseUrl: String?,
    apiKey: String?,
    completionsPath: String?,
    embeddingsPath: String?,
    observationRegistry: ObservationRegistry,
    requestFactory: ObjectProvider<ClientHttpRequestFactory>
) : OpenAiCompatibleModelFactory(
    baseUrl,
    apiKey,
    completionsPath,
    embeddingsPath,
    observationRegistry,
    requestFactory
) {

    /**
     * Key covering every argument that shapes a created LLM service.
     *
     * Without the full key, calls that differ only in pricing, cutoff, options converter
     * or retry template would get a service built for different arguments.
     * NOTE(review): equality relies on these types' `equals`. If PricingModel,
     * OptionsConverter or RetryTemplate use identity equality, reuse the same instances
     * to get cache hits.
     */
    private data class LlmCacheKey(
        val provider: String,
        val model: String,
        val pricingModel: PricingModel,
        val knowledgeCutoffDate: LocalDate?,
        val optionsConverter: OptionsConverter<*>,
        val retryTemplate: RetryTemplate,
    )

    private val llmServiceCache = ConcurrentHashMap<LlmCacheKey, LlmService<*>>()

    // A (provider, model) pair avoids the collisions a "provider:model" string allows,
    // e.g. ("a:b", "c") vs ("a", "b:c")
    private val embeddingServiceCache = ConcurrentHashMap<Pair<String, String>, EmbeddingService>()

    override fun openAiCompatibleLlm(
        model: String,
        pricingModel: PricingModel,
        provider: String,
        knowledgeCutoffDate: LocalDate?,
        optionsConverter: OptionsConverter<*>,
        retryTemplate: RetryTemplate
    ): LlmService<*> {
        val cacheKey = LlmCacheKey(provider, model, pricingModel, knowledgeCutoffDate, optionsConverter, retryTemplate)

        // Fast path: only a genuine hit is reported as "cached"
        llmServiceCache[cacheKey]?.let { cached ->
            logger.debug("Returning cached LLM service for {}:{}", provider, model)
            return cached
        }

        return llmServiceCache.computeIfAbsent(cacheKey) {
            logger.info("Creating new LLM service for {}:{}", provider, model)
            super.openAiCompatibleLlm(
                model,
                pricingModel,
                provider,
                knowledgeCutoffDate,
                optionsConverter,
                retryTemplate
            )
        }
    }

    override fun openAiCompatibleEmbeddingService(
        model: String,
        provider: String
    ): EmbeddingService {
        val cacheKey = provider to model

        embeddingServiceCache[cacheKey]?.let { cached ->
            logger.debug("Returning cached embedding service for {}:{}", provider, model)
            return cached
        }

        return embeddingServiceCache.computeIfAbsent(cacheKey) {
            logger.info("Creating new embedding service for {}:{}", provider, model)
            super.openAiCompatibleEmbeddingService(model, provider)
        }
    }

    /** Drops all cached services; subsequent requests create fresh ones. */
    fun clearCache() {
        logger.info("Clearing service caches")
        llmServiceCache.clear()
        embeddingServiceCache.clear()
    }
}

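Factory with Custom Authentication

The token is read each time a service is created. A service that already exists keeps the token it was built with, so create services again after the token rotates.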

/**
 * Factory that obtains its API key from [tokenProvider] whenever a service is created.
 *
 * The parent is constructed without an API key. Each LLM or embedding service is built
 * by a short-lived delegate factory configured with the token current at that moment.
 * A created service keeps the token it was built with; to pick up a rotated token,
 * create the service again.
 *
 * @param tokenProvider returns the current API token; called once per service creation
 */
class TokenRefreshingFactory(
    baseUrl: String?,
    private val tokenProvider: () -> String,  // Function to get current token
    completionsPath: String?,
    embeddingsPath: String?,
    observationRegistry: ObservationRegistry,
    requestFactory: ObjectProvider<ClientHttpRequestFactory>
) : OpenAiCompatibleModelFactory(
    baseUrl,
    apiKey = null,  // Supplied per service through the delegate factory
    completionsPath,
    embeddingsPath,
    observationRegistry,
    requestFactory
) {

    // Constructor parameters are not visible inside member functions, and the parent only
    // exposes logger, openAiApi and chatModelOf to subclasses, so keep our own copies.
    private val delegateBaseUrl = baseUrl
    private val delegateCompletionsPath = completionsPath
    private val delegateEmbeddingsPath = embeddingsPath
    private val delegateObservationRegistry = observationRegistry
    private val delegateRequestFactory = requestFactory

    private fun getCurrentApiKey(): String {
        val token = tokenProvider()
        logger.debug("Retrieved fresh API token")
        return token
    }

    /** Builds a plain factory with the same endpoint settings and the token current right now. */
    private fun refreshedFactory() = OpenAiCompatibleModelFactory(
        delegateBaseUrl,
        getCurrentApiKey(),
        delegateCompletionsPath,
        delegateEmbeddingsPath,
        delegateObservationRegistry,
        delegateRequestFactory
    )

    override fun openAiCompatibleLlm(
        model: String,
        pricingModel: PricingModel,
        provider: String,
        knowledgeCutoffDate: LocalDate?,
        optionsConverter: OptionsConverter<*>,
        retryTemplate: RetryTemplate
    ): LlmService<*> = refreshedFactory().openAiCompatibleLlm(
        model,
        pricingModel,
        provider,
        knowledgeCutoffDate,
        optionsConverter,
        retryTemplate
    )

    // Without this override, embedding services would be built by the parent with a null API key
    override fun openAiCompatibleEmbeddingService(
        model: String,
        provider: String
    ): EmbeddingService = refreshedFactory().openAiCompatibleEmbeddingService(model, provider)
}

// Usage
// `authService` stands in for your own token source; it is not defined in this snippet.
// NOTE(review): `ObjectProviders` is also not defined here — confirm where it comes from.
val factory = TokenRefreshingFactory(
    baseUrl = null,  // null: keep the parent factory's default endpoint
    tokenProvider = { authService.getCurrentToken() },  // invoked when services are created
    completionsPath = null,
    embeddingsPath = null,
    observationRegistry = ObservationRegistry.create(),
    requestFactory = ObjectProviders.empty()
)

Using chatModelOf for Advanced Scenarios

The protected `chatModelOf` method creates raw Spring AI `ChatModel` instances. You can use them directly or wrap them with extra behavior:

/**
 * Factory exposing raw and instrumented Spring AI [ChatModel]s built with [chatModelOf].
 */
class AdvancedFactory(
    baseUrl: String?,
    apiKey: String?,
    completionsPath: String?,
    embeddingsPath: String?,
    observationRegistry: ObservationRegistry,
    requestFactory: ObjectProvider<ClientHttpRequestFactory>
) : OpenAiCompatibleModelFactory(
    baseUrl,
    apiKey,
    completionsPath,
    embeddingsPath,
    observationRegistry,
    requestFactory
) {

    /**
     * Creates a raw ChatModel for direct Spring AI usage.
     *
     * @param retryTemplate retry policy; defaults to Spring AI's default retry template
     */
    fun createRawChatModel(
        model: String,
        retryTemplate: RetryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE
    ): ChatModel {
        logger.debug("Creating raw ChatModel for model: {}", model)
        return chatModelOf(model, retryTemplate)
    }

    /**
     * Wraps a ChatModel so that each `call(Prompt)` is timed and logged.
     *
     * NOTE(review): members not overridden here are forwarded to the raw model by Kotlin
     * `by` delegation. Any the delegate implements by calling its own `call(Prompt)`
     * (e.g. convenience overloads) may bypass this instrumentation — confirm this for the
     * ChatModel methods you use.
     */
    fun createInstrumentedChatModel(model: String): ChatModel {
        val rawModel = chatModelOf(model, RetryUtils.DEFAULT_RETRY_TEMPLATE)

        return object : ChatModel by rawModel {
            override fun call(prompt: Prompt): ChatResponse {
                // Prompts can contain user data, so keep them out of INFO-level logs
                logger.debug("Calling model with prompt: {}", prompt)

                // nanoTime is monotonic; currentTimeMillis can jump with wall-clock changes
                val startNanos = System.nanoTime()
                fun elapsedMillis() = (System.nanoTime() - startNanos) / 1_000_000

                val response = try {
                    rawModel.call(prompt)
                } catch (e: RuntimeException) {
                    // Make slow or failing calls visible too, then propagate unchanged
                    logger.warn("Model call failed after {}ms", elapsedMillis(), e)
                    throw e
                }

                logger.info("Model responded in {}ms", elapsedMillis())
                return response
            }
        }
    }
}

Spring Integration for Custom Factories

/**
 * Spring configuration wiring the custom factories and the services they create.
 *
 * Required properties: `azure.openai.resource`, `azure.openai.api.key`.
 * Optional properties, with defaults in brackets:
 * `azure.openai.deployment.chat` [gpt-4], `azure.openai.deployment.embedding` [embeddings],
 * `ollama.host` [localhost], `ollama.port` [11434].
 */
@Configuration
class CustomFactoryConfiguration(
    private val observationRegistry: ObservationRegistry
) {

    @Bean
    fun azureOpenAiFactory(
        @Value("\${azure.openai.resource}") resourceName: String,
        @Value("\${azure.openai.api.key}") apiKey: String,
        // Kotlin defaults mirror the placeholder defaults so direct callers stay source-compatible
        @Value("\${azure.openai.deployment.chat:gpt-4}") chatDeploymentName: String = "gpt-4",
        @Value("\${azure.openai.deployment.embedding:embeddings}") embeddingDeploymentName: String = "embeddings"
    ): AzureOpenAiFactory = AzureOpenAiFactory(
        resourceName = resourceName,
        apiKey = apiKey,
        gpt4DeploymentName = chatDeploymentName,
        embeddingDeploymentName = embeddingDeploymentName,
        observationRegistry = observationRegistry
    )

    @Bean
    fun ollamaFactory(
        @Value("\${ollama.host:localhost}") host: String = "localhost",
        @Value("\${ollama.port:11434}") port: Int = 11434
    ): OllamaFactory = OllamaFactory(
        host = host,
        port = port,
        observationRegistry = observationRegistry
    )

    @Bean("azureGpt4")
    fun azureGpt4(factory: AzureOpenAiFactory) = factory.createGpt4Service()

    @Bean("localLlama")
    fun localLlama(factory: OllamaFactory) = factory.llama3_70b()
}

Testing Custom Factories

/**
 * Tests for [CustomOpenAiFactory]. Each test gets a fresh factory with a dummy API key.
 */
class CustomFactoryTest {

    private lateinit var factory: CustomOpenAiFactory

    @BeforeEach
    fun setup() {
        factory = CustomOpenAiFactory(
            baseUrl = null,
            apiKey = "test-key",
            completionsPath = null,
            embeddingsPath = null,
            observationRegistry = ObservationRegistry.create(),
            requestFactory = ObjectProviders.empty()
        )
    }

    @Test
    fun `should create LLM service`() {
        val service = factory.openAiCompatibleLlm(
            model = "gpt-4",
            pricingModel = PricingModel.usdPer1MTokens(30.0, 60.0),
            provider = "OpenAI",
            knowledgeCutoffDate = LocalDate.of(2023, 4, 1)
        )

        assertNotNull(service)
    }

    @Test
    fun `should expose protected members to subclass`() {
        // createCustomChatModel logs via the protected logger and delegates to the protected
        // chatModelOf, so a non-null model shows both are reachable from a subclass.
        // (The previous empty test body always passed without checking anything.)
        assertNotNull(factory.createCustomChatModel("gpt-4"))
    }
}

Install with Tessl CLI

npx tessl i tessl/maven-com-embabel-agent--embabel-agent-openai@0.3.0

docs

api-reference.md

configuration.md

extending.md

index.md

java-usage.md

options-converters.md

quickstart.md

spring-integration.md

use-cases.md

tile.json