[flang-commits] [flang] [flang] Enabled pulling of rebox into array_coor. (PR #199161)
Slava Zakharin via flang-commits
flang-commits at lists.llvm.org
Tue May 26 10:02:08 PDT 2026
https://github.com/vzakhari updated https://github.com/llvm/llvm-project/pull/199161
>From 419cd18ed35b2606abd6c0190a4eedf964abdf25 Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Thu, 21 May 2026 20:27:21 -0700
Subject: [PATCH 1/2] [flang] Enabled pulling of rebox into array_coor.
This patch enables pulling slicing `fir.rebox` operations
into `fir.array_coor`. This helps preserve information about
the original rank of the array being accessed.
`FIRToMemRef` and later passes may benefit from this.
Assisted by: Claude
---
.../flang/Optimizer/Builder/FIRBuilder.h | 13 +----
.../flang/Optimizer/Dialect/FIRBoxUtils.h | 29 ++++++++++
flang/lib/Optimizer/Builder/FIRBuilder.cpp | 26 ---------
flang/lib/Optimizer/Dialect/CMakeLists.txt | 1 +
flang/lib/Optimizer/Dialect/FIRBoxUtils.cpp | 42 ++++++++++++++
flang/lib/Optimizer/Dialect/FIROps.cpp | 31 ++++++++++-
.../test/Fir/array-coor-canonicalization.fir | 55 ++++++++++++++-----
7 files changed, 146 insertions(+), 51 deletions(-)
create mode 100644 flang/include/flang/Optimizer/Dialect/FIRBoxUtils.h
create mode 100644 flang/lib/Optimizer/Dialect/FIRBoxUtils.cpp
diff --git a/flang/include/flang/Optimizer/Builder/FIRBuilder.h b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
index dc99174d6b993..4a386a3559c6d 100644
--- a/flang/include/flang/Optimizer/Builder/FIRBuilder.h
+++ b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
@@ -16,6 +16,7 @@
#ifndef FORTRAN_OPTIMIZER_BUILDER_FIRBUILDER_H
#define FORTRAN_OPTIMIZER_BUILDER_FIRBUILDER_H
+#include "flang/Optimizer/Dialect/FIRBoxUtils.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
#include "flang/Optimizer/Dialect/FIRType.h"
@@ -736,6 +737,8 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
namespace fir::factory {
+using fir::genDimInfoFromBox;
+
//===----------------------------------------------------------------------===//
// ExtendedValue inquiry helpers
//===----------------------------------------------------------------------===//
@@ -972,16 +975,6 @@ uint64_t getProgramAddressSpace(mlir::DataLayout *dataLayout);
llvm::SmallVector<mlir::Value> updateRuntimeExtentsForEmptyArrays(
fir::FirOpBuilder &builder, mlir::Location loc, mlir::ValueRange extents);
-/// Given \p box of type fir::BaseBoxType representing an array,
-/// the function generates code to fetch the lower bounds,
-/// the extents and the strides from the box. The values are returned via
-/// \p lbounds, \p extents and \p strides.
-void genDimInfoFromBox(fir::FirOpBuilder &builder, mlir::Location loc,
- mlir::Value box,
- llvm::SmallVectorImpl<mlir::Value> *lbounds,
- llvm::SmallVectorImpl<mlir::Value> *extents,
- llvm::SmallVectorImpl<mlir::Value> *strides);
-
/// Generate an LLVM dialect lifetime start marker at the current insertion
/// point given an fir.alloca. Returns the value to be passed to the lifetime
/// end marker.
diff --git a/flang/include/flang/Optimizer/Dialect/FIRBoxUtils.h b/flang/include/flang/Optimizer/Dialect/FIRBoxUtils.h
new file mode 100644
index 0000000000000..c3a752d1c1e0e
--- /dev/null
+++ b/flang/include/flang/Optimizer/Dialect/FIRBoxUtils.h
@@ -0,0 +1,29 @@
+//===-- Optimizer/Dialect/FIRBoxUtils.h -- FIR box utilities --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef FORTRAN_OPTIMIZER_DIALECT_FIRBOXUTILS_H
+#define FORTRAN_OPTIMIZER_DIALECT_FIRBOXUTILS_H
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Value.h"
+#include "llvm/ADT/SmallVector.h"
+
+namespace fir {
+
+/// Given \p box of type fir::BaseBoxType representing an array, generate code
+/// to fetch the lower bounds, extents, and/or strides from the box. Non-null
+/// output pointers receive the corresponding values.
+void genDimInfoFromBox(mlir::OpBuilder &builder, mlir::Location loc,
+ mlir::Value box,
+ llvm::SmallVectorImpl<mlir::Value> *lbounds,
+ llvm::SmallVectorImpl<mlir::Value> *extents,
+ llvm::SmallVectorImpl<mlir::Value> *strides);
+
+} // namespace fir
+
+#endif // FORTRAN_OPTIMIZER_DIALECT_FIRBOXUTILS_H
diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
index 4ce5c6955b5c6..759955eeadcd1 100644
--- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp
+++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
@@ -1979,32 +1979,6 @@ llvm::SmallVector<mlir::Value> fir::factory::updateRuntimeExtentsForEmptyArrays(
return newExtents;
}
-void fir::factory::genDimInfoFromBox(
- fir::FirOpBuilder &builder, mlir::Location loc, mlir::Value box,
- llvm::SmallVectorImpl<mlir::Value> *lbounds,
- llvm::SmallVectorImpl<mlir::Value> *extents,
- llvm::SmallVectorImpl<mlir::Value> *strides) {
- auto boxType = mlir::dyn_cast<fir::BaseBoxType>(box.getType());
- assert(boxType && "must be a box");
- if (!lbounds && !extents && !strides)
- return;
-
- unsigned rank = fir::getBoxRank(boxType);
- assert(!boxType.isAssumedRank() && "must be an array of known rank");
- mlir::Type idxTy = builder.getIndexType();
- for (unsigned i = 0; i < rank; ++i) {
- mlir::Value dim = builder.createIntegerConstant(loc, idxTy, i);
- auto dimInfo =
- fir::BoxDimsOp::create(builder, loc, idxTy, idxTy, idxTy, box, dim);
- if (lbounds)
- lbounds->push_back(dimInfo.getLowerBound());
- if (extents)
- extents->push_back(dimInfo.getExtent());
- if (strides)
- strides->push_back(dimInfo.getByteStride());
- }
-}
-
mlir::Value fir::factory::genLifetimeStart(mlir::OpBuilder &builder,
mlir::Location loc,
fir::AllocaOp alloc,
diff --git a/flang/lib/Optimizer/Dialect/CMakeLists.txt b/flang/lib/Optimizer/Dialect/CMakeLists.txt
index 8fc076d88b78a..a50736d17f987 100644
--- a/flang/lib/Optimizer/Dialect/CMakeLists.txt
+++ b/flang/lib/Optimizer/Dialect/CMakeLists.txt
@@ -6,6 +6,7 @@ add_subdirectory(MIF)
add_flang_library(FIRDialect
CUDAKernelOpInterface.cpp
FIRAttr.cpp
+ FIRBoxUtils.cpp
FIRDialect.cpp
FIROperationMoveOpInterface.cpp
FIROps.cpp
diff --git a/flang/lib/Optimizer/Dialect/FIRBoxUtils.cpp b/flang/lib/Optimizer/Dialect/FIRBoxUtils.cpp
new file mode 100644
index 0000000000000..b0fe1939a1e0f
--- /dev/null
+++ b/flang/lib/Optimizer/Dialect/FIRBoxUtils.cpp
@@ -0,0 +1,42 @@
+//===-- FIRBoxUtils.cpp ---------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Dialect/FIRBoxUtils.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+
+namespace fir {
+
+void genDimInfoFromBox(mlir::OpBuilder &builder, mlir::Location loc,
+ mlir::Value box,
+ llvm::SmallVectorImpl<mlir::Value> *lbounds,
+ llvm::SmallVectorImpl<mlir::Value> *extents,
+ llvm::SmallVectorImpl<mlir::Value> *strides) {
+ auto boxType = mlir::dyn_cast<fir::BaseBoxType>(box.getType());
+ assert(boxType && "must be a box");
+ if (!lbounds && !extents && !strides)
+ return;
+
+ unsigned rank = fir::getBoxRank(boxType);
+ assert(!boxType.isAssumedRank() && "must be an array of known rank");
+ mlir::Type idxTy = builder.getIndexType();
+ for (unsigned i = 0; i < rank; ++i) {
+ mlir::Value dim = mlir::arith::ConstantIndexOp::create(builder, loc, i);
+ auto dimInfo =
+ fir::BoxDimsOp::create(builder, loc, idxTy, idxTy, idxTy, box, dim);
+ if (lbounds)
+ lbounds->push_back(dimInfo.getLowerBound());
+ if (extents)
+ extents->push_back(dimInfo.getExtent());
+ if (strides)
+ strides->push_back(dimInfo.getByteStride());
+ }
+}
+
+} // namespace fir
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 937f5c3f07e7d..767b052276f24 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -13,6 +13,7 @@
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/CUDAKernelOpInterface.h"
#include "flang/Optimizer/Dialect/FIRAttr.h"
+#include "flang/Optimizer/Dialect/FIRBoxUtils.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
#include "flang/Optimizer/Dialect/FIRType.h"
@@ -694,8 +695,18 @@ struct SimplifyArrayCoorOp : public mlir::OpRewritePattern<fir::ArrayCoorOp> {
// ranks differ is out of scope.
if (op.getSlice())
return mlir::failure();
- if (!boxedShape)
- return mlir::failure();
+ if (!boxedShape) {
+ // fir.rebox with a rank-reducing slice often has no shape operand.
+ // Synthesize fir.shape from box_dims extents. Do not synthesize
+ // fir.shift: rebox descriptors already have (lb-1)*stride in base_addr
+ // and FIRToMemRef must not subtract lb again for rebox memrefs.
+ if (!mlir::isa<fir::BaseBoxType>(boxedMemref.getType()))
+ return mlir::failure();
+ boxedShape = getShapeFromBoxDims(boxedMemref, origBoxRank, op.getLoc(),
+ op, rewriter);
+ if (!boxedShape)
+ return mlir::failure();
+ }
// Avoid emitting a plain ref array_coor whose shape is a ShiftType:
// the verifier rejects this (shift can only pair with fir.box memref).
if (!mlir::isa<fir::BaseBoxType>(boxedMemref.getType()) &&
@@ -1025,6 +1036,22 @@ struct SimplifyArrayCoorOp : public mlir::OpRewritePattern<fir::ArrayCoorOp> {
return typeparams;
}
+ // Build fir.shape<rank> from fir.box_dims extents of box.
+ static mlir::Value getShapeFromBoxDims(mlir::Value box, unsigned rank,
+ mlir::Location loc,
+ mlir::Operation *insertBefore,
+ mlir::PatternRewriter &rewriter) {
+ auto boxType = mlir::dyn_cast<fir::BaseBoxType>(box.getType());
+ if (!boxType || boxType.isAssumedRank() || fir::getBoxRank(boxType) != rank)
+ return nullptr;
+ mlir::OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPoint(insertBefore);
+ llvm::SmallVector<mlir::Value> extents;
+ fir::genDimInfoFromBox(rewriter, loc, box, /*lbounds=*/nullptr, &extents,
+ /*strides=*/nullptr);
+ return fir::ShapeOp::create(rewriter, loc, extents);
+ }
+
// If v is a shape_shift operation:
// fir.shape_shift %l1, %e1, %l2, %e2, ...
// create:
diff --git a/flang/test/Fir/array-coor-canonicalization.fir b/flang/test/Fir/array-coor-canonicalization.fir
index 5c06d709c524a..88e23b404f2b6 100644
--- a/flang/test/Fir/array-coor-canonicalization.fir
+++ b/flang/test/Fir/array-coor-canonicalization.fir
@@ -633,19 +633,22 @@ func.func @test21_box_addr_rank_reducing(%arg0: !fir.ref<!fir.array<16x6xf32>>,
}
// Rank-reducing slice via fir.rebox with no shift: 3D source box -> 2D box.
-// The rank-reducing branch currently requires the rebox to carry an explicit
-// shift (so the new array_coor on the source box gets a valid shape operand
-// of source rank). Without a shift we leave the IR unchanged.
-// CHECK-LABEL: func.func @test22_rebox_rank_reducing(
-// CHECK-SAME: %[[VAL_0:.*]]: !fir.box<!fir.array<?x?x?xf32>>,
-// CHECK-SAME: %[[VAL_1:.*]]: index) -> !fir.ref<f32> {
-// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_3:.*]] = arith.constant 100 : index
-// CHECK: %[[VAL_4:.*]] = fir.undefined index
-// CHECK: %[[VAL_5:.*]] = fir.slice %[[VAL_2]], %[[VAL_3]], %[[VAL_2]], %[[VAL_1]], %[[VAL_4]], %[[VAL_4]], %[[VAL_2]], %[[VAL_3]], %[[VAL_2]] : (index, index, index, index, index, index, index, index, index) -> !fir.slice<3>
-// CHECK: %[[VAL_6:.*]] = fir.rebox %[[VAL_0]] {{\[}}%[[VAL_5]]] : (!fir.box<!fir.array<?x?x?xf32>>, !fir.slice<3>) -> !fir.box<!fir.array<?x?xf32>>
-// CHECK: %[[VAL_7:.*]] = fir.array_coor %[[VAL_6]] %[[VAL_2]], %[[VAL_2]] : (!fir.box<!fir.array<?x?xf32>>, index, index) -> !fir.ref<f32>
-// CHECK: return %[[VAL_7]] : !fir.ref<f32>
+// array_coor on the sliced box is folded to a full-rank array_coor on the
+// source box; fir.shape is synthesized from fir.box_dims when rebox has no
+// shape operand (fir.shift is not used: fir.rebox yields default lower bounds).
+// CHECK-LABEL: func.func @test22_rebox_rank_reducing({{.*}}) -> !fir.ref<f32> {
+// CHECK: %[[T22_C2:.*]] = arith.constant 2 : index
+// CHECK: %[[T22_C0:.*]] = arith.constant 0 : index
+// CHECK: %[[T22_C1:.*]] = arith.constant 1 : index
+// CHECK: %[[T22_C100:.*]] = arith.constant 100 : index
+// CHECK: %[[T22_U:.*]] = fir.undefined index
+// CHECK: %[[T22_SL:.*]] = fir.slice %[[T22_C1]], %[[T22_C100]], %[[T22_C1]], %{{.*}}, %[[T22_U]], %[[T22_U]], %[[T22_C1]], %[[T22_C100]], %[[T22_C1]] : (index, index, index, index, index, index, index, index, index) -> !fir.slice<3>
+// CHECK: %[[T22_BD0:.*]]:3 = fir.box_dims %{{.*}}, %[[T22_C0]]
+// CHECK: %[[T22_BD1:.*]]:3 = fir.box_dims %{{.*}}, %[[T22_C1]]
+// CHECK: %[[T22_BD2:.*]]:3 = fir.box_dims %{{.*}}, %[[T22_C2]]
+// CHECK: %[[T22_SH:.*]] = fir.shape %[[T22_BD0]]#1, %[[T22_BD1]]#1, %[[T22_BD2]]#1 : (index, index, index) -> !fir.shape<3>
+// CHECK: %[[T22_AC:.*]] = fir.array_coor %{{.*}}(%[[T22_SH]]) {{\[}}%[[T22_SL]]] %[[T22_C1]], %{{.*}}, %[[T22_C1]] : (!fir.box<!fir.array<?x?x?xf32>>, !fir.shape<3>, !fir.slice<3>, index, index, index) -> !fir.ref<f32>
+// CHECK: return %[[T22_AC]] : !fir.ref<f32>
// CHECK: }
func.func @test22_rebox_rank_reducing(%arg0: !fir.box<!fir.array<?x?x?xf32>>, %j: index) -> !fir.ref<f32> {
%c1 = arith.constant 1 : index
@@ -849,3 +852,29 @@ func.func @test29_merged_box_addr_no_fold(%arg0: !fir.ref<!fir.array<4x3xf32>>)
%ac = fir.array_coor %ba(%sh) %c1, %c1 : (!fir.ref<!fir.array<4x3xf32>>, !fir.shape<2>, index, index) -> !fir.ref<f32>
return %ac : !fir.ref<f32>
}
+
+// Rank-reducing 2D->1D slice via fir.rebox: array_coor uses the sliced 1D box
+// directly; fold to 2-index array_coor on the 2D source.
+// CHECK-LABEL: func.func @test30_rebox_rank_reducing_2d({{.*}}) -> !fir.ref<f32> {
+// CHECK: %[[T29_C0:.*]] = arith.constant 0 : index
+// CHECK: %[[T29_C1:.*]] = arith.constant 1 : index
+// CHECK: %[[T29_C2:.*]] = arith.constant 2 : index
+// CHECK: %[[T29_U:.*]] = fir.undefined index
+// CHECK: %[[T29_SL:.*]] = fir.slice %{{.*}}, %{{.*}}, %[[T29_C1]], %[[T29_C2]], %[[T29_U]], %[[T29_U]] : (index, index, index, index, index, index) -> !fir.slice<2>
+// CHECK: %[[T29_BD0:.*]]:3 = fir.box_dims %{{.*}}, %[[T29_C0]]
+// CHECK: %[[T29_BD1:.*]]:3 = fir.box_dims %{{.*}}, %[[T29_C1]]
+// CHECK: %[[T29_SH:.*]] = fir.shape %[[T29_BD0]]#1, %[[T29_BD1]]#1 : (index, index) -> !fir.shape<2>
+// CHECK: %[[T29_AC:.*]] = fir.array_coor %{{.*}}(%[[T29_SH]]) {{\[}}%[[T29_SL]]] %{{.*}}, %[[T29_C2]] : (!fir.box<!fir.array<?x?xf32>>, !fir.shape<2>, !fir.slice<2>, index, index) -> !fir.ref<f32>
+// CHECK: return %[[T29_AC]] : !fir.ref<f32>
+// CHECK: }
+func.func @test30_rebox_rank_reducing_2d(
+ %arg0: !fir.box<!fir.array<?x?xf32>>, %si: index, %ei: index,
+ %row: index) -> !fir.ref<f32> {
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %u = fir.undefined index
+ %s = fir.slice %si, %ei, %c1, %c2, %u, %u : (index, index, index, index, index, index) -> !fir.slice<2>
+ %b = fir.rebox %arg0 [%s] : (!fir.box<!fir.array<?x?xf32>>, !fir.slice<2>) -> !fir.box<!fir.array<?xf32>>
+ %ac = fir.array_coor %b %row : (!fir.box<!fir.array<?xf32>>, index) -> !fir.ref<f32>
+ return %ac : !fir.ref<f32>
+}
>From 83f6c5a43e8b73814e73dbd1af70579b46d4f3ab Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Tue, 26 May 2026 10:00:22 -0700
Subject: [PATCH 2/2] Fixed comment.
---
flang/lib/Optimizer/Dialect/FIROps.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 767b052276f24..55cafc2b32a3c 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -698,8 +698,8 @@ struct SimplifyArrayCoorOp : public mlir::OpRewritePattern<fir::ArrayCoorOp> {
if (!boxedShape) {
// fir.rebox with a rank-reducing slice often has no shape operand.
// Synthesize fir.shape from box_dims extents. Do not synthesize
- // fir.shift: rebox descriptors already have (lb-1)*stride in base_addr
- // and FIRToMemRef must not subtract lb again for rebox memrefs.
+        // fir.shift: the descriptor produced by fir.rebox has default
+ // lower bounds.
if (!mlir::isa<fir::BaseBoxType>(boxedMemref.getType()))
return mlir::failure();
boxedShape = getShapeFromBoxDims(boxedMemref, origBoxRank, op.getLoc(),
More information about the flang-commits
mailing list