This is an automated email from the ASF dual-hosted git repository. hsun pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-teaclave.git
commit 014059a223fd8ddf3aab375e0b44e8375b3b5e60 Author: GeminiCarrie <[email protected]> AuthorDate: Wed May 10 06:15:57 2023 +0000 Update python sdk and examples --- .gitignore | 3 + cmake/scripts/test.sh | 18 +- examples/README.md | 10 +- examples/python/builtin_face_detection.py | 4 - examples/python/builtin_gbdt_train.py | 2 - examples/python/builtin_online_decrypt.py | 1 - examples/python/builtin_ordered_set_intersect.py | 15 +- examples/python/builtin_password_check.py | 15 +- .../python/builtin_private_join_and_compute.py | 15 +- examples/python/builtin_rsa_sign.py | 6 +- examples/python/mesapy_deadloop_cancel.py | 3 +- examples/python/mesapy_echo.py | 2 +- examples/python/mesapy_logistic_reg.py | 15 +- examples/python/requirements.txt | 3 + examples/python/test_disable_function.py | 6 +- examples/python/utils.py | 7 +- examples/python/wasm_c_simple_add.py | 2 +- examples/python/wasm_rust_psi.py | 15 +- sdk/python/teaclave.py | 761 ++++++++++----------- tests/scripts/functional_tests.py | 131 +++- 20 files changed, 500 insertions(+), 534 deletions(-) diff --git a/.gitignore b/.gitignore index 4190671d..ba070e54 100644 --- a/.gitignore +++ b/.gitignore @@ -32,3 +32,6 @@ cov_report examples/c/builtin_echo examples/c/builtin_ordered_set_intersect examples/python/out.jpg +# ignore grpc files during building and testing +sdk/python/*_pb2.py +sdk/python/*_grpc.py diff --git a/cmake/scripts/test.sh b/cmake/scripts/test.sh index e61e531e..0d2474dd 100755 --- a/cmake/scripts/test.sh +++ b/cmake/scripts/test.sh @@ -79,6 +79,14 @@ wait_port() { done } +generate_python_grpc_stubs() { + python3 -m grpc_tools.protoc \ + --proto_path=${TEACLAVE_PROJECT_ROOT}/services/proto/src/proto \ + --python_out=${TEACLAVE_PROJECT_ROOT}/sdk/python \ + --grpclib_python_out=${TEACLAVE_PROJECT_ROOT}/sdk/python \ + ${TEACLAVE_PROJECT_ROOT}/services/proto/src/proto/*.proto +} + run_integration_tests() { trap cleanup INT TERM ERR @@ -152,7 +160,9 @@ run_functional_tests() { ./teaclave_functional_tests -t end_to_end - # Run script tests + generate_python_grpc_stubs + + export PYTHONPATH=${TEACLAVE_PROJECT_ROOT}/sdk/python ./scripts/functional_tests.py -v popd @@ -283,6 +293,8 @@ run_examples() { sleep 3 # wait for execution services popd + generate_python_grpc_stubs + # run builtin examples builtin_examples @@ -328,6 +340,8 @@ run_libos_examples() { sleep 3 # wait for execution services popd + generate_python_grpc_stubs + # run builtin examples builtin_examples @@ -374,6 +388,8 @@ run_cancel_test() { echo "executor 1 pid: $exe_pid1" echo "executor 2 pid: $exe_pid2" + generate_python_grpc_stubs + pushd ${TEACLAVE_PROJECT_ROOT}/examples/python export PYTHONPATH=${TEACLAVE_PROJECT_ROOT}/sdk/python python3 mesapy_deadloop_cancel.py diff --git a/examples/README.md b/examples/README.md index 7cba2cf5..8eeda5df 100644 --- a/examples/README.md +++ b/examples/README.md @@ -10,8 +10,14 @@ results with the Teclave's client SDK in both single and multi-party setups. Before trying these examples, please make sure all services in the Teaclave platform has been properly launched. Also, for examples implemented in Python, -don't forget to set the `PYTHONPATH` to the `sdk` path so that the scripts can -successfully import the `teaclave` module. +don't forget to generate protocol stub files and set the `PYTHONPATH` to the +`sdk` path so that the scripts can successfully import the `teaclave` module. + +Generate stub files by grpcio-tools and grpclib. + +``` +python3 -m grpc_tools.protoc --proto_path=../../services/proto/src/proto --python_out=. --grpclib_python_out=. ../../services/proto/src/proto/{teaclave_authentication_service.proto,teaclave_frontend_service.proto,teaclave_common.proto} +``` For instance, use the following command to invoke an echo function in Teaclave: diff --git a/examples/python/builtin_face_detection.py b/examples/python/builtin_face_detection.py index a0743735..570f9f9a 100644 --- a/examples/python/builtin_face_detection.py +++ b/examples/python/builtin_face_detection.py @@ -17,13 +17,9 @@ # specific language governing permissions and limitations # under the License. -import os -import sys import json from PIL import Image, ImageDraw -import requests - from utils import USER_ID, USER_PASSWORD, connect_authentication_service, connect_frontend_service from teaclave import FunctionArgument diff --git a/examples/python/builtin_gbdt_train.py b/examples/python/builtin_gbdt_train.py index 9b125104..bc932109 100644 --- a/examples/python/builtin_gbdt_train.py +++ b/examples/python/builtin_gbdt_train.py @@ -17,8 +17,6 @@ # specific language governing permissions and limitations # under the License. -import sys - from teaclave import FunctionInput, FunctionOutput, FunctionArgument, OwnerList, DataMap from utils import USER_ID, USER_PASSWORD, connect_authentication_service, connect_frontend_service diff --git a/examples/python/builtin_online_decrypt.py b/examples/python/builtin_online_decrypt.py index 9a30b3b0..302857da 100644 --- a/examples/python/builtin_online_decrypt.py +++ b/examples/python/builtin_online_decrypt.py @@ -17,7 +17,6 @@ # specific language governing permissions and limitations # under the License. -import sys import base64 from utils import USER_ID, USER_PASSWORD, connect_authentication_service, connect_frontend_service diff --git a/examples/python/builtin_ordered_set_intersect.py b/examples/python/builtin_ordered_set_intersect.py index 277cef70..d084bf2b 100644 --- a/examples/python/builtin_ordered_set_intersect.py +++ b/examples/python/builtin_ordered_set_intersect.py @@ -17,10 +17,8 @@ # specific language governing permissions and limitations # under the License. -import sys - from teaclave import FunctionInput, FunctionOutput, FunctionArgument, OwnerList, DataMap -from utils import USER_ID, USER_PASSWORD, connect_authentication_service, connect_frontend_service, PlatformAdmin +from utils import connect_authentication_service, connect_frontend_service, PlatformAdmin # In the example, user 0 creates the task and user 0, 1, upload their private data. # Then user 0 invokes the task and user 0, 1 get the result. @@ -63,13 +61,6 @@ USER_DATA_1 = UserData("user1", "password", ], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) -class DataList: - - def __init__(self, data_name, data_id): - self.data_name = data_name - self.data_id = data_id - - class Client: def __init__(self, user_id, user_password): @@ -143,8 +134,8 @@ class Client: output_id = client.register_output_file(url, schema, key, iv) print(f"[+] {self.user_id} assigning data to task") - client.assign_data_to_task(task_id, [DataList(input_label, input_id)], - [DataList(output_label, output_id)]) + client.assign_data_to_task(task_id, [DataMap(input_label, input_id)], + [DataMap(output_label, output_id)]) def approve_task(self, task_id): client = self.client diff --git a/examples/python/builtin_password_check.py b/examples/python/builtin_password_check.py index 1715be0d..dac68646 100644 --- a/examples/python/builtin_password_check.py +++ b/examples/python/builtin_password_check.py @@ -17,10 +17,8 @@ # specific language governing permissions and limitations # under the License. -import sys - -from teaclave import FunctionInput, FunctionOutput, FunctionArgument, OwnerList, DataMap -from utils import USER_ID, USER_PASSWORD, connect_authentication_service, connect_frontend_service, PlatformAdmin +from teaclave import FunctionInput, OwnerList, DataMap +from utils import connect_authentication_service, connect_frontend_service, PlatformAdmin # In the example, user 0 creates the task and user 0, 1, upload their private data. # Then user 0 invokes the task and user 0, 1 get the result. @@ -69,13 +67,6 @@ USER_DATA_1 = UserData("user1", "password", ], [], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) -class DataList: - - def __init__(self, data_name, data_id): - self.data_name = data_name - self.data_id = data_id - - class Client: def __init__(self, user_id, user_password): @@ -134,7 +125,7 @@ class Client: input_id = client.register_input_file(url, schema, key, iv, cmac) print(f"[+] {self.user_id} assigning data to task") - client.assign_data_to_task(task_id, [DataList(input_label, input_id)], + client.assign_data_to_task(task_id, [DataMap(input_label, input_id)], []) def approve_task(self, task_id): diff --git a/examples/python/builtin_private_join_and_compute.py b/examples/python/builtin_private_join_and_compute.py index 0b7c89b6..a5e77c9b 100644 --- a/examples/python/builtin_private_join_and_compute.py +++ b/examples/python/builtin_private_join_and_compute.py @@ -17,10 +17,8 @@ # specific language governing permissions and limitations # under the License. -import sys - from teaclave import FunctionInput, FunctionOutput, FunctionArgument, OwnerList, DataMap -from utils import USER_ID, USER_PASSWORD, connect_authentication_service, connect_frontend_service, PlatformAdmin +from utils import connect_authentication_service, connect_frontend_service, PlatformAdmin # In the example, user 3 creates the task and user 0, 1, 2 upload their private data. # Then user 3 invokes the task and user 0, 1, 2 get the result. @@ -70,13 +68,6 @@ USER_DATA_2 = UserData("user2", "password", USER_DATA_3 = UserData("user3", "password") -class DataList: - - def __init__(self, data_name, data_id): - self.data_name = data_name - self.data_id = data_id - - class ConfigClient: def __init__(self, user_id, user_password): @@ -173,8 +164,8 @@ class DataClient: output_id = client.register_output_file(url, schema, key, iv) print(f"[+] {self.user_id} assigning data to task") - client.assign_data_to_task(task_id, [DataList(input_label, input_id)], - [DataList(output_label, output_id)]) + client.assign_data_to_task(task_id, [DataMap(input_label, input_id)], + [DataMap(output_label, output_id)]) def approve_task(self, task_id): client = self.client diff --git a/examples/python/builtin_rsa_sign.py b/examples/python/builtin_rsa_sign.py index 3ba494a1..63128d66 100644 --- a/examples/python/builtin_rsa_sign.py +++ b/examples/python/builtin_rsa_sign.py @@ -17,10 +17,8 @@ # specific language governing permissions and limitations # under the License. -import sys - -from teaclave import FunctionInput, FunctionOutput, FunctionArgument, OwnerList, DataMap -from utils import USER_ID, USER_PASSWORD, connect_authentication_service, connect_frontend_service, PlatformAdmin +from teaclave import FunctionInput, FunctionArgument, OwnerList, DataMap +from utils import connect_authentication_service, connect_frontend_service, PlatformAdmin def get_client(user_id, user_password): diff --git a/examples/python/mesapy_deadloop_cancel.py b/examples/python/mesapy_deadloop_cancel.py index a4b0eaa1..c4311da6 100644 --- a/examples/python/mesapy_deadloop_cancel.py +++ b/examples/python/mesapy_deadloop_cancel.py @@ -17,10 +17,9 @@ # specific language governing permissions and limitations # under the License. -import sys import time -from teaclave import FunctionInput, FunctionOutput, OwnerList, DataMap, TaskStatus +from teaclave import TaskStatus from utils import USER_ID, USER_PASSWORD, connect_authentication_service, connect_frontend_service, PlatformAdmin diff --git a/examples/python/mesapy_echo.py b/examples/python/mesapy_echo.py index b60df003..71093744 100644 --- a/examples/python/mesapy_echo.py +++ b/examples/python/mesapy_echo.py @@ -19,7 +19,7 @@ import sys -from teaclave import FunctionInput, FunctionOutput, FunctionArgument, OwnerList, DataMap +from teaclave import FunctionArgument from utils import USER_ID, USER_PASSWORD, connect_authentication_service, connect_frontend_service, PlatformAdmin diff --git a/examples/python/mesapy_logistic_reg.py b/examples/python/mesapy_logistic_reg.py index 2b8c8b81..409275e7 100644 --- a/examples/python/mesapy_logistic_reg.py +++ b/examples/python/mesapy_logistic_reg.py @@ -20,11 +20,9 @@ An example about Logistic Regression in MesaPy. """ -import sys -import binascii from typing import List from teaclave import FunctionInput, FunctionOutput, FunctionArgument, OwnerList, DataMap -from utils import USER_ID, USER_PASSWORD, connect_authentication_service, connect_frontend_service, PlatformAdmin +from utils import connect_authentication_service, connect_frontend_service from enum import Enum @@ -79,13 +77,6 @@ class OutputData: self.label = label -class DataList: - - def __init__(self, data_name, data_id): - self.data_name = data_name - self.data_id = data_id - - class ConfigClient: def __init__(self, user_id, user_password): @@ -161,7 +152,7 @@ class ConfigClient: key = da.file_key iv = da.iv input_id = client.register_input_file(url, schema, key, iv, cmac) - input_data_list.append(DataList(da.label, input_id)) + input_data_list.append(DataMap(da.label, input_id)) print(f"[+] {self.user_id} registering output file") output_data_list = [] for out_data in outputs: @@ -170,7 +161,7 @@ class ConfigClient: key = out_data.file_key iv = out_data.iv output_id = client.register_output_file(out_url, schema, key, iv) - output_data_list.append(DataList(out_data.label, output_id)) + output_data_list.append(DataMap(out_data.label, output_id)) print(f"[+] {self.user_id} assigning data to task") client.assign_data_to_task(task_id, input_data_list, output_data_list) diff --git a/examples/python/requirements.txt b/examples/python/requirements.txt index 6f3b7486..82d36f2f 100644 --- a/examples/python/requirements.txt +++ b/examples/python/requirements.txt @@ -3,3 +3,6 @@ toml cryptography requests Pillow +grpclib +grpcio +grpcio-tools diff --git a/examples/python/test_disable_function.py b/examples/python/test_disable_function.py index dbea609d..9923f6e9 100644 --- a/examples/python/test_disable_function.py +++ b/examples/python/test_disable_function.py @@ -17,10 +17,8 @@ # specific language governing permissions and limitations # under the License. -import sys - -from teaclave import FunctionInput, FunctionOutput, FunctionArgument, OwnerList, DataMap -from utils import USER_ID, USER_PASSWORD, connect_authentication_service, connect_frontend_service, PlatformAdmin +from teaclave import FunctionArgument +from utils import connect_authentication_service, connect_frontend_service, PlatformAdmin class UserData: diff --git a/examples/python/utils.py b/examples/python/utils.py index 41299fa7..70115647 100644 --- a/examples/python/utils.py +++ b/examples/python/utils.py @@ -45,7 +45,7 @@ class PlatformAdmin: def __init__(self, user_id: str, user_password: str): self.client = AuthenticationService(AUTHENTICATION_SERVICE_ADDRESS, AS_ROOT_CA_CERT_PATH, - ENCLAVE_INFO_PATH).connect() + ENCLAVE_INFO_PATH) token = self.client.user_login(user_id, user_password) self.client.metadata = {"id": user_id, "token": token} @@ -55,10 +55,9 @@ class PlatformAdmin: def connect_authentication_service(): return AuthenticationService(AUTHENTICATION_SERVICE_ADDRESS, - AS_ROOT_CA_CERT_PATH, - ENCLAVE_INFO_PATH).connect() + AS_ROOT_CA_CERT_PATH, ENCLAVE_INFO_PATH) def connect_frontend_service(): return FrontendService(FRONTEND_SERVICE_ADDRESS, AS_ROOT_CA_CERT_PATH, - ENCLAVE_INFO_PATH).connect() + ENCLAVE_INFO_PATH) diff --git a/examples/python/wasm_c_simple_add.py b/examples/python/wasm_c_simple_add.py index 867cba3a..ee6762b4 100644 --- a/examples/python/wasm_c_simple_add.py +++ b/examples/python/wasm_c_simple_add.py @@ -19,7 +19,7 @@ import sys -from teaclave import FunctionInput, FunctionOutput, FunctionArgument, OwnerList, DataMap +from teaclave import FunctionArgument from utils import USER_ID, USER_PASSWORD, connect_authentication_service, connect_frontend_service, PlatformAdmin diff --git a/examples/python/wasm_rust_psi.py b/examples/python/wasm_rust_psi.py index 61fc1c53..60176e88 100644 --- a/examples/python/wasm_rust_psi.py +++ b/examples/python/wasm_rust_psi.py @@ -17,10 +17,8 @@ # specific language governing permissions and limitations # under the License. -import sys - from teaclave import FunctionInput, FunctionOutput, FunctionArgument, OwnerList, DataMap -from utils import USER_ID, USER_PASSWORD, connect_authentication_service, connect_frontend_service, PlatformAdmin +from utils import connect_authentication_service, connect_frontend_service, PlatformAdmin class UserData: @@ -66,13 +64,6 @@ USER_DATA_1 = UserData("user1", "password", [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) -class DataList: - - def __init__(self, data_name, data_id): - self.data_name = data_name - self.data_id = data_id - - class Client: def __init__(self, user_id, user_password): @@ -168,8 +159,8 @@ class Client: output_id = client.register_output_file(url, schema, key, iv) print(f"[+] {self.user_id} assigning data to task") - client.assign_data_to_task(task_id, [DataList(input_label, input_id)], - [DataList(output_label, output_id)]) + client.assign_data_to_task(task_id, [DataMap(input_label, input_id)], + [DataMap(output_label, output_id)]) def approve_task(self, task_id): client = self.client diff --git a/sdk/python/teaclave.py b/sdk/python/teaclave.py index 2fec9cd8..c19d325e 100644 --- a/sdk/python/teaclave.py +++ b/sdk/python/teaclave.py @@ -23,26 +23,33 @@ trusted TLS channel and communicate with Teaclave services (e.g., the authentication service and frontend service) through RPC protocols. """ -import struct import json import base64 import toml -import os import time +import os import ssl -import socket - -from typing import Tuple, Dict, List, Any -from enum import IntEnum import cryptography from cryptography import x509 from cryptography.hazmat.backends import default_backend +from google.protobuf.json_format import MessageToDict +from grpclib.client import Channel, _ChannelState +from grpclib.protocol import H2Protocol + from OpenSSL.crypto import load_certificate, FILETYPE_PEM, FILETYPE_ASN1 from OpenSSL.crypto import X509Store, X509StoreContext from OpenSSL import crypto +import teaclave_authentication_service_pb2 as auth +import teaclave_frontend_service_pb2 as fe +from teaclave_authentication_service_grpc import TeaclaveAuthenticationApiStub +from teaclave_frontend_service_grpc import TeaclaveFrontendStub +from teaclave_common_pb2 import TaskStatus, FileCryptoInfo + +from typing import Tuple, Dict, List, Any + __all__ = [ 'FrontendService', 'AuthenticationService', 'FunctionArgument', 'FunctionInput', 'FunctionOutput', 'OwnerList', 'DataMap' @@ -51,19 +58,13 @@ __all__ = [ Metadata = Dict[str, str] -class TaskStatus(IntEnum): - Created = 0 - DataAssigned = 1 - Approved = 2 - Staged = 3 - Running = 4 - Finished = 10 - Canceled = 20 - Failed = 99 - - class Request: - pass + message = None + + def __init__(self, method, response, metadata=dict()): + self.method = method + self.metadata = metadata + self.response = response class TeaclaveException(Exception): @@ -71,8 +72,8 @@ class TeaclaveException(Exception): class TeaclaveService: - channel = None metadata = None + stub = None def __init__(self, name: str, @@ -80,38 +81,79 @@ class TeaclaveService: as_root_ca_cert_path: str, enclave_info_path: str, dump_report=False): - self._context = ssl._create_unverified_context() self._name = name self._address = address self._as_root_ca_cert_path = as_root_ca_cert_path self._enclave_info_path = enclave_info_path - self._closed = False self._dump_report = dump_report + self._channel = TeaclaveChannel(self._name, self._address, + self._as_root_ca_cert_path, + self._enclave_info_path) + self._loop = self._channel._loop + + def call_method(self, request): + return self._loop.run_until_complete( + getattr(self.stub, request.method)(request.message, + metadata=request.metadata)) + def __enter__(self): return self def __exit__(self, *exc): - if not self._closed: - self.close() + self.close() def close(self): - self._closed = True - if self.channel: self.channel.close() + if self._channel: self._channel.close() - def check_channel(self): - if not self.channel: raise TeaclaveException("Channel is None") + def __del__(self) -> None: + self.close() def check_metadata(self): if not self.metadata: raise TeaclaveException("Metadata is None") - def connect(self): - """Establish trusted connection and verify remote attestation report. - """ - sock = socket.create_connection(self._address) - channel = self._context.wrap_socket(sock, - server_hostname=self._address[0]) - cert = channel.getpeercert(binary_form=True) + def check_channel(self): + self._channel.check_channel() + + def get_metadata(self): + return self.metadata + + +def create_context() -> ssl.SSLContext: + ctx = ssl._create_unverified_context() + ctx.options |= ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1 + ctx.set_ciphers('ECDHE+AESGCM:ECDHE+CHACHA20:DHE+AESGCM:DHE+CHACHA20') + ctx.set_alpn_protocols(['h2']) + try: + ctx.set_npn_protocols(['h2']) + except NotImplementedError: + pass + return ctx + + +class TeaclaveChannel(Channel): + + def __init__(self, + name: str, + address: Tuple[str, int], + as_root_ca_cert_path: str, + enclave_info_path: str, + dump_report=False): + context = create_context() + super().__init__(host=address[0], port=address[1], ssl=context) + self._name = name + self._as_root_ca_cert_path = as_root_ca_cert_path + self._enclave_info_path = enclave_info_path + self._dump_report = dump_report + + def check_channel(self): + if self._state == _ChannelState.TRANSIENT_FAILURE: + raise TeaclaveException("Channel is None") + + async def __connect__(self) -> H2Protocol: + protocol = await super().__connect__() + sslobj = protocol.connection._transport.get_extra_info('ssl_object') + cert = sslobj.getpeercert(binary_form=True) if not cert: raise TeaclaveException("Peer cert is None") try: self._verify_report(self._as_root_ca_cert_path, @@ -119,9 +161,7 @@ class TeaclaveService: except Exception as e: raise TeaclaveException( f"Failed to verify attestation report: {e}") - self.channel = channel - - return self + return protocol def _verify_report(self, as_root_ca_cert_path: str, enclave_info_path: str, cert: Dict[str, Any], endpoint_name: str): @@ -244,9 +284,9 @@ class FunctionInput: """ def __init__(self, name: str, description: str, optional=False): - self.name = name - self.description = description - self.optional = optional + self.message = fe.FunctionInput(name=name, + description=description, + optional=optional) class FunctionOutput: @@ -260,9 +300,9 @@ class FunctionOutput: """ def __init__(self, name: str, description: str, optional=False): - self.name = name - self.description = description - self.optional = optional + self.message = fe.FunctionOutput(name=name, + description=description, + optional=optional) class FunctionArgument: @@ -280,9 +320,9 @@ class FunctionArgument: key: str, default_value: str = "", allow_overwrite=True): - self.key = key - self.default_value = default_value - self.allow_overwrite = allow_overwrite + self.message = fe.FunctionArgument(key=key, + default_value=default_value, + allow_overwrite=allow_overwrite) class OwnerList: @@ -295,8 +335,7 @@ class OwnerList: """ def __init__(self, data_name: str, uids: List[str]): - self.data_name = data_name - self.uids = uids + self.message = fe.OwnerList(data_name=data_name, uids=uids) class DataMap: @@ -309,8 +348,7 @@ class DataMap: """ def __init__(self, data_name, data_id): - self.data_name = data_name - self.data_id = data_id + self.message = fe.DataMap(data_name=data_name, data_id=data_id) class CryptoInfo: @@ -324,73 +362,70 @@ class CryptoInfo: """ def __init__(self, schema: str, key: List[int], iv: List[int]): - self.schema = schema - self.key = key - self.iv = iv + + self.message = FileCryptoInfo(schema=schema, + key=bytes(key), + iv=bytes(iv)) class UserRegisterRequest(Request): def __init__(self, metadata: Metadata, user_id: str, user_password: str, role: str, attribute: str): - self.request = "user_register" - self.metadata = metadata - self.id = user_id - self.password = user_password - self.role = role - self.attribute = attribute + super().__init__("UserRegister", auth.UserRegisterResponse, metadata) + self.message = auth.UserRegisterRequest(id=user_id, + password=user_password, + role=role, + attribute=attribute) class UserUpdateRequest(Request): def __init__(self, metadata: Metadata, user_id: str, user_password: str, role: str, attribute: str): - self.request = "user_update" - self.metadata = metadata - self.id = user_id - self.password = user_password - self.role = role - self.attribute = attribute + super().__init__("UserUpdate", auth.UserUpdateResponse) + self.message = auth.UserUpdateRequest(id=user_id, + password=user_password, + role=role, + attribute=attribute) class UserLoginRequest(Request): def __init__(self, user_id: str, user_password: str): - self.request = "user_login" - self.id = user_id - self.password = user_password + super().__init__("UserLogin", auth.UserLoginResponse) + self.message = auth.UserLoginRequest(id=user_id, + password=user_password) class UserChangePasswordRequest(Request): def __init__(self, metadata: Metadata, password: str): - self.request = "user_change_password" - self.metadata = metadata - self.password = password + super().__init__("UserChangePassword", auth.UserChangePasswordResponse, + metadata) + self.message = auth.UserChangePasswordRequest(password=password) class ResetUserPasswordRequest(Request): def __init__(self, metadata: Metadata, user_id: str): - self.request = "reset_user_password" - self.metadata = metadata - self.id = user_id + super().__init__("ResetUserPassword", auth.ResetUserPasswordResponse, + metadata) + self.message = auth.ResetUserPasswordRequest(id=user_id) class DeleteUserRequest(Request): def __init__(self, metadata: Metadata, user_id: str): - self.request = "delete_user" - self.metadata = metadata - self.id = user_id + super().__init__("DeleteUser", auth.DeleteUserResponse, metadata) + self.message = auth.DeleteUserRequest(id=user_id) class ListUsersRequest(Request): def __init__(self, metadata: Metadata, user_id: str): - self.request = "list_users" - self.metadata = metadata - self.id = user_id + super().__init__("ListUsers", auth.ListUsersResponse, metadata) + self.message = auth.ListUsersRequest(id=user_id) class AuthenticationService(TeaclaveService): @@ -414,9 +449,13 @@ class AuthenticationService(TeaclaveService): dump_report=False): super().__init__("authentication", address, as_root_ca_cert_path, enclave_info_path, dump_report) + self.stub = TeaclaveAuthenticationApiStub(self._channel) - def user_register(self, user_id: str, user_password: str, role: str, - attribute: str): + def user_register(self, + user_id: str, + user_password: str, + role="", + attribute=""): """Register a new user. Args: @@ -430,18 +469,17 @@ class AuthenticationService(TeaclaveService): self.check_metadata() request = UserRegisterRequest(self.metadata, user_id, user_password, role, attribute) - _write_message(self.channel, request) - response = _read_message(self.channel) - if response["result"] == "ok": - pass - else: - reason = "unknown" - if "request_error" in response: - reason = response["request_error"] - raise TeaclaveException(f"Failed to register user ({reason})") - - def user_update(self, user_id: str, user_password: str, role: str, - attribute: str): + try: + response = self.call_method(request) + return response + except Exception as e: + raise TeaclaveException(f"Failed to register user {str(e)}") + + def user_update(self, + user_id: str, + user_password: str, + role: str, + attribute=""): """Update an existing user. Args: @@ -455,15 +493,11 @@ class AuthenticationService(TeaclaveService): self.check_metadata() request = UserUpdateRequest(self.metadata, user_id, user_password, role, attribute) - _write_message(self.channel, request) - response = _read_message(self.channel) - if response["result"] == "ok": - pass - else: - reason = "unknown" - if "request_error" in response: - reason = response["request_error"] - raise TeaclaveException(f"Failed to update user ({reason})") + try: + response = self.call_method(request) + return response + except Exception as e: + raise TeaclaveException(f"Failed to update user {str(e)}") def user_login(self, user_id: str, user_password: str) -> str: """Login and get a session token. @@ -477,17 +511,14 @@ class AuthenticationService(TeaclaveService): str: User login token. """ - self.check_channel() + self._channel.check_channel() request = UserLoginRequest(user_id, user_password) - _write_message(self.channel, request) - response = _read_message(self.channel) - if response["result"] == "ok": - return response["content"]["token"] - else: - reason = "unknown" - if "request_error" in response: - reason = response["request_error"] - raise TeaclaveException(f"Failed to login user ({reason})") + try: + response = self.call_method(request) + self.metadata = {"id": user_id, "token": response.token} + return response.token + except Exception as e: + raise TeaclaveException(f"Failed to login user {str(e)}") def user_change_password(self, user_password: str): """Change password. @@ -499,15 +530,11 @@ class AuthenticationService(TeaclaveService): self.check_channel() self.check_metadata() request = UserChangePasswordRequest(self.metadata, user_password) - _write_message(self.channel, request) - response = _read_message(self.channel) - if response["result"] == "ok": - pass - else: - reason = "unknown" - if "request_error" in response: - reason = response["request_error"] - raise TeaclaveException(f"Failed to change password ({reason})") + try: + response = self.call_method(request) + return response + except Exception as e: + raise TeaclaveException(f"Failed to change password {str(e)}") def reset_user_password(self, user_id: str) -> str: """Reset password of a managed user. @@ -523,15 +550,12 @@ class AuthenticationService(TeaclaveService): self.check_channel() self.check_metadata() request = ResetUserPasswordRequest(self.metadata, user_id) - _write_message(self.channel, request) - response = _read_message(self.channel) - if response["result"] == "ok": - return response["content"]["password"] - else: - reason = "unknown" - if "request_error" in response: - reason = response["request_error"] - raise TeaclaveException(f"Failed to reset password ({reason})") + try: + response = self.call_method(request) + return response + except Exception as e: + reason = str(e) + raise TeaclaveException(f"Failed to reset password {reason}") def delete_user(self, user_id: str) -> str: """Delete a user. @@ -543,14 +567,11 @@ class AuthenticationService(TeaclaveService): self.check_channel() self.check_metadata() request = DeleteUserRequest(self.metadata, user_id) - _write_message(self.channel, request) - response = _read_message(self.channel) - if response["result"] == "ok": - pass - else: - reason = "unknown" - if "request_error" in response: - reason = response["request_error"] + try: + response = self.call_method(request) + return response + except Exception as e: + reason = str(e) raise TeaclaveException(f"Failed to delete user ({reason})") def list_users(self, user_id: str) -> str: @@ -565,15 +586,13 @@ class AuthenticationService(TeaclaveService): str: User list """ self.check_channel() + self.check_metadata() request = ListUsersRequest(self.metadata, user_id) - _write_message(self.channel, request) - response = _read_message(self.channel) - if response["result"] == "ok": - return response["content"]["ids"] - else: - reason = "unknown" - if "request_error" in response: - reason = response["request_error"] + try: + response = self.call_method(request) + return response + except Exception as e: + reason = str(e) raise TeaclaveException(f"Failed to list user ({reason})") @@ -584,18 +603,23 @@ class RegisterFunctionRequest(Request): arguments: List[FunctionArgument], inputs: List[FunctionInput], outputs: List[FunctionOutput], user_allowlist: List[str], usage_quota: int): - self.request = "register_function" - self.metadata = metadata - self.name = name - self.description = description - self.executor_type = executor_type - self.public = public - self.payload = payload - self.arguments = arguments - self.inputs = inputs - self.outputs = outputs - self.user_allowlist = user_allowlist - self.usage_quota = usage_quota + super().__init__("RegisterFunction", fe.RegisterFunctionResponse, + metadata) + arguments = [x.message for x in arguments] + inputs = [x.message for x in inputs] + outputs = [x.message for x in outputs] + + self.message = fe.RegisterFunctionRequest( + name=name, + description=description, + executor_type=executor_type, + public=public, + payload=bytes(payload), + arguments=arguments, + inputs=inputs, + outputs=outputs, + user_allowlist=user_allowlist, + usage_quota=usage_quota) class UpdateFunctionRequest(Request): @@ -605,97 +629,87 @@ class UpdateFunctionRequest(Request): payload: List[int], arguments: List[FunctionArgument], inputs: List[FunctionInput], outputs: List[FunctionOutput], user_allowlist: List[str], usage_quota: int): - self.request = "update_function" - self.metadata = metadata - self.function_id = function_id - self.name = name - self.description = description - self.executor_type = executor_type - self.public = public - self.payload = payload - self.arguments = arguments - self.inputs = inputs - self.outputs = outputs - self.user_allowlist = user_allowlist - self.usage_quota = usage_quota + super().__init__("UpdateFunction", fe.UpdateFunctionResponse, metadata) + arguments = [x.message for x in arguments] + inputs = [x.message for x in inputs] + outputs = [x.message for x in outputs] + + self.message = fe.UpdateFunctionRequest(function_id, name, description, + executor_type, public, payload, + arguments, inputs, outputs, + user_allowlist, usage_quota) class ListFunctionsRequest(Request): def __init__(self, metadata: Metadata, user_id: str): - self.request = "list_functions" - self.metadata = metadata - self.user_id = user_id + super().__init__("ListFunctions", fe.ListFunctionsResponse, metadata) + self.message = fe.ListFunctionsRequest(user_id=user_id) class DeleteFunctionRequest(Request): def __init__(self, metadata: Metadata, function_id: str): - self.request = "delete_function" - self.metadata = metadata - self.function_id = function_id + super().__init__("ListFunctions", fe.DeleteFunctionResponse, metadata) + self.message = fe.DeleteFunctionRequest(function_id=function_id) class DisableFunctionRequest(Request): def __init__(self, metadata: Metadata, function_id: str): - self.request = "disable_function" - self.metadata = metadata - self.function_id = function_id + super().__init__("DisableFunction", fe.DisableFunctionResponse, + metadata) + self.message = fe.DisableFunctionRequest(function_id=function_id) class GetFunctionRequest(Request): def __init__(self, metadata: Metadata, function_id: str): - self.request = "get_function" - self.metadata = metadata - self.function_id = function_id + super().__init__("GetFunction", fe.GetFunctionResponse, metadata) + self.message = fe.GetFunctionRequest(function_id=function_id) class GetFunctionUsageStatsRequest(Request): def __init__(self, metadata: Metadata, function_id: str): - self.request = "get_function_usage_stats" - self.metadata = metadata - self.function_id = function_id + super().__init__("GetFunctionUsageStats", + fe.GetFunctionUsageStatsResponse, metadata) + self.message = fe.GetFunctionUsageStatsRequest(function_id=function_id) class RegisterInputFileRequest(Request): def __init__(self, metadata: Metadata, url: str, cmac: List[int], crypto_info: CryptoInfo): - self.request = "register_input_file" - self.metadata = metadata - self.url = url - self.cmac = cmac - self.crypto_info = crypto_info + super().__init__("RegisterInputFile", fe.RegisterInputFileResponse, + metadata) + self.message = fe.RegisterInputFileRequest( + url=url, cmac=bytes(cmac), crypto_info=crypto_info.message) class RegisterOutputFileRequest(Request): def __init__(self, metadata: Metadata, url: str, crypto_info: CryptoInfo): - self.request = "register_output_file" - self.metadata = metadata - self.url = url - self.crypto_info = crypto_info + super().__init__("RegisterOutputFile", fe.RegisterOutputFileResponse, + metadata) + self.message = fe.RegisterOutputFileRequest( + url=url, crypto_info=crypto_info.message) class UpdateInputFileRequest(Request): def __init__(self, metadata: Metadata, data_id: str, url: str): - self.request = "update_input_file" - self.metadata = metadata - self.data_id = data_id - self.url = url + super().__init__("UpdateInputFile", fe.UpdateInputFileResponse, + metadata) + self.message = fe.UpdateInputFileRequest(data_id=data_id, url=url) class UpdateOutputFileRequest(Request): def __init__(self, metadata: Metadata, data_id: str, url: str): - self.request = "update_output_file" - self.metadata = metadata - self.data_id = data_id - self.url = url + super().__init__("UpdateInputFile", fe.UpdateOutputFileResponse, + metadata) + self.message = fe.UpdateOutputFileRequest(data_id=data_id, url=url) class CreateTaskRequest(Request): @@ -704,56 +718,56 @@ class CreateTaskRequest(Request): function_arguments: Dict[str, Any], executor: str, inputs_ownership: List[OwnerList], outputs_ownership: List[OwnerList]): - self.request = "create_task" - self.metadata = metadata - self.function_id = function_id - self.function_arguments = function_arguments - self.executor = executor - self.inputs_ownership = inputs_ownership - self.outputs_ownership = outputs_ownership + super().__init__("CreateTask", fe.CreateTaskResponse, metadata) + inputs_ownership = [x.message for x in inputs_ownership] + outputs_ownership = [x.message for x in outputs_ownership] + + self.message = fe.CreateTaskRequest( + function_id=function_id, + function_arguments=function_arguments, + executor=executor, + inputs_ownership=inputs_ownership, + outputs_ownership=outputs_ownership) class AssignDataRequest(Request): def __init__(self, metadata: Metadata, task_id: str, inputs: List[DataMap], outputs: List[DataMap]): - self.request = "assign_data" - self.metadata = metadata - self.task_id = task_id - self.inputs = inputs - self.outputs = outputs + super().__init__("AssignData", fe.AssignDataResponse, metadata) + inputs = [x.message for x in inputs] + outputs = [x.message for x in outputs] + self.message = fe.AssignDataRequest(task_id=task_id, + inputs=inputs, + outputs=outputs) class ApproveTaskRequest(Request): def __init__(self, metadata: Metadata, task_id: str): - self.request = "approve_task" - self.metadata = metadata - self.task_id = task_id + super().__init__("ApproveTask", fe.ApproveTaskResponse, metadata) + self.message = fe.ApproveTaskRequest(task_id=task_id) class InvokeTaskRequest(Request): def __init__(self, metadata: Metadata, task_id: str): - self.request = "invoke_task" - self.metadata = metadata - self.task_id = task_id + super().__init__("InvokeTask", fe.InvokeTaskResponse, metadata) + self.message = fe.InvokeTaskRequest(task_id=task_id) class CancelTaskRequest(Request): def __init__(self, metadata: Metadata, task_id: str): - self.request = "cancel_task" - self.metadata = metadata - self.task_id = task_id + super().__init__("CancelTask", fe.CancelTaskResponse, metadata) + self.message = fe.CancelTaskRequest(task_id=task_id) class GetTaskRequest(Request): def __init__(self, metadata: Metadata, task_id: str): - self.request = "get_task" - self.metadata = metadata - self.task_id = task_id + super().__init__("GetTask", fe.GetTaskResponse, metadata) + self.message = fe.GetTaskRequest(task_id=task_id) class FrontendService(TeaclaveService): @@ -776,6 +790,7 @@ class FrontendService(TeaclaveService): dump_report=False): super().__init__("frontend", address, as_root_ca_cert_path, enclave_info_path, dump_report) + self.stub = TeaclaveFrontendStub(self._channel) def register_function( self, @@ -796,14 +811,11 @@ class FrontendService(TeaclaveService): executor_type, public, payload, arguments, inputs, outputs, user_allowlist, usage_quota) - _write_message(self.channel, request) - response = _read_message(self.channel) - if response["result"] == "ok": - return response["content"]["function_id"] - else: - reason = "unknown" - if "request_error" in response: - reason = response["request_error"] + try: + response = self.call_method(request) + return response.function_id + except Exception as e: + reason = str(e) raise TeaclaveException(f"Failed to register function ({reason})") def update_function( @@ -826,53 +838,45 @@ class FrontendService(TeaclaveService): description, executor_type, public, payload, arguments, inputs, outputs, user_allowlist, usage_quota) - _write_message(self.channel, request) - response = _read_message(self.channel) - if response["result"] == "ok": - return response["content"]["function_id"] - else: - reason = "unknown" - if "request_error" in response: - reason = response["request_error"] - raise TeaclaveException(f"Failed to update function ({reason})") + try: + response = self.call_method(request) + return response.function_id + except Exception as e: + reason = str(e) + raise TeaclaveException(f"Failed to register function ({reason})") def list_functions(self, user_id: str): self.check_metadata() self.check_channel() request = ListFunctionsRequest(self.metadata, user_id) - _write_message(self.channel, request) - response = _read_message(self.channel) - if response["result"] == "ok": - return response["content"] - else: - raise TeaclaveException("Failed to list functions") + try: + response = self.call_method(request) + except Exception as e: + raise TeaclaveException(f"Failed to list functions ({str(e)})") + return MessageToDict(response, + preserving_proto_field_name=True, + use_integers_for_enums=True) def get_function(self, function_id: str): self.check_metadata() self.check_channel() request = GetFunctionRequest(self.metadata, function_id) - _write_message(self.channel, request) - response = _read_message(self.channel) - if response["result"] == "ok": - return response["content"] - else: - reason = "unknown" - if "request_error" in response: - reason = response["request_error"] + try: + response = self.call_method(request) + return response + except Exception as e: + reason = str(e) raise TeaclaveException(f"Failed to get function ({reason})") - def get_function_usage_stats(self, user_id: str, function_id: str): + def get_function_usage_stats(self, function_id: str): self.check_metadata() self.check_channel() request = GetFunctionUsageStatsRequest(self.metadata, function_id) - _write_message(self.channel, request) - response = _read_message(self.channel) - if response["result"] == "ok": - return response["content"] - else: - reason = "unknown" - if "request_error" in response: - reason = response["request_error"] + try: + response = self.call_method(request) + return response + except Exception as e: + reason = str(e) raise TeaclaveException( f"Failed to get function usage statistics ({reason})") @@ -880,23 +884,23 @@ class FrontendService(TeaclaveService): self.check_metadata() self.check_channel() request = DeleteFunctionRequest(self.metadata, function_id) - _write_message(self.channel, request) - response = _read_message(self.channel) - if response["result"] == "ok": - return response["content"] - else: - raise TeaclaveException("Failed to delete function") + try: + response = self.call_method(request) + return response + except Exception as e: + reason = str(e) + raise TeaclaveException(f"Failed to delete function ({reason})") def disable_function(self, function_id: str): self.check_metadata() self.check_channel() request = DisableFunctionRequest(self.metadata, function_id) - _write_message(self.channel, request) - response = _read_message(self.channel) - if response["result"] == "ok": - return response["content"] - else: - raise TeaclaveException("Failed to disable function") + try: + response = self.call_method(request) + return response + except Exception as e: + reason = str(e) + raise TeaclaveException(f"Failed to disable function ({reason})") def register_input_file(self, url: str, schema: str, key: List[int], iv: List[int], cmac: List[int]): @@ -904,14 +908,11 @@ class FrontendService(TeaclaveService): self.check_channel() request = RegisterInputFileRequest(self.metadata, url, cmac, CryptoInfo(schema, key, iv)) - _write_message(self.channel, request) - response = _read_message(self.channel) - if response["result"] == "ok": - return response["content"]["data_id"] - else: - reason = "unknown" - if "request_error" in response: - reason = response["request_error"] + try: + response = self.call_method(request) + return response.data_id + except Exception as e: + reason = str(e) raise TeaclaveException( f"Failed to register input file ({reason})") @@ -921,14 +922,11 @@ class FrontendService(TeaclaveService): self.check_channel() request = RegisterOutputFileRequest(self.metadata, url, CryptoInfo(schema, key, iv)) - _write_message(self.channel, request) - response = _read_message(self.channel) - if response["result"] == "ok": - return response["content"]["data_id"] - else: - reason = "unknown" - if "request_error" in response: - reason = response["request_error"] + try: + response = self.call_method(request) + return response.data_id + except Exception as e: + reason = str(e) raise TeaclaveException( f"Failed to register output file ({reason})") @@ -944,14 +942,11 @@ class FrontendService(TeaclaveService): request = CreateTaskRequest(self.metadata, function_id, function_arguments, executor, inputs_ownership, outputs_ownership) - _write_message(self.channel, request) - response = _read_message(self.channel) - if response["result"] == "ok": - return response["content"]["task_id"] - else: - reason = "unknown" - if "request_error" in response: - reason = response["request_error"] + try: + response = self.call_method(request) + return response.task_id + except Exception as e: + reason = str(e) raise TeaclaveException(f"Failed to create task ({reason})") def assign_data_to_task(self, task_id: str, inputs: List[DataMap], @@ -959,14 +954,10 @@ class FrontendService(TeaclaveService): self.check_metadata() self.check_channel() request = AssignDataRequest(self.metadata, task_id, inputs, outputs) - _write_message(self.channel, request) - response = _read_message(self.channel) - if response["result"] == "ok": - pass - else: - reason = "unknown" - if "request_error" in response: - reason = response["request_error"] + try: + self.call_method(request) + except Exception as e: + reason = str(e) raise TeaclaveException( f"Failed to assign data to task ({reason})") @@ -974,137 +965,83 @@ class FrontendService(TeaclaveService): self.check_metadata() self.check_channel() request = ApproveTaskRequest(self.metadata, task_id) - _write_message(self.channel, request) - response = _read_message(self.channel) - if response["result"] == "ok": - pass - else: - reason = "unknown" - if "request_error" in response: - reason = response["request_error"] + try: + self.call_method(request) + except Exception as e: + reason = str(e) raise TeaclaveException(f"Failed to approve task ({reason})") def invoke_task(self, task_id: str): self.check_metadata() self.check_channel() request = InvokeTaskRequest(self.metadata, task_id) - _write_message(self.channel, request) - response = _read_message(self.channel) - if response["result"] == "ok": - pass - else: - reason = "unknown" - if "request_error" in response: - reason = response["request_error"] + try: + self.call_method(request) + except Exception as e: + reason = str(e) raise TeaclaveException(f"Failed to invoke task ({reason})") def cancel_task(self, task_id: str): self.check_metadata() self.check_channel() request = CancelTaskRequest(self.metadata, task_id) - _write_message(self.channel, request) - response = _read_message(self.channel) - if response["result"] == "ok": - pass - else: - reason = "unknown" - if "request_error" in response: - reason = response["request_error"] + try: + self.call_method(request) + except Exception as e: + reason = str(e) raise TeaclaveException(f"Failed to cancel task ({reason})") - def get_task(self, task_id: str) -> dict: + def get_task(self, task_id: str): self.check_metadata() self.check_channel() request = GetTaskRequest(self.metadata, task_id) - - _write_message(self.channel, request) - response = _read_message(self.channel) - if response["result"] != "ok": - reason = "unknown" - if "request_error" in response: - reason = response["request_error"] + try: + response = self.call_method(request) + return MessageToDict(response, + preserving_proto_field_name=True, + use_integers_for_enums=True) + except Exception as e: + reason = str(e) raise TeaclaveException(f"Failed to get task result ({reason})") - return response["content"] def get_task_result(self, task_id: str): self.check_metadata() self.check_channel() request = GetTaskRequest(self.metadata, task_id) - while True: - _write_message(self.channel, request) - response = _read_message(self.channel) - if response["result"] != "ok": - reason = "unknown" - if "request_error" in response: - reason = response["request_error"] + try: + time.sleep(1) + response = self.call_method(request) + if response.status == TaskStatus.Finished: + break + elif response.status == TaskStatus.Canceled: + raise TeaclaveException("Task Canceled, Error: " + + response.result.Err.reason) + elif response.status == TaskStatus.Failed: + raise TeaclaveException("Task Failed, Error: " + + response.result.Err.reason) + except Exception as e: + reason = str(e) raise TeaclaveException( f"Failed to get task result ({reason})") - time.sleep(1) - if response["content"]["status"] == TaskStatus.Finished: - break - elif response["content"]["status"] == TaskStatus.Canceled: - raise TeaclaveException( - "Task Canceled, Error: " + - response["content"]["result"]["result"]["Err"]["reason"]) - elif response["content"]["status"] == TaskStatus.Failed: - raise TeaclaveException( - "Task Failed, Error: " + - response["content"]["result"]["result"]["Err"]["reason"]) - return response["content"]["result"]["result"]["Ok"]["return_value"] + return response.result.Ok.return_value def get_output_cmac_by_tag(self, task_id: str, tag: str): self.check_metadata() self.check_channel() request = GetTaskRequest(self.metadata, task_id) while True: - _write_message(self.channel, request) - response = _read_message(self.channel) - if response["result"] != "ok": - reason = "unknown" - if "request_error" in response: - reason = response["request_error"] + try: + time.sleep(1) + response = self.call_method(request) + if response.status == TaskStatus.Finished: + break + except Exception as e: + reason = str(e) raise TeaclaveException( - f"Failed to get output cmac by tag ({reason})") - time.sleep(1) - if response["content"]["status"] == TaskStatus.Finished: - break - - return response["content"]["result"]["result"]["Ok"]["tags_map"][tag] - - -def _write_message(sock: ssl.SSLSocket, message: Any): - - class RequestEncoder(json.JSONEncoder): - - def default(self, o): - if isinstance(o, Request): - request = o.__dict__["request"] - j = {} - j["message"] = {} - j["message"][request] = {} - for k, v in o.__dict__.items(): - if k == "metadata": j[k] = v - elif k == "request": continue - else: j["message"][request][k] = v - return j - else: - return o.__dict__ - - message = json.dumps(message, cls=RequestEncoder, - separators=(',', ':')).encode() - sock.sendall(struct.pack(">Q", len(message))) - sock.sendall(message) - - -def _read_message(sock: ssl.SSLSocket): - response_len = struct.unpack(">Q", sock.read(8)) - raw = bytearray() - total_recv = 0 - while total_recv < response_len[0]: - data = sock.recv() - total_recv += len(data) - raw += data - response = json.loads(raw) - return response + f"Failed to get task result ({reason})") + response = MessageToDict(response, + preserving_proto_field_name=True, + use_integers_for_enums=True) + return base64.b64decode(response["result"]["Ok"]["tags_map"][tag]) diff --git a/tests/scripts/functional_tests.py b/tests/scripts/functional_tests.py index b7e248b9..04bb223c 100755 --- a/tests/scripts/functional_tests.py +++ b/tests/scripts/functional_tests.py @@ -33,6 +33,14 @@ from OpenSSL.crypto import load_certificate, FILETYPE_PEM, FILETYPE_ASN1 from OpenSSL.crypto import X509Store, X509StoreContext from OpenSSL import crypto +import h2.connection +import h2.events + +from io import BytesIO +from h2.config import H2Configuration +from urllib.parse import unquote +from teaclave_authentication_service_pb2 import UserLoginRequest, UserLoginResponse + HOSTNAME = 'localhost' AUTHENTICATION_SERVICE_ADDRESS = (HOSTNAME, 7776) CONTEXT = ssl._create_unverified_context() @@ -52,19 +60,6 @@ else: ENCLAVE_INFO_PATH = "../../release/tests/enclave_info.toml" -def write_message(sock, message): - message = json.dumps(message) - message = message.encode() - sock.write(struct.pack(">Q", len(message))) - sock.write(message) - - -def read_message(sock): - response_len = struct.unpack(">Q", sock.read(8)) - response = sock.read(response_len[0]) - return response - - def verify_report(cert, endpoint_name): def load_certificates(pem_bytes): @@ -121,50 +116,114 @@ def verify_report(cert, endpoint_name): raise Exception("mr_signer error") +def encode_message(message): + message_bin = message.SerializeToString() + header = struct.pack('?', False) + struct.pack('>I', len(message_bin)) + return header + message_bin + + +def decode_message(message_bin, message_type): + f = BytesIO(message_bin) + meta = f.read(5) + message_len = struct.unpack('>I', meta[1:])[0] + message_body = f.read(message_len) + message = message_type.FromString(message_body) + return message + + class TestAuthenticationService(unittest.TestCase): def setUp(self): sock = socket.create_connection(AUTHENTICATION_SERVICE_ADDRESS) + CONTEXT.set_alpn_protocols(['h2']) self.socket = CONTEXT.wrap_socket(sock, server_hostname=HOSTNAME) cert = self.socket.getpeercert(binary_form=True) verify_report(cert, "authentication") + config = H2Configuration(client_side=True, header_encoding='ascii') + self.connection = h2.connection.H2Connection(config) + self.connection.initiate_connection() + self.socket.sendall(self.connection.data_to_send()) + self.stream_id = 1 + + def set_headers(self, method_path): + headers = [(':method', 'POST'), (':path', method_path), + (':authority', HOSTNAME), (':scheme', 'https'), + ('content-type', 'application/grpc')] + return headers + + def send_message(self, message, method_path): + headers = self.set_headers(method_path) + self.connection.send_headers(self.stream_id, headers) + message_data = encode_message(message) + self.connection.send_data(self.stream_id, + message_data, + end_stream=True) + self.socket.sendall(self.connection.data_to_send()) + + def recv_message(self): + body = None + headers = None + response_stream_ended = False + max_frame_size = self.connection.max_outbound_frame_size + print(max_frame_size) + while not response_stream_ended: + # read raw data from the socket + data = self.socket.recv(max_frame_size) + if not data: + break + + # feed raw data into h2, and process resulting events + events = self.connection.receive_data(data) + for event in events: + if isinstance(event, h2.events.ResponseReceived): + headers = dict(event.headers) + if isinstance(event, h2.events.DataReceived): + # update flow control so the server doesn't starve us + self.connection.acknowledge_received_data( + event.flow_controlled_length, event.stream_id) + # more response body data received + body += event.data + if isinstance(event, h2.events.StreamEnded): + # response body completed, let's exit the loop + response_stream_ended = True + break + # send any pending data to the server + self.socket.sendall(self.connection.data_to_send()) + return (headers, body) def tearDown(self): + self.connection.close_connection() + self.socket.sendall(self.connection.data_to_send()) self.socket.close() def test_invalid_request(self): + path = '/teaclave_authentication_service_proto.TeaclaveAuthenticationApi/InvalidRequest' user_id = "invalid_id" user_password = "invalid_password" - message = { - "invalid_request": "user_login", - "id": user_id, - "password": user_password - } - write_message(self.socket, message) + message = UserLoginRequest(id=user_id, password=user_password) + self.send_message(message, path) - response = read_message(self.socket) - self.assertEqual( - response, b'{"result":"err","request_error":"invalid request"}') + (headers, response) = self.recv_message() + self.assertEqual(response, None) + # https://grpc.github.io/grpc/core/md_doc_statuscodes.html + # grpc status UNIMPLEMENTED: 12 + self.assertEqual(headers['grpc-status'], '12') def test_login_permission_denied(self): + path = '/teaclave_authentication_service_proto.TeaclaveAuthenticationApi/UserLogin' user_id = "invalid_id" user_password = "invalid_password" - message = { - "message": { - "user_login": { - "id": user_id, - "password": user_password - } - } - } - write_message(self.socket, message) - - response = read_message(self.socket) - self.assertEqual( - response, - b'{"result":"err","request_error":"authentication failed"}') + message = UserLoginRequest(id=user_id, password=user_password) + self.send_message(message, path) + (headers, body) = self.recv_message() + self.assertEqual(body, None) + self.assertEqual(headers['grpc-status'], '16') + message = unquote(headers['grpc-message'], + encoding='utf-8', + errors='replace') + self.assertEqual(message, 'authentication failed') if __name__ == '__main__': --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
