[flang-commits] [flang] [flang][cuda] Get device address in fir.declare (PR #118591)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Wed Dec 4 10:36:02 PST 2024


https://github.com/clementval updated https://github.com/llvm/llvm-project/pull/118591

>From f61d75c0d18d4247522de6110a23c9a4a5cf0b65 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Tue, 3 Dec 2024 21:50:44 -0800
Subject: [PATCH 1/2] [flang][cuda] Get device address in fir.declare

---
 .../Optimizer/Transforms/CUFOpConversion.h    |   5 +
 .../Optimizer/Transforms/CUFOpConversion.cpp  | 150 +++++++++++-------
 flang/test/Fir/CUDA/cuda-data-transfer.fir    |  29 ++--
 flang/test/Fir/CUDA/cuda-global-addr.mlir     |  34 ++++
 4 files changed, 145 insertions(+), 73 deletions(-)
 create mode 100644 flang/test/Fir/CUDA/cuda-global-addr.mlir

diff --git a/flang/include/flang/Optimizer/Transforms/CUFOpConversion.h b/flang/include/flang/Optimizer/Transforms/CUFOpConversion.h
index f061323db1704a..336cf46d82babf 100644
--- a/flang/include/flang/Optimizer/Transforms/CUFOpConversion.h
+++ b/flang/include/flang/Optimizer/Transforms/CUFOpConversion.h
@@ -23,11 +23,16 @@ class SymbolTable;
 
 namespace cuf {
 
+/// Patterns that convert CUF operations to runtime calls.
 void populateCUFToFIRConversionPatterns(const fir::LLVMTypeConverter &converter,
                                         mlir::DataLayout &dl,
                                         const mlir::SymbolTable &symtab,
                                         mlir::RewritePatternSet &patterns);
 
+/// Patterns that update fir operations in the presence of CUF.
+void populateFIRCUFConversionPatterns(const mlir::SymbolTable &symtab,
+                                      mlir::RewritePatternSet &patterns);
+
 } // namespace cuf
 
 #endif // FORTRAN_OPTIMIZER_TRANSFORMS_CUFOPCONVERSION_H_
diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
index 337ea04755d1a9..7f6843d66d39f8 100644
--- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
@@ -81,6 +81,15 @@ static bool hasDoubleDescriptors(OpTy op) {
   return false;
 }
 
+bool isDeviceGlobal(fir::GlobalOp op) {
+  auto attr = op.getDataAttr();
+  if (attr && (*attr == cuf::DataAttribute::Device ||
+               *attr == cuf::DataAttribute::Managed ||
+               *attr == cuf::DataAttribute::Constant))
+    return true;
+  return false;
+}
+
 static mlir::Value createConvertOp(mlir::PatternRewriter &rewriter,
                                    mlir::Location loc, mlir::Type toTy,
                                    mlir::Value val) {
@@ -89,62 +98,6 @@ static mlir::Value createConvertOp(mlir::PatternRewriter &rewriter,
   return val;
 }
 
-mlir::Value getDeviceAddress(mlir::PatternRewriter &rewriter,
-                             mlir::OpOperand &operand,
-                             const mlir::SymbolTable &symtab) {
-  mlir::Value v = operand.get();
-  auto declareOp = v.getDefiningOp<fir::DeclareOp>();
-  if (!declareOp)
-    return v;
-
-  auto addrOfOp = declareOp.getMemref().getDefiningOp<fir::AddrOfOp>();
-  if (!addrOfOp)
-    return v;
-
-  auto globalOp = symtab.lookup<fir::GlobalOp>(
-      addrOfOp.getSymbol().getRootReference().getValue());
-
-  if (!globalOp)
-    return v;
-
-  bool isDevGlobal{false};
-  auto attr = globalOp.getDataAttrAttr();
-  if (attr) {
-    switch (attr.getValue()) {
-    case cuf::DataAttribute::Device:
-    case cuf::DataAttribute::Managed:
-    case cuf::DataAttribute::Constant:
-      isDevGlobal = true;
-      break;
-    default:
-      break;
-    }
-  }
-  if (!isDevGlobal)
-    return v;
-  mlir::OpBuilder::InsertionGuard guard(rewriter);
-  rewriter.setInsertionPoint(operand.getOwner());
-  auto loc = declareOp.getLoc();
-  auto mod = declareOp->getParentOfType<mlir::ModuleOp>();
-  fir::FirOpBuilder builder(rewriter, mod);
-
-  mlir::func::FuncOp callee =
-      fir::runtime::getRuntimeFunc<mkRTKey(CUFGetDeviceAddress)>(loc, builder);
-  auto fTy = callee.getFunctionType();
-  auto toTy = fTy.getInput(0);
-  mlir::Value inputArg =
-      createConvertOp(rewriter, loc, toTy, declareOp.getResult());
-  mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
-  mlir::Value sourceLine =
-      fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
-  llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
-      builder, loc, fTy, inputArg, sourceFile, sourceLine)};
-  auto call = rewriter.create<fir::CallOp>(loc, callee, args);
-  mlir::Value cast = createConvertOp(
-      rewriter, loc, declareOp.getMemref().getType(), call->getResult(0));
-  return cast;
-}
-
 template <typename OpTy>
 static mlir::LogicalResult convertOpToCall(OpTy op,
                                            mlir::PatternRewriter &rewriter,
@@ -422,6 +375,54 @@ struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
   const fir::LLVMTypeConverter *typeConverter;
 };
 
+struct DeclareOpConversion : public mlir::OpRewritePattern<fir::DeclareOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  DeclareOpConversion(mlir::MLIRContext *context,
+                      const mlir::SymbolTable &symtab)
+      : OpRewritePattern(context), symTab{symtab} {}
+
+  mlir::LogicalResult
+  matchAndRewrite(fir::DeclareOp op,
+                  mlir::PatternRewriter &rewriter) const override {
+    if (auto addrOfOp = op.getMemref().getDefiningOp<fir::AddrOfOp>()) {
+      if (auto global = symTab.lookup<fir::GlobalOp>(
+              addrOfOp.getSymbol().getRootReference().getValue())) {
+        if (isDeviceGlobal(global)) {
+          rewriter.setInsertionPointAfter(addrOfOp);
+          auto mod = op->getParentOfType<mlir::ModuleOp>();
+          fir::FirOpBuilder builder(rewriter, mod);
+          mlir::Location loc = op.getLoc();
+          mlir::func::FuncOp callee =
+              fir::runtime::getRuntimeFunc<mkRTKey(CUFGetDeviceAddress)>(
+                  loc, builder);
+          auto fTy = callee.getFunctionType();
+          mlir::Type toTy = fTy.getInput(0);
+          mlir::Value inputArg =
+              createConvertOp(rewriter, loc, toTy, addrOfOp.getResult());
+          mlir::Value sourceFile =
+              fir::factory::locationToFilename(builder, loc);
+          mlir::Value sourceLine =
+              fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
+          llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
+              builder, loc, fTy, inputArg, sourceFile, sourceLine)};
+          auto call = rewriter.create<fir::CallOp>(loc, callee, args);
+          mlir::Value cast = createConvertOp(
+              rewriter, loc, op.getMemref().getType(), call->getResult(0));
+          rewriter.startOpModification(op);
+          op.getMemrefMutable().assign(cast);
+          rewriter.finalizeOpModification(op);
+          return success();
+        }
+      }
+    }
+    return failure();
+  }
+
+private:
+  const mlir::SymbolTable &symTab;
+};
+
 struct CUFFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -511,7 +512,7 @@ static mlir::Value emboxSrc(mlir::PatternRewriter &rewriter,
     builder.create<fir::StoreOp>(loc, src, alloc);
     addr = alloc;
   } else {
-    addr = getDeviceAddress(rewriter, op.getSrcMutable(), symtab);
+    addr = op.getSrc();
   }
   llvm::SmallVector<mlir::Value> lenParams;
   mlir::Type boxTy = fir::BoxType::get(srcTy);
@@ -531,7 +532,7 @@ static mlir::Value emboxDst(mlir::PatternRewriter &rewriter,
   mlir::Location loc = op.getLoc();
   fir::FirOpBuilder builder(rewriter, mod);
   mlir::Type dstTy = fir::unwrapRefType(op.getDst().getType());
-  mlir::Value dstAddr = getDeviceAddress(rewriter, op.getDstMutable(), symtab);
+  mlir::Value dstAddr = op.getDst();
   mlir::Type dstBoxTy = fir::BoxType::get(dstTy);
   llvm::SmallVector<mlir::Value> lenParams;
   mlir::Value dstBox =
@@ -652,8 +653,8 @@ struct CUFDataTransferOpConversion
       mlir::Value sourceLine =
           fir::factory::locationToLineNo(builder, loc, fTy.getInput(5));
 
-      mlir::Value dst = getDeviceAddress(rewriter, op.getDstMutable(), symtab);
-      mlir::Value src = getDeviceAddress(rewriter, op.getSrcMutable(), symtab);
+      mlir::Value dst = op.getDst();
+      mlir::Value src = op.getSrc();
       // Materialize the src if constant.
       if (matchPattern(src.getDefiningOp(), mlir::m_Constant())) {
         mlir::Value temp = builder.createTemporary(loc, srcTy);
@@ -823,6 +824,30 @@ class CUFOpConversion : public fir::impl::CUFOpConversionBase<CUFOpConversion> {
                       "error in CUF op conversion\n");
       signalPassFailure();
     }
+
+    target.addDynamicallyLegalOp<fir::DeclareOp>([&](fir::DeclareOp op) {
+      if (inDeviceContext(op))
+        return true;
+      if (auto addrOfOp = op.getMemref().getDefiningOp<fir::AddrOfOp>()) {
+        if (auto global = symtab.lookup<fir::GlobalOp>(
+                addrOfOp.getSymbol().getRootReference().getValue())) {
+          if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(global.getType())))
+            return true;
+          if (isDeviceGlobal(global))
+            return false;
+        }
+      }
+      return true;
+    });
+
+    patterns.clear();
+    cuf::populateFIRCUFConversionPatterns(symtab, patterns);
+    if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
+                                                  std::move(patterns)))) {
+      mlir::emitError(mlir::UnknownLoc::get(ctx),
+                      "error in CUF op conversion\n");
+      signalPassFailure();
+    }
   }
 };
 } // namespace
@@ -837,3 +862,8 @@ void cuf::populateCUFToFIRConversionPatterns(
                                                &dl, &converter);
   patterns.insert<CUFLaunchOpConversion>(patterns.getContext(), symtab);
 }
+
+void cuf::populateFIRCUFConversionPatterns(const mlir::SymbolTable &symtab,
+                                           mlir::RewritePatternSet &patterns) {
+  patterns.insert<DeclareOpConversion>(patterns.getContext(), symtab);
+}
diff --git a/flang/test/Fir/CUDA/cuda-data-transfer.fir b/flang/test/Fir/CUDA/cuda-data-transfer.fir
index b371d397777280..7203c33e7eb11f 100644
--- a/flang/test/Fir/CUDA/cuda-data-transfer.fir
+++ b/flang/test/Fir/CUDA/cuda-data-transfer.fir
@@ -199,12 +199,12 @@ func.func @_QPsub8() attributes {fir.bindc_name = "t"} {
 // CHECK: %[[ALLOCA:.*]] = fir.alloca !fir.array<5xi32> 
 // CHECK: %[[LOCAL:.*]] = fir.declare %[[ALLOCA]]
 // CHECK: %[[GBL:.*]] = fir.address_of(@_QMmtestsEn) : !fir.ref<!fir.array<5xi32>>
-// CHECK: %[[DECL:.*]] = fir.declare %[[GBL]]
-// CHECK: %[[HOST:.*]] = fir.convert %[[DECL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
-// CHECK: %[[SRC:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[HOST]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
-// CHECK: %[[SRC_CONV:.*]] = fir.convert %[[SRC]] : (!fir.llvm_ptr<i8>) -> !fir.ref<!fir.array<5xi32>>
+// CHECK: %[[GBL_CONV:.*]] = fir.convert %[[GBL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
+// CHECK: %[[ADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[GBL_CONV]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
+// CHECK: %[[ADDR_CONV:.*]] = fir.convert %[[ADDR]] : (!fir.llvm_ptr<i8>) -> !fir.ref<!fir.array<5xi32>>
+// CHECK: %[[DECL:.*]] = fir.declare %[[ADDR_CONV]]
 // CHECK: %[[DST:.*]] = fir.convert %[[LOCAL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
-// CHECK: %[[SRC:.*]] = fir.convert %[[SRC_CONV]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
+// CHECK: %[[SRC:.*]] = fir.convert %[[DECL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
 // CHECK: fir.call @_FortranACUFDataTransferPtrPtr(%[[DST]], %[[SRC]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
 
 
@@ -223,11 +223,11 @@ func.func @_QPsub9() {
 // CHECK: %[[ALLOCA:.*]] = fir.alloca !fir.array<5xi32> 
 // CHECK: %[[LOCAL:.*]] = fir.declare %[[ALLOCA]]
 // CHECK: %[[GBL:.*]] = fir.address_of(@_QMmtestsEn) : !fir.ref<!fir.array<5xi32>>
-// CHECK: %[[DECL:.*]] = fir.declare %[[GBL]]
-// CHECK: %[[HOST:.*]] = fir.convert %[[DECL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
-// CHECK: %[[DST:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[HOST]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
-// CHECK: %[[DST_CONV:.*]] = fir.convert %[[DST]] : (!fir.llvm_ptr<i8>) -> !fir.ref<!fir.array<5xi32>>
-// CHECK: %[[DST:.*]] = fir.convert %[[DST_CONV]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
+// CHECK: %[[GBL_CONV:.*]] = fir.convert %[[GBL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
+// CHECK: %[[ADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[GBL_CONV]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
+// CHECK: %[[ADDR_CONV:.*]] = fir.convert %[[ADDR]] : (!fir.llvm_ptr<i8>) -> !fir.ref<!fir.array<5xi32>>
+// CHECK: %[[DECL:.*]] = fir.declare %[[ADDR_CONV]]
+// CHECK: %[[DST:.*]] = fir.convert %[[DECL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
 // CHECK: %[[SRC:.*]] = fir.convert %[[LOCAL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
 // CHECK: fir.call @_FortranACUFDataTransferPtrPtr(%[[DST]], %[[SRC]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
 
@@ -380,9 +380,12 @@ func.func @_QPdevice_addr_conv() {
 }
 
 // CHECK-LABEL: func.func @_QPdevice_addr_conv()
-// CHECK: %[[DEV_ADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
-// CHECK: %[[DEV_ADDR_CONV:.*]] = fir.convert %[[DEV_ADDR]] : (!fir.llvm_ptr<i8>) -> !fir.ref<!fir.array<4xf32>>
-// CHECK: fir.embox %[[DEV_ADDR_CONV]](%{{.*}}) : (!fir.ref<!fir.array<4xf32>>, !fir.shape<1>) -> !fir.box<!fir.array<4xf32>>
+// CHECK: %[[GBL:.*]] = fir.address_of(@_QMmod1Ea_dev) : !fir.ref<!fir.array<4xf32>>
+// CHECK: %[[GBL_CONV:.*]] = fir.convert %[[GBL]] : (!fir.ref<!fir.array<4xf32>>) -> !fir.llvm_ptr<i8>
+// CHECK: %[[ADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[GBL_CONV]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
+// CHECK: %[[ADDR_CONV:.*]] = fir.convert %[[ADDR]] : (!fir.llvm_ptr<i8>) -> !fir.ref<!fir.array<4xf32>>
+// CHECK: %[[DECL:.*]] = fir.declare %[[ADDR_CONV]](%{{.*}}) {data_attr = #cuf.cuda<device>, uniq_name = "_QMmod1Ea_dev"} : (!fir.ref<!fir.array<4xf32>>, !fir.shape<1>) -> !fir.ref<!fir.array<4xf32>>
+// CHECK: fir.embox %[[DECL]](%{{.*}}) : (!fir.ref<!fir.array<4xf32>>, !fir.shape<1>) -> !fir.box<!fir.array<4xf32>>
 // CHECK: fir.call @_FortranACUFDataTransferCstDesc
 
 func.func @_QQchar_transfer() attributes {fir.bindc_name = "char_transfer"} {
diff --git a/flang/test/Fir/CUDA/cuda-global-addr.mlir b/flang/test/Fir/CUDA/cuda-global-addr.mlir
new file mode 100644
index 00000000000000..6d6022af6df8cd
--- /dev/null
+++ b/flang/test/Fir/CUDA/cuda-global-addr.mlir
@@ -0,0 +1,34 @@
+// RUN: fir-opt --cuf-convert %s | FileCheck %s
+
+module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<f80, dense<128> : vector<2xi64>>, #dlti.dl_entry<i128, dense<128> : vector<2xi64>>, #dlti.dl_entry<i64, dense<64> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr<272>, dense<64> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<271>, dense<32> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<270>, dense<32> : vector<4xi64>>, #dlti.dl_entry<f128, dense<128> : vector<2xi64>>, #dlti.dl_entry<f64, dense<64> : vector<2xi64>>, #dlti.dl_entry<f16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i32, dense<32> : vector<2xi64>>, #dlti.dl_entry<i16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i8, dense<8> : vector<2xi64>>, #dlti.dl_entry<i1, dense<8> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr, dense<64> : vector<4xi64>>, #dlti.dl_entry<"dlti.endianness", "little">, #dlti.dl_entry<"dlti.stack_alignment", 128 : i64>>} {
+fir.global @_QMmod1Eadev {data_attr = #cuf.cuda<device>} : !fir.array<10xi32> {
+  %0 = fir.zero_bits !fir.array<10xi32>
+  fir.has_value %0 : !fir.array<10xi32>
+}
+func.func @_QQmain() attributes {fir.bindc_name = "test"} {
+  %c14_i32 = arith.constant 14 : i32
+  %c6_i32 = arith.constant 6 : i32
+  %c4 = arith.constant 4 : index
+  %c1_i32 = arith.constant 1 : i32
+  %c0_i32 = arith.constant 0 : i32
+  %c10 = arith.constant 10 : index
+  %1 = fir.shape %c10 : (index) -> !fir.shape<1>
+  %3 = fir.address_of(@_QMmod1Eadev) : !fir.ref<!fir.array<10xi32>>
+  %4 = fir.declare %3(%1) {data_attr = #cuf.cuda<device>, uniq_name = "_QMmod1Eadev"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> !fir.ref<!fir.array<10xi32>>
+  %5 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFEi"}
+  %6 = fir.declare %5 {uniq_name = "_QFEi"} : (!fir.ref<i32>) -> !fir.ref<i32>
+  fir.store %c0_i32 to %6 : !fir.ref<i32>
+  %7 = fir.array_coor %4(%1) %c4 : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
+  cuf.data_transfer %c1_i32 to %7 {transfer_kind = #cuf.cuda_transfer<host_device>} : i32, !fir.ref<i32>
+  return
+}
+
+}
+
+// CHECK-LABEL: func.func @_QQmain()
+// CHECK: %[[ADDR:.*]] = fir.address_of(@_QMmod1Eadev) : !fir.ref<!fir.array<10xi32>>
+// CHECK: %[[ADDRPTR:.*]] = fir.convert %[[ADDR]] : (!fir.ref<!fir.array<10xi32>>) -> !fir.llvm_ptr<i8>
+// CHECK: %[[DEVICE_ADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[ADDRPTR]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
+// CHECK: %[[DEVICE_ADDR_CONV:.*]] = fir.convert %[[DEVICE_ADDR]] : (!fir.llvm_ptr<i8>) -> !fir.ref<!fir.array<10xi32>>
+// CHECK: %{{.*}} = fir.declare %[[DEVICE_ADDR_CONV]](%{{.*}}) {data_attr = #cuf.cuda<device>, uniq_name = "_QMmod1Eadev"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> !fir.ref<!fir.array<10xi32>>
+   
\ No newline at end of file

>From 407e881b949ef5e8cb349ffbec1d985873b0e779 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Wed, 4 Dec 2024 10:35:48 -0800
Subject: [PATCH 2/2] Check array_coor operation

---
 flang/test/Fir/CUDA/cuda-global-addr.mlir | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/flang/test/Fir/CUDA/cuda-global-addr.mlir b/flang/test/Fir/CUDA/cuda-global-addr.mlir
index 6d6022af6df8cd..2baead4010f5c5 100644
--- a/flang/test/Fir/CUDA/cuda-global-addr.mlir
+++ b/flang/test/Fir/CUDA/cuda-global-addr.mlir
@@ -30,5 +30,7 @@ func.func @_QQmain() attributes {fir.bindc_name = "test"} {
 // CHECK: %[[ADDRPTR:.*]] = fir.convert %[[ADDR]] : (!fir.ref<!fir.array<10xi32>>) -> !fir.llvm_ptr<i8>
 // CHECK: %[[DEVICE_ADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[ADDRPTR]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
 // CHECK: %[[DEVICE_ADDR_CONV:.*]] = fir.convert %[[DEVICE_ADDR]] : (!fir.llvm_ptr<i8>) -> !fir.ref<!fir.array<10xi32>>
-// CHECK: %{{.*}} = fir.declare %[[DEVICE_ADDR_CONV]](%{{.*}}) {data_attr = #cuf.cuda<device>, uniq_name = "_QMmod1Eadev"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> !fir.ref<!fir.array<10xi32>>
-   
\ No newline at end of file
+// CHECK: %[[DECL:.*]] = fir.declare %[[DEVICE_ADDR_CONV]](%{{.*}}) {data_attr = #cuf.cuda<device>, uniq_name = "_QMmod1Eadev"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> !fir.ref<!fir.array<10xi32>>
+// CHECK: %[[ARRAY_COOR:.*]] = fir.array_coor %[[DECL]](%{{.*}}) %c4{{.*}} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
+// CHECK: %[[ARRAY_COOR_PTR:.*]] = fir.convert %[[ARRAY_COOR]] : (!fir.ref<i32>) -> !fir.llvm_ptr<i8>
+// CHECK: fir.call @_FortranACUFDataTransferPtrPtr(%[[ARRAY_COOR_PTR]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none



More information about the flang-commits mailing list