On 6/22/23 18:16, Max Chou wrote:
--- a/target/riscv/vcrypto_helper.c
+++ b/target/riscv/vcrypto_helper.c
@@ -22,6 +22,7 @@
  #include "qemu/bitops.h"
  #include "qemu/bswap.h"
  #include "cpu.h"
+#include "crypto/aes.h"
  #include "exec/memop.h"
  #include "exec/exec-all.h"
  #include "exec/helper-proto.h"
@@ -195,3 +196,310 @@ RVVCALL(OPIVX2, vwsll_vx_w, WOP_UUU_W, H8, H4, DO_SLL)
  GEN_VEXT_VX(vwsll_vx_b, 2)
  GEN_VEXT_VX(vwsll_vx_h, 4)
  GEN_VEXT_VX(vwsll_vx_w, 8)
+/*
+ * SubBytes step: replace each of the 16 state bytes with its AES_sbox
+ * entry. The flat index j visits round_state[j / 4][j % 4], i.e. the
+ * bytes in memory order.
+ */
+static inline void aes_sub_bytes(uint8_t round_state[4][4])
+{
+    for (int j = 0; j < 16; j++) {
+        round_state[j / 4][j % 4] = AES_sbox[round_state[j / 4][j % 4]];
+    }
+}
+/*
+ * ShiftRows step. The second index selects the row and the first index
+ * the position within it (presumably the AES column -- confirm against
+ * the byte layout used by the callers). Row 0 is left untouched; rows
+ * 1, 2 and 3 are rotated by 1, 2 and 3 positions respectively.
+ * aes_inv_shift_bytes() applies the opposite rotations.
+ */
+static inline void aes_shift_bytes(uint8_t round_state[4][4])
+{
+    uint8_t temp;
+    temp = round_state[0][1]; /* row 1: [k] <- [k + 1] */
+    round_state[0][1] = round_state[1][1];
+    round_state[1][1] = round_state[2][1];
+    round_state[2][1] = round_state[3][1];
+    round_state[3][1] = temp;
+    temp = round_state[0][2]; /* row 2: swap [0] with [2] ... */
+    round_state[0][2] = round_state[2][2];
+    round_state[2][2] = temp;
+    temp = round_state[1][2]; /* ... and [1] with [3] */
+    round_state[1][2] = round_state[3][2];
+    round_state[3][2] = temp;
+    temp = round_state[0][3]; /* row 3: [k] <- [k + 3], i.e. [k - 1] */
+    round_state[0][3] = round_state[3][3];
+    round_state[3][3] = round_state[2][3];
+    round_state[2][3] = round_state[1][3];
+    round_state[1][3] = temp;
+}
+/*
+ * AddRoundKey step: XOR the 16 state bytes with round_key[0..15].
+ * Key byte j is applied to round_state[j / 4][j % 4], so round_key must
+ * point to at least 16 readable bytes laid out in the same order as the
+ * state.
+ */
+static inline void xor_round_key(uint8_t round_state[4][4], uint8_t *round_key)
+{
+    for (int j = 0; j < 16; j++) {
+        round_state[j / 4][j % 4] = round_state[j / 4][j % 4] ^ (round_key)[j];
+    }
+}
+/*
+ * InvSubBytes step: replace each of the 16 state bytes with its
+ * AES_isbox entry. Same traversal as aes_sub_bytes(), using the inverse
+ * table.
+ */
+static inline void aes_inv_sub_bytes(uint8_t round_state[4][4])
+{
+    for (int j = 0; j < 16; j++) {
+        round_state[j / 4][j % 4] = AES_isbox[round_state[j / 4][j % 4]];
+    }
+}
+/*
+ * InvShiftRows step: undo aes_shift_bytes(). The second index selects
+ * the row and the first index the position within it. Row 0 is left
+ * untouched; rows 1 and 3 are rotated in the opposite direction to
+ * aes_shift_bytes(), and row 2 uses the same pair of swaps, which is its
+ * own inverse.
+ */
+static inline void aes_inv_shift_bytes(uint8_t round_state[4][4])
+{
+    uint8_t temp;
+    temp = round_state[3][1]; /* row 1: [k] <- [k - 1] */
+    round_state[3][1] = round_state[2][1];
+    round_state[2][1] = round_state[1][1];
+    round_state[1][1] = round_state[0][1];
+    round_state[0][1] = temp;
+    temp = round_state[0][2]; /* row 2: swap [0] with [2] ... */
+    round_state[0][2] = round_state[2][2];
+    round_state[2][2] = temp;
+    temp = round_state[1][2]; /* ... and [1] with [3] */
+    round_state[1][2] = round_state[3][2];
+    round_state[3][2] = temp;
+    temp = round_state[0][3]; /* row 3: [k] <- [k + 1] */
+    round_state[0][3] = round_state[1][3];
+    round_state[1][3] = round_state[2][3];
+    round_state[2][3] = round_state[3][3];
+    round_state[3][3] = temp;
+}
+/*
+ * Multiply x by 2 in GF(2^8): shift left by one bit and, when bit 7 of
+ * x was set, XOR in 0x1b to reduce the result. The uint8_t return type
+ * drops the bit shifted out of the byte.
+ */
+static inline uint8_t xtime(uint8_t x)
+{
+    return (x << 1) ^ (((x >> 7) & 1) * 0x1b);
+}
+/*
+ * Multiply x by y in GF(2^8) by XOR-ing together x * 2^k (computed with
+ * repeated xtime() calls) for every set bit k of y.
+ *
+ * Only bits 0..4 of y are examined; higher bits are ignored. The callers
+ * visible here pass constants no larger than 0x0e.
+ */
+static inline uint8_t multiply(uint8_t x, uint8_t y)
+{
+    return (((y & 1) * x) ^ ((y >> 1 & 1) * xtime(x)) ^
+            ((y >> 2 & 1) * xtime(xtime(x))) ^
+            ((y >> 3 & 1) * xtime(xtime(xtime(x)))) ^
+            ((y >> 4 & 1) * xtime(xtime(xtime(xtime(x))))));
+}
+/*
+ * InvMixColumns step: for each j, treat round_state[j][0..3] as one
+ * 4-byte column and replace it with its product by the circulant matrix
+ * whose first row is { 0x0e, 0x0b, 0x0d, 0x09 }, using GF(2^8)
+ * arithmetic from multiply().
+ */
+static inline void aes_inv_mix_cols(uint8_t round_state[4][4])
+{
+    uint8_t a, b, c, d;
+    for (int j = 0; j < 4; ++j) {
+        /* Snapshot the column: every output byte needs all four inputs. */
+        a = round_state[j][0];
+        b = round_state[j][1];
+        c = round_state[j][2];
+        d = round_state[j][3];
+        round_state[j][0] = multiply(a, 0x0e) ^ multiply(b, 0x0b) ^
+                            multiply(c, 0x0d) ^ multiply(d, 0x09);
+        round_state[j][1] = multiply(a, 0x09) ^ multiply(b, 0x0e) ^
+                            multiply(c, 0x0b) ^ multiply(d, 0x0d);
+        round_state[j][2] = multiply(a, 0x0d) ^ multiply(b, 0x09) ^
+                            multiply(c, 0x0e) ^ multiply(d, 0x0b);
+        round_state[j][3] = multiply(a, 0x0b) ^ multiply(b, 0x0d) ^
+                            multiply(c, 0x09) ^ multiply(d, 0x0e);
+    }
+}
+/*
+ * MixColumns step: for each j, treat round_state[j][0..3] as one 4-byte
+ * column s0..s3 and update it in place.
+ *
+ * With b = s0 ^ s1 ^ s2 ^ s3, each byte becomes
+ *     s[k] ^ xtime(s[k] ^ s[k + 1]) ^ b
+ * which equals 2*s[k] ^ 3*s[k + 1] ^ s[k + 2] ^ s[k + 3] in GF(2^8)
+ * (indices taken modulo 4).
+ */
+static inline void aes_mix_cols(uint8_t round_state[4][4])
+{
+    uint8_t a, b;
+    for (int j = 0; j < 4; ++j) {
+        a = round_state[j][0]; /* saved: byte 0 is overwritten before use */
+        b = round_state[j][0] ^ round_state[j][1] ^ round_state[j][2] ^
+            round_state[j][3];
+        round_state[j][0] ^= xtime(round_state[j][0] ^ round_state[j][1]) ^ b;
+        round_state[j][1] ^= xtime(round_state[j][1] ^ round_state[j][2]) ^ b;
+        round_state[j][2] ^= xtime(round_state[j][2] ^ round_state[j][3]) ^ b;
+        round_state[j][3] ^= xtime(round_state[j][3] ^ a) ^ b;
+    }
+}
+/*
+ * Define HELPER(NAME), taking (vd, vs2, env, desc), which updates vd in
+ * 16-byte groups using a separate round key per group taken from vs2.
+ *
+ * For each group index i in [vstart / 4, vl / 4):
+ *   - round_key[] holds vs2[2 * i] and vs2[2 * i + 1] passed through
+ *     cpu_to_le64();
+ *   - the matching two 64-bit words of vd are converted in place with
+ *     cpu_to_le64s() and copied byte by byte into
+ *     round_state[j / 4][j % 4];
+ *   - __VA_ARGS__ is expanded as the operation body and may refer to
+ *     round_state and round_key;
+ *   - round_state is copied back into vd and the two words are converted
+ *     back with le64_to_cpus().
+ *
+ * Afterwards vstart is cleared and vext_set_elems_1s() is called for the
+ * byte range [vl * 4, total_elems * 4) with the vta value from desc.
+ * NOTE(review): the factor 4 is presumably the element size in bytes
+ * (32-bit elements, four per group) -- confirm against the translator.
+ */
+#define GEN_ZVKNED_HELPER_VV(NAME, ...)                                   \
+    void HELPER(NAME)(void *vd_vptr, void *vs2_vptr, CPURISCVState *env,  \
+                      uint32_t desc)                                      \
+    {                                                                     \
+        uint64_t *vd = vd_vptr;                                           \
+        uint64_t *vs2 = vs2_vptr;                                         \
+        uint32_t vl = env->vl;                                            \
+        uint32_t total_elems = vext_get_total_elems(env, desc, 4);        \
+        uint32_t vta = vext_vta(desc);                                    \
+                                                                          \
+        for (uint32_t i = env->vstart / 4; i < env->vl / 4; i++) {        \
+            uint64_t round_key[2] = {                                     \
+                cpu_to_le64(vs2[i * 2 + 0]),                              \
+                cpu_to_le64(vs2[i * 2 + 1]),                              \
+            };                                                            \
+            uint8_t round_state[4][4];                                    \
+            cpu_to_le64s(vd + i * 2 + 0);                                 \
+            cpu_to_le64s(vd + i * 2 + 1);                                 \
+            for (int j = 0; j < 16; j++) {                                \
+                round_state[j / 4][j % 4] = ((uint8_t *)(vd + i * 2))[j]; \
+            }                                                             \
+            __VA_ARGS__;                                                  \
+            for (int j = 0; j < 16; j++) {                                \
+                ((uint8_t *)(vd + i * 2))[j] = round_state[j / 4][j % 4]; \
+            }                                                             \
+            le64_to_cpus(vd + i * 2 + 0);                                 \
+            le64_to_cpus(vd + i * 2 + 1);                                 \
+        }                                                                 \
+        env->vstart = 0;                                                  \
+        /* set tail elements to 1s */                                     \
+        vext_set_elems_1s(vd, vta, vl * 4, total_elems * 4);              \
+    }
+/*
+ * Define HELPER(NAME), taking (vd, vs2, env, desc), which updates vd in
+ * 16-byte groups using one shared round key.
+ *
+ * Identical to GEN_ZVKNED_HELPER_VV except for the key source: every
+ * group index i in [vstart / 4, vl / 4) uses vs2[0] and vs2[1] (passed
+ * through cpu_to_le64()) as round_key[], instead of the words of vs2 at
+ * the group's own index.
+ *
+ * __VA_ARGS__ is expanded as the operation body and may refer to
+ * round_state and round_key. Afterwards vstart is cleared and
+ * vext_set_elems_1s() is called for the byte range
+ * [vl * 4, total_elems * 4) with the vta value from desc.
+ */
+#define GEN_ZVKNED_HELPER_VS(NAME, ...)                                   \
+    void HELPER(NAME)(void *vd_vptr, void *vs2_vptr, CPURISCVState *env,  \
+                      uint32_t desc)                                      \
+    {                                                                     \
+        uint64_t *vd = vd_vptr;                                           \
+        uint64_t *vs2 = vs2_vptr;                                         \
+        uint32_t vl = env->vl;                                            \
+        uint32_t total_elems = vext_get_total_elems(env, desc, 4);        \
+        uint32_t vta = vext_vta(desc);                                    \
+                                                                          \
+        for (uint32_t i = env->vstart / 4; i < env->vl / 4; i++) {        \
+            uint64_t round_key[2] = {                                     \
+                cpu_to_le64(vs2[0]),                                      \
+                cpu_to_le64(vs2[1]),                                      \
+            };                                                            \
+            uint8_t round_state[4][4];                                    \
+            cpu_to_le64s(vd + i * 2 + 0);                                 \
+            cpu_to_le64s(vd + i * 2 + 1);                                 \
+            for (int j = 0; j < 16; j++) {                                \
+                round_state[j / 4][j % 4] = ((uint8_t *)(vd + i * 2))[j]; \
+            }                                                             \
+            __VA_ARGS__;                                                  \
+            for (int j = 0; j < 16; j++) {                                \
+                ((uint8_t *)(vd + i * 2))[j] = round_state[j / 4][j % 4]; \
+            }                                                             \
+            le64_to_cpus(vd + i * 2 + 0);                                 \
+            le64_to_cpus(vd + i * 2 + 1);                                 \
+        }                                                                 \
+        env->vstart = 0;                                                  \
+        /* set tail elements to 1s */                                     \
+        vext_set_elems_1s(vd, vta, vl * 4, total_elems * 4);              \
+    }

See

https://lore.kernel.org/qemu-devel/20230620110758.787479-1-richard.henderson@linaro.org/

which should greatly simplify all of this.


r~

Reply via email to