This is an automated email from the ASF dual-hosted git repository.
mssun pushed a commit to branch develop
in repository https://gitbox.apache.org/repos/asf/incubator-teaclave.git
The following commit(s) were added to refs/heads/develop by this push:
new 5e547cf [function] Add native logistic regression prediction (#250)
5e547cf is described below
commit 5e547cfee1b2fa6faaca91905c7a2b648e344515
Author: Zhaofeng Chen <[email protected]>
AuthorDate: Thu Mar 26 13:10:04 2020 -0700
[function] Add native logistic regression prediction (#250)
---
function/src/gbdt_training.rs | 1 -
function/src/lib.rs | 3 +
...aining.rs => logistic_regression_prediction.rs} | 110 +++++++++++----------
function/src/logistic_regression_training.rs | 2 +-
.../expected_result.txt | 5 +
.../logistic_regression_prediction/model.txt | 1 +
.../predict_input.txt | 5 +
7 files changed, 71 insertions(+), 56 deletions(-)
diff --git a/function/src/gbdt_training.rs b/function/src/gbdt_training.rs
index 9b63f99..30c7ceb 100644
--- a/function/src/gbdt_training.rs
+++ b/function/src/gbdt_training.rs
@@ -78,7 +78,6 @@ impl TeaclaveFunction for GbdtTraining {
gbdt_train_mod.fit(&mut train_dv);
let model_json = serde_json::to_string(&gbdt_train_mod)?;
- log::debug!("create file...");
// save the model to output
let mut model_file = runtime.create_output(OUT_MODEL)?;
model_file.write_all(model_json.as_bytes())?;
diff --git a/function/src/lib.rs b/function/src/lib.rs
index 4113f90..a329e6e 100644
--- a/function/src/lib.rs
+++ b/function/src/lib.rs
@@ -29,12 +29,14 @@ mod context;
mod echo;
mod gbdt_prediction;
mod gbdt_training;
+mod logistic_regression_prediction;
mod logistic_regression_training;
mod mesapy;
pub use echo::Echo;
pub use gbdt_prediction::GbdtPrediction;
pub use gbdt_training::GbdtTraining;
+pub use logistic_regression_prediction::LogitRegPrediction;
pub use logistic_regression_training::LogitRegTraining;
pub use mesapy::Mesapy;
@@ -51,6 +53,7 @@ pub mod tests {
mesapy::tests::run_tests(),
context::tests::run_tests(),
logistic_regression_training::tests::run_tests(),
+ logistic_regression_prediction::tests::run_tests(),
)
}
}
diff --git a/function/src/logistic_regression_training.rs b/function/src/logistic_regression_prediction.rs
similarity index 55%
copy from function/src/logistic_regression_training.rs
copy to function/src/logistic_regression_prediction.rs
index a2c2873..7de3b9c 100644
--- a/function/src/logistic_regression_training.rs
+++ b/function/src/logistic_regression_prediction.rs
@@ -23,78 +23,81 @@ use rusty_machine::learning::optim::grad_desc::GradientDesc;
use rusty_machine::learning::SupModel;
use rusty_machine::linalg;
use serde_json;
-use std::format;
-use std::io::{self, BufRead, BufReader, Write};
use anyhow;
+use std::format;
+use std::io::{self, BufRead, BufReader, Write};
use teaclave_types::FunctionArguments;
use teaclave_types::{TeaclaveFunction, TeaclaveRuntime};
#[derive(Default)]
-pub struct LogitRegTraining;
+pub struct LogitRegPrediction;
-static TRAINING_DATA: &str = "training_data";
-static OUT_MODEL_FILE: &str = "model_file";
+static MODEL_FILE: &str = "model_file";
+static INPUT_DATA: &str = "data_file";
+static RESULT: &str = "result_file";
-impl TeaclaveFunction for LogitRegTraining {
+impl TeaclaveFunction for LogitRegPrediction {
fn execute(
&self,
runtime: Box<dyn TeaclaveRuntime + Send + Sync>,
- arguments: FunctionArguments,
+ _arguments: FunctionArguments,
) -> anyhow::Result<String> {
- let alg_alpha = arguments.get("alg_alpha")?.as_f64()?;
- let alg_iters = arguments.get("alg_iters")?.as_usize()?;
- let feature_size = arguments.get("feature_size")?.as_usize()?;
-
- let input = runtime.open_input(TRAINING_DATA)?;
- let (flattend_features, targets) = parse_training_data(input,
feature_size)?;
- let data_size = targets.len();
- let data_matrix = linalg::Matrix::new(data_size, feature_size,
flattend_features);
- let targets = linalg::Vector::new(targets);
-
- let gd = GradientDesc::new(alg_alpha, alg_iters);
- let mut lr = LogisticRegressor::new(gd);
- lr.train(&data_matrix, &targets)?;
-
- let model_json = serde_json::to_string(&lr).unwrap();
- let mut model_file = runtime.create_output(OUT_MODEL_FILE)?;
- model_file.write_all(model_json.as_bytes())?;
- Ok(format!("Trained {} lines of data.", data_size))
+ let mut model_json = String::new();
+ let mut f = runtime.open_input(MODEL_FILE)?;
+ f.read_to_string(&mut model_json)?;
+
+ let lr: LogisticRegressor<GradientDesc> =
serde_json::from_str(&model_json)?;
+ let feature_size = lr
+ .parameters()
+ .ok_or_else(|| anyhow::anyhow!("Model parameter is None"))?
+ .size()
+ - 1;
+
+ let input = runtime.open_input(INPUT_DATA)?;
+ let data_matrix = parse_input_data(input, feature_size)?;
+
+ let result = lr.predict(&data_matrix)?;
+
+ let mut output = runtime.create_output(RESULT)?;
+ let result_cnt = result.data().len();
+ for c in result.data().iter() {
+ writeln!(&mut output, "{:.4}", c)?;
+ }
+ Ok(format!("Predicted {} lines of data.", result_cnt))
}
}
-fn parse_training_data(
+fn parse_input_data(
input: impl io::Read,
feature_size: usize,
-) -> anyhow::Result<(Vec<f64>, Vec<f64>)> {
- let reader = BufReader::new(input);
- let mut targets = Vec::<f64>::new();
- let mut features = Vec::new();
+) -> anyhow::Result<linalg::Matrix<f64>> {
+ let mut flattened_data = Vec::new();
+ let mut count = 0;
+ let reader = BufReader::new(input);
for line_result in reader.lines() {
let line = line_result?;
let trimed_line = line.trim();
anyhow::ensure!(!trimed_line.is_empty(), "Empty line");
- log::debug!(trimed_line);
- let mut v: Vec<f64> = trimed_line
+ let v: Vec<f64> = trimed_line
.split(',')
.map(|x| x.parse::<f64>())
.collect::<std::result::Result<_, _>>()?;
anyhow::ensure!(
- v.len() == feature_size + 1,
+ v.len() == feature_size,
"Data format error: column len = {}, expected = {}",
v.len(),
- feature_size + 1
+ feature_size
);
- let label = v.swap_remove(feature_size);
- targets.push(label);
- features.extend(v);
+ flattened_data.extend(v);
+ count += 1;
}
- Ok((features, targets))
+ Ok(linalg::Matrix::new(count, feature_size, flattened_data))
}
#[cfg(feature = "enclave_unit_test")]
@@ -107,36 +110,35 @@ pub mod tests {
use teaclave_types::*;
pub fn run_tests() -> bool {
- run_tests!(test_logistic_regression_training)
+ run_tests!(test_logistic_regression_prediction)
}
- fn test_logistic_regression_training() {
- let func_args = FunctionArguments::new(hashmap! {
- "alg_alpha" => "0.3",
- "alg_iters" => "100",
- "feature_size" => "30"
- });
+ fn test_logistic_regression_prediction() {
+ let func_args = FunctionArguments::default();
- let base =
Path::new("fixtures/functions/logistic_regression_training");
- let training_data = base.join("train.txt");
- let plain_output = base.join("model.txt.out");
- let expected_output = base.join("expected_model.txt");
+ let base =
Path::new("fixtures/functions/logistic_regression_prediction");
+ let model = base.join("model.txt");
+ let plain_input = base.join("predict_input.txt");
+ let plain_output = base.join("predict_result.txt.out");
+ let expected_output = base.join("expected_result.txt");
let input_files = StagedFiles::new(hashmap!(
- TRAINING_DATA =>
- StagedFileInfo::new(&training_data, TeaclaveFile128Key::random()),
+ MODEL_FILE =>
+ StagedFileInfo::new(&model, TeaclaveFile128Key::random()),
+ INPUT_DATA =>
+ StagedFileInfo::new(&plain_input, TeaclaveFile128Key::random()),
));
let output_files = StagedFiles::new(hashmap!(
- OUT_MODEL_FILE =>
+ RESULT =>
StagedFileInfo::new(&plain_output, TeaclaveFile128Key::random())
));
let runtime = Box::new(RawIoRuntime::new(input_files, output_files));
- let function = LogitRegTraining;
+ let function = LogitRegPrediction;
let summary = function.execute(runtime, func_args).unwrap();
- assert_eq!(summary, "Trained 100 lines of data.");
+ assert_eq!(summary, "Predicted 5 lines of data.");
let result = fs::read_to_string(&plain_output).unwrap();
let expected = fs::read_to_string(&expected_output).unwrap();
diff --git a/function/src/logistic_regression_training.rs b/function/src/logistic_regression_training.rs
index a2c2873..d509319 100644
--- a/function/src/logistic_regression_training.rs
+++ b/function/src/logistic_regression_training.rs
@@ -59,6 +59,7 @@ impl TeaclaveFunction for LogitRegTraining {
let model_json = serde_json::to_string(&lr).unwrap();
let mut model_file = runtime.create_output(OUT_MODEL_FILE)?;
model_file.write_all(model_json.as_bytes())?;
+
Ok(format!("Trained {} lines of data.", data_size))
}
}
@@ -76,7 +77,6 @@ fn parse_training_data(
let trimed_line = line.trim();
anyhow::ensure!(!trimed_line.is_empty(), "Empty line");
- log::debug!(trimed_line);
let mut v: Vec<f64> = trimed_line
.split(',')
.map(|x| x.parse::<f64>())
diff --git a/tests/fixtures/functions/logistic_regression_prediction/expected_result.txt b/tests/fixtures/functions/logistic_regression_prediction/expected_result.txt
new file mode 100644
index 0000000..ae59d4d
--- /dev/null
+++ b/tests/fixtures/functions/logistic_regression_prediction/expected_result.txt
@@ -0,0 +1,5 @@
+0.7530
+0.9163
+0.2041
+0.0094
+0.0426
diff --git a/tests/fixtures/functions/logistic_regression_prediction/model.txt b/tests/fixtures/functions/logistic_regression_prediction/model.txt
new file mode 100644
index 0000000..45d02b7
--- /dev/null
+++ b/tests/fixtures/functions/logistic_regression_prediction/model.txt
@@ -0,0 +1 @@
+{"base":{"parameters":{"size":31,"data":[-0.7217673215572631,1.7917952971098938,0.5160567210379624,1.9611477804952018,-0.06059944514037786,-0.3922383186171758,-0.016518720611358107,0.5544320061020859,-0.0633262139909826,1.411549874207487,0.10834644397882805,0.5745352048596792,0.7267581684630282,0.3591176819552514,0.05218327931711743,0.2311147165162303,0.7200202688528744,-0.11124176185023933,0.20438197304348082,0.5711206726813367,0.4502424746039776,0.49126862948834127,0.5477485136765818,-
[...]
\ No newline at end of file
diff --git a/tests/fixtures/functions/logistic_regression_prediction/predict_input.txt b/tests/fixtures/functions/logistic_regression_prediction/predict_input.txt
new file mode 100644
index 0000000..4440050
--- /dev/null
+++ b/tests/fixtures/functions/logistic_regression_prediction/predict_input.txt
@@ -0,0 +1,5 @@
+1.5619560125603997,-1.3585291014731475,-1.929064964958864,-0.48811178352065915,-1.6298734512909983,0.8556691653018434,0.3798596856717057,0.5206354603638552,1.4795022936289703,-1.475871675153695,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0
+0.4317415153486741,-1.2724398149784077,1.2065551475519039,0.30051381061013843,-1.126829867003464,-0.5463861825373719,1.0927733527526797,1.3761579259389451,-1.077581460405606,-0.66117948106943,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0
+0.1831170075985541,0.8266275084009919,-1.6263223984375375,0.22082406698679818,1.1026886446611233,-1.079671043815752,0.6608823735448814,-0.5931674081381179,0.7960784158847922,-0.5670352239419173,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0
+-1.7258296754127966,0.3190923551324158,-0.07759090516648384,0.47445362292910587,-0.43834833360941045,0.4858568905058413,-1.029447361090786,-0.6760396593910052,-0.9488385062478163,1.7400447870621698,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0
+-1.3273559563065362,1.3048414649869655,-0.5504419191862819,0.48415113906417967,0.6153152869330243,0.5076592437240378,0.9637908015683003,1.0405852286895143,-0.48539483338526546,1.4724927513878685,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0
\ No newline at end of file
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]