1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115
//! The prediction uses an RNN model (LSTM) trained on CSV files created by ```cargo run --features replay```.
//! We use [TensorflowLite](https://www.tensorflow.org/lite/) with the Windows DLL.
//! This allows us to make predictions locally, without having to install Tensorflow, which is a heavy framework to
//! install and to start.
//! This comes at a price: some interesting features available with Tensorflow and Keras are [not supported
//! by TfLite](https://www.tensorflow.org/lite/convert/rnn). Our main issue is that owlyshield predict
//! can generate very long prediction sequences that could be handled by a stateful LSTM with truncated
//! backpropagation through time (TBPTT). But a stateful LSTM is not possible with TfLite, and the state
//! has to be manually propagated between epochs. That's why we limit the sequence length, capped to
//! [crate::predictions::prediction::PREDMTRXROWS]. See module [crate::predictions::prediction::input_tensors] for details.
use std::fs::File;
use std::io::{BufReader, Read};
use byteorder::{ByteOrder, LittleEndian};
use moonfire_tflite::*;
use crate::predictions::prediction::input_tensors::VecvecCapped;
use crate::predictions::prediction::PREDMTRXCOLS;
/// Path string of the .tflite model file (converted from Tensorflow/Keras); it is passed to
/// `Model::from_file` when a `TfLiteMalware` is created.
static MODEL: &str = "./models/model.tflite";
/// Path string of the JSON file for the features means vector, used by Standard Scaling.
static MEANS: &str = "./models/mean.json";
/// Path string of the JSON file for the features standard deviations vector, used by Standard Scaling.
static STDVS: &str = "./models/std.json";
/// A record to describe a tflite model, together with the vectors used to scale its inputs.
pub struct TfLiteMalware {
/// The TfLite model, loaded with `Model::from_file` from [MODEL].
model: Model,
/// Needed by Standard Scaling and set to [MEANS]. Indexed by column when standardizing.
means: Vec<f32>,
/// Needed by Standard Scaling and set to [STDVS]. Indexed by column when standardizing.
stdvs: Vec<f32>,
impl TfLiteMalware {
/// Creates a [TfLiteMalware]: the model is loaded from [MODEL], and the means and standard
/// deviations vectors are deserialized from JSON bytes with `serde_json::from_slice`.
///
/// # Panics
/// Panics if the model cannot be loaded or if either JSON buffer cannot be deserialized,
/// since all three results are unwrapped.
pub fn new() -> TfLiteMalware {
// Byte buffer receiving the content of the means file.
let mut means = Vec::new();
// NOTE(review): presumably called on a reader opened on [MEANS] — confirm.
.read_to_end(&mut means)
// Byte buffer receiving the content of the standard deviations file.
let mut stdvs = Vec::new();
// NOTE(review): presumably called on a reader opened on [STDVS] — confirm.
.read_to_end(&mut stdvs)
TfLiteMalware {
model: Model::from_file(MODEL).unwrap(),
means: serde_json::from_slice(means.as_slice()).unwrap(),
stdvs: serde_json::from_slice(stdvs.as_slice()).unwrap(),
/// Make a prediction on the sequence *predmtrx*. The prediction can be costly.
/// The sequence is standardized first (see `standardize`), then written to the first input tensor.
/// The model input tensor dimensions are (None, [PREDMTRXCOLS]) and is dimensioned accordingly
/// by the *InterpreterBuilder*, using the number of rows of *predmtrx*.
/// The model returns only the last prediction (it does not return sequences).
pub fn make_prediction(&self, predmtrx: &VecvecCapped<f32>) -> f32 {
// Scaled copy of the input, converted with `to_vec` (presumably to a flat vector of f32 — confirm).
let inputmtrx = self.standardize(predmtrx).to_vec();
let builder = Interpreter::builder();
// A new interpreter is built on every call, sized for the current number of rows.
let mut interpreter = builder
.build(&self.model, predmtrx.rows_len(), PREDMTRXCOLS)
let mut inputs = interpreter.inputs();
// The f32 values are written as little-endian bytes into the first input tensor's buffer.
let mut dst = inputs[0].bytes_mut();
LittleEndian::write_f32_into(inputmtrx.as_slice(), &mut dst);
let outputs = interpreter.outputs();
// The prediction is the first f32 of the first output tensor.
let y_pred = outputs[0].f32s()[0];
/// Standard Scaling of the input vectors with [MEANS] and [STDVS]:
/// each value of column *j* becomes `(value - means[j]) / stdvs[j]`.
/// A standard deviation lower than *epsilon* (0.0001) is replaced by *epsilon*, so that the
/// division never uses a (nearly) zero denominator.
fn standardize(&self, predmtrx: &VecvecCapped<f32>) -> VecvecCapped<f32> {
// Start from a copy of the input; the cells visited below are overwritten with scaled values.
let mut res = predmtrx.clone();
let epsilon = 0.0001f32;
for i in 0..predmtrx.rows_len() {
for j in 0..predmtrx.capacity_cols {
let stdvs_j = self.stdvs[j];
// Guard against tiny standard deviations.
let denominator = if stdvs_j < epsilon { epsilon } else { stdvs_j };
res[i][j] = (predmtrx[i][j] - self.means[j]) / denominator
// #[cfg(test)]
// #[doc(hidden)]
// mod tests {
// use crate::predictions::prediction::input_tensors::VecvecCapped;
// use crate::predictions::prediction::{PREDMTRXCOLS, PREDMTRXROWS};
// use crate::TfLiteMalware;
// #[test]
// fn test_standardize() {
// let mut predmtrx = VecvecCapped::new(PREDMTRXCOLS, PREDMTRXROWS);
// println!("{}", predmtrx.rows_len());
// for i in 0..PREDMTRXROWS {
// row.len() != self.capacity_cols {
// predmtrx.push_row([0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0].to_vec()).unwrap();
// }
// println!("{}", predmtrx.rows_len());
// let tflite_malware = TfLiteMalware::new();
// let x = tflite_malware.standardize(&predmtrx);
// println!("{:?}", x);
// //assert_eq!(x, 2);
// }
// }