ChatMemory: The Memory Component

Purpose

Large language models (LLMs) are stateless: they retain nothing from previous interactions. To carry context or state across multiple interactions, use the ChatMemory component.

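The core interfaces are:

  • ChatMemoryRepository: handles the low-level storage and retrieval of memory messages
  • ChatMemory: defines the policy that decides which messages to keep and when to remove them
  • ChatClient: uses ChatMemory through the Advisors mechanism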

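ChatMemory has one built-in implementation, MessageWindowChatMemory. It maintains a window of messages whose size never exceeds a configured maximum (default: 20 messages). When the number of messages exceeds the limit, older messages are evicted, but system messages are kept. If a new system message is added, all previous system messages are removed from memory. This keeps the most recent context available to the conversation while bounding memory usage.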
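
The eviction rule described above can be sketched in plain Java. This is a minimal illustration under stated assumptions, not Spring AI's source code: the `Msg` record and the `process` method are hypothetical stand-ins for the `Message` type and the window logic, and roles are plain strings.

```java
import java.util.ArrayList;
import java.util.List;

class WindowSketch {

    // Hypothetical stand-in for a chat message: a role plus its text.
    record Msg(String role, String text) {
        boolean isSystem() {
            return "SYSTEM".equals(role);
        }
    }

    // Merges stored and incoming messages, then trims to maxMessages.
    static List<Msg> process(List<Msg> stored, List<Msg> incoming, int maxMessages) {
        // A system message that is not already stored replaces all stored ones.
        boolean hasNewSystem = incoming.stream()
                .anyMatch(m -> m.isSystem() && !stored.contains(m));

        List<Msg> merged = new ArrayList<>();
        for (Msg m : stored) {
            if (!(hasNewSystem && m.isSystem())) {
                merged.add(m);
            }
        }
        merged.addAll(incoming);

        int toRemove = merged.size() - maxMessages;
        if (toRemove <= 0) {
            return merged;
        }

        // Drop the oldest non-system messages first; system messages survive.
        List<Msg> trimmed = new ArrayList<>();
        int removed = 0;
        for (Msg m : merged) {
            if (m.isSystem() || removed >= toRemove) {
                trimmed.add(m);
            } else {
                removed++;
            }
        }
        return trimmed;
    }

    public static void main(String[] args) {
        List<Msg> stored = List.of(new Msg("SYSTEM", "s1"), new Msg("USER", "u1"),
                new Msg("ASSISTANT", "a1"));
        System.out.println(process(stored, List.of(new Msg("USER", "u2")), 3));
    }
}
```

With a limit of 3, the oldest non-system message (u1) is dropped, while the system message s1 stays at the head of the window.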

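Internally, MessageWindowChatMemory stores its messages through an implementation of the ChatMemoryRepository interface. Available implementations include InMemoryChatMemoryRepository, JdbcChatMemoryRepository, CassandraChatMemoryRepository, and Neo4jChatMemoryRepository, covering both in-memory and persistent storage, which correspond to short-term and long-term memory respectively. The default is InMemoryChatMemoryRepository, which keeps messages in memory.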

Interface Definitions

ChatMemory

java
public interface ChatMemory {
	/**
	 * Save the specified messages in the chat memory for the specified conversation.
	 */
	void add(String conversationId, List<Message> messages);

	/**
	 * Get the messages in the chat memory for the specified conversation.
	 */
	List<Message> get(String conversationId);

	/**
	 * Clear the chat memory for the specified conversation.
	 */
	void clear(String conversationId);
}

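Notes:

  • Every conversation needs a conversationId, which distinguishes it from other conversations
  • ChatMemory provides methods to add messages to a conversation, retrieve the messages of a conversation, and clear a conversation's history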

ChatMemoryRepository

java
public interface ChatMemoryRepository {
	List<Message> findByConversationId(String conversationId);
	/**
	 * Replaces all the existing messages for the given conversation ID with the provided
	 * messages.
	 */
	void saveAll(String conversationId, List<Message> messages);
	void deleteByConversationId(String conversationId);
}
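
Because saveAll replaces the stored list rather than appending to it, a memory built on this interface has to read the history, append to it, and write the whole list back. The sketch below illustrates that pattern with hypothetical, simplified types (`Repo`, `MapRepo`, and plain `String` messages are assumptions for illustration, not Spring AI classes).

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

class RepositoryBackedMemorySketch {

    // Hypothetical, simplified repository: messages are plain strings here.
    interface Repo {
        List<String> findByConversationId(String conversationId);

        void saveAll(String conversationId, List<String> messages);

        void deleteByConversationId(String conversationId);
    }

    // Map-backed implementation, keyed by conversation ID.
    static class MapRepo implements Repo {
        private final Map<String, List<String>> store = new ConcurrentHashMap<>();

        public List<String> findByConversationId(String conversationId) {
            List<String> messages = store.get(conversationId);
            return messages != null ? new ArrayList<>(messages) : List.of();
        }

        public void saveAll(String conversationId, List<String> messages) {
            // Replaces whatever was stored for this conversation.
            store.put(conversationId, new ArrayList<>(messages));
        }

        public void deleteByConversationId(String conversationId) {
            store.remove(conversationId);
        }
    }

    // add = read the history, append the new messages, write everything back.
    static void add(Repo repo, String conversationId, List<String> newMessages) {
        List<String> all = new ArrayList<>(repo.findByConversationId(conversationId));
        all.addAll(newMessages);
        repo.saveAll(conversationId, all);
    }

    public static void main(String[] args) {
        Repo repo = new MapRepo();
        add(repo, "1", List.of("My name is James Bond"));
        add(repo, "1", List.of("What is my name?"));
        add(repo, "2", List.of("Hello"));
        System.out.println(repo.findByConversationId("1"));
        System.out.println(repo.findByConversationId("2"));
    }
}
```

Each conversation ID maps to its own list, so messages added under "1" never show up when reading "2".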

InMemoryChatMemoryRepository: in-memory storage for short-term memory

java
public final class InMemoryChatMemoryRepository implements ChatMemoryRepository {
	// Backing store, keyed by conversation ID
	Map<String, List<Message>> chatMemoryStore = new ConcurrentHashMap<>();

	@Override
	public List<Message> findByConversationId(String conversationId) {
		List<Message> messages = this.chatMemoryStore.get(conversationId);
		return messages != null ? new ArrayList<>(messages) : List.of();
	}

	@Override
	public void saveAll(String conversationId, List<Message> messages) {
		this.chatMemoryStore.put(conversationId, messages);
	}

	@Override
	public void deleteByConversationId(String conversationId) {
		this.chatMemoryStore.remove(conversationId);
	}
}

JdbcChatMemoryRepository: database storage for long-term memory

java
public final class JdbcChatMemoryRepository implements ChatMemoryRepository {
	private final JdbcTemplate jdbcTemplate;
	private final TransactionTemplate transactionTemplate;
	private final JdbcChatMemoryRepositoryDialect dialect;

	private JdbcChatMemoryRepository(JdbcTemplate jdbcTemplate, JdbcChatMemoryRepositoryDialect dialect,
			PlatformTransactionManager txManager) {
		this.jdbcTemplate = jdbcTemplate;
		this.dialect = dialect;
		this.transactionTemplate = new TransactionTemplate(
				txManager != null ? txManager : new DataSourceTransactionManager(jdbcTemplate.getDataSource()));
	}

	@Override
	public List<Message> findByConversationId(String conversationId) {
		return this.jdbcTemplate.query(this.dialect.getSelectMessagesSql(), new MessageRowMapper(), conversationId);
	}

	@Override
	public void saveAll(String conversationId, List<Message> messages) {
		this.transactionTemplate.execute(status -> {
			deleteByConversationId(conversationId);
			this.jdbcTemplate.batchUpdate(this.dialect.getInsertMessageSql(),
					new AddBatchPreparedStatement(conversationId, messages));
			return null;
		});
	}

	@Override
	public void deleteByConversationId(String conversationId) {
		this.jdbcTemplate.update(this.dialect.getDeleteMessagesSql(), conversationId);
	}
}

As the code shows, JdbcChatMemoryRepository issues its database calls using the SQL templates supplied by JdbcChatMemoryRepositoryDialect.

The MysqlChatMemoryRepositoryDialect implementation looks like this:

java
public class MysqlChatMemoryRepositoryDialect implements JdbcChatMemoryRepositoryDialect {

	@Override
	public String getSelectMessagesSql() {
		return "SELECT content, type FROM SPRING_AI_CHAT_MEMORY WHERE conversation_id = ? ORDER BY `timestamp`";
	}

	@Override
	public String getInsertMessageSql() {
		return "INSERT INTO SPRING_AI_CHAT_MEMORY (conversation_id, content, type, `timestamp`) VALUES (?, ?, ?, ?)";
	}

	@Override
	public String getSelectConversationIdsSql() {
		return "SELECT DISTINCT conversation_id FROM SPRING_AI_CHAT_MEMORY";
	}

	@Override
	public String getDeleteMessagesSql() {
		return "DELETE FROM SPRING_AI_CHAT_MEMORY WHERE conversation_id = ?";
	}

}

Using in-memory short-term memory

java
@Bean
public ChatMemoryRepository inMemoryChatMemoryRepository() {
    return new InMemoryChatMemoryRepository();
}

@Bean
public ChatMemory messageWindowChatMemory() {
    return MessageWindowChatMemory.builder()
            .maxMessages(10)
            .chatMemoryRepository(inMemoryChatMemoryRepository()) // set the backing repository
            .build();
}

@Bean
public ChatClient customChatClient(ChatModel chatModel) {
    return ChatClient.builder(chatModel)
            .defaultAdvisors(MessageChatMemoryAdvisor.builder(messageWindowChatMemory()).build())
            .build();
}

java
@Resource
private ChatClient customChatClient;

@RequestMapping("/20001")
public String execute20001(@RequestParam("conversationId") String conversationId,
                           @RequestParam("userRequest") String userRequest) {
    return customChatClient.prompt()
            .advisors(advisorSpec -> advisorSpec.param(ChatMemory.CONVERSATION_ID, conversationId))
            .advisors(new SimpleLoggerAdvisor())
            .user(userRequest)
            .call().content();
}

Test:

text
http://localhost:8080/20001?conversationId=1&userRequest=My name is James Bond // writes the context into memory
http://localhost:8080/20001?conversationId=1&userRequest=What is my name? // the reply contains James Bond, so the memory works
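
The test URLs above contain spaces and a question mark inside the userRequest value. Browsers usually encode these automatically, but a programmatic client has to do it itself. The following is a small sketch of building the same request URLs with the JDK only; the class and method names are hypothetical, and the host, port, and path are taken from the example above.

```java
import java.net.URI;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;

class ChatRequestUrlSketch {

    // Builds the request URI for the /20001 endpoint with encoded query values.
    static URI buildUri(String conversationId, String userRequest) {
        String query = "conversationId=" + URLEncoder.encode(conversationId, StandardCharsets.UTF_8)
                + "&userRequest=" + URLEncoder.encode(userRequest, StandardCharsets.UTF_8);
        return URI.create("http://localhost:8080/20001?" + query);
    }

    public static void main(String[] args) {
        System.out.println(buildUri("1", "My name is James Bond"));
        System.out.println(buildUri("1", "What is my name?"));
    }
}
```

URLEncoder uses form encoding, so spaces become `+` and `?` becomes `%3F`; servlet containers normally decode both back when binding query parameters.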

Using database-backed long-term memory

Add the dependencies

xml
<!-- memory-jdbc -->
<dependency>
    <groupId>org.springframework.ai</groupId>
    <artifactId>spring-ai-starter-model-chat-memory-repository-jdbc</artifactId>
</dependency>
<!-- data source -->
<dependency>
    <groupId>com.alibaba</groupId>
    <artifactId>druid</artifactId>
    <version>1.2.27</version>
</dependency>
<!-- MySQL driver -->
<dependency>
    <groupId>com.mysql</groupId>
    <artifactId>mysql-connector-j</artifactId>
    <version>8.4.0</version>
</dependency>

By default, a JdbcChatMemoryRepository bean is auto-configured, so you can simply inject it. Here we create the beans by hand to make the underlying mechanics visible.

java
@Bean
public DataSource dataSource() {
    DruidDataSource dataSource = new DruidDataSource();
    dataSource.setUrl("jdbc:mysql://localhost:3306/ai_demo");
    dataSource.setUsername("root");
    dataSource.setPassword("xxx");
    dataSource.setDriverClassName("com.mysql.cj.jdbc.Driver"); // Connector/J 8.x driver class; com.mysql.jdbc.Driver is deprecated
    return dataSource;
}

@Bean
public JdbcTemplate jdbcTemplate() {
    return new JdbcTemplate(dataSource());
}

@Bean
public DataSourceTransactionManager dataSourceTransactionManager() {
    return new DataSourceTransactionManager(jdbcTemplate().getDataSource());
}

@Bean
public JdbcChatMemoryRepository jdbcChatMemoryRepository() {
    return JdbcChatMemoryRepository.builder()
            .jdbcTemplate(jdbcTemplate())
            .dialect(new MysqlChatMemoryRepositoryDialect())
            .transactionManager(dataSourceTransactionManager())
            .build();
}

@Bean
public ChatMemory messageWindowChatMemory() {
    return MessageWindowChatMemory.builder()
            .maxMessages(10)
            .chatMemoryRepository(jdbcChatMemoryRepository())
            .build();
}

@Bean
public ChatClient customChatClient(ChatModel chatModel) {
    return ChatClient.builder(chatModel)
            .defaultAdvisors(MessageChatMemoryAdvisor.builder(messageWindowChatMemory()).build())
            .build();
}

application.properties

text
# init database schema
spring.ai.chat.memory.repository.jdbc.initialize-schema=always

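Notes:

  • spring.ai.chat.memory.repository.jdbc.initialize-schema: controls whether the database table is created automatically. always: always create it; never: never create it; embedded: create it only for embedded databases (for example, H2). The schema scripts ship in the org.springframework.ai:spring-ai-model-chat-memory-repository-jdbc artifact; for example, schema-mysql.sql contains: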
sql
CREATE TABLE IF NOT EXISTS SPRING_AI_CHAT_MEMORY (
    `conversation_id` VARCHAR(36) NOT NULL,
    `content` TEXT NOT NULL,
    `type` ENUM('USER', 'ASSISTANT', 'SYSTEM', 'TOOL') NOT NULL,
    `timestamp` TIMESTAMP NOT NULL,

    INDEX `SPRING_AI_CHAT_MEMORY_CONVERSATION_ID_TIMESTAMP_IDX` (`conversation_id`, `timestamp`)
);
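
One practical consequence of this schema is that conversation_id is a VARCHAR(36), which is exactly the length of a standard UUID string. The sketch below shows one way to generate and validate IDs against that limit; the class, constant, and method names are hypothetical, and the 36-character bound is taken from the schema above.

```java
import java.util.UUID;

class ConversationIdSketch {

    // Upper bound taken from the VARCHAR(36) column in the MySQL schema.
    static final int MAX_ID_LENGTH = 36;

    // A random UUID string is 36 characters long, so it fits the column exactly.
    static String newConversationId() {
        return UUID.randomUUID().toString();
    }

    // Rejects IDs that are empty or would not fit into the conversation_id column.
    static String requireFits(String conversationId) {
        if (conversationId == null || conversationId.isEmpty()
                || conversationId.length() > MAX_ID_LENGTH) {
            throw new IllegalArgumentException(
                    "conversationId must be 1 to " + MAX_ID_LENGTH + " characters");
        }
        return conversationId;
    }

    public static void main(String[] args) {
        String id = requireFits(newConversationId());
        System.out.println(id + " (" + id.length() + " characters)");
    }
}
```

Validating the ID before it reaches the repository turns an oversized value into a clear application error instead of a database-level failure or truncation.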
