[flang-commits] [flang] 864902e - [flang][cuda] Call CUFGetDeviceAddress to get global device address from host address (#112989)

via flang-commits flang-commits at lists.llvm.org
Fri Oct 18 17:35:41 PDT 2024


Author: Renaud Kauffmann
Date: 2024-10-18T17:35:38-07:00
New Revision: 864902e9b4d8bc6d3f0852d5c475e3dc97dd8335

URL: https://github.com/llvm/llvm-project/commit/864902e9b4d8bc6d3f0852d5c475e3dc97dd8335
DIFF: https://github.com/llvm/llvm-project/commit/864902e9b4d8bc6d3f0852d5c475e3dc97dd8335.diff

LOG: [flang][cuda] Call CUFGetDeviceAddress to get global device address from host address (#112989)

Added: 
    

Modified: 
    flang/include/flang/Optimizer/Transforms/CufOpConversion.h
    flang/lib/Optimizer/Transforms/CufOpConversion.cpp
    flang/test/Fir/CUDA/cuda-data-transfer.fir

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/Transforms/CufOpConversion.h b/flang/include/flang/Optimizer/Transforms/CufOpConversion.h
index 79ce4ac5c6cbc0..0a71cdfddec1ab 100644
--- a/flang/include/flang/Optimizer/Transforms/CufOpConversion.h
+++ b/flang/include/flang/Optimizer/Transforms/CufOpConversion.h
@@ -18,12 +18,17 @@ class LLVMTypeConverter;
 
 namespace mlir {
 class DataLayout;
+class SymbolTable;
 }
 
 namespace cuf {
 
+/// Add to \p patterns the rewrite patterns that lower CUF operations to FIR.
+/// The data transfer pattern stores a reference to \p symtab, which must
+/// therefore outlive \p patterns.
 void populateCUFToFIRConversionPatterns(const fir::LLVMTypeConverter &converter,
                                         mlir::DataLayout &dl,
+                                        const mlir::SymbolTable &symtab,
                                         mlir::RewritePatternSet &patterns);
 
 } // namespace cuf

diff  --git a/flang/lib/Optimizer/Transforms/CufOpConversion.cpp b/flang/lib/Optimizer/Transforms/CufOpConversion.cpp
index 91ef1259332de9..9df559ee0ab1f8 100644
--- a/flang/lib/Optimizer/Transforms/CufOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CufOpConversion.cpp
@@ -77,6 +77,79 @@ static bool hasDoubleDescriptors(OpTy op) {
   return false;
 }
 
+/// Return \p val as \p toTy, inserting a fir.convert only if the types differ.
+static mlir::Value createConvertOp(mlir::PatternRewriter &rewriter,
+                                   mlir::Location loc, mlir::Type toTy,
+                                   mlir::Value val) {
+  if (val.getType() != toTy)
+    return rewriter.create<fir::ConvertOp>(loc, toTy, val);
+  return val;
+}
+
+/// Return the address to hand to the runtime for \p operand of a data
+/// transfer. If \p operand is defined by a fir.declare of a fir.address_of
+/// whose fir.global has the device, managed or constant CUDA data attribute,
+/// a call to the CUFGetDeviceAddress runtime function is inserted before the
+/// operand's owner and its result is returned. Otherwise the operand value
+/// is returned unchanged.
+static mlir::Value getDeviceAddress(mlir::PatternRewriter &rewriter,
+                                    mlir::OpOperand &operand,
+                                    const mlir::SymbolTable &symtab) {
+  mlir::Value v = operand.get();
+  auto declareOp = v.getDefiningOp<fir::DeclareOp>();
+  if (!declareOp)
+    return v;
+
+  auto addrOfOp = declareOp.getMemref().getDefiningOp<fir::AddrOfOp>();
+  if (!addrOfOp)
+    return v;
+
+  auto globalOp = symtab.lookup<fir::GlobalOp>(
+      addrOfOp.getSymbol().getRootReference().getValue());
+
+  if (!globalOp)
+    return v;
+
+  // Only device, managed and constant globals have a device-side symbol whose
+  // address must be looked up at runtime. Pinned data lives in page-locked
+  // host memory, so its host address is already the one to pass.
+  bool isDevGlobal{false};
+  auto attr = globalOp.getDataAttrAttr();
+  if (attr) {
+    switch (attr.getValue()) {
+    case cuf::DataAttribute::Device:
+    case cuf::DataAttribute::Managed:
+    case cuf::DataAttribute::Constant:
+      isDevGlobal = true;
+      break;
+    default:
+      break;
+    }
+  }
+  if (!isDevGlobal)
+    return v;
+  mlir::OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(operand.getOwner());
+  auto loc = declareOp.getLoc();
+  auto mod = declareOp->getParentOfType<mlir::ModuleOp>();
+  fir::FirOpBuilder builder(rewriter, mod);
+
+  mlir::func::FuncOp callee =
+      fir::runtime::getRuntimeFunc<mkRTKey(CUFGetDeviceAddress)>(loc, builder);
+  auto fTy = callee.getFunctionType();
+  auto toTy = fTy.getInput(0);
+  mlir::Value inputArg =
+      createConvertOp(rewriter, loc, toTy, declareOp.getResult());
+  mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
+  mlir::Value sourceLine =
+      fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
+  llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
+      builder, loc, fTy, inputArg, sourceFile, sourceLine)};
+  auto call = rewriter.create<fir::CallOp>(loc, callee, args);
+
+  return call->getResult(0);
+}
+
 template <typename OpTy>
 static mlir::LogicalResult convertOpToCall(OpTy op,
                                            mlir::PatternRewriter &rewriter,
@@ -363,18 +436,14 @@ struct CufFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> {
   }
 };
 
-static mlir::Value createConvertOp(mlir::PatternRewriter &rewriter,
-                                   mlir::Location loc, mlir::Type toTy,
-                                   mlir::Value val) {
-  if (val.getType() != toTy)
-    return rewriter.create<fir::ConvertOp>(loc, toTy, val);
-  return val;
-}
-
 struct CufDataTransferOpConversion
     : public mlir::OpRewritePattern<cuf::DataTransferOp> {
   using OpRewritePattern::OpRewritePattern;
 
+  CufDataTransferOpConversion(mlir::MLIRContext *context,
+                              const mlir::SymbolTable &symtab)
+      : OpRewritePattern(context), symtab{symtab} {}
+
   mlir::LogicalResult
   matchAndRewrite(cuf::DataTransferOp op,
                   mlir::PatternRewriter &rewriter) const override {
@@ -445,9 +514,11 @@ struct CufDataTransferOpConversion
       mlir::Value sourceLine =
           fir::factory::locationToLineNo(builder, loc, fTy.getInput(5));
 
-      llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
-          builder, loc, fTy, op.getDst(), op.getSrc(), bytes, modeValue,
-          sourceFile, sourceLine)};
+      mlir::Value dst = getDeviceAddress(rewriter, op.getDstMutable(), symtab);
+      mlir::Value src = getDeviceAddress(rewriter, op.getSrcMutable(), symtab);
+      llvm::SmallVector<mlir::Value> args{
+          fir::runtime::createArguments(builder, loc, fTy, dst, src, bytes,
+                                        modeValue, sourceFile, sourceLine)};
       builder.create<fir::CallOp>(loc, func, args);
       rewriter.eraseOp(op);
       return mlir::success();
@@ -552,6 +623,11 @@ struct CufDataTransferOpConversion
     }
     return mlir::success();
   }
+
+private:
+  // Symbol table of the module being converted; it is owned by the caller
+  // and must outlive this pattern, which only keeps a reference to it.
+  const mlir::SymbolTable &symtab;
 };
 
 class CufOpConversion : public fir::impl::CufOpConversionBase<CufOpConversion> {
@@ -565,13 +641,15 @@ class CufOpConversion : public fir::impl::CufOpConversionBase<CufOpConversion> {
     mlir::ModuleOp module = mlir::dyn_cast<mlir::ModuleOp>(op);
     if (!module)
       return signalPassFailure();
+    mlir::SymbolTable symtab(module);
 
     std::optional<mlir::DataLayout> dl =
         fir::support::getOrSetDataLayout(module, /*allowDefaultLayout=*/false);
     fir::LLVMTypeConverter typeConverter(module, /*applyTBAA=*/false,
                                          /*forceUnifiedTBAATree=*/false, *dl);
     target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect>();
-    cuf::populateCUFToFIRConversionPatterns(typeConverter, *dl, patterns);
+    cuf::populateCUFToFIRConversionPatterns(typeConverter, *dl, symtab,
+                                            patterns);
     if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
                                                   std::move(patterns)))) {
       mlir::emitError(mlir::UnknownLoc::get(ctx),
@@ -584,9 +662,9 @@ class CufOpConversion : public fir::impl::CufOpConversionBase<CufOpConversion> {
 
 void cuf::populateCUFToFIRConversionPatterns(
     const fir::LLVMTypeConverter &converter, mlir::DataLayout &dl,
-    mlir::RewritePatternSet &patterns) {
+    const mlir::SymbolTable &symtab, mlir::RewritePatternSet &patterns) {
   patterns.insert<CufAllocOpConversion>(patterns.getContext(), &dl, &converter);
   patterns.insert<CufAllocateOpConversion, CufDeallocateOpConversion,
-                  CufFreeOpConversion, CufDataTransferOpConversion>(
-      patterns.getContext());
+                  CufFreeOpConversion>(patterns.getContext());
+  patterns.insert<CufDataTransferOpConversion>(patterns.getContext(), symtab);
 }

diff  --git a/flang/test/Fir/CUDA/cuda-data-transfer.fir b/flang/test/Fir/CUDA/cuda-data-transfer.fir
index ed894aed5534a0..c33c50115b9fc0 100644
--- a/flang/test/Fir/CUDA/cuda-data-transfer.fir
+++ b/flang/test/Fir/CUDA/cuda-data-transfer.fir
@@ -189,4 +189,51 @@ func.func @_QPsub7() {
 // CHECK: %[[SRC:.*]] = fir.convert %[[IHOST]]#0 : (!fir.ref<!fir.array<10xi32>>) -> !fir.llvm_ptr<i8>
 // CHECK: fir.call @_FortranACUFDataTransferPtrPtr(%[[DST]], %[[SRC]], %[[BYTES]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
 
+fir.global @_QMmtestsEn(dense<[3, 4, 5, 6, 7]> : tensor<5xi32>) {data_attr = #cuf.cuda<device>} : !fir.array<5xi32>
+// Device global read into a host array: the source passed to
+// CUFDataTransferPtrPtr is the result of CUFGetDeviceAddress on the global.
+func.func @_QPsub8() attributes {fir.bindc_name = "t"} {
+  %c5 = arith.constant 5 : index
+  %0 = fir.alloca !fir.array<5xi32> {bindc_name = "m", uniq_name = "_QFEm"}
+  %1 = fir.shape %c5 : (index) -> !fir.shape<1>
+  %2 = fir.declare %0(%1) {uniq_name = "_QFEm"} : (!fir.ref<!fir.array<5xi32>>, !fir.shape<1>) -> !fir.ref<!fir.array<5xi32>>
+  %3 = fir.address_of(@_QMmtestsEn) : !fir.ref<!fir.array<5xi32>>
+  %4 = fir.declare %3(%1) {data_attr = #cuf.cuda<device>, uniq_name = "_QMmtestsEn"} : (!fir.ref<!fir.array<5xi32>>, !fir.shape<1>) -> !fir.ref<!fir.array<5xi32>>
+  cuf.data_transfer %4 to %2 {transfer_kind = #cuf.cuda_transfer<device_host>} : !fir.ref<!fir.array<5xi32>>, !fir.ref<!fir.array<5xi32>>
+  return
+}
+
+// CHECK-LABEL: func.func @_QPsub8()
+// CHECK: %[[ALLOCA:.*]] = fir.alloca !fir.array<5xi32>
+// CHECK: %[[LOCAL:.*]] = fir.declare %[[ALLOCA]]
+// CHECK: %[[GBL:.*]] = fir.address_of(@_QMmtestsEn) : !fir.ref<!fir.array<5xi32>>
+// CHECK: %[[DECL:.*]] = fir.declare %[[GBL]]
+// CHECK: %[[HOST:.*]] = fir.convert %[[DECL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
+// CHECK: %[[SRC:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[HOST]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
+// CHECK: %[[DST:.*]] = fir.convert %[[LOCAL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
+// CHECK: fir.call @_FortranACUFDataTransferPtrPtr(%[[DST]], %[[SRC]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
+
+
+// Host array written into a device global: the destination passed to
+// CUFDataTransferPtrPtr is the result of CUFGetDeviceAddress on the global.
+func.func @_QPsub9() {
+  %c5 = arith.constant 5 : index
+  %0 = fir.alloca !fir.array<5xi32> {bindc_name = "m", uniq_name = "_QFtest9Em"}
+  %1 = fir.shape %c5 : (index) -> !fir.shape<1>
+  %2 = fir.declare %0(%1) {uniq_name = "_QFtest9Em"} : (!fir.ref<!fir.array<5xi32>>, !fir.shape<1>) -> !fir.ref<!fir.array<5xi32>>
+  %3 = fir.address_of(@_QMmtestsEn) : !fir.ref<!fir.array<5xi32>>
+  %4 = fir.declare %3(%1) {data_attr = #cuf.cuda<device>, uniq_name = "_QMmtestsEn"} : (!fir.ref<!fir.array<5xi32>>, !fir.shape<1>) -> !fir.ref<!fir.array<5xi32>>
+  cuf.data_transfer %2 to %4 {transfer_kind = #cuf.cuda_transfer<host_device>} : !fir.ref<!fir.array<5xi32>>, !fir.ref<!fir.array<5xi32>>
+  return
+}
+
+// CHECK-LABEL: func.func @_QPsub9()
+// CHECK: %[[ALLOCA:.*]] = fir.alloca !fir.array<5xi32>
+// CHECK: %[[LOCAL:.*]] = fir.declare %[[ALLOCA]]
+// CHECK: %[[GBL:.*]] = fir.address_of(@_QMmtestsEn) : !fir.ref<!fir.array<5xi32>>
+// CHECK: %[[DECL:.*]] = fir.declare %[[GBL]]
+// CHECK: %[[HOST:.*]] = fir.convert %[[DECL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
+// CHECK: %[[DST:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[HOST]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
+// CHECK: %[[SRC:.*]] = fir.convert %[[LOCAL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
+// CHECK: fir.call @_FortranACUFDataTransferPtrPtr(%[[DST]], %[[SRC]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
 } // end of module


        


More information about the flang-commits mailing list