Preface
This article takes a look at RedisChatMemory in Spring AI Alibaba.
RedisChatMemory
community/memories/spring-ai-alibaba-redis-memory/src/main/java/com/alibaba/cloud/ai/memory/redis/RedisChatMemory.java
public class RedisChatMemory implements ChatMemory, AutoCloseable {

    private static final Logger logger = LoggerFactory.getLogger(RedisChatMemory.class);
    private static final String DEFAULT_KEY_PREFIX = "spring_ai_alibaba_chat_memory";
    private static final String DEFAULT_HOST = "127.0.0.1";
    private static final int DEFAULT_PORT = 6379;
    private static final String DEFAULT_PASSWORD = null;

    private final JedisPool jedisPool;
    // A single connection borrowed from the pool in the constructor and reused by every method
    private final Jedis jedis;
    private final ObjectMapper objectMapper;

    public RedisChatMemory() {
        this(DEFAULT_HOST, DEFAULT_PORT, DEFAULT_PASSWORD);
    }

    public RedisChatMemory(String host, int port, String password) {
        JedisPoolConfig poolConfig = new JedisPoolConfig();
        this.jedisPool = new JedisPool(poolConfig, host, port, 2000, password);
        this.jedis = jedisPool.getResource();
        this.objectMapper = new ObjectMapper();
        SimpleModule module = new SimpleModule();
        module.addDeserializer(Message.class, new MessageDeserializer());
        this.objectMapper.registerModule(module);
        logger.info("Connected to Redis at {}:{}", host, port);
    }

    @Override
    public void add(String conversationId, List<Message> messages) {
        String key = DEFAULT_KEY_PREFIX + conversationId;
        for (Message message : messages) {
            try {
                String messageJson = objectMapper.writeValueAsString(message);
                jedis.rpush(key, messageJson);
            }
            catch (JsonProcessingException e) {
                throw new RuntimeException("Error serializing message", e);
            }
        }
        logger.info("Added messages to conversationId: {}", conversationId);
    }

    @Override
    public List<Message> get(String conversationId, int lastN) {
        String key = DEFAULT_KEY_PREFIX + conversationId;
        // Negative indexes count from the tail: the last lastN elements
        List<String> messageStrings = jedis.lrange(key, -lastN, -1);
        List<Message> messages = new ArrayList<>();
        for (String messageString : messageStrings) {
            try {
                Message message = objectMapper.readValue(messageString, Message.class);
                messages.add(message);
            }
            catch (JsonProcessingException e) {
                throw new RuntimeException("Error deserializing message", e);
            }
        }
        logger.info("Retrieved {} messages for conversationId: {}", messages.size(), conversationId);
        return messages;
    }

    @Override
    public void clear(String conversationId) {
        String key = DEFAULT_KEY_PREFIX + conversationId;
        jedis.del(key);
        logger.info("Cleared messages for conversationId: {}", conversationId);
    }

    @Override
    public void close() {
        if (jedis != null) {
            jedis.close();
            logger.info("Redis connection closed.");
        }
        if (jedisPool != null) {
            jedisPool.close();
            logger.info("Jedis pool closed.");
        }
    }

    public void clearOverLimit(String conversationId, int maxLimit, int deleteSize) {
        try {
            String key = DEFAULT_KEY_PREFIX + conversationId;
            List<String> all = jedis.lrange(key, 0, -1);
            if (all.size() >= maxLimit) {
                all = all.stream().skip(Math.max(0, deleteSize)).toList();
            }
            this.clear(conversationId);
            for (String message : all) {
                jedis.rpush(key, message);
            }
        }
        catch (Exception e) {
            logger.error("Error clearing messages from Redis chat memory", e);
            throw new RuntimeException(e);
        }
    }

    public void updateMessageById(String conversationId, String messages) {
        // Note: this key has a ':' separator, unlike DEFAULT_KEY_PREFIX + conversationId used above
        String key = "spring_ai_alibaba_chat_memory:" + conversationId;
        try {
            this.jedis.del(key);
            this.jedis.rpush(key, new String[] { messages });
        }
        catch (Exception var6) {
            logger.error("Error updating messages from Redis chat memory", var6);
            throw new RuntimeException(var6);
        }
    }
}
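The RedisChatMemory constructor creates a JedisPool, borrows one Jedis connection from it that every later call reuses, and registers a MessageDeserializer for the org.springframework.ai.chat.messages.Message type on the ObjectMapper. The add method iterates over messages, serializes each one to JSON, and rpushes it onto the list stored at spring_ai_alibaba_chat_memory{conversationId} (the prefix and the conversation ID are concatenated with no separator). The get method uses lrange to fetch the most recent lastN entries and deserializes them back into Message objects; the clear method simply deletes the key; the close method closes the Jedis connection first and then the JedisPool. Since a Jedis instance is not thread-safe, sharing that single connection means one RedisChatMemory instance should not be used from several threads at once.
Two further methods sit outside the ChatMemory interface. clearOverLimit reads the whole list and, once its size has reached maxLimit, drops the oldest deleteSize messages; it then deletes the key and rpushes the remaining messages back, whether or not anything was dropped. updateMessageById replaces the list with a single string, but it builds its key as spring_ai_alibaba_chat_memory:{conversationId} (with a colon), so it does not touch the key used by add, get and clear.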
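To make the index arithmetic behind get and clearOverLimit concrete, here is a minimal, dependency-free sketch. It does not use Jedis: the hypothetical RedisListModel class keeps lists in a HashMap and mimics only the RPUSH, LRANGE and DEL behaviour RedisChatMemory relies on (inclusive bounds, negative indexes counted from the tail, out-of-range indexes clamped rather than rejected), plus the same trimming logic as clearOverLimit.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical in-memory stand-in for the Redis list commands used by RedisChatMemory.
// It is NOT Jedis; it only reproduces RPUSH/LRANGE/DEL index semantics for illustration.
public class RedisListModel {

    private static final String DEFAULT_KEY_PREFIX = "spring_ai_alibaba_chat_memory";

    private final Map<String, List<String>> store = new HashMap<>();

    // Same key construction as add/get/clear: prefix + id, no separator
    static String key(String conversationId) {
        return DEFAULT_KEY_PREFIX + conversationId;
    }

    void rpush(String key, String value) {
        store.computeIfAbsent(key, k -> new ArrayList<>()).add(value);
    }

    // LRANGE rules: inclusive bounds, negative indexes count from the tail,
    // out-of-range indexes are clamped instead of raising an error
    List<String> lrange(String key, long start, long stop) {
        List<String> list = store.getOrDefault(key, List.of());
        long len = list.size();
        if (start < 0) start += len;
        if (stop < 0) stop += len;
        if (start < 0) start = 0;
        if (start > stop || start >= len) return List.of();
        if (stop >= len) stop = len - 1;
        return new ArrayList<>(list.subList((int) start, (int) stop + 1));
    }

    void del(String key) {
        store.remove(key);
    }

    // The range RedisChatMemory.get(conversationId, lastN) asks for
    List<String> getLastN(String conversationId, int lastN) {
        return lrange(key(conversationId), -lastN, -1);
    }

    // Same list arithmetic as clearOverLimit: once size >= maxLimit, drop the oldest deleteSize entries,
    // then delete the key and push the remainder back
    void clearOverLimit(String conversationId, int maxLimit, int deleteSize) {
        String redisKey = key(conversationId);
        List<String> all = lrange(redisKey, 0, -1);
        if (all.size() >= maxLimit) {
            all = all.stream().skip(Math.max(0, deleteSize)).toList();
        }
        del(redisKey);
        for (String message : all) {
            rpush(redisKey, message);
        }
    }

    public static void main(String[] args) {
        RedisListModel redis = new RedisListModel();
        for (String m : List.of("m1", "m2", "m3", "m4", "m5")) {
            redis.rpush(key("c1"), m);
        }
        System.out.println(key("c1"));                // spring_ai_alibaba_chat_memoryc1
        System.out.println(redis.getLastN("c1", 2));  // [m4, m5]
        System.out.println(redis.getLastN("c1", 10)); // [m1, m2, m3, m4, m5]
        System.out.println(redis.getLastN("c1", 0));  // [m1, m2, m3, m4, m5]
        redis.clearOverLimit("c1", 5, 2);
        System.out.println(redis.getLastN("c1", 10)); // [m3, m4, m5]
    }
}
```

One consequence the model makes visible: because get passes -lastN as the start index, lastN = 0 turns into lrange(key, 0, -1) and returns the whole conversation rather than an empty list.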
MessageDeserializer
community/memories/spring-ai-alibaba-redis-memory/src/main/java/com/alibaba/cloud/ai/memory/redis/serializer/MessageDeserializer.java
public class MessageDeserializer extends JsonDeserializer<Message> {

    private static final Logger logger = LoggerFactory.getLogger(MessageDeserializer.class);

    public Message deserialize(JsonParser p, DeserializationContext ctxt) {
        ObjectMapper mapper = (ObjectMapper) p.getCodec();
        JsonNode node = null;
        Message message = null;
        try {
            node = mapper.readTree(p);
            String messageType = node.get("messageType").asText();
            switch (messageType) {
                case "USER" -> message = new UserMessage(node.get("text").asText(),
                        mapper.convertValue(node.get("media"), new TypeReference<Collection<Media>>() {
                        }),
                        mapper.convertValue(node.get("metadata"), new TypeReference<Map<String, Object>>() {
                        }));
                case "ASSISTANT" -> message = new AssistantMessage(node.get("text").asText(),
                        mapper.convertValue(node.get("metadata"), new TypeReference<Map<String, Object>>() {
                        }),
                        (List<AssistantMessage.ToolCall>) mapper.convertValue(node.get("toolCalls"),
                                new TypeReference<Collection<AssistantMessage.ToolCall>>() {
                                }),
                        (List<Media>) mapper.convertValue(node.get("media"), new TypeReference<Collection<Media>>() {
                        }));
                default -> throw new IllegalArgumentException("Unknown message type: " + messageType);
            }
        }
        catch (IOException e) {
            logger.error("Error deserializing message", e);
        }
        return message;
    }
}
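MessageDeserializer extends JsonDeserializer. It reads the messageType field and builds a UserMessage for USER (from text, media and metadata) or an AssistantMessage for ASSISTANT (from text, metadata, toolCalls and media). Any other type, such as SYSTEM or TOOL, makes it throw IllegalArgumentException, whereas an IOException while reading the tree is only logged and the method returns null.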
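The dispatch MessageDeserializer performs, reading one discriminator field and choosing the concrete class from it, can be sketched without Jackson. In the sketch below a Map<String, Object> stands in for the parsed JsonNode, and SimpleMessage, SimpleUserMessage and SimpleAssistantMessage are hypothetical stand-ins, not Spring AI types; it assumes Java 17 for the sealed interface and records.

```java
import java.util.List;
import java.util.Map;

// Hypothetical, dependency-free illustration of discriminator-based deserialization.
// The types below are NOT Spring AI classes; a Map plays the role of the parsed JsonNode.
public class MessageTypeDispatch {

    sealed interface SimpleMessage permits SimpleUserMessage, SimpleAssistantMessage {
        String text();
    }

    record SimpleUserMessage(String text, Map<String, Object> metadata) implements SimpleMessage {
    }

    record SimpleAssistantMessage(String text, Map<String, Object> metadata, List<String> toolCalls)
            implements SimpleMessage {
    }

    // Read the "messageType" discriminator, then build the matching concrete type
    @SuppressWarnings("unchecked")
    static SimpleMessage fromNode(Map<String, Object> node) {
        String messageType = String.valueOf(node.get("messageType"));
        String text = (String) node.get("text");
        Map<String, Object> metadata = (Map<String, Object>) node.getOrDefault("metadata", Map.of());
        return switch (messageType) {
            case "USER" -> new SimpleUserMessage(text, metadata);
            case "ASSISTANT" -> new SimpleAssistantMessage(text, metadata,
                    (List<String>) node.getOrDefault("toolCalls", List.of()));
            // As in MessageDeserializer, every other type (e.g. SYSTEM, TOOL) is rejected
            default -> throw new IllegalArgumentException("Unknown message type: " + messageType);
        };
    }

    public static void main(String[] args) {
        SimpleMessage user = fromNode(Map.of("messageType", "USER", "text", "hi"));
        SimpleMessage assistant = fromNode(Map.of("messageType", "ASSISTANT", "text", "hello",
                "toolCalls", List.of("weatherTool")));
        System.out.println(user);
        System.out.println(assistant);
        try {
            fromNode(Map.of("messageType", "SYSTEM", "text", "be brief"));
        }
        catch (IllegalArgumentException e) {
            System.out.println(e.getMessage()); // Unknown message type: SYSTEM
        }
    }
}
```

The sealed interface keeps the set of concrete types closed in the same way the switch does, so supporting another message type means touching both the type list and the switch.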
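Summary
spring-ai-alibaba-redis-memory provides a Redis implementation of ChatMemory. Through Jedis it appends messages with rpush, fetches the most recent N with lrange, and removes a conversation's messages with del.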