1use std::{
2 collections::{HashMap, HashSet},
3 path::PathBuf,
4};
5
6use gbdt::{decision_tree::Data, gradient_boost::GBDT};
7use itertools::Itertools;
8
9use crate::{ldbg::LdBG, record::Record};
10
11#[derive(Debug)]
13pub struct MLdBG {
14 pub kmer_size: usize,
15 pub ldbgs: Vec<LdBG>,
16 pub scores: HashMap<Vec<u8>, f32>,
17}
18
19impl MLdBG {
20 #[must_use]
22 pub fn new(kmer_size: usize) -> Self {
23 MLdBG {
24 kmer_size,
25 ldbgs: Vec::new(),
26 scores: HashMap::new(),
27 }
28 }
29
30 #[must_use]
46 pub fn from_ldbgs(ldbgs: Vec<LdBG>) -> Self {
47 let kmer_size = ldbgs[0].kmer_size;
48
49 for ldbg in &ldbgs {
50 assert!(
51 ldbg.kmer_size == kmer_size,
52 "The k-mer size of the LdBG does not match the k-mer size of the MLdBG."
53 );
54 }
55
56 MLdBG {
57 kmer_size,
58 ldbgs,
59 scores: HashMap::new(),
60 }
61 }
62
63 pub fn push(&mut self, ldbg: LdBG) {
75 assert!(
76 ldbg.kmer_size == self.kmer_size,
77 "The k-mer size of the LdBG does not match the k-mer size of the MLdBG."
78 );
79
80 self.ldbgs.push(ldbg);
81 }
82
83 pub fn insert(&mut self, index: usize, ldbg: LdBG) {
85 if index <= self.ldbgs.len() {
86 self.ldbgs.insert(index, ldbg);
87 }
88 }
89
90 pub fn append(&mut self, ldbg: LdBG) {
92 self.ldbgs.push(ldbg);
93 }
94
95 pub fn append_from_file(&mut self, name: String, seq_path: &PathBuf) {
97 let l = LdBG::from_file(name, self.kmer_size, seq_path);
98 self.ldbgs.push(l);
99 }
100
101 pub fn append_from_filtered_file<F>(&mut self, name: String, seq_path: &PathBuf, filter: F)
119 where
120 F: Fn(&bio::io::fasta::Record, &HashSet<Vec<u8>>) -> bool,
121 {
122 let reader = bio::io::fasta::Reader::from_file(seq_path).unwrap();
123 let all_reads: Vec<bio::io::fasta::Record> = reader.records().map(|r| r.unwrap()).collect();
124
125 let kmer_union = self.union_of_kmers();
126
127 let filtered_reads: Vec<Vec<u8>> = all_reads
128 .into_iter()
129 .filter(|r| filter(r, &kmer_union))
130 .map(|r| r.seq().to_vec())
131 .collect();
132
133 let l = LdBG::from_sequences(name, self.kmer_size, &filtered_reads);
134 self.ldbgs.push(l);
135 }
136
137 #[must_use]
156 pub fn score_kmers(mut self, model_path: &PathBuf) -> Self {
157 let gbdt = GBDT::load_model(model_path.to_str().unwrap()).unwrap();
158
159 self.scores = self
160 .union_of_kmers()
161 .iter()
162 .map(|cn_kmer| {
163 let compressed_len = crate::utils::homopolymer_compressed(cn_kmer).len();
164 let compressed_len_diff = (cn_kmer.len() - compressed_len) as f32;
165 let entropy = crate::utils::shannon_entropy(cn_kmer);
166 let gc_content = crate::utils::gc_content(cn_kmer);
167
168 let lcov = self.ldbgs[0]
169 .kmers
170 .get(cn_kmer)
171 .map_or(0, |record| record.coverage());
172
173 let scov_fw = self.ldbgs[1]
174 .kmers
175 .get(cn_kmer)
176 .map_or(0, |sr| sr.fw_coverage());
177 let scov_rc = self.ldbgs[1]
178 .kmers
179 .get(cn_kmer)
180 .map_or(0, |sr| sr.rc_coverage());
181 let scov_total = scov_fw + scov_rc;
182 let strand_ratio = if scov_total > 0 {
183 (scov_fw as f32).max(scov_rc as f32) / scov_total as f32
184 } else {
185 0.5
186 };
187
188 let features = vec![
189 if lcov > 0 { 1.0 } else { 0.0 }, scov_total as f32, strand_ratio as f32, compressed_len_diff, entropy, gc_content, ];
196
197 let data = Data::new_test_data(features, None);
198 let prediction = *gbdt.predict(&vec![data]).first().unwrap_or(&0.0);
199
200 (cn_kmer.clone(), prediction)
201 })
202 .collect();
203
204 self
205 }
206
207 fn distance_to_a_contig_end(
208 contigs: &Vec<Vec<u8>>,
209 kmer_size: usize,
210 ) -> HashMap<Vec<u8>, usize> {
211 let mut distances = HashMap::new();
212
213 for contig in contigs {
214 for (distance_from_start, cn_kmer) in contig
215 .windows(kmer_size)
216 .map(crate::utils::canonicalize_kmer)
217 .enumerate()
218 {
219 let distance_from_end = contig.len() - distance_from_start - kmer_size;
220
221 distances.insert(
222 cn_kmer,
223 if distance_from_start < distance_from_end {
224 distance_from_start
225 } else {
226 distance_from_end
227 },
228 );
229 }
230 }
231
232 distances
233 }
234
235 pub fn collapse(&mut self) -> LdBG {
236 let mut ldbg = LdBG::new(self.ldbgs[0].name.clone(), self.kmer_size);
237
238 for cn_kmer in self.union_of_kmers() {
239 let coverage = self
240 .ldbgs
241 .iter()
242 .map(|ldbg| {
243 ldbg.kmers
244 .get(&cn_kmer)
245 .map_or(0, |record| record.coverage())
246 })
247 .sum::<u16>();
248
249 let sources = self
250 .ldbgs
251 .iter()
252 .enumerate()
253 .filter_map(|(index, ldbg)| {
254 if ldbg.kmers.contains_key(&cn_kmer) {
255 Some(index)
256 } else {
257 None
258 }
259 })
260 .collect::<Vec<_>>();
261
262 ldbg.kmers
263 .insert(cn_kmer.clone(), Record::new(coverage, None));
264 ldbg.sources.insert(cn_kmer.clone(), sources);
265 }
266
267 ldbg.scores = self.scores.clone();
268
269 for l in &self.ldbgs {
270 for cn_kmer in l.noise.iter() {
271 ldbg.noise.insert(cn_kmer.clone());
272 }
273 }
274
275 ldbg.infer_edges();
276
277 ldbg
278 }
279
280 pub fn filter_reads<F>(&mut self, seq_path: &PathBuf, filter: F) -> Vec<Vec<u8>>
295 where
296 F: Fn(&bio::io::fasta::Record, &HashSet<Vec<u8>>) -> bool,
297 {
298 let reader = bio::io::fasta::Reader::from_file(seq_path).unwrap();
299 let all_reads: Vec<bio::io::fasta::Record> = reader.records().map(|r| r.unwrap()).collect();
300
301 let kmer_union = self.union_of_kmers();
302
303 let filtered_reads: Vec<Vec<u8>> = all_reads
304 .into_iter()
305 .filter(|r| filter(r, &kmer_union))
306 .map(|r| r.seq().to_vec())
307 .collect();
308
309 filtered_reads
310 }
311
312 #[must_use]
314 pub fn union_of_kmers(&self) -> HashSet<Vec<u8>> {
315 let mut kmer_union = HashSet::new();
316
317 for ldbg in &self.ldbgs {
318 for kmer in ldbg.kmers.keys() {
319 kmer_union.insert(kmer.clone());
320 }
321 }
322
323 kmer_union
324 }
325
326 #[must_use]
328 pub fn get(&self, index: usize) -> Option<&LdBG> {
329 self.ldbgs.get(index)
330 }
331
332 pub fn iter(&self) -> std::slice::Iter<LdBG> {
334 self.ldbgs.iter()
335 }
336
337 pub fn iter_mut(&mut self) -> std::slice::IterMut<LdBG> {
339 self.ldbgs.iter_mut()
340 }
341
342 pub fn clear(&mut self) {
344 self.ldbgs.clear();
345 }
346
347 pub fn remove(&mut self, index: usize) -> Option<LdBG> {
349 if index < self.ldbgs.len() {
350 Some(self.ldbgs.remove(index))
351 } else {
352 None
353 }
354 }
355
356 #[must_use]
358 pub fn len(&self) -> usize {
359 self.ldbgs.len()
360 }
361
362 #[must_use]
364 pub fn is_empty(&self) -> bool {
365 self.ldbgs.is_empty()
366 }
367
368 pub fn pop(&mut self) -> Option<LdBG> {
370 self.ldbgs.pop()
371 }
372
373 pub fn pop_if<F>(&mut self, condition: F) -> Option<LdBG>
375 where
376 F: Fn(&LdBG) -> bool,
377 {
378 let index = self.ldbgs.iter().position(condition);
379 if let Some(index) = index {
380 Some(self.ldbgs.remove(index))
381 } else {
382 None
383 }
384 }
385}