[flang-commits] [flang] [flang][cuda] Convert cuf.alloc for box to fir.alloca in device context (PR #102662)

via flang-commits flang-commits at lists.llvm.org
Fri Aug 9 12:02:37 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

<details>
<summary>Changes</summary>

In a device context, managed memory is not available, so it makes no sense to allocate the descriptor with it. Fall back to fir.alloca instead, since it is handled well in device code.
The matching cuf.free is simply dropped.

---
Full diff: https://github.com/llvm/llvm-project/pull/102662.diff


2 Files Affected:

- (modified) flang/lib/Optimizer/Transforms/CufOpConversion.cpp (+30) 
- (modified) flang/test/Fir/CUDA/cuda-allocate.fir (+11) 


``````````diff
diff --git a/flang/lib/Optimizer/Transforms/CufOpConversion.cpp b/flang/lib/Optimizer/Transforms/CufOpConversion.cpp
index f059d36315a345..d391ede82c2707 100644
--- a/flang/lib/Optimizer/Transforms/CufOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CufOpConversion.cpp
@@ -141,6 +141,20 @@ struct CufDeallocateOpConversion
   }
 };
 
+static bool inDeviceContext(mlir::Operation *op) {
+  if (op->getParentOfType<cuf::KernelOp>())
+    return true;
+  if (auto funcOp = op->getParentOfType<mlir::func::FuncOp>()) {
+    if (auto cudaProcAttr =
+            funcOp.getOperation()->getAttrOfType<cuf::ProcAttributeAttr>(
+                cuf::getProcAttrName())) {
+      return cudaProcAttr.getValue() != cuf::ProcAttribute::Host &&
+             cudaProcAttr.getValue() != cuf::ProcAttribute::HostDevice;
+    }
+  }
+  return false;
+}
+
 struct CufAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -157,6 +171,16 @@ struct CufAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
     if (!boxTy)
       return failure();
 
+    if (inDeviceContext(op.getOperation())) {
+      // In device context just replace the cuf.alloc operation with a fir.alloc
+      // the cuf.free will be removed.
+      rewriter.replaceOpWithNewOp<fir::AllocaOp>(
+          op, op.getInType(), op.getUniqName() ? *op.getUniqName() : "",
+          op.getBindcName() ? *op.getBindcName() : "", op.getTypeparams(),
+          op.getShape());
+      return mlir::success();
+    }
+
     auto mod = op->getParentOfType<mlir::ModuleOp>();
     fir::FirOpBuilder builder(rewriter, mod);
     mlir::Location loc = op.getLoc();
@@ -200,6 +224,11 @@ struct CufFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> {
     if (!mlir::isa<fir::BaseBoxType>(refTy.getEleTy()))
       return failure();
 
+    if (inDeviceContext(op.getOperation())) {
+      rewriter.eraseOp(op);
+      return mlir::success();
+    }
+
     auto mod = op->getParentOfType<mlir::ModuleOp>();
     fir::FirOpBuilder builder(rewriter, mod);
     mlir::Location loc = op.getLoc();
@@ -248,6 +277,7 @@ class CufOpConversion : public fir::impl::CufOpConversionBase<CufOpConversion> {
         [](::cuf::AllocateOp op) { return isBoxGlobal(op); });
     target.addDynamicallyLegalOp<cuf::DeallocateOp>(
         [](::cuf::DeallocateOp op) { return isBoxGlobal(op); });
+    target.addLegalDialect<fir::FIROpsDialect>();
     patterns.insert<CufAllocOpConversion>(ctx, &*dl, &typeConverter);
     patterns.insert<CufAllocateOpConversion, CufDeallocateOpConversion,
                     CufFreeOpConversion>(ctx);
diff --git a/flang/test/Fir/CUDA/cuda-allocate.fir b/flang/test/Fir/CUDA/cuda-allocate.fir
index 569e72f57d6d6c..a9bc7a8518e90e 100644
--- a/flang/test/Fir/CUDA/cuda-allocate.fir
+++ b/flang/test/Fir/CUDA/cuda-allocate.fir
@@ -57,6 +57,17 @@ func.func @_QPsub3() {
 // CHECK: cuf.allocate
 // CHECK: cuf.deallocate
 
+func.func @_QPsub4() attributes {cuf.proc_attr = #cuf.cuda_proc<device>} {
+  %0 = cuf.alloc !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub1Ea"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
+  %4:2 = hlfir.declare %0 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
+  cuf.free %4#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>> {data_attr = #cuf.cuda<device>}
+  return
+}
+
+// CHECK-LABEL: func.func @_QPsub4()
+// CHECK: fir.alloca
+// CHECK-NOT: cuf.free
+
 }
 
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/102662


More information about the flang-commits mailing list