[flang-commits] [flang] [flang][cuda] Add conversion after CUFGetDeviceAddress to avoid issue when emboxing (PR #116145)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Wed Nov 13 17:56:19 PST 2024


https://github.com/clementval created https://github.com/llvm/llvm-project/pull/116145

When calling `_FortranACUFGetDeviceAddress` the return type is a raw pointer. Cast it to the real ref type so emboxing does not trigger verifier errors. 

>From 3f8a0bd6f3a6e6c26558f7ec4bb256b59a728663 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Wed, 13 Nov 2024 17:53:05 -0800
Subject: [PATCH] [flang][cuda] Add conversion after CUFGetDeviceAddress to
 avoid problem when emboxing

---
 .../Optimizer/Transforms/CUFOpConversion.cpp  |  5 ++--
 flang/test/Fir/CUDA/cuda-data-transfer.fir    | 23 +++++++++++++++++++
 2 files changed, 26 insertions(+), 2 deletions(-)

diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
index 58a3cdc905d36e..bca0a09c5bff65 100644
--- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
@@ -140,8 +140,9 @@ mlir::Value getDeviceAddress(mlir::PatternRewriter &rewriter,
   llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
       builder, loc, fTy, inputArg, sourceFile, sourceLine)};
   auto call = rewriter.create<fir::CallOp>(loc, callee, args);
-
-  return call->getResult(0);
+  mlir::Value cast = createConvertOp(
+      rewriter, loc, declareOp.getMemref().getType(), call->getResult(0));
+  return cast;
 }
 
 template <typename OpTy>
diff --git a/flang/test/Fir/CUDA/cuda-data-transfer.fir b/flang/test/Fir/CUDA/cuda-data-transfer.fir
index 491d417271ce74..9c6d9e0c100125 100644
--- a/flang/test/Fir/CUDA/cuda-data-transfer.fir
+++ b/flang/test/Fir/CUDA/cuda-data-transfer.fir
@@ -202,7 +202,9 @@ func.func @_QPsub8() attributes {fir.bindc_name = "t"} {
 // CHECK: %[[DECL:.*]] = fir.declare %[[GBL]]
 // CHECK: %[[HOST:.*]] = fir.convert %[[DECL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
 // CHECK: %[[SRC:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[HOST]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
+// CHECK: %[[SRC_CONV:.*]] = fir.convert %[[SRC]] : (!fir.llvm_ptr<i8>) -> !fir.ref<!fir.array<5xi32>>
 // CHECK: %[[DST:.*]] = fir.convert %[[LOCAL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
+// CHECK: %[[SRC:.*]] = fir.convert %[[SRC_CONV]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
 // CHECK: fir.call @_FortranACUFDataTransferPtrPtr(%[[DST]], %[[SRC]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
 
 
@@ -224,6 +226,8 @@ func.func @_QPsub9() {
 // CHECK: %[[DECL:.*]] = fir.declare %[[GBL]]
 // CHECK: %[[HOST:.*]] = fir.convert %[[DECL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
 // CHECK: %[[DST:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[HOST]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
+// CHECK: %[[DST_CONV:.*]] = fir.convert %[[DST]] : (!fir.llvm_ptr<i8>) -> !fir.ref<!fir.array<5xi32>>
+// CHECK: %[[DST:.*]] = fir.convert %[[DST_CONV]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
 // CHECK: %[[SRC:.*]] = fir.convert %[[LOCAL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
 // CHECK: fir.call @_FortranACUFDataTransferPtrPtr(%[[DST]], %[[SRC]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
 
@@ -361,5 +365,24 @@ func.func @_QPshape_shift2() {
 // CHECK: %[[BYTES:.*]] = arith.muli %[[C10]], %c4{{.*}} : i64
 // CHECK: fir.call @_FortranACUFDataTransferPtrPtr(%{{.*}}, %{{.*}}, %[[BYTES]], %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
 
+fir.global @_QMmod1Ea_dev {data_attr = #cuf.cuda<device>} : !fir.array<4xf32> {
+  %0 = fir.zero_bits !fir.array<4xf32>
+  fir.has_value %0 : !fir.array<4xf32>
+}
+func.func @_QPdevice_addr_conv() {
+  %cst = arith.constant 4.200000e+01 : f32
+  %c4 = arith.constant 4 : index
+  %0 = fir.address_of(@_QMmod1Ea_dev) : !fir.ref<!fir.array<4xf32>>
+  %1 = fir.shape %c4 : (index) -> !fir.shape<1>
+  %2 = fir.declare %0(%1) {data_attr = #cuf.cuda<device>, uniq_name = "_QMmod1Ea_dev"} : (!fir.ref<!fir.array<4xf32>>, !fir.shape<1>) -> !fir.ref<!fir.array<4xf32>>
+  cuf.data_transfer %cst to %2 {transfer_kind = #cuf.cuda_transfer<host_device>} : f32, !fir.ref<!fir.array<4xf32>>
+  return
+}
+
+// CHECK-LABEL: func.func @_QPdevice_addr_conv()
+// CHECK: %[[DEV_ADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
+// CHECK: %[[DEV_ADDR_CONV:.*]] = fir.convert %[[DEV_ADDR]] : (!fir.llvm_ptr<i8>) -> !fir.ref<!fir.array<4xf32>>
+// CHECK: fir.embox %[[DEV_ADDR_CONV]](%{{.*}}) : (!fir.ref<!fir.array<4xf32>>, !fir.shape<1>) -> !fir.box<!fir.array<4xf32>>
+// CHECK: fir.call @_FortranACUFDataTransferDescDescNoRealloc
 
 } // end of module



More information about the flang-commits mailing list