Author: Deric C.
Date: 2026-03-05T09:45:39-08:00
New Revision: f8f0f93236b2aee6b63a454c05cc2876568f8972
URL: https://github.com/llvm/llvm-project/commit/f8f0f93236b2aee6b63a454c05cc2876568f8972
DIFF: https://github.com/llvm/llvm-project/commit/f8f0f93236b2aee6b63a454c05cc2876568f8972.diff

LOG: [HLSL][Matrix] Add APValue and ConstExpr evaluator support for matrices (#178762)

Fixes #168935

This PR adds basic support for matrix APValues and a ConstExpr evaluator for matrices.

- ConstExpr evaluation changes:
  - Matrix initializer list
  - Matrix HLSL elementwise cast
  - Matrix HLSL aggregate splat
  - Vector HLSL matrix truncation
  - Int HLSL matrix truncation
  - Float HLSL matrix truncation
- Matrix APValue:
  - AST dumper and serialization
  - Value flattening

Note that matrix APValues hold their elements in row-major order irrespective of the `-fmatrix-memory-layout` flag. The `-fmatrix-memory-layout` flag affects codegen, not semantics, so the memory layout used for the matrix APValue can be chosen independently of the memory layout used for codegen.

There are also a number of places that expect switch-case coverage over all APValue kinds but do not currently support matrix APValues. I have added placeholder `llvm_unreachable`s to these places for the matrix APValue, as these places cannot currently be exercised in Clang tests (as far as I know).
Assisted-by: claude-opus-4.5 Added: clang/test/AST/HLSL/ast-dump-APValue-matrix.hlsl clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixConstantExpr.hlsl Modified: clang/include/clang/AST/APValue.h clang/include/clang/AST/PropertiesBase.td clang/include/clang/AST/TypeBase.h clang/lib/AST/APValue.cpp clang/lib/AST/ASTImporter.cpp clang/lib/AST/ExprConstant.cpp clang/lib/AST/ItaniumMangle.cpp clang/lib/AST/MicrosoftMangle.cpp clang/lib/AST/TextNodeDumper.cpp clang/lib/AST/Type.cpp clang/lib/CodeGen/CGExprConstant.cpp clang/lib/Sema/SemaTemplate.cpp clang/test/CodeGenHLSL/BoolMatrix.hlsl Removed: ################################################################################ diff --git a/clang/include/clang/AST/APValue.h b/clang/include/clang/AST/APValue.h index 8a2d6d434792a..3961e4e7fdfe0 100644 --- a/clang/include/clang/AST/APValue.h +++ b/clang/include/clang/AST/APValue.h @@ -136,6 +136,7 @@ class APValue { ComplexFloat, LValue, Vector, + Matrix, Array, Struct, Union, @@ -275,6 +276,15 @@ class APValue { Vec &operator=(const Vec &) = delete; ~Vec() { delete[] Elts; } }; + struct Mat { + APValue *Elts = nullptr; + unsigned NumRows = 0; + unsigned NumCols = 0; + Mat() = default; + Mat(const Mat &) = delete; + Mat &operator=(const Mat &) = delete; + ~Mat() { delete[] Elts; } + }; struct Arr { APValue *Elts; unsigned NumElts, ArrSize; @@ -308,8 +318,9 @@ class APValue { // We ensure elsewhere that Data is big enough for LV and MemberPointerData. typedef llvm::AlignedCharArrayUnion<void *, APSInt, APFloat, ComplexAPSInt, - ComplexAPFloat, Vec, Arr, StructData, - UnionData, AddrLabelDiffData> DataType; + ComplexAPFloat, Vec, Mat, Arr, StructData, + UnionData, AddrLabelDiffData> + DataType; static const size_t DataSize = sizeof(DataType); DataType Data; @@ -341,6 +352,13 @@ class APValue { : Kind(None), AllowConstexprUnknown(false) { MakeVector(); setVector(E, N); } + /// Creates a matrix APValue with given dimensions. 
The elements + /// are read from \p E and assumed to be in row-major order. + explicit APValue(const APValue *E, unsigned NumRows, unsigned NumCols) + : Kind(None), AllowConstexprUnknown(false) { + MakeMatrix(); + setMatrix(E, NumRows, NumCols); + } /// Creates an integer complex APValue with the given real and imaginary /// values. APValue(APSInt R, APSInt I) : Kind(None), AllowConstexprUnknown(false) { @@ -471,6 +489,7 @@ class APValue { bool isComplexFloat() const { return Kind == ComplexFloat; } bool isLValue() const { return Kind == LValue; } bool isVector() const { return Kind == Vector; } + bool isMatrix() const { return Kind == Matrix; } bool isArray() const { return Kind == Array; } bool isStruct() const { return Kind == Struct; } bool isUnion() const { return Kind == Union; } @@ -573,6 +592,37 @@ class APValue { return ((const Vec *)(const void *)&Data)->NumElts; } + unsigned getMatrixNumRows() const { + assert(isMatrix() && "Invalid accessor"); + return ((const Mat *)(const void *)&Data)->NumRows; + } + unsigned getMatrixNumColumns() const { + assert(isMatrix() && "Invalid accessor"); + return ((const Mat *)(const void *)&Data)->NumCols; + } + unsigned getMatrixNumElements() const { + return getMatrixNumRows() * getMatrixNumColumns(); + } + APValue &getMatrixElt(unsigned Idx) { + assert(isMatrix() && "Invalid accessor"); + assert(Idx < getMatrixNumElements() && "Index out of range"); + return ((Mat *)(char *)&Data)->Elts[Idx]; + } + const APValue &getMatrixElt(unsigned Idx) const { + return const_cast<APValue *>(this)->getMatrixElt(Idx); + } + APValue &getMatrixElt(unsigned Row, unsigned Col) { + assert(isMatrix() && "Invalid accessor"); + assert(Row < getMatrixNumRows() && "Row index out of range"); + assert(Col < getMatrixNumColumns() && "Column index out of range"); + // Matrix elements are stored in row-major order. 
+ unsigned I = Row * getMatrixNumColumns() + Col; + return getMatrixElt(I); + } + const APValue &getMatrixElt(unsigned Row, unsigned Col) const { + return const_cast<APValue *>(this)->getMatrixElt(Row, Col); + } + APValue &getArrayInitializedElt(unsigned I) { assert(isArray() && "Invalid accessor"); assert(I < getArrayInitializedElts() && "Index out of range"); @@ -668,6 +718,11 @@ class APValue { for (unsigned i = 0; i != N; ++i) InternalElts[i] = E[i]; } + void setMatrix(const APValue *E, unsigned NumRows, unsigned NumCols) { + MutableArrayRef<APValue> InternalElts = setMatrixUninit(NumRows, NumCols); + for (unsigned i = 0; i != NumRows * NumCols; ++i) + InternalElts[i] = E[i]; + } void setComplexInt(APSInt R, APSInt I) { assert(R.getBitWidth() == I.getBitWidth() && "Invalid complex int (type mismatch)."); @@ -716,6 +771,11 @@ class APValue { new ((void *)(char *)&Data) Vec(); Kind = Vector; } + void MakeMatrix() { + assert(isAbsent() && "Bad state change"); + new ((void *)(char *)&Data) Mat(); + Kind = Matrix; + } void MakeComplexInt() { assert(isAbsent() && "Bad state change"); new ((void *)(char *)&Data) ComplexAPSInt(); @@ -757,6 +817,15 @@ class APValue { V->NumElts = N; return {V->Elts, V->NumElts}; } + MutableArrayRef<APValue> setMatrixUninit(unsigned NumRows, unsigned NumCols) { + assert(isMatrix() && "Invalid accessor"); + Mat *M = ((Mat *)(char *)&Data); + unsigned NumElts = NumRows * NumCols; + M->Elts = new APValue[NumElts]; + M->NumRows = NumRows; + M->NumCols = NumCols; + return {M->Elts, NumElts}; + } MutableArrayRef<LValuePathEntry> setLValueUninit(LValueBase B, const CharUnits &O, unsigned Size, bool OnePastTheEnd, bool IsNullPtr); diff --git a/clang/include/clang/AST/PropertiesBase.td b/clang/include/clang/AST/PropertiesBase.td index 4581e55c28027..0011e57ed5ef7 100644 --- a/clang/include/clang/AST/PropertiesBase.td +++ b/clang/include/clang/AST/PropertiesBase.td @@ -353,6 +353,29 @@ let Class = PropertyTypeCase<APValue, "Vector"> in { return 
result; }]>; } +let Class = PropertyTypeCase<APValue, "Matrix"> in { + def : ReadHelper<[{ + SmallVector<APValue, 16> buffer; + unsigned numElts = node.getMatrixNumElements(); + for (unsigned i = 0; i < numElts; ++i) + buffer.push_back(node.getMatrixElt(i)); + }]>; + def : Property<"numRows", UInt32> { + let Read = [{ node.getMatrixNumRows() }]; + } + def : Property<"numCols", UInt32> { + let Read = [{ node.getMatrixNumColumns() }]; + } + def : Property<"elements", Array<APValue>> { let Read = [{ buffer }]; } + def : Creator<[{ + APValue result; + result.MakeMatrix(); + (void)result.setMatrixUninit(numRows, numCols); + for (unsigned i = 0; i < elements.size(); i++) + result.getMatrixElt(i) = elements[i]; + return result; + }]>; +} let Class = PropertyTypeCase<APValue, "Array"> in { def : ReadHelper<[{ SmallVector<APValue, 4> buffer{}; diff --git a/clang/include/clang/AST/TypeBase.h b/clang/include/clang/AST/TypeBase.h index dc4442bfeb795..ec7845c3b3adb 100644 --- a/clang/include/clang/AST/TypeBase.h +++ b/clang/include/clang/AST/TypeBase.h @@ -4435,7 +4435,7 @@ class ConstantMatrixType final : public MatrixType { /// row-major order flattened index. Otherwise, returns the column-major order /// flattened index. unsigned getFlattenedIndex(unsigned Row, unsigned Column, - bool IsRowMajor = false) { + bool IsRowMajor = false) const { return IsRowMajor ? 
getRowMajorFlattenedIndex(Row, Column) : getColumnMajorFlattenedIndex(Row, Column); } diff --git a/clang/lib/AST/APValue.cpp b/clang/lib/AST/APValue.cpp index 2e1c8eb3726cf..95b6f7f745ccb 100644 --- a/clang/lib/AST/APValue.cpp +++ b/clang/lib/AST/APValue.cpp @@ -333,6 +333,11 @@ APValue::APValue(const APValue &RHS) setVector(((const Vec *)(const char *)&RHS.Data)->Elts, RHS.getVectorLength()); break; + case Matrix: + MakeMatrix(); + setMatrix(((const Mat *)(const char *)&RHS.Data)->Elts, + RHS.getMatrixNumRows(), RHS.getMatrixNumColumns()); + break; case ComplexInt: MakeComplexInt(); setComplexInt(RHS.getComplexIntReal(), RHS.getComplexIntImag()); @@ -414,6 +419,8 @@ void APValue::DestroyDataAndMakeUninit() { ((APFixedPoint *)(char *)&Data)->~APFixedPoint(); else if (Kind == Vector) ((Vec *)(char *)&Data)->~Vec(); + else if (Kind == Matrix) + ((Mat *)(char *)&Data)->~Mat(); else if (Kind == ComplexInt) ((ComplexAPSInt *)(char *)&Data)->~ComplexAPSInt(); else if (Kind == ComplexFloat) @@ -444,6 +451,7 @@ bool APValue::needsCleanup() const { case Union: case Array: case Vector: + case Matrix: return true; case Int: return getInt().needsCleanup(); @@ -580,6 +588,12 @@ void APValue::Profile(llvm::FoldingSetNodeID &ID) const { getVectorElt(I).Profile(ID); return; + case Matrix: + for (unsigned R = 0, N = getMatrixNumRows(); R != N; ++R) + for (unsigned C = 0, M = getMatrixNumColumns(); C != M; ++C) + getMatrixElt(R, C).Profile(ID); + return; + case Int: profileIntValue(ID, getInt()); return; @@ -747,6 +761,24 @@ void APValue::printPretty(raw_ostream &Out, const PrintingPolicy &Policy, Out << '}'; return; } + case APValue::Matrix: { + const auto *MT = Ty->castAs<ConstantMatrixType>(); + QualType ElemTy = MT->getElementType(); + Out << '{'; + for (unsigned R = 0; R < getMatrixNumRows(); ++R) { + if (R != 0) + Out << ", "; + Out << '{'; + for (unsigned C = 0; C < getMatrixNumColumns(); ++C) { + if (C != 0) + Out << ", "; + getMatrixElt(R, C).printPretty(Out, Policy, 
ElemTy, Ctx); + } + Out << '}'; + } + Out << '}'; + return; + } case APValue::ComplexInt: Out << getComplexIntReal() << "+" << getComplexIntImag() << "i"; return; @@ -1139,6 +1171,7 @@ LinkageInfo LinkageComputer::getLVForValue(const APValue &V, case APValue::ComplexInt: case APValue::ComplexFloat: case APValue::Vector: + case APValue::Matrix: break; case APValue::AddrLabelDiff: diff --git a/clang/lib/AST/ASTImporter.cpp b/clang/lib/AST/ASTImporter.cpp index 00af38626a8f7..340a0fb4fb5aa 100644 --- a/clang/lib/AST/ASTImporter.cpp +++ b/clang/lib/AST/ASTImporter.cpp @@ -10630,6 +10630,9 @@ ASTNodeImporter::ImportAPValue(const APValue &FromValue) { Elts.data(), FromValue.getVectorLength()); break; } + case APValue::Matrix: + // Matrix values cannot currently arise in APValue import contexts. + llvm_unreachable("Matrix APValue import not yet supported"); case APValue::Array: Result.MakeArray(FromValue.getArrayInitializedElts(), FromValue.getArraySize()); diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp index a131cf8d158df..f233ff76632ae 100644 --- a/clang/lib/AST/ExprConstant.cpp +++ b/clang/lib/AST/ExprConstant.cpp @@ -1772,6 +1772,7 @@ static bool EvaluateIntegerOrLValue(const Expr *E, APValue &Result, EvalInfo &Info); static bool EvaluateFloat(const Expr *E, APFloat &Result, EvalInfo &Info); static bool EvaluateComplex(const Expr *E, ComplexValue &Res, EvalInfo &Info); +static bool EvaluateMatrix(const Expr *E, APValue &Result, EvalInfo &Info); static bool EvaluateAtomic(const Expr *E, const LValue *This, APValue &Result, EvalInfo &Info); static bool EvaluateAsRValue(EvalInfo &Info, const Expr *E, APValue &Result); @@ -2602,6 +2603,7 @@ static bool HandleConversionToBool(const APValue &Val, bool &Result) { Result = Val.getMemberPointerDecl(); return true; case APValue::Vector: + case APValue::Matrix: case APValue::Array: case APValue::Struct: case APValue::Union: @@ -3932,6 +3934,12 @@ static unsigned elementwiseSize(EvalInfo &Info, 
QualType BaseTy) { Size += NumEl; continue; } + if (Type->isConstantMatrixType()) { + unsigned NumEl = + Type->castAs<ConstantMatrixType>()->getNumElementsFlattened(); + Size += NumEl; + continue; + } if (Type->isConstantArrayType()) { QualType ElTy = cast<ConstantArrayType>(Info.Ctx.getAsArrayType(Type)) ->getElementType(); @@ -3982,6 +3990,11 @@ static bool hlslAggSplatHelper(EvalInfo &Info, const Expr *E, APValue &SrcVal, SrcTy = SrcTy->castAs<VectorType>()->getElementType(); SrcVal = SrcVal.getVectorElt(0); } + if (SrcVal.isMatrix()) { + assert(SrcTy->isConstantMatrixType() && "Type mismatch."); + SrcTy = SrcTy->castAs<ConstantMatrixType>()->getElementType(); + SrcVal = SrcVal.getMatrixElt(0, 0); + } return true; } @@ -4011,6 +4024,22 @@ static bool flattenAPValue(EvalInfo &Info, const Expr *E, APValue Value, } continue; } + if (Work.isMatrix()) { + assert(Type->isConstantMatrixType() && "Type mismatch."); + const auto *MT = Type->castAs<ConstantMatrixType>(); + QualType ElTy = MT->getElementType(); + // Matrix elements are flattened in row-major order. + for (unsigned Row = 0; Row < Work.getMatrixNumRows() && Populated < Size; + Row++) { + for (unsigned Col = 0; + Col < Work.getMatrixNumColumns() && Populated < Size; Col++) { + Elements.push_back(Work.getMatrixElt(Row, Col)); + Types.push_back(ElTy); + Populated++; + } + } + continue; + } if (Work.isArray()) { assert(Type->isConstantArrayType() && "Type mismatch."); QualType ElTy = cast<ConstantArrayType>(Info.Ctx.getAsArrayType(Type)) @@ -7742,6 +7771,7 @@ class APValueToBufferConverter { case APValue::FixedPoint: // FIXME: We should support these. + case APValue::Matrix: case APValue::Union: case APValue::MemberPointer: case APValue::AddrLabelDiff: { @@ -11667,8 +11697,17 @@ bool VectorExprEvaluator::VisitCastExpr(const CastExpr *E) { return Success(Elements, E); } case CK_HLSLMatrixTruncation: { - // TODO: See #168935. Add matrix truncation support to expr constant. 
- return Error(E); + // Matrix truncation occurs in row-major order. + APValue Val; + if (!EvaluateMatrix(SE, Val, Info)) + return Error(E); + SmallVector<APValue, 16> Elements; + for (unsigned Row = 0; + Row < Val.getMatrixNumRows() && Elements.size() < NElts; Row++) + for (unsigned Col = 0; + Col < Val.getMatrixNumColumns() && Elements.size() < NElts; Col++) + Elements.push_back(Val.getMatrixElt(Row, Col)); + return Success(Elements, E); } case CK_HLSLAggregateSplatCast: { APValue Val; @@ -14594,6 +14633,117 @@ bool VectorExprEvaluator::VisitShuffleVectorExpr(const ShuffleVectorExpr *E) { return Success(APValue(ResultElements.data(), ResultElements.size()), E); } +//===----------------------------------------------------------------------===// +// Matrix Evaluation +//===----------------------------------------------------------------------===// + +namespace { +class MatrixExprEvaluator : public ExprEvaluatorBase<MatrixExprEvaluator> { + APValue &Result; + +public: + MatrixExprEvaluator(EvalInfo &Info, APValue &Result) + : ExprEvaluatorBaseTy(Info), Result(Result) {} + + bool Success(ArrayRef<APValue> M, const Expr *E) { + auto *CMTy = E->getType()->castAs<ConstantMatrixType>(); + assert(M.size() == CMTy->getNumElementsFlattened()); + // FIXME: remove this APValue copy. 
+ Result = APValue(M.data(), CMTy->getNumRows(), CMTy->getNumColumns()); + return true; + } + bool Success(const APValue &M, const Expr *E) { + assert(M.isMatrix() && "expected matrix"); + Result = M; + return true; + } + + bool VisitCastExpr(const CastExpr *E); + bool VisitInitListExpr(const InitListExpr *E); +}; +} // end anonymous namespace + +static bool EvaluateMatrix(const Expr *E, APValue &Result, EvalInfo &Info) { + assert(E->isPRValue() && E->getType()->isConstantMatrixType() && + "not a matrix prvalue"); + return MatrixExprEvaluator(Info, Result).Visit(E); +} + +bool MatrixExprEvaluator::VisitCastExpr(const CastExpr *E) { + const auto *MT = E->getType()->castAs<ConstantMatrixType>(); + unsigned NumRows = MT->getNumRows(); + unsigned NumCols = MT->getNumColumns(); + unsigned NElts = NumRows * NumCols; + QualType EltTy = MT->getElementType(); + const Expr *SE = E->getSubExpr(); + + switch (E->getCastKind()) { + case CK_HLSLAggregateSplatCast: { + APValue Val; + QualType ValTy; + + if (!hlslAggSplatHelper(Info, SE, Val, ValTy)) + return false; + + APValue CastedVal; + const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts()); + if (!handleScalarCast(Info, FPO, E, ValTy, EltTy, Val, CastedVal)) + return false; + + SmallVector<APValue, 16> SplatEls(NElts, CastedVal); + return Success(SplatEls, E); + } + case CK_HLSLElementwiseCast: { + SmallVector<APValue> SrcVals; + SmallVector<QualType> SrcTypes; + + if (!hlslElementwiseCastHelper(Info, SE, E->getType(), SrcVals, SrcTypes)) + return false; + + const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts()); + SmallVector<QualType, 16> DestTypes(NElts, EltTy); + SmallVector<APValue, 16> ResultEls(NElts); + if (!handleElementwiseCast(Info, E, FPO, SrcVals, SrcTypes, DestTypes, + ResultEls)) + return false; + return Success(ResultEls, E); + } + default: + return ExprEvaluatorBaseTy::VisitCastExpr(E); + } +} + +bool MatrixExprEvaluator::VisitInitListExpr(const InitListExpr *E) { + const auto 
*MT = E->getType()->castAs<ConstantMatrixType>(); + QualType EltTy = MT->getElementType(); + + assert(E->getNumInits() == MT->getNumElementsFlattened() && + "Expected number of elements in initializer list to match the number " + "of matrix elements"); + + SmallVector<APValue, 16> Elements; + Elements.reserve(MT->getNumElementsFlattened()); + + // The following loop assumes the elements of the matrix InitListExpr are in + // row-major order, which matches the row-major ordering assumption of the + // matrix APValue. + for (unsigned I = 0, N = MT->getNumElementsFlattened(); I < N; ++I) { + if (EltTy->isIntegerType()) { + llvm::APSInt IntVal; + if (!EvaluateInteger(E->getInit(I), IntVal, Info)) + return false; + Elements.push_back(APValue(IntVal)); + } else { + llvm::APFloat FloatVal(0.0); + if (!EvaluateFloat(E->getInit(I), FloatVal, Info)) + return false; + Elements.push_back(APValue(FloatVal)); + } + } + + return Success(Elements, E); +} + //===----------------------------------------------------------------------===// // Array Evaluation //===----------------------------------------------------------------------===// @@ -18932,8 +19082,10 @@ bool IntExprEvaluator::VisitCastExpr(const CastExpr *E) { return Success(Val.getVectorElt(0), E); } case CK_HLSLMatrixTruncation: { - // TODO: See #168935. Add matrix truncation support to expr constant. - return Error(E); + APValue Val; + if (!EvaluateMatrix(SubExpr, Val, Info)) + return Error(E); + return Success(Val.getMatrixElt(0, 0), E); } case CK_HLSLElementwiseCast: { SmallVector<APValue> SrcVals; @@ -19529,8 +19681,10 @@ bool FloatExprEvaluator::VisitCastExpr(const CastExpr *E) { return Success(Val.getVectorElt(0), E); } case CK_HLSLMatrixTruncation: { - // TODO: See #168935. Add matrix truncation support to expr constant. 
- return Error(E); + APValue Val; + if (!EvaluateMatrix(SubExpr, Val, Info)) + return Error(E); + return Success(Val.getMatrixElt(0, 0), E); } case CK_HLSLElementwiseCast: { SmallVector<APValue> SrcVals; @@ -20429,6 +20583,9 @@ static bool Evaluate(APValue &Result, EvalInfo &Info, const Expr *E) { } else if (T->isVectorType()) { if (!EvaluateVector(E, Result, Info)) return false; + } else if (T->isConstantMatrixType()) { + if (!EvaluateMatrix(E, Result, Info)) + return false; } else if (T->isIntegralOrEnumerationType()) { if (!IntExprEvaluator(Info, Result).Visit(E)) return false; diff --git a/clang/lib/AST/ItaniumMangle.cpp b/clang/lib/AST/ItaniumMangle.cpp index 1faf7f1466e39..eea04b14eaf09 100644 --- a/clang/lib/AST/ItaniumMangle.cpp +++ b/clang/lib/AST/ItaniumMangle.cpp @@ -6495,6 +6495,9 @@ static bool isZeroInitialized(QualType T, const APValue &V) { return true; } + case APValue::Matrix: + llvm_unreachable("Matrix APValues not yet supported"); + case APValue::Int: return !V.getInt(); @@ -6708,6 +6711,9 @@ void CXXNameMangler::mangleValueInTemplateArg(QualType T, const APValue &V, break; } + case APValue::Matrix: + llvm_unreachable("Matrix template argument mangling not yet supported"); + case APValue::Int: mangleIntegerLiteral(T, V.getInt()); break; diff --git a/clang/lib/AST/MicrosoftMangle.cpp b/clang/lib/AST/MicrosoftMangle.cpp index 1f28d281be9fe..1bf92d4209f9f 100644 --- a/clang/lib/AST/MicrosoftMangle.cpp +++ b/clang/lib/AST/MicrosoftMangle.cpp @@ -2154,6 +2154,11 @@ void MicrosoftCXXNameMangler::mangleTemplateArgValue(QualType T, return; } + case APValue::Matrix: { + Error("template argument (value type: matrix)"); + return; + } + case APValue::AddrLabelDiff: { Error("template argument (value type: address label diff )"); return; diff --git a/clang/lib/AST/TextNodeDumper.cpp b/clang/lib/AST/TextNodeDumper.cpp index be8b03e1c2b72..89d2a2691e8fb 100644 --- a/clang/lib/AST/TextNodeDumper.cpp +++ b/clang/lib/AST/TextNodeDumper.cpp @@ -620,6 +620,7 @@ 
static bool isSimpleAPValue(const APValue &Value) { case APValue::Vector: case APValue::Array: case APValue::Struct: + case APValue::Matrix: return false; case APValue::Union: return isSimpleAPValue(Value.getUnionValue()); @@ -812,6 +813,19 @@ void TextNodeDumper::Visit(const APValue &Value, QualType Ty) { return; } + case APValue::Matrix: { + unsigned NumRows = Value.getMatrixNumRows(); + unsigned NumCols = Value.getMatrixNumColumns(); + OS << "Matrix " << NumRows << "x" << NumCols; + + dumpAPValueChildren( + Value, Ty, + [](const APValue &Value, unsigned Index) -> const APValue & { + return Value.getMatrixElt(Index); + }, + Value.getMatrixNumElements(), "element", "elements"); + return; + } case APValue::Union: { OS << "Union"; { diff --git a/clang/lib/AST/Type.cpp b/clang/lib/AST/Type.cpp index a85f08753a132..290b91effeacd 100644 --- a/clang/lib/AST/Type.cpp +++ b/clang/lib/AST/Type.cpp @@ -3084,6 +3084,10 @@ bool Type::isLiteralType(const ASTContext &Ctx) const { if (BaseTy->isScalarType() || BaseTy->isVectorType() || BaseTy->isAnyComplexType()) return true; + // Matrices with constant numbers of rows and columns are also literal types + // in HLSL. 
+ if (Ctx.getLangOpts().HLSL && BaseTy->isConstantMatrixType()) + return true; // -- a reference type; or if (BaseTy->isReferenceType()) return true; diff --git a/clang/lib/CodeGen/CGExprConstant.cpp b/clang/lib/CodeGen/CGExprConstant.cpp index 3f44243e1c35c..0739935acd867 100644 --- a/clang/lib/CodeGen/CGExprConstant.cpp +++ b/clang/lib/CodeGen/CGExprConstant.cpp @@ -2548,6 +2548,35 @@ ConstantEmitter::tryEmitPrivate(const APValue &Value, QualType DestType, } return llvm::ConstantVector::get(Inits); } + case APValue::Matrix: { + const auto *MT = DestType->castAs<ConstantMatrixType>(); + unsigned NumRows = Value.getMatrixNumRows(); + unsigned NumCols = Value.getMatrixNumColumns(); + unsigned NumElts = NumRows * NumCols; + SmallVector<llvm::Constant *, 16> Inits(NumElts); + + bool IsRowMajor = CGM.getLangOpts().getDefaultMatrixMemoryLayout() == + LangOptions::MatrixMemoryLayout::MatrixRowMajor; + + for (unsigned Row = 0; Row != NumRows; ++Row) { + for (unsigned Col = 0; Col != NumCols; ++Col) { + const APValue &Elt = Value.getMatrixElt(Row, Col); + unsigned Idx = MT->getFlattenedIndex(Row, Col, IsRowMajor); + if (Elt.isInt()) + Inits[Idx] = + llvm::ConstantInt::get(CGM.getLLVMContext(), Elt.getInt()); + else if (Elt.isFloat()) + Inits[Idx] = + llvm::ConstantFP::get(CGM.getLLVMContext(), Elt.getFloat()); + else if (Elt.isIndeterminate()) + Inits[Idx] = llvm::PoisonValue::get( + CGM.getTypes().ConvertType(MT->getElementType())); + else + llvm_unreachable("unsupported matrix element type"); + } + } + return llvm::ConstantVector::get(Inits); + } case APValue::AddrLabelDiff: { const AddrLabelExpr *LHSExpr = Value.getAddrLabelDiffLHS(); const AddrLabelExpr *RHSExpr = Value.getAddrLabelDiffRHS(); diff --git a/clang/lib/Sema/SemaTemplate.cpp b/clang/lib/Sema/SemaTemplate.cpp index 9f3bf7437fc54..9b0bec20618a0 100644 --- a/clang/lib/Sema/SemaTemplate.cpp +++ b/clang/lib/Sema/SemaTemplate.cpp @@ -8138,6 +8138,9 @@ static Expr *BuildExpressionFromNonTypeTemplateArgumentValue( 
return MakeInitList(Elts); } + case APValue::Matrix: + llvm_unreachable("Matrix template argument expression not yet supported"); + case APValue::None: case APValue::Indeterminate: llvm_unreachable("Unexpected APValue kind."); diff --git a/clang/test/AST/HLSL/ast-dump-APValue-matrix.hlsl b/clang/test/AST/HLSL/ast-dump-APValue-matrix.hlsl new file mode 100644 index 0000000000000..98428225ebb3d --- /dev/null +++ b/clang/test/AST/HLSL/ast-dump-APValue-matrix.hlsl @@ -0,0 +1,56 @@ +// Test without serialization: +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library -finclude-default-header -std=hlsl202x \ +// RUN: -fnative-half-type -fnative-int16-type \ +// RUN: -ast-dump %s -ast-dump-filter Test \ +// RUN: | FileCheck --strict-whitespace %s +// +// Test with serialization: +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library -finclude-default-header -std=hlsl202x \ +// RUN: -fnative-half-type -fnative-int16-type -emit-pch -o %t %s +// RUN: %clang_cc1 -x hlsl -triple dxil-pc-shadermodel6.6-library -finclude-default-header -std=hlsl202x \ +// RUN: -fnative-half-type -fnative-int16-type \ +// RUN: -include-pch %t -ast-dump-all -ast-dump-filter Test /dev/null \ +// RUN: | sed -e "s/ <undeserialized declarations>//" -e "s/ imported//" \ +// RUN: | FileCheck --strict-whitespace %s + +export void Test() { + constexpr int2x2 mat2x2 = {1, 2, 3, 4}; + // CHECK: VarDecl {{.*}} mat2x2 {{.*}} constexpr cinit + // CHECK-NEXT: |-value: Matrix 2x2 + // CHECK-NEXT: | `-elements: Int 1, Int 2, Int 3, Int 4 + + constexpr float3x2 mat3x2 = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + // CHECK: VarDecl {{.*}} mat3x2 {{.*}} constexpr cinit + // CHECK-NEXT: |-value: Matrix 3x2 + // CHECK-NEXT: | |-elements: Float 1.000000e+00, Float 2.000000e+00, Float 3.000000e+00, Float 4.000000e+00 + // CHECK-NEXT: | `-elements: Float 5.000000e+00, Float 6.000000e+00 + + constexpr int16_t3x2 i16mat3x2 = {-1, 2, -3, 4, -5, 6}; + // CHECK: VarDecl {{.*}} i16mat3x2 {{.*}} constexpr cinit + // 
CHECK-NEXT: |-value: Matrix 3x2 + // CHECK-NEXT: | |-elements: Int -1, Int 2, Int -3, Int 4 + // CHECK-NEXT: | `-elements: Int -5, Int 6 + + constexpr int64_t4x1 i64mat4x1 = {100, -200, 300, -400}; + // CHECK: VarDecl {{.*}} i64mat4x1 {{.*}} constexpr cinit + // CHECK-NEXT: |-value: Matrix 4x1 + // CHECK-NEXT: | `-elements: Int 100, Int -200, Int 300, Int -400 + + constexpr half2x3 hmat2x3 = {1.5h, -2.5h, 3.5h, -4.5h, 5.5h, -6.5h}; + // CHECK: VarDecl {{.*}} hmat2x3 {{.*}} constexpr cinit + // CHECK-NEXT: |-value: Matrix 2x3 + // CHECK-NEXT: | |-elements: Float 1.500000e+00, Float -2.500000e+00, Float 3.500000e+00, Float -4.500000e+00 + // CHECK-NEXT: | `-elements: Float 5.500000e+00, Float -6.500000e+00 + + constexpr double1x4 dmat1x4 = {0.5l, -1.25l, 2.75l, -3.125l}; + // CHECK: VarDecl {{.*}} dmat1x4 {{.*}} constexpr cinit + // CHECK-NEXT: |-value: Matrix 1x4 + // CHECK-NEXT: | `-elements: Float 5.000000e-01, Float -1.250000e+00, Float 2.750000e+00, Float -3.125000e+00 + + constexpr bool3x3 bmat3x3 = {true, false, true, false, true, false, true, false, true}; + // CHECK: VarDecl {{.*}} bmat3x3 {{.*}} constexpr cinit + // CHECK-NEXT: |-value: Matrix 3x3 + // CHECK-NEXT: | |-elements: Int 1, Int 0, Int 1, Int 0 + // CHECK-NEXT: | |-elements: Int 1, Int 0, Int 1, Int 0 + // CHECK-NEXT: | `-element: Int 1 +} diff --git a/clang/test/CodeGenHLSL/BoolMatrix.hlsl b/clang/test/CodeGenHLSL/BoolMatrix.hlsl index c101ac02f7891..c61d82635d513 100644 --- a/clang/test/CodeGenHLSL/BoolMatrix.hlsl +++ b/clang/test/CodeGenHLSL/BoolMatrix.hlsl @@ -58,12 +58,9 @@ bool2x2 fn2(bool V) { // CHECK-NEXT: [[ENTRY:.*:]] // CHECK-NEXT: [[RETVAL:%.*]] = alloca i1, align 4 // CHECK-NEXT: [[S:%.*]] = alloca [[STRUCT_S:%.*]], align 1 +// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 1 [[S]], ptr align 1 @__const._Z3fn3v.s, i32 20, i1 false) // CHECK-NEXT: [[BM:%.*]] = getelementptr inbounds nuw [[STRUCT_S]], ptr [[S]], i32 0, i32 0 -// CHECK-NEXT: store <4 x i32> <i32 1, i32 0, i32 1, 
i32 0>, ptr [[BM]], align 1 -// CHECK-NEXT: [[F:%.*]] = getelementptr inbounds nuw [[STRUCT_S]], ptr [[S]], i32 0, i32 1 -// CHECK-NEXT: store float 1.000000e+00, ptr [[F]], align 1 -// CHECK-NEXT: [[BM1:%.*]] = getelementptr inbounds nuw [[STRUCT_S]], ptr [[S]], i32 0, i32 0 -// CHECK-NEXT: [[TMP0:%.*]] = load <4 x i32>, ptr [[BM1]], align 1 +// CHECK-NEXT: [[TMP0:%.*]] = load <4 x i32>, ptr [[BM]], align 1 // CHECK-NEXT: [[MATRIXEXT:%.*]] = extractelement <4 x i32> [[TMP0]], i32 0 // CHECK-NEXT: store i32 [[MATRIXEXT]], ptr [[RETVAL]], align 4 // CHECK-NEXT: [[TMP1:%.*]] = load i1, ptr [[RETVAL]], align 4 @@ -114,15 +111,12 @@ void fn5() { // CHECK-NEXT: [[V:%.*]] = alloca i32, align 4 // CHECK-NEXT: [[S:%.*]] = alloca [[STRUCT_S:%.*]], align 1 // CHECK-NEXT: store i32 0, ptr [[V]], align 4 -// CHECK-NEXT: [[BM:%.*]] = getelementptr inbounds nuw [[STRUCT_S]], ptr [[S]], i32 0, i32 0 -// CHECK-NEXT: store <4 x i32> <i32 1, i32 0, i32 1, i32 0>, ptr [[BM]], align 1 -// CHECK-NEXT: [[F:%.*]] = getelementptr inbounds nuw [[STRUCT_S]], ptr [[S]], i32 0, i32 1 -// CHECK-NEXT: store float 1.000000e+00, ptr [[F]], align 1 +// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 1 [[S]], ptr align 1 @__const._Z3fn6v.s, i32 20, i1 false) // CHECK-NEXT: [[TMP0:%.*]] = load i32, ptr [[V]], align 4 // CHECK-NEXT: [[LOADEDV:%.*]] = trunc i32 [[TMP0]] to i1 -// CHECK-NEXT: [[BM1:%.*]] = getelementptr inbounds nuw [[STRUCT_S]], ptr [[S]], i32 0, i32 0 +// CHECK-NEXT: [[BM:%.*]] = getelementptr inbounds nuw [[STRUCT_S]], ptr [[S]], i32 0, i32 0 // CHECK-NEXT: [[TMP1:%.*]] = zext i1 [[LOADEDV]] to i32 -// CHECK-NEXT: [[TMP2:%.*]] = getelementptr <4 x i32>, ptr [[BM1]], i32 0, i32 1 +// CHECK-NEXT: [[TMP2:%.*]] = getelementptr <4 x i32>, ptr [[BM]], i32 0, i32 1 // CHECK-NEXT: store i32 [[TMP1]], ptr [[TMP2]], align 4 // CHECK-NEXT: ret void // diff --git a/clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixConstantExpr.hlsl 
b/clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixConstantExpr.hlsl new file mode 100644 index 0000000000000..64220980d9edc --- /dev/null +++ b/clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixConstantExpr.hlsl @@ -0,0 +1,127 @@ +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library -finclude-default-header -std=hlsl202x -fmatrix-memory-layout=column-major -verify %s +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library -finclude-default-header -std=hlsl202x -fmatrix-memory-layout=row-major -verify %s + +// expected-no-diagnostics + +// Matrix subscripting is not currently supported with matrix constexpr. So all +// tests involve casting to another type to determine if the output is correct. + +export void fn() { + + // Matrix truncation to int - should get element at (0,0) + { + constexpr int2x3 IM = {1, 2, 3, + 4, 5, 6}; + _Static_assert((int)IM == 1, "Woo!"); + } + + // Matrix splat to vector + { + constexpr bool2x2 BM2x2 = true; + constexpr bool4 BV4 = (bool4)BM2x2; + _Static_assert(BV4.x == true, "Woo!"); + _Static_assert(BV4.y == true, "Woo!"); + _Static_assert(BV4.z == true, "Woo!"); + _Static_assert(BV4.w == true, "Woo!"); + } + + // Matrix cast to vector + { + constexpr float2x2 FM2x2 = {1.5, 2.5, 3.5, 4.5}; + constexpr float4 FV4 = (float4)FM2x2; + _Static_assert(FV4.x == 1.5, "Woo!"); + _Static_assert(FV4.y == 2.5, "Woo!"); + _Static_assert(FV4.z == 3.5, "Woo!"); + _Static_assert(FV4.w == 4.5, "Woo!"); + } + + // Matrix cast to array + { + constexpr float2x2 FM2x2 = {1.5, 2.5, 3.5, 4.5}; + constexpr float FA4[4] = (float[4])FM2x2; + _Static_assert(FA4[0] == 1.5, "Woo!"); + _Static_assert(FA4[1] == 2.5, "Woo!"); + _Static_assert(FA4[2] == 3.5, "Woo!"); + _Static_assert(FA4[3] == 4.5, "Woo!"); + } + + // Array cast to matrix to vector + { + constexpr int IA4[4] = {1, 2, 3, 4}; + constexpr int2x2 IM2x2 = (int2x2)IA4; + constexpr int4 IV4 = (int4)IM2x2; + _Static_assert(IV4.x == 1, "Woo!"); + _Static_assert(IV4.y == 2, "Woo!"); + _Static_assert(IV4.z == 
3, "Woo!"); + _Static_assert(IV4.w == 4, "Woo!"); + } + + // Vector cast to matrix to vector + { + constexpr bool4 BV4_0 = {true, false, true, false}; + constexpr bool2x2 BM2x2 = (bool2x2)BV4_0; + constexpr bool4 BV4 = (bool4)BM2x2; + _Static_assert(BV4.x == true, "Woo!"); + _Static_assert(BV4.y == false, "Woo!"); + _Static_assert(BV4.z == true, "Woo!"); + _Static_assert(BV4.w == false, "Woo!"); + } + + // Matrix truncation to vector + { + constexpr int3x2 IM3x2 = { 1, 2, + 3, 4, + 5, 6}; + constexpr int4 IV4 = (int4)IM3x2; + _Static_assert(IV4.x == 1, "Woo!"); + _Static_assert(IV4.y == 2, "Woo!"); + _Static_assert(IV4.z == 3, "Woo!"); + _Static_assert(IV4.w == 4, "Woo!"); + } + + // Matrix truncation to array + { + constexpr int3x2 IM3x2 = { 1, 2, + 3, 4, + 5, 6}; + constexpr int IA4[4] = (int[4])IM3x2; + _Static_assert(IA4[0] == 1, "Woo!"); + _Static_assert(IA4[1] == 2, "Woo!"); + _Static_assert(IA4[2] == 3, "Woo!"); + _Static_assert(IA4[3] == 4, "Woo!"); + } + + // Array cast to matrix truncation to vector + { + constexpr float FA16[16] = { 1.0, 2.0, 3.0, 4.0, + 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, + 13.0, 14.0, 15.0, 16.0}; + constexpr float4x4 FM4x4 = (float4x4)FA16; + constexpr float4 FV4 = (float4)FM4x4; + _Static_assert(FV4.x == 1.0, "Woo!"); + _Static_assert(FV4.y == 2.0, "Woo!"); + _Static_assert(FV4.z == 3.0, "Woo!"); + _Static_assert(FV4.w == 4.0, "Woo!"); + } + + // Vector cast to matrix truncation to vector + { + constexpr bool4 BV4 = {true, false, true, false}; + constexpr bool2x2 BM2x2 = (bool2x2)BV4; + constexpr bool3 BV3 = (bool3)BM2x2; + _Static_assert(BV4.x == true, "Woo!"); + _Static_assert(BV4.y == false, "Woo!"); + _Static_assert(BV4.z == true, "Woo!"); + } + + // Matrix cast to vector with type conversion + { + constexpr float2x2 FM2x2 = {1.5, 2.5, 3.5, 4.5}; + constexpr int4 IV4 = (int4)FM2x2; + _Static_assert(IV4.x == 1, "Woo!"); + _Static_assert(IV4.y == 2, "Woo!"); + _Static_assert(IV4.z == 3, "Woo!"); + _Static_assert(IV4.w == 
4, "Woo!"); + } +} _______________________________________________ cfe-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
