This is an automated email from the ASF dual-hosted git repository.

mssun pushed a commit to branch develop
in repository https://gitbox.apache.org/repos/asf/incubator-teaclave.git


The following commit(s) were added to refs/heads/develop by this push:
     new 5e547cf  [function] Add native logistic regression prediction (#250)
5e547cf is described below

commit 5e547cfee1b2fa6faaca91905c7a2b648e344515
Author: Zhaofeng Chen <[email protected]>
AuthorDate: Thu Mar 26 13:10:04 2020 -0700

    [function] Add native logistic regression prediction (#250)
---
 function/src/gbdt_training.rs                      |   1 -
 function/src/lib.rs                                |   3 +
 ...aining.rs => logistic_regression_prediction.rs} | 110 +++++++++++----------
 function/src/logistic_regression_training.rs       |   2 +-
 .../expected_result.txt                            |   5 +
 .../logistic_regression_prediction/model.txt       |   1 +
 .../predict_input.txt                              |   5 +
 7 files changed, 71 insertions(+), 56 deletions(-)

diff --git a/function/src/gbdt_training.rs b/function/src/gbdt_training.rs
index 9b63f99..30c7ceb 100644
--- a/function/src/gbdt_training.rs
+++ b/function/src/gbdt_training.rs
@@ -78,7 +78,6 @@ impl TeaclaveFunction for GbdtTraining {
         gbdt_train_mod.fit(&mut train_dv);
         let model_json = serde_json::to_string(&gbdt_train_mod)?;
 
-        log::debug!("create file...");
         // save the model to output
         let mut model_file = runtime.create_output(OUT_MODEL)?;
         model_file.write_all(model_json.as_bytes())?;
diff --git a/function/src/lib.rs b/function/src/lib.rs
index 4113f90..a329e6e 100644
--- a/function/src/lib.rs
+++ b/function/src/lib.rs
@@ -29,12 +29,14 @@ mod context;
 mod echo;
 mod gbdt_prediction;
 mod gbdt_training;
+mod logistic_regression_prediction;
 mod logistic_regression_training;
 mod mesapy;
 
 pub use echo::Echo;
 pub use gbdt_prediction::GbdtPrediction;
 pub use gbdt_training::GbdtTraining;
+pub use logistic_regression_prediction::LogitRegPrediction;
 pub use logistic_regression_training::LogitRegTraining;
 pub use mesapy::Mesapy;
 
@@ -51,6 +53,7 @@ pub mod tests {
             mesapy::tests::run_tests(),
             context::tests::run_tests(),
             logistic_regression_training::tests::run_tests(),
+            logistic_regression_prediction::tests::run_tests(),
         )
     }
 }
diff --git a/function/src/logistic_regression_training.rs b/function/src/logistic_regression_prediction.rs
similarity index 55%
copy from function/src/logistic_regression_training.rs
copy to function/src/logistic_regression_prediction.rs
index a2c2873..7de3b9c 100644
--- a/function/src/logistic_regression_training.rs
+++ b/function/src/logistic_regression_prediction.rs
@@ -23,78 +23,81 @@ use rusty_machine::learning::optim::grad_desc::GradientDesc;
 use rusty_machine::learning::SupModel;
 use rusty_machine::linalg;
 use serde_json;
-use std::format;
-use std::io::{self, BufRead, BufReader, Write};
 
 use anyhow;
+use std::format;
+use std::io::{self, BufRead, BufReader, Write};
 use teaclave_types::FunctionArguments;
 use teaclave_types::{TeaclaveFunction, TeaclaveRuntime};
 
 #[derive(Default)]
-pub struct LogitRegTraining;
+pub struct LogitRegPrediction;
 
-static TRAINING_DATA: &str = "training_data";
-static OUT_MODEL_FILE: &str = "model_file";
+static MODEL_FILE: &str = "model_file";
+static INPUT_DATA: &str = "data_file";
+static RESULT: &str = "result_file";
 
-impl TeaclaveFunction for LogitRegTraining {
+impl TeaclaveFunction for LogitRegPrediction {
     fn execute(
         &self,
         runtime: Box<dyn TeaclaveRuntime + Send + Sync>,
-        arguments: FunctionArguments,
+        _arguments: FunctionArguments,
     ) -> anyhow::Result<String> {
-        let alg_alpha = arguments.get("alg_alpha")?.as_f64()?;
-        let alg_iters = arguments.get("alg_iters")?.as_usize()?;
-        let feature_size = arguments.get("feature_size")?.as_usize()?;
-
-        let input = runtime.open_input(TRAINING_DATA)?;
-        let (flattend_features, targets) = parse_training_data(input, feature_size)?;
-        let data_size = targets.len();
-        let data_matrix = linalg::Matrix::new(data_size, feature_size, flattend_features);
-        let targets = linalg::Vector::new(targets);
-
-        let gd = GradientDesc::new(alg_alpha, alg_iters);
-        let mut lr = LogisticRegressor::new(gd);
-        lr.train(&data_matrix, &targets)?;
-
-        let model_json = serde_json::to_string(&lr).unwrap();
-        let mut model_file = runtime.create_output(OUT_MODEL_FILE)?;
-        model_file.write_all(model_json.as_bytes())?;
-        Ok(format!("Trained {} lines of data.", data_size))
+        let mut model_json = String::new();
+        let mut f = runtime.open_input(MODEL_FILE)?;
+        f.read_to_string(&mut model_json)?;
+
+        let lr: LogisticRegressor<GradientDesc> = serde_json::from_str(&model_json)?;
+        let feature_size = lr
+            .parameters()
+            .ok_or_else(|| anyhow::anyhow!("Model parameter is None"))?
+            .size()
+            - 1;
+
+        let input = runtime.open_input(INPUT_DATA)?;
+        let data_matrix = parse_input_data(input, feature_size)?;
+
+        let result = lr.predict(&data_matrix)?;
+
+        let mut output = runtime.create_output(RESULT)?;
+        let result_cnt = result.data().len();
+        for c in result.data().iter() {
+            writeln!(&mut output, "{:.4}", c)?;
+        }
+        Ok(format!("Predicted {} lines of data.", result_cnt))
     }
 }
 
-fn parse_training_data(
+fn parse_input_data(
     input: impl io::Read,
     feature_size: usize,
-) -> anyhow::Result<(Vec<f64>, Vec<f64>)> {
-    let reader = BufReader::new(input);
-    let mut targets = Vec::<f64>::new();
-    let mut features = Vec::new();
+) -> anyhow::Result<linalg::Matrix<f64>> {
+    let mut flattened_data = Vec::new();
+    let mut count = 0;
 
+    let reader = BufReader::new(input);
     for line_result in reader.lines() {
         let line = line_result?;
         let trimed_line = line.trim();
         anyhow::ensure!(!trimed_line.is_empty(), "Empty line");
 
-        log::debug!(trimed_line);
-        let mut v: Vec<f64> = trimed_line
+        let v: Vec<f64> = trimed_line
             .split(',')
             .map(|x| x.parse::<f64>())
             .collect::<std::result::Result<_, _>>()?;
 
         anyhow::ensure!(
-            v.len() == feature_size + 1,
+            v.len() == feature_size,
             "Data format error: column len = {}, expected = {}",
             v.len(),
-            feature_size + 1
+            feature_size
         );
 
-        let label = v.swap_remove(feature_size);
-        targets.push(label);
-        features.extend(v);
+        flattened_data.extend(v);
+        count += 1;
     }
 
-    Ok((features, targets))
+    Ok(linalg::Matrix::new(count, feature_size, flattened_data))
 }
 
 #[cfg(feature = "enclave_unit_test")]
@@ -107,36 +110,35 @@ pub mod tests {
     use teaclave_types::*;
 
     pub fn run_tests() -> bool {
-        run_tests!(test_logistic_regression_training)
+        run_tests!(test_logistic_regression_prediction)
     }
 
-    fn test_logistic_regression_training() {
-        let func_args = FunctionArguments::new(hashmap! {
-            "alg_alpha" => "0.3",
-            "alg_iters" => "100",
-            "feature_size" => "30"
-        });
+    fn test_logistic_regression_prediction() {
+        let func_args = FunctionArguments::default();
 
-        let base = Path::new("fixtures/functions/logistic_regression_training");
-        let training_data = base.join("train.txt");
-        let plain_output = base.join("model.txt.out");
-        let expected_output = base.join("expected_model.txt");
+        let base = Path::new("fixtures/functions/logistic_regression_prediction");
+        let model = base.join("model.txt");
+        let plain_input = base.join("predict_input.txt");
+        let plain_output = base.join("predict_result.txt.out");
+        let expected_output = base.join("expected_result.txt");
 
         let input_files = StagedFiles::new(hashmap!(
-            TRAINING_DATA =>
-            StagedFileInfo::new(&training_data, TeaclaveFile128Key::random()),
+            MODEL_FILE =>
+            StagedFileInfo::new(&model, TeaclaveFile128Key::random()),
+            INPUT_DATA =>
+            StagedFileInfo::new(&plain_input, TeaclaveFile128Key::random()),
         ));
 
         let output_files = StagedFiles::new(hashmap!(
-            OUT_MODEL_FILE =>
+            RESULT =>
             StagedFileInfo::new(&plain_output, TeaclaveFile128Key::random())
         ));
 
         let runtime = Box::new(RawIoRuntime::new(input_files, output_files));
 
-        let function = LogitRegTraining;
+        let function = LogitRegPrediction;
         let summary = function.execute(runtime, func_args).unwrap();
-        assert_eq!(summary, "Trained 100 lines of data.");
+        assert_eq!(summary, "Predicted 5 lines of data.");
 
         let result = fs::read_to_string(&plain_output).unwrap();
         let expected = fs::read_to_string(&expected_output).unwrap();
diff --git a/function/src/logistic_regression_training.rs b/function/src/logistic_regression_training.rs
index a2c2873..d509319 100644
--- a/function/src/logistic_regression_training.rs
+++ b/function/src/logistic_regression_training.rs
@@ -59,6 +59,7 @@ impl TeaclaveFunction for LogitRegTraining {
         let model_json = serde_json::to_string(&lr).unwrap();
         let mut model_file = runtime.create_output(OUT_MODEL_FILE)?;
         model_file.write_all(model_json.as_bytes())?;
+
         Ok(format!("Trained {} lines of data.", data_size))
     }
 }
@@ -76,7 +77,6 @@ fn parse_training_data(
         let trimed_line = line.trim();
         anyhow::ensure!(!trimed_line.is_empty(), "Empty line");
 
-        log::debug!(trimed_line);
         let mut v: Vec<f64> = trimed_line
             .split(',')
             .map(|x| x.parse::<f64>())
diff --git a/tests/fixtures/functions/logistic_regression_prediction/expected_result.txt b/tests/fixtures/functions/logistic_regression_prediction/expected_result.txt
new file mode 100644
index 0000000..ae59d4d
--- /dev/null
+++ b/tests/fixtures/functions/logistic_regression_prediction/expected_result.txt
@@ -0,0 +1,5 @@
+0.7530
+0.9163
+0.2041
+0.0094
+0.0426
diff --git a/tests/fixtures/functions/logistic_regression_prediction/model.txt b/tests/fixtures/functions/logistic_regression_prediction/model.txt
new file mode 100644
index 0000000..45d02b7
--- /dev/null
+++ b/tests/fixtures/functions/logistic_regression_prediction/model.txt
@@ -0,0 +1 @@
+{"base":{"parameters":{"size":31,"data":[-0.7217673215572631,1.7917952971098938,0.5160567210379624,1.9611477804952018,-0.06059944514037786,-0.3922383186171758,-0.016518720611358107,0.5544320061020859,-0.0633262139909826,1.411549874207487,0.10834644397882805,0.5745352048596792,0.7267581684630282,0.3591176819552514,0.05218327931711743,0.2311147165162303,0.7200202688528744,-0.11124176185023933,0.20438197304348082,0.5711206726813367,0.4502424746039776,0.49126862948834127,0.5477485136765818,-
 [...]
\ No newline at end of file
diff --git a/tests/fixtures/functions/logistic_regression_prediction/predict_input.txt b/tests/fixtures/functions/logistic_regression_prediction/predict_input.txt
new file mode 100644
index 0000000..4440050
--- /dev/null
+++ b/tests/fixtures/functions/logistic_regression_prediction/predict_input.txt
@@ -0,0 +1,5 @@
+1.5619560125603997,-1.3585291014731475,-1.929064964958864,-0.48811178352065915,-1.6298734512909983,0.8556691653018434,0.3798596856717057,0.5206354603638552,1.4795022936289703,-1.475871675153695,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0
+0.4317415153486741,-1.2724398149784077,1.2065551475519039,0.30051381061013843,-1.126829867003464,-0.5463861825373719,1.0927733527526797,1.3761579259389451,-1.077581460405606,-0.66117948106943,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0
+0.1831170075985541,0.8266275084009919,-1.6263223984375375,0.22082406698679818,1.1026886446611233,-1.079671043815752,0.6608823735448814,-0.5931674081381179,0.7960784158847922,-0.5670352239419173,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0
+-1.7258296754127966,0.3190923551324158,-0.07759090516648384,0.47445362292910587,-0.43834833360941045,0.4858568905058413,-1.029447361090786,-0.6760396593910052,-0.9488385062478163,1.7400447870621698,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0
+-1.3273559563065362,1.3048414649869655,-0.5504419191862819,0.48415113906417967,0.6153152869330243,0.5076592437240378,0.9637908015683003,1.0405852286895143,-0.48539483338526546,1.4724927513878685,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0
\ No newline at end of file


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to