Module: Mesa Branch: main Commit: 839495efc67cce5b44e76abc35b3b355c75a88c7 URL: http://cgit.freedesktop.org/mesa/mesa/commit/?id=839495efc67cce5b44e76abc35b3b355c75a88c7
Author: Ian Romanick <[email protected]> Date: Wed Jun 9 14:53:49 2021 -0700 nir/algebraic: Add lowering for dot_4x8 instructions v2: Fix copy-and-paste bugs in lowering patterns. v3: Add has_sudot_4x8 flag. Requested by Rhys. v4: Since the names of the opcodes changed from dp4 to dot_4x8, also change the names of the lowering helpers. Suggested by Jason. Reviewed-by: Jason Ekstrand <[email protected]> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/12142> --- src/compiler/nir/nir_opt_algebraic.py | 36 +++++++++++++++++++++++++++++++++++ src/compiler/nir/nir_search_helpers.h | 21 ++++++++++++++++++++ 2 files changed, 57 insertions(+) diff --git a/src/compiler/nir/nir_opt_algebraic.py b/src/compiler/nir/nir_opt_algebraic.py index 2f0a044e2a8..ee876b8b7ab 100644 --- a/src/compiler/nir/nir_opt_algebraic.py +++ b/src/compiler/nir/nir_opt_algebraic.py @@ -223,6 +223,42 @@ optimizations = [ (('sudot_4x8_iadd_sat', '#a', '#b', 'c(is_not_const)'), ('iadd_sat', ('sudot_4x8_iadd', a, b, 0), c), '!options->lower_add_sat'), ] +# Shorthand for the expansion of just the dot product part of the [iu]dp4a +# instructions. 
+sdot_4x8_a_b = ('iadd', ('iadd', ('imul', ('extract_i8', a, 0), ('extract_i8', b, 0)),
+                                 ('imul', ('extract_i8', a, 1), ('extract_i8', b, 1))),
+                        ('iadd', ('imul', ('extract_i8', a, 2), ('extract_i8', b, 2)),
+                                 ('imul', ('extract_i8', a, 3), ('extract_i8', b, 3))))
+udot_4x8_a_b = ('iadd', ('iadd', ('imul', ('extract_u8', a, 0), ('extract_u8', b, 0)),
+                                 ('imul', ('extract_u8', a, 1), ('extract_u8', b, 1))),
+                        ('iadd', ('imul', ('extract_u8', a, 2), ('extract_u8', b, 2)),
+                                 ('imul', ('extract_u8', a, 3), ('extract_u8', b, 3))))
+sudot_4x8_a_b = ('iadd', ('iadd', ('imul', ('extract_i8', a, 0), ('extract_u8', b, 0)),
+                                  ('imul', ('extract_i8', a, 1), ('extract_u8', b, 1))),
+                         ('iadd', ('imul', ('extract_i8', a, 2), ('extract_u8', b, 2)),
+                                  ('imul', ('extract_i8', a, 3), ('extract_u8', b, 3))))
+
+optimizations.extend([
+   (('sdot_4x8_iadd', a, b, c), ('iadd', sdot_4x8_a_b, c), '!options->has_dot_4x8'),
+   (('udot_4x8_uadd', a, b, c), ('iadd', udot_4x8_a_b, c), '!options->has_dot_4x8'),
+   (('sudot_4x8_iadd', a, b, c), ('iadd', sudot_4x8_a_b, c), '!options->has_sudot_4x8'),
+
+   # For the unsigned dot-product, the largest possible value 4*(255*255) =
+   # 0x3f804, so we don't have to worry about that intermediate result
+   # overflowing.  0x100000000 - 0x3f804 = 0xfffc07fc.  If c is a constant
+   # that is less than 0xfffc07fc, then the result cannot overflow ever.
+   (('udot_4x8_uadd_sat', a, b, '#c(is_ult_0xfffc07fc)'), ('udot_4x8_uadd', a, b, c)),
+   (('udot_4x8_uadd_sat', a, b, c), ('uadd_sat', udot_4x8_a_b, c), '!options->has_dot_4x8'),
+
+   # For the signed dot-product, the largest positive value is 4*(-128*-128) =
+   # 0x10000, and the largest negative value is 4*(-128*127) = -0xfe00.  We
+   # don't have to worry about that intermediate result overflowing or
+   # underflowing.
+   (('sdot_4x8_iadd_sat', a, b, c), ('iadd_sat', sdot_4x8_a_b, c), '!options->has_dot_4x8'),
+   # Likewise, the mixed-sign dot product lies in [-0x1fe00, 0x1fa04].
+   (('sudot_4x8_iadd_sat', a, b, c), ('iadd_sat', sudot_4x8_a_b, c), '!options->has_sudot_4x8'),
+])
+
 # Float sizes
 for s in [16, 32, 64]:
     optimizations.extend([
diff --git a/src/compiler/nir/nir_search_helpers.h b/src/compiler/nir/nir_search_helpers.h
index 24938484377..1188b50ed2d 100644
--- a/src/compiler/nir/nir_search_helpers.h
+++ b/src/compiler/nir/nir_search_helpers.h
@@ -205,6 +205,27 @@ is_not_const_zero(UNUSED struct hash_table *ht, const nir_alu_instr *instr,
    return true;
 }
 
+/** Is every component of a constant src unsigned less than 0xfffc07fc? */
+static inline bool
+is_ult_0xfffc07fc(UNUSED struct hash_table *ht, const nir_alu_instr *instr,
+                  unsigned src, unsigned num_components,
+                  const uint8_t *swizzle)
+{
+   /* only constant srcs: */
+   if (!nir_src_is_const(instr->src[src].src))
+      return false;
+
+   for (unsigned i = 0; i < num_components; i++) {
+      /* Keep the full 64-bit value so wide constants are not truncated. */
+      const uint64_t val =
+         nir_src_comp_as_uint(instr->src[src].src, swizzle[i]);
+      if (val >= 0xfffc07fcU)
+         return false;
+   }
+
+   return true;
+}
+
 static inline bool
 is_not_const(UNUSED struct hash_table *ht, const nir_alu_instr *instr,
              unsigned src, UNUSED unsigned num_components,
