[flang-commits] [flang] 1b7c6cc - [flang][openacc] Add support for allocatable and pointer in data operand

Valentin Clement via flang-commits flang-commits at lists.llvm.org
Tue Apr 25 21:06:34 PDT 2023


Author: Valentin Clement
Date: 2023-04-25T21:06:26-07:00
New Revision: 1b7c6cc688071b08e206669a2340f8390f0adf7a

URL: https://github.com/llvm/llvm-project/commit/1b7c6cc688071b08e206669a2340f8390f0adf7a
DIFF: https://github.com/llvm/llvm-project/commit/1b7c6cc688071b08e206669a2340f8390f0adf7a.diff

LOG: [flang][openacc] Add support for allocatable and pointer in data operand

Add lowering support for allocatable and pointer data operands (scalars,
whole arrays, and array sections) to acc.bounds and the acc data operations.

Reviewed By: razvanlupusoru

Differential Revision: https://reviews.llvm.org/D149189

Added: 
    

Modified: 
    flang/lib/Lower/OpenACC.cpp
    flang/test/Lower/OpenACC/acc-enter-data.f90

Removed: 
    


################################################################################
diff  --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 6d06dfc128b56..6fa64975b405c 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -131,7 +131,8 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
              Fortran::lower::AbstractConverter &converter,
              Fortran::lower::StatementContext &stmtCtx,
              const std::list<Fortran::parser::SectionSubscript> &subscripts,
-             std::stringstream &asFortran, const Fortran::parser::Name &name) {
+             std::stringstream &asFortran, const Fortran::parser::Name &name,
+             mlir::Value baseAddr) {
   int dimension = 0;
   mlir::Type idxTy = builder.getIndexType();
   mlir::Type boundTy = builder.getType<mlir::acc::DataBoundsType>();
@@ -146,17 +147,16 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
       mlir::Value lbound, ubound, extent;
       std::optional<std::int64_t> lval, uval;
       mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
-      fir::ExtendedValue dataExv =
-          converter.getSymbolExtendedValue(*name.symbol);
       mlir::Value baseLb =
           fir::factory::readLowerBound(builder, loc, dataExv, dimension, one);
       bool defaultLb = baseLb == one;
       mlir::Value stride;
       bool strideInBytes = false;
-      if (fir::getBase(dataExv).getType().isa<fir::BaseBoxType>()) {
+
+      if (fir::unwrapRefType(baseAddr.getType()).isa<fir::BaseBoxType>()) {
         mlir::Value d = builder.createIntegerConstant(loc, idxTy, dimension);
         auto dimInfo = builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy,
-                                                      fir::getBase(dataExv), d);
+                                                      baseAddr, d);
         stride = dimInfo.getByteStride();
         strideInBytes = true;
       }
@@ -255,24 +255,33 @@ genDataOperandOperations(const Fortran::parser::AccObjectList &objectList,
     if (!symAddr)
       llvm::report_fatal_error("could not retrieve symbol address");
 
-    mlir::Type symTy = symAddr.getType();
-    if (auto refTy = symTy.dyn_cast<fir::ReferenceType>())
-      symTy = refTy.getEleTy();
-
-    if (auto boxTy =
-            fir::unwrapRefType(symAddr.getType()).dyn_cast<fir::BaseBoxType>())
-      if (boxTy.getEleTy()
-              .isa<fir::PointerType, fir::HeapType, fir::RecordType>())
-        TODO(loc, "pointer, allocatable and derived type box");
+    if (auto boxTy = fir::unwrapRefType(symAddr.getType())
+                         .dyn_cast<fir::BaseBoxType>()) {
+      if (boxTy.getEleTy().isa<fir::RecordType>())
+        TODO(loc, "derived type");
 
+      // Load the box when baseAddr is a `fir.ref<fir.box<T>>` or a
+      // `fir.ref<fir.class<T>>` type.
+      if (symAddr.getType().isa<fir::ReferenceType>())
+        return builder.create<fir::LoadOp>(loc, symAddr);
+    }
     return symAddr;
   };
 
   auto createOpAndAddOperand = [&](mlir::Value baseAddr, llvm::StringRef name,
                                    mlir::Location loc,
                                    llvm::SmallVector<mlir::Value> &bounds) {
-    if (baseAddr.getType().isa<fir::BaseBoxType>())
-      baseAddr = builder.create<fir::BoxAddrOp>(loc, baseAddr);
+    if (auto boxTy = baseAddr.getType().dyn_cast<fir::BaseBoxType>()) {
+      // Get the actual data address when the descriptor is an allocatable or
+      // a pointer.
+      if (boxTy.getEleTy().isa<fir::HeapType, fir::PointerType>()) {
+        mlir::Value boxAddr = builder.create<fir::BoxAddrOp>(
+            loc, fir::ReferenceType::get(boxTy.getEleTy()), baseAddr);
+        baseAddr = builder.create<fir::LoadOp>(loc, boxAddr);
+      } else { // Get the address of the boxed value.
+        baseAddr = builder.create<fir::BoxAddrOp>(loc, baseAddr);
+      }
+    }
 
     Op op = builder.create<Op>(loc, baseAddr.getType(), baseAddr);
     op.setNameAttr(builder.getStringAttr(name));
@@ -308,15 +317,15 @@ genDataOperandOperations(const Fortran::parser::AccObjectList &objectList,
                       Fortran::parser::GetLastName(*dataRef);
                   std::stringstream asFortran;
                   asFortran << name.ToString();
+                  mlir::Value baseAddr =
+                      getDataOperandBaseAddr(*name.symbol, operandLocation);
                   if (!arrayElement->subscripts.empty()) {
                     asFortran << '(';
                     bounds = genBoundsOps(builder, operandLocation, converter,
                                           stmtCtx, arrayElement->subscripts,
-                                          asFortran, name);
+                                          asFortran, name, baseAddr);
                   }
                   asFortran << ')';
-                  mlir::Value baseAddr =
-                      getDataOperandBaseAddr(*name.symbol, operandLocation);
                   createOpAndAddOperand(baseAddr, asFortran.str(),
                                         operandLocation, bounds);
                 } else if (Fortran::parser::Unwrap<
@@ -332,7 +341,8 @@ genDataOperandOperations(const Fortran::parser::AccObjectList &objectList,
                     mlir::Value baseAddr =
                         getDataOperandBaseAddr(*name.symbol, operandLocation);
                     llvm::SmallVector<mlir::Value> bounds;
-                    if (baseAddr.getType().isa<fir::BaseBoxType>())
+                    if (fir::unwrapRefType(baseAddr.getType())
+                            .isa<fir::BaseBoxType>())
                       bounds = genBoundsOpsFromBox(builder, operandLocation,
                                                    converter, *name.symbol,
                                                    baseAddr, (*expr).Rank());

diff  --git a/flang/test/Lower/OpenACC/acc-enter-data.f90 b/flang/test/Lower/OpenACC/acc-enter-data.f90
index e3b02906f34a6..47b4242e4b243 100644
--- a/flang/test/Lower/OpenACC/acc-enter-data.f90
+++ b/flang/test/Lower/OpenACC/acc-enter-data.f90
@@ -41,10 +41,14 @@ subroutine acc_enter_data
 !CHECK: %[[CREATE_C:.*]] = acc.create varPtr(%[[C]] : !fir.ref<!fir.array<10x10xf32>>)   -> !fir.ref<!fir.array<10x10xf32>> {dataClause = 8 : i64, name = "c", structured = false}
 !CHECK: acc.enter_data dataOperands(%[[CREATE_A]], %[[CREATE_B]], %[[CREATE_C]] : !fir.ref<!fir.array<10x10xf32>>, !fir.ref<!fir.array<10x10xf32>>, !fir.ref<!fir.array<10x10xf32>>){{$}}
 
-  !$acc enter data copyin(a) create(b)
+  !$acc enter data copyin(a) create(b) attach(d)
 !CHECK: %[[COPYIN_A:.*]] = acc.copyin varPtr(%[[A]] : !fir.ref<!fir.array<10x10xf32>>)   -> !fir.ref<!fir.array<10x10xf32>> {name = "a", structured = false}
 !CHECK: %[[CREATE_B:.*]] = acc.create varPtr(%[[B]] : !fir.ref<!fir.array<10x10xf32>>)   -> !fir.ref<!fir.array<10x10xf32>> {name = "b", structured = false}
-!CHECK: acc.enter_data dataOperands(%[[COPYIN_A]], %[[CREATE_B]] : !fir.ref<!fir.array<10x10xf32>>, !fir.ref<!fir.array<10x10xf32>>){{$}}
+!CHECK: %[[BOX_D:.*]] = fir.load %[[D]] : !fir.ref<!fir.box<!fir.ptr<f32>>> 
+!CHECK: %[[BOX_ADDR_D:.*]] = fir.box_addr %[[BOX_D]] : (!fir.box<!fir.ptr<f32>>) -> !fir.ref<!fir.ptr<f32>>
+!CHECK: %[[D_PTR:.*]] = fir.load %[[BOX_ADDR_D]] : !fir.ref<!fir.ptr<f32>> 
+!CHECK: %[[ATTACH_D:.*]] = acc.attach varPtr(%[[D_PTR]] : !fir.ptr<f32>) -> !fir.ptr<f32> {name = "d", structured = false}
+!CHECK: acc.enter_data dataOperands(%[[COPYIN_A]], %[[CREATE_B]], %[[ATTACH_D]] : !fir.ref<!fir.array<10x10xf32>>, !fir.ref<!fir.array<10x10xf32>>, !fir.ptr<f32>){{$}}
 
   !$acc enter data create(a) async
 !CHECK: %[[CREATE_A:.*]] = acc.create varPtr(%[[A]] : !fir.ref<!fir.array<10x10xf32>>)   -> !fir.ref<!fir.array<10x10xf32>> {name = "a", structured = false}
@@ -348,3 +352,100 @@ subroutine acc_enter_data_assumed(a, b, n, m)
 
 end subroutine
 
+subroutine acc_enter_data_allocatable()
+  real, allocatable :: a(:)
+  integer, allocatable :: i
+  
+!CHECK-LABEL: func.func @_QPacc_enter_data_allocatable() {
+!CHECK: %[[A:.*]] = fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", uniq_name = "_QFacc_enter_data_allocatableEa"}
+!CHECK: %[[I:.*]] = fir.alloca !fir.box<!fir.heap<i32>> {bindc_name = "i", uniq_name = "_QFacc_enter_data_allocatableEi"}
+
+  !$acc enter data create(a)
+!CHECK: %[[BOX_A_0:.*]] = fir.load %[[A]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
+!CHECK: %[[C0_0:.*]] = arith.constant 0 : index
+!CHECK: %[[BOX_A_1:.*]] = fir.load %[[A]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
+!CHECK: %[[C0_1:.*]] = arith.constant 0 : index
+!CHECK: %[[DIMS0:.*]]:3 = fir.box_dims %[[BOX_A_1]], %[[C0_1]] : (!fir.box<!fir.heap<!fir.array<?xf32>>>, index) -> (index, index, index)
+!CHECK: %[[DIMS1:.*]]:3 = fir.box_dims %[[BOX_A_0]], %[[C0_0]] : (!fir.box<!fir.heap<!fir.array<?xf32>>>, index) -> (index, index, index)
+!CHECK: %[[BOUND:.*]] = acc.bounds extent(%[[DIMS1]]#1 : index) stride(%[[DIMS1]]#2 : index) startIdx(%[[DIMS0]]#0 : index) {strideInBytes = true}
+!CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[BOX_A_0]] : (!fir.box<!fir.heap<!fir.array<?xf32>>>) -> !fir.ref<!fir.heap<!fir.array<?xf32>>>
+!CHECK: %[[ADDR:.*]] = fir.load %[[BOX_ADDR]] : !fir.ref<!fir.heap<!fir.array<?xf32>>>
+!CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[ADDR]] : !fir.heap<!fir.array<?xf32>>) bounds(%[[BOUND]]) -> !fir.heap<!fir.array<?xf32>> {name = "a", structured = false}
+!CHECK: acc.enter_data dataOperands(%[[CREATE]] : !fir.heap<!fir.array<?xf32>>)
+
+  !$acc enter data create(a(:))
+!CHECK: %[[BOX_A_0:.*]] = fir.load %[[A]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
+!CHECK: %[[BOX_A_1:.*]] = fir.load %[[A]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
+!CHECK: %[[C0:.*]] = arith.constant 0 : index
+!CHECK: %[[DIMS0:.*]]:3 = fir.box_dims %[[BOX_A_1]], %[[C0]] : (!fir.box<!fir.heap<!fir.array<?xf32>>>, index) -> (index, index, index)
+!CHECK: %[[C0:.*]] = arith.constant 0 : index
+!CHECK: %[[DIMS1:.*]]:3 = fir.box_dims %[[BOX_A_0]], %[[C0]] : (!fir.box<!fir.heap<!fir.array<?xf32>>>, index) -> (index, index, index)
+!CHECK: %[[BOX_A_2:.*]] = fir.load %[[A]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
+!CHECK: %[[C0:.*]] = arith.constant 0 : index
+!CHECK: %[[DIMS2:.*]]:3 = fir.box_dims %[[BOX_A_2]], %[[C0]] : (!fir.box<!fir.heap<!fir.array<?xf32>>>, index) -> (index, index, index)
+!CHECK: %[[BOUND:.*]] = acc.bounds extent(%[[DIMS2]]#1 : index) stride(%[[DIMS1]]#2 : index) startIdx(%[[DIMS0]]#0 : index) {strideInBytes = true}
+!CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[BOX_A_0]] : (!fir.box<!fir.heap<!fir.array<?xf32>>>) -> !fir.ref<!fir.heap<!fir.array<?xf32>>>
+!CHECK: %[[ADDR:.*]] = fir.load %[[BOX_ADDR]] : !fir.ref<!fir.heap<!fir.array<?xf32>>>
+!CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[ADDR]] : !fir.heap<!fir.array<?xf32>>) bounds(%[[BOUND]]) -> !fir.heap<!fir.array<?xf32>> {name = "a(:)", structured = false}
+!CHECK: acc.enter_data dataOperands(%[[CREATE]] : !fir.heap<!fir.array<?xf32>>)
+
+  !$acc enter data create(a(2:5))
+!CHECK: %[[BOX_A_0:.*]] = fir.load %[[A]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
+!CHECK: %[[BOX_A_1:.*]] = fir.load %[[A]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
+!CHECK: %[[C0:.*]] = arith.constant 0 : index
+!CHECK: %[[DIMS0:.*]]:3 = fir.box_dims %[[BOX_A_1]], %[[C0]] : (!fir.box<!fir.heap<!fir.array<?xf32>>>, index) -> (index, index, index)
+!CHECK: %[[C0:.*]] = arith.constant 0 : index
+!CHECK: %[[DIMS1:.*]]:3 = fir.box_dims %[[BOX_A_0]], %[[C0]] : (!fir.box<!fir.heap<!fir.array<?xf32>>>, index) -> (index, index, index)
+!CHECK: %[[C2:.*]] = arith.constant 2 : index
+!CHECK: %[[LB:.*]] = arith.subi %[[C2]], %[[DIMS0]]#0 : index
+!CHECK: %[[C5:.*]] = arith.constant 5 : index
+!CHECK: %[[UB:.*]] = arith.subi %[[C5]], %[[DIMS0]]#0 : index
+!CHECK: %[[BOUND:.*]] = acc.bounds lowerbound(%[[LB]] : index) upperbound(%[[UB]] : index) stride(%[[DIMS1]]#2 : index) startIdx(%[[DIMS0]]#0 : index) {strideInBytes = true}
+!CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[BOX_A_0]] : (!fir.box<!fir.heap<!fir.array<?xf32>>>) -> !fir.ref<!fir.heap<!fir.array<?xf32>>>
+!CHECK: %[[ADDR:.*]] = fir.load %[[BOX_ADDR]] : !fir.ref<!fir.heap<!fir.array<?xf32>>>
+!CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[ADDR]] : !fir.heap<!fir.array<?xf32>>) bounds(%[[BOUND]]) -> !fir.heap<!fir.array<?xf32>> {name = "a(2:5)", structured = false}
+!CHECK: acc.enter_data dataOperands(%[[CREATE]] : !fir.heap<!fir.array<?xf32>>)
+
+  !$acc enter data create(a(3:))
+!CHECK: %[[BOX_A_0:.*]] = fir.load %[[A]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
+!CHECK: %[[BOX_A_1:.*]] = fir.load %[[A]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
+!CHECK: %[[C0:.*]] = arith.constant 0 : index
+!CHECK: %[[DIMS0:.*]]:3 = fir.box_dims %[[BOX_A_1]], %[[C0]] : (!fir.box<!fir.heap<!fir.array<?xf32>>>, index) -> (index, index, index)
+!CHECK: %[[C0:.*]] = arith.constant 0 : index
+!CHECK: %[[DIMS1:.*]]:3 = fir.box_dims %[[BOX_A_0]], %[[C0]] : (!fir.box<!fir.heap<!fir.array<?xf32>>>, index) -> (index, index, index)
+!CHECK: %[[C3:.*]] = arith.constant 3 : index
+!CHECK: %[[LB:.*]] = arith.subi %[[C3]], %[[DIMS0]]#0 : index
+!CHECK: %[[BOX_A_1:.*]] = fir.load %[[A]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
+!CHECK: %[[C0:.*]] = arith.constant 0 : index
+!CHECK: %[[DIMS2:.*]]:3 = fir.box_dims %[[BOX_A_1]], %[[C0]] : (!fir.box<!fir.heap<!fir.array<?xf32>>>, index) -> (index, index, index)
+!CHECK: %[[EXT:.*]] = arith.subi %[[DIMS2]]#1, %[[LB]] : index 
+!CHECK: %[[BOUND:.*]] = acc.bounds lowerbound(%[[LB]] : index) extent(%[[EXT]] : index) stride(%[[DIMS1]]#2 : index) startIdx(%[[DIMS0]]#0 : index) {strideInBytes = true}
+!CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[BOX_A_0]] : (!fir.box<!fir.heap<!fir.array<?xf32>>>) -> !fir.ref<!fir.heap<!fir.array<?xf32>>>
+!CHECK: %[[ADDR:.*]] = fir.load %[[BOX_ADDR]] : !fir.ref<!fir.heap<!fir.array<?xf32>>>
+!CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[ADDR]] : !fir.heap<!fir.array<?xf32>>) bounds(%[[BOUND]]) -> !fir.heap<!fir.array<?xf32>> {name = "a(3:)", structured = false}
+!CHECK: acc.enter_data dataOperands(%[[CREATE]] : !fir.heap<!fir.array<?xf32>>)
+
+  !$acc enter data create(a(:7))
+!CHECK: %[[BOX_A_0:.*]] = fir.load %[[A]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
+!CHECK: %[[BOX_A_1:.*]] = fir.load %[[A]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
+!CHECK: %[[C0:.*]] = arith.constant 0 : index
+!CHECK: %[[DIMS0:.*]]:3 = fir.box_dims %[[BOX_A_1]], %[[C0]] : (!fir.box<!fir.heap<!fir.array<?xf32>>>, index) -> (index, index, index)
+!CHECK: %[[C0:.*]] = arith.constant 0 : index
+!CHECK: %[[DIMS1:.*]]:3 = fir.box_dims %[[BOX_A_0]], %[[C0]] : (!fir.box<!fir.heap<!fir.array<?xf32>>>, index) -> (index, index, index)
+!CHECK: %[[C7:.*]] = arith.constant 7 : index
+!CHECK: %[[UB:.*]] = arith.subi %[[C7]], %[[DIMS0]]#0 : index
+!CHECK: %[[BOUND:.*]] = acc.bounds upperbound(%[[UB]] : index) stride(%[[DIMS1]]#2 : index) startIdx(%[[DIMS0]]#0 : index) {strideInBytes = true}
+!CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[BOX_A_0]] : (!fir.box<!fir.heap<!fir.array<?xf32>>>) -> !fir.ref<!fir.heap<!fir.array<?xf32>>>
+!CHECK: %[[ADDR:.*]] = fir.load %[[BOX_ADDR]] : !fir.ref<!fir.heap<!fir.array<?xf32>>>
+!CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[ADDR]] : !fir.heap<!fir.array<?xf32>>) bounds(%[[BOUND]]) -> !fir.heap<!fir.array<?xf32>> {name = "a(:7)", structured = false}
+!CHECK: acc.enter_data dataOperands(%[[CREATE]] : !fir.heap<!fir.array<?xf32>>)
+
+  !$acc enter data create(i)
+!CHECK: %[[BOX_I:.*]] = fir.load %[[I]] : !fir.ref<!fir.box<!fir.heap<i32>>>
+!CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[BOX_I]] : (!fir.box<!fir.heap<i32>>) -> !fir.ref<!fir.heap<i32>>
+!CHECK: %[[ADDR:.*]] = fir.load %[[BOX_ADDR]] : !fir.ref<!fir.heap<i32>>
+!CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[ADDR]] : !fir.heap<i32>)   -> !fir.heap<i32> {name = "i", structured = false}
+!CHECK: acc.enter_data dataOperands(%[[CREATE]] : !fir.heap<i32>)
+
+end subroutine
+


        


More information about the flang-commits mailing list