Author: Michael Kruse Date: 2021-10-06T12:21:04-05:00 New Revision: 2130117f92e51df73ac8c4b7e37f7f89178a89f2
URL: https://github.com/llvm/llvm-project/commit/2130117f92e51df73ac8c4b7e37f7f89178a89f2 DIFF: https://github.com/llvm/llvm-project/commit/2130117f92e51df73ac8c4b7e37f7f89178a89f2.diff LOG: [Clang][OpenMP] Allow loop-transformations with template parameters. Clang would reject #pragma omp for #pragma omp tile sizes(P) for (int i = 0; i < 128; ++i) {} where P is a template parameter, but the loop itself is not template-dependent. Because P is context-dependent, the TransformedStmt cannot be generated and therefore is nullptr (until the template is instantiated by TreeTransform). The OMPForDirective would still expect a loop in the dependent context and trigger an error. Fix by introducing a NumGeneratedLoops field to OMPLoopTransformationDirective. This is used to distinguish the case where no TransformedStmt will be generated at all (e.g. #pragma omp unroll full) from the case where template instantiation is needed. In the latter case, delay resolving the iteration space until template instantiation, as is already done when the for-loop itself is template-dependent. A more radical solution would always delay the iteration space analysis until template instantiation, but would also break many test cases. 
Reviewed By: ABataev Differential Revision: https://reviews.llvm.org/D111124 Added: Modified: clang/include/clang/AST/StmtOpenMP.h clang/lib/AST/StmtOpenMP.cpp clang/lib/Sema/SemaOpenMP.cpp clang/lib/Serialization/ASTReaderStmt.cpp clang/lib/Serialization/ASTWriterStmt.cpp clang/test/OpenMP/tile_ast_print.cpp clang/test/OpenMP/unroll_ast_print.cpp Removed: ################################################################################ diff --git a/clang/include/clang/AST/StmtOpenMP.h b/clang/include/clang/AST/StmtOpenMP.h index 285426d26e21..60d47b93ba79 100644 --- a/clang/include/clang/AST/StmtOpenMP.h +++ b/clang/include/clang/AST/StmtOpenMP.h @@ -959,6 +959,9 @@ class OMPLoopBasedDirective : public OMPExecutableDirective { class OMPLoopTransformationDirective : public OMPLoopBasedDirective { friend class ASTStmtReader; + /// Number of loops generated by this loop transformation. + unsigned NumGeneratedLoops = 0; + protected: explicit OMPLoopTransformationDirective(StmtClass SC, OpenMPDirectiveKind Kind, @@ -967,10 +970,16 @@ class OMPLoopTransformationDirective : public OMPLoopBasedDirective { unsigned NumAssociatedLoops) : OMPLoopBasedDirective(SC, Kind, StartLoc, EndLoc, NumAssociatedLoops) {} + /// Set the number of loops generated by this loop transformation. + void setNumGeneratedLoops(unsigned Num) { NumGeneratedLoops = Num; } + public: /// Return the number of associated (consumed) loops. unsigned getNumAssociatedLoops() const { return getLoopsNumber(); } + /// Return the number of loops generated by this loop transformation. + unsigned getNumGeneratedLoops() { return NumGeneratedLoops; } + /// Get the de-sugared statements after after the loop transformation. 
/// /// Might be nullptr if either the directive generates no loops and is handled @@ -5058,7 +5067,9 @@ class OMPTileDirective final : public OMPLoopTransformationDirective { unsigned NumLoops) : OMPLoopTransformationDirective(OMPTileDirectiveClass, llvm::omp::OMPD_tile, StartLoc, EndLoc, - NumLoops) {} + NumLoops) { + setNumGeneratedLoops(3 * NumLoops); + } void setPreInits(Stmt *PreInits) { Data->getChildren()[PreInitsOffset] = PreInits; @@ -5163,7 +5174,7 @@ class OMPUnrollDirective final : public OMPLoopTransformationDirective { static OMPUnrollDirective * Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc, ArrayRef<OMPClause *> Clauses, Stmt *AssociatedStmt, - Stmt *TransformedStmt, Stmt *PreInits); + unsigned NumGeneratedLoops, Stmt *TransformedStmt, Stmt *PreInits); /// Build an empty '#pragma omp unroll' AST node for deserialization. /// diff --git a/clang/lib/AST/StmtOpenMP.cpp b/clang/lib/AST/StmtOpenMP.cpp index c615463f42da..014274f46cae 100644 --- a/clang/lib/AST/StmtOpenMP.cpp +++ b/clang/lib/AST/StmtOpenMP.cpp @@ -138,9 +138,18 @@ bool OMPLoopBasedDirective::doForAllLoops( Stmt *TransformedStmt = Dir->getTransformedStmt(); if (!TransformedStmt) { - // May happen if the loop transformation does not result in a - // generated loop (such as full unrolling). - break; + unsigned NumGeneratedLoops = Dir->getNumGeneratedLoops(); + if (NumGeneratedLoops == 0) { + // May happen if the loop transformation does not result in a + // generated loop (such as full unrolling). + break; + } + if (NumGeneratedLoops > 0) { + // The loop transformation construct has generated loops, but these + // may not have been generated yet due to being in a dependent + // context. 
+ return true; + } } CurStmt = TransformedStmt; @@ -419,10 +428,13 @@ OMPTileDirective *OMPTileDirective::CreateEmpty(const ASTContext &C, OMPUnrollDirective * OMPUnrollDirective::Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc, ArrayRef<OMPClause *> Clauses, - Stmt *AssociatedStmt, Stmt *TransformedStmt, - Stmt *PreInits) { + Stmt *AssociatedStmt, unsigned NumGeneratedLoops, + Stmt *TransformedStmt, Stmt *PreInits) { + assert(NumGeneratedLoops <= 1 && "Unrolling generates at most one loop"); + auto *Dir = createDirective<OMPUnrollDirective>( C, Clauses, AssociatedStmt, TransformedStmtOffset + 1, StartLoc, EndLoc); + Dir->setNumGeneratedLoops(NumGeneratedLoops); Dir->setTransformedStmt(TransformedStmt); Dir->setPreInits(PreInits); return Dir; diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp index af70a180b27c..850b6f162d72 100644 --- a/clang/lib/Sema/SemaOpenMP.cpp +++ b/clang/lib/Sema/SemaOpenMP.cpp @@ -12919,10 +12919,12 @@ StmtResult Sema::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses, Body, OriginalInits)) return StmtError(); + unsigned NumGeneratedLoops = PartialClause ? 1 : 0; + // Delay unrolling to when template is completely instantiated. if (CurContext->isDependentContext()) return OMPUnrollDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt, - nullptr, nullptr); + NumGeneratedLoops, nullptr, nullptr); OMPLoopBasedDirective::HelperExprs &LoopHelper = LoopHelpers.front(); @@ -12941,9 +12943,9 @@ StmtResult Sema::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses, // The generated loop may only be passed to other loop-associated directive // when a partial clause is specified. Without the requirement it is // sufficient to generate loop unroll metadata at code-generation. 
- if (!PartialClause) + if (NumGeneratedLoops == 0) return OMPUnrollDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt, - nullptr, nullptr); + NumGeneratedLoops, nullptr, nullptr); // Otherwise, we need to provide a de-sugared/transformed AST that can be // associated with another loop directive. @@ -13164,7 +13166,8 @@ StmtResult Sema::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses, LoopHelper.Init->getBeginLoc(), LoopHelper.Inc->getEndLoc()); return OMPUnrollDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt, - OuterFor, buildPreInits(Context, PreInits)); + NumGeneratedLoops, OuterFor, + buildPreInits(Context, PreInits)); } OMPClause *Sema::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind, Expr *Expr, diff --git a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp index 34a58831e0d4..4e6eaf77ff56 100644 --- a/clang/lib/Serialization/ASTReaderStmt.cpp +++ b/clang/lib/Serialization/ASTReaderStmt.cpp @@ -2327,6 +2327,7 @@ void ASTStmtReader::VisitOMPSimdDirective(OMPSimdDirective *D) { void ASTStmtReader::VisitOMPLoopTransformationDirective( OMPLoopTransformationDirective *D) { VisitOMPLoopBasedDirective(D); + D->setNumGeneratedLoops(Record.readUInt32()); } void ASTStmtReader::VisitOMPTileDirective(OMPTileDirective *D) { diff --git a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp index bf32294bc95f..000bf808d32b 100644 --- a/clang/lib/Serialization/ASTWriterStmt.cpp +++ b/clang/lib/Serialization/ASTWriterStmt.cpp @@ -2226,6 +2226,7 @@ void ASTStmtWriter::VisitOMPSimdDirective(OMPSimdDirective *D) { void ASTStmtWriter::VisitOMPLoopTransformationDirective( OMPLoopTransformationDirective *D) { VisitOMPLoopBasedDirective(D); + Record.writeUInt32(D->getNumGeneratedLoops()); } void ASTStmtWriter::VisitOMPTileDirective(OMPTileDirective *D) { diff --git a/clang/test/OpenMP/tile_ast_print.cpp b/clang/test/OpenMP/tile_ast_print.cpp index 37791f0a8475..14f064358d8d 100644 --- 
a/clang/test/OpenMP/tile_ast_print.cpp +++ b/clang/test/OpenMP/tile_ast_print.cpp @@ -162,4 +162,25 @@ void tfoo6() { } +// PRINT-LABEL: template <int Tile> void foo7(int start, int stop, int step) { +// DUMP-LABEL: FunctionTemplateDecl {{.*}} foo7 +template <int Tile> +void foo7(int start, int stop, int step) { + // PRINT: #pragma omp tile sizes(Tile) + // DUMP: OMPTileDirective + // DUMP-NEXT: OMPSizesClause + // DUMP-NEXT: DeclRefExpr {{.*}} 'Tile' 'int' + #pragma omp tile sizes(Tile) + // PRINT-NEXT: for (int i = start; i < stop; i += step) + // DUMP-NEXT: ForStmt + for (int i = start; i < stop; i += step) + // PRINT-NEXT: body(i); + // DUMP: CallExpr + body(i); +} +void tfoo7() { + foo7<5>(0, 42, 2); +} + + #endif diff --git a/clang/test/OpenMP/unroll_ast_print.cpp b/clang/test/OpenMP/unroll_ast_print.cpp index 63e7b1dbe6ed..4d858284877f 100644 --- a/clang/test/OpenMP/unroll_ast_print.cpp +++ b/clang/test/OpenMP/unroll_ast_print.cpp @@ -124,4 +124,26 @@ void unroll_template() { unroll_templated<int,0,1024,1,4>(); } + +// PRINT-LABEL: template <int Factor> void unroll_templated_factor(int start, int stop, int step) { +// DUMP-LABEL: FunctionTemplateDecl {{.*}} unroll_templated_factor +template <int Factor> +void unroll_templated_factor(int start, int stop, int step) { + // PRINT: #pragma omp unroll partial(Factor) + // DUMP: OMPUnrollDirective + // DUMP-NEXT: OMPPartialClause + // DUMP-NEXT: DeclRefExpr {{.*}} 'Factor' 'int' + #pragma omp unroll partial(Factor) + // PRINT-NEXT: for (int i = start; i < stop; i += step) + // DUMP-NEXT: ForStmt + for (int i = start; i < stop; i += step) + // PRINT-NEXT: body(i); + // DUMP: CallExpr + body(i); +} +void unroll_template_factor() { + unroll_templated_factor<4>(0, 42, 2); +} + + #endif _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits