hidive/
eval_model.rs

1use num_format::{Locale, ToFormattedString};
2use std::path::PathBuf;
3
4use crate::train::{
5    create_dataset_for_model, distance_to_a_contig_end, plot_roc_curve, process_reads,
6};
7use gbdt::decision_tree::DataVec;
8use gbdt::gradient_boost::GBDT;
9use skydive::ldbg::LdBG;
10use std::io::Write;
11
12pub fn start(
13    output: &PathBuf,
14    kmer_size: usize,
15    long_read_seq_paths: &Vec<PathBuf>,
16    short_read_seq_paths: &Vec<PathBuf>,
17    truth_seq_paths: &Vec<PathBuf>,
18    model_path: &PathBuf,
19) {
20    let long_read_seq_urls = skydive::parse::parse_file_names(long_read_seq_paths);
21    let short_read_seq_urls = skydive::parse::parse_file_names(short_read_seq_paths);
22    let truth_seq_urls = skydive::parse::parse_file_names(truth_seq_paths);
23
24    // Read all long reads.
25    let all_lr_seqs: Vec<Vec<u8>> = process_reads(&long_read_seq_urls, "long");
26    let l1 = LdBG::from_sequences(String::from("l1"), kmer_size, &all_lr_seqs);
27
28    // Read all short reads.
29    let all_sr_seqs: Vec<Vec<u8>> = process_reads(&short_read_seq_urls, "short");
30    let s1 = LdBG::from_sequences(String::from("s1"), kmer_size, &all_sr_seqs);
31
32    // Read all truth sequences.
33    let all_truth_seqs: Vec<Vec<u8>> = process_reads(&truth_seq_urls, "truth");
34    let t1 = LdBG::from_sequences(String::from("s1"), kmer_size, &all_truth_seqs);
35
36    let lr_contigs = l1.assemble_all();
37    let lr_distances = distance_to_a_contig_end(&lr_contigs, kmer_size);
38
39    let sr_contigs = s1.assemble_all();
40    let sr_distances = distance_to_a_contig_end(&sr_contigs, kmer_size);
41
42    // load model
43    skydive::elog!(
44        "Loading GBDT model from {}...",
45        model_path.to_str().unwrap()
46    );
47    let gbdt = GBDT::load_model(model_path.to_str().unwrap()).expect("Unable to load model");
48
49    // Prepare test data.
50    let test_kmers = l1
51        .kmers
52        .keys()
53        .chain(s1.kmers.keys())
54        .chain(t1.kmers.keys());
55    let test_data: DataVec =
56        create_dataset_for_model(test_kmers, &lr_distances, &sr_distances, &l1, &s1, &t1);
57
58    // Predict the test data.
59    skydive::elog!("Computing accuracy on test data...");
60    let prediction = gbdt.predict(&test_data);
61    let pred_threshold = 0.5;
62
63    // Evaluate accuracy of the model on the test data.
64    let mut num_correct = 0u32;
65    let mut num_total = 0u32;
66    for (data, pred) in test_data.iter().zip(prediction.iter()) {
67        let truth = data.label;
68        let call = if *pred > pred_threshold { 1.0 } else { 0.0 };
69        if (call - truth).abs() < f32::EPSILON {
70            num_correct += 1;
71        }
72        num_total += 1;
73    }
74
75    // Precision, Recall, and F1 score calculations.
76    let (precision, recall, f1_score) =
77        crate::train::compute_precision_recall_f1(&test_data, &prediction, pred_threshold);
78    skydive::elog!("Prediction threshold: {:.2}", pred_threshold);
79    skydive::elog!("Precision: {:.2}%", 100.0 * precision);
80    skydive::elog!("Recall: {:.2}%", 100.0 * recall);
81    skydive::elog!("F1 score: {:.2}%", 100.0 * f1_score);
82
83    // TPR and FPR calculations at various thresholds.
84    let fpr_tpr = crate::train::compute_fpr_tpr(&test_data, &prediction);
85
86    // Save TPR and FPR at various thresholds to a file.
87    let csv_output = output.with_extension("csv");
88    let mut writer = std::fs::File::create(&csv_output).expect("Unable to create file");
89    for (tpr, fpr) in &fpr_tpr {
90        writeln!(writer, "{},{}", tpr, fpr).expect("Unable to write data");
91    }
92    skydive::elog!(
93        "TPR and FPR at various thresholds saved to {}",
94        csv_output.to_str().unwrap()
95    );
96
97    // Create a ROC curve.
98    let png_output = output.with_extension("png");
99    plot_roc_curve(&png_output, &fpr_tpr).expect("Unable to plot ROC curve");
100    skydive::elog!("ROC curve saved to {}", png_output.to_str().unwrap());
101
102    skydive::elog!(
103        "Predicted accuracy: {}/{} ({:.2}%)",
104        num_correct.to_formatted_string(&Locale::en),
105        num_total.to_formatted_string(&Locale::en),
106        100.0 * num_correct as f32 / num_total as f32
107    );
108
109    gbdt.save_model(output.to_str().unwrap())
110        .expect("Unable to save model");
111
112    skydive::elog!("Model saved to {}", output.to_str().unwrap());
113}