[flang-commits] [flang] [flang] Enabled pulling of rebox into array_coor. (PR #199161)

Slava Zakharin via flang-commits flang-commits at lists.llvm.org
Tue May 26 10:02:08 PDT 2026


https://github.com/vzakhari updated https://github.com/llvm/llvm-project/pull/199161

>From 419cd18ed35b2606abd6c0190a4eedf964abdf25 Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Thu, 21 May 2026 20:27:21 -0700
Subject: [PATCH 1/2] [flang] Enabled pulling of rebox into array_coor.

This patch enables pulling slicing `fir.rebox` operations
into `fir.array_coor`. This helps preserve information about
the original rank of the array being accessed.
`FIRToMemRef` and later passes may benefit from this.

Assisted by: Claude
---
 .../flang/Optimizer/Builder/FIRBuilder.h      | 13 +----
 .../flang/Optimizer/Dialect/FIRBoxUtils.h     | 29 ++++++++++
 flang/lib/Optimizer/Builder/FIRBuilder.cpp    | 26 ---------
 flang/lib/Optimizer/Dialect/CMakeLists.txt    |  1 +
 flang/lib/Optimizer/Dialect/FIRBoxUtils.cpp   | 42 ++++++++++++++
 flang/lib/Optimizer/Dialect/FIROps.cpp        | 31 ++++++++++-
 .../test/Fir/array-coor-canonicalization.fir  | 55 ++++++++++++++-----
 7 files changed, 146 insertions(+), 51 deletions(-)
 create mode 100644 flang/include/flang/Optimizer/Dialect/FIRBoxUtils.h
 create mode 100644 flang/lib/Optimizer/Dialect/FIRBoxUtils.cpp

diff --git a/flang/include/flang/Optimizer/Builder/FIRBuilder.h b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
index dc99174d6b993..4a386a3559c6d 100644
--- a/flang/include/flang/Optimizer/Builder/FIRBuilder.h
+++ b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
@@ -16,6 +16,7 @@
 #ifndef FORTRAN_OPTIMIZER_BUILDER_FIRBUILDER_H
 #define FORTRAN_OPTIMIZER_BUILDER_FIRBUILDER_H
 
+#include "flang/Optimizer/Dialect/FIRBoxUtils.h"
 #include "flang/Optimizer/Dialect/FIROps.h"
 #include "flang/Optimizer/Dialect/FIROpsSupport.h"
 #include "flang/Optimizer/Dialect/FIRType.h"
@@ -736,6 +737,8 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
 
 namespace fir::factory {
 
+using fir::genDimInfoFromBox;
+
 //===----------------------------------------------------------------------===//
 // ExtendedValue inquiry helpers
 //===----------------------------------------------------------------------===//
@@ -972,16 +975,6 @@ uint64_t getProgramAddressSpace(mlir::DataLayout *dataLayout);
 llvm::SmallVector<mlir::Value> updateRuntimeExtentsForEmptyArrays(
     fir::FirOpBuilder &builder, mlir::Location loc, mlir::ValueRange extents);
 
-/// Given \p box of type fir::BaseBoxType representing an array,
-/// the function generates code to fetch the lower bounds,
-/// the extents and the strides from the box. The values are returned via
-/// \p lbounds, \p extents and \p strides.
-void genDimInfoFromBox(fir::FirOpBuilder &builder, mlir::Location loc,
-                       mlir::Value box,
-                       llvm::SmallVectorImpl<mlir::Value> *lbounds,
-                       llvm::SmallVectorImpl<mlir::Value> *extents,
-                       llvm::SmallVectorImpl<mlir::Value> *strides);
-
 /// Generate an LLVM dialect lifetime start marker at the current insertion
 /// point given an fir.alloca. Returns the value to be passed to the lifetime
 /// end marker.
diff --git a/flang/include/flang/Optimizer/Dialect/FIRBoxUtils.h b/flang/include/flang/Optimizer/Dialect/FIRBoxUtils.h
new file mode 100644
index 0000000000000..c3a752d1c1e0e
--- /dev/null
+++ b/flang/include/flang/Optimizer/Dialect/FIRBoxUtils.h
@@ -0,0 +1,29 @@
+//===-- Optimizer/Dialect/FIRBoxUtils.h -- FIR box utilities --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef FORTRAN_OPTIMIZER_DIALECT_FIRBOXUTILS_H
+#define FORTRAN_OPTIMIZER_DIALECT_FIRBOXUTILS_H
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Value.h"
+#include "llvm/ADT/SmallVector.h"
+
+namespace fir {
+
+/// Given \p box of type fir::BaseBoxType representing an array, generate code
+/// to fetch the lower bounds, extents, and/or strides from the box. Non-null
+/// output pointers receive the corresponding values.
+void genDimInfoFromBox(mlir::OpBuilder &builder, mlir::Location loc,
+                       mlir::Value box,
+                       llvm::SmallVectorImpl<mlir::Value> *lbounds,
+                       llvm::SmallVectorImpl<mlir::Value> *extents,
+                       llvm::SmallVectorImpl<mlir::Value> *strides);
+
+} // namespace fir
+
+#endif // FORTRAN_OPTIMIZER_DIALECT_FIRBOXUTILS_H
diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
index 4ce5c6955b5c6..759955eeadcd1 100644
--- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp
+++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
@@ -1979,32 +1979,6 @@ llvm::SmallVector<mlir::Value> fir::factory::updateRuntimeExtentsForEmptyArrays(
   return newExtents;
 }
 
-void fir::factory::genDimInfoFromBox(
-    fir::FirOpBuilder &builder, mlir::Location loc, mlir::Value box,
-    llvm::SmallVectorImpl<mlir::Value> *lbounds,
-    llvm::SmallVectorImpl<mlir::Value> *extents,
-    llvm::SmallVectorImpl<mlir::Value> *strides) {
-  auto boxType = mlir::dyn_cast<fir::BaseBoxType>(box.getType());
-  assert(boxType && "must be a box");
-  if (!lbounds && !extents && !strides)
-    return;
-
-  unsigned rank = fir::getBoxRank(boxType);
-  assert(!boxType.isAssumedRank() && "must be an array of known rank");
-  mlir::Type idxTy = builder.getIndexType();
-  for (unsigned i = 0; i < rank; ++i) {
-    mlir::Value dim = builder.createIntegerConstant(loc, idxTy, i);
-    auto dimInfo =
-        fir::BoxDimsOp::create(builder, loc, idxTy, idxTy, idxTy, box, dim);
-    if (lbounds)
-      lbounds->push_back(dimInfo.getLowerBound());
-    if (extents)
-      extents->push_back(dimInfo.getExtent());
-    if (strides)
-      strides->push_back(dimInfo.getByteStride());
-  }
-}
-
 mlir::Value fir::factory::genLifetimeStart(mlir::OpBuilder &builder,
                                            mlir::Location loc,
                                            fir::AllocaOp alloc,
diff --git a/flang/lib/Optimizer/Dialect/CMakeLists.txt b/flang/lib/Optimizer/Dialect/CMakeLists.txt
index 8fc076d88b78a..a50736d17f987 100644
--- a/flang/lib/Optimizer/Dialect/CMakeLists.txt
+++ b/flang/lib/Optimizer/Dialect/CMakeLists.txt
@@ -6,6 +6,7 @@ add_subdirectory(MIF)
 add_flang_library(FIRDialect
   CUDAKernelOpInterface.cpp
   FIRAttr.cpp
+  FIRBoxUtils.cpp
   FIRDialect.cpp
   FIROperationMoveOpInterface.cpp
   FIROps.cpp
diff --git a/flang/lib/Optimizer/Dialect/FIRBoxUtils.cpp b/flang/lib/Optimizer/Dialect/FIRBoxUtils.cpp
new file mode 100644
index 0000000000000..b0fe1939a1e0f
--- /dev/null
+++ b/flang/lib/Optimizer/Dialect/FIRBoxUtils.cpp
@@ -0,0 +1,42 @@
+//===-- FIRBoxUtils.cpp ---------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Dialect/FIRBoxUtils.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+
+namespace fir {
+
+void genDimInfoFromBox(mlir::OpBuilder &builder, mlir::Location loc,
+                       mlir::Value box,
+                       llvm::SmallVectorImpl<mlir::Value> *lbounds,
+                       llvm::SmallVectorImpl<mlir::Value> *extents,
+                       llvm::SmallVectorImpl<mlir::Value> *strides) {
+  auto boxType = mlir::dyn_cast<fir::BaseBoxType>(box.getType());
+  assert(boxType && "must be a box");
+  if (!lbounds && !extents && !strides)
+    return;
+
+  unsigned rank = fir::getBoxRank(boxType);
+  assert(!boxType.isAssumedRank() && "must be an array of known rank");
+  mlir::Type idxTy = builder.getIndexType();
+  for (unsigned i = 0; i < rank; ++i) {
+    mlir::Value dim = mlir::arith::ConstantIndexOp::create(builder, loc, i);
+    auto dimInfo =
+        fir::BoxDimsOp::create(builder, loc, idxTy, idxTy, idxTy, box, dim);
+    if (lbounds)
+      lbounds->push_back(dimInfo.getLowerBound());
+    if (extents)
+      extents->push_back(dimInfo.getExtent());
+    if (strides)
+      strides->push_back(dimInfo.getByteStride());
+  }
+}
+
+} // namespace fir
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 937f5c3f07e7d..767b052276f24 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -13,6 +13,7 @@
 #include "flang/Optimizer/Dialect/FIROps.h"
 #include "flang/Optimizer/Dialect/CUDAKernelOpInterface.h"
 #include "flang/Optimizer/Dialect/FIRAttr.h"
+#include "flang/Optimizer/Dialect/FIRBoxUtils.h"
 #include "flang/Optimizer/Dialect/FIRDialect.h"
 #include "flang/Optimizer/Dialect/FIROpsSupport.h"
 #include "flang/Optimizer/Dialect/FIRType.h"
@@ -694,8 +695,18 @@ struct SimplifyArrayCoorOp : public mlir::OpRewritePattern<fir::ArrayCoorOp> {
       // ranks differ is out of scope.
       if (op.getSlice())
         return mlir::failure();
-      if (!boxedShape)
-        return mlir::failure();
+      if (!boxedShape) {
+        // fir.rebox with a rank-reducing slice often has no shape operand.
+        // Synthesize fir.shape from box_dims extents. Do not synthesize
+        // fir.shift: rebox descriptors already have (lb-1)*stride in base_addr
+        // and FIRToMemRef must not subtract lb again for rebox memrefs.
+        if (!mlir::isa<fir::BaseBoxType>(boxedMemref.getType()))
+          return mlir::failure();
+        boxedShape = getShapeFromBoxDims(boxedMemref, origBoxRank, op.getLoc(),
+                                         op, rewriter);
+        if (!boxedShape)
+          return mlir::failure();
+      }
       // Avoid emitting a plain ref array_coor whose shape is a ShiftType:
       // the verifier rejects this (shift can only pair with fir.box memref).
       if (!mlir::isa<fir::BaseBoxType>(boxedMemref.getType()) &&
@@ -1025,6 +1036,22 @@ struct SimplifyArrayCoorOp : public mlir::OpRewritePattern<fir::ArrayCoorOp> {
     return typeparams;
   }
 
+  // Build fir.shape<rank> from fir.box_dims extents of box.
+  static mlir::Value getShapeFromBoxDims(mlir::Value box, unsigned rank,
+                                         mlir::Location loc,
+                                         mlir::Operation *insertBefore,
+                                         mlir::PatternRewriter &rewriter) {
+    auto boxType = mlir::dyn_cast<fir::BaseBoxType>(box.getType());
+    if (!boxType || boxType.isAssumedRank() || fir::getBoxRank(boxType) != rank)
+      return nullptr;
+    mlir::OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPoint(insertBefore);
+    llvm::SmallVector<mlir::Value> extents;
+    fir::genDimInfoFromBox(rewriter, loc, box, /*lbounds=*/nullptr, &extents,
+                           /*strides=*/nullptr);
+    return fir::ShapeOp::create(rewriter, loc, extents);
+  }
+
   // If v is a shape_shift operation:
   //   fir.shape_shift %l1, %e1, %l2, %e2, ...
   // create:
diff --git a/flang/test/Fir/array-coor-canonicalization.fir b/flang/test/Fir/array-coor-canonicalization.fir
index 5c06d709c524a..88e23b404f2b6 100644
--- a/flang/test/Fir/array-coor-canonicalization.fir
+++ b/flang/test/Fir/array-coor-canonicalization.fir
@@ -633,19 +633,22 @@ func.func @test21_box_addr_rank_reducing(%arg0: !fir.ref<!fir.array<16x6xf32>>,
 }
 
 // Rank-reducing slice via fir.rebox with no shift: 3D source box -> 2D box.
-// The rank-reducing branch currently requires the rebox to carry an explicit
-// shift (so the new array_coor on the source box gets a valid shape operand
-// of source rank). Without a shift we leave the IR unchanged.
-// CHECK-LABEL:   func.func @test22_rebox_rank_reducing(
-// CHECK-SAME:                      %[[VAL_0:.*]]: !fir.box<!fir.array<?x?x?xf32>>,
-// CHECK-SAME:                      %[[VAL_1:.*]]: index) -> !fir.ref<f32> {
-// CHECK:           %[[VAL_2:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_3:.*]] = arith.constant 100 : index
-// CHECK:           %[[VAL_4:.*]] = fir.undefined index
-// CHECK:           %[[VAL_5:.*]] = fir.slice %[[VAL_2]], %[[VAL_3]], %[[VAL_2]], %[[VAL_1]], %[[VAL_4]], %[[VAL_4]], %[[VAL_2]], %[[VAL_3]], %[[VAL_2]] : (index, index, index, index, index, index, index, index, index) -> !fir.slice<3>
-// CHECK:           %[[VAL_6:.*]] = fir.rebox %[[VAL_0]] {{\[}}%[[VAL_5]]] : (!fir.box<!fir.array<?x?x?xf32>>, !fir.slice<3>) -> !fir.box<!fir.array<?x?xf32>>
-// CHECK:           %[[VAL_7:.*]] = fir.array_coor %[[VAL_6]] %[[VAL_2]], %[[VAL_2]] : (!fir.box<!fir.array<?x?xf32>>, index, index) -> !fir.ref<f32>
-// CHECK:           return %[[VAL_7]] : !fir.ref<f32>
+// array_coor on the sliced box is folded to a full-rank array_coor on the
+// source box; fir.shape is synthesized from fir.box_dims when rebox has no
+// shape operand (fir.shift is not used: rebox base_addr is pre-adjusted).
+// CHECK-LABEL:   func.func @test22_rebox_rank_reducing({{.*}}) -> !fir.ref<f32> {
+// CHECK:           %[[T22_C2:.*]] = arith.constant 2 : index
+// CHECK:           %[[T22_C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[T22_C1:.*]] = arith.constant 1 : index
+// CHECK:           %[[T22_C100:.*]] = arith.constant 100 : index
+// CHECK:           %[[T22_U:.*]] = fir.undefined index
+// CHECK:           %[[T22_SL:.*]] = fir.slice %[[T22_C1]], %[[T22_C100]], %[[T22_C1]], %{{.*}}, %[[T22_U]], %[[T22_U]], %[[T22_C1]], %[[T22_C100]], %[[T22_C1]] : (index, index, index, index, index, index, index, index, index) -> !fir.slice<3>
+// CHECK:           %[[T22_BD0:.*]]:3 = fir.box_dims %{{.*}}, %[[T22_C0]]
+// CHECK:           %[[T22_BD1:.*]]:3 = fir.box_dims %{{.*}}, %[[T22_C1]]
+// CHECK:           %[[T22_BD2:.*]]:3 = fir.box_dims %{{.*}}, %[[T22_C2]]
+// CHECK:           %[[T22_SH:.*]] = fir.shape %[[T22_BD0]]#1, %[[T22_BD1]]#1, %[[T22_BD2]]#1 : (index, index, index) -> !fir.shape<3>
+// CHECK:           %[[T22_AC:.*]] = fir.array_coor %{{.*}}(%[[T22_SH]]) {{\[}}%[[T22_SL]]] %[[T22_C1]], %{{.*}}, %[[T22_C1]] : (!fir.box<!fir.array<?x?x?xf32>>, !fir.shape<3>, !fir.slice<3>, index, index, index) -> !fir.ref<f32>
+// CHECK:           return %[[T22_AC]] : !fir.ref<f32>
 // CHECK:         }
 func.func @test22_rebox_rank_reducing(%arg0: !fir.box<!fir.array<?x?x?xf32>>, %j: index) -> !fir.ref<f32> {
   %c1 = arith.constant 1 : index
@@ -849,3 +852,29 @@ func.func @test29_merged_box_addr_no_fold(%arg0: !fir.ref<!fir.array<4x3xf32>>)
   %ac = fir.array_coor %ba(%sh) %c1, %c1 : (!fir.ref<!fir.array<4x3xf32>>, !fir.shape<2>, index, index) -> !fir.ref<f32>
   return %ac : !fir.ref<f32>
 }
+
+// Rank-reducing 2D->1D slice via fir.rebox: array_coor uses the sliced 1D box
+// directly; fold to 2-index array_coor on the 2D source.
+// CHECK-LABEL:   func.func @test30_rebox_rank_reducing_2d({{.*}}) -> !fir.ref<f32> {
+// CHECK:           %[[T30_C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[T30_C1:.*]] = arith.constant 1 : index
+// CHECK:           %[[T30_C2:.*]] = arith.constant 2 : index
+// CHECK:           %[[T30_U:.*]] = fir.undefined index
+// CHECK:           %[[T30_SL:.*]] = fir.slice %{{.*}}, %{{.*}}, %[[T30_C1]], %[[T30_C2]], %[[T30_U]], %[[T30_U]] : (index, index, index, index, index, index) -> !fir.slice<2>
+// CHECK:           %[[T30_BD0:.*]]:3 = fir.box_dims %{{.*}}, %[[T30_C0]]
+// CHECK:           %[[T30_BD1:.*]]:3 = fir.box_dims %{{.*}}, %[[T30_C1]]
+// CHECK:           %[[T30_SH:.*]] = fir.shape %[[T30_BD0]]#1, %[[T30_BD1]]#1 : (index, index) -> !fir.shape<2>
+// CHECK:           %[[T30_AC:.*]] = fir.array_coor %{{.*}}(%[[T30_SH]]) {{\[}}%[[T30_SL]]] %{{.*}}, %[[T30_C2]] : (!fir.box<!fir.array<?x?xf32>>, !fir.shape<2>, !fir.slice<2>, index, index) -> !fir.ref<f32>
+// CHECK:           return %[[T30_AC]] : !fir.ref<f32>
+// CHECK:         }
+func.func @test30_rebox_rank_reducing_2d(
+    %arg0: !fir.box<!fir.array<?x?xf32>>, %si: index, %ei: index,
+    %row: index) -> !fir.ref<f32> {
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %u = fir.undefined index
+  %s = fir.slice %si, %ei, %c1, %c2, %u, %u : (index, index, index, index, index, index) -> !fir.slice<2>
+  %b = fir.rebox %arg0 [%s] : (!fir.box<!fir.array<?x?xf32>>, !fir.slice<2>) -> !fir.box<!fir.array<?xf32>>
+  %ac = fir.array_coor %b %row : (!fir.box<!fir.array<?xf32>>, index) -> !fir.ref<f32>
+  return %ac : !fir.ref<f32>
+}

>From 83f6c5a43e8b73814e73dbd1af70579b46d4f3ab Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Tue, 26 May 2026 10:00:22 -0700
Subject: [PATCH 2/2] Fixed comment.

---
 flang/lib/Optimizer/Dialect/FIROps.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 767b052276f24..55cafc2b32a3c 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -698,8 +698,8 @@ struct SimplifyArrayCoorOp : public mlir::OpRewritePattern<fir::ArrayCoorOp> {
       if (!boxedShape) {
         // fir.rebox with a rank-reducing slice often has no shape operand.
         // Synthesize fir.shape from box_dims extents. Do not synthesize
-        // fir.shift: rebox descriptors already have (lb-1)*stride in base_addr
-        // and FIRToMemRef must not subtract lb again for rebox memrefs.
+        // fir.shift: the descriptors produced by fir.rebox have default
+        // lower bounds.
         if (!mlir::isa<fir::BaseBoxType>(boxedMemref.getType()))
           return mlir::failure();
         boxedShape = getShapeFromBoxDims(boxedMemref, origBoxRank, op.getLoc(),



More information about the flang-commits mailing list