1use num_format::{Locale, ToFormattedString};
2use std::path::PathBuf;
3
4use crate::train::{
5 create_dataset_for_model, distance_to_a_contig_end, plot_roc_curve, process_reads,
6};
7use gbdt::decision_tree::DataVec;
8use gbdt::gradient_boost::GBDT;
9use skydive::ldbg::LdBG;
10use std::io::Write;
11
12pub fn start(
13 output: &PathBuf,
14 kmer_size: usize,
15 long_read_seq_paths: &Vec<PathBuf>,
16 short_read_seq_paths: &Vec<PathBuf>,
17 truth_seq_paths: &Vec<PathBuf>,
18 model_path: &PathBuf,
19) {
20 let long_read_seq_urls = skydive::parse::parse_file_names(long_read_seq_paths);
21 let short_read_seq_urls = skydive::parse::parse_file_names(short_read_seq_paths);
22 let truth_seq_urls = skydive::parse::parse_file_names(truth_seq_paths);
23
24 let all_lr_seqs: Vec<Vec<u8>> = process_reads(&long_read_seq_urls, "long");
26 let l1 = LdBG::from_sequences(String::from("l1"), kmer_size, &all_lr_seqs);
27
28 let all_sr_seqs: Vec<Vec<u8>> = process_reads(&short_read_seq_urls, "short");
30 let s1 = LdBG::from_sequences(String::from("s1"), kmer_size, &all_sr_seqs);
31
32 let all_truth_seqs: Vec<Vec<u8>> = process_reads(&truth_seq_urls, "truth");
34 let t1 = LdBG::from_sequences(String::from("s1"), kmer_size, &all_truth_seqs);
35
36 let lr_contigs = l1.assemble_all();
37 let lr_distances = distance_to_a_contig_end(&lr_contigs, kmer_size);
38
39 let sr_contigs = s1.assemble_all();
40 let sr_distances = distance_to_a_contig_end(&sr_contigs, kmer_size);
41
42 skydive::elog!(
44 "Loading GBDT model from {}...",
45 model_path.to_str().unwrap()
46 );
47 let gbdt = GBDT::load_model(model_path.to_str().unwrap()).expect("Unable to load model");
48
49 let test_kmers = l1
51 .kmers
52 .keys()
53 .chain(s1.kmers.keys())
54 .chain(t1.kmers.keys());
55 let test_data: DataVec =
56 create_dataset_for_model(test_kmers, &lr_distances, &sr_distances, &l1, &s1, &t1);
57
58 skydive::elog!("Computing accuracy on test data...");
60 let prediction = gbdt.predict(&test_data);
61 let pred_threshold = 0.5;
62
63 let mut num_correct = 0u32;
65 let mut num_total = 0u32;
66 for (data, pred) in test_data.iter().zip(prediction.iter()) {
67 let truth = data.label;
68 let call = if *pred > pred_threshold { 1.0 } else { 0.0 };
69 if (call - truth).abs() < f32::EPSILON {
70 num_correct += 1;
71 }
72 num_total += 1;
73 }
74
75 let (precision, recall, f1_score) =
77 crate::train::compute_precision_recall_f1(&test_data, &prediction, pred_threshold);
78 skydive::elog!("Prediction threshold: {:.2}", pred_threshold);
79 skydive::elog!("Precision: {:.2}%", 100.0 * precision);
80 skydive::elog!("Recall: {:.2}%", 100.0 * recall);
81 skydive::elog!("F1 score: {:.2}%", 100.0 * f1_score);
82
83 let fpr_tpr = crate::train::compute_fpr_tpr(&test_data, &prediction);
85
86 let csv_output = output.with_extension("csv");
88 let mut writer = std::fs::File::create(&csv_output).expect("Unable to create file");
89 for (tpr, fpr) in &fpr_tpr {
90 writeln!(writer, "{},{}", tpr, fpr).expect("Unable to write data");
91 }
92 skydive::elog!(
93 "TPR and FPR at various thresholds saved to {}",
94 csv_output.to_str().unwrap()
95 );
96
97 let png_output = output.with_extension("png");
99 plot_roc_curve(&png_output, &fpr_tpr).expect("Unable to plot ROC curve");
100 skydive::elog!("ROC curve saved to {}", png_output.to_str().unwrap());
101
102 skydive::elog!(
103 "Predicted accuracy: {}/{} ({:.2}%)",
104 num_correct.to_formatted_string(&Locale::en),
105 num_total.to_formatted_string(&Locale::en),
106 100.0 * num_correct as f32 / num_total as f32
107 );
108
109 gbdt.save_model(output.to_str().unwrap())
110 .expect("Unable to save model");
111
112 skydive::elog!("Model saved to {}", output.to_str().unwrap());
113}