jhuber6 updated this revision to Diff 439853.
jhuber6 added a comment.

Update comment.


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D128550/new/

https://reviews.llvm.org/D128550

Files:
  clang/lib/CodeGen/CGOpenMPRuntime.cpp
  llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
  llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
  llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Index: llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
===================================================================
--- llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -791,6 +791,38 @@
   Entry->setAlignment(Align(1));
 }
 
+OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetKernel(
+    const LocationDescription &Loc, Value *&Return, Value *Ident,
+    Value *DeviceID, Value *NumTeams, Value *NumThreads, Value *HostPtr,
+    ArrayRef<Value *> KernelArgs, ArrayRef<Value *> NoWaitArgs) {
+  if (!updateToLocation(Loc))
+    return Loc.IP;
+
+  auto *KernelArgsPtr =
+      Builder.CreateAlloca(OpenMPIRBuilder::KernelArgs, nullptr, "kernel_args");
+  for (unsigned I = 0, Size = KernelArgs.size(); I != Size; ++I) {
+    llvm::Value *Arg =
+        Builder.CreateStructGEP(OpenMPIRBuilder::KernelArgs, KernelArgsPtr, I);
+    Builder.CreateAlignedStore(
+        KernelArgs[I], Arg,
+        M.getDataLayout().getPrefTypeAlign(KernelArgs[I]->getType()));
+  }
+
+  bool HasNoWait = !NoWaitArgs.empty();
+  SmallVector<Value *> OffloadingArgs{Ident,      DeviceID, NumTeams,
+                                      NumThreads, HostPtr,  KernelArgsPtr};
+  if (HasNoWait)
+    OffloadingArgs.append(NoWaitArgs.begin(), NoWaitArgs.end());
+
+  Return = Builder.CreateCall(
+      HasNoWait
+          ? getOrCreateRuntimeFunction(M, OMPRTL___tgt_target_kernel_nowait)
+          : getOrCreateRuntimeFunction(M, OMPRTL___tgt_target_kernel),
+      OffloadingArgs);
+
+  return Builder.saveIP();
+}
+
 void OpenMPIRBuilder::emitCancelationCheckImpl(Value *CancelFlag,
                                                omp::Directive CanceledDirective,
                                                FinalizeCallbackTy ExitCB) {
Index: llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
===================================================================
--- llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
+++ llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
@@ -88,6 +88,8 @@
 __OMP_STRUCT_TYPE(Ident, ident_t, Int32, Int32, Int32, Int32, Int8Ptr)
 __OMP_STRUCT_TYPE(OffloadEntry, __tgt_offload_entry, Int8Ptr, Int8Ptr, SizeTy,
                   Int32, Int32)
+__OMP_STRUCT_TYPE(KernelArgs, __tgt_kernel_arguments, Int32, VoidPtrPtr,
+                  VoidPtrPtr, Int64Ptr, Int64Ptr, VoidPtrPtr, VoidPtrPtr)
 __OMP_STRUCT_TYPE(AsyncInfo, __tgt_async_info, Int8Ptr)
 
 #undef __OMP_STRUCT_TYPE
@@ -412,6 +414,10 @@
 __OMP_RTL(__tgt_target_teams_nowait_mapper, false, Int32, IdentPtr, Int64,
           VoidPtr, Int32, VoidPtrPtr, VoidPtrPtr, Int64Ptr, Int64Ptr,
           VoidPtrPtr, VoidPtrPtr, Int32, Int32, Int32, VoidPtr, Int32, VoidPtr)
+__OMP_RTL(__tgt_target_kernel, false, Int32, IdentPtr, Int64, Int32, Int32,
+          VoidPtr, KernelArgsPtr)
+__OMP_RTL(__tgt_target_kernel_nowait, false, Int32, IdentPtr, Int64, Int32,
+          Int32, VoidPtr, KernelArgsPtr, Int32, VoidPtr, Int32, VoidPtr)
 __OMP_RTL(__tgt_register_requires, false, Void, Int64)
 __OMP_RTL(__tgt_target_data_begin_mapper, false, Void, IdentPtr, Int64, Int32, VoidPtrPtr,
           VoidPtrPtr, Int64Ptr, Int64Ptr, VoidPtrPtr, VoidPtrPtr)
@@ -937,6 +943,10 @@
                 ParamAttrs())
 __OMP_RTL_ATTRS(__tgt_target_teams_nowait_mapper, ForkAttrs, AttributeSet(),
                 ParamAttrs())
+__OMP_RTL_ATTRS(__tgt_target_kernel, ForkAttrs, AttributeSet(),
+                ParamAttrs())
+__OMP_RTL_ATTRS(__tgt_target_kernel_nowait, ForkAttrs, AttributeSet(),
+                ParamAttrs())
 __OMP_RTL_ATTRS(__tgt_register_requires, ForkAttrs, AttributeSet(),
                 ParamAttrs())
 __OMP_RTL_ATTRS(__tgt_target_data_begin_mapper, ForkAttrs, AttributeSet(),
Index: llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
===================================================================
--- llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -821,6 +821,23 @@
                                 omp::Directive CanceledDirective,
                                 FinalizeCallbackTy ExitCB = {});
 
+  /// Generate a target region entry call.
+  ///
+  /// \param Loc The location at which the request originated and is fulfilled.
+  /// \param Return Output: the value produced by the emitted runtime call.
+  /// \param DeviceID Identifier for the device via the 'device' clause.
+  /// \param NumTeams Number of teams for the region via the 'num_teams' clause
+  ///                 or 0 if unspecified and -1 if there is no 'teams' clause.
+  /// \param NumThreads Number of threads via the 'thread_limit' clause.
+  /// \param HostPtr Pointer to the host-side pointer of the target kernel.
+  /// \param KernelArgs Array of arguments to the kernel.
+  /// \param NoWaitArgs Optional array of arguments to the nowait kernel.
+  InsertPointTy emitTargetKernel(const LocationDescription &Loc, Value *&Return,
+                                 Value *Ident, Value *DeviceID, Value *NumTeams,
+                                 Value *NumThreads, Value *HostPtr,
+                                 ArrayRef<Value *> KernelArgs,
+                                 ArrayRef<Value *> NoWaitArgs = {});
+
   /// Generate a barrier runtime call.
   ///
   /// \param Loc The location at which the request originated and is fulfilled.
Index: clang/lib/CodeGen/CGOpenMPRuntime.cpp
===================================================================
--- clang/lib/CodeGen/CGOpenMPRuntime.cpp
+++ clang/lib/CodeGen/CGOpenMPRuntime.cpp
@@ -6717,11 +6717,9 @@
     default:
       break;
     }
-  } else if (DefaultNT == -1) {
-    return nullptr;
   }
 
-  return Bld.getInt32(DefaultNT);
+  return llvm::ConstantInt::get(CGF.Int32Ty, DefaultNT);
 }
 
 static llvm::Value *getNumThreads(CodeGenFunction &CGF, const CapturedStmt *CS,
@@ -10310,23 +10308,28 @@
     // Emit tripcount for the target loop-based directive.
     emitTargetNumIterationsCall(CGF, D, DeviceID, SizeEmitter);
 
-    bool HasNowait = D.hasClausesOfKind<OMPNowaitClause>();
+    // Arguments for the target kernel.
+    SmallVector<llvm::Value *> KernelArgs{
+        PointerNum,
+        InputInfo.BasePointersArray.getPointer(),
+        InputInfo.PointersArray.getPointer(),
+        InputInfo.SizesArray.getPointer(),
+        MapTypesArray,
+        MapNamesArray,
+        InputInfo.MappersArray.getPointer()};
+
+    // Arguments passed to the 'nowait' variant.
+    SmallVector<llvm::Value *> NoWaitKernelArgs{
+        CGF.Builder.getInt32(0),
+        llvm::ConstantPointerNull::get(CGM.VoidPtrTy),
+        CGF.Builder.getInt32(0),
+        llvm::ConstantPointerNull::get(CGM.VoidPtrTy),
+    };
+
+    bool HasNoWait = D.hasClausesOfKind<OMPNowaitClause>();
+
     // The target region is an outlined function launched by the runtime
-    // via calls __tgt_target() or __tgt_target_teams().
-    //
-    // __tgt_target() launches a target region with one team and one thread,
-    // executing a serial region.  This master thread may in turn launch
-    // more threads within its team upon encountering a parallel region,
-    // however, no additional teams can be launched on the device.
-    //
-    // __tgt_target_teams() launches a target region with one or more teams,
-    // each with one or more threads.  This call is required for target
-    // constructs such as:
-    //  'target teams'
-    //  'target' / 'teams'
-    //  'target teams distribute parallel for'
-    //  'target parallel'
-    // and so on.
+    // via calls to __tgt_target_kernel().
     //
     // Note that on the host and CPU targets, the runtime implementation of
     // these calls simply call the outlined function without forking threads.
@@ -10337,70 +10340,15 @@
     // In contrast, on the NVPTX target, the implementation of
     // __tgt_target_teams() launches a GPU kernel with the requested number
     // of teams and threads so no additional calls to the runtime are required.
-    if (NumTeams) {
-      // If we have NumTeams defined this means that we have an enclosed teams
-      // region. Therefore we also expect to have NumThreads defined. These two
-      // values should be defined in the presence of a teams directive,
-      // regardless of having any clauses associated. If the user is using teams
-      // but no clauses, these two values will be the default that should be
-      // passed to the runtime library - a 32-bit integer with the value zero.
-      assert(NumThreads && "Thread limit expression should be available along "
-                           "with number of teams.");
-      SmallVector<llvm::Value *> OffloadingArgs = {
-          RTLoc,
-          DeviceID,
-          OutlinedFnID,
-          PointerNum,
-          InputInfo.BasePointersArray.getPointer(),
-          InputInfo.PointersArray.getPointer(),
-          InputInfo.SizesArray.getPointer(),
-          MapTypesArray,
-          MapNamesArray,
-          InputInfo.MappersArray.getPointer(),
-          NumTeams,
-          NumThreads};
-      if (HasNowait) {
-        // Add int32_t depNum = 0, void *depList = nullptr, int32_t
-        // noAliasDepNum = 0, void *noAliasDepList = nullptr.
-        OffloadingArgs.push_back(CGF.Builder.getInt32(0));
-        OffloadingArgs.push_back(llvm::ConstantPointerNull::get(CGM.VoidPtrTy));
-        OffloadingArgs.push_back(CGF.Builder.getInt32(0));
-        OffloadingArgs.push_back(llvm::ConstantPointerNull::get(CGM.VoidPtrTy));
-      }
-      Return = CGF.EmitRuntimeCall(
-          OMPBuilder.getOrCreateRuntimeFunction(
-              CGM.getModule(), HasNowait
-                                   ? OMPRTL___tgt_target_teams_nowait_mapper
-                                   : OMPRTL___tgt_target_teams_mapper),
-          OffloadingArgs);
-    } else {
-      SmallVector<llvm::Value *> OffloadingArgs = {
-          RTLoc,
-          DeviceID,
-          OutlinedFnID,
-          PointerNum,
-          InputInfo.BasePointersArray.getPointer(),
-          InputInfo.PointersArray.getPointer(),
-          InputInfo.SizesArray.getPointer(),
-          MapTypesArray,
-          MapNamesArray,
-          InputInfo.MappersArray.getPointer()};
-      if (HasNowait) {
-        // Add int32_t depNum = 0, void *depList = nullptr, int32_t
-        // noAliasDepNum = 0, void *noAliasDepList = nullptr.
-        OffloadingArgs.push_back(CGF.Builder.getInt32(0));
-        OffloadingArgs.push_back(llvm::ConstantPointerNull::get(CGM.VoidPtrTy));
-        OffloadingArgs.push_back(CGF.Builder.getInt32(0));
-        OffloadingArgs.push_back(llvm::ConstantPointerNull::get(CGM.VoidPtrTy));
-      }
-      Return = CGF.EmitRuntimeCall(
-          OMPBuilder.getOrCreateRuntimeFunction(
-              CGM.getModule(), HasNowait ? OMPRTL___tgt_target_nowait_mapper
-                                         : OMPRTL___tgt_target_mapper),
-          OffloadingArgs);
-    }
-
     // Check the error code and execute the host version if required.
+    CGF.Builder.restoreIP(
+        HasNoWait ? OMPBuilder.emitTargetKernel(
+                        CGF.Builder, Return, RTLoc, DeviceID, NumTeams,
+                        NumThreads, OutlinedFnID, KernelArgs, NoWaitKernelArgs)
+                  : OMPBuilder.emitTargetKernel(CGF.Builder, Return, RTLoc,
+                                                DeviceID, NumTeams, NumThreads,
+                                                OutlinedFnID, KernelArgs));
+
     llvm::BasicBlock *OffloadFailedBlock =
         CGF.createBasicBlock("omp_offload.failed");
     llvm::BasicBlock *OffloadContBlock =
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to