skydive/
mldbg.rs

1use std::{
2    collections::{HashMap, HashSet},
3    path::PathBuf,
4};
5
6use gbdt::{decision_tree::Data, gradient_boost::GBDT};
7use itertools::Itertools;
8
9use crate::{ldbg::LdBG, record::Record};
10
11/// Represents a multi-color linked de Bruijn graph, all built with same k-mer size.
12#[derive(Debug)]
13pub struct MLdBG {
14    pub kmer_size: usize,
15    pub ldbgs: Vec<LdBG>,
16    pub scores: HashMap<Vec<u8>, f32>,
17}
18
19impl MLdBG {
20    /// Create an empty multi-color `LdBG`.
21    #[must_use]
22    pub fn new(kmer_size: usize) -> Self {
23        MLdBG {
24            kmer_size,
25            ldbgs: Vec::new(),
26            scores: HashMap::new(),
27        }
28    }
29
30    /// Create an `MLdBG` from a vector of `LdBGs`.
31    ///
32    /// # Arguments
33    ///
34    /// * `ldbgs` - A vector of `LdBGs`.
35    ///
36    /// # Returns
37    ///
38    /// A new `MLdBG`.
39    ///
40    /// # Panics
41    ///
42    /// This function will panic if the `kmer_size` of the `LdBG` being added does not match the
43    /// `kmer_size` of the `MLdBG`. Specifically, it will panic at the `assert!` statement if the
44    /// condition `ldbg.kmer_size == self.kmer_size` is not met.
45    #[must_use]
46    pub fn from_ldbgs(ldbgs: Vec<LdBG>) -> Self {
47        let kmer_size = ldbgs[0].kmer_size;
48
49        for ldbg in &ldbgs {
50            assert!(
51                ldbg.kmer_size == kmer_size,
52                "The k-mer size of the LdBG does not match the k-mer size of the MLdBG."
53            );
54        }
55
56        MLdBG {
57            kmer_size,
58            ldbgs,
59            scores: HashMap::new(),
60        }
61    }
62
63    /// Add a `LdBG` to the `MLdBG`.
64    ///
65    /// # Arguments
66    ///
67    /// * `ldbg` - The `LdBG` to add.
68    ///
69    /// # Panics
70    ///
71    /// The `push` function will panic if the `kmer_size` of the `ldbg` being added does not match
72    /// the `kmer_size` of the `MLdBG`. Specifically, it will panic at the `assert!` statement if
73    /// the condition `ldbg.kmer_size == self.kmer_size` is not met.
74    pub fn push(&mut self, ldbg: LdBG) {
75        assert!(
76            ldbg.kmer_size == self.kmer_size,
77            "The k-mer size of the LdBG does not match the k-mer size of the MLdBG."
78        );
79
80        self.ldbgs.push(ldbg);
81    }
82
83    /// Insert a `LdBG` at a specific position in the `MLdBG`.
84    pub fn insert(&mut self, index: usize, ldbg: LdBG) {
85        if index <= self.ldbgs.len() {
86            self.ldbgs.insert(index, ldbg);
87        }
88    }
89
90    /// Append a `LdBG` to the end of the `MLdBG`.
91    pub fn append(&mut self, ldbg: LdBG) {
92        self.ldbgs.push(ldbg);
93    }
94
95    /// Append an `LdBG` to the end of the `MLdBG`, created anew from a fasta file.
96    pub fn append_from_file(&mut self, name: String, seq_path: &PathBuf) {
97        let l = LdBG::from_file(name, self.kmer_size, seq_path);
98        self.ldbgs.push(l);
99    }
100
101    /// Append an `LdBG` to the end of the `MLdBG`, created from a filtered set of
102    /// sequences in a fasta file.
103    ///
104    /// This function reads sequences from a specified fasta file, filters them based on a provided
105    /// condition, and then creates an `LdBG` from the filtered sequences. The new `LdBG` is appended to
106    /// the `MLdBG`.
107    ///
108    /// # Arguments
109    ///
110    /// * `name` - A string representing the name of the new `LdBG`.
111    /// * `seq_path` - A reference to a `PathBuf` representing the path to the fasta file containing the sequences.
112    /// * `filter` - A closure that takes a fasta record and a set of kmers and returns a boolean indicating
113    ///              whether the record should be included.
114    ///
115    /// # Panics
116    ///
117    /// This function will panic if it cannot read the fasta file.
118    pub fn append_from_filtered_file<F>(&mut self, name: String, seq_path: &PathBuf, filter: F)
119    where
120        F: Fn(&bio::io::fasta::Record, &HashSet<Vec<u8>>) -> bool,
121    {
122        let reader = bio::io::fasta::Reader::from_file(seq_path).unwrap();
123        let all_reads: Vec<bio::io::fasta::Record> = reader.records().map(|r| r.unwrap()).collect();
124
125        let kmer_union = self.union_of_kmers();
126
127        let filtered_reads: Vec<Vec<u8>> = all_reads
128            .into_iter()
129            .filter(|r| filter(r, &kmer_union))
130            .map(|r| r.seq().to_vec())
131            .collect();
132
133        let l = LdBG::from_sequences(name, self.kmer_size, &filtered_reads);
134        self.ldbgs.push(l);
135    }
136
137    /// Scores the k-mers in the `MLdBG` using a pre-trained Gradient Boosting
138    /// Decision Tree (`GBDT`) model.
139    ///
140    /// This function loads a `GBDT` model from the specified path and uses it to predict scores
141    /// for each k-mer in the union of k-mers from all `LdBGs` in the `MLdBG`. The scores are stored
142    /// in the `scores` field of the `MLdBG`.
143    ///
144    /// # Arguments
145    ///
146    /// * `model_path` - A path to the file containing the pre-trained `GBDT` model.
147    ///
148    /// # Returns
149    ///
150    /// An updated `MLdBG` with the k-mer scores populated.
151    ///
152    /// # Panics
153    ///
154    /// This function will panic if the model file cannot be loaded or if the prediction fails.
155    #[must_use]
156    pub fn score_kmers(mut self, model_path: &PathBuf) -> Self {
157        let gbdt = GBDT::load_model(model_path.to_str().unwrap()).unwrap();
158
159        self.scores = self
160            .union_of_kmers()
161            .iter()
162            .map(|cn_kmer| {
163                let compressed_len = crate::utils::homopolymer_compressed(cn_kmer).len();
164                let compressed_len_diff = (cn_kmer.len() - compressed_len) as f32;
165                let entropy = crate::utils::shannon_entropy(cn_kmer);
166                let gc_content = crate::utils::gc_content(cn_kmer);
167
168                let lcov = self.ldbgs[0]
169                    .kmers
170                    .get(cn_kmer)
171                    .map_or(0, |record| record.coverage());
172
173                let scov_fw = self.ldbgs[1]
174                    .kmers
175                    .get(cn_kmer)
176                    .map_or(0, |sr| sr.fw_coverage());
177                let scov_rc = self.ldbgs[1]
178                    .kmers
179                    .get(cn_kmer)
180                    .map_or(0, |sr| sr.rc_coverage());
181                let scov_total = scov_fw + scov_rc;
182                let strand_ratio = if scov_total > 0 {
183                    (scov_fw as f32).max(scov_rc as f32) / scov_total as f32
184                } else {
185                    0.5
186                };
187
188                let features = vec![
189                    if lcov > 0 { 1.0 } else { 0.0 }, // present in long reads
190                    scov_total as f32,                // coverage in short reads
191                    strand_ratio as f32, // measure of strand bias (0.5 = balanced, 1.0 = all on one strand)
192                    compressed_len_diff, // homopolymer compression length difference
193                    entropy,             // shannon entropy
194                    gc_content,          // gc content
195                ];
196
197                let data = Data::new_test_data(features, None);
198                let prediction = *gbdt.predict(&vec![data]).first().unwrap_or(&0.0);
199
200                (cn_kmer.clone(), prediction)
201            })
202            .collect();
203
204        self
205    }
206
207    fn distance_to_a_contig_end(
208        contigs: &Vec<Vec<u8>>,
209        kmer_size: usize,
210    ) -> HashMap<Vec<u8>, usize> {
211        let mut distances = HashMap::new();
212
213        for contig in contigs {
214            for (distance_from_start, cn_kmer) in contig
215                .windows(kmer_size)
216                .map(crate::utils::canonicalize_kmer)
217                .enumerate()
218            {
219                let distance_from_end = contig.len() - distance_from_start - kmer_size;
220
221                distances.insert(
222                    cn_kmer,
223                    if distance_from_start < distance_from_end {
224                        distance_from_start
225                    } else {
226                        distance_from_end
227                    },
228                );
229            }
230        }
231
232        distances
233    }
234
235    pub fn collapse(&mut self) -> LdBG {
236        let mut ldbg = LdBG::new(self.ldbgs[0].name.clone(), self.kmer_size);
237
238        for cn_kmer in self.union_of_kmers() {
239            let coverage = self
240                .ldbgs
241                .iter()
242                .map(|ldbg| {
243                    ldbg.kmers
244                        .get(&cn_kmer)
245                        .map_or(0, |record| record.coverage())
246                })
247                .sum::<u16>();
248
249            let sources = self
250                .ldbgs
251                .iter()
252                .enumerate()
253                .filter_map(|(index, ldbg)| {
254                    if ldbg.kmers.contains_key(&cn_kmer) {
255                        Some(index)
256                    } else {
257                        None
258                    }
259                })
260                .collect::<Vec<_>>();
261
262            ldbg.kmers
263                .insert(cn_kmer.clone(), Record::new(coverage, None));
264            ldbg.sources.insert(cn_kmer.clone(), sources);
265        }
266
267        ldbg.scores = self.scores.clone();
268
269        for l in &self.ldbgs {
270            for cn_kmer in l.noise.iter() {
271                ldbg.noise.insert(cn_kmer.clone());
272            }
273        }
274
275        ldbg.infer_edges();
276
277        ldbg
278    }
279
280    /// Filter reads from a fasta file based on a condition.
281    ///
282    /// # Arguments
283    ///
284    /// * `seq_path` - A path to the fasta file containing the reads.
285    /// * `filter` - A closure that takes a fasta record and a set of kmers and returns a boolean.
286    ///
287    /// # Returns
288    ///
289    /// A vector of the kept reads.
290    ///
291    /// # Panics
292    ///
293    /// This function will panic if it cannot read the fasta file.
294    pub fn filter_reads<F>(&mut self, seq_path: &PathBuf, filter: F) -> Vec<Vec<u8>>
295    where
296        F: Fn(&bio::io::fasta::Record, &HashSet<Vec<u8>>) -> bool,
297    {
298        let reader = bio::io::fasta::Reader::from_file(seq_path).unwrap();
299        let all_reads: Vec<bio::io::fasta::Record> = reader.records().map(|r| r.unwrap()).collect();
300
301        let kmer_union = self.union_of_kmers();
302
303        let filtered_reads: Vec<Vec<u8>> = all_reads
304            .into_iter()
305            .filter(|r| filter(r, &kmer_union))
306            .map(|r| r.seq().to_vec())
307            .collect();
308
309        filtered_reads
310    }
311
312    /// Get the union of kmers from all `LdBGs` in the `MLdBG`.
313    #[must_use]
314    pub fn union_of_kmers(&self) -> HashSet<Vec<u8>> {
315        let mut kmer_union = HashSet::new();
316
317        for ldbg in &self.ldbgs {
318            for kmer in ldbg.kmers.keys() {
319                kmer_union.insert(kmer.clone());
320            }
321        }
322
323        kmer_union
324    }
325
326    /// Get a reference to the `LdBG` at a specific index.
327    #[must_use]
328    pub fn get(&self, index: usize) -> Option<&LdBG> {
329        self.ldbgs.get(index)
330    }
331
332    /// Returns an iterator over the `LdBGs` in the `MLdBG`.
333    pub fn iter(&self) -> std::slice::Iter<LdBG> {
334        self.ldbgs.iter()
335    }
336
337    /// Returns a mutable iterator over the `LdBGs` in the `MLdBG`.
338    pub fn iter_mut(&mut self) -> std::slice::IterMut<LdBG> {
339        self.ldbgs.iter_mut()
340    }
341
342    /// Clear all `LdBGs` from the `MLdBG`.
343    pub fn clear(&mut self) {
344        self.ldbgs.clear();
345    }
346
347    /// Remove a `LdBG` from the `MLdBG` by index.
348    pub fn remove(&mut self, index: usize) -> Option<LdBG> {
349        if index < self.ldbgs.len() {
350            Some(self.ldbgs.remove(index))
351        } else {
352            None
353        }
354    }
355
356    /// Returns the number of `LdBGs` in the `MLdBG`.
357    #[must_use]
358    pub fn len(&self) -> usize {
359        self.ldbgs.len()
360    }
361
362    /// Check if the `MLdBG` is empty.
363    #[must_use]
364    pub fn is_empty(&self) -> bool {
365        self.ldbgs.is_empty()
366    }
367
368    /// Remove and return the last `LdBG` from the `MLdBG`.
369    pub fn pop(&mut self) -> Option<LdBG> {
370        self.ldbgs.pop()
371    }
372
373    /// Remove and return the `LdBG` from the `MLdBG` if it matches a certain condition.
374    pub fn pop_if<F>(&mut self, condition: F) -> Option<LdBG>
375    where
376        F: Fn(&LdBG) -> bool,
377    {
378        let index = self.ldbgs.iter().position(condition);
379        if let Some(index) = index {
380            Some(self.ldbgs.remove(index))
381        } else {
382            None
383        }
384    }
385}