On 6/22/23 18:16, Max Chou wrote:
--- a/target/riscv/vcrypto_helper.c
+++ b/target/riscv/vcrypto_helper.c
@@ -22,6 +22,7 @@
 #include "qemu/bitops.h"
 #include "qemu/bswap.h"
 #include "cpu.h"
+#include "crypto/aes.h"
 #include "exec/memop.h"
 #include "exec/exec-all.h"
 #include "exec/helper-proto.h"
@@ -195,3 +196,310 @@ RVVCALL(OPIVX2, vwsll_vx_w, WOP_UUU_W, H8, H4, DO_SLL)
 GEN_VEXT_VX(vwsll_vx_b, 2)
 GEN_VEXT_VX(vwsll_vx_h, 4)
 GEN_VEXT_VX(vwsll_vx_w, 8)
+
+static inline void aes_sub_bytes(uint8_t round_state[4][4])
+{
+    for (int j = 0; j < 16; j++) {
+        round_state[j / 4][j % 4] = AES_sbox[round_state[j / 4][j % 4]];
+    }
+}
+
+static inline void aes_shift_bytes(uint8_t round_state[4][4])
+{
+    uint8_t temp;
+    temp = round_state[0][1];
+    round_state[0][1] = round_state[1][1];
+    round_state[1][1] = round_state[2][1];
+    round_state[2][1] = round_state[3][1];
+    round_state[3][1] = temp;
+    temp = round_state[0][2];
+    round_state[0][2] = round_state[2][2];
+    round_state[2][2] = temp;
+    temp = round_state[1][2];
+    round_state[1][2] = round_state[3][2];
+    round_state[3][2] = temp;
+    temp = round_state[0][3];
+    round_state[0][3] = round_state[3][3];
+    round_state[3][3] = round_state[2][3];
+    round_state[2][3] = round_state[1][3];
+    round_state[1][3] = temp;
+}
+
+static inline void xor_round_key(uint8_t round_state[4][4], uint8_t *round_key)
+{
+    for (int j = 0; j < 16; j++) {
+        round_state[j / 4][j % 4] = round_state[j / 4][j % 4] ^ (round_key)[j];
+    }
+}
+
+static inline void aes_inv_sub_bytes(uint8_t round_state[4][4])
+{
+    for (int j = 0; j < 16; j++) {
+        round_state[j / 4][j % 4] = AES_isbox[round_state[j / 4][j % 4]];
+    }
+}
+
+static inline void aes_inv_shift_bytes(uint8_t round_state[4][4])
+{
+    uint8_t temp;
+    temp = round_state[3][1];
+    round_state[3][1] = round_state[2][1];
+    round_state[2][1] = round_state[1][1];
+    round_state[1][1] = round_state[0][1];
+    round_state[0][1] = temp;
+    temp = round_state[0][2];
+    round_state[0][2] = round_state[2][2];
+    round_state[2][2] = temp;
+    temp = round_state[1][2];
+    round_state[1][2] = round_state[3][2];
+    round_state[3][2] = temp;
+    temp = round_state[0][3];
+    round_state[0][3] = round_state[1][3];
+    round_state[1][3] = round_state[2][3];
+    round_state[2][3] = round_state[3][3];
+    round_state[3][3] = temp;
+}
+
+static inline uint8_t xtime(uint8_t x)
+{
+    return (x << 1) ^ (((x >> 7) & 1) * 0x1b);
+}
+
+static inline uint8_t multiply(uint8_t x, uint8_t y)
+{
+    return (((y & 1) * x) ^ ((y >> 1 & 1) * xtime(x)) ^
+            ((y >> 2 & 1) * xtime(xtime(x))) ^
+            ((y >> 3 & 1) * xtime(xtime(xtime(x)))) ^
+            ((y >> 4 & 1) * xtime(xtime(xtime(xtime(x))))));
+}
+
+static inline void aes_inv_mix_cols(uint8_t round_state[4][4])
+{
+    uint8_t a, b, c, d;
+    for (int j = 0; j < 4; ++j) {
+        a = round_state[j][0];
+        b = round_state[j][1];
+        c = round_state[j][2];
+        d = round_state[j][3];
+        round_state[j][0] = multiply(a, 0x0e) ^ multiply(b, 0x0b) ^
+                            multiply(c, 0x0d) ^ multiply(d, 0x09);
+        round_state[j][1] = multiply(a, 0x09) ^ multiply(b, 0x0e) ^
+                            multiply(c, 0x0b) ^ multiply(d, 0x0d);
+        round_state[j][2] = multiply(a, 0x0d) ^ multiply(b, 0x09) ^
+                            multiply(c, 0x0e) ^ multiply(d, 0x0b);
+        round_state[j][3] = multiply(a, 0x0b) ^ multiply(b, 0x0d) ^
+                            multiply(c, 0x09) ^ multiply(d, 0x0e);
+    }
+}
+
+static inline void aes_mix_cols(uint8_t round_state[4][4])
+{
+    uint8_t a, b;
+    for (int j = 0; j < 4; ++j) {
+        a = round_state[j][0];
+        b = round_state[j][0] ^ round_state[j][1] ^ round_state[j][2] ^
+            round_state[j][3];
+        round_state[j][0] ^= xtime(round_state[j][0] ^ round_state[j][1]) ^ b;
+        round_state[j][1] ^= xtime(round_state[j][1] ^ round_state[j][2]) ^ b;
+        round_state[j][2] ^= xtime(round_state[j][2] ^ round_state[j][3]) ^ b;
+        round_state[j][3] ^= xtime(round_state[j][3] ^ a) ^ b;
+    }
+}
+
+#define GEN_ZVKNED_HELPER_VV(NAME, ...)                                   \
+    void HELPER(NAME)(void *vd_vptr, void *vs2_vptr, CPURISCVState *env, \
+                      uint32_t desc)                                      \
+    {                                                                     \
+        uint64_t *vd = vd_vptr;                                           \
+        uint64_t *vs2 = vs2_vptr;                                         \
+        uint32_t vl = env->vl;                                            \
+        uint32_t total_elems = vext_get_total_elems(env, desc, 4);        \
+        uint32_t vta = vext_vta(desc);                                    \
+                                                                          \
+        for (uint32_t i = env->vstart / 4; i < env->vl / 4; i++) {        \
+            uint64_t round_key[2] = {                                     \
+                cpu_to_le64(vs2[i * 2 + 0]),                              \
+                cpu_to_le64(vs2[i * 2 + 1]),                              \
+            };                                                            \
+            uint8_t round_state[4][4];                                    \
+            cpu_to_le64s(vd + i * 2 + 0);                                 \
+            cpu_to_le64s(vd + i * 2 + 1);                                 \
+            for (int j = 0; j < 16; j++) {                                \
+                round_state[j / 4][j % 4] = ((uint8_t *)(vd + i * 2))[j]; \
+            }                                                             \
+            __VA_ARGS__;                                                  \
+            for (int j = 0; j < 16; j++) {                                \
+                ((uint8_t *)(vd + i * 2))[j] = round_state[j / 4][j % 4]; \
+            }                                                             \
+            le64_to_cpus(vd + i * 2 + 0);                                 \
+            le64_to_cpus(vd + i * 2 + 1);                                 \
+        }                                                                 \
+        env->vstart = 0;                                                  \
+        /* set tail elements to 1s */                                     \
+        vext_set_elems_1s(vd, vta, vl * 4, total_elems * 4);              \
+    }
+
+#define GEN_ZVKNED_HELPER_VS(NAME, ...)                                   \
+    void HELPER(NAME)(void *vd_vptr, void *vs2_vptr, CPURISCVState *env, \
+                      uint32_t desc)                                      \
+    {                                                                     \
+        uint64_t *vd = vd_vptr;                                           \
+        uint64_t *vs2 = vs2_vptr;                                         \
+        uint32_t vl = env->vl;                                            \
+        uint32_t total_elems = vext_get_total_elems(env, desc, 4);        \
+        uint32_t vta = vext_vta(desc);                                    \
+                                                                          \
+        for (uint32_t i = env->vstart / 4; i < env->vl / 4; i++) {        \
+            uint64_t round_key[2] = {                                     \
+                cpu_to_le64(vs2[0]),                                      \
+                cpu_to_le64(vs2[1]),                                      \
+            };                                                            \
+            uint8_t round_state[4][4];                                    \
+            cpu_to_le64s(vd + i * 2 + 0);                                 \
+            cpu_to_le64s(vd + i * 2 + 1);                                 \
+            for (int j = 0; j < 16; j++) {                                \
+                round_state[j / 4][j % 4] = ((uint8_t *)(vd + i * 2))[j]; \
+            }                                                             \
+            __VA_ARGS__;                                                  \
+            for (int j = 0; j < 16; j++) {                                \
+                ((uint8_t *)(vd + i * 2))[j] = round_state[j / 4][j % 4]; \
+            }                                                             \
+            le64_to_cpus(vd + i * 2 + 0);                                 \
+            le64_to_cpus(vd + i * 2 + 1);                                 \
+        }                                                                 \
+        env->vstart = 0;                                                  \
+        /* set tail elements to 1s */                                     \
+        vext_set_elems_1s(vd, vta, vl * 4, total_elems * 4);              \
+    }
See https://lore.kernel.org/qemu-devel/20230620110758.787479-1-richard.hender...@linaro.org/, which should greatly simplify all of this.
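For comparison, a rough sketch of what one of these helpers could collapse to on top of that series. Illustration only: the AESState type and the aesenc_SB_SR_MC_AK() round primitive are assumed from the linked patches (the names and signatures there are authoritative, not this sketch), vaesem_one_group() is a made-up helper name, and the cross-endian byte swapping done in the patch above is glossed over:

  #include "crypto/aes-round.h"

  /*
   * Illustrative only: one 128-bit element group of vaesem.vv via the
   * combined-round primitive from the linked series, replacing the
   * open-coded aes_sub_bytes/aes_shift_bytes/aes_mix_cols/xor_round_key
   * sequence above.
   */
  static void vaesem_one_group(uint64_t *vd, const uint64_t *vs2, uint32_t i)
  {
      AESState state, key;

      /* Load one element group of round state and round key. */
      state.d[0] = vd[i * 2 + 0];
      state.d[1] = vd[i * 2 + 1];
      key.d[0] = vs2[i * 2 + 0];
      key.d[1] = vs2[i * 2 + 1];

      /*
       * SubBytes + ShiftRows + MixColumns + AddRoundKey in a single
       * call; false selects the little-endian element layout.
       */
      aesenc_SB_SR_MC_AK(&state, &state, &key, false);

      vd[i * 2 + 0] = state.d[0];
      vd[i * 2 + 1] = state.d[1];
  }

A nice side effect is that the host-accelerated implementations from that series would apply automatically, instead of the byte-at-a-time tables here.

r~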