1use alloc::{
2 collections::{BTreeMap, BTreeSet, VecDeque},
3 vec::Vec,
4};
5
6use crate::GlobalItemIndex;
7
8#[derive(Debug)]
11pub struct CycleError(BTreeSet<GlobalItemIndex>);
12
13impl CycleError {
14 pub fn into_node_ids(self) -> impl ExactSizeIterator<Item = GlobalItemIndex> {
15 self.0.into_iter()
16 }
17}
18
19#[derive(Default, Clone)]
37pub struct CallGraph {
38 nodes: BTreeMap<GlobalItemIndex, Vec<GlobalItemIndex>>,
40}
41
42impl CallGraph {
43 pub fn out_edges(&self, gid: GlobalItemIndex) -> &[GlobalItemIndex] {
45 self.nodes.get(&gid).map(|out_edges| out_edges.as_slice()).unwrap_or(&[])
46 }
47
48 pub fn get_or_insert_node(&mut self, id: GlobalItemIndex) -> &mut Vec<GlobalItemIndex> {
53 self.nodes.entry(id).or_default()
54 }
55
56 pub fn add_edge(&mut self, caller: GlobalItemIndex, callee: GlobalItemIndex) {
67 assert_ne!(caller, callee, "a procedure cannot call itself");
68
69 self.get_or_insert_node(callee);
71 let callees = self.get_or_insert_node(caller);
73 if callees.contains(&callee) {
75 return;
76 }
77
78 callees.push(callee);
79 }
80
81 pub fn remove_edge(&mut self, caller: GlobalItemIndex, callee: GlobalItemIndex) {
83 if let Some(out_edges) = self.nodes.get_mut(&caller) {
84 out_edges.retain(|n| *n != callee);
85 }
86 }
87
88 pub fn num_predecessors(&self, id: GlobalItemIndex) -> usize {
91 self.nodes.iter().filter(|(_, out_edges)| out_edges.contains(&id)).count()
92 }
93
94 pub fn toposort(&self) -> Result<Vec<GlobalItemIndex>, CycleError> {
98 if self.nodes.is_empty() {
99 return Ok(vec![]);
100 }
101
102 let mut output = Vec::with_capacity(self.nodes.len());
103 let mut graph = self.clone();
104
105 let mut has_preds = BTreeSet::default();
108 for (_node, out_edges) in graph.nodes.iter() {
109 for succ in out_edges.iter() {
110 has_preds.insert(*succ);
111 }
112 }
113 let mut roots =
114 VecDeque::from_iter(graph.nodes.keys().copied().filter(|n| !has_preds.contains(n)));
115
116 let mut expect_cycle = false;
120 if roots.is_empty() {
121 expect_cycle = true;
122 roots.extend(graph.nodes.keys().next());
123 }
124
125 let mut successors = Vec::with_capacity(4);
126 while let Some(id) = roots.pop_front() {
127 output.push(id);
128 successors.clear();
129 successors.extend(graph.nodes[&id].iter().copied());
130 for mid in successors.drain(..) {
131 graph.remove_edge(id, mid);
132 if graph.num_predecessors(mid) == 0 {
133 roots.push_back(mid);
134 }
135 }
136 }
137
138 let has_cycle = graph
139 .nodes
140 .iter()
141 .any(|(n, out_edges)| !output.contains(n) || !out_edges.is_empty());
142 if has_cycle {
143 let mut in_cycle = BTreeSet::default();
144 for (n, edges) in graph.nodes.iter() {
145 if edges.is_empty() {
146 continue;
147 }
148 in_cycle.insert(*n);
149 in_cycle.extend(edges.as_slice());
150 }
151 Err(CycleError(in_cycle))
152 } else {
153 assert!(!expect_cycle, "we expected a cycle to be found, but one was not identified");
154 Ok(output)
155 }
156 }
157
158 pub fn subgraph(&self, root: GlobalItemIndex) -> Self {
161 let mut worklist = VecDeque::from_iter([root]);
162 let mut graph = Self::default();
163 let mut visited = BTreeSet::default();
164
165 while let Some(gid) = worklist.pop_front() {
166 if !visited.insert(gid) {
167 continue;
168 }
169
170 let new_successors = graph.get_or_insert_node(gid);
171 let prev_successors = self.out_edges(gid);
172 worklist.extend(prev_successors.iter().cloned());
173 new_successors.extend_from_slice(prev_successors);
174 }
175
176 graph
177 }
178
179 fn reverse_reachable(&self, root: GlobalItemIndex) -> BTreeSet<GlobalItemIndex> {
181 let mut worklist = VecDeque::from_iter([root]);
182 let mut visited = BTreeSet::default();
183
184 while let Some(gid) = worklist.pop_front() {
185 if !visited.insert(gid) {
186 continue;
187 }
188
189 worklist.extend(
190 self.nodes
191 .iter()
192 .filter(|(_, out_edges)| out_edges.contains(&gid))
193 .map(|(pred, _)| *pred),
194 );
195 }
196
197 visited
198 }
199
200 pub fn toposort_caller(
206 &self,
207 caller: GlobalItemIndex,
208 ) -> Result<Vec<GlobalItemIndex>, CycleError> {
209 let mut output = Vec::with_capacity(self.nodes.len());
210
211 let caller_subgraph = self.subgraph(caller);
213 let mut graph = caller_subgraph.clone();
214
215 let caller_cycle = caller_subgraph
218 .nodes
219 .values()
220 .any(|edges| edges.contains(&caller))
221 .then(|| caller_subgraph.reverse_reachable(caller));
222
223 graph.nodes.iter_mut().for_each(|(_pred, out_edges)| {
225 out_edges.retain(|n| *n != caller);
226 });
227
228 let mut roots = VecDeque::from_iter([caller]);
229 let mut successors = Vec::with_capacity(4);
230 while let Some(id) = roots.pop_front() {
231 output.push(id);
232 successors.clear();
233 successors.extend(graph.nodes[&id].iter().copied());
234 for mid in successors.drain(..) {
235 graph.remove_edge(id, mid);
236 if graph.num_predecessors(mid) == 0 {
237 roots.push_back(mid);
238 }
239 }
240 }
241
242 let has_cycle = output.len() != graph.nodes.len() || caller_cycle.is_some();
243 if has_cycle {
244 let mut in_cycle = BTreeSet::default();
245 for (n, edges) in graph.nodes.iter() {
246 if edges.is_empty() {
247 continue;
248 }
249 in_cycle.insert(*n);
250 in_cycle.extend(edges.as_slice());
251 }
252 if let Some(caller_cycle) = caller_cycle {
253 in_cycle.extend(caller_cycle);
254 }
255 Err(CycleError(in_cycle))
256 } else {
257 Ok(output)
258 }
259 }
260}
261
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{GlobalItemIndex, ModuleIndex, ast::ItemIndex};

    // Fixture identifiers: procedures 1-3 in module A and in module B.
    const A: ModuleIndex = ModuleIndex::const_new(1);
    const B: ModuleIndex = ModuleIndex::const_new(2);
    const P1: ItemIndex = ItemIndex::const_new(1);
    const P2: ItemIndex = ItemIndex::const_new(2);
    const P3: ItemIndex = ItemIndex::const_new(3);
    const A1: GlobalItemIndex = GlobalItemIndex { module: A, index: P1 };
    const A2: GlobalItemIndex = GlobalItemIndex { module: A, index: P2 };
    const A3: GlobalItemIndex = GlobalItemIndex { module: A, index: P3 };
    const B1: GlobalItemIndex = GlobalItemIndex { module: B, index: P1 };
    const B2: GlobalItemIndex = GlobalItemIndex { module: B, index: P2 };
    const B3: GlobalItemIndex = GlobalItemIndex { module: B, index: P3 };

    /// `add_edge` registers both endpoints, counts predecessors per node, and keeps out-edges
    /// in insertion order.
    #[test]
    fn callgraph_add_edge() {
        let graph = callgraph_simple();

        assert_eq!(graph.num_predecessors(A1), 0);
        assert_eq!(graph.num_predecessors(B1), 0);
        assert_eq!(graph.num_predecessors(A2), 1);
        assert_eq!(graph.num_predecessors(B2), 2);
        assert_eq!(graph.num_predecessors(B3), 1);
        assert_eq!(graph.num_predecessors(A3), 2);

        assert_eq!(graph.out_edges(A1), &[A2]);
        assert_eq!(graph.out_edges(B1), &[B2]);
        assert_eq!(graph.out_edges(A2), &[B2, A3]);
        assert_eq!(graph.out_edges(B2), &[B3]);
        assert_eq!(graph.out_edges(A3), &[]);
        assert_eq!(graph.out_edges(B3), &[A3]);
    }

    /// Same checks on the cyclic fixture: building a cycle with `add_edge` is allowed.
    #[test]
    fn callgraph_add_edge_with_cycle() {
        let graph = callgraph_cycle();

        assert_eq!(graph.num_predecessors(A1), 0);
        assert_eq!(graph.num_predecessors(B1), 0);
        assert_eq!(graph.num_predecessors(A2), 2);
        assert_eq!(graph.num_predecessors(B2), 2);
        assert_eq!(graph.num_predecessors(B3), 1);
        assert_eq!(graph.num_predecessors(A3), 1);

        assert_eq!(graph.out_edges(A1), &[A2]);
        assert_eq!(graph.out_edges(B1), &[B2]);
        assert_eq!(graph.out_edges(A2), &[B2]);
        assert_eq!(graph.out_edges(B2), &[B3]);
        assert_eq!(graph.out_edges(A3), &[A2]);
        assert_eq!(graph.out_edges(B3), &[A3]);
    }

    /// The subgraph rooted at A2 contains exactly the nodes reachable from A2 (listed in key
    /// order).
    #[test]
    fn callgraph_subgraph() {
        let graph = callgraph_simple();
        let subgraph = graph.subgraph(A2);

        assert_eq!(subgraph.nodes.keys().copied().collect::<Vec<_>>(), vec![A2, A3, B2, B3]);
    }

    /// Extracting a subgraph terminates even when the reachable nodes form a cycle.
    #[test]
    fn callgraph_with_cycle_subgraph() {
        let graph = callgraph_cycle();
        let subgraph = graph.subgraph(A2);

        assert_eq!(subgraph.nodes.keys().copied().collect::<Vec<_>>(), vec![A2, A3, B2, B3]);
    }

    /// The roots A1 and B1 come first. A3 comes last because it has two callers (A2 and B3) and
    /// is only released once both have been emitted.
    #[test]
    fn callgraph_toposort() {
        let graph = callgraph_simple();

        let sorted = graph.toposort().expect("expected valid topological ordering");
        assert_eq!(sorted.as_slice(), &[A1, B1, A2, B2, B3, A3]);
    }

    /// Sorting from A2 only orders A2 and the nodes reachable from it; A1 and B1 are excluded.
    #[test]
    fn callgraph_toposort_caller() {
        let graph = callgraph_simple();

        let sorted = graph.toposort_caller(A2).expect("expected valid topological ordering");
        assert_eq!(sorted.as_slice(), &[A2, B2, B3, A3]);
    }

    /// The cycle A2 -> B2 -> B3 -> A3 -> A2 cannot be ordered, so its members are reported.
    #[test]
    fn callgraph_with_cycle_toposort() {
        let graph = callgraph_cycle();

        let err = graph.toposort().expect_err("expected topological sort to fail with cycle");
        assert_eq!(err.0.into_iter().collect::<Vec<_>>(), &[A2, A3, B2, B3]);
    }

    /// A1 is not on the cycle itself, but the cycle is reachable from it.
    #[test]
    fn callgraph_toposort_caller_with_reachable_cycle() {
        let graph = callgraph_cycle();

        let err = graph
            .toposort_caller(A1)
            .expect_err("expected toposort_caller to fail when a reachable cycle exists");
        assert_eq!(err.0.into_iter().collect::<Vec<_>>(), &[A2, A3, B2, B3]);
    }

    /// A2 is itself on the cycle (A3 calls back into it), and that must be reported as an error.
    #[test]
    fn callgraph_toposort_caller_root_closing_cycle() {
        let graph = callgraph_cycle();

        let err = graph
            .toposort_caller(A2)
            .expect_err("expected toposort_caller to detect cycle closing back into root");
        assert_eq!(err.0.into_iter().collect::<Vec<_>>(), &[A2, A3, B2, B3]);
    }

    /// Builds an acyclic fixture:
    /// A1 -> A2, B1 -> B2, A2 -> B2, A2 -> A3, B2 -> B3, B3 -> A3.
    fn callgraph_simple() -> CallGraph {
        let mut graph = CallGraph::default();
        graph.add_edge(A1, A2);
        graph.add_edge(B1, B2);
        graph.add_edge(A2, B2);
        graph.add_edge(A2, A3);
        graph.add_edge(B2, B3);
        graph.add_edge(B3, A3);

        graph
    }

    /// Like `callgraph_simple`, but `A2 -> A3` is replaced by `A3 -> A2`. This closes the
    /// cycle A2 -> B2 -> B3 -> A3 -> A2, which is reachable from both A1 and B1.
    fn callgraph_cycle() -> CallGraph {
        let mut graph = CallGraph::default();
        graph.add_edge(A1, A2);
        graph.add_edge(B1, B2);
        graph.add_edge(A2, B2);
        graph.add_edge(B2, B3);
        graph.add_edge(B3, A3);
        graph.add_edge(A3, A2);

        graph
    }
}