Author: Deric C.
Date: 2026-03-05T09:45:39-08:00
New Revision: f8f0f93236b2aee6b63a454c05cc2876568f8972

URL: 
https://github.com/llvm/llvm-project/commit/f8f0f93236b2aee6b63a454c05cc2876568f8972
DIFF: 
https://github.com/llvm/llvm-project/commit/f8f0f93236b2aee6b63a454c05cc2876568f8972.diff

LOG: [HLSL][Matrix] Add APValue and ConstExpr evaluator support for matrices 
(#178762)

Fixes #168935

This PR adds basic support for matrix APValues and a ConstExpr evaluator
for matrices.
- ConstExpr evaluation changes:
  - Matrix initializer list
  - Matrix HLSL elementwise cast
  - Matrix HLSL aggregate splat
  - Vector HLSL matrix truncation
  - Int HLSL matrix truncation
  - Float HLSL matrix truncation
- Matrix APValue:
  - AST dumper and serialization
  - Value flattening

Note that APValue matrices hold their elements in row-major order
irrespective of the `-fmatrix-memory-layout` flag.
The `-fmatrix-memory-layout` flag is for codegen, not semantics, so the
decision of which memory layout to use for the matrix APValue can be
independent of the memory layout for codegen.

There are also a number of places expecting switch case coverage over
all APValue kinds but which do not currently support matrix APValues. I
have added placeholder llvm_unreachables to these places for the matrix
APValue, as these places cannot currently be exercised in clang tests
(AFAIK).

Assisted-by: claude-opus-4.5

Added: 
    clang/test/AST/HLSL/ast-dump-APValue-matrix.hlsl
    clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixConstantExpr.hlsl

Modified: 
    clang/include/clang/AST/APValue.h
    clang/include/clang/AST/PropertiesBase.td
    clang/include/clang/AST/TypeBase.h
    clang/lib/AST/APValue.cpp
    clang/lib/AST/ASTImporter.cpp
    clang/lib/AST/ExprConstant.cpp
    clang/lib/AST/ItaniumMangle.cpp
    clang/lib/AST/MicrosoftMangle.cpp
    clang/lib/AST/TextNodeDumper.cpp
    clang/lib/AST/Type.cpp
    clang/lib/CodeGen/CGExprConstant.cpp
    clang/lib/Sema/SemaTemplate.cpp
    clang/test/CodeGenHLSL/BoolMatrix.hlsl

Removed: 
    


################################################################################
diff  --git a/clang/include/clang/AST/APValue.h 
b/clang/include/clang/AST/APValue.h
index 8a2d6d434792a..3961e4e7fdfe0 100644
--- a/clang/include/clang/AST/APValue.h
+++ b/clang/include/clang/AST/APValue.h
@@ -136,6 +136,7 @@ class APValue {
     ComplexFloat,
     LValue,
     Vector,
+    Matrix,
     Array,
     Struct,
     Union,
@@ -275,6 +276,15 @@ class APValue {
     Vec &operator=(const Vec &) = delete;
     ~Vec() { delete[] Elts; }
   };
+  struct Mat {
+    APValue *Elts = nullptr;
+    unsigned NumRows = 0;
+    unsigned NumCols = 0;
+    Mat() = default;
+    Mat(const Mat &) = delete;
+    Mat &operator=(const Mat &) = delete;
+    ~Mat() { delete[] Elts; }
+  };
   struct Arr {
     APValue *Elts;
     unsigned NumElts, ArrSize;
@@ -308,8 +318,9 @@ class APValue {
 
   // We ensure elsewhere that Data is big enough for LV and MemberPointerData.
   typedef llvm::AlignedCharArrayUnion<void *, APSInt, APFloat, ComplexAPSInt,
-                                      ComplexAPFloat, Vec, Arr, StructData,
-                                      UnionData, AddrLabelDiffData> DataType;
+                                      ComplexAPFloat, Vec, Mat, Arr, 
StructData,
+                                      UnionData, AddrLabelDiffData>
+      DataType;
   static const size_t DataSize = sizeof(DataType);
 
   DataType Data;
@@ -341,6 +352,13 @@ class APValue {
       : Kind(None), AllowConstexprUnknown(false) {
     MakeVector(); setVector(E, N);
   }
+  /// Creates a matrix APValue with given dimensions. The elements
+  /// are read from \p E and assumed to be in row-major order.
+  explicit APValue(const APValue *E, unsigned NumRows, unsigned NumCols)
+      : Kind(None), AllowConstexprUnknown(false) {
+    MakeMatrix();
+    setMatrix(E, NumRows, NumCols);
+  }
   /// Creates an integer complex APValue with the given real and imaginary
   /// values.
   APValue(APSInt R, APSInt I) : Kind(None), AllowConstexprUnknown(false) {
@@ -471,6 +489,7 @@ class APValue {
   bool isComplexFloat() const { return Kind == ComplexFloat; }
   bool isLValue() const { return Kind == LValue; }
   bool isVector() const { return Kind == Vector; }
+  bool isMatrix() const { return Kind == Matrix; }
   bool isArray() const { return Kind == Array; }
   bool isStruct() const { return Kind == Struct; }
   bool isUnion() const { return Kind == Union; }
@@ -573,6 +592,37 @@ class APValue {
     return ((const Vec *)(const void *)&Data)->NumElts;
   }
 
+  unsigned getMatrixNumRows() const {
+    assert(isMatrix() && "Invalid accessor");
+    return ((const Mat *)(const void *)&Data)->NumRows;
+  }
+  unsigned getMatrixNumColumns() const {
+    assert(isMatrix() && "Invalid accessor");
+    return ((const Mat *)(const void *)&Data)->NumCols;
+  }
+  unsigned getMatrixNumElements() const {
+    return getMatrixNumRows() * getMatrixNumColumns();
+  }
+  APValue &getMatrixElt(unsigned Idx) {
+    assert(isMatrix() && "Invalid accessor");
+    assert(Idx < getMatrixNumElements() && "Index out of range");
+    return ((Mat *)(char *)&Data)->Elts[Idx];
+  }
+  const APValue &getMatrixElt(unsigned Idx) const {
+    return const_cast<APValue *>(this)->getMatrixElt(Idx);
+  }
+  APValue &getMatrixElt(unsigned Row, unsigned Col) {
+    assert(isMatrix() && "Invalid accessor");
+    assert(Row < getMatrixNumRows() && "Row index out of range");
+    assert(Col < getMatrixNumColumns() && "Column index out of range");
+    // Matrix elements are stored in row-major order.
+    unsigned I = Row * getMatrixNumColumns() + Col;
+    return getMatrixElt(I);
+  }
+  const APValue &getMatrixElt(unsigned Row, unsigned Col) const {
+    return const_cast<APValue *>(this)->getMatrixElt(Row, Col);
+  }
+
   APValue &getArrayInitializedElt(unsigned I) {
     assert(isArray() && "Invalid accessor");
     assert(I < getArrayInitializedElts() && "Index out of range");
@@ -668,6 +718,11 @@ class APValue {
     for (unsigned i = 0; i != N; ++i)
       InternalElts[i] = E[i];
   }
+  void setMatrix(const APValue *E, unsigned NumRows, unsigned NumCols) {
+    MutableArrayRef<APValue> InternalElts = setMatrixUninit(NumRows, NumCols);
+    for (unsigned i = 0; i != NumRows * NumCols; ++i)
+      InternalElts[i] = E[i];
+  }
   void setComplexInt(APSInt R, APSInt I) {
     assert(R.getBitWidth() == I.getBitWidth() &&
            "Invalid complex int (type mismatch).");
@@ -716,6 +771,11 @@ class APValue {
     new ((void *)(char *)&Data) Vec();
     Kind = Vector;
   }
+  void MakeMatrix() {
+    assert(isAbsent() && "Bad state change");
+    new ((void *)(char *)&Data) Mat();
+    Kind = Matrix;
+  }
   void MakeComplexInt() {
     assert(isAbsent() && "Bad state change");
     new ((void *)(char *)&Data) ComplexAPSInt();
@@ -757,6 +817,15 @@ class APValue {
     V->NumElts = N;
     return {V->Elts, V->NumElts};
   }
+  MutableArrayRef<APValue> setMatrixUninit(unsigned NumRows, unsigned NumCols) 
{
+    assert(isMatrix() && "Invalid accessor");
+    Mat *M = ((Mat *)(char *)&Data);
+    unsigned NumElts = NumRows * NumCols;
+    M->Elts = new APValue[NumElts];
+    M->NumRows = NumRows;
+    M->NumCols = NumCols;
+    return {M->Elts, NumElts};
+  }
   MutableArrayRef<LValuePathEntry>
   setLValueUninit(LValueBase B, const CharUnits &O, unsigned Size,
                   bool OnePastTheEnd, bool IsNullPtr);

diff  --git a/clang/include/clang/AST/PropertiesBase.td 
b/clang/include/clang/AST/PropertiesBase.td
index 4581e55c28027..0011e57ed5ef7 100644
--- a/clang/include/clang/AST/PropertiesBase.td
+++ b/clang/include/clang/AST/PropertiesBase.td
@@ -353,6 +353,29 @@ let Class = PropertyTypeCase<APValue, "Vector"> in {
     return result;
   }]>;
 }
+let Class = PropertyTypeCase<APValue, "Matrix"> in {
+  def : ReadHelper<[{
+    SmallVector<APValue, 16> buffer;
+    unsigned numElts = node.getMatrixNumElements();
+    for (unsigned i = 0; i < numElts; ++i)
+      buffer.push_back(node.getMatrixElt(i));
+  }]>;
+  def : Property<"numRows", UInt32> {
+    let Read = [{ node.getMatrixNumRows() }];
+  }
+  def : Property<"numCols", UInt32> {
+    let Read = [{ node.getMatrixNumColumns() }];
+  }
+  def : Property<"elements", Array<APValue>> { let Read = [{ buffer }]; }
+  def : Creator<[{
+    APValue result;
+    result.MakeMatrix();
+    (void)result.setMatrixUninit(numRows, numCols);
+    for (unsigned i = 0; i < elements.size(); i++)
+      result.getMatrixElt(i) = elements[i];
+    return result;
+  }]>;
+}
 let Class = PropertyTypeCase<APValue, "Array"> in {
   def : ReadHelper<[{
     SmallVector<APValue, 4> buffer{};

diff  --git a/clang/include/clang/AST/TypeBase.h 
b/clang/include/clang/AST/TypeBase.h
index dc4442bfeb795..ec7845c3b3adb 100644
--- a/clang/include/clang/AST/TypeBase.h
+++ b/clang/include/clang/AST/TypeBase.h
@@ -4435,7 +4435,7 @@ class ConstantMatrixType final : public MatrixType {
   /// row-major order flattened index. Otherwise, returns the column-major 
order
   /// flattened index.
   unsigned getFlattenedIndex(unsigned Row, unsigned Column,
-                             bool IsRowMajor = false) {
+                             bool IsRowMajor = false) const {
     return IsRowMajor ? getRowMajorFlattenedIndex(Row, Column)
                       : getColumnMajorFlattenedIndex(Row, Column);
   }

diff  --git a/clang/lib/AST/APValue.cpp b/clang/lib/AST/APValue.cpp
index 2e1c8eb3726cf..95b6f7f745ccb 100644
--- a/clang/lib/AST/APValue.cpp
+++ b/clang/lib/AST/APValue.cpp
@@ -333,6 +333,11 @@ APValue::APValue(const APValue &RHS)
     setVector(((const Vec *)(const char *)&RHS.Data)->Elts,
               RHS.getVectorLength());
     break;
+  case Matrix:
+    MakeMatrix();
+    setMatrix(((const Mat *)(const char *)&RHS.Data)->Elts,
+              RHS.getMatrixNumRows(), RHS.getMatrixNumColumns());
+    break;
   case ComplexInt:
     MakeComplexInt();
     setComplexInt(RHS.getComplexIntReal(), RHS.getComplexIntImag());
@@ -414,6 +419,8 @@ void APValue::DestroyDataAndMakeUninit() {
     ((APFixedPoint *)(char *)&Data)->~APFixedPoint();
   else if (Kind == Vector)
     ((Vec *)(char *)&Data)->~Vec();
+  else if (Kind == Matrix)
+    ((Mat *)(char *)&Data)->~Mat();
   else if (Kind == ComplexInt)
     ((ComplexAPSInt *)(char *)&Data)->~ComplexAPSInt();
   else if (Kind == ComplexFloat)
@@ -444,6 +451,7 @@ bool APValue::needsCleanup() const {
   case Union:
   case Array:
   case Vector:
+  case Matrix:
     return true;
   case Int:
     return getInt().needsCleanup();
@@ -580,6 +588,12 @@ void APValue::Profile(llvm::FoldingSetNodeID &ID) const {
       getVectorElt(I).Profile(ID);
     return;
 
+  case Matrix:
+    for (unsigned R = 0, N = getMatrixNumRows(); R != N; ++R)
+      for (unsigned C = 0, M = getMatrixNumColumns(); C != M; ++C)
+        getMatrixElt(R, C).Profile(ID);
+    return;
+
   case Int:
     profileIntValue(ID, getInt());
     return;
@@ -747,6 +761,24 @@ void APValue::printPretty(raw_ostream &Out, const 
PrintingPolicy &Policy,
     Out << '}';
     return;
   }
+  case APValue::Matrix: {
+    const auto *MT = Ty->castAs<ConstantMatrixType>();
+    QualType ElemTy = MT->getElementType();
+    Out << '{';
+    for (unsigned R = 0; R < getMatrixNumRows(); ++R) {
+      if (R != 0)
+        Out << ", ";
+      Out << '{';
+      for (unsigned C = 0; C < getMatrixNumColumns(); ++C) {
+        if (C != 0)
+          Out << ", ";
+        getMatrixElt(R, C).printPretty(Out, Policy, ElemTy, Ctx);
+      }
+      Out << '}';
+    }
+    Out << '}';
+    return;
+  }
   case APValue::ComplexInt:
     Out << getComplexIntReal() << "+" << getComplexIntImag() << "i";
     return;
@@ -1139,6 +1171,7 @@ LinkageInfo LinkageComputer::getLVForValue(const APValue 
&V,
   case APValue::ComplexInt:
   case APValue::ComplexFloat:
   case APValue::Vector:
+  case APValue::Matrix:
     break;
 
   case APValue::AddrLabelDiff:

diff  --git a/clang/lib/AST/ASTImporter.cpp b/clang/lib/AST/ASTImporter.cpp
index 00af38626a8f7..340a0fb4fb5aa 100644
--- a/clang/lib/AST/ASTImporter.cpp
+++ b/clang/lib/AST/ASTImporter.cpp
@@ -10630,6 +10630,9 @@ ASTNodeImporter::ImportAPValue(const APValue 
&FromValue) {
                Elts.data(), FromValue.getVectorLength());
     break;
   }
+  case APValue::Matrix:
+    // Matrix values cannot currently arise in APValue import contexts.
+    llvm_unreachable("Matrix APValue import not yet supported");
   case APValue::Array:
     Result.MakeArray(FromValue.getArrayInitializedElts(),
                      FromValue.getArraySize());

diff  --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index a131cf8d158df..f233ff76632ae 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -1772,6 +1772,7 @@ static bool EvaluateIntegerOrLValue(const Expr *E, 
APValue &Result,
                                     EvalInfo &Info);
 static bool EvaluateFloat(const Expr *E, APFloat &Result, EvalInfo &Info);
 static bool EvaluateComplex(const Expr *E, ComplexValue &Res, EvalInfo &Info);
+static bool EvaluateMatrix(const Expr *E, APValue &Result, EvalInfo &Info);
 static bool EvaluateAtomic(const Expr *E, const LValue *This, APValue &Result,
                            EvalInfo &Info);
 static bool EvaluateAsRValue(EvalInfo &Info, const Expr *E, APValue &Result);
@@ -2602,6 +2603,7 @@ static bool HandleConversionToBool(const APValue &Val, 
bool &Result) {
     Result = Val.getMemberPointerDecl();
     return true;
   case APValue::Vector:
+  case APValue::Matrix:
   case APValue::Array:
   case APValue::Struct:
   case APValue::Union:
@@ -3932,6 +3934,12 @@ static unsigned elementwiseSize(EvalInfo &Info, QualType 
BaseTy) {
       Size += NumEl;
       continue;
     }
+    if (Type->isConstantMatrixType()) {
+      unsigned NumEl =
+          Type->castAs<ConstantMatrixType>()->getNumElementsFlattened();
+      Size += NumEl;
+      continue;
+    }
     if (Type->isConstantArrayType()) {
       QualType ElTy = cast<ConstantArrayType>(Info.Ctx.getAsArrayType(Type))
                           ->getElementType();
@@ -3982,6 +3990,11 @@ static bool hlslAggSplatHelper(EvalInfo &Info, const 
Expr *E, APValue &SrcVal,
     SrcTy = SrcTy->castAs<VectorType>()->getElementType();
     SrcVal = SrcVal.getVectorElt(0);
   }
+  if (SrcVal.isMatrix()) {
+    assert(SrcTy->isConstantMatrixType() && "Type mismatch.");
+    SrcTy = SrcTy->castAs<ConstantMatrixType>()->getElementType();
+    SrcVal = SrcVal.getMatrixElt(0, 0);
+  }
   return true;
 }
 
@@ -4011,6 +4024,22 @@ static bool flattenAPValue(EvalInfo &Info, const Expr 
*E, APValue Value,
       }
       continue;
     }
+    if (Work.isMatrix()) {
+      assert(Type->isConstantMatrixType() && "Type mismatch.");
+      const auto *MT = Type->castAs<ConstantMatrixType>();
+      QualType ElTy = MT->getElementType();
+      // Matrix elements are flattened in row-major order.
+      for (unsigned Row = 0; Row < Work.getMatrixNumRows() && Populated < Size;
+           Row++) {
+        for (unsigned Col = 0;
+             Col < Work.getMatrixNumColumns() && Populated < Size; Col++) {
+          Elements.push_back(Work.getMatrixElt(Row, Col));
+          Types.push_back(ElTy);
+          Populated++;
+        }
+      }
+      continue;
+    }
     if (Work.isArray()) {
       assert(Type->isConstantArrayType() && "Type mismatch.");
       QualType ElTy = cast<ConstantArrayType>(Info.Ctx.getAsArrayType(Type))
@@ -7742,6 +7771,7 @@ class APValueToBufferConverter {
     case APValue::FixedPoint:
       // FIXME: We should support these.
 
+    case APValue::Matrix:
     case APValue::Union:
     case APValue::MemberPointer:
     case APValue::AddrLabelDiff: {
@@ -11667,8 +11697,17 @@ bool VectorExprEvaluator::VisitCastExpr(const CastExpr 
*E) {
     return Success(Elements, E);
   }
   case CK_HLSLMatrixTruncation: {
-    // TODO: See #168935. Add matrix truncation support to expr constant.
-    return Error(E);
+    // Matrix truncation occurs in row-major order.
+    APValue Val;
+    if (!EvaluateMatrix(SE, Val, Info))
+      return Error(E);
+    SmallVector<APValue, 16> Elements;
+    for (unsigned Row = 0;
+         Row < Val.getMatrixNumRows() && Elements.size() < NElts; Row++)
+      for (unsigned Col = 0;
+           Col < Val.getMatrixNumColumns() && Elements.size() < NElts; Col++)
+        Elements.push_back(Val.getMatrixElt(Row, Col));
+    return Success(Elements, E);
   }
   case CK_HLSLAggregateSplatCast: {
     APValue Val;
@@ -14594,6 +14633,117 @@ bool 
VectorExprEvaluator::VisitShuffleVectorExpr(const ShuffleVectorExpr *E) {
   return Success(APValue(ResultElements.data(), ResultElements.size()), E);
 }
 
+//===----------------------------------------------------------------------===//
+// Matrix Evaluation
+//===----------------------------------------------------------------------===//
+
+namespace {
+class MatrixExprEvaluator : public ExprEvaluatorBase<MatrixExprEvaluator> {
+  APValue &Result;
+
+public:
+  MatrixExprEvaluator(EvalInfo &Info, APValue &Result)
+      : ExprEvaluatorBaseTy(Info), Result(Result) {}
+
+  bool Success(ArrayRef<APValue> M, const Expr *E) {
+    auto *CMTy = E->getType()->castAs<ConstantMatrixType>();
+    assert(M.size() == CMTy->getNumElementsFlattened());
+    // FIXME: remove this APValue copy.
+    Result = APValue(M.data(), CMTy->getNumRows(), CMTy->getNumColumns());
+    return true;
+  }
+  bool Success(const APValue &M, const Expr *E) {
+    assert(M.isMatrix() && "expected matrix");
+    Result = M;
+    return true;
+  }
+
+  bool VisitCastExpr(const CastExpr *E);
+  bool VisitInitListExpr(const InitListExpr *E);
+};
+} // end anonymous namespace
+
+static bool EvaluateMatrix(const Expr *E, APValue &Result, EvalInfo &Info) {
+  assert(E->isPRValue() && E->getType()->isConstantMatrixType() &&
+         "not a matrix prvalue");
+  return MatrixExprEvaluator(Info, Result).Visit(E);
+}
+
+bool MatrixExprEvaluator::VisitCastExpr(const CastExpr *E) {
+  const auto *MT = E->getType()->castAs<ConstantMatrixType>();
+  unsigned NumRows = MT->getNumRows();
+  unsigned NumCols = MT->getNumColumns();
+  unsigned NElts = NumRows * NumCols;
+  QualType EltTy = MT->getElementType();
+  const Expr *SE = E->getSubExpr();
+
+  switch (E->getCastKind()) {
+  case CK_HLSLAggregateSplatCast: {
+    APValue Val;
+    QualType ValTy;
+
+    if (!hlslAggSplatHelper(Info, SE, Val, ValTy))
+      return false;
+
+    APValue CastedVal;
+    const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+    if (!handleScalarCast(Info, FPO, E, ValTy, EltTy, Val, CastedVal))
+      return false;
+
+    SmallVector<APValue, 16> SplatEls(NElts, CastedVal);
+    return Success(SplatEls, E);
+  }
+  case CK_HLSLElementwiseCast: {
+    SmallVector<APValue> SrcVals;
+    SmallVector<QualType> SrcTypes;
+
+    if (!hlslElementwiseCastHelper(Info, SE, E->getType(), SrcVals, SrcTypes))
+      return false;
+
+    const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+    SmallVector<QualType, 16> DestTypes(NElts, EltTy);
+    SmallVector<APValue, 16> ResultEls(NElts);
+    if (!handleElementwiseCast(Info, E, FPO, SrcVals, SrcTypes, DestTypes,
+                               ResultEls))
+      return false;
+    return Success(ResultEls, E);
+  }
+  default:
+    return ExprEvaluatorBaseTy::VisitCastExpr(E);
+  }
+}
+
+bool MatrixExprEvaluator::VisitInitListExpr(const InitListExpr *E) {
+  const auto *MT = E->getType()->castAs<ConstantMatrixType>();
+  QualType EltTy = MT->getElementType();
+
+  assert(E->getNumInits() == MT->getNumElementsFlattened() &&
+         "Expected number of elements in initializer list to match the number "
+         "of matrix elements");
+
+  SmallVector<APValue, 16> Elements;
+  Elements.reserve(MT->getNumElementsFlattened());
+
+  // The following loop assumes the elements of the matrix InitListExpr are in
+  // row-major order, which matches the row-major ordering assumption of the
+  // matrix APValue.
+  for (unsigned I = 0, N = MT->getNumElementsFlattened(); I < N; ++I) {
+    if (EltTy->isIntegerType()) {
+      llvm::APSInt IntVal;
+      if (!EvaluateInteger(E->getInit(I), IntVal, Info))
+        return false;
+      Elements.push_back(APValue(IntVal));
+    } else {
+      llvm::APFloat FloatVal(0.0);
+      if (!EvaluateFloat(E->getInit(I), FloatVal, Info))
+        return false;
+      Elements.push_back(APValue(FloatVal));
+    }
+  }
+
+  return Success(Elements, E);
+}
+
 
//===----------------------------------------------------------------------===//
 // Array Evaluation
 
//===----------------------------------------------------------------------===//
@@ -18932,8 +19082,10 @@ bool IntExprEvaluator::VisitCastExpr(const CastExpr 
*E) {
     return Success(Val.getVectorElt(0), E);
   }
   case CK_HLSLMatrixTruncation: {
-    // TODO: See #168935. Add matrix truncation support to expr constant.
-    return Error(E);
+    APValue Val;
+    if (!EvaluateMatrix(SubExpr, Val, Info))
+      return Error(E);
+    return Success(Val.getMatrixElt(0, 0), E);
   }
   case CK_HLSLElementwiseCast: {
     SmallVector<APValue> SrcVals;
@@ -19529,8 +19681,10 @@ bool FloatExprEvaluator::VisitCastExpr(const CastExpr 
*E) {
     return Success(Val.getVectorElt(0), E);
   }
   case CK_HLSLMatrixTruncation: {
-    // TODO: See #168935. Add matrix truncation support to expr constant.
-    return Error(E);
+    APValue Val;
+    if (!EvaluateMatrix(SubExpr, Val, Info))
+      return Error(E);
+    return Success(Val.getMatrixElt(0, 0), E);
   }
   case CK_HLSLElementwiseCast: {
     SmallVector<APValue> SrcVals;
@@ -20429,6 +20583,9 @@ static bool Evaluate(APValue &Result, EvalInfo &Info, 
const Expr *E) {
   } else if (T->isVectorType()) {
     if (!EvaluateVector(E, Result, Info))
       return false;
+  } else if (T->isConstantMatrixType()) {
+    if (!EvaluateMatrix(E, Result, Info))
+      return false;
   } else if (T->isIntegralOrEnumerationType()) {
     if (!IntExprEvaluator(Info, Result).Visit(E))
       return false;

diff  --git a/clang/lib/AST/ItaniumMangle.cpp b/clang/lib/AST/ItaniumMangle.cpp
index 1faf7f1466e39..eea04b14eaf09 100644
--- a/clang/lib/AST/ItaniumMangle.cpp
+++ b/clang/lib/AST/ItaniumMangle.cpp
@@ -6495,6 +6495,9 @@ static bool isZeroInitialized(QualType T, const APValue 
&V) {
     return true;
   }
 
+  case APValue::Matrix:
+    llvm_unreachable("Matrix APValues not yet supported");
+
   case APValue::Int:
     return !V.getInt();
 
@@ -6708,6 +6711,9 @@ void CXXNameMangler::mangleValueInTemplateArg(QualType T, 
const APValue &V,
     break;
   }
 
+  case APValue::Matrix:
+    llvm_unreachable("Matrix template argument mangling not yet supported");
+
   case APValue::Int:
     mangleIntegerLiteral(T, V.getInt());
     break;

diff  --git a/clang/lib/AST/MicrosoftMangle.cpp 
b/clang/lib/AST/MicrosoftMangle.cpp
index 1f28d281be9fe..1bf92d4209f9f 100644
--- a/clang/lib/AST/MicrosoftMangle.cpp
+++ b/clang/lib/AST/MicrosoftMangle.cpp
@@ -2154,6 +2154,11 @@ void 
MicrosoftCXXNameMangler::mangleTemplateArgValue(QualType T,
     return;
   }
 
+  case APValue::Matrix: {
+    Error("template argument (value type: matrix)");
+    return;
+  }
+
   case APValue::AddrLabelDiff: {
     Error("template argument (value type: address label 
diff )");
     return;

diff  --git a/clang/lib/AST/TextNodeDumper.cpp 
b/clang/lib/AST/TextNodeDumper.cpp
index be8b03e1c2b72..89d2a2691e8fb 100644
--- a/clang/lib/AST/TextNodeDumper.cpp
+++ b/clang/lib/AST/TextNodeDumper.cpp
@@ -620,6 +620,7 @@ static bool isSimpleAPValue(const APValue &Value) {
   case APValue::Vector:
   case APValue::Array:
   case APValue::Struct:
+  case APValue::Matrix:
     return false;
   case APValue::Union:
     return isSimpleAPValue(Value.getUnionValue());
@@ -812,6 +813,19 @@ void TextNodeDumper::Visit(const APValue &Value, QualType 
Ty) {
 
     return;
   }
+  case APValue::Matrix: {
+    unsigned NumRows = Value.getMatrixNumRows();
+    unsigned NumCols = Value.getMatrixNumColumns();
+    OS << "Matrix " << NumRows << "x" << NumCols;
+
+    dumpAPValueChildren(
+        Value, Ty,
+        [](const APValue &Value, unsigned Index) -> const APValue & {
+          return Value.getMatrixElt(Index);
+        },
+        Value.getMatrixNumElements(), "element", "elements");
+    return;
+  }
   case APValue::Union: {
     OS << "Union";
     {

diff  --git a/clang/lib/AST/Type.cpp b/clang/lib/AST/Type.cpp
index a85f08753a132..290b91effeacd 100644
--- a/clang/lib/AST/Type.cpp
+++ b/clang/lib/AST/Type.cpp
@@ -3084,6 +3084,10 @@ bool Type::isLiteralType(const ASTContext &Ctx) const {
   if (BaseTy->isScalarType() || BaseTy->isVectorType() ||
       BaseTy->isAnyComplexType())
     return true;
+  // Matrices with constant numbers of rows and columns are also literal types
+  // in HLSL.
+  if (Ctx.getLangOpts().HLSL && BaseTy->isConstantMatrixType())
+    return true;
   //    -- a reference type; or
   if (BaseTy->isReferenceType())
     return true;

diff  --git a/clang/lib/CodeGen/CGExprConstant.cpp 
b/clang/lib/CodeGen/CGExprConstant.cpp
index 3f44243e1c35c..0739935acd867 100644
--- a/clang/lib/CodeGen/CGExprConstant.cpp
+++ b/clang/lib/CodeGen/CGExprConstant.cpp
@@ -2548,6 +2548,35 @@ ConstantEmitter::tryEmitPrivate(const APValue &Value, 
QualType DestType,
     }
     return llvm::ConstantVector::get(Inits);
   }
+  case APValue::Matrix: {
+    const auto *MT = DestType->castAs<ConstantMatrixType>();
+    unsigned NumRows = Value.getMatrixNumRows();
+    unsigned NumCols = Value.getMatrixNumColumns();
+    unsigned NumElts = NumRows * NumCols;
+    SmallVector<llvm::Constant *, 16> Inits(NumElts);
+
+    bool IsRowMajor = CGM.getLangOpts().getDefaultMatrixMemoryLayout() ==
+                      LangOptions::MatrixMemoryLayout::MatrixRowMajor;
+
+    for (unsigned Row = 0; Row != NumRows; ++Row) {
+      for (unsigned Col = 0; Col != NumCols; ++Col) {
+        const APValue &Elt = Value.getMatrixElt(Row, Col);
+        unsigned Idx = MT->getFlattenedIndex(Row, Col, IsRowMajor);
+        if (Elt.isInt())
+          Inits[Idx] =
+              llvm::ConstantInt::get(CGM.getLLVMContext(), Elt.getInt());
+        else if (Elt.isFloat())
+          Inits[Idx] =
+              llvm::ConstantFP::get(CGM.getLLVMContext(), Elt.getFloat());
+        else if (Elt.isIndeterminate())
+          Inits[Idx] = llvm::PoisonValue::get(
+              CGM.getTypes().ConvertType(MT->getElementType()));
+        else
+          llvm_unreachable("unsupported matrix element type");
+      }
+    }
+    return llvm::ConstantVector::get(Inits);
+  }
   case APValue::AddrLabelDiff: {
     const AddrLabelExpr *LHSExpr = Value.getAddrLabelDiffLHS();
     const AddrLabelExpr *RHSExpr = Value.getAddrLabelDiffRHS();

diff  --git a/clang/lib/Sema/SemaTemplate.cpp b/clang/lib/Sema/SemaTemplate.cpp
index 9f3bf7437fc54..9b0bec20618a0 100644
--- a/clang/lib/Sema/SemaTemplate.cpp
+++ b/clang/lib/Sema/SemaTemplate.cpp
@@ -8138,6 +8138,9 @@ static Expr 
*BuildExpressionFromNonTypeTemplateArgumentValue(
     return MakeInitList(Elts);
   }
 
+  case APValue::Matrix:
+    llvm_unreachable("Matrix template argument expression not yet supported");
+
   case APValue::None:
   case APValue::Indeterminate:
     llvm_unreachable("Unexpected APValue kind.");

diff  --git a/clang/test/AST/HLSL/ast-dump-APValue-matrix.hlsl 
b/clang/test/AST/HLSL/ast-dump-APValue-matrix.hlsl
new file mode 100644
index 0000000000000..98428225ebb3d
--- /dev/null
+++ b/clang/test/AST/HLSL/ast-dump-APValue-matrix.hlsl
@@ -0,0 +1,56 @@
+// Test without serialization:
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library 
-finclude-default-header -std=hlsl202x \
+// RUN:            -fnative-half-type -fnative-int16-type \
+// RUN:            -ast-dump %s -ast-dump-filter Test \
+// RUN: | FileCheck --strict-whitespace %s
+//
+// Test with serialization:
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library 
-finclude-default-header -std=hlsl202x \
+// RUN:            -fnative-half-type -fnative-int16-type -emit-pch -o %t %s
+// RUN: %clang_cc1 -x hlsl -triple dxil-pc-shadermodel6.6-library 
-finclude-default-header -std=hlsl202x \
+// RUN:           -fnative-half-type -fnative-int16-type \
+// RUN:           -include-pch %t -ast-dump-all -ast-dump-filter Test 
/dev/null \
+// RUN: | sed -e "s/ <undeserialized declarations>//" -e "s/ imported//" \
+// RUN: | FileCheck --strict-whitespace %s
+
+export void Test() {
+  constexpr int2x2 mat2x2 = {1, 2, 3, 4};
+  // CHECK: VarDecl {{.*}} mat2x2 {{.*}} constexpr cinit
+  // CHECK-NEXT: |-value: Matrix 2x2
+  // CHECK-NEXT: | `-elements: Int 1, Int 2, Int 3, Int 4
+
+  constexpr float3x2 mat3x2 = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+  // CHECK: VarDecl {{.*}} mat3x2 {{.*}} constexpr cinit
+  // CHECK-NEXT: |-value: Matrix 3x2
+  // CHECK-NEXT: | |-elements: Float 1.000000e+00, Float 2.000000e+00, Float 
3.000000e+00, Float 4.000000e+00
+  // CHECK-NEXT: | `-elements: Float 5.000000e+00, Float 6.000000e+00
+
+  constexpr int16_t3x2 i16mat3x2 = {-1, 2, -3, 4, -5, 6};
+  // CHECK: VarDecl {{.*}} i16mat3x2 {{.*}} constexpr cinit
+  // CHECK-NEXT: |-value: Matrix 3x2
+  // CHECK-NEXT: | |-elements: Int -1, Int 2, Int -3, Int 4
+  // CHECK-NEXT: | `-elements: Int -5, Int 6
+
+  constexpr int64_t4x1 i64mat4x1 = {100, -200, 300, -400};
+  // CHECK: VarDecl {{.*}} i64mat4x1 {{.*}} constexpr cinit
+  // CHECK-NEXT: |-value: Matrix 4x1
+  // CHECK-NEXT: | `-elements: Int 100, Int -200, Int 300, Int -400
+
+  constexpr half2x3 hmat2x3 = {1.5h, -2.5h, 3.5h, -4.5h, 5.5h, -6.5h};
+  // CHECK: VarDecl {{.*}} hmat2x3 {{.*}} constexpr cinit
+  // CHECK-NEXT: |-value: Matrix 2x3
+  // CHECK-NEXT: | |-elements: Float 1.500000e+00, Float -2.500000e+00, Float 
3.500000e+00, Float -4.500000e+00
+  // CHECK-NEXT: | `-elements: Float 5.500000e+00, Float -6.500000e+00
+
+  constexpr double1x4 dmat1x4 = {0.5l, -1.25l, 2.75l, -3.125l};
+  // CHECK: VarDecl {{.*}} dmat1x4 {{.*}} constexpr cinit
+  // CHECK-NEXT: |-value: Matrix 1x4
+  // CHECK-NEXT: | `-elements: Float 5.000000e-01, Float -1.250000e+00, Float 
2.750000e+00, Float -3.125000e+00
+
+  constexpr bool3x3 bmat3x3 = {true, false, true, false, true, false, true, 
false, true};
+  // CHECK: VarDecl {{.*}} bmat3x3 {{.*}} constexpr cinit
+  // CHECK-NEXT: |-value: Matrix 3x3
+  // CHECK-NEXT: | |-elements: Int 1, Int 0, Int 1, Int 0
+  // CHECK-NEXT: | |-elements: Int 1, Int 0, Int 1, Int 0
+  // CHECK-NEXT: | `-element: Int 1
+}

diff  --git a/clang/test/CodeGenHLSL/BoolMatrix.hlsl 
b/clang/test/CodeGenHLSL/BoolMatrix.hlsl
index c101ac02f7891..c61d82635d513 100644
--- a/clang/test/CodeGenHLSL/BoolMatrix.hlsl
+++ b/clang/test/CodeGenHLSL/BoolMatrix.hlsl
@@ -58,12 +58,9 @@ bool2x2 fn2(bool V) {
 // CHECK-NEXT:  [[ENTRY:.*:]]
 // CHECK-NEXT:    [[RETVAL:%.*]] = alloca i1, align 4
 // CHECK-NEXT:    [[S:%.*]] = alloca [[STRUCT_S:%.*]], align 1
+// CHECK-NEXT:    call void @llvm.memcpy.p0.p0.i32(ptr align 1 [[S]], ptr 
align 1 @__const._Z3fn3v.s, i32 20, i1 false)
 // CHECK-NEXT:    [[BM:%.*]] = getelementptr inbounds nuw [[STRUCT_S]], ptr 
[[S]], i32 0, i32 0
-// CHECK-NEXT:    store <4 x i32> <i32 1, i32 0, i32 1, i32 0>, ptr [[BM]], 
align 1
-// CHECK-NEXT:    [[F:%.*]] = getelementptr inbounds nuw [[STRUCT_S]], ptr 
[[S]], i32 0, i32 1
-// CHECK-NEXT:    store float 1.000000e+00, ptr [[F]], align 1
-// CHECK-NEXT:    [[BM1:%.*]] = getelementptr inbounds nuw [[STRUCT_S]], ptr 
[[S]], i32 0, i32 0
-// CHECK-NEXT:    [[TMP0:%.*]] = load <4 x i32>, ptr [[BM1]], align 1
+// CHECK-NEXT:    [[TMP0:%.*]] = load <4 x i32>, ptr [[BM]], align 1
 // CHECK-NEXT:    [[MATRIXEXT:%.*]] = extractelement <4 x i32> [[TMP0]], i32 0
 // CHECK-NEXT:    store i32 [[MATRIXEXT]], ptr [[RETVAL]], align 4
 // CHECK-NEXT:    [[TMP1:%.*]] = load i1, ptr [[RETVAL]], align 4
@@ -114,15 +111,12 @@ void fn5() {
 // CHECK-NEXT:    [[V:%.*]] = alloca i32, align 4
 // CHECK-NEXT:    [[S:%.*]] = alloca [[STRUCT_S:%.*]], align 1
 // CHECK-NEXT:    store i32 0, ptr [[V]], align 4
-// CHECK-NEXT:    [[BM:%.*]] = getelementptr inbounds nuw [[STRUCT_S]], ptr [[S]], i32 0, i32 0
-// CHECK-NEXT:    store <4 x i32> <i32 1, i32 0, i32 1, i32 0>, ptr [[BM]], align 1
-// CHECK-NEXT:    [[F:%.*]] = getelementptr inbounds nuw [[STRUCT_S]], ptr [[S]], i32 0, i32 1
-// CHECK-NEXT:    store float 1.000000e+00, ptr [[F]], align 1
+// CHECK-NEXT:    call void @llvm.memcpy.p0.p0.i32(ptr align 1 [[S]], ptr align 1 @__const._Z3fn6v.s, i32 20, i1 false)
 // CHECK-NEXT:    [[TMP0:%.*]] = load i32, ptr [[V]], align 4
 // CHECK-NEXT:    [[LOADEDV:%.*]] = trunc i32 [[TMP0]] to i1
-// CHECK-NEXT:    [[BM1:%.*]] = getelementptr inbounds nuw [[STRUCT_S]], ptr [[S]], i32 0, i32 0
+// CHECK-NEXT:    [[BM:%.*]] = getelementptr inbounds nuw [[STRUCT_S]], ptr [[S]], i32 0, i32 0
 // CHECK-NEXT:    [[TMP1:%.*]] = zext i1 [[LOADEDV]] to i32
-// CHECK-NEXT:    [[TMP2:%.*]] = getelementptr <4 x i32>, ptr [[BM1]], i32 0, i32 1
+// CHECK-NEXT:    [[TMP2:%.*]] = getelementptr <4 x i32>, ptr [[BM]], i32 0, i32 1
 // CHECK-NEXT:    store i32 [[TMP1]], ptr [[TMP2]], align 4
 // CHECK-NEXT:    ret void
 //

diff  --git a/clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixConstantExpr.hlsl b/clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixConstantExpr.hlsl
new file mode 100644
index 0000000000000..64220980d9edc
--- /dev/null
+++ b/clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixConstantExpr.hlsl
@@ -0,0 +1,127 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library -finclude-default-header -std=hlsl202x -fmatrix-memory-layout=column-major -verify %s
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library -finclude-default-header -std=hlsl202x -fmatrix-memory-layout=row-major -verify %s
+
+// expected-no-diagnostics
+
+// Matrix subscripting is not currently supported with matrix constexpr. So all
+// tests involve casting to another type to determine if the output is correct.
+
+export void fn() {
+
+  // Matrix truncation to int - should get element at (0,0)
+  {
+    constexpr int2x3 IM = {1, 2, 3,
+                           4, 5, 6};
+    _Static_assert((int)IM == 1, "Woo!");
+  }
+
+  // Scalar splat to matrix, then matrix cast to vector
+  {
+    constexpr bool2x2 BM2x2 = true;
+    constexpr bool4 BV4 = (bool4)BM2x2;
+    _Static_assert(BV4.x == true, "Woo!");
+    _Static_assert(BV4.y == true, "Woo!");
+    _Static_assert(BV4.z == true, "Woo!");
+    _Static_assert(BV4.w == true, "Woo!");
+  }
+
+  // Matrix cast to vector
+  {
+    constexpr float2x2 FM2x2 = {1.5, 2.5, 3.5, 4.5};
+    constexpr float4 FV4 = (float4)FM2x2;
+    _Static_assert(FV4.x == 1.5, "Woo!");
+    _Static_assert(FV4.y == 2.5, "Woo!");
+    _Static_assert(FV4.z == 3.5, "Woo!");
+    _Static_assert(FV4.w == 4.5, "Woo!");
+  }
+
+  // Matrix cast to array
+  {
+    constexpr float2x2 FM2x2 = {1.5, 2.5, 3.5, 4.5};
+    constexpr float FA4[4] = (float[4])FM2x2;
+    _Static_assert(FA4[0] == 1.5, "Woo!");
+    _Static_assert(FA4[1] == 2.5, "Woo!");
+    _Static_assert(FA4[2] == 3.5, "Woo!");
+    _Static_assert(FA4[3] == 4.5, "Woo!");
+  }
+
+  // Array cast to matrix to vector
+  {
+    constexpr int IA4[4] = {1, 2, 3, 4};
+    constexpr int2x2 IM2x2 = (int2x2)IA4;
+    constexpr int4 IV4 = (int4)IM2x2;
+    _Static_assert(IV4.x == 1, "Woo!");
+    _Static_assert(IV4.y == 2, "Woo!");
+    _Static_assert(IV4.z == 3, "Woo!");
+    _Static_assert(IV4.w == 4, "Woo!");
+  }
+
+  // Vector cast to matrix to vector
+  {
+    constexpr bool4 BV4_0 = {true, false, true, false};
+    constexpr bool2x2 BM2x2 = (bool2x2)BV4_0;
+    constexpr bool4 BV4 = (bool4)BM2x2;
+    _Static_assert(BV4.x == true, "Woo!");
+    _Static_assert(BV4.y == false, "Woo!");
+    _Static_assert(BV4.z == true, "Woo!");
+    _Static_assert(BV4.w == false, "Woo!");
+  }
+
+  // Matrix truncation to vector
+  {
+    constexpr int3x2 IM3x2 = { 1,  2,
+                               3,  4,
+                               5,  6};
+    constexpr int4 IV4 = (int4)IM3x2;
+    _Static_assert(IV4.x == 1, "Woo!");
+    _Static_assert(IV4.y == 2, "Woo!");
+    _Static_assert(IV4.z == 3, "Woo!");
+    _Static_assert(IV4.w == 4, "Woo!");
+  }
+
+  // Matrix truncation to array
+  {
+    constexpr int3x2 IM3x2 = { 1,  2,
+                               3,  4,
+                               5,  6};
+    constexpr int IA4[4] = (int[4])IM3x2;
+    _Static_assert(IA4[0] == 1, "Woo!");
+    _Static_assert(IA4[1] == 2, "Woo!");
+    _Static_assert(IA4[2] == 3, "Woo!");
+    _Static_assert(IA4[3] == 4, "Woo!");
+  }
+
+  // Array cast to matrix truncation to vector
+  {
+    constexpr float FA16[16] = { 1.0,  2.0,  3.0,  4.0,
+                                 5.0,  6.0,  7.0,  8.0,
+                                 9.0, 10.0, 11.0, 12.0,
+                                13.0, 14.0, 15.0, 16.0};
+    constexpr float4x4 FM4x4 = (float4x4)FA16;
+    constexpr float4 FV4 = (float4)FM4x4;
+    _Static_assert(FV4.x == 1.0, "Woo!");
+    _Static_assert(FV4.y == 2.0, "Woo!");
+    _Static_assert(FV4.z == 3.0, "Woo!");
+    _Static_assert(FV4.w == 4.0, "Woo!");
+  }
+
+  // Vector cast to matrix truncation to vector
+  {
+    constexpr bool4 BV4 = {true, false, true, false};
+    constexpr bool2x2 BM2x2 = (bool2x2)BV4;
+    constexpr bool3 BV3 = (bool3)BM2x2; // result under test: check BV3, not the input BV4
+    _Static_assert(BV3.x == true, "Woo!");
+    _Static_assert(BV3.y == false, "Woo!");
+    _Static_assert(BV3.z == true, "Woo!");
+  }
+
+  // Matrix cast to vector with type conversion
+  {
+    constexpr float2x2 FM2x2 = {1.5, 2.5, 3.5, 4.5};
+    constexpr int4 IV4 = (int4)FM2x2;
+    _Static_assert(IV4.x == 1, "Woo!");
+    _Static_assert(IV4.y == 2, "Woo!");
+    _Static_assert(IV4.z == 3, "Woo!");
+    _Static_assert(IV4.w == 4, "Woo!");
+  }
+}


        
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to