This is an automated email from the ASF dual-hosted git repository.
mssun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-teaclave.git
The following commit(s) were added to refs/heads/master by this push:
new fc9cebd Use json type for function arguments instead of string/string hashmap (#307)
fc9cebd is described below
commit fc9cebda32ab26f1eb2272887c51658470c4bb65
Author: Mingshen Sun <[email protected]>
AuthorDate: Wed May 20 19:55:24 2020 -0700
Use json type for function arguments instead of string/string hashmap (#307)
---
function/src/echo.rs | 13 ++-
function/src/gbdt_train.rs | 48 +++-----
function/src/logistic_regression_train.rs | 24 ++--
sdk/python/teaclave.py | 1 +
services/execution/enclave/src/service.rs | 28 +++--
services/management/enclave/src/service.rs | 3 +-
.../src/proto/teaclave_frontend_service.proto | 4 +-
services/proto/src/teaclave_frontend_service.rs | 8 +-
.../enclave/src/end_to_end/builtin_gbdt_train.rs | 24 ++--
tests/integration/enclave/src/teaclave_worker.rs | 24 ++--
types/src/staged_function.rs | 129 +++++++--------------
11 files changed, 124 insertions(+), 182 deletions(-)
diff --git a/function/src/echo.rs b/function/src/echo.rs
index a6de293..2d41057 100644
--- a/function/src/echo.rs
+++ b/function/src/echo.rs
@@ -24,6 +24,7 @@ use teaclave_types::{FunctionArguments, FunctionRuntime};
#[derive(Default)]
pub struct Echo;
+#[derive(serde::Deserialize)]
struct EchoArguments {
message: String,
}
@@ -32,8 +33,8 @@ impl TryFrom<FunctionArguments> for EchoArguments {
type Error = anyhow::Error;
fn try_from(arguments: FunctionArguments) -> Result<Self, Self::Error> {
- let message = arguments.get("message")?.to_string();
- Ok(Self { message })
+ use anyhow::Context;
+ serde_json::from_str(&arguments.into_string()).context("Cannot
deserialize arguments")
}
}
@@ -57,6 +58,7 @@ impl Echo {
#[cfg(feature = "enclave_unit_test")]
pub mod tests {
use super::*;
+ use serde_json::json;
use teaclave_runtime::*;
use teaclave_test_utils::*;
use teaclave_types::*;
@@ -66,9 +68,10 @@ pub mod tests {
}
fn test_echo() {
- let args = FunctionArguments::new(hashmap!(
- "message" => "Hello Teaclave!"
- ));
+ let args = FunctionArguments::from_json(json!({
+ "message": "Hello Teaclave!"
+ }))
+ .unwrap();
let input_files = StagedFiles::default();
let output_files = StagedFiles::default();
diff --git a/function/src/gbdt_train.rs b/function/src/gbdt_train.rs
index b887427..7f63330 100644
--- a/function/src/gbdt_train.rs
+++ b/function/src/gbdt_train.rs
@@ -34,6 +34,7 @@ const OUT_MODEL: &str = "trained_model";
#[derive(Default)]
pub struct GbdtTrain;
+#[derive(serde::Deserialize)]
struct GbdtTrainArguments {
feature_size: usize,
max_depth: u32,
@@ -50,27 +51,8 @@ impl TryFrom<FunctionArguments> for GbdtTrainArguments {
type Error = anyhow::Error;
fn try_from(arguments: FunctionArguments) -> Result<Self, Self::Error> {
- let feature_size = arguments.get("feature_size")?.as_usize()?;
- let max_depth = arguments.get("max_depth")?.as_u32()?;
- let iterations = arguments.get("iterations")?.as_usize()?;
- let shrinkage = arguments.get("shrinkage")?.as_f32()?;
- let feature_sample_ratio =
arguments.get("feature_sample_ratio")?.as_f64()?;
- let data_sample_ratio = arguments.get("data_sample_ratio")?.as_f64()?;
- let min_leaf_size = arguments.get("min_leaf_size")?.as_usize()?;
- let loss = arguments.get("loss")?.as_str().to_owned();
- let training_optimization_level =
arguments.get("training_optimization_level")?.as_u8()?;
-
- Ok(Self {
- feature_size,
- max_depth,
- iterations,
- shrinkage,
- feature_sample_ratio,
- data_sample_ratio,
- min_leaf_size,
- loss,
- training_optimization_level,
- })
+ use anyhow::Context;
+ serde_json::from_str(&arguments.into_string()).context("Cannot
deserialize arguments")
}
}
@@ -164,6 +146,7 @@ fn parse_training_data(input: impl io::Read, feature_size:
usize) -> anyhow::Res
#[cfg(feature = "enclave_unit_test")]
pub mod tests {
use super::*;
+ use serde_json::json;
use std::untrusted::fs;
use teaclave_crypto::*;
use teaclave_runtime::*;
@@ -175,17 +158,18 @@ pub mod tests {
}
fn test_gbdt_train() {
- let arguments = FunctionArguments::new(hashmap!(
- "feature_size" => "4",
- "max_depth" => "4",
- "iterations" => "100",
- "shrinkage" => "0.1",
- "feature_sample_ratio" => "1.0",
- "data_sample_ratio" => "1.0",
- "min_leaf_size" => "1",
- "loss" => "LAD",
- "training_optimization_level" => "2"
- ));
+ let arguments = FunctionArguments::from_json(json!({
+ "feature_size": 4,
+ "max_depth": 4,
+ "iterations": 100,
+ "shrinkage": 0.1,
+ "feature_sample_ratio": 1.0,
+ "data_sample_ratio": 1.0,
+ "min_leaf_size": 1,
+ "loss": "LAD",
+ "training_optimization_level": 2
+ }))
+ .unwrap();
let plain_input = "fixtures/functions/gbdt_training/train.txt";
let plain_output =
"fixtures/functions/gbdt_training/training_model.txt.out";
diff --git a/function/src/logistic_regression_train.rs
b/function/src/logistic_regression_train.rs
index 5df3bbe..098d063 100644
--- a/function/src/logistic_regression_train.rs
+++ b/function/src/logistic_regression_train.rs
@@ -35,6 +35,7 @@ const OUT_MODEL_FILE: &str = "model_file";
#[derive(Default)]
pub struct LogisticRegressionTrain;
+#[derive(serde::Deserialize)]
struct LogisticRegressionTrainArguments {
alg_alpha: f64,
alg_iters: usize,
@@ -45,15 +46,8 @@ impl TryFrom<FunctionArguments> for
LogisticRegressionTrainArguments {
type Error = anyhow::Error;
fn try_from(arguments: FunctionArguments) -> Result<Self, Self::Error> {
- let alg_alpha = arguments.get("alg_alpha")?.as_f64()?;
- let alg_iters = arguments.get("alg_iters")?.as_usize()?;
- let feature_size = arguments.get("feature_size")?.as_usize()?;
-
- Ok(Self {
- alg_alpha,
- alg_iters,
- feature_size,
- })
+ use anyhow::Context;
+ serde_json::from_str(&arguments.into_string()).context("Cannot
deserialize arguments")
}
}
@@ -125,6 +119,7 @@ fn parse_training_data(
#[cfg(feature = "enclave_unit_test")]
pub mod tests {
use super::*;
+ use serde_json::json;
use std::path::Path;
use std::untrusted::fs;
use teaclave_crypto::*;
@@ -137,11 +132,12 @@ pub mod tests {
}
fn test_logistic_regression_train() {
- let arguments = FunctionArguments::new(hashmap! {
- "alg_alpha" => "0.3",
- "alg_iters" => "100",
- "feature_size" => "30"
- });
+ let arguments = FunctionArguments::from_json(json!({
+ "alg_alpha": 0.3,
+ "alg_iters": 100,
+ "feature_size": 30
+ }))
+ .unwrap();
let base =
Path::new("fixtures/functions/logistic_regression_training");
let training_data = base.join("train.txt");
diff --git a/sdk/python/teaclave.py b/sdk/python/teaclave.py
index 7e1e4c5..1819ca0 100644
--- a/sdk/python/teaclave.py
+++ b/sdk/python/teaclave.py
@@ -158,6 +158,7 @@ class FrontendClient:
executor,
inputs_ownership=[],
outputs_ownership=[]):
+ function_arguments = json.dumps(function_arguments)
request = CreateTaskRequest(self.metadata, function_id,
function_arguments, executor,
inputs_ownership, outputs_ownership)
diff --git a/services/execution/enclave/src/service.rs
b/services/execution/enclave/src/service.rs
index 18ce902..1928679 100644
--- a/services/execution/enclave/src/service.rs
+++ b/services/execution/enclave/src/service.rs
@@ -177,6 +177,7 @@ fn finalize_task(file_mgr: &TaskFileManager) ->
Result<HashMap<String, FileAuthT
#[cfg(feature = "enclave_unit_test")]
pub mod tests {
use super::*;
+ use serde_json::json;
use std::format;
use teaclave_crypto::*;
use url::Url;
@@ -184,11 +185,13 @@ pub mod tests {
pub fn test_invoke_echo() {
let task_id = Uuid::new_v4();
+ let function_arguments =
+ FunctionArguments::from_json(json!({"message": "Hello,
Teaclave!"})).unwrap();
let staged_task = StagedTask::new()
.task_id(task_id)
.executor(Executor::Builtin)
.function_name("builtin-echo")
- .function_arguments(hashmap!("message" => "Hello, Teaclave!"));
+ .function_arguments(function_arguments);
let file_mgr = TaskFileManager::new(
WORKER_BASE_DIR,
@@ -210,17 +213,18 @@ pub mod tests {
pub fn test_invoke_gbdt_train() {
let task_id = Uuid::new_v4();
- let function_arguments = FunctionArguments::new(hashmap!(
- "feature_size" => "4",
- "max_depth" => "4",
- "iterations" => "100",
- "shrinkage" => "0.1",
- "feature_sample_ratio" => "1.0",
- "data_sample_ratio" => "1.0",
- "min_leaf_size" => "1",
- "loss" => "LAD",
- "training_optimization_level" => "2",
- ));
+ let function_arguments = FunctionArguments::from_json(json!({
+ "feature_size": 4,
+ "max_depth": 4,
+ "iterations": 100,
+ "shrinkage": 0.1,
+ "feature_sample_ratio": 1.0,
+ "data_sample_ratio": 1.0,
+ "min_leaf_size": 1,
+ "loss": "LAD",
+ "training_optimization_level": 2,
+ }))
+ .unwrap();
let fixture_dir = format!(
"file:///{}/fixtures/functions/gbdt_training",
env!("TEACLAVE_TEST_INSTALL_DIR")
diff --git a/services/management/enclave/src/service.rs
b/services/management/enclave/src/service.rs
index 35311a6..49f436f 100644
--- a/services/management/enclave/src/service.rs
+++ b/services/management/enclave/src/service.rs
@@ -576,6 +576,7 @@ impl TeaclaveManagementService {
#[cfg(feature = "enclave_unit_test")]
pub mod tests {
use super::*;
+ use serde_json::json;
use std::collections::HashMap;
use teaclave_types::{
hashmap, Executor, FileAuthTag, FileCrypto, FunctionArguments,
FunctionInput,
@@ -631,7 +632,7 @@ pub mod tests {
.arguments(vec!["arg".to_string()])
.public(true)
.owner("mock_user");
- let function_arguments = FunctionArguments::new(hashmap!("arg" =>
"data"));
+ let function_arguments = FunctionArguments::from_json(json!({"arg":
"data"})).unwrap();
let task = Task::<Create>::new(
UserID::from("mock_user"),
diff --git a/services/proto/src/proto/teaclave_frontend_service.proto
b/services/proto/src/proto/teaclave_frontend_service.proto
index 6a3dcb3..cb2ce90 100644
--- a/services/proto/src/proto/teaclave_frontend_service.proto
+++ b/services/proto/src/proto/teaclave_frontend_service.proto
@@ -110,7 +110,7 @@ message DataMap {
message CreateTaskRequest {
string function_id = 1;
- map<string, string> function_arguments = 2;
+ string function_arguments = 2;
string executor = 3;
repeated OwnerList inputs_ownership = 10;
repeated OwnerList outputs_ownership= 11;
@@ -129,7 +129,7 @@ message GetTaskResponse {
string creator = 2;
string function_id = 3;
string function_owner = 4;
- map<string, string> function_arguments = 5;
+ string function_arguments = 5;
repeated OwnerList inputs_ownership = 6;
repeated OwnerList outputs_ownership = 7;
repeated string participants = 8;
diff --git a/services/proto/src/teaclave_frontend_service.rs
b/services/proto/src/teaclave_frontend_service.rs
index cf36589..0b8db6e 100644
--- a/services/proto/src/teaclave_frontend_service.rs
+++ b/services/proto/src/teaclave_frontend_service.rs
@@ -958,7 +958,7 @@ impl std::convert::TryFrom<proto::CreateTaskRequest> for
CreateTaskRequest {
type Error = Error;
fn try_from(proto: proto::CreateTaskRequest) -> Result<Self> {
- let function_arguments = proto.function_arguments.into();
+ let function_arguments: FunctionArguments =
proto.function_arguments.try_into()?;
let inputs_ownership = from_proto_ownership(proto.inputs_ownership);
let outputs_ownership = from_proto_ownership(proto.outputs_ownership);
let function_id = proto.function_id.try_into()?;
@@ -977,7 +977,7 @@ impl std::convert::TryFrom<proto::CreateTaskRequest> for
CreateTaskRequest {
impl From<CreateTaskRequest> for proto::CreateTaskRequest {
fn from(request: CreateTaskRequest) -> Self {
- let function_arguments = request.function_arguments.into();
+ let function_arguments = request.function_arguments.into_string();
let inputs_ownership = to_proto_ownership(request.inputs_ownership);
let outputs_ownership = to_proto_ownership(request.outputs_ownership);
@@ -1054,7 +1054,7 @@ impl std::convert::TryFrom<proto::GetTaskResponse> for
GetTaskResponse {
type Error = Error;
fn try_from(proto: proto::GetTaskResponse) -> Result<Self> {
- let function_arguments = proto.function_arguments.into();
+ let function_arguments: FunctionArguments =
proto.function_arguments.try_into()?;
let inputs_ownership = from_proto_ownership(proto.inputs_ownership);
let outputs_ownership = from_proto_ownership(proto.outputs_ownership);
let assigned_inputs = from_proto_file_ids(proto.assigned_inputs)?;
@@ -1086,7 +1086,7 @@ impl std::convert::TryFrom<proto::GetTaskResponse> for
GetTaskResponse {
impl From<GetTaskResponse> for proto::GetTaskResponse {
fn from(response: GetTaskResponse) -> Self {
- let function_arguments = response.function_arguments.into();
+ let function_arguments = response.function_arguments.into_string();
let inputs_ownership = to_proto_ownership(response.inputs_ownership);
let outputs_ownership = to_proto_ownership(response.outputs_ownership);
let assigned_inputs = to_proto_file_ids(response.assigned_inputs);
diff --git a/tests/functional/enclave/src/end_to_end/builtin_gbdt_train.rs
b/tests/functional/enclave/src/end_to_end/builtin_gbdt_train.rs
index 2eae2bf..0c35a68 100644
--- a/tests/functional/enclave/src/end_to_end/builtin_gbdt_train.rs
+++ b/tests/functional/enclave/src/end_to_end/builtin_gbdt_train.rs
@@ -103,20 +103,22 @@ fn create_gbdt_training_task(
client: &mut TeaclaveFrontendClient,
function_id: &ExternalID,
) -> ExternalID {
+ let arguments = FunctionArguments::from_json(serde_json::json!({
+ "feature_size": 4,
+ "max_depth": 4,
+ "iterations": 100,
+ "shrinkage": 0.1,
+ "feature_sample_ratio": 1.0,
+ "data_sample_ratio": 1.0,
+ "min_leaf_size": 1,
+ "loss": "LAD",
+ "training_optimization_level": 2
+ }))
+ .unwrap();
let request = CreateTaskRequest::new()
.executor(Executor::Builtin)
.function_id(function_id.clone())
- .function_arguments(hashmap!(
- "feature_size" => "4",
- "max_depth" => "4",
- "iterations" => "100",
- "shrinkage" => "0.1",
- "feature_sample_ratio" => "1.0",
- "data_sample_ratio" => "1.0",
- "min_leaf_size" => "1",
- "loss" => "LAD",
- "training_optimization_level" => "2"
- ))
+ .function_arguments(arguments)
.inputs_ownership(hashmap!("training_data" => vec![USERNAME]))
.outputs_ownership(hashmap!("trained_model" => vec![USERNAME]));
diff --git a/tests/integration/enclave/src/teaclave_worker.rs
b/tests/integration/enclave/src/teaclave_worker.rs
index 09b9f5d..35b0547 100644
--- a/tests/integration/enclave/src/teaclave_worker.rs
+++ b/tests/integration/enclave/src/teaclave_worker.rs
@@ -17,6 +17,7 @@
use std::prelude::v1::*;
+use serde_json::json;
use teaclave_crypto::TeaclaveFile128Key;
use teaclave_types::{
hashmap, read_all_bytes, Executor, ExecutorType, FileAuthTag,
FunctionArguments,
@@ -25,17 +26,18 @@ use teaclave_types::{
use teaclave_worker::Worker;
fn test_start_worker() {
- let arguments = FunctionArguments::new(hashmap!(
- "feature_size" => "4",
- "max_depth" => "4",
- "iterations" => "100",
- "shrinkage" => "0.1",
- "feature_sample_ratio" => "1.0",
- "data_sample_ratio" => "1.0",
- "min_leaf_size" => "1",
- "loss" => "LAD",
- "training_optimization_level" => "2"
- ));
+ let arguments = FunctionArguments::from_json(json!({
+ "feature_size": 4,
+ "max_depth": 4,
+ "iterations": 100,
+ "shrinkage": 0.1,
+ "feature_sample_ratio": 1.0,
+ "data_sample_ratio": 1.0,
+ "min_leaf_size": 1,
+ "loss": "LAD",
+ "training_optimization_level": 2
+ }))
+ .unwrap();
let plain_input = "fixtures/functions/gbdt_training/train.txt";
let enc_output = "fixtures/functions/gbdt_training/model.enc.out";
diff --git a/types/src/staged_function.rs b/types/src/staged_function.rs
index 47f3bd9..284191a 100644
--- a/types/src/staged_function.rs
+++ b/types/src/staged_function.rs
@@ -20,103 +20,22 @@ use crate::{Executor, ExecutorType, StagedFiles,
TeaclaveRuntime};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::prelude::v1::*;
-use std::str::FromStr;
use anyhow::{Context, Result};
pub type FunctionRuntime = Box<dyn TeaclaveRuntime + Send + Sync>;
-
-#[derive(Clone, Debug, Serialize, Deserialize)]
-pub struct ArgumentValue {
- inner: String,
-}
-
-impl From<String> for ArgumentValue {
- fn from(value: String) -> Self {
- ArgumentValue::new(value)
- }
-}
-
-impl From<&str> for ArgumentValue {
- fn from(value: &str) -> Self {
- ArgumentValue::new(value.into())
- }
-}
-
-impl From<&String> for ArgumentValue {
- fn from(value: &String) -> Self {
- ArgumentValue::new(value.into())
- }
-}
-
-impl From<ArgumentValue> for String {
- fn from(value: ArgumentValue) -> Self {
- value.as_str().to_owned()
- }
-}
-
-impl ArgumentValue {
- pub fn new(value: String) -> Self {
- Self { inner: value }
- }
-
- pub fn inner(&self) -> &String {
- &self.inner
- }
-
- pub fn as_str(&self) -> &str {
- &self.inner
- }
-
- pub fn as_usize(&self) -> Result<usize> {
- usize::from_str(&self.inner).with_context(|| format!("cannot parse
{}", self.inner))
- }
-
- pub fn as_u32(&self) -> Result<u32> {
- u32::from_str(&self.inner).with_context(|| format!("cannot parse {}",
self.inner))
- }
-
- pub fn as_f32(&self) -> Result<f32> {
- f32::from_str(&self.inner).with_context(|| format!("cannot parse {}",
self.inner))
- }
-
- pub fn as_f64(&self) -> Result<f64> {
- f64::from_str(&self.inner).with_context(|| format!("cannot parse {}",
self.inner))
- }
-
- pub fn as_u8(&self) -> Result<u8> {
- u8::from_str(&self.inner).with_context(|| format!("cannot parse {}",
self.inner))
- }
-}
-
-impl std::fmt::Display for ArgumentValue {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- write!(f, "{}", self.inner)
- }
-}
+type ArgumentValue = serde_json::Value;
#[derive(Clone, Serialize, Deserialize, Debug, Default)]
pub struct FunctionArguments {
#[serde(flatten)]
- inner: HashMap<String, ArgumentValue>,
-}
-
-impl<S: core::default::Default + std::hash::BuildHasher>
From<FunctionArguments>
- for HashMap<String, String, S>
-{
- fn from(arguments: FunctionArguments) -> Self {
- arguments
- .inner()
- .iter()
- .map(|(k, v)| (k.to_owned(), v.as_str().to_owned()))
- .collect()
- }
+ inner: serde_json::Map<String, ArgumentValue>,
}
impl From<HashMap<String, String>> for FunctionArguments {
fn from(map: HashMap<String, String>) -> Self {
- let inner = map.iter().fold(HashMap::new(), |mut acc, (k, v)| {
- acc.insert(k.into(), v.into());
+ let inner = map.iter().fold(serde_json::Map::new(), |mut acc, (k, v)| {
+ acc.insert(k.to_owned(), v.to_owned().into());
acc
});
@@ -124,16 +43,39 @@ impl From<HashMap<String, String>> for FunctionArguments {
}
}
+impl std::convert::TryFrom<String> for FunctionArguments {
+ type Error = anyhow::Error;
+
+ fn try_from(s: String) -> Result<Self, Self::Error> {
+ let v: ArgumentValue = serde_json::from_str(&s)?;
+ let inner = match v {
+ ArgumentValue::Object(o) => o,
+ _ => anyhow::bail!("Cannot convert to function arguments"),
+ };
+
+ Ok(Self { inner })
+ }
+}
+
impl FunctionArguments {
- pub fn new(map: HashMap<String, ArgumentValue>) -> Self {
- Self { inner: map }
+ pub fn from_json(json: ArgumentValue) -> Result<Self> {
+ let inner = match json {
+ ArgumentValue::Object(o) => o,
+ _ => anyhow::bail!("Not an json object"),
+ };
+
+ Ok(Self { inner })
+ }
+
+ pub fn from_map(map: HashMap<String, String>) -> Self {
+ map.into()
}
- pub fn inner(&self) -> &HashMap<String, ArgumentValue> {
+ pub fn inner(&self) -> &serde_json::Map<String, ArgumentValue> {
&self.inner
}
- pub fn inner_mut(&mut self) -> &mut HashMap<String, ArgumentValue> {
+ pub fn inner_mut(&mut self) -> &mut serde_json::Map<String, ArgumentValue>
{
&mut self.inner
}
@@ -148,11 +90,18 @@ impl FunctionArguments {
self.inner.into_iter().for_each(|(k, v)| {
vector.push(k);
- vector.push(v.to_string());
+ match v {
+ ArgumentValue::String(s) => vector.push(s),
+ _ => vector.push(v.to_string()),
+ }
});
vector
}
+
+ pub fn into_string(self) -> String {
+ ArgumentValue::Object(self.inner).to_string()
+ }
}
#[derive(Debug, Default)]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]