This is an automated email from the ASF dual-hosted git repository.
mssun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-teaclave.git
The following commit(s) were added to refs/heads/master by this push:
new 7c51791 Support optional input/output files (#603)
7c51791 is described below
commit 7c517915745425de53730641172f770838605589
Author: Qinkun Bao <[email protected]>
AuthorDate: Thu Feb 3 22:35:32 2022 -0800
Support optional input/output files (#603)
---
cmake/scripts/test.sh | 1 +
examples/c/builtin_ordered_set_intersect.c | 8 +-
examples/python/mesapy_optional_files.py | 216 +++++++++++++++++++++
examples/python/mesapy_optional_files_payload.py | 27 +++
.../rust/builtin_ordered_set_intersect/src/main.rs | 8 +-
sdk/python/teaclave.py | 8 +-
services/management/enclave/src/service.rs | 23 +--
.../src/proto/teaclave_frontend_service.proto | 2 +
services/proto/src/teaclave_frontend_service.rs | 4 +
.../enclave/src/end_to_end/builtin_gbdt_train.rs | 4 +-
.../enclave/src/end_to_end/mesapy_data_fusion.rs | 8 +-
tests/functional/enclave/src/management_service.rs | 24 +--
types/src/function.rs | 8 +-
types/src/task_state.rs | 21 +-
14 files changed, 316 insertions(+), 46 deletions(-)
diff --git a/cmake/scripts/test.sh b/cmake/scripts/test.sh
index a2eb075..3b18f33 100755
--- a/cmake/scripts/test.sh
+++ b/cmake/scripts/test.sh
@@ -239,6 +239,7 @@ run_examples() {
python3 builtin_echo.py
python3 mesapy_echo.py
python3 mesapy_logistic_reg.py
+ python3 mesapy_optional_files.py
python3 builtin_gbdt_train.py
python3 builtin_online_decrypt.py
python3 builtin_private_join_and_compute.py
diff --git a/examples/c/builtin_ordered_set_intersect.c
b/examples/c/builtin_ordered_set_intersect.c
index c93d6cd..1ce5e57 100644
--- a/examples/c/builtin_ordered_set_intersect.c
+++ b/examples/c/builtin_ordered_set_intersect.c
@@ -62,12 +62,12 @@ const char *register_function_request_serialized = QUOTE(
"payload": [],
"arguments": ["order"],
"inputs": [
- {"name": "input_data1", "description": "Client 0 data."},
- {"name": "input_data2", "description": "Client 1 data."}
+ {"name": "input_data1", "description": "Client 0 data.", "optional":
false},
+ {"name": "input_data2", "description": "Client 1 data.", "optional":
false}
],
"outputs": [
- {"name": "output_result1", "description": "Output data."},
- {"name": "output_result2", "description": "Output data."}
+ {"name": "output_result1", "description": "Output data.", "optional":
false},
+ {"name": "output_result2", "description": "Output data.", "optional":
false}
],
"user_allowlist": ["user0", "user1"]
});
diff --git a/examples/python/mesapy_optional_files.py
b/examples/python/mesapy_optional_files.py
new file mode 100644
index 0000000..add7ceb
--- /dev/null
+++ b/examples/python/mesapy_optional_files.py
@@ -0,0 +1,216 @@
+#!/usr/bin/env python3
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import sys
+
+from teaclave import FunctionInput, FunctionOutput, OwnerList, DataMap
+from utils import USER_ID, USER_PASSWORD, connect_authentication_service,
connect_frontend_service, PlatformAdmin
+
+
+class OptionalFilesTwoParticipants:
+    """Runs the optional-files function with two data owners.
+
+    Each participant registers the same fixture file as one of the
+    function's optional inputs. Each participant approves the task only after
+    checking that its inputs are owned by exactly the two users given to the
+    constructor.
+    """
+
+    def __init__(self, user_id1, user_password1, user_id2, user_password2):
+        self.user_id1 = user_id1
+        self.user_password1 = user_password1
+        self.client1 = self._login(user_id1, user_password1)
+
+        self.user_id2 = user_id2
+        self.user_password2 = user_password2
+        self.client2 = self._login(user_id2, user_password2)
+
+    @staticmethod
+    def _login(user_id, user_password):
+        """Log in ``user_id``; return a frontend client carrying its token."""
+        with connect_authentication_service() as auth_client:
+            token = auth_client.user_login(user_id, user_password)
+        client = connect_frontend_service()
+        client.metadata = {"id": user_id, "token": token}
+        return client
+
+    def validate_task(self, task):
+        """Return True iff the task's inputs belong to exactly our two users.
+
+        The check uses the ids this object was constructed with rather than
+        hard-coded names, so the example works for any pair of users. The
+        order of the two ownership entries does not matter.
+        """
+        ownership = task["inputs_ownership"]
+        # The task must have data from exactly two parties.
+        if len(ownership) != 2:
+            return False
+        owners = [entry['uids'] for entry in ownership]
+        expected = [[self.user_id1], [self.user_id2]]
+        return owners == expected or owners == expected[::-1]
+
+    def run_task(self, function_id):
+        """Create, fill, approve and invoke a two-party task.
+
+        Returns the task result as ``bytes``.
+        """
+        client1 = self.client1
+        client2 = self.client2
+        print("[+] registering input file")
+        url = "http://localhost:6789/fixtures/functions/gbdt_training/train.enc"
+        cmac = [
+            0x88, 0x1a, 0xdc, 0xa6, 0xb0, 0x52, 0x44, 0x72, 0xda, 0x0a, 0x9d,
+            0x0b, 0xb0, 0x2b, 0x9a, 0xf9
+        ]
+        schema = "teaclave-file-128"
+        key = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
+        iv = []
+        input_file_id1 = client1.register_input_file(url, schema, key, iv,
+                                                     cmac)
+        input_file_id2 = client2.register_input_file(url, schema, key, iv,
+                                                     cmac)
+
+        print("[+] creating task")
+        task_id = client1.create_task(function_id=function_id,
+                                      function_arguments={},
+                                      executor="mesapy",
+                                      inputs_ownership=[
+                                          OwnerList("input_data1",
+                                                    [self.user_id1]),
+                                          OwnerList("input_data2",
+                                                    [self.user_id2])
+                                      ])
+        client1.assign_data_to_task(task_id,
+                                    [DataMap("input_data1", input_file_id1)],
+                                    [])
+        client2.assign_data_to_task(task_id,
+                                    [DataMap("input_data2", input_file_id2)],
+                                    [])
+
+        # Each participant independently inspects the task before approving.
+        # A disapproval is only reported here; the subsequent invoke is then
+        # expected to be rejected by the service.
+        for user_id, client in ((self.user_id1, client1),
+                                (self.user_id2, client2)):
+            task = client.get_task(task_id)
+            if self.validate_task(task):
+                print(f"{user_id} approves the task")
+                client.approve_task(task_id)
+            else:
+                print(f"{user_id} disapproves the task")
+
+        client1.invoke_task(task_id)
+        result = client1.get_task_result(task_id)
+        return bytes(result)
+
+
+class OptionalFilesOneParticipant:
+    """Runs the optional-files function on behalf of a single data owner."""
+
+    def __init__(self, user_id, user_password):
+        self.user_id = user_id
+        self.user_password = user_password
+        with connect_authentication_service() as client:
+            print(f"[+] {self.user_id} login")
+            token = client.user_login(self.user_id, self.user_password)
+
+        self.client = connect_frontend_service()
+        metadata = {"id": self.user_id, "token": token}
+        self.client.metadata = metadata
+
+    def register_function(self):
+        """Register the payload function; all three of its inputs are optional.
+
+        Because every input is declared with ``optional=True``, tasks for this
+        function may be created while supplying any subset of the inputs.
+        Returns the id of the registered function.
+        """
+        print("[+] registering function")
+        with open("mesapy_optional_files_payload.py", "rb") as f:
+            payload = f.read()
+        function_id = self.client.register_function(
+            name="mesapy-optional-files",
+            description="Returns the first line of input_data1 if provided",
+            executor_type="python",
+            payload=list(payload),
+            arguments=[],
+            inputs=[
+                FunctionInput("input_data1", "Client 0 data.", True),
+                FunctionInput("input_data2", "Client 1 data.", True),
+                FunctionInput("input_data3", "Client 2 data.", True)
+            ])
+        return function_id
+
+    def with_input(self, user_id, function_id):
+        """Run a task that supplies only ``input_data1``, owned by ``user_id``.
+
+        The other optional inputs are left unassigned. Returns the task
+        result as ``bytes``.
+        """
+        client = self.client
+        print("[+] registering input file")
+        url = "http://localhost:6789/fixtures/functions/gbdt_training/train.enc"
+        cmac = [
+            0x88, 0x1a, 0xdc, 0xa6, 0xb0, 0x52, 0x44, 0x72, 0xda, 0x0a, 0x9d,
+            0x0b, 0xb0, 0x2b, 0x9a, 0xf9
+        ]
+        schema = "teaclave-file-128"
+        key = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
+        iv = []
+        input_file_id = client.register_input_file(url, schema, key, iv, cmac)
+        print("[+] creating task")
+        task_id = client.create_task(
+            function_id=function_id,
+            function_arguments={},
+            executor="mesapy",
+            inputs_ownership=[OwnerList("input_data1", [user_id])])
+        print("[+] assigning data to task")
+        client.assign_data_to_task(task_id,
+                                   [DataMap("input_data1", input_file_id)], [])
+
+        print("[+] approving task")
+        client.approve_task(task_id)
+
+        print("[+] invoking task")
+        client.invoke_task(task_id)
+        print("[+] getting result")
+        result = client.get_task_result(task_id)
+        print("[+] done")
+        return bytes(result)
+
+    def without_input(self, user_id, function_id):
+        """Run a task that supplies none of the optional inputs.
+
+        Unlike ``with_input``, no input ownership is declared and
+        ``approve_task`` is not called before invoking. ``user_id`` is unused;
+        it is accepted for symmetry with ``with_input``. Returns the task
+        result as ``bytes``.
+        """
+        print("[+] creating task")
+        client = self.client
+        task_id = client.create_task(function_id=function_id,
+                                     function_arguments={},
+                                     executor="mesapy")
+
+        print("[+] invoking task")
+        client.invoke_task(task_id)
+
+        print("[+] getting result")
+        result = client.get_task_result(task_id)
+        print("[+] done")
+        return bytes(result)
+
+
+def main():
+    """Run the optional-files example: no inputs, one input, two parties."""
+    platform_admin = PlatformAdmin("admin", "teaclave")
+    # Registration is expected to fail when a user already exists (e.g. on a
+    # re-run). Handle each user separately so that one pre-existing account
+    # does not prevent the other from being registered.
+    for user_id in ("user_a", "user_b"):
+        try:
+            platform_admin.register_user(user_id, "password")
+        except Exception:
+            pass
+    task = OptionalFilesOneParticipant("user_a", "password")
+    function_id = task.register_function()
+    print("Data owners do not register input files")
+    rt = task.without_input("user_a", function_id)
+    print(rt)
+    print("Data owners register input files")
+    rt = task.with_input("user_a", function_id)
+    print(rt)
+    print("The task has more than one participant")
+    task = OptionalFilesTwoParticipants("user_a", "password", "user_b",
+                                        "password")
+    rt = task.run_task(function_id)
+    print(rt)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/examples/python/mesapy_optional_files_payload.py
b/examples/python/mesapy_optional_files_payload.py
new file mode 100644
index 0000000..9650952
--- /dev/null
+++ b/examples/python/mesapy_optional_files_payload.py
@@ -0,0 +1,27 @@
+#!/usr/bin/env python3
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+def entrypoint(argv):
+    """Return the first line of the optional input ``input_data1``.
+
+    ``input_data1`` is declared optional, so the task may run without it. In
+    that case opening it fails and a placeholder message is returned. Only
+    the open is guarded, so read errors are not misreported as a missing
+    file. An empty file yields an empty string.
+    """
+    try:
+        f = teaclave_open("input_data1", "rb")
+    except Exception:
+        return "input_data1 does not exist"
+    data = f.readlines()
+    # An empty input is not the same as a missing one.
+    return data[0] if data else ""
diff --git a/examples/rust/builtin_ordered_set_intersect/src/main.rs
b/examples/rust/builtin_ordered_set_intersect/src/main.rs
index 658784e..64960f2 100644
--- a/examples/rust/builtin_ordered_set_intersect/src/main.rs
+++ b/examples/rust/builtin_ordered_set_intersect/src/main.rs
@@ -120,12 +120,12 @@ impl Client {
None,
Some(&["order"]),
Some(vec![
- teaclave_client_sdk::FunctionInput::new("input_data1", "Client
0 data."),
- teaclave_client_sdk::FunctionInput::new("input_data2", "Client
1 data."),
+ teaclave_client_sdk::FunctionInput::new("input_data1", "Client
0 data.", false),
+ teaclave_client_sdk::FunctionInput::new("input_data2", "Client
1 data.", false),
]),
Some(vec![
- teaclave_client_sdk::FunctionOutput::new("output_result1",
"Output data."),
- teaclave_client_sdk::FunctionOutput::new("output_result2",
"Output data."),
+ teaclave_client_sdk::FunctionOutput::new("output_result1",
"Output data.", false),
+ teaclave_client_sdk::FunctionOutput::new("output_result2",
"Output data.", false),
]),
)?;
self.client.get_function(&function_id)?;
diff --git a/sdk/python/teaclave.py b/sdk/python/teaclave.py
index ed9b1f1..39b650e 100644
--- a/sdk/python/teaclave.py
+++ b/sdk/python/teaclave.py
@@ -209,11 +209,13 @@ class FunctionInput:
Args:
name: Name of input data.
description: Description of the input data.
+ optional: [Default: False] Data owners do not need to register the
data.
"""
- def __init__(self, name: str, description: str):
+ def __init__(self, name: str, description: str, optional=False):
self.name = name
self.description = description
+ self.optional = optional
class FunctionOutput:
@@ -222,11 +224,13 @@ class FunctionOutput:
Args:
name: Name of output data.
description: Description of the output data.
+ optional: [Default: False] Data owners do not need to register the
data.
"""
- def __init__(self, name: str, description: str):
+ def __init__(self, name: str, description: str, optional=False):
self.name = name
self.description = description
+ self.optional = optional
class OwnerList:
diff --git a/services/management/enclave/src/service.rs
b/services/management/enclave/src/service.rs
index 675fe9a..55e730d 100644
--- a/services/management/enclave/src/service.rs
+++ b/services/management/enclave/src/service.rs
@@ -438,8 +438,8 @@ impl TeaclaveManagement for TeaclaveManagementService {
// access control: none
// when a task is created, following rules will be verified:
// 1) arugments match function definition
- // 2) input match function definition
- // 3) output match function definition
+ // 2) input files match function definition
+ // 3) output files match function definition
// 4) requested user_id in the user_allowlist
fn create_task(
&self,
@@ -466,7 +466,6 @@ impl TeaclaveManagement for TeaclaveManagementService {
return
Err(TeaclaveManagementServiceError::PermissionDenied.into());
}
}
-
let task = Task::<Create>::new(
user_id,
request.executor,
@@ -478,7 +477,6 @@ impl TeaclaveManagement for TeaclaveManagementService {
.map_err(|_| TeaclaveManagementServiceError::BadTask)?;
log::debug!("CreateTask: {:?}", task);
-
let ts: TaskState = task.into();
self.write_to_db(&ts)
.map_err(|_| TeaclaveManagementServiceError::StorageError)?;
@@ -643,11 +641,8 @@ impl TeaclaveManagement for TeaclaveManagementService {
})?;
log::debug!("InvokeTask: get task: {:?}", task);
-
let staged_task = task.stage_for_running(&user_id, function)?;
-
log::debug!("InvokeTask: staged task: {:?}", staged_task);
-
self.enqueue_to_db(StagedTask::get_queue_key().as_bytes(),
&staged_task)?;
let ts: TaskState = task.into();
@@ -831,10 +826,10 @@ impl TeaclaveManagementService {
input_file.uuid =
Uuid::parse_str("00000000-0000-0000-0000-000000000002")?;
self.write_to_db(&input_file)?;
- let function_input = FunctionInput::new("input", "input_desc");
- let function_output = FunctionOutput::new("output", "output_desc");
- let function_input2 = FunctionInput::new("input2", "input_desc");
- let function_output2 = FunctionOutput::new("output2", "output_desc");
+ let function_input = FunctionInput::new("input", "input_desc", false);
+ let function_output = FunctionOutput::new("output", "output_desc",
false);
+ let function_input2 = FunctionInput::new("input2", "input_desc",
false);
+ let function_output2 = FunctionOutput::new("output2", "output_desc",
false);
let function = FunctionBuilder::new()
.id(Uuid::parse_str("00000000-0000-0000-0000-000000000001").unwrap())
@@ -850,7 +845,7 @@ impl TeaclaveManagementService {
self.write_to_db(&function)?;
- let function_output = FunctionOutput::new("output", "output_desc");
+ let function_output = FunctionOutput::new("output", "output_desc",
false);
let function = FunctionBuilder::new()
.id(Uuid::parse_str("00000000-0000-0000-0000-000000000002").unwrap())
.name("mock-func-2")
@@ -912,8 +907,8 @@ pub mod tests {
}
pub fn handle_function() {
- let function_input = FunctionInput::new("input", "input_desc");
- let function_output = FunctionOutput::new("output", "output_desc");
+ let function_input = FunctionInput::new("input", "input_desc", false);
+ let function_output = FunctionOutput::new("output", "output_desc",
false);
let function = FunctionBuilder::new()
.id(Uuid::new_v4())
.name("mock_function")
diff --git a/services/proto/src/proto/teaclave_frontend_service.proto
b/services/proto/src/proto/teaclave_frontend_service.proto
index 1dfb038..71d3628 100644
--- a/services/proto/src/proto/teaclave_frontend_service.proto
+++ b/services/proto/src/proto/teaclave_frontend_service.proto
@@ -98,11 +98,13 @@ message GetInputFileResponse {
message FunctionInput {
string name = 1;
string description = 2;
+ bool optional = 3;
}
message FunctionOutput {
string name = 1;
string description = 2;
+ bool optional = 3;
}
message OwnerList {
diff --git a/services/proto/src/teaclave_frontend_service.rs
b/services/proto/src/teaclave_frontend_service.rs
index ba5d870..26d3198 100644
--- a/services/proto/src/teaclave_frontend_service.rs
+++ b/services/proto/src/teaclave_frontend_service.rs
@@ -1051,6 +1051,7 @@ impl std::convert::TryFrom<proto::FunctionInput> for
FunctionInput {
let ret = Self {
name: proto.name,
description: proto.description,
+ optional: proto.optional,
};
Ok(ret)
@@ -1062,6 +1063,7 @@ impl From<FunctionInput> for proto::FunctionInput {
Self {
name: input.name,
description: input.description,
+ optional: input.optional,
}
}
}
@@ -1073,6 +1075,7 @@ impl std::convert::TryFrom<proto::FunctionOutput> for
FunctionOutput {
let ret = Self {
name: proto.name,
description: proto.description,
+ optional: proto.optional,
};
Ok(ret)
@@ -1084,6 +1087,7 @@ impl From<FunctionOutput> for proto::FunctionOutput {
Self {
name: output.name,
description: output.description,
+ optional: output.optional,
}
}
}
diff --git a/tests/functional/enclave/src/end_to_end/builtin_gbdt_train.rs
b/tests/functional/enclave/src/end_to_end/builtin_gbdt_train.rs
index 675c61d..b323666 100644
--- a/tests/functional/enclave/src/end_to_end/builtin_gbdt_train.rs
+++ b/tests/functional/enclave/src/end_to_end/builtin_gbdt_train.rs
@@ -46,8 +46,8 @@ fn authorized_frontend_client() -> TeaclaveFrontendClient {
}
fn register_gbdt_function(client: &mut TeaclaveFrontendClient) -> ExternalID {
- let fn_input = FunctionInput::new("training_data", "Input traning data
file.");
- let fn_output = FunctionOutput::new("trained_model", "Output trained
model.");
+ let fn_input = FunctionInput::new("training_data", "Input traning data
file.", false);
+ let fn_output = FunctionOutput::new("trained_model", "Output trained
model.", false);
let fn_args = vec![
"feature_size",
"max_depth",
diff --git a/tests/functional/enclave/src/end_to_end/mesapy_data_fusion.rs
b/tests/functional/enclave/src/end_to_end/mesapy_data_fusion.rs
index d8c901c..69047de 100644
--- a/tests/functional/enclave/src/end_to_end/mesapy_data_fusion.rs
+++ b/tests/functional/enclave/src/end_to_end/mesapy_data_fusion.rs
@@ -52,9 +52,9 @@ def entrypoint(argv):
return summary
"#;
- let input1 = FunctionInput::new("InPartyA", "Input from party A");
- let input2 = FunctionInput::new("InPartyB", "Input from party B");
- let fusion_output = FunctionOutput::new("OutFusionData", "Output fusion
data");
+ let input1 = FunctionInput::new("InPartyA", "Input from party A", false);
+ let input2 = FunctionInput::new("InPartyB", "Input from party B", false);
+ let fusion_output = FunctionOutput::new("OutFusionData", "Output fusion
data", false);
let request = RegisterFunctionRequestBuilder::new()
.name("mesapy_data_fusion_demo")
.description("Mesapy Data Fusion Function")
@@ -230,7 +230,7 @@ def entrypoint(argv):
return "%s" % cnt
"#;
- let input_spec = FunctionInput::new("InputData", "Lines of Data");
+ let input_spec = FunctionInput::new("InputData", "Lines of Data", false);
let request = RegisterFunctionRequestBuilder::new()
.name("wlc")
.description("Mesapy Word Line Count Function")
diff --git a/tests/functional/enclave/src/management_service.rs
b/tests/functional/enclave/src/management_service.rs
index 7001006..765192b 100644
--- a/tests/functional/enclave/src/management_service.rs
+++ b/tests/functional/enclave/src/management_service.rs
@@ -126,8 +126,8 @@ fn test_get_input_file() {
#[test_case]
fn test_register_function() {
- let function_input = FunctionInput::new("input", "input_desc");
- let function_output = FunctionOutput::new("output", "output_desc");
+ let function_input = FunctionInput::new("input", "input_desc", false);
+ let function_output = FunctionOutput::new("output", "output_desc", false);
let request = RegisterFunctionRequestBuilder::new()
.name("mock_function")
.executor_type(ExecutorType::Python)
@@ -146,8 +146,8 @@ fn test_register_function() {
#[test_case]
fn test_register_private_function() {
- let function_input = FunctionInput::new("input", "input_desc");
- let function_output = FunctionOutput::new("output", "output_desc");
+ let function_input = FunctionInput::new("input", "input_desc", false);
+ let function_output = FunctionOutput::new("output", "output_desc", false);
let request = RegisterFunctionRequestBuilder::new()
.name("mock_function")
.executor_type(ExecutorType::Python)
@@ -167,8 +167,8 @@ fn test_register_private_function() {
#[test_case]
fn test_delete_function() {
- let function_input = FunctionInput::new("input", "input_desc");
- let function_output = FunctionOutput::new("output", "output_desc");
+ let function_input = FunctionInput::new("input", "input_desc", false);
+ let function_output = FunctionOutput::new("output", "output_desc", false);
let request = RegisterFunctionRequestBuilder::new()
.name("mock_function")
.executor_type(ExecutorType::Python)
@@ -190,8 +190,8 @@ fn test_delete_function() {
#[test_case]
fn test_update_function() {
- let function_input = FunctionInput::new("input", "input_desc");
- let function_output = FunctionOutput::new("output", "output_desc");
+ let function_input = FunctionInput::new("input", "input_desc", false);
+ let function_output = FunctionOutput::new("output", "output_desc", false);
let request = RegisterFunctionRequestBuilder::new()
.name("mock_function")
.executor_type(ExecutorType::Python)
@@ -206,8 +206,8 @@ fn test_update_function() {
let response = client.register_function(request);
let original_id = response.unwrap().function_id;
- let function_input = FunctionInput::new("input", "input_desc");
- let function_output = FunctionOutput::new("output", "output_desc");
+ let function_input = FunctionInput::new("input", "input_desc", false);
+ let function_output = FunctionOutput::new("output", "output_desc", false);
let request = UpdateFunctionRequestBuilder::new()
.function_id(original_id.clone())
.name("mock_function")
@@ -241,8 +241,8 @@ fn test_list_functions() {
#[test_case]
fn test_get_function() {
- let function_input = FunctionInput::new("input", "input_desc");
- let function_output = FunctionOutput::new("output", "output_desc");
+ let function_input = FunctionInput::new("input", "input_desc", false);
+ let function_output = FunctionOutput::new("output", "output_desc", false);
let request = RegisterFunctionRequestBuilder::new()
.name("mock_function")
.executor_type(ExecutorType::Python)
diff --git a/types/src/function.rs b/types/src/function.rs
index 8f93f64..b2ee267 100644
--- a/types/src/function.rs
+++ b/types/src/function.rs
@@ -24,13 +24,15 @@ use uuid::Uuid;
pub struct FunctionInput {
pub name: String,
pub description: String,
+ /// When `true`, task creation accepts requests that do not provide an
+ /// owner for this input (task creation in task_state.rs adds optional
+ /// inputs to the set of supplied keys before comparing it to the spec).
+ pub optional: bool,
}
impl FunctionInput {
- pub fn new(name: impl Into<String>, description: impl Into<String>) ->
Self {
+ /// Creates an input spec; `optional` marks whether tasks may omit it.
+ pub fn new(name: impl Into<String>, description: impl Into<String>,
optional: bool) -> Self {
Self {
name: name.into(),
description: description.into(),
+ optional,
}
}
}
@@ -39,13 +41,15 @@ impl FunctionInput {
pub struct FunctionOutput {
pub name: String,
pub description: String,
+ /// When `true`, task creation accepts requests that do not provide an
+ /// owner for this output (task creation in task_state.rs adds optional
+ /// outputs to the set of supplied keys before comparing it to the spec).
+ pub optional: bool,
}
impl FunctionOutput {
- pub fn new(name: impl Into<String>, description: impl Into<String>) ->
Self {
+ /// Creates an output spec; `optional` marks whether tasks may omit it.
+ pub fn new(name: impl Into<String>, description: impl Into<String>,
optional: bool) -> Self {
Self {
name: name.into(),
description: description.into(),
+ optional,
}
}
}
diff --git a/types/src/task_state.rs b/types/src/task_state.rs
index dca7203..54227d5 100644
--- a/types/src/task_state.rs
+++ b/types/src/task_state.rs
@@ -135,12 +135,29 @@ impl Task<Create> {
// check input fkeys
let inputs_spec: HashSet<&String> = function.inputs.iter().map(|f|
&f.name).collect();
- let req_input_fkeys: HashSet<&String> =
req_input_owners.keys().collect();
+ let mut req_input_fkeys: HashSet<&String> =
req_input_owners.keys().collect();
+ // If an input/output file is marked with `optional: True`, users do
not need to
+ // register the file.
+ let option_inputs_spec: HashSet<&String> = function
+ .inputs
+ .iter()
+ .filter(|f| f.optional)
+ .map(|f| &f.name)
+ .collect();
+ req_input_fkeys.extend(&option_inputs_spec);
+
ensure!(inputs_spec == req_input_fkeys, "input keys mismatch");
// check output fkeys
let outputs_spec: HashSet<&String> = function.outputs.iter().map(|f|
&f.name).collect();
- let req_output_fkeys: HashSet<&String> =
req_output_owners.keys().collect();
+ let mut req_output_fkeys: HashSet<&String> =
req_output_owners.keys().collect();
+ let option_outputs_spec: HashSet<&String> = function
+ .outputs
+ .iter()
+ .filter(|f| f.optional)
+ .map(|f| &f.name)
+ .collect();
+ req_output_fkeys.extend(&option_outputs_spec);
ensure!(outputs_spec == req_output_fkeys, "output keys mismatch");
let ts = TaskState {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]