https://github.com/Icohedron created 
https://github.com/llvm/llvm-project/pull/184882

Fixes #99138

- Defines a `__builtin_hlsl_mul` clang builtin in `Builtins.td`.
- Exposes `__builtin_hlsl_mul` as the HLSL `mul` function in `hlsl_alias_intrinsics.h`
- Adds sema for `__builtin_hlsl_mul` to `CheckBuiltinFunctionCall` in 
`SemaHLSL.cpp`
- Adds codegen for `__builtin_hlsl_mul` to `EmitHLSLBuiltinExpr` in 
`CGHLSLBuiltins.cpp`
    - Vector-vector multiplication uses `dot`, except for double vectors on DirectX,
which are expanded to a scalar multiply followed by `llvm.fmuladd` operations.
    - Matrix-matrix, matrix-vector, and vector-matrix multiplication lower to 
the `llvm.matrix.multiply` intrinsic
- Adds codegen tests to `clang/test/CodeGenHLSL/builtins/mul.hlsl`
- Adds sema tests to `clang/test/SemaHLSL/BuiltIns/mul-errors.hlsl`
- Implements lowering of the `llvm.matrix.multiply` intrinsic to DXIL in 
`DXILIntrinsicExpansion.cpp`
    - Currently only supports column-major matrix memory layout

Note: the SPIR-V backend currently assumes a column-major memory layout when 
lowering matrix multiply and does not support row-major layouts. For 
consistency, the DXIL lowering added in this PR also assumes column-major layout 
only. Support for other matrix memory layouts in both the DXIL and SPIR-V 
backends will be added in a separate PR.

Assisted-by: claude-opus-4.6

>From 2491daa60da34861eeff2a239b443b984844d6a9 Mon Sep 17 00:00:00 2001
From: Deric Cheung <[email protected]>
Date: Thu, 5 Mar 2026 11:22:33 -0800
Subject: [PATCH] Implement HLSL mul function

- Defines a `__builtin_hlsl_mul` clang builtin in `Builtins.td`.
- Exposes `__builtin_hlsl_mul` as the HLSL `mul` function in `hlsl_alias_intrinsics.h`
- Adds sema checks for `__builtin_hlsl_mul` to `CheckBuiltinFunctionCall` in 
`SemaHLSL.cpp`
- Adds codegen for `__builtin_hlsl_mul` to `EmitHLSLBuiltinExpr` in 
`CGHLSLBuiltins.cpp`
    - Vector-vector multiplication uses `dot`, except for double vectors on DirectX,
which are expanded to a scalar multiply followed by `llvm.fmuladd` operations.
    - Matrix-matrix, matrix-vector, and vector-matrix multiplication lower to 
the `llvm.matrix.multiply` intrinsic
- Adds codegen tests to `clang/test/CodeGenHLSL/builtins/mul.hlsl`
- Adds sema tests to `clang/test/SemaHLSL/BuiltIns/mul-errors.hlsl`
- Implements lowering of the `llvm.matrix.multiply` intrinsic to DXIL in 
`DXILIntrinsicExpansion.cpp`
    - Currently only supports column-major matrix memory layout

Note: the SPIR-V and DXIL backends currently assume a column-major matrix 
memory layout when lowering matrix multiply; row-major layouts are not yet supported.

Assisted-by: claude-opus-4.6
---
 clang/include/clang/Basic/Builtins.td         |   6 +
 clang/lib/CodeGen/CGHLSLBuiltins.cpp          |  91 +++++
 .../lib/Headers/hlsl/hlsl_alias_intrinsics.h  |  69 ++++
 clang/lib/Sema/SemaHLSL.cpp                   |  57 +++
 clang/test/CodeGenHLSL/builtins/mul.hlsl      | 143 ++++++++
 clang/test/SemaHLSL/BuiltIns/mul-errors.hlsl  |  42 +++
 .../Target/DirectX/DXILIntrinsicExpansion.cpp |  96 +++++
 llvm/test/CodeGen/DirectX/matrix-multiply.ll  | 342 ++++++++++++++++++
 8 files changed, 846 insertions(+)
 create mode 100644 clang/test/CodeGenHLSL/builtins/mul.hlsl
 create mode 100644 clang/test/SemaHLSL/BuiltIns/mul-errors.hlsl
 create mode 100644 llvm/test/CodeGen/DirectX/matrix-multiply.ll

diff --git a/clang/include/clang/Basic/Builtins.td 
b/clang/include/clang/Basic/Builtins.td
index 531c3702161f2..c6a0c10449ba9 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -5312,6 +5312,12 @@ def HLSLMad : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "void(...)";
 }
 
+def HLSLMul : LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_mul"];
+  let Attributes = [NoThrow, Const, CustomTypeChecking];
+  let Prototype = "void(...)"; // Result type is computed in SemaHLSL.
+}
+
 def HLSLNormalize : LangBuiltin<"HLSL_LANG"> {
   let Spellings = ["__builtin_hlsl_normalize"];
   let Attributes = [NoThrow, Const, CustomTypeChecking];
diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp 
b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
index 70891eac39425..34e92c42fdef9 100644
--- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp
+++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
@@ -13,6 +13,7 @@
 #include "CGBuiltin.h"
 #include "CGHLSLRuntime.h"
 #include "CodeGenFunction.h"
+#include "llvm/IR/MatrixBuilder.h"
 
 using namespace clang;
 using namespace CodeGen;
@@ -1006,6 +1007,96 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned 
BuiltinID,
     Value *Mul = Builder.CreateNUWMul(M, A);
     return Builder.CreateNUWAdd(Mul, B);
   }
+  case Builtin::BI__builtin_hlsl_mul: {
+    Value *Op0 = EmitScalarExpr(E->getArg(0));
+    Value *Op1 = EmitScalarExpr(E->getArg(1));
+    QualType QTy0 = E->getArg(0)->getType();
+    QualType QTy1 = E->getArg(1)->getType();
+    llvm::Type *T0 = Op0->getType();
+    // Classify each operand (scalar, vector or matrix) from its AST type.
+    bool IsScalar0 = QTy0->isScalarType();
+    bool IsVec0 = QTy0->isVectorType();
+    bool IsMat0 = QTy0->isConstantMatrixType();
+    bool IsScalar1 = QTy1->isScalarType();
+    bool IsVec1 = QTy1->isVectorType();
+    bool IsMat1 = QTy1->isConstantMatrixType();
+    bool IsFP =
+        QTy0->hasFloatingRepresentation() || QTy1->hasFloatingRepresentation();
+
+    // Cases 1-4, 7: scalar * scalar/vector/matrix or vector/matrix * scalar
+    if (IsScalar0 || IsScalar1) {
+      // Splat scalar to match the other operand's type
+      Value *Scalar = IsScalar0 ? Op0 : Op1;
+      Value *Other = IsScalar0 ? Op1 : Op0;
+      llvm::Type *OtherTy = Other->getType();
+
+      // Note: Matrices are flat vectors in the IR, so the following
+      // if-condition is also true when Other is a matrix, not just a vector.
+      if (OtherTy->isVectorTy()) {
+        unsigned NumElts = cast<FixedVectorType>(OtherTy)->getNumElements();
+        Scalar = Builder.CreateVectorSplat(NumElts, Scalar);
+      }
+
+      if (IsFP)
+        return Builder.CreateFMul(Scalar, Other, "hlsl.mul");
+      return Builder.CreateMul(Scalar, Other, "hlsl.mul");
+    }
+
+    // Case 5: vector * vector -> scalar (dot product)
+    if (IsVec0 && IsVec1) {
+      auto *VecTy0 = E->getArg(0)->getType()->castAs<VectorType>();
+      QualType EltQTy = VecTy0->getElementType();
+
+      // DXIL doesn't have a dot product intrinsic for double vectors,
+      // so expand to scalar multiply-add for DXIL.
+      if (CGM.getTarget().getTriple().isDXIL() &&
+          EltQTy->isSpecificBuiltinType(BuiltinType::Double)) {
+        unsigned NumElts = cast<FixedVectorType>(T0)->getNumElements();
+        Value *Sum = nullptr;
+        for (unsigned I = 0; I < NumElts; ++I) {
+          Value *L = Builder.CreateExtractElement(Op0, I);
+          Value *R = Builder.CreateExtractElement(Op1, I);
+          if (Sum)
+            Sum = Builder.CreateIntrinsic(Sum->getType(), Intrinsic::fmuladd,
+                                          {L, R, Sum});
+          else
+            Sum = Builder.CreateFMul(L, R);
+        }
+        return Sum;
+      }
+
+      return Builder.CreateIntrinsic(
+          /*ReturnType=*/T0->getScalarType(),
+          getDotProductIntrinsic(CGM.getHLSLRuntime(), EltQTy),
+          ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.mul");
+    }
+
+    // Cases 6, 8, 9: matrix involved -> use llvm.matrix.multiply
+    llvm::MatrixBuilder MB(Builder);
+    if (IsVec0 && IsMat1) {
+      // vector<N> * matrix<N,M> -> vector<M>
+      // Treat the vector as a 1xN (row) matrix.
+      unsigned N = QTy0->castAs<VectorType>()->getNumElements();
+      auto *MatTy = QTy1->castAs<ConstantMatrixType>();
+      unsigned M = MatTy->getNumColumns();
+      return MB.CreateMatrixMultiply(Op0, Op1, 1, N, M, "hlsl.mul");
+    }
+    if (IsMat0 && IsVec1) {
+      // matrix<M,N> * vector<N> -> vector<M>
+      // Treat the vector as an Nx1 (column) matrix.
+      auto *MatTy = QTy0->castAs<ConstantMatrixType>();
+      unsigned Rows = MatTy->getNumRows();
+      unsigned Cols = MatTy->getNumColumns();
+      return MB.CreateMatrixMultiply(Op0, Op1, Rows, Cols, 1, "hlsl.mul");
+    }
+    assert(IsMat0 && IsMat1);
+    // matrix<M,K> * matrix<K,N> -> matrix<M,N>
+    auto *MatTy0 = QTy0->castAs<ConstantMatrixType>();
+    auto *MatTy1 = QTy1->castAs<ConstantMatrixType>();
+    return MB.CreateMatrixMultiply(Op0, Op1, MatTy0->getNumRows(),
+                                   MatTy0->getNumColumns(),
+                                   MatTy1->getNumColumns(), "hlsl.mul");
+  }
   case Builtin::BI__builtin_hlsl_elementwise_rcp: {
     Value *Op0 = EmitScalarExpr(E->getArg(0));
     if (!E->getArg(0)->getType()->hasFloatingRepresentation())
diff --git a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h 
b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
index 2543401bdfbf9..2e9847803a8a1 100644
--- a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
@@ -1775,6 +1775,75 @@ double3 min(double3, double3);
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
 double4 min(double4, double4);
 
+//===----------------------------------------------------------------------===//
+// mul builtins
+//===----------------------------------------------------------------------===//
+
+/// \fn R mul(X x, Y y)
+/// \brief Multiplies x and y using matrix math.
+/// \param x [in] The first input value. If x is a vector, it is treated as a
+///   row vector.
+/// \param y [in] The second input value. If y is a vector, it is treated as a
+///   column vector.
+///
+/// The number of columns of x must equal the number of rows of y; the result
+/// has the rows of x and the columns of y. Two vectors yield their dot
+/// product (a scalar). Scalar operands are multiplied element-wise.
+///
+/// This function supports 9 overloaded forms:
+///   1. scalar * scalar -> scalar
+///   2. scalar * vector -> vector
+///   3. scalar * matrix -> matrix
+///   4. vector * scalar -> vector
+///   5. vector * vector -> scalar (dot product)
+///   6. vector * matrix -> vector
+///   7. matrix * scalar -> matrix
+///   8. matrix * vector -> vector
+///   9. matrix * matrix -> matrix
+// Case 1: scalar * scalar -> scalar
+// NOTE(review): T is unconstrained; mul(float2x3, float2x3) matches it too.
+template <typename T> _HLSL_BUILTIN_ALIAS(__builtin_hlsl_mul) T mul(T, T);
+
+// Case 2: scalar * vector -> vector
+template <typename T, int N>
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mul)
+vector<T, N> mul(T, vector<T, N>);
+
+// Case 3: scalar * matrix -> matrix
+template <typename T, int R, int C>
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mul)
+matrix<T, R, C> mul(T, matrix<T, R, C>);
+
+// Case 4: vector * scalar -> vector
+template <typename T, int N>
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mul)
+vector<T, N> mul(vector<T, N>, T);
+
+// Case 5: vector * vector -> scalar (dot product)
+template <typename T, int N>
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mul)
+T mul(vector<T, N>, vector<T, N>);
+
+// Case 6: vector * matrix -> vector
+template <typename T, int R, int C>
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mul)
+vector<T, C> mul(vector<T, R>, matrix<T, R, C>);
+
+// Case 7: matrix * scalar -> matrix
+template <typename T, int R, int C>
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mul)
+matrix<T, R, C> mul(matrix<T, R, C>, T);
+
+// Case 8: matrix * vector -> vector
+template <typename T, int R, int C>
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mul)
+vector<T, R> mul(matrix<T, R, C>, vector<T, C>);
+
+// Case 9: matrix * matrix -> matrix
+template <typename T, int R, int K, int C>
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_mul)
+matrix<T, R, C> mul(matrix<T, R, K>, matrix<T, K, C>);
+
 
//===----------------------------------------------------------------------===//
 // normalize builtins
 
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 804ea70aaddce..46a30acd95b68 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -3775,6 +3775,63 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned 
BuiltinID, CallExpr *TheCall) {
       return true;
     break;
   }
+  case Builtin::BI__builtin_hlsl_mul: {
+    if (SemaRef.checkArgCount(TheCall, 2))
+      return true;
+
+    Expr *Arg0 = TheCall->getArg(0);
+    Expr *Arg1 = TheCall->getArg(1);
+    QualType Ty0 = Arg0->getType();
+    QualType Ty1 = Arg1->getType();
+    const auto *VTy0 = Ty0->getAs<VectorType>();
+    const auto *VTy1 = Ty1->getAs<VectorType>();
+    const auto *MTy0 = Ty0->getAs<ConstantMatrixType>();
+    const auto *MTy1 = Ty1->getAs<ConstantMatrixType>();
+    QualType EltTy0 = VTy0   ? VTy0->getElementType()
+                      : MTy0 ? MTy0->getElementType()
+                             : Ty0;
+    QualType EltTy1 = VTy1   ? VTy1->getElementType()
+                      : MTy1 ? MTy1->getElementType()
+                             : Ty1;
+
+    // Inner (contracted) dimension: vector length, LHS matrix columns or RHS
+    // matrix rows. Scalars are multiplied element-wise, so they use 0 (none).
+    unsigned Inner0 = VTy0   ? VTy0->getNumElements()
+                      : MTy0 ? MTy0->getNumColumns()
+                             : 0;
+    unsigned Inner1 = VTy1   ? VTy1->getNumElements()
+                      : MTy1 ? MTy1->getNumRows()
+                             : 0;
+
+    // Overloading alone misses some bad operands (e.g. mul(float2x3, float2x3)
+    // via mul(T, T), or direct builtin calls) that codegen cannot lower.
+    ASTContext &Ctx = getASTContext();
+    if (!EltTy0->isArithmeticType() || !EltTy1->isArithmeticType() ||
+        !Ctx.hasSameUnqualifiedType(EltTy0, EltTy1) ||
+        (Inner0 && Inner1 && Inner0 != Inner1)) {
+      SemaRef.Diag(TheCall->getExprLoc(), diag::err_typecheck_invalid_operands)
+          << Ty0 << Ty1 << Arg0->getSourceRange() << Arg1->getSourceRange();
+      return true;
+    }
+
+    // Scalars broadcast, vector * vector is a dot product, else matrix math.
+    QualType RetTy;
+    if (!Inner1)
+      RetTy = Ty0; // x * scalar, including scalar * scalar
+    else if (!Inner0)
+      RetTy = Ty1; // scalar * y
+    else if (VTy0 && VTy1)
+      RetTy = EltTy0; // vector * vector -> dot product
+    else if (VTy0)
+      RetTy = Ctx.getExtVectorType(EltTy0, MTy1->getNumColumns());
+    else if (VTy1)
+      RetTy = Ctx.getExtVectorType(EltTy0, MTy0->getNumRows());
+    else
+      RetTy = Ctx.getConstantMatrixType(EltTy0, MTy0->getNumRows(),
+                                        MTy1->getNumColumns());
+    TheCall->setType(RetTy.getUnqualifiedType());
+    break;
+  }
   case Builtin::BI__builtin_hlsl_normalize: {
     if (SemaRef.checkArgCount(TheCall, 1))
       return true;
diff --git a/clang/test/CodeGenHLSL/builtins/mul.hlsl 
b/clang/test/CodeGenHLSL/builtins/mul.hlsl
new file mode 100644
index 0000000000000..0a95d6004e567
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/mul.hlsl
@@ -0,0 +1,143 @@
+// RUN: %clang_cc1 -finclude-default-header -triple 
dxil-pc-shadermodel6.3-library -emit-llvm -o - %s | FileCheck %s 
--check-prefixes=CHECK,DXIL
+// RUN: %clang_cc1 -finclude-default-header -triple 
spirv-unknown-vulkan1.3-library -emit-llvm -o - %s | FileCheck %s 
--check-prefixes=CHECK,SPIRV
+// Each section below covers one of the nine mul() overload forms.
+// -- Case 1: scalar * scalar -> scalar --
+
+// CHECK-LABEL: test_scalar_mulf
+// CHECK: [[A:%.*]] = load float, ptr %a.addr
+// CHECK: [[B:%.*]] = load float, ptr %b.addr
+// CHECK: %hlsl.mul = fmul {{.*}} float [[A]], [[B]]
+// CHECK: ret float %hlsl.mul
+export float test_scalar_mulf(float a, float b) { return mul(a, b); }
+
+// CHECK-LABEL: test_scalar_muli
+// CHECK: [[A:%.*]] = load i32, ptr %a.addr
+// CHECK: [[B:%.*]] = load i32, ptr %b.addr
+// CHECK: %hlsl.mul = mul i32 [[A]], [[B]]
+// CHECK: ret i32 %hlsl.mul
+export int test_scalar_muli(int a, int b) { return mul(a, b); }
+
+// -- Case 2: scalar * vector -> vector --
+
+// CHECK-LABEL: test_scalar_vec_mul
+// CHECK: [[A:%.*]] = load float, ptr %a.addr
+// CHECK: [[B:%.*]] = load <3 x float>, ptr %b.addr
+// CHECK: %.splatinsert = insertelement <3 x float> poison, float [[A]], i64 0
+// CHECK: %.splat = shufflevector <3 x float> %.splatinsert, <3 x float> 
poison, <3 x i32> zeroinitializer
+// CHECK: %hlsl.mul = fmul {{.*}} <3 x float> %.splat, [[B]]
+// CHECK: ret <3 x float> %hlsl.mul
+export float3 test_scalar_vec_mul(float a, float3 b) { return mul(a, b); }
+
+// -- Case 3: scalar * matrix -> matrix --
+
+// CHECK-LABEL: test_scalar_mat_mul
+// CHECK: [[A:%.*]] = load float, ptr %a.addr
+// CHECK: [[B:%.*]] = load <6 x float>, ptr %b.addr
+// CHECK: %.splatinsert = insertelement <6 x float> poison, float [[A]], i64 0
+// CHECK: %.splat = shufflevector <6 x float> %.splatinsert, <6 x float> 
poison, <6 x i32> zeroinitializer
+// CHECK: %hlsl.mul = fmul {{.*}} <6 x float> %.splat, [[B]]
+// CHECK: ret <6 x float> %hlsl.mul
+export float2x3 test_scalar_mat_mul(float a, float2x3 b) { return mul(a, b); }
+
+// -- Case 4: vector * scalar -> vector --
+
+// CHECK-LABEL: test_vec_scalar_mul
+// CHECK: [[A:%.*]] = load <3 x float>, ptr %a.addr
+// CHECK: [[B:%.*]] = load float, ptr %b.addr
+// CHECK: %.splatinsert = insertelement <3 x float> poison, float [[B]], i64 0
+// CHECK: %.splat = shufflevector <3 x float> %.splatinsert, <3 x float> 
poison, <3 x i32> zeroinitializer
+// CHECK: %hlsl.mul = fmul {{.*}} <3 x float> %.splat, [[A]]
+// CHECK: ret <3 x float> %hlsl.mul
+export float3 test_vec_scalar_mul(float3 a, float b) { return mul(a, b); }
+
+// -- Case 5: vector * vector -> scalar (dot product) --
+
+// CHECK-LABEL: test_vec_vec_mul
+// CHECK: [[A:%.*]] = load <3 x float>, ptr %a.addr
+// CHECK: [[B:%.*]] = load <3 x float>, ptr %b.addr
+// DXIL: %hlsl.mul = call {{.*}} float @llvm.dx.fdot.v3f32(<3 x float> [[A]], 
<3 x float> [[B]])
+// SPIRV: %hlsl.mul = call {{.*}} float @llvm.spv.fdot.v3f32(<3 x float> 
[[A]], <3 x float> [[B]])
+// CHECK: ret float %hlsl.mul
+export float test_vec_vec_mul(float3 a, float3 b) { return mul(a, b); }
+
+// CHECK-LABEL: test_vec_vec_muli
+// CHECK: [[A:%.*]] = load <3 x i32>, ptr %a.addr
+// CHECK: [[B:%.*]] = load <3 x i32>, ptr %b.addr
+// DXIL: %hlsl.mul = call i32 @llvm.dx.sdot.v3i32(<3 x i32> [[A]], <3 x i32> 
[[B]])
+// SPIRV: %hlsl.mul = call i32 @llvm.spv.sdot.v3i32(<3 x i32> [[A]], <3 x i32> 
[[B]])
+// CHECK: ret i32 %hlsl.mul
+export int test_vec_vec_muli(int3 a, int3 b) { return mul(a, b); }
+
+// CHECK-LABEL: test_vec_vec_mulu
+// CHECK: [[A:%.*]] = load <3 x i32>, ptr %a.addr
+// CHECK: [[B:%.*]] = load <3 x i32>, ptr %b.addr
+// DXIL: %hlsl.mul = call i32 @llvm.dx.udot.v3i32(<3 x i32> [[A]], <3 x i32> 
[[B]])
+// SPIRV: %hlsl.mul = call i32 @llvm.spv.udot.v3i32(<3 x i32> [[A]], <3 x i32> 
[[B]])
+// CHECK: ret i32 %hlsl.mul
+export uint test_vec_vec_mulu(uint3 a, uint3 b) { return mul(a, b); }
+
+// Double vector dot product: DXIL uses scalar arithmetic, SPIR-V uses fdot
+// CHECK-LABEL: test_vec_vec_muld
+// CHECK: [[A:%.*]] = load <3 x double>, ptr %a.addr
+// CHECK: [[B:%.*]] = load <3 x double>, ptr %b.addr
+// DXIL-NOT: @llvm.dx.fdot
+// DXIL: [[A0:%.*]] = extractelement <3 x double> [[A]], i64 0
+// DXIL: [[B0:%.*]] = extractelement <3 x double> [[B]], i64 0
+// DXIL: [[MUL0:%.*]] = fmul {{.*}} double [[A0]], [[B0]]
+// DXIL: [[A1:%.*]] = extractelement <3 x double> [[A]], i64 1
+// DXIL: [[B1:%.*]] = extractelement <3 x double> [[B]], i64 1
+// DXIL: [[FMA0:%.*]] = call {{.*}} double @llvm.fmuladd.f64(double [[A1]], 
double [[B1]], double [[MUL0]])
+// DXIL: [[A2:%.*]] = extractelement <3 x double> [[A]], i64 2
+// DXIL: [[B2:%.*]] = extractelement <3 x double> [[B]], i64 2
+// DXIL: [[FMA1:%.*]] = call {{.*}} double @llvm.fmuladd.f64(double [[A2]], 
double [[B2]], double [[FMA0]])
+// DXIL: ret double [[FMA1]]
+// SPIRV: %hlsl.mul = call {{.*}} double @llvm.spv.fdot.v3f64(<3 x double> 
[[A]], <3 x double> [[B]])
+// SPIRV: ret double %hlsl.mul
+export double test_vec_vec_muld(double3 a, double3 b) { return mul(a, b); }
+
+// -- Case 6: vector * matrix -> vector --
+
+// CHECK-LABEL: test_vec_mat_mul
+// CHECK: [[V:%.*]] = load <2 x float>, ptr %v.addr
+// CHECK: [[M:%.*]] = load <6 x float>, ptr %m.addr
+// CHECK: %hlsl.mul = call {{.*}} <3 x float> 
@llvm.matrix.multiply.v3f32.v2f32.v6f32(<2 x float> [[V]], <6 x float> [[M]], 
i32 1, i32 2, i32 3)
+// CHECK: ret <3 x float> %hlsl.mul
+export float3 test_vec_mat_mul(float2 v, float2x3 m) { return mul(v, m); }
+
+// -- Case 7: matrix * scalar -> matrix --
+
+// CHECK-LABEL: test_mat_scalar_mul
+// CHECK: [[A:%.*]] = load <6 x float>, ptr %a.addr
+// CHECK: [[B:%.*]] = load float, ptr %b.addr
+// CHECK: %.splatinsert = insertelement <6 x float> poison, float [[B]], i64 0
+// CHECK: %.splat = shufflevector <6 x float> %.splatinsert, <6 x float> 
poison, <6 x i32> zeroinitializer
+// CHECK: %hlsl.mul = fmul {{.*}} <6 x float> %.splat, [[A]]
+// CHECK: ret <6 x float> %hlsl.mul
+export float2x3 test_mat_scalar_mul(float2x3 a, float b) { return mul(a, b); }
+
+// -- Case 8: matrix * vector -> vector --
+
+// CHECK-LABEL: test_mat_vec_mul
+// CHECK: [[M:%.*]] = load <6 x float>, ptr %m.addr
+// CHECK: [[V:%.*]] = load <3 x float>, ptr %v.addr
+// CHECK: %hlsl.mul = call {{.*}} <2 x float> 
@llvm.matrix.multiply.v2f32.v6f32.v3f32(<6 x float> [[M]], <3 x float> [[V]], 
i32 2, i32 3, i32 1)
+// CHECK: ret <2 x float> %hlsl.mul
+export float2 test_mat_vec_mul(float2x3 m, float3 v) { return mul(m, v); }
+
+// -- Case 9: matrix * matrix -> matrix --
+
+// CHECK-LABEL: test_mat_mat_mul
+// CHECK: [[A:%.*]] = load <6 x float>, ptr %a.addr
+// CHECK: [[B:%.*]] = load <12 x float>, ptr %b.addr
+// CHECK: %hlsl.mul = call {{.*}} <8 x float> 
@llvm.matrix.multiply.v8f32.v6f32.v12f32(<6 x float> [[A]], <12 x float> [[B]], 
i32 2, i32 3, i32 4)
+// CHECK: ret <8 x float> %hlsl.mul
+export float2x4 test_mat_mat_mul(float2x3 a, float3x4 b) { return mul(a, b); }
+
+// -- Case 9 with integer elements: matrix * matrix -> matrix --
+
+// CHECK-LABEL: test_mat_mat_muli
+// CHECK: [[A:%.*]] = load <6 x i32>, ptr %a.addr
+// CHECK: [[B:%.*]] = load <12 x i32>, ptr %b.addr
+// CHECK: %hlsl.mul = call <8 x i32> 
@llvm.matrix.multiply.v8i32.v6i32.v12i32(<6 x i32> [[A]], <12 x i32> [[B]], i32 
2, i32 3, i32 4)
+// CHECK: ret <8 x i32> %hlsl.mul
+export int2x4 test_mat_mat_muli(int2x3 a, int3x4 b) { return mul(a, b); }
diff --git a/clang/test/SemaHLSL/BuiltIns/mul-errors.hlsl 
b/clang/test/SemaHLSL/BuiltIns/mul-errors.hlsl
new file mode 100644
index 0000000000000..01227126bff10
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/mul-errors.hlsl
@@ -0,0 +1,42 @@
+// RUN: %clang_cc1 -finclude-default-header -triple 
dxil-pc-shadermodel6.3-library %s -verify
+
+// expected-note@*:* 54 {{candidate template ignored}}
+// The note count above is 6 failing calls x 9 mul() overload templates.
+// Every call below matches no mul() overload, so overload resolution fails.
+
+// Inner dimension mismatch: vector * vector with different sizes
+export float test_vec_dim_mismatch(float2 a, float3 b) {
+  return mul(a, b);
+  // expected-error@-1 {{no matching function for call to 'mul'}}
+}
+
+// Inner dimension mismatch: matrix * matrix
+export float2x4 test_mat_dim_mismatch(float2x3 a, float4x4 b) {
+  return mul(a, b);
+  // expected-error@-1 {{no matching function for call to 'mul'}}
+}
+
+// Inner dimension mismatch: vector * matrix
+export float3 test_vec_mat_mismatch(float3 v, float2x3 m) {
+  return mul(v, m);
+  // expected-error@-1 {{no matching function for call to 'mul'}}
+}
+
+// Inner dimension mismatch: matrix * vector
+export float2 test_mat_vec_mismatch(float2x3 m, float2 v) {
+  return mul(m, v);
+  // expected-error@-1 {{no matching function for call to 'mul'}}
+}
+
+// Type mismatch: different element types
+export float test_type_mismatch(float a, int b) {
+  return mul(a, b);
+  // expected-error@-1 {{no matching function for call to 'mul'}}
+}
+
+// Type mismatch: different vector element types
+export float test_vec_type_mismatch(float3 a, int3 b) {
+  return mul(a, b);
+  // expected-error@-1 {{no matching function for call to 'mul'}}
+}
+
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp 
b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index c4bf097e5a0f8..78ac0edc7de47 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -22,6 +22,7 @@
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/IntrinsicsDirectX.h"
+#include "llvm/IR/MatrixBuilder.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/IR/Type.h"
@@ -229,6 +230,7 @@ static bool isIntrinsicExpansion(Function &F) {
   case Intrinsic::usub_sat:
   case Intrinsic::vector_reduce_add:
   case Intrinsic::vector_reduce_fadd:
+  case Intrinsic::matrix_multiply:
     return true;
   case Intrinsic::dx_resource_load_rawbuffer:
     return resourceAccessNeeds64BitExpansion(
@@ -1043,6 +1045,97 @@ static Value *expandSignIntrinsic(CallInst *Orig) {
   return Builder.CreateSub(ZextGT, ZextLT);
 }
 
+// Expand llvm.matrix.multiply: Result[r,c] = dot(row_r(LHS), col_c(RHS)), with
+// element (r,c) of an R-row matrix at index c*R + r (column-major). Non-double
+// float dots of 2-4 elements use dx.dot2/3/4; double and other inner
+// dimensions use fmul/fadd; integer elements use mul + a dx.imad chain.
+static Value *expandMatrixMultiply(CallInst *Orig) {
+  Value *LHS = Orig->getArgOperand(0);
+  Value *RHS = Orig->getArgOperand(1);
+  unsigned LHSRows = cast<ConstantInt>(Orig->getArgOperand(2))->getZExtValue();
+  unsigned LHSCols = cast<ConstantInt>(Orig->getArgOperand(3))->getZExtValue();
+  unsigned RHSCols = cast<ConstantInt>(Orig->getArgOperand(4))->getZExtValue();
+
+  auto *RetTy = cast<FixedVectorType>(Orig->getType());
+  Type *EltTy = RetTy->getElementType();
+  bool IsFP = EltTy->isFloatingPointTy();
+
+  IRBuilder<> Builder(Orig);
+  // Carry the call's fast-math flags over to the expanded FP operations.
+  if (isa<FPMathOperator>(Orig))
+    Builder.setFastMathFlags(Orig->getFastMathFlags());
+
+  // Column-major indexing (TODO: support row-major indexing):
+  // LHS row R, element K: index = K * LHSRows + R
+  // RHS col C, element K: index = C * LHSCols + K
+  Value *Result = PoisonValue::get(RetTy);
+
+  // Extract all scalar elements from LHS and RHS once, then reuse them.
+  unsigned LHSSize = LHSRows * LHSCols;
+  unsigned RHSSize = LHSCols * RHSCols;
+  SmallVector<Value *, 16> LHSElts(LHSSize);
+  SmallVector<Value *, 16> RHSElts(RHSSize);
+  for (unsigned I = 0; I < LHSSize; ++I)
+    LHSElts[I] = Builder.CreateExtractElement(LHS, I);
+  for (unsigned I = 0; I < RHSSize; ++I)
+    RHSElts[I] = Builder.CreateExtractElement(RHS, I);
+
+  // dx.dot2/3/4 only cover non-double float dots of 2-4 elements. Any other
+  // float shape is still valid input, so expand it with scalar fmul/fadd.
+  Intrinsic::ID FloatDotID = Intrinsic::not_intrinsic;
+  if (IsFP && !EltTy->isDoubleTy()) {
+    switch (LHSCols) {
+    case 2:
+      FloatDotID = Intrinsic::dx_dot2;
+      break;
+    case 3:
+      FloatDotID = Intrinsic::dx_dot3;
+      break;
+    case 4:
+      FloatDotID = Intrinsic::dx_dot4;
+      break;
+    default:
+      break;
+    }
+  }
+  bool UseScalarFP = IsFP && FloatDotID == Intrinsic::not_intrinsic;
+
+  for (unsigned C = 0; C < RHSCols; ++C) {
+    for (unsigned R = 0; R < LHSRows; ++R) {
+      // Gather row R from LHS and column C from RHS.
+      SmallVector<Value *, 4> RowElts, ColElts;
+      for (unsigned K = 0; K < LHSCols; ++K) {
+        RowElts.push_back(LHSElts[K * LHSRows + R]);
+        ColElts.push_back(RHSElts[C * LHSCols + K]);
+      }
+
+      Value *Dot;
+      if (UseScalarFP) {
+        // Scalar fmul+fadd expansion for double types, K=1 and K>4.
+        Dot = Builder.CreateFMul(RowElts[0], ColElts[0]);
+        for (unsigned K = 1; K < LHSCols; ++K)
+          Dot = Builder.CreateFAdd(Dot,
+                                   Builder.CreateFMul(RowElts[K], ColElts[K]));
+      } else if (IsFP) {
+        // Emit scalar-arg DXIL dot directly (dx.dot2/dx.dot3/dx.dot4).
+        SmallVector<Value *, 8> Args;
+        Args.append(RowElts.begin(), RowElts.end());
+        Args.append(ColElts.begin(), ColElts.end());
+        Dot = Builder.CreateIntrinsic(EltTy, FloatDotID, Args);
+      } else {
+        // Integer: emit multiply + imad chain.
+        Dot = Builder.CreateMul(RowElts[0], ColElts[0]);
+        for (unsigned K = 1; K < LHSCols; ++K)
+          Dot = Builder.CreateIntrinsic(EltTy, Intrinsic::dx_imad,
+                                        {RowElts[K], ColElts[K], Dot});
+      }
+      unsigned ResIdx = C * LHSRows + R;
+      Result = Builder.CreateInsertElement(Result, Dot, ResIdx);
+    }
+  }
+  return Result;
+}
+
 static bool expandIntrinsic(Function &F, CallInst *Orig) {
   Value *Result = nullptr;
   Intrinsic::ID IntrinsicId = F.getIntrinsicID();
@@ -1144,6 +1237,9 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
   case Intrinsic::vector_reduce_fadd:
     Result = expandVecReduceAdd(Orig, IntrinsicId);
     break;
+  case Intrinsic::matrix_multiply:
+    Result = expandMatrixMultiply(Orig);
+    break;
   }
   if (Result) {
     Orig->replaceAllUsesWith(Result);
diff --git a/llvm/test/CodeGen/DirectX/matrix-multiply.ll 
b/llvm/test/CodeGen/DirectX/matrix-multiply.ll
new file mode 100644
index 0000000000000..dca270f053a02
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/matrix-multiply.ll
@@ -0,0 +1,342 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py 
UTC_ARGS: --version 6
+; RUN: opt -S -dxil-intrinsic-expansion < %s | FileCheck %s
+
+; Verify that llvm.matrix.multiply is expanded to scalar dot products for DXIL.
+
+declare <4 x float> @llvm.matrix.multiply.v4f32.v4f32.v4f32(<4 x float>, <4 x 
float>, i32, i32, i32)
+declare <3 x float> @llvm.matrix.multiply.v3f32.v2f32.v6f32(<2 x float>, <6 x 
float>, i32, i32, i32)
+declare <2 x float> @llvm.matrix.multiply.v2f32.v6f32.v3f32(<6 x float>, <3 x 
float>, i32, i32, i32)
+declare <4 x i32> @llvm.matrix.multiply.v4i32.v4i32.v4i32(<4 x i32>, <4 x 
i32>, i32, i32, i32)
+declare <16 x float> @llvm.matrix.multiply.v16f32.v16f32.v16f32(<16 x float>, 
<16 x float>, i32, i32, i32)
+declare <4 x float> @llvm.matrix.multiply.v4f32.v16f32.v4f32(<16 x float>, <4 
x float>, i32, i32, i32)
+declare <4 x double> @llvm.matrix.multiply.v4f64.v4f64.v4f64(<4 x double>, <4 
x double>, i32, i32, i32)
+declare <2 x double> @llvm.matrix.multiply.v2f64.v4f64.v2f64(<4 x double>, <2 
x double>, i32, i32, i32)
+declare <6 x float> @llvm.matrix.multiply.v6f32.v2f32.v3f32(<2 x float>, <3 x 
float>, i32, i32, i32)
+declare <6 x i32> @llvm.matrix.multiply.v6i32.v2i32.v3i32(<2 x i32>, <3 x 
i32>, i32, i32, i32)
+
+; 2x2 float: 4 dot2 calls.
+define <4 x float> @test_float_2x2(<4 x float> %a, <4 x float> %b) {
+; CHECK-LABEL: define <4 x float> @test_float_2x2(
+; CHECK-SAME: <4 x float> [[A:%.*]], <4 x float> [[B:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <4 x float> [[A]], i64 0
+; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <4 x float> [[A]], i64 1
+; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <4 x float> [[A]], i64 2
+; CHECK-NEXT:    [[TMP4:%.*]] = extractelement <4 x float> [[A]], i64 3
+; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <4 x float> [[B]], i64 0
+; CHECK-NEXT:    [[TMP6:%.*]] = extractelement <4 x float> [[B]], i64 1
+; CHECK-NEXT:    [[TMP7:%.*]] = extractelement <4 x float> [[B]], i64 2
+; CHECK-NEXT:    [[TMP8:%.*]] = extractelement <4 x float> [[B]], i64 3
+; CHECK-NEXT:    [[TMP9:%.*]] = call float @llvm.dx.dot2.f32(float [[TMP1]], 
float [[TMP3]], float [[TMP5]], float [[TMP6]])
+; CHECK-NEXT:    [[TMP10:%.*]] = insertelement <4 x float> poison, float 
[[TMP9]], i64 0
+; CHECK-NEXT:    [[TMP11:%.*]] = call float @llvm.dx.dot2.f32(float [[TMP2]], 
float [[TMP4]], float [[TMP5]], float [[TMP6]])
+; CHECK-NEXT:    [[TMP12:%.*]] = insertelement <4 x float> [[TMP10]], float 
[[TMP11]], i64 1
+; CHECK-NEXT:    [[TMP13:%.*]] = call float @llvm.dx.dot2.f32(float [[TMP1]], 
float [[TMP3]], float [[TMP7]], float [[TMP8]])
+; CHECK-NEXT:    [[TMP14:%.*]] = insertelement <4 x float> [[TMP12]], float 
[[TMP13]], i64 2
+; CHECK-NEXT:    [[TMP15:%.*]] = call float @llvm.dx.dot2.f32(float [[TMP2]], 
float [[TMP4]], float [[TMP7]], float [[TMP8]])
+; CHECK-NEXT:    [[TMP16:%.*]] = insertelement <4 x float> [[TMP14]], float 
[[TMP15]], i64 3
+; CHECK-NEXT:    ret <4 x float> [[TMP16]]
+;
+  %r = call <4 x float> @llvm.matrix.multiply.v4f32.v4f32.v4f32(<4 x float> 
%a, <4 x float> %b, i32 2, i32 2, i32 2)
+  ret <4 x float> %r
+}
+
+; 1x2 * 2x3 = 1x3 (vec * mat): one dot2 of %v with each column of %m.
+define <3 x float> @test_vec_mat(<2 x float> %v, <6 x float> %m) {
+; CHECK-LABEL: define <3 x float> @test_vec_mat(
+; CHECK-SAME: <2 x float> [[V:%.*]], <6 x float> [[M:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <2 x float> [[V]], i64 0
+; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <2 x float> [[V]], i64 1
+; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <6 x float> [[M]], i64 0
+; CHECK-NEXT:    [[TMP4:%.*]] = extractelement <6 x float> [[M]], i64 1
+; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <6 x float> [[M]], i64 2
+; CHECK-NEXT:    [[TMP6:%.*]] = extractelement <6 x float> [[M]], i64 3
+; CHECK-NEXT:    [[TMP7:%.*]] = extractelement <6 x float> [[M]], i64 4
+; CHECK-NEXT:    [[TMP8:%.*]] = extractelement <6 x float> [[M]], i64 5
+; CHECK-NEXT:    [[TMP9:%.*]] = call float @llvm.dx.dot2.f32(float [[TMP1]], 
float [[TMP2]], float [[TMP3]], float [[TMP4]])
+; CHECK-NEXT:    [[TMP10:%.*]] = insertelement <3 x float> poison, float 
[[TMP9]], i64 0
+; CHECK-NEXT:    [[TMP11:%.*]] = call float @llvm.dx.dot2.f32(float [[TMP1]], 
float [[TMP2]], float [[TMP5]], float [[TMP6]])
+; CHECK-NEXT:    [[TMP12:%.*]] = insertelement <3 x float> [[TMP10]], float 
[[TMP11]], i64 1
+; CHECK-NEXT:    [[TMP13:%.*]] = call float @llvm.dx.dot2.f32(float [[TMP1]], 
float [[TMP2]], float [[TMP7]], float [[TMP8]])
+; CHECK-NEXT:    [[TMP14:%.*]] = insertelement <3 x float> [[TMP12]], float 
[[TMP13]], i64 2
+; CHECK-NEXT:    ret <3 x float> [[TMP14]]
+;
+  %r = call <3 x float> @llvm.matrix.multiply.v3f32.v2f32.v6f32(<2 x float> 
%v, <6 x float> %m, i32 1, i32 2, i32 3)
+  ret <3 x float> %r
+}
+
+; 2x3 * 3x1 = 2x1 (mat * vec): one dot3 of each row of %m with %v.
+define <2 x float> @test_mat_vec(<6 x float> %m, <3 x float> %v) {
+; CHECK-LABEL: define <2 x float> @test_mat_vec(
+; CHECK-SAME: <6 x float> [[M:%.*]], <3 x float> [[V:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <6 x float> [[M]], i64 0
+; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <6 x float> [[M]], i64 1
+; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <6 x float> [[M]], i64 2
+; CHECK-NEXT:    [[TMP4:%.*]] = extractelement <6 x float> [[M]], i64 3
+; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <6 x float> [[M]], i64 4
+; CHECK-NEXT:    [[TMP6:%.*]] = extractelement <6 x float> [[M]], i64 5
+; CHECK-NEXT:    [[TMP7:%.*]] = extractelement <3 x float> [[V]], i64 0
+; CHECK-NEXT:    [[TMP8:%.*]] = extractelement <3 x float> [[V]], i64 1
+; CHECK-NEXT:    [[TMP9:%.*]] = extractelement <3 x float> [[V]], i64 2
+; CHECK-NEXT:    [[TMP10:%.*]] = call float @llvm.dx.dot3.f32(float [[TMP1]], 
float [[TMP3]], float [[TMP5]], float [[TMP7]], float [[TMP8]], float [[TMP9]])
+; CHECK-NEXT:    [[TMP11:%.*]] = insertelement <2 x float> poison, float 
[[TMP10]], i64 0
+; CHECK-NEXT:    [[TMP12:%.*]] = call float @llvm.dx.dot3.f32(float [[TMP2]], 
float [[TMP4]], float [[TMP6]], float [[TMP7]], float [[TMP8]], float [[TMP9]])
+; CHECK-NEXT:    [[TMP13:%.*]] = insertelement <2 x float> [[TMP11]], float 
[[TMP12]], i64 1
+; CHECK-NEXT:    ret <2 x float> [[TMP13]]
+;
+  %r = call <2 x float> @llvm.matrix.multiply.v2f32.v6f32.v3f32(<6 x float> 
%m, <3 x float> %v, i32 2, i32 3, i32 1)
+  ret <2 x float> %r
+}
+
+; 2x2 * 2x2 i32: each element is a mul followed by an imad accumulate.
+define <4 x i32> @test_int_2x2(<4 x i32> %a, <4 x i32> %b) {
+; CHECK-LABEL: define <4 x i32> @test_int_2x2(
+; CHECK-SAME: <4 x i32> [[A:%.*]], <4 x i32> [[B:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <4 x i32> [[A]], i64 0
+; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <4 x i32> [[A]], i64 1
+; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <4 x i32> [[A]], i64 2
+; CHECK-NEXT:    [[TMP4:%.*]] = extractelement <4 x i32> [[A]], i64 3
+; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <4 x i32> [[B]], i64 0
+; CHECK-NEXT:    [[TMP6:%.*]] = extractelement <4 x i32> [[B]], i64 1
+; CHECK-NEXT:    [[TMP7:%.*]] = extractelement <4 x i32> [[B]], i64 2
+; CHECK-NEXT:    [[TMP8:%.*]] = extractelement <4 x i32> [[B]], i64 3
+; CHECK-NEXT:    [[TMP9:%.*]] = mul i32 [[TMP1]], [[TMP5]]
+; CHECK-NEXT:    [[TMP10:%.*]] = call i32 @llvm.dx.imad.i32(i32 [[TMP3]], i32 
[[TMP6]], i32 [[TMP9]])
+; CHECK-NEXT:    [[TMP11:%.*]] = insertelement <4 x i32> poison, i32 
[[TMP10]], i64 0
+; CHECK-NEXT:    [[TMP12:%.*]] = mul i32 [[TMP2]], [[TMP5]]
+; CHECK-NEXT:    [[TMP13:%.*]] = call i32 @llvm.dx.imad.i32(i32 [[TMP4]], i32 
[[TMP6]], i32 [[TMP12]])
+; CHECK-NEXT:    [[TMP14:%.*]] = insertelement <4 x i32> [[TMP11]], i32 
[[TMP13]], i64 1
+; CHECK-NEXT:    [[TMP15:%.*]] = mul i32 [[TMP1]], [[TMP7]]
+; CHECK-NEXT:    [[TMP16:%.*]] = call i32 @llvm.dx.imad.i32(i32 [[TMP3]], i32 
[[TMP8]], i32 [[TMP15]])
+; CHECK-NEXT:    [[TMP17:%.*]] = insertelement <4 x i32> [[TMP14]], i32 
[[TMP16]], i64 2
+; CHECK-NEXT:    [[TMP18:%.*]] = mul i32 [[TMP2]], [[TMP7]]
+; CHECK-NEXT:    [[TMP19:%.*]] = call i32 @llvm.dx.imad.i32(i32 [[TMP4]], i32 
[[TMP8]], i32 [[TMP18]])
+; CHECK-NEXT:    [[TMP20:%.*]] = insertelement <4 x i32> [[TMP17]], i32 
[[TMP19]], i64 3
+; CHECK-NEXT:    ret <4 x i32> [[TMP20]]
+;
+  %r = call <4 x i32> @llvm.matrix.multiply.v4i32.v4i32.v4i32(<4 x i32> %a, <4 
x i32> %b, i32 2, i32 2, i32 2)
+  ret <4 x i32> %r
+}
+
+; 4x4 * 4x4 float: 16 dot4 calls, one per (row of %a, column of %b) pair.
+define <16 x float> @test_float_4x4(<16 x float> %a, <16 x float> %b) {
+; CHECK-LABEL: define <16 x float> @test_float_4x4(
+; CHECK-SAME: <16 x float> [[A:%.*]], <16 x float> [[B:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <16 x float> [[A]], i64 0
+; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <16 x float> [[A]], i64 1
+; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <16 x float> [[A]], i64 2
+; CHECK-NEXT:    [[TMP4:%.*]] = extractelement <16 x float> [[A]], i64 3
+; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <16 x float> [[A]], i64 4
+; CHECK-NEXT:    [[TMP6:%.*]] = extractelement <16 x float> [[A]], i64 5
+; CHECK-NEXT:    [[TMP7:%.*]] = extractelement <16 x float> [[A]], i64 6
+; CHECK-NEXT:    [[TMP8:%.*]] = extractelement <16 x float> [[A]], i64 7
+; CHECK-NEXT:    [[TMP9:%.*]] = extractelement <16 x float> [[A]], i64 8
+; CHECK-NEXT:    [[TMP10:%.*]] = extractelement <16 x float> [[A]], i64 9
+; CHECK-NEXT:    [[TMP11:%.*]] = extractelement <16 x float> [[A]], i64 10
+; CHECK-NEXT:    [[TMP12:%.*]] = extractelement <16 x float> [[A]], i64 11
+; CHECK-NEXT:    [[TMP13:%.*]] = extractelement <16 x float> [[A]], i64 12
+; CHECK-NEXT:    [[TMP14:%.*]] = extractelement <16 x float> [[A]], i64 13
+; CHECK-NEXT:    [[TMP15:%.*]] = extractelement <16 x float> [[A]], i64 14
+; CHECK-NEXT:    [[TMP16:%.*]] = extractelement <16 x float> [[A]], i64 15
+; CHECK-NEXT:    [[TMP17:%.*]] = extractelement <16 x float> [[B]], i64 0
+; CHECK-NEXT:    [[TMP18:%.*]] = extractelement <16 x float> [[B]], i64 1
+; CHECK-NEXT:    [[TMP19:%.*]] = extractelement <16 x float> [[B]], i64 2
+; CHECK-NEXT:    [[TMP20:%.*]] = extractelement <16 x float> [[B]], i64 3
+; CHECK-NEXT:    [[TMP21:%.*]] = extractelement <16 x float> [[B]], i64 4
+; CHECK-NEXT:    [[TMP22:%.*]] = extractelement <16 x float> [[B]], i64 5
+; CHECK-NEXT:    [[TMP23:%.*]] = extractelement <16 x float> [[B]], i64 6
+; CHECK-NEXT:    [[TMP24:%.*]] = extractelement <16 x float> [[B]], i64 7
+; CHECK-NEXT:    [[TMP25:%.*]] = extractelement <16 x float> [[B]], i64 8
+; CHECK-NEXT:    [[TMP26:%.*]] = extractelement <16 x float> [[B]], i64 9
+; CHECK-NEXT:    [[TMP27:%.*]] = extractelement <16 x float> [[B]], i64 10
+; CHECK-NEXT:    [[TMP28:%.*]] = extractelement <16 x float> [[B]], i64 11
+; CHECK-NEXT:    [[TMP29:%.*]] = extractelement <16 x float> [[B]], i64 12
+; CHECK-NEXT:    [[TMP30:%.*]] = extractelement <16 x float> [[B]], i64 13
+; CHECK-NEXT:    [[TMP31:%.*]] = extractelement <16 x float> [[B]], i64 14
+; CHECK-NEXT:    [[TMP32:%.*]] = extractelement <16 x float> [[B]], i64 15
+; CHECK-NEXT:    [[TMP33:%.*]] = call float @llvm.dx.dot4.f32(float [[TMP1]], 
float [[TMP5]], float [[TMP9]], float [[TMP13]], float [[TMP17]], float 
[[TMP18]], float [[TMP19]], float [[TMP20]])
+; CHECK-NEXT:    [[TMP34:%.*]] = insertelement <16 x float> poison, float 
[[TMP33]], i64 0
+; CHECK-NEXT:    [[TMP35:%.*]] = call float @llvm.dx.dot4.f32(float [[TMP2]], 
float [[TMP6]], float [[TMP10]], float [[TMP14]], float [[TMP17]], float 
[[TMP18]], float [[TMP19]], float [[TMP20]])
+; CHECK-NEXT:    [[TMP36:%.*]] = insertelement <16 x float> [[TMP34]], float 
[[TMP35]], i64 1
+; CHECK-NEXT:    [[TMP37:%.*]] = call float @llvm.dx.dot4.f32(float [[TMP3]], 
float [[TMP7]], float [[TMP11]], float [[TMP15]], float [[TMP17]], float 
[[TMP18]], float [[TMP19]], float [[TMP20]])
+; CHECK-NEXT:    [[TMP38:%.*]] = insertelement <16 x float> [[TMP36]], float 
[[TMP37]], i64 2
+; CHECK-NEXT:    [[TMP39:%.*]] = call float @llvm.dx.dot4.f32(float [[TMP4]], 
float [[TMP8]], float [[TMP12]], float [[TMP16]], float [[TMP17]], float 
[[TMP18]], float [[TMP19]], float [[TMP20]])
+; CHECK-NEXT:    [[TMP40:%.*]] = insertelement <16 x float> [[TMP38]], float 
[[TMP39]], i64 3
+; CHECK-NEXT:    [[TMP41:%.*]] = call float @llvm.dx.dot4.f32(float [[TMP1]], 
float [[TMP5]], float [[TMP9]], float [[TMP13]], float [[TMP21]], float 
[[TMP22]], float [[TMP23]], float [[TMP24]])
+; CHECK-NEXT:    [[TMP42:%.*]] = insertelement <16 x float> [[TMP40]], float 
[[TMP41]], i64 4
+; CHECK-NEXT:    [[TMP43:%.*]] = call float @llvm.dx.dot4.f32(float [[TMP2]], 
float [[TMP6]], float [[TMP10]], float [[TMP14]], float [[TMP21]], float 
[[TMP22]], float [[TMP23]], float [[TMP24]])
+; CHECK-NEXT:    [[TMP44:%.*]] = insertelement <16 x float> [[TMP42]], float 
[[TMP43]], i64 5
+; CHECK-NEXT:    [[TMP45:%.*]] = call float @llvm.dx.dot4.f32(float [[TMP3]], 
float [[TMP7]], float [[TMP11]], float [[TMP15]], float [[TMP21]], float 
[[TMP22]], float [[TMP23]], float [[TMP24]])
+; CHECK-NEXT:    [[TMP46:%.*]] = insertelement <16 x float> [[TMP44]], float 
[[TMP45]], i64 6
+; CHECK-NEXT:    [[TMP47:%.*]] = call float @llvm.dx.dot4.f32(float [[TMP4]], 
float [[TMP8]], float [[TMP12]], float [[TMP16]], float [[TMP21]], float 
[[TMP22]], float [[TMP23]], float [[TMP24]])
+; CHECK-NEXT:    [[TMP48:%.*]] = insertelement <16 x float> [[TMP46]], float 
[[TMP47]], i64 7
+; CHECK-NEXT:    [[TMP49:%.*]] = call float @llvm.dx.dot4.f32(float [[TMP1]], 
float [[TMP5]], float [[TMP9]], float [[TMP13]], float [[TMP25]], float 
[[TMP26]], float [[TMP27]], float [[TMP28]])
+; CHECK-NEXT:    [[TMP50:%.*]] = insertelement <16 x float> [[TMP48]], float 
[[TMP49]], i64 8
+; CHECK-NEXT:    [[TMP51:%.*]] = call float @llvm.dx.dot4.f32(float [[TMP2]], 
float [[TMP6]], float [[TMP10]], float [[TMP14]], float [[TMP25]], float 
[[TMP26]], float [[TMP27]], float [[TMP28]])
+; CHECK-NEXT:    [[TMP52:%.*]] = insertelement <16 x float> [[TMP50]], float 
[[TMP51]], i64 9
+; CHECK-NEXT:    [[TMP53:%.*]] = call float @llvm.dx.dot4.f32(float [[TMP3]], 
float [[TMP7]], float [[TMP11]], float [[TMP15]], float [[TMP25]], float 
[[TMP26]], float [[TMP27]], float [[TMP28]])
+; CHECK-NEXT:    [[TMP54:%.*]] = insertelement <16 x float> [[TMP52]], float 
[[TMP53]], i64 10
+; CHECK-NEXT:    [[TMP55:%.*]] = call float @llvm.dx.dot4.f32(float [[TMP4]], 
float [[TMP8]], float [[TMP12]], float [[TMP16]], float [[TMP25]], float 
[[TMP26]], float [[TMP27]], float [[TMP28]])
+; CHECK-NEXT:    [[TMP56:%.*]] = insertelement <16 x float> [[TMP54]], float 
[[TMP55]], i64 11
+; CHECK-NEXT:    [[TMP57:%.*]] = call float @llvm.dx.dot4.f32(float [[TMP1]], 
float [[TMP5]], float [[TMP9]], float [[TMP13]], float [[TMP29]], float 
[[TMP30]], float [[TMP31]], float [[TMP32]])
+; CHECK-NEXT:    [[TMP58:%.*]] = insertelement <16 x float> [[TMP56]], float 
[[TMP57]], i64 12
+; CHECK-NEXT:    [[TMP59:%.*]] = call float @llvm.dx.dot4.f32(float [[TMP2]], 
float [[TMP6]], float [[TMP10]], float [[TMP14]], float [[TMP29]], float 
[[TMP30]], float [[TMP31]], float [[TMP32]])
+; CHECK-NEXT:    [[TMP60:%.*]] = insertelement <16 x float> [[TMP58]], float 
[[TMP59]], i64 13
+; CHECK-NEXT:    [[TMP61:%.*]] = call float @llvm.dx.dot4.f32(float [[TMP3]], 
float [[TMP7]], float [[TMP11]], float [[TMP15]], float [[TMP29]], float 
[[TMP30]], float [[TMP31]], float [[TMP32]])
+; CHECK-NEXT:    [[TMP62:%.*]] = insertelement <16 x float> [[TMP60]], float 
[[TMP61]], i64 14
+; CHECK-NEXT:    [[TMP63:%.*]] = call float @llvm.dx.dot4.f32(float [[TMP4]], 
float [[TMP8]], float [[TMP12]], float [[TMP16]], float [[TMP29]], float 
[[TMP30]], float [[TMP31]], float [[TMP32]])
+; CHECK-NEXT:    [[TMP64:%.*]] = insertelement <16 x float> [[TMP62]], float 
[[TMP63]], i64 15
+; CHECK-NEXT:    ret <16 x float> [[TMP64]]
+;
+  %r = call <16 x float> @llvm.matrix.multiply.v16f32.v16f32.v16f32(<16 x 
float> %a, <16 x float> %b, i32 4, i32 4, i32 4)
+  ret <16 x float> %r
+}
+
+; 4x4 * 4x1 = 4x1 (mat * vec): 4 dot4 calls, one per row of %m.
+define <4 x float> @test_mat4x4_vec4(<16 x float> %m, <4 x float> %v) {
+; CHECK-LABEL: define <4 x float> @test_mat4x4_vec4(
+; CHECK-SAME: <16 x float> [[M:%.*]], <4 x float> [[V:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <16 x float> [[M]], i64 0
+; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <16 x float> [[M]], i64 1
+; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <16 x float> [[M]], i64 2
+; CHECK-NEXT:    [[TMP4:%.*]] = extractelement <16 x float> [[M]], i64 3
+; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <16 x float> [[M]], i64 4
+; CHECK-NEXT:    [[TMP6:%.*]] = extractelement <16 x float> [[M]], i64 5
+; CHECK-NEXT:    [[TMP7:%.*]] = extractelement <16 x float> [[M]], i64 6
+; CHECK-NEXT:    [[TMP8:%.*]] = extractelement <16 x float> [[M]], i64 7
+; CHECK-NEXT:    [[TMP9:%.*]] = extractelement <16 x float> [[M]], i64 8
+; CHECK-NEXT:    [[TMP10:%.*]] = extractelement <16 x float> [[M]], i64 9
+; CHECK-NEXT:    [[TMP11:%.*]] = extractelement <16 x float> [[M]], i64 10
+; CHECK-NEXT:    [[TMP12:%.*]] = extractelement <16 x float> [[M]], i64 11
+; CHECK-NEXT:    [[TMP13:%.*]] = extractelement <16 x float> [[M]], i64 12
+; CHECK-NEXT:    [[TMP14:%.*]] = extractelement <16 x float> [[M]], i64 13
+; CHECK-NEXT:    [[TMP15:%.*]] = extractelement <16 x float> [[M]], i64 14
+; CHECK-NEXT:    [[TMP16:%.*]] = extractelement <16 x float> [[M]], i64 15
+; CHECK-NEXT:    [[TMP17:%.*]] = extractelement <4 x float> [[V]], i64 0
+; CHECK-NEXT:    [[TMP18:%.*]] = extractelement <4 x float> [[V]], i64 1
+; CHECK-NEXT:    [[TMP19:%.*]] = extractelement <4 x float> [[V]], i64 2
+; CHECK-NEXT:    [[TMP20:%.*]] = extractelement <4 x float> [[V]], i64 3
+; CHECK-NEXT:    [[TMP21:%.*]] = call float @llvm.dx.dot4.f32(float [[TMP1]], 
float [[TMP5]], float [[TMP9]], float [[TMP13]], float [[TMP17]], float 
[[TMP18]], float [[TMP19]], float [[TMP20]])
+; CHECK-NEXT:    [[TMP22:%.*]] = insertelement <4 x float> poison, float 
[[TMP21]], i64 0
+; CHECK-NEXT:    [[TMP23:%.*]] = call float @llvm.dx.dot4.f32(float [[TMP2]], 
float [[TMP6]], float [[TMP10]], float [[TMP14]], float [[TMP17]], float 
[[TMP18]], float [[TMP19]], float [[TMP20]])
+; CHECK-NEXT:    [[TMP24:%.*]] = insertelement <4 x float> [[TMP22]], float 
[[TMP23]], i64 1
+; CHECK-NEXT:    [[TMP25:%.*]] = call float @llvm.dx.dot4.f32(float [[TMP3]], 
float [[TMP7]], float [[TMP11]], float [[TMP15]], float [[TMP17]], float 
[[TMP18]], float [[TMP19]], float [[TMP20]])
+; CHECK-NEXT:    [[TMP26:%.*]] = insertelement <4 x float> [[TMP24]], float 
[[TMP25]], i64 2
+; CHECK-NEXT:    [[TMP27:%.*]] = call float @llvm.dx.dot4.f32(float [[TMP4]], 
float [[TMP8]], float [[TMP12]], float [[TMP16]], float [[TMP17]], float 
[[TMP18]], float [[TMP19]], float [[TMP20]])
+; CHECK-NEXT:    [[TMP28:%.*]] = insertelement <4 x float> [[TMP26]], float 
[[TMP27]], i64 3
+; CHECK-NEXT:    ret <4 x float> [[TMP28]]
+;
+  %r = call <4 x float> @llvm.matrix.multiply.v4f32.v16f32.v4f32(<16 x float> 
%m, <4 x float> %v, i32 4, i32 4, i32 1)
+  ret <4 x float> %r
+}
+
+; 2x2 * 2x2 double: no dot calls; each element is an fmul + fadd chain.
+define <4 x double> @test_double_2x2(<4 x double> %a, <4 x double> %b) {
+; CHECK-LABEL: define <4 x double> @test_double_2x2(
+; CHECK-SAME: <4 x double> [[A:%.*]], <4 x double> [[B:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <4 x double> [[A]], i64 0
+; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <4 x double> [[A]], i64 1
+; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <4 x double> [[A]], i64 2
+; CHECK-NEXT:    [[TMP4:%.*]] = extractelement <4 x double> [[A]], i64 3
+; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <4 x double> [[B]], i64 0
+; CHECK-NEXT:    [[TMP6:%.*]] = extractelement <4 x double> [[B]], i64 1
+; CHECK-NEXT:    [[TMP7:%.*]] = extractelement <4 x double> [[B]], i64 2
+; CHECK-NEXT:    [[TMP8:%.*]] = extractelement <4 x double> [[B]], i64 3
+; CHECK-NEXT:    [[TMP9:%.*]] = fmul double [[TMP1]], [[TMP5]]
+; CHECK-NEXT:    [[TMP10:%.*]] = fmul double [[TMP3]], [[TMP6]]
+; CHECK-NEXT:    [[TMP11:%.*]] = fadd double [[TMP9]], [[TMP10]]
+; CHECK-NEXT:    [[TMP12:%.*]] = insertelement <4 x double> poison, double 
[[TMP11]], i64 0
+; CHECK-NEXT:    [[TMP13:%.*]] = fmul double [[TMP2]], [[TMP5]]
+; CHECK-NEXT:    [[TMP14:%.*]] = fmul double [[TMP4]], [[TMP6]]
+; CHECK-NEXT:    [[TMP15:%.*]] = fadd double [[TMP13]], [[TMP14]]
+; CHECK-NEXT:    [[TMP16:%.*]] = insertelement <4 x double> [[TMP12]], double 
[[TMP15]], i64 1
+; CHECK-NEXT:    [[TMP17:%.*]] = fmul double [[TMP1]], [[TMP7]]
+; CHECK-NEXT:    [[TMP18:%.*]] = fmul double [[TMP3]], [[TMP8]]
+; CHECK-NEXT:    [[TMP19:%.*]] = fadd double [[TMP17]], [[TMP18]]
+; CHECK-NEXT:    [[TMP20:%.*]] = insertelement <4 x double> [[TMP16]], double 
[[TMP19]], i64 2
+; CHECK-NEXT:    [[TMP21:%.*]] = fmul double [[TMP2]], [[TMP7]]
+; CHECK-NEXT:    [[TMP22:%.*]] = fmul double [[TMP4]], [[TMP8]]
+; CHECK-NEXT:    [[TMP23:%.*]] = fadd double [[TMP21]], [[TMP22]]
+; CHECK-NEXT:    [[TMP24:%.*]] = insertelement <4 x double> [[TMP20]], double 
[[TMP23]], i64 3
+; CHECK-NEXT:    ret <4 x double> [[TMP24]]
+;
+  %r = call <4 x double> @llvm.matrix.multiply.v4f64.v4f64.v4f64(<4 x double> 
%a, <4 x double> %b, i32 2, i32 2, i32 2)
+  ret <4 x double> %r
+}
+
+; 2x2 * 2x1 double (mat * vec): one fmul + fadd chain per row of %m.
+define <2 x double> @test_double_mat_vec(<4 x double> %m, <2 x double> %v) {
+; CHECK-LABEL: define <2 x double> @test_double_mat_vec(
+; CHECK-SAME: <4 x double> [[M:%.*]], <2 x double> [[V:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <4 x double> [[M]], i64 0
+; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <4 x double> [[M]], i64 1
+; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <4 x double> [[M]], i64 2
+; CHECK-NEXT:    [[TMP4:%.*]] = extractelement <4 x double> [[M]], i64 3
+; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <2 x double> [[V]], i64 0
+; CHECK-NEXT:    [[TMP6:%.*]] = extractelement <2 x double> [[V]], i64 1
+; CHECK-NEXT:    [[TMP7:%.*]] = fmul double [[TMP1]], [[TMP5]]
+; CHECK-NEXT:    [[TMP8:%.*]] = fmul double [[TMP3]], [[TMP6]]
+; CHECK-NEXT:    [[TMP9:%.*]] = fadd double [[TMP7]], [[TMP8]]
+; CHECK-NEXT:    [[TMP10:%.*]] = insertelement <2 x double> poison, double 
[[TMP9]], i64 0
+; CHECK-NEXT:    [[TMP11:%.*]] = fmul double [[TMP2]], [[TMP5]]
+; CHECK-NEXT:    [[TMP12:%.*]] = fmul double [[TMP4]], [[TMP6]]
+; CHECK-NEXT:    [[TMP13:%.*]] = fadd double [[TMP11]], [[TMP12]]
+; CHECK-NEXT:    [[TMP14:%.*]] = insertelement <2 x double> [[TMP10]], double 
[[TMP13]], i64 1
+; CHECK-NEXT:    ret <2 x double> [[TMP14]]
+;
+  %r = call <2 x double> @llvm.matrix.multiply.v2f64.v4f64.v2f64(<4 x double> 
%m, <2 x double> %v, i32 2, i32 2, i32 1)
+  ret <2 x double> %r
+}
+
+; K=1 float outer product (2x1 * 1x3 = 2x3): one fmul per element, no sum.
+define <6 x float> @test_k1_outer_product(<2 x float> %a, <3 x float> %b) {
+; CHECK-LABEL: define <6 x float> @test_k1_outer_product(
+; CHECK-SAME: <2 x float> [[A:%.*]], <3 x float> [[B:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <2 x float> [[A]], i64 0
+; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <2 x float> [[A]], i64 1
+; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <3 x float> [[B]], i64 0
+; CHECK-NEXT:    [[TMP4:%.*]] = extractelement <3 x float> [[B]], i64 1
+; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <3 x float> [[B]], i64 2
+; CHECK-NEXT:    [[TMP6:%.*]] = fmul float [[TMP1]], [[TMP3]]
+; CHECK-NEXT:    [[TMP7:%.*]] = insertelement <6 x float> poison, float 
[[TMP6]], i64 0
+; CHECK-NEXT:    [[TMP8:%.*]] = fmul float [[TMP2]], [[TMP3]]
+; CHECK-NEXT:    [[TMP9:%.*]] = insertelement <6 x float> [[TMP7]], float 
[[TMP8]], i64 1
+; CHECK-NEXT:    [[TMP10:%.*]] = fmul float [[TMP1]], [[TMP4]]
+; CHECK-NEXT:    [[TMP11:%.*]] = insertelement <6 x float> [[TMP9]], float 
[[TMP10]], i64 2
+; CHECK-NEXT:    [[TMP12:%.*]] = fmul float [[TMP2]], [[TMP4]]
+; CHECK-NEXT:    [[TMP13:%.*]] = insertelement <6 x float> [[TMP11]], float 
[[TMP12]], i64 3
+; CHECK-NEXT:    [[TMP14:%.*]] = fmul float [[TMP1]], [[TMP5]]
+; CHECK-NEXT:    [[TMP15:%.*]] = insertelement <6 x float> [[TMP13]], float 
[[TMP14]], i64 4
+; CHECK-NEXT:    [[TMP16:%.*]] = fmul float [[TMP2]], [[TMP5]]
+; CHECK-NEXT:    [[TMP17:%.*]] = insertelement <6 x float> [[TMP15]], float 
[[TMP16]], i64 5
+; CHECK-NEXT:    ret <6 x float> [[TMP17]]
+;
+  %r = call <6 x float> @llvm.matrix.multiply.v6f32.v2f32.v3f32(<2 x float> 
%a, <3 x float> %b, i32 2, i32 1, i32 3)
+  ret <6 x float> %r
+}
+
+; K=1 i32 outer product (2x1 * 1x3 = 2x3): one mul per element, no imad.
+define <6 x i32> @test_k1_int_outer_product(<2 x i32> %a, <3 x i32> %b) {
+; CHECK-LABEL: define <6 x i32> @test_k1_int_outer_product(
+; CHECK-SAME: <2 x i32> [[A:%.*]], <3 x i32> [[B:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <2 x i32> [[A]], i64 0
+; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <2 x i32> [[A]], i64 1
+; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <3 x i32> [[B]], i64 0
+; CHECK-NEXT:    [[TMP4:%.*]] = extractelement <3 x i32> [[B]], i64 1
+; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <3 x i32> [[B]], i64 2
+; CHECK-NEXT:    [[TMP6:%.*]] = mul i32 [[TMP1]], [[TMP3]]
+; CHECK-NEXT:    [[TMP7:%.*]] = insertelement <6 x i32> poison, i32 [[TMP6]], 
i64 0
+; CHECK-NEXT:    [[TMP8:%.*]] = mul i32 [[TMP2]], [[TMP3]]
+; CHECK-NEXT:    [[TMP9:%.*]] = insertelement <6 x i32> [[TMP7]], i32 
[[TMP8]], i64 1
+; CHECK-NEXT:    [[TMP10:%.*]] = mul i32 [[TMP1]], [[TMP4]]
+; CHECK-NEXT:    [[TMP11:%.*]] = insertelement <6 x i32> [[TMP9]], i32 
[[TMP10]], i64 2
+; CHECK-NEXT:    [[TMP12:%.*]] = mul i32 [[TMP2]], [[TMP4]]
+; CHECK-NEXT:    [[TMP13:%.*]] = insertelement <6 x i32> [[TMP11]], i32 
[[TMP12]], i64 3
+; CHECK-NEXT:    [[TMP14:%.*]] = mul i32 [[TMP1]], [[TMP5]]
+; CHECK-NEXT:    [[TMP15:%.*]] = insertelement <6 x i32> [[TMP13]], i32 
[[TMP14]], i64 4
+; CHECK-NEXT:    [[TMP16:%.*]] = mul i32 [[TMP2]], [[TMP5]]
+; CHECK-NEXT:    [[TMP17:%.*]] = insertelement <6 x i32> [[TMP15]], i32 
[[TMP16]], i64 5
+; CHECK-NEXT:    ret <6 x i32> [[TMP17]]
+;
+  %r = call <6 x i32> @llvm.matrix.multiply.v6i32.v2i32.v3i32(<2 x i32> %a, <3 
x i32> %b, i32 2, i32 1, i32 3)
+  ret <6 x i32> %r
+}

_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to