Skip to main content

miden_assembly/linker/
callgraph.rs

1use alloc::{
2    collections::{BTreeMap, BTreeSet, VecDeque},
3    vec::Vec,
4};
5
6use crate::GlobalItemIndex;
7
8/// Represents the inability to construct a topological ordering of the nodes in a [CallGraph]
9/// due to a cycle in the graph, which can happen due to recursion.
10#[derive(Debug)]
11pub struct CycleError(BTreeSet<GlobalItemIndex>);
12
13impl CycleError {
14    pub fn into_node_ids(self) -> impl ExactSizeIterator<Item = GlobalItemIndex> {
15        self.0.into_iter()
16    }
17}
18
19// CALL GRAPH
20// ================================================================================================
21
22/// A [CallGraph] is a directed, acyclic graph which represents all of the edges between procedures
23/// formed by a caller/callee relationship.
24///
25/// More precisely, this graph can be used to perform the following analyses:
26///
27/// - What is the maximum call stack depth for a program?
28/// - Are there any recursive procedure calls?
29/// - Are there procedures which are unreachable from the program entrypoint?, i.e. dead code
30/// - What is the set of procedures which are reachable from a given procedure, and which of those
31///   are (un)conditionally called?
32///
33/// A [CallGraph] is the actual graph underpinning the conceptual "module graph" of the linker, and
34/// the two are intrinsically linked to one another (i.e. a [CallGraph] is meaningless without
35/// the corresponding [super::Linker] state).
36#[derive(Default, Clone)]
37pub struct CallGraph {
38    /// The adjacency matrix for procedures in the call graph
39    nodes: BTreeMap<GlobalItemIndex, Vec<GlobalItemIndex>>,
40}
41
42impl CallGraph {
43    /// Gets the set of edges from the given caller to its callees in the graph.
44    pub fn out_edges(&self, gid: GlobalItemIndex) -> &[GlobalItemIndex] {
45        self.nodes.get(&gid).map(|out_edges| out_edges.as_slice()).unwrap_or(&[])
46    }
47
48    /// Inserts a node in the graph for `id`, if not already present.
49    ///
50    /// Returns the set of [GlobalItemIndex] which are the outbound neighbors of `id` in the
51    /// graph, i.e. the callees of a call-like instruction.
52    pub fn get_or_insert_node(&mut self, id: GlobalItemIndex) -> &mut Vec<GlobalItemIndex> {
53        self.nodes.entry(id).or_default()
54    }
55
56    /// Add an edge in the call graph from `caller` to `callee`.
57    ///
58    /// This operation is unchecked, i.e. it is possible to introduce cycles in the graph using it.
59    /// As a result, it is essential that the caller either know that adding the edge does _not_
60    /// introduce a cycle, or that [Self::toposort] is run once the graph is built, in order to
61    /// verify that the graph is valid and has no cycles.
62    ///
63    /// NOTE: This function will panic if you attempt to add an edge from a function to itself,
64    /// which trivially introduces a cycle. All other cycle-inducing edges must be caught by a
65    /// call to [Self::toposort].
66    pub fn add_edge(&mut self, caller: GlobalItemIndex, callee: GlobalItemIndex) {
67        assert_ne!(caller, callee, "a procedure cannot call itself");
68
69        // Make sure the callee is in the graph
70        self.get_or_insert_node(callee);
71        // Make sure the caller is in the graph
72        let callees = self.get_or_insert_node(caller);
73        // If the caller already references the callee, we're done
74        if callees.contains(&callee) {
75            return;
76        }
77
78        callees.push(callee);
79    }
80
81    /// Removes the edge between `caller` and `callee` from the graph
82    pub fn remove_edge(&mut self, caller: GlobalItemIndex, callee: GlobalItemIndex) {
83        if let Some(out_edges) = self.nodes.get_mut(&caller) {
84            out_edges.retain(|n| *n != callee);
85        }
86    }
87
88    /// Returns the number of predecessors of `id` in the graph, i.e.
89    /// the number of procedures which call `id`.
90    pub fn num_predecessors(&self, id: GlobalItemIndex) -> usize {
91        self.nodes.iter().filter(|(_, out_edges)| out_edges.contains(&id)).count()
92    }
93
94    /// Construct the topological ordering of all nodes in the call graph.
95    ///
96    /// Returns `Err` if a cycle is detected in the graph
97    pub fn toposort(&self) -> Result<Vec<GlobalItemIndex>, CycleError> {
98        if self.nodes.is_empty() {
99            return Ok(vec![]);
100        }
101
102        let mut output = Vec::with_capacity(self.nodes.len());
103        let mut graph = self.clone();
104
105        // Build the set of roots by finding all nodes
106        // that have no predecessors
107        let mut has_preds = BTreeSet::default();
108        for (_node, out_edges) in graph.nodes.iter() {
109            for succ in out_edges.iter() {
110                has_preds.insert(*succ);
111            }
112        }
113        let mut roots =
114            VecDeque::from_iter(graph.nodes.keys().copied().filter(|n| !has_preds.contains(n)));
115
116        // If all nodes have predecessors, there must be a cycle, so just pick a node and let the
117        // algorithm find the cycle for that node so we have a useful error. Set a flag so that we
118        // can assert that the cycle was actually found as a sanity check
119        let mut expect_cycle = false;
120        if roots.is_empty() {
121            expect_cycle = true;
122            roots.extend(graph.nodes.keys().next());
123        }
124
125        let mut successors = Vec::with_capacity(4);
126        while let Some(id) = roots.pop_front() {
127            output.push(id);
128            successors.clear();
129            successors.extend(graph.nodes[&id].iter().copied());
130            for mid in successors.drain(..) {
131                graph.remove_edge(id, mid);
132                if graph.num_predecessors(mid) == 0 {
133                    roots.push_back(mid);
134                }
135            }
136        }
137
138        let has_cycle = graph
139            .nodes
140            .iter()
141            .any(|(n, out_edges)| !output.contains(n) || !out_edges.is_empty());
142        if has_cycle {
143            let mut in_cycle = BTreeSet::default();
144            for (n, edges) in graph.nodes.iter() {
145                if edges.is_empty() {
146                    continue;
147                }
148                in_cycle.insert(*n);
149                in_cycle.extend(edges.as_slice());
150            }
151            Err(CycleError(in_cycle))
152        } else {
153            assert!(!expect_cycle, "we expected a cycle to be found, but one was not identified");
154            Ok(output)
155        }
156    }
157
158    /// Gets a new graph which is a subgraph of `self` containing all of the nodes reachable from
159    /// `root`, and nothing else.
160    pub fn subgraph(&self, root: GlobalItemIndex) -> Self {
161        let mut worklist = VecDeque::from_iter([root]);
162        let mut graph = Self::default();
163        let mut visited = BTreeSet::default();
164
165        while let Some(gid) = worklist.pop_front() {
166            if !visited.insert(gid) {
167                continue;
168            }
169
170            let new_successors = graph.get_or_insert_node(gid);
171            let prev_successors = self.out_edges(gid);
172            worklist.extend(prev_successors.iter().cloned());
173            new_successors.extend_from_slice(prev_successors);
174        }
175
176        graph
177    }
178
179    /// Computes the set of nodes in this graph which can reach `root`.
180    fn reverse_reachable(&self, root: GlobalItemIndex) -> BTreeSet<GlobalItemIndex> {
181        let mut worklist = VecDeque::from_iter([root]);
182        let mut visited = BTreeSet::default();
183
184        while let Some(gid) = worklist.pop_front() {
185            if !visited.insert(gid) {
186                continue;
187            }
188
189            worklist.extend(
190                self.nodes
191                    .iter()
192                    .filter(|(_, out_edges)| out_edges.contains(&gid))
193                    .map(|(pred, _)| *pred),
194            );
195        }
196
197        visited
198    }
199
200    /// Constructs the topological ordering of nodes in the call graph, for which `caller` is an
201    /// ancestor.
202    ///
203    /// # Errors
204    /// Returns an error if a cycle is detected in the graph.
205    pub fn toposort_caller(
206        &self,
207        caller: GlobalItemIndex,
208    ) -> Result<Vec<GlobalItemIndex>, CycleError> {
209        let mut output = Vec::with_capacity(self.nodes.len());
210
211        // Build a subgraph of `self` containing only those nodes reachable from `caller`
212        let caller_subgraph = self.subgraph(caller);
213        let mut graph = caller_subgraph.clone();
214
215        // Preserve the full set of nodes participating in cycles that close back into `caller`
216        // before we erase those back-edges to seed the traversal from `caller`.
217        let caller_cycle = caller_subgraph
218            .nodes
219            .values()
220            .any(|edges| edges.contains(&caller))
221            .then(|| caller_subgraph.reverse_reachable(caller));
222
223        // Remove all predecessor edges to `caller`
224        graph.nodes.iter_mut().for_each(|(_pred, out_edges)| {
225            out_edges.retain(|n| *n != caller);
226        });
227
228        let mut roots = VecDeque::from_iter([caller]);
229        let mut successors = Vec::with_capacity(4);
230        while let Some(id) = roots.pop_front() {
231            output.push(id);
232            successors.clear();
233            successors.extend(graph.nodes[&id].iter().copied());
234            for mid in successors.drain(..) {
235                graph.remove_edge(id, mid);
236                if graph.num_predecessors(mid) == 0 {
237                    roots.push_back(mid);
238                }
239            }
240        }
241
242        let has_cycle = output.len() != graph.nodes.len() || caller_cycle.is_some();
243        if has_cycle {
244            let mut in_cycle = BTreeSet::default();
245            for (n, edges) in graph.nodes.iter() {
246                if edges.is_empty() {
247                    continue;
248                }
249                in_cycle.insert(*n);
250                in_cycle.extend(edges.as_slice());
251            }
252            if let Some(caller_cycle) = caller_cycle {
253                in_cycle.extend(caller_cycle);
254            }
255            Err(CycleError(in_cycle))
256        } else {
257            Ok(output)
258        }
259    }
260}
261
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{GlobalItemIndex, ModuleIndex, ast::ItemIndex};

    // Two modules (`a` and `b`) with three procedures each are enough to build every graph
    // shape exercised below.
    const A: ModuleIndex = ModuleIndex::const_new(1);
    const B: ModuleIndex = ModuleIndex::const_new(2);
    const P1: ItemIndex = ItemIndex::const_new(1);
    const P2: ItemIndex = ItemIndex::const_new(2);
    const P3: ItemIndex = ItemIndex::const_new(3);
    const A1: GlobalItemIndex = GlobalItemIndex { module: A, index: P1 };
    const A2: GlobalItemIndex = GlobalItemIndex { module: A, index: P2 };
    const A3: GlobalItemIndex = GlobalItemIndex { module: A, index: P3 };
    const B1: GlobalItemIndex = GlobalItemIndex { module: B, index: P1 };
    const B2: GlobalItemIndex = GlobalItemIndex { module: B, index: P2 };
    const B3: GlobalItemIndex = GlobalItemIndex { module: B, index: P3 };

    #[test]
    fn callgraph_add_edge() {
        let graph = callgraph_simple();

        // Verify the graph structure
        assert_eq!(graph.num_predecessors(A1), 0);
        assert_eq!(graph.num_predecessors(B1), 0);
        assert_eq!(graph.num_predecessors(A2), 1);
        assert_eq!(graph.num_predecessors(B2), 2);
        assert_eq!(graph.num_predecessors(B3), 1);
        assert_eq!(graph.num_predecessors(A3), 2);

        assert_eq!(graph.out_edges(A1), &[A2]);
        assert_eq!(graph.out_edges(B1), &[B2]);
        assert_eq!(graph.out_edges(A2), &[B2, A3]);
        assert_eq!(graph.out_edges(B2), &[B3]);
        assert_eq!(graph.out_edges(A3), &[]);
        assert_eq!(graph.out_edges(B3), &[A3]);
    }

    #[test]
    fn callgraph_add_edge_with_cycle() {
        let graph = callgraph_cycle();

        // Verify the graph structure
        assert_eq!(graph.num_predecessors(A1), 0);
        assert_eq!(graph.num_predecessors(B1), 0);
        assert_eq!(graph.num_predecessors(A2), 2);
        assert_eq!(graph.num_predecessors(B2), 2);
        assert_eq!(graph.num_predecessors(B3), 1);
        assert_eq!(graph.num_predecessors(A3), 1);

        assert_eq!(graph.out_edges(A1), &[A2]);
        assert_eq!(graph.out_edges(B1), &[B2]);
        assert_eq!(graph.out_edges(A2), &[B2]);
        assert_eq!(graph.out_edges(B2), &[B3]);
        assert_eq!(graph.out_edges(A3), &[A2]);
        assert_eq!(graph.out_edges(B3), &[A3]);
    }

    #[test]
    fn callgraph_remove_edge() {
        let mut graph = callgraph_simple();
        graph.remove_edge(A2, A3);

        // Only the requested edge is gone; both endpoints remain in the graph
        assert_eq!(graph.out_edges(A2), &[B2]);
        assert_eq!(graph.num_predecessors(A3), 1);
        assert_eq!(graph.out_edges(B3), &[A3]);
    }

    #[test]
    fn callgraph_subgraph() {
        let graph = callgraph_simple();
        let subgraph = graph.subgraph(A2);

        assert_eq!(subgraph.nodes.keys().copied().collect::<Vec<_>>(), vec![A2, A3, B2, B3]);
    }

    #[test]
    fn callgraph_with_cycle_subgraph() {
        let graph = callgraph_cycle();
        let subgraph = graph.subgraph(A2);

        assert_eq!(subgraph.nodes.keys().copied().collect::<Vec<_>>(), vec![A2, A3, B2, B3]);
    }

    #[test]
    fn callgraph_toposort() {
        let graph = callgraph_simple();

        let sorted = graph.toposort().expect("expected valid topological ordering");
        assert_eq!(sorted.as_slice(), &[A1, B1, A2, B2, B3, A3]);
    }

    #[test]
    fn callgraph_toposort_empty() {
        let graph = CallGraph::default();

        let sorted = graph.toposort().expect("expected an empty graph to be trivially sorted");
        assert!(sorted.is_empty());
    }

    #[test]
    fn callgraph_toposort_caller() {
        let graph = callgraph_simple();

        let sorted = graph.toposort_caller(A2).expect("expected valid topological ordering");
        assert_eq!(sorted.as_slice(), &[A2, B2, B3, A3]);
    }

    #[test]
    fn callgraph_with_cycle_toposort() {
        let graph = callgraph_cycle();

        let err = graph.toposort().expect_err("expected topological sort to fail with cycle");
        assert_eq!(err.0.into_iter().collect::<Vec<_>>(), &[A2, A3, B2, B3]);
    }

    #[test]
    fn callgraph_toposort_caller_with_reachable_cycle() {
        let graph = callgraph_cycle();

        let err = graph
            .toposort_caller(A1)
            .expect_err("expected toposort_caller to fail when a reachable cycle exists");
        assert_eq!(err.0.into_iter().collect::<Vec<_>>(), &[A2, A3, B2, B3]);
    }

    #[test]
    fn callgraph_toposort_caller_root_closing_cycle() {
        let graph = callgraph_cycle();

        let err = graph
            .toposort_caller(A2)
            .expect_err("expected toposort_caller to detect cycle closing back into root");
        assert_eq!(err.0.into_iter().collect::<Vec<_>>(), &[A2, A3, B2, B3]);
    }

    /// Builds an acyclic call graph with the following shape:
    ///
    /// ```text
    /// a::a1 -> a::a2 -> a::a3
    ///            |        ^
    ///            v        |
    /// b::b1 -> b::b2 -> b::b3
    /// ```
    fn callgraph_simple() -> CallGraph {
        // Construct the graph
        let mut graph = CallGraph::default();
        graph.add_edge(A1, A2);
        graph.add_edge(B1, B2);
        graph.add_edge(A2, B2);
        graph.add_edge(A2, A3);
        graph.add_edge(B2, B3);
        graph.add_edge(B3, A3);

        graph
    }

    /// Builds a call graph containing the cycle `a2 -> b2 -> b3 -> a3 -> a2`:
    ///
    /// ```text
    /// a::a1 -> a::a2 <- a::a3
    ///            |        ^
    ///            v        |
    /// b::b1 -> b::b2 -> b::b3
    /// ```
    fn callgraph_cycle() -> CallGraph {
        // Construct the graph
        let mut graph = CallGraph::default();
        graph.add_edge(A1, A2);
        graph.add_edge(B1, B2);
        graph.add_edge(A2, B2);
        graph.add_edge(B2, B3);
        graph.add_edge(B3, A3);
        graph.add_edge(A3, A2);

        graph
    }
}