https://github.com/Icohedron created https://github.com/llvm/llvm-project/pull/184840
Forgot to change the target branch before merging. This PR is a cherry-pick of the squashed-and-merged PR commit b16aa4b7ec665911c74300cd7442659b70973d13 from #183424. This PR fixes #182963. This PR is an extension of #178762, which has already been merged. This PR adds support for `ConstantMatrixType` and the HLSL casts `CK_HLSLArrayRValue`, `CK_HLSLMatrixTruncation`, `CK_HLSLAggregateSplatCast`, and `CK_HLSLElementwiseCast` to the bytecode constexpr evaluator. The implementations of CK_HLSLAggregateSplatCast and CK_HLSLElementwiseCast are incomplete, as they still need to support struct and array types to enable use of the experimental new constant interpreter on other existing HLSL constexpr tests. The completion of the implementations of these casts will be tracked in a separate issue (#183426) and implemented in a separate PR. Assisted-by: claude-opus-4.6 >From b6a48d2a7df7bc284df42df3464dc577cb9d0f09 Mon Sep 17 00:00:00 2001 From: "Deric C." <[email protected]> Date: Thu, 5 Mar 2026 09:45:48 -0800 Subject: [PATCH] [clang][bytecode][HLSL][Matrix] Support `ConstantMatrixType` and more HLSL casts in the new constant interpreter for basic matrix constexpr evaluation in HLSL (#183424) Fixes #182963 This PR is an extension of #178762 and is to be merged immediately after it. This PR adds support for `ConstantMatrixType` and the HLSL casts `CK_HLSLArrayRValue`, `CK_HLSLMatrixTruncation`, `CK_HLSLAggregateSplatCast`, and `CK_HLSLElementwiseCast` to the bytecode constexpr evaluator. The implementations of CK_HLSLAggregateSplatCast and CK_HLSLElementwiseCast are incomplete, as they still need to support struct and array types to enable use of the experimental new constant interpreter on other existing HLSL constexpr tests. The completion of the implementations of these casts will be tracked in a separate issue (#183426) and implemented in a separate PR. 
Assisted-by: claude-opus-4.6 --- clang/lib/AST/ByteCode/Compiler.cpp | 226 ++++++++++++++++++ clang/lib/AST/ByteCode/Compiler.h | 5 + clang/lib/AST/ByteCode/Pointer.cpp | 18 ++ clang/lib/AST/ByteCode/Program.cpp | 10 + .../BuiltinMatrix/MatrixConstantExpr.hlsl | 4 + 5 files changed, 263 insertions(+) diff --git a/clang/lib/AST/ByteCode/Compiler.cpp b/clang/lib/AST/ByteCode/Compiler.cpp index e33fa4f86c052..93ad8eb26f29e 100644 --- a/clang/lib/AST/ByteCode/Compiler.cpp +++ b/clang/lib/AST/ByteCode/Compiler.cpp @@ -811,6 +811,199 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *CE) { case CK_LValueBitCast: return this->emitInvalidCast(CastKind::ReinterpretLike, /*Fatal=*/true, CE); + case CK_HLSLArrayRValue: { + // Non-decaying array rvalue cast - creates an rvalue copy of an lvalue + // array, similar to LValueToRValue for composite types. + if (!Initializing) { + UnsignedOrNone LocalIndex = allocateLocal(CE); + if (!LocalIndex) + return false; + if (!this->emitGetPtrLocal(*LocalIndex, CE)) + return false; + } + if (!this->visit(SubExpr)) + return false; + return this->emitMemcpy(CE); + } + + case CK_HLSLMatrixTruncation: { + assert(SubExpr->getType()->isConstantMatrixType()); + if (OptPrimType ResultT = classify(CE)) { + assert(!DiscardResult); + // Result must be either a float or integer. Take the first element. + if (!this->visit(SubExpr)) + return false; + return this->emitArrayElemPop(*ResultT, 0, CE); + } + // Otherwise, this truncates to a a constant matrix type. 
+ assert(CE->getType()->isConstantMatrixType()); + + if (!Initializing) { + UnsignedOrNone LocalIndex = allocateTemporary(CE); + if (!LocalIndex) + return false; + if (!this->emitGetPtrLocal(*LocalIndex, CE)) + return false; + } + unsigned ToSize = + CE->getType()->getAs<ConstantMatrixType>()->getNumElementsFlattened(); + if (!this->visit(SubExpr)) + return false; + return this->emitCopyArray(classifyMatrixElementType(SubExpr->getType()), 0, + 0, ToSize, CE); + } + + case CK_HLSLAggregateSplatCast: { + // Aggregate splat cast: convert a scalar value to one of an aggregate type, + // inserting casts when necessary to convert the scalar to the aggregate's + // element type(s). + // TODO: Aggregate splat to struct and array types + assert(canClassify(SubExpr->getType())); + + unsigned NumElems; + PrimType DestElemT; + QualType DestElemType; + if (const auto *VT = CE->getType()->getAs<VectorType>()) { + NumElems = VT->getNumElements(); + DestElemType = VT->getElementType(); + } else if (const auto *MT = CE->getType()->getAs<ConstantMatrixType>()) { + NumElems = MT->getNumElementsFlattened(); + DestElemType = MT->getElementType(); + } else { + return false; + } + DestElemT = classifyPrim(DestElemType); + + if (!Initializing) { + UnsignedOrNone LocalIndex = allocateLocal(CE); + if (!LocalIndex) + return false; + if (!this->emitGetPtrLocal(*LocalIndex, CE)) + return false; + } + + PrimType SrcElemT = classifyPrim(SubExpr->getType()); + unsigned SrcOffset = + allocateLocalPrimitive(SubExpr, DestElemT, /*IsConst=*/true); + + if (!this->visit(SubExpr)) + return false; + if (SrcElemT != DestElemT) { + if (!this->emitPrimCast(SrcElemT, DestElemT, DestElemType, CE)) + return false; + } + if (!this->emitSetLocal(DestElemT, SrcOffset, CE)) + return false; + + for (unsigned I = 0; I != NumElems; ++I) { + if (!this->emitGetLocal(DestElemT, SrcOffset, CE)) + return false; + if (!this->emitInitElem(DestElemT, I, CE)) + return false; + } + return true; + } + + case 
CK_HLSLElementwiseCast: { + // Elementwise cast: flatten source elements of one aggregate type and store + // to a destination scalar or aggregate type of the same or fewer number of + // elements, while inserting casts as necessary. + // TODO: Elementwise cast to structs, nested arrays, and arrays of composite + // types + QualType SrcType = SubExpr->getType(); + QualType DestType = CE->getType(); + + // Allowed SrcTypes + const auto *SrcVT = SrcType->getAs<VectorType>(); + const auto *SrcMT = SrcType->getAs<ConstantMatrixType>(); + const auto *SrcAT = SrcType->getAsArrayTypeUnsafe(); + const auto *SrcCAT = SrcAT ? dyn_cast<ConstantArrayType>(SrcAT) : nullptr; + + // Allowed DestTypes + const auto *DestVT = DestType->getAs<VectorType>(); + const auto *DestMT = DestType->getAs<ConstantMatrixType>(); + const auto *DestAT = DestType->getAsArrayTypeUnsafe(); + const auto *DestCAT = + DestAT ? dyn_cast<ConstantArrayType>(DestAT) : nullptr; + const OptPrimType DestPT = classify(DestType); + + if (!SrcVT && !SrcMT && !SrcCAT) + return false; + if (!DestVT && !DestMT && !DestCAT && !DestPT) + return false; + + unsigned SrcNumElems; + PrimType SrcElemT; + if (SrcVT) { + SrcNumElems = SrcVT->getNumElements(); + SrcElemT = classifyPrim(SrcVT->getElementType()); + } else if (SrcMT) { + SrcNumElems = SrcMT->getNumElementsFlattened(); + SrcElemT = classifyPrim(SrcMT->getElementType()); + } else if (SrcCAT) { + SrcNumElems = SrcCAT->getZExtSize(); + SrcElemT = classifyPrim(SrcCAT->getElementType()); + } + + if (DestPT) { + // Scalar destination: extract element 0 and cast. 
+ if (!this->visit(SubExpr)) + return false; + if (!this->emitArrayElemPop(SrcElemT, 0, CE)) + return false; + if (SrcElemT != *DestPT) { + if (!this->emitPrimCast(SrcElemT, *DestPT, DestType, CE)) + return false; + } + return true; + } + + unsigned DestNumElems; + PrimType DestElemT; + QualType DestElemType; + if (DestVT) { + DestNumElems = DestVT->getNumElements(); + DestElemType = DestVT->getElementType(); + } else if (DestMT) { + DestNumElems = DestMT->getNumElementsFlattened(); + DestElemType = DestMT->getElementType(); + } else if (DestCAT) { + DestNumElems = DestCAT->getZExtSize(); + DestElemType = DestCAT->getElementType(); + } + DestElemT = classifyPrim(DestElemType); + + if (!Initializing) { + UnsignedOrNone LocalIndex = allocateTemporary(CE); + if (!LocalIndex) + return false; + if (!this->emitGetPtrLocal(*LocalIndex, CE)) + return false; + } + + unsigned SrcOffset = + allocateLocalPrimitive(SubExpr, PT_Ptr, /*IsConst=*/true); + if (!this->visit(SubExpr)) + return false; + if (!this->emitSetLocal(PT_Ptr, SrcOffset, CE)) + return false; + + unsigned NumElems = std::min(SrcNumElems, DestNumElems); + for (unsigned I = 0; I != NumElems; ++I) { + if (!this->emitGetLocal(PT_Ptr, SrcOffset, CE)) + return false; + if (!this->emitArrayElemPop(SrcElemT, I, CE)) + return false; + if (SrcElemT != DestElemT) { + if (!this->emitPrimCast(SrcElemT, DestElemT, DestElemType, CE)) + return false; + } + if (!this->emitInitElem(DestElemT, I, CE)) + return false; + } + return true; + } + default: return this->emitInvalid(CE); } @@ -1813,6 +2006,20 @@ bool Compiler<Emitter>::VisitImplicitValueInitExpr( return true; } + if (const auto *MT = E->getType()->getAs<ConstantMatrixType>()) { + unsigned NumElems = MT->getNumElementsFlattened(); + QualType ElemQT = MT->getElementType(); + PrimType ElemT = classifyPrim(ElemQT); + + for (unsigned I = 0; I != NumElems; ++I) { + if (!this->visitZeroInitializer(ElemT, ElemQT, E)) + return false; + if (!this->emitInitElem(ElemT, I, E)) + 
return false; + } + return true; + } + return false; } @@ -2129,6 +2336,25 @@ bool Compiler<Emitter>::visitInitList(ArrayRef<const Expr *> Inits, return true; } + if (const auto *MT = QT->getAs<ConstantMatrixType>()) { + unsigned NumElems = MT->getNumElementsFlattened(); + assert(Inits.size() == NumElems); + + QualType ElemQT = MT->getElementType(); + PrimType ElemT = classifyPrim(ElemQT); + + // Matrix initializer list elements are in row-major order, which matches + // the matrix APValue convention and therefore no index remapping is + // required. + for (unsigned I = 0; I != NumElems; ++I) { + if (!this->visit(Inits[I])) + return false; + if (!this->emitInitElem(ElemT, I, E)) + return false; + } + return true; + } + return false; } diff --git a/clang/lib/AST/ByteCode/Compiler.h b/clang/lib/AST/ByteCode/Compiler.h index 8f1f9dad4469e..717928dc1fbbd 100644 --- a/clang/lib/AST/ByteCode/Compiler.h +++ b/clang/lib/AST/ByteCode/Compiler.h @@ -406,6 +406,11 @@ class Compiler : public ConstStmtVisitor<Compiler<Emitter>, bool>, return *this->classify(T->getAs<VectorType>()->getElementType()); } + PrimType classifyMatrixElementType(QualType T) const { + assert(T->isMatrixType()); + return *this->classify(T->getAs<MatrixType>()->getElementType()); + } + bool emitComplexReal(const Expr *SubExpr); bool emitComplexBoolCast(const Expr *E); bool emitComplexComparison(const Expr *LHS, const Expr *RHS, diff --git a/clang/lib/AST/ByteCode/Pointer.cpp b/clang/lib/AST/ByteCode/Pointer.cpp index e237013f4199c..f4352e7edf5f8 100644 --- a/clang/lib/AST/ByteCode/Pointer.cpp +++ b/clang/lib/AST/ByteCode/Pointer.cpp @@ -934,6 +934,24 @@ std::optional<APValue> Pointer::toRValue(const Context &Ctx, return true; } + // Constant Matrix types. 
+ if (const auto *MT = Ty->getAs<ConstantMatrixType>()) { + assert(Ptr.getFieldDesc()->isPrimitiveArray()); + QualType ElemTy = MT->getElementType(); + PrimType ElemT = *Ctx.classify(ElemTy); + unsigned NumElems = MT->getNumElementsFlattened(); + + SmallVector<APValue> Values; + Values.reserve(NumElems); + for (unsigned I = 0; I != NumElems; ++I) { + TYPE_SWITCH(ElemT, + { Values.push_back(Ptr.elem<T>(I).toAPValue(ASTCtx)); }); + } + + R = APValue(Values.data(), MT->getNumRows(), MT->getNumColumns()); + return true; + } + llvm_unreachable("invalid value to return"); }; diff --git a/clang/lib/AST/ByteCode/Program.cpp b/clang/lib/AST/ByteCode/Program.cpp index bf511ff215db9..efef5db177e56 100644 --- a/clang/lib/AST/ByteCode/Program.cpp +++ b/clang/lib/AST/ByteCode/Program.cpp @@ -483,5 +483,15 @@ Descriptor *Program::createDescriptor(const DeclTy &D, const Type *Ty, IsTemporary, IsMutable); } + // Same with constant matrix types. + if (const auto *MT = Ty->getAs<ConstantMatrixType>()) { + OptPrimType ElemTy = Ctx.classify(MT->getElementType()); + if (!ElemTy) + return nullptr; + + return allocateDescriptor(D, *ElemTy, MDSize, MT->getNumElementsFlattened(), + IsConst, IsTemporary, IsMutable); + } + return nullptr; } diff --git a/clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixConstantExpr.hlsl b/clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixConstantExpr.hlsl index 64220980d9edc..2c55e1a0ee4b3 100644 --- a/clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixConstantExpr.hlsl +++ b/clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixConstantExpr.hlsl @@ -1,5 +1,7 @@ // RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library -finclude-default-header -std=hlsl202x -fmatrix-memory-layout=column-major -verify %s // RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library -finclude-default-header -std=hlsl202x -fmatrix-memory-layout=row-major -verify %s +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library -finclude-default-header -std=hlsl202x -fmatrix-memory-layout=column-major 
-fexperimental-new-constant-interpreter -verify %s +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library -finclude-default-header -std=hlsl202x -fmatrix-memory-layout=row-major -fexperimental-new-constant-interpreter -verify %s // expected-no-diagnostics @@ -43,6 +45,8 @@ export void fn() { _Static_assert(FA4[1] == 2.5, "Woo!"); _Static_assert(FA4[2] == 3.5, "Woo!"); _Static_assert(FA4[3] == 4.5, "Woo!"); + constexpr float F = (float)FA4; + _Static_assert(F == 1.5, "Woo!"); } // Array cast to matrix to vector _______________________________________________ cfe-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
