This is an automated email from the ASF dual-hosted git repository.

yuanz pushed a commit to branch main
in repository 
https://gitbox.apache.org/repos/asf/incubator-teaclave-trustzone-sdk.git

commit 265b00823029af0f7f0ba757503ba36fc3dd5824
Author: ivila <[email protected]>
AuthorDate: Fri Mar 7 14:13:04 2025 +0800

    examples: add mnist-rs
    
    - add example mnist-rs
    - add test_mnist_rs
    
    Signed-off-by: Zehui Chen <[email protected]>
    Reviewed-by: Yuan Zhuang <[email protected]>
---
 ci/ci.sh                                           |   2 +
 ci/ci.sh => examples/mnist-rs/Makefile             |  55 +++--
 examples/mnist-rs/README.md                        | 234 +++++++++++++++++++++
 ci/ci.sh => examples/mnist-rs/host/Cargo.toml      |  56 +++--
 ci/ci.sh => examples/mnist-rs/host/Makefile        |  54 ++---
 examples/mnist-rs/host/samples/0.bin               | Bin 0 -> 784 bytes
 examples/mnist-rs/host/samples/1.bin               | Bin 0 -> 784 bytes
 examples/mnist-rs/host/samples/2.bin               | Bin 0 -> 784 bytes
 examples/mnist-rs/host/samples/3.bin               | Bin 0 -> 784 bytes
 examples/mnist-rs/host/samples/4.bin               | Bin 0 -> 784 bytes
 examples/mnist-rs/host/samples/5.bin               | Bin 0 -> 784 bytes
 examples/mnist-rs/host/samples/6.bin               | Bin 0 -> 784 bytes
 examples/mnist-rs/host/samples/7.bin               | Bin 0 -> 784 bytes
 examples/mnist-rs/host/samples/7.png               | Bin 0 -> 721 bytes
 examples/mnist-rs/host/samples/8.bin               | Bin 0 -> 784 bytes
 examples/mnist-rs/host/samples/9.bin               | Bin 0 -> 784 bytes
 examples/mnist-rs/host/samples/model.bin           | Bin 0 -> 31575 bytes
 examples/mnist-rs/host/src/commands/infer.rs       |  85 ++++++++
 examples/mnist-rs/host/src/commands/mod.rs         |  20 ++
 examples/mnist-rs/host/src/commands/serve.rs       | 116 ++++++++++
 examples/mnist-rs/host/src/commands/train.rs       | 131 ++++++++++++
 examples/mnist-rs/host/src/main.rs                 |  45 ++++
 examples/mnist-rs/host/src/tee.rs                  | 157 ++++++++++++++
 ci/ci.sh => examples/mnist-rs/proto/Cargo.toml     |  44 +---
 examples/mnist-rs/proto/src/inference.rs           |  18 ++
 examples/mnist-rs/proto/src/lib.rs                 |  26 +++
 examples/mnist-rs/proto/src/train.rs               |  33 +++
 ci/ci.sh => examples/mnist-rs/rust-toolchain.toml  |  41 +---
 examples/mnist-rs/ta/Cargo.toml                    |  51 +++++
 ci/ci.sh => examples/mnist-rs/ta/common/Cargo.toml |  47 ++---
 examples/mnist-rs/ta/common/src/lib.rs             |  25 +++
 examples/mnist-rs/ta/common/src/model.rs           |  86 ++++++++
 examples/mnist-rs/ta/common/src/utils.rs           |  35 +++
 .../mnist-rs/ta/inference/Cargo.toml               |  53 ++---
 examples/mnist-rs/ta/inference/Makefile            |  51 +++++
 examples/mnist-rs/ta/inference/build.rs            |  25 +++
 examples/mnist-rs/ta/inference/src/main.rs         |  91 ++++++++
 examples/mnist-rs/ta/inference/uuid.txt            |   1 +
 ci/ci.sh => examples/mnist-rs/ta/train/Cargo.toml  |  53 ++---
 examples/mnist-rs/ta/train/Makefile                |  50 +++++
 examples/mnist-rs/ta/train/build.rs                |  25 +++
 examples/mnist-rs/ta/train/src/main.rs             | 127 +++++++++++
 examples/mnist-rs/ta/train/src/trainer.rs          | 114 ++++++++++
 examples/mnist-rs/ta/train/uuid.txt                |   1 +
 tests/setup.sh                                     |   5 +
 ci/ci.sh => tests/test_mnist_rs.sh                 |  51 +++--
 46 files changed, 1722 insertions(+), 286 deletions(-)

diff --git a/ci/ci.sh b/ci/ci.sh
index 7556757..8281081 100755
--- a/ci/ci.sh
+++ b/ci/ci.sh
@@ -46,6 +46,8 @@ if [ "$STD" ]; then
     ./test_tls_server.sh
     ./test_eth_wallet.sh
     ./test_secure_db_abstraction.sh
+else
+    ./test_mnist_rs.sh  # only when STD is unset or empty, i.e. non-std builds
 fi
 
 popd
diff --git a/ci/ci.sh b/examples/mnist-rs/Makefile
old mode 100755
new mode 100644
similarity index 51%
copy from ci/ci.sh
copy to examples/mnist-rs/Makefile
index 7556757..19bac24
--- a/ci/ci.sh
+++ b/examples/mnist-rs/Makefile
@@ -1,5 +1,3 @@
-#!/bin/bash
-
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
@@ -17,35 +15,36 @@
 # specific language governing permissions and limitations
 # under the License.
 
-set -xe
+# Cross-compiler prefixes and Rust targets for the host app and the TAs.
+# Each can be overridden separately, e.g. `make TARGET_TA=...`.
+CROSS_COMPILE_HOST ?= aarch64-linux-gnu-
+CROSS_COMPILE_TA ?= aarch64-linux-gnu-
+TARGET_HOST ?= aarch64-unknown-linux-gnu
+TARGET_TA ?= aarch64-unknown-linux-gnu
+
+.PHONY: toolchain host ta all clean
+
+all: toolchain host ta
 
-pushd ../tests
+# Install the toolchain requested by this directory's rust-toolchain.toml.
+toolchain:
+	rustup toolchain install
 
-./test_hello_world.sh
-./test_random.sh
-./test_secure_storage.sh
-./test_aes.sh
-./test_hotp.sh
-./test_acipher.sh
-./test_big_int.sh
-./test_diffie_hellman.sh
-./test_digest.sh
-./test_authentication.sh
-./test_time.sh
-./test_signature_verification.sh
-./test_supp_plugin.sh
-./test_error_handling.sh
-./test_tcp_client.sh
-./test_udp_socket.sh
+# Build the host (REE) application; $(MAKE) propagates -j/-n to the sub-make.
+host: toolchain
+	$(q)$(MAKE) -C host TARGET=$(TARGET_HOST) \
+		CROSS_COMPILE=$(CROSS_COMPILE_HOST)
 
-# Run std only tests
-if [ "$STD" ]; then
-    ./test_serde.sh
-    ./test_message_passing_interface.sh
-    ./test_tls_client.sh
-    ./test_tls_server.sh
-    ./test_eth_wallet.sh
-    ./test_secure_db_abstraction.sh
-fi
+# Build both trusted applications: train and inference.
+ta: toolchain
+	$(q)$(MAKE) -C ta/train TARGET=$(TARGET_TA) \
+		CROSS_COMPILE=$(CROSS_COMPILE_TA)
+	$(q)$(MAKE) -C ta/inference TARGET=$(TARGET_TA) \
+		CROSS_COMPILE=$(CROSS_COMPILE_TA)
 
-popd
+# Remove the build outputs of the host, proto and both TAs.
+clean:
+	$(q)$(MAKE) -C host clean
+	cd proto && cargo clean
+	$(q)$(MAKE) -C ta/train clean
+	$(q)$(MAKE) -C ta/inference clean
diff --git a/examples/mnist-rs/README.md b/examples/mnist-rs/README.md
new file mode 100644
index 0000000..97d7af2
--- /dev/null
+++ b/examples/mnist-rs/README.md
@@ -0,0 +1,234 @@
+# mnist-rs
+
+This example demonstrates how to train a model and perform inference inside a TEE.
+
+## Install the TA
+
+There are two TAs in the project:
+
+| TA | UUID | Usage |
+| ---- | ---- | ---- |
+| Train | 1b5f5b74-e9cf-4e62-8c3e-7e41da6d76f6 | for training a new model |
+| Inference | ff09aa8a-fbb9-4734-ae8c-d7cd1a3f6744 | for performing inference |
+
+They are separate TAs because training normally consumes far more resources than
+inference, so the two need different build settings (memory, concurrency, etc.).
+
+Make sure both TAs are installed before running any subcommand; refer to the
+[Makefile](../../Makefile) in the root directory for details.
+
+## Running the Host
+
+There are three subcommands in the host:
+
+1. Train
+
+    Trains a new model and exports it to the given path.
+
+    ``` shell
+    mnist-rs train -o model.bin
+    ```
+
+    This subcommand downloads the MNIST dataset, feeds it into the TEE, performs
+    training inside the TEE, and writes the model to the given path once
+    training has finished.
+
+    ```mermaid
+    sequenceDiagram
+    actor D as Developer
+    participant C as mnist-rs(REE)
+    participant T as Train TA(TEE)
+    
+    D ->> C: train -o model.bin
+    C ->> C: Fetch mnist dataset
+
+    C ->> T: open_session_with_operation
+    T ->> T: Initialize Global Trainer<br/> with given learning rate
+    T ->> C: Initialize finished
+
+    loop iterate over num_epochs
+        rect rgb(191, 223, 255)
+        note right of C: Train
+            loop chunk by batch_size over train datasets
+                C ->> T: invoke_command Train with chunk datasets
+                T ->> T: Forward with given data
+                T ->> T: Backward Optimization
+                T ->> C: Train Output(loss, accuracy)
+            end
+        end
+        rect rgb(200, 150, 255)
+        note right of C: Valid
+            loop chunk by batch_size over test datasets
+                C ->> T: invoke_command Test with chunk datasets
+                T ->> T: Forward with given data
+                T ->> C: Test Output(loss, accuracy)
+            end
+        end
+    end
+
+    C ->> T: Export Command
+    T ->> C: Model Record
+    C ->> D: model.bin
+    ```
+
+    For detailed usage, run `mnist-rs train --help`. Sample output:
+
+    ``` shell
+    Usage: mnist-rs train [OPTIONS]
+
+    Options:
+      -n, --num-epochs <NUM_EPOCHS>        [default: 6]
+      -b, --batch-size <BATCH_SIZE>        [default: 64]
+      -l, --learning-rate <LEARNING_RATE>  [default: 0.0001]
+      -o, --output <OUTPUT>
+      -h, --help                           Print help
+    ```
+
+2. Infer
+
+    Loads a model from the given path, tests it with a given image, and prints
+    the inference result.
+
+    ```shell
+    mnist-rs infer -m model.bin -b samples/7.bin -i samples/7.png
+    ```
+
+    This subcommand loads the model from the given path, tests it with the
+    given binaries and images, and prints the inference results. For
+    convenience, you can use the sample binaries and images in the `samples`
+    folder.
+
+    ```mermaid
+    sequenceDiagram
+    actor D as Developer
+    participant C as mnist-rs(REE)
+    participant T as Inference TA(TEE)
+    
+    D ->> C: infer -m model.bin<br/> -b samples/7.bin<br/> -i samples/7.png
+
+    C ->> C: load Model Record from disk
+    C ->> T: open_session_with_operation
+    T ->> T: Initialize Global Model<br/> with given Model Record
+    T ->> C: Initialize finished
+
+    rect rgb(191, 223, 255)
+    note right of C: Infer with samples/7.bin
+        C ->> C: Load file from disk.
+        C ->> T: invoke_command: Feed data
+        T ->> T: Forward with given data
+        T ->> C: Infer result
+        C ->> D: Print result
+    end
+
+    rect rgb(191, 223, 255)
+    note right of C: Infer with samples/7.png
+        C ->> C: Load image from disk.
+        C ->> C: Convert image to luma8 binary
+        C ->> T: invoke_command with data
+        T ->> T: Forward with given data
+        T ->> C: Infer result
+        C ->> D: Print result
+    end
+
+    ```
+    For detailed usage, run `mnist-rs infer --help`. Sample output:
+
+    ```shell
+    Usage: mnist-rs infer [OPTIONS] --model <MODEL>
+
+    Options:
+        -m, --model <MODEL>    The path of the model
+        -b, --binary <BINARY>  The path of the input binary, must be 784 byte binary, can be multiple
+        -i, --image <IMAGE>    The path of the input image, must be dimension of 28x28, can be multiple
+        -h, --help             Print help
+    ```
+
+3. Serve
+
+    Loads a model from the given path, starts a web server and serves it as an
+    API.
+
+    ```shell
+    mnist-rs serve -m model.bin
+    ```
+
+    This subcommand loads the model from the given path and starts a
+    web server that provides inference APIs.
+
+    **Available APIs**:
+
+    | Method | Endpoint | Body |
+    | ---- | ---- | ---- |
+    | POST | `/inference/image` | an image with dimensions 28x28 |
+    | POST | `/inference/binary` | a 784-byte binary |
+
+    You can test the server with the following commands:
+
+    ```shell
+    # Perform inference using an image
+    curl --data-binary "@./samples/7.png" http://localhost:3000/inference/image
+    # Perform inference using a binary file
+    curl --data-binary "@./samples/7.bin" http://localhost:3000/inference/binary
+    ```
+
+    ```mermaid
+    sequenceDiagram
+    actor D as Developer
+    actor H as HttpClient
+    participant C as mnist-rs(REE)
+    participant T as Inference TA(TEE)
+    
+    D ->> C: serve -m model.bin
+
+    C ->> C: Load Model Record from disk
+    C ->> T: open_session_with_operation
+    T ->> T: Initialize Global Model<br/> with given Model Record
+    T ->> C: Initialize finished
+
+    C ->> C: Start http server
+
+    loop accept request
+        par /inference/binary
+            H ->> C: Request with binary data
+            C ->> T: invoke_command: Feed data
+            T ->> T: Forward with given data
+            T ->> C: Infer result
+            C ->> H: Infer result
+        end
+        par /inference/image
+            H ->> C: Request with image data
+            C ->> C: Convert image to luma8 binary
+            C ->> T: invoke_command: Feed data
+            T ->> T: Forward with given data
+            T ->> C: Infer result
+            C ->> H: Infer result
+        end
+    end
+    
+
+    ```
+
+    For detailed usage, run `mnist-rs serve --help`. Sample output:
+
+    ```shell
+    Usage: mnist-rs serve [OPTIONS] --model <MODEL>
+
+    Options:
+        -m, --model <MODEL>  The path of the model
+        -p, --port <PORT>    [default: 3000]
+        -h, --help           Print help
+    ```
+
+## Credits
+
+This demo project is inspired by the crates and examples from
+[tracel-ai/burn](https://github.com/tracel-ai/burn), including:
+
+1. [crates/burn-no-std-tests](https://github.com/tracel-ai/burn/tree/v0.16.0/crates/burn-no-std-tests)
+2. [examples/custom-training-loop](https://github.com/tracel-ai/burn/tree/v0.16.0/examples/custom-training-loop)
+3. [examples/mnist-inference-web](https://github.com/tracel-ai/burn/tree/v0.16.0/examples/mnist-inference-web)
+
+Special thanks to @[Guillaume Lagrange](https://github.com/laggui) for sharing
+knowledge and providing early reviews.
+
+TODO: add standard license files after 0.4.0 is released.
diff --git a/ci/ci.sh b/examples/mnist-rs/host/Cargo.toml
old mode 100755
new mode 100644
similarity index 53%
copy from ci/ci.sh
copy to examples/mnist-rs/host/Cargo.toml
index 7556757..90fab20
--- a/ci/ci.sh
+++ b/examples/mnist-rs/host/Cargo.toml
@@ -1,5 +1,3 @@
-#!/bin/bash
-
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
@@ -17,35 +15,29 @@
 # specific language governing permissions and limitations
 # under the License.
 
-set -xe
-
-pushd ../tests
-
-./test_hello_world.sh
-./test_random.sh
-./test_secure_storage.sh
-./test_aes.sh
-./test_hotp.sh
-./test_acipher.sh
-./test_big_int.sh
-./test_diffie_hellman.sh
-./test_digest.sh
-./test_authentication.sh
-./test_time.sh
-./test_signature_verification.sh
-./test_supp_plugin.sh
-./test_error_handling.sh
-./test_tcp_client.sh
-./test_udp_socket.sh
+[package]
+name = "mnist-rs"
+version = "0.4.0"
+authors = ["Teaclave Contributors <[email protected]>"]
+license = "Apache-2.0"
+repository = "https://github.com/apache/incubator-teaclave-trustzone-sdk.git"
+description = "An example of Rust OP-TEE TrustZone SDK."  # also the CLI `about` text (clap)
+edition = "2021"
+publish = false  # example crate, never uploaded to a registry
 
-# Run std only tests
-if [ "$STD" ]; then
-    ./test_serde.sh
-    ./test_message_passing_interface.sh
-    ./test_tls_client.sh
-    ./test_tls_server.sh
-    ./test_eth_wallet.sh
-    ./test_secure_db_abstraction.sh
-fi
+[dependencies]
+proto = { path = "../proto" }  # shared data types, e.g. Image, IMAGE_SIZE
+optee-teec = { path = "../../../optee-teec" }  # OP-TEE client API (Context, TA sessions)
+clap = { version = "4.5.31", features = ["derive"] }  # command-line parsing
+rand = "0.9.0"  # dataset shuffling in `train`
+rust-mnist = "0.2.0"  # loading the MNIST IDX files in `train`
+bytemuck = { version = "1.21.0", features = ["min_const_generics"] }  # slice casts, e.g. in `serve`
+serde_json = "1.0.139"  # NOTE(review): no use in commands/ or main.rs; presumably tee.rs — confirm
+image = "0.25.5"  # image decoding for `infer` and `serve`
+tiny_http = "0.12.0"  # HTTP server for `serve`
+anyhow = "1.0.97"  # error handling
+ureq = "3.0.8"  # MNIST download in `train`
+flate2 = "1.1.0"  # gunzip of the downloaded MNIST files
 
-popd
+[profile.release]
+lto = true  # link-time optimization for release builds
diff --git a/ci/ci.sh b/examples/mnist-rs/host/Makefile
old mode 100755
new mode 100644
similarity index 57%
copy from ci/ci.sh
copy to examples/mnist-rs/host/Makefile
index 7556757..1804544
--- a/ci/ci.sh
+++ b/examples/mnist-rs/host/Makefile
@@ -1,5 +1,3 @@
-#!/bin/bash
-
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
@@ -17,35 +15,29 @@
 # specific language governing permissions and limitations
 # under the License.
 
-set -xe
-
-pushd ../tests
-
-./test_hello_world.sh
-./test_random.sh
-./test_secure_storage.sh
-./test_aes.sh
-./test_hotp.sh
-./test_acipher.sh
-./test_big_int.sh
-./test_diffie_hellman.sh
-./test_digest.sh
-./test_authentication.sh
-./test_time.sh
-./test_signature_verification.sh
-./test_supp_plugin.sh
-./test_error_handling.sh
-./test_tcp_client.sh
-./test_udp_socket.sh
-
-# Run std only tests
-if [ "$STD" ]; then
-    ./test_serde.sh
-    ./test_message_passing_interface.sh
-    ./test_tls_client.sh
-    ./test_tls_server.sh
-    ./test_eth_wallet.sh
-    ./test_secure_db_abstraction.sh
-fi
-
-popd
+NAME := mnist-rs
+
+# Rust target triple and cross-compiler prefix. The parent Makefile passes
+# both; the defaults allow running `make` in this directory directly.
+TARGET ?= aarch64-unknown-linux-gnu
+CROSS_COMPILE ?= aarch64-linux-gnu-
+OBJCOPY := $(CROSS_COMPILE)objcopy
+# cargo --config override that makes the cross gcc the linker for $(TARGET).
+LINKER_CFG := target.$(TARGET).linker=\"$(CROSS_COMPILE)gcc\"
+
+OUT_DIR := $(CURDIR)/target/$(TARGET)/release
+
+.PHONY: all host strip clean
+
+all: host strip
+
+# Build the release binary for $(TARGET), the triple OUT_DIR points at.
+host:
+	@cargo build --target $(TARGET) --release --config $(LINKER_CFG)
+
+# Strip unneeded symbols from the release binary in place.
+strip: host
+	@$(OBJCOPY) --strip-unneeded $(OUT_DIR)/$(NAME) $(OUT_DIR)/$(NAME)
+
+clean:
+	@cargo clean
diff --git a/examples/mnist-rs/host/samples/0.bin 
b/examples/mnist-rs/host/samples/0.bin
new file mode 100644
index 0000000..6d4d6d9
Binary files /dev/null and b/examples/mnist-rs/host/samples/0.bin differ
diff --git a/examples/mnist-rs/host/samples/1.bin 
b/examples/mnist-rs/host/samples/1.bin
new file mode 100644
index 0000000..e4727fe
Binary files /dev/null and b/examples/mnist-rs/host/samples/1.bin differ
diff --git a/examples/mnist-rs/host/samples/2.bin 
b/examples/mnist-rs/host/samples/2.bin
new file mode 100644
index 0000000..e9820a4
Binary files /dev/null and b/examples/mnist-rs/host/samples/2.bin differ
diff --git a/examples/mnist-rs/host/samples/3.bin 
b/examples/mnist-rs/host/samples/3.bin
new file mode 100644
index 0000000..3b30163
Binary files /dev/null and b/examples/mnist-rs/host/samples/3.bin differ
diff --git a/examples/mnist-rs/host/samples/4.bin 
b/examples/mnist-rs/host/samples/4.bin
new file mode 100644
index 0000000..5a7ec04
Binary files /dev/null and b/examples/mnist-rs/host/samples/4.bin differ
diff --git a/examples/mnist-rs/host/samples/5.bin 
b/examples/mnist-rs/host/samples/5.bin
new file mode 100644
index 0000000..8616576
Binary files /dev/null and b/examples/mnist-rs/host/samples/5.bin differ
diff --git a/examples/mnist-rs/host/samples/6.bin 
b/examples/mnist-rs/host/samples/6.bin
new file mode 100644
index 0000000..9b747bf
Binary files /dev/null and b/examples/mnist-rs/host/samples/6.bin differ
diff --git a/examples/mnist-rs/host/samples/7.bin 
b/examples/mnist-rs/host/samples/7.bin
new file mode 100644
index 0000000..6e67157
Binary files /dev/null and b/examples/mnist-rs/host/samples/7.bin differ
diff --git a/examples/mnist-rs/host/samples/7.png 
b/examples/mnist-rs/host/samples/7.png
new file mode 100644
index 0000000..637d88f
Binary files /dev/null and b/examples/mnist-rs/host/samples/7.png differ
diff --git a/examples/mnist-rs/host/samples/8.bin 
b/examples/mnist-rs/host/samples/8.bin
new file mode 100644
index 0000000..eb7ab3e
Binary files /dev/null and b/examples/mnist-rs/host/samples/8.bin differ
diff --git a/examples/mnist-rs/host/samples/9.bin 
b/examples/mnist-rs/host/samples/9.bin
new file mode 100644
index 0000000..bd51b5a
Binary files /dev/null and b/examples/mnist-rs/host/samples/9.bin differ
diff --git a/examples/mnist-rs/host/samples/model.bin 
b/examples/mnist-rs/host/samples/model.bin
new file mode 100644
index 0000000..6cae1f9
Binary files /dev/null and b/examples/mnist-rs/host/samples/model.bin differ
diff --git a/examples/mnist-rs/host/src/commands/infer.rs 
b/examples/mnist-rs/host/src/commands/infer.rs
new file mode 100644
index 0000000..8c3b339
--- /dev/null
+++ b/examples/mnist-rs/host/src/commands/infer.rs
@@ -0,0 +1,110 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use clap::Parser;
+use image::EncodableLayout;
+use optee_teec::Context;
+use proto::{Image, IMAGE_SIZE};
+
+// Command-line options of the `infer` subcommand.
+#[derive(Parser, Debug)]
+pub struct Args {
+    /// The path of the model.
+    #[arg(short, long)]
+    model: String,
+    /// The path of the input binary, must be 784 byte binary, can be multiple
+    #[arg(short, long)]
+    binary: Vec<String>,
+    /// The path of the input image, must be dimension of 28x28, can be multiple
+    #[arg(short, long)]
+    image: Vec<String>,
+}
+
+// Runs the inference TA on every `--binary` and `--image` input and prints one
+// result per input: all binaries first, then all images.
+//
+// Fails if no input is given, if a file cannot be read or decoded, if an input
+// is not exactly IMAGE_SIZE bytes (after luma8 conversion for images), or if the
+// TA returns a different number of results than there are inputs.
+pub fn execute(args: &Args) -> anyhow::Result<()> {
+    // Checked before opening a TEE session, so nothing is wasted on an empty run.
+    anyhow::ensure!(
+        !args.binary.is_empty() || !args.image.is_empty(),
+        "no input given, pass at least one --binary or --image"
+    );
+
+    let model_path = std::path::absolute(&args.model)?;
+    println!("Load model from \"{}\"", model_path.display());
+    let record = std::fs::read(&model_path)?;
+    let mut ctx = Context::new()?;
+    let mut caller = crate::tee::InferenceTaConnector::new(&mut ctx, &record)?;
+
+    // Raw binaries come first and decoded images after them; the printing loop
+    // at the end relies on this order.
+    let mut inputs: Vec<Image> = args
+        .binary
+        .iter()
+        .map(|v| {
+            let data = std::fs::read(v)?;
+            anyhow::ensure!(
+                data.len() == IMAGE_SIZE,
+                "{}: expected {} bytes, got {}",
+                v,
+                IMAGE_SIZE,
+                data.len()
+            );
+            TryInto::<Image>::try_into(data)
+                .map_err(|err| anyhow::Error::msg(format!("cannot convert {:?} into Image", err)))
+        })
+        .collect::<Result<Vec<_>, anyhow::Error>>()?;
+    let images: Vec<Image> = args
+        .image
+        .iter()
+        .map(|v| {
+            // Propagate decode errors instead of panicking on a bad file.
+            let img = image::open(v)?.to_luma8();
+            let bytes = img.as_bytes();
+            anyhow::ensure!(
+                bytes.len() == IMAGE_SIZE,
+                "{}: expected {} pixels, got {}x{}",
+                v,
+                IMAGE_SIZE,
+                img.width(),
+                img.height()
+            );
+            TryInto::<Image>::try_into(bytes)
+                .map_err(|err| anyhow::Error::msg(format!("cannot convert {:?} into Image", err)))
+        })
+        .collect::<Result<Vec<_>, anyhow::Error>>()?;
+    inputs.extend(images);
+
+    let result = caller.infer_batch(&inputs)?;
+    anyhow::ensure!(
+        inputs.len() == result.len(),
+        "inference TA returned {} results for {} inputs",
+        result.len(),
+        inputs.len()
+    );
+
+    // result[i] belongs to the i-th input, so walk the paths in the same order.
+    for (i, path) in args.binary.iter().chain(&args.image).enumerate() {
+        println!("{}. {}: {}", i + 1, path, result[i]);
+    }
+    println!("Infer Success");
+
+    Ok(())
+}
diff --git a/examples/mnist-rs/host/src/commands/mod.rs 
b/examples/mnist-rs/host/src/commands/mod.rs
new file mode 100644
index 0000000..64e4e41
--- /dev/null
+++ b/examples/mnist-rs/host/src/commands/mod.rs
@@ -0,0 +1,20 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+pub mod infer; // `mnist-rs infer`: inference on local binary/image files
+pub mod serve; // `mnist-rs serve`: HTTP inference server
+pub mod train; // `mnist-rs train`: train a model inside the TEE
diff --git a/examples/mnist-rs/host/src/commands/serve.rs 
b/examples/mnist-rs/host/src/commands/serve.rs
new file mode 100644
index 0000000..6f8b3e1
--- /dev/null
+++ b/examples/mnist-rs/host/src/commands/serve.rs
@@ -0,0 +1,158 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use clap::Parser;
+use image::EncodableLayout;
+use proto::{IMAGE_HEIGHT, IMAGE_SIZE, IMAGE_WIDTH};
+use std::io::{Cursor, Read};
+use tiny_http::{Request, Response, Server};
+
+// Command-line options of the `serve` subcommand.
+#[derive(Parser, Debug)]
+pub struct Args {
+    /// The path of the model.
+    #[arg(short, long)]
+    model: String,
+    #[arg(short, long, default_value_t = 3000)]
+    port: u16,
+}
+
+type HttpResponse = Response<Cursor<Vec<u8>>>;
+
+// Upper bound on the body accepted by `/inference/image`. Bodies come from
+// untrusted clients, so they are never buffered without a limit; a 28x28
+// image is expected to be far smaller than this.
+const MAX_IMAGE_BODY_SIZE: usize = 1024 * 1024;
+
+// Loads the model into the inference TA, then serves HTTP requests on
+// 0.0.0.0:<port> one at a time, forever. Returns only on a setup error or
+// when receiving the next request fails.
+pub fn execute(args: &Args) -> anyhow::Result<()> {
+    let model_path = std::path::absolute(&args.model)?;
+    println!("Load model from \"{}\"", model_path.display());
+    let record = std::fs::read(&model_path)?;
+    let mut ctx = optee_teec::Context::new()?;
+    let mut caller = crate::tee::InferenceTaConnector::new(&mut ctx, &record)?;
+
+    let addr = format!("0.0.0.0:{}", args.port);
+    println!("Server runs on: {}", addr);
+
+    let server = Server::http(&addr)
+        .map_err(|err| anyhow::Error::msg(format!("cannot start server: {:?}", err)))?;
+
+    loop {
+        let mut request = server.recv()?;
+        let response = match handle(&mut caller, &mut request) {
+            Ok(v) => v,
+            Err(err) => {
+                eprintln!("unexpected error: {:#?}", err);
+                Response::from_string("Internal Error").with_status_code(500)
+            }
+        };
+        // A client that disconnects early must not take the whole server down.
+        if let Err(err) = request.respond(response) {
+            eprintln!("cannot send response: {:?}", err);
+        }
+    }
+}
+
+// Routes a request: only POST to `/inference/image` or `/inference/binary` is
+// served; other methods get 400 and unknown paths 404.
+fn handle(
+    caller: &mut crate::tee::InferenceTaConnector<'_>,
+    request: &mut Request,
+) -> anyhow::Result<HttpResponse> {
+    if request.method().ne(&tiny_http::Method::Post) {
+        return Ok(Response::from_string("Invalid Request Method").with_status_code(400));
+    }
+
+    match request.url() {
+        "/inference/image" => handle_image(caller, request),
+        "/inference/binary" => handle_binary(caller, request),
+        _ => Ok(Response::from_string("Not Found").with_status_code(404)),
+    }
+}
+
+// Reads the request body but stops after `limit` bytes. Returns `None` when the
+// body is longer than `limit`, so oversized uploads are never fully buffered.
+fn read_body(request: &mut Request, limit: usize) -> anyhow::Result<Option<Vec<u8>>> {
+    let mut data = Vec::with_capacity(IMAGE_SIZE);
+    // One byte past `limit` is enough to tell "too long" from "exactly full".
+    request
+        .as_reader()
+        .take(limit as u64 + 1)
+        .read_to_end(&mut data)?;
+    Ok((data.len() <= limit).then_some(data))
+}
+
+// Decodes an uploaded image (format guessed from its content), converts it to
+// 8-bit grayscale and runs inference on it. Bodies that are too large, cannot
+// be decoded, or are not IMAGE_WIDTH x IMAGE_HEIGHT get a 400 response.
+fn handle_image(
+    caller: &mut crate::tee::InferenceTaConnector<'_>,
+    request: &mut Request,
+) -> anyhow::Result<HttpResponse> {
+    let Some(data) = read_body(request, MAX_IMAGE_BODY_SIZE)? else {
+        return Ok(Response::from_string("Invalid Image").with_status_code(400));
+    };
+    // Undecodable data is a client error: answer 400 rather than 500.
+    let img = match image::ImageReader::new(Cursor::new(data))
+        .with_guessed_format()?
+        .decode()
+    {
+        Ok(img) => img.to_luma8(),
+        Err(_) => return Ok(Response::from_string("Invalid Image").with_status_code(400)),
+    };
+    if img.width() as usize != IMAGE_WIDTH || img.height() as usize != IMAGE_HEIGHT {
+        return Ok(Response::from_string("Invalid Image").with_status_code(400));
+    }
+    let result = handle_infer(caller, img.as_bytes())?;
+
+    println!("Performing Inference with Image, Result is {}", result);
+    Ok(Response::from_data(result.to_string()))
+}
+
+// Runs inference on a raw body of exactly IMAGE_SIZE bytes; any other length
+// gets a 400 response.
+fn handle_binary(
+    caller: &mut crate::tee::InferenceTaConnector<'_>,
+    request: &mut Request,
+) -> anyhow::Result<HttpResponse> {
+    let data = match read_body(request, IMAGE_SIZE)? {
+        Some(data) if data.len() == IMAGE_SIZE => data,
+        _ => return Ok(Response::from_string("Invalid Tensor").with_status_code(400)),
+    };
+
+    let result = handle_infer(caller, &data)?;
+    println!("Performing Inference with Binary, Result is {}", result);
+    Ok(Response::from_data(result.to_string()))
+}
+
+// Sends one sample to the inference TA and returns its (single) result.
+// `bytemuck::cast_slice` panics unless `image.len()` is a multiple of the
+// element size `infer_batch` takes (presumably `proto::Image`); both callers
+// are expected to pass exactly one sample's worth of bytes.
+fn handle_infer(
+    caller: &mut crate::tee::InferenceTaConnector<'_>,
+    image: &[u8],
+) -> anyhow::Result<u8> {
+    let result = caller.infer_batch(bytemuck::cast_slice(image))?;
+    result
+        .first()
+        .copied()
+        .ok_or_else(|| anyhow::Error::msg("inference TA returned no result"))
+}
diff --git a/examples/mnist-rs/host/src/commands/train.rs 
b/examples/mnist-rs/host/src/commands/train.rs
new file mode 100644
index 0000000..4e8fac0
--- /dev/null
+++ b/examples/mnist-rs/host/src/commands/train.rs
@@ -0,0 +1,161 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::io::{Cursor, Read};
+use std::path::PathBuf;
+
+use crate::tee::TrainerTaConnector;
+use optee_teec::Context;
+use proto::Image;
+use rand::seq::SliceRandom;
+
+// Command-line options of the `train` subcommand.
+#[derive(clap::Parser, Debug)]
+pub struct Args {
+    #[arg(short, long, default_value_t = 6)]
+    num_epochs: usize,
+    #[arg(short, long, default_value_t = 64)]
+    batch_size: usize,
+    #[arg(short, long, default_value_t = 0.0001)]
+    learning_rate: f64,
+    #[arg(short, long)]
+    output: Option<String>,
+}
+
+// Pairs every image with its label and shuffles the pairs once. If the two
+// inputs differ in length, the surplus items of the longer one are dropped.
+fn convert_datasets(images: &[Image], labels: &[u8]) -> Vec<(Image, u8)> {
+    let mut datasets: Vec<(Image, u8)> = images
+        .iter()
+        .copied()
+        .zip(labels.iter().copied())
+        .collect();
+    datasets.shuffle(&mut rand::rng());
+    datasets
+}
+
+// Splits a batch of (image, label) pairs into the parallel image and label
+// vectors passed to `TrainerTaConnector::train` / `valid`.
+fn split_batch(batch: &[(Image, u8)]) -> (Vec<Image>, Vec<u8>) {
+    batch.iter().copied().unzip()
+}
+
+// Trains a model inside the TEE. Each epoch feeds the shuffled training set to
+// the trainer TA in `batch_size` chunks and then evaluates on the test set,
+// printing loss/accuracy per chunk. With `--output`, the exported model record
+// is written to that path.
+pub fn execute(args: &Args) -> anyhow::Result<()> {
+    // `slice::chunks(0)` panics, so reject a zero batch size up front.
+    anyhow::ensure!(args.batch_size > 0, "batch size must be greater than 0");
+    // Initialize trainer
+    let mut ctx = Context::new()?;
+    let mut trainer = TrainerTaConnector::new(&mut ctx, args.learning_rate)?;
+    // Download mnist data
+    let data = check_download_mnist_data()?;
+    // Prepare datasets
+    let train_datasets = convert_datasets(&data.train_data, &data.train_labels);
+    let valid_datasets = convert_datasets(&data.test_data, &data.test_labels);
+    // Training loop, originally inspired by burn/examples/custom-training-loop
+    for epoch in 1..=args.num_epochs {
+        for (iteration, batch) in train_datasets.chunks(args.batch_size).enumerate() {
+            let (images, labels) = split_batch(batch);
+            let output = trainer.train(&images, &labels)?;
+            println!(
+                "[Train - Epoch {} - Iteration {}] Loss {:.3} | Accuracy {:.3} %",
+                epoch, iteration, output.loss, output.accuracy,
+            );
+        }
+
+        for (iteration, batch) in valid_datasets.chunks(args.batch_size).enumerate() {
+            let (images, labels) = split_batch(batch);
+            let output = trainer.valid(&images, &labels)?;
+            println!(
+                "[Valid - Epoch {} - Iteration {}] Loss {:.3} | Accuracy {:.3} %",
+                epoch, iteration, output.loss, output.accuracy,
+            );
+        }
+    }
+    // Export the model to the given path
+    if let Some(output_path) = args.output.as_ref() {
+        let record = trainer.export()?;
+        println!("Export record to \"{}\"", output_path);
+        std::fs::write(output_path, &record)?;
+    }
+    println!("Train Success");
+    Ok(())
+}
+
+// Makes sure the four MNIST IDX files exist, decompressed, under ./data/.
+// Each missing file (or one with an unexpected size) is downloaded, checked
+// against its compressed size, gunzipped, checked against its uncompressed
+// size and written out. The files are then loaded with `rust_mnist`.
+fn check_download_mnist_data() -> anyhow::Result<rust_mnist::Mnist> {
+    const DATA_PATH: &str = "./data/";
+    const BASE_URL: &str = "https://storage.googleapis.com/cvdf-datasets/mnist";
+
+    // Expected file properties (name, compressed_size, uncompressed_size) for
+    // verification after download
+    const EXPECTED_MNIST_FILE_SIZES: [(&str, u64, u64); 4] = [
+        ("train-images-idx3-ubyte", 9912422, 47040016),
+        ("train-labels-idx1-ubyte", 28881, 60008),
+        ("t10k-images-idx3-ubyte", 1648877, 7840016),
+        ("t10k-labels-idx1-ubyte", 4542, 10008),
+    ];
+
+    let folder = PathBuf::from(DATA_PATH);
+    // create_dir_all is a no-op when the directory already exists.
+    std::fs::create_dir_all(&folder)?;
+
+    for &(filename, compressed_size, uncompressed_size) in EXPECTED_MNIST_FILE_SIZES.iter() {
+        let file = folder.join(filename);
+        // Files are stored decompressed, so the cache check must compare
+        // against the uncompressed size.
+        let cached = std::fs::metadata(&file)
+            .map(|meta| meta.is_file() && meta.len() == uncompressed_size)
+            .unwrap_or(false);
+        if cached {
+            println!("File {} exist, skip.", file.display());
+            continue;
+        }
+
+        let url = format!("{}/{}.gz", BASE_URL, filename);
+        println!("Download {} from {}", filename, url);
+        let body = ureq::get(&url).call()?.body_mut().read_to_vec()?;
+        anyhow::ensure!(
+            body.len() as u64 == compressed_size,
+            "{}: got {} bytes, expected {}",
+            url,
+            body.len(),
+            compressed_size
+        );
+
+        let mut gz = flate2::bufread::GzDecoder::new(Cursor::new(body));
+        let mut buffer = Vec::with_capacity(uncompressed_size as usize);
+        gz.read_to_end(&mut buffer)?;
+        anyhow::ensure!(
+            buffer.len() as u64 == uncompressed_size,
+            "{}: decompressed to {} bytes, expected {}",
+            filename,
+            buffer.len(),
+            uncompressed_size
+        );
+
+        std::fs::write(&file, &buffer)?;
+    }
+
+    Ok(rust_mnist::Mnist::new(DATA_PATH))
+}
diff --git a/examples/mnist-rs/host/src/main.rs 
b/examples/mnist-rs/host/src/main.rs
new file mode 100644
index 0000000..da3a6a6
--- /dev/null
+++ b/examples/mnist-rs/host/src/main.rs
@@ -0,0 +1,45 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+mod commands;
+mod tee;
+
+use clap::{Parser, Subcommand};
+
+// Top-level command-line interface of the host application. Plain `//`
+// comments are used on purpose: clap's derive turns `///` doc comments into
+// help text, which would change the CLI output.
+#[derive(Parser)]
+#[command(version, about, long_about = None)]
+struct Cli {
+    // The sub-command selected on the command line.
+    #[command(subcommand)]
+    command: Commands,
+}
+
+// Available sub-commands. Each variant wraps that command's `Args` and is
+// dispatched by `main` to the matching `commands::<name>::execute`.
+#[derive(Subcommand)]
+enum Commands {
+    Train(commands::train::Args),
+    Infer(commands::infer::Args),
+    Serve(commands::serve::Args),
+}
+
+/// Entry point: parses the command line and runs the chosen sub-command,
+/// propagating its result as the process exit status.
+fn main() -> anyhow::Result<()> {
+    let Cli { command } = Cli::parse();
+    match command {
+        Commands::Infer(args) => commands::infer::execute(&args),
+        Commands::Serve(args) => commands::serve::execute(&args),
+        Commands::Train(args) => commands::train::execute(&args),
+    }
+}
diff --git a/examples/mnist-rs/host/src/tee.rs 
b/examples/mnist-rs/host/src/tee.rs
new file mode 100644
index 0000000..d59dbc9
--- /dev/null
+++ b/examples/mnist-rs/host/src/tee.rs
@@ -0,0 +1,157 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use optee_teec::{Context, ErrorKind, Operation, ParamNone, ParamTmpRef, 
Session, Uuid};
+use proto::{inference, train, Image};
+
+/// Capacity (1 KiB) of the host buffer that receives the JSON-encoded
+/// `train::Output` written by the training TA.
+const MAX_OUTPUT_SERIALIZE_SIZE: usize = 1024;
+/// Capacity (10 MiB) of the host buffer that receives an exported model record.
+const MAX_MODEL_RECORD_SIZE: usize = 10 << 20;
+
+/// Host-side handle to an open session with the training TA.
+pub struct TrainerTaConnector<'a> {
+    sess: Session<'a>,
+}
+
+impl<'a> TrainerTaConnector<'a> {
+    /// Opens a session with the training TA.
+    ///
+    /// `learning_rate` is sent as little-endian `f64` bytes in parameter 0 of
+    /// the open-session operation. Returns `BadParameters` if the compiled-in
+    /// `train::UUID` cannot be parsed.
+    pub fn new(ctx: &'a mut Context, learning_rate: f64) -> optee_teec::Result<Self> {
+        let bytes = learning_rate.to_le_bytes();
+        let uuid = Uuid::parse_str(train::UUID).map_err(|err| {
+            println!("parse uuid \"{}\" failed due to: {:?}", train::UUID, err);
+            ErrorKind::BadParameters
+        })?;
+        let mut op = Operation::new(
+            0,
+            ParamTmpRef::new_input(bytes.as_slice()),
+            ParamNone,
+            ParamNone,
+            ParamNone,
+        );
+
+        Ok(Self {
+            sess: ctx.open_session_with_operation(uuid, &mut op)?,
+        })
+    }
+
+    /// Sends one batch with the `Train` command and returns the loss/accuracy
+    /// the TA reports for it.
+    pub fn train(&mut self, images: &[Image], labels: &[u8]) -> optee_teec::Result<train::Output> {
+        self.invoke_batch(train::Command::Train, images, labels)
+    }
+
+    /// Sends one batch with the `Valid` command and returns the loss/accuracy
+    /// the TA reports (presumably evaluation without a weight update — the
+    /// exact behaviour is defined by the TA).
+    pub fn valid(&mut self, images: &[Image], labels: &[u8]) -> optee_teec::Result<train::Output> {
+        self.invoke_batch(train::Command::Valid, images, labels)
+    }
+
+    /// Shared implementation of `train` and `valid`: passes `images` (param 0)
+    /// and `labels` (param 1) to command `cmd`, then decodes the JSON
+    /// `train::Output` the TA writes into param 2.
+    fn invoke_batch(
+        &mut self,
+        cmd: train::Command,
+        images: &[Image],
+        labels: &[u8],
+    ) -> optee_teec::Result<train::Output> {
+        let mut buffer = vec![0_u8; MAX_OUTPUT_SERIALIZE_SIZE];
+        let images = bytemuck::cast_slice(images);
+        let size = {
+            let mut op = Operation::new(
+                0,
+                ParamTmpRef::new_input(images),
+                ParamTmpRef::new_input(labels),
+                ParamTmpRef::new_output(&mut buffer),
+                ParamNone,
+            );
+            self.sess.invoke_command(cmd as u32, &mut op)?;
+            op.parameters().2.updated_size()
+        };
+        // Bounds-checked: a reported size larger than our buffer is a protocol
+        // error, not a reason to panic on the slice.
+        let payload = buffer.get(..size).ok_or_else(|| {
+            println!("proto error: output size {} exceeds buffer {}", size, buffer.len());
+            ErrorKind::BadFormat
+        })?;
+        let result = serde_json::from_slice(payload).map_err(|err| {
+            println!("proto error: {:?}", err);
+            ErrorKind::BadFormat
+        })?;
+        Ok(result)
+    }
+
+    /// Asks the TA to export the trained model and returns the serialized
+    /// record (at most `MAX_MODEL_RECORD_SIZE` bytes).
+    pub fn export(&mut self) -> optee_teec::Result<Vec<u8>> {
+        let mut buffer = vec![0_u8; MAX_MODEL_RECORD_SIZE];
+        let size = {
+            let mut op = Operation::new(
+                0,
+                ParamTmpRef::new_output(&mut buffer),
+                ParamNone,
+                ParamNone,
+                ParamNone,
+            );
+            self.sess
+                .invoke_command(train::Command::Export as u32, &mut op)?;
+            op.parameters().0.updated_size()
+        };
+        buffer.resize(size, 0);
+        Ok(buffer)
+    }
+}
+
+/// Host-side handle to an open session with the inference TA.
+pub struct InferenceTaConnector<'a> {
+    sess: Session<'a>,
+}
+
+// SAFETY: NOTE(review): this asserts the connector (and the OP-TEE `Session`
+// it wraps, presumably not `Send` on its own) may be moved to another thread.
+// No justification is visible here — confirm that a libteec session may be
+// used from a thread other than the one that opened it, and that the borrowed
+// `Context` is never used concurrently.
+unsafe impl Send for InferenceTaConnector<'_> {}
+
+impl<'a> InferenceTaConnector<'a> {
+    /// Opens a session with the inference TA, passing the serialized model
+    /// `record` as parameter 0 of the open-session operation. Returns
+    /// `BadParameters` if the compiled-in `inference::UUID` cannot be parsed.
+    pub fn new(ctx: &'a mut Context, record: &[u8]) -> optee_teec::Result<Self> {
+        let uuid = Uuid::parse_str(inference::UUID).map_err(|err| {
+            println!(
+                "parse uuid \"{}\" failed due to: {:?}",
+                inference::UUID,
+                err
+            );
+            ErrorKind::BadParameters
+        })?;
+        let mut op = Operation::new(
+            0,
+            ParamTmpRef::new_input(record),
+            ParamNone,
+            ParamNone,
+            ParamNone,
+        );
+
+        Ok(Self {
+            sess: ctx.open_session_with_operation(uuid, &mut op)?,
+        })
+    }
+
+    /// Classifies a batch of images and returns one predicted label per image.
+    ///
+    /// The images are sent as raw bytes in parameter 0 and the TA writes one
+    /// `u8` per image into parameter 1. Fails with `ErrorKind::Generic` if the
+    /// TA reports a different number of results than images sent.
+    pub fn infer_batch(&mut self, images: &[Image]) -> optee_teec::Result<Vec<u8>> {
+        let mut output = vec![0_u8; images.len()];
+        let size = {
+            let mut op = Operation::new(
+                0,
+                ParamTmpRef::new_input(bytemuck::cast_slice(images)),
+                ParamTmpRef::new_output(&mut output),
+                ParamNone,
+                ParamNone,
+            );
+            // The inference TA does not inspect the command ID, so 0 is used.
+            self.sess.invoke_command(0, &mut op)?;
+            op.parameters().1.updated_size()
+        };
+
+        // Exactly one result per input image is expected ("want"); `size` is
+        // what the TA actually reported ("got").
+        if output.len() != size {
+            println!("mismatch response, want {}, got {}", output.len(), size);
+            return Err(ErrorKind::Generic.into());
+        }
+        Ok(output)
+    }
+}
diff --git a/ci/ci.sh b/examples/mnist-rs/proto/Cargo.toml
old mode 100755
new mode 100644
similarity index 57%
copy from ci/ci.sh
copy to examples/mnist-rs/proto/Cargo.toml
index 7556757..1bba120
--- a/ci/ci.sh
+++ b/examples/mnist-rs/proto/Cargo.toml
@@ -1,5 +1,3 @@
-#!/bin/bash
-
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
@@ -17,35 +15,15 @@
 # specific language governing permissions and limitations
 # under the License.
 
-set -xe
-
-pushd ../tests
-
-./test_hello_world.sh
-./test_random.sh
-./test_secure_storage.sh
-./test_aes.sh
-./test_hotp.sh
-./test_acipher.sh
-./test_big_int.sh
-./test_diffie_hellman.sh
-./test_digest.sh
-./test_authentication.sh
-./test_time.sh
-./test_signature_verification.sh
-./test_supp_plugin.sh
-./test_error_handling.sh
-./test_tcp_client.sh
-./test_udp_socket.sh
-
-# Run std only tests
-if [ "$STD" ]; then
-    ./test_serde.sh
-    ./test_message_passing_interface.sh
-    ./test_tls_client.sh
-    ./test_tls_server.sh
-    ./test_eth_wallet.sh
-    ./test_secure_db_abstraction.sh
-fi
+[package]
+name = "proto"
+version = "0.4.0"
+authors = ["Teaclave Contributors <[email protected]>"]
+license = "Apache-2.0"
+repository = "https://github.com/apache/incubator-teaclave-trustzone-sdk.git"
+description = "Data structures and functions shared by host and TA."
+edition = "2021"
 
-popd
+[dependencies]
+num_enum = { version = "0.7.3", default-features = false }
+serde = { version = "1.0.218", default-features = false, features = ["derive"] 
}
diff --git a/examples/mnist-rs/proto/src/inference.rs 
b/examples/mnist-rs/proto/src/inference.rs
new file mode 100644
index 0000000..18a58fb
--- /dev/null
+++ b/examples/mnist-rs/proto/src/inference.rs
@@ -0,0 +1,18 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+/// UUID of the inference TA, embedded at compile time from the TA's
+/// `uuid.txt`. Surrounding ASCII whitespace (e.g. a trailing newline added by
+/// an editor) is trimmed so `Uuid::parse_str` still accepts it, matching the
+/// TA Makefile which reads the same file via `$(shell cat ...)`.
+pub const UUID: &str = include_str!("../../ta/inference/uuid.txt").trim_ascii();
diff --git a/examples/mnist-rs/proto/src/lib.rs 
b/examples/mnist-rs/proto/src/lib.rs
new file mode 100644
index 0000000..69212ac
--- /dev/null
+++ b/examples/mnist-rs/proto/src/lib.rs
@@ -0,0 +1,26 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#![no_std]
+pub mod inference;
+pub mod train;
+
+/// Height of an input image, in pixels.
+pub const IMAGE_HEIGHT: usize = 28;
+/// Width of an input image, in pixels.
+pub const IMAGE_WIDTH: usize = 28;
+/// Number of pixels (one byte each) in a flattened image.
+pub const IMAGE_SIZE: usize = IMAGE_HEIGHT * IMAGE_WIDTH;
+/// Number of output classes of the classifier.
+pub const NUM_CLASSES: usize = 10;
+/// One image as a flat array of 8-bit pixel intensities.
+pub type Image = [u8; IMAGE_SIZE];
diff --git a/examples/mnist-rs/proto/src/train.rs 
b/examples/mnist-rs/proto/src/train.rs
new file mode 100644
index 0000000..9bc804f
--- /dev/null
+++ b/examples/mnist-rs/proto/src/train.rs
@@ -0,0 +1,33 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+use num_enum::{IntoPrimitive, TryFromPrimitive};
+
+/// Command IDs understood by the training TA.
+///
+/// The host and TA exchange these as raw `u32` values, so the discriminants
+/// are spelled out to make the wire values explicit and stable.
+#[derive(Debug, TryFromPrimitive, IntoPrimitive)]
+#[repr(u32)]
+pub enum Command {
+    Train = 0,
+    Valid = 1,
+    Export = 2,
+}
+
+/// Per-batch metrics returned for a `Train` or `Valid` command; the host
+/// decodes it from JSON written by the TA.
+#[derive(serde::Serialize, serde::Deserialize)]
+pub struct Output {
+    /// Loss reported by the TA for the batch.
+    pub loss: f32,
+    /// Accuracy reported by the TA for the batch (presumably a fraction in
+    /// [0, 1] — confirm against the TA implementation).
+    pub accuracy: f32,
+}
+
+/// UUID of the training TA, embedded at compile time from the TA's
+/// `uuid.txt`. Surrounding ASCII whitespace (e.g. a trailing newline added by
+/// an editor) is trimmed so `Uuid::parse_str` still accepts it.
+pub const UUID: &str = include_str!("../../ta/train/uuid.txt").trim_ascii();
diff --git a/ci/ci.sh b/examples/mnist-rs/rust-toolchain.toml
old mode 100755
new mode 100644
similarity index 57%
copy from ci/ci.sh
copy to examples/mnist-rs/rust-toolchain.toml
index 7556757..25dac2a
--- a/ci/ci.sh
+++ b/examples/mnist-rs/rust-toolchain.toml
@@ -1,5 +1,3 @@
-#!/bin/bash
-
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
@@ -17,35 +15,12 @@
 # specific language governing permissions and limitations
 # under the License.
 
-set -xe
-
-pushd ../tests
-
-./test_hello_world.sh
-./test_random.sh
-./test_secure_storage.sh
-./test_aes.sh
-./test_hotp.sh
-./test_acipher.sh
-./test_big_int.sh
-./test_diffie_hellman.sh
-./test_digest.sh
-./test_authentication.sh
-./test_time.sh
-./test_signature_verification.sh
-./test_supp_plugin.sh
-./test_error_handling.sh
-./test_tcp_client.sh
-./test_udp_socket.sh
-
-# Run std only tests
-if [ "$STD" ]; then
-    ./test_serde.sh
-    ./test_message_passing_interface.sh
-    ./test_tls_client.sh
-    ./test_tls_server.sh
-    ./test_eth_wallet.sh
-    ./test_secure_db_abstraction.sh
-fi
+# Toolchain override, burn currently requires Rust 1.83 and will soon require
+# Rust 1.85.
 
-popd
+[toolchain]
+channel = "nightly-2025-01-16"
+components = [ "rust-src" ]
+targets = ["aarch64-unknown-linux-gnu", "arm-unknown-linux-gnueabihf"]
+# minimal profile: install rustc, cargo, and rust-std
+profile = "minimal"
diff --git a/examples/mnist-rs/ta/Cargo.toml b/examples/mnist-rs/ta/Cargo.toml
new file mode 100644
index 0000000..8303332
--- /dev/null
+++ b/examples/mnist-rs/ta/Cargo.toml
@@ -0,0 +1,51 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+[workspace]
+resolver = "2"
+
+members = [
+    "common",
+    "train",
+    "inference",
+]
+
+[workspace.package]
+edition = "2021"
+license = "Apache-2.0"
+version = "0.4.0"
+repository = "https://github.com/apache/incubator-teaclave-trustzone-sdk.git"
+authors = ["Teaclave Contributors <[email protected]>"]
+
+[workspace.dependencies]
+optee-utee-sys = { path = "../../../optee-utee/optee-utee-sys" }
+optee-utee = { path = "../../../optee-utee" }
+optee-utee-build = { path = "../../../optee-utee-build" }
+
+proto = { path = "../proto" }
+
+bytemuck = { version = "1.21.0", features = ["min_const_generics"] }
+burn = { git = "https://github.com/tracel-ai/burn.git", rev = "a1ca4346424197f42ab1b5bff7f9d1210b029318", default-features = false, features = ["ndarray", "autodiff"] }
+spin = "0.9.8"
+serde = { version = "1.0.218", default-features = false, features = ["derive"] 
}
+serde_json = { version = "1.0.139", default-features = false, features = 
["alloc"] }
+
+
+[profile.release]
+panic = "abort"
+lto = true
+opt-level = 1
diff --git a/ci/ci.sh b/examples/mnist-rs/ta/common/Cargo.toml
old mode 100755
new mode 100644
similarity index 57%
copy from ci/ci.sh
copy to examples/mnist-rs/ta/common/Cargo.toml
index 7556757..66bdb84
--- a/ci/ci.sh
+++ b/examples/mnist-rs/ta/common/Cargo.toml
@@ -1,5 +1,3 @@
-#!/bin/bash
-
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
@@ -17,35 +15,18 @@
 # specific language governing permissions and limitations
 # under the License.
 
-set -xe
-
-pushd ../tests
-
-./test_hello_world.sh
-./test_random.sh
-./test_secure_storage.sh
-./test_aes.sh
-./test_hotp.sh
-./test_acipher.sh
-./test_big_int.sh
-./test_diffie_hellman.sh
-./test_digest.sh
-./test_authentication.sh
-./test_time.sh
-./test_signature_verification.sh
-./test_supp_plugin.sh
-./test_error_handling.sh
-./test_tcp_client.sh
-./test_udp_socket.sh
-
-# Run std only tests
-if [ "$STD" ]; then
-    ./test_serde.sh
-    ./test_message_passing_interface.sh
-    ./test_tls_client.sh
-    ./test_tls_server.sh
-    ./test_eth_wallet.sh
-    ./test_secure_db_abstraction.sh
-fi
+[package]
+name = "common"
+description = "Some common structs and functions."
+publish = false
+version.workspace = true
+authors.workspace = true
+license.workspace = true
+repository.workspace = true
+edition.workspace = true
 
-popd
+[dependencies]
+proto = { workspace = true }
+optee-utee-sys = { workspace = true }
+optee-utee = { workspace = true }
+burn = { workspace = true }
diff --git a/examples/mnist-rs/ta/common/src/lib.rs 
b/examples/mnist-rs/ta/common/src/lib.rs
new file mode 100644
index 0000000..9332788
--- /dev/null
+++ b/examples/mnist-rs/ta/common/src/lib.rs
@@ -0,0 +1,25 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#![no_std]
+extern crate alloc;
+
+mod model;
+mod utils;
+
+pub use model::Model;
+pub use utils::*;
diff --git a/examples/mnist-rs/ta/common/src/model.rs 
b/examples/mnist-rs/ta/common/src/model.rs
new file mode 100644
index 0000000..39213f4
--- /dev/null
+++ b/examples/mnist-rs/ta/common/src/model.rs
@@ -0,0 +1,86 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use alloc::vec::Vec;
+use burn::{
+    prelude::*,
+    record::{FullPrecisionSettings, Recorder, RecorderError},
+};
+use proto::{Image, IMAGE_SIZE, NUM_CLASSES};
+
+/// A single fully-connected layer mapping a flattened image (`IMAGE_SIZE`
+/// inputs) to `NUM_CLASSES` unnormalized class scores.
+///
+/// This is a simple model designed solely to demonstrate how to train and
+/// perform inference; don't use it in production.
+#[derive(Module, Debug)]
+pub struct Model<B: Backend> {
+    linear: nn::Linear<B>,
+}
+
+impl<B: Backend> Model<B> {
+    /// Creates a freshly initialised model on `device`.
+    pub fn new(device: &B::Device) -> Self {
+        Self {
+            linear: nn::LinearConfig::new(IMAGE_SIZE, 
NUM_CLASSES).init(device),
+        }
+    }
+
+    /// Runs the model on a `[batch, IMAGE_SIZE]` input and returns
+    /// `[batch, NUM_CLASSES]` scores.
+    pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
+        self.linear.forward(input)
+    }
+
+    /// Serializes the model parameters with burn's `BinBytesRecorder` at full
+    /// precision.
+    pub fn export(&self) -> Result<Vec<u8>, RecorderError> {
+        let recorder = 
burn::record::BinBytesRecorder::<FullPrecisionSettings>::new();
+        recorder.record(self.clone().into_record(), ())
+    }
+
+    /// Rebuilds a model on `device` from bytes produced by `export`.
+    ///
+    /// Returns the recorder's error if `record` cannot be decoded.
+    pub fn import(device: &B::Device, record: Vec<u8>) -> Result<Self, 
RecorderError> {
+        let recorder = 
burn::record::BinBytesRecorder::<FullPrecisionSettings>::new();
+        let record = recorder.load(record, device)?;
+
+        // Start from a fresh model and overwrite its parameters with the
+        // decoded record.
+        let m = Self::new(device);
+        Ok(m.load_record(record))
+    }
+}
+
+impl<B: Backend> Model<B> {
+    /// Converts one image into a normalized `[1, IMAGE_SIZE]` float tensor.
+    pub fn image_to_tensor(device: &B::Device, image: &Image) -> Tensor<B, 2> {
+        let tensor = 
TensorData::from(image.as_slice()).convert::<B::FloatElem>();
+        let tensor = Tensor::<B, 1>::from_data(tensor, device);
+        let tensor = tensor.reshape([1, IMAGE_SIZE]);
+
+        // Normalize input: scale pixels from [0, 255] to [0, 1], then
+        // standardize so the data has roughly mean 0 and std 1. The values
+        // mean=0.1307, std=0.3081 were copied from the PyTorch MNIST example:
+        // 
https://github.com/pytorch/examples/blob/54f4572509891883a947411fd7239237dd2a39c3/mnist/main.py#L122
+        ((tensor / 255) - 0.1307) / 0.3081
+    }
+
+    /// Converts a batch of images into one `[images.len(), IMAGE_SIZE]` tensor
+    /// by normalizing each image and concatenating along dimension 0.
+    ///
+    /// NOTE(review): `Tensor::cat` on an empty list presumably panics; callers
+    /// should not pass an empty batch — confirm.
+    pub fn images_to_tensors(device: &B::Device, images: &[Image]) -> 
Tensor<B, 2> {
+        let tensors = images
+            .iter()
+            .map(|v| Self::image_to_tensor(device, v))
+            .collect();
+        Tensor::cat(tensors, 0)
+    }
+
+    /// Converts labels into a 1-D integer tensor of length `labels.len()`.
+    ///
+    /// Builds one single-element tensor per label and concatenates them.
+    /// NOTE(review): building one `TensorData` from all labels would avoid
+    /// `labels.len()` small tensors; an empty slice has the same `cat` caveat
+    /// as `images_to_tensors`.
+    pub fn labels_to_tensors(device: &B::Device, labels: &[u8]) -> Tensor<B, 
1, Int> {
+        let targets = labels
+            .iter()
+            .map(|item| {
+                Tensor::<B, 1, Int>::from_data([(*item as 
i64).elem::<B::IntElem>()], device)
+            })
+            .collect();
+        Tensor::cat(targets, 0)
+    }
+}
diff --git a/examples/mnist-rs/ta/common/src/utils.rs 
b/examples/mnist-rs/ta/common/src/utils.rs
new file mode 100644
index 0000000..e80e030
--- /dev/null
+++ b/examples/mnist-rs/ta/common/src/utils.rs
@@ -0,0 +1,35 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use optee_utee::{trace_println, ErrorKind, Parameter, Result};
+
+/// Copies `data` into the memref output parameter `param` and sets the
+/// parameter's updated size to `data.len()`.
+///
+/// Returns `ShortBuffer` if the caller-provided buffer is smaller than
+/// `data`. In that case the updated size is still set to `data.len()`, so the
+/// client learns how large a buffer it needs to retry with.
+/// NOTE(review): `as_memref` is unsafe; this assumes `param` was declared as a
+/// memref (output) parameter by the client — confirm at each call site.
+pub fn copy_to_output(param: &mut Parameter, data: &[u8]) -> Result<()> {
+    let mut output = unsafe { param.as_memref()? };
+
+    let available = output.buffer().len();
+    if available < data.len() {
+        trace_println!(
+            "expect output buffer size {}, got size {} instead",
+            data.len(),
+            available
+        );
+        // GlobalPlatform convention for TEE_ERROR_SHORT_BUFFER: report the
+        // required size back to the client.
+        output.set_updated_size(data.len());
+        return Err(ErrorKind::ShortBuffer.into());
+    }
+    output.buffer()[..data.len()].copy_from_slice(data);
+    output.set_updated_size(data.len());
+    Ok(())
+}
diff --git a/ci/ci.sh b/examples/mnist-rs/ta/inference/Cargo.toml
old mode 100755
new mode 100644
similarity index 57%
copy from ci/ci.sh
copy to examples/mnist-rs/ta/inference/Cargo.toml
index 7556757..b10e306
--- a/ci/ci.sh
+++ b/examples/mnist-rs/ta/inference/Cargo.toml
@@ -1,5 +1,3 @@
-#!/bin/bash
-
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
@@ -17,35 +15,26 @@
 # specific language governing permissions and limitations
 # under the License.
 
-set -xe
-
-pushd ../tests
-
-./test_hello_world.sh
-./test_random.sh
-./test_secure_storage.sh
-./test_aes.sh
-./test_hotp.sh
-./test_acipher.sh
-./test_big_int.sh
-./test_diffie_hellman.sh
-./test_digest.sh
-./test_authentication.sh
-./test_time.sh
-./test_signature_verification.sh
-./test_supp_plugin.sh
-./test_error_handling.sh
-./test_tcp_client.sh
-./test_udp_socket.sh
+[package]
+name = "inference"
+description = "An inference TA of MNIST model."
+publish = false
+version.workspace = true
+authors.workspace = true
+license.workspace = true
+repository.workspace = true
+edition.workspace = true
 
-# Run std only tests
-if [ "$STD" ]; then
-    ./test_serde.sh
-    ./test_message_passing_interface.sh
-    ./test_tls_client.sh
-    ./test_tls_server.sh
-    ./test_eth_wallet.sh
-    ./test_secure_db_abstraction.sh
-fi
+[dependencies]
+common = { path = "../common" }
+proto = { workspace = true }
+optee-utee-sys = { workspace = true }
+optee-utee = { workspace = true }
+burn = { workspace = true }
+serde_json = { workspace = true }
+bytemuck = { workspace = true }
+spin = { workspace = true }
 
-popd
+[build-dependencies]
+proto = { workspace = true }
+optee-utee-build = { workspace = true }
diff --git a/examples/mnist-rs/ta/inference/Makefile 
b/examples/mnist-rs/ta/inference/Makefile
new file mode 100644
index 0000000..1e6f1bb
--- /dev/null
+++ b/examples/mnist-rs/ta/inference/Makefile
@@ -0,0 +1,51 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+UUID ?= $(shell cat "./uuid.txt")
+NAME := inference
+
+TARGET ?= aarch64-unknown-linux-gnu
+CROSS_COMPILE ?= aarch64-linux-gnu-
+OBJCOPY := $(CROSS_COMPILE)objcopy
+# Configure the linker to use GCC, which works on both cross-compilation and ARM machines
+LINKER_CFG := target.$(TARGET).linker=\"$(CROSS_COMPILE)gcc\"
+EXTRA_FLAGS := -Z build-std=core,alloc
+
+TA_SIGN_KEY ?= $(TA_DEV_KIT_DIR)/keys/default_ta.pem
+SIGN := $(TA_DEV_KIT_DIR)/scripts/sign_encrypt.py
+OUT_DIR := $(CURDIR)/../target/$(TARGET)/release
+
+# None of these targets is a file the recipe creates under that name; declare
+# them phony so a stray file called e.g. `clean` or `sign` cannot make them
+# look up to date.
+.PHONY: all ta strip sign clean
+
+ifeq ($(STD),)
+all: ta strip sign
+else
+all:
+	@echo "Please \`unset STD\` then rerun \`source environment\` to build the No-STD version"
+endif
+
+# Build the TA ELF with cargo (core/alloc rebuilt for the no-std target).
+ta:
+	@cargo build --target $(TARGET) --release --config $(LINKER_CFG) $(EXTRA_FLAGS)
+
+# Strip symbols into stripped_$(NAME) next to the cargo output.
+strip: ta
+	@$(OBJCOPY) --strip-unneeded $(OUT_DIR)/$(NAME) $(OUT_DIR)/stripped_$(NAME)
+
+# Sign the stripped ELF into $(OUT_DIR)/$(UUID).ta.
+sign: strip
+	@$(SIGN) --uuid $(UUID) --key $(TA_SIGN_KEY) --in $(OUT_DIR)/stripped_$(NAME) --out $(OUT_DIR)/$(UUID).ta
+	@echo "SIGN =>  ${UUID}"
+
+clean:
+	@cargo clean
diff --git a/examples/mnist-rs/ta/inference/build.rs 
b/examples/mnist-rs/ta/inference/build.rs
new file mode 100644
index 0000000..314c05c
--- /dev/null
+++ b/examples/mnist-rs/ta/inference/build.rs
@@ -0,0 +1,25 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use optee_utee_build::{Error, RustEdition, TaConfig};
+
+/// Generates the OP-TEE user TA header for the *inference* TA.
+fn main() -> Result<(), Error> {
+    // This crate is the inference TA, so its header must carry the inference
+    // UUID — the same `ta/inference/uuid.txt` value the Makefile signs the
+    // binary with. Using `proto::train::UUID` here made the embedded header
+    // disagree with the signed UUID.
+    let config = TaConfig::new_default_with_cargo_env(proto::inference::UUID)?
+        .ta_data_size(1024 * 1024)
+        .ta_stack_size(1024 * 1024);
+    optee_utee_build::build(RustEdition::Before2024, config)
+}
diff --git a/examples/mnist-rs/ta/inference/src/main.rs 
b/examples/mnist-rs/ta/inference/src/main.rs
new file mode 100644
index 0000000..9c23396
--- /dev/null
+++ b/examples/mnist-rs/ta/inference/src/main.rs
@@ -0,0 +1,91 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#![no_std]
+#![no_main]
+extern crate alloc;
+
+use burn::{
+    backend::{ndarray::NdArrayDevice, NdArray},
+    tensor::cast::ToElement,
+};
+
+use common::{copy_to_output, Model};
+use optee_utee::{
+    ta_close_session, ta_create, ta_destroy, ta_invoke_command, 
ta_open_session, trace_println,
+};
+use optee_utee::{ErrorKind, Parameters, Result};
+use proto::Image;
+use spin::Mutex;
+
+// Model type used by this TA: burn's NdArray backend (no autodiff wrapper,
+// since this TA only runs inference).
+type NoStdModel = Model<NdArray>;
+// All tensors in this TA are created on the NdArray CPU device.
+const DEVICE: NdArrayDevice = NdArrayDevice::Cpu;
+// Model loaded by `open_session`; `None` until a session has been opened.
+// This is a single global, so if the TA instance hosts several sessions, each
+// new session replaces the model the others use.
+static MODEL: Mutex<Option<NoStdModel>> = Mutex::new(Option::None);
+
+// TA instance creation hook; only logs; the model is loaded later in
+// `open_session`.
+#[ta_create]
+fn create() -> Result<()> {
+    trace_println!("[+] TA create");
+    Ok(())
+}
+
+// Session open hook: parameter 0 carries a serialized model record (as
+// produced by `Model::export`), which is decoded and stored in `MODEL`.
+// Fails with `BadParameters` if the record cannot be imported.
+#[ta_open_session]
+fn open_session(params: &mut Parameters) -> Result<()> {
+    let mut p0 = unsafe { params.0.as_memref()? };
+
+    let mut model = MODEL.lock();
+    // Replaces any model loaded by an earlier session.
+    model.replace(Model::import(&DEVICE, p0.buffer().to_vec()).map_err(|err| {
+        trace_println!("import failed: {:?}", err);
+        ErrorKind::BadParameters
+    })?);
+
+    Ok(())
+}
+
+// Session close hook; only logs. The model stays in `MODEL` until it is
+// replaced by another `open_session`.
+#[ta_close_session]
+fn close_session() {
+    trace_println!("[+] TA close session");
+}
+
+// TA instance destruction hook; only logs.
+#[ta_destroy]
+fn destroy() {
+    trace_println!("[+] TA destroy");
+}
+
+// Classifies a batch of images.
+//
+// Parameter 0 (input) holds the raw bytes of one or more `Image`s; one `u8`
+// predicted label per image is written to parameter 1 (output). The command
+// ID is ignored. Returns `BadParameters` for an empty input or one whose
+// length is not a whole number of images, and `CorruptObject` if no model has
+// been loaded.
+#[ta_invoke_command]
+fn invoke_command(_cmd_id: u32, params: &mut Parameters) -> Result<()> {
+    trace_println!("[+] TA invoke command");
+    let mut p0 = unsafe { params.0.as_memref()? };
+    // `cast_slice` would panic (aborting the TA) on a length that is not a
+    // multiple of the image size; reject such input instead.
+    let images: &[Image] = bytemuck::try_cast_slice(p0.buffer()).map_err(|err| {
+        trace_println!("invalid input buffer: {:?}", err);
+        ErrorKind::BadParameters
+    })?;
+    // Concatenating zero tensors is not meaningful; refuse an empty batch.
+    if images.is_empty() {
+        trace_println!("empty input batch");
+        return Err(ErrorKind::BadParameters.into());
+    }
+    let input = NoStdModel::images_to_tensors(&DEVICE, images);
+
+    let output = MODEL
+        .lock()
+        .as_ref()
+        .ok_or(ErrorKind::CorruptObject)?
+        .forward(input);
+    // One prediction per row: index of the highest class probability.
+    let result: alloc::vec::Vec<u8> = output
+        .iter_dim(0)
+        .map(|v| {
+            let data = burn::tensor::activation::softmax(v, 1);
+            data.argmax(1).into_scalar().to_u8()
+        })
+        .collect();
+
+    copy_to_output(&mut params.1, &result)
+}
+
+include!(concat!(env!("OUT_DIR"), "/user_ta_header.rs"));
diff --git a/examples/mnist-rs/ta/inference/uuid.txt 
b/examples/mnist-rs/ta/inference/uuid.txt
new file mode 100644
index 0000000..d00c0a4
--- /dev/null
+++ b/examples/mnist-rs/ta/inference/uuid.txt
@@ -0,0 +1 @@
+ff09aa8a-fbb9-4734-ae8c-d7cd1a3f6744
\ No newline at end of file
diff --git a/ci/ci.sh b/examples/mnist-rs/ta/train/Cargo.toml
old mode 100755
new mode 100644
similarity index 56%
copy from ci/ci.sh
copy to examples/mnist-rs/ta/train/Cargo.toml
index 7556757..a0c24f9
--- a/ci/ci.sh
+++ b/examples/mnist-rs/ta/train/Cargo.toml
@@ -1,5 +1,3 @@
-#!/bin/bash
-
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
@@ -17,35 +15,26 @@
 # specific language governing permissions and limitations
 # under the License.
 
-set -xe
-
-pushd ../tests
-
-./test_hello_world.sh
-./test_random.sh
-./test_secure_storage.sh
-./test_aes.sh
-./test_hotp.sh
-./test_acipher.sh
-./test_big_int.sh
-./test_diffie_hellman.sh
-./test_digest.sh
-./test_authentication.sh
-./test_time.sh
-./test_signature_verification.sh
-./test_supp_plugin.sh
-./test_error_handling.sh
-./test_tcp_client.sh
-./test_udp_socket.sh
+[package]
+name = "train"
+description = "A TA that trains an MNIST model."
+publish = false
+version.workspace = true
+authors.workspace = true
+license.workspace = true
+repository.workspace = true
+edition.workspace = true
 
-# Run std only tests
-if [ "$STD" ]; then
-    ./test_serde.sh
-    ./test_message_passing_interface.sh
-    ./test_tls_client.sh
-    ./test_tls_server.sh
-    ./test_eth_wallet.sh
-    ./test_secure_db_abstraction.sh
-fi
+[dependencies]
+common = { path = "../common" }
+proto = { workspace = true }
+optee-utee-sys = { workspace = true }
+optee-utee = { workspace = true }
+burn = { workspace = true }
+spin = { workspace = true }
+serde_json = { workspace = true }
+bytemuck = { workspace = true, features = ["min_const_generics"] }
 
-popd
+[build-dependencies]
+proto = { workspace = true }
+optee-utee-build = { workspace = true }
diff --git a/examples/mnist-rs/ta/train/Makefile 
b/examples/mnist-rs/ta/train/Makefile
new file mode 100644
index 0000000..4b8bbd4
--- /dev/null
+++ b/examples/mnist-rs/ta/train/Makefile
@@ -0,0 +1,59 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Read the UUID once at parse time (`?=` is recursive and would re-run
+# `cat` on every expansion). A UUID set on the command line or in the
+# environment still takes precedence.
+ifeq ($(origin UUID),undefined)
+UUID := $(shell cat "./uuid.txt")
+endif
+NAME := train
+
+TARGET ?= aarch64-unknown-linux-gnu
+CROSS_COMPILE ?= aarch64-linux-gnu-
+OBJCOPY := $(CROSS_COMPILE)objcopy
+# Configure the linker to use GCC, which works on both cross-compilation and ARM machines
+LINKER_CFG := target.$(TARGET).linker=\"$(CROSS_COMPILE)gcc\"
+# Build core and alloc from source for the target; the TA is #![no_std].
+EXTRA_FLAGS := -Z build-std=core,alloc
+
+TA_SIGN_KEY ?= $(TA_DEV_KIT_DIR)/keys/default_ta.pem
+SIGN := $(TA_DEV_KIT_DIR)/scripts/sign_encrypt.py
+OUT_DIR := $(CURDIR)/../target/$(TARGET)/release
+
+# These targets are actions, not files named after themselves.
+.PHONY: all ta strip sign clean
+
+ifeq ($(STD),)
+all: ta strip sign
+else
+all:
+	@echo "Please \`unset STD\` then rerun \`source environment\` to build the No-STD version"
+endif
+
+ta:
+	@cargo build --target $(TARGET) --release --config $(LINKER_CFG) $(EXTRA_FLAGS)
+
+strip: ta
+	@$(OBJCOPY) --strip-unneeded $(OUT_DIR)/$(NAME) $(OUT_DIR)/stripped_$(NAME)
+
+sign: strip
+	@$(SIGN) --uuid $(UUID) --key $(TA_SIGN_KEY) --in $(OUT_DIR)/stripped_$(NAME) --out $(OUT_DIR)/$(UUID).ta
+	@echo "SIGN =>  $(UUID)"
+
+clean:
+	@cargo clean
diff --git a/examples/mnist-rs/ta/train/build.rs 
b/examples/mnist-rs/ta/train/build.rs
new file mode 100644
index 0000000..e60b54b
--- /dev/null
+++ b/examples/mnist-rs/ta/train/build.rs
@@ -0,0 +1,25 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use optee_utee_build::{Error, RustEdition, TaConfig};
+
+/// Build script: configures the training TA's memory sizes and hands the
+/// configuration to `optee_utee_build::build`.
+fn main() -> Result<(), Error> {
+    let uuid = proto::train::UUID;
+    // 5 MiB data size (presumably the TA heap) and 1 MiB of stack.
+    let config = TaConfig::new_default_with_cargo_env(uuid)?
+        .ta_data_size(5 * 1024 * 1024)
+        .ta_stack_size(1024 * 1024);
+    optee_utee_build::build(RustEdition::Before2024, config)
+}
diff --git a/examples/mnist-rs/ta/train/src/main.rs 
b/examples/mnist-rs/ta/train/src/main.rs
new file mode 100644
index 0000000..f3ad57a
--- /dev/null
+++ b/examples/mnist-rs/ta/train/src/main.rs
@@ -0,0 +1,127 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#![no_std]
+#![no_main]
+extern crate alloc;
+
+use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray};
+use common::copy_to_output;
+use optee_utee::{
+    ta_close_session, ta_create, ta_destroy, ta_invoke_command, 
ta_open_session, trace_println,
+};
+use optee_utee::{ErrorKind, Parameters, Result};
+use proto::train::Command;
+use spin::Mutex;
+
+mod trainer;
+
+/// Trainer on the CPU `NdArray` backend wrapped in `Autodiff` for gradients.
+type NoStdTrainer = trainer::Trainer<Autodiff<NdArray>>;
+
+/// The single device all tensors of this TA are created on.
+const DEVICE: NdArrayDevice = NdArrayDevice::Cpu;
+/// Trainer created by `open_session`; `None` until a session is opened.
+static TRAINER: Mutex<Option<NoStdTrainer>> = Mutex::new(None);
+
+/// `#[ta_create]` hook: only logs and reports success.
+#[ta_create]
+fn create() -> Result<()> {
+    let banner = "[+] TA create";
+    trace_println!("{}", banner);
+    Ok(())
+}
+
+/// Opens a session: `params.0` must hold exactly 8 bytes, the little-endian
+/// `f64` learning rate. A fresh trainer built with it replaces `TRAINER`.
+/// Any other buffer length is logged and reported as `BadParameters`.
+#[ta_open_session]
+fn open_session(params: &mut Parameters) -> Result<()> {
+    let mut lr_param = unsafe { params.0.as_memref()? };
+    let lr_bytes: [u8; 8] = lr_param.buffer().try_into().map_err(|err| {
+        trace_println!("bad parameter {:?}", err);
+        ErrorKind::BadParameters
+    })?;
+    let learning_rate = f64::from_le_bytes(lr_bytes);
+    trace_println!("Initialize with learning_rate: {}", learning_rate);
+
+    TRAINER
+        .lock()
+        .replace(NoStdTrainer::new(DEVICE, learning_rate));
+    Ok(())
+}
+
+/// `#[ta_close_session]` hook: only logs; `TRAINER` is left untouched.
+#[ta_close_session]
+fn close_session() {
+    let banner = "[+] TA close session";
+    trace_println!("{}", banner);
+}
+
+/// `#[ta_destroy]` hook: only logs.
+#[ta_destroy]
+fn destroy() {
+    let banner = "[+] TA destroy";
+    trace_println!("{}", banner);
+}
+
+/// Dispatches a `proto::train::Command`:
+/// - `Train`: one optimizer step on the batch given as images in `params.0`
+///   and labels in `params.1`; the JSON-encoded `Output` goes to `params.2`.
+/// - `Valid`: evaluates the same kind of batch without updating the model.
+/// - `Export`: writes the exported model to `params.0`.
+///
+/// Unknown command ids and malformed batches yield `BadParameters`; a
+/// missing trainer (no successful `open_session`) yields `CorruptObject`.
+#[ta_invoke_command]
+fn invoke_command(cmd_id: u32, params: &mut Parameters) -> Result<()> {
+    match Command::try_from(cmd_id) {
+        Ok(Command::Train) => {
+            let mut p0 = unsafe { params.0.as_memref()? };
+            let mut p1 = unsafe { params.1.as_memref()? };
+            let labels: &[u8] = p1.buffer();
+            let images = parse_batch(p0.buffer(), labels)?;
+
+            let mut trainer = TRAINER.lock();
+            let result = trainer
+                .as_mut()
+                .ok_or(ErrorKind::CorruptObject)?
+                .train(images, labels);
+            copy_to_output(&mut params.2, &encode_output(&result)?)
+        }
+        Ok(Command::Valid) => {
+            let mut p0 = unsafe { params.0.as_memref()? };
+            let mut p1 = unsafe { params.1.as_memref()? };
+            let labels: &[u8] = p1.buffer();
+            let images = parse_batch(p0.buffer(), labels)?;
+
+            let trainer = TRAINER.lock();
+            let result = trainer
+                .as_ref()
+                .ok_or(ErrorKind::CorruptObject)?
+                .valid(images, labels);
+            copy_to_output(&mut params.2, &encode_output(&result)?)
+        }
+        Ok(Command::Export) => {
+            let trainer = TRAINER.lock();
+            let result = trainer
+                .as_ref()
+                .ok_or(ErrorKind::CorruptObject)?
+                .export()
+                .map_err(|err| {
+                    trace_println!("unexpected error: {:?}", err);
+                    ErrorKind::BadState
+                })?;
+            copy_to_output(&mut params.0, &result)
+        }
+        Err(_) => Err(ErrorKind::BadParameters.into()),
+    }
+}
+
+/// Reinterprets `images` as a slice of `proto::Image` and checks that it is a
+/// non-empty batch with exactly one label per image.
+///
+/// `bytemuck::cast_slice` would panic (aborting the TA) on a buffer that is not
+/// a whole number of images. Mismatched or empty batches would otherwise reach
+/// the model (an empty batch gives a NaN loss). All of these become `BadParameters`.
+fn parse_batch<'a>(images: &'a [u8], labels: &[u8]) -> Result<&'a [proto::Image]> {
+    let images: &[proto::Image] = bytemuck::try_cast_slice(images).map_err(|err| {
+        trace_println!("bad images buffer: {:?}", err);
+        ErrorKind::BadParameters
+    })?;
+    if images.is_empty() || images.len() != labels.len() {
+        trace_println!(
+            "bad batch: {} images, {} labels",
+            images.len(),
+            labels.len()
+        );
+        return Err(ErrorKind::BadParameters.into());
+    }
+    Ok(images)
+}
+
+/// Serializes a train/valid `Output` as JSON for the host; a serialization
+/// failure is logged and reported as `BadState`.
+fn encode_output(output: &proto::train::Output) -> Result<alloc::vec::Vec<u8>> {
+    serde_json::to_vec(output).map_err(|err| {
+        trace_println!("unexpected error: {:?}", err);
+        ErrorKind::BadState.into()
+    })
+}
+
+// Pulls in the TA header generated into OUT_DIR at build time
+// (presumably by optee_utee_build::build in build.rs — TODO confirm).
+include!(concat!(env!("OUT_DIR"), "/user_ta_header.rs"));
diff --git a/examples/mnist-rs/ta/train/src/trainer.rs 
b/examples/mnist-rs/ta/train/src/trainer.rs
new file mode 100644
index 0000000..5d4b9a8
--- /dev/null
+++ b/examples/mnist-rs/ta/train/src/trainer.rs
@@ -0,0 +1,114 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use alloc::vec::Vec;
+use burn::{
+    module::AutodiffModule,
+    nn::loss::CrossEntropyLoss,
+    optim::{adaptor::OptimizerAdaptor, Adam, AdamConfig, GradientsParams, 
Optimizer},
+    prelude::*,
+    record::RecorderError,
+    tensor::{backend::AutodiffBackend, cast::ToElement},
+};
+use common::Model;
+use proto::{train::Output, Image};
+
+/// Holds the model being trained together with its optimizer and settings.
+pub struct Trainer<B: AutodiffBackend> {
+    /// Current model; replaced by the optimizer's result after each `train`.
+    model: Model<B>,
+    /// Device on which input and target tensors are created.
+    device: B::Device,
+    /// Adam optimizer adapted to `Model<B>`, stored so it persists across `train` calls.
+    optim: OptimizerAdaptor<Adam, Model<B>, B>,
+    /// Learning rate passed to every optimizer step.
+    lr: f64,
+}
+
+impl<B: AutodiffBackend> Trainer<B> {
+    /// Builds a trainer with a new model and a default-configured Adam
+    /// optimizer, using `lr` as the learning rate for every step.
+    ///
+    /// The backend RNG is first reseeded from `optee_utee::Random`, which
+    /// presumably randomizes the initial weights created by `Model::new`.
+    pub fn new(device: B::Device, lr: f64) -> Self {
+        let mut seed_bytes = [0_u8; 8];
+        optee_utee::Random::generate(&mut seed_bytes[..]);
+        B::seed(u64::from_le_bytes(seed_bytes));
+
+        let optim = AdamConfig::new().init();
+        let model = Model::new(&device);
+        Self {
+            model,
+            device,
+            optim,
+            lr,
+        }
+    }
+
+    /// Runs one training step on a batch and returns its loss and accuracy.
+    ///
+    /// Both metrics are measured on the model as it was before this step;
+    /// the optimizer's updated model then replaces `self.model`.
+    ///
+    /// Originally inspired by the burn/examples/custom-training-loop package.
+    /// You may refer to
+    /// https://github.com/tracel-ai/burn/blob/v0.16.0/examples/custom-training-loop
+    /// for details.
+    pub fn train(&mut self, images: &[Image], labels: &[u8]) -> Output {
+        let model = self.model.clone();
+        let inputs = Model::images_to_tensors(&self.device, images);
+        let targets = Model::labels_to_tensors(&self.device, labels);
+
+        let logits = model.forward(inputs);
+        let loss = CrossEntropyLoss::new(None, &logits.device())
+            .forward(logits.clone(), targets.clone());
+        let batch_accuracy = accuracy(logits, targets);
+
+        // Backpropagate, attach the gradients to the model's parameters, and
+        // let the optimizer produce the updated model.
+        let param_grads = GradientsParams::from_grads(loss.backward(), &model);
+        self.model = self.optim.step(self.lr, model, param_grads);
+
+        Output {
+            loss: loss.into_scalar().to_f32(),
+            accuracy: batch_accuracy,
+        }
+    }
+
+    /// Evaluates a batch without touching the model; returns loss and accuracy.
+    ///
+    /// Originally inspired by the burn/examples/custom-training-loop package.
+    /// You may refer to
+    /// https://github.com/tracel-ai/burn/blob/v0.16.0/examples/custom-training-loop
+    /// for details.
+    pub fn valid(&self, images: &[Image], labels: &[u8]) -> Output {
+        // Evaluate on the inner (non-autodiff) backend.
+        let eval_model = self.model.valid();
+        let inputs = Model::images_to_tensors(&self.device, images);
+        let targets = Model::labels_to_tensors(&self.device, labels);
+
+        let logits = eval_model.forward(inputs);
+        let loss = CrossEntropyLoss::new(None, &logits.device())
+            .forward(logits.clone(), targets.clone());
+        let batch_accuracy = accuracy(logits, targets);
+
+        Output {
+            loss: loss.into_scalar().to_f32(),
+            accuracy: batch_accuracy,
+        }
+    }
+
+    /// Serializes the current model via `Model::export`.
+    pub fn export(&self) -> Result<Vec<u8>, RecorderError> {
+        self.model.export()
+    }
+}
+
+// Originally copied from the burn/crates/no-std-tests package. You may refer
+// to https://github.com/tracel-ai/burn/blob/v0.16.0/crates/burn-no-std-test for
+// details.
+/// Percentage (0-100) of rows in `output` whose arg-max equals the matching
+/// entry of `targets`.
+///
+/// An empty batch returns 0.0 instead of NaN (0/0): serde_json encodes
+/// non-finite floats as `null`, which the host could not decode as an `f32`.
+fn accuracy<B: Backend>(output: Tensor<B, 2>, targets: Tensor<B, 1, Int>) -> f32 {
+    let num_predictions: usize = targets.dims().iter().product();
+    if num_predictions == 0 {
+        return 0.0;
+    }
+    let predictions = output.argmax(1).squeeze(1);
+    let num_corrects = predictions.equal(targets).int().sum().into_scalar();
+
+    num_corrects.elem::<f32>() / num_predictions as f32 * 100.0
+}
diff --git a/examples/mnist-rs/ta/train/uuid.txt 
b/examples/mnist-rs/ta/train/uuid.txt
new file mode 100644
index 0000000..1c014e8
--- /dev/null
+++ b/examples/mnist-rs/ta/train/uuid.txt
@@ -0,0 +1 @@
+1b5f5b74-e9cf-4e62-8c3e-7e41da6d76f6
\ No newline at end of file
diff --git a/tests/setup.sh b/tests/setup.sh
index e8a6a2a..6e7329f 100755
--- a/tests/setup.sh
+++ b/tests/setup.sh
@@ -43,6 +43,11 @@ run_in_qemu() {
     sleep 5
 }
 
+run_in_qemu_with_timeout_secs() {
+    screen -S qemu_screen -p 0 -X stuff "$1\n"
+    sleep $2
+}
+
 # Check if the image file exists locally
 if [ ! -d "${IMG}" ]; then
     echo "Image file '${IMG}' not found locally. Downloading from network."
diff --git a/ci/ci.sh b/tests/test_mnist_rs.sh
similarity index 51%
copy from ci/ci.sh
copy to tests/test_mnist_rs.sh
index 7556757..8b66e79 100755
--- a/ci/ci.sh
+++ b/tests/test_mnist_rs.sh
@@ -19,33 +19,30 @@
 
 set -xe
 
-pushd ../tests
+# Include base script
+source setup.sh
 
-./test_hello_world.sh
-./test_random.sh
-./test_secure_storage.sh
-./test_aes.sh
-./test_hotp.sh
-./test_acipher.sh
-./test_big_int.sh
-./test_diffie_hellman.sh
-./test_digest.sh
-./test_authentication.sh
-./test_time.sh
-./test_signature_verification.sh
-./test_supp_plugin.sh
-./test_error_handling.sh
-./test_tcp_client.sh
-./test_udp_socket.sh
+# Copy TA and host binary
+cp ../examples/mnist-rs/ta/target/$TARGET_TA/release/*.ta shared
+cp ../examples/mnist-rs/host/target/$TARGET_HOST/release/mnist-rs shared
+# Copy samples files
+cp -r ../examples/mnist-rs/host/samples shared
 
-# Run std only tests
-if [ "$STD" ]; then
-    ./test_serde.sh
-    ./test_message_passing_interface.sh
-    ./test_tls_client.sh
-    ./test_tls_server.sh
-    ./test_eth_wallet.sh
-    ./test_secure_db_abstraction.sh
-fi
+# Run script specific commands in QEMU
+run_in_qemu "cp *.ta /lib/optee_armtz/\n"
+# Do not export the model due to QEMU's memory limitations.
+# Training needs far longer than run_in_qemu's fixed pause, so wait 300 s.
+run_in_qemu_with_timeout_secs "./mnist-rs train -n 1\n" 300
+# Keep this command on one line: a newline inside the quotes would be typed
+# into the console and run `infer` without its `-i` argument.
+run_in_qemu "./mnist-rs infer -m samples/model.bin -b samples/7.bin -i samples/7.png\n"
+run_in_qemu "^C"
 
-popd
+# Script specific checks
+{
+    grep -q "Train Success" screenlog.0 &&
+    grep -q "Infer Success" screenlog.0
+} || {
+    cat -v screenlog.0
+    cat -v /tmp/serial.log
+    false
+}
+
+rm screenlog.0


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to