[flang-commits] [flang] [flang][cuda] Allocate descriptor in managed memory when memref is a block argument (PR #123829)

via flang-commits flang-commits at lists.llvm.org
Tue Jan 21 14:03:53 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

<details>
<summary>Changes</summary>

When emboxing a memory reference that comes from a block argument, look up the CUDA data attribute on that argument so that the descriptor is allocated in managed memory when needed.

---
Full diff: https://github.com/llvm/llvm-project/pull/123829.diff


2 Files Affected:

- (modified) flang/lib/Optimizer/CodeGen/CodeGen.cpp (+29-8) 
- (modified) flang/test/Fir/CUDA/cuda-code-gen.mlir (+17) 


``````````diff
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 5ba93fefab3f9e..43c0e2686a8c3b 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -1725,15 +1725,35 @@ struct EmboxOpConversion : public EmboxCommonConversion<fir::EmboxOp> {
   }
 };
 
-static bool isDeviceAllocation(mlir::Value val) {
+static bool isDeviceAllocation(mlir::Value val, mlir::Value adaptorVal) {
   if (auto loadOp = mlir::dyn_cast_or_null<fir::LoadOp>(val.getDefiningOp()))
-    return isDeviceAllocation(loadOp.getMemref());
+    return isDeviceAllocation(loadOp.getMemref(), {});
   if (auto boxAddrOp =
           mlir::dyn_cast_or_null<fir::BoxAddrOp>(val.getDefiningOp()))
-    return isDeviceAllocation(boxAddrOp.getVal());
+    return isDeviceAllocation(boxAddrOp.getVal(), {});
   if (auto convertOp =
           mlir::dyn_cast_or_null<fir::ConvertOp>(val.getDefiningOp()))
-    return isDeviceAllocation(convertOp.getValue());
+    return isDeviceAllocation(convertOp.getValue(), {});
+  if (!val.getDefiningOp() && adaptorVal) {
+    if (auto blockArg = llvm::cast<mlir::BlockArgument>(adaptorVal)) {
+      if (blockArg.getOwner() && blockArg.getOwner()->getParentOp() &&
+          blockArg.getOwner()->isEntryBlock()) {
+        if (auto func = mlir::dyn_cast_or_null<mlir::FunctionOpInterface>(
+                *blockArg.getOwner()->getParentOp())) {
+          auto argAttrs = func.getArgAttrs(blockArg.getArgNumber());
+          for (auto attr : argAttrs) {
+            if (attr.getName().getValue().ends_with(cuf::getDataAttrName())) {
+              auto dataAttr =
+                  mlir::dyn_cast<cuf::DataAttributeAttr>(attr.getValue());
+              if (dataAttr.getValue() != cuf::DataAttribute::Pinned &&
+                  dataAttr.getValue() != cuf::DataAttribute::Unified)
+                return true;
+            }
+          }
+        }
+      }
+    }
+  }
   if (auto callOp = mlir::dyn_cast_or_null<fir::CallOp>(val.getDefiningOp()))
     if (callOp.getCallee() &&
         (callOp.getCallee().value().getRootReference().getValue().starts_with(
@@ -1928,7 +1948,8 @@ struct XEmboxOpConversion : public EmboxCommonConversion<fir::cg::XEmboxOp> {
     if (fir::isDerivedTypeWithLenParams(boxTy))
       TODO(loc, "fir.embox codegen of derived with length parameters");
     mlir::Value result = placeInMemoryIfNotGlobalInit(
-        rewriter, loc, boxTy, dest, isDeviceAllocation(xbox.getMemref()));
+        rewriter, loc, boxTy, dest,
+        isDeviceAllocation(xbox.getMemref(), adaptor.getMemref()));
     rewriter.replaceOp(xbox, result);
     return mlir::success();
   }
@@ -2052,9 +2073,9 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
       dest = insertStride(rewriter, loc, dest, dim, std::get<1>(iter.value()));
     }
     dest = insertBaseAddress(rewriter, loc, dest, base);
-    mlir::Value result =
-        placeInMemoryIfNotGlobalInit(rewriter, rebox.getLoc(), destBoxTy, dest,
-                                     isDeviceAllocation(rebox.getBox()));
+    mlir::Value result = placeInMemoryIfNotGlobalInit(
+        rewriter, rebox.getLoc(), destBoxTy, dest,
+        isDeviceAllocation(rebox.getBox(), rebox.getBox()));
     rewriter.replaceOp(rebox, result);
     return mlir::success();
   }
diff --git a/flang/test/Fir/CUDA/cuda-code-gen.mlir b/flang/test/Fir/CUDA/cuda-code-gen.mlir
index 3ad28fa7bd5179..7ac89836a3ff16 100644
--- a/flang/test/Fir/CUDA/cuda-code-gen.mlir
+++ b/flang/test/Fir/CUDA/cuda-code-gen.mlir
@@ -170,3 +170,20 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<!llvm.ptr<270> = dense<32> : vec
 
 // CHECK-LABEL: llvm.func @_QQmain()
 // CHECK-COUNT-3: llvm.call @_FortranACUFAllocDescriptor
+
+// -----
+
+module attributes {dlti.dl_spec = #dlti.dl_spec<!llvm.ptr<270> = dense<32> : vector<4xi64>, f128 = dense<128> : vector<2xi64>, f64 = dense<64> : vector<2xi64>, f16 = dense<16> : vector<2xi64>, i32 = dense<32> : vector<2xi64>, i64 = dense<64> : vector<2xi64>, !llvm.ptr<272> = dense<64> : vector<4xi64>, !llvm.ptr<271> = dense<32> : vector<4xi64>, f80 = dense<128> : vector<2xi64>, i128 = dense<128> : vector<2xi64>, i16 = dense<16> : vector<2xi64>, i8 = dense<8> : vector<2xi64>, !llvm.ptr = dense<64> : vector<4xi64>, i1 = dense<8> : vector<2xi64>, "dlti.endianness" = "little", "dlti.stack_alignment" = 128 : i64>, fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", gpu.container_module, llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", llvm.ident = "flang version 20.0.0 (git at github.com:clementval/llvm-project.git efc2415bcce8e8a9e73e77aa122c8aba1c1fbbd2)", llvm.target_triple = "x86_64-unknown-linux-gnu"} {
+  func.func @_QPouter(%arg0: !fir.ref<!fir.array<100x100xf64>> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "a"}) {
+    %c0_i32 = arith.constant 0 : i32
+    %c100 = arith.constant 100 : index
+    %0 = fir.alloca tuple<!fir.box<!fir.array<100x100xf64>>>
+    %1 = fir.coordinate_of %0, %c0_i32 : (!fir.ref<tuple<!fir.box<!fir.array<100x100xf64>>>>, i32) -> !fir.ref<!fir.box<!fir.array<100x100xf64>>>
+    %2 = fircg.ext_embox %arg0(%c100, %c100) : (!fir.ref<!fir.array<100x100xf64>>, index, index) -> !fir.box<!fir.array<100x100xf64>>
+    fir.store %2 to %1 : !fir.ref<!fir.box<!fir.array<100x100xf64>>>
+    return
+  }
+}
+
+// CHECK-LABEL: llvm.func @_QPouter
+// CHECK: _FortranACUFAllocDescriptor

``````````

</details>


https://github.com/llvm/llvm-project/pull/123829


More information about the flang-commits mailing list