This is an automated email from the ASF dual-hosted git repository.

mssun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-teaclave.git


The following commit(s) were added to refs/heads/master by this push:
     new 7c51791  Support optional input/output files (#603)
7c51791 is described below

commit 7c517915745425de53730641172f770838605589
Author: Qinkun Bao <[email protected]>
AuthorDate: Thu Feb 3 22:35:32 2022 -0800

    Support optional input/output files (#603)
---
 cmake/scripts/test.sh                              |   1 +
 examples/c/builtin_ordered_set_intersect.c         |   8 +-
 examples/python/mesapy_optional_files.py           | 216 +++++++++++++++++++++
 examples/python/mesapy_optional_files_payload.py   |  27 +++
 .../rust/builtin_ordered_set_intersect/src/main.rs |   8 +-
 sdk/python/teaclave.py                             |   8 +-
 services/management/enclave/src/service.rs         |  23 +--
 .../src/proto/teaclave_frontend_service.proto      |   2 +
 services/proto/src/teaclave_frontend_service.rs    |   4 +
 .../enclave/src/end_to_end/builtin_gbdt_train.rs   |   4 +-
 .../enclave/src/end_to_end/mesapy_data_fusion.rs   |   8 +-
 tests/functional/enclave/src/management_service.rs |  24 +--
 types/src/function.rs                              |   8 +-
 types/src/task_state.rs                            |  21 +-
 14 files changed, 316 insertions(+), 46 deletions(-)

diff --git a/cmake/scripts/test.sh b/cmake/scripts/test.sh
index a2eb075..3b18f33 100755
--- a/cmake/scripts/test.sh
+++ b/cmake/scripts/test.sh
@@ -239,6 +239,7 @@ run_examples() {
   python3 builtin_echo.py
   python3 mesapy_echo.py
   python3 mesapy_logistic_reg.py
+  python3 mesapy_optional_files.py
   python3 builtin_gbdt_train.py
   python3 builtin_online_decrypt.py
   python3 builtin_private_join_and_compute.py
diff --git a/examples/c/builtin_ordered_set_intersect.c b/examples/c/builtin_ordered_set_intersect.c
index c93d6cd..1ce5e57 100644
--- a/examples/c/builtin_ordered_set_intersect.c
+++ b/examples/c/builtin_ordered_set_intersect.c
@@ -62,12 +62,12 @@ const char *register_function_request_serialized = QUOTE(
     "payload": [],
     "arguments": ["order"],
     "inputs": [
-        {"name": "input_data1", "description": "Client 0 data."},
-        {"name": "input_data2", "description": "Client 1 data."}
+        {"name": "input_data1", "description": "Client 0 data.", "optional": false},
+        {"name": "input_data2", "description": "Client 1 data.", "optional": false}
     ],
     "outputs": [ 
-        {"name": "output_result1", "description": "Output data."},
-        {"name": "output_result2", "description": "Output data."}
+        {"name": "output_result1", "description": "Output data.", "optional": false},
+        {"name": "output_result2", "description": "Output data.", "optional": false}
     ],
     "user_allowlist": ["user0", "user1"]
 });
diff --git a/examples/python/mesapy_optional_files.py b/examples/python/mesapy_optional_files.py
new file mode 100644
index 0000000..add7ceb
--- /dev/null
+++ b/examples/python/mesapy_optional_files.py
@@ -0,0 +1,216 @@
+#!/usr/bin/env python3
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import sys
+
+from teaclave import FunctionInput, FunctionOutput, OwnerList, DataMap
+from utils import USER_ID, USER_PASSWORD, connect_authentication_service, connect_frontend_service, PlatformAdmin
+
+
+class OptionalFilesTwoParticipants:
+
+    def __init__(self, user_id1, user_password1, user_id2, user_password2):
+        self.user_id1 = user_id1
+        self.user_password1 = user_password1
+        with connect_authentication_service() as client1:
+            token1 = client1.user_login(self.user_id1, self.user_password1)
+
+        self.client1 = connect_frontend_service()
+        metadata = {"id": self.user_id1, "token": token1}
+        self.client1.metadata = metadata
+
+        self.user_id2 = user_id2
+        self.user_password2 = user_password2
+        with connect_authentication_service() as client2:
+            token2 = client2.user_login(self.user_id2, self.user_password2)
+
+        self.client2 = connect_frontend_service()
+        metadata = {"id": self.user_id2, "token": token2}
+        self.client2.metadata = metadata
+
+    def validate_task(self, task):
+        # The task has data from two parties
+        if len(task["inputs_ownership"]) != 2:
+            return False
+        # The data is from user_a and user_b
+        if (task["inputs_ownership"][1]['uids'] == ['user_a']
+                and task["inputs_ownership"][0]['uids'] == ['user_b']) or (
+                    task["inputs_ownership"][1]['uids'] == ['user_b']
+                    and task["inputs_ownership"][0]['uids'] == ['user_a']):
+            return True
+        else:
+            return False
+
+    def run_task(self, function_id):
+        client1 = self.client1
+        client2 = self.client2
+        print("[+] registering input file")
+        url = "http://localhost:6789/fixtures/functions/gbdt_training/train.enc";
+        cmac = [
+            0x88, 0x1a, 0xdc, 0xa6, 0xb0, 0x52, 0x44, 0x72, 0xda, 0x0a, 0x9d,
+            0x0b, 0xb0, 0x2b, 0x9a, 0xf9
+        ]
+        schema = "teaclave-file-128"
+        key = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
+        iv = []
+        input_file_id1 = client1.register_input_file(url, schema, key, iv,
+                                                     cmac)
+        input_file_id2 = client2.register_input_file(url, schema, key, iv,
+                                                     cmac)
+
+        print("[+] creating task")
+        task_id = client1.create_task(function_id=function_id,
+                                      function_arguments={},
+                                      executor="mesapy",
+                                      inputs_ownership=[
+                                          OwnerList("input_data1",
+                                                    [self.user_id1]),
+                                          OwnerList("input_data2",
+                                                    [self.user_id2])
+                                      ])
+        client1.assign_data_to_task(task_id,
+                                    [DataMap("input_data1", input_file_id1)],
+                                    [])
+        client2.assign_data_to_task(task_id,
+                                    [DataMap("input_data2", input_file_id2)],
+                                    [])
+
+        task1 = client1.get_task(task_id)
+        if self.validate_task(task1):
+            print("User_a approves the task")
+            client1.approve_task(task_id)
+        else:
+            print("User_a disapproves the task")
+
+        task2 = client2.get_task(task_id)
+        if self.validate_task(task2):
+            print("User_b approves the task")
+            client2.approve_task(task_id)
+        else:
+            print("User_b disapproves the task")
+
+        client1.invoke_task(task_id)
+        result = client1.get_task_result(task_id)
+        return bytes(result)
+
+
+class OptionalFilesOneParticipant:
+
+    def __init__(self, user_id, user_password):
+        self.user_id = user_id
+        self.user_password = user_password
+        with connect_authentication_service() as client:
+            print(f"[+] {self.user_id} login")
+            token = client.user_login(self.user_id, self.user_password)
+
+        self.client = connect_frontend_service()
+        metadata = {"id": self.user_id, "token": token}
+        self.client.metadata = metadata
+
+    # The function template defines three optional input files
+    def register_function(self):
+        print("[+] registering function")
+        with open("mesapy_optional_files_payload.py", "rb") as f:
+            payload = f.read()
+        function_id = self.client.register_function(
+            name="mesapy-echo",
+            description="An echo function implemented in Python",
+            executor_type="python",
+            payload=list(payload),
+            arguments=[],
+            inputs=[
+                FunctionInput("input_data1", "Client 0 data.", True),
+                FunctionInput("input_data2", "Client 1 data.", True),
+                FunctionInput("input_data3", "Client 2 data.", True)
+            ])
+        return function_id
+
+    def with_input(self, user_id, function_id):
+        client = self.client
+        print("[+] registering input file")
+        url = "http://localhost:6789/fixtures/functions/gbdt_training/train.enc";
+        cmac = [
+            0x88, 0x1a, 0xdc, 0xa6, 0xb0, 0x52, 0x44, 0x72, 0xda, 0x0a, 0x9d,
+            0x0b, 0xb0, 0x2b, 0x9a, 0xf9
+        ]
+        schema = "teaclave-file-128"
+        key = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
+        iv = []
+        input_file_id = client.register_input_file(url, schema, key, iv, cmac)
+        print("[+] creating task")
+        task_id = client.create_task(
+            function_id=function_id,
+            function_arguments={},
+            executor="mesapy",
+            inputs_ownership=[OwnerList("input_data1", [user_id])])
+        print("[+] assigning data to task")
+        client.assign_data_to_task(task_id,
+                                   [DataMap("input_data1", input_file_id)], [])
+
+        print("[+] approving task")
+        client.approve_task(task_id)
+
+        print("[+] invoking task")
+        client.invoke_task(task_id)
+        print("[+] getting result")
+        result = client.get_task_result(task_id)
+        print("[+] done")
+        return bytes(result)
+
+    def without_input(self, user_id, function_id):
+        print("[+] creating task")
+        client = self.client
+        task_id = client.create_task(function_id=function_id,
+                                     function_arguments={},
+                                     executor="mesapy")
+
+        print("[+] invoking task")
+        client.invoke_task(task_id)
+
+        print("[+] getting result")
+        result = client.get_task_result(task_id)
+        print("[+] done")
+        return bytes(result)
+
+
+def main():
+
+    platform_admin = PlatformAdmin("admin", "teaclave")
+    try:
+        platform_admin.register_user("user_a", "password")
+        platform_admin.register_user("user_b", "password")
+    except Exception:
+        pass
+    task = OptionalFilesOneParticipant("user_a", "password")
+    function_id = task.register_function()
+    print("Data owners do not register input files")
+    rt = task.without_input("user_a", function_id)
+    print(rt)
+    print("Data owners register input files")
+    rt = task.with_input("user_a", function_id)
+    print(rt)
+    print("The task has more than more participants")
+    task = OptionalFilesTwoParticipants("user_a", "password", "user_b",
+                                        "password")
+    rt = task.run_task(function_id)
+    print(rt)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/examples/python/mesapy_optional_files_payload.py b/examples/python/mesapy_optional_files_payload.py
new file mode 100644
index 0000000..9650952
--- /dev/null
+++ b/examples/python/mesapy_optional_files_payload.py
@@ -0,0 +1,27 @@
+#!/usr/bin/env python3
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+def entrypoint(argv):
+    try:
+        f = teaclave_open("input_data1", "rb")
+        data = f.readlines()
+        return data[0]
+    except:
+        return "input_data1 does not exist"
diff --git a/examples/rust/builtin_ordered_set_intersect/src/main.rs b/examples/rust/builtin_ordered_set_intersect/src/main.rs
index 658784e..64960f2 100644
--- a/examples/rust/builtin_ordered_set_intersect/src/main.rs
+++ b/examples/rust/builtin_ordered_set_intersect/src/main.rs
@@ -120,12 +120,12 @@ impl Client {
             None,
             Some(&["order"]),
             Some(vec![
-                teaclave_client_sdk::FunctionInput::new("input_data1", "Client 0 data."),
-                teaclave_client_sdk::FunctionInput::new("input_data2", "Client 1 data."),
+                teaclave_client_sdk::FunctionInput::new("input_data1", "Client 0 data.", false),
+                teaclave_client_sdk::FunctionInput::new("input_data2", "Client 1 data.", false),
             ]),
             Some(vec![
-                teaclave_client_sdk::FunctionOutput::new("output_result1", "Output data."),
-                teaclave_client_sdk::FunctionOutput::new("output_result2", "Output data."),
+                teaclave_client_sdk::FunctionOutput::new("output_result1", "Output data.", false),
+                teaclave_client_sdk::FunctionOutput::new("output_result2", "Output data.", false),
             ]),
         )?;
         self.client.get_function(&function_id)?;
diff --git a/sdk/python/teaclave.py b/sdk/python/teaclave.py
index ed9b1f1..39b650e 100644
--- a/sdk/python/teaclave.py
+++ b/sdk/python/teaclave.py
@@ -209,11 +209,13 @@ class FunctionInput:
     Args:
         name: Name of input data.
         description: Description of the input data.
+        optional: [Default: False] Data owners do not need to register the data.
     """
 
-    def __init__(self, name: str, description: str):
+    def __init__(self, name: str, description: str, optional=False):
         self.name = name
         self.description = description
+        self.optional = optional
 
 
 class FunctionOutput:
@@ -222,11 +224,13 @@ class FunctionOutput:
     Args:
         name: Name of output data.
         description: Description of the output data.
+        optional: [Default: False] Data owners do not need to register the data.
     """
 
-    def __init__(self, name: str, description: str):
+    def __init__(self, name: str, description: str, optional=False):
         self.name = name
         self.description = description
+        self.optional = optional
 
 
 class OwnerList:
diff --git a/services/management/enclave/src/service.rs b/services/management/enclave/src/service.rs
index 675fe9a..55e730d 100644
--- a/services/management/enclave/src/service.rs
+++ b/services/management/enclave/src/service.rs
@@ -438,8 +438,8 @@ impl TeaclaveManagement for TeaclaveManagementService {
     // access control: none
     // when a task is created, following rules will be verified:
     // 1) arugments match function definition
-    // 2) input match function definition
-    // 3) output match function definition
+    // 2) input files match function definition
+    // 3) output files match function definition
     // 4) requested user_id in the user_allowlist
     fn create_task(
         &self,
@@ -466,7 +466,6 @@ impl TeaclaveManagement for TeaclaveManagementService {
                 return Err(TeaclaveManagementServiceError::PermissionDenied.into());
             }
         }
-
         let task = Task::<Create>::new(
             user_id,
             request.executor,
@@ -478,7 +477,6 @@ impl TeaclaveManagement for TeaclaveManagementService {
         .map_err(|_| TeaclaveManagementServiceError::BadTask)?;
 
         log::debug!("CreateTask: {:?}", task);
-
         let ts: TaskState = task.into();
         self.write_to_db(&ts)
             .map_err(|_| TeaclaveManagementServiceError::StorageError)?;
@@ -643,11 +641,8 @@ impl TeaclaveManagement for TeaclaveManagementService {
         })?;
 
         log::debug!("InvokeTask: get task: {:?}", task);
-
         let staged_task = task.stage_for_running(&user_id, function)?;
-
         log::debug!("InvokeTask: staged task: {:?}", staged_task);
-
         self.enqueue_to_db(StagedTask::get_queue_key().as_bytes(), &staged_task)?;
 
         let ts: TaskState = task.into();
@@ -831,10 +826,10 @@ impl TeaclaveManagementService {
         input_file.uuid = Uuid::parse_str("00000000-0000-0000-0000-000000000002")?;
         self.write_to_db(&input_file)?;
 
-        let function_input = FunctionInput::new("input", "input_desc");
-        let function_output = FunctionOutput::new("output", "output_desc");
-        let function_input2 = FunctionInput::new("input2", "input_desc");
-        let function_output2 = FunctionOutput::new("output2", "output_desc");
+        let function_input = FunctionInput::new("input", "input_desc", false);
+        let function_output = FunctionOutput::new("output", "output_desc", false);
+        let function_input2 = FunctionInput::new("input2", "input_desc", false);
+        let function_output2 = FunctionOutput::new("output2", "output_desc", false);
 
         let function = FunctionBuilder::new()
             .id(Uuid::parse_str("00000000-0000-0000-0000-000000000001").unwrap())
@@ -850,7 +845,7 @@ impl TeaclaveManagementService {
 
         self.write_to_db(&function)?;
 
-        let function_output = FunctionOutput::new("output", "output_desc");
+        let function_output = FunctionOutput::new("output", "output_desc", false);
         let function = FunctionBuilder::new()
             .id(Uuid::parse_str("00000000-0000-0000-0000-000000000002").unwrap())
             .name("mock-func-2")
@@ -912,8 +907,8 @@ pub mod tests {
     }
 
     pub fn handle_function() {
-        let function_input = FunctionInput::new("input", "input_desc");
-        let function_output = FunctionOutput::new("output", "output_desc");
+        let function_input = FunctionInput::new("input", "input_desc", false);
+        let function_output = FunctionOutput::new("output", "output_desc", false);
         let function = FunctionBuilder::new()
             .id(Uuid::new_v4())
             .name("mock_function")
diff --git a/services/proto/src/proto/teaclave_frontend_service.proto b/services/proto/src/proto/teaclave_frontend_service.proto
index 1dfb038..71d3628 100644
--- a/services/proto/src/proto/teaclave_frontend_service.proto
+++ b/services/proto/src/proto/teaclave_frontend_service.proto
@@ -98,11 +98,13 @@ message GetInputFileResponse {
 message FunctionInput {
   string name = 1;
   string description = 2;
+  bool optional = 3;
 }
 
 message FunctionOutput {
   string name = 1;
   string description = 2;
+  bool optional = 3;
 }
 
 message OwnerList {
diff --git a/services/proto/src/teaclave_frontend_service.rs b/services/proto/src/teaclave_frontend_service.rs
index ba5d870..26d3198 100644
--- a/services/proto/src/teaclave_frontend_service.rs
+++ b/services/proto/src/teaclave_frontend_service.rs
@@ -1051,6 +1051,7 @@ impl std::convert::TryFrom<proto::FunctionInput> for FunctionInput {
         let ret = Self {
             name: proto.name,
             description: proto.description,
+            optional: proto.optional,
         };
 
         Ok(ret)
@@ -1062,6 +1063,7 @@ impl From<FunctionInput> for proto::FunctionInput {
         Self {
             name: input.name,
             description: input.description,
+            optional: input.optional,
         }
     }
 }
@@ -1073,6 +1075,7 @@ impl std::convert::TryFrom<proto::FunctionOutput> for FunctionOutput {
         let ret = Self {
             name: proto.name,
             description: proto.description,
+            optional: proto.optional,
         };
 
         Ok(ret)
@@ -1084,6 +1087,7 @@ impl From<FunctionOutput> for proto::FunctionOutput {
         Self {
             name: output.name,
             description: output.description,
+            optional: output.optional,
         }
     }
 }
diff --git a/tests/functional/enclave/src/end_to_end/builtin_gbdt_train.rs b/tests/functional/enclave/src/end_to_end/builtin_gbdt_train.rs
index 675c61d..b323666 100644
--- a/tests/functional/enclave/src/end_to_end/builtin_gbdt_train.rs
+++ b/tests/functional/enclave/src/end_to_end/builtin_gbdt_train.rs
@@ -46,8 +46,8 @@ fn authorized_frontend_client() -> TeaclaveFrontendClient {
 }
 
 fn register_gbdt_function(client: &mut TeaclaveFrontendClient) -> ExternalID {
-    let fn_input = FunctionInput::new("training_data", "Input traning data file.");
-    let fn_output = FunctionOutput::new("trained_model", "Output trained model.");
+    let fn_input = FunctionInput::new("training_data", "Input traning data file.", false);
+    let fn_output = FunctionOutput::new("trained_model", "Output trained model.", false);
     let fn_args = vec![
         "feature_size",
         "max_depth",
diff --git a/tests/functional/enclave/src/end_to_end/mesapy_data_fusion.rs b/tests/functional/enclave/src/end_to_end/mesapy_data_fusion.rs
index d8c901c..69047de 100644
--- a/tests/functional/enclave/src/end_to_end/mesapy_data_fusion.rs
+++ b/tests/functional/enclave/src/end_to_end/mesapy_data_fusion.rs
@@ -52,9 +52,9 @@ def entrypoint(argv):
     return summary
 "#;
 
-    let input1 = FunctionInput::new("InPartyA", "Input from party A");
-    let input2 = FunctionInput::new("InPartyB", "Input from party B");
-    let fusion_output = FunctionOutput::new("OutFusionData", "Output fusion data");
+    let input1 = FunctionInput::new("InPartyA", "Input from party A", false);
+    let input2 = FunctionInput::new("InPartyB", "Input from party B", false);
+    let fusion_output = FunctionOutput::new("OutFusionData", "Output fusion data", false);
     let request = RegisterFunctionRequestBuilder::new()
         .name("mesapy_data_fusion_demo")
         .description("Mesapy Data Fusion Function")
@@ -230,7 +230,7 @@ def entrypoint(argv):
     return "%s" % cnt
 "#;
 
-    let input_spec = FunctionInput::new("InputData", "Lines of Data");
+    let input_spec = FunctionInput::new("InputData", "Lines of Data", false);
     let request = RegisterFunctionRequestBuilder::new()
         .name("wlc")
         .description("Mesapy Word Line Count Function")
diff --git a/tests/functional/enclave/src/management_service.rs b/tests/functional/enclave/src/management_service.rs
index 7001006..765192b 100644
--- a/tests/functional/enclave/src/management_service.rs
+++ b/tests/functional/enclave/src/management_service.rs
@@ -126,8 +126,8 @@ fn test_get_input_file() {
 
 #[test_case]
 fn test_register_function() {
-    let function_input = FunctionInput::new("input", "input_desc");
-    let function_output = FunctionOutput::new("output", "output_desc");
+    let function_input = FunctionInput::new("input", "input_desc", false);
+    let function_output = FunctionOutput::new("output", "output_desc", false);
     let request = RegisterFunctionRequestBuilder::new()
         .name("mock_function")
         .executor_type(ExecutorType::Python)
@@ -146,8 +146,8 @@ fn test_register_function() {
 
 #[test_case]
 fn test_register_private_function() {
-    let function_input = FunctionInput::new("input", "input_desc");
-    let function_output = FunctionOutput::new("output", "output_desc");
+    let function_input = FunctionInput::new("input", "input_desc", false);
+    let function_output = FunctionOutput::new("output", "output_desc", false);
     let request = RegisterFunctionRequestBuilder::new()
         .name("mock_function")
         .executor_type(ExecutorType::Python)
@@ -167,8 +167,8 @@ fn test_register_private_function() {
 
 #[test_case]
 fn test_delete_function() {
-    let function_input = FunctionInput::new("input", "input_desc");
-    let function_output = FunctionOutput::new("output", "output_desc");
+    let function_input = FunctionInput::new("input", "input_desc", false);
+    let function_output = FunctionOutput::new("output", "output_desc", false);
     let request = RegisterFunctionRequestBuilder::new()
         .name("mock_function")
         .executor_type(ExecutorType::Python)
@@ -190,8 +190,8 @@ fn test_delete_function() {
 
 #[test_case]
 fn test_update_function() {
-    let function_input = FunctionInput::new("input", "input_desc");
-    let function_output = FunctionOutput::new("output", "output_desc");
+    let function_input = FunctionInput::new("input", "input_desc", false);
+    let function_output = FunctionOutput::new("output", "output_desc", false);
     let request = RegisterFunctionRequestBuilder::new()
         .name("mock_function")
         .executor_type(ExecutorType::Python)
@@ -206,8 +206,8 @@ fn test_update_function() {
     let response = client.register_function(request);
     let original_id = response.unwrap().function_id;
 
-    let function_input = FunctionInput::new("input", "input_desc");
-    let function_output = FunctionOutput::new("output", "output_desc");
+    let function_input = FunctionInput::new("input", "input_desc", false);
+    let function_output = FunctionOutput::new("output", "output_desc", false);
     let request = UpdateFunctionRequestBuilder::new()
         .function_id(original_id.clone())
         .name("mock_function")
@@ -241,8 +241,8 @@ fn test_list_functions() {
 
 #[test_case]
 fn test_get_function() {
-    let function_input = FunctionInput::new("input", "input_desc");
-    let function_output = FunctionOutput::new("output", "output_desc");
+    let function_input = FunctionInput::new("input", "input_desc", false);
+    let function_output = FunctionOutput::new("output", "output_desc", false);
     let request = RegisterFunctionRequestBuilder::new()
         .name("mock_function")
         .executor_type(ExecutorType::Python)
diff --git a/types/src/function.rs b/types/src/function.rs
index 8f93f64..b2ee267 100644
--- a/types/src/function.rs
+++ b/types/src/function.rs
@@ -24,13 +24,15 @@ use uuid::Uuid;
 pub struct FunctionInput {
     pub name: String,
     pub description: String,
+    pub optional: bool,
 }
 
 impl FunctionInput {
-    pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
+    pub fn new(name: impl Into<String>, description: impl Into<String>, optional: bool) -> Self {
         Self {
             name: name.into(),
             description: description.into(),
+            optional,
         }
     }
 }
@@ -39,13 +41,15 @@ impl FunctionInput {
 pub struct FunctionOutput {
     pub name: String,
     pub description: String,
+    pub optional: bool,
 }
 
 impl FunctionOutput {
-    pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
+    pub fn new(name: impl Into<String>, description: impl Into<String>, optional: bool) -> Self {
         Self {
             name: name.into(),
             description: description.into(),
+            optional,
         }
     }
 }
diff --git a/types/src/task_state.rs b/types/src/task_state.rs
index dca7203..54227d5 100644
--- a/types/src/task_state.rs
+++ b/types/src/task_state.rs
@@ -135,12 +135,29 @@ impl Task<Create> {
 
         // check input fkeys
         let inputs_spec: HashSet<&String> = function.inputs.iter().map(|f| &f.name).collect();
-        let req_input_fkeys: HashSet<&String> = req_input_owners.keys().collect();
+        let mut req_input_fkeys: HashSet<&String> = req_input_owners.keys().collect();
+        // If an input/output file is marked with `optional: True`, users do not need to
+        // register the file.
+        let option_inputs_spec: HashSet<&String> = function
+            .inputs
+            .iter()
+            .filter(|f| f.optional)
+            .map(|f| &f.name)
+            .collect();
+        req_input_fkeys.extend(&option_inputs_spec);
+
         ensure!(inputs_spec == req_input_fkeys, "input keys mismatch");
 
         // check output fkeys
         let outputs_spec: HashSet<&String> = function.outputs.iter().map(|f| &f.name).collect();
-        let req_output_fkeys: HashSet<&String> = req_output_owners.keys().collect();
+        let mut req_output_fkeys: HashSet<&String> = req_output_owners.keys().collect();
+        let option_outputs_spec: HashSet<&String> = function
+            .outputs
+            .iter()
+            .filter(|f| f.optional)
+            .map(|f| &f.name)
+            .collect();
+        req_output_fkeys.extend(&option_outputs_spec);
         ensure!(outputs_spec == req_output_fkeys, "output keys mismatch");
 
         let ts = TaskState {

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to