Juejin · Artificial Intelligence · April 29, 15:28
A Look at Spring AI Alibaba's PlantUMLGenerator

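This article examines PlantUMLGenerator, the component in the Spring AI Alibaba project that generates flowcharts for a graph. It analyzes the abstract class DiagramGenerator and its subclass PlantUMLGenerator, and looks at the role each plays in diagram generation and how each is implemented. Walking through the core methods shows how PlantUMLGenerator turns a StateGraph's nodes and edges into PlantUML text that can be rendered as a diagram. That diagram helps developers understand and maintain the graph's logic.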

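💡 The abstract class `DiagramGenerator` defines the framework for diagram generation. It declares the basic hooks a diagram needs, such as `appendHeader`, `appendFooter`, and `call`, and subclasses must implement them.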

🔨 `PlantUMLGenerator` extends `DiagramGenerator` and produces the concrete PlantUML diagram. It implements `appendHeader` and the other abstract methods so that each hook emits PlantUML syntax.

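🔄 The `generate` method is the core of the design. It takes the nodes, the edges, and a context as input and works in a fixed order:

1. It declares every node, recursing into subgraphs.
2. It declares a condition node for each conditional edge.
3. It emits the edge leaving START.
4. It emits all remaining edges.

At each step it calls the abstract hooks, so the subclass decides the concrete syntax of the final text.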

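🚦 The `Context` record holds what generation needs: the `StringBuilder` that collects the output, the title, whether to print conditional edges, and whether the graph is a subgraph. Its `titleToSnakeCase` method replaces every non-alphanumeric character in the title with an underscore. `PlantUMLGenerator` uses the result as the diagram name after `@startuml`.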
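Despite its name, `titleToSnakeCase` does not lowercase the title or split camelCase words. It only replaces each character outside `[a-zA-Z0-9]` with an underscore, and that includes spaces, punctuation, and any non-ASCII characters. The standalone snippet below copies that regex to show what ends up after `@startuml`. The class name `TitleToSnakeCaseDemo` is only for illustration and is not part of the library.

```java
// Standalone copy of the regex used by Context.titleToSnakeCase():
// every character outside [a-zA-Z0-9] becomes '_', letter case is unchanged.
public class TitleToSnakeCaseDemo {

    static String titleToSnakeCase(String title) {
        return title.replaceAll("[^a-zA-Z0-9]", "_");
    }

    public static void main(String[] args) {
        System.out.println(titleToSnakeCase("demo"));            // demo
        System.out.println(titleToSnakeCase("Order Flow (v2)")); // Order_Flow__v2_
        System.out.println(titleToSnakeCase("OrderFlow"));       // OrderFlow (no lowercasing, no word splitting)
    }
}
```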

This article takes a look at Spring AI Alibaba's PlantUMLGenerator.

DiagramGenerator

spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/DiagramGenerator.java

```java
public abstract class DiagramGenerator {

    public enum CallStyle {
        DEFAULT, START, END, CONDITIONAL, PARALLEL
    }

    public record Context(StringBuilder sb, String title, boolean printConditionalEdge, boolean isSubGraph) {

        static Builder builder() {
            return new Builder();
        }

        static public class Builder {

            String title;
            boolean printConditionalEdge;
            boolean IsSubGraph;

            private Builder() {
            }

            public Builder title(String title) {
                this.title = title;
                return this;
            }

            public Builder printConditionalEdge(boolean value) {
                this.printConditionalEdge = value;
                return this;
            }

            public Builder isSubGraph(boolean value) {
                this.IsSubGraph = value;
                return this;
            }

            public Context build() {
                return new Context(new StringBuilder(), title, printConditionalEdge, IsSubGraph);
            }
        }

        /**
         * Converts a given title string to snake_case format by replacing all
         * non-alphanumeric characters with underscores.
         * @return the snake_case formatted string
         */
        public String titleToSnakeCase() {
            return title.replaceAll("[^a-zA-Z0-9]", "_");
        }

        /**
         * Returns a string representation of this object by returning the string built in
         * {@link #sb}.
         * @return a string representation of this object.
         */
        @Override
        public String toString() {
            return sb.toString();
        }
    }

    /**
     * Appends a header to the output based on the provided context.
     * @param ctx The {@link Context} containing the information needed for appending the
     * header.
     */
    protected abstract void appendHeader(Context ctx);

    /**
     * Appends a footer to the content.
     * @param ctx Context object containing the necessary information.
     */
    protected abstract void appendFooter(Context ctx);

    /**
     * Appends an edge from node {@code from} to node {@code to}.
     * @param ctx the current context
     * @param from the id of the source node
     * @param to the id of the target node
     * @param style the call style, which determines how the edge is rendered
     */
    protected abstract void call(Context ctx, String from, String to, CallStyle style);

    /**
     * Appends a labelled edge from node {@code from} to node {@code to}.
     * @param ctx the current context
     * @param from the id of the source node
     * @param to the id of the target node
     * @param description the edge label (for conditional edges, the condition value)
     * @param style the call style, which determines how the edge is rendered
     */
    protected abstract void call(Context ctx, String from, String to, String description, CallStyle style);

    /**
     * Declares the condition node attached to the START edge when the graph's entry
     * point is conditional.
     * @param ctx The context containing contextual information needed for the
     * declaration.
     * @param name The name of the condition node to be declared.
     */
    protected abstract void declareConditionalStart(Context ctx, String name);

    /**
     * Declares a node in the specified context with the given name.
     * @param ctx the context in which to declare the node (not null)
     * @param name the name of the node to be declared (not null, not empty)
     */
    protected abstract void declareNode(Context ctx, String name);

    /**
     * Declares a conditional edge in the context with a specified ordinal.
     * @param ctx the context
     * @param ordinal the ordinal value
     */
    protected abstract void declareConditionalEdge(Context ctx, int ordinal);

    /**
     * Optionally comments out the next line written to the given context.
     * @param ctx The context in which the line is to be commented.
     * @param yesOrNo Whether the next line should be commented out ({@literal true}) or
     * left active ({@literal false}).
     */
    protected abstract void commentLine(Context ctx, boolean yesOrNo);

    /**
     * Generate a textual representation of the given graph.
     * @param nodes the state graph nodes used to generate the context, which must not be
     * null
     * @param edges the state graph edges used to generate the context, which must not be
     * null
     * @param title The title of the graph.
     * @param printConditionalEdge Whether to print the conditional edge condition.
     * @return A string representation of the graph.
     */
    public final String generate(StateGraph.Nodes nodes, StateGraph.Edges edges, String title,
            boolean printConditionalEdge) {
        return generate(nodes, edges,
                Context.builder().title(title).isSubGraph(false).printConditionalEdge(printConditionalEdge).build())
            .toString();
    }

    /**
     * Generates a context based on the given state graph.
     * @param nodes the state graph nodes used to generate the context, which must not be
     * null
     * @param edges the state graph edges used to generate the context, which must not be
     * null
     * @param ctx the initial context, which must not be null
     * @return the generated context, which will not be null
     */
    protected final Context generate(StateGraph.Nodes nodes, StateGraph.Edges edges, Context ctx) {
        appendHeader(ctx);

        // 1. Declare nodes; a subgraph node is rendered recursively into its own context.
        for (var n : nodes.elements) {
            if (n instanceof SubGraphNode subGraphNode) {
                @SuppressWarnings("unchecked")
                var subGraph = (StateGraph) subGraphNode.subGraph();
                Context subgraphCtx = generate(subGraph.nodes, subGraph.edges,
                        Context.builder()
                            .title(n.id())
                            .printConditionalEdge(ctx.printConditionalEdge)
                            .isSubGraph(true)
                            .build());
                ctx.sb().append(subgraphCtx);
            }
            else {
                declareNode(ctx, n.id());
            }
        }

        // 2. Declare one condition node per conditional edge (START and parallel edges excluded).
        final int[] conditionalEdgeCount = { 0 };
        edges.elements.stream()
            .filter(e -> !Objects.equals(e.sourceId(), START))
            .filter(e -> !e.isParallel())
            .forEach(e -> {
                if (e.target().value() != null) {
                    conditionalEdgeCount[0] += 1;
                    commentLine(ctx, !ctx.printConditionalEdge());
                    declareConditionalEdge(ctx, conditionalEdgeCount[0]);
                }
            });

        // 3. Emit the entry edge leaving START (parallel, plain, or conditional).
        var edgeStart = edges.elements.stream()
            .filter(e -> Objects.equals(e.sourceId(), START))
            .findFirst()
            .orElseThrow();
        if (edgeStart.isParallel()) {
            edgeStart.targets().forEach(target -> {
                call(ctx, START, target.id(), CallStyle.START);
            });
        }
        else if (edgeStart.target().id() != null) {
            call(ctx, START, edgeStart.target().id(), CallStyle.START);
        }
        else if (edgeStart.target().value() != null) {
            String conditionName = "startcondition";
            commentLine(ctx, !ctx.printConditionalEdge());
            declareConditionalStart(ctx, conditionName);
            edgeCondition(ctx, edgeStart.target().value(), START, conditionName);
        }

        // 4. Emit every other edge; condition names are numbered in the same order as in step 2.
        conditionalEdgeCount[0] = 0; // reset
        edges.elements.stream().filter(e -> !Objects.equals(e.sourceId(), START)).forEach(v -> {
            if (v.isParallel()) {
                v.targets().forEach(target -> {
                    call(ctx, v.sourceId(), target.id(), CallStyle.PARALLEL);
                });
            }
            else if (v.target().id() != null) {
                call(ctx, v.sourceId(), v.target().id(), CallStyle.DEFAULT);
            }
            else if (v.target().value() != null) {
                conditionalEdgeCount[0] += 1;
                String conditionName = format("condition%d", conditionalEdgeCount[0]);
                edgeCondition(ctx, v.targets().get(0).value(), v.sourceId(), conditionName);
            }
        });

        appendFooter(ctx);
        return ctx;
    }

    /**
     * Renders the edges of a conditional edge: source -> condition node, then for each
     * mapping both condition node -> target and source -> target, one of which is
     * commented out depending on {@code ctx.printConditionalEdge()}.
     * @param ctx the current context
     * @param condition the edge condition whose mappings are rendered
     * @param k the id of the source node
     * @param conditionName the name of the condition node
     */
    private void edgeCondition(Context ctx, EdgeCondition condition, String k, String conditionName) {
        commentLine(ctx, !ctx.printConditionalEdge());
        call(ctx, k, conditionName, CallStyle.CONDITIONAL);
        condition.mappings().forEach((cond, to) -> {
            commentLine(ctx, !ctx.printConditionalEdge());
            call(ctx, conditionName, to, cond, CallStyle.CONDITIONAL);
            commentLine(ctx, ctx.printConditionalEdge());
            call(ctx, k, to, cond, CallStyle.CONDITIONAL);
        });
    }
}
```

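DiagramGenerator is an abstract class and the base class for diagram generation. It declares the abstract methods appendHeader, appendFooter, call, declareConditionalStart, declareNode, declareConditionalEdge, and commentLine. It also provides a final generate method, a template method that builds a textual representation of the graph from the nodes, edges, and ctx. Because generate is final, the order of the output is fixed, and subclasses control only the syntax of each piece.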
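The following is a minimal, hypothetical sketch of the template-method shape of generate. `MiniDiagramGenerator`, `MiniPlantUml` and the `Edge` record are illustrative stand-ins, not the library's types, and conditional edges and subgraphs are left out. It keeps the fixed order described above: header, node declarations, the START edge, the remaining edges, then the footer. The subclass reuses PlantUMLGenerator's format strings, with the header trimmed to `@startuml`/`title`.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical, simplified model of the DiagramGenerator template method.
// Edge is a plain record here, NOT the library's StateGraph edge type.
public class TemplateMethodSketch {

    static final String START = "__START__";
    static final String END = "__END__";

    record Edge(String sourceId, List<String> targetIds) {}

    static abstract class MiniDiagramGenerator {
        protected abstract void appendHeader(StringBuilder sb, String title);
        protected abstract void declareNode(StringBuilder sb, String name);
        protected abstract void call(StringBuilder sb, String from, String to);
        protected abstract void appendFooter(StringBuilder sb);

        // Fixed order, mirroring generate(): header, nodes, START edge, other edges, footer.
        public final String generate(List<String> nodes, List<Edge> edges, String title) {
            StringBuilder sb = new StringBuilder();
            appendHeader(sb, title);
            nodes.forEach(n -> declareNode(sb, n));
            Edge start = edges.stream()
                .filter(e -> e.sourceId().equals(START))
                .findFirst()
                .orElseThrow(); // like the original, a missing entry edge fails fast
            start.targetIds().forEach(t -> call(sb, START, t));
            edges.stream()
                .filter(e -> !e.sourceId().equals(START))
                .forEach(e -> e.targetIds().forEach(t -> call(sb, e.sourceId(), t)));
            appendFooter(sb);
            return sb.toString();
        }
    }

    // PlantUML-flavoured subclass reusing PlantUMLGenerator's format strings (header trimmed).
    static class MiniPlantUml extends MiniDiagramGenerator {
        protected void appendHeader(StringBuilder sb, String title) {
            sb.append(String.format("@startuml %s\ntitle \"%s\"\n", title, title));
        }

        protected void declareNode(StringBuilder sb, String name) {
            sb.append(String.format("usecase \"%s\"<<Node>>\n", name));
        }

        protected void call(StringBuilder sb, String from, String to) {
            sb.append(String.format("\"%s\" -down-> \"%s\"\n", from, to));
        }

        protected void appendFooter(StringBuilder sb) {
            sb.append("@enduml\n");
        }
    }

    public static void main(String[] args) {
        List<Edge> edges = new ArrayList<>();
        edges.add(new Edge("a", List.of("b")));
        edges.add(new Edge(START, List.of("a"))); // added second, still emitted first
        edges.add(new Edge("b", List.of(END)));
        System.out.print(new MiniPlantUml().generate(List.of("a", "b"), edges, "mini"));
    }
}
```

Swapping in a different subclass changes only the syntax of each line. The order in which nodes and edges appear stays the one fixed by the final `generate` method.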

PlantUMLGenerator

spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/diagram/PlantUMLGenerator.java

```java
package com.alibaba.cloud.ai.graph.diagram;

import com.alibaba.cloud.ai.graph.DiagramGenerator;

import static com.alibaba.cloud.ai.graph.StateGraph.END;
import static com.alibaba.cloud.ai.graph.StateGraph.START;
import static java.lang.String.format;

public class PlantUMLGenerator extends DiagramGenerator {

    @Override
    protected void appendHeader(Context ctx) {
        if (ctx.isSubGraph()) {
            // a subgraph is nested inside the parent diagram as a rectangle
            ctx.sb()
                .append(format("rectangle %s [ {{\ntitle \"%s\"\n", ctx.title(), ctx.title()))
                .append(format("circle \" \" as %s\n", START))
                .append(format("circle exit as %s\n", END));
        }
        else {
            ctx.sb()
                .append(format("@startuml %s\n", ctx.titleToSnakeCase()))
                .append("skinparam usecaseFontSize 14\n")
                .append("skinparam usecaseStereotypeFontSize 12\n")
                .append("skinparam hexagonFontSize 14\n")
                .append("skinparam hexagonStereotypeFontSize 12\n")
                .append(format("title \"%s\"\n", ctx.title()))
                .append("footer\n\n")
                .append("powered by spring-ai-alibaba\n")
                .append("end footer\n")
                .append(format("circle start<<input>> as %s\n", START))
                .append(format("circle stop as %s\n", END));
        }
    }

    @Override
    protected void appendFooter(Context ctx) {
        if (ctx.isSubGraph()) {
            ctx.sb().append("\n}} ]\n");
        }
        else {
            ctx.sb().append("@enduml\n");
        }
    }

    @Override
    protected void call(Context ctx, String from, String to, CallStyle style) {
        ctx.sb().append(switch (style) {
            case CONDITIONAL -> format("\"%s\" .down.> \"%s\"\n", from, to);
            default -> format("\"%s\" -down-> \"%s\"\n", from, to);
        });
    }

    @Override
    protected void call(Context ctx, String from, String to, String description, CallStyle style) {
        ctx.sb().append(switch (style) {
            case CONDITIONAL -> format("\"%s\" .down.> \"%s\": \"%s\"\n", from, to, description);
            default -> format("\"%s\" -down-> \"%s\": \"%s\"\n", from, to, description);
        });
    }

    @Override
    protected void declareConditionalStart(Context ctx, String name) {
        ctx.sb().append(format("hexagon \"check state\" as %s<<Condition>>\n", name));
    }

    @Override
    protected void declareNode(Context ctx, String name) {
        ctx.sb().append(format("usecase \"%s\"<<Node>>\n", name));
    }

    @Override
    protected void declareConditionalEdge(Context ctx, int ordinal) {
        ctx.sb().append(format("hexagon \"check state\" as condition%d<<Condition>>\n", ordinal));
    }

    @Override
    protected void commentLine(Context ctx, boolean yesOrNo) {
        // ' starts a line comment in PlantUML
        if (yesOrNo)
            ctx.sb().append("'");
    }
}
```

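PlantUMLGenerator implements DiagramGenerator's abstract methods with PlantUML constructs:

- Nodes become `usecase "..."<<Node>>` elements.
- Condition nodes become `hexagon "check state"` elements.
- Regular and parallel edges become solid `-down->` arrows, and conditional edges become dotted `.down.>` arrows.
- commentLine writes a leading `'`, PlantUML's line-comment marker.
- A subgraph is wrapped in a `rectangle ... [ {{ ... }} ]` block instead of `@startuml`/`@enduml`.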
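The interplay between `edgeCondition` and `commentLine` is the least obvious part. For every mapping, both a route through the condition node and a direct edge from the source node are written, and one of the two is commented out. The standalone sketch below reproduces that output using the format strings copied from PlantUMLGenerator. The hexagon line stands in for what `declareConditionalEdge` writes earlier in `generate`. The class `ConditionalEdgeSketch`, its `render` helper, and the example values (`agent`, `tools`, `continue`) are hypothetical.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Standalone reproduction of how one conditional edge is rendered by
// DiagramGenerator.edgeCondition() + PlantUMLGenerator. Example values are hypothetical.
public class ConditionalEdgeSketch {

    // PlantUMLGenerator.commentLine(): a leading ' comments out the next PlantUML line
    static void commentLine(StringBuilder sb, boolean yesOrNo) {
        if (yesOrNo) sb.append("'");
    }

    static String render(String from, String conditionName, Map<String, String> mappings,
                         boolean printConditionalEdge) {
        StringBuilder sb = new StringBuilder();
        // condition node (written earlier in generate() via declareConditionalEdge())
        commentLine(sb, !printConditionalEdge);
        sb.append(String.format("hexagon \"check state\" as %s<<Condition>>\n", conditionName));
        // source node -> condition node
        commentLine(sb, !printConditionalEdge);
        sb.append(String.format("\"%s\" .down.> \"%s\"\n", from, conditionName));
        mappings.forEach((cond, to) -> {
            // condition node -> target: active only when conditions are printed
            commentLine(sb, !printConditionalEdge);
            sb.append(String.format("\"%s\" .down.> \"%s\": \"%s\"\n", conditionName, to, cond));
            // source node -> target directly: active only when conditions are NOT printed
            commentLine(sb, printConditionalEdge);
            sb.append(String.format("\"%s\" .down.> \"%s\": \"%s\"\n", from, to, cond));
        });
        return sb.toString();
    }

    public static void main(String[] args) {
        Map<String, String> mappings = new LinkedHashMap<>(); // keeps mapping order stable
        mappings.put("continue", "tools");
        mappings.put("end", "__END__");
        System.out.print(render("agent", "condition1", mappings, true));
        System.out.println("----");
        System.out.print(render("agent", "condition1", mappings, false));
    }
}
```

Emitting both variants and commenting one out means the generated text always contains the full information. Toggling `printConditionalEdge` only changes which lines PlantUML actually draws.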

StateGraph

spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/StateGraph.java

```java
/**
 * Represents a state graph with nodes and edges.
 */
public class StateGraph {

    public static String END = "__END__";

    public static String START = "__START__";

    final Nodes nodes = new Nodes();

    final Edges edges = new Edges();

    private OverAllState overAllState;

    private String name;

    public OverAllState getOverAllState() {
        return overAllState;
    }

    public StateGraph setOverAllState(OverAllState overAllState) {
        this.overAllState = overAllState;
        return this;
    }

    private final PlainTextStateSerializer stateSerializer;

    //......

    /**
     * Instantiates a new State graph.
     * @param overAllState the overall state
     * @param plainTextStateSerializer the plain text state serializer
     */
    public StateGraph(OverAllState overAllState, PlainTextStateSerializer plainTextStateSerializer) {
        this.overAllState = overAllState;
        this.stateSerializer = plainTextStateSerializer;
    }

    public StateGraph(String name, OverAllState overAllState) {
        this.name = name;
        this.overAllState = overAllState;
        this.stateSerializer = new GsonSerializer();
    }

    /**
     * Instantiates a new State graph.
     * @param overAllState the overall state
     */
    public StateGraph(OverAllState overAllState) {
        this.overAllState = overAllState;
        this.stateSerializer = new GsonSerializer();
    }

    public StateGraph(String name, AgentStateFactory<OverAllState> factory) {
        this.name = name;
        this.overAllState = factory.apply(Map.of());
        this.stateSerializer = new GsonSerializer2(factory);
    }

    public StateGraph(AgentStateFactory<OverAllState> factory) {
        this.overAllState = factory.apply(Map.of());
        this.stateSerializer = new GsonSerializer2(factory);
    }

    /**
     * Instantiates a new State graph.
     */
    public StateGraph() {
        this.stateSerializer = new GsonSerializer();
    }

    public String getName() {
        return name;
    }

    /**
     * Key strategies map.
     * @return the map
     */
    public Map<String, KeyStrategy> keyStrategies() {
        return overAllState.keyStrategies();
    }

    /**
     * Gets state serializer.
     * @return the state serializer
     */
    public StateSerializer getStateSerializer() {
        return stateSerializer;
    }

    /**
     * Gets state factory.
     * @return the state factory
     */
    public final AgentStateFactory<OverAllState> getStateFactory() {
        return stateSerializer.stateFactory();
    }

    /**
     * Adds a node to the graph.
     * @param id the identifier of the node
     * @param action the action to be performed by the node
     * @throws GraphStateException if the node identifier is invalid or the node already
     * exists
     */
    public StateGraph addNode(String id, AsyncNodeAction action) throws GraphStateException {
        return addNode(id, AsyncNodeActionWithConfig.of(action));
    }

    /**
     * @param id the identifier of the node
     * @param actionWithConfig the action to be performed by the node
     * @return this
     * @throws GraphStateException if the node identifier is invalid or the node already
     * exists
     */
    public StateGraph addNode(String id, AsyncNodeActionWithConfig actionWithConfig) throws GraphStateException {
        Node node = new Node(id, (config) -> actionWithConfig);
        return addNode(id, node);
    }

    /**
     * @param id the identifier of the node
     * @param node the node to be added
     * @return this
     * @throws GraphStateException if the node identifier is invalid or the node already
     * exists
     */
    public StateGraph addNode(String id, Node node) throws GraphStateException {
        if (Objects.equals(node.id(), END)) {
            throw Errors.invalidNodeIdentifier.exception(END);
        }
        if (!Objects.equals(node.id(), id)) {
            throw Errors.invalidNodeIdentifier.exception(node.id(), id);
        }
        if (nodes.elements.contains(node)) {
            throw Errors.duplicateNodeError.exception(id);
        }
        nodes.elements.add(node);
        return this;
    }

    /**
     * Adds a subgraph to the state graph by creating a node with the specified
     * identifier. This implies that the subgraph shares the same state with the parent graph.
     * @param id the identifier of the node representing the subgraph
     * @param subGraph the compiled subgraph to be added
     * @return this state graph instance
     * @throws GraphStateException if the node identifier is invalid or the node already
     * exists
     */
    public StateGraph addNode(String id, CompiledGraph subGraph) throws GraphStateException {
        if (Objects.equals(id, END)) {
            throw Errors.invalidNodeIdentifier.exception(END);
        }
        var node = new SubCompiledGraphNode(id, subGraph);
        if (nodes.elements.contains(node)) {
            throw Errors.duplicateNodeError.exception(id);
        }
        nodes.elements.add(node);
        return this;
    }

    /**
     * Adds a subgraph to the state graph by creating a node with the specified
     * identifier. This implies that the subgraph shares the same state with the parent graph.
     * @param id the identifier of the node representing the subgraph
     * @param subGraph the subgraph to be added. It will be compiled on compilation of the
     * parent
     * @return this state graph instance
     * @throws GraphStateException if the node identifier is invalid or the node already
     * exists
     */
    public StateGraph addNode(String id, StateGraph subGraph) throws GraphStateException {
        if (Objects.equals(id, END)) {
            throw Errors.invalidNodeIdentifier.exception(END);
        }
        subGraph.validateGraph();
        OverAllState subGraphOverAllState = subGraph.getOverAllState();
        OverAllState superOverAllState = getOverAllState();
        if (subGraphOverAllState != null) {
            // register the subgraph's key strategies on the parent if they are missing there
            Map<String, KeyStrategy> strategies = subGraphOverAllState.keyStrategies();
            for (Map.Entry<String, KeyStrategy> strategyEntry : strategies.entrySet()) {
                if (!superOverAllState.containStrategy(strategyEntry.getKey())) {
                    superOverAllState.registerKeyAndStrategy(strategyEntry.getKey(), strategyEntry.getValue());
                }
            }
        }
        subGraph.setOverAllState(getOverAllState());
        var node = new SubStateGraphNode(id, subGraph);
        if (nodes.elements.contains(node)) {
            throw Errors.duplicateNodeError.exception(id);
        }
        nodes.elements.add(node);
        return this;
    }

    /**
     * Adds an edge to the graph.
     * @param sourceId the identifier of the source node
     * @param targetId the identifier of the target node
     * @throws GraphStateException if the edge identifier is invalid or the edge already
     * exists
     */
    public StateGraph addEdge(String sourceId, String targetId) throws GraphStateException {
        if (Objects.equals(sourceId, END)) {
            throw Errors.invalidEdgeIdentifier.exception(END);
        }
        // if (Objects.equals(sourceId, START)) {
        // this.entryPoint = new EdgeValue<>(targetId);
        // return this;
        // }
        var newEdge = new Edge(sourceId, new EdgeValue(targetId));
        int index = edges.elements.indexOf(newEdge);
        if (index >= 0) {
            // an edge from the same source already exists: merge the new target into it
            var newTargets = new ArrayList<>(edges.elements.get(index).targets());
            newTargets.add(newEdge.target());
            edges.elements.set(index, new Edge(sourceId, newTargets));
        }
        else {
            edges.elements.add(newEdge);
        }
        return this;
    }

    /**
     * Adds conditional edges to the graph.
     * @param sourceId the identifier of the source node
     * @param condition the condition to determine the target node
     * @param mappings the mappings of conditions to target nodes
     * @throws GraphStateException if the edge identifier is invalid, the mappings are
     * empty, or the edge already exists
     */
    public StateGraph addConditionalEdges(String sourceId, AsyncEdgeAction condition, Map<String, String> mappings)
            throws GraphStateException {
        if (Objects.equals(sourceId, END)) {
            throw Errors.invalidEdgeIdentifier.exception(END);
        }
        if (mappings == null || mappings.isEmpty()) {
            throw Errors.edgeMappingIsEmpty.exception(sourceId);
        }
        var newEdge = new Edge(sourceId, new EdgeValue(new EdgeCondition(condition, mappings)));
        if (edges.elements.contains(newEdge)) {
            throw Errors.duplicateConditionalEdgeError.exception(sourceId);
        }
        else {
            edges.elements.add(newEdge);
        }
        return this;
    }

    void validateGraph() throws GraphStateException {
        var edgeStart = edges.edgeBySourceId(START).orElseThrow(Errors.missingEntryPoint::exception);
        edgeStart.validate(nodes);
        for (Edge edge : edges.elements) {
            edge.validate(nodes);
        }
    }

    /**
     * Compiles the state graph into a compiled graph.
     * @param config the compile configuration
     * @return a compiled graph
     * @throws GraphStateException if there are errors related to the graph state
     */
    public CompiledGraph compile(CompileConfig config) throws GraphStateException {
        Objects.requireNonNull(config, "config cannot be null");
        validateGraph();
        return new CompiledGraph(this, config);
    }

    /**
     * Compiles the state graph into a compiled graph.
     * @return a compiled graph
     * @throws GraphStateException if there are errors related to the graph state
     */
    public CompiledGraph compile() throws GraphStateException {
        SaverConfig saverConfig = SaverConfig.builder().register(SaverConstant.MEMORY, new MemorySaver()).build();
        return compile(CompileConfig.builder()
            .plainTextStateSerializer(new JacksonSerializer())
            .saverConfig(saverConfig)
            .build());
    }

    /**
     * Generates a drawable graph representation of the state graph.
     * @param type the type of graph representation to generate
     * @param title the title of the graph
     * @param printConditionalEdges whether to print conditional edges
     * @return a diagram code of the state graph
     */
    public GraphRepresentation getGraph(GraphRepresentation.Type type, String title, boolean printConditionalEdges) {
        String content = type.generator.generate(nodes, edges, title, printConditionalEdges);
        return new GraphRepresentation(type, content);
    }

    /**
     * Generates a drawable graph representation of the state graph.
     * @param type the type of graph representation to generate
     * @param title the title of the graph
     * @return a diagram code of the state graph
     */
    public GraphRepresentation getGraph(GraphRepresentation.Type type, String title) {
        String content = type.generator.generate(nodes, edges, title, true);
        return new GraphRepresentation(type, content);
    }

    public GraphRepresentation getGraph(GraphRepresentation.Type type) {
        String content = type.generator.generate(nodes, edges, name, true);
        return new GraphRepresentation(type, content);
    }

    //......
}
```

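StateGraph provides methods such as addNode, addEdge, and addConditionalEdges. Its getGraph method generates the state diagram with the DiagramGenerator associated with the given GraphRepresentation.Type (type.generator). The overloads without printConditionalEdges default it to true, and getGraph(type) uses the graph's name as the title.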
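In the code above, addEdge does not always append a new edge. If `indexOf` finds an existing edge for the same source, it merges the new target into that edge, and the merged edge later takes the `isParallel()` branch of `generate`. The sketch below is a hypothetical, simplified model of that behaviour. It assumes two things: edges are matched by `sourceId` alone, as the indexOf-and-merge logic implies, and an edge with more than one target counts as parallel. The `Edge` record, `ParallelEdgeSketch` and `render` are illustrative names, not the library's types. Because PlantUMLGenerator has no special case for `CallStyle.PARALLEL`, each target ends up as an ordinary solid `-down->` arrow.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical, simplified model of StateGraph.addEdge(): ASSUMES edges are matched
// by sourceId alone and that an edge with more than one target is "parallel".
public class ParallelEdgeSketch {

    record Edge(String sourceId, List<String> targets) {
        boolean isParallel() { return targets.size() > 1; }
    }

    final List<Edge> edges = new ArrayList<>();

    ParallelEdgeSketch addEdge(String sourceId, String targetId) {
        for (int i = 0; i < edges.size(); i++) {
            Edge existing = edges.get(i);
            if (existing.sourceId().equals(sourceId)) {
                // same source seen before: merge the target instead of adding a new edge
                List<String> merged = new ArrayList<>(existing.targets());
                merged.add(targetId);
                edges.set(i, new Edge(sourceId, List.copyOf(merged)));
                return this;
            }
        }
        edges.add(new Edge(sourceId, List.of(targetId)));
        return this;
    }

    // PARALLEL falls into PlantUMLGenerator's default branch: one solid arrow per target
    String render() {
        StringBuilder sb = new StringBuilder();
        for (Edge e : edges) {
            for (String t : e.targets()) {
                sb.append(String.format("\"%s\" -down-> \"%s\"\n", e.sourceId(), t));
            }
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        ParallelEdgeSketch g = new ParallelEdgeSketch()
            .addEdge("__START__", "a")
            .addEdge("a", "b")
            .addEdge("a", "c"); // merged into the existing edge from "a"
        System.out.println(g.edges);
        System.out.print(g.render());
    }
}
```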

Example

```java
@Test
public void testGraph() throws GraphStateException {
    // getOverAllState() is a helper of the test class that is not shown in this article;
    // node_async is a static import whose declaration is likewise not shown
    OverAllState overAllState = getOverAllState();
    StateGraph workflow = new StateGraph(overAllState)
        .addNode("agent_1", node_async(state -> {
            System.out.println("agent_1");
            return Map.of("messages", "message1");
        }))
        .addNode("agent_2", node_async(state -> {
            System.out.println("agent_2");
            return Map.of("messages", new String[] { "message2" });
        }))
        .addNode("agent_3", node_async(state -> {
            System.out.println("agent_3");
            List<String> messages = Optional.ofNullable(state.value("messages").get())
                .filter(List.class::isInstance)
                .map(List.class::cast)
                .orElse(new ArrayList<>());
            int steps = messages.size() + 1;
            return Map.of("messages", "message3", "steps", steps);
        }))
        .addEdge("agent_1", "agent_2")
        .addEdge("agent_2", "agent_3")
        .addEdge(StateGraph.START, "agent_1")
        .addEdge("agent_3", StateGraph.END);

    GraphRepresentation representation = workflow.getGraph(GraphRepresentation.Type.PLANTUML, "demo");
    System.out.println(representation.content());
}
```

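The output is shown below. The START edge comes first even though it was the third edge added, because generate always emits the entry edge before all the others: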

```plantuml
@startuml demo
skinparam usecaseFontSize 14
skinparam usecaseStereotypeFontSize 12
skinparam hexagonFontSize 14
skinparam hexagonStereotypeFontSize 12
title "demo"
footer

powered by spring-ai-alibaba
end footer
circle start<<input>> as __START__
circle stop as __END__
usecase "agent_1"<<Node>>
usecase "agent_2"<<Node>>
usecase "agent_3"<<Node>>
"__START__" -down-> "agent_1"
"agent_1" -down-> "agent_2"
"agent_2" -down-> "agent_3"
"agent_3" -down-> "__END__"
@enduml
```

Summary

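DiagramGenerator is an abstract class and the base class for diagram generation. It declares the abstract methods appendHeader, appendFooter, call, declareConditionalStart, declareNode, declareConditionalEdge, and commentLine. It also provides a generate method that builds a textual representation of the graph from the nodes, edges, and ctx. PlantUMLGenerator extends DiagramGenerator and implements these abstract methods using PlantUML syntax.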
