Memory Component: ChatMemory
Purpose
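Large language models (LLMs) are stateless: they retain no information from previous interactions. To carry context or state across multiple interactions, you need the ChatMemory component.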
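The core interfaces are:
- ChatMemoryRepository: handles the low-level storage and retrieval of memory messages
- ChatMemory: defines the strategy, i.e. decides which messages to keep and when to remove them
- ChatClient: uses ChatMemory through the Advisors mechanism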
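ChatMemory has one built-in implementation, MessageWindowChatMemory, which maintains a window of messages no larger than a specified maximum (default: 20 messages). When the number of messages exceeds this limit, older messages are evicted, but system messages are preserved. If a new system message is added, all previous system messages are removed from memory. This keeps the latest context available to the conversation while keeping memory usage bounded.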
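The eviction rules above fit in a few lines of plain Java. The sketch below is a simplified, hypothetical model, not Spring AI source: `Msg` and `WindowPolicy` are illustrative names, a message is only a role plus text, and "new system message" is simplified to "any system message in the incoming batch".

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for a chat message: only a role and some text.
record Msg(String role, String text) {
    boolean isSystem() {
        return "system".equals(role);
    }
}

class WindowPolicy {

    // Combines the stored history with the incoming messages using the rules described above.
    static List<Msg> apply(List<Msg> history, List<Msg> incoming, int maxMessages) {
        // Rule 1: a system message in the incoming batch replaces all older system messages.
        boolean newSystem = incoming.stream().anyMatch(Msg::isSystem);
        List<Msg> merged = new ArrayList<>();
        for (Msg m : history) {
            if (!(newSystem && m.isSystem())) {
                merged.add(m);
            }
        }
        merged.addAll(incoming);

        // Rule 2: over the limit, evict the oldest non-system messages; system messages stay.
        int excess = merged.size() - maxMessages;
        List<Msg> result = new ArrayList<>();
        for (Msg m : merged) {
            if (excess > 0 && !m.isSystem()) {
                excess--; // skip (evict) this old message
            } else {
                result.add(m);
            }
        }
        return result;
    }

    public static void main(String[] args) {
        List<Msg> history = List.of(
                new Msg("system", "You are a helpful assistant"),
                new Msg("user", "Hi"),
                new Msg("assistant", "Hello"));
        List<Msg> window = apply(history, List.of(new Msg("user", "My name is James Bond")), 3);
        window.forEach(System.out::println);
    }
}
```

With a window of 3, the oldest non-system message (`Hi`) is evicted while the system message is kept.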
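Internally, MessageWindowChatMemory stores messages through an implementation of the ChatMemoryRepository interface. Available implementations include InMemoryChatMemoryRepository, JdbcChatMemoryRepository, CassandraChatMemoryRepository and Neo4jChatMemoryRepository, covering both in-memory and persistent storage, i.e. short-term and long-term memory. By default, InMemoryChatMemoryRepository is used for in-memory storage.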
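The split between strategy (ChatMemory) and storage (ChatMemoryRepository) can be sketched as a read-modify-write cycle keyed by conversationId. This is a hypothetical, simplified model using String messages: `SimpleRepository`, `MapRepository` and `WindowMemory` are illustrative names, and system-message handling is left out.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

// Hypothetical, simplified counterpart of ChatMemoryRepository: pure storage, no policy.
interface SimpleRepository {
    List<String> findByConversationId(String conversationId);

    void saveAll(String conversationId, List<String> messages); // replaces the whole history

    void deleteByConversationId(String conversationId);
}

class MapRepository implements SimpleRepository {
    private final Map<String, List<String>> store = new ConcurrentHashMap<>();

    @Override
    public List<String> findByConversationId(String conversationId) {
        return new ArrayList<>(store.getOrDefault(conversationId, List.of()));
    }

    @Override
    public void saveAll(String conversationId, List<String> messages) {
        store.put(conversationId, List.copyOf(messages));
    }

    @Override
    public void deleteByConversationId(String conversationId) {
        store.remove(conversationId);
    }
}

// Hypothetical counterpart of a window-based ChatMemory: policy only, storage is delegated.
class WindowMemory {
    private final SimpleRepository repository;
    private final int maxMessages;

    WindowMemory(SimpleRepository repository, int maxMessages) {
        this.repository = repository;
        this.maxMessages = maxMessages;
    }

    void add(String conversationId, List<String> messages) {
        List<String> all = repository.findByConversationId(conversationId); // read
        all.addAll(messages);                                               // modify
        int from = Math.max(0, all.size() - maxMessages);                   // keep the newest
        repository.saveAll(conversationId, all.subList(from, all.size()));  // write back
    }

    List<String> get(String conversationId) {
        return repository.findByConversationId(conversationId);
    }

    void clear(String conversationId) {
        repository.deleteByConversationId(conversationId);
    }
}
```

Swapping `MapRepository` for another implementation changes where messages live without touching the windowing logic, which is the same reason MessageWindowChatMemory can run on top of any ChatMemoryRepository.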
Interface definitions
ChatMemory
public interface ChatMemory {

    // Simplified excerpt: the real interface also declares constants such as
    // CONVERSATION_ID, the advisor parameter key used in the controller below.

    /**
     * Save the specified messages in the chat memory for the specified conversation.
     */
    void add(String conversationId, List<Message> messages);

    /**
     * Get the messages in the chat memory for the specified conversation.
     */
    List<Message> get(String conversationId);

    /**
     * Clear the chat memory for the specified conversation.
     */
    void clear(String conversationId);
}
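Notes:
- Every conversation needs a conversationId, which distinguishes it from other conversations.
- ChatMemory provides methods to add messages to a conversation, retrieve a conversation's messages, and clear its history.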
ChatMemoryRepository
public interface ChatMemoryRepository {

    // Lists all stored conversation IDs (also part of the interface in Spring AI 1.0)
    List<String> findConversationIds();

    List<Message> findByConversationId(String conversationId);

    /**
     * Replaces all the existing messages for the given conversation ID with the provided
     * messages.
     */
    void saveAll(String conversationId, List<Message> messages);

    void deleteByConversationId(String conversationId);
}
InMemoryChatMemoryRepository: in-memory storage for short-term memory
public final class InMemoryChatMemoryRepository implements ChatMemoryRepository {

    // Simplified excerpt: argument checks and some methods are omitted.

    // Storage container: conversationId -> messages, safe for concurrent access
    Map<String, List<Message>> chatMemoryStore = new ConcurrentHashMap<>();

    @Override
    public List<Message> findByConversationId(String conversationId) {
        List<Message> messages = this.chatMemoryStore.get(conversationId);
        // Return a copy so callers cannot modify the stored list
        return messages != null ? new ArrayList<>(messages) : List.of();
    }

    @Override
    public void saveAll(String conversationId, List<Message> messages) {
        // Replaces the whole history of the conversation
        this.chatMemoryStore.put(conversationId, messages);
    }

    @Override
    public void deleteByConversationId(String conversationId) {
        this.chatMemoryStore.remove(conversationId);
    }
}
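The excerpt copies the list on read but stores the list passed to saveAll as-is, so a caller that kept a reference to that list and mutated it later would also change the stored history. A hypothetical variant that snapshots on write as well (`CopyingRepository` is an illustrative name, and String messages are used for brevity):

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

// Hypothetical variant of the in-memory repository that copies on write as well as on read.
class CopyingRepository {
    private final Map<String, List<String>> store = new ConcurrentHashMap<>();

    List<String> findByConversationId(String conversationId) {
        List<String> messages = store.get(conversationId);
        // Copy on read: callers get their own mutable list
        return messages != null ? new ArrayList<>(messages) : new ArrayList<>();
    }

    void saveAll(String conversationId, List<String> messages) {
        // Copy on write: an immutable snapshot, so later changes to the caller's list are not visible
        store.put(conversationId, List.copyOf(messages));
    }

    void deleteByConversationId(String conversationId) {
        store.remove(conversationId);
    }
}
```

The extra copy costs one list allocation per save, which is negligible for a window of a few dozen messages.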
JdbcChatMemoryRepository: database storage for long-term memory
public final class JdbcChatMemoryRepository implements ChatMemoryRepository {

    // Simplified excerpt: the builder and some methods are omitted.

    private final JdbcTemplate jdbcTemplate;

    private final TransactionTemplate transactionTemplate;

    private final JdbcChatMemoryRepositoryDialect dialect;

    private JdbcChatMemoryRepository(JdbcTemplate jdbcTemplate, JdbcChatMemoryRepositoryDialect dialect,
            PlatformTransactionManager txManager) {
        this.jdbcTemplate = jdbcTemplate;
        this.dialect = dialect;
        // Fall back to a DataSourceTransactionManager when no transaction manager is supplied
        this.transactionTemplate = new TransactionTemplate(
                txManager != null ? txManager : new DataSourceTransactionManager(jdbcTemplate.getDataSource()));
    }

    @Override
    public List<Message> findByConversationId(String conversationId) {
        // Run the dialect's SELECT and map each row back to a Message
        return this.jdbcTemplate.query(this.dialect.getSelectMessagesSql(), new MessageRowMapper(), conversationId);
    }

    @Override
    public void saveAll(String conversationId, List<Message> messages) {
        // Delete the old rows and batch-insert the new ones in a single transaction
        this.transactionTemplate.execute(status -> {
            deleteByConversationId(conversationId);
            this.jdbcTemplate.batchUpdate(this.dialect.getInsertMessageSql(),
                    new AddBatchPreparedStatement(conversationId, messages));
            return null;
        });
    }

    @Override
    public void deleteByConversationId(String conversationId) {
        this.jdbcTemplate.update(this.dialect.getDeleteMessagesSql(), conversationId);
    }
}
As shown, JdbcChatMemoryRepository issues its database calls using the SQL templates provided by a JdbcChatMemoryRepositoryDialect.
MysqlChatMemoryRepositoryDialect looks like this:
public class MysqlChatMemoryRepositoryDialect implements JdbcChatMemoryRepositoryDialect {

    @Override
    public String getSelectMessagesSql() {
        return "SELECT content, type FROM SPRING_AI_CHAT_MEMORY WHERE conversation_id = ? ORDER BY `timestamp`";
    }

    @Override
    public String getInsertMessageSql() {
        return "INSERT INTO SPRING_AI_CHAT_MEMORY (conversation_id, content, type, `timestamp`) VALUES (?, ?, ?, ?)";
    }

    @Override
    public String getSelectConversationIdsSql() {
        return "SELECT DISTINCT conversation_id FROM SPRING_AI_CHAT_MEMORY";
    }

    @Override
    public String getDeleteMessagesSql() {
        return "DELETE FROM SPRING_AI_CHAT_MEMORY WHERE conversation_id = ?";
    }
}
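To make the SQL above concrete, here is an in-memory simulation of what saveAll and the SELECT do against SPRING_AI_CHAT_MEMORY. It is a hypothetical model with no database involved (`Row`, `Entry` and `FakeChatMemoryTable` are illustrative names): every save deletes the conversation's rows and re-inserts the whole window, and reads rely on the timestamp column for ordering, so each inserted row here gets a strictly increasing timestamp. That last point is an assumption about how ordering stays stable, not a quote of Spring AI's code.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

// Hypothetical row of the SPRING_AI_CHAT_MEMORY table.
record Row(String conversationId, String content, String type, long timestamp) {}

// Hypothetical message to be saved: type (USER, ASSISTANT, ...) plus content.
record Entry(String type, String content) {}

// In-memory simulation of the table and of the repository's SQL.
class FakeChatMemoryTable {
    private final List<Row> rows = new ArrayList<>();
    private long clock = 0; // assumption: each inserted row gets a strictly increasing timestamp

    // Like saveAll: DELETE ... WHERE conversation_id = ?, then batch INSERT of the whole window.
    // synchronized is a crude stand-in for the surrounding transaction.
    synchronized void saveAll(String conversationId, List<Entry> messages) {
        rows.removeIf(r -> r.conversationId().equals(conversationId));
        for (Entry e : messages) {
            rows.add(new Row(conversationId, e.content(), e.type(), clock++));
        }
    }

    // Like SELECT content, type ... WHERE conversation_id = ? ORDER BY `timestamp`.
    synchronized List<Row> select(String conversationId) {
        return rows.stream()
                .filter(r -> r.conversationId().equals(conversationId))
                .sorted(Comparator.comparingLong(Row::timestamp))
                .toList();
    }
}
```

The trade-off of this delete-then-insert design is that every turn rewrites the entire window (at most maxMessages rows), which stays cheap because the window is small and keeps the stored rows exactly in sync with what ChatMemory decided to keep.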
Using in-memory (short-term) memory
// Configuration class
@Bean
public ChatMemoryRepository inMemoryChatMemoryRepository() {
    return new InMemoryChatMemoryRepository();
}

@Bean
public ChatMemory messageWindowChatMemory() {
    return MessageWindowChatMemory.builder()
            .maxMessages(10) // keep at most 10 messages per conversation
            .chatMemoryRepository(inMemoryChatMemoryRepository()) // set the repository
            .build();
}

@Bean
public ChatClient customChatClient(ChatModel chatModel) {
    return ChatClient.builder(chatModel)
            // Register the memory advisor for every request made through this client
            .defaultAdvisors(MessageChatMemoryAdvisor.builder(messageWindowChatMemory()).build())
            .build();
}

// Controller
@Resource
private ChatClient customChatClient;

@RequestMapping("/20001")
public String execute20001(@RequestParam("conversationId") String conversationId,
        @RequestParam("userRequest") String userRequest) {
    return customChatClient.prompt()
            // Tell the memory advisor which conversation this request belongs to
            .advisors(advisorSpec -> advisorSpec.param(ChatMemory.CONVERSATION_ID, conversationId))
            .advisors(new SimpleLoggerAdvisor())
            .user(userRequest)
            .call().content();
}
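What MessageChatMemoryAdvisor does around each call can be approximated as: load the conversation's history, send it together with the new user message, then store both the user message and the reply. The sketch below is a conceptual model only (`MemoryAdvisorSketch` and `toyModel` are hypothetical; the real advisor works with Spring AI Message objects and the ChatMemory bean), paired with a toy stateless "model" that can only recall a name if it appears in the prompt it receives.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

// Conceptual model of a memory advisor wrapped around a stateless model call.
class MemoryAdvisorSketch {
    private final Map<String, List<String>> memory = new HashMap<>();
    private final Function<List<String>, String> model; // stateless: sees only the prompt it is given

    MemoryAdvisorSketch(Function<List<String>, String> model) {
        this.model = model;
    }

    String call(String conversationId, String userText) {
        List<String> history = memory.computeIfAbsent(conversationId, id -> new ArrayList<>());
        // Before the call: stored history + the new user message form the prompt
        List<String> prompt = new ArrayList<>(history);
        prompt.add("user: " + userText);
        String answer = model.apply(prompt);
        // After the call: remember both sides of this turn
        history.add("user: " + userText);
        history.add("assistant: " + answer);
        return answer;
    }

    // Toy stateless "model": it can only recall a name that appears in the prompt it receives.
    static String toyModel(List<String> prompt) {
        String last = prompt.get(prompt.size() - 1);
        if (!last.endsWith("What is my name?")) {
            return "OK";
        }
        for (String line : prompt) {
            int i = line.indexOf("My name is ");
            if (i >= 0) {
                return "Your name is " + line.substring(i + "My name is ".length());
            }
        }
        return "I don't know";
    }

    public static void main(String[] args) {
        MemoryAdvisorSketch chat = new MemoryAdvisorSketch(MemoryAdvisorSketch::toyModel);
        System.out.println(chat.call("1", "My name is James Bond")); // OK
        System.out.println(chat.call("1", "What is my name?"));      // Your name is James Bond
        System.out.println(chat.call("2", "What is my name?"));      // I don't know
    }
}
```

Conversation "2" has no stored history, so the same question cannot be answered there: the model itself remembers nothing, and all recall comes from what the advisor puts into the prompt.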
Test:
http://localhost:8080/20001?conversationId=1&userRequest=My name is James Bond // stores the name in this conversation's memory
http://localhost:8080/20001?conversationId=1&userRequest=What is my name? // the reply contains James Bond, so the memory works
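Browsers URL-encode the spaces in these requests automatically; when calling the endpoint from code, the parameters need to be encoded explicitly. A small helper using the JDK's `URLEncoder` (the class name `TestUrls` is illustrative, and `localhost:8080` is taken from the URLs above):

```java
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;

class TestUrls {
    // Builds a request URL for the /20001 endpoint with properly encoded query parameters.
    static String url(String conversationId, String userRequest) {
        return "http://localhost:8080/20001?conversationId="
                + URLEncoder.encode(conversationId, StandardCharsets.UTF_8)
                + "&userRequest=" + URLEncoder.encode(userRequest, StandardCharsets.UTF_8);
    }

    public static void main(String[] args) {
        System.out.println(url("1", "My name is James Bond"));
        System.out.println(url("1", "What is my name?"));
    }
}
```

`URLEncoder` uses form encoding, so spaces become `+` and `?` becomes `%3F`; servlet containers such as Tomcat decode these back when Spring binds the `@RequestParam` values.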
Using database-backed (long-term) memory
Add the dependencies
<!-- JDBC chat memory repository (version assumed to be managed by the Spring AI BOM) -->
<dependency>
    <groupId>org.springframework.ai</groupId>
    <artifactId>spring-ai-starter-model-chat-memory-repository-jdbc</artifactId>
</dependency>
<!-- Data source (connection pool) -->
<dependency>
    <groupId>com.alibaba</groupId>
    <artifactId>druid</artifactId>
    <version>1.2.27</version>
</dependency>
<!-- MySQL driver -->
<dependency>
    <groupId>com.mysql</groupId>
    <artifactId>mysql-connector-j</artifactId>
    <version>8.4.0</version>
</dependency>
A JdbcChatMemoryRepository bean is auto-configured by default, so you can simply inject it. Here we create the beans manually to show what happens under the hood.
@Bean
public DataSource dataSource() {
    DruidDataSource dataSource = new DruidDataSource();
    dataSource.setUrl("jdbc:mysql://localhost:3306/ai_demo");
    dataSource.setUsername("root");
    dataSource.setPassword("xxx");
    // Driver class for mysql-connector-j 8.x; com.mysql.jdbc.Driver is the deprecated legacy name
    dataSource.setDriverClassName("com.mysql.cj.jdbc.Driver");
    return dataSource;
}

@Bean
public JdbcTemplate jdbcTemplate() {
    return new JdbcTemplate(dataSource());
}

@Bean
public DataSourceTransactionManager dataSourceTransactionManager() {
    return new DataSourceTransactionManager(jdbcTemplate().getDataSource());
}

@Bean
public JdbcChatMemoryRepository jdbcChatMemoryRepository() {
    return JdbcChatMemoryRepository.builder()
            .jdbcTemplate(jdbcTemplate())
            .dialect(new MysqlChatMemoryRepositoryDialect()) // MySQL-specific SQL templates
            .transactionManager(dataSourceTransactionManager())
            .build();
}

@Bean
public ChatMemory messageWindowChatMemory() {
    return MessageWindowChatMemory.builder()
            .maxMessages(10)
            .chatMemoryRepository(jdbcChatMemoryRepository()) // swap in the JDBC repository
            .build();
}

@Bean
public ChatClient customChatClient(ChatModel chatModel) {
    return ChatClient.builder(chatModel)
            .defaultAdvisors(MessageChatMemoryAdvisor.builder(messageWindowChatMemory()).build())
            .build();
}
application.properties
# init database schema
spring.ai.chat.memory.repository.jdbc.initialize-schema=always
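Notes:
- spring.ai.chat.memory.repository.jdbc.initialize-schema: controls whether the table is created automatically. always: always run the schema script; never: never run it; embedded: run it only for embedded databases (e.g., H2), which is the default. Since MySQL is not an embedded database, always is set here; the script uses CREATE TABLE IF NOT EXISTS, so running it on every startup is harmless. The DDL scripts ship in the spring-ai-model-chat-memory-repository-jdbc artifact (groupId org.springframework.ai); schema-mysql.sql, for example, contains: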
CREATE TABLE IF NOT EXISTS SPRING_AI_CHAT_MEMORY (
    `conversation_id` VARCHAR(36) NOT NULL,
    `content` TEXT NOT NULL,
    `type` ENUM('USER', 'ASSISTANT', 'SYSTEM', 'TOOL') NOT NULL,
    `timestamp` TIMESTAMP NOT NULL,
    INDEX `SPRING_AI_CHAT_MEMORY_CONVERSATION_ID_TIMESTAMP_IDX` (`conversation_id`, `timestamp`)
);
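Because conversation_id is declared as VARCHAR(36), a canonical UUID string (36 characters) fits exactly, while longer IDs would be rejected or truncated depending on MySQL's SQL mode. A hypothetical helper (`ConversationIds` is an illustrative name) that generates UUID-based IDs and checks the length before an ID reaches the repository:

```java
import java.util.UUID;

class ConversationIds {
    // Matches conversation_id VARCHAR(36) in the schema above.
    static final int MAX_LENGTH = 36;

    // A canonical UUID string (32 hex digits plus 4 hyphens) is exactly 36 characters.
    static String newConversationId() {
        return UUID.randomUUID().toString();
    }

    // Rejects IDs that are empty or would not fit the column.
    static String requireValid(String conversationId) {
        if (conversationId == null || conversationId.isBlank() || conversationId.length() > MAX_LENGTH) {
            throw new IllegalArgumentException(
                    "conversationId must be 1 to " + MAX_LENGTH + " characters: " + conversationId);
        }
        return conversationId;
    }

    public static void main(String[] args) {
        String id = requireValid(newConversationId());
        System.out.println(id + " (" + id.length() + " chars)");
    }
}
```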
