import asyncio
import base64
import re
from dataclasses import asdict
from typing import Literal, Optional
from langchain_core.runnables.graph import (
CurveStyle,
Edge,
MermaidDrawMethod,
Node,
NodeStyles,
)
# Characters checked at both ends of a node name by draw_mermaid: a name that
# starts AND ends with one of them is wrapped in <p>...</p> before being used as
# the node label.
# NOTE(review): presumably this keeps Mermaid from interpreting such names as
# Markdown emphasis/code markup -- confirm against the Mermaid renderer.
MARKDOWN_SPECIAL_CHARS = "*_`"
[docs]
def draw_mermaid(
    nodes: dict[str, Node],
    edges: list[Edge],
    *,
    first_node: Optional[str] = None,
    last_node: Optional[str] = None,
    with_styles: bool = True,
    curve_style: CurveStyle = CurveStyle.LINEAR,
    node_styles: Optional[NodeStyles] = None,
    wrap_label_n_words: int = 9,
) -> str:
    """Draws a Mermaid graph using the provided graph data.

    Node ids may contain ``:`` separators. Edges whose source and target share
    leading ``:``-separated segments are grouped under that shared prefix, and
    every non-empty prefix is rendered as a Mermaid ``subgraph`` named after
    the prefix's last segment.

    Args:
        nodes (dict[str, Node]): Mapping of node id to node.
        edges (list[Edge]): List of edges, object with a source,
            target and data.
        first_node (str, optional): Id of the first node. Defaults to None.
        last_node (str, optional): Id of the last node. Defaults to None.
        with_styles (bool, optional): Whether to include styles in the graph.
            When False, the init directive, the node declarations and the
            ``classDef`` lines are all omitted and only edges are emitted.
            Defaults to True.
        curve_style (CurveStyle, optional): Curve style for the edges.
            Defaults to CurveStyle.LINEAR.
        node_styles (NodeStyles, optional): Node colors for different types.
            Defaults to NodeStyles().
        wrap_label_n_words (int, optional): Words to wrap the edge labels.
            Defaults to 9.

    Returns:
        str: Mermaid graph syntax.

    Raises:
        ValueError: If two different prefixes would produce a subgraph with
            the same name.
    """
    # Output fragments are collected here and joined once at the end.
    parts: list[str] = []

    # Initialize Mermaid graph configuration
    if with_styles:
        parts.append(
            f"%%{{init: {{'flowchart': {{'curve': '{curve_style.value}'"
            f"}}}}}}%%\ngraph TD;\n"
        )
    else:
        parts.append("graph TD;\n")

    if with_styles:
        markdown_chars = tuple(MARKDOWN_SPECIAL_CHARS)

        # Add nodes to the graph
        for key, node in nodes.items():
            node_name = node.name.split(":")[-1]
            label = (
                f"<p>{node_name}</p>"
                if node_name.startswith(markdown_chars)
                and node_name.endswith(markdown_chars)
                else node_name
            )
            if node.metadata:
                label = (
                    f"{label}<hr/><small><em>"
                    + "\n".join(
                        f"{meta_key} = {value}"
                        for meta_key, value in node.metadata.items()
                    )
                    + "</em></small>"
                )

            # Pick the template by comparing ids directly, so a first/last node
            # that happens to be called "default" cannot change how every other
            # node is drawn. last_node is checked first: when the same id is
            # both first and last, the "last" styling wins.
            if key == last_node:
                template = "{0}([{1}]):::last"
            elif key == first_node:
                template = "{0}([{1}]):::first"
            else:
                template = "{0}({1})"
            node_label = template.format(_escape_node_label(key), label)
            parts.append(f"\t{node_label}\n")

    # Group edges by the leading ":"-separated segments that source and target
    # have in common. Comparison stops at the first differing segment, so the
    # group key is always a genuine prefix of both ids.
    edge_groups: dict[str, list[Edge]] = {}
    for edge in edges:
        shared_segments: list[str] = []
        for src_part, tgt_part in zip(
            edge.source.split(":"), edge.target.split(":")
        ):
            if src_part != tgt_part:
                break
            shared_segments.append(src_part)
        edge_groups.setdefault(":".join(shared_segments), []).append(edge)

    seen_subgraphs: set[str] = set()

    def add_subgraph(group_edges: list[Edge], prefix: str) -> None:
        # A group holding only one self-referencing edge is a node looping to
        # itself, not a real subgraph, so it gets no subgraph/end wrapper.
        self_loop = (
            len(group_edges) == 1
            and group_edges[0].source == group_edges[0].target
        )
        if prefix and not self_loop:
            subgraph = prefix.split(":")[-1]
            if subgraph in seen_subgraphs:
                msg = (
                    f"Found duplicate subgraph '{subgraph}' -- this likely means that "
                    "you're reusing a subgraph node with the same name. "
                    "Please adjust your graph to have subgraph nodes with unique names."
                )
                raise ValueError(msg)

            seen_subgraphs.add(subgraph)
            parts.append(f"\tsubgraph {subgraph}\n")

        for edge in group_edges:
            source, target = edge.source, edge.target

            # Add BR every wrap_label_n_words words
            if edge.data is not None:
                edge_data = edge.data
                words = str(edge_data).split()  # Split the string into words
                # Group words into chunks of wrap_label_n_words size
                if len(words) > wrap_label_n_words:
                    edge_data = "&nbsp<br>&nbsp".join(
                        " ".join(words[i : i + wrap_label_n_words])
                        for i in range(0, len(words), wrap_label_n_words)
                    )
                if edge.conditional:
                    edge_label = f" -. &nbsp;{edge_data}&nbsp; .-> "
                else:
                    edge_label = f" -- &nbsp;{edge_data}&nbsp; --> "
            else:
                edge_label = " -.-> " if edge.conditional else " --> "

            parts.append(
                f"\t{_escape_node_label(source)}{edge_label}"
                f"{_escape_node_label(target)};\n"
            )

        # Recursively add nested subgraphs. A nested prefix is rendered here
        # only when this prefix is its closest ancestor that has a group of its
        # own; deeper descendants are left to that closer ancestor's call.
        # Rendering every descendant from every ancestor would emit deeper
        # subgraphs twice and trip the duplicate-name check above.
        prefix_depth = len(prefix.split(":"))
        for nested_prefix, nested_edges in edge_groups.items():
            if not nested_prefix.startswith(prefix + ":"):
                continue
            segments = nested_prefix.split(":")
            has_closer_ancestor = any(
                ":".join(segments[:depth]) in edge_groups
                for depth in range(prefix_depth + 1, len(segments))
            )
            if has_closer_ancestor:
                continue
            add_subgraph(nested_edges, nested_prefix)

        if prefix and not self_loop:
            parts.append("\tend\n")

    # Start with the top-level edges (no common prefix)
    add_subgraph(edge_groups.get("", []), "")

    # Add remaining subgraphs
    # NOTE(review): only single-segment prefixes are started from here, so a
    # nested prefix whose top-level ancestor has no edge group of its own is
    # never rendered -- confirm whether that omission is intended.
    for prefix, grouped_edges in edge_groups.items():
        if ":" in prefix or prefix == "":
            continue
        add_subgraph(grouped_edges, prefix)

    # Add custom styles for nodes
    if with_styles:
        parts.append(_generate_mermaid_graph_styles(node_styles or NodeStyles()))
    return "".join(parts)
def _escape_node_label(node_label: str) -> str:
"""Escapes the node label for Mermaid syntax."""
return re.sub(r"[^a-zA-Z-_0-9]", "_", node_label)
def _generate_mermaid_graph_styles(node_colors: NodeStyles) -> str:
    """Build one tab-indented ``classDef`` line per field of ``node_colors``.

    Each dataclass field name becomes the Mermaid class name and the field
    value becomes its style definition.
    """
    class_defs = [
        f"\tclassDef {class_name} {style}\n"
        for class_name, style in asdict(node_colors).items()
    ]
    return "".join(class_defs)
[docs]
def draw_mermaid_png(
    mermaid_syntax: str,
    output_file_path: Optional[str] = None,
    draw_method: MermaidDrawMethod = MermaidDrawMethod.API,
    background_color: Optional[str] = "white",
    padding: int = 10,
) -> bytes:
    """Render Mermaid syntax to a PNG image and return its bytes.

    Args:
        mermaid_syntax (str): Mermaid graph syntax.
        output_file_path (str, optional): Path to save the PNG image.
            Defaults to None.
        draw_method (MermaidDrawMethod, optional): Method to draw the graph.
            Defaults to MermaidDrawMethod.API.
        background_color (str, optional): Background color of the image.
            Defaults to "white".
        padding (int, optional): Padding around the image; only passed on to
            the Pyppeteer renderer. Defaults to 10.

    Returns:
        bytes: PNG image bytes.

    Raises:
        ValueError: If an invalid draw method is provided.
    """
    if draw_method == MermaidDrawMethod.PYPPETEER:
        import asyncio

        # The Pyppeteer renderer is a coroutine; drive it to completion here.
        render_job = _render_mermaid_using_pyppeteer(
            mermaid_syntax, output_file_path, background_color, padding
        )
        return asyncio.run(render_job)

    if draw_method == MermaidDrawMethod.API:
        return _render_mermaid_using_api(
            mermaid_syntax, output_file_path, background_color
        )

    supported_methods = ", ".join(method.value for method in MermaidDrawMethod)
    msg = (
        f"Invalid draw method: {draw_method}. "
        f"Supported draw methods are: {supported_methods}"
    )
    raise ValueError(msg)
async def _render_mermaid_using_pyppeteer(
    mermaid_syntax: str,
    output_file_path: Optional[str] = None,
    background_color: Optional[str] = "white",
    padding: int = 10,
    device_scale_factor: int = 3,
) -> bytes:
    """Renders Mermaid graph using Pyppeteer.

    Launches a browser through Pyppeteer, loads Mermaid JS from the jsDelivr
    CDN, renders the graph to SVG in a blank page and screenshots that page.

    Args:
        mermaid_syntax (str): Mermaid graph syntax.
        output_file_path (str, optional): Path to save the screenshot to.
            Defaults to None, in which case nothing is written.
        background_color (str, optional): Value assigned to the page body's
            CSS ``background``. Defaults to "white".
        padding (int, optional): Amount added to both the SVG width and
            height when sizing the viewport. Defaults to 10.
        device_scale_factor (int, optional): ``deviceScaleFactor`` used for
            the viewport. Defaults to 3.

    Returns:
        bytes: Screenshot image bytes.

    Raises:
        ImportError: If Pyppeteer is not installed.
    """
    try:
        from pyppeteer import launch  # type: ignore[import]
    except ImportError as e:
        msg = "Install Pyppeteer to use the Pyppeteer method: `pip install pyppeteer`."
        raise ImportError(msg) from e

    browser = await launch()
    try:
        page = await browser.newPage()

        # Setup Mermaid JS
        await page.goto("about:blank")
        await page.addScriptTag(
            {"url": "https://cdn.jsdelivr.net/npm/mermaid/dist/mermaid.min.js"}
        )
        await page.evaluate(
            """() => {
                mermaid.initialize({startOnLoad:true});
            }"""
        )

        # Render SVG
        svg_code = await page.evaluate(
            """(mermaidGraph) => {
                return mermaid.mermaidAPI.render('mermaid', mermaidGraph);
            }""",
            mermaid_syntax,
        )

        # Put the SVG on the page and apply the requested background color
        await page.evaluate(
            """(svg, background_color) => {
                document.body.innerHTML = svg;
                document.body.style.background = background_color;
            }""",
            svg_code["svg"],
            background_color,
        )

        # Size the viewport to the rendered SVG plus padding, then screenshot
        dimensions = await page.evaluate(
            """() => {
                const svgElement = document.querySelector('svg');
                const rect = svgElement.getBoundingClientRect();
                return { width: rect.width, height: rect.height };
            }"""
        )
        await page.setViewport(
            {
                "width": int(dimensions["width"] + padding),
                "height": int(dimensions["height"] + padding),
                "deviceScaleFactor": device_scale_factor,
            }
        )

        img_bytes = await page.screenshot({"fullPage": False})
    finally:
        # Always shut the browser down, even when a rendering step raised,
        # so a failed render does not leave the browser process running.
        await browser.close()

    if output_file_path is not None:

        def write_to_file(path: str, data: bytes) -> None:
            with open(path, "wb") as file:
                file.write(data)

        # Do the blocking file write in the default executor so the event
        # loop is not blocked while the image is written.
        await asyncio.get_running_loop().run_in_executor(
            None, write_to_file, output_file_path, img_bytes
        )

    return img_bytes
def _render_mermaid_using_api(
    mermaid_syntax: str,
    output_file_path: Optional[str] = None,
    background_color: Optional[str] = "white",
    file_type: Optional[Literal["jpeg", "png", "webp"]] = "png",
) -> bytes:
    """Renders Mermaid graph using the Mermaid.INK API.

    Args:
        mermaid_syntax (str): Mermaid graph syntax.
        output_file_path (str, optional): Path to save the image to.
            Defaults to None, in which case nothing is written.
        background_color (str, optional): Either a ``#RGB``/``#RRGGBB`` hex
            code or a color name. None leaves the ``bgColor`` query parameter
            out of the request. Defaults to "white".
        file_type (str, optional): Image format to request. None leaves the
            ``type`` query parameter out of the request. Defaults to "png".

    Returns:
        bytes: Image bytes returned by the API.

    Raises:
        ImportError: If the ``requests`` module is not installed.
        ValueError: If the API responds with a status code other than 200.
    """
    try:
        import requests  # type: ignore[import]
    except ImportError as e:
        msg = (
            "Install the `requests` module to use the Mermaid.INK API: "
            "`pip install requests`."
        )
        raise ImportError(msg) from e

    # Use Mermaid API to render the image
    mermaid_syntax_encoded = base64.b64encode(mermaid_syntax.encode("utf8")).decode(
        "ascii"
    )

    # Only send the options that were actually provided; formatting None into
    # the URL would send the literal text "None" to the service.
    query_params: list[str] = []
    if file_type is not None:
        query_params.append(f"type={file_type}")
    if background_color is not None:
        if re.fullmatch(r"#(?:[0-9a-fA-F]{3}){1,2}", background_color):
            # A raw "#" starts the URL fragment, which is never sent to the
            # server, so the hex digits are sent without it.
            bg_value = background_color[1:]
        else:
            # Anything that is not a hex code is sent as a named color,
            # marked with a leading "!".
            bg_value = f"!{background_color}"
        query_params.append(f"bgColor={bg_value}")

    image_url = f"https://mermaid.ink/img/{mermaid_syntax_encoded}"
    if query_params:
        image_url += "?" + "&".join(query_params)

    response = requests.get(image_url, timeout=10)
    if response.status_code != 200:
        msg = (
            f"Failed to render the graph using the Mermaid.INK API. "
            f"Status code: {response.status_code}."
        )
        raise ValueError(msg)

    img_bytes = response.content
    if output_file_path is not None:
        with open(output_file_path, "wb") as file:
            file.write(img_bytes)

    return img_bytes