llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-clang-codegen
Author: Deric C. (Icohedron)
<details>
<summary>Changes</summary>
Fixes #<!-- -->99138
- Defines a `__builtin_hlsl_mul` clang builtin in `Builtins.td`.
- Exposes the `__builtin_hlsl_mul` clang builtin under the name `mul` via
`_HLSL_BUILTIN_ALIAS` overload declarations in `hlsl_alias_intrinsics.h`
- Adds semantic checking for `__builtin_hlsl_mul` (argument count and
result-type computation) to `CheckBuiltinFunctionCall` in `SemaHLSL.cpp`
- Adds codegen for `__builtin_hlsl_mul` to `EmitHLSLBuiltinExpr` in
`CGHLSLBuiltins.cpp`
- Vector-vector multiplication lowers to the target's `dot` intrinsic, except
for double vectors on DirectX, which are expanded to a scalar multiply
followed by a chain of `llvm.fmuladd` calls.
- Matrix-matrix, matrix-vector, and vector-matrix multiplication lower to
the `llvm.matrix.multiply` intrinsic
- Adds codegen tests to `clang/test/CodeGenHLSL/builtins/mul.hlsl`
- Adds sema tests to `clang/test/SemaHLSL/BuiltIns/mul-errors.hlsl`
- Implements lowering of the `llvm.matrix.multiply` intrinsic to DXIL in
`DXILIntrinsicExpansion.cpp`
- Currently only supports column-major matrix memory layout
Note: The SPIR-V backend currently does not support row-major matrix memory
layouts when lowering matrix multiply; it always assumes a column-major layout.
Therefore, the DXIL lowering added in this PR also assumes a column-major
layout only. Support for other matrix memory layouts in the DXIL and SPIR-V
backends will be implemented in a separate PR.
Assisted-by: claude-opus-4.6
---
Patch is 49.95 KiB, truncated to 20.00 KiB below, full version:
https://github.com/llvm/llvm-project/pull/184882.diff
8 Files Affected:
- (modified) clang/include/clang/Basic/Builtins.td (+6)
- (modified) clang/lib/CodeGen/CGHLSLBuiltins.cpp (+91)
- (modified) clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h (+69)
- (modified) clang/lib/Sema/SemaHLSL.cpp (+57)
- (added) clang/test/CodeGenHLSL/builtins/mul.hlsl (+143)
- (added) clang/test/SemaHLSL/BuiltIns/mul-errors.hlsl (+42)
- (modified) llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp (+96)
- (added) llvm/test/CodeGen/DirectX/matrix-multiply.ll (+342)
``````````diff
diff --git a/clang/include/clang/Basic/Builtins.td
b/clang/include/clang/Basic/Builtins.td
index 531c3702161f2..c6a0c10449ba9 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -5312,6 +5312,12 @@ def HLSLMad : LangBuiltin<"HLSL_LANG"> {
let Prototype = "void(...)";
}
+def HLSLMul : LangBuiltin<"HLSL_LANG"> {
+ let Spellings = ["__builtin_hlsl_mul"];
+ let Attributes = [NoThrow, Const, CustomTypeChecking];
+ let Prototype = "void(...)";
+}
+
def HLSLNormalize : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_normalize"];
let Attributes = [NoThrow, Const, CustomTypeChecking];
diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp
b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
index 70891eac39425..34e92c42fdef9 100644
--- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp
+++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
@@ -13,6 +13,7 @@
#include "CGBuiltin.h"
#include "CGHLSLRuntime.h"
#include "CodeGenFunction.h"
+#include "llvm/IR/MatrixBuilder.h"
using namespace clang;
using namespace CodeGen;
@@ -1006,6 +1007,96 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned
BuiltinID,
Value *Mul = Builder.CreateNUWMul(M, A);
return Builder.CreateNUWAdd(Mul, B);
}
+ case Builtin::BI__builtin_hlsl_mul: {
+ Value *Op0 = EmitScalarExpr(E->getArg(0));
+ Value *Op1 = EmitScalarExpr(E->getArg(1));
+ QualType QTy0 = E->getArg(0)->getType();
+ QualType QTy1 = E->getArg(1)->getType();
+ llvm::Type *T0 = Op0->getType();
+
+ bool IsScalar0 = QTy0->isScalarType();
+ bool IsVec0 = QTy0->isVectorType();
+ bool IsMat0 = QTy0->isConstantMatrixType();
+ bool IsScalar1 = QTy1->isScalarType();
+ bool IsVec1 = QTy1->isVectorType();
+ bool IsMat1 = QTy1->isConstantMatrixType();
+ bool IsFP =
+ QTy0->hasFloatingRepresentation() || QTy1->hasFloatingRepresentation();
+
+ // Cases 1-4, 7: scalar * scalar/vector/matrix or vector/matrix * scalar
+ if (IsScalar0 || IsScalar1) {
+ // Splat scalar to match the other operand's type
+ Value *Scalar = IsScalar0 ? Op0 : Op1;
+ Value *Other = IsScalar0 ? Op1 : Op0;
+ llvm::Type *OtherTy = Other->getType();
+
+ // Note: Matrices are flat vectors in the IR, so the following
+ // if-condition is also true when Other is a matrix, not just a vector.
+ if (OtherTy->isVectorTy()) {
+ unsigned NumElts = cast<FixedVectorType>(OtherTy)->getNumElements();
+ Scalar = Builder.CreateVectorSplat(NumElts, Scalar);
+ }
+
+ if (IsFP)
+ return Builder.CreateFMul(Scalar, Other, "hlsl.mul");
+ return Builder.CreateMul(Scalar, Other, "hlsl.mul");
+ }
+
+ // Case 5: vector * vector -> scalar (dot product)
+ if (IsVec0 && IsVec1) {
+ auto *VecTy0 = E->getArg(0)->getType()->castAs<VectorType>();
+ QualType EltQTy = VecTy0->getElementType();
+
+ // DXIL doesn't have a dot product intrinsic for double vectors,
+ // so expand to scalar multiply-add for DXIL.
+ if (CGM.getTarget().getTriple().isDXIL() &&
+ EltQTy->isSpecificBuiltinType(BuiltinType::Double)) {
+ unsigned NumElts = cast<FixedVectorType>(T0)->getNumElements();
+ Value *Sum = nullptr;
+ for (unsigned I = 0; I < NumElts; ++I) {
+ Value *L = Builder.CreateExtractElement(Op0, I);
+ Value *R = Builder.CreateExtractElement(Op1, I);
+ if (Sum)
+ Sum = Builder.CreateIntrinsic(Sum->getType(), Intrinsic::fmuladd,
+ {L, R, Sum});
+ else
+ Sum = Builder.CreateFMul(L, R);
+ }
+ return Sum;
+ }
+
+ return Builder.CreateIntrinsic(
+ /*ReturnType=*/T0->getScalarType(),
+ getDotProductIntrinsic(CGM.getHLSLRuntime(), EltQTy),
+ ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.mul");
+ }
+
+ // Cases 6, 8, 9: matrix involved -> use llvm.matrix.multiply
+ llvm::MatrixBuilder MB(Builder);
+ if (IsVec0 && IsMat1) {
+ // vector<N> * matrix<N,M> -> vector<M>
+ // Treat vector as 1×N matrix
+ unsigned N = QTy0->castAs<VectorType>()->getNumElements();
+ auto *MatTy = QTy1->castAs<ConstantMatrixType>();
+ unsigned M = MatTy->getNumColumns();
+ return MB.CreateMatrixMultiply(Op0, Op1, 1, N, M, "hlsl.mul");
+ }
+ if (IsMat0 && IsVec1) {
+ // matrix<M,N> * vector<N> -> vector<M>
+ // Treat vector as N×1 matrix
+ auto *MatTy = QTy0->castAs<ConstantMatrixType>();
+ unsigned Rows = MatTy->getNumRows();
+ unsigned Cols = MatTy->getNumColumns();
+ return MB.CreateMatrixMultiply(Op0, Op1, Rows, Cols, 1, "hlsl.mul");
+ }
+ assert(IsMat0 && IsMat1);
+ // matrix<M,K> * matrix<K,N> -> matrix<M,N>
+ auto *MatTy0 = QTy0->castAs<ConstantMatrixType>();
+ auto *MatTy1 = QTy1->castAs<ConstantMatrixType>();
+ return MB.CreateMatrixMultiply(Op0, Op1, MatTy0->getNumRows(),
+ MatTy0->getNumColumns(),
+ MatTy1->getNumColumns(), "hlsl.mul");
+ }
case Builtin::BI__builtin_hlsl_elementwise_rcp: {
Value *Op0 = EmitScalarExpr(E->getArg(0));
if (!E->getArg(0)->getType()->hasFloatingRepresentation())
diff --git a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
index 2543401bdfbf9..2e9847803a8a1 100644
--- a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
@@ -1775,6 +1775,75 @@ double3 min(double3, double3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
double4 min(double4, double4);
+//===----------------------------------------------------------------------===//
+// mul builtins
+//===----------------------------------------------------------------------===//
+
+/// \fn R mul(X x, Y y)
+/// \brief Multiplies x and y using matrix math.
+/// \param x [in] The first input value. If x is a vector, it is treated as a
+/// row vector.
+/// \param y [in] The second input value. If y is a vector, it is treated as a
+/// column vector.
+///
+/// The inner dimension x-columns and y-rows must be equal. The result has the
+/// dimension x-rows x y-columns. When both x and y are vectors, the result is
+/// a dot product (scalar). Scalar operands are multiplied element-wise.
+///
+/// This function supports 9 overloaded forms:
+/// 1. scalar * scalar -> scalar
+/// 2. scalar * vector -> vector
+/// 3. scalar * matrix -> matrix
+/// 4. vector * scalar -> vector
+/// 5. vector * vector -> scalar (dot product)
+/// 6. vector * matrix -> vector
+/// 7. matrix * scalar -> matrix
+/// 8. matrix * vector -> vector
+/// 9. matrix * matrix -> matrix
+
+// Case 1: scalar * scalar -> scalar
+template <typename T> _HLSL_BUILTIN_ALIAS(__builtin_hlsl_mul) T mul(T, T);
+
+// Case 2: scalar * vector -> vector
+template <typename T, int N>
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mul)
+vector<T, N> mul(T, vector<T, N>);
+
+// Case 3: scalar * matrix -> matrix
+template <typename T, int R, int C>
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mul)
+matrix<T, R, C> mul(T, matrix<T, R, C>);
+
+// Case 4: vector * scalar -> vector
+template <typename T, int N>
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mul)
+vector<T, N> mul(vector<T, N>, T);
+
+// Case 5: vector * vector -> scalar (dot product)
+template <typename T, int N>
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mul)
+T mul(vector<T, N>, vector<T, N>);
+
+// Case 6: vector * matrix -> vector
+template <typename T, int R, int C>
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mul)
+vector<T, C> mul(vector<T, R>, matrix<T, R, C>);
+
+// Case 7: matrix * scalar -> matrix
+template <typename T, int R, int C>
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mul)
+matrix<T, R, C> mul(matrix<T, R, C>, T);
+
+// Case 8: matrix * vector -> vector
+template <typename T, int R, int C>
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mul)
+vector<T, R> mul(matrix<T, R, C>, vector<T, C>);
+
+// Case 9: matrix * matrix -> matrix
+template <typename T, int R, int K, int C>
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mul)
+matrix<T, R, C> mul(matrix<T, R, K>, matrix<T, K, C>);
+
//===----------------------------------------------------------------------===//
// normalize builtins
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 804ea70aaddce..46a30acd95b68 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -3775,6 +3775,63 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned
BuiltinID, CallExpr *TheCall) {
return true;
break;
}
+ case Builtin::BI__builtin_hlsl_mul: {
+ if (SemaRef.checkArgCount(TheCall, 2))
+ return true;
+
+ Expr *Arg0 = TheCall->getArg(0);
+ Expr *Arg1 = TheCall->getArg(1);
+ QualType Ty0 = Arg0->getType();
+ QualType Ty1 = Arg1->getType();
+
+ auto getElemType = [](QualType T) -> QualType {
+ if (const auto *VTy = T->getAs<VectorType>())
+ return VTy->getElementType();
+ if (const auto *MTy = T->getAs<ConstantMatrixType>())
+ return MTy->getElementType();
+ return T;
+ };
+
+ QualType EltTy0 = getElemType(Ty0);
+
+ bool IsScalar0 = Ty0->isScalarType();
+ bool IsVec0 = Ty0->isVectorType();
+ bool IsMat0 = Ty0->isConstantMatrixType();
+ bool IsScalar1 = Ty1->isScalarType();
+ bool IsVec1 = Ty1->isVectorType();
+ bool IsMat1 = Ty1->isConstantMatrixType();
+
+ QualType RetTy;
+
+ if (IsScalar0 && IsScalar1) {
+ RetTy = EltTy0;
+ } else if (IsScalar0 && IsVec1) {
+ RetTy = Ty1;
+ } else if (IsScalar0 && IsMat1) {
+ RetTy = Ty1;
+ } else if (IsVec0 && IsScalar1) {
+ RetTy = Ty0;
+ } else if (IsVec0 && IsVec1) {
+ RetTy = EltTy0;
+ } else if (IsVec0 && IsMat1) {
+ auto *MatTy = Ty1->castAs<ConstantMatrixType>();
+ RetTy = getASTContext().getExtVectorType(EltTy0, MatTy->getNumColumns());
+ } else if (IsMat0 && IsScalar1) {
+ RetTy = Ty0;
+ } else if (IsMat0 && IsVec1) {
+ auto *MatTy = Ty0->castAs<ConstantMatrixType>();
+ RetTy = getASTContext().getExtVectorType(EltTy0, MatTy->getNumRows());
+ } else {
+ assert(IsMat0 && IsMat1);
+ auto *MatTy0 = Ty0->castAs<ConstantMatrixType>();
+ auto *MatTy1 = Ty1->castAs<ConstantMatrixType>();
+ RetTy = getASTContext().getConstantMatrixType(
+ EltTy0, MatTy0->getNumRows(), MatTy1->getNumColumns());
+ }
+
+ TheCall->setType(RetTy);
+ break;
+ }
case Builtin::BI__builtin_hlsl_normalize: {
if (SemaRef.checkArgCount(TheCall, 1))
return true;
diff --git a/clang/test/CodeGenHLSL/builtins/mul.hlsl
b/clang/test/CodeGenHLSL/builtins/mul.hlsl
new file mode 100644
index 0000000000000..0a95d6004e567
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/mul.hlsl
@@ -0,0 +1,143 @@
+// RUN: %clang_cc1 -finclude-default-header -triple
dxil-pc-shadermodel6.3-library -emit-llvm -o - %s | FileCheck %s
--check-prefixes=CHECK,DXIL
+// RUN: %clang_cc1 -finclude-default-header -triple
spirv-unknown-vulkan1.3-library -emit-llvm -o - %s | FileCheck %s
--check-prefixes=CHECK,SPIRV
+
+// -- Case 1: scalar * scalar -> scalar --
+
+// CHECK-LABEL: test_scalar_mulf
+// CHECK: [[A:%.*]] = load float, ptr %a.addr
+// CHECK: [[B:%.*]] = load float, ptr %b.addr
+// CHECK: %hlsl.mul = fmul {{.*}} float [[A]], [[B]]
+// CHECK: ret float %hlsl.mul
+export float test_scalar_mulf(float a, float b) { return mul(a, b); }
+
+// CHECK-LABEL: test_scalar_muli
+// CHECK: [[A:%.*]] = load i32, ptr %a.addr
+// CHECK: [[B:%.*]] = load i32, ptr %b.addr
+// CHECK: %hlsl.mul = mul i32 [[A]], [[B]]
+// CHECK: ret i32 %hlsl.mul
+export int test_scalar_muli(int a, int b) { return mul(a, b); }
+
+// -- Case 2: scalar * vector -> vector --
+
+// CHECK-LABEL: test_scalar_vec_mul
+// CHECK: [[A:%.*]] = load float, ptr %a.addr
+// CHECK: [[B:%.*]] = load <3 x float>, ptr %b.addr
+// CHECK: %.splatinsert = insertelement <3 x float> poison, float [[A]], i64 0
+// CHECK: %.splat = shufflevector <3 x float> %.splatinsert, <3 x float>
poison, <3 x i32> zeroinitializer
+// CHECK: %hlsl.mul = fmul {{.*}} <3 x float> %.splat, [[B]]
+// CHECK: ret <3 x float> %hlsl.mul
+export float3 test_scalar_vec_mul(float a, float3 b) { return mul(a, b); }
+
+// -- Case 3: scalar * matrix -> matrix --
+
+// CHECK-LABEL: test_scalar_mat_mul
+// CHECK: [[A:%.*]] = load float, ptr %a.addr
+// CHECK: [[B:%.*]] = load <6 x float>, ptr %b.addr
+// CHECK: %.splatinsert = insertelement <6 x float> poison, float [[A]], i64 0
+// CHECK: %.splat = shufflevector <6 x float> %.splatinsert, <6 x float>
poison, <6 x i32> zeroinitializer
+// CHECK: %hlsl.mul = fmul {{.*}} <6 x float> %.splat, [[B]]
+// CHECK: ret <6 x float> %hlsl.mul
+export float2x3 test_scalar_mat_mul(float a, float2x3 b) { return mul(a, b); }
+
+// -- Case 4: vector * scalar -> vector --
+
+// CHECK-LABEL: test_vec_scalar_mul
+// CHECK: [[A:%.*]] = load <3 x float>, ptr %a.addr
+// CHECK: [[B:%.*]] = load float, ptr %b.addr
+// CHECK: %.splatinsert = insertelement <3 x float> poison, float [[B]], i64 0
+// CHECK: %.splat = shufflevector <3 x float> %.splatinsert, <3 x float>
poison, <3 x i32> zeroinitializer
+// CHECK: %hlsl.mul = fmul {{.*}} <3 x float> %.splat, [[A]]
+// CHECK: ret <3 x float> %hlsl.mul
+export float3 test_vec_scalar_mul(float3 a, float b) { return mul(a, b); }
+
+// -- Case 5: vector * vector -> scalar (dot product) --
+
+// CHECK-LABEL: test_vec_vec_mul
+// CHECK: [[A:%.*]] = load <3 x float>, ptr %a.addr
+// CHECK: [[B:%.*]] = load <3 x float>, ptr %b.addr
+// DXIL: %hlsl.mul = call {{.*}} float @llvm.dx.fdot.v3f32(<3 x float> [[A]],
<3 x float> [[B]])
+// SPIRV: %hlsl.mul = call {{.*}} float @llvm.spv.fdot.v3f32(<3 x float>
[[A]], <3 x float> [[B]])
+// CHECK: ret float %hlsl.mul
+export float test_vec_vec_mul(float3 a, float3 b) { return mul(a, b); }
+
+// CHECK-LABEL: test_vec_vec_muli
+// CHECK: [[A:%.*]] = load <3 x i32>, ptr %a.addr
+// CHECK: [[B:%.*]] = load <3 x i32>, ptr %b.addr
+// DXIL: %hlsl.mul = call i32 @llvm.dx.sdot.v3i32(<3 x i32> [[A]], <3 x i32>
[[B]])
+// SPIRV: %hlsl.mul = call i32 @llvm.spv.sdot.v3i32(<3 x i32> [[A]], <3 x i32>
[[B]])
+// CHECK: ret i32 %hlsl.mul
+export int test_vec_vec_muli(int3 a, int3 b) { return mul(a, b); }
+
+// CHECK-LABEL: test_vec_vec_mulu
+// CHECK: [[A:%.*]] = load <3 x i32>, ptr %a.addr
+// CHECK: [[B:%.*]] = load <3 x i32>, ptr %b.addr
+// DXIL: %hlsl.mul = call i32 @llvm.dx.udot.v3i32(<3 x i32> [[A]], <3 x i32>
[[B]])
+// SPIRV: %hlsl.mul = call i32 @llvm.spv.udot.v3i32(<3 x i32> [[A]], <3 x i32>
[[B]])
+// CHECK: ret i32 %hlsl.mul
+export uint test_vec_vec_mulu(uint3 a, uint3 b) { return mul(a, b); }
+
+// Double vector dot product: DXIL uses scalar arithmetic, SPIR-V uses fdot
+// CHECK-LABEL: test_vec_vec_muld
+// CHECK: [[A:%.*]] = load <3 x double>, ptr %a.addr
+// CHECK: [[B:%.*]] = load <3 x double>, ptr %b.addr
+// DXIL-NOT: @llvm.dx.fdot
+// DXIL: [[A0:%.*]] = extractelement <3 x double> [[A]], i64 0
+// DXIL: [[B0:%.*]] = extractelement <3 x double> [[B]], i64 0
+// DXIL: [[MUL0:%.*]] = fmul {{.*}} double [[A0]], [[B0]]
+// DXIL: [[A1:%.*]] = extractelement <3 x double> [[A]], i64 1
+// DXIL: [[B1:%.*]] = extractelement <3 x double> [[B]], i64 1
+// DXIL: [[FMA0:%.*]] = call {{.*}} double @llvm.fmuladd.f64(double [[A1]],
double [[B1]], double [[MUL0]])
+// DXIL: [[A2:%.*]] = extractelement <3 x double> [[A]], i64 2
+// DXIL: [[B2:%.*]] = extractelement <3 x double> [[B]], i64 2
+// DXIL: [[FMA1:%.*]] = call {{.*}} double @llvm.fmuladd.f64(double [[A2]],
double [[B2]], double [[FMA0]])
+// DXIL: ret double [[FMA1]]
+// SPIRV: %hlsl.mul = call {{.*}} double @llvm.spv.fdot.v3f64(<3 x double>
[[A]], <3 x double> [[B]])
+// SPIRV: ret double %hlsl.mul
+export double test_vec_vec_muld(double3 a, double3 b) { return mul(a, b); }
+
+// -- Case 6: vector * matrix -> vector --
+
+// CHECK-LABEL: test_vec_mat_mul
+// CHECK: [[V:%.*]] = load <2 x float>, ptr %v.addr
+// CHECK: [[M:%.*]] = load <6 x float>, ptr %m.addr
+// CHECK: %hlsl.mul = call {{.*}} <3 x float>
@llvm.matrix.multiply.v3f32.v2f32.v6f32(<2 x float> [[V]], <6 x float> [[M]],
i32 1, i32 2, i32 3)
+// CHECK: ret <3 x float> %hlsl.mul
+export float3 test_vec_mat_mul(float2 v, float2x3 m) { return mul(v, m); }
+
+// -- Case 7: matrix * scalar -> matrix --
+
+// CHECK-LABEL: test_mat_scalar_mul
+// CHECK: [[A:%.*]] = load <6 x float>, ptr %a.addr
+// CHECK: [[B:%.*]] = load float, ptr %b.addr
+// CHECK: %.splatinsert = insertelement <6 x float> poison, float [[B]], i64 0
+// CHECK: %.splat = shufflevector <6 x float> %.splatinsert, <6 x float>
poison, <6 x i32> zeroinitializer
+// CHECK: %hlsl.mul = fmul {{.*}} <6 x float> %.splat, [[A]]
+// CHECK: ret <6 x float> %hlsl.mul
+export float2x3 test_mat_scalar_mul(float2x3 a, float b) { return mul(a, b); }
+
+// -- Case 8: matrix * vector -> vector --
+
+// CHECK-LABEL: test_mat_vec_mul
+// CHECK: [[M:%.*]] = load <6 x float>, ptr %m.addr
+// CHECK: [[V:%.*]] = load <3 x float>, ptr %v.addr
+// CHECK: %hlsl.mul = call {{.*}} <2 x float>
@llvm.matrix.multiply.v2f32.v6f32.v3f32(<6 x float> [[M]], <3 x float> [[V]],
i32 2, i32 3, i32 1)
+// CHECK: ret <2 x float> %hlsl.mul
+export float2 test_mat_vec_mul(float2x3 m, float3 v) { return mul(m, v); }
+
+// -- Case 9: matrix * matrix -> matrix --
+
+// CHECK-LABEL: test_mat_mat_mul
+// CHECK: [[A:%.*]] = load <6 x float>, ptr %a.addr
+// CHECK: [[B:%.*]] = load <12 x float>, ptr %b.addr
+// CHECK: %hlsl.mul = call {{.*}} <8 x float>
@llvm.matrix.multiply.v8f32.v6f32.v12f32(<6 x float> [[A]], <12 x float> [[B]],
i32 2, i32 3, i32 4)
+// CHECK: ret <8 x float> %hlsl.mul
+export float2x4 test_mat_mat_mul(float2x3 a, float3x4 b) { return mul(a, b); }
+
+// -- Integer matrix multiply --
+
+// CHECK-LABEL: test_mat_mat_muli
+// CHECK: [[A:%.*]] = load <6 x i32>, ptr %a.addr
+// CHECK: [[B:%.*]] = load <12 x i32>, ptr %b.addr
+// CHECK: %hlsl.mul = call <8 x i32>
@llvm.matrix.multiply.v8i32.v6i32.v12i32(<6 x i32> [[A]], <12 x i32> [[B]], i32
2, i32 3, i32 4)
+// CHECK: ret <8 x i32> %hlsl.mul
+export int2x4 test_mat_mat_muli(int2x3 a, int3x4 b) { return mul(a, b); }
diff --git a/clang/test/SemaHLSL/BuiltIns/mul-errors.hlsl
b/clang/test/SemaHLSL/BuiltIns/mul-errors.hlsl
new file mode 100644
index 0000000000000..01227126bff10
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/mul-errors.hlsl
@@ -0,0 +1,42 @@
+// RUN: %clang_cc1 -finclude-default-header -triple
dxil-pc-shadermodel6.3-library %s -verify
+
+// expected-note@*:* 54 {{candidate template ignored}}
+
+// Test mul error cases via template overload resolution
+
+// Inner dimension mismatch: vector * vector with different sizes
+export float test_vec_dim_mismatch(float2 a, float3 b) {
+ return mul(a, b);
+ // expected-error@-1 {{no matching function for call to 'mul'}}
+}
+
+// Inner dimension mismatch: matrix * matrix
+export float2x4 test_mat_dim_mismatch(float2x3 a, float4x4 b) {
+ return mul(a, b);
+ // expected-error@-1 {{no matching function for call to 'mul'}}
+}
+
+// Inner dimension mismatch: vector * matrix
+export float3 test_vec_mat_mismatch(float3 v, float2x3 m) {
+ return mul(v, m);
+ // expected-error@-1 {{no matching function for call to 'mul'}}
+}
+
+// Inner dimension mismatch: matrix * vector
+export float2 test_mat_vec_mismatch(float2x3 m, float2 v) {
+ return mul(m, v);
+ // expected-error@-1 {{no matching function for call to 'mul'}}
+}
+
+// Type mismatch: different element types
+export float test_type_mismatch(float a, int b) {
+ return mul(a, b);
+ // expected-error@-1 {{no matching function for call to 'mul'}}
+}
+
+// Type mismatch: different vector element types
+export float test_vec_type_mismatch(float3 a, int3 b) {
+ return mul(a, b);
+ // expected-error@-1 {{no matching function for call to 'mul'}}
+}
+
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index c4bf097e5a0f8..78ac0edc7de47 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -22,6 +22,7 @@
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsDirectX.h"
+#include "llvm/IR/MatrixBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Type.h"
@@ -229,6 +230,7 @@ static bool isIntrinsicExpansion(Function &F) ...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/184882
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits