Author: Sarah Spall
Date: 2025-11-06T08:06:48-08:00
New Revision: 4d67e157682d0953761593955ae6c87ed65b3274

URL: 
https://github.com/llvm/llvm-project/commit/4d67e157682d0953761593955ae6c87ed65b3274
DIFF: 
https://github.com/llvm/llvm-project/commit/4d67e157682d0953761593955ae6c87ed65b3274.diff

LOG: [HLSL] add support for HLSLAggregateSplatCast and HLSLElementwiseCast to 
constant expression evaluator (#164700)

Add support to handle these casts in the constant expression evaluator. 

- HLSLAggregateSplatCast
- HLSLElementwiseCast
- HLSLArrayRValue

Add tests 
Closes #125766
Closes #125321

Added: 
    clang/test/SemaHLSL/Types/AggregateSplatConstantExpr.hlsl
    clang/test/SemaHLSL/Types/ElementwiseCastConstantExpr.hlsl

Modified: 
    clang/lib/AST/ExprConstant.cpp

Removed: 
    


################################################################################
diff  --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index 193f87ca1807d..4f6f52bf839e9 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -3829,6 +3829,351 @@ static bool CheckArraySize(EvalInfo &Info, const 
ConstantArrayType *CAT,
       /*Diag=*/true);
 }
 
+/// Converts a single already-evaluated scalar constant from \p SourceTy to
+/// \p DestTy.
+///
+/// Supported pairings are any combination of boolean, integer and real
+/// floating-point source and destination types. Every other pairing is
+/// diagnosed as an invalid subexpression of a constant expression.
+///
+/// \param Info     Evaluation state; supplies the ASTContext and diagnostics.
+/// \param FPO      Floating-point options forwarded to int-to-float casts.
+/// \param E        Expression the conversion is attributed to.
+/// \param SourceTy Type of \p Original.
+/// \param DestTy   Type to convert to.
+/// \param Original Value to convert.
+/// \param Result   Receives the converted value on success.
+/// \returns true on success, false if the conversion failed or is unsupported.
+static bool handleScalarCast(EvalInfo &Info, const FPOptions FPO, const Expr *E,
+                             QualType SourceTy, QualType DestTy,
+                             APValue const &Original, APValue &Result) {
+  // Boolean must be checked before integer, since isIntegerType() is also
+  // true for bool.
+  if (SourceTy->isBooleanType()) {
+    if (DestTy->isBooleanType()) {
+      Result = Original;
+      return true;
+    }
+    if (DestTy->isIntegerType() || DestTy->isRealFloatingType()) {
+      bool BoolResult;
+      if (!HandleConversionToBool(Original, BoolResult))
+        return false;
+      uint64_t IntResult = BoolResult;
+      if (DestTy->isIntegerType()) {
+        Result = APValue(Info.Ctx.MakeIntValue(IntResult, DestTy));
+        return true;
+      }
+      // bool -> float is performed through an unsigned 64-bit integer that
+      // holds 0 or 1.
+      QualType IntTy = Info.Ctx.getIntTypeForBitwidth(64, false);
+      APSInt IntVal = Info.Ctx.MakeIntValue(IntResult, IntTy);
+      APValue FloatVal = APValue(APFloat(0.0));
+      if (!HandleIntToFloatCast(Info, E, FPO, IntTy, IntVal, DestTy,
+                                FloatVal.getFloat()))
+        return false;
+      Result = FloatVal;
+      return true;
+    }
+    // Unsupported destination: fall through to the diagnostic below instead
+    // of reporting success with Result left unwritten.
+  } else if (SourceTy->isIntegerType()) {
+    if (DestTy->isRealFloatingType()) {
+      Result = APValue(APFloat(0.0));
+      return HandleIntToFloatCast(Info, E, FPO, SourceTy, Original.getInt(),
+                                  DestTy, Result.getFloat());
+    }
+    if (DestTy->isBooleanType()) {
+      bool BoolResult;
+      if (!HandleConversionToBool(Original, BoolResult))
+        return false;
+      uint64_t IntResult = BoolResult;
+      Result = APValue(Info.Ctx.MakeIntValue(IntResult, DestTy));
+      return true;
+    }
+    if (DestTy->isIntegerType()) {
+      Result = APValue(
+          HandleIntToIntCast(Info, E, DestTy, SourceTy, Original.getInt()));
+      return true;
+    }
+  } else if (SourceTy->isRealFloatingType()) {
+    if (DestTy->isRealFloatingType()) {
+      // Start from the source value; the helper converts it in place.
+      Result = Original;
+      return HandleFloatToFloatCast(Info, E, SourceTy, DestTy,
+                                    Result.getFloat());
+    }
+    if (DestTy->isBooleanType()) {
+      bool BoolResult;
+      if (!HandleConversionToBool(Original, BoolResult))
+        return false;
+      uint64_t IntResult = BoolResult;
+      Result = APValue(Info.Ctx.MakeIntValue(IntResult, DestTy));
+      return true;
+    }
+    if (DestTy->isIntegerType()) {
+      Result = APValue(APSInt());
+      return HandleFloatToIntCast(Info, E, SourceTy, Original.getFloat(),
+                                  DestTy, Result.getInt());
+    }
+  }
+
+  Info.FFDiag(E, diag::note_invalid_subexpr_in_const_expr);
+  return false;
+}
+
+/// Builds the aggregate (or scalar) value \p Result of type \p ResultType from
+/// a flat list of scalar \p Elements, casting each element from its entry in
+/// \p ElTypes to the type of the destination slot it lands in.
+///
+/// The destination type is walked with an explicit stack. Children of arrays
+/// and records are pushed in reverse so they are popped in forward order
+/// (record base first, then fields in declaration order), which is the order
+/// in which \p Elements is consumed. Bit-field widths travel with each work
+/// item because integer results stored into a bit-field must be truncated to
+/// the field's width.
+///
+/// \returns false if a scalar cast fails or an unsupported destination type is
+/// encountered; true otherwise, including when the loop stops because all of
+/// \p Elements has been consumed.
+static bool constructAggregate(EvalInfo &Info, const FPOptions FPO,
+                               const Expr *E, APValue &Result,
+                               QualType ResultType,
+                               SmallVectorImpl<APValue> &Elements,
+                               SmallVectorImpl<QualType> &ElTypes) {
+
+  // Each work item is (destination slot, slot type, bit-field width or 0).
+  SmallVector<std::tuple<APValue *, QualType, unsigned>> WorkList = {
+      {&Result, ResultType, 0}};
+
+  // Index of the next source element to consume.
+  unsigned ElI = 0;
+  while (!WorkList.empty() && ElI < Elements.size()) {
+    auto [Res, Type, BitWidth] = WorkList.pop_back_val();
+
+    if (Type->isRealFloatingType()) {
+      if (!handleScalarCast(Info, FPO, E, ElTypes[ElI], Type, Elements[ElI],
+                            *Res))
+        return false;
+      ElI++;
+      continue;
+    }
+    if (Type->isIntegerType()) {
+      if (!handleScalarCast(Info, FPO, E, ElTypes[ElI], Type, Elements[ElI],
+                            *Res))
+        return false;
+      if (BitWidth > 0) {
+        if (!Res->isInt())
+          return false;
+        // Bit-field destination: drop the bits that do not fit in the field,
+        // then widen back so the stored value keeps its original bit width.
+        APSInt &Int = Res->getInt();
+        unsigned OldBitWidth = Int.getBitWidth();
+        unsigned NewBitWidth = BitWidth;
+        if (NewBitWidth < OldBitWidth)
+          Int = Int.trunc(NewBitWidth).extend(OldBitWidth);
+      }
+      ElI++;
+      continue;
+    }
+    if (Type->isVectorType()) {
+      // Vectors are filled directly: one source element per lane.
+      // NOTE(review): ElI is not re-checked against Elements.size() inside
+      // this loop; presumably callers always supply enough elements - confirm.
+      QualType ElTy = Type->castAs<VectorType>()->getElementType();
+      unsigned NumEl = Type->castAs<VectorType>()->getNumElements();
+      SmallVector<APValue> Vals(NumEl);
+      for (unsigned I = 0; I < NumEl; ++I) {
+        if (!handleScalarCast(Info, FPO, E, ElTypes[ElI], ElTy, Elements[ElI],
+                              Vals[I]))
+          return false;
+        ElI++;
+      }
+      *Res = APValue(Vals.data(), NumEl);
+      continue;
+    }
+    if (Type->isConstantArrayType()) {
+      QualType ElTy = cast<ConstantArrayType>(Info.Ctx.getAsArrayType(Type))
+                          ->getElementType();
+      uint64_t Size =
+          cast<ConstantArrayType>(Info.Ctx.getAsArrayType(Type))->getZExtSize();
+      // Allocate every element up front, then queue them last-to-first so the
+      // stack yields them first-to-last.
+      *Res = APValue(APValue::UninitArray(), Size, Size);
+      for (int64_t I = Size - 1; I > -1; --I)
+        WorkList.emplace_back(&Res->getArrayInitializedElt(I), ElTy, 0u);
+      continue;
+    }
+    if (Type->isRecordType()) {
+      const RecordDecl *RD = Type->getAsRecordDecl();
+
+      unsigned NumBases = 0;
+      if (auto *CXXRD = dyn_cast<CXXRecordDecl>(RD))
+        NumBases = CXXRD->getNumBases();
+
+      *Res = APValue(APValue::UninitStruct(), NumBases,
+                     std::distance(RD->field_begin(), RD->field_end()));
+
+      // Collect the base and fields in forward order, then reverse the list
+      // before appending it to the stack.
+      SmallVector<std::tuple<APValue *, QualType, unsigned>> ReverseList;
+      // we need to traverse backwards
+      // Visit the base classes.
+      if (auto *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
+        if (CXXRD->getNumBases() > 0) {
+          assert(CXXRD->getNumBases() == 1);
+          const CXXBaseSpecifier &BS = CXXRD->bases_begin()[0];
+          ReverseList.emplace_back(&Res->getStructBase(0), BS.getType(), 0u);
+        }
+      }
+
+      // Visit the fields.
+      for (FieldDecl *FD : RD->fields()) {
+        unsigned FDBW = 0;
+        // Unnamed bit-fields do not receive a source element.
+        if (FD->isUnnamedBitField())
+          continue;
+        if (FD->isBitField()) {
+          FDBW = FD->getBitWidthValue();
+        }
+
+        ReverseList.emplace_back(&Res->getStructField(FD->getFieldIndex()),
+                                 FD->getType(), FDBW);
+      }
+
+      std::reverse(ReverseList.begin(), ReverseList.end());
+      llvm::append_range(WorkList, ReverseList);
+      continue;
+    }
+    // Any other destination type is not supported here.
+    Info.FFDiag(E, diag::note_invalid_subexpr_in_const_expr);
+    return false;
+  }
+  return true;
+}
+
+/// Casts each entry of \p Elements from the matching entry of \p SrcTypes to
+/// the matching entry of \p DestTypes, writing the converted value into the
+/// same index of \p Results.
+///
+/// \returns false as soon as one scalar cast fails, true otherwise.
+static bool handleElementwiseCast(EvalInfo &Info, const Expr *E,
+                                  const FPOptions FPO,
+                                  SmallVectorImpl<APValue> &Elements,
+                                  SmallVectorImpl<QualType> &SrcTypes,
+                                  SmallVectorImpl<QualType> &DestTypes,
+                                  SmallVectorImpl<APValue> &Results) {
+
+  assert((Elements.size() == SrcTypes.size()) &&
+         (Elements.size() == DestTypes.size()));
+
+  const unsigned NumElements = Elements.size();
+  unsigned Idx = 0;
+  while (Idx != NumElements) {
+    // Work on a private copy of the source element.
+    APValue Source = Elements[Idx];
+    if (!handleScalarCast(Info, FPO, E, SrcTypes[Idx], DestTypes[Idx], Source,
+                          Results[Idx]))
+      return false;
+    ++Idx;
+  }
+  return true;
+}
+
+/// Counts the scalar elements obtained by flattening \p BaseTy.
+///
+/// Floating-point, integer and boolean types count as one element and a
+/// vector counts as its number of lanes. Constant arrays and records are
+/// expanded into their element, base and field types (unnamed bit-fields are
+/// skipped). Types of any other kind contribute nothing.
+static unsigned elementwiseSize(EvalInfo &Info, QualType BaseTy) {
+  unsigned Count = 0;
+  SmallVector<QualType> Pending;
+  Pending.push_back(BaseTy);
+
+  while (!Pending.empty()) {
+    QualType Cur = Pending.pop_back_val();
+
+    if (Cur->isRealFloatingType() || Cur->isIntegerType() ||
+        Cur->isBooleanType()) {
+      Count += 1;
+    } else if (Cur->isVectorType()) {
+      Count += Cur->castAs<VectorType>()->getNumElements();
+    } else if (Cur->isConstantArrayType()) {
+      // Queue the element type once per array element.
+      const auto *CAT = cast<ConstantArrayType>(Info.Ctx.getAsArrayType(Cur));
+      QualType ElemTy = CAT->getElementType();
+      for (uint64_t Left = CAT->getZExtSize(); Left != 0; --Left)
+        Pending.push_back(ElemTy);
+    } else if (Cur->isRecordType()) {
+      const RecordDecl *Record = Cur->getAsRecordDecl();
+
+      // Queue the base class; the assert expects at most one.
+      if (auto *CXXRecord = dyn_cast<CXXRecordDecl>(Record)) {
+        if (CXXRecord->getNumBases() > 0) {
+          assert(CXXRecord->getNumBases() == 1);
+          Pending.push_back(CXXRecord->bases_begin()[0].getType());
+        }
+      }
+
+      // Queue every field except unnamed bit-fields.
+      for (FieldDecl *Field : Record->fields())
+        if (!Field->isUnnamedBitField())
+          Pending.push_back(Field->getType());
+    }
+  }
+  return Count;
+}
+
+/// Evaluates the operand \p E of an HLSLAggregateSplatCast and reduces it to
+/// the single scalar that will be splatted.
+///
+/// \param Info   Evaluation state.
+/// \param E      Operand of the splat cast.
+/// \param SrcVal Receives the scalar value; a one-element vector is unwrapped
+///               to its only element.
+/// \param SrcTy  Receives the type of \p SrcVal (the vector's element type
+///               when the operand was a one-element vector).
+/// \returns false if \p E could not be evaluated, true otherwise.
+static bool hlslAggSplatHelper(EvalInfo &Info, const Expr *E, APValue &SrcVal,
+                               QualType &SrcTy) {
+  SrcTy = E->getType();
+
+  if (!Evaluate(SrcVal, Info, E))
+    return false;
+
+  // Parenthesize the whole condition so the message applies to every
+  // alternative and the expression does not mix '||' and '&&' unbracketed.
+  assert((SrcVal.isFloat() || SrcVal.isInt() ||
+          (SrcVal.isVector() && SrcVal.getVectorLength() == 1)) &&
+         "Not a valid HLSLAggregateSplatCast.");
+
+  if (SrcVal.isVector()) {
+    assert(SrcTy->isVectorType() && "Type mismatch.");
+    SrcTy = SrcTy->castAs<VectorType>()->getElementType();
+    SrcVal = SrcVal.getVectorElt(0);
+  }
+  return true;
+}
+
+/// Flattens \p Value, whose type is \p BaseTy, into a list of scalar
+/// constants and their types, stopping once \p Size scalars were produced.
+///
+/// Scalars (int/float) are emitted as-is, vectors lane by lane, arrays element
+/// by element and records as base class first, then fields in declaration
+/// order (unnamed bit-fields are skipped). An explicit stack is used, so the
+/// children of arrays and records are pushed in reverse order.
+///
+/// \param Info     Evaluation state; supplies the ASTContext and diagnostics.
+/// \param E        Expression used for diagnostics.
+/// \param Value    Value to flatten.
+/// \param BaseTy   Type of \p Value.
+/// \param Elements Receives the flattened scalar values (appended).
+/// \param Types    Receives the type of each appended value.
+/// \param Size     Maximum number of scalars to produce.
+/// \returns false if a value of an unsupported kind is encountered or a base
+/// class subobject is not a struct value; true otherwise.
+static bool flattenAPValue(EvalInfo &Info, const Expr *E, APValue Value,
+                           QualType BaseTy, SmallVectorImpl<APValue> &Elements,
+                           SmallVectorImpl<QualType> &Types, unsigned Size) {
+
+  SmallVector<std::pair<APValue, QualType>> WorkList = {{Value, BaseTy}};
+  unsigned Populated = 0;
+  while (!WorkList.empty() && Populated < Size) {
+    auto [Work, Type] = WorkList.pop_back_val();
+
+    if (Work.isFloat() || Work.isInt()) {
+      Elements.push_back(Work);
+      Types.push_back(Type);
+      Populated++;
+      continue;
+    }
+    if (Work.isVector()) {
+      assert(Type->isVectorType() && "Type mismatch.");
+      QualType ElTy = Type->castAs<VectorType>()->getElementType();
+      for (unsigned I = 0; I < Work.getVectorLength() && Populated < Size;
+           I++) {
+        Elements.push_back(Work.getVectorElt(I));
+        Types.push_back(ElTy);
+        Populated++;
+      }
+      continue;
+    }
+    if (Work.isArray()) {
+      assert(Type->isConstantArrayType() && "Type mismatch.");
+      QualType ElTy = cast<ConstantArrayType>(Info.Ctx.getAsArrayType(Type))
+                          ->getElementType();
+      // Only the first getArrayInitializedElts() elements are stored
+      // individually; all remaining elements share the array filler.
+      const unsigned NumInit = Work.getArrayInitializedElts();
+      // Count down on the unsigned size so an empty array pushes nothing
+      // (size - 1 would wrap around), and push last-to-first so the stack
+      // yields the elements first-to-last.
+      for (unsigned I = Work.getArraySize(); I-- > 0;) {
+        const APValue &Elt =
+            I < NumInit ? Work.getArrayInitializedElt(I) : Work.getArrayFiller();
+        WorkList.emplace_back(Elt, ElTy);
+      }
+      continue;
+    }
+
+    if (Work.isStruct()) {
+      assert(Type->isRecordType() && "Type mismatch.");
+
+      const RecordDecl *RD = Type->getAsRecordDecl();
+
+      SmallVector<std::pair<APValue, QualType>> ReverseList;
+      // Visit the fields.
+      for (FieldDecl *FD : RD->fields()) {
+        if (FD->isUnnamedBitField())
+          continue;
+        ReverseList.emplace_back(Work.getStructField(FD->getFieldIndex()),
+                                 FD->getType());
+      }
+
+      std::reverse(ReverseList.begin(), ReverseList.end());
+      llvm::append_range(WorkList, ReverseList);
+
+      // Visit the base classes. The base is pushed last so that it is popped
+      // (and therefore flattened) before any of the fields.
+      if (auto *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
+        if (CXXRD->getNumBases() > 0) {
+          assert(CXXRD->getNumBases() == 1);
+          const CXXBaseSpecifier &BS = CXXRD->bases_begin()[0];
+          const APValue &Base = Work.getStructBase(0);
+
+          // Can happen in error cases.
+          if (!Base.isStruct())
+            return false;
+
+          WorkList.emplace_back(Base, BS.getType());
+        }
+      }
+      continue;
+    }
+    Info.FFDiag(E, diag::note_invalid_subexpr_in_const_expr);
+    return false;
+  }
+  return true;
+}
+
 namespace {
 /// A handle to a complete object (an object that is not a subobject of
 /// another object).
@@ -4639,6 +4984,30 @@ handleLValueToRValueConversion(EvalInfo &Info, const 
Expr *Conv, QualType Type,
   return Obj && extractSubobject(Info, Conv, Obj, LVal.Designator, RVal, AK);
 }
 
+/// Evaluates the operand \p E of an HLSLElementwiseCast and flattens its value
+/// into scalar elements.
+///
+/// At most as many scalars are produced as the destination type \p DestTy
+/// flattens to. The values are appended to \p SrcVals and their types to
+/// \p SrcTypes.
+///
+/// \returns false if evaluation, the lvalue load or the flattening fails.
+static bool hlslElementwiseCastHelper(EvalInfo &Info, const Expr *E,
+                                      QualType DestTy,
+                                      SmallVectorImpl<APValue> &SrcVals,
+                                      SmallVectorImpl<QualType> &SrcTypes) {
+  APValue Evaluated;
+  if (!Evaluate(Evaluated, Info, E))
+    return false;
+
+  // If evaluation produced an lvalue, load the value it designates.
+  if (Evaluated.isLValue()) {
+    LValue Loc;
+    Loc.setFrom(Info.Ctx, Evaluated);
+    if (!handleLValueToRValueConversion(Info, E, E->getType(), Loc, Evaluated))
+      return false;
+  }
+
+  // Flatten no more source scalars than the destination can hold.
+  const unsigned NumDestEls = elementwiseSize(Info, DestTy);
+  return flattenAPValue(Info, E, Evaluated, E->getType(), SrcVals, SrcTypes,
+                        NumDestEls);
+}
+
 /// Perform an assignment of Val to LVal. Takes ownership of Val.
 static bool handleAssignment(EvalInfo &Info, const Expr *E, const LValue &LVal,
                              QualType LValType, APValue &Val) {
@@ -8670,6 +9039,25 @@ class ExprEvaluatorBase
     case CK_UserDefinedConversion:
       return StmtVisitorTy::Visit(E->getSubExpr());
 
+    case CK_HLSLArrayRValue: {
+      const Expr *SubExpr = E->getSubExpr();
+      if (!SubExpr->isGLValue()) {
+        APValue Val;
+        if (!Evaluate(Val, Info, SubExpr))
+          return false;
+        return DerivedSuccess(Val, E);
+      }
+
+      LValue LVal;
+      if (!EvaluateLValue(SubExpr, LVal, Info))
+        return false;
+      APValue RVal;
+      // Note, we use the subexpression's type in order to retain 
cv-qualifiers.
+      if (!handleLValueToRValueConversion(Info, E, SubExpr->getType(), LVal,
+                                          RVal))
+        return false;
+      return DerivedSuccess(RVal, E);
+    }
     case CK_LValueToRValue: {
       LValue LVal;
       if (!EvaluateLValue(E->getSubExpr(), LVal, Info))
@@ -10854,6 +11242,42 @@ bool RecordExprEvaluator::VisitCastExpr(const CastExpr 
*E) {
     Result = *Value;
     return true;
   }
+  case CK_HLSLAggregateSplatCast: {
+    APValue Val;
+    QualType ValTy;
+
+    if (!hlslAggSplatHelper(Info, E->getSubExpr(), Val, ValTy))
+      return false;
+
+    unsigned NEls = elementwiseSize(Info, E->getType());
+    // splat our Val
+    SmallVector<APValue> SplatEls(NEls, Val);
+    SmallVector<QualType> SplatType(NEls, ValTy);
+
+    // cast the elements and construct our struct result
+    const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+    if (!constructAggregate(Info, FPO, E, Result, E->getType(), SplatEls,
+                            SplatType))
+      return false;
+
+    return true;
+  }
+  case CK_HLSLElementwiseCast: {
+    SmallVector<APValue> SrcEls;
+    SmallVector<QualType> SrcTypes;
+
+    if (!hlslElementwiseCastHelper(Info, E->getSubExpr(), E->getType(), SrcEls,
+                                   SrcTypes))
+      return false;
+
+    // cast the elements and construct our struct result
+    const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+    if (!constructAggregate(Info, FPO, E, Result, E->getType(), SrcEls,
+                            SrcTypes))
+      return false;
+
+    return true;
+  }
   }
 }
 
@@ -11349,6 +11773,38 @@ bool VectorExprEvaluator::VisitCastExpr(const CastExpr 
*E) {
       Elements.push_back(Val.getVectorElt(I));
     return Success(Elements, E);
   }
+  case CK_HLSLAggregateSplatCast: {
+    APValue Val;
+    QualType ValTy;
+
+    if (!hlslAggSplatHelper(Info, SE, Val, ValTy))
+      return false;
+
+    // cast our Val once.
+    APValue Result;
+    const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+    if (!handleScalarCast(Info, FPO, E, ValTy, VTy->getElementType(), Val,
+                          Result))
+      return false;
+
+    SmallVector<APValue, 4> SplatEls(NElts, Result);
+    return Success(SplatEls, E);
+  }
+  case CK_HLSLElementwiseCast: {
+    SmallVector<APValue> SrcVals;
+    SmallVector<QualType> SrcTypes;
+
+    if (!hlslElementwiseCastHelper(Info, SE, E->getType(), SrcVals, SrcTypes))
+      return false;
+
+    const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+    SmallVector<QualType, 4> DestTypes(NElts, VTy->getElementType());
+    SmallVector<APValue, 4> ResultEls(NElts);
+    if (!handleElementwiseCast(Info, E, FPO, SrcVals, SrcTypes, DestTypes,
+                               ResultEls))
+      return false;
+    return Success(ResultEls, E);
+  }
   default:
     return ExprEvaluatorBaseTy::VisitCastExpr(E);
   }
@@ -13316,6 +13772,7 @@ namespace {
     bool VisitCallExpr(const CallExpr *E) {
       return handleCallExpr(E, Result, &This);
     }
+    bool VisitCastExpr(const CastExpr *E);
     bool VisitInitListExpr(const InitListExpr *E,
                            QualType AllocType = QualType());
     bool VisitArrayInitLoopExpr(const ArrayInitLoopExpr *E);
@@ -13386,6 +13843,49 @@ static bool MaybeElementDependentArrayFiller(const 
Expr *FillerExpr) {
   return true;
 }
 
+/// Evaluates a cast whose result is an array.
+///
+/// HLSLAggregateSplatCast replicates one scalar across every flattened
+/// element of the destination; HLSLElementwiseCast flattens the operand and
+/// converts it element by element. All other cast kinds are forwarded to the
+/// base evaluator.
+bool ArrayExprEvaluator::VisitCastExpr(const CastExpr *E) {
+  const Expr *Operand = E->getSubExpr();
+  const CastKind Kind = E->getCastKind();
+
+  if (Kind != CK_HLSLAggregateSplatCast && Kind != CK_HLSLElementwiseCast)
+    return ExprEvaluatorBaseTy::VisitCastExpr(E);
+
+  // Gather the flat list of source scalars and their types.
+  SmallVector<APValue> SrcEls;
+  SmallVector<QualType> SrcTypes;
+  if (Kind == CK_HLSLAggregateSplatCast) {
+    APValue Scalar;
+    QualType ScalarTy;
+    if (!hlslAggSplatHelper(Info, Operand, Scalar, ScalarTy))
+      return false;
+
+    // One copy of the scalar per flattened destination element.
+    const unsigned NumEls = elementwiseSize(Info, E->getType());
+    SrcEls.assign(NumEls, Scalar);
+    SrcTypes.assign(NumEls, ScalarTy);
+  } else {
+    if (!hlslElementwiseCastHelper(Info, Operand, E->getType(), SrcEls,
+                                   SrcTypes))
+      return false;
+  }
+
+  // Cast the elements and build the array result.
+  const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+  return constructAggregate(Info, FPO, E, Result, E->getType(), SrcEls,
+                            SrcTypes);
+}
+
 bool ArrayExprEvaluator::VisitInitListExpr(const InitListExpr *E,
                                            QualType AllocType) {
   const ConstantArrayType *CAT = Info.Ctx.getAsConstantArrayType(
@@ -17192,7 +17692,6 @@ bool IntExprEvaluator::VisitCastExpr(const CastExpr *E) 
{
   case CK_NoOp:
   case CK_LValueToRValueBitCast:
   case CK_HLSLArrayRValue:
-  case CK_HLSLElementwiseCast:
     return ExprEvaluatorBaseTy::VisitCastExpr(E);
 
   case CK_MemberPointerToBoolean:
@@ -17339,6 +17838,21 @@ bool IntExprEvaluator::VisitCastExpr(const CastExpr 
*E) {
       return Error(E);
     return Success(Val.getVectorElt(0), E);
   }
+  case CK_HLSLElementwiseCast: {
+    SmallVector<APValue> SrcVals;
+    SmallVector<QualType> SrcTypes;
+
+    if (!hlslElementwiseCastHelper(Info, SubExpr, DestType, SrcVals, SrcTypes))
+      return false;
+
+    // cast our single element
+    const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+    APValue ResultVal;
+    if (!handleScalarCast(Info, FPO, E, SrcTypes[0], DestType, SrcVals[0],
+                          ResultVal))
+      return false;
+    return Success(ResultVal, E);
+  }
   }
 
   llvm_unreachable("unknown cast resulting in integral value");
@@ -17876,6 +18390,9 @@ bool FloatExprEvaluator::VisitCastExpr(const CastExpr 
*E) {
   default:
     return ExprEvaluatorBaseTy::VisitCastExpr(E);
 
+  case CK_HLSLAggregateSplatCast:
+    llvm_unreachable("invalid cast kind for floating value");
+
   case CK_IntegralToFloating: {
     APSInt IntResult;
     const FPOptions FPO = E->getFPFeaturesInEffect(
@@ -17914,6 +18431,23 @@ bool FloatExprEvaluator::VisitCastExpr(const CastExpr 
*E) {
       return Error(E);
     return Success(Val.getVectorElt(0), E);
   }
+  case CK_HLSLElementwiseCast: {
+    SmallVector<APValue> SrcVals;
+    SmallVector<QualType> SrcTypes;
+
+    if (!hlslElementwiseCastHelper(Info, SubExpr, E->getType(), SrcVals,
+                                   SrcTypes))
+      return false;
+    APValue Val;
+
+    // cast our single element
+    const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+    APValue ResultVal;
+    if (!handleScalarCast(Info, FPO, E, SrcTypes[0], E->getType(), SrcVals[0],
+                          ResultVal))
+      return false;
+    return Success(ResultVal, E);
+  }
   }
 }
 

diff  --git a/clang/test/SemaHLSL/Types/AggregateSplatConstantExpr.hlsl 
b/clang/test/SemaHLSL/Types/AggregateSplatConstantExpr.hlsl
new file mode 100644
index 0000000000000..630acd8297642
--- /dev/null
+++ b/clang/test/SemaHLSL/Types/AggregateSplatConstantExpr.hlsl
@@ -0,0 +1,89 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library 
-finclude-default-header -fnative-half-type -fnative-int16-type -std=hlsl202x 
-verify %s
+
+// expected-no-diagnostics
+
+struct Base {
+  double D;
+  uint64_t2 U;
+  int16_t I : 5;
+  uint16_t I2: 5;
+};
+
+struct R : Base {
+  int G : 10;
+  int : 30;
+  float F;
+};
+
+struct B1 {
+  float A;
+  float B;
+};
+
+struct B2 : B1 {
+  int C;
+  int D;
+  bool BB;
+};
+
+// tests for HLSLAggregateSplatCast
+export void fn() {
+  // result type vector
+  // splat from a vector of size 1
+  
+  constexpr float1 Y = {1.0};
+  constexpr float4 F4 = (float4)Y;
+  _Static_assert(F4[0] == 1.0, "Woo!");
+  _Static_assert(F4[1] == 1.0, "Woo!");
+  _Static_assert(F4[2] == 1.0, "Woo!");
+  _Static_assert(F4[3] == 1.0, "Woo!");
+
+  // result type array
+  // splat from a scalar
+  constexpr float F = 3.33;
+  constexpr int B6[6] = (int[6])F;
+  _Static_assert(B6[0] == 3, "Woo!");
+  _Static_assert(B6[1] == 3, "Woo!");
+  _Static_assert(B6[2] == 3, "Woo!");
+  _Static_assert(B6[3] == 3, "Woo!");
+  _Static_assert(B6[4] == 3, "Woo!");
+  _Static_assert(B6[5] == 3, "Woo!");
+
+  // splat from a vector of size 1
+  constexpr int1 A1 = {1};
+  constexpr uint64_t2 A7[2] = (uint64_t2[2])A1;
+  _Static_assert(A7[0][0] == 1, "Woo!");
+  _Static_assert(A7[0][1] == 1, "Woo!");
+  _Static_assert(A7[1][0] == 1, "Woo!");
+  _Static_assert(A7[1][1] == 1, "Woo!");
+
+  // result type struct
+  // splat from a scalar
+  constexpr double D = 97.6789;
+  constexpr R SR = (R)(D + 3.0);
+  _Static_assert(SR.D == 100.6789, "Woo!");
+  _Static_assert(SR.U[0] == 100, "Woo!");
+  _Static_assert(SR.U[1] == 100, "Woo!");
+  _Static_assert(SR.I == 4, "Woo!");
+  _Static_assert(SR.I2 == 4, "Woo!");
+  _Static_assert(SR.G == 100, "Woo!");
+  _Static_assert(SR.F == 100.6789, "Woo!");
+
+  // splat from a vector of size 1
+  constexpr float1 A100 = {1000.1111};
+  constexpr B2 SB2 = (B2)A100;
+  _Static_assert(SB2.A == 1000.1111, "Woo!");
+  _Static_assert(SB2.B == 1000.1111, "Woo!");
+  _Static_assert(SB2.C == 1000, "Woo!");
+  _Static_assert(SB2.D == 1000, "Woo!");
+  _Static_assert(SB2.BB == true, "Woo!");
+
+  // splat from a bool to an int and float etc
+  constexpr bool B = true;
+  constexpr B2 SB3 = (B2)B;
+  _Static_assert(SB3.A == 1.0, "Woo!");
+  _Static_assert(SB3.B == 1.0, "Woo!");
+  _Static_assert(SB3.C == 1, "Woo!");
+  _Static_assert(SB3.D == 1, "Woo!");
+  _Static_assert(SB3.BB == true, "Woo!");
+}

diff  --git a/clang/test/SemaHLSL/Types/ElementwiseCastConstantExpr.hlsl 
b/clang/test/SemaHLSL/Types/ElementwiseCastConstantExpr.hlsl
new file mode 100644
index 0000000000000..c9963c36ce23a
--- /dev/null
+++ b/clang/test/SemaHLSL/Types/ElementwiseCastConstantExpr.hlsl
@@ -0,0 +1,90 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library 
-finclude-default-header -fnative-half-type -fnative-int16-type -std=hlsl202x 
-verify %s
+
+// expected-no-diagnostics
+
+struct Base {
+  double D;
+  uint64_t2 U;
+  int16_t I : 5;
+  uint16_t I2: 5;
+};
+
+struct R : Base {
+  int G : 10;
+  int : 30;
+  float F;
+};
+
+struct B1 {
+  float A;
+  float B;
+};
+
+struct B2 : B1 {
+  int C;
+  int D;
+  bool BB;
+};
+
+export void fn() {
+  _Static_assert(((float4)(int[6]){1,2,3,4,5,6}).x == 1.0, "Woo!");
+
+  // This compiling successfully verifies that the array constant expression
+  // gets truncated to an int at compile time for instantiation via the
+  // flat cast
+  _Static_assert(((int)(int[2]){1,2}) == 1, "Woo!");
+
+  // truncation tests
+  // result type int
+  // truncate from struct
+  constexpr B1 SB1 = {1.0, 3.0};
+  constexpr int X = (int)SB1;
+  _Static_assert(X == 1, "Woo!");
+
+  // result type float
+  // truncate from array
+  constexpr B1 Arr[2] = {4.0, 3.0, 2.0, 1.0};
+  constexpr float F = (float)Arr;
+  _Static_assert(F == 4.0, "Woo!");
+
+  // result type vector
+  // truncate from array of vector
+  constexpr int2 Arr2[2] = {5,6,7,8};
+  constexpr int2 I2 = (int2)Arr2;
+  _Static_assert(I2[0] == 5, "Woo!");
+  _Static_assert(I2[1] == 6, "Woo!");
+
+  // lhs and rhs are same "size" tests
+  
+  // result type vector from  array
+  constexpr int4 I4 = (int4)Arr;
+  _Static_assert(I4[0] == 4, "Woo!");
+  _Static_assert(I4[1] == 3, "Woo!");
+  _Static_assert(I4[2] == 2, "Woo!");
+  _Static_assert(I4[3] == 1, "Woo!");
+
+  // result type array from vector
+  constexpr double3 D3 = {100.11, 200.11, 300.11};
+  constexpr float FArr[3] = (float[3])D3;
+  _Static_assert(FArr[0] == 100.11, "Woo!");
+  _Static_assert(FArr[1] == 200.11, "Woo!");
+  _Static_assert(FArr[2] == 300.11, "Woo!");
+
+  // result type struct from struct
+  constexpr B2 SB2 = {5.5, 6.5, 1000, 5000, false};
+  constexpr Base SB = (Base)SB2;
+  _Static_assert(SB.D == 5.5, "Woo!");
+  _Static_assert(SB.U[0] == 6, "Woo!");
+  _Static_assert(SB.U[1] == 1000, "Woo!");
+  _Static_assert(SB.I == 8, "Woo!");
+  _Static_assert(SB.I2 == 0, "Woo!");
+
+  // Make sure we read bitfields correctly
+  constexpr Base BB = {222.22, {100, 200}, -2, 7};
+  constexpr int Arr3[5] = (int[5])BB;
+  _Static_assert(Arr3[0] == 222, "Woo!");
+  _Static_assert(Arr3[1] == 100, "Woo!");
+  _Static_assert(Arr3[2] == 200, "Woo!");
+  _Static_assert(Arr3[3] == -2, "Woo!");
+  _Static_assert(Arr3[4] == 7, "Woo!");
+}


        
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to