Preface
This article takes a look at the PlantUMLGenerator in Spring AI Alibaba.
DiagramGenerator
spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/DiagramGenerator.java
```java
public abstract class DiagramGenerator {

    public enum CallStyle {
        DEFAULT, START, END, CONDITIONAL, PARALLEL
    }

    public record Context(StringBuilder sb, String title, boolean printConditionalEdge, boolean isSubGraph) {

        static Builder builder() {
            return new Builder();
        }

        static public class Builder {

            String title;

            boolean printConditionalEdge;

            boolean IsSubGraph;

            private Builder() {
            }

            public Builder title(String title) {
                this.title = title;
                return this;
            }

            public Builder printConditionalEdge(boolean value) {
                this.printConditionalEdge = value;
                return this;
            }

            public Builder isSubGraph(boolean value) {
                this.IsSubGraph = value;
                return this;
            }

            public Context build() {
                return new Context(new StringBuilder(), title, printConditionalEdge, IsSubGraph);
            }

        }

        /**
         * Converts a given title string to snake_case format by replacing all
         * non-alphanumeric characters with underscores.
         * @return the snake_case formatted string
         */
        public String titleToSnakeCase() {
            return title.replaceAll("[^a-zA-Z0-9]", "_");
        }

        /**
         * Returns a string representation of this object by returning the string built in
         * {@link #sb}.
         * @return a string representation of this object.
         */
        @Override
        public String toString() {
            return sb.toString();
        }

    }

    /**
     * Appends a header to the output based on the provided context.
     * @param ctx The {@link Context} containing the information needed for appending the
     * header.
     */
    protected abstract void appendHeader(Context ctx);

    /**
     * Appends a footer to the content.
     * @param ctx Context object containing the necessary information.
     */
    protected abstract void appendFooter(Context ctx);

    /**
     * This method is an abstract method that must be implemented by subclasses. It is
     * used to initiate a communication call between two parties identified by their phone
     * numbers.
     * @param ctx The current context in which the call is being made.
     * @param from The phone number of the caller.
     * @param to The phone number of the recipient.
     */
    protected abstract void call(Context ctx, String from, String to, CallStyle style);

    /**
     * Abstract method that must be implemented by subclasses to handle the logic of
     * making a call.
     * @param ctx The context in which the call is being made.
     * @param from The phone number of the caller.
     * @param to The phone number of the recipient.
     * @param description A brief description of the call.
     */
    protected abstract void call(Context ctx, String from, String to, String description, CallStyle style);

    /**
     * Declares a conditional element in the configuration or template. This method is
     * used to mark the start of a conditional section based on the provided {@code name}.
     * It takes a {@code Context} object that may contain additional parameters necessary
     * for the declaration, and a {@code name} which identifies the type or key associated
     * with the conditional section.
     * @param ctx The context containing contextual information needed for the
     * declaration.
     * @param name The name of the conditional section to be declared.
     */
    protected abstract void declareConditionalStart(Context ctx, String name);

    /**
     * Declares a node in the specified context with the given name.
     * @param ctx the context in which to declare the node {@code @literal (not null)}
     * @param name the name of the node to be declared
     * {@code @literal (not null, not empty)}
     */
    protected abstract void declareNode(Context ctx, String name);

    /**
     * Declares a conditional edge in the context with a specified ordinal.
     * @param ctx the context
     * @param ordinal the ordinal value
     */
    protected abstract void declareConditionalEdge(Context ctx, int ordinal);

    /**
     * Comment a line in the given context.
     * @param ctx The context in which the line is to be commented.
     * @param yesOrNo Whether the line should be uncommented ({@literal true}) or
     * commented ({@literal false}).
     */
    protected abstract void commentLine(Context ctx, boolean yesOrNo);

    /**
     * Generate a textual representation of the given graph.
     * @param nodes the state graph nodes used to generate the context, which must not be
     * null
     * @param edges the state graph edges used to generate the context, which must not be
     * null
     * @param title The title of the graph.
     * @param printConditionalEdge Whether to print the conditional edge condition.
     * @return A string representation of the graph.
     */
    public final String generate(StateGraph.Nodes nodes, StateGraph.Edges edges, String title,
            boolean printConditionalEdge) {
        return generate(nodes, edges,
                Context.builder().title(title).isSubGraph(false).printConditionalEdge(printConditionalEdge).build())
            .toString();
    }

    /**
     * Generates a context based on the given state graph.
     * @param nodes the state graph nodes used to generate the context, which must not be
     * null
     * @param edges the state graph edges used to generate the context, which must not be
     * null
     * @param ctx the initial context, which must not be null
     * @return the generated context, which will not be null
     */
    protected final Context generate(StateGraph.Nodes nodes, StateGraph.Edges edges, Context ctx) {

        appendHeader(ctx);

        for (var n : nodes.elements) {

            if (n instanceof SubGraphNode subGraphNode) {

                @SuppressWarnings("unchecked")
                var subGraph = (StateGraph) subGraphNode.subGraph();
                Context subgraphCtx = generate(subGraph.nodes, subGraph.edges,
                        Context.builder()
                            .title(n.id())
                            .printConditionalEdge(ctx.printConditionalEdge)
                            .isSubGraph(true)
                            .build());
                ctx.sb().append(subgraphCtx);
            }
            else {
                declareNode(ctx, n.id());
            }
        }

        final int[] conditionalEdgeCount = { 0 };

        edges.elements.stream()
            .filter(e -> !Objects.equals(e.sourceId(), START))
            .filter(e -> !e.isParallel())
            .forEach(e -> {
                if (e.target().value() != null) {
                    conditionalEdgeCount[0] += 1;
                    commentLine(ctx, !ctx.printConditionalEdge());
                    declareConditionalEdge(ctx, conditionalEdgeCount[0]);
                }
            });

        var edgeStart = edges.elements.stream()
            .filter(e -> Objects.equals(e.sourceId(), START))
            .findFirst()
            .orElseThrow();
        if (edgeStart.isParallel()) {
            edgeStart.targets().forEach(target -> {
                call(ctx, START, target.id(), CallStyle.START);
            });
        }
        else if (edgeStart.target().id() != null) {
            call(ctx, START, edgeStart.target().id(), CallStyle.START);
        }
        else if (edgeStart.target().value() != null) {
            String conditionName = "startcondition";
            commentLine(ctx, !ctx.printConditionalEdge());
            declareConditionalStart(ctx, conditionName);
            edgeCondition(ctx, edgeStart.target().value(), START, conditionName);
        }

        conditionalEdgeCount[0] = 0; // reset

        edges.elements.stream().filter(e -> !Objects.equals(e.sourceId(), START)).forEach(v -> {

            if (v.isParallel()) {
                v.targets().forEach(target -> {
                    call(ctx, v.sourceId(), target.id(), CallStyle.PARALLEL);
                });
            }
            else if (v.target().id() != null) {
                call(ctx, v.sourceId(), v.target().id(), CallStyle.DEFAULT);
            }
            else if (v.target().value() != null) {
                conditionalEdgeCount[0] += 1;
                String conditionName = format("condition%d", conditionalEdgeCount[0]);

                edgeCondition(ctx, v.targets().get(0).value(), v.sourceId(), conditionName);
            }
        });

        appendFooter(ctx);

        return ctx;
    }

    /**
     * Evaluates an edge condition based on the given context and condition.
     * @param ctx the current context used for evaluation
     * @param condition the condition to be evaluated
     * @param k a string identifier for the condition
     * @param conditionName the name of the condition being processed
     */
    private void edgeCondition(Context ctx, EdgeCondition condition, String k, String conditionName) {
        commentLine(ctx, !ctx.printConditionalEdge());
        call(ctx, k, conditionName, CallStyle.CONDITIONAL);

        condition.mappings().forEach((cond, to) -> {
            commentLine(ctx, !ctx.printConditionalEdge());
            call(ctx, conditionName, to, cond, CallStyle.CONDITIONAL);
            commentLine(ctx, ctx.printConditionalEdge());
            call(ctx, k, to, cond, CallStyle.CONDITIONAL);
        });
    }

}
```
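DiagramGenerator is an abstract class that serves as the base for diagram generators. It declares the abstract methods appendHeader, appendFooter, call, declareConditionalStart, declareNode, declareConditionalEdge and commentLine. It also provides a generate method that turns nodes, edges and a ctx into a textual representation of the graph. generate works as a template method with a fixed sequence. It appends the header, then declares each node, rendering a SubGraphNode recursively as a nested subgraph. Next it declares one condition node per conditional edge and draws the edge leaving START. It then draws all remaining edges, whether parallel, plain or conditional, and finally appends the footer.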
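The template-method structure can be illustrated with a much smaller, self-contained sketch. Everything in it is hypothetical: MiniDiagramGenerator, Edge and MermaidLikeGenerator are illustrative names, not Spring AI Alibaba classes. Node and edge handling is reduced to plain strings. The point is that the base class fixes the order of the steps (header, nodes, edges, footer), and a subclass only supplies the syntax. Here the subclass emits Mermaid flowchart text instead of PlantUML.

```java
import java.util.List;

public class TemplateMethodSketch {

    // Hypothetical edge type used only by this sketch
    public record Edge(String from, String to) {}

    // Base class: generate() fixes the order of steps, subclasses supply the syntax
    public abstract static class MiniDiagramGenerator {

        protected abstract void appendHeader(StringBuilder sb, String title);

        protected abstract void declareNode(StringBuilder sb, String name);

        protected abstract void call(StringBuilder sb, String from, String to);

        protected abstract void appendFooter(StringBuilder sb);

        // Template method: header, node declarations, edges, footer
        public final String generate(List<String> nodes, List<Edge> edges, String title) {
            StringBuilder sb = new StringBuilder();
            appendHeader(sb, title);
            nodes.forEach(n -> declareNode(sb, n));
            edges.forEach(e -> call(sb, e.from(), e.to()));
            appendFooter(sb);
            return sb.toString();
        }
    }

    // Hypothetical subclass that targets Mermaid flowchart syntax instead of PlantUML
    public static class MermaidLikeGenerator extends MiniDiagramGenerator {

        @Override
        protected void appendHeader(StringBuilder sb, String title) {
            // "%%" starts a Mermaid comment; the title is kept as a comment line
            sb.append("%% ").append(title).append('\n').append("flowchart TD\n");
        }

        @Override
        protected void declareNode(StringBuilder sb, String name) {
            sb.append("    ").append(name).append('\n');
        }

        @Override
        protected void call(StringBuilder sb, String from, String to) {
            sb.append("    ").append(from).append(" --> ").append(to).append('\n');
        }

        @Override
        protected void appendFooter(StringBuilder sb) {
            // Mermaid has no closing marker, unlike PlantUML's @enduml
        }
    }

    public static void main(String[] args) {
        String out = new MermaidLikeGenerator().generate(
                List.of("agent_1", "agent_2"),
                List.of(new Edge("agent_1", "agent_2")),
                "demo");
        System.out.print(out);
    }
}
```

Supporting another output format then only means adding another subclass. The traversal logic in generate stays untouched, which appears to be the same motivation behind keeping DiagramGenerator.generate final.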
PlantUMLGenerator
spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/diagram/PlantUMLGenerator.java
```java
public class PlantUMLGenerator extends DiagramGenerator {

    @Override
    protected void appendHeader(Context ctx) {
        if (ctx.isSubGraph()) {
            ctx.sb()
                .append(format("rectangle %s [ {{\ntitle \"%s\"\n", ctx.title(), ctx.title()))
                .append(format("circle \" \" as %s\n", START))
                .append(format("circle exit as %s\n", END));
        }
        else {
            ctx.sb()
                .append(format("@startuml %s\n", ctx.titleToSnakeCase()))
                .append("skinparam usecaseFontSize 14\n")
                .append("skinparam usecaseStereotypeFontSize 12\n")
                .append("skinparam hexagonFontSize 14\n")
                .append("skinparam hexagonStereotypeFontSize 12\n")
                .append(format("title \"%s\"\n", ctx.title()))
                .append("footer\n\n")
                .append("powered by spring-ai-alibaba\n")
                .append("end footer\n")
                .append(format("circle start<<input>> as %s\n", START))
                .append(format("circle stop as %s\n", END));
        }
    }

    @Override
    protected void appendFooter(Context ctx) {
        if (ctx.isSubGraph()) {
            ctx.sb().append("\n}} ]\n");
        }
        else {
            ctx.sb().append("@enduml\n");
        }
    }

    @Override
    protected void call(Context ctx, String from, String to, CallStyle style) {
        ctx.sb().append(switch (style) {
            case CONDITIONAL -> format("\"%s\" .down.> \"%s\"\n", from, to);
            default -> format("\"%s\" -down-> \"%s\"\n", from, to);
        });
    }

    @Override
    protected void call(Context ctx, String from, String to, String description, CallStyle style) {
        ctx.sb().append(switch (style) {
            case CONDITIONAL -> format("\"%s\" .down.> \"%s\": \"%s\"\n", from, to, description);
            default -> format("\"%s\" -down-> \"%s\": \"%s\"\n", from, to, description);
        });
    }

    @Override
    protected void declareConditionalStart(Context ctx, String name) {
        ctx.sb().append(format("hexagon \"check state\" as %s<<Condition>>\n", name));
    }

    @Override
    protected void declareNode(Context ctx, String name) {
        ctx.sb().append(format("usecase \"%s\"<<Node>>\n", name));
    }

    @Override
    protected void declareConditionalEdge(Context ctx, int ordinal) {
        ctx.sb().append(format("hexagon \"check state\" as condition%d<<Condition>>\n", ordinal));
    }

    @Override
    protected void commentLine(Context ctx, boolean yesOrNo) {
        if (yesOrNo)
            ctx.sb().append("'");
    }

}
```
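PlantUMLGenerator implements the abstract methods of DiagramGenerator using PlantUML syntax. A top-level graph is wrapped in @startuml/@enduml, while a subgraph is emitted as a rectangle block. Nodes are declared as usecase elements and condition checks as hexagon elements. Conditional edges are drawn as dotted .down.> arrows and all other edges as solid -down-> arrows. commentLine(ctx, true) prefixes the next line with ', PlantUML's single-line comment marker.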
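The interplay between commentLine and printConditionalEdge is easiest to see on a single conditional edge. The following standalone sketch re-creates, with the same format strings, what DiagramGenerator.edgeCondition produces through PlantUMLGenerator's call and commentLine. It is a simplified copy for illustration, not the library classes. The hexagon named condition1 is assumed to have been declared earlier by declareConditionalEdge, and the mapping keys yes/no are made-up examples.

```java
import java.util.LinkedHashMap;
import java.util.Map;

public class PlantUmlConditionalSketch {

    enum CallStyle { DEFAULT, CONDITIONAL }

    // Same format strings as PlantUMLGenerator.call(ctx, from, to, style)
    public static String call(String from, String to, CallStyle style) {
        return switch (style) {
            case CONDITIONAL -> String.format("\"%s\" .down.> \"%s\"\n", from, to);
            default -> String.format("\"%s\" -down-> \"%s\"\n", from, to);
        };
    }

    // Same format strings as PlantUMLGenerator.call(ctx, from, to, description, style)
    public static String call(String from, String to, String description, CallStyle style) {
        return switch (style) {
            case CONDITIONAL -> String.format("\"%s\" .down.> \"%s\": \"%s\"\n", from, to, description);
            default -> String.format("\"%s\" -down-> \"%s\": \"%s\"\n", from, to, description);
        };
    }

    // PlantUMLGenerator.commentLine: a leading ' turns the following line into a comment
    public static String commentLine(boolean yesOrNo) {
        return yesOrNo ? "'" : "";
    }

    // Mirrors DiagramGenerator.edgeCondition for one conditional edge
    public static String edgeCondition(String source, String conditionName, Map<String, String> mappings,
            boolean printConditionalEdge) {
        StringBuilder sb = new StringBuilder();
        // source -> hexagon: visible only when printConditionalEdge is true
        sb.append(commentLine(!printConditionalEdge)).append(call(source, conditionName, CallStyle.CONDITIONAL));
        mappings.forEach((cond, to) -> {
            // hexagon -> target, labelled with the mapping key
            sb.append(commentLine(!printConditionalEdge)).append(call(conditionName, to, cond, CallStyle.CONDITIONAL));
            // direct source -> target shortcut: visible only when printConditionalEdge is false
            sb.append(commentLine(printConditionalEdge)).append(call(source, to, cond, CallStyle.CONDITIONAL));
        });
        return sb.toString();
    }

    public static void main(String[] args) {
        Map<String, String> mappings = new LinkedHashMap<>();
        mappings.put("yes", "agent_2");
        mappings.put("no", "__END__");
        System.out.print(edgeCondition("agent_1", "condition1", mappings, true));
        System.out.println("----");
        System.out.print(edgeCondition("agent_1", "condition1", mappings, false));
    }
}
```

Both variants are always written into the diagram text. printConditionalEdge only decides which one is commented out. With true, the route through the check state hexagon is drawn. With false, the hexagon declaration and its arrows are commented out, and a direct labelled arrow from the source to each target remains.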
StateGraph
spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/StateGraph.java
```java
/**
 * Represents a state graph with nodes and edges.
 *
 */
public class StateGraph {

    public static String END = "__END__";

    public static String START = "__START__";

    final Nodes nodes = new Nodes();

    final Edges edges = new Edges();

    private OverAllState overAllState;

    private String name;

    public OverAllState getOverAllState() {
        return overAllState;
    }

    public StateGraph setOverAllState(OverAllState overAllState) {
        this.overAllState = overAllState;
        return this;
    }

    private final PlainTextStateSerializer stateSerializer;

    //......

    /**
     * Instantiates a new State graph.
     * @param overAllState the over all state
     * @param plainTextStateSerializer the plain text state serializer
     */
    public StateGraph(OverAllState overAllState, PlainTextStateSerializer plainTextStateSerializer) {
        this.overAllState = overAllState;
        this.stateSerializer = plainTextStateSerializer;
    }

    public StateGraph(String name, OverAllState overAllState) {
        this.name = name;
        this.overAllState = overAllState;
        this.stateSerializer = new GsonSerializer();
    }

    /**
     * Instantiates a new State graph.
     * @param overAllState the over all state
     */
    public StateGraph(OverAllState overAllState) {
        this.overAllState = overAllState;
        this.stateSerializer = new GsonSerializer();
    }

    public StateGraph(String name, AgentStateFactory<OverAllState> factory) {
        this.name = name;
        this.overAllState = factory.apply(Map.of());
        this.stateSerializer = new GsonSerializer2(factory);
    }

    public StateGraph(AgentStateFactory<OverAllState> factory) {
        this.overAllState = factory.apply(Map.of());
        this.stateSerializer = new GsonSerializer2(factory);
    }

    /**
     * Instantiates a new State graph.
     */
    public StateGraph() {
        this.stateSerializer = new GsonSerializer();
    }

    public String getName() {
        return name;
    }

    /**
     * Key strategies map.
     * @return the map
     */
    public Map<String, KeyStrategy> keyStrategies() {
        return overAllState.keyStrategies();
    }

    /**
     * Gets state serializer.
     * @return the state serializer
     */
    public StateSerializer getStateSerializer() {
        return stateSerializer;
    }

    /**
     * Gets state factory.
     * @return the state factory
     */
    public final AgentStateFactory<OverAllState> getStateFactory() {
        return stateSerializer.stateFactory();
    }

    /**
     * /** Adds a node to the graph.
     * @param id the identifier of the node
     * @param action the action to be performed by the node
     * @throws GraphStateException if the node identifier is invalid or the node already
     * exists
     */
    public StateGraph addNode(String id, AsyncNodeAction action) throws GraphStateException {
        return addNode(id, AsyncNodeActionWithConfig.of(action));
    }

    /**
     * @param id the identifier of the node
     * @param actionWithConfig the action to be performed by the node
     * @return this
     * @throws GraphStateException if the node identifier is invalid or the node already
     * exists
     */
    public StateGraph addNode(String id, AsyncNodeActionWithConfig actionWithConfig) throws GraphStateException {
        Node node = new Node(id, (config) -> actionWithConfig);
        return addNode(id, node);
    }

    /**
     * @param id the identifier of the node
     * @param node the node to be added
     * @return this
     * @throws GraphStateException if the node identifier is invalid or the node already
     * exists
     */
    public StateGraph addNode(String id, Node node) throws GraphStateException {
        if (Objects.equals(node.id(), END)) {
            throw Errors.invalidNodeIdentifier.exception(END);
        }
        if (!Objects.equals(node.id(), id)) {
            throw Errors.invalidNodeIdentifier.exception(node.id(), id);
        }

        if (nodes.elements.contains(node)) {
            throw Errors.duplicateNodeError.exception(id);
        }

        nodes.elements.add(node);
        return this;
    }

    /**
     * Adds a subgraph to the state graph by creating a node with the specified
     * identifier. This implies that Subgraph share the same state with parent graph
     * @param id the identifier of the node representing the subgraph
     * @param subGraph the compiled subgraph to be added
     * @return this state graph instance
     * @throws GraphStateException if the node identifier is invalid or the node already
     * exists
     */
    public StateGraph addNode(String id, CompiledGraph subGraph) throws GraphStateException {
        if (Objects.equals(id, END)) {
            throw Errors.invalidNodeIdentifier.exception(END);
        }

        var node = new SubCompiledGraphNode(id, subGraph);

        if (nodes.elements.contains(node)) {
            throw Errors.duplicateNodeError.exception(id);
        }

        nodes.elements.add(node);
        return this;
    }

    /**
     * Adds a subgraph to the state graph by creating a node with the specified
     * identifier. This implies that Subgraph share the same state with parent graph
     * @param id the identifier of the node representing the subgraph
     * @param subGraph the subgraph to be added. it will be compiled on compilation of the
     * parent
     * @return this state graph instance
     * @throws GraphStateException if the node identifier is invalid or the node already
     * exists
     */
    public StateGraph addNode(String id, StateGraph subGraph) throws GraphStateException {
        if (Objects.equals(id, END)) {
            throw Errors.invalidNodeIdentifier.exception(END);
        }

        subGraph.validateGraph();
        OverAllState subGraphOverAllState = subGraph.getOverAllState();
        OverAllState superOverAllState = getOverAllState();
        if (subGraphOverAllState != null) {
            Map<String, KeyStrategy> strategies = subGraphOverAllState.keyStrategies();
            for (Map.Entry<String, KeyStrategy> strategyEntry : strategies.entrySet()) {
                if (!superOverAllState.containStrategy(strategyEntry.getKey())) {
                    superOverAllState.registerKeyAndStrategy(strategyEntry.getKey(), strategyEntry.getValue());
                }
            }
        }
        subGraph.setOverAllState(getOverAllState());

        var node = new SubStateGraphNode(id, subGraph);

        if (nodes.elements.contains(node)) {
            throw Errors.duplicateNodeError.exception(id);
        }

        nodes.elements.add(node);
        return this;
    }

    /**
     * Adds an edge to the graph.
     * @param sourceId the identifier of the source node
     * @param targetId the identifier of the target node
     * @throws GraphStateException if the edge identifier is invalid or the edge already
     * exists
     */
    public StateGraph addEdge(String sourceId, String targetId) throws GraphStateException {
        if (Objects.equals(sourceId, END)) {
            throw Errors.invalidEdgeIdentifier.exception(END);
        }

        // if (Objects.equals(sourceId, START)) {
        // this.entryPoint = new EdgeValue<>(targetId);
        // return this;
        // }

        var newEdge = new Edge(sourceId, new EdgeValue(targetId));

        int index = edges.elements.indexOf(newEdge);
        if (index >= 0) {
            var newTargets = new ArrayList<>(edges.elements.get(index).targets());
            newTargets.add(newEdge.target());
            edges.elements.set(index, new Edge(sourceId, newTargets));
        }
        else {
            edges.elements.add(newEdge);
        }
        return this;
    }

    /**
     * Adds conditional edges to the graph.
     * @param sourceId the identifier of the source node
     * @param condition the condition to determine the target node
     * @param mappings the mappings of conditions to target nodes
     * @throws GraphStateException if the edge identifier is invalid, the mappings are
     * empty, or the edge already exists
     */
    public StateGraph addConditionalEdges(String sourceId, AsyncEdgeAction condition, Map<String, String> mappings)
            throws GraphStateException {
        if (Objects.equals(sourceId, END)) {
            throw Errors.invalidEdgeIdentifier.exception(END);
        }
        if (mappings == null || mappings.isEmpty()) {
            throw Errors.edgeMappingIsEmpty.exception(sourceId);
        }

        var newEdge = new Edge(sourceId, new EdgeValue(new EdgeCondition(condition, mappings)));

        if (edges.elements.contains(newEdge)) {
            throw Errors.duplicateConditionalEdgeError.exception(sourceId);
        }
        else {
            edges.elements.add(newEdge);
        }
        return this;
    }

    void validateGraph() throws GraphStateException {
        var edgeStart = edges.edgeBySourceId(START).orElseThrow(Errors.missingEntryPoint::exception);

        edgeStart.validate(nodes);

        for (Edge edge : edges.elements) {
            edge.validate(nodes);
        }
    }

    /**
     * Compiles the state graph into a compiled graph.
     * @param config the compile configuration
     * @return a compiled graph
     * @throws GraphStateException if there are errors related to the graph state
     */
    public CompiledGraph compile(CompileConfig config) throws GraphStateException {
        Objects.requireNonNull(config, "config cannot be null");

        validateGraph();

        return new CompiledGraph(this, config);
    }

    /**
     * Compiles the state graph into a compiled graph.
     * @return a compiled graph
     * @throws GraphStateException if there are errors related to the graph state
     */
    public CompiledGraph compile() throws GraphStateException {
        SaverConfig saverConfig = SaverConfig.builder().register(SaverConstant.MEMORY, new MemorySaver()).build();
        return compile(CompileConfig.builder()
            .plainTextStateSerializer(new JacksonSerializer())
            .saverConfig(saverConfig)
            .build());
    }

    /**
     * Generates a drawable graph representation of the state graph.
     * @param type the type of graph representation to generate
     * @param title the title of the graph
     * @param printConditionalEdges whether to print conditional edges
     * @return a diagram code of the state graph
     */
    public GraphRepresentation getGraph(GraphRepresentation.Type type, String title, boolean printConditionalEdges) {

        String content = type.generator.generate(nodes, edges, title, printConditionalEdges);

        return new GraphRepresentation(type, content);
    }

    /**
     * Generates a drawable graph representation of the state graph.
     * @param type the type of graph representation to generate
     * @param title the title of the graph
     * @return a diagram code of the state graph
     */
    public GraphRepresentation getGraph(GraphRepresentation.Type type, String title) {

        String content = type.generator.generate(nodes, edges, title, true);

        return new GraphRepresentation(type, content);
    }

    public GraphRepresentation getGraph(GraphRepresentation.Type type) {

        String content = type.generator.generate(nodes, edges, name, true);

        return new GraphRepresentation(type, content);
    }

    //......
}
```
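StateGraph provides methods such as addNode, addEdge and addConditionalEdges. Its getGraph method delegates to the DiagramGenerator associated with the given GraphRepresentation.Type (type.generator) to produce the diagram of the state graph. The overload without printConditionalEdges defaults it to true, and getGraph(type) uses the graph's name as the title.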
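addEdge does not reject a second edge from the same source. It looks up the existing edge with indexOf and merges the new target into it, and that merged edge is what generate later draws with CallStyle.PARALLEL. The following is a simplified, hypothetical model of that behaviour, not the library's Edge/EdgeValue types. It assumes edges are matched by source id only, as the indexOf-then-merge logic above implies, and it treats an edge with more than one target as parallel. Its rendering uses PlantUMLGenerator's non-conditional arrow format, because PARALLEL falls into the default branch of call.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;

public class EdgeMergeSketch {

    // Hypothetical edge: one source, one or more targets
    public record Edge(String sourceId, List<String> targets) {
        public boolean isParallel() {
            return targets.size() > 1;
        }
    }

    private final List<Edge> edges = new ArrayList<>();

    // Assumption: edges are identified by sourceId, so a repeated source merges targets
    public EdgeMergeSketch addEdge(String sourceId, String targetId) {
        for (int i = 0; i < edges.size(); i++) {
            Edge existing = edges.get(i);
            if (existing.sourceId().equals(sourceId)) {
                List<String> merged = new ArrayList<>(existing.targets());
                merged.add(targetId);
                // replace in place, keeping the original position in the list
                edges.set(i, new Edge(sourceId, List.copyOf(merged)));
                return this;
            }
        }
        edges.add(new Edge(sourceId, List.of(targetId)));
        return this;
    }

    public Optional<Edge> edgeBySourceId(String sourceId) {
        return edges.stream().filter(e -> e.sourceId().equals(sourceId)).findFirst();
    }

    // Parallel targets are drawn as ordinary solid arrows, one per target
    public String render() {
        StringBuilder sb = new StringBuilder();
        for (Edge e : edges) {
            for (String target : e.targets()) {
                sb.append(String.format("\"%s\" -down-> \"%s\"\n", e.sourceId(), target));
            }
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        EdgeMergeSketch g = new EdgeMergeSketch()
                .addEdge("__START__", "agent_1")
                .addEdge("agent_1", "agent_2")
                .addEdge("agent_1", "agent_3");
        System.out.println(g.edgeBySourceId("agent_1").orElseThrow());
        System.out.print(g.render());
    }
}
```

In a PlantUML diagram, therefore, a fan-out from one node looks the same as several independent edges. The parallelism is only visible in the graph model, not in the arrow style.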
Example
```java
@Test
public void testGraph() throws GraphStateException {
    OverAllState overAllState = getOverAllState();
    StateGraph workflow = new StateGraph(overAllState).addNode("agent_1", node_async(state -> {
        System.out.println("agent_1");
        return Map.of("messages", "message1");
    })).addNode("agent_2", node_async(state -> {
        System.out.println("agent_2");
        return Map.of("messages", new String[] { "message2" });
    })).addNode("agent_3", node_async(state -> {
        System.out.println("agent_3");
        List<String> messages = Optional.ofNullable(state.value("messages").get())
            .filter(List.class::isInstance)
            .map(List.class::cast)
            .orElse(new ArrayList<>());
        int steps = messages.size() + 1;
        return Map.of("messages", "message3", "steps", steps);
    }))
        .addEdge("agent_1", "agent_2")
        .addEdge("agent_2", "agent_3")
        .addEdge(StateGraph.START, "agent_1")
        .addEdge("agent_3", StateGraph.END);

    GraphRepresentation representation = workflow.getGraph(GraphRepresentation.Type.PLANTUML, "demo");
    System.out.println(representation.content());
}
```
The output is as follows:
```plantuml
@startuml demo
skinparam usecaseFontSize 14
skinparam usecaseStereotypeFontSize 12
skinparam hexagonFontSize 14
skinparam hexagonStereotypeFontSize 12
title "demo"
footer

powered by spring-ai-alibaba
end footer
circle start<<input>> as __START__
circle stop as __END__
usecase "agent_1"<<Node>>
usecase "agent_2"<<Node>>
usecase "agent_3"<<Node>>
"__START__" -down-> "agent_1"
"agent_1" -down-> "agent_2"
"agent_2" -down-> "agent_3"
"agent_3" -down-> "__END__"
@enduml
```
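The START edge is printed first even though the test adds it third, because generate handles the edge leaving START before it iterates over the remaining edges. The following standalone sketch reproduces that output for a linear graph. It reuses PlantUMLGenerator's header, node, arrow and footer format strings in the same order generate calls them. It covers only the top-level, non-conditional case. LinearPlantUmlSketch and its Edge record are illustrative names, not library types.

```java
import java.util.List;

public class LinearPlantUmlSketch {

    static final String START = "__START__";

    static final String END = "__END__";

    // Hypothetical single-target edge used only in this sketch
    public record Edge(String from, String to) {}

    public static String render(String title, List<String> nodes, List<Edge> edges) {
        StringBuilder sb = new StringBuilder();
        // appendHeader for a top-level graph; titleToSnakeCase replaces non-alphanumerics with '_'
        sb.append(String.format("@startuml %s\n", title.replaceAll("[^a-zA-Z0-9]", "_")))
            .append("skinparam usecaseFontSize 14\n")
            .append("skinparam usecaseStereotypeFontSize 12\n")
            .append("skinparam hexagonFontSize 14\n")
            .append("skinparam hexagonStereotypeFontSize 12\n")
            .append(String.format("title \"%s\"\n", title))
            .append("footer\n\n")
            .append("powered by spring-ai-alibaba\n")
            .append("end footer\n")
            .append(String.format("circle start<<input>> as %s\n", START))
            .append(String.format("circle stop as %s\n", END));
        // declareNode for every node, in insertion order
        for (String n : nodes) {
            sb.append(String.format("usecase \"%s\"<<Node>>\n", n));
        }
        // generate() draws the edge leaving START first ...
        for (Edge e : edges) {
            if (e.from().equals(START)) {
                sb.append(arrow(e));
            }
        }
        // ... then every other edge in insertion order
        for (Edge e : edges) {
            if (!e.from().equals(START)) {
                sb.append(arrow(e));
            }
        }
        // appendFooter for a top-level graph
        sb.append("@enduml\n");
        return sb.toString();
    }

    // CallStyle.START and CallStyle.DEFAULT both use the solid-arrow branch of call()
    private static String arrow(Edge e) {
        return String.format("\"%s\" -down-> \"%s\"\n", e.from(), e.to());
    }

    public static void main(String[] args) {
        System.out.print(render("demo",
                List.of("agent_1", "agent_2", "agent_3"),
                List.of(new Edge("agent_1", "agent_2"),
                        new Edge("agent_2", "agent_3"),
                        new Edge(START, "agent_1"),
                        new Edge("agent_3", END))));
    }
}
```

The same snake_case rule explains the first line for other titles. For example, a title of "my demo" would yield @startuml my_demo, while the title line keeps the original text.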
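Summary
DiagramGenerator is an abstract class that serves as the base for diagram generators. It declares the abstract methods appendHeader, appendFooter, call, declareConditionalStart, declareNode, declareConditionalEdge and commentLine, and provides a generate method that builds a textual representation of the graph from nodes, edges and ctx. PlantUMLGenerator extends DiagramGenerator and implements these abstract methods according to PlantUML syntax. StateGraph.getGraph is the entry point that selects the generator through GraphRepresentation.Type.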