On Sat, Apr 28, 2018 at 4:14 AM, Karol Herbst <kher...@redhat.com> wrote:
> OpenCL has explicit casts where one can specify the rounding mode and put a > sat modifier: > > https://www.khronos.org/registry/OpenCL/sdk/2.1/docs/ > man/xhtml/convert_T.html > > _sat is valid for all conversions to an integer type and rounding modes are > valid for all conversions involving floats. > > Allthough the FPRoundingMode modifier is allowed without any restrictions > in > capabilities, it can only be used together with fp16 in GLSL. Additionally > it > can be used for conversions to/from floating points in OpenCL. > > The SaturatedConversion modifier, OpSatConvertUToS and OpSatConvertSToU are > only supported for Kernels, so current drivers are safe. > > Signed-off-by: Karol Herbst <kher...@redhat.com> > --- > src/compiler/glsl/glsl_to_nir.cpp | 2 +- > src/compiler/nir/nir.h | 2 +- > src/compiler/nir/nir_opcodes.py | 28 +++++----- > src/compiler/nir/nir_opcodes_c.py | 26 +++++---- > src/compiler/spirv/spirv_to_nir.c | 4 +- > src/compiler/spirv/vtn_alu.c | 108 ++++++++++++++++++++++++------ > -------- > src/compiler/spirv/vtn_glsl450.c | 2 +- > src/compiler/spirv/vtn_private.h | 2 +- > 8 files changed, 107 insertions(+), 67 deletions(-) > > diff --git a/src/compiler/glsl/glsl_to_nir.cpp > b/src/compiler/glsl/glsl_to_nir.cpp > index 8e5e9c34912..fcb6ef27e47 100644 > --- a/src/compiler/glsl/glsl_to_nir.cpp > +++ b/src/compiler/glsl/glsl_to_nir.cpp > @@ -1589,7 +1589,7 @@ nir_visitor::visit(ir_expression *ir) > nir_alu_type src_type = nir_get_nir_type_for_glsl_ > base_type(types[0]); > nir_alu_type dst_type = nir_get_nir_type_for_glsl_ > base_type(out_type); > result = nir_build_alu(&b, nir_type_conversion_op(src_type, > dst_type, > - nir_rounding_mode_undef), > + nir_rounding_mode_undef, false), > srcs[0], NULL, NULL, NULL); > /* b2i and b2f don't have fixed bit-size versions so the builder > will > * just assume 32 and we have to fix it up here. > diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h > index f3326e6df94..f32e5bd8bb2 100644 > --- a/src/compiler/nir/nir.h > +++ b/src/compiler/nir/nir.h > @@ -784,7 +784,7 @@ nir_get_nir_type_for_glsl_type(const struct glsl_type > *type) > } > > nir_op nir_type_conversion_op(nir_alu_type src, nir_alu_type dst, > - nir_rounding_mode rnd); > + nir_rounding_mode rnd, bool saturation); > > typedef enum { > NIR_OP_IS_COMMUTATIVE = (1 << 0), > diff --git a/src/compiler/nir/nir_opcodes.py b/src/compiler/nir/nir_ > opcodes.py > index f4cd175bc6a..9c51f77bf1b 100644 > --- a/src/compiler/nir/nir_opcodes.py > +++ b/src/compiler/nir/nir_opcodes.py > @@ -168,26 +168,28 @@ unop("flog2", tfloat, "log2f(src0)") > > # Generate all of the numeric conversion opcodes > for src_t in [tint, tuint, tfloat]: > - if src_t in (tint, tuint): > - dst_types = [tfloat, src_t] > - elif src_t == tfloat: > - dst_types = [tint, tuint, tfloat] > - > - for dst_t in dst_types: > + for dst_t in [tint, tuint, tfloat]: > if dst_t == tfloat: > bit_sizes = [16, 32, 64] > + sat_modes = [''] > else: > bit_sizes = [8, 16, 32, 64] > + if src_t != tfloat and dst_t != src_t: > + sat_modes = ['_sat'] > + else: > + sat_modes = ['_sat', ''] > for bit_size in bit_sizes: > - if dst_t == tfloat and src_t == tfloat: > - rnd_modes = ['_rtne', '_rtz', ''] > - for rnd_mode in rnd_modes: > + for sat_mode in sat_modes: > + if src_t == tfloat or dst_t == tfloat: > + for rnd_mode in ['_rtne', '_rtz', '_ru', '_rd', '']: > + unop_convert("{0}2{1}{2}{3}{4}".format(src_t[0], > dst_t[0], > + bit_size, > rnd_mode, > + sat_mode), > + dst_t + str(bit_size), src_t, "src0") > + else: > unop_convert("{0}2{1}{2}{3}".format(src_t[0], dst_t[0], > - bit_size, > rnd_mode), > + bit_size, sat_mode), > dst_t + str(bit_size), src_t, "src0") > - else: > - unop_convert("{0}2{1}{2}".format(src_t[0], dst_t[0], > bit_size), > - dst_t + str(bit_size), src_t, "src0") > As I mentioned on IRC, we need proper constant folding. Getting rounding modes on f32->f16 wrong isn't good and I probably shouldn't have let it through. Let's not make the problem worse. Not correctly handling _sat is especially bad. > > # We'll hand-code the to/from bool conversion opcodes. Because bool > doesn't > # have multiple bit-sizes, we can always infer the size from the other > type. > diff --git a/src/compiler/nir/nir_opcodes_c.py b/src/compiler/nir/nir_ > opcodes_c.py > index 19079f86e7b..9b8642f0cc1 100644 > --- a/src/compiler/nir/nir_opcodes_c.py > +++ b/src/compiler/nir/nir_opcodes_c.py > @@ -30,7 +30,8 @@ template = Template(""" > #include "nir.h" > > nir_op > -nir_type_conversion_op(nir_alu_type src, nir_alu_type dst, > nir_rounding_mode rnd) > +nir_type_conversion_op(nir_alu_type src, nir_alu_type dst, > nir_rounding_mode rnd, > + bool saturate) > { > nir_alu_type src_base = (nir_alu_type) nir_alu_type_get_base_type( > src); > nir_alu_type dst_base = (nir_alu_type) nir_alu_type_get_base_type( > dst); > @@ -41,7 +42,8 @@ nir_type_conversion_op(nir_alu_type src, nir_alu_type > dst, nir_rounding_mode rnd > return nir_op_fmov; > } else if ((src_base == nir_type_int || src_base == nir_type_uint) && > (dst_base == nir_type_int || dst_base == nir_type_uint) && > - src_bit_size == dst_bit_size) { > + src_bit_size == dst_bit_size && > + (src_base == dst_base || !saturate)) { > /* Integer <-> integer conversions with the same bit-size on both > * ends are just no-op moves. > */ > @@ -54,12 +56,9 @@ nir_type_conversion_op(nir_alu_type src, nir_alu_type > dst, nir_rounding_mode rnd > switch (dst_base) { > % for dst_t in ['int', 'uint', 'float']: > case nir_type_${dst_t}: > +<% orig_dst_t = dst_t %> > % if src_t in ['int', 'uint'] and dst_t in ['int', 'uint']: > -% if dst_t == 'int': > -<% continue %> > -% else: > -<% dst_t = src_t %> > -% endif > +<% dst_t = src_t %> > % endif > switch (dst_bit_size) { > % if dst_t == 'float': > @@ -69,18 +68,25 @@ nir_type_conversion_op(nir_alu_type src, nir_alu_type > dst, nir_rounding_mode rnd > % endif > % for dst_bits in bit_sizes: > case ${dst_bits}: > -% if src_t == 'float' and dst_t == 'float': > +% if src_t == 'float' or dst_t == 'float': > switch(rnd) { > -% for rnd_t in [('rtne', '_rtne'), ('rtz', '_rtz'), > ('undef', '')]: > +% for rnd_t in [('rtne', '_rtne'), ('rtz', '_rtz'), > ('ru', '_ru'), ('rd', '_rd'), ('undef', '')]: > case nir_rounding_mode_${rnd_t[0]}: > +% if dst_t != 'float': > + if (saturate) > + return > ${'nir_op_{0}2{1}{2}{3}_sat'.format(src_t[0], > dst_t[0], > + > dst_bits, rnd_t[1])}; > +% endif > return ${'nir_op_{0}2{1}{2}{3}'.format(src_t[0], > dst_t[0], > > dst_bits, rnd_t[1])}; > % endfor > default: > - unreachable("Invalid 16-bit nir rounding > mode"); > + unreachable("Invalid float nir rounding mode"); > } > % else: > assert(rnd == nir_rounding_mode_undef); > + if (saturate) > + return ${'nir_op_{0}2{1}{2}_sat'.format(src_t[0], > orig_dst_t[0], dst_bits)}; > return ${'nir_op_{0}2{1}{2}'.format(src_t[0], > dst_t[0], dst_bits)}; > % endif > % endfor > diff --git a/src/compiler/spirv/spirv_to_nir.c > b/src/compiler/spirv/spirv_to_nir.c > index 2a835f047e4..6f1a1871b38 100644 > --- a/src/compiler/spirv/spirv_to_nir.c > +++ b/src/compiler/spirv/spirv_to_nir.c > @@ -1726,7 +1726,7 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp > opcode, > bit_size = glsl_get_bit_size(val->type->type); > }; > > - nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap, > + nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, val, opcode, > &swap, > > nir_alu_type_get_type_size(src_alu_type), > > nir_alu_type_get_type_size(dst_alu_type)); > nir_const_value src[4]; > @@ -3839,6 +3839,8 @@ vtn_handle_body_instruction(struct vtn_builder *b, > SpvOp opcode, > case SpvOpUConvert: > case SpvOpSConvert: > case SpvOpFConvert: > + case SpvOpSatConvertUToS: > + case SpvOpSatConvertSToU: > case SpvOpQuantizeToF16: > case SpvOpConvertPtrToU: > case SpvOpConvertUToPtr: > diff --git a/src/compiler/spirv/vtn_alu.c b/src/compiler/spirv/vtn_alu.c > index 3134849ba90..b96f7d688fb 100644 > --- a/src/compiler/spirv/vtn_alu.c > +++ b/src/compiler/spirv/vtn_alu.c > @@ -273,8 +273,46 @@ vtn_handle_bitcast(struct vtn_builder *b, struct > vtn_ssa_value *dest, > dest->def = nir_vec(&b->nb, dest_chan, dest_components); > } > > +static void > +handle_rounding_mode(struct vtn_builder *b, struct vtn_value *val, int > member, > + const struct vtn_decoration *dec, void > *_out_rounding_mode) > +{ > + nir_rounding_mode *out_rounding_mode = _out_rounding_mode; > + assert(dec->scope == VTN_DEC_DECORATION); > + if (dec->decoration != SpvDecorationFPRoundingMode) > + return; > + switch (dec->literals[0]) { > + case SpvFPRoundingModeRTE: > + *out_rounding_mode = nir_rounding_mode_rtne; > + break; > + case SpvFPRoundingModeRTZ: > + *out_rounding_mode = nir_rounding_mode_rtz; > + break; > + case SpvFPRoundingModeRTP: > + *out_rounding_mode = nir_rounding_mode_ru; > + break; > + case SpvFPRoundingModeRTN: > + *out_rounding_mode = nir_rounding_mode_rd; > + break; > + default: > + unreachable("Not supported rounding mode"); > + break; > + } > +} > + > +static void > +handle_saturation(struct vtn_builder *b, struct vtn_value *val, int > member, > + const struct vtn_decoration *dec, void *_out_saturation) > +{ > + bool *out_saturation = _out_saturation; > + assert(dec->scope == VTN_DEC_DECORATION); > + if (dec->decoration != SpvDecorationSaturatedConversion) > + return; > + *out_saturation = true; > +} > + > nir_op > -vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b, > +vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b, struct vtn_value > *val, > SpvOp opcode, bool *swap, > unsigned src_bit_size, unsigned > dst_bit_size) > { > @@ -356,42 +394,67 @@ vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder > *b, > case SpvOpConvertSToF: > case SpvOpConvertUToF: > case SpvOpSConvert: > - case SpvOpFConvert: { > + case SpvOpFConvert: > + case SpvOpSatConvertUToS: > + case SpvOpSatConvertSToU: { > nir_alu_type src_type; > nir_alu_type dst_type; > > + nir_rounding_mode rounding_mode = nir_rounding_mode_undef; > + bool saturation = false; > + > switch (opcode) { > case SpvOpConvertFToS: > src_type = nir_type_float; > dst_type = nir_type_int; > + vtn_foreach_decoration(b, val, handle_rounding_mode, > &rounding_mode); > + vtn_foreach_decoration(b, val, handle_saturation, &saturation); > break; > case SpvOpConvertFToU: > src_type = nir_type_float; > dst_type = nir_type_uint; > + vtn_foreach_decoration(b, val, handle_rounding_mode, > &rounding_mode); > + vtn_foreach_decoration(b, val, handle_saturation, &saturation); > break; > case SpvOpFConvert: > src_type = dst_type = nir_type_float; > + vtn_foreach_decoration(b, val, handle_rounding_mode, > &rounding_mode); > break; > case SpvOpConvertSToF: > src_type = nir_type_int; > dst_type = nir_type_float; > + vtn_foreach_decoration(b, val, handle_rounding_mode, > &rounding_mode); > break; > case SpvOpSConvert: > src_type = dst_type = nir_type_int; > + vtn_foreach_decoration(b, val, handle_saturation, &saturation); > break; > case SpvOpConvertUToF: > src_type = nir_type_uint; > dst_type = nir_type_float; > + vtn_foreach_decoration(b, val, handle_rounding_mode, > &rounding_mode); > break; > case SpvOpUConvert: > src_type = dst_type = nir_type_uint; > + vtn_foreach_decoration(b, val, handle_saturation, &saturation); > + break; > + case SpvOpSatConvertUToS: > + src_type = nir_type_uint; > + dst_type = nir_type_int; > + saturation = true; > + break; > + case SpvOpSatConvertSToU: > + src_type = nir_type_int; > + dst_type = nir_type_uint; > + saturation = true; > break; > default: > unreachable("Invalid opcode"); > } > src_type |= src_bit_size; > dst_type |= dst_bit_size; > - return nir_type_conversion_op(src_type, dst_type, > nir_rounding_mode_undef); > + > + return nir_type_conversion_op(src_type, dst_type, rounding_mode, > saturation); > } > /* Derivatives: */ > case SpvOpDPdx: return nir_op_fddx; > @@ -417,27 +480,6 @@ handle_no_contraction(struct vtn_builder *b, struct > vtn_value *val, int member, > b->nb.exact = true; > } > > -static void > -handle_rounding_mode(struct vtn_builder *b, struct vtn_value *val, int > member, > - const struct vtn_decoration *dec, void > *_out_rounding_mode) > -{ > - nir_rounding_mode *out_rounding_mode = _out_rounding_mode; > - assert(dec->scope == VTN_DEC_DECORATION); > - if (dec->decoration != SpvDecorationFPRoundingMode) > - return; > - switch (dec->literals[0]) { > - case SpvFPRoundingModeRTE: > - *out_rounding_mode = nir_rounding_mode_rtne; > - break; > - case SpvFPRoundingModeRTZ: > - *out_rounding_mode = nir_rounding_mode_rtz; > - break; > - default: > - unreachable("Not supported rounding mode"); > - break; > - } > -} > - > void > vtn_handle_alu(struct vtn_builder *b, SpvOp opcode, > const uint32_t *w, unsigned count) > @@ -579,7 +621,7 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode, > bool swap; > unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type); > unsigned dst_bit_size = glsl_get_bit_size(type); > - nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap, > + nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, val, opcode, &swap, > src_bit_size, > dst_bit_size); > > if (swap) { > @@ -605,7 +647,7 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode, > bool swap; > unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type); > unsigned dst_bit_size = glsl_get_bit_size(type); > - nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap, > + nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, val, opcode, &swap, > src_bit_size, > dst_bit_size); > > assert(!swap); > @@ -623,23 +665,11 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode, > vtn_handle_bitcast(b, val->ssa, src[0]); > break; > > - case SpvOpFConvert: { > - nir_alu_type src_alu_type = nir_get_nir_type_for_glsl_ > type(vtn_src[0]->type); > - nir_alu_type dst_alu_type = nir_get_nir_type_for_glsl_type(type); > - nir_rounding_mode rounding_mode = nir_rounding_mode_undef; > - > - vtn_foreach_decoration(b, val, handle_rounding_mode, > &rounding_mode); > - nir_op op = nir_type_conversion_op(src_alu_type, dst_alu_type, > rounding_mode); > - > - val->ssa->def = nir_build_alu(&b->nb, op, src[0], src[1], NULL, > NULL); > - break; > - } > - > default: { > bool swap; > unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type); > unsigned dst_bit_size = glsl_get_bit_size(type); > - nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap, > + nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, val, opcode, &swap, > src_bit_size, > dst_bit_size); > > if (swap) { > diff --git a/src/compiler/spirv/vtn_glsl450.c b/src/compiler/spirv/vtn_ > glsl450.c > index 6fa759b1bba..284371446b5 100644 > --- a/src/compiler/spirv/vtn_glsl450.c > +++ b/src/compiler/spirv/vtn_glsl450.c > @@ -659,7 +659,7 @@ handle_glsl450_alu(struct vtn_builder *b, enum > GLSLstd450 entrypoint, > nir_op conversion_op = > nir_type_conversion_op(nir_type_float | eta->bit_size, > nir_type_float | I->bit_size, > - nir_rounding_mode_undef); > + nir_rounding_mode_undef, false); > eta = nir_build_alu(nb, conversion_op, eta, NULL, NULL, NULL); > } > /* k = 1.0 - eta * eta * (1.0 - dot(N, I) * dot(N, I)) */ > diff --git a/src/compiler/spirv/vtn_private.h b/src/compiler/spirv/vtn_ > private.h > index b501bbf9b4a..0895c865fbb 100644 > --- a/src/compiler/spirv/vtn_private.h > +++ b/src/compiler/spirv/vtn_private.h > @@ -708,7 +708,7 @@ typedef void (*vtn_execution_mode_foreach_cb)(struct > vtn_builder *, > void vtn_foreach_execution_mode(struct vtn_builder *b, struct vtn_value > *value, > vtn_execution_mode_foreach_cb cb, void > *data); > > -nir_op vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b, > +nir_op vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b, struct > vtn_value *val, > SpvOp opcode, bool *swap, > unsigned src_bit_size, unsigned > dst_bit_size); > > -- > 2.14.3 > > _______________________________________________ > mesa-dev mailing list > mesa-dev@lists.freedesktop.org > https://lists.freedesktop.org/mailman/listinfo/mesa-dev >
_______________________________________________ mesa-dev mailing list mesa-dev@lists.freedesktop.org https://lists.freedesktop.org/mailman/listinfo/mesa-dev