skydive/
stage.rs

1// Import the Result type from the anyhow crate for error handling.
2use anyhow::Result;
3use linked_hash_set::LinkedHashSet;
4use parquet::data_type::AsBytes;
5use rust_htslib::bam::record::Aux;
6
7// Import various standard library collections.
8use std::collections::{HashMap, HashSet};
9use std::env;
10use std::fs::File;
11use std::io::BufWriter;
12use std::path::PathBuf;
13
14// Import the Url type to work with URLs.
15use url::Url;
16
17// Import ExponentialBackoff for retrying operations.
18use backoff::ExponentialBackoff;
19
20// Import rayon's parallel iterator traits.
21// use rayon::prelude::*;
22use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
23
24// Import types from rust_htslib for working with BAM files.
25use bio::io::fasta;
26use rust_htslib::bam::{self, FetchDefinition, IndexedReader, Read};
27use rust_htslib::faidx::Reader;
28
29// Import functions for authorizing access to Google Cloud Storage.
30use crate::env::{gcs_authorize_data_access, local_guess_curl_ca_bundle};
31
32/// Function to open a BAM/CRAM file from a URL and cache its contents locally.
33///
34/// # Arguments
35///
36/// * `seqs_url` - A reference to a URL object representing the sequence file URL.
37///
38/// # Returns
39///
40/// An `IndexedReader` object representing the opened BAM/CRAM file.
41///
42/// # Errors
43///
44/// This function returns an error if the BAM/CRAM file cannot be opened.
45///
46/// # Panics
47///
48/// This function panics if the URL scheme is not recognized.
49pub fn open_bam(seqs_url: &Url) -> Result<IndexedReader> {
50    if seqs_url.to_string().starts_with("gs://") && env::var("GCS_OAUTH_TOKEN").is_err() {
51        gcs_authorize_data_access();
52    }
53
54    // Try to open the BAM file from the URL, with retries for authorization.
55    let bam = match IndexedReader::from_url(seqs_url) {
56        Ok(bam) => bam,
57        Err(_) => {
58            crate::elog!("Read '{}', attempt 2 (reauthorizing to GCS)", seqs_url);
59
60            // If opening fails, try authorizing access to Google Cloud Storage.
61            gcs_authorize_data_access();
62
63            // Try opening the BAM file again.
64            match IndexedReader::from_url(seqs_url) {
65                Ok(bam) => bam,
66                Err(_) => {
67                    crate::elog!("Read '{}', attempt 3 (overriding cURL CA bundle)", seqs_url);
68
69                    // If it still fails, guess the cURL CA bundle path.
70                    local_guess_curl_ca_bundle();
71
72                    // Try one last time to open the BAM file.
73                    IndexedReader::from_url(seqs_url)?
74                }
75            }
76        }
77    };
78
79    Ok(bam)
80}
81
82/// Function to open a FASTA file from a URL and cache its contents locally.
83///
84/// # Arguments
85///
86/// * `seqs_url` - A reference to a URL object representing the sequence file URL.
87///
88/// # Returns
89///
90/// A `Reader` object representing the opened FASTA file.
91///
92/// # Errors
93///
94/// This function returns an error if the FASTA file cannot be opened.
95///
96/// # Panics
97///
98/// This function panics if the URL scheme is not recognized.
99pub fn open_fasta(seqs_url: &Url) -> Result<Reader> {
100    if seqs_url.to_string().starts_with("gs://") && env::var("GCS_OAUTH_TOKEN").is_err() {
101        gcs_authorize_data_access();
102    }
103
104    // Try to open the FASTA file from the URL, with retries for authorization.
105    let fasta = match Reader::from_url(seqs_url) {
106        Ok(fasta) => fasta,
107        Err(_) => {
108            crate::elog!("Read '{}', attempt 2 (reauthorizing to GCS)", seqs_url);
109
110            // If opening fails, try authorizing access to Google Cloud Storage.
111            gcs_authorize_data_access();
112
113            // Try opening the BAM file again.
114            match Reader::from_url(seqs_url) {
115                Ok(bam) => bam,
116                Err(_) => {
117                    crate::elog!("Read '{}', attempt 3 (overriding cURL CA bundle)", seqs_url);
118
119                    // If it still fails, guess the cURL CA bundle path.
120                    local_guess_curl_ca_bundle();
121
122                    // Try one last time to open the FASTA file.
123                    Reader::from_url(seqs_url)?
124                }
125            }
126        }
127    };
128
129    Ok(fasta)
130}
131
132// Function to get a mapping between read group and sample name from a BAM header.
133fn get_rg_to_sm_mapping(bam: &IndexedReader) -> HashMap<String, String> {
134    let header = bam::Header::from_template(bam.header());
135
136    let rg_sm_map: HashMap<String, String> = header
137        .to_hashmap()
138        .into_iter()
139        .flat_map(|(_, records)| records)
140        .filter(|record| record.contains_key("ID") && record.contains_key("SM"))
141        .map(|record| (record["ID"].clone(), record["SM"].clone()))
142        .collect();
143
144    rg_sm_map
145}
146
147fn get_sm_name_from_rg(read: &bam::Record, rg_sm_map: &HashMap<String, String>) -> Result<String> {
148    let rg = read.aux(b"RG")?;
149
150    if let Aux::String(v) = rg {
151        if let Some(sm) = rg_sm_map.get(v) {
152            Ok(sm.to_owned())
153        } else {
154            Err(anyhow::anyhow!(
155                "Sample name not found for read group: {}",
156                v
157            ))
158        }
159    } else {
160        Err(anyhow::anyhow!("Read group is not a string"))
161    }
162}
163
164// Function to extract seqs from a BAM file within a specified genomic region.
165pub fn extract_aligned_bam_reads(
166    _basename: &str,
167    bam: &mut IndexedReader,
168    chr: &str,
169    start: &u64,
170    stop: &u64,
171    name: &str,
172    haplotype: Option<u8>,
173) -> Result<Vec<fasta::Record>> {
174    let rg_sm_map = get_rg_to_sm_mapping(bam);
175
176    let mut bmap = HashMap::new();
177
178    let _ = bam.fetch(((*chr).as_bytes(), *start, *stop));
179    for p in bam.pileup() {
180        let pileup = p.unwrap();
181
182        if *start <= (pileup.pos() as u64) && (pileup.pos() as u64) < *stop {
183            // for alignment in pileup.alignments().filter(|a| !a.record().is_secondary()) {
184            for (i, alignment) in pileup.alignments().enumerate().filter(|(_, a)| {
185                haplotype.is_none()
186                    || a.record()
187                        .aux(b"HP")
188                        .ok()
189                        .map(|aux| match aux {
190                            Aux::U8(v) => v == haplotype.unwrap(),
191                            _ => false,
192                        })
193                        .unwrap_or(false)
194            }) {
195                let qname = String::from_utf8_lossy(alignment.record().qname()).into_owned();
196                let sm = match get_sm_name_from_rg(&alignment.record(), &rg_sm_map) {
197                    Ok(a) => a,
198                    Err(_) => String::from("unknown"),
199                };
200
201                let is_secondary = alignment.record().is_secondary();
202                let is_supplementary = alignment.record().is_supplementary();
203                let seq_name = format!("{qname}|{name}|{sm}|{i}|{is_secondary}|{is_supplementary}");
204
205                // crate::elog!("{}", seq_name);
206
207                if !bmap.contains_key(&seq_name) {
208                    bmap.insert(seq_name.clone(), String::new());
209                }
210
211                if !alignment.is_del() && !alignment.is_refskip() {
212                    let a = alignment.record().seq()[alignment.qpos().unwrap()];
213
214                    bmap.get_mut(&seq_name).unwrap().push(a as char);
215                }
216
217                if let bam::pileup::Indel::Ins(len) = alignment.indel() {
218                    if let Some(pos1) = alignment.qpos() {
219                        let pos2 = pos1 + (len as usize);
220                        for pos in pos1..pos2 {
221                            let a = alignment.record().seq()[pos];
222
223                            bmap.get_mut(&seq_name).unwrap().push(a as char);
224                        }
225                    }
226                }
227            }
228        }
229    }
230
231    let records = bmap
232        .iter()
233        .map(|kv| fasta::Record::with_attrs(kv.0.as_str(), None, kv.1.as_bytes()))
234        .collect();
235
236    Ok(records)
237}
238
239// Function to extract unaligned seqs from a BAM file
240fn extract_unaligned_bam_reads(
241    _basename: &str,
242    bam: &mut IndexedReader,
243) -> Result<Vec<fasta::Record>> {
244    let rg_sm_map = get_rg_to_sm_mapping(bam);
245
246    let _ = bam.fetch(FetchDefinition::Unmapped);
247    let records = bam
248        .records()
249        .map(|r| {
250            let read = r.unwrap();
251            let qname = String::from_utf8_lossy(read.qname()).into_owned();
252            let sm = get_sm_name_from_rg(&read, &rg_sm_map).unwrap();
253
254            let seq_name = format!("{qname}|{sm}");
255
256            let vseq = read.seq().as_bytes();
257            let bseq = vseq.as_bytes();
258
259            let seq = fasta::Record::with_attrs(seq_name.as_str(), Some(""), bseq);
260
261            seq
262        })
263        .collect();
264
265    Ok(records)
266}
267
268// Function to extract seqs from a FASTA file within a specified genomic region.
269fn extract_fasta_seqs(
270    basename: &String,
271    fasta: &mut Reader,
272    chr: &String,
273    start: &u64,
274    stop: &u64,
275    name: &String,
276) -> Result<Vec<fasta::Record>> {
277    let id = format!("{chr}:{start}-{stop}|{name}|{basename}");
278    let seq = fasta.fetch_seq_string(chr, usize::try_from(*start)?, usize::try_from(*stop - 1)?)?;
279
280    if seq.len() > 0 {
281        let records = vec![fasta::Record::with_attrs(id.as_str(), None, seq.as_bytes())];
282
283        return Ok(records);
284    }
285
286    Err(anyhow::anyhow!("No sequence found for locus: {}", id))
287}
288
289// Function to stage data from a single file.
290fn stage_data_from_one_file(
291    seqs_url: &Url,
292    loci: &LinkedHashSet<(String, u64, u64, String)>,
293    unmapped: bool,
294    haplotype: Option<u8>,
295) -> Result<Vec<fasta::Record>> {
296    let mut all_seqs = Vec::new();
297
298    let basename = seqs_url
299        .path_segments()
300        .map(|c| c.collect::<Vec<_>>())
301        .unwrap()
302        .last()
303        .unwrap()
304        .to_string();
305
306    let seqs_str = seqs_url.as_str();
307    if seqs_str.ends_with(".bam") || seqs_str.ends_with(".cram") {
308        // Handle BAM/CRAM file processing
309        let basename = basename
310            .trim_end_matches(".bam")
311            .trim_end_matches(".cram")
312            .to_string();
313        let mut bam = open_bam(seqs_url)?;
314
315        // Extract seqs for the current locus.
316        for (chr, start, stop, name) in loci.iter() {
317            let aligned_seqs =
318                extract_aligned_bam_reads(&basename, &mut bam, chr, start, stop, name, haplotype)
319                    .unwrap();
320            all_seqs.extend(aligned_seqs);
321        }
322
323        // Optionally extract unaligned reads.
324        if unmapped {
325            let unaligned_seqs = extract_unaligned_bam_reads(&basename, &mut bam).unwrap();
326            all_seqs.extend(unaligned_seqs);
327        }
328    } else if seqs_str.ends_with(".fa")
329        || seqs_str.ends_with(".fasta")
330        || seqs_str.ends_with(".fa.gz")
331        || seqs_str.ends_with(".fasta.gz")
332    {
333        // Handle FASTA file processing
334        let basename = basename
335            .trim_end_matches(".fasta.gz")
336            .trim_end_matches(".fasta")
337            .trim_end_matches(".fa.gz")
338            .trim_end_matches(".fa")
339            .to_string();
340        let mut fasta = open_fasta(seqs_url)?;
341
342        for (chr, start, stop, name) in loci.iter() {
343            // Extract seqs for the current locus.
344            let seqs = extract_fasta_seqs(&basename, &mut fasta, chr, start, stop, name)
345                .map_or_else(|_| Vec::new(), |s| s);
346
347            // Extend the all_seqs vector with the seqs from the current locus.
348            all_seqs.extend(seqs);
349        }
350    } else {
351        // Handle unknown file extension
352        return Err(anyhow::anyhow!("Unsupported file type: {}", seqs_url));
353    }
354
355    Ok(all_seqs)
356}
357
358// Function to stage data from multiple BAM files.
359fn stage_data_from_all_files(
360    seq_urls: &HashSet<Url>,
361    loci: &LinkedHashSet<(String, u64, u64, String)>,
362    unmapped: bool,
363    haplotype: Option<u8>,
364) -> Result<Vec<fasta::Record>> {
365    // Use a parallel iterator to process multiple BAM files concurrently.
366    let all_data: Vec<_> = seq_urls
367        .par_iter()
368        .map(|seqs_url| {
369            // Define an operation to stage data from one file.
370            let op = || {
371                let seqs = stage_data_from_one_file(seqs_url, loci, unmapped, haplotype)?;
372                Ok(seqs)
373            };
374
375            // Retry the operation with exponential backoff in case of failure.
376            match backoff::retry(ExponentialBackoff::default(), op) {
377                Ok(seqs) => seqs,
378                Err(e) => {
379                    // If all retries fail, panic with an error message.
380                    panic!("Error: {e}");
381                }
382            }
383        })
384        .collect();
385
386    let flattened_data = all_data.into_iter().flatten().collect::<Vec<_>>();
387
388    // Return a flattened vector of sequences
389    Ok(flattened_data)
390}
391
/// Reports whether the range `[start, end]` fully covers at least one locus in `loci`.
///
/// A locus is covered when `start` is at or before the locus start and `end` is at or after
/// the locus end. Only the coordinates are compared; the contig name stored in each locus is
/// not consulted.
///
/// # Arguments
///
/// * `start` - The start position of the genomic range.
/// * `end` - The end position of the genomic range.
/// * `loci` - A `HashSet` of `(contig, start, end)` tuples describing the loci.
///
/// # Returns
///
/// `Ok(true)` if the range covers any locus, `Ok(false)` otherwise.
///
/// # Errors
///
/// Returns an error message if `start` or `end` is negative.
///
/// # Panics
///
/// This function does not panic.
pub fn read_spans_locus(
    start: i64,
    end: i64,
    loci: &HashSet<(String, i64, i64)>,
) -> Result<bool, String> {
    if start.min(end) < 0 {
        return Err(String::from(
            "Error: Negative genomic positions are not allowed.",
        ));
    }

    for (_, locus_start, locus_end) in loci {
        if start <= *locus_start && *locus_end <= end {
            return Ok(true);
        }
    }

    Ok(false)
}
422
423/// Public function to stage data from multiple BAM files and write to an output file.
424///
425/// # Arguments
426///
427/// * `output_path` - A reference to a `PathBuf` representing the output file path.
428/// * `loci` - A reference to a `HashSet` of tuples representing the loci to extract.
429/// * `seq_urls` - A reference to a `HashSet` of URLs representing the sequence files.
430/// * `unmapped` - A boolean indicating whether to extract unmapped reads.
431/// * `cache_path` - A reference to a `PathBuf` representing the cache directory path.
432///
433/// # Returns
434///
435/// The number of records written to the output file.
436///
437/// # Errors
438///
439/// This function returns an error if the output file cannot be created.
440///
441/// # Panics
442///
443/// If an error occurs while staging data from the files.
444///
445pub fn stage_data(
446    output_path: &PathBuf,
447    loci: &LinkedHashSet<(String, u64, u64, String)>,
448    seq_urls: &HashSet<Url>,
449    unmapped: bool,
450    haplotype: Option<u8>,
451    cache_path: &PathBuf,
452) -> Result<usize> {
453    let current_dir = env::current_dir()?;
454    env::set_current_dir(cache_path).unwrap();
455
456    // Stage data from all BAM files.
457    let all_data = match stage_data_from_all_files(seq_urls, loci, unmapped, haplotype) {
458        Ok(all_data) => all_data,
459        Err(e) => {
460            panic!("Error: {e}");
461        }
462    };
463
464    env::set_current_dir(current_dir).unwrap();
465
466    // Write to a FASTA file.
467    let mut buf_writer = BufWriter::new(File::create(output_path)?);
468    let mut fasta_writer = fasta::Writer::new(&mut buf_writer);
469
470    for record in all_data.iter() {
471        if record.seq().len() > 0 {
472            fasta_writer.write_record(record)?;
473        }
474    }
475
476    let _ = fasta_writer.flush();
477
478    Ok(all_data.len())
479}
480
481pub fn stage_data_in_memory(
482    loci: &LinkedHashSet<(String, u64, u64, String)>,
483    seq_urls: &HashSet<Url>,
484    unmapped: bool,
485    cache_path: &PathBuf,
486) -> Result<Vec<fasta::Record>> {
487    let current_dir = env::current_dir()?;
488    env::set_current_dir(cache_path).unwrap();
489
490    // Stage data from all BAM files.
491    let all_data = match stage_data_from_all_files(seq_urls, loci, unmapped, None) {
492        Ok(all_data) => all_data,
493        Err(e) => {
494            panic!("Error: {e}");
495        }
496    };
497
498    env::set_current_dir(current_dir).unwrap();
499
500    Ok(all_data)
501}
502
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;
    use url::Url;

    // Remote BAM files used by the tests below. These tests need access to the bucket.
    const SEQS_URL_1: &str = "gs://fc-8c3900db-633f-477f-96b3-fb31ae265c44/results/PBFlowcell/m84060_230907_210011_s2/reads/ccs/aligned/m84060_230907_210011_s2.bam";
    const SEQS_URL_2: &str = "gs://fc-8c3900db-633f-477f-96b3-fb31ae265c44/results/PBFlowcell/m84043_230901_211947_s1/reads/ccs/aligned/m84043_230901_211947_s1.hifi_reads.bc2080.bam";

    // Builds the single-locus set shared by the staging tests.
    fn test_loci() -> LinkedHashSet<(String, u64, u64, String)> {
        let mut loci = LinkedHashSet::new();
        loci.insert(("chr15".to_string(), 23960193, 23963918, "test".to_string()));
        loci
    }

    // This test may pass while still printing a message to stderr about a failure to access
    // data: open_bam() logs each failed attempt before trying the next authorization method.
    #[test]
    fn test_open_bam() {
        let seqs_url = Url::parse(SEQS_URL_1).unwrap();

        let opened = open_bam(&seqs_url);

        assert!(opened.is_ok(), "Failed to open bam file");
    }

    #[test]
    fn test_stage_data_from_one_file() {
        let seqs_url = Url::parse(SEQS_URL_1).unwrap();
        let loci = test_loci();

        let staged = stage_data_from_one_file(&seqs_url, &loci, false, None);

        assert!(staged.is_ok(), "Failed to stage data from one file");
    }

    #[test]
    fn test_stage_data() {
        let cache_path = std::env::temp_dir();
        let output_path = cache_path.join("test.bam");
        let loci = test_loci();

        let mut seq_urls = HashSet::new();
        seq_urls.insert(Url::parse(SEQS_URL_1).unwrap());

        let staged = stage_data(&output_path, &loci, &seq_urls, false, None, &cache_path);

        assert!(staged.is_ok(), "Failed to stage data from file");

        println!("{:?}", staged.unwrap());
    }

    #[test]
    fn test_stage_multiple_data() {
        let cache_path = std::env::temp_dir();
        let output_path = cache_path.join("test.bam");
        let loci = test_loci();

        let mut seq_urls = HashSet::new();
        seq_urls.insert(Url::parse(SEQS_URL_1).unwrap());
        seq_urls.insert(Url::parse(SEQS_URL_2).unwrap());

        let staged = stage_data(&output_path, &loci, &seq_urls, false, None, &cache_path);

        println!("{:?}", staged);

        assert!(staged.is_ok(), "Failed to stage data from all files");
    }
}