Claudio Bantaloukas <[email protected]> writes:
> [...]
> @@ -4004,6 +4008,44 @@ SHAPE (ternary_bfloat_lane)
> typedef ternary_bfloat_lane_base<2> ternary_bfloat_lanex2_def;
> SHAPE (ternary_bfloat_lanex2)
> +/* sv<t0>_t svfoo[_t0](sv<t0>_t, svmfloat8_t, svmfloat8_t, uint64_t)
> +
> + where the final argument is an integer constant expression in the range
> + [0, 15]. */
> +struct ternary_mfloat8_lane_def
> + : public ternary_resize2_lane_base<8, TYPE_mfloat, TYPE_mfloat>
> +{
> + void
> +  build (function_builder &b, const function_group_info &group) const override
> + {
> + gcc_assert (group.fpm_mode == FPM_set);
> + b.add_overloaded_functions (group, MODE_none);
> + build_all (b, "v0,v0,vM,vM,su64", group, MODE_none);
> + }
> +
> + bool
> + check (function_checker &c) const override
> + {
> + return c.require_immediate_lane_index (3, 2, 1);
> + }
> +
> + tree
> + resolve (function_resolver &r) const override
> + {
> + type_suffix_index type;
> + if (!r.check_num_arguments (5)
> + || (type = r.infer_vector_type (0)) == NUM_TYPE_SUFFIXES
> + || !r.require_vector_type (1, VECTOR_TYPE_svmfloat8_t)
> + || !r.require_vector_type (2, VECTOR_TYPE_svmfloat8_t)
> + || !r.require_integer_immediate (3)
> + || !r.require_scalar_type (4, "int64_t"))
This should be "uint64_t" rather than "int64_t", to match the uint64_t in
the prototype comment above.
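That is, something like (untested):

    || !r.require_scalar_type (4, "uint64_t"))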
> + return error_mark_node;
> +
> +    return r.resolve_to (r.mode_suffix_id, type, TYPE_SUFFIX_mf8, GROUP_none);
> + }
> +};
> +SHAPE (ternary_mfloat8_lane)
> +
> /* sv<t0>_t svfoo[_t0](sv<t0>_t, svbfloatt16_t, svbfloat16_t)
> sv<t0>_t svfoo[_n_t0](sv<t0>_t, svbfloat16_t, bfloat16_t). */
> struct ternary_bfloat_opt_n_def
> @@ -4019,6 +4061,46 @@ struct ternary_bfloat_opt_n_def
> };
> SHAPE (ternary_bfloat_opt_n)
>
> +/* sv<t0>_t svfoo[_t0](sv<t0>_t, svmfloatt8_t, svmfloat8_t)
> + sv<t0>_t svfoo[_n_t0](sv<t0>_t, svmfloat8_t, bfloat8_t). */
> +struct ternary_mfloat8_opt_n_def
> + : public ternary_resize2_opt_n_base<8, TYPE_mfloat, TYPE_mfloat>
> +{
> + void
> +  build (function_builder &b, const function_group_info &group) const override
> + {
> + gcc_assert (group.fpm_mode == FPM_set);
> + b.add_overloaded_functions (group, MODE_none);
> + build_all (b, "v0,v0,vM,vM", group, MODE_none);
> + build_all (b, "v0,v0,vM,sM", group, MODE_n);
> + }
> +
> + tree
> + resolve (function_resolver &r) const override
> + {
> + type_suffix_index type;
> + if (!r.check_num_arguments (4)
> + || (type = r.infer_vector_type (0)) == NUM_TYPE_SUFFIXES
> + || !r.require_vector_type (1, VECTOR_TYPE_svmfloat8_t)
> + || !r.require_scalar_type (3, "int64_t"))
> + return error_mark_node;
> +
> + tree scalar_form
> + = r.lookup_form (MODE_n, type, TYPE_SUFFIX_mf8, GROUP_none);
> + if (r.scalar_argument_p (2))
> + {
> + if (scalar_form)
> + return scalar_form;
> + return error_mark_node;
It looks like this would return error_mark_node without reporting
an error first.
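If we keep this structure, the fall-through presumably wants to report the
failure before returning, e.g. something like (untested, and assuming
report_no_such_form is the right helper for a missing _n form):

    return r.report_no_such_form (type);

although the restructuring suggested below would sidestep this anyway.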
> + }
> + if (scalar_form && !r.require_vector_or_scalar_type (2))
> + return error_mark_node;
> +
> +    return r.resolve_to (r.mode_suffix_id, type, TYPE_SUFFIX_mf8, GROUP_none);
> + }
In this context (unlike finish_opt_n_resolution) we know that there is
a bijection between the vector and scalar forms. So I think we can just
add require_vector_or_scalar_type to the initial checks:
  if (!r.check_num_arguments (4)
      || (type = r.infer_vector_type (0)) == NUM_TYPE_SUFFIXES
      || !r.require_vector_type (1, VECTOR_TYPE_svmfloat8_t)
      || !r.require_vector_or_scalar_type (2)
      || !r.require_scalar_type (3, "uint64_t"))
    return error_mark_node;

  auto mode = r.mode_suffix_id;
  if (r.scalar_argument_p (2))
    mode = MODE_n;
  else if (!r.require_vector_type (2, VECTOR_TYPE_svmfloat8_t))
    return error_mark_node;

  return r.resolve_to (mode, type, TYPE_SUFFIX_mf8, GROUP_none);

(Untested, and with the same s/int64_t/uint64_t/ change as above applied here too.)
> [...]
> +;; -------------------------------------------------------------------------
> +;; ---- [FP] Mfloat8 Multiply-and-accumulate operations
> +;; -------------------------------------------------------------------------
> +;; Includes:
> +;; - FMLALB (vectors, FP8 to FP16)
> +;; - FMLALT (vectors, FP8 to FP16)
> +;; - FMLALB (indexed, FP8 to FP16)
> +;; - FMLALT (indexed, FP8 to FP16)
> +;; - FMLALLBB (vectors)
> +;; - FMLALLBB (indexed)
> +;; - FMLALLBT (vectors)
> +;; - FMLALLBT (indexed)
> +;; - FMLALLTB (vectors)
> +;; - FMLALLTB (indexed)
> +;; - FMLALLTT (vectors)
> +;; - FMLALLTT (indexed)
> +;; -------------------------------------------------------------------------
> +
> +(define_insn "@aarch64_sve_add_<sve2_fp8_fma_op><mode>"
> + [(set (match_operand:SVE_FULL_HSF 0 "register_operand")
> + (unspec:SVE_FULL_HSF
> + [(match_operand:SVE_FULL_HSF 1 "register_operand")
> + (match_operand:VNx16QI 2 "register_operand")
> + (match_operand:VNx16QI 3 "register_operand")
> + (reg:DI FPM_REGNUM)]
> + SVE2_FP8_TERNARY))]
> + "TARGET_SSVE_FP8FMA"
> + {@ [ cons: =0 , 1 , 2 , 3 ; attrs: movprfx ]
> +     [ w , 0 , w , w ; * ] <sve2_fp8_fma_op>\t%0.<Vetype>, %2.b, %3.b
> +     [ ?&w , w , w , w ; yes ] movprfx\t%0, %1\;<sve2_fp8_fma_op>\t%0.<Vetype>, %2.b, %3.b
> + }
> +)
> +
> +(define_insn "@aarch64_sve_add_lane_<sve2_fp8_fma_op><mode>"
> + [(set (match_operand:SVE_FULL_HSF 0 "register_operand")
> + (unspec:SVE_FULL_HSF
> + [(match_operand:SVE_FULL_HSF 1 "register_operand")
> + (match_operand:VNx16QI 2 "register_operand")
> + (match_operand:VNx16QI 3 "register_operand")
> + (match_operand:SI 4 "const_int_operand")
> + (reg:DI FPM_REGNUM)]
> + SVE2_FP8_TERNARY_LANE))]
> + "TARGET_SSVE_FP8FMA"
> + {@ [ cons: =0 , 1 , 2 , 3 ; attrs: movprfx ]
> +     [ w , 0 , w , y ; * ] <sve2_fp8_fma_op>\t%0.<Vetype>, %2.b, %3.b[%4]
> +     [ ?&w , w , w , y ; yes ] movprfx\t%0, %1\;<sve2_fp8_fma_op>\t%0.<Vetype>, %2.b, %3.b[%4]
> + }
> +)
> +
It goes against my instincts to ask for more cut-&-paste, but: I think
we should split the SVE2_FP8_TERNARY and SVE2_FP8_TERNARY_LANE operator
lists into HF-only and SF-only versions, rather than define invalid
combinations (FMLALB/FMLALT only produce .h results, while the FMLALL*
forms only produce .s results).  [ Hope I didn't suggest the opposite
earlier -- always a risk, unfortunately. :( ]
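Something along these lines, say (iterator and unspec names are purely
illustrative here -- reuse whatever the patch already defines; untested):

(define_int_iterator SVE2_FP8_TERNARY_VNX8HF
  [UNSPEC_FMLALB_FP8 UNSPEC_FMLALT_FP8])

(define_int_iterator SVE2_FP8_TERNARY_VNX4SF
  [UNSPEC_FMLALLBB_FP8 UNSPEC_FMLALLBT_FP8
   UNSPEC_FMLALLTB_FP8 UNSPEC_FMLALLTT_FP8])

with one define_insn (and one _lane define_insn) per element size, using
VNx8HF and VNx4SF directly instead of iterating over SVE_FULL_HSF.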
> [...]
> +/* SVE2 versions of fp8 multiply-accumulate instructions are enabled through
> +ssve-fp8fma. */
> +#define TARGET_SSVE_FP8FMA ((\
> + (TARGET_SVE2 && TARGET_FP8FMA) || TARGET_STREAMING) \
> + && (AARCH64_HAVE_ISA(SSVE_FP8FMA) || TARGET_NON_STREAMING))
Formatting nits, sorry, but: long line for the comment, and missing space
in the final line. Also, the comment doesn't cover the non-streaming case.
Maybe:
/* SVE2 versions of fp8 multiply-accumulate instructions are enabled for
   non-streaming mode by +fp8fma and for streaming mode by +ssve-fp8fma.  */
#define TARGET_SSVE_FP8FMA \
  (((TARGET_SVE2 && TARGET_FP8FMA) || TARGET_STREAMING) \
   && (AARCH64_HAVE_ISA (SSVE_FP8FMA) || TARGET_NON_STREAMING))
> diff --git a/gcc/doc/invoke.texi b/gcc/doc/invoke.texi
> index 93e096bc9d5..119f636dc16 100644
> --- a/gcc/doc/invoke.texi
> +++ b/gcc/doc/invoke.texi
> @@ -21824,6 +21824,10 @@ Enable support for Armv8.9-a/9.4-a translation hardening extension.
> Enable the RCpc3 (Release Consistency) extension.
> @item fp8
> Enable the fp8 (8-bit floating point) extension.
> +@item fp8fma
> +Enable the fp8 (8-bit floating point) multiply accumulate extension.
> +@item ssve-fp8fma
> +Enable the fp8 (8-bit floating point) multiply accumulate extension streaming mode.
Maybe "in streaming mode"? Also: the usual 80-character line limit applies
here too, where possible.
> [...]
> diff --git a/gcc/testsuite/gcc.target/aarch64/sve2/acle/asm/mlalb_lane_mf8.c b/gcc/testsuite/gcc.target/aarch64/sve2/acle/asm/mlalb_lane_mf8.c
> new file mode 100644
> index 00000000000..5b43f4d6611
> --- /dev/null
> +++ b/gcc/testsuite/gcc.target/aarch64/sve2/acle/asm/mlalb_lane_mf8.c
> @@ -0,0 +1,88 @@
> +/* { dg-final { check-function-bodies "**" "" "-DCHECK_ASM" } } */
> +/* { dg-additional-options "-march=armv8.5-a+sve2+fp8fma" } */
> +/* { dg-require-effective-target aarch64_asm_fp8fma_ok } */
> +/* { dg-require-effective-target aarch64_asm_ssve-fp8fma_ok } */
> +/* { dg-skip-if "" { *-*-* } { "-DSTREAMING_COMPATIBLE" } { "" } } */
> +
> +#include "test_sve_acle.h"
Following on from the comment on patch 3, the corresponding change here
would probably be:
/* { dg-do assemble { target aarch64_asm_ssve-fp8fma_ok } } */
/* { dg-do compile { target { ! aarch64_asm_ssve-fp8fma_ok } } } */
/* { dg-final { check-function-bodies "**" "" "-DCHECK_ASM" } } */
#include "test_sve_acle.h"
#pragma GCC target "+fp8fma"
#ifdef STREAMING_COMPATIBLE
#pragma GCC target "+ssve-fp8fma"
#endif
(which assumes that +ssve-fp8fma is good for +fp8fma too).
> +/*
> +** mlalb_lane_0_f16_tied1:
> +** msr fpmr, x0
> +** fmlalb z0\.h, z4\.b, z5\.b\[0\]
> +** ret
> +*/
> +TEST_DUAL_Z (mlalb_lane_0_f16_tied1, svfloat16_t, svmfloat8_t,
> + z0 = svmlalb_lane_f16_mf8_fpm (z0, z4, z5, 0, fpm0),
> + z0 = svmlalb_lane_fpm (z0, z4, z5, 0, fpm0))
> +
> +/*
> +** mlalb_lane_0_f16_tied2:
> +** msr fpmr, x0
> +** mov (z[0-9]+)\.d, z0\.d
> +** movprfx z0, z4
> +** fmlalb z0\.h, \1\.b, z1\.b\[0\]
> +** ret
> +*/
> +TEST_DUAL_Z_REV (mlalb_lane_0_f16_tied2, svfloat16_t, svmfloat8_t,
> + z0_res = svmlalb_lane_f16_mf8_fpm (z4, z0, z1, 0, fpm0),
> + z0_res = svmlalb_lane_fpm (z4, z0, z1, 0, fpm0))
> +
> +/*
> +** mlalb_lane_0_f16_tied3:
> +** msr fpmr, x0
> +** mov (z[0-9]+)\.d, z0\.d
> +** movprfx z0, z4
> +** fmlalb z0\.h, z1\.b, \1\.b\[0\]
> +** ret
> +*/
> +TEST_DUAL_Z_REV (mlalb_lane_0_f16_tied3, svfloat16_t, svmfloat8_t,
> + z0_res = svmlalb_lane_f16_mf8_fpm (z4, z1, z0, 0, fpm0),
> + z0_res = svmlalb_lane_fpm (z4, z1, z0, 0, fpm0))
> +
> +/*
> +** mlalb_lane_0_f16_untied:
> +** msr fpmr, x0
> +** movprfx z0, z1
> +** fmlalb z0\.h, z4\.b, z5\.b\[0\]
> +** ret
> +*/
> +TEST_DUAL_Z (mlalb_lane_0_f16_untied, svfloat16_t, svmfloat8_t,
> + z0 = svmlalb_lane_f16_mf8_fpm (z1, z4, z5, 0, fpm0),
> + z0 = svmlalb_lane_fpm (z1, z4, z5, 0, fpm0))
> +
> +/*
> +** mlalb_lane_1_f16:
> +** msr fpmr, x0
> +** fmlalb z0\.h, z4\.b, z5\.b\[1\]
> +** ret
> +*/
> +TEST_DUAL_Z (mlalb_lane_1_f16, svfloat16_t, svmfloat8_t,
> + z0 = svmlalb_lane_f16_mf8_fpm (z0, z4, z5, 1, fpm0),
> + z0 = svmlalb_lane_fpm (z0, z4, z5, 1, fpm0))
> +
> +/*
> +** mlalb_lane_z8_f16:
> +** ...
> +** msr fpmr, x0
> +** mov (z[0-7])\.d, z8\.d
> +** fmlalb z0\.h, z1\.b, \1\.b\[1\]
> +** ldr d8, \[sp\], 32
> +** ret
> +*/
> +TEST_DUAL_LANE_REG (mlalb_lane_z8_f16, svfloat16_t, svmfloat8_t, z8,
> + z0 = svmlalb_lane_f16_mf8_fpm (z0, z1, z8, 1, fpm0),
> + z0 = svmlalb_lane_fpm (z0, z1, z8, 1, fpm0))
> +
> +/*
> +** mlalb_lane_z16_f16:
> +** ...
> +** msr fpmr, x0
> +** mov (z[0-7])\.d, z16\.d
> +** fmlalb z0\.h, z1\.b, \1\.b\[1\]
> +** ...
> +** ret
> +*/
> +TEST_DUAL_LANE_REG (mlalb_lane_z16_f16, svfloat16_t, svmfloat8_t, z16,
> + z0 = svmlalb_lane_f16_mf8_fpm (z0, z1, z16, 1, fpm0),
> + z0 = svmlalb_lane_fpm (z0, z1, z16, 1, fpm0))
It would be good to have a test for the upper limit of the index range,
like for the _f32 tests. Same for svmlalt_lane.
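For mlalb that would be something like the following (untested, following
the style of the existing tests; 15 being the upper bound given in the
shape comment above):

/*
** mlalb_lane_15_f16:
** msr fpmr, x0
** fmlalb z0\.h, z4\.b, z5\.b\[15\]
** ret
*/
TEST_DUAL_Z (mlalb_lane_15_f16, svfloat16_t, svmfloat8_t,
             z0 = svmlalb_lane_f16_mf8_fpm (z0, z4, z5, 15, fpm0),
             z0 = svmlalb_lane_fpm (z0, z4, z5, 15, fpm0))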
Looks good to me otherwise, thanks,
Richard