================
@@ -2774,6 +2779,88 @@ void 
LoweringPreparePass::buildCUDARegisterGlobalFunctions(
   }
 }
 
+// Emit `__{cuda|hip}RegisterVar` calls inside `__{cuda|hip}_register_globals`
+// for every device-side shadow that carries a `cu.var_registration` attribute
+// (attached by `CIRGenNVCUDARuntime::handleVarRegistration`).
+void LoweringPreparePass::buildCUDARegisterVars(cir::CIRBaseBuilderTy &builder,
+                                                FuncOp regGlobalFunc) {
+  mlir::Location loc = mlirModule.getLoc();
+  llvm::StringRef cudaPrefix = getCUDAPrefix(astCtx);
+  cir::CIRDataLayout dataLayout(mlirModule);
+
+  PointerType voidPtrTy = builder.getVoidPtrTy();
+  PointerType voidPtrPtrTy = builder.getPointerTo(voidPtrTy);
+  IntType intTy = builder.getSIntNTy(32);
+  IntType sizeTy =
+      builder.getUIntNTy(astCtx->getTargetInfo().getMaxPointerWidth());
+  IntType charTy = cir::IntType::get(&getContext(), astCtx->getCharWidth(),
+                                     /*isSigned=*/false);
+
+  if (cudaDeviceVars.empty())
+    return;
+
+  cir::CIRBaseBuilderTy globalBuilder(getContext());
+  globalBuilder.setInsertionPointToStart(mlirModule.getBody());
+
+  // void __{cuda|hip}RegisterVar(void **fatbinHandle,
+  //                              char *hostVar, char *deviceAddress,
+  //                              const char *deviceName, int ext,
+  //                              size_t size, int constant, int normalized);
+  // OG ignores parameter types, treating pointers as void*.
+  cir::VoidType voidTy = builder.getVoidTy();
+  FuncOp cudaRegisterVar = buildRuntimeFunction(
+      globalBuilder, addUnderscoredPrefix(cudaPrefix, "RegisterVar"), loc,
+      FuncType::get({voidPtrPtrTy, voidPtrTy, voidPtrTy, voidPtrTy, intTy,
+                     sizeTy, intTy, intTy},
+                    voidTy));
+
+  auto makeConstantString = [&](llvm::StringRef str) -> GlobalOp {
+    auto strType = ArrayType::get(&getContext(), charTy, 1 + str.size());
+    auto tmpString = cir::GlobalOp::create(
+        globalBuilder, loc, (".str" + str).str(), strType,
+        /*isConstant=*/true, {},
+        /*linkage=*/cir::GlobalLinkageKind::PrivateLinkage);
+    tmpString.setInitialValueAttr(
+        ConstArrayAttr::get(strType, StringAttr::get(str + "\0", strType)));
+    tmpString.setPrivate();
+    return tmpString;
+  };
+
+  mlir::Value fatbinHandle = *regGlobalFunc.args_begin();
+
+  for (auto &[global, regAttr] : cudaDeviceVars) {
+    switch (regAttr.getKind()) {
+    case cir::CUDADeviceVarKind::Variable:
+      break;
+    case cir::CUDADeviceVarKind::Surface:
+      llvm_unreachable("Surface registration NYI");
+    case cir::CUDADeviceVarKind::Texture:
+      llvm_unreachable("Texture registration NYI");
+    }
+
+    if (regAttr.getIsManaged())
+      llvm_unreachable("Managed variable registration NYI");
+
+    GlobalOp deviceNameStr = makeConstantString(regAttr.getDeviceSideName());
+    mlir::Value deviceName = builder.createBitcast(
+        builder.createGetGlobal(deviceNameStr), voidPtrTy);
+    mlir::Value hostVar =
+        builder.createBitcast(builder.createGetGlobal(global), voidPtrTy);
+
+    auto isExtern = ConstantOp::create(
+        builder, loc, IntAttr::get(intTy, regAttr.getIsExtern() ? 1 : 0));
+    llvm::TypeSize size = dataLayout.getTypeSizeInBits(global.getSymType());
+    auto varSize = ConstantOp::create(
+        builder, loc, IntAttr::get(sizeTy, size.getFixedValue() / 8));
----------------
RiverDave wrote:

Missed it for some reason, thanks again.

https://github.com/llvm/llvm-project/pull/199270
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to