This is an automated email from the ASF dual-hosted git repository. yuanz pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/incubator-teaclave-trustzone-sdk.git
commit 265b00823029af0f7f0ba757503ba36fc3dd5824 Author: ivila <[email protected]> AuthorDate: Fri Mar 7 14:13:04 2025 +0800 examples: add mnist-rs - add example mnist-rs - add test_mnist_rs Signed-off-by: Zehui Chen <[email protected]> Reviewed-by: Yuan Zhuang <[email protected]> --- ci/ci.sh | 2 + ci/ci.sh => examples/mnist-rs/Makefile | 55 +++-- examples/mnist-rs/README.md | 234 +++++++++++++++++++++ ci/ci.sh => examples/mnist-rs/host/Cargo.toml | 56 +++-- ci/ci.sh => examples/mnist-rs/host/Makefile | 54 ++--- examples/mnist-rs/host/samples/0.bin | Bin 0 -> 784 bytes examples/mnist-rs/host/samples/1.bin | Bin 0 -> 784 bytes examples/mnist-rs/host/samples/2.bin | Bin 0 -> 784 bytes examples/mnist-rs/host/samples/3.bin | Bin 0 -> 784 bytes examples/mnist-rs/host/samples/4.bin | Bin 0 -> 784 bytes examples/mnist-rs/host/samples/5.bin | Bin 0 -> 784 bytes examples/mnist-rs/host/samples/6.bin | Bin 0 -> 784 bytes examples/mnist-rs/host/samples/7.bin | Bin 0 -> 784 bytes examples/mnist-rs/host/samples/7.png | Bin 0 -> 721 bytes examples/mnist-rs/host/samples/8.bin | Bin 0 -> 784 bytes examples/mnist-rs/host/samples/9.bin | Bin 0 -> 784 bytes examples/mnist-rs/host/samples/model.bin | Bin 0 -> 31575 bytes examples/mnist-rs/host/src/commands/infer.rs | 85 ++++++++ examples/mnist-rs/host/src/commands/mod.rs | 20 ++ examples/mnist-rs/host/src/commands/serve.rs | 116 ++++++++++ examples/mnist-rs/host/src/commands/train.rs | 131 ++++++++++++ examples/mnist-rs/host/src/main.rs | 45 ++++ examples/mnist-rs/host/src/tee.rs | 157 ++++++++++++++ ci/ci.sh => examples/mnist-rs/proto/Cargo.toml | 44 +--- examples/mnist-rs/proto/src/inference.rs | 18 ++ examples/mnist-rs/proto/src/lib.rs | 26 +++ examples/mnist-rs/proto/src/train.rs | 33 +++ ci/ci.sh => examples/mnist-rs/rust-toolchain.toml | 41 +--- examples/mnist-rs/ta/Cargo.toml | 51 +++++ ci/ci.sh => examples/mnist-rs/ta/common/Cargo.toml | 47 ++--- examples/mnist-rs/ta/common/src/lib.rs | 25 +++ 
examples/mnist-rs/ta/common/src/model.rs | 86 ++++++++ examples/mnist-rs/ta/common/src/utils.rs | 35 +++ .../mnist-rs/ta/inference/Cargo.toml | 53 ++--- examples/mnist-rs/ta/inference/Makefile | 51 +++++ examples/mnist-rs/ta/inference/build.rs | 25 +++ examples/mnist-rs/ta/inference/src/main.rs | 91 ++++++++ examples/mnist-rs/ta/inference/uuid.txt | 1 + ci/ci.sh => examples/mnist-rs/ta/train/Cargo.toml | 53 ++--- examples/mnist-rs/ta/train/Makefile | 50 +++++ examples/mnist-rs/ta/train/build.rs | 25 +++ examples/mnist-rs/ta/train/src/main.rs | 127 +++++++++++ examples/mnist-rs/ta/train/src/trainer.rs | 114 ++++++++++ examples/mnist-rs/ta/train/uuid.txt | 1 + tests/setup.sh | 5 + ci/ci.sh => tests/test_mnist_rs.sh | 51 +++-- 46 files changed, 1722 insertions(+), 286 deletions(-) diff --git a/ci/ci.sh b/ci/ci.sh index 7556757..8281081 100755 --- a/ci/ci.sh +++ b/ci/ci.sh @@ -46,6 +46,8 @@ if [ "$STD" ]; then ./test_tls_server.sh ./test_eth_wallet.sh ./test_secure_db_abstraction.sh +else + ./test_mnist_rs.sh fi popd diff --git a/ci/ci.sh b/examples/mnist-rs/Makefile old mode 100755 new mode 100644 similarity index 51% copy from ci/ci.sh copy to examples/mnist-rs/Makefile index 7556757..19bac24 --- a/ci/ci.sh +++ b/examples/mnist-rs/Makefile @@ -1,5 +1,3 @@ -#!/bin/bash - # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -17,35 +15,32 @@ # specific language governing permissions and limitations # under the License. 
-set -xe +# Cross-compilers and Rust targets for the host app (_HOST) and the TAs (_TA); +# override them on the command line (e.g. make TARGET_TA=...) for other platforms +CROSS_COMPILE_HOST ?= aarch64-linux-gnu- +CROSS_COMPILE_TA ?= aarch64-linux-gnu- +TARGET_HOST ?= aarch64-unknown-linux-gnu +TARGET_TA ?= aarch64-unknown-linux-gnu + +.PHONY: toolchain host ta all clean + +all: toolchain host ta -pushd ../tests +toolchain: + rustup toolchain install -./test_hello_world.sh -./test_random.sh -./test_secure_storage.sh -./test_aes.sh -./test_hotp.sh -./test_acipher.sh -./test_big_int.sh -./test_diffie_hellman.sh -./test_digest.sh -./test_authentication.sh -./test_time.sh -./test_signature_verification.sh -./test_supp_plugin.sh -./test_error_handling.sh -./test_tcp_client.sh -./test_udp_socket.sh +host: toolchain + $(q)$(MAKE) -C host TARGET=$(TARGET_HOST) \ + CROSS_COMPILE=$(CROSS_COMPILE_HOST) -# Run std only tests -if [ "$STD" ]; then - ./test_serde.sh - ./test_message_passing_interface.sh - ./test_tls_client.sh - ./test_tls_server.sh - ./test_eth_wallet.sh - ./test_secure_db_abstraction.sh -fi +ta: toolchain + $(q)$(MAKE) -C ta/train TARGET=$(TARGET_TA) \ + CROSS_COMPILE=$(CROSS_COMPILE_TA) + $(q)$(MAKE) -C ta/inference TARGET=$(TARGET_TA) \ + CROSS_COMPILE=$(CROSS_COMPILE_TA) -popd +clean: + $(q)$(MAKE) -C host clean + cd proto && cargo clean + $(q)$(MAKE) -C ta/train clean + $(q)$(MAKE) -C ta/inference clean diff --git a/examples/mnist-rs/README.md b/examples/mnist-rs/README.md new file mode 100644 index 0000000..97d7af2 --- /dev/null +++ b/examples/mnist-rs/README.md @@ -0,0 +1,234 @@ +# mnist-rs + +This demo demonstrates how to train and perform inference in TEE.
+ +## Install the TAs + +There are two TAs in the project: + +| TA | UUID | Usage | +| ---- | ---- | ---- | +| Train | 1b5f5b74-e9cf-4e62-8c3e-7e41da6d76f6 | for training a new model | +| Inference | ff09aa8a-fbb9-4734-ae8c-d7cd1a3f6744 | for performing inference | + +They are kept separate because training normally consumes far more resources +than inference, which calls for different build settings (memory, concurrency, etc.). + +Make sure to install both TAs before using any subcommand; refer to the +[Makefile](../../Makefile) in the root directory for details. + +## Running the Host + +There are three subcommands in the host: + +1. Train + + Trains a new model and exports it to the given path. + + ``` shell + mnist-rs train -o model.bin + ``` + + This subcommand downloads the MNIST dataset, feeds it into the TEE, + performs training inside the TEE, and writes the model to the given path + once training has finished. + + ```mermaid + sequenceDiagram + actor D as Developer + participant C as mnist-rs(REE) + participant T as Train TA(TEE) + + D ->> C: train -o model.bin + C ->> C: Fetch mnist dataset + + C ->> T: open_session_with_operation + T ->> T: Initialize Global Trainer<br/> with given learning rate + T ->> C: Initialize finished + + loop iterate over num_epochs + rect rgb(191, 223, 255) + note right of C: Train + loop chunk by batch_size over train datasets + C ->> T: invoke_command Train with chunk datasets + T ->> T: Forward with given data + T ->> T: Backward Optimization + T ->> C: Train Output(loss, accuracy) + end + end + rect rgb(200, 150, 255) + note right of C: Valid + loop chunk by batch_size over test datasets + C ->> T: invoke_command Valid with chunk datasets + T ->> T: Forward with given data + T ->> C: Valid Output(loss, accuracy) + end + end + end + + C ->> T: Export Command + T ->> C: Model Record + C ->> D: model.bin + ``` + + For detailed usage, run: `mnist-rs train --help`, a demo output is: + + ``` shell + Usage: mnist-rs train
[OPTIONS] + + Options: + -n, --num-epochs <NUM_EPOCHS> [default: 6] + -b, --batch-size <BATCH_SIZE> [default: 64] + -l, --learning-rate <LEARNING_RATE> [default: 0.0001] + -o, --output <OUTPUT> + -h, --help Print help + ``` + +2. Infer + + Loads a model from the given path, tests it with a given image, and prints + the inference result. + + ```shell + mnist-rs infer -m model.bin -b samples/7.bin -i samples/7.png + ``` + + This subcommand loads the model from the given path, tests it with the + given binaries and images, and prints the inference results. For + convenience, you can use the sample binaries and images in the `samples` + folder. + + ```mermaid + sequenceDiagram + actor D as Developer + participant C as mnist-rs(REE) + participant T as Inference TA(TEE) + + D ->> C: infer -m model.bin<br/> -b samples/7.bin<br/> -i samples/7.png + + C ->> C: load Model Record from disk + C ->> T: open_session_with_operation + T ->> T: Initialize Global Model<br/> with given Model Record + T ->> C: Initialize finished + + rect rgb(191, 223, 255) + note right of C: Infer with samples/7.bin + C ->> C: Load file from disk. + C ->> T: invoke_command: Feed data + T ->> T: Forward with given data + T ->> C: Infer result + C ->> D: Print result + end + + rect rgb(191, 223, 255) + note right of C: Infer with samples/7.png + C ->> C: Load image from disk. + C ->> C: Convert image to luma8 binary + C ->> T: invoke_command with data + T ->> T: Forward with given data + T ->> C: Infer result + C ->> D: Print result + end + + ``` + For detailed usage, run: `mnist-rs infer --help`, a demo output is: + + ```shell + Usage: mnist-rs infer [OPTIONS] --model <MODEL> + + Options: + -m, --model <MODEL> The path of the model + -b, --binary <BINARY> The path of the input binary, must be 784 byte binary, can be multiple + -i, --image <IMAGE> The path of the input image, must be dimension of 28x28, can be multiple + -h, --help Print help + ``` + +3.
Serve + + Loads a model from the given path, starts a web server and serves it as an + API. + + ```shell + mnist-rs serve -m model.bin + ``` + + This subcommand loads the model from the given path and starts a + web server that provides inference APIs. + + **Available APIs**: + + | Method | Endpoint | Body | + | ---- | ---- | ---- | + | POST | `/inference/image` | an image with dimensions 28x28 | + | POST | `/inference/binary` | a 784-byte binary | + + You can test the server with the following commands: + + ```shell + # Perform inference using an image + curl --data-binary "@./samples/7.png" http://localhost:3000/inference/image + # Perform inference using a binary file + curl --data-binary "@./samples/7.bin" http://localhost:3000/inference/binary + ``` + + ```mermaid + sequenceDiagram + actor D as Developer + actor H as HttpClient + participant C as mnist-rs(REE) + participant T as Inference TA(TEE) + + D ->> C: serve -m model.bin + + C ->> C: Load Model Record from disk + C ->> T: open_session_with_operation + T ->> T: Initialize Global Model<br/> with given Model Record + T ->> C: Initialize finished + + C ->> C: Start http server + + loop accept request + par /inference/binary + H ->> C: Request with binary data + C ->> T: invoke_command: Feed data + T ->> T: Forward with given data + T ->> C: Infer result + C ->> H: Infer result + end + par /inference/image + H ->> C: Request with image data + C ->> C: Convert image to luma8 binary + C ->> T: invoke_command: Feed data + T ->> T: Forward with given data + T ->> C: Infer result + C ->> H: Infer result + end + end + + + ``` + + For detailed usage, run: `mnist-rs serve --help`, a demo output is: + + ```shell + Usage: mnist-rs serve [OPTIONS] --model <MODEL> + + Options: + -m, --model <MODEL> The path of the model + -p, --port <PORT> [default: 3000] + -h, --help Print help + ``` + +## Credits + +This demo project is inspired by the crates and examples from
+[tracel-ai/burn](https://github.com/tracel-ai/burn), including: + +1. [crates/burn-no-std-tests](https://github.com/tracel-ai/burn/tree/v0.16.0/crates/burn-no-std-tests) +2. [examples/custom-training-loop](https://github.com/tracel-ai/burn/tree/v0.16.0/examples/custom-training-loop) +3. [examples/mnist-inference-web](https://github.com/tracel-ai/burn/tree/v0.16.0/examples/mnist-inference-web) + +Special thanks to @[Guillaume Lagrange](https://github.com/laggui) for sharing +knowledge and providing early reviews. + +TODO: add standard license files after 0.4.0 is released. diff --git a/ci/ci.sh b/examples/mnist-rs/host/Cargo.toml old mode 100755 new mode 100644 similarity index 53% copy from ci/ci.sh copy to examples/mnist-rs/host/Cargo.toml index 7556757..90fab20 --- a/ci/ci.sh +++ b/examples/mnist-rs/host/Cargo.toml @@ -1,5 +1,3 @@ -#!/bin/bash - # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -17,35 +15,29 @@ # specific language governing permissions and limitations # under the License. -set -xe - -pushd ../tests - -./test_hello_world.sh -./test_random.sh -./test_secure_storage.sh -./test_aes.sh -./test_hotp.sh -./test_acipher.sh -./test_big_int.sh -./test_diffie_hellman.sh -./test_digest.sh -./test_authentication.sh -./test_time.sh -./test_signature_verification.sh -./test_supp_plugin.sh -./test_error_handling.sh -./test_tcp_client.sh -./test_udp_socket.sh +[package] +name = "mnist-rs" +version = "0.4.0" +authors = ["Teaclave Contributors <[email protected]>"] +license = "Apache-2.0" +repository = "https://github.com/apache/incubator-teaclave-trustzone-sdk.git" +description = "An example of Rust OP-TEE TrustZone SDK."
+edition = "2021" +publish = false -# Run std only tests -if [ "$STD" ]; then - ./test_serde.sh - ./test_message_passing_interface.sh - ./test_tls_client.sh - ./test_tls_server.sh - ./test_eth_wallet.sh - ./test_secure_db_abstraction.sh -fi +[dependencies] +proto = { path = "../proto" } +optee-teec = { path = "../../../optee-teec" } +clap = { version = "4.5.31", features = ["derive"] } +rand = "0.9.0" +rust-mnist = "0.2.0" +bytemuck = { version = "1.21.0", features = ["min_const_generics"] } +serde_json = "1.0.139" +image = "0.25.5" +tiny_http = "0.12.0" +anyhow = "1.0.97" +ureq = "3.0.8" +flate2 = "1.1.0" -popd +[profile.release] +lto = true diff --git a/ci/ci.sh b/examples/mnist-rs/host/Makefile old mode 100755 new mode 100644 similarity index 57% copy from ci/ci.sh copy to examples/mnist-rs/host/Makefile index 7556757..1804544 --- a/ci/ci.sh +++ b/examples/mnist-rs/host/Makefile @@ -1,5 +1,3 @@ -#!/bin/bash - # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -17,35 +15,23 @@ # specific language governing permissions and limitations # under the License. 
-set -xe - -pushd ../tests - -./test_hello_world.sh -./test_random.sh -./test_secure_storage.sh -./test_aes.sh -./test_hotp.sh -./test_acipher.sh -./test_big_int.sh -./test_diffie_hellman.sh -./test_digest.sh -./test_authentication.sh -./test_time.sh -./test_signature_verification.sh -./test_supp_plugin.sh -./test_error_handling.sh -./test_tcp_client.sh -./test_udp_socket.sh - -# Run std only tests -if [ "$STD" ]; then - ./test_serde.sh - ./test_message_passing_interface.sh - ./test_tls_client.sh - ./test_tls_server.sh - ./test_eth_wallet.sh - ./test_secure_db_abstraction.sh -fi - -popd +NAME := mnist-rs + +TARGET ?= aarch64-unknown-linux-gnu +CROSS_COMPILE ?= aarch64-linux-gnu- +OBJCOPY := $(CROSS_COMPILE)objcopy +LINKER_CFG := target.$(TARGET).linker=\"$(CROSS_COMPILE)gcc\" + +OUT_DIR := $(CURDIR)/target/$(TARGET)/release + +.PHONY: all host strip clean +all: host strip + +host: + @cargo build --target $(TARGET) --release --config $(LINKER_CFG) + +strip: host + @$(OBJCOPY) --strip-unneeded $(OUT_DIR)/$(NAME) $(OUT_DIR)/$(NAME) + +clean: + @cargo clean diff --git a/examples/mnist-rs/host/samples/0.bin b/examples/mnist-rs/host/samples/0.bin new file mode 100644 index 0000000..6d4d6d9 Binary files /dev/null and b/examples/mnist-rs/host/samples/0.bin differ diff --git a/examples/mnist-rs/host/samples/1.bin b/examples/mnist-rs/host/samples/1.bin new file mode 100644 index 0000000..e4727fe Binary files /dev/null and b/examples/mnist-rs/host/samples/1.bin differ diff --git a/examples/mnist-rs/host/samples/2.bin b/examples/mnist-rs/host/samples/2.bin new file mode 100644 index 0000000..e9820a4 Binary files /dev/null and b/examples/mnist-rs/host/samples/2.bin differ diff --git a/examples/mnist-rs/host/samples/3.bin b/examples/mnist-rs/host/samples/3.bin new file mode 100644 index 0000000..3b30163 Binary files /dev/null and b/examples/mnist-rs/host/samples/3.bin differ diff --git a/examples/mnist-rs/host/samples/4.bin b/examples/mnist-rs/host/samples/4.bin new file mode 100644 index
0000000..5a7ec04 Binary files /dev/null and b/examples/mnist-rs/host/samples/4.bin differ diff --git a/examples/mnist-rs/host/samples/5.bin b/examples/mnist-rs/host/samples/5.bin new file mode 100644 index 0000000..8616576 Binary files /dev/null and b/examples/mnist-rs/host/samples/5.bin differ diff --git a/examples/mnist-rs/host/samples/6.bin b/examples/mnist-rs/host/samples/6.bin new file mode 100644 index 0000000..9b747bf Binary files /dev/null and b/examples/mnist-rs/host/samples/6.bin differ diff --git a/examples/mnist-rs/host/samples/7.bin b/examples/mnist-rs/host/samples/7.bin new file mode 100644 index 0000000..6e67157 Binary files /dev/null and b/examples/mnist-rs/host/samples/7.bin differ diff --git a/examples/mnist-rs/host/samples/7.png b/examples/mnist-rs/host/samples/7.png new file mode 100644 index 0000000..637d88f Binary files /dev/null and b/examples/mnist-rs/host/samples/7.png differ diff --git a/examples/mnist-rs/host/samples/8.bin b/examples/mnist-rs/host/samples/8.bin new file mode 100644 index 0000000..eb7ab3e Binary files /dev/null and b/examples/mnist-rs/host/samples/8.bin differ diff --git a/examples/mnist-rs/host/samples/9.bin b/examples/mnist-rs/host/samples/9.bin new file mode 100644 index 0000000..bd51b5a Binary files /dev/null and b/examples/mnist-rs/host/samples/9.bin differ diff --git a/examples/mnist-rs/host/samples/model.bin b/examples/mnist-rs/host/samples/model.bin new file mode 100644 index 0000000..6cae1f9 Binary files /dev/null and b/examples/mnist-rs/host/samples/model.bin differ diff --git a/examples/mnist-rs/host/src/commands/infer.rs b/examples/mnist-rs/host/src/commands/infer.rs new file mode 100644 index 0000000..8c3b339 --- /dev/null +++ b/examples/mnist-rs/host/src/commands/infer.rs @@ -0,0 +1,85 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use clap::Parser; +use image::EncodableLayout; +use optee_teec::Context; +use proto::{Image, IMAGE_SIZE}; + +#[derive(Parser, Debug)] +pub struct Args { + /// The path of the model. + #[arg(short, long)] + model: String, + /// The path of the input binary, must be 784 byte binary, can be multiple + #[arg(short, long)] + binary: Vec<String>, + /// The path of the input image, must be dimension of 28x28, can be multiple + #[arg(short, long)] + image: Vec<String>, +} + +pub fn execute(args: &Args) -> anyhow::Result<()> { + let model_path = std::path::absolute(&args.model)?; + println!("Load model from \"{}\"", model_path.display()); + let record = std::fs::read(&model_path)?; + let mut ctx = Context::new()?; + let mut caller = crate::tee::InferenceTaConnector::new(&mut ctx, &record)?; + + let mut binaries: Vec<Image> = args + .binary + .iter() + .map(|v| { + let data = std::fs::read(v)?; + anyhow::ensure!(data.len() == IMAGE_SIZE); + + TryInto::<Image>::try_into(data) + .map_err(|err| anyhow::Error::msg(format!("cannot convert {:?} into Image", err))) + }) + .collect::<Result<Vec<_>, anyhow::Error>>()?; + let images: Vec<Image> = args + .image + .iter() + .map(|v| { + let img = image::open(v)?.to_luma8(); + let bytes = img.as_bytes(); + anyhow::ensure!(bytes.len() == IMAGE_SIZE); + TryInto::<Image>::try_into(bytes) + .map_err(|err| anyhow::Error::msg(format!("cannot convert {:?} into Image", err))) + }) + .collect::<Result<Vec<_>, anyhow::Error>>()?; + binaries.extend(images); + + let result = caller.infer_batch(&binaries)?; + anyhow::ensure!(binaries.len() == result.len()); + + for (i, binary) in args.binary.iter().enumerate() { + println!("{}. {}: {}", i + 1, binary, result[i]); + } + + for (i, image) in args.image.iter().enumerate() { + println!( + "{}. {}: {}", + i + args.binary.len() + 1, + image, + result[i + args.binary.len()] + ); + } + println!("Infer Success"); + + Ok(()) +} diff --git a/examples/mnist-rs/host/src/commands/mod.rs b/examples/mnist-rs/host/src/commands/mod.rs new file mode 100644 index 0000000..64e4e41 --- /dev/null +++ b/examples/mnist-rs/host/src/commands/mod.rs @@ -0,0 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub mod infer; +pub mod serve; +pub mod train; diff --git a/examples/mnist-rs/host/src/commands/serve.rs b/examples/mnist-rs/host/src/commands/serve.rs new file mode 100644 index 0000000..6f8b3e1 --- /dev/null +++ b/examples/mnist-rs/host/src/commands/serve.rs @@ -0,0 +1,116 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements.
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use clap::Parser; +use image::EncodableLayout; +use proto::{IMAGE_HEIGHT, IMAGE_SIZE, IMAGE_WIDTH}; +use std::io::Cursor; +use tiny_http::{Request, Response, Server}; + +#[derive(Parser, Debug)] +pub struct Args { + /// The path of the model. + #[arg(short, long)] + model: String, + #[arg(short, long, default_value_t = 3000)] + port: u16, +} + +type HttpResponse = Response<Cursor<Vec<u8>>>; + +pub fn execute(args: &Args) -> anyhow::Result<()> { + let model_path = std::path::absolute(&args.model)?; + println!("Load model from \"{}\"", model_path.display()); + let record = std::fs::read(&model_path)?; + let mut ctx = optee_teec::Context::new()?; + let mut caller = crate::tee::InferenceTaConnector::new(&mut ctx, &record)?; + + let addr = format!("0.0.0.0:{}", args.port); + println!("Server runs on: {}", addr); + + let server = Server::http(&addr) + .map_err(|err| anyhow::Error::msg(format!("cannot start server: {:?}", err)))?; + + loop { + let mut request = server.recv()?; + let response = match handle(&mut caller, &mut request) { + Ok(v) => v, + Err(err) => { + eprintln!("unexpected error: {:#?}", err); + Response::from_string("Internal Error").with_status_code(500) + } + }; + request.respond(response)?; + } +} + +fn handle( + caller: &mut 
crate::tee::InferenceTaConnector<'_>, + request: &mut Request, +) -> anyhow::Result<HttpResponse> { + if request.method().ne(&tiny_http::Method::Post) { + return Ok(Response::from_string("Invalid Request Method").with_status_code(400)); + } + + match request.url() { + "/inference/image" => handle_image(caller, request), + "/inference/binary" => handle_binary(caller, request), + _ => Ok(Response::from_string("Not Found").with_status_code(404)), + } +} + +fn handle_image( + caller: &mut crate::tee::InferenceTaConnector<'_>, + request: &mut Request, +) -> anyhow::Result<HttpResponse> { + let mut data = Vec::with_capacity(IMAGE_SIZE); + request.as_reader().read_to_end(&mut data)?; + let img = image::ImageReader::new(Cursor::new(data)) + .with_guessed_format()? + .decode()? + .to_luma8(); + if img.width() as usize != IMAGE_WIDTH || img.height() as usize != IMAGE_HEIGHT { + return Ok(Response::from_string("Invalid Image").with_status_code(400)); + } + let result = handle_infer(caller, img.as_bytes())?; + + println!("Performing Inference with Image, Result is {}", result); + Ok(Response::from_data(result.to_string())) +} + +fn handle_binary( + caller: &mut crate::tee::InferenceTaConnector<'_>, + request: &mut Request, +) -> anyhow::Result<HttpResponse> { + let mut data = Vec::with_capacity(IMAGE_SIZE); + request.as_reader().read_to_end(&mut data)?; + if data.len() != IMAGE_SIZE { + return Ok(Response::from_string("Invalid Tensor").with_status_code(400)); + } + + let result = handle_infer(caller, &data)?; + println!("Performing Inference with Binary, Result is {}", result); + Ok(Response::from_data(result.to_string())) +} + +fn handle_infer( + caller: &mut crate::tee::InferenceTaConnector<'_>, + image: &[u8], +) -> anyhow::Result<u8> { + let result = caller.infer_batch(bytemuck::cast_slice(image))?; + Ok(result[0]) +} diff --git a/examples/mnist-rs/host/src/commands/train.rs b/examples/mnist-rs/host/src/commands/train.rs new file mode 100644 index 0000000..4e8fac0 --- 
/dev/null +++ b/examples/mnist-rs/host/src/commands/train.rs @@ -0,0 +1,131 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::io::{Cursor, Read}; +use std::path::PathBuf; + +use crate::tee::TrainerTaConnector; +use optee_teec::Context; +use proto::Image; +use rand::seq::SliceRandom; + +#[derive(clap::Parser, Debug)] +pub struct Args { + #[arg(short, long, default_value_t = 6)] + num_epochs: usize, + #[arg(short, long, default_value_t = 64)] + batch_size: usize, + #[arg(short, long, default_value_t = 0.0001)] + learning_rate: f64, + #[arg(short, long)] + output: Option<String>, +} + +fn convert_datasets(images: &Vec<Image>, labels: &[u8]) -> Vec<(Image, u8)> { + let mut datasets: Vec<(Image, u8)> = images + .iter() + .map(|v| v.to_owned()) + .zip(labels.iter().copied()) + .collect(); + datasets.shuffle(&mut rand::rng()); + datasets +} + +pub fn execute(args: &Args) -> anyhow::Result<()> { + // Initialize trainer + let mut ctx = Context::new()?; + let mut trainer = TrainerTaConnector::new(&mut ctx, args.learning_rate)?; + // Download mnist data + let data = check_download_mnist_data()?; + // Prepare datasets + let train_datasets = convert_datasets(&data.train_data, &data.train_labels); + let valid_datasets 
= convert_datasets(&data.test_data, &data.test_labels); + // Training loop, Originally inspired by burn/crates/custom-training-loop + for epoch in 1..args.num_epochs + 1 { + for (iteration, data) in train_datasets.chunks(args.batch_size).enumerate() { + let images: Vec<Image> = data.iter().map(|v| v.0).collect(); + let labels: Vec<u8> = data.iter().map(|v| v.1).collect(); + let output = trainer.train(&images, &labels)?; + println!( + "[Train - Epoch {} - Iteration {}] Loss {:.3} | Accuracy {:.3} %", + epoch, iteration, output.loss, output.accuracy, + ); + } + + for (iteration, data) in valid_datasets.chunks(args.batch_size).enumerate() { + let images: Vec<Image> = data.iter().map(|v| v.0).collect(); + let labels: Vec<u8> = data.iter().map(|v| v.1).collect(); + let output = trainer.valid(&images, &labels)?; + println!( + "[Valid - Epoch {} - Iteration {}] Loss {:.3} | Accuracy {:.3} %", + epoch, iteration, output.loss, output.accuracy, + ); + } + } + // Export the model to the given path + if let Some(output_path) = args.output.as_ref() { + let record = trainer.export()?; + println!("Export record to \"{}\"", output_path); + std::fs::write(output_path, &record)?; + } + println!("Train Success"); + Ok(()) +} + +fn check_download_mnist_data() -> anyhow::Result<rust_mnist::Mnist> { + const DATA_PATH: &str = "./data/"; + + let folder = PathBuf::from(DATA_PATH); + if !folder.exists() { + std::fs::create_dir_all(&folder)?; + } + + // Expected file properties (name, compressed_size, uncompressed_size) for + // verification after download + const EXPECTED_MNIST_FILE_SIZES: [(&str, u64, u64); 4] = [ + ("train-images-idx3-ubyte", 9912422, 47040016), + ("train-labels-idx1-ubyte", 28881, 60008), + ("t10k-images-idx3-ubyte", 1648877, 7840016), + ("t10k-labels-idx1-ubyte", 4542, 10008), + ]; + + // Verify if all files are correctly downloaded + for (filename, compressed_size, uncompressed_size) in EXPECTED_MNIST_FILE_SIZES.iter() { + let file = folder.join(filename); + if 
file.exists() && file.is_file() && std::fs::metadata(&file)?.len() == *uncompressed_size { + println!("File {} exist, skip.", file.display()); + continue; + } + + let url = format!( + "https://storage.googleapis.com/cvdf-datasets/mnist/{}.gz", + filename + ); + println!("Download {} from {}", filename, url); + let body = ureq::get(&url).call()?.body_mut().read_to_vec()?; + + anyhow::ensure!(body.len() == *compressed_size as usize); + + let mut gz = flate2::bufread::GzDecoder::new(Cursor::new(body)); + let mut buffer = Vec::with_capacity(*uncompressed_size as usize); + gz.read_to_end(&mut buffer)?; + anyhow::ensure!(buffer.len() == *uncompressed_size as usize); + std::fs::write(file, &buffer)?; + } + + Ok(rust_mnist::Mnist::new(DATA_PATH)) +} diff --git a/examples/mnist-rs/host/src/main.rs b/examples/mnist-rs/host/src/main.rs new file mode 100644 index 0000000..da3a6a6 --- /dev/null +++ b/examples/mnist-rs/host/src/main.rs @@ -0,0 +1,45 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License.
+ +mod commands; +mod tee; + +use clap::{Parser, Subcommand}; + +#[derive(Parser)] +#[command(version, about, long_about = None)] +struct Cli { + #[command(subcommand)] + command: Commands, +} + +#[derive(Subcommand)] +enum Commands { + Train(commands::train::Args), + Infer(commands::infer::Args), + Serve(commands::serve::Args), +} + +fn main() -> anyhow::Result<()> { + let cli = Cli::parse(); + + match cli.command { + Commands::Train(args) => commands::train::execute(&args), + Commands::Infer(args) => commands::infer::execute(&args), + Commands::Serve(args) => commands::serve::execute(&args), + } +} diff --git a/examples/mnist-rs/host/src/tee.rs b/examples/mnist-rs/host/src/tee.rs new file mode 100644 index 0000000..d59dbc9 --- /dev/null +++ b/examples/mnist-rs/host/src/tee.rs @@ -0,0 +1,157 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+
+// Host-side connectors that open OP-TEE sessions to the `train` and
+// `inference` TAs and marshal requests and replies for them.
+
+use optee_teec::{Context, ErrorKind, Operation, ParamNone, ParamTmpRef, Session, Uuid};
+use proto::{inference, train, Image};
+
+// Capacity of the buffer that receives a JSON-serialised `train::Output`.
+const MAX_OUTPUT_SERIALIZE_SIZE: usize = 1 * 1024;
+// Capacity of the buffer that receives an exported model record.
+const MAX_MODEL_RECORD_SIZE: usize = 10 * 1024 * 1024;
+
+/// An open session with the train TA.
+pub struct TrainerTaConnector<'a> {
+    sess: Session<'a>,
+}
+
+impl<'a> TrainerTaConnector<'a> {
+    /// Opens a session with the train TA, passing `learning_rate` as 8
+    /// little-endian bytes in parameter 0.
+    ///
+    /// Fails with `BadParameters` if `train::UUID` does not parse as a UUID,
+    /// or with whatever error the TEE client reports when opening the session.
+    pub fn new(ctx: &'a mut Context, learning_rate: f64) -> optee_teec::Result<Self> {
+        let bytes = learning_rate.to_le_bytes();
+        let uuid = Uuid::parse_str(train::UUID).map_err(|err| {
+            println!("parse uuid \"{}\" failed due to: {:?}", train::UUID, err);
+            ErrorKind::BadParameters
+        })?;
+        let mut op = Operation::new(
+            0,
+            ParamTmpRef::new_input(bytes.as_slice()),
+            ParamNone,
+            ParamNone,
+            ParamNone,
+        );
+
+        Ok(Self {
+            sess: ctx.open_session_with_operation(uuid, &mut op)?,
+        })
+    }
+
+    /// Runs one training step on a batch and returns the metrics the TA reports.
+    pub fn train(&mut self, images: &[Image], labels: &[u8]) -> optee_teec::Result<train::Output> {
+        self.run_batch(train::Command::Train, images, labels)
+    }
+
+    /// Evaluates a batch without training and returns the metrics the TA reports.
+    pub fn valid(&mut self, images: &[Image], labels: &[u8]) -> optee_teec::Result<train::Output> {
+        self.run_batch(train::Command::Valid, images, labels)
+    }
+
+    // Shared implementation of `train`/`valid`. Images go in parameter 0 as raw
+    // bytes and labels in parameter 1. The TA writes a JSON `train::Output`
+    // into parameter 2. A reply that does not deserialize maps to `BadFormat`.
+    fn run_batch(
+        &mut self,
+        cmd: train::Command,
+        images: &[Image],
+        labels: &[u8],
+    ) -> optee_teec::Result<train::Output> {
+        let mut buffer = vec![0_u8; MAX_OUTPUT_SERIALIZE_SIZE];
+        let images = bytemuck::cast_slice(images);
+        let size = {
+            let mut op = Operation::new(
+                0,
+                ParamTmpRef::new_input(images),
+                ParamTmpRef::new_input(labels),
+                ParamTmpRef::new_output(&mut buffer),
+                ParamNone,
+            );
+            self.sess.invoke_command(cmd as u32, &mut op)?;
+            op.parameters().2.updated_size()
+        };
+        let result = serde_json::from_slice(&buffer[0..size]).map_err(|err| {
+            println!("proto error: {:?}", err);
+            ErrorKind::BadFormat
+        })?;
+        Ok(result)
+    }
+
+    /// Asks the train TA to serialise its current model. Returns the bytes the
+    /// TA wrote into parameter 0, at most `MAX_MODEL_RECORD_SIZE` of them.
+    pub fn export(&mut self) -> optee_teec::Result<Vec<u8>> {
+        let mut buffer = vec![0_u8; MAX_MODEL_RECORD_SIZE];
+        let size = {
+            let mut op = Operation::new(
+                0,
+                ParamTmpRef::new_output(&mut buffer),
+                ParamNone,
+                ParamNone,
+                ParamNone,
+            );
+            self.sess
+                .invoke_command(train::Command::Export as u32, &mut op)?;
+            op.parameters().0.updated_size()
+        };
+        buffer.resize(size, 0);
+        Ok(buffer)
+    }
+}
+
+/// An open session with the inference TA, created with a serialised model record.
+pub struct InferenceTaConnector<'a> {
+    sess: Session<'a>,
+}
+
+// SAFETY(review): this asserts the session may be moved to another thread.
+// Nothing in this file synchronises access to it. Confirm that callers (e.g.
+// the `serve` command) never use one connector from two threads at once.
+unsafe impl Send for InferenceTaConnector<'_> {}
+
+impl<'a> InferenceTaConnector<'a> {
+    /// Opens a session with the inference TA, passing the model `record` in
+    /// parameter 0.
+    ///
+    /// Fails with `BadParameters` if `inference::UUID` does not parse as a
+    /// UUID, or with whatever error the TEE client reports when opening the
+    /// session.
+    pub fn new(ctx: &'a mut Context, record: &[u8]) -> optee_teec::Result<Self> {
+        let uuid = Uuid::parse_str(inference::UUID).map_err(|err| {
+            println!(
+                "parse uuid \"{}\" failed due to: {:?}",
+                inference::UUID,
+                err
+            );
+            ErrorKind::BadParameters
+        })?;
+        let mut op = Operation::new(
+            0,
+            ParamTmpRef::new_input(record),
+            ParamNone,
+            ParamNone,
+            ParamNone,
+        );
+
+        Ok(Self {
+            sess: ctx.open_session_with_operation(uuid, &mut op)?,
+        })
+    }
+
+    /// Classifies a batch of images and returns one predicted label per image.
+    ///
+    /// An empty batch returns an empty vector without contacting the TA. Fails
+    /// with `Generic` if the TA reports a different number of labels than the
+    /// number of images sent.
+    pub fn infer_batch(&mut self, images: &[Image]) -> optee_teec::Result<Vec<u8>> {
+        // The TA builds its input with `Tensor::cat`, which presumably rejects
+        // an empty list, so do not send it an empty batch.
+        if images.is_empty() {
+            return Ok(Vec::new());
+        }
+        let mut output = vec![0_u8; images.len()];
+        let size = {
+            let mut op = Operation::new(
+                0,
+                ParamTmpRef::new_input(bytemuck::cast_slice(images)),
+                ParamTmpRef::new_output(&mut output),
+                ParamNone,
+                ParamNone,
+            );
+            self.sess.invoke_command(0, &mut op)?;
+            op.parameters().1.updated_size()
+        };
+
+        if output.len() != size {
+            // "want" is the expected count (one label per image); "got" is
+            // what the TA reported.
+            println!("mismatch response, want {}, got {}", output.len(), size);
+            return Err(ErrorKind::Generic.into());
+        }
+        Ok(output)
+    }
+}
diff --git a/ci/ci.sh b/examples/mnist-rs/proto/Cargo.toml
old mode 100755
new mode 100644
similarity index 57%
copy from ci/ci.sh
copy to examples/mnist-rs/proto/Cargo.toml
index 7556757..1bba120
--- a/ci/ci.sh
+++ b/examples/mnist-rs/proto/Cargo.toml
@@ -1,5 +1,3 @@
-#!/bin/bash
-
 # 
Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -17,35 +15,15 @@ # specific language governing permissions and limitations # under the License. -set -xe - -pushd ../tests - -./test_hello_world.sh -./test_random.sh -./test_secure_storage.sh -./test_aes.sh -./test_hotp.sh -./test_acipher.sh -./test_big_int.sh -./test_diffie_hellman.sh -./test_digest.sh -./test_authentication.sh -./test_time.sh -./test_signature_verification.sh -./test_supp_plugin.sh -./test_error_handling.sh -./test_tcp_client.sh -./test_udp_socket.sh - -# Run std only tests -if [ "$STD" ]; then - ./test_serde.sh - ./test_message_passing_interface.sh - ./test_tls_client.sh - ./test_tls_server.sh - ./test_eth_wallet.sh - ./test_secure_db_abstraction.sh -fi +[package] +name = "proto" +version = "0.4.0" +authors = ["Teaclave Contributors <[email protected]>"] +license = "Apache-2.0" +repository = "https://github.com/apache/incubator-teaclave-trustzone-sdk.git" +description = "Data structures and functions shared by host and TA." +edition = "2021" -popd +[dependencies] +num_enum = { version = "0.7.3", default-features = false } +serde = { version = "1.0.218", default-features = false, features = ["derive"] } diff --git a/examples/mnist-rs/proto/src/inference.rs b/examples/mnist-rs/proto/src/inference.rs new file mode 100644 index 0000000..18a58fb --- /dev/null +++ b/examples/mnist-rs/proto/src/inference.rs @@ -0,0 +1,18 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+/// UUID of the inference TA, embedded at compile time from
+/// `ta/inference/uuid.txt`. The file content is used verbatim (not trimmed),
+/// so a trailing newline would become part of the string and presumably break
+/// `Uuid::parse_str` on the host.
+pub const UUID: &str = include_str!("../../ta/inference/uuid.txt");
diff --git a/examples/mnist-rs/proto/src/lib.rs b/examples/mnist-rs/proto/src/lib.rs
new file mode 100644
index 0000000..69212ac
--- /dev/null
+++ b/examples/mnist-rs/proto/src/lib.rs
@@ -0,0 +1,26 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Definitions shared by the host application and the TAs of the mnist-rs
+//! example.
+#![no_std]
+pub mod inference;
+pub mod train;
+
+/// Height of one image, in pixels.
+pub const IMAGE_HEIGHT: usize = 28;
+/// Width of one image, in pixels.
+pub const IMAGE_WIDTH: usize = 28;
+/// Number of pixels (one byte each) in an `Image`.
+pub const IMAGE_SIZE: usize = IMAGE_HEIGHT * IMAGE_WIDTH;
+/// Number of classes the model predicts.
+pub const NUM_CLASSES: usize = 10;
+/// One image as `IMAGE_SIZE` pixel bytes in a flat array (presumably
+/// row-major; the model only sees it reshaped to `[1, IMAGE_SIZE]`).
+pub type Image = [u8; IMAGE_SIZE];
diff --git a/examples/mnist-rs/proto/src/train.rs b/examples/mnist-rs/proto/src/train.rs
new file mode 100644
index 0000000..9bc804f
--- /dev/null
+++ b/examples/mnist-rs/proto/src/train.rs
@@ -0,0 +1,33 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+use num_enum::{IntoPrimitive, TryFromPrimitive};
+
+/// Command IDs understood by the train TA. The numeric values are part of the
+/// host/TA protocol: the host sends `cmd as u32` and the TA decodes it with
+/// `Command::try_from`. They are therefore spelled out explicitly.
+#[derive(Debug, TryFromPrimitive, IntoPrimitive)]
+#[repr(u32)]
+pub enum Command {
+    /// Train on one batch of images and labels.
+    Train = 0,
+    /// Evaluate one batch of images and labels without training.
+    Valid = 1,
+    /// Serialise the current model and return it to the host.
+    Export = 2,
+}
+
+/// Metrics the train TA reports for one batch, exchanged as JSON.
+#[derive(serde::Serialize, serde::Deserialize)]
+pub struct Output {
+    pub loss: f32,
+    pub accuracy: f32,
+}
+
+/// UUID of the train TA, embedded at compile time from `ta/train/uuid.txt`.
+/// The file content is used verbatim (not trimmed), so a trailing newline
+/// would become part of the string and presumably break `Uuid::parse_str` on
+/// the host.
+pub const UUID: &str = include_str!("../../ta/train/uuid.txt");
diff --git a/ci/ci.sh b/examples/mnist-rs/rust-toolchain.toml
old mode 100755
new mode 100644
similarity index 57%
copy from ci/ci.sh
copy to examples/mnist-rs/rust-toolchain.toml
index 7556757..25dac2a
--- a/ci/ci.sh
+++ b/examples/mnist-rs/rust-toolchain.toml
@@ -1,5 +1,3 @@
-#!/bin/bash
-
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements. See the NOTICE file
 # distributed with this work for additional information
@@ -17,35 +15,12 @@
 # specific language governing permissions and limitations
 # under the License.
 
-set -xe
-
-pushd ../tests
-
-./test_hello_world.sh
-./test_random.sh
-./test_secure_storage.sh
-./test_aes.sh
-./test_hotp.sh
-./test_acipher.sh
-./test_big_int.sh
-./test_diffie_hellman.sh
-./test_digest.sh
-./test_authentication.sh
-./test_time.sh
-./test_signature_verification.sh
-./test_supp_plugin.sh
-./test_error_handling.sh
-./test_tcp_client.sh
-./test_udp_socket.sh
-
-# Run std only tests
-if [ "$STD" ]; then
-    ./test_serde.sh
-    ./test_message_passing_interface.sh
-    ./test_tls_client.sh
-    ./test_tls_server.sh
-    ./test_eth_wallet.sh
-    ./test_secure_db_abstraction.sh
-fi
+# Toolchain override, burn currently requires Rust 1.83 and will soon require
+# Rust 1.85.
-popd +[toolchain] +channel = "nightly-2025-01-16" +components = [ "rust-src" ] +targets = ["aarch64-unknown-linux-gnu", "arm-unknown-linux-gnueabihf"] +# minimal profile: install rustc, cargo, and rust-std +profile = "minimal" diff --git a/examples/mnist-rs/ta/Cargo.toml b/examples/mnist-rs/ta/Cargo.toml new file mode 100644 index 0000000..8303332 --- /dev/null +++ b/examples/mnist-rs/ta/Cargo.toml @@ -0,0 +1,51 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +[workspace] +resolver = "2" + +members = [ + "common", + "train", + "inference", +] + +[workspace.package] +edition = "2021" +license = "Apache-2.0" +version = "0.4.0" +repository = "https://github.com/apache/incubator-teaclave-trustzone-sdk.git" +authors = ["Teaclave Contributors <[email protected]>"] + +[workspace.dependencies] +optee-utee-sys = { path = "../../../optee-utee/optee-utee-sys" } +optee-utee = { path = "../../../optee-utee" } +optee-utee-build = { path = "../../../optee-utee-build" } + +proto = { path = "../proto" } + +bytemuck = { version = "1.21.0", features = ["min_const_generics"] } +burn = { git = "https://github.com/tracel-ai/burn.git", rev = "a1ca4346424197f42ab1b5bff7f9d1210b029318", default-features = false, features = ["ndarray", "autodiff"] } +spin = "0.9.8" +serde = { version = "1.0.218", default-features = false, features = ["derive"] } +serde_json = { version = "1.0.139", default-features = false, features = ["alloc"] } + + +[profile.release] +panic = "abort" +lto = true +opt-level = 1 diff --git a/ci/ci.sh b/examples/mnist-rs/ta/common/Cargo.toml old mode 100755 new mode 100644 similarity index 57% copy from ci/ci.sh copy to examples/mnist-rs/ta/common/Cargo.toml index 7556757..66bdb84 --- a/ci/ci.sh +++ b/examples/mnist-rs/ta/common/Cargo.toml @@ -1,5 +1,3 @@ -#!/bin/bash - # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -17,35 +15,18 @@ # specific language governing permissions and limitations # under the License. 
-set -xe - -pushd ../tests - -./test_hello_world.sh -./test_random.sh -./test_secure_storage.sh -./test_aes.sh -./test_hotp.sh -./test_acipher.sh -./test_big_int.sh -./test_diffie_hellman.sh -./test_digest.sh -./test_authentication.sh -./test_time.sh -./test_signature_verification.sh -./test_supp_plugin.sh -./test_error_handling.sh -./test_tcp_client.sh -./test_udp_socket.sh - -# Run std only tests -if [ "$STD" ]; then - ./test_serde.sh - ./test_message_passing_interface.sh - ./test_tls_client.sh - ./test_tls_server.sh - ./test_eth_wallet.sh - ./test_secure_db_abstraction.sh -fi +[package] +name = "common" +description = "Some common structs and functions." +publish = false +version.workspace = true +authors.workspace = true +license.workspace = true +repository.workspace = true +edition.workspace = true -popd +[dependencies] +proto = { workspace = true } +optee-utee-sys = { workspace = true } +optee-utee = { workspace = true } +burn = { workspace = true } diff --git a/examples/mnist-rs/ta/common/src/lib.rs b/examples/mnist-rs/ta/common/src/lib.rs new file mode 100644 index 0000000..9332788 --- /dev/null +++ b/examples/mnist-rs/ta/common/src/lib.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+
+//! Model definition and helpers shared by the train and inference TAs.
+#![no_std]
+extern crate alloc;
+
+mod model;
+mod utils;
+
+pub use model::Model;
+pub use utils::*;
diff --git a/examples/mnist-rs/ta/common/src/model.rs b/examples/mnist-rs/ta/common/src/model.rs
new file mode 100644
index 0000000..39213f4
--- /dev/null
+++ b/examples/mnist-rs/ta/common/src/model.rs
@@ -0,0 +1,86 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use alloc::vec::Vec;
+use burn::{
+    prelude::*,
+    record::{FullPrecisionSettings, Recorder, RecorderError},
+};
+use proto::{Image, IMAGE_SIZE, NUM_CLASSES};
+
+/// A single fully-connected layer that maps `IMAGE_SIZE` normalised pixels to
+/// `NUM_CLASSES` raw class scores.
+///
+/// This model is deliberately minimal. It exists only to demonstrate how to
+/// train and run inference inside a TA; do not use it in production.
+#[derive(Module, Debug)]
+pub struct Model<B: Backend> {
+    linear: nn::Linear<B>,
+}
+
+impl<B: Backend> Model<B> {
+    /// Creates a freshly initialised model on `device`.
+    pub fn new(device: &B::Device) -> Self {
+        Self {
+            linear: nn::LinearConfig::new(IMAGE_SIZE, NUM_CLASSES).init(device),
+        }
+    }
+
+    /// Maps a `[batch, IMAGE_SIZE]` input to `[batch, NUM_CLASSES]` raw scores.
+    pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
+        self.linear.forward(input)
+    }
+
+    /// Serialises the model's parameters with burn's `BinBytesRecorder` at full
+    /// precision. The bytes can be turned back into a model with `import`.
+    pub fn export(&self) -> Result<Vec<u8>, RecorderError> {
+        let recorder = burn::record::BinBytesRecorder::<FullPrecisionSettings>::new();
+        // `into_record` takes the module by value, hence the clone.
+        recorder.record(self.clone().into_record(), ())
+    }
+
+    /// Rebuilds a model on `device` from bytes produced by `export`.
+    ///
+    /// Returns a `RecorderError` if `record` cannot be decoded.
+    pub fn import(device: &B::Device, record: Vec<u8>) -> Result<Self, RecorderError> {
+        let recorder = burn::record::BinBytesRecorder::<FullPrecisionSettings>::new();
+        let record = recorder.load(record, device)?;
+
+        // Start from a fresh model of the right shape, then overwrite its
+        // parameters with the loaded record.
+        let m = Self::new(device);
+        Ok(m.load_record(record))
+    }
+}
+
+// Conversions from protocol data (`proto::Image`, `u8` labels) to tensors.
+impl<B: Backend> Model<B> {
+    /// Converts one image into a normalised `[1, IMAGE_SIZE]` float tensor.
+    pub fn image_to_tensor(device: &B::Device, image: &Image) -> Tensor<B, 2> {
+        let tensor = TensorData::from(image.as_slice()).convert::<B::FloatElem>();
+        let tensor = Tensor::<B, 1>::from_data(tensor, device);
+        let tensor = tensor.reshape([1, IMAGE_SIZE]);
+
+        // Normalize input: scale pixels to [0, 1], then shift and scale so
+        // that mean=0 and std=1. The values mean=0.1307 and std=0.3081 were
+        // copied from the PyTorch MNIST example:
+        // https://github.com/pytorch/examples/blob/54f4572509891883a947411fd7239237dd2a39c3/mnist/main.py#L122
+        ((tensor / 255) - 0.1307) / 0.3081
+    }
+
+    /// Converts a batch of images into an `[images.len(), IMAGE_SIZE]` tensor.
+    /// It does this by concatenating the per-image tensors along dimension 0.
+    ///
+    /// NOTE(review): an empty `images` slice makes this call `Tensor::cat` on
+    /// an empty list, which burn presumably rejects with a panic. Confirm this.
+    pub fn images_to_tensors(device: &B::Device, images: &[Image]) -> Tensor<B, 2> {
+        let tensors = images
+            .iter()
+            .map(|v| Self::image_to_tensor(device, v))
+            .collect();
+        Tensor::cat(tensors, 0)
+    }
+
+    /// Converts labels into a 1-D integer tensor of length `labels.len()`.
+    ///
+    /// NOTE(review): the same empty-input caveat as `images_to_tensors` applies.
+    pub fn labels_to_tensors(device: &B::Device, labels: &[u8]) -> Tensor<B, 1, Int> {
+        let targets = labels
+            .iter()
+            .map(|item| {
+                Tensor::<B, 1, Int>::from_data([(*item as i64).elem::<B::IntElem>()], device)
+            })
+            .collect();
+        Tensor::cat(targets, 0)
+    }
+}
diff --git a/examples/mnist-rs/ta/common/src/utils.rs b/examples/mnist-rs/ta/common/src/utils.rs
new file mode
100644
index 0000000..e80e030
--- /dev/null
+++ b/examples/mnist-rs/ta/common/src/utils.rs
@@ -0,0 +1,35 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use optee_utee::{trace_println, ErrorKind, Parameter, Result};
+
+/// Copies `data` into the memref output parameter `param` and sets the
+/// parameter's updated size to `data.len()`.
+///
+/// If the caller's buffer is smaller than `data`, nothing is copied. The
+/// updated size is still set to `data.len()`, so the client can learn how much
+/// space it needs, and `ShortBuffer` is returned.
+///
+/// # Errors
+/// - any error from `Parameter::as_memref`;
+/// - `ShortBuffer` if the output buffer is too small.
+pub fn copy_to_output(param: &mut Parameter, data: &[u8]) -> Result<()> {
+    let mut output = unsafe { param.as_memref()? };
+
+    let capacity = output.buffer().len();
+    if capacity < data.len() {
+        trace_println!(
+            "expect output buffer size {}, got size {} instead",
+            data.len(),
+            capacity
+        );
+        // Report the required size, following the GlobalPlatform
+        // TEE_ERROR_SHORT_BUFFER convention, so the client can retry with a
+        // large enough buffer.
+        output.set_updated_size(data.len());
+        return Err(ErrorKind::ShortBuffer.into());
+    }
+    output.buffer()[..data.len()].copy_from_slice(data);
+    output.set_updated_size(data.len());
+    Ok(())
+}
diff --git a/ci/ci.sh b/examples/mnist-rs/ta/inference/Cargo.toml
old mode 100755
new mode 100644
similarity index 57%
copy from ci/ci.sh
copy to examples/mnist-rs/ta/inference/Cargo.toml
index 7556757..b10e306
--- a/ci/ci.sh
+++ b/examples/mnist-rs/ta/inference/Cargo.toml
@@ -1,5 +1,3 @@
-#!/bin/bash
-
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.
See the NOTICE file # distributed with this work for additional information @@ -17,35 +15,26 @@ # specific language governing permissions and limitations # under the License. -set -xe - -pushd ../tests - -./test_hello_world.sh -./test_random.sh -./test_secure_storage.sh -./test_aes.sh -./test_hotp.sh -./test_acipher.sh -./test_big_int.sh -./test_diffie_hellman.sh -./test_digest.sh -./test_authentication.sh -./test_time.sh -./test_signature_verification.sh -./test_supp_plugin.sh -./test_error_handling.sh -./test_tcp_client.sh -./test_udp_socket.sh +[package] +name = "inference" +description = "An inference TA of MNIST model." +publish = false +version.workspace = true +authors.workspace = true +license.workspace = true +repository.workspace = true +edition.workspace = true -# Run std only tests -if [ "$STD" ]; then - ./test_serde.sh - ./test_message_passing_interface.sh - ./test_tls_client.sh - ./test_tls_server.sh - ./test_eth_wallet.sh - ./test_secure_db_abstraction.sh -fi +[dependencies] +common = { path = "../common" } +proto = { workspace = true } +optee-utee-sys = { workspace = true } +optee-utee = { workspace = true } +burn = { workspace = true } +serde_json = { workspace = true } +bytemuck = { workspace = true } +spin = { workspace = true } -popd +[build-dependencies] +proto = { workspace = true } +optee-utee-build = { workspace = true } diff --git a/examples/mnist-rs/ta/inference/Makefile b/examples/mnist-rs/ta/inference/Makefile new file mode 100644 index 0000000..1e6f1bb --- /dev/null +++ b/examples/mnist-rs/ta/inference/Makefile @@ -0,0 +1,51 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Builds, strips and signs the inference TA.
+# UUID defaults to the contents of uuid.txt; it is read once (simple
+# assignment) instead of re-running `cat` on every expansion, and can still be
+# overridden with `make UUID=...` or from the environment.
+ifndef UUID
+UUID := $(shell cat "./uuid.txt")
+endif
+NAME := inference
+
+TARGET ?= aarch64-unknown-linux-gnu
+CROSS_COMPILE ?= aarch64-linux-gnu-
+OBJCOPY := $(CROSS_COMPILE)objcopy
+# Configure the linker to use GCC, which works on both cross-compilation and ARM machines
+LINKER_CFG := target.$(TARGET).linker=\"$(CROSS_COMPILE)gcc\"
+EXTRA_FLAGS := -Z build-std=core,alloc
+
+TA_SIGN_KEY ?= $(TA_DEV_KIT_DIR)/keys/default_ta.pem
+SIGN := $(TA_DEV_KIT_DIR)/scripts/sign_encrypt.py
+OUT_DIR := $(CURDIR)/../target/$(TARGET)/release
+
+# These targets are actions, not files: a stray file named e.g. `sign` or
+# `clean` must not make them look up to date.
+.PHONY: all ta strip sign clean
+
+ifeq ($(STD),)
+all: ta strip sign
+else
+all:
+	@echo "Please \`unset STD\` then rerun \`source environment\` to build the No-STD version"
+endif
+
+ta:
+	@cargo build --target $(TARGET) --release --config $(LINKER_CFG) $(EXTRA_FLAGS)
+
+strip: ta
+	@$(OBJCOPY) --strip-unneeded $(OUT_DIR)/$(NAME) $(OUT_DIR)/stripped_$(NAME)
+
+sign: strip
+	@$(SIGN) --uuid $(UUID) --key $(TA_SIGN_KEY) --in $(OUT_DIR)/stripped_$(NAME) --out $(OUT_DIR)/$(UUID).ta
+	@echo "SIGN => ${UUID}"
+
+clean:
+	@cargo clean
diff --git a/examples/mnist-rs/ta/inference/build.rs b/examples/mnist-rs/ta/inference/build.rs
new file mode 100644
index 0000000..314c05c
--- /dev/null
+++ b/examples/mnist-rs/ta/inference/build.rs
@@ -0,0 +1,25 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.
The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use optee_utee_build::{Error, RustEdition, TaConfig};
+
+// Generates the TA header for the inference TA. The UUID must be the inference
+// TA's own (`proto::inference::UUID`, read from ta/inference/uuid.txt). It must
+// match the `--uuid` the Makefile signs with and the UUID the host opens a
+// session to.
+fn main() -> Result<(), Error> {
+    let config = TaConfig::new_default_with_cargo_env(proto::inference::UUID)?
+        // 1 MiB of TA data (presumably the TA heap) and 1 MiB of stack.
+        .ta_data_size(1 * 1024 * 1024)
+        .ta_stack_size(1 * 1024 * 1024);
+    optee_utee_build::build(RustEdition::Before2024, config)
+}
diff --git a/examples/mnist-rs/ta/inference/src/main.rs b/examples/mnist-rs/ta/inference/src/main.rs
new file mode 100644
index 0000000..9c23396
--- /dev/null
+++ b/examples/mnist-rs/ta/inference/src/main.rs
@@ -0,0 +1,91 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Inference TA: holds a model imported when a session is opened and
+//! classifies batches of images sent by the host.
+#![no_std]
+#![no_main]
+extern crate alloc;
+
+use burn::{
+    backend::{ndarray::NdArrayDevice, NdArray},
+    tensor::cast::ToElement,
+};
+
+use common::{copy_to_output, Model};
+use optee_utee::{
+    ta_close_session, ta_create, ta_destroy, ta_invoke_command, ta_open_session, trace_println,
+};
+use optee_utee::{ErrorKind, Parameters, Result};
+use proto::Image;
+use spin::Mutex;
+
+type NoStdModel = Model<NdArray>;
+const DEVICE: NdArrayDevice = NdArrayDevice::Cpu;
+// The model imported by `open_session`; `None` until a session has been opened.
+static MODEL: Mutex<Option<NoStdModel>> = Mutex::new(Option::None);
+
+#[ta_create]
+fn create() -> Result<()> {
+    trace_println!("[+] TA create");
+    Ok(())
+}
+
+// Imports the model record the host passes in parameter 0 and stores it in
+// `MODEL`, replacing any previously loaded model. Fails with `BadParameters`
+// if the record cannot be decoded.
+#[ta_open_session]
+fn open_session(params: &mut Parameters) -> Result<()> {
+    let mut p0 = unsafe { params.0.as_memref()? };
+
+    let mut model = MODEL.lock();
+    model.replace(Model::import(&DEVICE, p0.buffer().to_vec()).map_err(|err| {
+        trace_println!("import failed: {:?}", err);
+        ErrorKind::BadParameters
+    })?);
+
+    Ok(())
+}
+
+#[ta_close_session]
+fn close_session() {
+    trace_println!("[+] TA close session");
+}
+
+#[ta_destroy]
+fn destroy() {
+    trace_println!("[+] TA destroy");
+}
+
+// Classifies a batch and ignores the command ID.
+// - Parameter 0 carries the images as raw bytes (`n * IMAGE_SIZE` of them).
+// - Parameter 1 receives one predicted class index per image.
+// The buffer comes from the normal world. A length that is not a whole number
+// of images, or an empty batch, is therefore rejected with `BadParameters`
+// instead of panicking, which would abort the TA.
+#[ta_invoke_command]
+fn invoke_command(_cmd_id: u32, params: &mut Parameters) -> Result<()> {
+    trace_println!("[+] TA invoke command");
+    let mut p0 = unsafe { params.0.as_memref()? };
+    let images: &[Image] = bytemuck::try_cast_slice(p0.buffer()).map_err(|err| {
+        trace_println!("bad image buffer: {:?}", err);
+        ErrorKind::BadParameters
+    })?;
+    if images.is_empty() {
+        trace_println!("empty image batch");
+        return Err(ErrorKind::BadParameters.into());
+    }
+    let input = NoStdModel::images_to_tensors(&DEVICE, images);
+
+    let output = MODEL
+        .lock()
+        .as_ref()
+        .ok_or(ErrorKind::CorruptObject)?
+        .forward(input);
+    // One row of scores per image; softmax then argmax picks the predicted class.
+    let result: alloc::vec::Vec<u8> = output
+        .iter_dim(0)
+        .map(|v| {
+            let data = burn::tensor::activation::softmax(v, 1);
+            data.argmax(1).into_scalar().to_u8()
+        })
+        .collect();
+
+    // Fails with ShortBuffer if parameter 1 cannot hold one byte per image.
+    copy_to_output(&mut params.1, &result)
+}
+
+include!(concat!(env!("OUT_DIR"), "/user_ta_header.rs"));
diff --git a/examples/mnist-rs/ta/inference/uuid.txt b/examples/mnist-rs/ta/inference/uuid.txt
new file mode 100644
index 0000000..d00c0a4
--- /dev/null
+++ b/examples/mnist-rs/ta/inference/uuid.txt
@@ -0,0 +1 @@
+ff09aa8a-fbb9-4734-ae8c-d7cd1a3f6744
\ No newline at end of file
diff --git a/ci/ci.sh b/examples/mnist-rs/ta/train/Cargo.toml
old mode 100755
new mode 100644
similarity index 56%
copy from ci/ci.sh
copy to examples/mnist-rs/ta/train/Cargo.toml
index 7556757..a0c24f9
--- a/ci/ci.sh
+++ b/examples/mnist-rs/ta/train/Cargo.toml
@@ -1,5 +1,3 @@
-#!/bin/bash
-
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements. See the NOTICE file
 # distributed with this work for additional information
@@ -17,35 +15,26 @@
 # specific language governing permissions and limitations
 # under the License.
 
-set -xe
-
-pushd ../tests
-
-./test_hello_world.sh
-./test_random.sh
-./test_secure_storage.sh
-./test_aes.sh
-./test_hotp.sh
-./test_acipher.sh
-./test_big_int.sh
-./test_diffie_hellman.sh
-./test_digest.sh
-./test_authentication.sh
-./test_time.sh
-./test_signature_verification.sh
-./test_supp_plugin.sh
-./test_error_handling.sh
-./test_tcp_client.sh
-./test_udp_socket.sh
+[package]
+name = "train"
+description = "A training TA of MNIST model."
+publish = false +version.workspace = true +authors.workspace = true +license.workspace = true +repository.workspace = true +edition.workspace = true -# Run std only tests -if [ "$STD" ]; then - ./test_serde.sh - ./test_message_passing_interface.sh - ./test_tls_client.sh - ./test_tls_server.sh - ./test_eth_wallet.sh - ./test_secure_db_abstraction.sh -fi +[dependencies] +common = { path = "../common" } +proto = { workspace = true } +optee-utee-sys = { workspace = true } +optee-utee = { workspace = true } +burn = { workspace = true } +spin = { workspace = true } +serde_json = { workspace = true } +bytemuck = { workspace = true, features = ["min_const_generics"] } -popd +[build-dependencies] +proto = { workspace = true } +optee-utee-build = { workspace = true } diff --git a/examples/mnist-rs/ta/train/Makefile b/examples/mnist-rs/ta/train/Makefile new file mode 100644 index 0000000..4b8bbd4 --- /dev/null +++ b/examples/mnist-rs/ta/train/Makefile @@ -0,0 +1,50 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+
+# Builds, strips and signs the train TA.
+# UUID defaults to the contents of uuid.txt; it is read once (simple
+# assignment) instead of re-running `cat` on every expansion, and can still be
+# overridden with `make UUID=...` or from the environment.
+ifndef UUID
+UUID := $(shell cat "./uuid.txt")
+endif
+NAME := train
+
+TARGET ?= aarch64-unknown-linux-gnu
+CROSS_COMPILE ?= aarch64-linux-gnu-
+OBJCOPY := $(CROSS_COMPILE)objcopy
+# Configure the linker to use GCC, which works on both cross-compilation and ARM machines
+LINKER_CFG := target.$(TARGET).linker=\"$(CROSS_COMPILE)gcc\"
+EXTRA_FLAGS := -Z build-std=core,alloc
+
+TA_SIGN_KEY ?= $(TA_DEV_KIT_DIR)/keys/default_ta.pem
+SIGN := $(TA_DEV_KIT_DIR)/scripts/sign_encrypt.py
+OUT_DIR := $(CURDIR)/../target/$(TARGET)/release
+
+# These targets are actions, not files: a stray file named e.g. `sign` or
+# `clean` must not make them look up to date.
+.PHONY: all ta strip sign clean
+
+ifeq ($(STD),)
+all: ta strip sign
+else
+all:
+	@echo "Please \`unset STD\` then rerun \`source environment\` to build the No-STD version"
+endif
+
+ta:
+	@cargo build --target $(TARGET) --release --config $(LINKER_CFG) $(EXTRA_FLAGS)
+
+strip: ta
+	@$(OBJCOPY) --strip-unneeded $(OUT_DIR)/$(NAME) $(OUT_DIR)/stripped_$(NAME)
+
+sign: strip
+	@$(SIGN) --uuid $(UUID) --key $(TA_SIGN_KEY) --in $(OUT_DIR)/stripped_$(NAME) --out $(OUT_DIR)/$(UUID).ta
+	@echo "SIGN => ${UUID}"
+
+clean:
+	@cargo clean
diff --git a/examples/mnist-rs/ta/train/build.rs b/examples/mnist-rs/ta/train/build.rs
new file mode 100644
index 0000000..e60b54b
--- /dev/null
+++ b/examples/mnist-rs/ta/train/build.rs
@@ -0,0 +1,25 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.
See the License for the +// specific language governing permissions and limitations +// under the License. + +use optee_utee_build::{Error, RustEdition, TaConfig}; + +fn main() -> Result<(), Error> { + let config = TaConfig::new_default_with_cargo_env(proto::train::UUID)? + .ta_data_size(5 * 1024 * 1024) + .ta_stack_size(1 * 1024 * 1024); + optee_utee_build::build(RustEdition::Before2024, config) +} diff --git a/examples/mnist-rs/ta/train/src/main.rs b/examples/mnist-rs/ta/train/src/main.rs new file mode 100644 index 0000000..f3ad57a --- /dev/null +++ b/examples/mnist-rs/ta/train/src/main.rs @@ -0,0 +1,127 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#![no_std] +#![no_main] +extern crate alloc; + +use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray}; +use common::copy_to_output; +use optee_utee::{ + ta_close_session, ta_create, ta_destroy, ta_invoke_command, ta_open_session, trace_println, +}; +use optee_utee::{ErrorKind, Parameters, Result}; +use proto::train::Command; +use spin::Mutex; + +mod trainer; + +type NoStdTrainer = trainer::Trainer<Autodiff<NdArray>>; + +const DEVICE: NdArrayDevice = NdArrayDevice::Cpu; +static TRAINER: Mutex<Option<NoStdTrainer>> = Mutex::new(Option::None); + +#[ta_create] +fn create() -> Result<()> { + trace_println!("[+] TA create"); + Ok(()) +} + +#[ta_open_session] +fn open_session(params: &mut Parameters) -> Result<()> { + let mut p0 = unsafe { params.0.as_memref()? }; + + let learning_rate = f64::from_le_bytes(p0.buffer().try_into().map_err(|err| { + trace_println!("bad parameter {:?}", err); + ErrorKind::BadParameters + })?); + trace_println!("Initialize with learning_rate: {}", learning_rate); + + let mut trainer = TRAINER.lock(); + trainer.replace(NoStdTrainer::new(DEVICE, learning_rate)); + + Ok(()) +} + +#[ta_close_session] +fn close_session() { + trace_println!("[+] TA close session"); +} + +#[ta_destroy] +fn destroy() { + trace_println!("[+] TA destroy"); +} + +#[ta_invoke_command] +fn invoke_command(cmd_id: u32, params: &mut Parameters) -> Result<()> { + match Command::try_from(cmd_id) { + Ok(Command::Train) => { + let mut p0 = unsafe { params.0.as_memref()? }; + let mut p1 = unsafe { params.1.as_memref()? }; + + let images = p0.buffer(); + let labels = p1.buffer(); + + let mut trainer = TRAINER.lock(); + let result = trainer + .as_mut() + .ok_or(ErrorKind::CorruptObject)? 
+ .train(bytemuck::cast_slice(images), labels); + let bytes = serde_json::to_vec(&result).map_err(|err| { + trace_println!("unexpected error: {:?}", err); + ErrorKind::BadState + })?; + + copy_to_output(&mut params.2, &bytes) + } + Ok(Command::Valid) => { + let mut p0 = unsafe { params.0.as_memref()? }; + let mut p1 = unsafe { params.1.as_memref()? }; + + let images = p0.buffer(); + let labels = p1.buffer(); + + let trainer = TRAINER.lock(); + let result = trainer + .as_ref() + .ok_or(ErrorKind::CorruptObject)? + .valid(bytemuck::cast_slice(images), labels); + + let bytes = serde_json::to_vec(&result).map_err(|err| { + trace_println!("unexpected error: {:?}", err); + ErrorKind::BadState + })?; + copy_to_output(&mut params.2, &bytes) + } + Ok(Command::Export) => { + let trainer = TRAINER.lock(); + let result = trainer + .as_ref() + .ok_or(ErrorKind::CorruptObject)? + .export() + .map_err(|err| { + trace_println!("unexpected error: {:?}", err); + ErrorKind::BadState + })?; + copy_to_output(&mut params.0, &result) + } + Err(_) => Err(ErrorKind::BadParameters.into()), + } +} + +include!(concat!(env!("OUT_DIR"), "/user_ta_header.rs")); diff --git a/examples/mnist-rs/ta/train/src/trainer.rs b/examples/mnist-rs/ta/train/src/trainer.rs new file mode 100644 index 0000000..5d4b9a8 --- /dev/null +++ b/examples/mnist-rs/ta/train/src/trainer.rs @@ -0,0 +1,114 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use alloc::vec::Vec; +use burn::{ + module::AutodiffModule, + nn::loss::CrossEntropyLoss, + optim::{adaptor::OptimizerAdaptor, Adam, AdamConfig, GradientsParams, Optimizer}, + prelude::*, + record::RecorderError, + tensor::{backend::AutodiffBackend, cast::ToElement}, +}; +use common::Model; +use proto::{train::Output, Image}; + +pub struct Trainer<B: AutodiffBackend> { + model: Model<B>, + device: B::Device, + optim: OptimizerAdaptor<Adam, Model<B>, B>, + lr: f64, +} + +impl<B: AutodiffBackend> Trainer<B> { + pub fn new(device: B::Device, lr: f64) -> Self { + let mut seed = [0_u8; 8]; + optee_utee::Random::generate(seed.as_mut_slice()); + B::seed(u64::from_le_bytes(seed)); + + Self { + optim: AdamConfig::new().init(), + model: Model::new(&device), + device, + lr, + } + } + + // Originally inspired by the burn/examples/custom-training-loop package. + // You may refer to + // https://github.com/tracel-ai/burn/blob/v0.16.0/examples/custom-training-loop + // for details. + pub fn train(&mut self, images: &[Image], labels: &[u8]) -> Output { + let images = Model::images_to_tensors(&self.device, images); + let targets = Model::labels_to_tensors(&self.device, labels); + let model = self.model.clone(); + + let output = model.forward(images); + let loss = + CrossEntropyLoss::new(None, &output.device()).forward(output.clone(), targets.clone()); + let accuracy = accuracy(output, targets); + + // Gradients for the current backward pass + let grads = loss.backward(); + // Gradients linked to each parameter of the model. 
+ let grads = GradientsParams::from_grads(grads, &model); + // Update the model using the optimizer. + self.model = self.optim.step(self.lr, model, grads); + + Output { + loss: loss.into_scalar().to_f32(), + accuracy, + } + } + + // Originally inspired by the burn/examples/custom-training-loop package. + // You may refer to + // https://github.com/tracel-ai/burn/blob/v0.16.0/examples/custom-training-loop + // for details. + pub fn valid(&self, images: &[Image], labels: &[u8]) -> Output { + // Get the model without autodiff. + let model_valid = self.model.valid(); + + let images = Model::images_to_tensors(&self.device, images); + let targets = Model::labels_to_tensors(&self.device, labels); + + let output = model_valid.forward(images); + let loss = + CrossEntropyLoss::new(None, &output.device()).forward(output.clone(), targets.clone()); + let accuracy = accuracy(output, targets); + + Output { + loss: loss.into_scalar().to_f32(), + accuracy, + } + } + + pub fn export(&self) -> Result<Vec<u8>, RecorderError> { + self.model.export() + } +} + +// Originally copied from the burn/crates/no-std-tests package. You may refer +// to https://github.com/tracel-ai/burn/blob/v0.16.0/crates/burn-no-std-test for +// details. 
+fn accuracy<B: Backend>(output: Tensor<B, 2>, targets: Tensor<B, 1, Int>) -> f32 { + let predictions = output.argmax(1).squeeze(1); + let num_predictions: usize = targets.dims().iter().product(); + let num_corrects = predictions.equal(targets).int().sum().into_scalar(); + + num_corrects.elem::<f32>() / num_predictions as f32 * 100.0 +} diff --git a/examples/mnist-rs/ta/train/uuid.txt b/examples/mnist-rs/ta/train/uuid.txt new file mode 100644 index 0000000..1c014e8 --- /dev/null +++ b/examples/mnist-rs/ta/train/uuid.txt @@ -0,0 +1 @@ +1b5f5b74-e9cf-4e62-8c3e-7e41da6d76f6 \ No newline at end of file diff --git a/tests/setup.sh b/tests/setup.sh index e8a6a2a..6e7329f 100755 --- a/tests/setup.sh +++ b/tests/setup.sh @@ -43,6 +43,11 @@ run_in_qemu() { sleep 5 } +run_in_qemu_with_timeout_secs() { + screen -S qemu_screen -p 0 -X stuff "$1\n" + sleep $2 +} + # Check if the image file exists locally if [ ! -d "${IMG}" ]; then echo "Image file '${IMG}' not found locally. Downloading from network." 
diff --git a/ci/ci.sh b/tests/test_mnist_rs.sh similarity index 51% copy from ci/ci.sh copy to tests/test_mnist_rs.sh index 7556757..8b66e79 100755 --- a/ci/ci.sh +++ b/tests/test_mnist_rs.sh @@ -19,33 +19,30 @@ set -xe -pushd ../tests +# Include base script +source setup.sh -./test_hello_world.sh -./test_random.sh -./test_secure_storage.sh -./test_aes.sh -./test_hotp.sh -./test_acipher.sh -./test_big_int.sh -./test_diffie_hellman.sh -./test_digest.sh -./test_authentication.sh -./test_time.sh -./test_signature_verification.sh -./test_supp_plugin.sh -./test_error_handling.sh -./test_tcp_client.sh -./test_udp_socket.sh +# Copy TA and host binary +cp ../examples/mnist-rs/ta/target/$TARGET_TA/release/*.ta shared +cp ../examples/mnist-rs/host/target/$TARGET_HOST/release/mnist-rs shared +# Copy samples files +cp -r ../examples/mnist-rs/host/samples shared -# Run std only tests -if [ "$STD" ]; then - ./test_serde.sh - ./test_message_passing_interface.sh - ./test_tls_client.sh - ./test_tls_server.sh - ./test_eth_wallet.sh - ./test_secure_db_abstraction.sh -fi +# Run script specific commands in QEMU +run_in_qemu "cp *.ta /lib/optee_armtz/\n" +# Do not export the model due to QEMU's memory limitations. +run_in_qemu_with_timeout_secs "./mnist-rs train -n 1\n" 300 +run_in_qemu "./mnist-rs infer -m samples/model.bin -b samples/7.bin -i samples/7.png\n" +run_in_qemu "^C" -popd +# Script specific checks +{ + grep -q "Train Success" screenlog.0 && + grep -q "Infer Success" screenlog.0 +} || { + cat -v screenlog.0 + cat -v /tmp/serial.log + false +} + +rm screenlog.0 --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
