[flang-commits] [flang] [flang][cuda] Get device address in fir.declare (PR #118591)
Valentin Clement バレンタイン クレメン via flang-commits
flang-commits at lists.llvm.org
Wed Dec 4 10:36:02 PST 2024
https://github.com/clementval updated https://github.com/llvm/llvm-project/pull/118591
>From f61d75c0d18d4247522de6110a23c9a4a5cf0b65 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Tue, 3 Dec 2024 21:50:44 -0800
Subject: [PATCH 1/2] [flang][cuda] Get device address in fir.declare
---
.../Optimizer/Transforms/CUFOpConversion.h | 5 +
.../Optimizer/Transforms/CUFOpConversion.cpp | 150 +++++++++++-------
flang/test/Fir/CUDA/cuda-data-transfer.fir | 29 ++--
flang/test/Fir/CUDA/cuda-global-addr.mlir | 34 ++++
4 files changed, 145 insertions(+), 73 deletions(-)
create mode 100644 flang/test/Fir/CUDA/cuda-global-addr.mlir
diff --git a/flang/include/flang/Optimizer/Transforms/CUFOpConversion.h b/flang/include/flang/Optimizer/Transforms/CUFOpConversion.h
index f061323db1704a..336cf46d82babf 100644
--- a/flang/include/flang/Optimizer/Transforms/CUFOpConversion.h
+++ b/flang/include/flang/Optimizer/Transforms/CUFOpConversion.h
@@ -23,11 +23,16 @@ class SymbolTable;
namespace cuf {
+/// Patterns that convert CUF operations to runtime calls.
void populateCUFToFIRConversionPatterns(const fir::LLVMTypeConverter &converter,
mlir::DataLayout &dl,
const mlir::SymbolTable &symtab,
mlir::RewritePatternSet &patterns);
+/// Patterns that update fir operations in the presence of CUF.
+void populateFIRCUFConversionPatterns(const mlir::SymbolTable &symtab,
+ mlir::RewritePatternSet &patterns);
+
} // namespace cuf
#endif // FORTRAN_OPTIMIZER_TRANSFORMS_CUFOPCONVERSION_H_
diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
index 337ea04755d1a9..7f6843d66d39f8 100644
--- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
@@ -81,6 +81,15 @@ static bool hasDoubleDescriptors(OpTy op) {
return false;
}
+bool isDeviceGlobal(fir::GlobalOp op) {
+ auto attr = op.getDataAttr();
+ if (attr && (*attr == cuf::DataAttribute::Device ||
+ *attr == cuf::DataAttribute::Managed ||
+ *attr == cuf::DataAttribute::Constant))
+ return true;
+ return false;
+}
+
static mlir::Value createConvertOp(mlir::PatternRewriter &rewriter,
mlir::Location loc, mlir::Type toTy,
mlir::Value val) {
@@ -89,62 +98,6 @@ static mlir::Value createConvertOp(mlir::PatternRewriter &rewriter,
return val;
}
-mlir::Value getDeviceAddress(mlir::PatternRewriter &rewriter,
- mlir::OpOperand &operand,
- const mlir::SymbolTable &symtab) {
- mlir::Value v = operand.get();
- auto declareOp = v.getDefiningOp<fir::DeclareOp>();
- if (!declareOp)
- return v;
-
- auto addrOfOp = declareOp.getMemref().getDefiningOp<fir::AddrOfOp>();
- if (!addrOfOp)
- return v;
-
- auto globalOp = symtab.lookup<fir::GlobalOp>(
- addrOfOp.getSymbol().getRootReference().getValue());
-
- if (!globalOp)
- return v;
-
- bool isDevGlobal{false};
- auto attr = globalOp.getDataAttrAttr();
- if (attr) {
- switch (attr.getValue()) {
- case cuf::DataAttribute::Device:
- case cuf::DataAttribute::Managed:
- case cuf::DataAttribute::Constant:
- isDevGlobal = true;
- break;
- default:
- break;
- }
- }
- if (!isDevGlobal)
- return v;
- mlir::OpBuilder::InsertionGuard guard(rewriter);
- rewriter.setInsertionPoint(operand.getOwner());
- auto loc = declareOp.getLoc();
- auto mod = declareOp->getParentOfType<mlir::ModuleOp>();
- fir::FirOpBuilder builder(rewriter, mod);
-
- mlir::func::FuncOp callee =
- fir::runtime::getRuntimeFunc<mkRTKey(CUFGetDeviceAddress)>(loc, builder);
- auto fTy = callee.getFunctionType();
- auto toTy = fTy.getInput(0);
- mlir::Value inputArg =
- createConvertOp(rewriter, loc, toTy, declareOp.getResult());
- mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
- mlir::Value sourceLine =
- fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
- llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
- builder, loc, fTy, inputArg, sourceFile, sourceLine)};
- auto call = rewriter.create<fir::CallOp>(loc, callee, args);
- mlir::Value cast = createConvertOp(
- rewriter, loc, declareOp.getMemref().getType(), call->getResult(0));
- return cast;
-}
-
template <typename OpTy>
static mlir::LogicalResult convertOpToCall(OpTy op,
mlir::PatternRewriter &rewriter,
@@ -422,6 +375,54 @@ struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
const fir::LLVMTypeConverter *typeConverter;
};
+struct DeclareOpConversion : public mlir::OpRewritePattern<fir::DeclareOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ DeclareOpConversion(mlir::MLIRContext *context,
+ const mlir::SymbolTable &symtab)
+ : OpRewritePattern(context), symTab{symtab} {}
+
+ mlir::LogicalResult
+ matchAndRewrite(fir::DeclareOp op,
+ mlir::PatternRewriter &rewriter) const override {
+ if (auto addrOfOp = op.getMemref().getDefiningOp<fir::AddrOfOp>()) {
+ if (auto global = symTab.lookup<fir::GlobalOp>(
+ addrOfOp.getSymbol().getRootReference().getValue())) {
+ if (isDeviceGlobal(global)) {
+ rewriter.setInsertionPointAfter(addrOfOp);
+ auto mod = op->getParentOfType<mlir::ModuleOp>();
+ fir::FirOpBuilder builder(rewriter, mod);
+ mlir::Location loc = op.getLoc();
+ mlir::func::FuncOp callee =
+ fir::runtime::getRuntimeFunc<mkRTKey(CUFGetDeviceAddress)>(
+ loc, builder);
+ auto fTy = callee.getFunctionType();
+ mlir::Type toTy = fTy.getInput(0);
+ mlir::Value inputArg =
+ createConvertOp(rewriter, loc, toTy, addrOfOp.getResult());
+ mlir::Value sourceFile =
+ fir::factory::locationToFilename(builder, loc);
+ mlir::Value sourceLine =
+ fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
+ llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
+ builder, loc, fTy, inputArg, sourceFile, sourceLine)};
+ auto call = rewriter.create<fir::CallOp>(loc, callee, args);
+ mlir::Value cast = createConvertOp(
+ rewriter, loc, op.getMemref().getType(), call->getResult(0));
+ rewriter.startOpModification(op);
+ op.getMemrefMutable().assign(cast);
+ rewriter.finalizeOpModification(op);
+ return success();
+ }
+ }
+ }
+ return failure();
+ }
+
+private:
+ const mlir::SymbolTable &symTab;
+};
+
struct CUFFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> {
using OpRewritePattern::OpRewritePattern;
@@ -511,7 +512,7 @@ static mlir::Value emboxSrc(mlir::PatternRewriter &rewriter,
builder.create<fir::StoreOp>(loc, src, alloc);
addr = alloc;
} else {
- addr = getDeviceAddress(rewriter, op.getSrcMutable(), symtab);
+ addr = op.getSrc();
}
llvm::SmallVector<mlir::Value> lenParams;
mlir::Type boxTy = fir::BoxType::get(srcTy);
@@ -531,7 +532,7 @@ static mlir::Value emboxDst(mlir::PatternRewriter &rewriter,
mlir::Location loc = op.getLoc();
fir::FirOpBuilder builder(rewriter, mod);
mlir::Type dstTy = fir::unwrapRefType(op.getDst().getType());
- mlir::Value dstAddr = getDeviceAddress(rewriter, op.getDstMutable(), symtab);
+ mlir::Value dstAddr = op.getDst();
mlir::Type dstBoxTy = fir::BoxType::get(dstTy);
llvm::SmallVector<mlir::Value> lenParams;
mlir::Value dstBox =
@@ -652,8 +653,8 @@ struct CUFDataTransferOpConversion
mlir::Value sourceLine =
fir::factory::locationToLineNo(builder, loc, fTy.getInput(5));
- mlir::Value dst = getDeviceAddress(rewriter, op.getDstMutable(), symtab);
- mlir::Value src = getDeviceAddress(rewriter, op.getSrcMutable(), symtab);
+ mlir::Value dst = op.getDst();
+ mlir::Value src = op.getSrc();
// Materialize the src if constant.
if (matchPattern(src.getDefiningOp(), mlir::m_Constant())) {
mlir::Value temp = builder.createTemporary(loc, srcTy);
@@ -823,6 +824,30 @@ class CUFOpConversion : public fir::impl::CUFOpConversionBase<CUFOpConversion> {
"error in CUF op conversion\n");
signalPassFailure();
}
+
+ target.addDynamicallyLegalOp<fir::DeclareOp>([&](fir::DeclareOp op) {
+ if (inDeviceContext(op))
+ return true;
+ if (auto addrOfOp = op.getMemref().getDefiningOp<fir::AddrOfOp>()) {
+ if (auto global = symtab.lookup<fir::GlobalOp>(
+ addrOfOp.getSymbol().getRootReference().getValue())) {
+ if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(global.getType())))
+ return true;
+ if (isDeviceGlobal(global))
+ return false;
+ }
+ }
+ return true;
+ });
+
+ patterns.clear();
+ cuf::populateFIRCUFConversionPatterns(symtab, patterns);
+ if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
+ std::move(patterns)))) {
+ mlir::emitError(mlir::UnknownLoc::get(ctx),
+ "error in CUF op conversion\n");
+ signalPassFailure();
+ }
}
};
} // namespace
@@ -837,3 +862,8 @@ void cuf::populateCUFToFIRConversionPatterns(
&dl, &converter);
patterns.insert<CUFLaunchOpConversion>(patterns.getContext(), symtab);
}
+
+void cuf::populateFIRCUFConversionPatterns(const mlir::SymbolTable &symtab,
+ mlir::RewritePatternSet &patterns) {
+ patterns.insert<DeclareOpConversion>(patterns.getContext(), symtab);
+}
diff --git a/flang/test/Fir/CUDA/cuda-data-transfer.fir b/flang/test/Fir/CUDA/cuda-data-transfer.fir
index b371d397777280..7203c33e7eb11f 100644
--- a/flang/test/Fir/CUDA/cuda-data-transfer.fir
+++ b/flang/test/Fir/CUDA/cuda-data-transfer.fir
@@ -199,12 +199,12 @@ func.func @_QPsub8() attributes {fir.bindc_name = "t"} {
// CHECK: %[[ALLOCA:.*]] = fir.alloca !fir.array<5xi32>
// CHECK: %[[LOCAL:.*]] = fir.declare %[[ALLOCA]]
// CHECK: %[[GBL:.*]] = fir.address_of(@_QMmtestsEn) : !fir.ref<!fir.array<5xi32>>
-// CHECK: %[[DECL:.*]] = fir.declare %[[GBL]]
-// CHECK: %[[HOST:.*]] = fir.convert %[[DECL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
-// CHECK: %[[SRC:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[HOST]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
-// CHECK: %[[SRC_CONV:.*]] = fir.convert %[[SRC]] : (!fir.llvm_ptr<i8>) -> !fir.ref<!fir.array<5xi32>>
+// CHECK: %[[GBL_CONV:.*]] = fir.convert %[[GBL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
+// CHECK: %[[ADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[GBL_CONV]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
+// CHECK: %[[ADDR_CONV:.*]] = fir.convert %[[ADDR]] : (!fir.llvm_ptr<i8>) -> !fir.ref<!fir.array<5xi32>>
+// CHECK: %[[DECL:.*]] = fir.declare %[[ADDR_CONV]]
// CHECK: %[[DST:.*]] = fir.convert %[[LOCAL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
-// CHECK: %[[SRC:.*]] = fir.convert %[[SRC_CONV]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
+// CHECK: %[[SRC:.*]] = fir.convert %[[DECL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
// CHECK: fir.call @_FortranACUFDataTransferPtrPtr(%[[DST]], %[[SRC]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
@@ -223,11 +223,11 @@ func.func @_QPsub9() {
// CHECK: %[[ALLOCA:.*]] = fir.alloca !fir.array<5xi32>
// CHECK: %[[LOCAL:.*]] = fir.declare %[[ALLOCA]]
// CHECK: %[[GBL:.*]] = fir.address_of(@_QMmtestsEn) : !fir.ref<!fir.array<5xi32>>
-// CHECK: %[[DECL:.*]] = fir.declare %[[GBL]]
-// CHECK: %[[HOST:.*]] = fir.convert %[[DECL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
-// CHECK: %[[DST:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[HOST]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
-// CHECK: %[[DST_CONV:.*]] = fir.convert %[[DST]] : (!fir.llvm_ptr<i8>) -> !fir.ref<!fir.array<5xi32>>
-// CHECK: %[[DST:.*]] = fir.convert %[[DST_CONV]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
+// CHECK: %[[GBL_CONV:.*]] = fir.convert %[[GBL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
+// CHECK: %[[ADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[GBL_CONV]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
+// CHECK: %[[ADDR_CONV:.*]] = fir.convert %[[ADDR]] : (!fir.llvm_ptr<i8>) -> !fir.ref<!fir.array<5xi32>>
+// CHECK: %[[DECL:.*]] = fir.declare %[[ADDR_CONV]]
+// CHECK: %[[DST:.*]] = fir.convert %[[DECL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
// CHECK: %[[SRC:.*]] = fir.convert %[[LOCAL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
// CHECK: fir.call @_FortranACUFDataTransferPtrPtr(%[[DST]], %[[SRC]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
@@ -380,9 +380,12 @@ func.func @_QPdevice_addr_conv() {
}
// CHECK-LABEL: func.func @_QPdevice_addr_conv()
-// CHECK: %[[DEV_ADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
-// CHECK: %[[DEV_ADDR_CONV:.*]] = fir.convert %[[DEV_ADDR]] : (!fir.llvm_ptr<i8>) -> !fir.ref<!fir.array<4xf32>>
-// CHECK: fir.embox %[[DEV_ADDR_CONV]](%{{.*}}) : (!fir.ref<!fir.array<4xf32>>, !fir.shape<1>) -> !fir.box<!fir.array<4xf32>>
+// CHECK: %[[GBL:.*]] = fir.address_of(@_QMmod1Ea_dev) : !fir.ref<!fir.array<4xf32>>
+// CHECK: %[[GBL_CONV:.*]] = fir.convert %[[GBL]] : (!fir.ref<!fir.array<4xf32>>) -> !fir.llvm_ptr<i8>
+// CHECK: %[[ADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[GBL_CONV]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
+// CHECK: %[[ADDR_CONV:.*]] = fir.convert %[[ADDR]] : (!fir.llvm_ptr<i8>) -> !fir.ref<!fir.array<4xf32>>
+// CHECK: %[[DECL:.*]] = fir.declare %[[ADDR_CONV]](%{{.*}}) {data_attr = #cuf.cuda<device>, uniq_name = "_QMmod1Ea_dev"} : (!fir.ref<!fir.array<4xf32>>, !fir.shape<1>) -> !fir.ref<!fir.array<4xf32>>
+// CHECK: fir.embox %[[DECL]](%{{.*}}) : (!fir.ref<!fir.array<4xf32>>, !fir.shape<1>) -> !fir.box<!fir.array<4xf32>>
// CHECK: fir.call @_FortranACUFDataTransferCstDesc
func.func @_QQchar_transfer() attributes {fir.bindc_name = "char_transfer"} {
diff --git a/flang/test/Fir/CUDA/cuda-global-addr.mlir b/flang/test/Fir/CUDA/cuda-global-addr.mlir
new file mode 100644
index 00000000000000..6d6022af6df8cd
--- /dev/null
+++ b/flang/test/Fir/CUDA/cuda-global-addr.mlir
@@ -0,0 +1,34 @@
+// RUN: fir-opt --cuf-convert %s | FileCheck %s
+
+module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<f80, dense<128> : vector<2xi64>>, #dlti.dl_entry<i128, dense<128> : vector<2xi64>>, #dlti.dl_entry<i64, dense<64> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr<272>, dense<64> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<271>, dense<32> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<270>, dense<32> : vector<4xi64>>, #dlti.dl_entry<f128, dense<128> : vector<2xi64>>, #dlti.dl_entry<f64, dense<64> : vector<2xi64>>, #dlti.dl_entry<f16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i32, dense<32> : vector<2xi64>>, #dlti.dl_entry<i16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i8, dense<8> : vector<2xi64>>, #dlti.dl_entry<i1, dense<8> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr, dense<64> : vector<4xi64>>, #dlti.dl_entry<"dlti.endianness", "little">, #dlti.dl_entry<"dlti.stack_alignment", 128 : i64>>} {
+fir.global @_QMmod1Eadev {data_attr = #cuf.cuda<device>} : !fir.array<10xi32> {
+ %0 = fir.zero_bits !fir.array<10xi32>
+ fir.has_value %0 : !fir.array<10xi32>
+}
+func.func @_QQmain() attributes {fir.bindc_name = "test"} {
+ %c14_i32 = arith.constant 14 : i32
+ %c6_i32 = arith.constant 6 : i32
+ %c4 = arith.constant 4 : index
+ %c1_i32 = arith.constant 1 : i32
+ %c0_i32 = arith.constant 0 : i32
+ %c10 = arith.constant 10 : index
+ %1 = fir.shape %c10 : (index) -> !fir.shape<1>
+ %3 = fir.address_of(@_QMmod1Eadev) : !fir.ref<!fir.array<10xi32>>
+ %4 = fir.declare %3(%1) {data_attr = #cuf.cuda<device>, uniq_name = "_QMmod1Eadev"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> !fir.ref<!fir.array<10xi32>>
+ %5 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFEi"}
+ %6 = fir.declare %5 {uniq_name = "_QFEi"} : (!fir.ref<i32>) -> !fir.ref<i32>
+ fir.store %c0_i32 to %6 : !fir.ref<i32>
+ %7 = fir.array_coor %4(%1) %c4 : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
+ cuf.data_transfer %c1_i32 to %7 {transfer_kind = #cuf.cuda_transfer<host_device>} : i32, !fir.ref<i32>
+ return
+}
+
+}
+
+// CHECK-LABEL: func.func @_QQmain()
+// CHECK: %[[ADDR:.*]] = fir.address_of(@_QMmod1Eadev) : !fir.ref<!fir.array<10xi32>>
+// CHECK: %[[ADDRPTR:.*]] = fir.convert %[[ADDR]] : (!fir.ref<!fir.array<10xi32>>) -> !fir.llvm_ptr<i8>
+// CHECK: %[[DEVICE_ADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[ADDRPTR]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
+// CHECK: %[[DEVICE_ADDR_CONV:.*]] = fir.convert %[[DEVICE_ADDR]] : (!fir.llvm_ptr<i8>) -> !fir.ref<!fir.array<10xi32>>
+// CHECK: %{{.*}} = fir.declare %[[DEVICE_ADDR_CONV]](%{{.*}}) {data_attr = #cuf.cuda<device>, uniq_name = "_QMmod1Eadev"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> !fir.ref<!fir.array<10xi32>>
+
\ No newline at end of file
>From 407e881b949ef5e8cb349ffbec1d985873b0e779 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Wed, 4 Dec 2024 10:35:48 -0800
Subject: [PATCH 2/2] Check array_coor operation
---
flang/test/Fir/CUDA/cuda-global-addr.mlir | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/flang/test/Fir/CUDA/cuda-global-addr.mlir b/flang/test/Fir/CUDA/cuda-global-addr.mlir
index 6d6022af6df8cd..2baead4010f5c5 100644
--- a/flang/test/Fir/CUDA/cuda-global-addr.mlir
+++ b/flang/test/Fir/CUDA/cuda-global-addr.mlir
@@ -30,5 +30,7 @@ func.func @_QQmain() attributes {fir.bindc_name = "test"} {
// CHECK: %[[ADDRPTR:.*]] = fir.convert %[[ADDR]] : (!fir.ref<!fir.array<10xi32>>) -> !fir.llvm_ptr<i8>
// CHECK: %[[DEVICE_ADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[ADDRPTR]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
// CHECK: %[[DEVICE_ADDR_CONV:.*]] = fir.convert %[[DEVICE_ADDR]] : (!fir.llvm_ptr<i8>) -> !fir.ref<!fir.array<10xi32>>
-// CHECK: %{{.*}} = fir.declare %[[DEVICE_ADDR_CONV]](%{{.*}}) {data_attr = #cuf.cuda<device>, uniq_name = "_QMmod1Eadev"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> !fir.ref<!fir.array<10xi32>>
-
\ No newline at end of file
+// CHECK: %[[DECL:.*]] = fir.declare %[[DEVICE_ADDR_CONV]](%{{.*}}) {data_attr = #cuf.cuda<device>, uniq_name = "_QMmod1Eadev"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> !fir.ref<!fir.array<10xi32>>
+// CHECK: %[[ARRAY_COOR:.*]] = fir.array_coor %[[DECL]](%{{.*}}) %c4{{.*}} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
+// CHECK: %[[ARRAY_COOR_PTR:.*]] = fir.convert %[[ARRAY_COOR]] : (!fir.ref<i32>) -> !fir.llvm_ptr<i8>
+// CHECK: fir.call @_FortranACUFDataTransferPtrPtr(%[[ARRAY_COOR_PTR]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
More information about the flang-commits
mailing list