Module: Mesa Branch: main Commit: fe956d0182aeb54d03fdc69711ceae15cc29168a URL: http://cgit.freedesktop.org/mesa/mesa/commit/?id=fe956d0182aeb54d03fdc69711ceae15cc29168a
Author: Ian Romanick <[email protected]> Date: Mon Jun 14 14:12:36 2021 -0700 spirv: Add support for SPV_KHR_integer_dot_product v2 (Ivan): Add missing capability enum handling. v3 (idr): Properly handle cases where dest_size != 32. v4 (idr): Rewrite most of the error checking to use vtn_fail_if. Use nir_ssa_def with vtn_push_nir_ssa instead of vtn_ssa_value with vtn_push_ssa_value. All suggested by Jason. Massive rewrite of the handling of packed 4x8 saturating opcodes. Based on some observations made by Jason. v5 (idr): Remove some debugging cruft accidentally added in v4. Noticed by Jason. v6: Emit packed versions of vectored instructions when possible. Suggested by Jason. Reviewed-by: Jason Ekstrand <[email protected]> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/12142> --- src/compiler/spirv/spirv_to_nir.c | 13 ++ src/compiler/spirv/vtn_alu.c | 272 ++++++++++++++++++++++++++++++++++++++ src/compiler/spirv/vtn_private.h | 3 + 3 files changed, 288 insertions(+) diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c index 983d8f9f06f..a64039f9469 100644 --- a/src/compiler/spirv/spirv_to_nir.c +++ b/src/compiler/spirv/spirv_to_nir.c @@ -4364,6 +4364,10 @@ vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode, case SpvCapabilityImageGatherExtended: case SpvCapabilityStorageImageExtendedFormats: case SpvCapabilityVector16: + case SpvCapabilityDotProductKHR: + case SpvCapabilityDotProductInputAllKHR: + case SpvCapabilityDotProductInput4x8BitKHR: + case SpvCapabilityDotProductInput4x8BitPackedKHR: break; case SpvCapabilityLinkage: @@ -5650,6 +5654,15 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode, vtn_handle_alu(b, opcode, w, count); break; + case SpvOpSDotKHR: + case SpvOpUDotKHR: + case SpvOpSUDotKHR: + case SpvOpSDotAccSatKHR: + case SpvOpUDotAccSatKHR: + case SpvOpSUDotAccSatKHR: + vtn_handle_integer_dot(b, opcode, w, count); + break; + case SpvOpBitcast: vtn_handle_bitcast(b, w, 
count); break; diff --git a/src/compiler/spirv/vtn_alu.c b/src/compiler/spirv/vtn_alu.c index 48f41ac249a..ed731184d2d 100644 --- a/src/compiler/spirv/vtn_alu.c +++ b/src/compiler/spirv/vtn_alu.c @@ -765,6 +765,14 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode, break; } + case SpvOpSDotKHR: + case SpvOpUDotKHR: + case SpvOpSUDotKHR: + case SpvOpSDotAccSatKHR: + case SpvOpUDotAccSatKHR: + case SpvOpSUDotAccSatKHR: + unreachable("Should have called vtn_handle_integer_dot instead."); + default: { bool swap; bool exact; @@ -823,6 +831,270 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode, b->nb.exact = b->exact; } +void +vtn_handle_integer_dot(struct vtn_builder *b, SpvOp opcode, + const uint32_t *w, unsigned count) +{ + struct vtn_value *dest_val = vtn_untyped_value(b, w[2]); /* w[2]: Result id */ + const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type; /* w[1]: Result Type */ + const unsigned dest_size = glsl_get_bit_size(dest_type); + + vtn_handle_no_contraction(b, dest_val); + + /* Collect the various SSA sources (operands w[3] through w[num_inputs + 2]). + * + * Due to the optional "Packed Vector Format" field, determine number of + * inputs from the opcode. This differs from vtn_handle_alu. + */ + const unsigned num_inputs = (opcode == SpvOpSDotAccSatKHR || + opcode == SpvOpUDotAccSatKHR || + opcode == SpvOpSUDotAccSatKHR) ? 3 : 2; + + vtn_assert(count >= num_inputs + 3); + + struct vtn_ssa_value *vtn_src[3] = { NULL, }; + nir_ssa_def *src[3] = { NULL, }; + + for (unsigned i = 0; i < num_inputs; i++) { + vtn_src[i] = vtn_ssa_value(b, w[i + 3]); + src[i] = vtn_src[i]->def; + + vtn_assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type)); + } + + /* For all of the opcodes *except* SpvOpSUDotKHR and SpvOpSUDotAccSatKHR, + * the SPV_KHR_integer_dot_product spec says: + * + * _Vector 1_ and _Vector 2_ must have the same type. + * + * The practical requirement is the same bit-size and the same number of + * components. This check is applied to every opcode, including the SU variants.
+ */ + vtn_fail_if(glsl_get_bit_size(vtn_src[0]->type) != + glsl_get_bit_size(vtn_src[1]->type) || + glsl_get_vector_elements(vtn_src[0]->type) != + glsl_get_vector_elements(vtn_src[1]->type), + "Vector 1 and vector 2 source of opcode %s must have the same " + "type", + spirv_op_to_string(opcode)); + + if (num_inputs == 3) { + /* The SPV_KHR_integer_dot_product spec says: + * + * The type of Accumulator must be the same as Result Type. + * + * The handling of SpvOpSDotAccSatKHR and friends with the packed 4x8 + * types (far below) assumes these types have the same size. + */ + vtn_fail_if(dest_type != vtn_src[2]->type, + "Accumulator type must be the same as Result Type for " + "opcode %s", + spirv_op_to_string(opcode)); + } + + if (glsl_type_is_vector(vtn_src[0]->type)) { + /* FINISHME: Is this actually as good or better for platforms that don't + * have the special instructions (i.e., one or both of has_dot_4x8 or + * has_sudot_4x8 is false)? + */ + if (glsl_get_vector_elements(vtn_src[0]->type) == 4 && + glsl_get_bit_size(vtn_src[0]->type) == 8 && + glsl_get_bit_size(dest_type) <= 32) { + src[0] = nir_pack_32_4x8(&b->nb, src[0]); + src[1] = nir_pack_32_4x8(&b->nb, src[1]); /* now single 32-bit values, so the packed 4x8 branch below handles them */ + } + } else if (glsl_type_is_scalar(vtn_src[0]->type) && + glsl_type_is_32bit(vtn_src[0]->type)) { + /* The SPV_KHR_integer_dot_product spec says: + * + * When _Vector 1_ and _Vector 2_ are scalar integer types, _Packed + * Vector Format_ must be specified to select how the integers are to + * be interpreted as vectors. + * + * The "Packed Vector Format" value follows the last input, i.e. it is w[num_inputs + 3].
+ */ + vtn_assert(count == (num_inputs + 4)); + const SpvPackedVectorFormat pack_format = w[num_inputs + 3]; + vtn_fail_if(pack_format != SpvPackedVectorFormatPackedVectorFormat4x8BitKHR, + "Unsupported vector packing format %d for opcode %s", + pack_format, spirv_op_to_string(opcode)); + } else { + vtn_fail_with_opcode("Invalid source types.", opcode); + } + + nir_ssa_def *dest = NULL; + + if (src[0]->num_components > 1) { + const nir_op s_conversion_op = + nir_type_conversion_op(nir_type_int, nir_type_int | dest_size, + nir_rounding_mode_undef); + + const nir_op u_conversion_op = + nir_type_conversion_op(nir_type_uint, nir_type_uint | dest_size, + nir_rounding_mode_undef); + + nir_op src0_conversion_op; + nir_op src1_conversion_op; + + switch (opcode) { + case SpvOpSDotKHR: + case SpvOpSDotAccSatKHR: + src0_conversion_op = s_conversion_op; + src1_conversion_op = s_conversion_op; + break; + + case SpvOpUDotKHR: + case SpvOpUDotAccSatKHR: + src0_conversion_op = u_conversion_op; + src1_conversion_op = u_conversion_op; + break; + + case SpvOpSUDotKHR: + case SpvOpSUDotAccSatKHR: + src0_conversion_op = s_conversion_op; + src1_conversion_op = u_conversion_op; + break; + + default: + unreachable("Invalid opcode."); + } + + /* The SPV_KHR_integer_dot_product spec says: + * + * All components of the input vectors are sign-extended to the bit + * width of the result's type. The sign-extended input vectors are + * then multiplied component-wise and all components of the vector + * resulting from the component-wise multiplication are added + * together. The resulting value will equal the low-order N bits of + * the correct result R, where N is the result width and R is + * computed with enough precision to avoid overflow and underflow.
+ * For UDot (both operands) and SUDot (_Vector 2_ only), u_conversion_op is used instead, so those operands are converted as unsigned rather than sign-extended. */ + const unsigned vector_components = + glsl_get_vector_elements(vtn_src[0]->type); + + for (unsigned i = 0; i < vector_components; i++) { + nir_ssa_def *const src0 = + nir_build_alu(&b->nb, src0_conversion_op, + nir_channel(&b->nb, src[0], i), NULL, NULL, NULL); + + nir_ssa_def *const src1 = + nir_build_alu(&b->nb, src1_conversion_op, + nir_channel(&b->nb, src[1], i), NULL, NULL, NULL); + + nir_ssa_def *const mul_result = nir_imul(&b->nb, src0, src1); + + dest = (i == 0) ? mul_result : nir_iadd(&b->nb, dest, mul_result); + } + + if (num_inputs == 3) { + /* For SpvOpSDotAccSatKHR, the SPV_KHR_integer_dot_product spec says: + * + * Signed integer dot product of _Vector 1_ and _Vector 2_ and + * signed saturating addition of the result with _Accumulator_. + * + * For SpvOpUDotAccSatKHR, the SPV_KHR_integer_dot_product spec says: + * + * Unsigned integer dot product of _Vector 1_ and _Vector 2_ and + * unsigned saturating addition of the result with _Accumulator_. + * + * For SpvOpSUDotAccSatKHR, the SPV_KHR_integer_dot_product spec says: + * + * Mixed-signedness integer dot product of _Vector 1_ and _Vector + * 2_ and signed saturating addition of the result with + * _Accumulator_. + */ + dest = (opcode == SpvOpUDotAccSatKHR) + ?
+ nir_uadd_sat(&b->nb, dest, src[2]) + : nir_iadd_sat(&b->nb, dest, src[2]); + } + } else { + assert(src[0]->num_components == 1 && src[1]->num_components == 1); + assert(src[0]->bit_size == 32 && src[1]->bit_size == 32); + + nir_ssa_def *const zero = nir_imm_zero(&b->nb, 1, 32); /* addend for the non-saturating 4x8 ops below */ + bool is_signed; /* selects i2i/iadd_sat vs. u2u/uadd_sat in the dest_size != 32 fix-up */ + + switch (opcode) { + case SpvOpSDotKHR: + dest = nir_sdot_4x8_iadd(&b->nb, src[0], src[1], zero); + is_signed = true; + break; + + case SpvOpUDotKHR: + dest = nir_udot_4x8_uadd(&b->nb, src[0], src[1], zero); + is_signed = false; + break; + + case SpvOpSUDotKHR: + dest = nir_sudot_4x8_iadd(&b->nb, src[0], src[1], zero); + is_signed = true; + break; + + case SpvOpSDotAccSatKHR: + if (dest_size == 32) + dest = nir_sdot_4x8_iadd_sat(&b->nb, src[0], src[1], src[2]); + else + dest = nir_sdot_4x8_iadd(&b->nb, src[0], src[1], zero); + + is_signed = true; + break; + + case SpvOpUDotAccSatKHR: + if (dest_size == 32) + dest = nir_udot_4x8_uadd_sat(&b->nb, src[0], src[1], src[2]); + else + dest = nir_udot_4x8_uadd(&b->nb, src[0], src[1], zero); + + is_signed = false; + break; + + case SpvOpSUDotAccSatKHR: + if (dest_size == 32) + dest = nir_sudot_4x8_iadd_sat(&b->nb, src[0], src[1], src[2]); + else + dest = nir_sudot_4x8_iadd(&b->nb, src[0], src[1], zero); + + is_signed = true; + break; + + default: + unreachable("Invalid opcode."); + } + + if (dest_size != 32) { + /* When the accumulator is 32-bits, a NIR dot-product with saturate + * is generated above. In all other cases a regular dot-product is + * generated above, and separate addition with saturate is generated + * here (non-accumulating opcodes only get the bit-size conversion). + * + * The SPV_KHR_integer_dot_product spec says: + * + * If any of the multiplications or additions, with the exception + * of the final accumulation, overflow or underflow, the result of + * the instruction is undefined. + * + * Therefore it is safe to cast the dot-product result down to the + * size of the accumulator before doing the addition.
Since the + * result of the dot-product cannot overflow 32-bits, this is also + * safe to cast up. + */ + if (num_inputs == 3) { + dest = is_signed + ? nir_iadd_sat(&b->nb, nir_i2i(&b->nb, dest, dest_size), src[2]) + : nir_uadd_sat(&b->nb, nir_u2u(&b->nb, dest, dest_size), src[2]); + } else { + dest = is_signed + ? nir_i2i(&b->nb, dest, dest_size) + : nir_u2u(&b->nb, dest, dest_size); + } + } + } + + vtn_push_nir_ssa(b, w[2], dest); + + b->nb.exact = b->exact; /* restore the builder default; presumably changed by vtn_handle_no_contraction() above */ +} + void vtn_handle_bitcast(struct vtn_builder *b, const uint32_t *w, unsigned count) { diff --git a/src/compiler/spirv/vtn_private.h b/src/compiler/spirv/vtn_private.h index f2cfe144405..d95e3c72e81 100644 --- a/src/compiler/spirv/vtn_private.h +++ b/src/compiler/spirv/vtn_private.h @@ -919,6 +919,9 @@ nir_op vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b, void vtn_handle_alu(struct vtn_builder *b, SpvOp opcode, const uint32_t *w, unsigned count); +void vtn_handle_integer_dot(struct vtn_builder *b, SpvOp opcode, + const uint32_t *w, unsigned count); /* Handles OpSDotKHR, OpUDotKHR, OpSUDotKHR and their AccSat variants. */ + void vtn_handle_bitcast(struct vtn_builder *b, const uint32_t *w, unsigned count);
