[flang-commits] [flang] [flang][cuda] Add cuf.device_address operation (PR #122975)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Tue Jan 14 13:52:57 PST 2025


https://github.com/clementval created https://github.com/llvm/llvm-project/pull/122975

Introduce a new op to get the device address from a host symbol. This simplifies the current conversion and is also in preparation for some legalization work that needs to be done in cuf kernel and cuf kernel launch, similar to 
https://github.com/llvm/llvm-project/pull/122802

>From 1b0a3a3f840c0d80fc64beea290c6630878b28c5 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Tue, 14 Jan 2025 12:57:45 -0800
Subject: [PATCH] [flang][cuda] Add cuf.device_address operation

---
 .../flang/Optimizer/Dialect/CUF/CUFOps.td     | 12 +++
 .../Optimizer/Transforms/CUFOpConversion.cpp  | 76 +++++++++++++------
 flang/test/Fir/CUDA/cuda-data-transfer.fir    |  3 +
 flang/test/Fir/CUDA/cuda-global-addr.mlir     |  1 +
 flang/test/Fir/CUDA/cuda-launch.fir           |  6 +-
 5 files changed, 71 insertions(+), 27 deletions(-)

diff --git a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
index 6f886726b12834..a270e69b394104 100644
--- a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
+++ b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
@@ -335,4 +335,16 @@ def cuf_RegisterKernelOp : cuf_Op<"register_kernel", []> {
   }];
 }
 
+def cuf_DeviceAddressOp : cuf_Op<"device_address", []> {
+  let summary = "Get the device address from a host symbol";
+
+  let arguments = (ins SymbolRefAttr:$hostSymbol);
+
+  let assemblyFormat = [{
+    $hostSymbol attr-dict `->` type($addr)
+  }];
+
+  let results = (outs fir_ReferenceType:$addr);
+}
+
 #endif // FORTRAN_DIALECT_CUF_CUF_OPS
diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
index d61d9f63cb2949..e93bed37d39f78 100644
--- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
@@ -366,22 +366,47 @@ struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
   const fir::LLVMTypeConverter *typeConverter;
 };
 
-static mlir::Value genGetDeviceAddress(mlir::PatternRewriter &rewriter,
-                                       mlir::ModuleOp mod, mlir::Location loc,
-                                       mlir::Value inputArg) {
-  fir::FirOpBuilder builder(rewriter, mod);
-  mlir::func::FuncOp callee =
-      fir::runtime::getRuntimeFunc<mkRTKey(CUFGetDeviceAddress)>(loc, builder);
-  auto fTy = callee.getFunctionType();
-  mlir::Value conv = createConvertOp(rewriter, loc, fTy.getInput(0), inputArg);
-  mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
-  mlir::Value sourceLine =
-      fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
-  llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
-      builder, loc, fTy, conv, sourceFile, sourceLine)};
-  auto call = rewriter.create<fir::CallOp>(loc, callee, args);
-  return createConvertOp(rewriter, loc, inputArg.getType(), call->getResult(0));
-}
+struct CUFDeviceAddressOpConversion
+    : public mlir::OpRewritePattern<cuf::DeviceAddressOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  CUFDeviceAddressOpConversion(mlir::MLIRContext *context,
+                               const mlir::SymbolTable &symtab)
+      : OpRewritePattern(context), symTab{symtab} {}
+
+  mlir::LogicalResult
+  matchAndRewrite(cuf::DeviceAddressOp op,
+                  mlir::PatternRewriter &rewriter) const override {
+    if (auto global = symTab.lookup<fir::GlobalOp>(
+            op.getHostSymbol().getRootReference().getValue())) {
+      auto mod = op->getParentOfType<mlir::ModuleOp>();
+      mlir::Location loc = op.getLoc();
+      auto hostAddr = rewriter.create<fir::AddrOfOp>(
+          loc, fir::ReferenceType::get(global.getType()), op.getHostSymbol());
+      fir::FirOpBuilder builder(rewriter, mod);
+      mlir::func::FuncOp callee =
+          fir::runtime::getRuntimeFunc<mkRTKey(CUFGetDeviceAddress)>(loc,
+                                                                     builder);
+      auto fTy = callee.getFunctionType();
+      mlir::Value conv =
+          createConvertOp(rewriter, loc, fTy.getInput(0), hostAddr);
+      mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
+      mlir::Value sourceLine =
+          fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
+      llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
+          builder, loc, fTy, conv, sourceFile, sourceLine)};
+      auto call = rewriter.create<fir::CallOp>(loc, callee, args);
+      mlir::Value addr = createConvertOp(rewriter, loc, hostAddr.getType(),
+                                         call->getResult(0));
+      rewriter.replaceOp(op, addr.getDefiningOp());
+      return success();
+    }
+    return failure();
+  }
+
+private:
+  const mlir::SymbolTable &symTab;
+};
 
 struct DeclareOpConversion : public mlir::OpRewritePattern<fir::DeclareOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -398,9 +423,8 @@ struct DeclareOpConversion : public mlir::OpRewritePattern<fir::DeclareOp> {
               addrOfOp.getSymbol().getRootReference().getValue())) {
         if (cuf::isRegisteredDeviceGlobal(global)) {
           rewriter.setInsertionPointAfter(addrOfOp);
-          auto mod = op->getParentOfType<mlir::ModuleOp>();
-          mlir::Value devAddr = genGetDeviceAddress(rewriter, mod, op.getLoc(),
-                                                    addrOfOp.getResult());
+          mlir::Value devAddr = rewriter.create<cuf::DeviceAddressOp>(
+              op.getLoc(), addrOfOp.getType(), addrOfOp.getSymbol());
           rewriter.startOpModification(op);
           op.getMemrefMutable().assign(devAddr);
           rewriter.finalizeOpModification(op);
@@ -773,7 +797,6 @@ struct CUFLaunchOpConversion
       }
     }
     llvm::SmallVector<mlir::Value> args;
-    auto mod = op->getParentOfType<mlir::ModuleOp>();
     for (mlir::Value arg : op.getArgs()) {
       // If the argument is a global descriptor, make sure we pass the device
       // copy of this descriptor and not the host one.
@@ -785,8 +808,11 @@ struct CUFLaunchOpConversion
             if (auto global = symTab.lookup<fir::GlobalOp>(
                     addrOfOp.getSymbol().getRootReference().getValue())) {
               if (cuf::isRegisteredDeviceGlobal(global)) {
-                arg = genGetDeviceAddress(rewriter, mod, op.getLoc(),
-                                          declareOp.getResult());
+                arg = rewriter
+                          .create<cuf::DeviceAddressOp>(op.getLoc(),
+                                                        addrOfOp.getType(),
+                                                        addrOfOp.getSymbol())
+                          .getResult();
               }
             }
           }
@@ -907,10 +933,12 @@ void cuf::populateCUFToFIRConversionPatterns(
       patterns.getContext());
   patterns.insert<CUFDataTransferOpConversion>(patterns.getContext(), symtab,
                                                &dl, &converter);
-  patterns.insert<CUFLaunchOpConversion>(patterns.getContext(), symtab);
+  patterns.insert<CUFLaunchOpConversion, CUFDeviceAddressOpConversion>(
+      patterns.getContext(), symtab);
 }
 
 void cuf::populateFIRCUFConversionPatterns(const mlir::SymbolTable &symtab,
                                            mlir::RewritePatternSet &patterns) {
-  patterns.insert<DeclareOpConversion>(patterns.getContext(), symtab);
+  patterns.insert<DeclareOpConversion, CUFDeviceAddressOpConversion>(
+      patterns.getContext(), symtab);
 }
diff --git a/flang/test/Fir/CUDA/cuda-data-transfer.fir b/flang/test/Fir/CUDA/cuda-data-transfer.fir
index 7203c33e7eb11f..5ed27f1be0a430 100644
--- a/flang/test/Fir/CUDA/cuda-data-transfer.fir
+++ b/flang/test/Fir/CUDA/cuda-data-transfer.fir
@@ -198,6 +198,7 @@ func.func @_QPsub8() attributes {fir.bindc_name = "t"} {
 // CHECK-LABEL: func.func @_QPsub8()
 // CHECK: %[[ALLOCA:.*]] = fir.alloca !fir.array<5xi32> 
 // CHECK: %[[LOCAL:.*]] = fir.declare %[[ALLOCA]]
+// CHECK: fir.address_of(@_QMmtestsEn) : !fir.ref<!fir.array<5xi32>>
 // CHECK: %[[GBL:.*]] = fir.address_of(@_QMmtestsEn) : !fir.ref<!fir.array<5xi32>>
 // CHECK: %[[GBL_CONV:.*]] = fir.convert %[[GBL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
 // CHECK: %[[ADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[GBL_CONV]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
@@ -222,6 +223,7 @@ func.func @_QPsub9() {
 // CHECK-LABEL: func.func @_QPsub9()
 // CHECK: %[[ALLOCA:.*]] = fir.alloca !fir.array<5xi32> 
 // CHECK: %[[LOCAL:.*]] = fir.declare %[[ALLOCA]]
+// CHECK: fir.address_of(@_QMmtestsEn) : !fir.ref<!fir.array<5xi32>>
 // CHECK: %[[GBL:.*]] = fir.address_of(@_QMmtestsEn) : !fir.ref<!fir.array<5xi32>>
 // CHECK: %[[GBL_CONV:.*]] = fir.convert %[[GBL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
 // CHECK: %[[ADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[GBL_CONV]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
@@ -380,6 +382,7 @@ func.func @_QPdevice_addr_conv() {
 }
 
 // CHECK-LABEL: func.func @_QPdevice_addr_conv()
+// CHECK: fir.address_of(@_QMmod1Ea_dev) : !fir.ref<!fir.array<4xf32>>
 // CHECK: %[[GBL:.*]] = fir.address_of(@_QMmod1Ea_dev) : !fir.ref<!fir.array<4xf32>>
 // CHECK: %[[GBL_CONV:.*]] = fir.convert %[[GBL]] : (!fir.ref<!fir.array<4xf32>>) -> !fir.llvm_ptr<i8>
 // CHECK: %[[ADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[GBL_CONV]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
diff --git a/flang/test/Fir/CUDA/cuda-global-addr.mlir b/flang/test/Fir/CUDA/cuda-global-addr.mlir
index 94ee74736f6508..0ccd0c797fb6f5 100644
--- a/flang/test/Fir/CUDA/cuda-global-addr.mlir
+++ b/flang/test/Fir/CUDA/cuda-global-addr.mlir
@@ -26,6 +26,7 @@ func.func @_QQmain() attributes {fir.bindc_name = "test"} {
 }
 
 // CHECK-LABEL: func.func @_QQmain()
+// CHECK: fir.address_of(@_QMmod1Eadev) : !fir.ref<!fir.array<10xi32>>
 // CHECK: %[[ADDR:.*]] = fir.address_of(@_QMmod1Eadev) : !fir.ref<!fir.array<10xi32>>
 // CHECK: %[[ADDRPTR:.*]] = fir.convert %[[ADDR]] : (!fir.ref<!fir.array<10xi32>>) -> !fir.llvm_ptr<i8>
 // CHECK: %[[DEVICE_ADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[ADDRPTR]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
diff --git a/flang/test/Fir/CUDA/cuda-launch.fir b/flang/test/Fir/CUDA/cuda-launch.fir
index 1e19b3bea1296f..8432b9ec926e38 100644
--- a/flang/test/Fir/CUDA/cuda-launch.fir
+++ b/flang/test/Fir/CUDA/cuda-launch.fir
@@ -98,9 +98,9 @@ module attributes {gpu.container_module, dlti.dl_spec = #dlti.dl_spec<#dlti.dl_e
 }
 
 // CHECK-LABEL: func.func @_QQmain()
+// CHECK: _FortranACUFSyncGlobalDescriptor
 // CHECK: %[[ADDROF:.*]] = fir.address_of(@_QMdevptrEdev_ptr) : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
-// CHECK: %[[DECL:.*]] = fir.declare %[[ADDROF]] {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QMdevptrEdev_ptr"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
-// CHECK: %[[CONV_DECL:.*]] = fir.convert %[[DECL]] : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> !fir.llvm_ptr<i8>
-// CHECK: %[[DEVADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[CONV_DECL]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
+// CHECK: %[[CONV_ADDR:.*]] = fir.convert %[[ADDROF]] : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> !fir.llvm_ptr<i8>
+// CHECK: %[[DEVADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[CONV_ADDR]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
 // CHECK: %[[CONV_DEVADDR:.*]] = fir.convert %[[DEVADDR]] : (!fir.llvm_ptr<i8>) -> !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
 // CHECK: gpu.launch_func  @cuda_device_mod::@_QMdevptrPtest blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}})  dynamic_shared_memory_size %{{.*}} args(%[[CONV_DEVADDR]] : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>)



More information about the flang-commits mailing list