1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
//! The prediction uses an RNN model (LSTM) trained on CSV files created by ```cargo run --features replay```.
//! We use [TensorflowLite](https://www.tensorflow.org/lite/) with the windows dll.
//!
//! This allows us to make predictions locally, without having to install Tensorflow, which is a heavy framework to
//! install and to start.
//! This comes at a price: some interesting features available with Tensorflow and Keras are [not supported
//! by TfLite](https://www.tensorflow.org/lite/convert/rnn). Our main issue is that owlyshield predict
//! can generate very long prediction sequences that could be handled by a stateful LSTM with truncated
//! backpropagation through time (TBTT). But a stateful LSTM is not possible with TfLite and the state
//! has to be manually propagated between epochs. That's why we limit the sequence length, capped to
//! [crate::predictions::prediction::PREDMTRXROWS]. See module [crate::predictions::prediction::input_tensors] for details.

use std::fs::File;
use std::io::{BufReader, Read};

use byteorder::{ByteOrder, LittleEndian};
use moonfire_tflite::*;

use crate::predictions::prediction::input_tensors::VecvecCapped;
use crate::predictions::prediction::PREDMTRXCOLS;

/// Path to the .tflite model (converted from Tensorflow/Keras), loaded at startup by [TfLiteMalware::new].
static MODEL: &str = "./models/model.tflite";
/// Path to the JSON file holding the features means vector, used by Standard Scaling.
static MEANS: &str = "./models/mean.json";
/// Path to the JSON file holding the features standard-deviations vector, used by Standard Scaling.
static STDVS: &str = "./models/std.json";

/// A record to describe a tflite model, bundling the model itself with the
/// Standard Scaling parameters needed to normalize inputs before inference.
pub struct TfLiteMalware {
    /// The TfLite model loaded from [MODEL].
    model: Model,
    /// Per-feature means, deserialized from [MEANS]; needed by Standard Scaling.
    means: Vec<f32>,
    /// Per-feature standard deviations, deserialized from [STDVS]; needed by Standard Scaling.
    stdvs: Vec<f32>,
}

impl TfLiteMalware {
    /// Loads the TfLite model from [MODEL] and the Standard Scaling vectors
    /// from [MEANS] and [STDVS].
    ///
    /// # Panics
    /// Panics if any of the three files is missing or cannot be parsed. This is
    /// acceptable here because the application cannot operate without a model,
    /// but note that a `Result`-returning constructor would be cleaner.
    pub fn new() -> TfLiteMalware {
        // std::fs::read sizes the buffer from file metadata — simpler and
        // faster than a BufReader + read_to_end into a growable Vec.
        let means = std::fs::read(MEANS).unwrap();
        let stdvs = std::fs::read(STDVS).unwrap();

        TfLiteMalware {
            model: Model::from_file(MODEL).unwrap(),
            means: serde_json::from_slice(&means).unwrap(),
            stdvs: serde_json::from_slice(&stdvs).unwrap(),
        }
    }

    /// Make a prediction on the sequence *predmtrx*. The prediction can be costly.
    /// The model input tensor dimensions are (None, [PREDMTRXCOLS]) and is dimensioned accordingly
    /// by the *InterpreterBuilder*.
    /// The model returns only the last prediction (it does not return sequences).
    ///
    /// # Panics
    /// Panics if the interpreter cannot be built or if inference fails.
    pub fn make_prediction(&self, predmtrx: &VecvecCapped<f32>) -> f32 {
        // Standardize first, then flatten row-major into a single f32 vector.
        let inputmtrx = self.standardize(predmtrx).to_vec();
        let mut interpreter = Interpreter::builder()
            .build(&self.model, predmtrx.rows_len(), PREDMTRXCOLS)
            .unwrap();

        // Serialize the input as little-endian f32 bytes into the input tensor.
        let mut inputs = interpreter.inputs();
        let dst = inputs[0].bytes_mut();
        LittleEndian::write_f32_into(inputmtrx.as_slice(), dst);
        interpreter.invoke().unwrap();

        // Single scalar output: the malware probability for the last step.
        interpreter.outputs()[0].f32s()[0]
    }

    /// Standard Scaling of the input vectors with [MEANS] and [STDVS]:
    /// `(x - mean) / std`, element-wise per column.
    fn standardize(&self, predmtrx: &VecvecCapped<f32>) -> VecvecCapped<f32> {
        let mut res = predmtrx.clone();
        // Guard against division by (near-)zero standard deviations.
        let epsilon = 0.0001f32;
        for i in 0..predmtrx.rows_len() {
            for j in 0..predmtrx.capacity_cols {
                let stdvs_j = self.stdvs[j];
                let denominator = if stdvs_j < epsilon { epsilon } else { stdvs_j };
                res[i][j] = (predmtrx[i][j] - self.means[j]) / denominator
            }
        }
        res
    }
}
//
// #[cfg(test)]
// #[doc(hidden)]
// mod tests {
//     use crate::predictions::prediction::input_tensors::VecvecCapped;
//     use crate::predictions::prediction::{PREDMTRXCOLS, PREDMTRXROWS};
//     use crate::TfLiteMalware;
//
//     #[test]
//     fn test_standardize() {
//         let mut predmtrx = VecvecCapped::new(PREDMTRXCOLS, PREDMTRXROWS);
//         println!("{}", predmtrx.rows_len());
//         for i in 0..PREDMTRXROWS {
//             row.len() != self.capacity_cols {
//             predmtrx.push_row([0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0].to_vec()).unwrap();
//         }
//         println!("{}", predmtrx.rows_len());
//         let tflite_malware = TfLiteMalware::new();
//         let x = tflite_malware.standardize(&predmtrx);
//         println!("{:?}", x);
//         //assert_eq!(x, 2);
//     }
// }