gtbercea updated this revision to Diff 121538.

Repository:
  rL LLVM

https://reviews.llvm.org/D38976

Files:
  lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp
  lib/CodeGen/CGOpenMPRuntimeNVPTX.h
  test/OpenMP/nvptx_data_sharing.cpp
  test/OpenMP/nvptx_parallel_codegen.cpp
  test/OpenMP/nvptx_target_teams_codegen.cpp

Index: test/OpenMP/nvptx_target_teams_codegen.cpp
===================================================================
--- test/OpenMP/nvptx_target_teams_codegen.cpp
+++ test/OpenMP/nvptx_target_teams_codegen.cpp
@@ -60,7 +60,7 @@
   //
   // CHECK: [[AWAIT_WORK]]
   // CHECK: call void @llvm.nvvm.barrier0()
-  // CHECK: [[KPR:%.+]] = call i1 @__kmpc_kernel_parallel(i8** [[OMP_WORK_FN]])
+  // CHECK: [[KPR:%.+]] = call i1 @__kmpc_kernel_parallel(i8** [[OMP_WORK_FN]], i8*** %shared_args)
   // CHECK: [[KPRB:%.+]] = zext i1 [[KPR]] to i8
   // store i8 [[KPRB]], i8* [[OMP_EXEC_STATUS]], align 1
   // CHECK: [[WORK:%.+]] = load i8*, i8** [[OMP_WORK_FN]],
@@ -146,7 +146,7 @@
   //
   // CHECK: [[AWAIT_WORK]]
   // CHECK: call void @llvm.nvvm.barrier0()
-  // CHECK: [[KPR:%.+]] = call i1 @__kmpc_kernel_parallel(i8** [[OMP_WORK_FN]])
+  // CHECK: [[KPR:%.+]] = call i1 @__kmpc_kernel_parallel(i8** [[OMP_WORK_FN]], i8*** %shared_args)
   // CHECK: [[KPRB:%.+]] = zext i1 [[KPR]] to i8
   // store i8 [[KPRB]], i8* [[OMP_EXEC_STATUS]], align 1
   // CHECK: [[WORK:%.+]] = load i8*, i8** [[OMP_WORK_FN]],
Index: test/OpenMP/nvptx_parallel_codegen.cpp
===================================================================
--- test/OpenMP/nvptx_parallel_codegen.cpp
+++ test/OpenMP/nvptx_parallel_codegen.cpp
@@ -78,7 +78,7 @@
   //
   // CHECK: [[AWAIT_WORK]]
   // CHECK: call void @llvm.nvvm.barrier0()
-  // CHECK: [[KPR:%.+]] = call i1 @__kmpc_kernel_parallel(i8** [[OMP_WORK_FN]])
+  // CHECK: [[KPR:%.+]] = call i1 @__kmpc_kernel_parallel(i8** [[OMP_WORK_FN]],
   // CHECK: [[KPRB:%.+]] = zext i1 [[KPR]] to i8
   // store i8 [[KPRB]], i8* [[OMP_EXEC_STATUS]], align 1
   // CHECK: [[WORK:%.+]] = load i8*, i8** [[OMP_WORK_FN]],
@@ -92,20 +92,20 @@
   //
   // CHECK: [[EXEC_PARALLEL]]
   // CHECK: [[WF1:%.+]] = load i8*, i8** [[OMP_WORK_FN]],
-  // CHECK: [[WM1:%.+]] = icmp eq i8* [[WF1]], bitcast (void (i32*, i32*)* [[PARALLEL_FN1:@.+]] to i8*)
+  // CHECK: [[WM1:%.+]] = icmp eq i8* [[WF1]], bitcast (void (i16, i32, i8**)* [[PARALLEL_FN1:@.+]]_wrapper to i8*)
   // CHECK: br i1 [[WM1]], label {{%?}}[[EXEC_PFN1:.+]], label {{%?}}[[CHECK_NEXT1:.+]]
   //
   // CHECK: [[EXEC_PFN1]]
-  // CHECK: call void [[PARALLEL_FN1]](
+  // CHECK: call void [[PARALLEL_FN1]]_wrapper(
   // CHECK: br label {{%?}}[[TERM_PARALLEL:.+]]
   //
   // CHECK: [[CHECK_NEXT1]]
   // CHECK: [[WF2:%.+]] = load i8*, i8** [[OMP_WORK_FN]],
-  // CHECK: [[WM2:%.+]] = icmp eq i8* [[WF2]], bitcast (void (i32*, i32*)* [[PARALLEL_FN2:@.+]] to i8*)
+  // CHECK: [[WM2:%.+]] = icmp eq i8* [[WF2]], bitcast (void (i16, i32, i8**)* [[PARALLEL_FN2:@.+]]_wrapper to i8*)
   // CHECK: br i1 [[WM2]], label {{%?}}[[EXEC_PFN2:.+]], label {{%?}}[[CHECK_NEXT2:.+]]
   //
   // CHECK: [[EXEC_PFN2]]
-  // CHECK: call void [[PARALLEL_FN2]](
+  // CHECK: call void [[PARALLEL_FN2]]_wrapper(
   // CHECK: br label {{%?}}[[TERM_PARALLEL:.+]]
   //
   // CHECK: [[CHECK_NEXT2]]
@@ -152,13 +152,13 @@
   // CHECK-DAG: [[MWS:%.+]] = call i32 @llvm.nvvm.read.ptx.sreg.warpsize()
   // CHECK: [[MTMP1:%.+]] = sub i32 [[MNTH]], [[MWS]]
   // CHECK: call void @__kmpc_kernel_init(i32 [[MTMP1]]
-  // CHECK: call void @__kmpc_kernel_prepare_parallel(i8* bitcast (void (i32*, i32*)* [[PARALLEL_FN1]] to i8*))
+  // CHECK: call void @__kmpc_kernel_prepare_parallel(i8* bitcast (void (i16, i32, i8**)* [[PARALLEL_FN1]]_wrapper to i8*),
   // CHECK: call void @llvm.nvvm.barrier0()
   // CHECK: call void @llvm.nvvm.barrier0()
   // CHECK: call void @__kmpc_serialized_parallel(
   // CHECK: {{call|invoke}} void [[PARALLEL_FN3:@.+]](
   // CHECK: call void @__kmpc_end_serialized_parallel(
-  // CHECK: call void @__kmpc_kernel_prepare_parallel(i8* bitcast (void (i32*, i32*)* [[PARALLEL_FN2]] to i8*))
+  // CHECK: call void @__kmpc_kernel_prepare_parallel(i8* bitcast (void (i16, i32, i8**)* [[PARALLEL_FN2]]_wrapper to i8*),
   // CHECK: call void @llvm.nvvm.barrier0()
   // CHECK: call void @llvm.nvvm.barrier0()
   // CHECK-64-DAG: load i32, i32* [[REF_A]]
@@ -203,7 +203,7 @@
   //
   // CHECK: [[AWAIT_WORK]]
   // CHECK: call void @llvm.nvvm.barrier0()
-  // CHECK: [[KPR:%.+]] = call i1 @__kmpc_kernel_parallel(i8** [[OMP_WORK_FN]])
+  // CHECK: [[KPR:%.+]] = call i1 @__kmpc_kernel_parallel(i8** [[OMP_WORK_FN]],
   // CHECK: [[KPRB:%.+]] = zext i1 [[KPR]] to i8
   // store i8 [[KPRB]], i8* [[OMP_EXEC_STATUS]], align 1
   // CHECK: [[WORK:%.+]] = load i8*, i8** [[OMP_WORK_FN]],
@@ -217,11 +217,11 @@
   //
   // CHECK: [[EXEC_PARALLEL]]
   // CHECK: [[WF:%.+]] = load i8*, i8** [[OMP_WORK_FN]],
-  // CHECK: [[WM:%.+]] = icmp eq i8* [[WF]], bitcast (void (i32*, i32*)* [[PARALLEL_FN4:@.+]] to i8*)
+  // CHECK: [[WM:%.+]] = icmp eq i8* [[WF]], bitcast (void (i16, i32, i8**)* [[PARALLEL_FN4:@.+]]_wrapper to i8*)
   // CHECK: br i1 [[WM]], label {{%?}}[[EXEC_PFN:.+]], label {{%?}}[[CHECK_NEXT:.+]]
   //
   // CHECK: [[EXEC_PFN]]
-  // CHECK: call void [[PARALLEL_FN4]](
+  // CHECK: call void [[PARALLEL_FN4]]_wrapper(
   // CHECK: br label {{%?}}[[TERM_PARALLEL:.+]]
   //
   // CHECK: [[CHECK_NEXT]]
@@ -283,7 +283,7 @@
   // CHECK: br i1 [[CMP]], label {{%?}}[[IF_THEN:.+]], label {{%?}}[[IF_ELSE:.+]]
   //
   // CHECK: [[IF_THEN]]
-  // CHECK: call void @__kmpc_kernel_prepare_parallel(i8* bitcast (void (i32*, i32*)* [[PARALLEL_FN4]] to i8*))
+  // CHECK: call void @__kmpc_kernel_prepare_parallel(i8* bitcast (void (i16, i32, i8**)* [[PARALLEL_FN4]]_wrapper to i8*),
   // CHECK: call void @llvm.nvvm.barrier0()
   // CHECK: call void @llvm.nvvm.barrier0()
   // CHECK: br label {{%?}}[[IF_END:.+]]
Index: test/OpenMP/nvptx_data_sharing.cpp
===================================================================
--- /dev/null
+++ test/OpenMP/nvptx_data_sharing.cpp
@@ -0,0 +1,52 @@
+// Test device data sharing codegen.
+///==========================================================================///
+
+// RUN: %clang_cc1 -verify -fopenmp -x c++ -triple powerpc64le-unknown-unknown -fopenmp-targets=nvptx64-nvidia-cuda -emit-llvm-bc %s -o %t-ppc-host.bc
+// RUN: %clang_cc1 -verify -fopenmp -x c++ -triple nvptx64-unknown-unknown -fopenmp-targets=nvptx64-nvidia-cuda -emit-llvm %s -fopenmp-is-device -fopenmp-host-ir-file-path %t-ppc-host.bc -o - | FileCheck %s --check-prefix CK1
+
+// expected-no-diagnostics
+
+#ifndef HEADER
+#define HEADER
+
+void test_ds(){
+  #pragma omp target
+  {
+    int a = 10;
+    #pragma omp parallel
+    {
+       a = 1000;
+    }
+  }
+}
+
+/// ========= In the worker function ========= ///
+
+// CK1: define internal void @__omp_offloading_{{.*}}test_ds{{.*}}worker(){{.*}}{
+// CK1: [[SHAREDARGS:%.+]] = alloca i8**
+// CK1: call i1 @__kmpc_kernel_parallel(i8** %work_fn, i8*** [[SHAREDARGS]])
+// CK1: [[SHARGSTMP:%.+]] = load i8**, i8*** [[SHAREDARGS]]
+// CK1: call void @__omp_outlined___wrapper{{.*}}({{.*}}, i8** %5)
+
+/// ========= In the kernel function ========= ///
+
+// CK1: {{.*}}define void @__omp_offloading{{.*}}test_ds{{.*}}()
+// CK1: [[SHAREDARGS1:%.+]] = alloca i8**
+// CK1: call void @__kmpc_kernel_prepare_parallel({{.*}}, i8*** [[SHAREDARGS1]], i32 1)
+// CK1: [[SHARGSTMP1:%.+]] = load i8**, i8*** [[SHAREDARGS1]]
+// CK1: [[SHARGSTMP2:%.+]] = getelementptr inbounds i8*, i8** [[SHARGSTMP1]]
+// CK1: [[SHAREDVAR:%.+]] = bitcast i32* {{.*}} to i8*
+// CK1: store i8* [[SHAREDVAR]], i8** [[SHARGSTMP2]]
+
+/// ========= In the data sharing wrapper function ========= ///
+
+// CK1: {{.*}}define internal void @__omp_outlined___wrapper({{.*}}i8**){{.*}}{
+// CK1: [[SHAREDARGS2:%.+]] = alloca i8**
+// CK1: store i8** %2, i8*** [[SHAREDARGS2]]
+// CK1: [[SHARGSTMP3:%.+]] = load i8**, i8*** [[SHAREDARGS2]]
+// CK1: [[SHARGSTMP4:%.+]] = getelementptr inbounds i8*, i8** [[SHARGSTMP3]]
+// CK1: [[SHARGSTMP5:%.+]] = bitcast i8** [[SHARGSTMP4]] to i32**
+// CK1: [[SHARGSTMP6:%.+]] = load i32*, i32** [[SHARGSTMP5]]
+// CK1: call void @__omp_outlined__({{.*}}, i32* [[SHARGSTMP6]])
+
+#endif
\ No newline at end of file
Index: lib/CodeGen/CGOpenMPRuntimeNVPTX.h
===================================================================
--- lib/CodeGen/CGOpenMPRuntimeNVPTX.h
+++ lib/CodeGen/CGOpenMPRuntimeNVPTX.h
@@ -305,6 +305,17 @@
   // target region and used by containing directives such as 'parallel'
   // to emit optimized code.
   ExecutionMode CurrentExecutionMode;
+
+  /// Map between an outlined function and its wrapper.
+  llvm::DenseMap<llvm::Function *, llvm::Function *> WrapperFunctionsMap;
+
+  /// Emit function which wraps the outlined parallel region
+  /// and controls the parameters which are passed to this function.
+  /// The wrapper ensures that the outlined function is called
+  /// with the correct arguments when data is shared.
+  llvm::Function *
+  createDataSharingWrapper(llvm::Function *OutlinedParallelFn,
+      const OMPExecutableDirective &D);
 };
 
 } // CodeGen namespace.
Index: lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp
===================================================================
--- lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp
+++ lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp
@@ -294,6 +294,7 @@
   EntryFunctionState EST;
   WorkerFunctionState WST(CGM);
   Work.clear();
+  WrapperFunctionsMap.clear();
 
   // Emit target region as a standalone region.
   class NVPTXPrePostActionTy : public PrePostActionTy {
@@ -467,7 +468,7 @@
 }
 
 void CGOpenMPRuntimeNVPTX::emitWorkerFunction(WorkerFunctionState &WST) {
-  auto &Ctx = CGM.getContext();
+  ASTContext &Ctx = CGM.getContext();
 
   CodeGenFunction CGF(CGM, /*suppressNewContext=*/true);
   CGF.disableDebugInfo();
@@ -510,7 +511,10 @@
   CGF.InitTempAlloca(ExecStatus, Bld.getInt8(/*C=*/0));
   CGF.InitTempAlloca(WorkFn, llvm::Constant::getNullValue(CGF.Int8PtrTy));
 
-  llvm::Value *Args[] = {WorkFn.getPointer()};
+  // Set up shared arguments
+  Address SharedArgs =
+      CGF.CreateDefaultAlignTempAlloca(CGF.Int8PtrPtrTy, "shared_args");
+  llvm::Value *Args[] = {WorkFn.getPointer(), SharedArgs.getPointer()};
   llvm::Value *Ret = CGF.EmitRuntimeCall(
       createNVPTXRuntimeFunction(OMPRTL_NVPTX__kmpc_kernel_parallel), Args);
   Bld.CreateStore(Bld.CreateZExt(Ret, CGF.Int8Ty), ExecStatus);
@@ -529,6 +533,9 @@
   // Signal start of parallel region.
   CGF.EmitBlock(ExecuteBB);
 
+  // Current context
+  ASTContext &Ctx = CGF.getContext();
+
   // Process work items: outlined parallel functions.
   for (auto *W : Work) {
     // Try to match this outlined function.
@@ -544,14 +551,18 @@
     // Execute this outlined function.
     CGF.EmitBlock(ExecuteFNBB);
 
-    // Insert call to work function.
-    // FIXME: Pass arguments to outlined function from master thread.
-    auto *Fn = cast<llvm::Function>(W);
-    Address ZeroAddr =
-        CGF.CreateDefaultAlignTempAlloca(CGF.Int32Ty, /*Name=*/".zero.addr");
-    CGF.InitTempAlloca(ZeroAddr, CGF.Builder.getInt32(/*C=*/0));
-    llvm::Value *FnArgs[] = {ZeroAddr.getPointer(), ZeroAddr.getPointer()};
-    emitCall(CGF, Fn, FnArgs);
+    // Insert call to work function via shared wrapper. The shared
+    // wrapper takes exactly three arguments:
+    //   - the parallelism level;
+    //   - the master thread ID;
+    //   - the list of references to shared arguments.
+    //
+    // TODO: Assert that the function is a wrapper function.
+    Address Capture = CGF.EmitLoadOfPointer(SharedArgs,
+       Ctx.getPointerType(
+          Ctx.getPointerType(Ctx.VoidPtrTy)).castAs<PointerType>());
+    emitCall(CGF, W, {Bld.getInt16(/*ParallelLevel=*/0),
+        getMasterThreadID(CGF), Capture.getPointer()});
 
     // Go to end of parallel region.
     CGF.EmitBranch(TerminateBB);
@@ -617,16 +628,18 @@
   }
   case OMPRTL_NVPTX__kmpc_kernel_prepare_parallel: {
     /// Build void __kmpc_kernel_prepare_parallel(
-    /// void *outlined_function);
-    llvm::Type *TypeParams[] = {CGM.Int8PtrTy};
+    /// void *outlined_function, void ***args, kmp_int32 nArgs);
+    llvm::Type *TypeParams[] = {CGM.Int8PtrTy,
+        CGM.Int8PtrPtrTy->getPointerTo(0), CGM.Int32Ty};
     llvm::FunctionType *FnTy =
         llvm::FunctionType::get(CGM.VoidTy, TypeParams, /*isVarArg*/ false);
     RTLFn = CGM.CreateRuntimeFunction(FnTy, "__kmpc_kernel_prepare_parallel");
     break;
   }
   case OMPRTL_NVPTX__kmpc_kernel_parallel: {
-    /// Build bool __kmpc_kernel_parallel(void **outlined_function);
-    llvm::Type *TypeParams[] = {CGM.Int8PtrPtrTy};
+    /// Build bool __kmpc_kernel_parallel(void **outlined_function, void ***args);
+    llvm::Type *TypeParams[] = {CGM.Int8PtrPtrTy,
+        CGM.Int8PtrPtrTy->getPointerTo(0)};
     llvm::Type *RetTy = CGM.getTypes().ConvertType(CGM.getContext().BoolTy);
     llvm::FunctionType *FnTy =
         llvm::FunctionType::get(RetTy, TypeParams, /*isVarArg*/ false);
@@ -845,8 +858,17 @@
 llvm::Value *CGOpenMPRuntimeNVPTX::emitParallelOutlinedFunction(
     const OMPExecutableDirective &D, const VarDecl *ThreadIDVar,
     OpenMPDirectiveKind InnermostKind, const RegionCodeGenTy &CodeGen) {
-  return CGOpenMPRuntime::emitParallelOutlinedFunction(D, ThreadIDVar,
-                                                       InnermostKind, CodeGen);
+
+  auto *OutlinedFun = cast<llvm::Function>(
+    CGOpenMPRuntime::emitParallelOutlinedFunction(
+          D, ThreadIDVar, InnermostKind, CodeGen));
+  if (!isInSpmdExecutionMode()) {
+    llvm::Function *WrapperFun =
+        createDataSharingWrapper(OutlinedFun, D);
+    WrapperFunctionsMap[OutlinedFun] = WrapperFun;
+  }
+
+  return OutlinedFun;
 }
 
 llvm::Value *CGOpenMPRuntimeNVPTX::emitTeamsOutlinedFunction(
@@ -898,15 +920,52 @@
     CodeGenFunction &CGF, SourceLocation Loc, llvm::Value *OutlinedFn,
     ArrayRef<llvm::Value *> CapturedVars, const Expr *IfCond) {
   llvm::Function *Fn = cast<llvm::Function>(OutlinedFn);
+  llvm::Function *WFn = WrapperFunctionsMap[Fn];
+  assert(WFn && "Wrapper function does not exist!");
 
-  auto &&L0ParallelGen = [this, Fn](CodeGenFunction &CGF, PrePostActionTy &) {
+  // Force inline this outlined function at its call site.
+  Fn->setLinkage(llvm::GlobalValue::InternalLinkage);
+
+  auto &&L0ParallelGen = [this, WFn, &CapturedVars](CodeGenFunction &CGF,
+                                                    PrePostActionTy &) {
     CGBuilderTy &Bld = CGF.Builder;
 
-    // Prepare for parallel region. Indicate the outlined function.
-    llvm::Value *Args[] = {Bld.CreateBitOrPointerCast(Fn, CGM.Int8PtrTy)};
-    CGF.EmitRuntimeCall(
-        createNVPTXRuntimeFunction(OMPRTL_NVPTX__kmpc_kernel_prepare_parallel),
-        Args);
+    llvm::Value *ID = Bld.CreateBitOrPointerCast(WFn, CGM.Int8PtrTy);
+
+    if (!CapturedVars.empty()) {
+      // Prepare for parallel region. Indicate the outlined function.
+      Address SharedArgs =
+          CGF.CreateDefaultAlignTempAlloca(CGF.VoidPtrPtrTy,
+              "shared_args");
+      llvm::Value *SharedArgsPtr = SharedArgs.getPointer();
+      llvm::Value *Args[] = {ID, SharedArgsPtr,
+                             Bld.getInt32(CapturedVars.size())};
+
+      CGF.EmitRuntimeCall(
+          createNVPTXRuntimeFunction(OMPRTL_NVPTX__kmpc_kernel_prepare_parallel),
+          Args);
+
+      unsigned Idx = 0;
+      ASTContext &Ctx = CGF.getContext();
+      for (llvm::Value *V : CapturedVars) {
+        Address Dst = Bld.CreateConstInBoundsGEP(
+            CGF.EmitLoadOfPointer(SharedArgs,
+            Ctx.getPointerType(
+                Ctx.getPointerType(Ctx.VoidPtrTy)).castAs<PointerType>()),
+            Idx, CGF.getPointerSize());
+        llvm::Value *PtrV = Bld.CreateBitCast(V, CGF.VoidPtrTy);
+        CGF.EmitStoreOfScalar(PtrV, Dst, /*Volatile=*/false,
+            Ctx.getPointerType(Ctx.VoidPtrTy));
+        Idx++;
+      }
+    } else {
+      llvm::Value *Args[] = {ID,
+          llvm::ConstantPointerNull::get(CGF.VoidPtrPtrTy->getPointerTo(0)),
+          /*nArgs=*/Bld.getInt32(0)};
+      CGF.EmitRuntimeCall(
+          createNVPTXRuntimeFunction(OMPRTL_NVPTX__kmpc_kernel_prepare_parallel),
+          Args);
+    }
 
     // Activate workers. This barrier is used by the master to signal
     // work for the workers.
@@ -921,7 +980,7 @@
     syncCTAThreads(CGF);
 
     // Remember for post-processing in worker loop.
-    Work.push_back(Fn);
+    Work.emplace_back(WFn);
   };
 
   auto *RTLoc = emitUpdateLocation(CGF, Loc);
@@ -2317,3 +2376,111 @@
   }
   CGOpenMPRuntime::emitOutlinedFunctionCall(CGF, Loc, OutlinedFn, TargetArgs);
 }
+
+/// Emit the code that each thread requires to execute when it encounters
+/// one of the three possible parallelism levels. This also emits the required
+/// data sharing code for each level.
+static void emitParallelismLevelCode(
+    CodeGenFunction &CGF) {
+  // Emit L0 and L1 level parallel code.
+  // TODO: Emit L1 code for nested parallelism.
+  llvm::BasicBlock *AfterBB = CGF.createBasicBlock(".after.parallel");
+
+  // Emit L0 code
+  llvm::BasicBlock *LBB = CGF.createBasicBlock(".level0.parallel");
+  CGF.EmitBlock(LBB);
+  CGF.EmitBranch(AfterBB);
+  CGF.EmitBlock(AfterBB);
+}
+
+/// Emit function which wraps the outlined parallel region
+/// and controls the arguments which are passed to this function.
+/// The wrapper ensures that the outlined function is called
+/// with the correct arguments when data is shared.
+llvm::Function *CGOpenMPRuntimeNVPTX::createDataSharingWrapper(
+    llvm::Function *OutlinedParallelFn, const OMPExecutableDirective &D) {
+  ASTContext &Ctx = CGM.getContext();
+  const auto &CS = *cast<CapturedStmt>(D.getAssociatedStmt());
+
+  // Wrapper arguments: parallel level, source thread ID, shared argument list.
+  FunctionArgList WrapperArgs;
+  QualType Int16QTy =
+      Ctx.getIntTypeForBitwidth(/*DestWidth=*/16, /*Signed=*/false);
+  QualType Int32QTy =
+      Ctx.getIntTypeForBitwidth(/*DestWidth=*/32, /*Signed=*/false);
+  QualType Int32PtrQTy = Ctx.getPointerType(Int32QTy);
+  QualType VoidPtrPtrQTy = Ctx.getPointerType(Ctx.VoidPtrTy);
+  ImplicitParamDecl ParallelLevelArg(Ctx, Int16QTy, ImplicitParamDecl::Other);
+  ImplicitParamDecl WrapperArg(Ctx, Int32QTy, ImplicitParamDecl::Other);
+  ImplicitParamDecl SharedArgsList(Ctx, VoidPtrPtrQTy,
+      ImplicitParamDecl::Other);
+  WrapperArgs.emplace_back(&ParallelLevelArg);
+  WrapperArgs.emplace_back(&WrapperArg);
+  WrapperArgs.emplace_back(&SharedArgsList);
+
+  auto &CGFI =
+      CGM.getTypes().arrangeBuiltinFunctionDeclaration(Ctx.VoidTy, WrapperArgs);
+
+  auto *Fn = llvm::Function::Create(
+      CGM.getTypes().GetFunctionType(CGFI), llvm::GlobalValue::InternalLinkage,
+      OutlinedParallelFn->getName() + "_wrapper", &CGM.getModule());
+  CGM.SetInternalFunctionAttributes(/*D=*/nullptr, Fn, CGFI);
+  Fn->setLinkage(llvm::GlobalValue::InternalLinkage);
+
+  CodeGenFunction CGF(CGM, /*suppressNewContext=*/true);
+  CGF.StartFunction(GlobalDecl(), Ctx.VoidTy, Fn, CGFI, WrapperArgs);
+
+  const auto *RD = CS.getCapturedRecordDecl();
+  auto CurField = RD->field_begin();
+
+  // Emit code which performs the data sharing.
+  emitParallelismLevelCode(CGF);
+
+  // Get the array of arguments.
+  SmallVector<llvm::Value *, 8> Args;
+
+  // TODO: support SIMD and pass actual values
+  Args.emplace_back(llvm::ConstantPointerNull::get(
+      CGM.Int32Ty->getPointerTo()));
+  Args.emplace_back(llvm::ConstantPointerNull::get(
+      CGM.Int32Ty->getPointerTo()));
+
+  CGBuilderTy &Bld = CGF.Builder;
+  auto CI = CS.capture_begin();
+
+  // Load the start of the array
+  auto SharedArgs =
+      CGF.EmitLoadOfPointer(CGF.GetAddrOfLocalVar(&SharedArgsList),
+          VoidPtrPtrQTy->castAs<PointerType>());
+
+  // For each captured variable
+  for (unsigned I = 0; I < CS.capture_size(); ++I, ++CI, ++CurField) {
+    // Name of captured variable
+    StringRef Name;
+    if (CI->capturesThis())
+      Name = "this";
+    else
+      Name = CI->getCapturedVar()->getName();
+
+    // We retrieve the Clang type of the argument. We use it to create
+    // an alloca which will give us the LLVM type.
+    QualType ElemTy = CurField->getType();
+    // If this is a capture by copy the element type has to be the pointer to
+    // the data.
+    if (CI->capturesVariableByCopy())
+      ElemTy = Ctx.getPointerType(ElemTy);
+
+    // Get shared address of the captured variable.
+    Address ArgAddress = Bld.CreateConstInBoundsGEP(
+        SharedArgs, I, CGF.getPointerSize());
+    Address TypedArgAddress = Bld.CreateBitCast(
+        ArgAddress, CGF.ConvertTypeForMem(Ctx.getPointerType(ElemTy)));
+    llvm::Value *Arg = CGF.EmitLoadOfScalar(TypedArgAddress,
+        /*Volatile=*/false, Int32PtrQTy, SourceLocation());
+    Args.emplace_back(Arg);
+  }
+
+  emitCall(CGF, OutlinedParallelFn, Args);
+  CGF.FinishFunction();
+  return Fn;
+}
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
http://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to