1use std::collections::{BTreeMap, BTreeSet, HashMap};
9use std::fs::File;
10use std::io::Write;
11
12use itertools::Itertools;
13
14use indicatif::ParallelProgressIterator;
15use rayon::prelude::*;
16
/// Read-by-SNP allele matrix consumed by the WMEC phasing routines in this file.
///
/// `reads` and `confidences` are indexed as `[read][snp]`.
#[derive(Debug)]
pub struct WMECData {
    /// Allele observed by each read at each SNP; `None` where the read has no call.
    pub reads: Vec<Vec<Option<u8>>>,
    /// Weight attached to each allele call, laid out like `reads`.
    pub confidences: Vec<Vec<Option<u32>>>,
    /// Number of SNP columns, taken from the length of the first read.
    pub num_snps: usize,
}

impl WMECData {
    /// Wraps the two matrices and records the number of SNP columns.
    ///
    /// `num_snps` is the length of the first read, or `0` when `reads` is empty
    /// (an empty matrix is accepted instead of panicking on `reads[0]`).
    #[must_use]
    pub fn new(reads: Vec<Vec<Option<u8>>>, confidences: Vec<Vec<Option<u32>>>) -> Self {
        let num_snps = reads.first().map_or(0, Vec::len);
        Self {
            reads,
            confidences,
            num_snps,
        }
    }

    /// Sums the confidences of the reads in `set_r` at column `snp`, split by allele.
    ///
    /// Returns `(w0, w1)`: `w0` is the total confidence of calls that are *not* `0`
    /// (the weight that disagrees with allele 0) and `w1` is the total confidence of
    /// calls that *are* `0` (the weight that disagrees with allele 1). Entries that
    /// lack either an allele or a confidence contribute nothing.
    ///
    /// # Panics
    ///
    /// Panics if a read index in `set_r` or `snp` is out of bounds.
    #[must_use]
    pub fn compute_costs(&self, snp: usize, set_r: &BTreeSet<usize>) -> (u32, u32) {
        let mut w0 = 0;
        let mut w1 = 0;

        for &read_index in set_r {
            // The confidence matrix is only consulted when the allele is present.
            let call = self.reads[read_index][snp]
                .and_then(|allele| self.confidences[read_index][snp].map(|weight| (allele, weight)));

            match call {
                Some((0, weight)) => w1 += weight,
                Some((_, weight)) => w0 += weight,
                None => {}
            }
        }

        (w0, w1)
    }

    /// Cost contributed by column `snp` for the bipartition `(r, s)`.
    ///
    /// Each side independently pays the smaller of its two weights from
    /// [`compute_costs`](Self::compute_costs); the result is the sum over both sides.
    #[must_use]
    pub fn delta_c(&self, snp: usize, r: &BTreeSet<usize>, s: &BTreeSet<usize>) -> u32 {
        let (w0_r, w1_r) = self.compute_costs(snp, r);
        let (w0_s, w1_s) = self.compute_costs(snp, s);

        w0_r.min(w1_r) + w0_s.min(w1_s)
    }

    /// Writes the matrix to `path` as tab-separated text.
    ///
    /// The header row holds one `SNP_<j>` column per SNP (each preceded by a tab).
    /// Every following row starts with `Read_<i>` and has one `allele,confidence`
    /// cell per position; positions without an allele are written as `-,1`.
    ///
    /// # Errors
    ///
    /// Returns any I/O error raised while creating or writing the file, and an
    /// [`std::io::ErrorKind::InvalidData`] error when a position has an allele but
    /// no confidence.
    pub fn write_reads_matrix(&self, path: &str) -> std::io::Result<()> {
        // Buffer the many small writes instead of hitting the file once per cell.
        let mut out = std::io::BufWriter::new(File::create(path)?);

        for j in 0..self.num_snps {
            write!(out, "\tSNP_{}", j)?;
        }
        writeln!(out)?;

        for (i, (read, conf)) in self.reads.iter().zip(&self.confidences).enumerate() {
            write!(out, "Read_{}", i)?;
            for (j, (allele, qual)) in read.iter().zip(conf).enumerate() {
                match (allele, qual) {
                    // Any allele value gets a cell so that columns stay aligned.
                    (Some(allele), Some(qual)) => write!(out, "\t{},{}", allele, qual)?,
                    (Some(_), None) => {
                        return Err(std::io::Error::new(
                            std::io::ErrorKind::InvalidData,
                            format!("read {} has an allele but no confidence at SNP {}", i, j),
                        ));
                    }
                    (None, _) => write!(out, "\t-,1")?,
                }
            }
            writeln!(out)?;
        }

        // Flush explicitly: errors raised while dropping a BufWriter are discarded.
        out.flush()
    }
}
103
/// Enumerates every ordered split of `set` into two disjoint subsets `(r, s)`.
///
/// A set of `n` elements yields `2^n` pairs. Pair number `mask` places the `j`-th
/// smallest element in `r` when bit `j` of `mask` is clear and in `s` when it is
/// set, so the first pair is `(set, {})` and the last is `({}, set)`.
///
/// # Panics
///
/// Panics if `set` has at least `usize::BITS` elements, because the number of
/// bipartitions no longer fits in a `usize`.
fn generate_bipartitions(set: &BTreeSet<usize>) -> Vec<(BTreeSet<usize>, BTreeSet<usize>)> {
    let elements: Vec<usize> = set.iter().copied().collect();

    // The mask is an explicit `usize`; an untyped `1 << n` would default to `i32`
    // and silently misbehave from 31 elements upwards.
    assert!(
        elements.len() < usize::BITS as usize,
        "cannot enumerate the bipartitions of {} elements",
        elements.len()
    );
    let num_partitions = 1usize << elements.len();

    (0..num_partitions)
        .map(|mask| {
            let mut r = BTreeSet::new();
            let mut s = BTreeSet::new();

            for (bit, &elem) in elements.iter().enumerate() {
                if mask & (1usize << bit) == 0 {
                    r.insert(elem);
                } else {
                    s.insert(elem);
                }
            }

            (r, s)
        })
        .collect()
}
127
128fn initialize_dp(
130 data: &WMECData,
131) -> (
132 HashMap<(usize, BTreeSet<usize>, BTreeSet<usize>), u32>,
133 HashMap<(usize, BTreeSet<usize>, BTreeSet<usize>), Option<(BTreeSet<usize>, BTreeSet<usize>)>>,
134) {
135 let mut dp: HashMap<(usize, BTreeSet<usize>, BTreeSet<usize>), u32> = HashMap::new();
136 let mut backtrack: HashMap<
137 (usize, BTreeSet<usize>, BTreeSet<usize>),
138 Option<(BTreeSet<usize>, BTreeSet<usize>)>,
139 > = HashMap::new();
140
141 let active_fragments: BTreeSet<usize> = data
142 .reads
143 .iter()
144 .enumerate()
145 .filter(|(_, read)| read[0].is_some()) .map(|(index, _)| index)
147 .collect();
148
149 let partitions = generate_bipartitions(&active_fragments);
150 for (r, s) in partitions {
151 let cost = data.delta_c(0, &r, &s);
152 dp.insert((0, r.clone(), s.clone()), cost);
153 backtrack.insert((0, r.clone(), s.clone()), None);
154 }
155
156 (dp, backtrack)
157}
158
159fn update_dp_old(
161 data: &WMECData,
162 dp: &mut HashMap<(usize, BTreeSet<usize>, BTreeSet<usize>), u32>,
163 backtrack: &mut HashMap<
164 (usize, BTreeSet<usize>, BTreeSet<usize>),
165 Option<(BTreeSet<usize>, BTreeSet<usize>)>,
166 >,
167 snp: usize,
168) {
169 let active_fragments: BTreeSet<usize> = data
170 .reads
171 .iter()
172 .enumerate()
173 .filter(|(_, read)| read[snp].is_some()) .map(|(index, _)| index)
175 .collect();
176 let partitions = generate_bipartitions(&active_fragments);
177
178 for (r, s) in &partitions {
179 let delta_cost = data.delta_c(snp, r, s);
180 let mut min_cost = u32::MAX;
181 let mut best_bipartition = None;
182
183 let prev_active_fragments: BTreeSet<usize> = data
184 .reads
185 .iter()
186 .enumerate()
187 .filter(|(_, read)| read[snp - 1].is_some()) .map(|(index, _)| index)
189 .collect();
190
191 for (prev_r, prev_s) in generate_bipartitions(&prev_active_fragments) {
192 let r_compatible = r
193 .intersection(&prev_active_fragments)
194 .all(|&x| prev_r.contains(&x));
195 let s_compatible = s
196 .intersection(&prev_active_fragments)
197 .all(|&x| prev_s.contains(&x));
198
199 if r_compatible && s_compatible {
200 if let Some(&prev_cost) = dp.get(&(snp - 1, prev_r.clone(), prev_s.clone())) {
201 let current_cost = delta_cost + prev_cost;
202
203 if current_cost < min_cost {
204 min_cost = current_cost;
205 best_bipartition = Some((prev_r.clone(), prev_s.clone()));
206 }
207 }
208 }
209 }
210
211 dp.insert((snp, r.clone(), s.clone()), min_cost);
212 backtrack.insert((snp, r.clone(), s.clone()), best_bipartition);
213 }
214}
215
216fn update_dp(
217 data: &WMECData,
218 dp: &mut HashMap<(usize, BTreeSet<usize>, BTreeSet<usize>), u32>,
219 backtrack: &mut HashMap<
220 (usize, BTreeSet<usize>, BTreeSet<usize>),
221 Option<(BTreeSet<usize>, BTreeSet<usize>)>,
222 >,
223 snp: usize,
224) {
225 let active_fragments: BTreeSet<usize> = data
226 .reads
227 .iter()
228 .enumerate()
229 .filter(|(_, read)| read[snp].is_some())
230 .map(|(index, _)| index)
231 .collect();
232 let partitions = generate_bipartitions(&active_fragments);
233
234 let prev_active_fragments: BTreeSet<usize> = data
236 .reads
237 .iter()
238 .enumerate()
239 .filter(|(_, read)| read[snp - 1].is_some())
240 .map(|(index, _)| index)
241 .collect();
242
243 let results: Vec<_> = partitions
245 .par_iter()
246 .map(|(r, s)| {
247 let delta_cost = data.delta_c(snp, r, s);
248 let mut min_cost = u32::MAX;
249 let mut best_bipartition = None;
250
251 for (prev_r, prev_s) in generate_bipartitions(&prev_active_fragments) {
252 let r_compatible = r
253 .intersection(&prev_active_fragments)
254 .all(|&x| prev_r.contains(&x));
255 let s_compatible = s
256 .intersection(&prev_active_fragments)
257 .all(|&x| prev_s.contains(&x));
258
259 if r_compatible && s_compatible {
260 if let Some(&prev_cost) = dp.get(&(snp - 1, prev_r.clone(), prev_s.clone())) {
261 let current_cost = delta_cost + prev_cost;
262
263 if current_cost < min_cost {
264 min_cost = current_cost;
265 best_bipartition = Some((prev_r.clone(), prev_s.clone()));
266 }
267 }
268 }
269 }
270
271 ((r.clone(), s.clone()), (min_cost, best_bipartition))
272 })
273 .collect();
274
275 for ((r, s), (min_cost, best_bipartition)) in results {
277 dp.insert((snp, r.clone(), s.clone()), min_cost);
278 backtrack.insert((snp, r, s), best_bipartition);
279 }
280}
281
282fn backtrack_haplotypes(
283 data: &WMECData,
284 dp: &HashMap<(usize, BTreeSet<usize>, BTreeSet<usize>), u32>,
285 backtrack: &HashMap<
286 (usize, BTreeSet<usize>, BTreeSet<usize>),
287 Option<(BTreeSet<usize>, BTreeSet<usize>)>,
288 >,
289) -> (Vec<u8>, Vec<u8>, BTreeSet<usize>, BTreeSet<usize>) {
290 let mut best_cost = u32::MAX;
291 let mut best_bipartition = None;
292
293 let final_active_fragments: BTreeSet<usize> = data
295 .reads
296 .iter()
297 .enumerate()
298 .filter(|(_, read)| read[data.num_snps - 1].is_some())
299 .map(|(index, _)| index)
300 .collect();
301
302 for (r, s) in generate_bipartitions(&final_active_fragments) {
303 if let Some(&cost) = dp.get(&(data.num_snps - 1, r.clone(), s.clone())) {
304 if cost < best_cost {
305 best_cost = cost;
306 best_bipartition = Some((r, s));
307 }
308 }
309 }
310
311 let mut haplotype1 = vec![0; data.num_snps];
312 let mut haplotype2 = vec![0; data.num_snps];
313 let mut current_bipartition = best_bipartition.unwrap();
314
315 for snp in (0..data.num_snps).rev() {
316 let (r, s) = ¤t_bipartition;
317 let (w0_r, w1_r) = data.compute_costs(snp, r);
318 let (w0_s, w1_s) = data.compute_costs(snp, s);
319
320 if w0_r + w1_s <= w1_r + w0_s {
321 haplotype1[snp] = 0;
322 haplotype2[snp] = 1;
323 } else {
324 haplotype1[snp] = 1;
325 haplotype2[snp] = 0;
326 }
327
328 if snp > 0 {
329 if let Some(prev_bipartition) = backtrack
330 .get(&(snp, r.clone(), s.clone()))
331 .and_then(|x| x.as_ref())
332 {
333 current_bipartition = prev_bipartition.clone();
334 } else {
335 panic!("No valid previous bipartition found for SNP {snp}");
337 }
338 }
339 }
340
341 crate::elog!("wmec score: {}", best_cost);
342
343 (
344 haplotype1,
345 haplotype2,
346 current_bipartition.0,
347 current_bipartition.1,
348 )
349}
350
351#[must_use]
353pub fn phase(data: &WMECData) -> (Vec<u8>, Vec<u8>, BTreeSet<usize>, BTreeSet<usize>) {
354 let (mut dp, mut backtrack) = initialize_dp(data);
355
356 for snp in 1..data.num_snps {
357 update_dp(data, &mut dp, &mut backtrack, snp);
358 }
359
360 backtrack_haplotypes(data, &dp, &backtrack)
361}
362
363#[must_use]
364pub fn phase_all(
365 data: &WMECData,
366 window: usize,
367 stride: usize,
368) -> (Vec<u8>, Vec<u8>, BTreeSet<usize>, BTreeSet<usize>) {
369 data.write_reads_matrix("mat.tsv");
370
371 let mut windows: Vec<_> = (0..data.num_snps)
373 .step_by(stride)
374 .map(|start| (start, (start + window).min(data.num_snps)))
375 .collect();
376
377 if windows.len() > 1 {
379 if let Some(&(_, prev_end)) = windows.get(windows.len() - 2) {
380 if let Some(&(_, last_end)) = windows.last() {
381 if last_end == prev_end {
382 windows.pop();
383 }
384 }
385 }
386 }
387
388 let pb = crate::utils::default_bounded_progress_bar("Processing windows", windows.len() as u64);
389
390 let haplotype_pairs: Vec<_> = windows
392 .par_iter()
394 .progress_with(pb)
395 .map(|&(start, end)| {
396 let (window_indices, window_reads): (Vec<usize>, Vec<Vec<Option<u8>>>) = data
397 .reads
398 .iter()
399 .zip(data.confidences.iter())
400 .enumerate()
401 .map(|(read_idx, (read, confidence))| {
402 let window_read = read[start..end].to_vec();
403 let window_confidence = confidence[start..end].to_vec();
404
405 let none_count = window_read.iter().filter(|x| x.is_none()).count();
406 let lowqual_count = window_confidence
407 .iter()
408 .filter(|&&x| x.is_some() && x.unwrap() == 0)
409 .count();
410
411 (
412 none_count + lowqual_count,
413 read_idx,
414 window_read,
415 window_confidence,
416 )
417 })
418 .collect::<Vec<_>>()
419 .into_iter()
420 .sorted_by_key(|(bad_count, _, _, _)| *bad_count)
421 .take(12)
422 .map(|(_, read_idx, window_read, _)| (read_idx, window_read))
423 .unzip();
424
425 let (_, window_confidences): (Vec<usize>, Vec<Vec<Option<u32>>>) = data
426 .reads
427 .iter()
428 .zip(data.confidences.iter())
429 .enumerate()
430 .map(|(read_idx, (read, confidence))| {
431 let window_read = read[start..end].to_vec();
432 let window_confidence = confidence[start..end].to_vec();
433
434 let none_count = window_read.iter().filter(|x| x.is_none()).count();
435 let lowqual_count = window_confidence
436 .iter()
437 .filter(|&&x| x.is_some() && x.unwrap() == 0)
438 .count();
439
440 (
441 none_count + lowqual_count,
442 read_idx,
443 window_read,
444 window_confidence,
445 )
446 })
447 .collect::<Vec<_>>()
448 .into_iter()
449 .sorted_by_key(|(bad_count, _, _, _)| *bad_count)
450 .take(12)
451 .map(|(_, read_idx, _, window_confidences)| (read_idx, window_confidences))
452 .unzip();
453
454 crate::elog!("window_reads: {} {} {}", start, end, window_reads.len());
455 crate::elog!("window_indices: {} {} {:?}", start, end, window_indices);
456
457 let window_data = WMECData::new(window_reads, window_confidences);
458
459 let (mut dp, mut backtrack) = initialize_dp(&window_data);
460 for snp in 1..window_data.num_snps {
461 update_dp(&window_data, &mut dp, &mut backtrack, snp);
462 }
463
464 let (hap1, hap2, part1, part2) = backtrack_haplotypes(&window_data, &dp, &backtrack);
465
466 let re1: BTreeSet<_> = part1.iter().map(|&i| window_indices[i as usize]).collect();
467 let re2: BTreeSet<_> = part2.iter().map(|&i| window_indices[i as usize]).collect();
468
469 crate::elog!("window_re1: {} {} {:?} {:?}", start, end, part1, re1);
470 crate::elog!("window_re2: {} {} {:?} {:?}", start, end, part2, re2);
471
472 (hap1, hap2, re1, re2)
473 })
474 .collect();
475
476 let mut haplotype1 = Vec::new();
477 let mut haplotype2 = Vec::new();
478 let mut part1 = BTreeSet::new();
479 let mut part2 = BTreeSet::new();
480
481 let overlap = window - stride;
482
483 let mut all_read_assignments = BTreeMap::new();
484
485 let mut i = 0;
486 for (hap1, hap2, reads1, reads2) in haplotype_pairs {
487 i += 1;
492
493 if haplotype1.len() == 0 {
494 for read_num in reads1.clone() {
495 all_read_assignments
496 .entry(read_num)
497 .or_insert(vec![])
498 .push(1);
499 }
500
501 for read_num in reads2.clone() {
502 all_read_assignments
503 .entry(read_num)
504 .or_insert(vec![])
505 .push(2);
506 }
507
508 haplotype1 = hap1;
509 haplotype2 = hap2;
510 part1 = reads1;
511 part2 = reads2;
512 } else {
513 crate::elog!(
515 "haplotypes: {} {} {}",
516 haplotype1.len(),
517 haplotype2.len(),
518 overlap
519 );
520
521 let new_overlap_len = std::cmp::min(overlap, hap1.len());
522 let h1_overlap = &haplotype1[haplotype1.len() - new_overlap_len..];
523 let h2_overlap = &haplotype2[haplotype2.len() - new_overlap_len..];
524 let new_overlap = &hap1[..new_overlap_len];
525
526 let h1_matches = h1_overlap
528 .iter()
529 .zip(new_overlap.iter())
530 .filter(|(a, b)| a == b)
531 .count();
532 let h2_matches = h2_overlap
533 .iter()
534 .zip(new_overlap.iter())
535 .filter(|(a, b)| a == b)
536 .count();
537
538 if h1_matches >= h2_matches {
540 haplotype1.extend_from_slice(&hap1[std::cmp::min(overlap, hap1.len())..]);
541 haplotype2.extend_from_slice(&hap2[std::cmp::min(overlap, hap2.len())..]);
542
543 for read_num in reads1.clone() {
544 all_read_assignments
545 .entry(read_num)
546 .or_insert(vec![])
547 .push(1);
548 }
549
550 for read_num in reads2.clone() {
551 all_read_assignments
552 .entry(read_num)
553 .or_insert(vec![])
554 .push(2);
555 }
556 } else {
557 haplotype1.extend_from_slice(&hap2[std::cmp::min(overlap, hap2.len())..]);
558 haplotype2.extend_from_slice(&hap1[std::cmp::min(overlap, hap1.len())..]);
559
560 for read_num in reads1.clone() {
561 all_read_assignments
562 .entry(read_num)
563 .or_insert(vec![])
564 .push(2);
565 }
566
567 for read_num in reads2.clone() {
568 all_read_assignments
569 .entry(read_num)
570 .or_insert(vec![])
571 .push(1);
572 }
573 }
574 }
575 }
576
577 let mut final_part1 = BTreeSet::new();
578 let mut final_part2 = BTreeSet::new();
579
580 crate::elog!("all_read_assignments: {:?}", &all_read_assignments);
581
582 for (read_num, assignments) in all_read_assignments {
583 let most_common = assignments
585 .iter()
586 .fold(HashMap::new(), |mut counts, &value| {
587 *counts.entry(value).or_insert(0) += 1;
588 counts
589 })
590 .into_iter()
591 .max_by_key(|&(_, count)| count)
592 .map(|(value, _)| value)
593 .unwrap_or(0); if most_common == 1 {
596 final_part1.insert(read_num);
597 } else if most_common == 2 {
598 final_part2.insert(read_num);
599 }
600 }
601
602 (haplotype1, haplotype2, final_part1, final_part2)
603}
604
#[cfg(test)]
mod tests {
    use super::*;

    /// Four reads over two SNPs, shared by the two manuscript-example tests.
    fn manuscript_example() -> WMECData {
        let reads = vec![
            vec![Some(0), None],
            vec![Some(1), Some(0)],
            vec![Some(1), Some(1)],
            vec![None, Some(0)],
        ];

        let confidences = vec![
            vec![Some(5), None],
            vec![Some(3), Some(2)],
            vec![Some(6), Some(1)],
            vec![None, Some(2)],
        ];

        WMECData::new(reads, confidences)
    }

    /// Collects read indices into the set type used as a DP key.
    fn set_of(items: Vec<usize>) -> BTreeSet<usize> {
        items.into_iter().collect()
    }

    /// Column-0 costs for a selection of bipartitions of reads 0, 1 and 2.
    #[test]
    fn test_initialization_manuscript_example_1() {
        let data = manuscript_example();

        let expected_values = vec![
            (vec![0, 1, 2], vec![], 5),
            (vec![0, 1], vec![2], 3),
            (vec![0, 2], vec![1], 5),
            (vec![1, 2], vec![0], 0),
        ];

        for (r, s, expected_cost) in expected_values {
            let r_set = set_of(r);
            let s_set = set_of(s);
            let cost = data.delta_c(0, &r_set, &s_set);
            assert_eq!(
                cost, expected_cost,
                "Cost for partition ({:?}, {:?}) does not match expected",
                r_set, s_set
            );
        }
    }

    /// Accumulated costs at column 1 after running the recurrence.
    #[test]
    fn test_recurrence_manuscript_example_2() {
        let data = manuscript_example();

        let expected_values = vec![
            (vec![1, 2, 3], vec![], 1),
            (vec![1, 2], vec![3], 1),
            (vec![1, 3], vec![2], 3),
            (vec![2, 3], vec![1], 4),
        ];

        let (mut dp, mut backtrack) = initialize_dp(&data);
        for snp in 1..data.num_snps {
            update_dp(&data, &mut dp, &mut backtrack, snp);
        }

        for (r, s, expected_cost) in expected_values {
            let r_set = set_of(r);
            let s_set = set_of(s);
            // A missing state is reported as `u32::MAX` so the assertion fails loudly.
            let actual_cost = dp
                .get(&(1, r_set.clone(), s_set.clone()))
                .copied()
                .unwrap_or(u32::MAX);
            assert_eq!(
                actual_cost, expected_cost,
                "Cost for partition ({:?}, {:?}) does not match expected",
                r_set, s_set
            );
        }
    }

    /// End-to-end phasing of a six-read, five-SNP matrix.
    #[test]
    fn test_whatshap_manuscript_figure_1() {
        let reads = vec![
            vec![Some(0), None, None, Some(0), Some(1)],
            vec![Some(1), Some(0), Some(0), None, None],
            vec![Some(1), Some(1), Some(0), None, None],
            vec![None, Some(0), Some(0), Some(1), None],
            vec![None, None, Some(1), Some(0), Some(1)],
            vec![None, None, None, Some(0), Some(1)],
        ];

        // NOTE(review): read 3 carries a confidence (31) at the last SNP although its
        // allele there is `None`; `compute_costs` skips positions without an allele.
        let confidences = vec![
            vec![Some(32), None, None, Some(34), Some(17)],
            vec![Some(15), Some(25), Some(13), None, None],
            vec![Some(7), Some(3), Some(15), None, None],
            vec![None, Some(12), Some(23), Some(29), Some(31)],
            vec![None, None, Some(25), Some(17), Some(19)],
            vec![None, None, None, Some(20), Some(10)],
        ];

        let data = WMECData::new(reads, confidences);

        // Only the haplotypes are checked; the read sets are not asserted on.
        let (haplotype1, haplotype2, _, _) = phase(&data);

        assert_eq!(
            haplotype1,
            vec![0, 1, 1, 0, 1],
            "Haplotype 1 does not match expected"
        );
        assert_eq!(
            haplotype2,
            vec![1, 0, 0, 1, 0],
            "Haplotype 2 does not match expected"
        );
    }
}