Juejin · Artificial Intelligence · April 27, 12:28
A Look at RedisChatMemory in Spring AI Alibaba

 

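This article walks through the implementation of RedisChatMemory in Spring AI Alibaba. The component uses Redis as its message store and supports adding, retrieving, and clearing the messages of a conversation, as well as closing the underlying connection. It connects to Redis through a JedisPool and uses an ObjectMapper to serialize and deserialize messages. The article also covers MessageDeserializer, which converts a JSON message into a UserMessage or an AssistantMessage; together these pieces provide Redis-backed chat memory.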

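💡 RedisChatMemory is the Redis-based ChatMemory implementation provided by Spring AI Alibaba. It implements the ChatMemory interface to store and retrieve chat messages.

🔑 The RedisChatMemory constructor initializes a JedisPool and registers MessageDeserializer on the ObjectMapper. The JedisPool manages Redis connections; the ObjectMapper serializes and deserializes messages.

➕ The add method serializes each message to JSON and appends it with the rpush command to a Redis list whose key is the key prefix followed by the conversationId.

📜 The get method uses the lrange command to fetch the most recent n messages from Redis, then deserializes them into Message objects.

🗑️ The clear method uses the del command to delete all messages stored for the given conversationId.

🚪 The close method closes the Jedis connection and then the JedisPool, releasing their resources.

🛠️ MessageDeserializer deserializes a JSON message into a UserMessage or an AssistantMessage, choosing the message type from the messageType field.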

This article takes a look at RedisChatMemory in Spring AI Alibaba.

RedisChatMemory

community/memories/spring-ai-alibaba-redis-memory/src/main/java/com/alibaba/cloud/ai/memory/redis/RedisChatMemory.java

public class RedisChatMemory implements ChatMemory, AutoCloseable {

    private static final Logger logger = LoggerFactory.getLogger(RedisChatMemory.class);

    private static final String DEFAULT_KEY_PREFIX = "spring_ai_alibaba_chat_memory";

    private static final String DEFAULT_HOST = "127.0.0.1";

    private static final int DEFAULT_PORT = 6379;

    private static final String DEFAULT_PASSWORD = null;

    private final JedisPool jedisPool;

    private final Jedis jedis;

    private final ObjectMapper objectMapper;

    public RedisChatMemory() {
        this(DEFAULT_HOST, DEFAULT_PORT, DEFAULT_PASSWORD);
    }

    public RedisChatMemory(String host, int port, String password) {
        JedisPoolConfig poolConfig = new JedisPoolConfig();
        this.jedisPool = new JedisPool(poolConfig, host, port, 2000, password);
        this.jedis = jedisPool.getResource();
        this.objectMapper = new ObjectMapper();
        SimpleModule module = new SimpleModule();
        module.addDeserializer(Message.class, new MessageDeserializer());
        this.objectMapper.registerModule(module);
        logger.info("Connected to Redis at {}:{}", host, port);
    }

    @Override
    public void add(String conversationId, List<Message> messages) {
        String key = DEFAULT_KEY_PREFIX + conversationId;
        for (Message message : messages) {
            try {
                String messageJson = objectMapper.writeValueAsString(message);
                jedis.rpush(key, messageJson);
            }
            catch (JsonProcessingException e) {
                throw new RuntimeException("Error serializing message", e);
            }
        }
        logger.info("Added messages to conversationId: {}", conversationId);
    }

    @Override
    public List<Message> get(String conversationId, int lastN) {
        String key = DEFAULT_KEY_PREFIX + conversationId;
        List<String> messageStrings = jedis.lrange(key, -lastN, -1);
        List<Message> messages = new ArrayList<>();
        for (String messageString : messageStrings) {
            try {
                Message message = objectMapper.readValue(messageString, Message.class);
                messages.add(message);
            }
            catch (JsonProcessingException e) {
                throw new RuntimeException("Error deserializing message", e);
            }
        }
        logger.info("Retrieved {} messages for conversationId: {}", messages.size(), conversationId);
        return messages;
    }

    @Override
    public void clear(String conversationId) {
        String key = DEFAULT_KEY_PREFIX + conversationId;
        jedis.del(key);
        logger.info("Cleared messages for conversationId: {}", conversationId);
    }

    @Override
    public void close() {
        if (jedis != null) {
            jedis.close();
            logger.info("Redis connection closed.");
        }
        if (jedisPool != null) {
            jedisPool.close();
            logger.info("Jedis pool closed.");
        }
    }

    public void clearOverLimit(String conversationId, int maxLimit, int deleteSize) {
        try {
            String key = DEFAULT_KEY_PREFIX + conversationId;
            List<String> all = jedis.lrange(key, 0, -1);
            if (all.size() >= maxLimit) {
                all = all.stream().skip(Math.max(0, deleteSize)).toList();
            }
            this.clear(conversationId);
            for (String message : all) {
                jedis.rpush(key, message);
            }
        }
        catch (Exception e) {
            logger.error("Error clearing messages from Redis chat memory", e);
            throw new RuntimeException(e);
        }
    }

    public void updateMessageById(String conversationId, String messages) {
        String key = "spring_ai_alibaba_chat_memory:" + conversationId;
        try {
            this.jedis.del(key);
            this.jedis.rpush(key, new String[] { messages });
        }
        catch (Exception var6) {
            logger.error("Error updating messages from Redis chat memory", var6);
            throw new RuntimeException(var6);
        }
    }

}

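The RedisChatMemory constructor initializes the JedisPool and registers a MessageDeserializer for the org.springframework.ai.chat.messages.Message type on the ObjectMapper. The add method iterates over messages, serializes each one to JSON, and rpushes it onto the list at spring_ai_alibaba_chat_memory{conversationId}. The get method uses lrange with the range -lastN to -1 to fetch the most recent n entries, then deserializes each into a Message object. The clear method simply deletes the key. The close method closes jedis first and then jedisPool.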
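To make the list semantics behind add, get, and clear concrete, here is a minimal in-memory sketch of the three Redis commands involved. InMemoryListStore and its lastN helper are hypothetical names invented for this illustration; they are not part of Spring AI Alibaba or Jedis, and the sketch only models the documented behavior of RPUSH, LRANGE, and DEL rather than talking to a real Redis server.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical in-memory stand-in for the Redis list commands used by
// RedisChatMemory. Illustration only; not part of Spring AI Alibaba or Jedis.
final class InMemoryListStore {

    private final Map<String, List<String>> lists = new HashMap<>();

    // RPUSH: append a value to the tail of the list stored at key.
    long rpush(String key, String value) {
        List<String> list = lists.computeIfAbsent(key, k -> new ArrayList<>());
        list.add(value);
        return list.size();
    }

    // LRANGE: inclusive range; negative indexes count from the tail
    // (-1 is the last element), and out-of-range indexes are clamped.
    List<String> lrange(String key, long start, long stop) {
        List<String> list = lists.getOrDefault(key, List.of());
        int size = list.size();
        long from = start < 0 ? size + start : start;
        long to = stop < 0 ? size + stop : stop;
        from = Math.max(from, 0);
        to = Math.min(to, size - 1);
        if (from > to) {
            return new ArrayList<>();
        }
        return new ArrayList<>(list.subList((int) from, (int) to + 1));
    }

    // DEL: remove the whole list; returns the number of keys removed.
    long del(String key) {
        return lists.remove(key) == null ? 0 : 1;
    }

    // Mirrors the range that get(conversationId, lastN) passes to lrange.
    List<String> lastN(String key, int lastN) {
        return lrange(key, -lastN, -1);
    }
}
```

Under these semantics, asking for more messages than the list holds returns the whole list, because the start index is clamped to 0. Passing lastN = 0 also returns the whole list rather than nothing, since the range becomes 0 to -1.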

MessageDeserializer

community/memories/spring-ai-alibaba-redis-memory/src/main/java/com/alibaba/cloud/ai/memory/redis/serializer/MessageDeserializer.java

public class MessageDeserializer extends JsonDeserializer<Message> {

    private static final Logger logger = LoggerFactory.getLogger(MessageDeserializer.class);

    public Message deserialize(JsonParser p, DeserializationContext ctxt) {
        ObjectMapper mapper = (ObjectMapper) p.getCodec();
        JsonNode node = null;
        Message message = null;
        try {
            node = mapper.readTree(p);
            String messageType = node.get("messageType").asText();
            switch (messageType) {
                case "USER" -> message = new UserMessage(node.get("text").asText(),
                        mapper.convertValue(node.get("media"), new TypeReference<Collection<Media>>() {
                        }), mapper.convertValue(node.get("metadata"), new TypeReference<Map<String, Object>>() {
                        }));
                case "ASSISTANT" -> message = new AssistantMessage(node.get("text").asText(),
                        mapper.convertValue(node.get("metadata"), new TypeReference<Map<String, Object>>() {
                        }), (List<AssistantMessage.ToolCall>) mapper.convertValue(node.get("toolCalls"),
                                new TypeReference<Collection<AssistantMessage.ToolCall>>() {
                                }),
                        (List<Media>) mapper.convertValue(node.get("media"), new TypeReference<Collection<Media>>() {
                        }));
                default -> throw new IllegalArgumentException("Unknown message type: " + messageType);
            }
            ;
        }
        catch (IOException e) {
            logger.error("Error deserializing message", e);
        }
        return message;
    }

}

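MessageDeserializer extends JsonDeserializer. It reads the messageType field, then creates a UserMessage for the USER type and an AssistantMessage for the ASSISTANT type. Any other messageType value falls through to the default branch and throws an IllegalArgumentException.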
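The dispatch on messageType can be sketched without Jackson or Spring AI on the classpath. In the sketch below, SketchMessage, SketchUserMessage, SketchAssistantMessage, and MessageTypeDispatch are hypothetical simplified stand-ins made up for this illustration, and the parsed JSON object is modeled as a plain Map; the real deserializer works on a JsonNode and also restores media, metadata, and toolCalls.

```java
import java.util.Map;

// Hypothetical simplified stand-ins for illustration only; the real types
// are Spring AI's Message, UserMessage, and AssistantMessage.
interface SketchMessage {
    String text();
}

record SketchUserMessage(String text) implements SketchMessage {
}

record SketchAssistantMessage(String text) implements SketchMessage {
}

final class MessageTypeDispatch {

    // fields: an already-parsed JSON object, modeled as a Map in this sketch.
    static SketchMessage fromFields(Map<String, Object> fields) {
        String messageType = String.valueOf(fields.get("messageType"));
        String text = String.valueOf(fields.get("text"));
        // Same shape as the switch in MessageDeserializer: one branch per
        // supported type, and a default branch that rejects everything else.
        return switch (messageType) {
            case "USER" -> new SketchUserMessage(text);
            case "ASSISTANT" -> new SketchAssistantMessage(text);
            default -> throw new IllegalArgumentException("Unknown message type: " + messageType);
        };
    }
}
```

With this shape, a stored entry whose messageType is anything other than USER or ASSISTANT (for example SYSTEM) is rejected at read time instead of being silently dropped.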

Summary

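spring-ai-alibaba-redis-memory provides a Redis implementation of ChatMemory. Through Jedis, it adds messages with rpush, fetches the most recent N with lrange, and deletes the messages of a given conversation with del.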

