Author: Deric C.
Date: 2026-03-23T14:49:53-07:00
New Revision: bcfe1e94d3f062139aae009e44ae458f69cf2df2

URL: 
https://github.com/llvm/llvm-project/commit/bcfe1e94d3f062139aae009e44ae458f69cf2df2
DIFF: 
https://github.com/llvm/llvm-project/commit/bcfe1e94d3f062139aae009e44ae458f69cf2df2.diff

LOG: [HLSL] Allow 1x1 matrices to be splatted like scalars (#188119)

Fixes #186859 by allowing 1x1 matrices to be splatted like the scalar
and vec1 cases.

Assisted-by: GitHub Copilot (powered by Claude Opus 4.6)

Added: 
    

Modified: 
    clang/lib/Sema/SemaCast.cpp
    clang/lib/Sema/SemaHLSL.cpp
    clang/test/CodeGenHLSL/BasicFeatures/AggregateSplatCast.hlsl

Removed: 
    


################################################################################
diff  --git a/clang/lib/Sema/SemaCast.cpp b/clang/lib/Sema/SemaCast.cpp
index 5360f8a2908bf..330e5ec699790 100644
--- a/clang/lib/Sema/SemaCast.cpp
+++ b/clang/lib/Sema/SemaCast.cpp
@@ -2939,11 +2939,17 @@ bool CastOperation::CheckHLSLCStyleCast(CheckedConversionKind CCK) {
   if (Self.HLSL().CanPerformAggregateSplatCast(SrcExpr.get(), DestType)) {
     SrcExpr = Self.DefaultLvalueConversion(SrcExpr.get());
     const VectorType *VT = SrcTy->getAs<VectorType>();
+    const ConstantMatrixType *MT = SrcTy->getAs<ConstantMatrixType>();
     // change splat from vec1 case to splat from scalar
     if (VT && VT->getNumElements() == 1)
       SrcExpr = Self.ImpCastExprToType(
           SrcExpr.get(), VT->getElementType(), CK_HLSLVectorTruncation,
           SrcExpr.get()->getValueKind(), nullptr, CCK);
+    // change splat from 1x1 matrix case to splat from scalar
+    else if (MT && MT->getNumElementsFlattened() == 1)
+      SrcExpr = Self.ImpCastExprToType(
+          SrcExpr.get(), MT->getElementType(), CK_HLSLMatrixTruncation,
+          SrcExpr.get()->getValueKind(), nullptr, CCK);
     // Inserting a scalar cast here allows for a simplified codegen in
     // the case the destTy is a vector
     if (const VectorType *DVT = DestType->getAs<VectorType>())

diff  --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index e2f2d6cb75c33..823d312df296e 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -4644,8 +4644,8 @@ bool SemaHLSL::CanPerformScalarCast(QualType SrcTy, QualType DestTy) {
 }
 
 // Can perform an HLSL Aggregate splat cast if the Dest is an aggregate and the
-// Src is a scalar or a vector of length 1
-// Or if Dest is a vector and Src is a vector of length 1
+// Src is a scalar, a vector of length 1, or a 1x1 matrix
+// Or if Dest is a vector and Src is a vector of length 1 or a 1x1 matrix
 bool SemaHLSL::CanPerformAggregateSplatCast(Expr *Src, QualType DestTy) {
 
   QualType SrcTy = Src->getType();
@@ -4656,13 +4656,18 @@ bool SemaHLSL::CanPerformAggregateSplatCast(Expr *Src, QualType DestTy) {
     return false;
 
   const VectorType *SrcVecTy = SrcTy->getAs<VectorType>();
+  const ConstantMatrixType *SrcMatTy = SrcTy->getAs<ConstantMatrixType>();
 
-  // Src isn't a scalar or a vector of length 1
-  if (!SrcTy->isScalarType() && !(SrcVecTy && SrcVecTy->getNumElements() == 1))
+  // Src isn't a scalar, a vector of length 1, or a 1x1 matrix
+  if (!SrcTy->isScalarType() &&
+      !(SrcVecTy && SrcVecTy->getNumElements() == 1) &&
+      !(SrcMatTy && SrcMatTy->getNumElementsFlattened() == 1))
     return false;
 
   if (SrcVecTy)
     SrcTy = SrcVecTy->getElementType();
+  else if (SrcMatTy)
+    SrcTy = SrcMatTy->getElementType();
 
   llvm::SmallVector<QualType> DestTypes;
   BuildFlattenedTypeList(DestTy, DestTypes);

diff  --git a/clang/test/CodeGenHLSL/BasicFeatures/AggregateSplatCast.hlsl b/clang/test/CodeGenHLSL/BasicFeatures/AggregateSplatCast.hlsl
index 4e6c7537bcaa4..abfea79f0a454 100644
--- a/clang/test/CodeGenHLSL/BasicFeatures/AggregateSplatCast.hlsl
+++ b/clang/test/CodeGenHLSL/BasicFeatures/AggregateSplatCast.hlsl
@@ -84,6 +84,39 @@ export void call5() {
   S s = (S)A;
 }
 
+// vector splat from 1x1 matrix
+// CHECK-LABEL: define void {{.*}}call9
+// CHECK: [[M:%.*]] = alloca [1 x <1 x float>], align 4
+// CHECK-NEXT: [[A:%.*]] = alloca <4 x i32>, align 4
+// CHECK-NEXT: store <1 x float> {{.*}}, ptr [[M]], align 4
+// CHECK-NEXT: [[L:%.*]] = load <1 x float>, ptr [[M]], align 4
+// CHECK-NEXT: [[ML:%.*]] = extractelement <1 x float> [[L]], i32 0
+// CHECK-NEXT: [[C:%.*]] = fptosi float [[ML]] to i32
+// CHECK-NEXT: [[SI:%.*]] = insertelement <4 x i32> poison, i32 [[C]], i64 0
+// CHECK-NEXT: [[S:%.*]] = shufflevector <4 x i32> [[SI]], <4 x i32> poison, <4 x i32> zeroinitializer
+// CHECK-NEXT: store <4 x i32> [[S]], ptr [[A]], align 4
+export void call9() {
+  float1x1 M = {1.0};
+  int4 A = (int4)M;
+}
+
+// struct splat from 1x1 matrix
+// CHECK-LABEL: define void {{.*}}call10
+// CHECK: [[M:%.*]] = alloca [1 x <1 x i32>], align 4
+// CHECK-NEXT: [[s:%.*]] = alloca %struct.S, align 1
+// CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[M]], align 4
+// CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[M]], align 4
+// CHECK-NEXT: [[ML:%.*]] = extractelement <1 x i32> [[L]], i32 0
+// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 0
+// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 1
+// CHECK-NEXT: store i32 [[ML]], ptr [[G1]], align 4
+// CHECK-NEXT: [[C:%.*]] = sitofp i32 [[ML]] to float
+// CHECK-NEXT: store float [[C]], ptr [[G2]], align 4
+export void call10() {
+  int1x1 M = {1};
+  S s = (S)M;
+}
+
 struct BFields {
   double DF;
   int E: 15;


        
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to