# OpenAI compatible model factory for the Embabel Agent Framework

A guide to subclassing `OpenAiCompatibleModelFactory` for custom behavior.

Extend the factory when you need:

- a provider-specific base URL or request paths (Azure OpenAI, Ollama, LM Studio)
- convenience methods that preset model names, pricing, and knowledge cutoff dates
- cross-cutting behavior such as metrics, validation, caching, or dynamic authentication

## Protected members

The factory exposes these protected members for subclasses:

```kotlin
/**
 * Logger instance for the factory class.
 */
protected val logger: Logger

/**
 * The configured OpenAI API instance.
 */
protected val openAiApi: OpenAiApi

/**
 * Creates a Spring AI ChatModel with the configured API.
 *
 * @param model model name/identifier
 * @param retryTemplate retry template for resilience
 * @return a ChatModel configured with the specified model
 */
protected fun chatModelOf(
    model: String,
    retryTemplate: RetryTemplate
): ChatModel
```

## Basic subclass

```kotlin
import com.embabel.agent.openai.OpenAiCompatibleModelFactory
import org.springframework.ai.chat.model.ChatModel
import org.springframework.retry.support.RetryTemplate
import io.micrometer.observation.ObservationRegistry
import org.springframework.beans.factory.ObjectProvider
import org.springframework.http.client.ClientHttpRequestFactory
class CustomOpenAiFactory(
    baseUrl: String?,
    apiKey: String?,
    completionsPath: String?,
    embeddingsPath: String?,
    observationRegistry: ObservationRegistry,
    requestFactory: ObjectProvider<ClientHttpRequestFactory>
) : OpenAiCompatibleModelFactory(
    baseUrl,
    apiKey,
    completionsPath,
    embeddingsPath,
    observationRegistry,
    requestFactory
) {

    // Access the protected logger in an init block
    init {
        logger.info("Custom factory initialized with base URL: {}", baseUrl ?: "default")
    }

    // Override or add methods as needed
    fun createCustomChatModel(model: String): ChatModel {
        logger.debug("Creating custom chat model: {}", model)
        return chatModelOf(model, RetryTemplate())
    }
}
```

## Minimal OpenAI factory

```kotlin
class SimpleOpenAiFactory(
    apiKey: String,
    observationRegistry: ObservationRegistry
) : OpenAiCompatibleModelFactory(
    baseUrl = null, // always use the OpenAI default
    apiKey = apiKey,
    completionsPath = null,
    embeddingsPath = null,
    observationRegistry = observationRegistry
) {

    init {
        logger.info("Simple OpenAI factory initialized")
    }

    // Convenience methods with preset configurations
    fun createGpt4() = openAiCompatibleLlm(
        model = "gpt-4",
        pricingModel = PricingModel.usdPer1MTokens(30.0, 60.0),
        provider = "OpenAI",
        knowledgeCutoffDate = LocalDate.of(2023, 4, 1)
    )

    fun createGpt35Turbo() = openAiCompatibleLlm(
        model = "gpt-3.5-turbo",
        pricingModel = PricingModel.usdPer1MTokens(0.5, 1.5),
        provider = "OpenAI",
        knowledgeCutoffDate = LocalDate.of(2021, 9, 1)
    )
}
```
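A minimal usage sketch for the factory above; the environment-variable key lookup and the no-op `ObservationRegistry.create()` registry mirror the usage blocks later in this guide:

```kotlin
// Usage (sketch)
val factory = SimpleOpenAiFactory(
    apiKey = System.getenv("OPENAI_API_KEY"),
    observationRegistry = ObservationRegistry.create()
)
val gpt4 = factory.createGpt4()
```

## Azure OpenAI factory

```kotlin
class AzureOpenAiFactory(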
    resourceName: String,
    apiKey: String,
    gpt4DeploymentName: String,
    embeddingDeploymentName: String,
    observationRegistry: ObservationRegistry,
    apiVersion: String = "2024-02-15-preview"
) : OpenAiCompatibleModelFactory(
    baseUrl = "https://$resourceName.openai.azure.com",
    apiKey = apiKey,
    completionsPath = "/openai/deployments/$gpt4DeploymentName/chat/completions?api-version=$apiVersion",
    embeddingsPath = "/openai/deployments/$embeddingDeploymentName/embeddings?api-version=$apiVersion",
    observationRegistry = observationRegistry
) {

    init {
        logger.info("Azure OpenAI factory initialized for resource: {}", resourceName)
    }

    // Azure-specific convenience methods
    fun createGpt4Service() = openAiCompatibleLlm(
        model = "gpt-4",
        pricingModel = PricingModel.usdPer1MTokens(30.0, 60.0),
        provider = "Azure OpenAI",
        knowledgeCutoffDate = LocalDate.of(2023, 4, 1)
    )

    fun createEmbeddingService() = openAiCompatibleEmbeddingService(
        model = "text-embedding-3-small",
        provider = "Azure OpenAI"
    )
}

// Usage
val factory = AzureOpenAiFactory(
    resourceName = "my-azure-resource",
    apiKey = System.getenv("AZURE_OPENAI_API_KEY"),
    gpt4DeploymentName = "gpt-4-deployment",
    embeddingDeploymentName = "embedding-deployment",
    observationRegistry = ObservationRegistry.create()
)
val service = factory.createGpt4Service()
```

## Ollama factory

```kotlin
class OllamaFactory(
host: String = "localhost",
port: Int = 11434,
observationRegistry: ObservationRegistry
) : OpenAiCompatibleModelFactory(
baseUrl = "http://$host:$port",
apiKey = null, // Ollama doesn't need auth
completionsPath = null,
embeddingsPath = null,
observationRegistry = observationRegistry
) {
init {
logger.info("Ollama factory initialized at {}:{}", host, port)
}
fun llama3_8b() = openAiCompatibleLlm(
model = "llama3:8b",
pricingModel = PricingModel.ALL_YOU_CAN_EAT,
provider = "Ollama",
knowledgeCutoffDate = null
)
fun llama3_70b() = openAiCompatibleLlm(
model = "llama3:70b",
pricingModel = PricingModel.ALL_YOU_CAN_EAT,
provider = "Ollama",
knowledgeCutoffDate = null
)
fun codellama() = openAiCompatibleLlm(
model = "codellama:34b",
pricingModel = PricingModel.ALL_YOU_CAN_EAT,
provider = "Ollama",
knowledgeCutoffDate = null
)
}
// Usage
val factory = OllamaFactory(
host = "localhost",
port = 11434,
observationRegistry = ObservationRegistry.create()
)
val llama = factory.llama3_70b()class LmStudioFactory(
    port: Int = 1234,
    observationRegistry: ObservationRegistry
) : OpenAiCompatibleModelFactory(
    baseUrl = "http://localhost:$port",
    apiKey = null,
    completionsPath = "/v1/chat/completions",
    embeddingsPath = null,
    observationRegistry = observationRegistry
) {

    init {
        logger.info("LM Studio factory initialized on port {}", port)
    }

    fun createModelService(
        modelName: String,
        knowledgeCutoffDate: LocalDate? = null
    ) = openAiCompatibleLlm(
        model = modelName,
        pricingModel = PricingModel.ALL_YOU_CAN_EAT,
        provider = "LM Studio",
        knowledgeCutoffDate = knowledgeCutoffDate
    )
}

// Usage
val factory = LmStudioFactory(
    port = 1234,
    observationRegistry = ObservationRegistry.create()
)
val service = factory.createModelService("local-model")
```

## Monitored factory

```kotlin
class MonitoredOpenAiFactory(
    baseUrl: String?,
    apiKey: String?,
    completionsPath: String?,
    embeddingsPath: String?,
    observationRegistry: ObservationRegistry,
    requestFactory: ObjectProvider<ClientHttpRequestFactory>,
    private val meterRegistry: MeterRegistry
) : OpenAiCompatibleModelFactory(
    baseUrl,
    apiKey,
    completionsPath,
    embeddingsPath,
    observationRegistry,
    requestFactory
) {

    private val modelCreationCounter = meterRegistry.counter("openai.model.created")
    private val embeddingCreationCounter = meterRegistry.counter("openai.embedding.created")

    override fun openAiCompatibleLlm(
        model: String,
        pricingModel: PricingModel,
        provider: String,
        knowledgeCutoffDate: LocalDate?,
        optionsConverter: OptionsConverter<*>,
        retryTemplate: RetryTemplate
    ): LlmService<*> {
        modelCreationCounter.increment()
        logger.info("Creating LLM service for model: {}", model)
        return super.openAiCompatibleLlm(
            model,
            pricingModel,
            provider,
            knowledgeCutoffDate,
            optionsConverter,
            retryTemplate
        )
    }

    override fun openAiCompatibleEmbeddingService(
        model: String,
        provider: String
    ): EmbeddingService {
        embeddingCreationCounter.increment()
        logger.info("Creating embedding service for model: {}", model)
        return super.openAiCompatibleEmbeddingService(model, provider)
    }
}
```
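A wiring sketch for the monitored factory. `SimpleMeterRegistry` is Micrometer's in-memory registry and stands in for whatever `MeterRegistry` your application exposes; `ObjectProviders.empty()` follows the usage blocks elsewhere in this guide:

```kotlin
import io.micrometer.core.instrument.simple.SimpleMeterRegistry

// Sketch: counters land in an in-memory Micrometer registry
val meterRegistry = SimpleMeterRegistry()
val monitored = MonitoredOpenAiFactory(
    baseUrl = null,
    apiKey = System.getenv("OPENAI_API_KEY"),
    completionsPath = null,
    embeddingsPath = null,
    observationRegistry = ObservationRegistry.create(),
    requestFactory = ObjectProviders.empty(),
    meterRegistry = meterRegistry
)
// After creating services, read the counter back:
// meterRegistry.counter("openai.model.created").count()
```

## Validating factory

```kotlin
class ValidatingOpenAiFactory(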
    baseUrl: String?,
    apiKey: String?,
    completionsPath: String?,
    embeddingsPath: String?,
    observationRegistry: ObservationRegistry,
    requestFactory: ObjectProvider<ClientHttpRequestFactory>
) : OpenAiCompatibleModelFactory(
    baseUrl,
    apiKey,
    completionsPath,
    embeddingsPath,
    observationRegistry,
    requestFactory
) {

    private val validModels = setOf(
        "gpt-3.5-turbo", "gpt-4", "gpt-4-turbo", "gpt-4o"
    )

    override fun openAiCompatibleLlm(
        model: String,
        pricingModel: PricingModel,
        provider: String,
        knowledgeCutoffDate: LocalDate?,
        optionsConverter: OptionsConverter<*>,
        retryTemplate: RetryTemplate
    ): LlmService<*> {
        require(model in validModels) {
            "Model '$model' is not in the allowed list: $validModels"
        }
        logger.info("Validated model '{}' - proceeding with creation", model)
        return super.openAiCompatibleLlm(
            model,
            pricingModel,
            provider,
            knowledgeCutoffDate,
            optionsConverter,
            retryTemplate
        )
    }
}
```
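A quick sketch of the guard in action. Given a `validatingFactory` constructed like the other factories in this guide, an unknown model name trips the `require` check and surfaces as an `IllegalArgumentException`:

```kotlin
// Sketch: requesting a model outside the allow-list fails fast
try {
    validatingFactory.openAiCompatibleLlm(
        model = "gpt-imaginary", // hypothetical name, not in validModels
        pricingModel = PricingModel.usdPer1MTokens(1.0, 2.0),
        provider = "OpenAI",
        knowledgeCutoffDate = null
    )
} catch (e: IllegalArgumentException) {
    println(e.message) // Model 'gpt-imaginary' is not in the allowed list: ...
}
```

## Caching factory

```kotlin
class CachingOpenAiFactory(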
    baseUrl: String?,
    apiKey: String?,
    completionsPath: String?,
    embeddingsPath: String?,
    observationRegistry: ObservationRegistry,
    requestFactory: ObjectProvider<ClientHttpRequestFactory>
) : OpenAiCompatibleModelFactory(
    baseUrl,
    apiKey,
    completionsPath,
    embeddingsPath,
    observationRegistry,
    requestFactory
) {

    private val llmServiceCache = ConcurrentHashMap<String, LlmService<*>>()
    private val embeddingServiceCache = ConcurrentHashMap<String, EmbeddingService>()

    override fun openAiCompatibleLlm(
        model: String,
        pricingModel: PricingModel,
        provider: String,
        knowledgeCutoffDate: LocalDate?,
        optionsConverter: OptionsConverter<*>,
        retryTemplate: RetryTemplate
    ): LlmService<*> {
        val cacheKey = "$provider:$model"
        return llmServiceCache.getOrPut(cacheKey) {
            logger.info("Creating new LLM service for {}", cacheKey)
            super.openAiCompatibleLlm(
                model,
                pricingModel,
                provider,
                knowledgeCutoffDate,
                optionsConverter,
                retryTemplate
            )
        }.also {
            logger.debug("Returning LLM service for {}", cacheKey)
        }
    }

    override fun openAiCompatibleEmbeddingService(
        model: String,
        provider: String
    ): EmbeddingService {
        val cacheKey = "$provider:$model"
        return embeddingServiceCache.getOrPut(cacheKey) {
            logger.info("Creating new embedding service for {}", cacheKey)
            super.openAiCompatibleEmbeddingService(model, provider)
        }.also {
            logger.debug("Returning embedding service for {}", cacheKey)
        }
    }

    fun clearCache() {
        logger.info("Clearing service caches")
        llmServiceCache.clear()
        embeddingServiceCache.clear()
    }
}
```
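A sketch of the cache at work. Given a `cachingFactory` built like the earlier examples, two requests for the same provider/model pair return the same instance:

```kotlin
// Sketch: the second lookup hits the cache, so both references are identical
val first = cachingFactory.openAiCompatibleLlm(
    model = "gpt-4",
    pricingModel = PricingModel.usdPer1MTokens(30.0, 60.0),
    provider = "OpenAI",
    knowledgeCutoffDate = LocalDate.of(2023, 4, 1)
)
val second = cachingFactory.openAiCompatibleLlm(
    model = "gpt-4",
    pricingModel = PricingModel.usdPer1MTokens(30.0, 60.0),
    provider = "OpenAI",
    knowledgeCutoffDate = LocalDate.of(2023, 4, 1)
)
check(first === second)
cachingFactory.clearCache() // the next request rebuilds the service
```

## Token-refreshing factory

```kotlin
class TokenRefreshingFactory(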
    // Captured as vals so the override below can rebuild a factory with a
    // fresh token; rename if the base class already exposes these properties
    private val baseUrl: String?,
    private val tokenProvider: () -> String, // function returning the current token
    private val completionsPath: String?,
    private val embeddingsPath: String?,
    private val observationRegistry: ObservationRegistry,
    private val requestFactory: ObjectProvider<ClientHttpRequestFactory>
) : OpenAiCompatibleModelFactory(
    baseUrl,
    apiKey = null, // set dynamically per call
    completionsPath,
    embeddingsPath,
    observationRegistry,
    requestFactory
) {

    private fun getCurrentApiKey(): String {
        val token = tokenProvider()
        logger.debug("Retrieved fresh API token")
        return token
    }

    override fun openAiCompatibleLlm(
        model: String,
        pricingModel: PricingModel,
        provider: String,
        knowledgeCutoffDate: LocalDate?,
        optionsConverter: OptionsConverter<*>,
        retryTemplate: RetryTemplate
    ): LlmService<*> {
        // Refresh the token before creating the service
        val currentToken = getCurrentApiKey()
        // Create a new factory instance with the current token
        val refreshedFactory = OpenAiCompatibleModelFactory(
            baseUrl,
            currentToken,
            completionsPath,
            embeddingsPath,
            observationRegistry,
            requestFactory
        )
        return refreshedFactory.openAiCompatibleLlm(
            model,
            pricingModel,
            provider,
            knowledgeCutoffDate,
            optionsConverter,
            retryTemplate
        )
    }
}

// Usage
val factory = TokenRefreshingFactory(
    baseUrl = null,
    tokenProvider = { authService.getCurrentToken() },
    completionsPath = null,
    embeddingsPath = null,
    observationRegistry = ObservationRegistry.create(),
    requestFactory = ObjectProviders.empty()
)
```

## Using chatModelOf directly

The protected `chatModelOf` method creates Spring AI `ChatModel` instances:

```kotlin
class AdvancedFactory(
    baseUrl: String?,
    apiKey: String?,
    completionsPath: String?,
    embeddingsPath: String?,
    observationRegistry: ObservationRegistry,
    requestFactory: ObjectProvider<ClientHttpRequestFactory>
) : OpenAiCompatibleModelFactory(
    baseUrl,
    apiKey,
    completionsPath,
    embeddingsPath,
    observationRegistry,
    requestFactory
) {

    // Create a raw ChatModel for direct Spring AI usage
    fun createRawChatModel(
        model: String,
        retryTemplate: RetryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE
    ): ChatModel {
        logger.debug("Creating raw ChatModel for model: {}", model)
        return chatModelOf(model, retryTemplate)
    }

    // Wrap the ChatModel with additional functionality
    fun createInstrumentedChatModel(model: String): ChatModel {
        val rawModel = chatModelOf(model, RetryUtils.DEFAULT_RETRY_TEMPLATE)
        return object : ChatModel by rawModel {
            override fun call(prompt: Prompt): ChatResponse {
                logger.info("Calling model with prompt: {}", prompt)
                val startTime = System.currentTimeMillis()
                val response = rawModel.call(prompt)
                val duration = System.currentTimeMillis() - startTime
                logger.info("Model responded in {}ms", duration)
                return response
            }
        }
    }
}
```
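A usage sketch for the instrumented model, assuming an `advancedFactory` constructed like the other factories here. `Prompt` is Spring AI's prompt type; the wrapper logs timing around each call:

```kotlin
import org.springframework.ai.chat.prompt.Prompt

// Sketch: every call through the instrumented model is timed and logged
val chat = advancedFactory.createInstrumentedChatModel("gpt-4")
val response = chat.call(Prompt("Summarize the Embabel agent framework in one sentence."))
```

## Spring configuration

```kotlin
@Configuration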
class CustomFactoryConfiguration(
    private val observationRegistry: ObservationRegistry
) {

    @Bean
    fun azureOpenAiFactory(
        @Value("\${azure.openai.resource}") resourceName: String,
        @Value("\${azure.openai.api.key}") apiKey: String
    ): AzureOpenAiFactory {
        return AzureOpenAiFactory(
            resourceName = resourceName,
            apiKey = apiKey,
            gpt4DeploymentName = "gpt-4",
            embeddingDeploymentName = "embeddings",
            observationRegistry = observationRegistry
        )
    }

    @Bean
    fun ollamaFactory(): OllamaFactory {
        return OllamaFactory(
            host = "localhost",
            port = 11434,
            observationRegistry = observationRegistry
        )
    }

    @Bean("azureGpt4")
    fun azureGpt4(factory: AzureOpenAiFactory) = factory.createGpt4Service()

    @Bean("localLlama")
    fun localLlama(factory: OllamaFactory) = factory.llama3_70b()
}
```
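A consumption sketch for the beans above. The qualifiers match the bean names, and the `LlmService<*>` type is an assumption based on the override signatures used earlier in this guide:

```kotlin
import org.springframework.beans.factory.annotation.Qualifier
import org.springframework.stereotype.Component

// Sketch: inject the named services and choose between cloud and local per request
@Component
class SummarizerAgent(
    @Qualifier("azureGpt4") private val azureGpt4: LlmService<*>,
    @Qualifier("localLlama") private val localLlama: LlmService<*>
)
```

## Testing

```kotlin
class CustomFactoryTest {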
    private lateinit var factory: CustomOpenAiFactory
    private lateinit var observationRegistry: ObservationRegistry

    @BeforeEach
    fun setup() {
        observationRegistry = ObservationRegistry.create()
        factory = CustomOpenAiFactory(
            baseUrl = null,
            apiKey = "test-key",
            completionsPath = null,
            embeddingsPath = null,
            observationRegistry = observationRegistry,
            requestFactory = ObjectProviders.empty()
        )
    }

    @Test
    fun `should create LLM service`() {
        val service = factory.openAiCompatibleLlm(
            model = "gpt-4",
            pricingModel = PricingModel.usdPer1MTokens(30.0, 60.0),
            provider = "OpenAI",
            knowledgeCutoffDate = LocalDate.of(2023, 4, 1)
        )
        assertNotNull(service)
    }

    @Test
    fun `should access protected logger`() {
        // The logger is exercised in the init block;
        // verify through captured log output or mocking, as in the sketch below
    }
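
    // Sketch: one way to verify the init-block logging is to capture output with
    // Logback's ListAppender. This assumes Logback is the SLF4J backend and that
    // the protected logger is named after the factory class - adjust the logger
    // lookup to match how the base factory actually names it.
    // Requires: ch.qos.logback.classic.spi.ILoggingEvent,
    // ch.qos.logback.core.read.ListAppender, org.slf4j.LoggerFactory
    @Test
    fun `should log initialization message`() {
        val logbackLogger =
            LoggerFactory.getLogger(CustomOpenAiFactory::class.java) as ch.qos.logback.classic.Logger
        val appender = ListAppender<ILoggingEvent>().apply { start() }
        logbackLogger.addAppender(appender)

        CustomOpenAiFactory(
            baseUrl = null,
            apiKey = "test-key",
            completionsPath = null,
            embeddingsPath = null,
            observationRegistry = observationRegistry,
            requestFactory = ObjectProviders.empty()
        )

        assertTrue(appender.list.any { it.formattedMessage.contains("Custom factory initialized") })
    }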
}
```

## Install with Tessl CLI

```bash
npx tessl i tessl/maven-com-embabel-agent--embabel-agent-openai
```