https://github.com/Icohedron updated https://github.com/llvm/llvm-project/pull/184882
>From 2491daa60da34861eeff2a239b443b984844d6a9 Mon Sep 17 00:00:00 2001 From: Deric Cheung <[email protected]> Date: Thu, 5 Mar 2026 11:22:33 -0800 Subject: [PATCH 1/4] Implement HLSL mul function - Defines a `__builtin_hlsl_mul` clang builtin in `Builtins.td`. - Links the `__builtin_hlsl_mul` clang builtin with `hlsl_alias_intrinsics.h` under the name `mul` - Adds sema checks for `__builtin_hlsl_mul` to `CheckBuiltinFunctionCall` in `SemaHLSL.cpp` - Adds codegen for `__builtin_hlsl_mul` to `EmitHLSLBuiltinExpr` in `CGHLSLBuiltins.cpp` - Vector-vector multiplication uses `dot`, except for double vectors in DirectX, which expand to scalar (fused) multiply-adds. - Matrix-matrix, matrix-vector, and vector-matrix multiplication lower to the `llvm.matrix.multiply` intrinsic - Adds codegen tests to `clang/test/CodeGenHLSL/builtins/mul.hlsl` - Adds sema tests to `clang/test/SemaHLSL/BuiltIns/mul-errors.hlsl` - Implements lowering of the `llvm.matrix.multiply` intrinsic to DXIL in `DXILIntrinsicExpansion.cpp` - Currently only supports column-major matrix memory layout Note: Currently the SPIR-V and DXIL backends do not support row-major matrix memory layouts when lowering matrix multiply. 
Assisted-by: claude-opus-4.6 --- clang/include/clang/Basic/Builtins.td | 6 + clang/lib/CodeGen/CGHLSLBuiltins.cpp | 91 +++++ .../lib/Headers/hlsl/hlsl_alias_intrinsics.h | 69 ++++ clang/lib/Sema/SemaHLSL.cpp | 57 +++ clang/test/CodeGenHLSL/builtins/mul.hlsl | 143 ++++++++ clang/test/SemaHLSL/BuiltIns/mul-errors.hlsl | 42 +++ .../Target/DirectX/DXILIntrinsicExpansion.cpp | 96 +++++ llvm/test/CodeGen/DirectX/matrix-multiply.ll | 342 ++++++++++++++++++ 8 files changed, 846 insertions(+) create mode 100644 clang/test/CodeGenHLSL/builtins/mul.hlsl create mode 100644 clang/test/SemaHLSL/BuiltIns/mul-errors.hlsl create mode 100644 llvm/test/CodeGen/DirectX/matrix-multiply.ll diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td index 531c3702161f2..c6a0c10449ba9 100644 --- a/clang/include/clang/Basic/Builtins.td +++ b/clang/include/clang/Basic/Builtins.td @@ -5312,6 +5312,12 @@ def HLSLMad : LangBuiltin<"HLSL_LANG"> { let Prototype = "void(...)"; } +def HLSLMul : LangBuiltin<"HLSL_LANG"> { + let Spellings = ["__builtin_hlsl_mul"]; + let Attributes = [NoThrow, Const, CustomTypeChecking]; + let Prototype = "void(...)"; +} + def HLSLNormalize : LangBuiltin<"HLSL_LANG"> { let Spellings = ["__builtin_hlsl_normalize"]; let Attributes = [NoThrow, Const, CustomTypeChecking]; diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp b/clang/lib/CodeGen/CGHLSLBuiltins.cpp index 70891eac39425..34e92c42fdef9 100644 --- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp +++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp @@ -13,6 +13,7 @@ #include "CGBuiltin.h" #include "CGHLSLRuntime.h" #include "CodeGenFunction.h" +#include "llvm/IR/MatrixBuilder.h" using namespace clang; using namespace CodeGen; @@ -1006,6 +1007,96 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID, Value *Mul = Builder.CreateNUWMul(M, A); return Builder.CreateNUWAdd(Mul, B); } + case Builtin::BI__builtin_hlsl_mul: { + Value *Op0 = EmitScalarExpr(E->getArg(0)); + Value *Op1 = 
EmitScalarExpr(E->getArg(1)); + QualType QTy0 = E->getArg(0)->getType(); + QualType QTy1 = E->getArg(1)->getType(); + llvm::Type *T0 = Op0->getType(); + + bool IsScalar0 = QTy0->isScalarType(); + bool IsVec0 = QTy0->isVectorType(); + bool IsMat0 = QTy0->isConstantMatrixType(); + bool IsScalar1 = QTy1->isScalarType(); + bool IsVec1 = QTy1->isVectorType(); + bool IsMat1 = QTy1->isConstantMatrixType(); + bool IsFP = + QTy0->hasFloatingRepresentation() || QTy1->hasFloatingRepresentation(); + + // Cases 1-4, 7: scalar * scalar/vector/matrix or vector/matrix * scalar + if (IsScalar0 || IsScalar1) { + // Splat scalar to match the other operand's type + Value *Scalar = IsScalar0 ? Op0 : Op1; + Value *Other = IsScalar0 ? Op1 : Op0; + llvm::Type *OtherTy = Other->getType(); + + // Note: Matrices are flat vectors in the IR, so the following + // if-condition is also true when Other is a matrix, not just a vector. + if (OtherTy->isVectorTy()) { + unsigned NumElts = cast<FixedVectorType>(OtherTy)->getNumElements(); + Scalar = Builder.CreateVectorSplat(NumElts, Scalar); + } + + if (IsFP) + return Builder.CreateFMul(Scalar, Other, "hlsl.mul"); + return Builder.CreateMul(Scalar, Other, "hlsl.mul"); + } + + // Case 5: vector * vector -> scalar (dot product) + if (IsVec0 && IsVec1) { + auto *VecTy0 = E->getArg(0)->getType()->castAs<VectorType>(); + QualType EltQTy = VecTy0->getElementType(); + + // DXIL doesn't have a dot product intrinsic for double vectors, + // so expand to scalar multiply-add for DXIL. 
+ if (CGM.getTarget().getTriple().isDXIL() && + EltQTy->isSpecificBuiltinType(BuiltinType::Double)) { + unsigned NumElts = cast<FixedVectorType>(T0)->getNumElements(); + Value *Sum = nullptr; + for (unsigned I = 0; I < NumElts; ++I) { + Value *L = Builder.CreateExtractElement(Op0, I); + Value *R = Builder.CreateExtractElement(Op1, I); + if (Sum) + Sum = Builder.CreateIntrinsic(Sum->getType(), Intrinsic::fmuladd, + {L, R, Sum}); + else + Sum = Builder.CreateFMul(L, R); + } + return Sum; + } + + return Builder.CreateIntrinsic( + /*ReturnType=*/T0->getScalarType(), + getDotProductIntrinsic(CGM.getHLSLRuntime(), EltQTy), + ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.mul"); + } + + // Cases 6, 8, 9: matrix involved -> use llvm.matrix.multiply + llvm::MatrixBuilder MB(Builder); + if (IsVec0 && IsMat1) { + // vector<N> * matrix<N,M> -> vector<M> + // Treat vector as 1×N matrix + unsigned N = QTy0->castAs<VectorType>()->getNumElements(); + auto *MatTy = QTy1->castAs<ConstantMatrixType>(); + unsigned M = MatTy->getNumColumns(); + return MB.CreateMatrixMultiply(Op0, Op1, 1, N, M, "hlsl.mul"); + } + if (IsMat0 && IsVec1) { + // matrix<M,N> * vector<N> -> vector<M> + // Treat vector as N×1 matrix + auto *MatTy = QTy0->castAs<ConstantMatrixType>(); + unsigned Rows = MatTy->getNumRows(); + unsigned Cols = MatTy->getNumColumns(); + return MB.CreateMatrixMultiply(Op0, Op1, Rows, Cols, 1, "hlsl.mul"); + } + assert(IsMat0 && IsMat1); + // matrix<M,K> * matrix<K,N> -> matrix<M,N> + auto *MatTy0 = QTy0->castAs<ConstantMatrixType>(); + auto *MatTy1 = QTy1->castAs<ConstantMatrixType>(); + return MB.CreateMatrixMultiply(Op0, Op1, MatTy0->getNumRows(), + MatTy0->getNumColumns(), + MatTy1->getNumColumns(), "hlsl.mul"); + } case Builtin::BI__builtin_hlsl_elementwise_rcp: { Value *Op0 = EmitScalarExpr(E->getArg(0)); if (!E->getArg(0)->getType()->hasFloatingRepresentation()) diff --git a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h index 
2543401bdfbf9..2e9847803a8a1 100644 --- a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h +++ b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h @@ -1775,6 +1775,75 @@ double3 min(double3, double3); _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min) double4 min(double4, double4); +//===----------------------------------------------------------------------===// +// mul builtins +//===----------------------------------------------------------------------===// + +/// \fn R mul(X x, Y y) +/// \brief Multiplies x and y using matrix math. +/// \param x [in] The first input value. If x is a vector, it is treated as a +/// row vector. +/// \param y [in] The second input value. If y is a vector, it is treated as a +/// column vector. +/// +/// The inner dimension x-columns and y-rows must be equal. The result has the +/// dimension x-rows x y-columns. When both x and y are vectors, the result is +/// a dot product (scalar). Scalar operands are multiplied element-wise. +/// +/// This function supports 9 overloaded forms: +/// 1. scalar * scalar -> scalar +/// 2. scalar * vector -> vector +/// 3. scalar * matrix -> matrix +/// 4. vector * scalar -> vector +/// 5. vector * vector -> scalar (dot product) +/// 6. vector * matrix -> vector +/// 7. matrix * scalar -> matrix +/// 8. matrix * vector -> vector +/// 9. 
matrix * matrix -> matrix + +// Case 1: scalar * scalar -> scalar +template <typename T> _HLSL_BUILTIN_ALIAS(__builtin_hlsl_mul) T mul(T, T); + +// Case 2: scalar * vector -> vector +template <typename T, int N> +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mul) +vector<T, N> mul(T, vector<T, N>); + +// Case 3: scalar * matrix -> matrix +template <typename T, int R, int C> +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mul) +matrix<T, R, C> mul(T, matrix<T, R, C>); + +// Case 4: vector * scalar -> vector +template <typename T, int N> +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mul) +vector<T, N> mul(vector<T, N>, T); + +// Case 5: vector * vector -> scalar (dot product) +template <typename T, int N> +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mul) +T mul(vector<T, N>, vector<T, N>); + +// Case 6: vector * matrix -> vector +template <typename T, int R, int C> +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mul) +vector<T, C> mul(vector<T, R>, matrix<T, R, C>); + +// Case 7: matrix * scalar -> matrix +template <typename T, int R, int C> +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mul) +matrix<T, R, C> mul(matrix<T, R, C>, T); + +// Case 8: matrix * vector -> vector +template <typename T, int R, int C> +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mul) +vector<T, R> mul(matrix<T, R, C>, vector<T, C>); + +// Case 9: matrix * matrix -> matrix +template <typename T, int R, int K, int C> +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mul) +matrix<T, R, C> mul(matrix<T, R, K>, matrix<T, K, C>); + //===----------------------------------------------------------------------===// // normalize builtins //===----------------------------------------------------------------------===// diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index 804ea70aaddce..46a30acd95b68 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -3775,6 +3775,63 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) { return true; break; } + case Builtin::BI__builtin_hlsl_mul: { + if 
(SemaRef.checkArgCount(TheCall, 2)) + return true; + + Expr *Arg0 = TheCall->getArg(0); + Expr *Arg1 = TheCall->getArg(1); + QualType Ty0 = Arg0->getType(); + QualType Ty1 = Arg1->getType(); + + auto getElemType = [](QualType T) -> QualType { + if (const auto *VTy = T->getAs<VectorType>()) + return VTy->getElementType(); + if (const auto *MTy = T->getAs<ConstantMatrixType>()) + return MTy->getElementType(); + return T; + }; + + QualType EltTy0 = getElemType(Ty0); + + bool IsScalar0 = Ty0->isScalarType(); + bool IsVec0 = Ty0->isVectorType(); + bool IsMat0 = Ty0->isConstantMatrixType(); + bool IsScalar1 = Ty1->isScalarType(); + bool IsVec1 = Ty1->isVectorType(); + bool IsMat1 = Ty1->isConstantMatrixType(); + + QualType RetTy; + + if (IsScalar0 && IsScalar1) { + RetTy = EltTy0; + } else if (IsScalar0 && IsVec1) { + RetTy = Ty1; + } else if (IsScalar0 && IsMat1) { + RetTy = Ty1; + } else if (IsVec0 && IsScalar1) { + RetTy = Ty0; + } else if (IsVec0 && IsVec1) { + RetTy = EltTy0; + } else if (IsVec0 && IsMat1) { + auto *MatTy = Ty1->castAs<ConstantMatrixType>(); + RetTy = getASTContext().getExtVectorType(EltTy0, MatTy->getNumColumns()); + } else if (IsMat0 && IsScalar1) { + RetTy = Ty0; + } else if (IsMat0 && IsVec1) { + auto *MatTy = Ty0->castAs<ConstantMatrixType>(); + RetTy = getASTContext().getExtVectorType(EltTy0, MatTy->getNumRows()); + } else { + assert(IsMat0 && IsMat1); + auto *MatTy0 = Ty0->castAs<ConstantMatrixType>(); + auto *MatTy1 = Ty1->castAs<ConstantMatrixType>(); + RetTy = getASTContext().getConstantMatrixType( + EltTy0, MatTy0->getNumRows(), MatTy1->getNumColumns()); + } + + TheCall->setType(RetTy); + break; + } case Builtin::BI__builtin_hlsl_normalize: { if (SemaRef.checkArgCount(TheCall, 1)) return true; diff --git a/clang/test/CodeGenHLSL/builtins/mul.hlsl b/clang/test/CodeGenHLSL/builtins/mul.hlsl new file mode 100644 index 0000000000000..0a95d6004e567 --- /dev/null +++ b/clang/test/CodeGenHLSL/builtins/mul.hlsl @@ -0,0 +1,143 @@ +// RUN: 
%clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.3-library -emit-llvm -o - %s | FileCheck %s --check-prefixes=CHECK,DXIL +// RUN: %clang_cc1 -finclude-default-header -triple spirv-unknown-vulkan1.3-library -emit-llvm -o - %s | FileCheck %s --check-prefixes=CHECK,SPIRV + +// -- Case 1: scalar * scalar -> scalar -- + +// CHECK-LABEL: test_scalar_mulf +// CHECK: [[A:%.*]] = load float, ptr %a.addr +// CHECK: [[B:%.*]] = load float, ptr %b.addr +// CHECK: %hlsl.mul = fmul {{.*}} float [[A]], [[B]] +// CHECK: ret float %hlsl.mul +export float test_scalar_mulf(float a, float b) { return mul(a, b); } + +// CHECK-LABEL: test_scalar_muli +// CHECK: [[A:%.*]] = load i32, ptr %a.addr +// CHECK: [[B:%.*]] = load i32, ptr %b.addr +// CHECK: %hlsl.mul = mul i32 [[A]], [[B]] +// CHECK: ret i32 %hlsl.mul +export int test_scalar_muli(int a, int b) { return mul(a, b); } + +// -- Case 2: scalar * vector -> vector -- + +// CHECK-LABEL: test_scalar_vec_mul +// CHECK: [[A:%.*]] = load float, ptr %a.addr +// CHECK: [[B:%.*]] = load <3 x float>, ptr %b.addr +// CHECK: %.splatinsert = insertelement <3 x float> poison, float [[A]], i64 0 +// CHECK: %.splat = shufflevector <3 x float> %.splatinsert, <3 x float> poison, <3 x i32> zeroinitializer +// CHECK: %hlsl.mul = fmul {{.*}} <3 x float> %.splat, [[B]] +// CHECK: ret <3 x float> %hlsl.mul +export float3 test_scalar_vec_mul(float a, float3 b) { return mul(a, b); } + +// -- Case 3: scalar * matrix -> matrix -- + +// CHECK-LABEL: test_scalar_mat_mul +// CHECK: [[A:%.*]] = load float, ptr %a.addr +// CHECK: [[B:%.*]] = load <6 x float>, ptr %b.addr +// CHECK: %.splatinsert = insertelement <6 x float> poison, float [[A]], i64 0 +// CHECK: %.splat = shufflevector <6 x float> %.splatinsert, <6 x float> poison, <6 x i32> zeroinitializer +// CHECK: %hlsl.mul = fmul {{.*}} <6 x float> %.splat, [[B]] +// CHECK: ret <6 x float> %hlsl.mul +export float2x3 test_scalar_mat_mul(float a, float2x3 b) { return mul(a, b); } + +// -- Case 4: 
vector * scalar -> vector -- + +// CHECK-LABEL: test_vec_scalar_mul +// CHECK: [[A:%.*]] = load <3 x float>, ptr %a.addr +// CHECK: [[B:%.*]] = load float, ptr %b.addr +// CHECK: %.splatinsert = insertelement <3 x float> poison, float [[B]], i64 0 +// CHECK: %.splat = shufflevector <3 x float> %.splatinsert, <3 x float> poison, <3 x i32> zeroinitializer +// CHECK: %hlsl.mul = fmul {{.*}} <3 x float> %.splat, [[A]] +// CHECK: ret <3 x float> %hlsl.mul +export float3 test_vec_scalar_mul(float3 a, float b) { return mul(a, b); } + +// -- Case 5: vector * vector -> scalar (dot product) -- + +// CHECK-LABEL: test_vec_vec_mul +// CHECK: [[A:%.*]] = load <3 x float>, ptr %a.addr +// CHECK: [[B:%.*]] = load <3 x float>, ptr %b.addr +// DXIL: %hlsl.mul = call {{.*}} float @llvm.dx.fdot.v3f32(<3 x float> [[A]], <3 x float> [[B]]) +// SPIRV: %hlsl.mul = call {{.*}} float @llvm.spv.fdot.v3f32(<3 x float> [[A]], <3 x float> [[B]]) +// CHECK: ret float %hlsl.mul +export float test_vec_vec_mul(float3 a, float3 b) { return mul(a, b); } + +// CHECK-LABEL: test_vec_vec_muli +// CHECK: [[A:%.*]] = load <3 x i32>, ptr %a.addr +// CHECK: [[B:%.*]] = load <3 x i32>, ptr %b.addr +// DXIL: %hlsl.mul = call i32 @llvm.dx.sdot.v3i32(<3 x i32> [[A]], <3 x i32> [[B]]) +// SPIRV: %hlsl.mul = call i32 @llvm.spv.sdot.v3i32(<3 x i32> [[A]], <3 x i32> [[B]]) +// CHECK: ret i32 %hlsl.mul +export int test_vec_vec_muli(int3 a, int3 b) { return mul(a, b); } + +// CHECK-LABEL: test_vec_vec_mulu +// CHECK: [[A:%.*]] = load <3 x i32>, ptr %a.addr +// CHECK: [[B:%.*]] = load <3 x i32>, ptr %b.addr +// DXIL: %hlsl.mul = call i32 @llvm.dx.udot.v3i32(<3 x i32> [[A]], <3 x i32> [[B]]) +// SPIRV: %hlsl.mul = call i32 @llvm.spv.udot.v3i32(<3 x i32> [[A]], <3 x i32> [[B]]) +// CHECK: ret i32 %hlsl.mul +export uint test_vec_vec_mulu(uint3 a, uint3 b) { return mul(a, b); } + +// Double vector dot product: DXIL uses scalar arithmetic, SPIR-V uses fdot +// CHECK-LABEL: test_vec_vec_muld +// CHECK: [[A:%.*]] = load <3 
x double>, ptr %a.addr +// CHECK: [[B:%.*]] = load <3 x double>, ptr %b.addr +// DXIL-NOT: @llvm.dx.fdot +// DXIL: [[A0:%.*]] = extractelement <3 x double> [[A]], i64 0 +// DXIL: [[B0:%.*]] = extractelement <3 x double> [[B]], i64 0 +// DXIL: [[MUL0:%.*]] = fmul {{.*}} double [[A0]], [[B0]] +// DXIL: [[A1:%.*]] = extractelement <3 x double> [[A]], i64 1 +// DXIL: [[B1:%.*]] = extractelement <3 x double> [[B]], i64 1 +// DXIL: [[FMA0:%.*]] = call {{.*}} double @llvm.fmuladd.f64(double [[A1]], double [[B1]], double [[MUL0]]) +// DXIL: [[A2:%.*]] = extractelement <3 x double> [[A]], i64 2 +// DXIL: [[B2:%.*]] = extractelement <3 x double> [[B]], i64 2 +// DXIL: [[FMA1:%.*]] = call {{.*}} double @llvm.fmuladd.f64(double [[A2]], double [[B2]], double [[FMA0]]) +// DXIL: ret double [[FMA1]] +// SPIRV: %hlsl.mul = call {{.*}} double @llvm.spv.fdot.v3f64(<3 x double> [[A]], <3 x double> [[B]]) +// SPIRV: ret double %hlsl.mul +export double test_vec_vec_muld(double3 a, double3 b) { return mul(a, b); } + +// -- Case 6: vector * matrix -> vector -- + +// CHECK-LABEL: test_vec_mat_mul +// CHECK: [[V:%.*]] = load <2 x float>, ptr %v.addr +// CHECK: [[M:%.*]] = load <6 x float>, ptr %m.addr +// CHECK: %hlsl.mul = call {{.*}} <3 x float> @llvm.matrix.multiply.v3f32.v2f32.v6f32(<2 x float> [[V]], <6 x float> [[M]], i32 1, i32 2, i32 3) +// CHECK: ret <3 x float> %hlsl.mul +export float3 test_vec_mat_mul(float2 v, float2x3 m) { return mul(v, m); } + +// -- Case 7: matrix * scalar -> matrix -- + +// CHECK-LABEL: test_mat_scalar_mul +// CHECK: [[A:%.*]] = load <6 x float>, ptr %a.addr +// CHECK: [[B:%.*]] = load float, ptr %b.addr +// CHECK: %.splatinsert = insertelement <6 x float> poison, float [[B]], i64 0 +// CHECK: %.splat = shufflevector <6 x float> %.splatinsert, <6 x float> poison, <6 x i32> zeroinitializer +// CHECK: %hlsl.mul = fmul {{.*}} <6 x float> %.splat, [[A]] +// CHECK: ret <6 x float> %hlsl.mul +export float2x3 test_mat_scalar_mul(float2x3 a, float b) { return 
mul(a, b); } + +// -- Case 8: matrix * vector -> vector -- + +// CHECK-LABEL: test_mat_vec_mul +// CHECK: [[M:%.*]] = load <6 x float>, ptr %m.addr +// CHECK: [[V:%.*]] = load <3 x float>, ptr %v.addr +// CHECK: %hlsl.mul = call {{.*}} <2 x float> @llvm.matrix.multiply.v2f32.v6f32.v3f32(<6 x float> [[M]], <3 x float> [[V]], i32 2, i32 3, i32 1) +// CHECK: ret <2 x float> %hlsl.mul +export float2 test_mat_vec_mul(float2x3 m, float3 v) { return mul(m, v); } + +// -- Case 9: matrix * matrix -> matrix -- + +// CHECK-LABEL: test_mat_mat_mul +// CHECK: [[A:%.*]] = load <6 x float>, ptr %a.addr +// CHECK: [[B:%.*]] = load <12 x float>, ptr %b.addr +// CHECK: %hlsl.mul = call {{.*}} <8 x float> @llvm.matrix.multiply.v8f32.v6f32.v12f32(<6 x float> [[A]], <12 x float> [[B]], i32 2, i32 3, i32 4) +// CHECK: ret <8 x float> %hlsl.mul +export float2x4 test_mat_mat_mul(float2x3 a, float3x4 b) { return mul(a, b); } + +// -- Integer matrix multiply -- + +// CHECK-LABEL: test_mat_mat_muli +// CHECK: [[A:%.*]] = load <6 x i32>, ptr %a.addr +// CHECK: [[B:%.*]] = load <12 x i32>, ptr %b.addr +// CHECK: %hlsl.mul = call <8 x i32> @llvm.matrix.multiply.v8i32.v6i32.v12i32(<6 x i32> [[A]], <12 x i32> [[B]], i32 2, i32 3, i32 4) +// CHECK: ret <8 x i32> %hlsl.mul +export int2x4 test_mat_mat_muli(int2x3 a, int3x4 b) { return mul(a, b); } diff --git a/clang/test/SemaHLSL/BuiltIns/mul-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/mul-errors.hlsl new file mode 100644 index 0000000000000..01227126bff10 --- /dev/null +++ b/clang/test/SemaHLSL/BuiltIns/mul-errors.hlsl @@ -0,0 +1,42 @@ +// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.3-library %s -verify + +// expected-note@*:* 54 {{candidate template ignored}} + +// Test mul error cases via template overload resolution + +// Inner dimension mismatch: vector * vector with different sizes +export float test_vec_dim_mismatch(float2 a, float3 b) { + return mul(a, b); + // expected-error@-1 {{no matching function for call to 
'mul'}} +} + +// Inner dimension mismatch: matrix * matrix +export float2x4 test_mat_dim_mismatch(float2x3 a, float4x4 b) { + return mul(a, b); + // expected-error@-1 {{no matching function for call to 'mul'}} +} + +// Inner dimension mismatch: vector * matrix +export float3 test_vec_mat_mismatch(float3 v, float2x3 m) { + return mul(v, m); + // expected-error@-1 {{no matching function for call to 'mul'}} +} + +// Inner dimension mismatch: matrix * vector +export float2 test_mat_vec_mismatch(float2x3 m, float2 v) { + return mul(m, v); + // expected-error@-1 {{no matching function for call to 'mul'}} +} + +// Type mismatch: different element types +export float test_type_mismatch(float a, int b) { + return mul(a, b); + // expected-error@-1 {{no matching function for call to 'mul'}} +} + +// Type mismatch: different vector element types +export float test_vec_type_mismatch(float3 a, int3 b) { + return mul(a, b); + // expected-error@-1 {{no matching function for call to 'mul'}} +} + diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp index c4bf097e5a0f8..78ac0edc7de47 100644 --- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp +++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp @@ -22,6 +22,7 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/IntrinsicsDirectX.h" +#include "llvm/IR/MatrixBuilder.h" #include "llvm/IR/Module.h" #include "llvm/IR/PassManager.h" #include "llvm/IR/Type.h" @@ -229,6 +230,7 @@ static bool isIntrinsicExpansion(Function &F) { case Intrinsic::usub_sat: case Intrinsic::vector_reduce_add: case Intrinsic::vector_reduce_fadd: + case Intrinsic::matrix_multiply: return true; case Intrinsic::dx_resource_load_rawbuffer: return resourceAccessNeeds64BitExpansion( @@ -1043,6 +1045,97 @@ static Value *expandSignIntrinsic(CallInst *Orig) { return Builder.CreateSub(ZextGT, ZextLT); } +// Expand llvm.matrix.multiply by extracting row/column vectors and 
computing +// dot products. +// Result[r,c] = dot(row_r(LHS), col_c(RHS)) +// Element (r,c) is at index c*NumRows + r (column-major). +static Value *expandMatrixMultiply(CallInst *Orig) { + Value *LHS = Orig->getArgOperand(0); + Value *RHS = Orig->getArgOperand(1); + unsigned LHSRows = cast<ConstantInt>(Orig->getArgOperand(2))->getZExtValue(); + unsigned LHSCols = cast<ConstantInt>(Orig->getArgOperand(3))->getZExtValue(); + unsigned RHSCols = cast<ConstantInt>(Orig->getArgOperand(4))->getZExtValue(); + + auto *RetTy = cast<FixedVectorType>(Orig->getType()); + Type *EltTy = RetTy->getElementType(); + bool IsFP = EltTy->isFloatingPointTy(); + + IRBuilder<> Builder(Orig); + + // Column-major indexing: + // LHS row R, element K: index = K * LHSRows + R + // RHS col C, element K: index = C * LHSCols + K + // TODO: support row-major indexing + Value *Result = PoisonValue::get(RetTy); + + // Extract all scalar elements from LHS and RHS once, then reuse them. + unsigned LHSSize = LHSRows * LHSCols; + unsigned RHSSize = LHSCols * RHSCols; + SmallVector<Value *, 16> LHSElts(LHSSize); + SmallVector<Value *, 16> RHSElts(RHSSize); + for (unsigned I = 0; I < LHSSize; ++I) + LHSElts[I] = Builder.CreateExtractElement(LHS, I); + for (unsigned I = 0; I < RHSSize; ++I) + RHSElts[I] = Builder.CreateExtractElement(RHS, I); + + // Choose the appropriate scalar-arg dot intrinsic for floats. + // K=1 and double types use scalar expansion instead. 
+ Intrinsic::ID FloatDotID = Intrinsic::not_intrinsic; + bool UseScalarFP = IsFP && (EltTy->isDoubleTy() || LHSCols == 1); + if (IsFP && !UseScalarFP) { + switch (LHSCols) { + case 2: + FloatDotID = Intrinsic::dx_dot2; + break; + case 3: + FloatDotID = Intrinsic::dx_dot3; + break; + case 4: + FloatDotID = Intrinsic::dx_dot4; + break; + default: + reportFatalUsageError( + "Invalid matrix inner dimension for dot product: must be 2-4"); + return nullptr; + } + } + + for (unsigned C = 0; C < RHSCols; ++C) { + for (unsigned R = 0; R < LHSRows; ++R) { + // Gather row R from LHS and column C from RHS. + SmallVector<Value *, 4> RowElts, ColElts; + for (unsigned K = 0; K < LHSCols; ++K) { + RowElts.push_back(LHSElts[K * LHSRows + R]); + ColElts.push_back(RHSElts[C * LHSCols + K]); + } + + Value *Dot; + if (UseScalarFP) { + // Scalar fmul+fadd expansion for double types and K=1. + Dot = Builder.CreateFMul(RowElts[0], ColElts[0]); + for (unsigned K = 1; K < LHSCols; ++K) + Dot = Builder.CreateFAdd(Dot, + Builder.CreateFMul(RowElts[K], ColElts[K])); + } else if (IsFP) { + // Emit scalar-arg DXIL dot directly (dx.dot2/dx.dot3/dx.dot4). + SmallVector<Value *, 8> Args; + Args.append(RowElts.begin(), RowElts.end()); + Args.append(ColElts.begin(), ColElts.end()); + Dot = Builder.CreateIntrinsic(EltTy, FloatDotID, Args); + } else { + // Integer: emit multiply + imad chain. 
+ Dot = Builder.CreateMul(RowElts[0], ColElts[0]); + for (unsigned K = 1; K < LHSCols; ++K) + Dot = Builder.CreateIntrinsic(EltTy, Intrinsic::dx_imad, + {RowElts[K], ColElts[K], Dot}); + } + unsigned ResIdx = C * LHSRows + R; + Result = Builder.CreateInsertElement(Result, Dot, ResIdx); + } + } + return Result; +} + static bool expandIntrinsic(Function &F, CallInst *Orig) { Value *Result = nullptr; Intrinsic::ID IntrinsicId = F.getIntrinsicID(); @@ -1144,6 +1237,9 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) { case Intrinsic::vector_reduce_fadd: Result = expandVecReduceAdd(Orig, IntrinsicId); break; + case Intrinsic::matrix_multiply: + Result = expandMatrixMultiply(Orig); + break; } if (Result) { Orig->replaceAllUsesWith(Result); diff --git a/llvm/test/CodeGen/DirectX/matrix-multiply.ll b/llvm/test/CodeGen/DirectX/matrix-multiply.ll new file mode 100644 index 0000000000000..dca270f053a02 --- /dev/null +++ b/llvm/test/CodeGen/DirectX/matrix-multiply.ll @@ -0,0 +1,342 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 6 +; RUN: opt -S -dxil-intrinsic-expansion < %s | FileCheck %s + +; Verify that llvm.matrix.multiply is expanded to scalar dot products for DXIL. 
+ +declare <4 x float> @llvm.matrix.multiply.v4f32.v4f32.v4f32(<4 x float>, <4 x float>, i32, i32, i32) +declare <3 x float> @llvm.matrix.multiply.v3f32.v2f32.v6f32(<2 x float>, <6 x float>, i32, i32, i32) +declare <2 x float> @llvm.matrix.multiply.v2f32.v6f32.v3f32(<6 x float>, <3 x float>, i32, i32, i32) +declare <4 x i32> @llvm.matrix.multiply.v4i32.v4i32.v4i32(<4 x i32>, <4 x i32>, i32, i32, i32) +declare <16 x float> @llvm.matrix.multiply.v16f32.v16f32.v16f32(<16 x float>, <16 x float>, i32, i32, i32) +declare <4 x float> @llvm.matrix.multiply.v4f32.v16f32.v4f32(<16 x float>, <4 x float>, i32, i32, i32) +declare <4 x double> @llvm.matrix.multiply.v4f64.v4f64.v4f64(<4 x double>, <4 x double>, i32, i32, i32) +declare <2 x double> @llvm.matrix.multiply.v2f64.v4f64.v2f64(<4 x double>, <2 x double>, i32, i32, i32) +declare <6 x float> @llvm.matrix.multiply.v6f32.v2f32.v3f32(<2 x float>, <3 x float>, i32, i32, i32) +declare <6 x i32> @llvm.matrix.multiply.v6i32.v2i32.v3i32(<2 x i32>, <3 x i32>, i32, i32, i32) + +; 2x2 float: 4 dot2 calls. 
+define <4 x float> @test_float_2x2(<4 x float> %a, <4 x float> %b) { +; CHECK-LABEL: define <4 x float> @test_float_2x2( +; CHECK-SAME: <4 x float> [[A:%.*]], <4 x float> [[B:%.*]]) { +; CHECK-NEXT: [[TMP1:%.*]] = extractelement <4 x float> [[A]], i64 0 +; CHECK-NEXT: [[TMP2:%.*]] = extractelement <4 x float> [[A]], i64 1 +; CHECK-NEXT: [[TMP3:%.*]] = extractelement <4 x float> [[A]], i64 2 +; CHECK-NEXT: [[TMP4:%.*]] = extractelement <4 x float> [[A]], i64 3 +; CHECK-NEXT: [[TMP5:%.*]] = extractelement <4 x float> [[B]], i64 0 +; CHECK-NEXT: [[TMP6:%.*]] = extractelement <4 x float> [[B]], i64 1 +; CHECK-NEXT: [[TMP7:%.*]] = extractelement <4 x float> [[B]], i64 2 +; CHECK-NEXT: [[TMP8:%.*]] = extractelement <4 x float> [[B]], i64 3 +; CHECK-NEXT: [[TMP9:%.*]] = call float @llvm.dx.dot2.f32(float [[TMP1]], float [[TMP3]], float [[TMP5]], float [[TMP6]]) +; CHECK-NEXT: [[TMP10:%.*]] = insertelement <4 x float> poison, float [[TMP9]], i64 0 +; CHECK-NEXT: [[TMP11:%.*]] = call float @llvm.dx.dot2.f32(float [[TMP2]], float [[TMP4]], float [[TMP5]], float [[TMP6]]) +; CHECK-NEXT: [[TMP12:%.*]] = insertelement <4 x float> [[TMP10]], float [[TMP11]], i64 1 +; CHECK-NEXT: [[TMP13:%.*]] = call float @llvm.dx.dot2.f32(float [[TMP1]], float [[TMP3]], float [[TMP7]], float [[TMP8]]) +; CHECK-NEXT: [[TMP14:%.*]] = insertelement <4 x float> [[TMP12]], float [[TMP13]], i64 2 +; CHECK-NEXT: [[TMP15:%.*]] = call float @llvm.dx.dot2.f32(float [[TMP2]], float [[TMP4]], float [[TMP7]], float [[TMP8]]) +; CHECK-NEXT: [[TMP16:%.*]] = insertelement <4 x float> [[TMP14]], float [[TMP15]], i64 3 +; CHECK-NEXT: ret <4 x float> [[TMP16]] +; + %r = call <4 x float> @llvm.matrix.multiply.v4f32.v4f32.v4f32(<4 x float> %a, <4 x float> %b, i32 2, i32 2, i32 2) + ret <4 x float> %r +} + +; 1x2 * 2x3: 3 dot2 calls. 
+define <3 x float> @test_vec_mat(<2 x float> %v, <6 x float> %m) { +; CHECK-LABEL: define <3 x float> @test_vec_mat( +; CHECK-SAME: <2 x float> [[V:%.*]], <6 x float> [[M:%.*]]) { +; CHECK-NEXT: [[TMP1:%.*]] = extractelement <2 x float> [[V]], i64 0 +; CHECK-NEXT: [[TMP2:%.*]] = extractelement <2 x float> [[V]], i64 1 +; CHECK-NEXT: [[TMP3:%.*]] = extractelement <6 x float> [[M]], i64 0 +; CHECK-NEXT: [[TMP4:%.*]] = extractelement <6 x float> [[M]], i64 1 +; CHECK-NEXT: [[TMP5:%.*]] = extractelement <6 x float> [[M]], i64 2 +; CHECK-NEXT: [[TMP6:%.*]] = extractelement <6 x float> [[M]], i64 3 +; CHECK-NEXT: [[TMP7:%.*]] = extractelement <6 x float> [[M]], i64 4 +; CHECK-NEXT: [[TMP8:%.*]] = extractelement <6 x float> [[M]], i64 5 +; CHECK-NEXT: [[TMP9:%.*]] = call float @llvm.dx.dot2.f32(float [[TMP1]], float [[TMP2]], float [[TMP3]], float [[TMP4]]) +; CHECK-NEXT: [[TMP10:%.*]] = insertelement <3 x float> poison, float [[TMP9]], i64 0 +; CHECK-NEXT: [[TMP11:%.*]] = call float @llvm.dx.dot2.f32(float [[TMP1]], float [[TMP2]], float [[TMP5]], float [[TMP6]]) +; CHECK-NEXT: [[TMP12:%.*]] = insertelement <3 x float> [[TMP10]], float [[TMP11]], i64 1 +; CHECK-NEXT: [[TMP13:%.*]] = call float @llvm.dx.dot2.f32(float [[TMP1]], float [[TMP2]], float [[TMP7]], float [[TMP8]]) +; CHECK-NEXT: [[TMP14:%.*]] = insertelement <3 x float> [[TMP12]], float [[TMP13]], i64 2 +; CHECK-NEXT: ret <3 x float> [[TMP14]] +; + %r = call <3 x float> @llvm.matrix.multiply.v3f32.v2f32.v6f32(<2 x float> %v, <6 x float> %m, i32 1, i32 2, i32 3) + ret <3 x float> %r +} + +; 2x3 * 3x1: 2 dot3 calls. 
+define <2 x float> @test_mat_vec(<6 x float> %m, <3 x float> %v) { +; CHECK-LABEL: define <2 x float> @test_mat_vec( +; CHECK-SAME: <6 x float> [[M:%.*]], <3 x float> [[V:%.*]]) { +; CHECK-NEXT: [[TMP1:%.*]] = extractelement <6 x float> [[M]], i64 0 +; CHECK-NEXT: [[TMP2:%.*]] = extractelement <6 x float> [[M]], i64 1 +; CHECK-NEXT: [[TMP3:%.*]] = extractelement <6 x float> [[M]], i64 2 +; CHECK-NEXT: [[TMP4:%.*]] = extractelement <6 x float> [[M]], i64 3 +; CHECK-NEXT: [[TMP5:%.*]] = extractelement <6 x float> [[M]], i64 4 +; CHECK-NEXT: [[TMP6:%.*]] = extractelement <6 x float> [[M]], i64 5 +; CHECK-NEXT: [[TMP7:%.*]] = extractelement <3 x float> [[V]], i64 0 +; CHECK-NEXT: [[TMP8:%.*]] = extractelement <3 x float> [[V]], i64 1 +; CHECK-NEXT: [[TMP9:%.*]] = extractelement <3 x float> [[V]], i64 2 +; CHECK-NEXT: [[TMP10:%.*]] = call float @llvm.dx.dot3.f32(float [[TMP1]], float [[TMP3]], float [[TMP5]], float [[TMP7]], float [[TMP8]], float [[TMP9]]) +; CHECK-NEXT: [[TMP11:%.*]] = insertelement <2 x float> poison, float [[TMP10]], i64 0 +; CHECK-NEXT: [[TMP12:%.*]] = call float @llvm.dx.dot3.f32(float [[TMP2]], float [[TMP4]], float [[TMP6]], float [[TMP7]], float [[TMP8]], float [[TMP9]]) +; CHECK-NEXT: [[TMP13:%.*]] = insertelement <2 x float> [[TMP11]], float [[TMP12]], i64 1 +; CHECK-NEXT: ret <2 x float> [[TMP13]] +; + %r = call <2 x float> @llvm.matrix.multiply.v2f32.v6f32.v3f32(<6 x float> %m, <3 x float> %v, i32 2, i32 3, i32 1) + ret <2 x float> %r +} + +; 2x2 integer: mul + imad chains. 
+define <4 x i32> @test_int_2x2(<4 x i32> %a, <4 x i32> %b) { +; CHECK-LABEL: define <4 x i32> @test_int_2x2( +; CHECK-SAME: <4 x i32> [[A:%.*]], <4 x i32> [[B:%.*]]) { +; CHECK-NEXT: [[TMP1:%.*]] = extractelement <4 x i32> [[A]], i64 0 +; CHECK-NEXT: [[TMP2:%.*]] = extractelement <4 x i32> [[A]], i64 1 +; CHECK-NEXT: [[TMP3:%.*]] = extractelement <4 x i32> [[A]], i64 2 +; CHECK-NEXT: [[TMP4:%.*]] = extractelement <4 x i32> [[A]], i64 3 +; CHECK-NEXT: [[TMP5:%.*]] = extractelement <4 x i32> [[B]], i64 0 +; CHECK-NEXT: [[TMP6:%.*]] = extractelement <4 x i32> [[B]], i64 1 +; CHECK-NEXT: [[TMP7:%.*]] = extractelement <4 x i32> [[B]], i64 2 +; CHECK-NEXT: [[TMP8:%.*]] = extractelement <4 x i32> [[B]], i64 3 +; CHECK-NEXT: [[TMP9:%.*]] = mul i32 [[TMP1]], [[TMP5]] +; CHECK-NEXT: [[TMP10:%.*]] = call i32 @llvm.dx.imad.i32(i32 [[TMP3]], i32 [[TMP6]], i32 [[TMP9]]) +; CHECK-NEXT: [[TMP11:%.*]] = insertelement <4 x i32> poison, i32 [[TMP10]], i64 0 +; CHECK-NEXT: [[TMP12:%.*]] = mul i32 [[TMP2]], [[TMP5]] +; CHECK-NEXT: [[TMP13:%.*]] = call i32 @llvm.dx.imad.i32(i32 [[TMP4]], i32 [[TMP6]], i32 [[TMP12]]) +; CHECK-NEXT: [[TMP14:%.*]] = insertelement <4 x i32> [[TMP11]], i32 [[TMP13]], i64 1 +; CHECK-NEXT: [[TMP15:%.*]] = mul i32 [[TMP1]], [[TMP7]] +; CHECK-NEXT: [[TMP16:%.*]] = call i32 @llvm.dx.imad.i32(i32 [[TMP3]], i32 [[TMP8]], i32 [[TMP15]]) +; CHECK-NEXT: [[TMP17:%.*]] = insertelement <4 x i32> [[TMP14]], i32 [[TMP16]], i64 2 +; CHECK-NEXT: [[TMP18:%.*]] = mul i32 [[TMP2]], [[TMP7]] +; CHECK-NEXT: [[TMP19:%.*]] = call i32 @llvm.dx.imad.i32(i32 [[TMP4]], i32 [[TMP8]], i32 [[TMP18]]) +; CHECK-NEXT: [[TMP20:%.*]] = insertelement <4 x i32> [[TMP17]], i32 [[TMP19]], i64 3 +; CHECK-NEXT: ret <4 x i32> [[TMP20]] +; + %r = call <4 x i32> @llvm.matrix.multiply.v4i32.v4i32.v4i32(<4 x i32> %a, <4 x i32> %b, i32 2, i32 2, i32 2) + ret <4 x i32> %r +} + +; 4x4 float: 16 dot4 calls. 
+define <16 x float> @test_float_4x4(<16 x float> %a, <16 x float> %b) { +; CHECK-LABEL: define <16 x float> @test_float_4x4( +; CHECK-SAME: <16 x float> [[A:%.*]], <16 x float> [[B:%.*]]) { +; CHECK-NEXT: [[TMP1:%.*]] = extractelement <16 x float> [[A]], i64 0 +; CHECK-NEXT: [[TMP2:%.*]] = extractelement <16 x float> [[A]], i64 1 +; CHECK-NEXT: [[TMP3:%.*]] = extractelement <16 x float> [[A]], i64 2 +; CHECK-NEXT: [[TMP4:%.*]] = extractelement <16 x float> [[A]], i64 3 +; CHECK-NEXT: [[TMP5:%.*]] = extractelement <16 x float> [[A]], i64 4 +; CHECK-NEXT: [[TMP6:%.*]] = extractelement <16 x float> [[A]], i64 5 +; CHECK-NEXT: [[TMP7:%.*]] = extractelement <16 x float> [[A]], i64 6 +; CHECK-NEXT: [[TMP8:%.*]] = extractelement <16 x float> [[A]], i64 7 +; CHECK-NEXT: [[TMP9:%.*]] = extractelement <16 x float> [[A]], i64 8 +; CHECK-NEXT: [[TMP10:%.*]] = extractelement <16 x float> [[A]], i64 9 +; CHECK-NEXT: [[TMP11:%.*]] = extractelement <16 x float> [[A]], i64 10 +; CHECK-NEXT: [[TMP12:%.*]] = extractelement <16 x float> [[A]], i64 11 +; CHECK-NEXT: [[TMP13:%.*]] = extractelement <16 x float> [[A]], i64 12 +; CHECK-NEXT: [[TMP14:%.*]] = extractelement <16 x float> [[A]], i64 13 +; CHECK-NEXT: [[TMP15:%.*]] = extractelement <16 x float> [[A]], i64 14 +; CHECK-NEXT: [[TMP16:%.*]] = extractelement <16 x float> [[A]], i64 15 +; CHECK-NEXT: [[TMP17:%.*]] = extractelement <16 x float> [[B]], i64 0 +; CHECK-NEXT: [[TMP18:%.*]] = extractelement <16 x float> [[B]], i64 1 +; CHECK-NEXT: [[TMP19:%.*]] = extractelement <16 x float> [[B]], i64 2 +; CHECK-NEXT: [[TMP20:%.*]] = extractelement <16 x float> [[B]], i64 3 +; CHECK-NEXT: [[TMP21:%.*]] = extractelement <16 x float> [[B]], i64 4 +; CHECK-NEXT: [[TMP22:%.*]] = extractelement <16 x float> [[B]], i64 5 +; CHECK-NEXT: [[TMP23:%.*]] = extractelement <16 x float> [[B]], i64 6 +; CHECK-NEXT: [[TMP24:%.*]] = extractelement <16 x float> [[B]], i64 7 +; CHECK-NEXT: [[TMP25:%.*]] = extractelement <16 x float> [[B]], i64 8 +; 
CHECK-NEXT: [[TMP26:%.*]] = extractelement <16 x float> [[B]], i64 9 +; CHECK-NEXT: [[TMP27:%.*]] = extractelement <16 x float> [[B]], i64 10 +; CHECK-NEXT: [[TMP28:%.*]] = extractelement <16 x float> [[B]], i64 11 +; CHECK-NEXT: [[TMP29:%.*]] = extractelement <16 x float> [[B]], i64 12 +; CHECK-NEXT: [[TMP30:%.*]] = extractelement <16 x float> [[B]], i64 13 +; CHECK-NEXT: [[TMP31:%.*]] = extractelement <16 x float> [[B]], i64 14 +; CHECK-NEXT: [[TMP32:%.*]] = extractelement <16 x float> [[B]], i64 15 +; CHECK-NEXT: [[TMP33:%.*]] = call float @llvm.dx.dot4.f32(float [[TMP1]], float [[TMP5]], float [[TMP9]], float [[TMP13]], float [[TMP17]], float [[TMP18]], float [[TMP19]], float [[TMP20]]) +; CHECK-NEXT: [[TMP34:%.*]] = insertelement <16 x float> poison, float [[TMP33]], i64 0 +; CHECK-NEXT: [[TMP35:%.*]] = call float @llvm.dx.dot4.f32(float [[TMP2]], float [[TMP6]], float [[TMP10]], float [[TMP14]], float [[TMP17]], float [[TMP18]], float [[TMP19]], float [[TMP20]]) +; CHECK-NEXT: [[TMP36:%.*]] = insertelement <16 x float> [[TMP34]], float [[TMP35]], i64 1 +; CHECK-NEXT: [[TMP37:%.*]] = call float @llvm.dx.dot4.f32(float [[TMP3]], float [[TMP7]], float [[TMP11]], float [[TMP15]], float [[TMP17]], float [[TMP18]], float [[TMP19]], float [[TMP20]]) +; CHECK-NEXT: [[TMP38:%.*]] = insertelement <16 x float> [[TMP36]], float [[TMP37]], i64 2 +; CHECK-NEXT: [[TMP39:%.*]] = call float @llvm.dx.dot4.f32(float [[TMP4]], float [[TMP8]], float [[TMP12]], float [[TMP16]], float [[TMP17]], float [[TMP18]], float [[TMP19]], float [[TMP20]]) +; CHECK-NEXT: [[TMP40:%.*]] = insertelement <16 x float> [[TMP38]], float [[TMP39]], i64 3 +; CHECK-NEXT: [[TMP41:%.*]] = call float @llvm.dx.dot4.f32(float [[TMP1]], float [[TMP5]], float [[TMP9]], float [[TMP13]], float [[TMP21]], float [[TMP22]], float [[TMP23]], float [[TMP24]]) +; CHECK-NEXT: [[TMP42:%.*]] = insertelement <16 x float> [[TMP40]], float [[TMP41]], i64 4 +; CHECK-NEXT: [[TMP43:%.*]] = call float @llvm.dx.dot4.f32(float 
[[TMP2]], float [[TMP6]], float [[TMP10]], float [[TMP14]], float [[TMP21]], float [[TMP22]], float [[TMP23]], float [[TMP24]]) +; CHECK-NEXT: [[TMP44:%.*]] = insertelement <16 x float> [[TMP42]], float [[TMP43]], i64 5 +; CHECK-NEXT: [[TMP45:%.*]] = call float @llvm.dx.dot4.f32(float [[TMP3]], float [[TMP7]], float [[TMP11]], float [[TMP15]], float [[TMP21]], float [[TMP22]], float [[TMP23]], float [[TMP24]]) +; CHECK-NEXT: [[TMP46:%.*]] = insertelement <16 x float> [[TMP44]], float [[TMP45]], i64 6 +; CHECK-NEXT: [[TMP47:%.*]] = call float @llvm.dx.dot4.f32(float [[TMP4]], float [[TMP8]], float [[TMP12]], float [[TMP16]], float [[TMP21]], float [[TMP22]], float [[TMP23]], float [[TMP24]]) +; CHECK-NEXT: [[TMP48:%.*]] = insertelement <16 x float> [[TMP46]], float [[TMP47]], i64 7 +; CHECK-NEXT: [[TMP49:%.*]] = call float @llvm.dx.dot4.f32(float [[TMP1]], float [[TMP5]], float [[TMP9]], float [[TMP13]], float [[TMP25]], float [[TMP26]], float [[TMP27]], float [[TMP28]]) +; CHECK-NEXT: [[TMP50:%.*]] = insertelement <16 x float> [[TMP48]], float [[TMP49]], i64 8 +; CHECK-NEXT: [[TMP51:%.*]] = call float @llvm.dx.dot4.f32(float [[TMP2]], float [[TMP6]], float [[TMP10]], float [[TMP14]], float [[TMP25]], float [[TMP26]], float [[TMP27]], float [[TMP28]]) +; CHECK-NEXT: [[TMP52:%.*]] = insertelement <16 x float> [[TMP50]], float [[TMP51]], i64 9 +; CHECK-NEXT: [[TMP53:%.*]] = call float @llvm.dx.dot4.f32(float [[TMP3]], float [[TMP7]], float [[TMP11]], float [[TMP15]], float [[TMP25]], float [[TMP26]], float [[TMP27]], float [[TMP28]]) +; CHECK-NEXT: [[TMP54:%.*]] = insertelement <16 x float> [[TMP52]], float [[TMP53]], i64 10 +; CHECK-NEXT: [[TMP55:%.*]] = call float @llvm.dx.dot4.f32(float [[TMP4]], float [[TMP8]], float [[TMP12]], float [[TMP16]], float [[TMP25]], float [[TMP26]], float [[TMP27]], float [[TMP28]]) +; CHECK-NEXT: [[TMP56:%.*]] = insertelement <16 x float> [[TMP54]], float [[TMP55]], i64 11 +; CHECK-NEXT: [[TMP57:%.*]] = call float 
@llvm.dx.dot4.f32(float [[TMP1]], float [[TMP5]], float [[TMP9]], float [[TMP13]], float [[TMP29]], float [[TMP30]], float [[TMP31]], float [[TMP32]]) +; CHECK-NEXT: [[TMP58:%.*]] = insertelement <16 x float> [[TMP56]], float [[TMP57]], i64 12 +; CHECK-NEXT: [[TMP59:%.*]] = call float @llvm.dx.dot4.f32(float [[TMP2]], float [[TMP6]], float [[TMP10]], float [[TMP14]], float [[TMP29]], float [[TMP30]], float [[TMP31]], float [[TMP32]]) +; CHECK-NEXT: [[TMP60:%.*]] = insertelement <16 x float> [[TMP58]], float [[TMP59]], i64 13 +; CHECK-NEXT: [[TMP61:%.*]] = call float @llvm.dx.dot4.f32(float [[TMP3]], float [[TMP7]], float [[TMP11]], float [[TMP15]], float [[TMP29]], float [[TMP30]], float [[TMP31]], float [[TMP32]]) +; CHECK-NEXT: [[TMP62:%.*]] = insertelement <16 x float> [[TMP60]], float [[TMP61]], i64 14 +; CHECK-NEXT: [[TMP63:%.*]] = call float @llvm.dx.dot4.f32(float [[TMP4]], float [[TMP8]], float [[TMP12]], float [[TMP16]], float [[TMP29]], float [[TMP30]], float [[TMP31]], float [[TMP32]]) +; CHECK-NEXT: [[TMP64:%.*]] = insertelement <16 x float> [[TMP62]], float [[TMP63]], i64 15 +; CHECK-NEXT: ret <16 x float> [[TMP64]] +; + %r = call <16 x float> @llvm.matrix.multiply.v16f32.v16f32.v16f32(<16 x float> %a, <16 x float> %b, i32 4, i32 4, i32 4) + ret <16 x float> %r +} + +; 4x4 * 4x1: 4 dot4 calls. 
+define <4 x float> @test_mat4x4_vec4(<16 x float> %m, <4 x float> %v) { +; CHECK-LABEL: define <4 x float> @test_mat4x4_vec4( +; CHECK-SAME: <16 x float> [[M:%.*]], <4 x float> [[V:%.*]]) { +; CHECK-NEXT: [[TMP1:%.*]] = extractelement <16 x float> [[M]], i64 0 +; CHECK-NEXT: [[TMP2:%.*]] = extractelement <16 x float> [[M]], i64 1 +; CHECK-NEXT: [[TMP3:%.*]] = extractelement <16 x float> [[M]], i64 2 +; CHECK-NEXT: [[TMP4:%.*]] = extractelement <16 x float> [[M]], i64 3 +; CHECK-NEXT: [[TMP5:%.*]] = extractelement <16 x float> [[M]], i64 4 +; CHECK-NEXT: [[TMP6:%.*]] = extractelement <16 x float> [[M]], i64 5 +; CHECK-NEXT: [[TMP7:%.*]] = extractelement <16 x float> [[M]], i64 6 +; CHECK-NEXT: [[TMP8:%.*]] = extractelement <16 x float> [[M]], i64 7 +; CHECK-NEXT: [[TMP9:%.*]] = extractelement <16 x float> [[M]], i64 8 +; CHECK-NEXT: [[TMP10:%.*]] = extractelement <16 x float> [[M]], i64 9 +; CHECK-NEXT: [[TMP11:%.*]] = extractelement <16 x float> [[M]], i64 10 +; CHECK-NEXT: [[TMP12:%.*]] = extractelement <16 x float> [[M]], i64 11 +; CHECK-NEXT: [[TMP13:%.*]] = extractelement <16 x float> [[M]], i64 12 +; CHECK-NEXT: [[TMP14:%.*]] = extractelement <16 x float> [[M]], i64 13 +; CHECK-NEXT: [[TMP15:%.*]] = extractelement <16 x float> [[M]], i64 14 +; CHECK-NEXT: [[TMP16:%.*]] = extractelement <16 x float> [[M]], i64 15 +; CHECK-NEXT: [[TMP17:%.*]] = extractelement <4 x float> [[V]], i64 0 +; CHECK-NEXT: [[TMP18:%.*]] = extractelement <4 x float> [[V]], i64 1 +; CHECK-NEXT: [[TMP19:%.*]] = extractelement <4 x float> [[V]], i64 2 +; CHECK-NEXT: [[TMP20:%.*]] = extractelement <4 x float> [[V]], i64 3 +; CHECK-NEXT: [[TMP21:%.*]] = call float @llvm.dx.dot4.f32(float [[TMP1]], float [[TMP5]], float [[TMP9]], float [[TMP13]], float [[TMP17]], float [[TMP18]], float [[TMP19]], float [[TMP20]]) +; CHECK-NEXT: [[TMP22:%.*]] = insertelement <4 x float> poison, float [[TMP21]], i64 0 +; CHECK-NEXT: [[TMP23:%.*]] = call float @llvm.dx.dot4.f32(float [[TMP2]], float [[TMP6]], 
float [[TMP10]], float [[TMP14]], float [[TMP17]], float [[TMP18]], float [[TMP19]], float [[TMP20]]) +; CHECK-NEXT: [[TMP24:%.*]] = insertelement <4 x float> [[TMP22]], float [[TMP23]], i64 1 +; CHECK-NEXT: [[TMP25:%.*]] = call float @llvm.dx.dot4.f32(float [[TMP3]], float [[TMP7]], float [[TMP11]], float [[TMP15]], float [[TMP17]], float [[TMP18]], float [[TMP19]], float [[TMP20]]) +; CHECK-NEXT: [[TMP26:%.*]] = insertelement <4 x float> [[TMP24]], float [[TMP25]], i64 2 +; CHECK-NEXT: [[TMP27:%.*]] = call float @llvm.dx.dot4.f32(float [[TMP4]], float [[TMP8]], float [[TMP12]], float [[TMP16]], float [[TMP17]], float [[TMP18]], float [[TMP19]], float [[TMP20]]) +; CHECK-NEXT: [[TMP28:%.*]] = insertelement <4 x float> [[TMP26]], float [[TMP27]], i64 3 +; CHECK-NEXT: ret <4 x float> [[TMP28]] +; + %r = call <4 x float> @llvm.matrix.multiply.v4f32.v16f32.v4f32(<16 x float> %m, <4 x float> %v, i32 4, i32 4, i32 1) + ret <4 x float> %r +} + +; 2x2 double: scalar fmul + fadd chains. +define <4 x double> @test_double_2x2(<4 x double> %a, <4 x double> %b) { +; CHECK-LABEL: define <4 x double> @test_double_2x2( +; CHECK-SAME: <4 x double> [[A:%.*]], <4 x double> [[B:%.*]]) { +; CHECK-NEXT: [[TMP1:%.*]] = extractelement <4 x double> [[A]], i64 0 +; CHECK-NEXT: [[TMP2:%.*]] = extractelement <4 x double> [[A]], i64 1 +; CHECK-NEXT: [[TMP3:%.*]] = extractelement <4 x double> [[A]], i64 2 +; CHECK-NEXT: [[TMP4:%.*]] = extractelement <4 x double> [[A]], i64 3 +; CHECK-NEXT: [[TMP5:%.*]] = extractelement <4 x double> [[B]], i64 0 +; CHECK-NEXT: [[TMP6:%.*]] = extractelement <4 x double> [[B]], i64 1 +; CHECK-NEXT: [[TMP7:%.*]] = extractelement <4 x double> [[B]], i64 2 +; CHECK-NEXT: [[TMP8:%.*]] = extractelement <4 x double> [[B]], i64 3 +; CHECK-NEXT: [[TMP9:%.*]] = fmul double [[TMP1]], [[TMP5]] +; CHECK-NEXT: [[TMP10:%.*]] = fmul double [[TMP3]], [[TMP6]] +; CHECK-NEXT: [[TMP11:%.*]] = fadd double [[TMP9]], [[TMP10]] +; CHECK-NEXT: [[TMP12:%.*]] = insertelement <4 x double> 
poison, double [[TMP11]], i64 0 +; CHECK-NEXT: [[TMP13:%.*]] = fmul double [[TMP2]], [[TMP5]] +; CHECK-NEXT: [[TMP14:%.*]] = fmul double [[TMP4]], [[TMP6]] +; CHECK-NEXT: [[TMP15:%.*]] = fadd double [[TMP13]], [[TMP14]] +; CHECK-NEXT: [[TMP16:%.*]] = insertelement <4 x double> [[TMP12]], double [[TMP15]], i64 1 +; CHECK-NEXT: [[TMP17:%.*]] = fmul double [[TMP1]], [[TMP7]] +; CHECK-NEXT: [[TMP18:%.*]] = fmul double [[TMP3]], [[TMP8]] +; CHECK-NEXT: [[TMP19:%.*]] = fadd double [[TMP17]], [[TMP18]] +; CHECK-NEXT: [[TMP20:%.*]] = insertelement <4 x double> [[TMP16]], double [[TMP19]], i64 2 +; CHECK-NEXT: [[TMP21:%.*]] = fmul double [[TMP2]], [[TMP7]] +; CHECK-NEXT: [[TMP22:%.*]] = fmul double [[TMP4]], [[TMP8]] +; CHECK-NEXT: [[TMP23:%.*]] = fadd double [[TMP21]], [[TMP22]] +; CHECK-NEXT: [[TMP24:%.*]] = insertelement <4 x double> [[TMP20]], double [[TMP23]], i64 3 +; CHECK-NEXT: ret <4 x double> [[TMP24]] +; + %r = call <4 x double> @llvm.matrix.multiply.v4f64.v4f64.v4f64(<4 x double> %a, <4 x double> %b, i32 2, i32 2, i32 2) + ret <4 x double> %r +} + +; 2x2 double * 2x1 double: 2 scalar fmul + fadd chains. 
+define <2 x double> @test_double_mat_vec(<4 x double> %m, <2 x double> %v) { +; CHECK-LABEL: define <2 x double> @test_double_mat_vec( +; CHECK-SAME: <4 x double> [[M:%.*]], <2 x double> [[V:%.*]]) { +; CHECK-NEXT: [[TMP1:%.*]] = extractelement <4 x double> [[M]], i64 0 +; CHECK-NEXT: [[TMP2:%.*]] = extractelement <4 x double> [[M]], i64 1 +; CHECK-NEXT: [[TMP3:%.*]] = extractelement <4 x double> [[M]], i64 2 +; CHECK-NEXT: [[TMP4:%.*]] = extractelement <4 x double> [[M]], i64 3 +; CHECK-NEXT: [[TMP5:%.*]] = extractelement <2 x double> [[V]], i64 0 +; CHECK-NEXT: [[TMP6:%.*]] = extractelement <2 x double> [[V]], i64 1 +; CHECK-NEXT: [[TMP7:%.*]] = fmul double [[TMP1]], [[TMP5]] +; CHECK-NEXT: [[TMP8:%.*]] = fmul double [[TMP3]], [[TMP6]] +; CHECK-NEXT: [[TMP9:%.*]] = fadd double [[TMP7]], [[TMP8]] +; CHECK-NEXT: [[TMP10:%.*]] = insertelement <2 x double> poison, double [[TMP9]], i64 0 +; CHECK-NEXT: [[TMP11:%.*]] = fmul double [[TMP2]], [[TMP5]] +; CHECK-NEXT: [[TMP12:%.*]] = fmul double [[TMP4]], [[TMP6]] +; CHECK-NEXT: [[TMP13:%.*]] = fadd double [[TMP11]], [[TMP12]] +; CHECK-NEXT: [[TMP14:%.*]] = insertelement <2 x double> [[TMP10]], double [[TMP13]], i64 1 +; CHECK-NEXT: ret <2 x double> [[TMP14]] +; + %r = call <2 x double> @llvm.matrix.multiply.v2f64.v4f64.v2f64(<4 x double> %m, <2 x double> %v, i32 2, i32 2, i32 1) + ret <2 x double> %r +} + +; K=1 float outer product (2x1 * 1x3 = 2x3): each element is a single fmul. 
+define <6 x float> @test_k1_outer_product(<2 x float> %a, <3 x float> %b) { +; CHECK-LABEL: define <6 x float> @test_k1_outer_product( +; CHECK-SAME: <2 x float> [[A:%.*]], <3 x float> [[B:%.*]]) { +; CHECK-NEXT: [[TMP1:%.*]] = extractelement <2 x float> [[A]], i64 0 +; CHECK-NEXT: [[TMP2:%.*]] = extractelement <2 x float> [[A]], i64 1 +; CHECK-NEXT: [[TMP3:%.*]] = extractelement <3 x float> [[B]], i64 0 +; CHECK-NEXT: [[TMP4:%.*]] = extractelement <3 x float> [[B]], i64 1 +; CHECK-NEXT: [[TMP5:%.*]] = extractelement <3 x float> [[B]], i64 2 +; CHECK-NEXT: [[TMP6:%.*]] = fmul float [[TMP1]], [[TMP3]] +; CHECK-NEXT: [[TMP7:%.*]] = insertelement <6 x float> poison, float [[TMP6]], i64 0 +; CHECK-NEXT: [[TMP8:%.*]] = fmul float [[TMP2]], [[TMP3]] +; CHECK-NEXT: [[TMP9:%.*]] = insertelement <6 x float> [[TMP7]], float [[TMP8]], i64 1 +; CHECK-NEXT: [[TMP10:%.*]] = fmul float [[TMP1]], [[TMP4]] +; CHECK-NEXT: [[TMP11:%.*]] = insertelement <6 x float> [[TMP9]], float [[TMP10]], i64 2 +; CHECK-NEXT: [[TMP12:%.*]] = fmul float [[TMP2]], [[TMP4]] +; CHECK-NEXT: [[TMP13:%.*]] = insertelement <6 x float> [[TMP11]], float [[TMP12]], i64 3 +; CHECK-NEXT: [[TMP14:%.*]] = fmul float [[TMP1]], [[TMP5]] +; CHECK-NEXT: [[TMP15:%.*]] = insertelement <6 x float> [[TMP13]], float [[TMP14]], i64 4 +; CHECK-NEXT: [[TMP16:%.*]] = fmul float [[TMP2]], [[TMP5]] +; CHECK-NEXT: [[TMP17:%.*]] = insertelement <6 x float> [[TMP15]], float [[TMP16]], i64 5 +; CHECK-NEXT: ret <6 x float> [[TMP17]] +; + %r = call <6 x float> @llvm.matrix.multiply.v6f32.v2f32.v3f32(<2 x float> %a, <3 x float> %b, i32 2, i32 1, i32 3) + ret <6 x float> %r +} + +; K=1 integer outer product (2x1 * 1x3 = 2x3): each element is a single mul. 
+define <6 x i32> @test_k1_int_outer_product(<2 x i32> %a, <3 x i32> %b) { +; CHECK-LABEL: define <6 x i32> @test_k1_int_outer_product( +; CHECK-SAME: <2 x i32> [[A:%.*]], <3 x i32> [[B:%.*]]) { +; CHECK-NEXT: [[TMP1:%.*]] = extractelement <2 x i32> [[A]], i64 0 +; CHECK-NEXT: [[TMP2:%.*]] = extractelement <2 x i32> [[A]], i64 1 +; CHECK-NEXT: [[TMP3:%.*]] = extractelement <3 x i32> [[B]], i64 0 +; CHECK-NEXT: [[TMP4:%.*]] = extractelement <3 x i32> [[B]], i64 1 +; CHECK-NEXT: [[TMP5:%.*]] = extractelement <3 x i32> [[B]], i64 2 +; CHECK-NEXT: [[TMP6:%.*]] = mul i32 [[TMP1]], [[TMP3]] +; CHECK-NEXT: [[TMP7:%.*]] = insertelement <6 x i32> poison, i32 [[TMP6]], i64 0 +; CHECK-NEXT: [[TMP8:%.*]] = mul i32 [[TMP2]], [[TMP3]] +; CHECK-NEXT: [[TMP9:%.*]] = insertelement <6 x i32> [[TMP7]], i32 [[TMP8]], i64 1 +; CHECK-NEXT: [[TMP10:%.*]] = mul i32 [[TMP1]], [[TMP4]] +; CHECK-NEXT: [[TMP11:%.*]] = insertelement <6 x i32> [[TMP9]], i32 [[TMP10]], i64 2 +; CHECK-NEXT: [[TMP12:%.*]] = mul i32 [[TMP2]], [[TMP4]] +; CHECK-NEXT: [[TMP13:%.*]] = insertelement <6 x i32> [[TMP11]], i32 [[TMP12]], i64 3 +; CHECK-NEXT: [[TMP14:%.*]] = mul i32 [[TMP1]], [[TMP5]] +; CHECK-NEXT: [[TMP15:%.*]] = insertelement <6 x i32> [[TMP13]], i32 [[TMP14]], i64 4 +; CHECK-NEXT: [[TMP16:%.*]] = mul i32 [[TMP2]], [[TMP5]] +; CHECK-NEXT: [[TMP17:%.*]] = insertelement <6 x i32> [[TMP15]], i32 [[TMP16]], i64 5 +; CHECK-NEXT: ret <6 x i32> [[TMP17]] +; + %r = call <6 x i32> @llvm.matrix.multiply.v6i32.v2i32.v3i32(<2 x i32> %a, <3 x i32> %b, i32 2, i32 1, i32 3) + ret <6 x i32> %r +} >From 4a105410fac0931c1bfec02e6356d5358faf37b9 Mon Sep 17 00:00:00 2001 From: Deric Cheung <[email protected]> Date: Thu, 5 Mar 2026 13:50:58 -0800 Subject: [PATCH 2/4] Split mul into header and builtin implementations - All non-matrix cases of mul are implemented in HLSL headers only - The case of mul between two double vectors is now always scalarized, even for SPIR-V which supports dot product between two double vectors - 
This is fine since the HLSL dot function will also not use the SPIR-V dot function for double vectors - Updated the codegen test for mul, which now also uses -O1 Assisted-by: claude-opus-4.6 --- clang/lib/CodeGen/CGHLSLBuiltins.cpp | 66 +-------- .../lib/Headers/hlsl/hlsl_alias_intrinsics.h | 5 - .../lib/Headers/hlsl/hlsl_intrinsic_helpers.h | 15 +++ clang/lib/Headers/hlsl/hlsl_intrinsics.h | 49 +++++++ clang/lib/Sema/SemaHLSL.cpp | 20 +-- clang/test/CodeGenHLSL/builtins/mul.hlsl | 127 +++++++----------- 6 files changed, 124 insertions(+), 158 deletions(-) diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp b/clang/lib/CodeGen/CGHLSLBuiltins.cpp index 34e92c42fdef9..dd4ce85897313 100644 --- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp +++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp @@ -1012,85 +1012,29 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID, Value *Op1 = EmitScalarExpr(E->getArg(1)); QualType QTy0 = E->getArg(0)->getType(); QualType QTy1 = E->getArg(1)->getType(); - llvm::Type *T0 = Op0->getType(); - bool IsScalar0 = QTy0->isScalarType(); bool IsVec0 = QTy0->isVectorType(); bool IsMat0 = QTy0->isConstantMatrixType(); - bool IsScalar1 = QTy1->isScalarType(); - bool IsVec1 = QTy1->isVectorType(); bool IsMat1 = QTy1->isConstantMatrixType(); - bool IsFP = - QTy0->hasFloatingRepresentation() || QTy1->hasFloatingRepresentation(); - - // Cases 1-4, 7: scalar * scalar/vector/matrix or vector/matrix * scalar - if (IsScalar0 || IsScalar1) { - // Splat scalar to match the other operand's type - Value *Scalar = IsScalar0 ? Op0 : Op1; - Value *Other = IsScalar0 ? Op1 : Op0; - llvm::Type *OtherTy = Other->getType(); - - // Note: Matrices are flat vectors in the IR, so the following - // if-condition is also true when Other is a matrix, not just a vector. 
- if (OtherTy->isVectorTy()) { - unsigned NumElts = cast<FixedVectorType>(OtherTy)->getNumElements(); - Scalar = Builder.CreateVectorSplat(NumElts, Scalar); - } - - if (IsFP) - return Builder.CreateFMul(Scalar, Other, "hlsl.mul"); - return Builder.CreateMul(Scalar, Other, "hlsl.mul"); - } - - // Case 5: vector * vector -> scalar (dot product) - if (IsVec0 && IsVec1) { - auto *VecTy0 = E->getArg(0)->getType()->castAs<VectorType>(); - QualType EltQTy = VecTy0->getElementType(); - - // DXIL doesn't have a dot product intrinsic for double vectors, - // so expand to scalar multiply-add for DXIL. - if (CGM.getTarget().getTriple().isDXIL() && - EltQTy->isSpecificBuiltinType(BuiltinType::Double)) { - unsigned NumElts = cast<FixedVectorType>(T0)->getNumElements(); - Value *Sum = nullptr; - for (unsigned I = 0; I < NumElts; ++I) { - Value *L = Builder.CreateExtractElement(Op0, I); - Value *R = Builder.CreateExtractElement(Op1, I); - if (Sum) - Sum = Builder.CreateIntrinsic(Sum->getType(), Intrinsic::fmuladd, - {L, R, Sum}); - else - Sum = Builder.CreateFMul(L, R); - } - return Sum; - } - - return Builder.CreateIntrinsic( - /*ReturnType=*/T0->getScalarType(), - getDotProductIntrinsic(CGM.getHLSLRuntime(), EltQTy), - ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.mul"); - } - // Cases 6, 8, 9: matrix involved -> use llvm.matrix.multiply + // Only matrix-involved cases reach the builtin (cases 6, 8, 9). 
llvm::MatrixBuilder MB(Builder); if (IsVec0 && IsMat1) { - // vector<N> * matrix<N,M> -> vector<M> - // Treat vector as 1×N matrix + // Case 6: vector<N> * matrix<N,M> -> vector<M> unsigned N = QTy0->castAs<VectorType>()->getNumElements(); auto *MatTy = QTy1->castAs<ConstantMatrixType>(); unsigned M = MatTy->getNumColumns(); return MB.CreateMatrixMultiply(Op0, Op1, 1, N, M, "hlsl.mul"); } - if (IsMat0 && IsVec1) { - // matrix<M,N> * vector<N> -> vector<M> - // Treat vector as N×1 matrix + if (IsMat0 && !IsMat1) { + // Case 8: matrix<M,N> * vector<N> -> vector<M> auto *MatTy = QTy0->castAs<ConstantMatrixType>(); unsigned Rows = MatTy->getNumRows(); unsigned Cols = MatTy->getNumColumns(); return MB.CreateMatrixMultiply(Op0, Op1, Rows, Cols, 1, "hlsl.mul"); } + // Case 9: matrix<M,K> * matrix<K,N> -> matrix<M,N> assert(IsMat0 && IsMat1); - // matrix<M,K> * matrix<K,N> -> matrix<M,N> auto *MatTy0 = QTy0->castAs<ConstantMatrixType>(); auto *MatTy1 = QTy1->castAs<ConstantMatrixType>(); return MB.CreateMatrixMultiply(Op0, Op1, MatTy0->getNumRows(), diff --git a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h index 2e9847803a8a1..e676ca07e325a 100644 --- a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h +++ b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h @@ -1829,11 +1829,6 @@ template <typename T, int R, int C> _HLSL_BUILTIN_ALIAS(__builtin_hlsl_mul) vector<T, C> mul(vector<T, R>, matrix<T, R, C>); -// Case 7: matrix * scalar -> matrix -template <typename T, int R, int C> -_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mul) -matrix<T, R, C> mul(matrix<T, R, C>, T); - // Case 8: matrix * vector -> vector template <typename T, int R, int C> _HLSL_BUILTIN_ALIAS(__builtin_hlsl_mul) diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h index dfd4659637929..858670359f39c 100644 --- a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h +++ b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h @@ 
-58,6 +58,21 @@ constexpr float dot2add_impl(half2 a, half2 b, float c) { #endif } +template <typename T, int N> +constexpr enable_if_t<!is_same<double, T>::value, T> +mul_vec_impl(vector<T, N> x, vector<T, N> y) { + return dot(x, y); +} + +// Double vectors do not have a dot intrinsic, so expand manually. +template <typename T, int N> +enable_if_t<is_same<double, T>::value, T> mul_vec_impl(vector<T, N> x, + vector<T, N> y) { + T sum = x[0] * y[0]; + [unroll] for (int i = 1; i < N; ++i) sum += x[i] * y[i]; + return sum; +} + template <typename T> constexpr T reflect_impl(T I, T N) { return I - 2 * N * I * N; } diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h index 330b3f12635e4..315f4ee2bf9d7 100644 --- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h +++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h @@ -790,5 +790,54 @@ fwidth(__detail::HLSL_FIXED_VECTOR<float, N> input) { return __detail::fwidth_impl(input); } +//===----------------------------------------------------------------------===// +// mul builtins +//===----------------------------------------------------------------------===// + +/// \fn R mul(X x, Y y) +/// \brief Multiplies x and y using matrix math. +/// \param x [in] The first input value. If x is a vector, it is treated as a +/// row vector. +/// \param y [in] The second input value. If y is a vector, it is treated as a +/// column vector. +/// +/// The inner dimension x-columns and y-rows must be equal. The result has the +/// dimension x-rows x y-columns. When both x and y are vectors, the result is +/// a dot product (scalar). Scalar operands are multiplied element-wise. 
+ +// Case 1: scalar * scalar -> scalar +template <typename T> +constexpr __detail::enable_if_t<__detail::is_arithmetic<T>::Value, T> mul(T x, + T y) { + return x * y; +} + +// Case 2: scalar * vector -> vector +template <typename T, int N> constexpr vector<T, N> mul(T x, vector<T, N> y) { + return x * y; +} + +// Case 3: scalar * matrix -> matrix +template <typename T, int R, int C> +constexpr matrix<T, R, C> mul(T x, matrix<T, R, C> y) { + return x * y; +} + +// Case 4: vector * scalar -> vector +template <typename T, int N> constexpr vector<T, N> mul(vector<T, N> x, T y) { + return x * y; +} + +// Case 5: vector * vector -> scalar (dot product) +template <typename T, int N> T mul(vector<T, N> x, vector<T, N> y) { + return __detail::mul_vec_impl(x, y); +} + +// Case 7: matrix * scalar -> matrix +template <typename T, int R, int C> +constexpr matrix<T, R, C> mul(matrix<T, R, C> x, T y) { + return x * y; +} + } // namespace hlsl #endif //_HLSL_HLSL_INTRINSICS_H_ diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index 46a30acd95b68..44a676110ff55 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -3794,34 +3794,24 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) { QualType EltTy0 = getElemType(Ty0); - bool IsScalar0 = Ty0->isScalarType(); bool IsVec0 = Ty0->isVectorType(); bool IsMat0 = Ty0->isConstantMatrixType(); - bool IsScalar1 = Ty1->isScalarType(); bool IsVec1 = Ty1->isVectorType(); bool IsMat1 = Ty1->isConstantMatrixType(); QualType RetTy; - if (IsScalar0 && IsScalar1) { - RetTy = EltTy0; - } else if (IsScalar0 && IsVec1) { - RetTy = Ty1; - } else if (IsScalar0 && IsMat1) { - RetTy = Ty1; - } else if (IsVec0 && IsScalar1) { - RetTy = Ty0; - } else if (IsVec0 && IsVec1) { - RetTy = EltTy0; - } else if (IsVec0 && IsMat1) { + // Only matrix-involved cases reach the builtin (cases 6, 8, 9). 
+ if (IsVec0 && IsMat1) { + // Case 6: vector * matrix -> vector auto *MatTy = Ty1->castAs<ConstantMatrixType>(); RetTy = getASTContext().getExtVectorType(EltTy0, MatTy->getNumColumns()); - } else if (IsMat0 && IsScalar1) { - RetTy = Ty0; } else if (IsMat0 && IsVec1) { + // Case 8: matrix * vector -> vector auto *MatTy = Ty0->castAs<ConstantMatrixType>(); RetTy = getASTContext().getExtVectorType(EltTy0, MatTy->getNumRows()); } else { + // Case 9: matrix * matrix -> matrix assert(IsMat0 && IsMat1); auto *MatTy0 = Ty0->castAs<ConstantMatrixType>(); auto *MatTy1 = Ty1->castAs<ConstantMatrixType>(); diff --git a/clang/test/CodeGenHLSL/builtins/mul.hlsl b/clang/test/CodeGenHLSL/builtins/mul.hlsl index 0a95d6004e567..f073191d45862 100644 --- a/clang/test/CodeGenHLSL/builtins/mul.hlsl +++ b/clang/test/CodeGenHLSL/builtins/mul.hlsl @@ -1,143 +1,116 @@ -// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.3-library -emit-llvm -o - %s | FileCheck %s --check-prefixes=CHECK,DXIL -// RUN: %clang_cc1 -finclude-default-header -triple spirv-unknown-vulkan1.3-library -emit-llvm -o - %s | FileCheck %s --check-prefixes=CHECK,SPIRV +// RUN: %clang_cc1 -finclude-default-header -O1 -triple dxil-pc-shadermodel6.3-library -emit-llvm -o - %s | FileCheck %s --check-prefixes=CHECK,DXIL +// RUN: %clang_cc1 -finclude-default-header -O1 -triple spirv-unknown-vulkan1.3-library -emit-llvm -o - %s | FileCheck %s --check-prefixes=CHECK,SPIRV // -- Case 1: scalar * scalar -> scalar -- // CHECK-LABEL: test_scalar_mulf -// CHECK: [[A:%.*]] = load float, ptr %a.addr -// CHECK: [[B:%.*]] = load float, ptr %b.addr -// CHECK: %hlsl.mul = fmul {{.*}} float [[A]], [[B]] -// CHECK: ret float %hlsl.mul +// CHECK: %mul.i = fmul {{.*}} float %b, %a +// CHECK: ret float %mul.i export float test_scalar_mulf(float a, float b) { return mul(a, b); } // CHECK-LABEL: test_scalar_muli -// CHECK: [[A:%.*]] = load i32, ptr %a.addr -// CHECK: [[B:%.*]] = load i32, ptr %b.addr -// CHECK: %hlsl.mul = mul 
i32 [[A]], [[B]] -// CHECK: ret i32 %hlsl.mul +// CHECK: %mul.i = mul {{.*}} i32 %b, %a +// CHECK: ret i32 %mul.i export int test_scalar_muli(int a, int b) { return mul(a, b); } // -- Case 2: scalar * vector -> vector -- // CHECK-LABEL: test_scalar_vec_mul -// CHECK: [[A:%.*]] = load float, ptr %a.addr -// CHECK: [[B:%.*]] = load <3 x float>, ptr %b.addr -// CHECK: %.splatinsert = insertelement <3 x float> poison, float [[A]], i64 0 -// CHECK: %.splat = shufflevector <3 x float> %.splatinsert, <3 x float> poison, <3 x i32> zeroinitializer -// CHECK: %hlsl.mul = fmul {{.*}} <3 x float> %.splat, [[B]] -// CHECK: ret <3 x float> %hlsl.mul +// CHECK: %splat.splatinsert.i = insertelement <3 x float> poison, float %a, i64 0 +// CHECK: %splat.splat.i = shufflevector <3 x float> %splat.splatinsert.i, <3 x float> poison, <3 x i32> zeroinitializer +// CHECK: %mul.i = fmul {{.*}} <3 x float> %splat.splat.i, %b +// CHECK: ret <3 x float> %mul.i export float3 test_scalar_vec_mul(float a, float3 b) { return mul(a, b); } // -- Case 3: scalar * matrix -> matrix -- // CHECK-LABEL: test_scalar_mat_mul -// CHECK: [[A:%.*]] = load float, ptr %a.addr -// CHECK: [[B:%.*]] = load <6 x float>, ptr %b.addr -// CHECK: %.splatinsert = insertelement <6 x float> poison, float [[A]], i64 0 -// CHECK: %.splat = shufflevector <6 x float> %.splatinsert, <6 x float> poison, <6 x i32> zeroinitializer -// CHECK: %hlsl.mul = fmul {{.*}} <6 x float> %.splat, [[B]] -// CHECK: ret <6 x float> %hlsl.mul +// CHECK: %scalar.splat.splatinsert.i = insertelement <6 x float> poison, float %a, i64 0 +// CHECK: %scalar.splat.splat.i = shufflevector <6 x float> %scalar.splat.splatinsert.i, <6 x float> poison, <6 x i32> zeroinitializer +// CHECK: [[MUL:%.*]] = fmul {{.*}} <6 x float> %scalar.splat.splat.i, %b +// CHECK: ret <6 x float> [[MUL]] export float2x3 test_scalar_mat_mul(float a, float2x3 b) { return mul(a, b); } // -- Case 4: vector * scalar -> vector -- // CHECK-LABEL: test_vec_scalar_mul -// CHECK: 
[[A:%.*]] = load <3 x float>, ptr %a.addr -// CHECK: [[B:%.*]] = load float, ptr %b.addr -// CHECK: %.splatinsert = insertelement <3 x float> poison, float [[B]], i64 0 -// CHECK: %.splat = shufflevector <3 x float> %.splatinsert, <3 x float> poison, <3 x i32> zeroinitializer -// CHECK: %hlsl.mul = fmul {{.*}} <3 x float> %.splat, [[A]] -// CHECK: ret <3 x float> %hlsl.mul +// CHECK: %splat.splatinsert.i = insertelement <3 x float> poison, float %b, i64 0 +// CHECK: %splat.splat.i = shufflevector <3 x float> %splat.splatinsert.i, <3 x float> poison, <3 x i32> zeroinitializer +// CHECK: %mul.i = fmul {{.*}} <3 x float> %splat.splat.i, %a +// CHECK: ret <3 x float> %mul.i export float3 test_vec_scalar_mul(float3 a, float b) { return mul(a, b); } // -- Case 5: vector * vector -> scalar (dot product) -- // CHECK-LABEL: test_vec_vec_mul -// CHECK: [[A:%.*]] = load <3 x float>, ptr %a.addr -// CHECK: [[B:%.*]] = load <3 x float>, ptr %b.addr -// DXIL: %hlsl.mul = call {{.*}} float @llvm.dx.fdot.v3f32(<3 x float> [[A]], <3 x float> [[B]]) -// SPIRV: %hlsl.mul = call {{.*}} float @llvm.spv.fdot.v3f32(<3 x float> [[A]], <3 x float> [[B]]) -// CHECK: ret float %hlsl.mul +// DXIL: %hlsl.dot.i = {{.*}} call {{.*}} float @llvm.dx.fdot.v3f32(<3 x float> {{.*}} %a, <3 x float> {{.*}} %b) +// SPIRV: %hlsl.dot.i = {{.*}} call {{.*}} float @llvm.spv.fdot.v3f32(<3 x float> {{.*}} %a, <3 x float> {{.*}} %b) +// CHECK: ret float %hlsl.dot.i export float test_vec_vec_mul(float3 a, float3 b) { return mul(a, b); } // CHECK-LABEL: test_vec_vec_muli -// CHECK: [[A:%.*]] = load <3 x i32>, ptr %a.addr -// CHECK: [[B:%.*]] = load <3 x i32>, ptr %b.addr -// DXIL: %hlsl.mul = call i32 @llvm.dx.sdot.v3i32(<3 x i32> [[A]], <3 x i32> [[B]]) -// SPIRV: %hlsl.mul = call i32 @llvm.spv.sdot.v3i32(<3 x i32> [[A]], <3 x i32> [[B]]) -// CHECK: ret i32 %hlsl.mul +// DXIL: %hlsl.dot.i = {{.*}} call {{.*}} i32 @llvm.dx.sdot.v3i32(<3 x i32> %a, <3 x i32> %b) +// SPIRV: %hlsl.dot.i = {{.*}} call {{.*}} i32 
@llvm.spv.sdot.v3i32(<3 x i32> %a, <3 x i32> %b) +// CHECK: ret i32 %hlsl.dot.i export int test_vec_vec_muli(int3 a, int3 b) { return mul(a, b); } // CHECK-LABEL: test_vec_vec_mulu -// CHECK: [[A:%.*]] = load <3 x i32>, ptr %a.addr -// CHECK: [[B:%.*]] = load <3 x i32>, ptr %b.addr -// DXIL: %hlsl.mul = call i32 @llvm.dx.udot.v3i32(<3 x i32> [[A]], <3 x i32> [[B]]) -// SPIRV: %hlsl.mul = call i32 @llvm.spv.udot.v3i32(<3 x i32> [[A]], <3 x i32> [[B]]) -// CHECK: ret i32 %hlsl.mul +// DXIL: %hlsl.dot.i = {{.*}} call {{.*}} i32 @llvm.dx.udot.v3i32(<3 x i32> %a, <3 x i32> %b) +// SPIRV: %hlsl.dot.i = {{.*}} call {{.*}} i32 @llvm.spv.udot.v3i32(<3 x i32> %a, <3 x i32> %b) +// CHECK: ret i32 %hlsl.dot.i export uint test_vec_vec_mulu(uint3 a, uint3 b) { return mul(a, b); } -// Double vector dot product: DXIL uses scalar arithmetic, SPIR-V uses fdot +// Double vector dot product: no dot intrinsic for double vectors // CHECK-LABEL: test_vec_vec_muld -// CHECK: [[A:%.*]] = load <3 x double>, ptr %a.addr -// CHECK: [[B:%.*]] = load <3 x double>, ptr %b.addr -// DXIL-NOT: @llvm.dx.fdot -// DXIL: [[A0:%.*]] = extractelement <3 x double> [[A]], i64 0 -// DXIL: [[B0:%.*]] = extractelement <3 x double> [[B]], i64 0 -// DXIL: [[MUL0:%.*]] = fmul {{.*}} double [[A0]], [[B0]] -// DXIL: [[A1:%.*]] = extractelement <3 x double> [[A]], i64 1 -// DXIL: [[B1:%.*]] = extractelement <3 x double> [[B]], i64 1 -// DXIL: [[FMA0:%.*]] = call {{.*}} double @llvm.fmuladd.f64(double [[A1]], double [[B1]], double [[MUL0]]) -// DXIL: [[A2:%.*]] = extractelement <3 x double> [[A]], i64 2 -// DXIL: [[B2:%.*]] = extractelement <3 x double> [[B]], i64 2 -// DXIL: [[FMA1:%.*]] = call {{.*}} double @llvm.fmuladd.f64(double [[A2]], double [[B2]], double [[FMA0]]) -// DXIL: ret double [[FMA1]] -// SPIRV: %hlsl.mul = call {{.*}} double @llvm.spv.fdot.v3f64(<3 x double> [[A]], <3 x double> [[B]]) -// SPIRV: ret double %hlsl.mul +// CHECK-NOT: @llvm.dx.fdot +// CHECK-NOT: @llvm.spv.fdot +// CHECK: [[A0:%.*]] = 
extractelement <3 x double> %{{[ab]}}, i64 0 +// CHECK: [[B0:%.*]] = extractelement <3 x double> %{{[ab]}}, i64 0 +// CHECK: %mul.i = fmul {{.*}} double [[A0]], [[B0]] +// CHECK: [[A1:%.*]] = extractelement <3 x double> %a, i64 1 +// CHECK: [[B1:%.*]] = extractelement <3 x double> %b, i64 1 +// CHECK: [[MUL1:%.*]] = fmul {{.*}} double [[A1]], [[B1]] +// CHECK: [[ADD1:%.*]] = fadd {{.*}} double [[MUL1]], %mul.i +// CHECK: [[A2:%.*]] = extractelement <3 x double> %a, i64 2 +// CHECK: [[B2:%.*]] = extractelement <3 x double> %b, i64 2 +// CHECK: [[MUL2:%.*]] = fmul {{.*}} double [[A2]], [[B2]] +// CHECK: [[ADD2:%.*]] = fadd {{.*}} double [[MUL2]], [[ADD1]] +// CHECK: ret double [[ADD2]] export double test_vec_vec_muld(double3 a, double3 b) { return mul(a, b); } // -- Case 6: vector * matrix -> vector -- // CHECK-LABEL: test_vec_mat_mul -// CHECK: [[V:%.*]] = load <2 x float>, ptr %v.addr -// CHECK: [[M:%.*]] = load <6 x float>, ptr %m.addr -// CHECK: %hlsl.mul = call {{.*}} <3 x float> @llvm.matrix.multiply.v3f32.v2f32.v6f32(<2 x float> [[V]], <6 x float> [[M]], i32 1, i32 2, i32 3) +// CHECK: %hlsl.mul = {{.*}} call {{.*}} <3 x float> @llvm.matrix.multiply.v3f32.v2f32.v6f32(<2 x float> %v, <6 x float> %m, i32 1, i32 2, i32 3) // CHECK: ret <3 x float> %hlsl.mul export float3 test_vec_mat_mul(float2 v, float2x3 m) { return mul(v, m); } // -- Case 7: matrix * scalar -> matrix -- // CHECK-LABEL: test_mat_scalar_mul -// CHECK: [[A:%.*]] = load <6 x float>, ptr %a.addr -// CHECK: [[B:%.*]] = load float, ptr %b.addr -// CHECK: %.splatinsert = insertelement <6 x float> poison, float [[B]], i64 0 -// CHECK: %.splat = shufflevector <6 x float> %.splatinsert, <6 x float> poison, <6 x i32> zeroinitializer -// CHECK: %hlsl.mul = fmul {{.*}} <6 x float> %.splat, [[A]] -// CHECK: ret <6 x float> %hlsl.mul +// CHECK: %scalar.splat.splatinsert.i = insertelement <6 x float> poison, float %b, i64 0 +// CHECK: %scalar.splat.splat.i = shufflevector <6 x float> 
%scalar.splat.splatinsert.i, <6 x float> poison, <6 x i32> zeroinitializer +// CHECK: [[MUL:%.*]] = fmul {{.*}} <6 x float> %scalar.splat.splat.i, %a +// CHECK: ret <6 x float> [[MUL]] export float2x3 test_mat_scalar_mul(float2x3 a, float b) { return mul(a, b); } // -- Case 8: matrix * vector -> vector -- // CHECK-LABEL: test_mat_vec_mul -// CHECK: [[M:%.*]] = load <6 x float>, ptr %m.addr -// CHECK: [[V:%.*]] = load <3 x float>, ptr %v.addr -// CHECK: %hlsl.mul = call {{.*}} <2 x float> @llvm.matrix.multiply.v2f32.v6f32.v3f32(<6 x float> [[M]], <3 x float> [[V]], i32 2, i32 3, i32 1) +// CHECK: %hlsl.mul = {{.*}} call {{.*}} <2 x float> @llvm.matrix.multiply.v2f32.v6f32.v3f32(<6 x float> %m, <3 x float> %v, i32 2, i32 3, i32 1) // CHECK: ret <2 x float> %hlsl.mul export float2 test_mat_vec_mul(float2x3 m, float3 v) { return mul(m, v); } // -- Case 9: matrix * matrix -> matrix -- // CHECK-LABEL: test_mat_mat_mul -// CHECK: [[A:%.*]] = load <6 x float>, ptr %a.addr -// CHECK: [[B:%.*]] = load <12 x float>, ptr %b.addr -// CHECK: %hlsl.mul = call {{.*}} <8 x float> @llvm.matrix.multiply.v8f32.v6f32.v12f32(<6 x float> [[A]], <12 x float> [[B]], i32 2, i32 3, i32 4) +// CHECK: %hlsl.mul = {{.*}} call {{.*}} <8 x float> @llvm.matrix.multiply.v8f32.v6f32.v12f32(<6 x float> %a, <12 x float> %b, i32 2, i32 3, i32 4) // CHECK: ret <8 x float> %hlsl.mul export float2x4 test_mat_mat_mul(float2x3 a, float3x4 b) { return mul(a, b); } // -- Integer matrix multiply -- // CHECK-LABEL: test_mat_mat_muli -// CHECK: [[A:%.*]] = load <6 x i32>, ptr %a.addr -// CHECK: [[B:%.*]] = load <12 x i32>, ptr %b.addr -// CHECK: %hlsl.mul = call <8 x i32> @llvm.matrix.multiply.v8i32.v6i32.v12i32(<6 x i32> [[A]], <12 x i32> [[B]], i32 2, i32 3, i32 4) +// CHECK: %hlsl.mul = {{.*}} call <8 x i32> @llvm.matrix.multiply.v8i32.v6i32.v12i32(<6 x i32> %a, <12 x i32> %b, i32 2, i32 3, i32 4) // CHECK: ret <8 x i32> %hlsl.mul export int2x4 test_mat_mat_muli(int2x3 a, int3x4 b) { return mul(a, b); } >From 
41612f30a85255bbd81c30d7c1dc33fe652bfc2c Mon Sep 17 00:00:00 2001 From: Deric Cheung <[email protected]> Date: Thu, 5 Mar 2026 13:57:05 -0800 Subject: [PATCH 3/4] Remove redundant mul definitions. Make mul documentation comment consistent between files --- .../lib/Headers/hlsl/hlsl_alias_intrinsics.h | 23 ------------------- clang/lib/Headers/hlsl/hlsl_intrinsics.h | 11 +++++++++ 2 files changed, 11 insertions(+), 23 deletions(-) diff --git a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h index e676ca07e325a..49a64b070207a 100644 --- a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h +++ b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h @@ -1801,29 +1801,6 @@ double4 min(double4, double4); /// 8. matrix * vector -> vector /// 9. matrix * matrix -> matrix -// Case 1: scalar * scalar -> scalar -template <typename T> _HLSL_BUILTIN_ALIAS(__builtin_hlsl_mul) T mul(T, T); - -// Case 2: scalar * vector -> vector -template <typename T, int N> -_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mul) -vector<T, N> mul(T, vector<T, N>); - -// Case 3: scalar * matrix -> matrix -template <typename T, int R, int C> -_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mul) -matrix<T, R, C> mul(T, matrix<T, R, C>); - -// Case 4: vector * scalar -> vector -template <typename T, int N> -_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mul) -vector<T, N> mul(vector<T, N>, T); - -// Case 5: vector * vector -> scalar (dot product) -template <typename T, int N> -_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mul) -T mul(vector<T, N>, vector<T, N>); - // Case 6: vector * matrix -> vector template <typename T, int R, int C> _HLSL_BUILTIN_ALIAS(__builtin_hlsl_mul) diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h index 315f4ee2bf9d7..0d81085664427 100644 --- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h +++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h @@ -804,6 +804,17 @@ fwidth(__detail::HLSL_FIXED_VECTOR<float, N> input) { /// The inner dimension x-columns 
and y-rows must be equal. The result has the /// dimension x-rows x y-columns. When both x and y are vectors, the result is /// a dot product (scalar). Scalar operands are multiplied element-wise. +/// +/// This function supports 9 overloaded forms: +/// 1. scalar * scalar -> scalar +/// 2. scalar * vector -> vector +/// 3. scalar * matrix -> matrix +/// 4. vector * scalar -> vector +/// 5. vector * vector -> scalar (dot product) +/// 6. vector * matrix -> vector +/// 7. matrix * scalar -> matrix +/// 8. matrix * vector -> vector +/// 9. matrix * matrix -> matrix // Case 1: scalar * scalar -> scalar template <typename T> >From 64f025eb4e13d24206bc8710a810b25201f31857 Mon Sep 17 00:00:00 2001 From: Deric Cheung <[email protected]> Date: Thu, 5 Mar 2026 14:51:33 -0800 Subject: [PATCH 4/4] Relax CHECKs for double vector x double vector mul test --- clang/test/CodeGenHLSL/builtins/mul.hlsl | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/clang/test/CodeGenHLSL/builtins/mul.hlsl b/clang/test/CodeGenHLSL/builtins/mul.hlsl index f073191d45862..e1af03791b367 100644 --- a/clang/test/CodeGenHLSL/builtins/mul.hlsl +++ b/clang/test/CodeGenHLSL/builtins/mul.hlsl @@ -60,22 +60,14 @@ export int test_vec_vec_muli(int3 a, int3 b) { return mul(a, b); } // CHECK: ret i32 %hlsl.dot.i export uint test_vec_vec_mulu(uint3 a, uint3 b) { return mul(a, b); } -// Double vector dot product: no dot intrinsic for double vectors +// Double vector dot product: no dot intrinsic for double vectors. +// The checks for this test are less precise because the scalar loop may be vectorized depending on the build configuration. 
// CHECK-LABEL: test_vec_vec_muld // CHECK-NOT: @llvm.dx.fdot // CHECK-NOT: @llvm.spv.fdot -// CHECK: [[A0:%.*]] = extractelement <3 x double> %{{[ab]}}, i64 0 -// CHECK: [[B0:%.*]] = extractelement <3 x double> %{{[ab]}}, i64 0 -// CHECK: %mul.i = fmul {{.*}} double [[A0]], [[B0]] -// CHECK: [[A1:%.*]] = extractelement <3 x double> %a, i64 1 -// CHECK: [[B1:%.*]] = extractelement <3 x double> %b, i64 1 -// CHECK: [[MUL1:%.*]] = fmul {{.*}} double [[A1]], [[B1]] -// CHECK: [[ADD1:%.*]] = fadd {{.*}} double [[MUL1]], %mul.i -// CHECK: [[A2:%.*]] = extractelement <3 x double> %a, i64 2 -// CHECK: [[B2:%.*]] = extractelement <3 x double> %b, i64 2 -// CHECK: [[MUL2:%.*]] = fmul {{.*}} double [[A2]], [[B2]] -// CHECK: [[ADD2:%.*]] = fadd {{.*}} double [[MUL2]], [[ADD1]] -// CHECK: ret double [[ADD2]] +// CHECK: fmul {{.*}} double +// CHECK: fadd {{.*}} double +// CHECK: ret double %{{.*}} export double test_vec_vec_muld(double3 a, double3 b) { return mul(a, b); } // -- Case 6: vector * matrix -> vector -- _______________________________________________ cfe-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
