[flang-commits] [flang] [flang][acc] Generate acc.bounds operation from FIR shape (PR #136637)

Razvan Lupusoru via flang-commits flang-commits at lists.llvm.org
Mon Apr 21 17:43:37 PDT 2025


https://github.com/razvanlupusoru created https://github.com/llvm/llvm-project/pull/136637

This PR adds support for generating `acc.bounds` operations through `MappableType`'s `generateAccBounds` when there is no fir.box entity. This is especially useful because the FIR type does not capture size information for explicit-shape arrays, and the current implementation relied on finding the box entity.

This scenario is possible because during HLFIRtoFIR, `fir.array_coor` and `fir.box_addr` operations are often optimized to use the raw address. If one tries to map the SSA value that represents such a variable, the correct dimensions need to be extracted from the shape information held in the fir declare operation.

>From 3a35af853b21b7d07ab7ddcd9923ceb66ec92266 Mon Sep 17 00:00:00 2001
From: Razvan Lupusoru <rlupusoru at nvidia.com>
Date: Mon, 21 Apr 2025 17:00:15 -0700
Subject: [PATCH] [flang][acc] Generate acc.bounds operation from FIR shape

This PR adds support for generating `acc.bounds` operations through
`MappableType`'s `generateAccBounds` when there is no fir.box entity.
This is especially useful because the FIR type does not capture size
information for explicit-shape arrays, and the current implementation
relied on finding the box entity.

This scenario is possible because during HLFIRtoFIR, `fir.array_coor`
and `fir.box_addr` operations are often optimized to use the raw address.
If one tries to map the SSA value that represents such a variable, the
correct dimensions need to be extracted from the shape information held
in the fir declare operation.
---
 .../OpenACC/FIROpenACCTypeInterfaces.cpp      | 72 +++++++++++++++++++
 flang/test/Fir/OpenACC/openacc-mappable.fir   | 72 ++++++++++++++++---
 2 files changed, 135 insertions(+), 9 deletions(-)

diff --git a/flang/lib/Optimizer/OpenACC/FIROpenACCTypeInterfaces.cpp b/flang/lib/Optimizer/OpenACC/FIROpenACCTypeInterfaces.cpp
index 38c9fc5bbb52c..2d0d032d08b3c 100644
--- a/flang/lib/Optimizer/OpenACC/FIROpenACCTypeInterfaces.cpp
+++ b/flang/lib/Optimizer/OpenACC/FIROpenACCTypeInterfaces.cpp
@@ -188,6 +188,78 @@ OpenACCMappableModel<fir::SequenceType>::generateAccBounds(
                                                mlir::acc::DataBoundsType>(
           firBuilder, loc, exv, info);
     }
+
+    if (mlir::isa<hlfir::DeclareOp, fir::DeclareOp>(varPtr.getDefiningOp())) {
+      mlir::Value zero =
+          firBuilder.createIntegerConstant(loc, builder.getIndexType(), 0);
+      mlir::Value one =
+          firBuilder.createIntegerConstant(loc, builder.getIndexType(), 1);
+
+      mlir::Value shape;
+      if (auto declareOp = mlir::dyn_cast_if_present<fir::DeclareOp>(
+              varPtr.getDefiningOp())) {
+        shape = declareOp.getShape();
+      } else if (auto declareOp = mlir::dyn_cast_if_present<hlfir::DeclareOp>(
+                     varPtr.getDefiningOp())) {
+        shape = declareOp.getShape();
+      }
+
+      const bool strideIncludeLowerExtent = true;
+
+      llvm::SmallVector<mlir::Value> accBounds;
+      if (auto shapeOp =
+              mlir::dyn_cast_if_present<fir::ShapeOp>(shape.getDefiningOp())) {
+        mlir::Value cummulativeExtent = one;
+        for (auto extent : shapeOp.getExtents()) {
+          mlir::Value upperbound =
+              builder.create<mlir::arith::SubIOp>(loc, extent, one);
+          mlir::Value stride = one;
+          if (strideIncludeLowerExtent) {
+            stride = cummulativeExtent;
+            cummulativeExtent = builder.create<mlir::arith::MulIOp>(
+                loc, cummulativeExtent, extent);
+          }
+          auto accBound = builder.create<mlir::acc::DataBoundsOp>(
+              loc, mlir::acc::DataBoundsType::get(builder.getContext()),
+              /*lowerbound=*/zero, /*upperbound=*/upperbound,
+              /*extent=*/extent, /*stride=*/stride, /*strideInBytes=*/false,
+              /*startIdx=*/one);
+          accBounds.push_back(accBound);
+        }
+      } else if (auto shapeShiftOp =
+                     mlir::dyn_cast_if_present<fir::ShapeShiftOp>(
+                         shape.getDefiningOp())) {
+        mlir::Value lowerbound;
+        mlir::Value cummulativeExtent = one;
+        for (auto [idx, val] : llvm::enumerate(shapeShiftOp.getPairs())) {
+          if (idx % 2 == 0) {
+            lowerbound = val;
+          } else {
+            mlir::Value extent = val;
+            mlir::Value upperbound =
+                builder.create<mlir::arith::SubIOp>(loc, extent, one);
+            upperbound = builder.create<mlir::arith::AddIOp>(loc, lowerbound,
+                                                             upperbound);
+            mlir::Value stride = one;
+            if (strideIncludeLowerExtent) {
+              stride = cummulativeExtent;
+              cummulativeExtent = builder.create<mlir::arith::MulIOp>(
+                  loc, cummulativeExtent, extent);
+            }
+            auto accBound = builder.create<mlir::acc::DataBoundsOp>(
+                loc, mlir::acc::DataBoundsType::get(builder.getContext()),
+                /*lowerbound=*/zero, /*upperbound=*/upperbound,
+                /*extent=*/extent, /*stride=*/stride, /*strideInBytes=*/false,
+                /*startIdx=*/lowerbound);
+            accBounds.push_back(accBound);
+          }
+        }
+      }
+
+      if (!accBounds.empty())
+        return accBounds;
+    }
+
     assert(false && "array with unknown dimension expected to have descriptor");
     return {};
   }
diff --git a/flang/test/Fir/OpenACC/openacc-mappable.fir b/flang/test/Fir/OpenACC/openacc-mappable.fir
index 005f002c491a5..3e3e455469f69 100644
--- a/flang/test/Fir/OpenACC/openacc-mappable.fir
+++ b/flang/test/Fir/OpenACC/openacc-mappable.fir
@@ -2,6 +2,7 @@
 // RUN: fir-opt %s -pass-pipeline='builtin.module(test-fir-openacc-interfaces)' -split-input-file --mlir-disable-threading 2>&1 | FileCheck %s
 
 module attributes {dlti.dl_spec = #dlti.dl_spec<f16 = dense<16> : vector<2xi64>, f128 = dense<128> : vector<2xi64>, !llvm.ptr<270> = dense<32> : vector<4xi64>, f64 = dense<64> : vector<2xi64>, !llvm.ptr<271> = dense<32> : vector<4xi64>, !llvm.ptr<272> = dense<64> : vector<4xi64>, i64 = dense<64> : vector<2xi64>, i128 = dense<128> : vector<2xi64>, f80 = dense<128> : vector<2xi64>, !llvm.ptr = dense<64> : vector<4xi64>, i1 = dense<8> : vector<2xi64>, i8 = dense<8> : vector<2xi64>, i16 = dense<16> : vector<2xi64>, i32 = dense<32> : vector<2xi64>, "dlti.endianness" = "little", "dlti.stack_alignment" = 128 : i64>, fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"} {
+  // This test exercises explicit-shape local array of form "arr(2:10)"
   func.func @_QPsub() {
     %c2 = arith.constant 2 : index
     %c10 = arith.constant 10 : index
@@ -15,13 +16,66 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<f16 = dense<16> : vector<2xi64>,
     acc.enter_data dataOperands(%5, %6 : !fir.box<!fir.array<10xf32>>, !fir.ref<!fir.array<10xf32>>)
     return
   }
-}
 
-// CHECK: Visiting: %{{.*}} = acc.copyin var(%{{.*}} : !fir.box<!fir.array<10xf32>>) -> !fir.box<!fir.array<10xf32>> {name = "arr", structured = false}
-// CHECK: Mappable: !fir.box<!fir.array<10xf32>>
-// CHECK: Type category: array
-// CHECK: Size: 40
-// CHECK: Visiting: %{{.*}} = acc.copyin varPtr(%{{.*}} : !fir.ref<!fir.array<10xf32>>) -> !fir.ref<!fir.array<10xf32>> {name = "arr", structured = false}
-// CHECK: Mappable: !fir.array<10xf32>
-// CHECK: Type category: array
-// CHECK: Size: 40
+  // CHECK: Visiting: %{{.*}} = acc.copyin var(%{{.*}} : !fir.box<!fir.array<10xf32>>) -> !fir.box<!fir.array<10xf32>> {name = "arr", structured = false}
+  // CHECK: Mappable: !fir.box<!fir.array<10xf32>>
+  // CHECK: Type category: array
+  // CHECK: Size: 40
+
+  // CHECK: Visiting: %{{.*}} = acc.copyin varPtr(%{{.*}} : !fir.ref<!fir.array<10xf32>>) -> !fir.ref<!fir.array<10xf32>> {name = "arr", structured = false}
+  // CHECK: Mappable: !fir.array<10xf32>
+  // CHECK: Type category: array
+  // CHECK: Size: 40
+
+  // This second test exercises explicit-shape array arguments of the following forms:
+  // `real :: arr1(nn), arr2(2:nn), arr3(10)`
+  // It uses the reference instead of the box in the clauses to test that bounds
+  // can be generated from the shape operations.
+  func.func @_QPacc_explicit_shape(%arg0: !fir.ref<!fir.array<?xf32>> {fir.bindc_name = "arr1"}, %arg1: !fir.ref<!fir.array<?xf32>> {fir.bindc_name = "arr2"}, %arg2: !fir.ref<i32> {fir.bindc_name = "nn"}) {
+    %c-1 = arith.constant -1 : index
+    %c2 = arith.constant 2 : index
+    %c0 = arith.constant 0 : index
+    %c10 = arith.constant 10 : index
+    %0 = fir.dummy_scope : !fir.dscope
+    %1:2 = hlfir.declare %arg2 dummy_scope %0 {uniq_name = "_QFacc_explicit_shapeEnn"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
+    %2 = fir.alloca !fir.array<10xf32> {bindc_name = "arr3", uniq_name = "_QFacc_explicit_shapeEarr3"}
+    %3 = fir.shape %c10 : (index) -> !fir.shape<1>
+    %4:2 = hlfir.declare %2(%3) {uniq_name = "_QFacc_explicit_shapeEarr3"} : (!fir.ref<!fir.array<10xf32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xf32>>, !fir.ref<!fir.array<10xf32>>)
+    %5 = fir.load %1#0 : !fir.ref<i32>
+    %6 = fir.convert %5 : (i32) -> index
+    %7 = arith.cmpi sgt, %6, %c0 : index
+    %8 = arith.select %7, %6, %c0 : index
+    %9 = fir.shape %8 : (index) -> !fir.shape<1>
+    %10:2 = hlfir.declare %arg0(%9) dummy_scope %0 {uniq_name = "_QFacc_explicit_shapeEarr1"} : (!fir.ref<!fir.array<?xf32>>, !fir.shape<1>, !fir.dscope) -> (!fir.box<!fir.array<?xf32>>, !fir.ref<!fir.array<?xf32>>)
+    %11 = arith.addi %6, %c-1 : index
+    %12 = arith.cmpi sgt, %11, %c0 : index
+    %13 = arith.select %12, %11, %c0 : index
+    %14 = fir.shape_shift %c2, %13 : (index, index) -> !fir.shapeshift<1>
+    %15:2 = hlfir.declare %arg1(%14) dummy_scope %0 {uniq_name = "_QFacc_explicit_shapeEarr2"} : (!fir.ref<!fir.array<?xf32>>, !fir.shapeshift<1>, !fir.dscope) -> (!fir.box<!fir.array<?xf32>>, !fir.ref<!fir.array<?xf32>>)
+    %16 = acc.copyin var(%10#1 : !fir.ref<!fir.array<?xf32>>) -> !fir.ref<!fir.array<?xf32>> {name = "arr1", structured = false}
+    %17 = acc.copyin var(%15#1 : !fir.ref<!fir.array<?xf32>>) -> !fir.ref<!fir.array<?xf32>> {name = "arr2", structured = false}
+    %18 = acc.copyin varPtr(%4#0 : !fir.ref<!fir.array<10xf32>>) -> !fir.ref<!fir.array<10xf32>> {name = "arr3", structured = false}
+    acc.enter_data dataOperands(%16, %17, %18 : !fir.ref<!fir.array<?xf32>>, !fir.ref<!fir.array<?xf32>>, !fir.ref<!fir.array<10xf32>>)
+    return
+  }
+
+  // CHECK: Visiting: %{{.*}} = acc.copyin varPtr(%{{.*}} : !fir.ref<!fir.array<?xf32>>) -> !fir.ref<!fir.array<?xf32>> {name = "arr1", structured = false}
+  // CHECK: Pointer-like: !fir.ref<!fir.array<?xf32>>
+  // CHECK: Mappable: !fir.array<?xf32>
+  // CHECK: Type category: array
+  // CHECK: Bound[0]: %{{.*}} = acc.bounds lowerbound(%c0{{.*}} : index) upperbound(%{{.*}} : index) extent(%{{.*}} : index) stride(%c1{{.*}} : index) startIdx(%c1{{.*}} : index)
+
+  // CHECK: Visiting: %{{.*}} = acc.copyin varPtr(%{{.*}} : !fir.ref<!fir.array<?xf32>>) -> !fir.ref<!fir.array<?xf32>> {name = "arr2", structured = false}
+  // CHECK: Pointer-like: !fir.ref<!fir.array<?xf32>>
+  // CHECK: Mappable: !fir.array<?xf32>
+  // CHECK: Type category: array
+  // CHECK: Bound[0]: %{{.*}} = acc.bounds lowerbound(%c0{{.*}} : index) upperbound(%{{.*}} : index) extent(%{{.*}} : index) stride(%c1{{.*}} : index) startIdx(%c2{{.*}} : index)
+
+  // CHECK: Visiting: %{{.*}} = acc.copyin varPtr(%{{.*}} : !fir.ref<!fir.array<10xf32>>) -> !fir.ref<!fir.array<10xf32>> {name = "arr3", structured = false}
+  // CHECK: Pointer-like: !fir.ref<!fir.array<10xf32>>
+  // CHECK: Mappable: !fir.array<10xf32>
+  // CHECK: Type category: array
+  // CHECK: Size: 40
+  // CHECK: Offset: 0
+  // CHECK: Bound[0]: %{{.*}} = acc.bounds lowerbound(%c0{{.*}} : index) upperbound(%{{.*}} : index) extent(%c10{{.*}} : index) stride(%c1{{.*}} : index) startIdx(%c1{{.*}} : index)
+}



More information about the flang-commits mailing list