================
@@ -15907,6 +15916,209 @@ StmtResult 
SemaOpenMP::ActOnOpenMPReverseDirective(Stmt *AStmt,
                                      buildPreInits(Context, PreInits));
 }
 
+/// Build the AST for \#pragma omp split counts(c1, c2, ...).
+///
+/// Splits the single associated loop into N consecutive loops, where N is the
+/// number of count expressions.
+StmtResult SemaOpenMP::ActOnOpenMPSplitDirective(ArrayRef<OMPClause *> Clauses,
+                                                 Stmt *AStmt,
+                                                 SourceLocation StartLoc,
+                                                 SourceLocation EndLoc) {
+  ASTContext &Context = getASTContext();
+  Scope *CurScope = SemaRef.getCurScope();
+
+  // Empty statement should only be possible if there already was an error.
+  if (!AStmt)
+    return StmtError();
+
+  const OMPCountsClause *CountsClause =
+      OMPExecutableDirective::getSingleClause<OMPCountsClause>(Clauses);
+  if (!CountsClause)
+    return StmtError();
+
+  // Split applies to a single loop; check it is transformable and get helpers.
+  constexpr unsigned NumLoops = 1;
+  Stmt *Body = nullptr;
+  SmallVector<OMPLoopBasedDirective::HelperExprs, NumLoops> LoopHelpers(
+      NumLoops);
+  SmallVector<SmallVector<Stmt *>, NumLoops + 1> OriginalInits;
+  if (!checkTransformableLoopNest(OMPD_split, AStmt, NumLoops, LoopHelpers,
+                                  Body, OriginalInits))
+    return StmtError();
+
+  // Delay applying the transformation to when template is completely
+  // instantiated.
+  if (SemaRef.CurContext->isDependentContext())
+    return OMPSplitDirective::Create(Context, StartLoc, EndLoc, Clauses,
+                                     NumLoops, AStmt, nullptr, nullptr);
+
+  assert(LoopHelpers.size() == NumLoops &&
+         "Expecting a single-dimensional loop iteration space");
+  assert(OriginalInits.size() == NumLoops &&
+         "Expecting a single-dimensional loop iteration space");
+  OMPLoopBasedDirective::HelperExprs &LoopHelper = LoopHelpers.front();
+
+  // Find the loop statement.
+  Stmt *LoopStmt = nullptr;
+  collectLoopStmts(AStmt, {LoopStmt});
+
+  // Determine the PreInit declarations.
+  SmallVector<Stmt *> PreInits;
+  addLoopPreInits(Context, LoopHelper, LoopStmt, OriginalInits[0], PreInits);
+
+  // Type and name of the original loop variable; we create one IV per segment
+  // and assign it to the original var so the body sees the same name.
+  auto *IterationVarRef = cast<DeclRefExpr>(LoopHelper.IterationVarRef);
+  QualType IVTy = IterationVarRef->getType();
+  uint64_t IVWidth = Context.getTypeSize(IVTy);
+  auto *OrigVar = cast<DeclRefExpr>(LoopHelper.Counters.front());
+
+  // Iteration variable SourceLocations.
+  SourceLocation OrigVarLoc = OrigVar->getExprLoc();
+  SourceLocation OrigVarLocBegin = OrigVar->getBeginLoc();
+  SourceLocation OrigVarLocEnd = OrigVar->getEndLoc();
+  // Internal variable names.
+  std::string OrigVarName = OrigVar->getNameInfo().getAsString();
+
+  enum class SplitCountKind { Constant, Fill };
+  SmallVector<std::pair<SplitCountKind, uint64_t>, 4> Entries;
+  for (Expr *CountExpr : CountsClause->getCountsRefs()) {
+    if (!CountExpr)
+      return OMPSplitDirective::Create(Context, StartLoc, EndLoc, Clauses,
+                                       NumLoops, AStmt, nullptr, nullptr);
+    if (isOMPFillCountExpr(CountExpr)) {
+      Entries.push_back({SplitCountKind::Fill, 0});
+      continue;
+    }
+    std::optional<llvm::APSInt> OptVal =
+        CountExpr->getIntegerConstantExpr(Context);
+    if (!OptVal || OptVal->isNegative())
+      return OMPSplitDirective::Create(Context, StartLoc, EndLoc, Clauses,
+                                       NumLoops, AStmt, nullptr, nullptr);
+    Entries.push_back({SplitCountKind::Constant, OptVal->getZExtValue()});
+  }
+
+  if (Entries.empty())
+    return StmtError();
+
+  unsigned NumFill = 0;
+  unsigned FillPos = 0;
+  for (unsigned I = 0; I < Entries.size(); ++I) {
+    if (Entries[I].first == SplitCountKind::Fill) {
+      ++NumFill;
+      FillPos = I;
+    }
+  }
+  if (NumFill > 1) {
+    Diag(CountsClause->getBeginLoc(),
+         diag::err_omp_split_counts_multiple_omp_fill);
+    return StmtError();
+  }
+  if (NumFill == 1 && FillPos != Entries.size() - 1) {
+    Diag(CountsClause->getBeginLoc(),
+         diag::err_omp_split_counts_omp_fill_not_last);
+    return StmtError();
+  }
+
+  Expr *NumIterExpr = LoopHelper.NumIterations;
+  if (NumFill == 1 && !NumIterExpr) {
+    Diag(CountsClause->getBeginLoc(),
+         diag::err_omp_split_counts_omp_fill_no_trip);
+    return StmtError();
----------------
amitamd7 wrote:

Yes, this is already handled by `checkTransformableLoopNest(...)`. I will
address the second point in the upcoming tests.

https://github.com/llvm/llvm-project/pull/183261
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to