ggeorgakoudis updated this revision to Diff 321375.
ggeorgakoudis added a comment.

Fix type for IfCond, formatting


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D95976/new/

https://reviews.llvm.org/D95976

Files:
  clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp
  llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
  openmp/libomptarget/deviceRTLs/common/src/omptarget.cu
  openmp/libomptarget/deviceRTLs/common/src/parallel.cu
  openmp/libomptarget/deviceRTLs/common/src/support.cu
  openmp/libomptarget/deviceRTLs/common/support.h
  openmp/libomptarget/deviceRTLs/interface.h

Index: openmp/libomptarget/deviceRTLs/interface.h
===================================================================
--- openmp/libomptarget/deviceRTLs/interface.h
+++ openmp/libomptarget/deviceRTLs/interface.h
@@ -177,6 +177,7 @@
  * The struct is identical to the one in the kmp.h file.
  * We maintain the same data structure for compatibility.
  */
+typedef short kmp_int16;
 typedef int kmp_int32;
 typedef struct ident {
   kmp_int32 reserved_1; /**<  might be used in Fortran; see above  */
@@ -437,6 +438,22 @@
 EXTERN void __kmpc_end_sharing_variables();
 EXTERN void __kmpc_get_shared_variables(void ***GlobalArgs);
 
+/// Entry point to start a new parallel region.
+///
+/// \param ident       The source identifier.
+/// \param global_tid  The global thread ID.
+/// \param if_expr     The if(expr) value, or 1 if none given; 0 serializes.
+/// \param num_threads The num_threads(expr), or -1 if none given.
+/// \param proc_bind   The proc_bind kind; currently unused (clang passes -1).
+/// \param fn          The outlined parallel region function.
+/// \param wrapper_fn  The worker wrapper function of fn, or null if none.
+/// \param args        The pointer array of arguments to fn.
+/// \param nargs       The number of arguments to fn (at most 16).
+EXTERN void __kmpc_parallel_51(ident_t *ident, kmp_int32 global_tid,
+                               kmp_int32 if_expr, kmp_int32 num_threads,
+                               int proc_bind, void *fn, void *wrapper_fn,
+                               void **args, size_t nargs);
+
 // SPMD execution mode interrogation function.
 EXTERN int8_t __kmpc_is_spmd_exec_mode();

Index: openmp/libomptarget/deviceRTLs/common/support.h
===================================================================
--- openmp/libomptarget/deviceRTLs/common/support.h
+++ openmp/libomptarget/deviceRTLs/common/support.h
@@ -95,4 +95,9 @@
 DEVICE unsigned int *GetTeamsReductionTimestamp();
 DEVICE char *GetTeamsReductionScratchpad();
 
+// Call fn(&global_tid, &bound_tid, args[0], ..., args[nargs - 1]); only
+// nargs <= 16 is supported.
+DEVICE void __kmp_invoke_microtask(kmp_int32 global_tid, kmp_int32 bound_tid,
+                                   void *fn, void **args, size_t nargs);
+
 #endif
Index: openmp/libomptarget/deviceRTLs/common/src/support.cu
===================================================================
--- openmp/libomptarget/deviceRTLs/common/src/support.cu
+++ openmp/libomptarget/deviceRTLs/common/src/support.cu
@@ -265,4 +265,110 @@
   return static_cast<char *>(ReductionScratchpadPtr) + 256;
 }
 
+// Invoke outlined function fn(&global_tid, &bound_tid, args[0..nargs-1]).
+// Only nargs <= 16 is supported; any larger nargs traps.
+DEVICE void __kmp_invoke_microtask(kmp_int32 global_tid, kmp_int32 bound_tid,
+                                   void *fn, void **args, size_t nargs) {
+  switch (nargs) {
+  case 0:
+    ((void (*)(kmp_int32 *, kmp_int32 *))fn)(&global_tid, &bound_tid);
+    break;
+  case 1:
+    ((void (*)(kmp_int32 *, kmp_int32 *, void *))fn)(&global_tid, &bound_tid,
+                                                     args[0]);
+    break;
+  case 2:
+    ((void (*)(kmp_int32 *, kmp_int32 *, void *, void *))fn)(
+        &global_tid, &bound_tid, args[0], args[1]);
+    break;
+  case 3:
+    ((void (*)(kmp_int32 *, kmp_int32 *, void *, void *, void *))fn)(
+        &global_tid, &bound_tid, args[0], args[1], args[2]);
+    break;
+  case 4:
+    ((void (*)(kmp_int32 *, kmp_int32 *, void *, void *, void *, void *))fn)(
+        &global_tid, &bound_tid, args[0], args[1], args[2], args[3]);
+    break;
+  case 5:
+    ((void (*)(kmp_int32 *, kmp_int32 *, void *, void *, void *, void *,
+               void *))fn)(&global_tid, &bound_tid, args[0], args[1], args[2],
+                           args[3], args[4]);
+    break;
+  case 6:
+    ((void (*)(kmp_int32 *, kmp_int32 *, void *, void *, void *, void *, void *,
+               void *))fn)(&global_tid, &bound_tid, args[0], args[1], args[2],
+                           args[3], args[4], args[5]);
+    break;
+  case 7:
+    ((void (*)(kmp_int32 *, kmp_int32 *, void *, void *, void *, void *, void *,
+               void *, void *))fn)(&global_tid, &bound_tid, args[0], args[1],
+                                   args[2], args[3], args[4], args[5], args[6]);
+    break;
+  case 8:
+    ((void (*)(kmp_int32 *, kmp_int32 *, void *, void *, void *, void *, void *,
+               void *, void *, void *))fn)(&global_tid, &bound_tid, args[0],
+                                           args[1], args[2], args[3], args[4],
+                                           args[5], args[6], args[7]);
+    break;
+  case 9:
+    ((void (*)(kmp_int32 *, kmp_int32 *, void *, void *, void *, void *, void *,
+               void *, void *, void *, void *))fn)(
+        &global_tid, &bound_tid, args[0], args[1], args[2], args[3], args[4],
+        args[5], args[6], args[7], args[8]);
+    break;
+  case 10:
+    ((void (*)(kmp_int32 *, kmp_int32 *, void *, void *, void *, void *, void *,
+               void *, void *, void *, void *, void *))fn)(
+        &global_tid, &bound_tid, args[0], args[1], args[2], args[3], args[4],
+        args[5], args[6], args[7], args[8], args[9]);
+    break;
+  case 11:
+    ((void (*)(kmp_int32 *, kmp_int32 *, void *, void *, void *, void *, void *,
+               void *, void *, void *, void *, void *, void *))fn)(
+        &global_tid, &bound_tid, args[0], args[1], args[2], args[3], args[4],
+        args[5], args[6], args[7], args[8], args[9], args[10]);
+    break;
+  case 12:
+    ((void (*)(kmp_int32 *, kmp_int32 *, void *, void *, void *, void *, void *,
+               void *, void *, void *, void *, void *, void *, void *))fn)(
+        &global_tid, &bound_tid, args[0], args[1], args[2], args[3], args[4],
+        args[5], args[6], args[7], args[8], args[9], args[10], args[11]);
+    break;
+  case 13:
+    ((void (*)(kmp_int32 *, kmp_int32 *, void *, void *, void *, void *, void *,
+               void *, void *, void *, void *, void *, void *, void *,
+               void *))fn)(&global_tid, &bound_tid, args[0], args[1], args[2],
+                           args[3], args[4], args[5], args[6], args[7], args[8],
+                           args[9], args[10], args[11], args[12]);
+    break;
+  case 14:
+    ((void (*)(kmp_int32 *, kmp_int32 *, void *, void *, void *, void *, void *,
+               void *, void *, void *, void *, void *, void *, void *, void *,
+               void *))fn)(&global_tid, &bound_tid, args[0], args[1], args[2],
+                           args[3], args[4], args[5], args[6], args[7], args[8],
+                           args[9], args[10], args[11], args[12], args[13]);
+    break;
+  case 15:
+    ((void (*)(kmp_int32 *, kmp_int32 *, void *, void *, void *, void *, void *,
+               void *, void *, void *, void *, void *, void *, void *, void *,
+               void *, void *))fn)(&global_tid, &bound_tid, args[0], args[1],
+                                   args[2], args[3], args[4], args[5], args[6],
+                                   args[7], args[8], args[9], args[10],
+                                   args[11], args[12], args[13], args[14]);
+    break;
+  case 16:
+    ((void (*)(kmp_int32 *, kmp_int32 *, void *, void *, void *, void *, void *,
+               void *, void *, void *, void *, void *, void *, void *, void *,
+               void *, void *, void *))fn)(
+        &global_tid, &bound_tid, args[0], args[1], args[2], args[3], args[4],
+        args[5], args[6], args[7], args[8], args[9], args[10], args[11],
+        args[12], args[13], args[14], args[15]);
+    break;
+  default:
+    // Unsupported arity: trap instead of silently skipping the region body.
+    printf("Too many arguments in kmp_invoke_microtask, aborting execution.\n");
+    __builtin_trap();
+  }
+}
+
 #pragma omp end declare target
Index: openmp/libomptarget/deviceRTLs/common/src/parallel.cu
===================================================================
--- openmp/libomptarget/deviceRTLs/common/src/parallel.cu
+++ openmp/libomptarget/deviceRTLs/common/src/parallel.cu
@@ -154,16 +154,6 @@
           (int)newTaskDescr->ThreadId(), (int)nThreads);
 
     isActive = true;
-    // Reconverge the threads at the end of the parallel region to correctly
-    // handle parallel levels.
-    // In Cuda9+ in non-SPMD mode we have either 1 worker thread or the whole
-    // warp. If only 1 thread is active, not need to reconverge the threads.
-    // If we have the whole warp, reconverge all the threads in the warp before
-    // actually trying to change the parallel level. Otherwise, parallel level
-    // can be changed incorrectly because of threads divergence.
-    bool IsActiveParallelRegion = threadsInTeam != 1;
-    IncParallelLevel(IsActiveParallelRegion,
-                     IsActiveParallelRegion ? __kmpc_impl_all_lanes : 1u);
   }
 
   return isActive;
@@ -180,17 +170,6 @@
   omptarget_nvptx_TaskDescr *currTaskDescr = getMyTopTaskDescriptor(threadId);
   omptarget_nvptx_threadPrivateContext->SetTopLevelTaskDescr(
       threadId, currTaskDescr->GetPrevTaskDescr());
-
-  // Reconverge the threads at the end of the parallel region to correctly
-  // handle parallel levels.
-  // In Cuda9+ in non-SPMD mode we have either 1 worker thread or the whole
-  // warp. If only 1 thread is active, not need to reconverge the threads.
-  // If we have the whole warp, reconverge all the threads in the warp before
-  // actually trying to change the parallel level. Otherwise, parallel level can
-  // be changed incorrectly because of threads divergence.
-    bool IsActiveParallelRegion = threadsInTeam != 1;
-    DecParallelLevel(IsActiveParallelRegion,
-                     IsActiveParallelRegion ? __kmpc_impl_all_lanes : 1u);
 }
 
 ////////////////////////////////////////////////////////////////////////////////
@@ -302,4 +281,94 @@
   PRINT(LD_IO, "call kmpc_push_proc_bind %d\n", (int)proc_bind);
 }
 
+////////////////////////////////////////////////////////////////////////////////
+// parallel interface
+////////////////////////////////////////////////////////////////////////////////
+
+// Run a parallel region: serialized when if_expr is 0; in SPMD mode each
+// calling thread invokes fn directly; otherwise the master prepares
+// wrapper_fn, shares args, and releases/awaits the workers via two barriers.
+EXTERN void __kmpc_parallel_51(kmp_Ident *ident, kmp_int32 global_tid,
+                               kmp_int32 if_expr, kmp_int32 num_threads,
+                               int proc_bind, void *fn, void *wrapper_fn,
+                               void **args, size_t nargs) {
+
+  // Handle the serialized case first, same for SPMD/non-SPMD.
+  // TODO: Add UNLIKELY to optimize?
+  if (!if_expr) {
+    __kmpc_serialized_parallel(ident, global_tid);
+    __kmp_invoke_microtask(global_tid, 0, fn, args, nargs);
+    __kmpc_end_serialized_parallel(ident, global_tid);
+
+    return;
+  }
+
+  if (__kmpc_is_spmd_exec_mode()) {
+    // Increment parallel level for SPMD warps.
+    if (GetThreadIdInBlock() == 0)
+      parallelLevel[0] =
+          1 + (GetNumberOfThreadsInBlock() > 1 ? OMP_ACTIVE_PARALLEL_LEVEL : 0);
+    else if (GetLaneId() == 0)
+      parallelLevel[GetWarpId()] =
+          1 + (GetNumberOfThreadsInBlock() > 1 ? OMP_ACTIVE_PARALLEL_LEVEL : 0);
+    // TODO: Is that synchronization correct/needed? Can only using a memory
+    // fence ensure consistency?
+    __kmpc_impl_syncthreads();
+
+    __kmp_invoke_microtask(global_tid, 0, fn, args, nargs);
+
+    // TODO: is decrementing parallel level needed? parallelLevel will reset to
+    // the next SPMD/non-SPMD parallel region execution, existing implementation
+    // does not decrement?
+    // parallelLevel[GetWarpId()] = 0;
+    return;
+  }
+
+  // Handle the num_threads clause.
+  if (num_threads != -1)
+    __kmpc_push_num_threads(ident, global_tid, num_threads);
+
+  __kmpc_kernel_prepare_parallel((void *)wrapper_fn);
+
+  if (nargs) {
+    void **GlobalArgs;
+    __kmpc_begin_sharing_variables(&GlobalArgs, nargs);
+    // TODO: faster memcpy?
+    for (size_t I = 0; I < nargs; ++I)
+      GlobalArgs[I] = args[I];
+  }
+
+  // A single-thread team does not count as an active parallel region.
+  bool IsActiveParallelRegion = threadsInTeam != 1;
+  int LevelInc = 1 + (IsActiveParallelRegion ? OMP_ACTIVE_PARALLEL_LEVEL : 0);
+  // Round up so a partial trailing warp (e.g. a 1-thread team) is counted.
+  int NumWarps = (threadsInTeam + WARPSIZE - 1) / WARPSIZE;
+  for (int I = 0; I < NumWarps; ++I) // Increment level for non-SPMD warps.
+    parallelLevel[I] += LevelInc;
+
+  // Master signals work to activate workers.
+  __kmpc_barrier_simple_spmd(nullptr, 0);
+
+  // OpenMP [2.5, Parallel Construct, p.49]
+  // There is an implied barrier at the end of a parallel region. After the
+  // end of a parallel region, only the master thread of the team resumes
+  // execution of the enclosing task region.
+  //
+  // The master waits at this barrier until all workers are done.
+  __kmpc_barrier_simple_spmd(nullptr, 0);
+
+  // Decrement parallel level for non-SPMD warps, undoing exactly the
+  // increment applied before the workers were released.
+  for (int I = 0; I < NumWarps; ++I)
+    parallelLevel[I] -= LevelInc;
+  // TODO: Is synchronization needed since out of parallel execution?
+
+  if (nargs)
+    __kmpc_end_sharing_variables();
+
+  // TODO: proc_bind is a noop?
+  // if (proc_bind != proc_bind_default)
+  //  __kmpc_push_proc_bind(ident, global_tid, proc_bind);
+}
+
 #pragma omp end declare target
Index: openmp/libomptarget/deviceRTLs/common/src/omptarget.cu
===================================================================
--- openmp/libomptarget/deviceRTLs/common/src/omptarget.cu
+++ openmp/libomptarget/deviceRTLs/common/src/omptarget.cu
@@ -87,11 +87,6 @@
   int threadId = GetThreadIdInBlock();
   if (threadId == 0) {
     usedSlotIdx = __kmpc_impl_smid() % MAX_SM;
-    parallelLevel[0] =
-        1 + (GetNumberOfThreadsInBlock() > 1 ? OMP_ACTIVE_PARALLEL_LEVEL : 0);
-  } else if (GetLaneId() == 0) {
-    parallelLevel[GetWarpId()] =
-        1 + (GetNumberOfThreadsInBlock() > 1 ? OMP_ACTIVE_PARALLEL_LEVEL : 0);
   }
   if (!RequiresOMPRuntime) {
     // Runtime is not required - exit.
Index: llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
===================================================================
--- llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
+++ llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
@@ -412,6 +412,8 @@
 __OMP_RTL(__kmpc_spmd_kernel_init, false, Void, Int32, Int16)
 __OMP_RTL(__kmpc_spmd_kernel_deinit_v2, false, Void, Int16)
 __OMP_RTL(__kmpc_kernel_prepare_parallel, false, Void, VoidPtr)
+__OMP_RTL(__kmpc_parallel_51, false, Void, IdentPtr, Int32, Int32, Int32, Int32,
+          VoidPtr, VoidPtr, VoidPtrPtr, SizeTy) // fn, wrapper_fn, args, nargs
 __OMP_RTL(__kmpc_kernel_parallel, false, Int1, VoidPtrPtr)
 __OMP_RTL(__kmpc_kernel_end_parallel, false, Void, )
 __OMP_RTL(__kmpc_serialized_parallel, false, Void, IdentPtr, Int32)
Index: clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp
===================================================================
--- clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp
+++ clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp
@@ -2072,56 +2072,45 @@
   emitOutlinedFunctionCall(CGF, Loc, OutlinedFn, OutlinedFnArgs);
 }
 
-void CGOpenMPRuntimeGPU::emitParallelCall(
-    CodeGenFunction &CGF, SourceLocation Loc, llvm::Function *OutlinedFn,
-    ArrayRef<llvm::Value *> CapturedVars, const Expr *IfCond) {
+void CGOpenMPRuntimeGPU::emitParallelCall(CodeGenFunction &CGF,
+                                          SourceLocation Loc,
+                                          llvm::Function *OutlinedFn,
+                                          ArrayRef<llvm::Value *> CapturedVars,
+                                          const Expr *IfCond) {
   if (!CGF.HaveInsertPoint())
     return;
 
-  if (getExecutionMode() == CGOpenMPRuntimeGPU::EM_SPMD)
-    emitSPMDParallelCall(CGF, Loc, OutlinedFn, CapturedVars, IfCond);
-  else
-    emitNonSPMDParallelCall(CGF, Loc, OutlinedFn, CapturedVars, IfCond);
-}
-
-void CGOpenMPRuntimeGPU::emitNonSPMDParallelCall(
-    CodeGenFunction &CGF, SourceLocation Loc, llvm::Value *OutlinedFn,
-    ArrayRef<llvm::Value *> CapturedVars, const Expr *IfCond) {
-  llvm::Function *Fn = cast<llvm::Function>(OutlinedFn);
-
-  // Force inline this outlined function at its call site.
-  Fn->setLinkage(llvm::GlobalValue::InternalLinkage);
+  auto &&CodeGen = [this, OutlinedFn, CapturedVars,
+                    Loc](CodeGenFunction &CGF, PrePostActionTy &Action) {
+    Action.Enter(CGF);
+    llvm::Function *Fn = cast<llvm::Function>(OutlinedFn);
 
-  // Ensure we do not inline the function. This is trivially true for the ones
-  // passed to __kmpc_fork_call but the ones calles in serialized regions
-  // could be inlined. This is not a perfect but it is closer to the invariant
-  // we want, namely, every data environment starts with a new function.
-  // TODO: We should pass the if condition to the runtime function and do the
-  //       handling there. Much cleaner code.
-  cast<llvm::Function>(OutlinedFn)->addFnAttr(llvm::Attribute::NoInline);
+    // Give the outlined function internal linkage (NoInline is added below).
+    Fn->setLinkage(llvm::GlobalValue::InternalLinkage);
 
-  Address ZeroAddr = CGF.CreateDefaultAlignTempAlloca(CGF.Int32Ty,
-                                                      /*Name=*/".zero.addr");
-  CGF.InitTempAlloca(ZeroAddr, CGF.Builder.getInt32(/*C*/ 0));
-  // ThreadId for serialized parallels is 0.
-  Address ThreadIDAddr = ZeroAddr;
-  auto &&CodeGen = [this, Fn, CapturedVars, Loc, &ThreadIDAddr](
-                       CodeGenFunction &CGF, PrePostActionTy &Action) {
-    Action.Enter(CGF);
+    // Ensure we do not inline the function. This is trivially true for the ones
+    // passed to __kmpc_fork_call but the ones called in serialized regions
+    // could be inlined. This is not perfect but it is closer to the invariant
+    // we want, namely, every data environment starts with a new function.
+    // Only SeqGen (used when IsInParallelRegion) reaches this lambda;
+    // ParallelGen passes the if condition to __kmpc_parallel_51 instead.
+    cast<llvm::Function>(OutlinedFn)->addFnAttr(llvm::Attribute::NoInline);
 
-    Address ZeroAddr =
-        CGF.CreateDefaultAlignTempAlloca(CGF.Int32Ty,
-                                         /*Name=*/".bound.zero.addr");
+    Address ZeroAddr = CGF.CreateDefaultAlignTempAlloca(CGF.Int32Ty,
+                                                        /*Name=*/".zero.addr");
     CGF.InitTempAlloca(ZeroAddr, CGF.Builder.getInt32(/*C*/ 0));
+    // ThreadId for serialized parallels is 0.
+    Address ThreadIDAddr = ZeroAddr;
+
     llvm::SmallVector<llvm::Value *, 16> OutlinedFnArgs;
     OutlinedFnArgs.push_back(ThreadIDAddr.getPointer());
     OutlinedFnArgs.push_back(ZeroAddr.getPointer());
     OutlinedFnArgs.append(CapturedVars.begin(), CapturedVars.end());
     emitOutlinedFunctionCall(CGF, Loc, Fn, OutlinedFnArgs);
   };
+
   auto &&SeqGen = [this, &CodeGen, Loc](CodeGenFunction &CGF,
                                         PrePostActionTy &) {
-
     RegionCodeGenTy RCG(CodeGen);
     llvm::Value *RTLoc = emitUpdateLocation(CGF, Loc);
     llvm::Value *ThreadID = getThreadID(CGF, Loc);
@@ -2138,47 +2127,33 @@
     RCG(CGF);
   };
 
-  auto &&L0ParallelGen = [this, CapturedVars, Fn](CodeGenFunction &CGF,
-                                                  PrePostActionTy &Action) {
+  auto &&ParallelGen = [this, Loc, OutlinedFn, CapturedVars,
+                        IfCond](CodeGenFunction &CGF, PrePostActionTy &Action) {
     CGBuilderTy &Bld = CGF.Builder;
-    llvm::Function *WFn = WrapperFunctionsMap[Fn];
-    assert(WFn && "Wrapper function does not exist!");
-    llvm::Value *ID = Bld.CreateBitOrPointerCast(WFn, CGM.Int8PtrTy);
-
-    // Prepare for parallel region. Indicate the outlined function.
-    llvm::Value *Args[] = {ID};
-    CGF.EmitRuntimeCall(
-        OMPBuilder.getOrCreateRuntimeFunction(
-            CGM.getModule(), OMPRTL___kmpc_kernel_prepare_parallel),
-        Args);
+    llvm::Function *WFn = WrapperFunctionsMap[OutlinedFn];
+    llvm::Value *ID = llvm::ConstantPointerNull::get(CGM.Int8PtrTy);
+    if (WFn) {
+      ID = Bld.CreateBitOrPointerCast(WFn, CGM.Int8PtrTy);
+      // Remember for post-processing in worker loop.
+      Work.emplace_back(WFn);
+    }
+    llvm::Value *FnPtr = Bld.CreateBitOrPointerCast(OutlinedFn, CGM.Int8PtrTy);
 
     // Create a private scope that will globalize the arguments
     // passed from the outside of the target region.
+    // TODO: Is that needed?
     CodeGenFunction::OMPPrivateScope PrivateArgScope(CGF);
 
+    Address CapturedVarsAddrs = CGF.CreateDefaultAlignTempAlloca(
+        llvm::ArrayType::get(CGM.VoidPtrTy, CapturedVars.size()),
+        "captured_vars_addrs");
     // There's something to share.
     if (!CapturedVars.empty()) {
       // Prepare for parallel region. Indicate the outlined function.
-      Address SharedArgs =
-          CGF.CreateDefaultAlignTempAlloca(CGF.VoidPtrPtrTy, "shared_arg_refs");
-      llvm::Value *SharedArgsPtr = SharedArgs.getPointer();
-
-      llvm::Value *DataSharingArgs[] = {
-          SharedArgsPtr,
-          llvm::ConstantInt::get(CGM.SizeTy, CapturedVars.size())};
-      CGF.EmitRuntimeCall(
-          OMPBuilder.getOrCreateRuntimeFunction(
-              CGM.getModule(), OMPRTL___kmpc_begin_sharing_variables),
-          DataSharingArgs);
-
-      // Store variable address in a list of references to pass to workers.
-      unsigned Idx = 0;
       ASTContext &Ctx = CGF.getContext();
-      Address SharedArgListAddress = CGF.EmitLoadOfPointer(
-          SharedArgs, Ctx.getPointerType(Ctx.getPointerType(Ctx.VoidPtrTy))
-                          .castAs<PointerType>());
+      unsigned Idx = 0;
       for (llvm::Value *V : CapturedVars) {
-        Address Dst = Bld.CreateConstInBoundsGEP(SharedArgListAddress, Idx);
+        Address Dst = Bld.CreateConstArrayGEP(CapturedVarsAddrs, Idx);
         llvm::Value *PtrV;
         if (V->getType()->isIntegerTy())
           PtrV = Bld.CreateIntToPtr(V, CGF.VoidPtrTy);
@@ -2190,139 +2165,36 @@
       }
     }
 
-    // Activate workers. This barrier is used by the master to signal
-    // work for the workers.
-    syncCTAThreads(CGF);
-
-    // OpenMP [2.5, Parallel Construct, p.49]
-    // There is an implied barrier at the end of a parallel region. After the
-    // end of a parallel region, only the master thread of the team resumes
-    // execution of the enclosing task region.
-    //
-    // The master waits at this barrier until all workers are done.
-    syncCTAThreads(CGF);
-
-    if (!CapturedVars.empty())
-      CGF.EmitRuntimeCall(OMPBuilder.getOrCreateRuntimeFunction(
-          CGM.getModule(), OMPRTL___kmpc_end_sharing_variables));
-
-    // Remember for post-processing in worker loop.
-    Work.emplace_back(WFn);
-  };
-
-  auto &&LNParallelGen = [this, Loc, &SeqGen, &L0ParallelGen](
-                             CodeGenFunction &CGF, PrePostActionTy &Action) {
-    if (IsInParallelRegion) {
-      SeqGen(CGF, Action);
-    } else if (IsInTargetMasterThreadRegion) {
-      L0ParallelGen(CGF, Action);
-    } else {
-      // Check for master and then parallelism:
-      // if (__kmpc_is_spmd_exec_mode() || __kmpc_parallel_level(loc, gtid)) {
-      //   Serialized execution.
-      // } else {
-      //   Worker call.
-      // }
-      CGBuilderTy &Bld = CGF.Builder;
-      llvm::BasicBlock *ExitBB = CGF.createBasicBlock(".exit");
-      llvm::BasicBlock *SeqBB = CGF.createBasicBlock(".sequential");
-      llvm::BasicBlock *ParallelCheckBB = CGF.createBasicBlock(".parcheck");
-      llvm::BasicBlock *MasterBB = CGF.createBasicBlock(".master");
-      llvm::Value *IsSPMD = Bld.CreateIsNotNull(
-          CGF.EmitNounwindRuntimeCall(OMPBuilder.getOrCreateRuntimeFunction(
-              CGM.getModule(), OMPRTL___kmpc_is_spmd_exec_mode)));
-      Bld.CreateCondBr(IsSPMD, SeqBB, ParallelCheckBB);
-      // There is no need to emit line number for unconditional branch.
-      (void)ApplyDebugLocation::CreateEmpty(CGF);
-      CGF.EmitBlock(ParallelCheckBB);
-      llvm::Value *RTLoc = emitUpdateLocation(CGF, Loc);
-      llvm::Value *ThreadID = getThreadID(CGF, Loc);
-      llvm::Value *PL = CGF.EmitRuntimeCall(
-          OMPBuilder.getOrCreateRuntimeFunction(CGM.getModule(),
-                                                OMPRTL___kmpc_parallel_level),
-          {RTLoc, ThreadID});
-      llvm::Value *Res = Bld.CreateIsNotNull(PL);
-      Bld.CreateCondBr(Res, SeqBB, MasterBB);
-      CGF.EmitBlock(SeqBB);
-      SeqGen(CGF, Action);
-      CGF.EmitBranch(ExitBB);
-      // There is no need to emit line number for unconditional branch.
-      (void)ApplyDebugLocation::CreateEmpty(CGF);
-      CGF.EmitBlock(MasterBB);
-      L0ParallelGen(CGF, Action);
-      CGF.EmitBranch(ExitBB);
-      // There is no need to emit line number for unconditional branch.
-      (void)ApplyDebugLocation::CreateEmpty(CGF);
-      // Emit the continuation block for code after the if.
-      CGF.EmitBlock(ExitBB, /*IsFinished=*/true);
-    }
-  };
-
-  if (IfCond) {
-    emitIfClause(CGF, IfCond, LNParallelGen, SeqGen);
-  } else {
-    CodeGenFunction::RunCleanupsScope Scope(CGF);
-    RegionCodeGenTy ThenRCG(LNParallelGen);
-    ThenRCG(CGF);
-  }
-}
-
-void CGOpenMPRuntimeGPU::emitSPMDParallelCall(
-    CodeGenFunction &CGF, SourceLocation Loc, llvm::Function *OutlinedFn,
-    ArrayRef<llvm::Value *> CapturedVars, const Expr *IfCond) {
-  // Just call the outlined function to execute the parallel region.
-  // OutlinedFn(&GTid, &zero, CapturedStruct);
-  //
-  llvm::SmallVector<llvm::Value *, 16> OutlinedFnArgs;
-
-  Address ZeroAddr = CGF.CreateDefaultAlignTempAlloca(CGF.Int32Ty,
-                                                      /*Name=*/".zero.addr");
-  CGF.InitTempAlloca(ZeroAddr, CGF.Builder.getInt32(/*C*/ 0));
-  // ThreadId for serialized parallels is 0.
-  Address ThreadIDAddr = ZeroAddr;
-  auto &&CodeGen = [this, OutlinedFn, CapturedVars, Loc, &ThreadIDAddr](
-                       CodeGenFunction &CGF, PrePostActionTy &Action) {
-    Action.Enter(CGF);
-
-    Address ZeroAddr =
-        CGF.CreateDefaultAlignTempAlloca(CGF.Int32Ty,
-                                         /*Name=*/".bound.zero.addr");
-    CGF.InitTempAlloca(ZeroAddr, CGF.Builder.getInt32(/*C*/ 0));
-    llvm::SmallVector<llvm::Value *, 16> OutlinedFnArgs;
-    OutlinedFnArgs.push_back(ThreadIDAddr.getPointer());
-    OutlinedFnArgs.push_back(ZeroAddr.getPointer());
-    OutlinedFnArgs.append(CapturedVars.begin(), CapturedVars.end());
-    emitOutlinedFunctionCall(CGF, Loc, OutlinedFn, OutlinedFnArgs);
-  };
-  auto &&SeqGen = [this, &CodeGen, Loc](CodeGenFunction &CGF,
-                                        PrePostActionTy &) {
+    llvm::Value *IfCondVal = nullptr;
+    if (IfCond)
+      IfCondVal = Bld.CreateIntCast(CGF.EvaluateExprAsBool(IfCond), CGF.Int32Ty,
+                                    /* isSigned */ false);
+    else
+      IfCondVal = llvm::ConstantInt::get(CGF.Int32Ty, 1);
 
-    RegionCodeGenTy RCG(CodeGen);
+    assert(IfCondVal && "Expected a value");
     llvm::Value *RTLoc = emitUpdateLocation(CGF, Loc);
-    llvm::Value *ThreadID = getThreadID(CGF, Loc);
-    llvm::Value *Args[] = {RTLoc, ThreadID};
-
-    NVPTXActionTy Action(
-        OMPBuilder.getOrCreateRuntimeFunction(
-            CGM.getModule(), OMPRTL___kmpc_serialized_parallel),
-        Args,
-        OMPBuilder.getOrCreateRuntimeFunction(
-            CGM.getModule(), OMPRTL___kmpc_end_serialized_parallel),
-        Args);
-    RCG.setAction(Action);
-    RCG(CGF);
+    llvm::Value *Args[] = {
+        RTLoc,
+        getThreadID(CGF, Loc),
+        IfCondVal,
+        llvm::ConstantInt::get(CGF.Int32Ty, -1), // num_threads
+        llvm::ConstantInt::get(CGF.Int32Ty, -1), // proc_bind
+        FnPtr,                                   // fn
+        ID,                                      // wrapper_fn
+        Bld.CreateBitOrPointerCast(CapturedVarsAddrs.getPointer(),
+                                   CGF.VoidPtrPtrTy),
+        llvm::ConstantInt::get(CGM.SizeTy, CapturedVars.size())};
+    CGF.EmitRuntimeCall(OMPBuilder.getOrCreateRuntimeFunction(
+                            CGM.getModule(), OMPRTL___kmpc_parallel_51),
+                        Args);
   };
 
-  if (IsInTargetMasterThreadRegion) {
-    // In the worker need to use the real thread id.
-    ThreadIDAddr = emitThreadIDAddress(CGF, Loc);
-    RegionCodeGenTy RCG(CodeGen);
+  if (IsInParallelRegion) {
+    RegionCodeGenTy RCG(SeqGen);
     RCG(CGF);
   } else {
-    // If we are not in the target region, it is definitely L2 parallelism or
-    // more, because for SPMD mode we always has L1 parallel level, sowe don't
-    // need to check for orphaned directives.
-    RegionCodeGenTy RCG(SeqGen);
+    RegionCodeGenTy RCG(ParallelGen);
     RCG(CGF);
   }
 }
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
  • [PATCH] D95976: [Open... Giorgis Georgakoudis via Phabricator via cfe-commits
    • [PATCH] D95976: ... Giorgis Georgakoudis via Phabricator via cfe-commits

Reply via email to