Spring Boot Starter for OpenAI integration providing auto-configuration for chat completion, embeddings, image generation, audio speech synthesis, audio transcription, and content moderation models. Includes high-level ChatClient API and conversation memory support.
The Chat Memory subsystem manages conversation history across multiple interactions, enabling context-aware conversations with configurable retention strategies.
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.memory.ChatMemoryRepository;
import org.springframework.ai.chat.memory.InMemoryChatMemoryRepository;
import org.springframework.ai.chat.memory.MessageWindowChatMemory;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.messages.AssistantMessage;package org.springframework.ai.chat.memory;
/**
 * Stores and retrieves the message history of chat conversations, keyed by conversation id.
 */
public interface ChatMemory {

    /** Conversation id used when the caller does not supply one. */
    String DEFAULT_CONVERSATION_ID = "default";

    /**
     * Advisor-context key under which the conversation id is passed, e.g.
     * {@code advisors(a -> a.param(ChatMemory.CONVERSATION_ID, id))}.
     * Always reference this constant instead of hard-coding its value.
     */
    String CONVERSATION_ID = "chat_memory_conversation_id";

    /**
     * Appends a single message to the conversation history.
     * By default delegates to {@link #add(String, List)} with a one-element list.
     */
    default void add(String conversationId, Message message) {
        add(conversationId, List.of(message));
    }

    /** Appends several messages, in order, to the conversation history. */
    void add(String conversationId, List<Message> messages);

    /** Returns the messages currently retained for the conversation. */
    List<Message> get(String conversationId);

    /** Removes the whole history of the conversation. */
    void clear(String conversationId);
}package org.springframework.ai.chat.memory;
/**
 * Persistence port used by ChatMemory implementations: stores, loads and deletes the
 * message list of each conversation.
 */
public interface ChatMemoryRepository {
// Save messages for a conversation. The implementations shown in this document overwrite
// whatever list was previously stored for the id (they do not append).
void saveAll(String conversationId, List<Message> messages);
// Find messages by conversation ID. The implementations shown in this document return an
// empty list when nothing is stored for the id.
List<Message> findByConversationId(String conversationId);
// Delete a conversation and all of its stored messages
void deleteByConversationId(String conversationId);
// Find all conversation IDs that currently have stored messages
List<String> findConversationIds();
}package org.springframework.ai.chat.memory;
/**
 * ChatMemory that retains only a fixed number of the most recent messages per conversation
 * (a sliding window), persisting them through the ChatMemoryRepository given to its builder.
 */
public class MessageWindowChatMemory implements ChatMemory {
// Constants
// NOTE(review): Spring AI's reference documentation describes a default window of 20 messages;
// confirm this value against the published class. ChatMemoryAutoConfiguration relies on the default.
public static final int DEFAULT_MAX_MESSAGES = 10;
// Factory method (builder pattern); the builder is the way instances are created in this document
public static Builder builder();
// Methods
public void add(String conversationId, Message message);
public void add(String conversationId, List<Message> messages);
public List<Message> get(String conversationId);
public void clear(String conversationId);
// Builder
public static class Builder {
// Storage backend for the window.
// NOTE(review): behaviour when this is not set is not shown here — confirm whether a default repository is used.
public Builder chatMemoryRepository(ChatMemoryRepository repository);
// Maximum number of messages kept per conversation (this document uses 10 and 20)
public Builder maxMessages(int maxMessages);
public MessageWindowChatMemory build();
}
}package org.springframework.ai.chat.memory;
/**
 * ChatMemoryRepository that keeps conversations in application memory. It is the repository
 * ChatMemoryAutoConfiguration registers when no other ChatMemoryRepository bean exists.
 * Being in-memory, its contents are presumably not durable across restarts or shared between
 * application instances — confirm before relying on it in production.
 */
public class InMemoryChatMemoryRepository implements ChatMemoryRepository {
public void saveAll(String conversationId, List<Message> messages);
public List<Message> findByConversationId(String conversationId);
public void deleteByConversationId(String conversationId);
public List<String> findConversationIds();
}package org.springframework.ai.chat.memory;
/**
 * ChatMemoryRepository backed by a relational database through Spring's JdbcTemplate.
 * NOTE(review): the package and constructor listed here should be confirmed against the
 * published artifact; the repository may ship in a separate module and be created via a builder.
 * The example JdbcChatMemoryRepository later in this document uses a different
 * (JdbcTemplate, ObjectMapper) constructor.
 */
public class JdbcChatMemoryRepository implements ChatMemoryRepository {
public JdbcChatMemoryRepository(JdbcTemplate jdbcTemplate);
public void saveAll(String conversationId, List<Message> messages);
public List<Message> findByConversationId(String conversationId);
public void deleteByConversationId(String conversationId);
public List<String> findConversationIds();
}package org.springframework.ai.chat.memory;
/**
 * ChatMemoryRepository backed by MongoDB through Spring Data's MongoTemplate.
 * NOTE(review): confirm package and constructor against the published artifact; the
 * repository may ship in a separate module and be created via a builder.
 */
public class MongoChatMemoryRepository implements ChatMemoryRepository {
public MongoChatMemoryRepository(MongoTemplate mongoTemplate);
public void saveAll(String conversationId, List<Message> messages);
public List<Message> findByConversationId(String conversationId);
public void deleteByConversationId(String conversationId);
public List<String> findConversationIds();
}package org.springframework.ai.chat.memory;
/**
 * ChatMemoryRepository backed by Apache Cassandra through Spring Data's CassandraTemplate.
 * NOTE(review): confirm package and construction API against the published artifact; the
 * repository may ship in a separate module with its own configuration/factory method.
 */
public class CassandraChatMemoryRepository implements ChatMemoryRepository {
public CassandraChatMemoryRepository(CassandraTemplate cassandraTemplate);
public void saveAll(String conversationId, List<Message> messages);
public List<Message> findByConversationId(String conversationId);
public void deleteByConversationId(String conversationId);
public List<String> findConversationIds();
}package org.springframework.ai.chat.memory;
/**
 * ChatMemoryRepository backed by Neo4j through Spring Data's Neo4jTemplate.
 * NOTE(review): confirm package and constructor against the published artifact; the
 * repository may ship in a separate module and be created via a builder or config object.
 */
public class Neo4jChatMemoryRepository implements ChatMemoryRepository {
public Neo4jChatMemoryRepository(Neo4jTemplate neo4jTemplate);
public void saveAll(String conversationId, List<Message> messages);
public List<Message> findByConversationId(String conversationId);
public void deleteByConversationId(String conversationId);
public List<String> findConversationIds();
}package org.springframework.ai.chat.memory;
/**
 * ChatMemoryRepository backed by Azure Cosmos DB through a CosmosTemplate.
 * NOTE(review): confirm package and constructor against the published artifact; the
 * repository may ship in a separate module and be created via a builder or config object.
 */
public class CosmosDBChatMemoryRepository implements ChatMemoryRepository {
public CosmosDBChatMemoryRepository(CosmosTemplate cosmosTemplate);
public void saveAll(String conversationId, List<Message> messages);
public List<Message> findByConversationId(String conversationId);
public void deleteByConversationId(String conversationId);
public List<String> findConversationIds();
}import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
/**
 * Service facade over {@link ChatMemory} for recording, reading and clearing conversation turns.
 */
@Service
public class ConversationService {

    private final ChatMemory chatMemory;

    @Autowired
    public ConversationService(ChatMemory chatMemory) {
        this.chatMemory = chatMemory;
    }

    /**
     * Records one exchange: the user's input first, then the assistant's reply.
     */
    public void recordInteraction(String conversationId, String userInput, String assistantResponse) {
        UserMessage question = new UserMessage(userInput);
        this.chatMemory.add(conversationId, question);
        AssistantMessage answer = new AssistantMessage(assistantResponse);
        this.chatMemory.add(conversationId, answer);
    }

    /** Returns whatever history the configured memory currently retains for the conversation. */
    public List<Message> getRecentMessages(String conversationId) {
        List<Message> history = this.chatMemory.get(conversationId);
        return history;
    }

    /** Drops the entire stored history of the conversation. */
    public void clearConversation(String conversationId) {
        this.chatMemory.clear(conversationId);
    }
}import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
import org.springframework.ai.chat.memory.ChatMemory;
/**
 * Chat service whose ChatClient loads and stores conversation history automatically through a
 * {@link MessageChatMemoryAdvisor} backed by the shared {@link ChatMemory}.
 */
@Service
public class MemoryChatService {

    private final ChatClient chatClient;
    private final ChatMemory chatMemory;

    @Autowired
    public MemoryChatService(ChatClient.Builder chatClientBuilder, ChatMemory chatMemory) {
        this.chatMemory = chatMemory;
        // Spring AI 1.x exposes the advisor through a builder rather than a public constructor.
        this.chatClient = chatClientBuilder
                .defaultAdvisors(MessageChatMemoryAdvisor.builder(chatMemory).build())
                .build();
    }

    /**
     * Sends {@code userMessage} within {@code conversationId}, so earlier turns of that conversation
     * are included and this exchange is remembered.
     *
     * @return the model's text reply
     */
    public String chat(String conversationId, String userMessage) {
        return chatClient.prompt()
                .user(userMessage)
                // The advisor looks the id up under ChatMemory.CONVERSATION_ID; a hard-coded key
                // that does not match it silently falls back to the default conversation.
                .advisors(advisor -> advisor.param(ChatMemory.CONVERSATION_ID, conversationId))
                .call()
                .content();
    }

    /** Drops the stored history of the conversation. */
    public void clearMemory(String conversationId) {
        chatMemory.clear(conversationId);
    }
}import org.springframework.ai.chat.memory.MessageWindowChatMemory;
import org.springframework.ai.chat.memory.ChatMemoryRepository;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
/**
 * Defines the application's ChatMemory bean as a 10-message sliding window over the
 * available ChatMemoryRepository.
 */
@Configuration
public class ChatMemoryConfig {
    @Bean
    public ChatMemory chatMemory(ChatMemoryRepository repository) {
        // Keep only last 10 messages using builder pattern
        return MessageWindowChatMemory.builder()
                .chatMemoryRepository(repository)
                .maxMessages(10)
                .build();
    }
}

/**
 * Renders the retained history of a conversation as one {@code ROLE: text} line per message,
 * under a "Conversation History:" heading.
 */
public String getConversationSummary(String conversationId) {
    List<Message> messages = chatMemory.get(conversationId);
    StringBuilder summary = new StringBuilder("Conversation History:\n");
    for (Message message : messages) {
        // Spring AI 1.x exposes message text via getText(); getContent() no longer exists.
        summary.append(message.getMessageType().name())
                .append(": ")
                .append(message.getText())
                .append("\n");
    }
    return summary.toString();
}import java.util.List;
/**
 * Seeds a conversation with a system prompt followed by an opening greeting exchange,
 * stored in a single call.
 */
public void initializeConversation(String conversationId) {
    Message systemPrompt = new SystemMessage("You are a helpful assistant.");
    Message greeting = new UserMessage("Hello!");
    Message reply = new AssistantMessage("Hello! How can I help you today?");
    chatMemory.add(conversationId, List.of(systemPrompt, greeting, reply));
}import java.util.HashMap;
import java.util.Map;
import java.util.List;
import java.util.ArrayList;
/**
 * Minimal in-process {@link ChatMemoryRepository} that keeps each conversation's messages in a map.
 *
 * <p>Safe for concurrent use, since a repository bean is shared by all request threads.
 * Conversation ids must not be {@code null} (the backing ConcurrentHashMap rejects null keys).
 */
public class CustomChatMemoryRepository implements ChatMemoryRepository {
    // ConcurrentHashMap: a plain HashMap can be corrupted by concurrent put/remove.
    private final Map<String, List<Message>> storage = new java.util.concurrent.ConcurrentHashMap<>();

    /** Replaces the stored history with a copy of {@code messages}. */
    @Override
    public void saveAll(String conversationId, List<Message> messages) {
        // Copy so later changes to the caller's list do not alter stored history.
        storage.put(conversationId, new ArrayList<>(messages));
    }

    /** Returns a mutable copy of the stored history, or an empty list if none is stored. */
    @Override
    public List<Message> findByConversationId(String conversationId) {
        // Return a copy: handing out the stored list would let callers mutate history
        // behind the repository's back.
        List<Message> stored = storage.get(conversationId);
        return stored == null ? new ArrayList<>() : new ArrayList<>(stored);
    }

    @Override
    public void deleteByConversationId(String conversationId) {
        storage.remove(conversationId);
    }

    /** Returns a snapshot of the ids that currently have stored history. */
    @Override
    public List<String> findConversationIds() {
        return new ArrayList<>(storage.keySet());
    }
}import org.springframework.jdbc.core.JdbcTemplate;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.List;
/**
 * Example {@link ChatMemoryRepository} that stores each conversation as one JSON document in a
 * {@code chat_memory(conversation_id, messages)} table.
 *
 * <p>{@link Message} is an interface, so Jackson cannot rebuild it from JSON without type
 * information. Messages are therefore persisted as {@link StoredMessage} (type name + text) and
 * rebuilt as User/Assistant/System messages. Other message types, metadata and media are not
 * supported by this example.
 *
 * <p>The upsert uses {@code ON CONFLICT}, which assumes a database that supports it
 * (e.g. PostgreSQL) and a unique constraint on {@code conversation_id}.
 */
public class JdbcChatMemoryRepository implements ChatMemoryRepository {

    private static final String UPSERT_SQL =
            "INSERT INTO chat_memory (conversation_id, messages) VALUES (?, ?) " +
            "ON CONFLICT (conversation_id) DO UPDATE SET messages = ?";
    private static final String SELECT_SQL = "SELECT messages FROM chat_memory WHERE conversation_id = ?";
    private static final String DELETE_SQL = "DELETE FROM chat_memory WHERE conversation_id = ?";
    private static final String SELECT_IDS_SQL = "SELECT conversation_id FROM chat_memory";

    /** JSON shape of one persisted message. */
    record StoredMessage(String type, String text) {}

    private final JdbcTemplate jdbcTemplate;
    private final ObjectMapper objectMapper;

    public JdbcChatMemoryRepository(JdbcTemplate jdbcTemplate, ObjectMapper objectMapper) {
        this.jdbcTemplate = jdbcTemplate;
        this.objectMapper = objectMapper;
    }

    /**
     * Replaces the stored history of {@code conversationId} with {@code messages}.
     *
     * @throws IllegalArgumentException if a message type cannot be round-tripped by this repository
     * @throws IllegalStateException if the messages cannot be serialized to JSON
     */
    @Override
    public void saveAll(String conversationId, List<Message> messages) {
        List<StoredMessage> stored = messages.stream()
                .map(JdbcChatMemoryRepository::toStored)
                .toList();
        String messagesJson;
        try {
            messagesJson = objectMapper.writeValueAsString(stored);
        } catch (com.fasterxml.jackson.core.JsonProcessingException e) {
            throw new IllegalStateException("Failed to save chat memory", e);
        }
        // Database errors propagate as Spring's (unchecked) DataAccessException.
        jdbcTemplate.update(UPSERT_SQL, conversationId, messagesJson, messagesJson);
    }

    /**
     * Loads the stored history, or an empty list if the conversation has none.
     *
     * @throws IllegalStateException if the stored JSON cannot be read back
     */
    @Override
    public List<Message> findByConversationId(String conversationId) {
        // queryForList returns an empty list for an unknown id, so "no history yet" does not need
        // exception handling (queryForObject would throw EmptyResultDataAccessException).
        List<String> rows = jdbcTemplate.queryForList(SELECT_SQL, String.class, conversationId);
        if (rows.isEmpty() || rows.get(0) == null) {
            return new java.util.ArrayList<>();
        }
        List<StoredMessage> stored;
        try {
            stored = objectMapper.readValue(rows.get(0),
                    objectMapper.getTypeFactory().constructCollectionType(List.class, StoredMessage.class));
        } catch (com.fasterxml.jackson.core.JsonProcessingException e) {
            // Surface corrupt rows instead of silently presenting an empty history.
            throw new IllegalStateException("Failed to read chat memory for conversation " + conversationId, e);
        }
        if (stored == null) {
            return new java.util.ArrayList<>();
        }
        return stored.stream()
                .map(JdbcChatMemoryRepository::fromStored)
                .collect(java.util.stream.Collectors.toCollection(java.util.ArrayList::new));
    }

    @Override
    public void deleteByConversationId(String conversationId) {
        jdbcTemplate.update(DELETE_SQL, conversationId);
    }

    @Override
    public List<String> findConversationIds() {
        return jdbcTemplate.queryForList(SELECT_IDS_SQL, String.class);
    }

    /** Converts a message to its persisted form, rejecting types {@link #fromStored} cannot rebuild. */
    private static StoredMessage toStored(Message message) {
        String type = message.getMessageType().name();
        if (!"USER".equals(type) && !"ASSISTANT".equals(type) && !"SYSTEM".equals(type)) {
            throw new IllegalArgumentException("Unsupported message type for chat memory: " + type);
        }
        return new StoredMessage(type, message.getText());
    }

    /** Rebuilds a message from its persisted form. */
    private static Message fromStored(StoredMessage stored) {
        return switch (stored.type()) {
            case "USER" -> new org.springframework.ai.chat.messages.UserMessage(stored.text());
            case "ASSISTANT" -> new org.springframework.ai.chat.messages.AssistantMessage(stored.text());
            case "SYSTEM" -> new org.springframework.ai.chat.messages.SystemMessage(stored.text());
            default -> throw new IllegalStateException("Unsupported message type in chat memory: " + stored.type());
        };
    }
}import org.springframework.ai.chat.messages.MessageType;
/**
 * Returns the retained messages of a conversation whose type is {@code type}, in stored order.
 */
public List<Message> getMessagesByType(String conversationId, MessageType type) {
    return chatMemory.get(conversationId).stream()
            .filter(msg -> msg.getMessageType() == type)
            .toList();
}

/** Returns only the user-authored messages of the conversation. */
public List<Message> getUserMessages(String conversationId) {
    return getMessagesByType(conversationId, MessageType.USER);
}

/** Returns only the assistant (model) messages of the conversation. */
public List<Message> getAssistantMessages(String conversationId) {
    return getMessagesByType(conversationId, MessageType.ASSISTANT);
}import java.util.Set;
import java.util.HashSet;
/**
 * Tracks which conversation ids are open and serves history only for open conversations.
 */
@Service
public class ConversationManager {
    private final ChatMemory chatMemory;
    // Synchronized wrapper: the service is shared by concurrent request threads (Spring beans are
    // singletons by default), and a bare HashSet is not safe under concurrent add/remove/contains.
    private final Set<String> activeConversations = java.util.Collections.synchronizedSet(new HashSet<>());

    @Autowired
    public ConversationManager(ChatMemory chatMemory) {
        this.chatMemory = chatMemory;
    }

    /** Opens a new conversation and returns its freshly generated id. */
    public String startConversation() {
        String conversationId = generateConversationId();
        activeConversations.add(conversationId);
        return conversationId;
    }

    /** Deletes the conversation's stored history and marks it inactive. */
    public void endConversation(String conversationId) {
        chatMemory.clear(conversationId);
        activeConversations.remove(conversationId);
    }

    public boolean isActiveConversation(String conversationId) {
        return activeConversations.contains(conversationId);
    }

    /**
     * Returns the retained history of an open conversation.
     *
     * @throws IllegalArgumentException if the conversation was not started here or has been ended
     */
    public List<Message> getHistory(String conversationId) {
        if (!isActiveConversation(conversationId)) {
            throw new IllegalArgumentException("Conversation not found: " + conversationId);
        }
        return chatMemory.get(conversationId);
    }

    private String generateConversationId() {
        return java.util.UUID.randomUUID().toString();
    }
}

/**
 * Trims a conversation's retained history to the newest messages that fit an approximate token budget.
 */
public class TokenAwareChatMemory {
    private final ChatMemory chatMemory;
    private final int maxTokens;

    public TokenAwareChatMemory(ChatMemory chatMemory, int maxTokens) {
        this.chatMemory = chatMemory;
        this.maxTokens = maxTokens;
    }

    /**
     * Returns the longest run of most-recent messages whose estimated token total is at most
     * {@code maxTokens}, in chronological order. Scanning stops at the first (newest-to-oldest)
     * message that would exceed the budget, so the result is always a contiguous suffix.
     */
    public List<Message> getWithinTokenLimit(String conversationId) {
        List<Message> messages = chatMemory.get(conversationId);
        int tokenCount = 0;
        int start = messages.size();
        // Walk from newest to oldest, extending the kept suffix while it still fits.
        while (start > 0) {
            int messageTokens = estimateTokens(messages.get(start - 1).getText());
            if (tokenCount + messageTokens > maxTokens) {
                break;
            }
            tokenCount += messageTokens;
            start--;
        }
        // One copy instead of repeated add(0, ...), which is quadratic on an ArrayList.
        return new ArrayList<>(messages.subList(start, messages.size()));
    }

    /**
     * Rough estimate of ~4 characters per token, rounded up so that short messages still cost
     * at least one token. Messages without text (e.g. null) count as zero.
     */
    private int estimateTokens(String text) {
        if (text == null || text.isEmpty()) {
            return 0;
        }
        int length = text.length();
        return length / 4 + (length % 4 == 0 ? 0 : 1);
    }
}import jakarta.servlet.http.HttpSession;
/**
 * REST endpoints for a chat whose memory is scoped to the caller's HTTP session.
 */
@RestController
@RequestMapping("/api/chat")
public class SessionChatController {
    /** Session attribute holding this session's stable conversation id. */
    private static final String CONVERSATION_ID_ATTRIBUTE =
            SessionChatController.class.getName() + ".conversationId";

    private final ChatClient chatClient;
    private final ChatMemory chatMemory;

    @Autowired
    public SessionChatController(ChatClient.Builder builder, ChatMemory chatMemory) {
        this.chatMemory = chatMemory;
        // Spring AI 1.x exposes the advisor through a builder rather than a public constructor.
        this.chatClient = builder
                .defaultAdvisors(MessageChatMemoryAdvisor.builder(chatMemory).build())
                .build();
    }

    /** Sends {@code message} in the session's conversation and returns the model's reply. */
    @PostMapping
    public String chat(@RequestBody String message, HttpSession session) {
        String conversationId = resolveConversationId(session);
        return chatClient.prompt()
                .user(message)
                // The advisor reads the id under ChatMemory.CONVERSATION_ID; use the constant.
                .advisors(advisor -> advisor.param(ChatMemory.CONVERSATION_ID, conversationId))
                .call()
                .content();
    }

    /** Drops the stored history of the session's conversation. */
    @DeleteMapping
    public void clearHistory(HttpSession session) {
        chatMemory.clear(resolveConversationId(session));
    }

    /**
     * Returns the conversation id stored in the session, creating one on first use.
     * The session id itself is not used because it can change during a session
     * (e.g. session-fixation protection on login), which would orphan the stored history.
     */
    private static String resolveConversationId(HttpSession session) {
        Object existing = session.getAttribute(CONVERSATION_ID_ATTRIBUTE);
        if (existing instanceof String id) {
            return id;
        }
        String created = java.util.UUID.randomUUID().toString();
        session.setAttribute(CONVERSATION_ID_ATTRIBUTE, created);
        return created;
    }
}The starter auto-configures chat memory with in-memory storage:
/**
 * Default chat-memory wiring: an in-memory repository with a message-window memory on top.
 * Each bean backs off when the application declares its own bean of that type.
 */
@Configuration
@ConditionalOnClass(ChatMemory.class)
public class ChatMemoryAutoConfiguration {

    /** In-memory storage, registered only if no other ChatMemoryRepository bean exists. */
    @Bean
    @ConditionalOnMissingBean
    public ChatMemoryRepository chatMemoryRepository() {
        InMemoryChatMemoryRepository inMemory = new InMemoryChatMemoryRepository();
        return inMemory;
    }

    /** Window memory with the builder's default size, over whichever repository bean is present. */
    @Bean
    @ConditionalOnMissingBean
    public ChatMemory chatMemory(ChatMemoryRepository repository) {
        MessageWindowChatMemory windowed = MessageWindowChatMemory.builder()
                .chatMemoryRepository(repository)
                .build();
        return windowed;
    }
}Configure chat memory via application.properties:
# The repository settings below are alternatives: enable exactly ONE
# spring.ai.chat.memory.repository-type. If several of these lines are active in
# one file, the last value wins. NOTE: confirm these property names against the
# starter's configuration metadata for your Spring AI version.

# In-memory repository (default)
spring.ai.chat.memory.repository-type=in-memory

# JDBC repository
#spring.ai.chat.memory.repository-type=jdbc
#spring.ai.chat.memory.jdbc.table-name=chat_memory
#spring.ai.chat.memory.jdbc.conversation-id-column=conversation_id
#spring.ai.chat.memory.jdbc.messages-column=messages

# MongoDB repository
#spring.ai.chat.memory.repository-type=mongo
#spring.ai.chat.memory.mongo.collection-name=chat_memory

# Cassandra repository
#spring.ai.chat.memory.repository-type=cassandra
#spring.ai.chat.memory.cassandra.keyspace=chat
#spring.ai.chat.memory.cassandra.table-name=chat_memory

# Neo4j repository
#spring.ai.chat.memory.repository-type=neo4j
#spring.ai.chat.memory.neo4j.label=ChatMemory

# CosmosDB repository
#spring.ai.chat.memory.repository-type=cosmosdb
#spring.ai.chat.memory.cosmosdb.container-name=chat-memory
#spring.ai.chat.memory.cosmosdb.database-name=spring-ai

# Message window settings
spring.ai.chat.memory.max-messages=10
To provide custom implementations, define your own beans:
/**
 * Supplies both chat-memory beans explicitly: a custom repository and a 20-message window
 * over it. The auto-configured defaults back off because they are conditional on missing beans.
 */
@Configuration
public class CustomMemoryConfig {

    @Bean
    public ChatMemoryRepository customRepository() {
        ChatMemoryRepository repository = new CustomChatMemoryRepository();
        return repository;
    }

    @Bean
    public ChatMemory customChatMemory(ChatMemoryRepository repository) {
        final int windowSize = 20;
        MessageWindowChatMemory memory = MessageWindowChatMemory.builder()
                .chatMemoryRepository(repository)
                .maxMessages(windowSize)
                .build();
        return memory;
    }
}The MessageWindowChatMemory keeps a fixed number of recent messages:
// Sliding window over the given repository: only the 10 most recent messages are retained.
ChatMemory memory = MessageWindowChatMemory.builder().chatMemoryRepository(repository).maxMessages(10).build();Create a custom implementation that summarizes old messages:
/**
 * ChatMemory that stores the full history but, on read, returns only the most recent
 * {@code maxMessages} messages verbatim. Everything older is replaced by one LLM-generated
 * summary, delivered as a leading system message.
 *
 * <p>Note: the summary is regenerated on every {@link #get} once the history exceeds the window,
 * which costs one extra model call per read; cache it if that matters.
 */
public class SummaryChatMemory implements ChatMemory {
    private final ChatMemoryRepository repository;
    private final ChatClient summarizer;
    private final int maxMessages;

    /**
     * @param repository  storage for the full, unsummarized history
     * @param summarizer  client used to summarize messages that fall outside the window
     * @param maxMessages number of most recent messages returned verbatim; must be {@code >= 0}
     * @throws IllegalArgumentException if {@code maxMessages} is negative
     */
    public SummaryChatMemory(
            ChatMemoryRepository repository,
            ChatClient summarizer,
            int maxMessages
    ) {
        if (maxMessages < 0) {
            throw new IllegalArgumentException("maxMessages must be >= 0: " + maxMessages);
        }
        this.repository = repository;
        this.summarizer = summarizer;
        this.maxMessages = maxMessages;
    }

    @Override
    public void add(String conversationId, Message message) {
        add(conversationId, List.of(message));
    }

    /** Appends to the full stored history; summarization happens lazily in {@link #get}. */
    @Override
    public void add(String conversationId, List<Message> messages) {
        List<Message> history = new ArrayList<>(repository.findByConversationId(conversationId));
        history.addAll(messages);
        repository.saveAll(conversationId, history);
    }

    @Override
    public List<Message> get(String conversationId) {
        List<Message> messages = repository.findByConversationId(conversationId);
        if (messages.size() <= maxMessages) {
            return messages;
        }
        // Everything before `split` is summarized; the rest is returned as-is.
        int split = messages.size() - maxMessages;
        String summary = summarizeMessages(messages.subList(0, split));
        List<Message> result = new ArrayList<>(maxMessages + 1);
        result.add(new SystemMessage("Previous conversation summary: " + summary));
        result.addAll(messages.subList(split, messages.size()));
        return result;
    }

    @Override
    public void clear(String conversationId) {
        repository.deleteByConversationId(conversationId);
    }

    /** Asks the summarizer model for a summary of the given messages, one "TYPE: text" line each. */
    private String summarizeMessages(List<Message> messages) {
        List<String> lines = messages.stream()
                .map(m -> m.getMessageType() + ": " + m.getText())
                .toList();
        String messagesText = String.join("\n", lines);
        return summarizer.prompt()
                .user("Summarize this conversation:\n" + messagesText)
                .call()
                .content();
    }
}package org.springframework.ai.chat.messages;
/**
 * Role of a chat message within a conversation.
 */
public enum MessageType {
    /** Input authored by the end user. */
    USER,
    /** Output produced by the model. */
    ASSISTANT,
    /** Instructions that steer the model's behaviour. */
    SYSTEM,
    /** Result of a tool call returned to the model (Spring AI 1.x name for the former FUNCTION role). */
    TOOL
}package org.springframework.ai.chat.messages;
/**
 * A single chat message: its text, its role, and arbitrary metadata.
 */
public interface Message {
    /** Returns the text of the message (Spring AI 1.x accessor; there is no getContent()). */
    String getText();

    /** Returns the role of the message. */
    MessageType getMessageType();

    /** Returns additional key/value metadata attached to the message. */
    Map<String, Object> getMetadata();
}package org.springframework.ai.chat.messages;
/**
 * Message authored by the end user, optionally carrying media attachments.
 */
public class UserMessage implements Message {
// Text-only user message
public UserMessage(String content);
// User message with media attachments.
// NOTE(review): newer Spring AI versions may expose this via UserMessage.builder() instead — confirm.
public UserMessage(String content, List<Media> media);
}package org.springframework.ai.chat.messages;
/**
 * Message produced by the model (the assistant side of a conversation).
 */
public class AssistantMessage implements Message {
// Assistant reply with text only
public AssistantMessage(String content);
// Assistant reply with additional key/value metadata
public AssistantMessage(String content, Map<String, Object> metadata);
}package org.springframework.ai.chat.messages;
/**
 * Instruction message that sets the assistant's behaviour, e.g. the
 * "You are a helpful assistant." prompt used when initializing a conversation.
 */
public class SystemMessage implements Message {
public SystemMessage(String content);
}tessl i tessl/maven-org-springframework-ai--spring-ai-starter-model-openai@1.1.1