diff --git a/src/mk_graph/mod.rs b/src/mk_graph/mod.rs index 480ee518..e93c4d00 100644 --- a/src/mk_graph/mod.rs +++ b/src/mk_graph/mod.rs @@ -14,6 +14,7 @@ use crate::printer::collect_smir; pub mod context; pub mod index; pub mod output; +pub mod traverse; pub mod util; // Re-exports for convenience @@ -45,7 +46,9 @@ pub fn emit_dotfile(tcx: TyCtxt<'_>) { /// Entry point to write the D2 file pub fn emit_d2file(tcx: TyCtxt<'_>) { - let smir_d2 = collect_smir(tcx).to_d2_file(); + let smir = collect_smir(tcx); + + let smir_d2 = smir.to_d2_file(); match mir_output_path(tcx, "smir.d2") { OutputDest::Stdout => { diff --git a/src/mk_graph/output/d2.rs b/src/mk_graph/output/d2.rs index c3ccc491..843a1223 100644 --- a/src/mk_graph/output/d2.rs +++ b/src/mk_graph/output/d2.rs @@ -1,151 +1,133 @@ //! D2 diagram format output for MIR graphs. -use crate::compat::stable_mir; -use stable_mir::mir::TerminatorKind; - use crate::printer::SmirJson; -use crate::MonoItemKind; -use crate::mk_graph::context::GraphContext; -use crate::mk_graph::util::{ - escape_d2, is_unqualified, name_lines, short_name, terminator_targets, -}; +use crate::mk_graph::util::escape_d2; -impl SmirJson { - /// Convert the MIR to D2 diagram format - pub fn to_d2_file(self) -> String { - let ctx = GraphContext::from_smir(&self); - let mut output = String::new(); - - output.push_str("direction: right\n\n"); - render_d2_allocs_legend(&ctx, &mut output); - - for item in self.items { - match item.mono_item_kind { - MonoItemKind::MonoItemFn { name, body, .. } => { - render_d2_function(&name, body.as_ref(), &ctx, &mut output); - } - MonoItemKind::MonoItemGlobalAsm { asm } => { - render_d2_asm(&asm, &mut output); - } - MonoItemKind::MonoItemStatic { name, .. } => { - render_d2_static(&name, &mut output); - } - } - } - - output - } -} +use crate::mk_graph::traverse::render_graph; +use crate::mk_graph::traverse::{GraphBuilder, RenderedFunction}; // ============================================================================= -// D2 Rendering Helpers +// D2 Builder // ============================================================================= -fn render_d2_allocs_legend(ctx: &GraphContext, out: &mut String) { - let legend_lines = ctx.allocs_legend_lines(); - - out.push_str("ALLOCS: {\n"); - out.push_str(" style.fill: \"#ffffcc\"\n"); - out.push_str(" style.stroke: \"#999999\"\n"); - let legend_text = legend_lines - .iter() - .map(|s| escape_d2(s)) - .collect::>() - .join("\\n"); - out.push_str(&format!(" label: \"{}\"\n", legend_text)); - out.push_str("}\n\n"); +pub struct D2Builder { + buf: String, } -fn render_d2_function( - name: &str, - body: Option<&stable_mir::mir::Body>, - ctx: &GraphContext, - out: &mut String, -) { - let fn_id = short_name(name); - let display_name = escape_d2(&name_lines(name)); - - // Function container - out.push_str(&format!("{}: {{\n", fn_id)); - out.push_str(&format!(" label: \"{}\"\n", display_name)); - out.push_str(" style.fill: \"#e0e0ff\"\n"); - - if let Some(body) = body { - render_d2_blocks(body, ctx, out); - render_d2_block_edges(body, out); +impl D2Builder { + pub fn new() -> Self { + Self { buf: String::new() } } +} - out.push_str("}\n\n"); - - // Call edges (must be outside the container) - if let Some(body) = body { - render_d2_call_edges(&fn_id, body, ctx, out); +impl Default for D2Builder { + fn default() -> Self { + Self::new() } } -fn render_d2_blocks(body: &stable_mir::mir::Body, ctx: &GraphContext, out: &mut String) { - for (idx, block) in body.blocks.iter().enumerate() { - let stmts: Vec = block - .statements +impl GraphBuilder for D2Builder { + type Output = String; + + fn begin_graph(&mut self, _name: &str) { + self.buf.push_str("direction: right\n\n"); + } + + fn alloc_legend(&mut self, lines: &[String]) { + self.buf.push_str("ALLOCS: {\n"); + self.buf.push_str(" style.fill: \"#ffffcc\"\n"); + self.buf.push_str(" style.stroke: \"#999999\"\n"); + + let text = lines .iter() - .map(|s| escape_d2(&ctx.render_stmt(s))) - .collect(); - let term_str = escape_d2(&ctx.render_terminator(&block.terminator)); + .map(|l| escape_d2(l)) + .collect::>() + .join("\\n"); - let mut label = format!("bb{}:", idx); - for stmt in &stmts { - label.push_str(&format!("\\n{}", stmt)); - } - label.push_str(&format!("\\n---\\n{}", term_str)); + self.buf.push_str(&format!(" label: \"{}\"\n", text)); + self.buf.push_str("}\n\n"); + } + + fn type_legend(&mut self, _lines: &[String]) {} - out.push_str(&format!(" bb{}: \"{}\"\n", idx, label)); + fn external_function(&mut self, id: &str, name: &str) { + self.buf + .push_str(&format!("{}: \"{}\"\n", id, escape_d2(name))); } -} -fn render_d2_block_edges(body: &stable_mir::mir::Body, out: &mut String) { - for (idx, block) in body.blocks.iter().enumerate() { - for target in terminator_targets(&block.terminator) { - out.push_str(&format!(" bb{} -> bb{}\n", idx, target)); + fn render_function(&mut self, func: &RenderedFunction) { + self.buf.push_str(&format!("{}: {{\n", func.id)); + self.buf + .push_str(&format!(" label: \"{}\"\n", escape_d2(&func.display_name))); + self.buf.push_str(" style.fill: \"#e0e0ff\"\n"); + + for block in &func.blocks { + let mut label = format!("bb{}:", block.idx); + + for stmt in &block.stmts { + label.push_str(&format!("\\n{}", escape_d2(stmt))); + } + + label.push_str(&format!("\\n---\\n{}", escape_d2(&block.terminator))); + + self.buf + .push_str(&format!(" bb{}: \"{}\"\n", block.idx, label)); } - } -} -fn render_d2_call_edges( - fn_id: &str, - body: &stable_mir::mir::Body, - ctx: &GraphContext, - out: &mut String, -) { - for (idx, block) in body.blocks.iter().enumerate() { - let TerminatorKind::Call { func, .. } = &block.terminator.kind else { - continue; - }; - let Some(callee_name) = ctx.resolve_call_target(func) else { - continue; - }; - if !is_unqualified(&callee_name) { - continue; + for block in &func.blocks { + for (target, _) in &block.cfg_edges { + self.buf + .push_str(&format!(" bb{} -> bb{}\n", block.idx, target)); + } } - let target_id = short_name(&callee_name); - out.push_str(&format!("{}: \"{}\"\n", target_id, escape_d2(&callee_name))); - out.push_str(&format!("{}.style.fill: \"#ffe0e0\"\n", target_id)); - out.push_str(&format!("{}.bb{} -> {}: call\n", fn_id, idx, target_id)); + self.buf.push_str("}\n\n"); + + for edge in &func.call_edges { + self.buf.push_str(&format!( + "{}: \"{}\"\n", + edge.callee_id, + escape_d2(&edge.callee_name) + )); + + self.buf + .push_str(&format!("{}.style.fill: \"#ffe0e0\"\n", edge.callee_id)); + + self.buf.push_str(&format!( + "{}.bb{} -> {}: call\n", + func.id, edge.block_idx, edge.callee_id + )); + } } -} -fn render_d2_asm(asm: &str, out: &mut String) { - let asm_id = short_name(asm); - let asm_text = escape_d2(&asm.lines().collect::()); - out.push_str(&format!("{}: \"{}\" {{\n", asm_id, asm_text)); - out.push_str(" style.fill: \"#ffe0ff\"\n"); - out.push_str("}\n\n"); + fn static_item(&mut self, id: &str, name: &str) { + self.buf + .push_str(&format!("{}: \"{}\" {{\n", id, escape_d2(name))); + self.buf.push_str(" style.fill: \"#e0ffe0\"\n"); + self.buf.push_str("}\n\n"); + } + + fn asm_item(&mut self, id: &str, content: &str) { + let text = escape_d2(&content.lines().collect::()); + + self.buf.push_str(&format!("{}: \"{}\" {{\n", id, text)); + self.buf.push_str(" style.fill: \"#ffe0ff\"\n"); + self.buf.push_str("}\n\n"); + } + + fn finish(self) -> String { + self.buf + } } -fn render_d2_static(name: &str, out: &mut String) { - let static_id = short_name(name); - out.push_str(&format!("{}: \"{}\" {{\n", static_id, escape_d2(name))); - out.push_str(" style.fill: \"#e0ffe0\"\n"); - out.push_str("}\n\n"); +// ============================================================================= +// Public entry point +// ============================================================================= + +impl SmirJson { + /// Convert the MIR to D2 using GraphBuilder traversal + pub fn to_d2_file(&self) -> String { + render_graph(self, D2Builder::new()) + } } diff --git a/src/mk_graph/traverse.rs b/src/mk_graph/traverse.rs new file mode 100644 index 00000000..1d9bd9b1 --- /dev/null +++ b/src/mk_graph/traverse.rs @@ -0,0 +1,194 @@ +//! Generic MIR graph traversal. +//! +//! This module owns the traversal order and graph semantics. +extern crate stable_mir; +use stable_mir::mir::{Body, Statement, Terminator, TerminatorKind}; + +use crate::printer::SmirJson; +use crate::MonoItemKind; + +use crate::mk_graph::context::GraphContext; +use crate::mk_graph::util::{ + hash_body, is_unqualified, name_lines, short_name, terminator_targets, +}; + +/// Represents a call from a block to another function. +/// +/// The callee is resolved during traversal and arguments are already +/// rendered as a string. Builders may choose how to visualize this edge. +pub struct CallEdge { + pub block_idx: usize, + pub callee_id: String, + pub callee_name: String, + pub rendered_args: String, +} + +/// A basic block with pre-rendered textual content and structural edges. +/// +/// `stmts` and `terminator` are pre-rendered strings produced using +/// `GraphContext`. Builders are free to format or escape them according +/// to their output format. +/// +/// `raw_stmts` and `raw_terminator` are escape hatches for renderers +/// that need to inspect the underlying MIR structure. +pub struct RenderedBlock<'a> { + pub idx: usize, + pub stmts: Vec, + pub raw_stmts: &'a [Statement], + pub terminator: String, + pub raw_terminator: &'a Terminator, + pub cfg_edges: Vec<(usize, Option)>, +} + +/// A fully analyzed MIR function ready for rendering. +/// +/// The traversal layer resolves call targets, renders statements and +/// terminators, and computes the control-flow edges. Builders receive +/// this structure and are responsible only for formatting it into a +/// specific graph representation. +pub struct RenderedFunction<'a> { + pub id: String, + pub display_name: String, + pub locals: Vec<(usize, String)>, + pub blocks: Vec>, + pub call_edges: Vec, +} + +/// Trait implemented by graph renderers. +/// +/// The traversal layer walks the MIR graph and constructs a +/// `RenderedFunction` representation. Implementations of this trait +/// consume those structures and emit format-specific output such as +/// D2, DOT, or other diagram formats. +/// +/// The trait intentionally separates graph structure from formatting. +/// Traversal decides *what* the graph contains while the builder +/// decides *how* it is rendered. +pub trait GraphBuilder { + type Output; + + fn begin_graph(&mut self, name: &str); + + fn alloc_legend(&mut self, lines: &[String]); + + fn type_legend(&mut self, lines: &[String]); + + fn external_function(&mut self, id: &str, name: &str); + + fn render_function(&mut self, func: &RenderedFunction); + + fn static_item(&mut self, id: &str, name: &str); + + fn asm_item(&mut self, id: &str, content: &str); + + fn finish(self) -> Self::Output; +} + +/// Traverse the SMIR representation and produce rendered graph data. +/// +/// This function performs MIR traversal, resolves call targets, and +/// constructs `RenderedFunction` structures which are then passed to +/// the provided `GraphBuilder`. +pub fn render_graph(smir: &SmirJson, mut builder: B) -> B::Output { + let ctx = GraphContext::from_smir(smir); + + builder.begin_graph(&smir.name); + + builder.alloc_legend(&ctx.allocs_legend_lines()); + builder.type_legend(&ctx.types_legend_lines()); + + for item in &smir.items { + match &item.mono_item_kind { + MonoItemKind::MonoItemFn { name, body, .. } => { + let func = render_function(&ctx, name, body.as_ref()); + builder.render_function(&func); + } + MonoItemKind::MonoItemStatic { name, .. } => { + let id = short_name(name); + builder.static_item(&id, name); + } + MonoItemKind::MonoItemGlobalAsm { asm } => { + let id = short_name(asm); + builder.asm_item(&id, asm); + } + } + } + + builder.finish() +} + +/// Emit graph events for a single function body. +/// Traverses blocks, CFG edges, and call edges without renderer-specific logic. +fn render_function<'a>( + ctx: &GraphContext, + name: &str, + body: Option<&'a Body>, +) -> RenderedFunction<'a> { + let id = match body { + Some(b) => format!("fn_{}_{}", short_name(name), hash_body(b)), + None => format!("fn_{}_no_body", short_name(name)), + }; + + let display_name = name_lines(name); + + let mut blocks = Vec::new(); + let mut call_edges = Vec::new(); + let mut locals = Vec::new(); + + if let Some(body) = body { + for (idx, decl) in body.local_decls() { + locals.push((idx, ctx.render_type_with_layout(decl.ty))); + } + + for (idx, block) in body.blocks.iter().enumerate() { + let stmts = block + .statements + .iter() + .map(|s| ctx.render_stmt(s)) + .collect(); + + let terminator = ctx.render_terminator(&block.terminator); + + let cfg_edges = terminator_targets(&block.terminator) + .into_iter() + .map(|t| (t, None)) + .collect(); + + blocks.push(RenderedBlock { + idx, + stmts, + raw_stmts: &block.statements, + terminator, + raw_terminator: &block.terminator, + cfg_edges, + }); + + if let TerminatorKind::Call { func, args, .. } = &block.terminator.kind { + if let Some(callee) = ctx.resolve_call_target(func) { + if is_unqualified(&callee) { + let rendered_args = args + .iter() + .map(|a| ctx.render_operand(a)) + .collect::>() + .join(", "); + + call_edges.push(CallEdge { + block_idx: idx, + callee_id: short_name(&callee), + callee_name: callee, + rendered_args, + }); + } + } + } + } + } + + RenderedFunction { + id, + display_name, + locals, + blocks, + call_edges, + } +} diff --git a/src/mk_graph/util.rs b/src/mk_graph/util.rs index ac21cbb1..d7708c74 100644 --- a/src/mk_graph/util.rs +++ b/src/mk_graph/util.rs @@ -4,8 +4,8 @@ use std::hash::{DefaultHasher, Hash, Hasher}; use crate::compat::stable_mir; use stable_mir::mir::{ - AggregateKind, BorrowKind, ConstOperand, Mutability, NonDivergingIntrinsic, NullOp, Operand, - Place, ProjectionElem, Rvalue, Terminator, TerminatorKind, UnwindAction, + AggregateKind, Body, BorrowKind, ConstOperand, Mutability, NonDivergingIntrinsic, NullOp, + Operand, Place, ProjectionElem, Rvalue, Terminator, TerminatorKind, UnwindAction, }; use stable_mir::ty::{IndexedVal, RigidTy}; @@ -280,3 +280,29 @@ pub fn terminator_targets(term: &Terminator) -> Vec { } } } + +/// Generate a consistent short hash for a MIR body. +/// Used to avoid fn_id collisions between monomorphizations. +pub fn hash_body(body: &Body) -> u64 { + let mut h = DefaultHasher::new(); + + // Hash number of blocks + body.blocks.len().hash(&mut h); + + for (idx, block) in body.blocks.iter().enumerate() { + idx.hash(&mut h); + + // Hash terminator kind + std::mem::discriminant(&block.terminator.kind).hash(&mut h); + + // Hash control-flow edges + for target in terminator_targets(&block.terminator) { + target.hash(&mut h); + } + + // Statement count for entropy + block.statements.len().hash(&mut h); + } + + h.finish() +}