[flang-commits] [flang] [flang][cuda] Allocate descriptor in managed memory on rebox block argument (PR #123971)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Wed Jan 22 09:09:54 PST 2025


https://github.com/clementval created https://github.com/llvm/llvm-project/pull/123971

This patch handles another case where the descriptor must be allocated with the CUF runtime rather than with a plain alloca instruction: a rebox whose input box is a block argument.

From 3863ad094a7307f49e52169da3842a75ca355682 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Wed, 22 Jan 2025 09:06:24 -0800
Subject: [PATCH] [flang][cuda] Allocate descriptor in managed memory on rebox
 block arg

---
 flang/lib/Optimizer/CodeGen/CodeGen.cpp | 38 +++++++++++++------------
 flang/test/Fir/CUDA/cuda-code-gen.mlir  | 11 +++++++
 2 files changed, 31 insertions(+), 18 deletions(-)

diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 43c0e2686a8c3b..6ff2c20d744537 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -2040,19 +2040,20 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
         getBaseAddrFromBox(loc, inputBoxTyPair, loweredBox, rewriter);
 
     if (!rebox.getSlice().empty() || !rebox.getSubcomponent().empty())
-      return sliceBox(rebox, boxTy, dest, baseAddr, inputExtents, inputStrides,
-                      operands, rewriter);
-    return reshapeBox(rebox, boxTy, dest, baseAddr, inputExtents, inputStrides,
-                      operands, rewriter);
+      return sliceBox(rebox, adaptor, boxTy, dest, baseAddr, inputExtents,
+                      inputStrides, operands, rewriter);
+    return reshapeBox(rebox, adaptor, boxTy, dest, baseAddr, inputExtents,
+                      inputStrides, operands, rewriter);
   }
 
 private:
   /// Write resulting shape and base address in descriptor, and replace rebox
   /// op.
   llvm::LogicalResult
-  finalizeRebox(fir::cg::XReboxOp rebox, mlir::Type destBoxTy, mlir::Value dest,
-                mlir::Value base, mlir::ValueRange lbounds,
-                mlir::ValueRange extents, mlir::ValueRange strides,
+  finalizeRebox(fir::cg::XReboxOp rebox, OpAdaptor adaptor,
+                mlir::Type destBoxTy, mlir::Value dest, mlir::Value base,
+                mlir::ValueRange lbounds, mlir::ValueRange extents,
+                mlir::ValueRange strides,
                 mlir::ConversionPatternRewriter &rewriter) const {
     mlir::Location loc = rebox.getLoc();
     mlir::Value zero =
@@ -2075,15 +2076,15 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
     dest = insertBaseAddress(rewriter, loc, dest, base);
     mlir::Value result = placeInMemoryIfNotGlobalInit(
         rewriter, rebox.getLoc(), destBoxTy, dest,
-        isDeviceAllocation(rebox.getBox(), rebox.getBox()));
+        isDeviceAllocation(rebox.getBox(), adaptor.getBox()));
     rewriter.replaceOp(rebox, result);
     return mlir::success();
   }
 
   // Apply slice given the base address, extents and strides of the input box.
   llvm::LogicalResult
-  sliceBox(fir::cg::XReboxOp rebox, mlir::Type destBoxTy, mlir::Value dest,
-           mlir::Value base, mlir::ValueRange inputExtents,
+  sliceBox(fir::cg::XReboxOp rebox, OpAdaptor adaptor, mlir::Type destBoxTy,
+           mlir::Value dest, mlir::Value base, mlir::ValueRange inputExtents,
            mlir::ValueRange inputStrides, mlir::ValueRange operands,
            mlir::ConversionPatternRewriter &rewriter) const {
     mlir::Location loc = rebox.getLoc();
@@ -2109,7 +2110,7 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
     if (rebox.getSlice().empty())
       // The array section is of the form array[%component][substring], keep
       // the input array extents and strides.
-      return finalizeRebox(rebox, destBoxTy, dest, base,
+      return finalizeRebox(rebox, adaptor, destBoxTy, dest, base,
                            /*lbounds*/ std::nullopt, inputExtents, inputStrides,
                            rewriter);
 
@@ -2158,15 +2159,16 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
         slicedStrides.emplace_back(stride);
       }
     }
-    return finalizeRebox(rebox, destBoxTy, dest, base, /*lbounds*/ std::nullopt,
-                         slicedExtents, slicedStrides, rewriter);
+    return finalizeRebox(rebox, adaptor, destBoxTy, dest, base,
+                         /*lbounds*/ std::nullopt, slicedExtents, slicedStrides,
+                         rewriter);
   }
 
   /// Apply a new shape to the data described by a box given the base address,
   /// extents and strides of the box.
   llvm::LogicalResult
-  reshapeBox(fir::cg::XReboxOp rebox, mlir::Type destBoxTy, mlir::Value dest,
-             mlir::Value base, mlir::ValueRange inputExtents,
+  reshapeBox(fir::cg::XReboxOp rebox, OpAdaptor adaptor, mlir::Type destBoxTy,
+             mlir::Value dest, mlir::Value base, mlir::ValueRange inputExtents,
              mlir::ValueRange inputStrides, mlir::ValueRange operands,
              mlir::ConversionPatternRewriter &rewriter) const {
     mlir::ValueRange reboxShifts{
@@ -2175,7 +2177,7 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
             rebox.getShift().size()};
     if (rebox.getShape().empty()) {
       // Only setting new lower bounds.
-      return finalizeRebox(rebox, destBoxTy, dest, base, reboxShifts,
+      return finalizeRebox(rebox, adaptor, destBoxTy, dest, base, reboxShifts,
                            inputExtents, inputStrides, rewriter);
     }
 
@@ -2199,8 +2201,8 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
       // nextStride = extent * stride;
       stride = rewriter.create<mlir::LLVM::MulOp>(loc, idxTy, extent, stride);
     }
-    return finalizeRebox(rebox, destBoxTy, dest, base, reboxShifts, newExtents,
-                         newStrides, rewriter);
+    return finalizeRebox(rebox, adaptor, destBoxTy, dest, base, reboxShifts,
+                         newExtents, newStrides, rewriter);
   }
 
   /// Return scalar element type of the input box.
diff --git a/flang/test/Fir/CUDA/cuda-code-gen.mlir b/flang/test/Fir/CUDA/cuda-code-gen.mlir
index 7ac89836a3ff16..063454799502af 100644
--- a/flang/test/Fir/CUDA/cuda-code-gen.mlir
+++ b/flang/test/Fir/CUDA/cuda-code-gen.mlir
@@ -187,3 +187,14 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<!llvm.ptr<270> = dense<32> : vec
 
 // CHECK-LABEL: llvm.func @_QPouter
 // CHECK: _FortranACUFAllocDescriptor
+
+// -----
+
+func.func @_QMm1Psub1(%arg0: !fir.box<!fir.array<?xi32>> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "da"}, %arg1: !fir.box<!fir.array<?xi32>> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "db"}, %arg2: !fir.ref<i32> {fir.bindc_name = "n"}) {
+  %0 = fircg.ext_rebox %arg0 : (!fir.box<!fir.array<?xi32>>) -> !fir.box<!fir.array<?xi32>>
+  %1 = fircg.ext_rebox %arg1 : (!fir.box<!fir.array<?xi32>>) -> !fir.box<!fir.array<?xi32>>
+  return
+}
+
+// CHECK-LABEL: llvm.func @_QMm1Psub1
+// CHECK-COUNT-2: _FortranACUFAllocDescriptor



More information about the flang-commits mailing list