// skydive/wmec.rs

1//! This module implements the `WhatsHap` phasing algorithm as described in:
2//! Murray Patterson, Tobias Marschall, Nadia Pisanti, Leo van Iersel, Leen Stougie, et al.
3//! `WhatsHap`: Weighted Haplotype Assembly for Future-Generation Sequencing Reads.
4//! Journal of Computational Biology, 2015, 22 (6), pp.498-509.
5//! DOI: 10.1089/cmb.2014.0157
6//! HAL ID: hal-01225988
7
8use std::collections::{BTreeMap, BTreeSet, HashMap};
9use std::fs::File;
10use std::io::Write;
11
12use itertools::Itertools;
13
14use indicatif::ParallelProgressIterator;
15use rayon::prelude::*;
16
/// Input matrices for the weighted minimum error correction (wMEC) problem.
///
/// Rows are sequencing fragments (reads) and columns are SNP positions.
/// `None` marks positions a read does not cover. `confidences` mirrors the
/// shape of `reads` and holds the weight of each allele call.
#[derive(Debug)]
pub struct WMECData {
    pub reads: Vec<Vec<Option<u8>>>, // Reads matrix where None represents missing data
    pub confidences: Vec<Vec<Option<u32>>>, // Confidence degrees matrix
    pub num_snps: usize,             // Number of SNP positions
}

impl WMECData {
    /// Builds a `WMECData` from parallel read/confidence matrices.
    ///
    /// `num_snps` is taken from the first row; an empty `reads` yields
    /// `num_snps == 0` instead of panicking. All rows are assumed to have
    /// the same length — TODO confirm callers guarantee this.
    #[must_use]
    pub fn new(reads: Vec<Vec<Option<u8>>>, confidences: Vec<Vec<Option<u32>>>) -> Self {
        debug_assert_eq!(
            reads.len(),
            confidences.len(),
            "reads and confidences must have the same number of rows"
        );
        // Previously `reads[0].len()`, which panicked on an empty matrix.
        let num_snps = reads.first().map_or(0, Vec::len);
        WMECData {
            reads,
            confidences,
            num_snps,
        }
    }

    /// Computes W^0(j, R) and W^1(j, R): the total cost of forcing every
    /// fragment in `set_r` to carry allele 0 (resp. 1) at SNP `snp`.
    ///
    /// A read contributes its confidence to the side it would have to be
    /// flipped to; positions with a missing allele or confidence are skipped.
    #[must_use]
    pub fn compute_costs(&self, snp: usize, set_r: &BTreeSet<usize>) -> (u32, u32) {
        let mut w0 = 0; // Cost for setting to 0
        let mut w1 = 0; // Cost for setting to 1

        for &read_index in set_r {
            if let (Some(allele), Some(confidence)) = (
                self.reads[read_index][snp],
                self.confidences[read_index][snp],
            ) {
                if allele == 0 {
                    w1 += confidence; // Cost of flipping 0 -> 1
                } else {
                    w0 += confidence; // Cost of flipping 1 -> 0
                }
            }
        }

        (w0, w1)
    }

    /// Calculates the minimum correction cost ΔC(j, (R, S)): the cheapest way
    /// to make SNP `snp` conflict-free under the bipartition `(r, s)`.
    #[must_use]
    pub fn delta_c(&self, snp: usize, r: &BTreeSet<usize>, s: &BTreeSet<usize>) -> u32 {
        let (w0_r, w1_r) = self.compute_costs(snp, r);
        let (w0_s, w1_s) = self.compute_costs(snp, s);

        // Given a bipartition (R, S) of F(j), the minimum cost to make position j conflict-free is:
        std::cmp::min(w0_r, w1_r) + std::cmp::min(w0_s, w1_s)

        // Alternatively, under the all heterozygous assumption, where one wants to enforce all SNPs
        // to be heterozygous, the equation becomes:
        // std::cmp::min(w0_r + w1_r, w0_s + w1_s)
    }

    /// Writes the reads matrix to `path` as tab-separated text (debugging aid).
    ///
    /// Covered positions are written as `allele,confidence`; uncovered ones
    /// as `-,1`. Output is buffered and flushed before returning.
    ///
    /// # Errors
    /// Returns any I/O error from creating or writing the file.
    pub fn write_reads_matrix(&self, path: &str) -> std::io::Result<()> {
        // Buffer the many small writes instead of issuing one syscall each.
        let mut file = std::io::BufWriter::new(File::create(path)?);

        // Header row with SNP positions.
        for j in 0..self.num_snps {
            write!(file, "\tSNP_{j}")?;
        }
        writeln!(file)?;

        // One row per read.
        for (i, (read, conf)) in self.reads.iter().zip(self.confidences.iter()).enumerate() {
            write!(file, "Read_{i}")?;
            for (allele, qual) in read.iter().zip(conf.iter()) {
                match (allele, qual) {
                    // Previously only alleles 0 and 1 were emitted (and a
                    // present allele with a missing confidence panicked),
                    // which silently misaligned columns for any other value.
                    (Some(a), q) => write!(file, "\t{},{}", a, q.unwrap_or(1))?,
                    (None, _) => write!(file, "\t-,1")?,
                }
            }
            writeln!(file)?;
        }

        // BufWriter's Drop swallows flush errors; surface them explicitly.
        file.flush()?;

        Ok(())
    }
}
103
// Enumerates every ordered bipartition (R, S) of `set` by treating each of
// the 2^|set| bit masks as a membership indicator: a clear bit puts the
// element in R, a set bit puts it in S. Mask 0 therefore yields (set, ∅).
fn generate_bipartitions(set: &BTreeSet<usize>) -> Vec<(BTreeSet<usize>, BTreeSet<usize>)> {
    let elements: Vec<usize> = set.iter().copied().collect();
    let total = 1usize << elements.len(); // 2^|set|

    (0..total)
        .map(|mask| {
            let mut left = BTreeSet::new();
            let mut right = BTreeSet::new();
            for (bit, &elem) in elements.iter().enumerate() {
                if mask & (1 << bit) == 0 {
                    left.insert(elem);
                } else {
                    right.insert(elem);
                }
            }
            (left, right)
        })
        .collect()
}
127
128// Function to initialize the DP table for SNP 0
129fn initialize_dp(
130    data: &WMECData,
131) -> (
132    HashMap<(usize, BTreeSet<usize>, BTreeSet<usize>), u32>,
133    HashMap<(usize, BTreeSet<usize>, BTreeSet<usize>), Option<(BTreeSet<usize>, BTreeSet<usize>)>>,
134) {
135    let mut dp: HashMap<(usize, BTreeSet<usize>, BTreeSet<usize>), u32> = HashMap::new();
136    let mut backtrack: HashMap<
137        (usize, BTreeSet<usize>, BTreeSet<usize>),
138        Option<(BTreeSet<usize>, BTreeSet<usize>)>,
139    > = HashMap::new();
140
141    let active_fragments: BTreeSet<usize> = data
142        .reads
143        .iter()
144        .enumerate()
145        .filter(|(_, read)| read[0].is_some()) // Only consider fragments covering SNP 0
146        .map(|(index, _)| index)
147        .collect();
148
149    let partitions = generate_bipartitions(&active_fragments);
150    for (r, s) in partitions {
151        let cost = data.delta_c(0, &r, &s);
152        dp.insert((0, r.clone(), s.clone()), cost);
153        backtrack.insert((0, r.clone(), s.clone()), None);
154    }
155
156    (dp, backtrack)
157}
158
// Serial reference implementation of the per-column DP update; superseded by
// the rayon-parallel `update_dp` below and apparently no longer called —
// retained for comparison. Requires `snp >= 1`.
//
// For each bipartition (R, S) of the fragments covering `snp`, finds the
// cheapest compatible bipartition of the previous column and records its cost
// and predecessor for the backtrace.
fn update_dp_old(
    data: &WMECData,
    dp: &mut HashMap<(usize, BTreeSet<usize>, BTreeSet<usize>), u32>,
    backtrack: &mut HashMap<
        (usize, BTreeSet<usize>, BTreeSet<usize>),
        Option<(BTreeSet<usize>, BTreeSet<usize>)>,
    >,
    snp: usize,
) {
    let active_fragments: BTreeSet<usize> = data
        .reads
        .iter()
        .enumerate()
        .filter(|(_, read)| read[snp].is_some()) // Only consider fragments covering SNP
        .map(|(index, _)| index)
        .collect();
    let partitions = generate_bipartitions(&active_fragments);

    for (r, s) in &partitions {
        let delta_cost = data.delta_c(snp, r, s);
        let mut min_cost = u32::MAX;
        let mut best_bipartition = None;

        // NOTE(review): loop-invariant, yet recomputed for every (r, s);
        // the parallel `update_dp` hoists this out of the loop.
        let prev_active_fragments: BTreeSet<usize> = data
            .reads
            .iter()
            .enumerate()
            .filter(|(_, read)| read[snp - 1].is_some()) // Fragments covering the previous SNP
            .map(|(index, _)| index)
            .collect();

        for (prev_r, prev_s) in generate_bipartitions(&prev_active_fragments) {
            // Compatible = every fragment shared with the previous column
            // stays on the same side of the bipartition.
            let r_compatible = r
                .intersection(&prev_active_fragments)
                .all(|&x| prev_r.contains(&x));
            let s_compatible = s
                .intersection(&prev_active_fragments)
                .all(|&x| prev_s.contains(&x));

            if r_compatible && s_compatible {
                if let Some(&prev_cost) = dp.get(&(snp - 1, prev_r.clone(), prev_s.clone())) {
                    let current_cost = delta_cost + prev_cost;

                    if current_cost < min_cost {
                        min_cost = current_cost;
                        best_bipartition = Some((prev_r.clone(), prev_s.clone()));
                    }
                }
            }
        }

        // min_cost stays u32::MAX (unreachable sentinel) when no compatible
        // predecessor exists.
        dp.insert((snp, r.clone(), s.clone()), min_cost);
        backtrack.insert((snp, r.clone(), s.clone()), best_bipartition);
    }
}
215
216fn update_dp(
217    data: &WMECData,
218    dp: &mut HashMap<(usize, BTreeSet<usize>, BTreeSet<usize>), u32>,
219    backtrack: &mut HashMap<
220        (usize, BTreeSet<usize>, BTreeSet<usize>),
221        Option<(BTreeSet<usize>, BTreeSet<usize>)>,
222    >,
223    snp: usize,
224) {
225    let active_fragments: BTreeSet<usize> = data
226        .reads
227        .iter()
228        .enumerate()
229        .filter(|(_, read)| read[snp].is_some())
230        .map(|(index, _)| index)
231        .collect();
232    let partitions = generate_bipartitions(&active_fragments);
233
234    // Pre-compute prev_active_fragments since it's used by all iterations
235    let prev_active_fragments: BTreeSet<usize> = data
236        .reads
237        .iter()
238        .enumerate()
239        .filter(|(_, read)| read[snp - 1].is_some())
240        .map(|(index, _)| index)
241        .collect();
242
243    // Collect results in parallel
244    let results: Vec<_> = partitions
245        .par_iter()
246        .map(|(r, s)| {
247            let delta_cost = data.delta_c(snp, r, s);
248            let mut min_cost = u32::MAX;
249            let mut best_bipartition = None;
250
251            for (prev_r, prev_s) in generate_bipartitions(&prev_active_fragments) {
252                let r_compatible = r
253                    .intersection(&prev_active_fragments)
254                    .all(|&x| prev_r.contains(&x));
255                let s_compatible = s
256                    .intersection(&prev_active_fragments)
257                    .all(|&x| prev_s.contains(&x));
258
259                if r_compatible && s_compatible {
260                    if let Some(&prev_cost) = dp.get(&(snp - 1, prev_r.clone(), prev_s.clone())) {
261                        let current_cost = delta_cost + prev_cost;
262
263                        if current_cost < min_cost {
264                            min_cost = current_cost;
265                            best_bipartition = Some((prev_r.clone(), prev_s.clone()));
266                        }
267                    }
268                }
269            }
270
271            ((r.clone(), s.clone()), (min_cost, best_bipartition))
272        })
273        .collect();
274
275    // Update the hashmaps with the results
276    for ((r, s), (min_cost, best_bipartition)) in results {
277        dp.insert((snp, r.clone(), s.clone()), min_cost);
278        backtrack.insert((snp, r, s), best_bipartition);
279    }
280}
281
// Recovers the two haplotype sequences and the final read bipartition from
// the filled DP tables.
//
// Starts from the cheapest bipartition at the last SNP column and walks the
// `backtrack` pointers down to SNP 0, assigning complementary alleles per
// column. Panics if the tables are inconsistent (no scored bipartition at
// the last column, or a missing predecessor along the chain).
fn backtrack_haplotypes(
    data: &WMECData,
    dp: &HashMap<(usize, BTreeSet<usize>, BTreeSet<usize>), u32>,
    backtrack: &HashMap<
        (usize, BTreeSet<usize>, BTreeSet<usize>),
        Option<(BTreeSet<usize>, BTreeSet<usize>)>,
    >,
) -> (Vec<u8>, Vec<u8>, BTreeSet<usize>, BTreeSet<usize>) {
    let mut best_cost = u32::MAX;
    let mut best_bipartition = None;

    // Restrict processing to reads that span variants within the window.
    let final_active_fragments: BTreeSet<usize> = data
        .reads
        .iter()
        .enumerate()
        .filter(|(_, read)| read[data.num_snps - 1].is_some())
        .map(|(index, _)| index)
        .collect();

    // Pick the cheapest bipartition of the last column as the starting point.
    for (r, s) in generate_bipartitions(&final_active_fragments) {
        if let Some(&cost) = dp.get(&(data.num_snps - 1, r.clone(), s.clone())) {
            if cost < best_cost {
                best_cost = cost;
                best_bipartition = Some((r, s));
            }
        }
    }

    let mut haplotype1 = vec![0; data.num_snps];
    let mut haplotype2 = vec![0; data.num_snps];
    // Panics if nothing was scored for the last column; the DP fill inserts
    // at least one entry per column, so this is an invariant violation.
    let mut current_bipartition = best_bipartition.unwrap();

    for snp in (0..data.num_snps).rev() {
        let (r, s) = &current_bipartition;
        let (w0_r, w1_r) = data.compute_costs(snp, r);
        let (w0_s, w1_s) = data.compute_costs(snp, s);

        // Assign complementary alleles (heterozygous call), choosing the
        // orientation with the smaller total correction cost.
        if w0_r + w1_s <= w1_r + w0_s {
            haplotype1[snp] = 0;
            haplotype2[snp] = 1;
        } else {
            haplotype1[snp] = 1;
            haplotype2[snp] = 0;
        }

        if snp > 0 {
            if let Some(prev_bipartition) = backtrack
                .get(&(snp, r.clone(), s.clone()))
                .and_then(|x| x.as_ref())
            {
                current_bipartition = prev_bipartition.clone();
            } else {
                // This should not happen if the DP table is correctly filled
                panic!("No valid previous bipartition found for SNP {snp}");
            }
        }
    }

    crate::elog!("wmec score: {}", best_cost);

    // After the loop, current_bipartition holds the bipartition at SNP 0,
    // i.e. the read partition that induced the haplotypes.
    (
        haplotype1,
        haplotype2,
        current_bipartition.0,
        current_bipartition.1,
    )
}
350
351// Main function to perform WMEC using dynamic programming
352#[must_use]
353pub fn phase(data: &WMECData) -> (Vec<u8>, Vec<u8>, BTreeSet<usize>, BTreeSet<usize>) {
354    let (mut dp, mut backtrack) = initialize_dp(data);
355
356    for snp in 1..data.num_snps {
357        update_dp(data, &mut dp, &mut backtrack, snp);
358    }
359
360    backtrack_haplotypes(data, &dp, &backtrack)
361}
362
363#[must_use]
364pub fn phase_all(
365    data: &WMECData,
366    window: usize,
367    stride: usize,
368) -> (Vec<u8>, Vec<u8>, BTreeSet<usize>, BTreeSet<usize>) {
369    data.write_reads_matrix("mat.tsv");
370
371    // First, collect all window ranges
372    let mut windows: Vec<_> = (0..data.num_snps)
373        .step_by(stride)
374        .map(|start| (start, (start + window).min(data.num_snps)))
375        .collect();
376
377    // Filter out last window if it completely overlaps with previous window
378    if windows.len() > 1 {
379        if let Some(&(_, prev_end)) = windows.get(windows.len() - 2) {
380            if let Some(&(_, last_end)) = windows.last() {
381                if last_end == prev_end {
382                    windows.pop();
383                }
384            }
385        }
386    }
387
388    let pb = crate::utils::default_bounded_progress_bar("Processing windows", windows.len() as u64);
389
390    // Process windows in parallel
391    let haplotype_pairs: Vec<_> = windows
392        // .iter()
393        .par_iter()
394        .progress_with(pb)
395        .map(|&(start, end)| {
396            let (window_indices, window_reads): (Vec<usize>, Vec<Vec<Option<u8>>>) = data
397                .reads
398                .iter()
399                .zip(data.confidences.iter())
400                .enumerate()
401                .map(|(read_idx, (read, confidence))| {
402                    let window_read = read[start..end].to_vec();
403                    let window_confidence = confidence[start..end].to_vec();
404
405                    let none_count = window_read.iter().filter(|x| x.is_none()).count();
406                    let lowqual_count = window_confidence
407                        .iter()
408                        .filter(|&&x| x.is_some() && x.unwrap() == 0)
409                        .count();
410
411                    (
412                        none_count + lowqual_count,
413                        read_idx,
414                        window_read,
415                        window_confidence,
416                    )
417                })
418                .collect::<Vec<_>>()
419                .into_iter()
420                .sorted_by_key(|(bad_count, _, _, _)| *bad_count)
421                .take(12)
422                .map(|(_, read_idx, window_read, _)| (read_idx, window_read))
423                .unzip();
424
425            let (_, window_confidences): (Vec<usize>, Vec<Vec<Option<u32>>>) = data
426                .reads
427                .iter()
428                .zip(data.confidences.iter())
429                .enumerate()
430                .map(|(read_idx, (read, confidence))| {
431                    let window_read = read[start..end].to_vec();
432                    let window_confidence = confidence[start..end].to_vec();
433
434                    let none_count = window_read.iter().filter(|x| x.is_none()).count();
435                    let lowqual_count = window_confidence
436                        .iter()
437                        .filter(|&&x| x.is_some() && x.unwrap() == 0)
438                        .count();
439
440                    (
441                        none_count + lowqual_count,
442                        read_idx,
443                        window_read,
444                        window_confidence,
445                    )
446                })
447                .collect::<Vec<_>>()
448                .into_iter()
449                .sorted_by_key(|(bad_count, _, _, _)| *bad_count)
450                .take(12)
451                .map(|(_, read_idx, _, window_confidences)| (read_idx, window_confidences))
452                .unzip();
453
454            crate::elog!("window_reads: {} {} {}", start, end, window_reads.len());
455            crate::elog!("window_indices: {} {} {:?}", start, end, window_indices);
456
457            let window_data = WMECData::new(window_reads, window_confidences);
458
459            let (mut dp, mut backtrack) = initialize_dp(&window_data);
460            for snp in 1..window_data.num_snps {
461                update_dp(&window_data, &mut dp, &mut backtrack, snp);
462            }
463
464            let (hap1, hap2, part1, part2) = backtrack_haplotypes(&window_data, &dp, &backtrack);
465
466            let re1: BTreeSet<_> = part1.iter().map(|&i| window_indices[i as usize]).collect();
467            let re2: BTreeSet<_> = part2.iter().map(|&i| window_indices[i as usize]).collect();
468
469            crate::elog!("window_re1: {} {} {:?} {:?}", start, end, part1, re1);
470            crate::elog!("window_re2: {} {} {:?} {:?}", start, end, part2, re2);
471
472            (hap1, hap2, re1, re2)
473        })
474        .collect();
475
476    let mut haplotype1 = Vec::new();
477    let mut haplotype2 = Vec::new();
478    let mut part1 = BTreeSet::new();
479    let mut part2 = BTreeSet::new();
480
481    let overlap = window - stride;
482
483    let mut all_read_assignments = BTreeMap::new();
484
485    let mut i = 0;
486    for (hap1, hap2, reads1, reads2) in haplotype_pairs {
487        // crate::elog!("{}", i);
488        // crate::elog!("{:?}", reads1);
489        // crate::elog!("{:?}", reads2);
490
491        i += 1;
492
493        if haplotype1.len() == 0 {
494            for read_num in reads1.clone() {
495                all_read_assignments
496                    .entry(read_num)
497                    .or_insert(vec![])
498                    .push(1);
499            }
500
501            for read_num in reads2.clone() {
502                all_read_assignments
503                    .entry(read_num)
504                    .or_insert(vec![])
505                    .push(2);
506            }
507
508            haplotype1 = hap1;
509            haplotype2 = hap2;
510            part1 = reads1;
511            part2 = reads2;
512        } else {
513            // Compare overlap regions to determine orientation
514            crate::elog!(
515                "haplotypes: {} {} {}",
516                haplotype1.len(),
517                haplotype2.len(),
518                overlap
519            );
520
521            let new_overlap_len = std::cmp::min(overlap, hap1.len());
522            let h1_overlap = &haplotype1[haplotype1.len() - new_overlap_len..];
523            let h2_overlap = &haplotype2[haplotype2.len() - new_overlap_len..];
524            let new_overlap = &hap1[..new_overlap_len];
525
526            // Count matches between overlapping regions
527            let h1_matches = h1_overlap
528                .iter()
529                .zip(new_overlap.iter())
530                .filter(|(a, b)| a == b)
531                .count();
532            let h2_matches = h2_overlap
533                .iter()
534                .zip(new_overlap.iter())
535                .filter(|(a, b)| a == b)
536                .count();
537
538            // Append new haplotypes in correct orientation based on best overlap match
539            if h1_matches >= h2_matches {
540                haplotype1.extend_from_slice(&hap1[std::cmp::min(overlap, hap1.len())..]);
541                haplotype2.extend_from_slice(&hap2[std::cmp::min(overlap, hap2.len())..]);
542
543                for read_num in reads1.clone() {
544                    all_read_assignments
545                        .entry(read_num)
546                        .or_insert(vec![])
547                        .push(1);
548                }
549
550                for read_num in reads2.clone() {
551                    all_read_assignments
552                        .entry(read_num)
553                        .or_insert(vec![])
554                        .push(2);
555                }
556            } else {
557                haplotype1.extend_from_slice(&hap2[std::cmp::min(overlap, hap2.len())..]);
558                haplotype2.extend_from_slice(&hap1[std::cmp::min(overlap, hap1.len())..]);
559
560                for read_num in reads1.clone() {
561                    all_read_assignments
562                        .entry(read_num)
563                        .or_insert(vec![])
564                        .push(2);
565                }
566
567                for read_num in reads2.clone() {
568                    all_read_assignments
569                        .entry(read_num)
570                        .or_insert(vec![])
571                        .push(1);
572                }
573            }
574        }
575    }
576
577    let mut final_part1 = BTreeSet::new();
578    let mut final_part2 = BTreeSet::new();
579
580    crate::elog!("all_read_assignments: {:?}", &all_read_assignments);
581
582    for (read_num, assignments) in all_read_assignments {
583        // Count frequency of assignments (1 or 2) for this read
584        let most_common = assignments
585            .iter()
586            .fold(HashMap::new(), |mut counts, &value| {
587                *counts.entry(value).or_insert(0) += 1;
588                counts
589            })
590            .into_iter()
591            .max_by_key(|&(_, count)| count)
592            .map(|(value, _)| value)
593            .unwrap_or(0); // Default to 0 if no assignments
594
595        if most_common == 1 {
596            final_part1.insert(read_num);
597        } else if most_common == 2 {
598            final_part2.insert(read_num);
599        }
600    }
601
602    (haplotype1, haplotype2, final_part1, final_part2)
603}
604
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE: the manuscript numbers SNPs from 1 while this implementation is
    // 0-based, so the paper's "C(1, ·)" corresponds to column index 0 here.
    #[test]
    fn test_initialization_manuscript_example_1() {
        // Define the dataset as a matrix of reads
        let reads = vec![
            vec![Some(0), None],    // f0
            vec![Some(1), Some(0)], // f1
            vec![Some(1), Some(1)], // f2
            vec![None, Some(0)],    // f3
        ];

        // Define the confidence degrees
        let confidences = vec![
            vec![Some(5), None],    // f0
            vec![Some(3), Some(2)], // f1
            vec![Some(6), Some(1)], // f2
            vec![None, Some(2)],    // f3
        ];

        // Initialize the dataset
        let data = WMECData::new(reads, confidences);

        // Define the expected values for C(1, ·)
        let expected_values = vec![
            (vec![0, 1, 2], vec![], 5), // C(1, ({f0, f1, f2}, ∅)) = 5
            (vec![0, 1], vec![2], 3),   // C(1, ({f0, f1}, {f2})) = 3
            (vec![0, 2], vec![1], 5),   // C(1, ({f0, f2}, {f1})) = 5
            (vec![1, 2], vec![0], 0),   // C(1, ({f1, f2}, {f0})) = 0
        ];

        // Check that the values in the dynamic programming table match the expected values
        for (r, s, expected_cost) in expected_values {
            let r_set: BTreeSet<usize> = r.into_iter().collect();
            let s_set: BTreeSet<usize> = s.into_iter().collect();
            let cost = data.delta_c(0, &r_set, &s_set);
            assert_eq!(
                cost, expected_cost,
                "Cost for partition ({:?}, {:?}) does not match expected",
                r_set, s_set
            );
        }
    }

    #[test]
    fn test_recurrence_manuscript_example_2() {
        // Define the dataset as a matrix of reads
        let reads = vec![
            vec![Some(0), None],    // f0
            vec![Some(1), Some(0)], // f1
            vec![Some(1), Some(1)], // f2
            vec![None, Some(0)],    // f3
        ];

        // Define the confidence degrees
        let confidences = vec![
            vec![Some(5), None],    // f0
            vec![Some(3), Some(2)], // f1
            vec![Some(6), Some(1)], // f2
            vec![None, Some(2)],    // f3
        ];

        // Initialize the dataset
        let data = WMECData::new(reads, confidences);

        // Define the expected values for C(2, ·)
        let expected_values = vec![
            (vec![1, 2, 3], vec![], 1), // C(2, ({f1, f2, f3}, ∅)) = 1
            (vec![1, 2], vec![3], 1),   // C(2, ({f1, f2}, {f3})) = 1
            (vec![1, 3], vec![2], 3),   // C(2, ({f1, f3}, {f2})) = 3
            (vec![2, 3], vec![1], 4),   // C(2, ({f2, f3}, {f1})) = 4
        ];

        let (mut dp, mut backtrack) = initialize_dp(&data);

        for snp in 1..data.num_snps {
            update_dp(&data, &mut dp, &mut backtrack, snp);
        }

        // Verify that the results comport with expected_values
        // (dp index 1 is the paper's SNP 2; see the note at the top of mod tests).
        for (r, s, expected_cost) in expected_values {
            let r_set: BTreeSet<usize> = r.into_iter().collect();
            let s_set: BTreeSet<usize> = s.into_iter().collect();
            let actual_cost = dp
                .get(&(1, r_set.clone(), s_set.clone()))
                .unwrap_or(&u32::MAX);
            assert_eq!(
                *actual_cost, expected_cost,
                "Cost for partition ({:?}, {:?}) does not match expected",
                r_set, s_set
            );
        }
    }

    /// This test case is based on Figure 1 from the paper:
    /// "WhatsHap: fast and accurate read-based phasing"
    /// by Marcel Martin et al.
    /// DOI: https://doi.org/10.1101/085050
    ///
    /// The figure illustrates a small example of the weighted minimum error correction problem,
    /// which is the core algorithmic component of WhatsHap.
    #[test]
    fn test_whatshap_manuscript_figure_1() {
        // Define the dataset as a matrix of reads
        let reads = vec![
            vec![Some(0), None, None, Some(0), Some(1)], // f0
            vec![Some(1), Some(0), Some(0), None, None], // f1
            vec![Some(1), Some(1), Some(0), None, None], // f2
            vec![None, Some(0), Some(0), Some(1), None], // f3
            vec![None, None, Some(1), Some(0), Some(1)], // f4
            vec![None, None, None, Some(0), Some(1)],    // f5
        ];

        // Define the confidence degrees
        let confidences = vec![
            vec![Some(32), None, None, Some(34), Some(17)], // f0
            vec![Some(15), Some(25), Some(13), None, None], // f1
            vec![Some(7), Some(3), Some(15), None, None],   // f2
            vec![None, Some(12), Some(23), Some(29), Some(31)], // f3
            vec![None, None, Some(25), Some(17), Some(19)], // f4
            vec![None, None, None, Some(20), Some(10)],     // f5
        ];

        // Initialize the dataset
        let data = WMECData::new(reads, confidences);

        // Perform the WMEC algorithm using dynamic programming
        let (haplotype1, haplotype2, part1, part2) = phase(&data);

        let expected_haplotype1 = vec![0, 1, 1, 0, 1];
        let expected_haplotype2 = vec![1, 0, 0, 1, 0];

        assert_eq!(
            haplotype1, expected_haplotype1,
            "Haplotype 1 does not match expected"
        );
        assert_eq!(
            haplotype2, expected_haplotype2,
            "Haplotype 2 does not match expected"
        );
    }
}