1use alloc::{collections::BTreeMap, format, string::String, vec::Vec};
26
27#[cfg(feature = "serde")]
28use serde::{Deserialize, Serialize};
29
30use crate::{
31 mast::{AsmOpId, MastNodeId},
32 serde::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable},
33 utils::{CsrMatrix, CsrValidationError},
34};
35
/// Maps operations within MAST nodes to the assembly operations ([`AsmOpId`]s) they
/// originated from.
///
/// Stored as a CSR matrix with one row per [`MastNodeId`]; each row holds
/// `(operation_index, AsmOpId)` entries sorted by strictly increasing operation index
/// (enforced by `add_asm_ops_for_node`).
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct OpToAsmOpId {
    // Row i holds the sorted (op_idx, AsmOpId) entries for node i.
    inner: CsrMatrix<MastNodeId, (usize, AsmOpId)>,
}
53
impl Default for OpToAsmOpId {
    /// Returns an empty mapping, equivalent to [`OpToAsmOpId::new`].
    fn default() -> Self {
        Self::new()
    }
}
59
60impl OpToAsmOpId {
61 pub fn new() -> Self {
63 Self { inner: CsrMatrix::new() }
64 }
65
66 pub fn with_capacity(nodes_capacity: usize, operations_capacity: usize) -> Self {
68 Self {
69 inner: CsrMatrix::with_capacity(nodes_capacity, operations_capacity),
70 }
71 }
72
73 pub fn is_empty(&self) -> bool {
75 self.inner.is_empty()
76 }
77
78 pub fn num_nodes(&self) -> usize {
80 self.inner.num_rows()
81 }
82
83 pub fn num_operations(&self) -> usize {
88 self.inner.num_elements()
89 }
90
91 pub fn add_asm_ops_for_node(
112 &mut self,
113 node_id: MastNodeId,
114 num_operations: usize,
115 asm_ops: Vec<(usize, AsmOpId)>,
116 ) -> Result<(), AsmOpIndexError> {
117 let expected_node = self.num_nodes() as u32;
118 let node_idx = u32::from(node_id);
119
120 if node_idx < expected_node {
122 return Err(AsmOpIndexError::NodeIndex(node_id));
123 }
124
125 for _ in expected_node..node_idx {
127 self.inner.push_empty_row().map_err(|_| AsmOpIndexError::InternalStructure)?;
128 }
129
130 for window in asm_ops.windows(2) {
132 if window[0].0 >= window[1].0 {
133 return Err(AsmOpIndexError::NonIncreasingOpIndices);
134 }
135 }
136
137 if let Some((max_idx, _)) = asm_ops.last()
139 && *max_idx >= num_operations
140 {
141 return Err(AsmOpIndexError::OpIndexOutOfBounds(*max_idx, num_operations));
142 }
143
144 self.inner.push_row(asm_ops).map_err(|_| AsmOpIndexError::InternalStructure)?;
145
146 Ok(())
147 }
148
149 pub fn asm_op_id_for_operation(&self, node_id: MastNodeId, op_idx: usize) -> Option<AsmOpId> {
166 let entries = self.inner.row(node_id)?;
167
168 match entries.binary_search_by_key(&op_idx, |(idx, _)| *idx) {
171 Ok(i) => Some(entries[i].1),
172 Err(i) if i > 0 => Some(entries[i - 1].1),
173 Err(_) => None,
174 }
175 }
176
177 pub fn first_asm_op_for_node(&self, node_id: MastNodeId) -> Option<AsmOpId> {
182 let entries = self.inner.row(node_id)?;
183 entries.first().map(|(_, id)| *id)
184 }
185
186 pub(super) fn validate_csr(&self, asm_op_count: usize) -> Result<(), String> {
196 self.inner
197 .validate_with(|(_op_idx, asm_op_id)| (u32::from(*asm_op_id) as usize) < asm_op_count)
198 .map_err(|e| format_validation_error(e, asm_op_count))
199 }
200
201 pub fn remap_nodes(&self, remapping: &BTreeMap<MastNodeId, MastNodeId>) -> Self {
209 if self.is_empty() {
210 return Self::new();
211 }
212 if remapping.is_empty() {
213 return self.clone();
215 }
216
217 let max_new_id = remapping.values().map(|id| u32::from(*id)).max().unwrap_or(0) as usize;
219 let num_new_nodes = max_new_id + 1;
220
221 let mut new_node_data: BTreeMap<usize, Vec<(usize, AsmOpId)>> = BTreeMap::new();
223
224 for (old_id, new_id) in remapping {
225 let new_idx = u32::from(*new_id) as usize;
226
227 if let Some(entries) = self.inner.row(*old_id)
228 && !entries.is_empty()
229 {
230 new_node_data.insert(new_idx, entries.to_vec());
231 }
232 }
233
234 let mut new_inner = CsrMatrix::with_capacity(num_new_nodes, self.inner.num_elements());
236
237 for new_idx in 0..num_new_nodes {
238 if let Some(data) = new_node_data.get(&new_idx) {
239 new_inner.push_row(data.iter().copied()).expect("node count should fit in u32");
240 } else {
241 new_inner.push_empty_row().expect("node count should fit in u32");
242 }
243 }
244
245 Self { inner: new_inner }
246 }
247
248 pub(super) fn write_into<W: ByteWriter>(&self, target: &mut W) {
250 self.inner.write_into(target);
251 }
252
253 pub(super) fn read_from<R: ByteReader>(
255 source: &mut R,
256 asm_op_count: usize,
257 ) -> Result<Self, DeserializationError> {
258 let inner: CsrMatrix<MastNodeId, (usize, AsmOpId)> = Deserializable::read_from(source)?;
259
260 let result = Self { inner };
261
262 result.validate_csr(asm_op_count).map_err(|e| {
263 DeserializationError::InvalidValue(format!("OpToAsmOpId validation failed: {}", e))
264 })?;
265
266 Ok(result)
267 }
268}
269
/// Errors that can occur when adding per-operation assembly op mappings to
/// `OpToAsmOpId` via `add_asm_ops_for_node`.
#[derive(Debug, PartialEq, Eq, thiserror::Error)]
pub enum AsmOpIndexError {
    /// The node id is lower than the next expected node id (i.e., a row for this node
    /// was already added).
    #[error("Invalid node index {0:?}")]
    NodeIndex(MastNodeId),
    /// The provided operation indices were not strictly increasing.
    #[error("Operation indices must be strictly increasing")]
    NonIncreasingOpIndices,
    /// An operation index was >= the node's operation count; fields are
    /// (offending index, operation count).
    #[error("Operation index {0} exceeds node's operation count {1}")]
    OpIndexOutOfBounds(usize, usize),
    /// The underlying CSR matrix rejected the update — exact cause depends on
    /// `CsrMatrix` (presumably a row-count limit; see the "fit in u32" expects).
    #[error("Internal CSR structure error")]
    InternalStructure,
}
289
290fn format_validation_error(error: CsrValidationError, asm_op_count: usize) -> String {
292 match error {
293 CsrValidationError::IndptrStartNotZero(val) => format!("indptr must start at 0, got {val}"),
294 CsrValidationError::IndptrNotMonotonic { index, prev, curr } => {
295 format!("indptr not monotonic at index {index}: {prev} > {curr}")
296 },
297 CsrValidationError::IndptrDataMismatch { indptr_end, data_len } => {
298 format!("indptr ends at {indptr_end}, but data.len() is {data_len}")
299 },
300 CsrValidationError::InvalidData { row, position } => format!(
301 "Invalid AsmOpId at row {row}, position {position}: exceeds asm_op count {asm_op_count}"
302 ),
303 }
304}
305
#[cfg(test)]
mod tests {
    use super::*;
    use crate::serde::SliceReader;

    /// Builds an `AsmOpId` from a raw u32 for test fixtures.
    fn test_asm_op_id(value: u32) -> AsmOpId {
        AsmOpId::new(value)
    }

    /// Builds a `MastNodeId` from a raw u32, bypassing validation.
    fn test_node_id(value: u32) -> MastNodeId {
        MastNodeId::new_unchecked(value)
    }

    // --- construction -----------------------------------------------------------

    #[test]
    fn test_op_to_asm_op_id_empty() {
        let storage = OpToAsmOpId::new();
        assert!(storage.is_empty());
        assert_eq!(storage.num_nodes(), 0);
        assert_eq!(storage.num_operations(), 0);
    }

    #[test]
    fn test_op_to_asm_op_id_default() {
        let storage = OpToAsmOpId::default();
        assert!(storage.is_empty());
    }

    #[test]
    fn test_op_to_asm_op_id_with_capacity() {
        // Capacity pre-allocation must not change the observable (empty) state.
        let storage = OpToAsmOpId::with_capacity(10, 100);
        assert!(storage.is_empty());
        assert_eq!(storage.num_nodes(), 0);
    }

    // --- adding nodes and querying operations -----------------------------------

    #[test]
    fn test_op_to_asm_op_id_single_node() {
        let mut storage = OpToAsmOpId::new();
        let node_id = test_node_id(0);
        let asm_op_id = test_asm_op_id(0);

        storage.add_asm_ops_for_node(node_id, 3, vec![(2, asm_op_id)]).unwrap();

        assert!(!storage.is_empty());
        assert_eq!(storage.num_nodes(), 1);
        assert_eq!(storage.num_operations(), 1);
        // Operations before the first recorded entry map to None ...
        assert_eq!(storage.asm_op_id_for_operation(node_id, 0), None);
        assert_eq!(storage.asm_op_id_for_operation(node_id, 1), None);
        assert_eq!(storage.asm_op_id_for_operation(node_id, 2), Some(asm_op_id));
        // ... and queries past the last entry return the last entry's id (no upper
        // bound is stored per node).
        assert_eq!(storage.asm_op_id_for_operation(node_id, 3), Some(asm_op_id));
    }

    #[test]
    fn test_op_to_asm_op_id_single_node_multiple_ops() {
        let mut storage = OpToAsmOpId::new();
        let node_id = test_node_id(0);

        storage
            .add_asm_ops_for_node(
                node_id,
                6,
                vec![(0, test_asm_op_id(10)), (2, test_asm_op_id(20)), (5, test_asm_op_id(30))],
            )
            .unwrap();

        assert_eq!(storage.num_operations(), 3);
        // Each entry covers operations from its index up to (excluding) the next
        // entry's index.
        assert_eq!(storage.asm_op_id_for_operation(node_id, 0), Some(test_asm_op_id(10)));
        assert_eq!(storage.asm_op_id_for_operation(node_id, 1), Some(test_asm_op_id(10)));
        assert_eq!(storage.asm_op_id_for_operation(node_id, 2), Some(test_asm_op_id(20)));
        assert_eq!(storage.asm_op_id_for_operation(node_id, 3), Some(test_asm_op_id(20)));
        assert_eq!(storage.asm_op_id_for_operation(node_id, 4), Some(test_asm_op_id(20)));
        assert_eq!(storage.asm_op_id_for_operation(node_id, 5), Some(test_asm_op_id(30)));
    }

    #[test]
    fn test_op_to_asm_op_id_empty_node() {
        let mut storage = OpToAsmOpId::new();
        let node_id = test_node_id(0);

        // A node with no asm ops still occupies a (empty) row.
        storage.add_asm_ops_for_node(node_id, 0, vec![]).unwrap();

        assert!(!storage.is_empty());
        assert_eq!(storage.num_nodes(), 1);
        assert_eq!(storage.num_operations(), 0);

        assert_eq!(storage.asm_op_id_for_operation(node_id, 0), None);
    }

    #[test]
    fn test_op_to_asm_op_id_multiple_nodes() {
        let mut storage = OpToAsmOpId::new();

        storage
            .add_asm_ops_for_node(test_node_id(0), 2, vec![(1, test_asm_op_id(0))])
            .unwrap();

        storage
            .add_asm_ops_for_node(
                test_node_id(1),
                3,
                vec![(0, test_asm_op_id(1)), (2, test_asm_op_id(2))],
            )
            .unwrap();

        assert_eq!(storage.num_nodes(), 2);

        assert_eq!(storage.asm_op_id_for_operation(test_node_id(0), 0), None);
        assert_eq!(storage.asm_op_id_for_operation(test_node_id(0), 1), Some(test_asm_op_id(0)));

        assert_eq!(storage.asm_op_id_for_operation(test_node_id(1), 0), Some(test_asm_op_id(1)));
        assert_eq!(storage.asm_op_id_for_operation(test_node_id(1), 1), Some(test_asm_op_id(1)));
        assert_eq!(storage.asm_op_id_for_operation(test_node_id(1), 2), Some(test_asm_op_id(2)));
    }

    #[test]
    fn test_op_to_asm_op_id_mixed_empty_and_populated_nodes() {
        let mut storage = OpToAsmOpId::new();

        storage
            .add_asm_ops_for_node(test_node_id(0), 1, vec![(0, test_asm_op_id(0))])
            .unwrap();

        storage.add_asm_ops_for_node(test_node_id(1), 0, vec![]).unwrap();

        storage
            .add_asm_ops_for_node(test_node_id(2), 2, vec![(1, test_asm_op_id(1))])
            .unwrap();

        assert_eq!(storage.num_nodes(), 3);

        assert_eq!(storage.asm_op_id_for_operation(test_node_id(0), 0), Some(test_asm_op_id(0)));
        assert_eq!(storage.asm_op_id_for_operation(test_node_id(1), 0), None);
        assert_eq!(storage.asm_op_id_for_operation(test_node_id(2), 0), None);
        assert_eq!(storage.asm_op_id_for_operation(test_node_id(2), 1), Some(test_asm_op_id(1)));
    }

    #[test]
    fn test_op_to_asm_op_id_gap_in_nodes() {
        let mut storage = OpToAsmOpId::new();

        storage
            .add_asm_ops_for_node(test_node_id(0), 1, vec![(0, test_asm_op_id(0))])
            .unwrap();

        // Skipping node 1 should implicitly create an empty row for it.
        storage
            .add_asm_ops_for_node(test_node_id(2), 1, vec![(0, test_asm_op_id(1))])
            .unwrap();

        assert_eq!(storage.num_nodes(), 3);

        assert_eq!(storage.asm_op_id_for_operation(test_node_id(1), 0), None);

        assert_eq!(storage.asm_op_id_for_operation(test_node_id(0), 0), Some(test_asm_op_id(0)));
        assert_eq!(storage.asm_op_id_for_operation(test_node_id(2), 0), Some(test_asm_op_id(1)));
    }

    // --- first_asm_op_for_node ---------------------------------------------------

    #[test]
    fn test_first_asm_op_for_node() {
        let mut storage = OpToAsmOpId::new();

        storage
            .add_asm_ops_for_node(test_node_id(0), 3, vec![(2, test_asm_op_id(42))])
            .unwrap();

        assert_eq!(storage.first_asm_op_for_node(test_node_id(0)), Some(test_asm_op_id(42)));
    }

    #[test]
    fn test_first_asm_op_for_node_empty() {
        let mut storage = OpToAsmOpId::new();

        storage.add_asm_ops_for_node(test_node_id(0), 0, vec![]).unwrap();

        assert_eq!(storage.first_asm_op_for_node(test_node_id(0)), None);
    }

    #[test]
    fn test_first_asm_op_for_node_nonexistent() {
        let storage = OpToAsmOpId::new();

        assert_eq!(storage.first_asm_op_for_node(test_node_id(0)), None);
    }

    #[test]
    fn test_first_asm_op_for_node_multiple_ops() {
        let mut storage = OpToAsmOpId::new();

        storage
            .add_asm_ops_for_node(
                test_node_id(0),
                4,
                vec![(1, test_asm_op_id(10)), (3, test_asm_op_id(30))],
            )
            .unwrap();

        assert_eq!(storage.first_asm_op_for_node(test_node_id(0)), Some(test_asm_op_id(10)));
    }

    // --- input validation errors -------------------------------------------------

    #[test]
    fn test_op_to_asm_op_id_non_increasing_ops() {
        let mut storage = OpToAsmOpId::new();

        let result = storage.add_asm_ops_for_node(
            test_node_id(0),
            3,
            vec![(2, test_asm_op_id(0)), (1, test_asm_op_id(1))],
        );

        assert_eq!(result, Err(AsmOpIndexError::NonIncreasingOpIndices));
    }

    #[test]
    fn test_op_to_asm_op_id_duplicate_ops() {
        let mut storage = OpToAsmOpId::new();

        // Equal indices also violate "strictly increasing".
        let result = storage.add_asm_ops_for_node(
            test_node_id(0),
            2,
            vec![(1, test_asm_op_id(0)), (1, test_asm_op_id(1))],
        );

        assert_eq!(result, Err(AsmOpIndexError::NonIncreasingOpIndices));
    }

    #[test]
    fn test_op_to_asm_op_id_node_already_added() {
        let mut storage = OpToAsmOpId::new();

        storage.add_asm_ops_for_node(test_node_id(0), 0, vec![]).unwrap();
        storage.add_asm_ops_for_node(test_node_id(1), 0, vec![]).unwrap();

        let result = storage.add_asm_ops_for_node(test_node_id(0), 0, vec![]);

        assert_eq!(result, Err(AsmOpIndexError::NodeIndex(test_node_id(0))));
    }

    // --- query edge cases ----------------------------------------------------------

    #[test]
    fn test_op_to_asm_op_id_query_nonexistent_node() {
        let storage = OpToAsmOpId::new();

        assert_eq!(storage.asm_op_id_for_operation(test_node_id(0), 0), None);
        assert_eq!(storage.asm_op_id_for_operation(test_node_id(999), 0), None);
    }

    #[test]
    fn test_op_to_asm_op_id_query_out_of_bounds_op() {
        let mut storage = OpToAsmOpId::new();

        storage
            .add_asm_ops_for_node(test_node_id(0), 2, vec![(1, test_asm_op_id(0))])
            .unwrap();

        // Out-of-range op indices are NOT rejected at query time: the per-node
        // operation count is not stored, so the last covering entry is returned.
        assert_eq!(storage.asm_op_id_for_operation(test_node_id(0), 2), Some(test_asm_op_id(0)));
        assert_eq!(storage.asm_op_id_for_operation(test_node_id(0), 100), Some(test_asm_op_id(0)));
    }

    // --- validation ----------------------------------------------------------------

    #[test]
    fn test_validate_csr_empty() {
        let storage = OpToAsmOpId::new();
        assert!(storage.validate_csr(0).is_ok());
    }

    #[test]
    fn test_validate_csr_valid() {
        let mut storage = OpToAsmOpId::new();
        storage
            .add_asm_ops_for_node(
                test_node_id(0),
                2,
                vec![(0, test_asm_op_id(0)), (1, test_asm_op_id(1))],
            )
            .unwrap();

        assert!(storage.validate_csr(2).is_ok());
    }

    #[test]
    fn test_validate_csr_invalid_asm_op_id() {
        let mut storage = OpToAsmOpId::new();
        // AsmOpId 5 is out of range for an asm_op_count of 2.
        storage
            .add_asm_ops_for_node(
                test_node_id(0),
                2,
                vec![(0, test_asm_op_id(0)), (1, test_asm_op_id(5))],
            )
            .unwrap();

        let result = storage.validate_csr(2);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Invalid AsmOpId"));
    }

    // --- serialization ---------------------------------------------------------------

    #[test]
    fn test_serialization_roundtrip_empty() {
        let storage = OpToAsmOpId::new();

        let mut bytes = alloc::vec::Vec::new();
        storage.write_into(&mut bytes);

        let mut reader = SliceReader::new(&bytes);
        let deserialized = OpToAsmOpId::read_from(&mut reader, 0).unwrap();

        assert_eq!(storage, deserialized);
    }

    #[test]
    fn test_serialization_roundtrip_with_data() {
        let mut storage = OpToAsmOpId::new();
        storage
            .add_asm_ops_for_node(
                test_node_id(0),
                3,
                vec![(0, test_asm_op_id(0)), (2, test_asm_op_id(1))],
            )
            .unwrap();
        storage.add_asm_ops_for_node(test_node_id(1), 0, vec![]).unwrap();
        storage
            .add_asm_ops_for_node(test_node_id(2), 2, vec![(1, test_asm_op_id(2))])
            .unwrap();

        let mut bytes = alloc::vec::Vec::new();
        storage.write_into(&mut bytes);

        let mut reader = SliceReader::new(&bytes);
        let deserialized = OpToAsmOpId::read_from(&mut reader, 3).unwrap();

        assert_eq!(storage, deserialized);
    }

    // --- derived trait impls ---------------------------------------------------------

    #[test]
    fn test_clone_and_equality() {
        let mut storage1 = OpToAsmOpId::new();
        storage1
            .add_asm_ops_for_node(test_node_id(0), 1, vec![(0, test_asm_op_id(42))])
            .unwrap();

        let storage2 = storage1.clone();
        assert_eq!(storage1, storage2);

        let mut storage3 = OpToAsmOpId::new();
        storage3
            .add_asm_ops_for_node(test_node_id(0), 1, vec![(0, test_asm_op_id(99))])
            .unwrap();

        assert_ne!(storage1, storage3);
    }

    #[test]
    fn test_debug_impl() {
        let storage = OpToAsmOpId::new();
        let debug_str = alloc::format!("{:?}", storage);
        assert!(debug_str.contains("OpToAsmOpId"));
    }
}