llvmbot wrote:

<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-hlsl

Author: Farzon Lotfi (farzonl)

<details>
<summary>Changes</summary>

fixes #<!-- -->168960

Adds `ICK_HLSL_Matrix_Splat`, hooks it up to `PerformImplicitConversion` and
`IsMatrixConversion`, and maps the conversion to `CK_HLSLAggregateSplatCast`.

---
Full diff: https://github.com/llvm/llvm-project/pull/170885.diff


8 Files Affected:

- (modified) clang/include/clang/Sema/Overload.h (+3) 
- (modified) clang/include/clang/Sema/Sema.h (+4) 
- (modified) clang/lib/Sema/SemaExpr.cpp (+33) 
- (modified) clang/lib/Sema/SemaExprCXX.cpp (+10) 
- (modified) clang/lib/Sema/SemaOverload.cpp (+12) 
- (added) clang/test/CodeGenHLSL/BasicFeatures/MatrixSplat.hlsl (+57) 
- (modified) clang/test/SemaHLSL/MatrixElementOverloadResolution.hlsl (+9-3) 
- (added) clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixSplatErrors.hlsl (+11) 


``````````diff
diff --git a/clang/include/clang/Sema/Overload.h 
b/clang/include/clang/Sema/Overload.h
index ab45328ee8ab7..cc9be00e9108c 100644
--- a/clang/include/clang/Sema/Overload.h
+++ b/clang/include/clang/Sema/Overload.h
@@ -207,6 +207,9 @@ class Sema;
     // HLSL vector splat from scalar or boolean type.
     ICK_HLSL_Vector_Splat,
 
+    /// HLSL matrix splat from scalar or boolean type.
+    ICK_HLSL_Matrix_Splat,
+
     /// The number of conversion kinds
     ICK_Num_Conversion_Kinds,
   };
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index 4a601a0eaf1b9..2a32bf8b257ad 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -7944,6 +7944,10 @@ class Sema final : public SemaBase {
   /// implicit casts if necessary.
   ExprResult prepareVectorSplat(QualType VectorTy, Expr *SplattedExpr);
 
+  /// Prepare `SplattedExpr` for a matrix splat operation, adding
+  /// implicit casts if necessary.
+  ExprResult prepareMatrixSplat(QualType MatrixTy, Expr *SplattedExpr);
+
   // CheckExtVectorCast - check type constraints for extended vectors.
   // Since vectors are an extension, there are no C standard reference for 
this.
   // We allow casting between vectors and integer datatypes of the same size,
diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp
index cfabd1b76c103..f5b6855b87c33 100644
--- a/clang/lib/Sema/SemaExpr.cpp
+++ b/clang/lib/Sema/SemaExpr.cpp
@@ -7806,6 +7806,39 @@ ExprResult Sema::prepareVectorSplat(QualType VectorTy, 
Expr *SplattedExpr) {
   return ImpCastExprToType(SplattedExpr, DestElemTy, CK);
 }
 
+ExprResult Sema::prepareMatrixSplat(QualType MatrixTy, Expr *SplattedExpr) {
+  QualType DestElemTy = MatrixTy->castAs<MatrixType>()->getElementType();
+
+  if (DestElemTy == SplattedExpr->getType())
+    return SplattedExpr;
+
+  assert(DestElemTy->isFloatingType() ||
+         DestElemTy->isIntegralOrEnumerationType());
+
+  CastKind CK;
+  if (SplattedExpr->getType()->isBooleanType()) {
+    // As with vectors, we want `true` to become -1 when splatting, and we
+    // need a two-step cast if the destination element type is floating.
+    if (DestElemTy->isFloatingType()) {
+      // Cast boolean to signed integral, then to floating.
+      ExprResult CastExprRes = ImpCastExprToType(SplattedExpr, Context.IntTy,
+                                                 CK_BooleanToSignedIntegral);
+      SplattedExpr = CastExprRes.get();
+      CK = CK_IntegralToFloating;
+    } else {
+      CK = CK_BooleanToSignedIntegral;
+    }
+  } else {
+    ExprResult CastExprRes = SplattedExpr;
+    CK = PrepareScalarCast(CastExprRes, DestElemTy);
+    if (CastExprRes.isInvalid())
+      return ExprError();
+    SplattedExpr = CastExprRes.get();
+  }
+
+  return ImpCastExprToType(SplattedExpr, DestElemTy, CK);
+}
+
 ExprResult Sema::CheckExtVectorCast(SourceRange R, QualType DestTy,
                                     Expr *CastExpr, CastKind &Kind) {
   assert(DestTy->isExtVectorType() && "Not an extended vector type!");
diff --git a/clang/lib/Sema/SemaExprCXX.cpp b/clang/lib/Sema/SemaExprCXX.cpp
index 69719ebd1fc8c..e7af3579be69a 100644
--- a/clang/lib/Sema/SemaExprCXX.cpp
+++ b/clang/lib/Sema/SemaExprCXX.cpp
@@ -5198,6 +5198,7 @@ Sema::PerformImplicitConversion(Expr *From, QualType 
ToType,
   case ICK_HLSL_Vector_Truncation:
   case ICK_HLSL_Matrix_Truncation:
   case ICK_HLSL_Vector_Splat:
+  case ICK_HLSL_Matrix_Splat:
     llvm_unreachable("Improper second standard conversion");
   }
 
@@ -5217,6 +5218,15 @@ Sema::PerformImplicitConversion(Expr *From, QualType 
ToType,
                  .get();
       break;
     }
+    case ICK_HLSL_Matrix_Splat: {
+      // Matrix splat from any arithmetic type to a matrix.
+      Expr *Elem = prepareMatrixSplat(ToType, From).get();
+      From =
+          ImpCastExprToType(Elem, ToType, CK_HLSLAggregateSplatCast, 
VK_PRValue,
+                            /*BasePath=*/nullptr, CCK)
+              .get();
+      break;
+    }
     case ICK_HLSL_Vector_Truncation: {
       // Note: HLSL built-in vectors are ExtVectors. Since this truncates a
       // vector to a smaller vector or to a scalar, this can only operate on
diff --git a/clang/lib/Sema/SemaOverload.cpp b/clang/lib/Sema/SemaOverload.cpp
index 9a3a78164f0f8..bc3cfe7ef9a0c 100644
--- a/clang/lib/Sema/SemaOverload.cpp
+++ b/clang/lib/Sema/SemaOverload.cpp
@@ -165,6 +165,7 @@ ImplicitConversionRank 
clang::GetConversionRank(ImplicitConversionKind Kind) {
       ICR_HLSL_Dimension_Reduction,
       ICR_Conversion,
       ICR_HLSL_Scalar_Widening,
+      ICR_HLSL_Scalar_Widening,
   };
   static_assert(std::size(Rank) == (int)ICK_Num_Conversion_Kinds);
   return Rank[(int)Kind];
@@ -228,6 +229,7 @@ static const char 
*GetImplicitConversionName(ImplicitConversionKind Kind) {
       "HLSL matrix truncation",
       "Non-decaying array conversion",
       "HLSL vector splat",
+      "HLSL matrix splat",
   };
   static_assert(std::size(Name) == (int)ICK_Num_Conversion_Kinds);
   return Name[Kind];
@@ -2145,6 +2147,15 @@ static bool IsMatrixConversion(Sema &S, QualType 
FromType, QualType ToType,
       return true;
     return IsVectorOrMatrixElementConversion(S, FromElTy, ToElTy, ICK, From);
   }
+
+  // Matrix splat from any arithmetic type to a matrix.
+  if (ToMatrixType && FromType->isArithmeticType()) {
+    ElConv = ICK_HLSL_Matrix_Splat;
+    QualType ToElTy = ToMatrixType->getElementType();
+    return IsVectorOrMatrixElementConversion(S, FromType, ToElTy, ICK, From);
+    ICK = ICK_HLSL_Matrix_Splat;
+    return true;
+  }
   if (FromMatrixType && !ToMatrixType) {
     ElConv = ICK_HLSL_Matrix_Truncation;
     QualType FromElTy = FromMatrixType->getElementType();
@@ -6301,6 +6312,7 @@ static bool CheckConvertedConstantConversions(Sema &S,
   case ICK_SVE_Vector_Conversion:
   case ICK_RVV_Vector_Conversion:
   case ICK_HLSL_Vector_Splat:
+  case ICK_HLSL_Matrix_Splat:
   case ICK_Vector_Splat:
   case ICK_Complex_Real:
   case ICK_Block_Pointer_Conversion:
diff --git a/clang/test/CodeGenHLSL/BasicFeatures/MatrixSplat.hlsl 
b/clang/test/CodeGenHLSL/BasicFeatures/MatrixSplat.hlsl
new file mode 100644
index 0000000000000..802c418f1dad5
--- /dev/null
+++ b/clang/test/CodeGenHLSL/BasicFeatures/MatrixSplat.hlsl
@@ -0,0 +1,57 @@
+// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py 
UTC_ARGS: --version 6
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.7-library -disable-llvm-passes 
-emit-llvm -finclude-default-header -o - %s | FileCheck %s
+
+// CHECK-LABEL: define hidden void @_Z13ConstantSplatv(
+// CHECK-SAME: ) #[[ATTR0:[0-9]+]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[M:%.*]] = alloca [16 x i32], align 4
+// CHECK-NEXT:    store <16 x i32> splat (i32 1), ptr [[M]], align 4
+// CHECK-NEXT:    ret void
+//
+void ConstantSplat() {
+    int4x4 M = 1;
+}
+
+// CHECK-LABEL: define hidden void @_Z18ConstantFloatSplatv(
+// CHECK-SAME: ) #[[ATTR0]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[M:%.*]] = alloca [4 x float], align 4
+// CHECK-NEXT:    store <4 x float> splat (float 3.250000e+00), ptr [[M]], 
align 4
+// CHECK-NEXT:    ret void
+//
+void ConstantFloatSplat() {
+    float2x2 M = 3.25;
+}
+
+// CHECK-LABEL: define hidden void @_Z12DynamicSplatf(
+// CHECK-SAME: float noundef nofpclass(nan inf) [[VALUE:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[VALUE_ADDR:%.*]] = alloca float, align 4
+// CHECK-NEXT:    [[M:%.*]] = alloca [9 x float], align 4
+// CHECK-NEXT:    store float [[VALUE]], ptr [[VALUE_ADDR]], align 4
+// CHECK-NEXT:    [[TMP0:%.*]] = load float, ptr [[VALUE_ADDR]], align 4
+// CHECK-NEXT:    [[SPLAT_SPLATINSERT:%.*]] = insertelement <9 x float> 
poison, float [[TMP0]], i64 0
+// CHECK-NEXT:    [[SPLAT_SPLAT:%.*]] = shufflevector <9 x float> 
[[SPLAT_SPLATINSERT]], <9 x float> poison, <9 x i32> zeroinitializer
+// CHECK-NEXT:    store <9 x float> [[SPLAT_SPLAT]], ptr [[M]], align 4
+// CHECK-NEXT:    ret void
+//
+void DynamicSplat(float Value) {
+    float3x3 M = Value;
+}
+
+// CHECK-LABEL: define hidden void @_Z13CastThenSplatDv4_f(
+// CHECK-SAME: <4 x float> noundef nofpclass(nan inf) [[VALUE:%.*]]) 
#[[ATTR0]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[VALUE_ADDR:%.*]] = alloca <4 x float>, align 16
+// CHECK-NEXT:    [[M:%.*]] = alloca [9 x float], align 4
+// CHECK-NEXT:    store <4 x float> [[VALUE]], ptr [[VALUE_ADDR]], align 16
+// CHECK-NEXT:    [[TMP0:%.*]] = load <4 x float>, ptr [[VALUE_ADDR]], align 16
+// CHECK-NEXT:    [[CAST_VTRUNC:%.*]] = extractelement <4 x float> [[TMP0]], 
i32 0
+// CHECK-NEXT:    [[SPLAT_SPLATINSERT:%.*]] = insertelement <9 x float> 
poison, float [[CAST_VTRUNC]], i64 0
+// CHECK-NEXT:    [[SPLAT_SPLAT:%.*]] = shufflevector <9 x float> 
[[SPLAT_SPLATINSERT]], <9 x float> poison, <9 x i32> zeroinitializer
+// CHECK-NEXT:    store <9 x float> [[SPLAT_SPLAT]], ptr [[M]], align 4
+// CHECK-NEXT:    ret void
+//
+void CastThenSplat(float4 Value) {
+    float3x3 M = (float) Value;
+}
diff --git a/clang/test/SemaHLSL/MatrixElementOverloadResolution.hlsl 
b/clang/test/SemaHLSL/MatrixElementOverloadResolution.hlsl
index 04149e176edbd..51500a3bcc145 100644
--- a/clang/test/SemaHLSL/MatrixElementOverloadResolution.hlsl
+++ b/clang/test/SemaHLSL/MatrixElementOverloadResolution.hlsl
@@ -228,12 +228,14 @@ void fn2x2(float2x2) {}
 void fn2x2IO(inout float2x2) {}
 void fnI2x2IO(inout int2x2) {}
 
-void matOrVec(float4 F) {}
-void matOrVec(float2x2 F) {}
+void matOrVec(float4 F) {}  // expected-note {{candidate function}}
+void matOrVec(float2x2 F) {}  // expected-note {{candidate function}}
 
 void matOrVec2(float3 F) {} // expected-note{{candidate function}}
 void matOrVec2(float2x3 F) {} // expected-note{{candidate function}}
 
+void matOrVec3(float4x4 F) {}
+
 export void Case8(float2x3 f23, float4x4 f44, float3x3 f33, float3x2 f32) {
   int2x2 i22 = f23;
   // expected-warning@-1{{implicit conversion truncates matrix: 'float2x3' 
(aka 'matrix<float, 2, 3>') to 'int2x2' (aka 'matrix<int, 2, 2>')}}
@@ -269,8 +271,12 @@ export void Case8(float2x3 f23, float4x4 f44, float3x3 
f33, float3x2 f32) {
   //CHECK-NEXT: ImplicitCastExpr {{.*}} 'float4x4':'matrix<float, 4, 4>' 
<LValueToRValue>
 
 #ifdef ERROR
-  matOrVec(2.0); // TODO: See #168960 this should be ambiguous once we 
implement ICK_HLSL_Matrix_Splat.
+  matOrVec(2.0); // expected-error {{call to 'matOrVec' is ambiguous}}
 #endif
+  matOrVec3(3.14);
+  //CHECK:  ImplicitCastExpr {{.*}} 'float4x4':'matrix<float, 4, 4>' 
<HLSLAggregateSplatCast>
+  //CHECK-NEXT: FloatingLiteral {{.*}} <col:13> 'float' 3.140000e+00
+
   matOrVec2(f23);
   //CHECK: DeclRefExpr {{.*}} 'void (float2x3)' lvalue Function {{.*}} 
'matOrVec2' 'void (float2x3)'
   //CHECK-NEXT: ImplicitCastExpr {{.*}} 'float2x3':'matrix<float, 2, 3>' 
<LValueToRValue>
diff --git a/clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixSplatErrors.hlsl 
b/clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixSplatErrors.hlsl
new file mode 100644
index 0000000000000..0c2e53d382180
--- /dev/null
+++ b/clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixSplatErrors.hlsl
@@ -0,0 +1,11 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library 
-finclude-default-header -std=hlsl202x -verify %s
+
+void SplatOfVectortoMat(int4 V){
+    int2x2 M = V;
+    // expected-error@-1 {{cannot initialize a variable of type 'int2x2' (aka 
'matrix<int, 2, 2>') with an lvalue of type 'int4' (aka 'vector<int, 4>')}}
+}
+
+void SplatOfMattoMat(int4x3 N){
+    int4x4 M = N;
+    // expected-error@-1 {{cannot initialize a variable of type 'matrix<[2 * 
...], 4>' with an lvalue of type 'matrix<[2 * ...], 3>'}}
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/170885
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to