================ @@ -15499,6 +15836,496 @@ StmtResult SemaOpenMP::ActOnOpenMPInterchangeDirective( buildPreInits(Context, PreInits)); } +StmtResult SemaOpenMP::ActOnOpenMPFuseDirective(ArrayRef<OMPClause *> Clauses, + Stmt *AStmt, + SourceLocation StartLoc, + SourceLocation EndLoc) { + + ASTContext &Context = getASTContext(); + DeclContext *CurrContext = SemaRef.CurContext; + Scope *CurScope = SemaRef.getCurScope(); + CaptureVars CopyTransformer(SemaRef); + + // Ensure the structured block is not empty + if (!AStmt) + return StmtError(); + + unsigned NumLoops = 1; + unsigned LoopSeqSize = 1; + + // Defer transformation in dependent contexts + // The NumLoopNests argument is set to a placeholder 1 (even though + // using looprange fuse could yield up to 3 top level loop nests) + // because a dependent context could prevent determining its true value + if (CurrContext->isDependentContext()) { + return OMPFuseDirective::Create(Context, StartLoc, EndLoc, Clauses, + NumLoops, LoopSeqSize, AStmt, nullptr, + nullptr); + } + + // Validate that the potential loop sequence is transformable for fusion + // Also collect the HelperExprs, Loop Stmts, Inits, and Number of loops + SmallVector<OMPLoopBasedDirective::HelperExprs, 4> LoopHelpers; + SmallVector<Stmt *> LoopStmts; + SmallVector<SmallVector<Stmt *>> OriginalInits; + SmallVector<SmallVector<Stmt *>> TransformsPreInits; + SmallVector<SmallVector<Stmt *>> LoopSequencePreInits; + SmallVector<OMPLoopCategory, 0> LoopCategories; + if (!checkTransformableLoopSequence(OMPD_fuse, AStmt, LoopSeqSize, NumLoops, + LoopHelpers, LoopStmts, OriginalInits, + TransformsPreInits, LoopSequencePreInits, + LoopCategories, Context)) + return StmtError(); + + // Handle clauses, which can be any of the following: [looprange, apply] + const OMPLoopRangeClause *LRC = + OMPExecutableDirective::getSingleClause<OMPLoopRangeClause>(Clauses); + + // The clause arguments are invalidated if any error arises + // such as non-constant or non-positive 
arguments + if (LRC && (!LRC->getFirst() || !LRC->getCount())) + return StmtError(); + + // Delayed semantic check of LoopRange constraint + // Evaluates the loop range arguments and returns the first and count values + auto EvaluateLoopRangeArguments = [&Context](Expr *First, Expr *Count, + uint64_t &FirstVal, + uint64_t &CountVal) { + llvm::APSInt FirstInt = First->EvaluateKnownConstInt(Context); + llvm::APSInt CountInt = Count->EvaluateKnownConstInt(Context); + FirstVal = FirstInt.getZExtValue(); + CountVal = CountInt.getZExtValue(); + }; + + // OpenMP [6.0, Restrictions] + // first + count - 1 must not evaluate to a value greater than the + // loop sequence length of the associated canonical loop sequence. + auto ValidLoopRange = [](uint64_t FirstVal, uint64_t CountVal, + unsigned NumLoops) -> bool { + return FirstVal + CountVal - 1 <= NumLoops; + }; + uint64_t FirstVal = 1, CountVal = 0, LastVal = LoopSeqSize; + + // Validates the loop range after evaluating the semantic information + // and ensures that the range is valid for the given loop sequence size. + // Expressions are evaluated at compile time to obtain constant values. + if (LRC) { + EvaluateLoopRangeArguments(LRC->getFirst(), LRC->getCount(), FirstVal, + CountVal); + if (CountVal == 1) + SemaRef.Diag(LRC->getCountLoc(), diag::warn_omp_redundant_fusion) + << getOpenMPDirectiveName(OMPD_fuse); + + if (!ValidLoopRange(FirstVal, CountVal, LoopSeqSize)) { + SemaRef.Diag(LRC->getFirstLoc(), diag::err_omp_invalid_looprange) + << getOpenMPDirectiveName(OMPD_fuse) << (FirstVal + CountVal - 1) + << LoopSeqSize; + return StmtError(); + } + + LastVal = FirstVal + CountVal - 1; + } + + // Complete fusion generates a single canonical loop nest + // However looprange clause generates several loop nests + unsigned NumLoopNests = LRC ? LoopSeqSize - CountVal + 1 : 1; + + // Emit a warning for redundant loop fusion when the sequence contains only + // one loop. 
+ if (LoopSeqSize == 1) + SemaRef.Diag(AStmt->getBeginLoc(), diag::warn_omp_redundant_fusion) + << getOpenMPDirectiveName(OMPD_fuse); + + assert(LoopHelpers.size() == LoopSeqSize && + "Expecting loop iteration space dimensionality to match number of " + "affected loops"); + assert(OriginalInits.size() == LoopSeqSize && + "Expecting loop iteration space dimensionality to match number of " + "affected loops"); + + // Select the type with the largest bit width among all induction variables + QualType IVType = LoopHelpers[FirstVal - 1].IterationVarRef->getType(); + for (unsigned int I = FirstVal; I < LastVal; ++I) { + QualType CurrentIVType = LoopHelpers[I].IterationVarRef->getType(); + if (Context.getTypeSize(CurrentIVType) > Context.getTypeSize(IVType)) { + IVType = CurrentIVType; + } + } + uint64_t IVBitWidth = Context.getIntWidth(IVType); + + // Create pre-init declarations for all loops lower bounds, upper bounds, + // strides and num-iterations for every top level loop in the fusion + SmallVector<VarDecl *, 4> LBVarDecls; + SmallVector<VarDecl *, 4> STVarDecls; + SmallVector<VarDecl *, 4> NIVarDecls; + SmallVector<VarDecl *, 4> UBVarDecls; + SmallVector<VarDecl *, 4> IVVarDecls; + + // Helper lambda to create variables for bounds, strides, and other + // expressions. Generates both the variable declaration and the corresponding + // initialization statement. 
+ auto CreateHelperVarAndStmt = + [&, &SemaRef = SemaRef](Expr *ExprToCopy, const std::string &BaseName, + unsigned I, bool NeedsNewVD = false) { + Expr *TransformedExpr = + AssertSuccess(CopyTransformer.TransformExpr(ExprToCopy)); + if (!TransformedExpr) + return std::pair<VarDecl *, StmtResult>(nullptr, StmtError()); + + auto Name = (Twine(".omp.") + BaseName + std::to_string(I)).str(); + + VarDecl *VD; + if (NeedsNewVD) { + VD = buildVarDecl(SemaRef, SourceLocation(), IVType, Name); + SemaRef.AddInitializerToDecl(VD, TransformedExpr, false); + + } else { + // Create a unique variable name + DeclRefExpr *DRE = cast<DeclRefExpr>(TransformedExpr); + VD = cast<VarDecl>(DRE->getDecl()); + VD->setDeclName(&SemaRef.PP.getIdentifierTable().get(Name)); + } + // Create the corresponding declaration statement + StmtResult DeclStmt = new (Context) class DeclStmt( + DeclGroupRef(VD), SourceLocation(), SourceLocation()); + return std::make_pair(VD, DeclStmt); + }; + + // PreInits hold a sequence of variable declarations that must be executed + // before the fused loop begins. These include bounds, strides, and other + // helper variables required for the transformation. Other loop transforms + // also contain their own preinits + SmallVector<Stmt *> PreInits; + // Iterator to keep track of loop transformations + unsigned int TransformIndex = 0; + + // Update the general preinits using the preinits generated by loop sequence + // generating loop transformations. These preinits differ slightly from + // single-loop transformation preinits, as they can be detached from a + // specific loop inside the multiple generated loop nests. This happens + // because certain helper variables, like '.omp.fuse.max', are introduced to + // handle fused iteration spaces and may not be directly tied to a single + // original loop. the preinit structure must ensure that hidden variables + // like '.omp.fuse.max' are still properly handled. 
+ // Transformations that apply this concept: Loopranged Fuse, Split + if (!LoopSequencePreInits.empty()) { + for (const auto <PreInits : LoopSequencePreInits) { + if (!LTPreInits.empty()) + llvm::append_range(PreInits, LTPreInits); + } + } + + // Process each single loop to generate and collect declarations + // and statements for all helper expressions related to + // particular single loop nests + + // Also In the case of the fused loops, we keep track of their original + // inits by appending them to their preinits statement, and in the case of + // transformations, also append their preinits (which contain the original + // loop initialization statement or other statements) + + // Firstly we need to update TransformIndex to match the begining of the + // looprange section + for (unsigned int I = 0; I < FirstVal - 1; ++I) { + if (LoopCategories[I] == OMPLoopCategory::TransformSingleLoop) + ++TransformIndex; + } + for (unsigned int I = FirstVal - 1, J = 0; I < LastVal; ++I, ++J) { + + if (LoopCategories[I] == OMPLoopCategory::RegularLoop) { + addLoopPreInits(Context, LoopHelpers[I], LoopStmts[I], OriginalInits[I], + PreInits); + } else if (LoopCategories[I] == OMPLoopCategory::TransformSingleLoop) { + // For transformed loops, insert both pre-inits and original inits. + // Order matters: pre-inits may define variables used in the original + // inits such as upper bounds... 
+ auto TransformPreInit = TransformsPreInits[TransformIndex++]; + if (!TransformPreInit.empty()) + llvm::append_range(PreInits, TransformPreInit); + + addLoopPreInits(Context, LoopHelpers[I], LoopStmts[I], OriginalInits[I], + PreInits); + } + auto [UBVD, UBDStmt] = CreateHelperVarAndStmt(LoopHelpers[I].UB, "ub", J); + auto [LBVD, LBDStmt] = CreateHelperVarAndStmt(LoopHelpers[I].LB, "lb", J); + auto [STVD, STDStmt] = CreateHelperVarAndStmt(LoopHelpers[I].ST, "st", J); + auto [NIVD, NIDStmt] = + CreateHelperVarAndStmt(LoopHelpers[I].NumIterations, "ni", J, true); + auto [IVVD, IVDStmt] = + CreateHelperVarAndStmt(LoopHelpers[I].IterationVarRef, "iv", J); + + if (!LBVD || !STVD || !NIVD || !IVVD) + assert(LBVD && STVD && NIVD && IVVD && + "OpenMP Fuse Helper variables creation failed"); + + UBVarDecls.push_back(UBVD); + LBVarDecls.push_back(LBVD); + STVarDecls.push_back(STVD); + NIVarDecls.push_back(NIVD); + IVVarDecls.push_back(IVVD); + + PreInits.push_back(LBDStmt.get()); + PreInits.push_back(STDStmt.get()); + PreInits.push_back(NIDStmt.get()); + PreInits.push_back(IVDStmt.get()); + } + + auto MakeVarDeclRef = [&SemaRef = this->SemaRef](VarDecl *VD) { + return buildDeclRefExpr(SemaRef, VD, VD->getType(), VD->getLocation(), + false); + }; + + // Following up the creation of the final fused loop will be performed + // which has the following shape (considering the selected loops): + // + // for (fuse.index = 0; fuse.index < max(ni0, ni1..., nik); ++fuse.index) { + // if (fuse.index < ni0){ + // iv0 = lb0 + st0 * fuse.index; + // original.index0 = iv0 + // body(0); + // } + // if (fuse.index < ni1){ + // iv1 = lb1 + st1 * fuse.index; + // original.index1 = iv1 + // body(1); + // } + // + // ... + // + // if (fuse.index < nik){ + // ivk = lbk + stk * fuse.index; + // original.indexk = ivk + // body(k); Expr *InitVal = IntegerLiteral::Create(Context, + // llvm::APInt(IVWidth, 0), + + // } + + // 1. 
Create the initialized fuse index + const std::string IndexName = Twine(".omp.fuse.index").str(); + Expr *InitVal = IntegerLiteral::Create(Context, llvm::APInt(IVBitWidth, 0), + IVType, SourceLocation()); + VarDecl *IndexDecl = + buildVarDecl(SemaRef, {}, IVType, IndexName, nullptr, nullptr); + SemaRef.AddInitializerToDecl(IndexDecl, InitVal, false); + StmtResult InitStmt = new (Context) + DeclStmt(DeclGroupRef(IndexDecl), SourceLocation(), SourceLocation()); + + if (!InitStmt.isUsable()) + return StmtError(); + + auto MakeIVRef = [&SemaRef = this->SemaRef, IndexDecl, IVType, + Loc = InitVal->getExprLoc()]() { + return buildDeclRefExpr(SemaRef, IndexDecl, IVType, Loc, false); + }; + + // 2. Iteratively compute the max number of logical iterations Max(NI_1, NI_2, + // ..., NI_k) + // + // This loop accumulates the maximum value across multiple expressions, + // ensuring each step constructs a unique AST node for correctness. By using + // intermediate temporary variables and conditional operators, we maintain + // distinct nodes and avoid duplicating subtrees, For instance, max(a,b,c): + // omp.temp0 = max(a, b) + // omp.temp1 = max(omp.temp0, c) + // omp.fuse.max = max(omp.temp1, omp.temp0) + + ExprResult MaxExpr; + // I is the true + for (unsigned I = FirstVal - 1, J = 0; I < LastVal; ++I, ++J) { + DeclRefExpr *NIRef = MakeVarDeclRef(NIVarDecls[J]); + QualType NITy = NIRef->getType(); + + if (MaxExpr.isUnset()) { + // Initialize MaxExpr with the first NI expression + MaxExpr = NIRef; + } else { + // Create a new acummulator variable t_i = MaxExpr + std::string TempName = (Twine(".omp.temp.") + Twine(J)).str(); + VarDecl *TempDecl = + buildVarDecl(SemaRef, {}, NITy, TempName, nullptr, nullptr); + TempDecl->setInit(MaxExpr.get()); + DeclRefExpr *TempRef = + buildDeclRefExpr(SemaRef, TempDecl, NITy, SourceLocation(), false); + DeclRefExpr *TempRef2 = + buildDeclRefExpr(SemaRef, TempDecl, NITy, SourceLocation(), false); + // Add a DeclStmt to PreInits to ensure the 
variable is declared. + StmtResult TempStmt = new (Context) + DeclStmt(DeclGroupRef(TempDecl), SourceLocation(), SourceLocation()); + + if (!TempStmt.isUsable()) + return StmtError(); + PreInits.push_back(TempStmt.get()); + + // Build MaxExpr <-(MaxExpr > NIRef ? MaxExpr : NIRef) + ExprResult Comparison = + SemaRef.BuildBinOp(nullptr, SourceLocation(), BO_GT, TempRef, NIRef); + // Handle any errors in Comparison creation + if (!Comparison.isUsable()) + return StmtError(); + + DeclRefExpr *NIRef2 = MakeVarDeclRef(NIVarDecls[J]); + // Update MaxExpr using a conditional expression to hold the max value + MaxExpr = new (Context) ConditionalOperator( + Comparison.get(), SourceLocation(), TempRef2, SourceLocation(), + NIRef2->getExprStmt(), NITy, VK_LValue, OK_Ordinary); + + if (!MaxExpr.isUsable()) + return StmtError(); + } + } + if (!MaxExpr.isUsable()) + return StmtError(); + + // 3. Declare the max variable + const std::string MaxName = Twine(".omp.fuse.max").str(); + VarDecl *MaxDecl = + buildVarDecl(SemaRef, {}, IVType, MaxName, nullptr, nullptr); + MaxDecl->setInit(MaxExpr.get()); + DeclRefExpr *MaxRef = buildDeclRefExpr(SemaRef, MaxDecl, IVType, {}, false); + StmtResult MaxStmt = new (Context) + DeclStmt(DeclGroupRef(MaxDecl), SourceLocation(), SourceLocation()); + + if (MaxStmt.isInvalid()) + return StmtError(); + PreInits.push_back(MaxStmt.get()); + + // 4. Create condition Expr: index < n_max + ExprResult CondExpr = SemaRef.BuildBinOp(CurScope, SourceLocation(), BO_LT, + MakeIVRef(), MaxRef); + if (!CondExpr.isUsable()) + return StmtError(); + // 5. Increment Expr: ++index + ExprResult IncrExpr = + SemaRef.BuildUnaryOp(CurScope, SourceLocation(), UO_PreInc, MakeIVRef()); + if (!IncrExpr.isUsable()) + return StmtError(); + + // 6. Build the Fused Loop Body + // The final fused loop iterates over the maximum logical range. Inside the + // loop, each original loop's index is calculated dynamically, and its body + // is executed conditionally. 
+ // + // Each sub-loop's body is guarded by a conditional statement to ensure + // it executes only within its logical iteration range: + // + // if (fuse.index < ni_k){ + // iv_k = lb_k + st_k * fuse.index; + // original.index = iv_k + // body(k); + // } + + CompoundStmt *FusedBody = nullptr; + SmallVector<Stmt *, 4> FusedBodyStmts; + for (unsigned I = FirstVal - 1, J = 0; I < LastVal; ++I, ++J) { + // Assingment of the original sub-loop index to compute the logical index + // IV_k = LB_k + omp.fuse.index * ST_k + ExprResult IdxExpr = + SemaRef.BuildBinOp(CurScope, SourceLocation(), BO_Mul, + MakeVarDeclRef(STVarDecls[J]), MakeIVRef()); + if (!IdxExpr.isUsable()) + return StmtError(); + IdxExpr = SemaRef.BuildBinOp(CurScope, SourceLocation(), BO_Add, + MakeVarDeclRef(LBVarDecls[J]), IdxExpr.get()); + + if (!IdxExpr.isUsable()) + return StmtError(); + IdxExpr = SemaRef.BuildBinOp(CurScope, SourceLocation(), BO_Assign, + MakeVarDeclRef(IVVarDecls[J]), IdxExpr.get()); + if (!IdxExpr.isUsable()) + return StmtError(); + + // Update the original i_k = IV_k + SmallVector<Stmt *, 4> BodyStmts; + BodyStmts.push_back(IdxExpr.get()); + llvm::append_range(BodyStmts, LoopHelpers[I].Updates); + + // If the loop is a CXXForRangeStmt then the iterator variable is needed + if (auto *SourceCXXFor = dyn_cast<CXXForRangeStmt>(LoopStmts[I])) + BodyStmts.push_back(SourceCXXFor->getLoopVarStmt()); + + Stmt *Body = (isa<ForStmt>(LoopStmts[I])) + ? 
cast<ForStmt>(LoopStmts[I])->getBody() + : cast<CXXForRangeStmt>(LoopStmts[I])->getBody(); + BodyStmts.push_back(Body); + + CompoundStmt *CombinedBody = + CompoundStmt::Create(Context, BodyStmts, FPOptionsOverride(), + SourceLocation(), SourceLocation()); + ExprResult Condition = + SemaRef.BuildBinOp(CurScope, SourceLocation(), BO_LT, MakeIVRef(), + MakeVarDeclRef(NIVarDecls[J])); + + if (!Condition.isUsable()) + return StmtError(); + + IfStmt *IfStatement = IfStmt::Create( + Context, SourceLocation(), IfStatementKind::Ordinary, nullptr, nullptr, + Condition.get(), SourceLocation(), SourceLocation(), CombinedBody, + SourceLocation(), nullptr); + + FusedBodyStmts.push_back(IfStatement); + } + FusedBody = CompoundStmt::Create(Context, FusedBodyStmts, FPOptionsOverride(), + SourceLocation(), SourceLocation()); + + // 7. Construct the final fused loop + ForStmt *FusedForStmt = new (Context) + ForStmt(Context, InitStmt.get(), CondExpr.get(), nullptr, IncrExpr.get(), + FusedBody, InitStmt.get()->getBeginLoc(), SourceLocation(), + IncrExpr.get()->getEndLoc()); + + // In the case of looprange, the result of fuse won't simply + // be a single loop (ForStmt), but rather a loop sequence + // (CompoundStmt) of 3 parts: the pre-fusion loops, the fused loop + // and the post-fusion loops, preserving its original order. + // + // Note: If looprange clause produces a single fused loop nest then + // this compound statement wrapper is unnecessary (Therefore this + // treatment is skipped) + + Stmt *FusionStmt = FusedForStmt; + if (LRC && CountVal != LoopSeqSize) { + SmallVector<Stmt *, 4> FinalLoops; + // Reset the transform index + TransformIndex = 0; + + // Collect all non-fused loops before and after the fused region. + // Pre-fusion and post-fusion loops are inserted in order exploiting their + // symmetry, along with their corresponding transformation pre-inits if + // needed. The fused loop is added between the two regions. 
+ for (unsigned I = 0; I < LoopSeqSize; ++I) { + if (I >= FirstVal - 1 && I < FirstVal + CountVal - 1) { + // Update the Transformation counter to skip already treated + // loop transformations + if (LoopCategories[I] != OMPLoopCategory::TransformSingleLoop) + ++TransformIndex; + continue; + } + + // No need to handle: + // Regular loops: they are kept intact as-is. + // Loop-sequence-generating transformations: already handled earlier. + // Only TransformSingleLoop requires inserting pre-inits here + + if (LoopCategories[I] == OMPLoopCategory::TransformSingleLoop) { + auto TransformPreInit = TransformsPreInits[TransformIndex++]; ---------------- alexey-bataev wrote:
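Better to use `ArrayRef<Stmt *>` here instead of `auto`: `auto TransformPreInit = TransformsPreInits[TransformIndex++];` copies the inner `SmallVector<Stmt *>`, whereas an `ArrayRef` is a cheap non-owning view of it. https://github.com/llvm/llvm-project/pull/139293 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits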