[flang-commits] [flang] [flang][cuda] Convert cuf.alloc and cuf.free for scalar and arrays (PR #110055)

via flang-commits flang-commits at lists.llvm.org
Wed Sep 25 15:39:07 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

<details>
<summary>Changes</summary>

This patch extends the conversion of cuf.alloc and cuf.free to scalars, constant-size arrays, and dynamic-size arrays (lowered to calls to the CUFMemAlloc and CUFMemFree runtime functions).

---
Full diff: https://github.com/llvm/llvm-project/pull/110055.diff


3 Files Affected:

- (modified) flang/lib/Optimizer/Transforms/CufOpConversion.cpp (+92-49) 
- (added) flang/test/Fir/CUDA/cuda-alloc-free.fir (+64) 
- (modified) flang/test/Fir/CUDA/cuda-allocate.fir (-11) 


``````````diff
diff --git a/flang/lib/Optimizer/Transforms/CufOpConversion.cpp b/flang/lib/Optimizer/Transforms/CufOpConversion.cpp
index f8ace2dd96a0d8..fff026bbda2d40 100644
--- a/flang/lib/Optimizer/Transforms/CufOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CufOpConversion.cpp
@@ -183,6 +183,29 @@ static bool inDeviceContext(mlir::Operation *op) {
   return false;
 }
 
+static int computeWidth(mlir::Location loc, mlir::Type type,
+                        fir::KindMapping &kindMap) { // Element size in bytes.
+  auto eleTy = fir::unwrapSequenceType(type); // Arrays: size of one element.
+  int width = 0;
+  if (eleTy.isInteger(1)) { // Check i1 first: 1 / 8 would truncate to 0.
+    width = 1;
+  } else if (auto t{mlir::dyn_cast<mlir::IntegerType>(eleTy)}) {
+    width = t.getWidth() / 8;
+  } else if (auto t{mlir::dyn_cast<mlir::FloatType>(eleTy)}) {
+    width = t.getWidth() / 8;
+  } else if (auto t{mlir::dyn_cast<fir::LogicalType>(eleTy)}) {
+    int kind = t.getFKind();
+    width = kindMap.getLogicalBitsize(kind) / 8;
+  } else if (auto t{mlir::dyn_cast<fir::ComplexType>(eleTy)}) {
+    int kind = t.getFKind();
+    int elemSize = kindMap.getRealBitsize(kind) / 8;
+    width = 2 * elemSize;
+  } else {
+    llvm::report_fatal_error("unsupported type");
+  }
+  return width;
+}
+
 struct CufAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -193,11 +216,6 @@ struct CufAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
   mlir::LogicalResult
   matchAndRewrite(cuf::AllocOp op,
                   mlir::PatternRewriter &rewriter) const override {
-    auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(op.getInType());
-
-    // Only convert cuf.alloc that allocates a descriptor.
-    if (!boxTy)
-      return failure();
 
     if (inDeviceContext(op.getOperation())) {
       // In device context just replace the cuf.alloc operation with a fir.alloc
@@ -212,11 +230,56 @@ struct CufAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
     auto mod = op->getParentOfType<mlir::ModuleOp>();
     fir::FirOpBuilder builder(rewriter, mod);
     mlir::Location loc = op.getLoc();
+    mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
+
+    if (!mlir::dyn_cast_or_null<fir::BaseBoxType>(op.getInType())) {
+      // Convert scalar and known size array allocations.
+      mlir::Value bytes;
+      fir::KindMapping kindMap{fir::getKindMapping(mod)};
+      if (fir::isa_trivial(op.getInType())) {
+        int width = computeWidth(loc, op.getInType(), kindMap);
+        bytes =
+            builder.createIntegerConstant(loc, builder.getIndexType(), width);
+      } else if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(
+                     op.getInType())) {
+        mlir::Value width = builder.createIntegerConstant(
+            loc, builder.getIndexType(),
+            computeWidth(loc, seqTy.getEleTy(), kindMap));
+        mlir::Value nbElem;
+        if (fir::sequenceWithNonConstantShape(seqTy)) {
+          assert(!op.getShape().empty() && "expect shape with dynamic arrays");
+          nbElem = builder.loadIfRef(loc, op.getShape()[0]);
+          for (unsigned i = 1; i < op.getShape().size(); ++i) {
+            nbElem = rewriter.create<mlir::arith::MulIOp>(
+                loc, nbElem, builder.loadIfRef(loc, op.getShape()[i]));
+          }
+        } else {
+          nbElem = builder.createIntegerConstant(loc, builder.getIndexType(),
+                                                 seqTy.getConstantArraySize());
+        }
+        bytes = rewriter.create<mlir::arith::MulIOp>(loc, nbElem, width);
+      }
+      mlir::func::FuncOp func =
+          fir::runtime::getRuntimeFunc<mkRTKey(CUFMemAlloc)>(loc, builder);
+      auto fTy = func.getFunctionType();
+      mlir::Value sourceLine =
+          fir::factory::locationToLineNo(builder, loc, fTy.getInput(3));
+      mlir::Value memTy = builder.createIntegerConstant(
+          loc, builder.getI32Type(), kMemTypeDevice);
+      llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
+          builder, loc, fTy, bytes, memTy, sourceFile, sourceLine)};
+      auto callOp = builder.create<fir::CallOp>(loc, func, args);
+      auto convOp = builder.createConvert(loc, op.getResult().getType(),
+                                          callOp.getResult(0));
+      rewriter.replaceOp(op, convOp);
+      return mlir::success();
+    }
+
+    // Convert descriptor allocations to function call.
+    auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(op.getInType());
     mlir::func::FuncOp func =
         fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocDesciptor)>(loc, builder);
-
     auto fTy = func.getFunctionType();
-    mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
     mlir::Value sourceLine =
         fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
 
@@ -245,26 +308,39 @@ struct CufFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> {
   mlir::LogicalResult
   matchAndRewrite(cuf::FreeOp op,
                   mlir::PatternRewriter &rewriter) const override {
-    // Only convert cuf.free on descriptor.
-    if (!mlir::isa<fir::ReferenceType>(op.getDevptr().getType()))
-      return failure();
-    auto refTy = mlir::dyn_cast<fir::ReferenceType>(op.getDevptr().getType());
-    if (!mlir::isa<fir::BaseBoxType>(refTy.getEleTy()))
-      return failure();
-
     if (inDeviceContext(op.getOperation())) {
       rewriter.eraseOp(op);
       return mlir::success();
     }
 
+    if (!mlir::isa<fir::ReferenceType>(op.getDevptr().getType()))
+      return failure();
+
     auto mod = op->getParentOfType<mlir::ModuleOp>();
     fir::FirOpBuilder builder(rewriter, mod);
     mlir::Location loc = op.getLoc();
+    mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
+
+    auto refTy = mlir::dyn_cast<fir::ReferenceType>(op.getDevptr().getType());
+    if (!mlir::isa<fir::BaseBoxType>(refTy.getEleTy())) {
+      mlir::func::FuncOp func =
+          fir::runtime::getRuntimeFunc<mkRTKey(CUFMemFree)>(loc, builder);
+      auto fTy = func.getFunctionType();
+      mlir::Value sourceLine =
+          fir::factory::locationToLineNo(builder, loc, fTy.getInput(3));
+      mlir::Value memTy = builder.createIntegerConstant(
+          loc, builder.getI32Type(), kMemTypeDevice);
+      llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
+          builder, loc, fTy, op.getDevptr(), memTy, sourceFile, sourceLine)};
+      builder.create<fir::CallOp>(loc, func, args);
+      rewriter.eraseOp(op);
+      return mlir::success();
+    }
+
+    // Convert cuf.free on descriptors.
     mlir::func::FuncOp func =
         fir::runtime::getRuntimeFunc<mkRTKey(CUFFreeDesciptor)>(loc, builder);
-
     auto fTy = func.getFunctionType();
-    mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
     mlir::Value sourceLine =
         fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
     llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
@@ -275,29 +351,6 @@ struct CufFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> {
   }
 };
 
-static int computeWidth(mlir::Location loc, mlir::Type type,
-                        fir::KindMapping &kindMap) {
-  auto eleTy = fir::unwrapSequenceType(type);
-  int width = 0;
-  if (auto t{mlir::dyn_cast<mlir::IntegerType>(eleTy)}) {
-    width = t.getWidth() / 8;
-  } else if (auto t{mlir::dyn_cast<mlir::FloatType>(eleTy)}) {
-    width = t.getWidth() / 8;
-  } else if (eleTy.isInteger(1)) {
-    width = 1;
-  } else if (auto t{mlir::dyn_cast<fir::LogicalType>(eleTy)}) {
-    int kind = t.getFKind();
-    width = kindMap.getLogicalBitsize(kind) / 8;
-  } else if (auto t{mlir::dyn_cast<fir::ComplexType>(eleTy)}) {
-    int kind = t.getFKind();
-    int elemSize = kindMap.getRealBitsize(kind) / 8;
-    width = 2 * elemSize;
-  } else {
-    llvm::report_fatal_error("unsupported type");
-  }
-  return width;
-}
-
 static mlir::Value createConvertOp(mlir::PatternRewriter &rewriter,
                                    mlir::Location loc, mlir::Type toTy,
                                    mlir::Value val) {
@@ -456,16 +509,6 @@ class CufOpConversion : public fir::impl::CufOpConversionBase<CufOpConversion> {
         fir::support::getOrSetDataLayout(module, /*allowDefaultLayout=*/false);
     fir::LLVMTypeConverter typeConverter(module, /*applyTBAA=*/false,
                                          /*forceUnifiedTBAATree=*/false, *dl);
-    target.addDynamicallyLegalOp<cuf::AllocOp>([](::cuf::AllocOp op) {
-      return !mlir::isa<fir::BaseBoxType>(op.getInType());
-    });
-    target.addDynamicallyLegalOp<cuf::FreeOp>([](::cuf::FreeOp op) {
-      if (auto refTy = mlir::dyn_cast_or_null<fir::ReferenceType>(
-              op.getDevptr().getType())) {
-        return !mlir::isa<fir::BaseBoxType>(refTy.getEleTy());
-      }
-      return true;
-    });
     target.addDynamicallyLegalOp<cuf::DataTransferOp>(
         [](::cuf::DataTransferOp op) {
           mlir::Type srcTy = fir::unwrapRefType(op.getSrc().getType());
diff --git a/flang/test/Fir/CUDA/cuda-alloc-free.fir b/flang/test/Fir/CUDA/cuda-alloc-free.fir
new file mode 100644
index 00000000000000..25821418a40f11
--- /dev/null
+++ b/flang/test/Fir/CUDA/cuda-alloc-free.fir
@@ -0,0 +1,64 @@
+// RUN: fir-opt --cuf-convert %s | FileCheck %s
+
+module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<f80, dense<128> : vector<2xi64>>, #dlti.dl_entry<i128, dense<128> : vector<2xi64>>, #dlti.dl_entry<i64, dense<64> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr<272>, dense<64> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<271>, dense<32> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<270>, dense<32> : vector<4xi64>>, #dlti.dl_entry<f128, dense<128> : vector<2xi64>>, #dlti.dl_entry<f64, dense<64> : vector<2xi64>>, #dlti.dl_entry<f16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i32, dense<32> : vector<2xi64>>, #dlti.dl_entry<i16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i8, dense<8> : vector<2xi64>>, #dlti.dl_entry<i1, dense<8> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr, dense<64> : vector<4xi64>>, #dlti.dl_entry<"dlti.endianness", "little">, #dlti.dl_entry<"dlti.stack_alignment", 128 : i64>>} {
+
+func.func @_QPsub1() {
+  %0 = cuf.alloc i32 {bindc_name = "idev", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub1Eidev"} -> !fir.ref<i32>
+  %1:2 = hlfir.declare %0 {data_attr = #cuf.cuda<device>, uniq_name = "_QFsub1Eidev"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+  cuf.free %1#1 : !fir.ref<i32> {data_attr = #cuf.cuda<device>}
+  return
+}
+
+// CHECK-LABEL: func.func @_QPsub1()
+// CHECK: %[[BYTES:.*]] = fir.convert %c4{{.*}} : (index) -> i64
+// CHECK: %[[ALLOC:.*]] = fir.call @_FortranACUFMemAlloc(%[[BYTES]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (i64, i32, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
+// CHECK: %[[CONV:.*]] = fir.convert %[[ALLOC]] : (!fir.llvm_ptr<i8>) -> !fir.ref<i32>
+// CHECK: %[[DECL:.*]]:2 = hlfir.declare %[[CONV]] {data_attr = #cuf.cuda<device>, uniq_name = "_QFsub1Eidev"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+// CHECK: %[[DEVPTR:.*]] = fir.convert %[[DECL]]#1 : (!fir.ref<i32>) -> !fir.llvm_ptr<i8>
+// CHECK: fir.call @_FortranACUFMemFree(%[[DEVPTR]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, i32, !fir.ref<i8>, i32) -> none
+
+func.func @_QPsub2() {
+  %0 = cuf.alloc !fir.array<10xf32> {bindc_name = "a", data_attr = #cuf.cuda<device>, uniq_name = "_QMcuda_varFcuda_alloc_freeEa"} -> !fir.ref<!fir.array<10xf32>>
+  cuf.free %0 : !fir.ref<!fir.array<10xf32>> {data_attr = #cuf.cuda<device>}
+  return
+}
+
+// CHECK-LABEL: func.func @_QPsub2()
+// CHECK: %[[BYTES:.*]] = arith.muli %c10{{.*}}, %c4{{.*}} : index
+// CHECK: %[[CONV_BYTES:.*]] = fir.convert %[[BYTES]] : (index) -> i64
+// CHECK: %{{.*}} = fir.call @_FortranACUFMemAlloc(%[[CONV_BYTES]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (i64, i32, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
+// CHECK: fir.call @_FortranACUFMemFree
+
+func.func @_QPsub3(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}, %arg1: !fir.ref<i32> {fir.bindc_name = "m"}) {
+  %0 = fir.dummy_scope : !fir.dscope
+  %1:2 = hlfir.declare %arg0 dummy_scope %0 {uniq_name = "_QFsub3En"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
+  %2:2 = hlfir.declare %arg1 dummy_scope %0 {uniq_name = "_QFsub3Em"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
+  %3 = fir.load %1#0 : !fir.ref<i32>
+  %4 = fir.convert %3 : (i32) -> i64
+  %5 = fir.convert %4 : (i64) -> index
+  %c0 = arith.constant 0 : index
+  %6 = arith.cmpi sgt, %5, %c0 : index
+  %7 = arith.select %6, %5, %c0 : index
+  %8 = fir.load %2#0 : !fir.ref<i32>
+  %9 = fir.convert %8 : (i32) -> i64
+  %10 = fir.convert %9 : (i64) -> index
+  %c0_0 = arith.constant 0 : index
+  %11 = arith.cmpi sgt, %10, %c0_0 : index
+  %12 = arith.select %11, %10, %c0_0 : index
+  %13 = cuf.alloc !fir.array<?x?xi32>, %7, %12 : index, index {bindc_name = "idev", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub3Eidev"} -> !fir.ref<!fir.array<?x?xi32>>
+  %14 = fir.shape %7, %12 : (index, index) -> !fir.shape<2>
+  %15:2 = hlfir.declare %13(%14) {data_attr = #cuf.cuda<device>, uniq_name = "_QFsub3Eidev"} : (!fir.ref<!fir.array<?x?xi32>>, !fir.shape<2>) -> (!fir.box<!fir.array<?x?xi32>>, !fir.ref<!fir.array<?x?xi32>>)
+  cuf.free %15#1 : !fir.ref<!fir.array<?x?xi32>> {data_attr = #cuf.cuda<device>}
+  return
+}
+
+// CHECK-LABEL: func.func @_QPsub3
+// CHECK: %[[N:.*]] = arith.select %{{.*}}, %{{.*}}, %{{.*}} : index
+// CHECK: %[[M:.*]] = arith.select %{{.*}}, %{{.*}}, %{{.*}} : index
+// CHECK: %[[NBELEM:.*]] = arith.muli %[[N]], %[[M]] : index
+// CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %c4{{.*}} : index
+// CHECK: %[[CONV_BYTES:.*]] = fir.convert %[[BYTES]] : (index) -> i64
+// CHECK: %{{.*}} = fir.call @_FortranACUFMemAlloc(%[[CONV_BYTES]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (i64, i32, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
+// CHECK: fir.call @_FortranACUFMemFree
+
+} // end module
diff --git a/flang/test/Fir/CUDA/cuda-allocate.fir b/flang/test/Fir/CUDA/cuda-allocate.fir
index 65c68bb69301af..d68ff894d5af5a 100644
--- a/flang/test/Fir/CUDA/cuda-allocate.fir
+++ b/flang/test/Fir/CUDA/cuda-allocate.fir
@@ -26,17 +26,6 @@ func.func @_QPsub1() {
 // CHECK: %[[BOX_NONE:.*]] = fir.convert %[[DECL_DESC]]#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
 // CHECK: fir.call @_FortranACUFFreeDesciptor(%[[BOX_NONE]], %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<i8>, i32) -> none
 
-// Check operations that should not be transformed yet.
-func.func @_QPsub2() {
-  %0 = cuf.alloc !fir.array<10xf32> {bindc_name = "a", data_attr = #cuf.cuda<device>, uniq_name = "_QMcuda_varFcuda_alloc_freeEa"} -> !fir.ref<!fir.array<10xf32>>
-  cuf.free %0 : !fir.ref<!fir.array<10xf32>> {data_attr = #cuf.cuda<device>}
-  return
-}
-
-// CHECK-LABEL: func.func @_QPsub2()
-// CHECK: cuf.alloc !fir.array<10xf32>
-// CHECK: cuf.free %{{.*}} : !fir.ref<!fir.array<10xf32>>
-
 fir.global @_QMmod1Ea {data_attr = #cuf.cuda<device>} : !fir.box<!fir.heap<!fir.array<?xf32>>> {
     %0 = fir.zero_bits !fir.heap<!fir.array<?xf32>>
     %c0 = arith.constant 0 : index

``````````

</details>


https://github.com/llvm/llvm-project/pull/110055


More information about the flang-commits mailing list