llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-clang Author: Deric C. (Icohedron) <details> <summary>Changes</summary> This PR fixes https://github.com/llvm/llvm-project/issues/183426 by completing the implementations of `CK_HLSLAggregateSplatCast` and `CK_HLSLElementwiseCast` in Clang's new bytecode-based constant-expression interpreter. This PR also adds RUN lines using `-fexperimental-new-constant-interpreter` to all HLSL tests that contain static assertions. Assisted-by: GitHub Copilot --- Patch is 37.28 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/189126.diff 21 Files Affected: - (modified) clang/lib/AST/ByteCode/Compiler.cpp (+453-105) - (modified) clang/lib/AST/ByteCode/Compiler.h (+23) - (modified) clang/test/SemaHLSL/BuiltIns/asfloat-constexpr.hlsl (+1) - (modified) clang/test/SemaHLSL/BuiltIns/asuint-constexpr.hlsl (+1) - (modified) clang/test/SemaHLSL/Language/InitIncompleteArrays.hlsl (+1) - (modified) clang/test/SemaHLSL/Types/AggregateSplatConstantExpr.hlsl (+1) - (modified) clang/test/SemaHLSL/Types/Arithmetic/half_size.hlsl (+4) - (modified) clang/test/SemaHLSL/Types/Arithmetic/literal_suffixes.hlsl (+1) - (modified) clang/test/SemaHLSL/Types/Arithmetic/literal_suffixes_202x.hlsl (+4) - (modified) clang/test/SemaHLSL/Types/Arithmetic/literal_suffixes_no_16bit.hlsl (+1) - (modified) clang/test/SemaHLSL/Types/BuiltinVector/BooleanVectorConstantExpr.hlsl (+1) - (modified) clang/test/SemaHLSL/Types/BuiltinVector/TruncationConstantExpr.hlsl (+1) - (modified) clang/test/SemaHLSL/Types/ElementwiseCastConstantExpr.hlsl (+1) - (modified) clang/test/SemaHLSL/Types/InitListConstantExpr.hlsl (+1) - (modified) clang/test/SemaHLSL/Types/Traits/IsIntangibleType.hlsl (+2) - (modified) clang/test/SemaHLSL/Types/Traits/IsIntangibleTypeErrors.hlsl (+1) - (modified) clang/test/SemaHLSL/Types/Traits/IsTypedResourceElementCompatible.hlsl (+1) - (modified) 
clang/test/SemaHLSL/Types/Traits/ScalarizedLayoutCompatible.hlsl (+2) - (modified) clang/test/SemaHLSL/Types/Traits/ScalarizedLayoutCompatibleErrors.hlsl (+1) - (modified) clang/test/SemaHLSL/Types/typedefs.hlsl (+2) - (modified) clang/test/SemaHLSL/group_shared.hlsl (+1) ``````````diff diff --git a/clang/lib/AST/ByteCode/Compiler.cpp b/clang/lib/AST/ByteCode/Compiler.cpp index 4d129a7ccd497..95dd64fdbdeda 100644 --- a/clang/lib/AST/ByteCode/Compiler.cpp +++ b/clang/lib/AST/ByteCode/Compiler.cpp @@ -855,26 +855,11 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *E) { } case CK_HLSLAggregateSplatCast: { - // Aggregate splat cast: convert a scalar value to one of an aggregate type, - // inserting casts when necessary to convert the scalar to the aggregate's - // element type(s). - // TODO: Aggregate splat to struct and array types + // Aggregate splat cast: convert a scalar value to one of an aggregate type + // by replicating and casting the scalar to every element of the destination + // aggregate (vector, matrix, array, or struct). assert(canClassify(SubExpr->getType())); - unsigned NumElems; - PrimType DestElemT; - QualType DestElemType; - if (const auto *VT = E->getType()->getAs<VectorType>()) { - NumElems = VT->getNumElements(); - DestElemType = VT->getElementType(); - } else if (const auto *MT = E->getType()->getAs<ConstantMatrixType>()) { - NumElems = MT->getNumElementsFlattened(); - DestElemType = MT->getElementType(); - } else { - return false; - } - DestElemT = classifyPrim(DestElemType); - if (!Initializing) { UnsignedOrNone LocalIndex = allocateLocal(E); if (!LocalIndex) @@ -883,99 +868,54 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *E) { return false; } + // The scalar to be splatted is stored in a local to be repeatedly loaded + // once for every scalar element of the destination. 
PrimType SrcElemT = classifyPrim(SubExpr->getType()); unsigned SrcOffset = - allocateLocalPrimitive(SubExpr, DestElemT, /*IsConst=*/true); + allocateLocalPrimitive(SubExpr, SrcElemT, /*IsConst=*/true); if (!this->visit(SubExpr)) return false; - if (SrcElemT != DestElemT) { - if (!this->emitPrimCast(SrcElemT, DestElemT, DestElemType, E)) - return false; - } - if (!this->emitSetLocal(DestElemT, SrcOffset, E)) + if (!this->emitSetLocal(SrcElemT, SrcOffset, E)) return false; - for (unsigned I = 0; I != NumElems; ++I) { - if (!this->emitGetLocal(DestElemT, SrcOffset, E)) - return false; - if (!this->emitInitElem(DestElemT, I, E)) - return false; - } - return true; + // Recursively splat the scalar into every element of the destination. + return emitHLSLAggregateSplat(SrcElemT, SrcOffset, E->getType(), E); } case CK_HLSLElementwiseCast: { - // Elementwise cast: flatten source elements of one aggregate type and store - // to a destination scalar or aggregate type of the same or fewer number of - // elements, while inserting casts as necessary. - // TODO: Elementwise cast to structs, nested arrays, and arrays of composite - // types + // Elementwise cast: flatten the elements of one aggregate source type and + // store to a destination scalar or aggregate type of the same or fewer + // number of elements. Casts are inserted element-wise to convert each + // source scalar element to its corresponding destination scalar element. QualType SrcType = SubExpr->getType(); QualType DestType = E->getType(); - // Allowed SrcTypes - const auto *SrcVT = SrcType->getAs<VectorType>(); - const auto *SrcMT = SrcType->getAs<ConstantMatrixType>(); - const auto *SrcAT = SrcType->getAsArrayTypeUnsafe(); - const auto *SrcCAT = SrcAT ? 
dyn_cast<ConstantArrayType>(SrcAT) : nullptr; - - // Allowed DestTypes - const auto *DestVT = DestType->getAs<VectorType>(); - const auto *DestMT = DestType->getAs<ConstantMatrixType>(); - const auto *DestAT = DestType->getAsArrayTypeUnsafe(); - const auto *DestCAT = - DestAT ? dyn_cast<ConstantArrayType>(DestAT) : nullptr; - const OptPrimType DestPT = classify(DestType); - - if (!SrcVT && !SrcMT && !SrcCAT) - return false; - if (!DestVT && !DestMT && !DestCAT && !DestPT) - return false; - - unsigned SrcNumElems; - PrimType SrcElemT; - if (SrcVT) { - SrcNumElems = SrcVT->getNumElements(); - SrcElemT = classifyPrim(SrcVT->getElementType()); - } else if (SrcMT) { - SrcNumElems = SrcMT->getNumElementsFlattened(); - SrcElemT = classifyPrim(SrcMT->getElementType()); - } else if (SrcCAT) { - SrcNumElems = SrcCAT->getZExtSize(); - SrcElemT = classifyPrim(SrcCAT->getElementType()); - } - - if (DestPT) { - // Scalar destination: extract element 0 and cast. + const OptPrimType DestT = classify(DestType); + if (DestT) { + // When the destination is a scalar, we only need the first scalar + // element of the source. 
+ unsigned SrcPtrOffset = + allocateLocalPrimitive(SubExpr, PT_Ptr, /*IsConst=*/true); if (!this->visit(SubExpr)) return false; - if (!this->emitArrayElemPop(SrcElemT, 0, E)) + if (!this->emitSetLocal(PT_Ptr, SrcPtrOffset, E)) + return false; + + SmallVector<HLSLFlatElement, 1> Elements; + if (!emitHLSLFlattenAggregate(SrcType, SrcPtrOffset, Elements, 1, E)) + return false; + if (Elements.empty()) return false; - if (SrcElemT != *DestPT) { - if (!this->emitPrimCast(SrcElemT, *DestPT, DestType, E)) - return false; - } - return true; - } - unsigned DestNumElems; - PrimType DestElemT; - QualType DestElemType; - if (DestVT) { - DestNumElems = DestVT->getNumElements(); - DestElemType = DestVT->getElementType(); - } else if (DestMT) { - DestNumElems = DestMT->getNumElementsFlattened(); - DestElemType = DestMT->getElementType(); - } else if (DestCAT) { - DestNumElems = DestCAT->getZExtSize(); - DestElemType = DestCAT->getElementType(); + const HLSLFlatElement &Src = Elements[0]; + if (!this->emitGetLocal(Src.Type, Src.LocalOffset, E)) + return false; + return this->emitPrimCast(Src.Type, *DestT, DestType, E); } - DestElemT = classifyPrim(DestElemType); if (!Initializing) { - UnsignedOrNone LocalIndex = allocateTemporary(E); + UnsignedOrNone LocalIndex = allocateLocal(E); if (!LocalIndex) return false; if (!this->emitGetPtrLocal(*LocalIndex, E)) @@ -989,20 +929,14 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *E) { if (!this->emitSetLocal(PT_Ptr, SrcOffset, E)) return false; - unsigned NumElems = std::min(SrcNumElems, DestNumElems); - for (unsigned I = 0; I != NumElems; ++I) { - if (!this->emitGetLocal(PT_Ptr, SrcOffset, E)) - return false; - if (!this->emitArrayElemPop(SrcElemT, I, E)) - return false; - if (SrcElemT != DestElemT) { - if (!this->emitPrimCast(SrcElemT, DestElemT, DestElemType, E)) - return false; - } - if (!this->emitInitElem(DestElemT, I, E)) - return false; - } - return true; + // Only flatten as many source elements as the destination requires. 
+ unsigned MaxElems = countHLSLFlatElements(DestType); + + SmallVector<HLSLFlatElement, 16> Elements; + if (!emitHLSLFlattenAggregate(SrcType, SrcOffset, Elements, MaxElems, E)) + return false; + + return emitHLSLConstructAggregate(DestType, Elements, E); } default: @@ -7992,6 +7926,420 @@ bool Compiler<Emitter>::emitBuiltinBitCast(const CastExpr *E) { return true; } +/// Replicate a scalar value into every scalar element of an aggregate. +/// The scalar is stored in a local at \p SrcOffset and a pointer to the +/// destination must be on top of the interpreter stack. Each element receives +/// the scalar, cast to its own type. +template <class Emitter> +bool Compiler<Emitter>::emitHLSLAggregateSplat(PrimType SrcT, + unsigned SrcOffset, + QualType DestType, + const Expr *E) { + // Vectors and matrices are treated as flat sequences of elements. + unsigned NumElems = 0; + QualType ElemType; + if (const auto *VT = DestType->getAs<VectorType>()) { + NumElems = VT->getNumElements(); + ElemType = VT->getElementType(); + } else if (const auto *MT = DestType->getAs<ConstantMatrixType>()) { + NumElems = MT->getNumElementsFlattened(); + ElemType = MT->getElementType(); + } + if (NumElems > 0) { + PrimType ElemT = classifyPrim(ElemType); + for (unsigned I = 0; I != NumElems; ++I) { + if (!this->emitGetLocal(SrcT, SrcOffset, E)) + return false; + if (!this->emitPrimCast(SrcT, ElemT, ElemType, E)) + return false; + if (!this->emitInitElem(ElemT, I, E)) + return false; + } + return true; + } + + // Arrays: primitive elements are filled directly; composite elements + // require recursion into each sub-aggregate. 
+ if (const auto *AT = DestType->getAsArrayTypeUnsafe()) { + const auto *CAT = cast<ConstantArrayType>(AT); + QualType ArrElemType = CAT->getElementType(); + unsigned ArrSize = CAT->getZExtSize(); + + if (OptPrimType ElemT = classify(ArrElemType)) { + for (unsigned I = 0; I != ArrSize; ++I) { + if (!this->emitGetLocal(SrcT, SrcOffset, E)) + return false; + if (!this->emitPrimCast(SrcT, *ElemT, ArrElemType, E)) + return false; + if (!this->emitInitElem(*ElemT, I, E)) + return false; + } + } else { + for (unsigned I = 0; I != ArrSize; ++I) { + if (!this->emitConstUint32(I, E)) + return false; + if (!this->emitArrayElemPtrUint32(E)) + return false; + if (!emitHLSLAggregateSplat(SrcT, SrcOffset, ArrElemType, E)) + return false; + if (!this->emitFinishInitPop(E)) + return false; + } + } + return true; + } + + // Records: fill base classes first, then named fields in declaration + // order. + if (DestType->isRecordType()) { + const Record *R = getRecord(DestType); + if (!R) + return false; + + if (const auto *CXXRD = dyn_cast<CXXRecordDecl>(R->getDecl())) { + for (const CXXBaseSpecifier &BS : CXXRD->bases()) { + const Record::Base *B = R->getBase(BS.getType()); + assert(B); + if (!this->emitGetPtrBase(B->Offset, E)) + return false; + if (!emitHLSLAggregateSplat(SrcT, SrcOffset, BS.getType(), E)) + return false; + if (!this->emitFinishInitPop(E)) + return false; + } + } + + for (const Record::Field &F : R->fields()) { + if (F.isUnnamedBitField()) + continue; + + QualType FieldType = F.Decl->getType(); + if (OptPrimType FieldT = classify(FieldType)) { + if (!this->emitGetLocal(SrcT, SrcOffset, E)) + return false; + if (!this->emitPrimCast(SrcT, *FieldT, FieldType, E)) + return false; + if (F.isBitField()) { + if (!this->emitInitBitField(*FieldT, F.Offset, F.bitWidth(), E)) + return false; + } else { + if (!this->emitInitField(*FieldT, F.Offset, E)) + return false; + } + } else { + if (!this->emitGetPtrField(F.Offset, E)) + return false; + if (!emitHLSLAggregateSplat(SrcT, 
SrcOffset, FieldType, E)) + return false; + if (!this->emitPopPtr(E)) + return false; + } + } + return true; + } + + return false; +} + +/// Return the total number of scalar elements in a type. This is used +/// to cap how many source elements are extracted during an elementwise cast, +/// so we never flatten more than the destination can hold. +template <class Emitter> +unsigned Compiler<Emitter>::countHLSLFlatElements(QualType Ty) { + // Vector and matrix types are treated as flat sequences of elements. + if (const auto *VT = Ty->getAs<VectorType>()) + return VT->getNumElements(); + if (const auto *MT = Ty->getAs<ConstantMatrixType>()) + return MT->getNumElementsFlattened(); + // Arrays: total count is array size * scalar elements per element. + if (const auto *AT = Ty->getAsArrayTypeUnsafe()) { + const auto *CAT = cast<ConstantArrayType>(AT); + return CAT->getZExtSize() * countHLSLFlatElements(CAT->getElementType()); + } + // Records: sum scalar element counts of base classes and named fields. + if (Ty->isRecordType()) { + const Record *R = getRecord(Ty); + if (!R) + return 0; + unsigned Count = 0; + if (const auto *CXXRD = dyn_cast<CXXRecordDecl>(R->getDecl())) { + for (const CXXBaseSpecifier &BS : CXXRD->bases()) + Count += countHLSLFlatElements(BS.getType()); + } + for (const Record::Field &F : R->fields()) { + if (F.isUnnamedBitField()) + continue; + Count += countHLSLFlatElements(F.Decl->getType()); + } + return Count; + } + // Scalar primitive types contribute one element. + if (classify(Ty)) + return 1; + return 0; +} + +/// Walk a source aggregate and extract every scalar element into its own local +/// variable. The results are appended to \p Elements in declaration order, +/// stopping once \p MaxElements have been collected. A pointer to the +/// source aggregate must be stored in the local at \p SrcOffset. 
+template <class Emitter> +bool Compiler<Emitter>::emitHLSLFlattenAggregate( + QualType SrcType, unsigned SrcOffset, + SmallVectorImpl<HLSLFlatElement> &Elements, unsigned MaxElements, + const Expr *E) { + + // Save a scalar value from the stack into a new local and record it. + auto saveToLocal = [&](PrimType T) -> bool { + unsigned Offset = allocateLocalPrimitive(E, T, /*IsConst=*/true); + if (!this->emitSetLocal(T, Offset, E)) + return false; + Elements.push_back({Offset, T}); + return true; + }; + + // Save a pointer from the stack into a new local for later use. + auto savePtrToLocal = [&]() -> UnsignedOrNone { + unsigned Offset = allocateLocalPrimitive(E, PT_Ptr, /*IsConst=*/true); + if (!this->emitSetLocal(PT_Ptr, Offset, E)) + return std::nullopt; + return Offset; + }; + + // Vectors and matrices are flat sequences of elements. + unsigned NumElems = 0; + QualType ElemType; + if (const auto *VT = SrcType->getAs<VectorType>()) { + NumElems = VT->getNumElements(); + ElemType = VT->getElementType(); + } else if (const auto *MT = SrcType->getAs<ConstantMatrixType>()) { + NumElems = MT->getNumElementsFlattened(); + ElemType = MT->getElementType(); + } + if (NumElems > 0) { + PrimType ElemT = classifyPrim(ElemType); + for (unsigned I = 0; I != NumElems && Elements.size() < MaxElements; ++I) { + if (!this->emitGetLocal(PT_Ptr, SrcOffset, E)) + return false; + if (!this->emitArrayElemPop(ElemT, I, E)) + return false; + if (!saveToLocal(ElemT)) + return false; + } + return true; + } + + // Arrays: primitive elements are extracted directly; composite elements + // require recursion into each sub-aggregate. 
+ if (const auto *AT = SrcType->getAsArrayTypeUnsafe()) { + const auto *CAT = cast<ConstantArrayType>(AT); + QualType ArrElemType = CAT->getElementType(); + unsigned ArrSize = CAT->getZExtSize(); + + if (OptPrimType ElemT = classify(ArrElemType)) { + for (unsigned I = 0; I != ArrSize && Elements.size() < MaxElements; ++I) { + if (!this->emitGetLocal(PT_Ptr, SrcOffset, E)) + return false; + if (!this->emitArrayElemPop(*ElemT, I, E)) + return false; + if (!saveToLocal(*ElemT)) + return false; + } + } else { + for (unsigned I = 0; I != ArrSize && Elements.size() < MaxElements; ++I) { + if (!this->emitGetLocal(PT_Ptr, SrcOffset, E)) + return false; + if (!this->emitConstUint32(I, E)) + return false; + if (!this->emitArrayElemPtrPopUint32(E)) + return false; + UnsignedOrNone ElemPtrOffset = savePtrToLocal(); + if (!ElemPtrOffset) + return false; + if (!emitHLSLFlattenAggregate(ArrElemType, *ElemPtrOffset, Elements, + MaxElements, E)) + return false; + } + } + return true; + } + + // Records: base classes come first, then named fields in declaration + // order. 
+ if (SrcType->isRecordType()) { + const Record *R = getRecord(SrcType); + if (!R) + return false; + + if (const auto *CXXRD = dyn_cast<CXXRecordDecl>(R->getDecl())) { + for (const CXXBaseSpecifier &BS : CXXRD->bases()) { + if (Elements.size() >= MaxElements) + break; + const Record::Base *B = R->getBase(BS.getType()); + assert(B); + if (!this->emitGetLocal(PT_Ptr, SrcOffset, E)) + return false; + if (!this->emitGetPtrBasePop(B->Offset, /*NullOK=*/false, E)) + return false; + UnsignedOrNone BasePtrOffset = savePtrToLocal(); + if (!BasePtrOffset) + return false; + if (!emitHLSLFlattenAggregate(BS.getType(), *BasePtrOffset, Elements, + MaxElements, E)) + return false; + } + } + + for (const Record::Field &F : R->fields()) { + if (Elements.size() >= MaxElements) + break; + if (F.isUnnamedBitField()) + continue; + + QualType FieldType = F.Decl->getType(); + if (!this->emitGetLocal(PT_Ptr, SrcOffset, E)) + return false; + if (!this->emitGetPtrFieldPop(F.Offset, E)) + return false; + + if (OptPrimType FieldT = classify(FieldType)) { + if (!this->emitLoadPop(*FieldT, E)) + return false; + if (!saveToLocal(*FieldT)) + return false; + } else { + UnsignedOrNone FieldPtrOffset = savePtrToLocal(); + if (!FieldPtrOffset) + return false; + if (!emitHLSLFlattenAggregate(FieldType, *FieldPtrOffset, Elements, + MaxElements, E)) + return false; + } + } + return true; + } + + return false; +} + +/// Populate an HLSL aggregate from a flat list of previously extracted source +/// elements, casting each to the corresponding destination element type. +/// \p ElemIdx tracks the current position in \p Elements and is advanced as +/// elements are consumed. A pointer to the destination must be on top of the +/// interpreter stack. +template <class Emitter> +bool Compiler<Emitter>::emitHLSLConstructAggregate( + QualType DestType, ArrayRef<HLSLFlatElement> Elements, unsigned &ElemIdx, + const Expr *E) { + + // Consume the next source element, cast it, and leave it on the stack. 
+ auto loadAndCast = [&](PrimType DestT, QualType DestQT) -> bool { + const auto &Src = Elements[ElemIdx++]; + if (!this->emitGetLocal(Src.Type, Src.LocalOffset, E)) + return false; + return this->emitPrimCast(Src.Type, DestT, DestQT, E); + }; + + // Vectors and matrices are flat sequences of elements. + unsigned NumElems = 0; + QualType ElemType; + if (const auto *VT = DestType->getAs<VectorType>()) { + NumElems = VT->getNumElements(); + ElemType = VT->getElementType(); + } else if (const auto *MT = DestType->getAs<ConstantMatrixType>()) { + NumElems = MT->getNumElementsFlattened(); + ElemType = MT->getElementType(); + } + if (NumElems > 0) { + PrimType DestElemT = classifyPrim(ElemType); + for (unsigned I = 0; I != NumElems; ++I) { + if (!loadAndCast(DestElemT, ElemType)) + return false; + if (!this->emitInitElem(DestElemT, I, E)) + return false; + } + return true; + } + + // Arrays: primitive elements are filled directly; composite elements + // require recursion into each sub-aggregate. + if (const auto *AT = DestType->getAsArrayTypeUnsafe()) { + const auto *CAT = cast<ConstantArrayType>(AT); + QualType ArrElemType = CAT->getElementType(); + unsigned ArrSize = CAT->getZExtSize(); + + if (OptPrimType ElemT = classify(ArrElemType)) { + for (unsigned I = 0; I != ArrSize; ++I) { + if (!loadAndCast(*ElemT, ArrElemType)) + return false; + if (!this->emitInitElem(*ElemT, I, E)) + return false; + } + } else { + for (unsigned I = 0; I != ArrSize; ++I) { + if (!this->emitConstUint32(I, E)) + return false; + if (!this->emitArrayElemPtrUin... [truncated] `````````` </details> https://github.com/llvm/llvm-project/pull/189126 _______________________________________________ cfe-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
