llvmbot wrote:

<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-clang-codegen

Author: Deric C. (Icohedron)

<details>
<summary>Changes</summary>

Fixes #99138

- Defines a `__builtin_hlsl_mul` clang builtin in `Builtins.td`.
- Exposes the `__builtin_hlsl_mul` clang builtin in `hlsl_alias_intrinsics.h` as the `mul` overload set (nine templated overloads declared with `_HLSL_BUILTIN_ALIAS`)
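- Adds Sema handling for `__builtin_hlsl_mul` to `CheckBuiltinFunctionCall` in `SemaHLSL.cpp`, which checks that exactly two arguments are passed and sets the call's result type from the operand shapes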
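- Adds codegen for `__builtin_hlsl_mul` to `EmitHLSLBuiltinExpr` in `CGHLSLBuiltins.cpp`
    - When either operand is a scalar, the scalar is splatted to the other operand's width and the operands are multiplied element-wise (`fmul` for floating point, `mul` for integers).
    - Vector-vector multiplication lowers to the target's `dot` intrinsic. The exception is double vectors on DirectX, because DXIL has no double dot product. These expand to a scalar `fmul` followed by `llvm.fmuladd` multiply-adds, which may or may not be fused.
    - Matrix-matrix, matrix-vector, and vector-matrix multiplication lower to the `llvm.matrix.multiply` intrinsic. A left vector operand is treated as a 1×N matrix and a right vector operand as an N×1 matrix.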
- Adds codegen tests to `clang/test/CodeGenHLSL/builtins/mul.hlsl`
- Adds sema tests to `clang/test/SemaHLSL/BuiltIns/mul-errors.hlsl`
- Implements lowering of the `llvm.matrix.multiply` intrinsic to DXIL in `DXILIntrinsicExpansion.cpp`
    - Currently only supports the column-major matrix memory layout

Note: When lowering matrix multiply, the SPIR-V backend currently assumes a column-major matrix memory layout and does not support row-major layouts. For consistency, this PR has the DXIL backend make the same column-major assumption. Support for other matrix memory layouts in both the DXIL and SPIR-V backends will be added in a separate PR.

Assisted-by: claude-opus-4.6

---

Patch is 49.95 KiB, truncated to 20.00 KiB below. Full version: https://github.com/llvm/llvm-project/pull/184882.diff


8 Files Affected:

- (modified) clang/include/clang/Basic/Builtins.td (+6) 
- (modified) clang/lib/CodeGen/CGHLSLBuiltins.cpp (+91) 
- (modified) clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h (+69) 
- (modified) clang/lib/Sema/SemaHLSL.cpp (+57) 
- (added) clang/test/CodeGenHLSL/builtins/mul.hlsl (+143) 
- (added) clang/test/SemaHLSL/BuiltIns/mul-errors.hlsl (+42) 
- (modified) llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp (+96) 
- (added) llvm/test/CodeGen/DirectX/matrix-multiply.ll (+342) 


``````````diff
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 531c3702161f2..c6a0c10449ba9 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -5312,6 +5312,12 @@ def HLSLMad : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "void(...)";
 }
 
+def HLSLMul : LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_mul"];
+  let Attributes = [NoThrow, Const, CustomTypeChecking];
+  let Prototype = "void(...)";
+}
+
 def HLSLNormalize : LangBuiltin<"HLSL_LANG"> {
   let Spellings = ["__builtin_hlsl_normalize"];
   let Attributes = [NoThrow, Const, CustomTypeChecking];
diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
index 70891eac39425..34e92c42fdef9 100644
--- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp
+++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
@@ -13,6 +13,7 @@
 #include "CGBuiltin.h"
 #include "CGHLSLRuntime.h"
 #include "CodeGenFunction.h"
+#include "llvm/IR/MatrixBuilder.h"
 
 using namespace clang;
 using namespace CodeGen;
@@ -1006,6 +1007,96 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
     Value *Mul = Builder.CreateNUWMul(M, A);
     return Builder.CreateNUWAdd(Mul, B);
   }
+  case Builtin::BI__builtin_hlsl_mul: {
+    Value *Op0 = EmitScalarExpr(E->getArg(0));
+    Value *Op1 = EmitScalarExpr(E->getArg(1));
+    QualType QTy0 = E->getArg(0)->getType();
+    QualType QTy1 = E->getArg(1)->getType();
+    llvm::Type *T0 = Op0->getType();
+
+    bool IsScalar0 = QTy0->isScalarType();
+    bool IsVec0 = QTy0->isVectorType();
+    bool IsMat0 = QTy0->isConstantMatrixType();
+    bool IsScalar1 = QTy1->isScalarType();
+    bool IsVec1 = QTy1->isVectorType();
+    bool IsMat1 = QTy1->isConstantMatrixType();
+    bool IsFP =
+        QTy0->hasFloatingRepresentation() || QTy1->hasFloatingRepresentation();
+
+    // Cases 1-4, 7: scalar * scalar/vector/matrix or vector/matrix * scalar
+    if (IsScalar0 || IsScalar1) {
+      // Splat scalar to match the other operand's type
+      Value *Scalar = IsScalar0 ? Op0 : Op1;
+      Value *Other = IsScalar0 ? Op1 : Op0;
+      llvm::Type *OtherTy = Other->getType();
+
+      // Note: Matrices are flat vectors in the IR, so the following
+      // if-condition is also true when Other is a matrix, not just a vector.
+      if (OtherTy->isVectorTy()) {
+        unsigned NumElts = cast<FixedVectorType>(OtherTy)->getNumElements();
+        Scalar = Builder.CreateVectorSplat(NumElts, Scalar);
+      }
+
+      if (IsFP)
+        return Builder.CreateFMul(Scalar, Other, "hlsl.mul");
+      return Builder.CreateMul(Scalar, Other, "hlsl.mul");
+    }
+
+    // Case 5: vector * vector -> scalar (dot product)
+    if (IsVec0 && IsVec1) {
+      auto *VecTy0 = E->getArg(0)->getType()->castAs<VectorType>();
+      QualType EltQTy = VecTy0->getElementType();
+
+      // DXIL doesn't have a dot product intrinsic for double vectors,
+      // so expand to scalar multiply-add for DXIL.
+      if (CGM.getTarget().getTriple().isDXIL() &&
+          EltQTy->isSpecificBuiltinType(BuiltinType::Double)) {
+        unsigned NumElts = cast<FixedVectorType>(T0)->getNumElements();
+        Value *Sum = nullptr;
+        for (unsigned I = 0; I < NumElts; ++I) {
+          Value *L = Builder.CreateExtractElement(Op0, I);
+          Value *R = Builder.CreateExtractElement(Op1, I);
+          if (Sum)
+            Sum = Builder.CreateIntrinsic(Sum->getType(), Intrinsic::fmuladd,
+                                          {L, R, Sum});
+          else
+            Sum = Builder.CreateFMul(L, R);
+        }
+        return Sum;
+      }
+
+      return Builder.CreateIntrinsic(
+          /*ReturnType=*/T0->getScalarType(),
+          getDotProductIntrinsic(CGM.getHLSLRuntime(), EltQTy),
+          ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.mul");
+    }
+
+    // Cases 6, 8, 9: matrix involved -> use llvm.matrix.multiply
+    llvm::MatrixBuilder MB(Builder);
+    if (IsVec0 && IsMat1) {
+      // vector<N> * matrix<N,M> -> vector<M>
+      // Treat vector as 1×N matrix
+      unsigned N = QTy0->castAs<VectorType>()->getNumElements();
+      auto *MatTy = QTy1->castAs<ConstantMatrixType>();
+      unsigned M = MatTy->getNumColumns();
+      return MB.CreateMatrixMultiply(Op0, Op1, 1, N, M, "hlsl.mul");
+    }
+    if (IsMat0 && IsVec1) {
+      // matrix<M,N> * vector<N> -> vector<M>
+      // Treat vector as N×1 matrix
+      auto *MatTy = QTy0->castAs<ConstantMatrixType>();
+      unsigned Rows = MatTy->getNumRows();
+      unsigned Cols = MatTy->getNumColumns();
+      return MB.CreateMatrixMultiply(Op0, Op1, Rows, Cols, 1, "hlsl.mul");
+    }
+    assert(IsMat0 && IsMat1);
+    // matrix<M,K> * matrix<K,N> -> matrix<M,N>
+    auto *MatTy0 = QTy0->castAs<ConstantMatrixType>();
+    auto *MatTy1 = QTy1->castAs<ConstantMatrixType>();
+    return MB.CreateMatrixMultiply(Op0, Op1, MatTy0->getNumRows(),
+                                   MatTy0->getNumColumns(),
+                                   MatTy1->getNumColumns(), "hlsl.mul");
+  }
   case Builtin::BI__builtin_hlsl_elementwise_rcp: {
     Value *Op0 = EmitScalarExpr(E->getArg(0));
     if (!E->getArg(0)->getType()->hasFloatingRepresentation())
diff --git a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
index 2543401bdfbf9..2e9847803a8a1 100644
--- a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
@@ -1775,6 +1775,75 @@ double3 min(double3, double3);
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
 double4 min(double4, double4);
 
+//===----------------------------------------------------------------------===//
+// mul builtins
+//===----------------------------------------------------------------------===//
+
+/// \fn R mul(X x, Y y)
+/// \brief Multiplies x and y using matrix math.
+/// \param x [in] The first input value. If x is a vector, it is treated as a
+///   row vector.
+/// \param y [in] The second input value. If y is a vector, it is treated as a
+///   column vector.
+///
+/// The inner dimension x-columns and y-rows must be equal. The result has the
+/// dimension x-rows x y-columns. When both x and y are vectors, the result is
+/// a dot product (scalar). Scalar operands are multiplied element-wise.
+///
+/// This function supports 9 overloaded forms:
+///   1. scalar * scalar -> scalar
+///   2. scalar * vector -> vector
+///   3. scalar * matrix -> matrix
+///   4. vector * scalar -> vector
+///   5. vector * vector -> scalar (dot product)
+///   6. vector * matrix -> vector
+///   7. matrix * scalar -> matrix
+///   8. matrix * vector -> vector
+///   9. matrix * matrix -> matrix
+
+// Case 1: scalar * scalar -> scalar
+template <typename T> _HLSL_BUILTIN_ALIAS(__builtin_hlsl_mul) T mul(T, T);
+
+// Case 2: scalar * vector -> vector
+template <typename T, int N>
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mul)
+vector<T, N> mul(T, vector<T, N>);
+
+// Case 3: scalar * matrix -> matrix
+template <typename T, int R, int C>
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mul)
+matrix<T, R, C> mul(T, matrix<T, R, C>);
+
+// Case 4: vector * scalar -> vector
+template <typename T, int N>
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mul)
+vector<T, N> mul(vector<T, N>, T);
+
+// Case 5: vector * vector -> scalar (dot product)
+template <typename T, int N>
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mul)
+T mul(vector<T, N>, vector<T, N>);
+
+// Case 6: vector * matrix -> vector
+template <typename T, int R, int C>
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mul)
+vector<T, C> mul(vector<T, R>, matrix<T, R, C>);
+
+// Case 7: matrix * scalar -> matrix
+template <typename T, int R, int C>
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mul)
+matrix<T, R, C> mul(matrix<T, R, C>, T);
+
+// Case 8: matrix * vector -> vector
+template <typename T, int R, int C>
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mul)
+vector<T, R> mul(matrix<T, R, C>, vector<T, C>);
+
+// Case 9: matrix * matrix -> matrix
+template <typename T, int R, int K, int C>
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mul)
+matrix<T, R, C> mul(matrix<T, R, K>, matrix<T, K, C>);
+
 
//===----------------------------------------------------------------------===//
 // normalize builtins
 
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 804ea70aaddce..46a30acd95b68 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -3775,6 +3775,63 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     break;
   }
+  case Builtin::BI__builtin_hlsl_mul: {
+    if (SemaRef.checkArgCount(TheCall, 2))
+      return true;
+
+    Expr *Arg0 = TheCall->getArg(0);
+    Expr *Arg1 = TheCall->getArg(1);
+    QualType Ty0 = Arg0->getType();
+    QualType Ty1 = Arg1->getType();
+
+    auto getElemType = [](QualType T) -> QualType {
+      if (const auto *VTy = T->getAs<VectorType>())
+        return VTy->getElementType();
+      if (const auto *MTy = T->getAs<ConstantMatrixType>())
+        return MTy->getElementType();
+      return T;
+    };
+
+    QualType EltTy0 = getElemType(Ty0);
+
+    bool IsScalar0 = Ty0->isScalarType();
+    bool IsVec0 = Ty0->isVectorType();
+    bool IsMat0 = Ty0->isConstantMatrixType();
+    bool IsScalar1 = Ty1->isScalarType();
+    bool IsVec1 = Ty1->isVectorType();
+    bool IsMat1 = Ty1->isConstantMatrixType();
+
+    QualType RetTy;
+
+    if (IsScalar0 && IsScalar1) {
+      RetTy = EltTy0;
+    } else if (IsScalar0 && IsVec1) {
+      RetTy = Ty1;
+    } else if (IsScalar0 && IsMat1) {
+      RetTy = Ty1;
+    } else if (IsVec0 && IsScalar1) {
+      RetTy = Ty0;
+    } else if (IsVec0 && IsVec1) {
+      RetTy = EltTy0;
+    } else if (IsVec0 && IsMat1) {
+      auto *MatTy = Ty1->castAs<ConstantMatrixType>();
+      RetTy = getASTContext().getExtVectorType(EltTy0, MatTy->getNumColumns());
+    } else if (IsMat0 && IsScalar1) {
+      RetTy = Ty0;
+    } else if (IsMat0 && IsVec1) {
+      auto *MatTy = Ty0->castAs<ConstantMatrixType>();
+      RetTy = getASTContext().getExtVectorType(EltTy0, MatTy->getNumRows());
+    } else {
+      assert(IsMat0 && IsMat1);
+      auto *MatTy0 = Ty0->castAs<ConstantMatrixType>();
+      auto *MatTy1 = Ty1->castAs<ConstantMatrixType>();
+      RetTy = getASTContext().getConstantMatrixType(
+          EltTy0, MatTy0->getNumRows(), MatTy1->getNumColumns());
+    }
+
+    TheCall->setType(RetTy);
+    break;
+  }
   case Builtin::BI__builtin_hlsl_normalize: {
     if (SemaRef.checkArgCount(TheCall, 1))
       return true;
diff --git a/clang/test/CodeGenHLSL/builtins/mul.hlsl b/clang/test/CodeGenHLSL/builtins/mul.hlsl
new file mode 100644
index 0000000000000..0a95d6004e567
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/mul.hlsl
@@ -0,0 +1,143 @@
+// RUN: %clang_cc1 -finclude-default-header -triple 
dxil-pc-shadermodel6.3-library -emit-llvm -o - %s | FileCheck %s 
--check-prefixes=CHECK,DXIL
+// RUN: %clang_cc1 -finclude-default-header -triple 
spirv-unknown-vulkan1.3-library -emit-llvm -o - %s | FileCheck %s 
--check-prefixes=CHECK,SPIRV
+
+// -- Case 1: scalar * scalar -> scalar --
+
+// CHECK-LABEL: test_scalar_mulf
+// CHECK: [[A:%.*]] = load float, ptr %a.addr
+// CHECK: [[B:%.*]] = load float, ptr %b.addr
+// CHECK: %hlsl.mul = fmul {{.*}} float [[A]], [[B]]
+// CHECK: ret float %hlsl.mul
+export float test_scalar_mulf(float a, float b) { return mul(a, b); }
+
+// CHECK-LABEL: test_scalar_muli
+// CHECK: [[A:%.*]] = load i32, ptr %a.addr
+// CHECK: [[B:%.*]] = load i32, ptr %b.addr
+// CHECK: %hlsl.mul = mul i32 [[A]], [[B]]
+// CHECK: ret i32 %hlsl.mul
+export int test_scalar_muli(int a, int b) { return mul(a, b); }
+
+// -- Case 2: scalar * vector -> vector --
+
+// CHECK-LABEL: test_scalar_vec_mul
+// CHECK: [[A:%.*]] = load float, ptr %a.addr
+// CHECK: [[B:%.*]] = load <3 x float>, ptr %b.addr
+// CHECK: %.splatinsert = insertelement <3 x float> poison, float [[A]], i64 0
+// CHECK: %.splat = shufflevector <3 x float> %.splatinsert, <3 x float> 
poison, <3 x i32> zeroinitializer
+// CHECK: %hlsl.mul = fmul {{.*}} <3 x float> %.splat, [[B]]
+// CHECK: ret <3 x float> %hlsl.mul
+export float3 test_scalar_vec_mul(float a, float3 b) { return mul(a, b); }
+
+// -- Case 3: scalar * matrix -> matrix --
+
+// CHECK-LABEL: test_scalar_mat_mul
+// CHECK: [[A:%.*]] = load float, ptr %a.addr
+// CHECK: [[B:%.*]] = load <6 x float>, ptr %b.addr
+// CHECK: %.splatinsert = insertelement <6 x float> poison, float [[A]], i64 0
+// CHECK: %.splat = shufflevector <6 x float> %.splatinsert, <6 x float> 
poison, <6 x i32> zeroinitializer
+// CHECK: %hlsl.mul = fmul {{.*}} <6 x float> %.splat, [[B]]
+// CHECK: ret <6 x float> %hlsl.mul
+export float2x3 test_scalar_mat_mul(float a, float2x3 b) { return mul(a, b); }
+
+// -- Case 4: vector * scalar -> vector --
+
+// CHECK-LABEL: test_vec_scalar_mul
+// CHECK: [[A:%.*]] = load <3 x float>, ptr %a.addr
+// CHECK: [[B:%.*]] = load float, ptr %b.addr
+// CHECK: %.splatinsert = insertelement <3 x float> poison, float [[B]], i64 0
+// CHECK: %.splat = shufflevector <3 x float> %.splatinsert, <3 x float> 
poison, <3 x i32> zeroinitializer
+// CHECK: %hlsl.mul = fmul {{.*}} <3 x float> %.splat, [[A]]
+// CHECK: ret <3 x float> %hlsl.mul
+export float3 test_vec_scalar_mul(float3 a, float b) { return mul(a, b); }
+
+// -- Case 5: vector * vector -> scalar (dot product) --
+
+// CHECK-LABEL: test_vec_vec_mul
+// CHECK: [[A:%.*]] = load <3 x float>, ptr %a.addr
+// CHECK: [[B:%.*]] = load <3 x float>, ptr %b.addr
+// DXIL: %hlsl.mul = call {{.*}} float @llvm.dx.fdot.v3f32(<3 x float> [[A]], 
<3 x float> [[B]])
+// SPIRV: %hlsl.mul = call {{.*}} float @llvm.spv.fdot.v3f32(<3 x float> 
[[A]], <3 x float> [[B]])
+// CHECK: ret float %hlsl.mul
+export float test_vec_vec_mul(float3 a, float3 b) { return mul(a, b); }
+
+// CHECK-LABEL: test_vec_vec_muli
+// CHECK: [[A:%.*]] = load <3 x i32>, ptr %a.addr
+// CHECK: [[B:%.*]] = load <3 x i32>, ptr %b.addr
+// DXIL: %hlsl.mul = call i32 @llvm.dx.sdot.v3i32(<3 x i32> [[A]], <3 x i32> 
[[B]])
+// SPIRV: %hlsl.mul = call i32 @llvm.spv.sdot.v3i32(<3 x i32> [[A]], <3 x i32> 
[[B]])
+// CHECK: ret i32 %hlsl.mul
+export int test_vec_vec_muli(int3 a, int3 b) { return mul(a, b); }
+
+// CHECK-LABEL: test_vec_vec_mulu
+// CHECK: [[A:%.*]] = load <3 x i32>, ptr %a.addr
+// CHECK: [[B:%.*]] = load <3 x i32>, ptr %b.addr
+// DXIL: %hlsl.mul = call i32 @llvm.dx.udot.v3i32(<3 x i32> [[A]], <3 x i32> 
[[B]])
+// SPIRV: %hlsl.mul = call i32 @llvm.spv.udot.v3i32(<3 x i32> [[A]], <3 x i32> 
[[B]])
+// CHECK: ret i32 %hlsl.mul
+export uint test_vec_vec_mulu(uint3 a, uint3 b) { return mul(a, b); }
+
+// Double vector dot product: DXIL uses scalar arithmetic, SPIR-V uses fdot
+// CHECK-LABEL: test_vec_vec_muld
+// CHECK: [[A:%.*]] = load <3 x double>, ptr %a.addr
+// CHECK: [[B:%.*]] = load <3 x double>, ptr %b.addr
+// DXIL-NOT: @llvm.dx.fdot
+// DXIL: [[A0:%.*]] = extractelement <3 x double> [[A]], i64 0
+// DXIL: [[B0:%.*]] = extractelement <3 x double> [[B]], i64 0
+// DXIL: [[MUL0:%.*]] = fmul {{.*}} double [[A0]], [[B0]]
+// DXIL: [[A1:%.*]] = extractelement <3 x double> [[A]], i64 1
+// DXIL: [[B1:%.*]] = extractelement <3 x double> [[B]], i64 1
+// DXIL: [[FMA0:%.*]] = call {{.*}} double @llvm.fmuladd.f64(double [[A1]], 
double [[B1]], double [[MUL0]])
+// DXIL: [[A2:%.*]] = extractelement <3 x double> [[A]], i64 2
+// DXIL: [[B2:%.*]] = extractelement <3 x double> [[B]], i64 2
+// DXIL: [[FMA1:%.*]] = call {{.*}} double @llvm.fmuladd.f64(double [[A2]], 
double [[B2]], double [[FMA0]])
+// DXIL: ret double [[FMA1]]
+// SPIRV: %hlsl.mul = call {{.*}} double @llvm.spv.fdot.v3f64(<3 x double> 
[[A]], <3 x double> [[B]])
+// SPIRV: ret double %hlsl.mul
+export double test_vec_vec_muld(double3 a, double3 b) { return mul(a, b); }
+
+// -- Case 6: vector * matrix -> vector --
+
+// CHECK-LABEL: test_vec_mat_mul
+// CHECK: [[V:%.*]] = load <2 x float>, ptr %v.addr
+// CHECK: [[M:%.*]] = load <6 x float>, ptr %m.addr
+// CHECK: %hlsl.mul = call {{.*}} <3 x float> 
@llvm.matrix.multiply.v3f32.v2f32.v6f32(<2 x float> [[V]], <6 x float> [[M]], 
i32 1, i32 2, i32 3)
+// CHECK: ret <3 x float> %hlsl.mul
+export float3 test_vec_mat_mul(float2 v, float2x3 m) { return mul(v, m); }
+
+// -- Case 7: matrix * scalar -> matrix --
+
+// CHECK-LABEL: test_mat_scalar_mul
+// CHECK: [[A:%.*]] = load <6 x float>, ptr %a.addr
+// CHECK: [[B:%.*]] = load float, ptr %b.addr
+// CHECK: %.splatinsert = insertelement <6 x float> poison, float [[B]], i64 0
+// CHECK: %.splat = shufflevector <6 x float> %.splatinsert, <6 x float> 
poison, <6 x i32> zeroinitializer
+// CHECK: %hlsl.mul = fmul {{.*}} <6 x float> %.splat, [[A]]
+// CHECK: ret <6 x float> %hlsl.mul
+export float2x3 test_mat_scalar_mul(float2x3 a, float b) { return mul(a, b); }
+
+// -- Case 8: matrix * vector -> vector --
+
+// CHECK-LABEL: test_mat_vec_mul
+// CHECK: [[M:%.*]] = load <6 x float>, ptr %m.addr
+// CHECK: [[V:%.*]] = load <3 x float>, ptr %v.addr
+// CHECK: %hlsl.mul = call {{.*}} <2 x float> 
@llvm.matrix.multiply.v2f32.v6f32.v3f32(<6 x float> [[M]], <3 x float> [[V]], 
i32 2, i32 3, i32 1)
+// CHECK: ret <2 x float> %hlsl.mul
+export float2 test_mat_vec_mul(float2x3 m, float3 v) { return mul(m, v); }
+
+// -- Case 9: matrix * matrix -> matrix --
+
+// CHECK-LABEL: test_mat_mat_mul
+// CHECK: [[A:%.*]] = load <6 x float>, ptr %a.addr
+// CHECK: [[B:%.*]] = load <12 x float>, ptr %b.addr
+// CHECK: %hlsl.mul = call {{.*}} <8 x float> 
@llvm.matrix.multiply.v8f32.v6f32.v12f32(<6 x float> [[A]], <12 x float> [[B]], 
i32 2, i32 3, i32 4)
+// CHECK: ret <8 x float> %hlsl.mul
+export float2x4 test_mat_mat_mul(float2x3 a, float3x4 b) { return mul(a, b); }
+
+// -- Integer matrix multiply --
+
+// CHECK-LABEL: test_mat_mat_muli
+// CHECK: [[A:%.*]] = load <6 x i32>, ptr %a.addr
+// CHECK: [[B:%.*]] = load <12 x i32>, ptr %b.addr
+// CHECK: %hlsl.mul = call <8 x i32> 
@llvm.matrix.multiply.v8i32.v6i32.v12i32(<6 x i32> [[A]], <12 x i32> [[B]], i32 
2, i32 3, i32 4)
+// CHECK: ret <8 x i32> %hlsl.mul
+export int2x4 test_mat_mat_muli(int2x3 a, int3x4 b) { return mul(a, b); }
diff --git a/clang/test/SemaHLSL/BuiltIns/mul-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/mul-errors.hlsl
new file mode 100644
index 0000000000000..01227126bff10
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/mul-errors.hlsl
@@ -0,0 +1,42 @@
+// RUN: %clang_cc1 -finclude-default-header -triple 
dxil-pc-shadermodel6.3-library %s -verify
+
+// expected-note@*:* 54 {{candidate template ignored}}
+
+// Test mul error cases via template overload resolution
+
+// Inner dimension mismatch: vector * vector with different sizes
+export float test_vec_dim_mismatch(float2 a, float3 b) {
+  return mul(a, b);
+  // expected-error@-1 {{no matching function for call to 'mul'}}
+}
+
+// Inner dimension mismatch: matrix * matrix
+export float2x4 test_mat_dim_mismatch(float2x3 a, float4x4 b) {
+  return mul(a, b);
+  // expected-error@-1 {{no matching function for call to 'mul'}}
+}
+
+// Inner dimension mismatch: vector * matrix
+export float3 test_vec_mat_mismatch(float3 v, float2x3 m) {
+  return mul(v, m);
+  // expected-error@-1 {{no matching function for call to 'mul'}}
+}
+
+// Inner dimension mismatch: matrix * vector
+export float2 test_mat_vec_mismatch(float2x3 m, float2 v) {
+  return mul(m, v);
+  // expected-error@-1 {{no matching function for call to 'mul'}}
+}
+
+// Type mismatch: different element types
+export float test_type_mismatch(float a, int b) {
+  return mul(a, b);
+  // expected-error@-1 {{no matching function for call to 'mul'}}
+}
+
+// Type mismatch: different vector element types
+export float test_vec_type_mismatch(float3 a, int3 b) {
+  return mul(a, b);
+  // expected-error@-1 {{no matching function for call to 'mul'}}
+}
+
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index c4bf097e5a0f8..78ac0edc7de47 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -22,6 +22,7 @@
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/IntrinsicsDirectX.h"
+#include "llvm/IR/MatrixBuilder.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/IR/Type.h"
@@ -229,6 +230,7 @@ static bool isIntrinsicExpansion(Function &F) ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/184882
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to