[flang-commits] [flang] [flang][cuda] Correctly embox logical constant (PR #116445)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Fri Nov 15 14:35:47 PST 2024


https://github.com/clementval created https://github.com/llvm/llvm-project/pull/116445

When the right-hand side of a data transfer is a logical constant, it was emboxed as `i1`, but this type is not supported in the descriptor. Detect this case and embox the constant as a logical type instead.
```
subroutine logical_cst
  logical*1, device, dimension(:,:), allocatable :: id2
  allocate(id2(10,10))
  id2 = .false.
end
```

From 0e1ab2802e85bc818c5d536ee6485811f33d61ce Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Fri, 15 Nov 2024 14:29:53 -0800
Subject: [PATCH] [flang][cuda] Correctly embox logical constant

---
 .../Optimizer/Transforms/CUFOpConversion.cpp  |  9 ++++++++-
 flang/test/Fir/CUDA/cuda-data-transfer.fir    | 20 ++++++++++++++++++-
 2 files changed, 27 insertions(+), 2 deletions(-)

diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
index 8a6f28b9422f9b..ec7f67dff763b4 100644
--- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
@@ -473,9 +473,16 @@ static mlir::Value emboxSrc(mlir::PatternRewriter &rewriter,
   mlir::Type srcTy = fir::unwrapRefType(op.getSrc().getType());
   if (fir::isa_trivial(srcTy) &&
       mlir::matchPattern(op.getSrc().getDefiningOp(), mlir::m_Constant())) {
+    mlir::Value src = op.getSrc();
+    if (srcTy.isInteger(1)) {
+      // i1 is not a supported type in the descriptor and it is actually coming
+      // from a LOGICAL constant. Store it as a fir.logical.
+      srcTy = fir::LogicalType::get(rewriter.getContext(), 4);
+      src = createConvertOp(rewriter, loc, srcTy, src);
+    }
     // Put constant in memory if it is not.
     mlir::Value alloc = builder.createTemporary(loc, srcTy);
-    builder.create<fir::StoreOp>(loc, op.getSrc(), alloc);
+    builder.create<fir::StoreOp>(loc, src, alloc);
     addr = alloc;
   } else {
     addr = getDeviceAddress(rewriter, op.getSrcMutable(), symtab);
diff --git a/flang/test/Fir/CUDA/cuda-data-transfer.fir b/flang/test/Fir/CUDA/cuda-data-transfer.fir
index 718e82ce99725d..3209197e118d19 100644
--- a/flang/test/Fir/CUDA/cuda-data-transfer.fir
+++ b/flang/test/Fir/CUDA/cuda-data-transfer.fir
@@ -385,7 +385,6 @@ func.func @_QPdevice_addr_conv() {
 // CHECK: fir.embox %[[DEV_ADDR_CONV]](%{{.*}}) : (!fir.ref<!fir.array<4xf32>>, !fir.shape<1>) -> !fir.box<!fir.array<4xf32>>
 // CHECK: fir.call @_FortranACUFDataTransferDescDescNoRealloc
 
-
 func.func @_QQchar_transfer() attributes {fir.bindc_name = "char_transfer"} {
   %c1 = arith.constant 1 : index
   %c10 = arith.constant 10 : index
@@ -447,5 +446,24 @@ func.func @_QPdevmul(%arg0: !fir.ref<!fir.array<1x?xf32>> {fir.bindc_name = "b"}
 // CHECK: %[[DST:.*]] = fir.convert %[[ALLOCA0]] : (!fir.ref<!fir.box<!fir.array<?x?xf32>>>) -> !fir.ref<!fir.box<none>>
 // CHECK: fir.call @_FortranACUFDataTransferDescDesc(%[[DST]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
 
+func.func @_QPlogical_cst() {
+  %c0_i64 = arith.constant 0 : i64
+  %false = arith.constant false
+  %c0 = arith.constant 0 : index
+  %0 = cuf.alloc !fir.box<!fir.heap<!fir.array<?x?x!fir.logical<1>>>> {bindc_name = "id2", data_attr = #cuf.cuda<device>, uniq_name = "_QFlogical_cstEid2"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?x?x!fir.logical<1>>>>>
+  %4 = fir.declare %0 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFlogical_cstEid2"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?x!fir.logical<1>>>>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?x?x!fir.logical<1>>>>>
+  cuf.data_transfer %false to %4 {transfer_kind = #cuf.cuda_transfer<host_device>} : i1, !fir.ref<!fir.box<!fir.heap<!fir.array<?x?x!fir.logical<1>>>>>
+  return
+}
+
+// CHECK-LABEL: func.func @_QPlogical_cst()
+// CHECK: %[[DESC:.*]] = fir.alloca !fir.box<!fir.logical<4>>
+// CHECK: %[[CONST:.*]] = fir.alloca !fir.logical<4>
+// CHECK: %[[CONV:.*]] = fir.convert %false : (i1) -> !fir.logical<4>
+// CHECK: fir.store %[[CONV]] to %[[CONST]] : !fir.ref<!fir.logical<4>>
+// CHECK: %[[EMBOX:.*]] = fir.embox %[[CONST]] : (!fir.ref<!fir.logical<4>>) -> !fir.box<!fir.logical<4>>
+// CHECK: fir.store %[[EMBOX]] to %[[DESC]] : !fir.ref<!fir.box<!fir.logical<4>>>
+// CHECK: %[[BOX_NONE:.*]] = fir.convert %[[DESC]] : (!fir.ref<!fir.box<!fir.logical<4>>>) -> !fir.ref<!fir.box<none>>
+// CHECK: fir.call @_FortranACUFDataTransferDescDesc(%{{.*}}, %[[BOX_NONE]], %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
 
 } // end of module



More information about the flang-commits mailing list