[flang-commits] [flang] [flang] Lowering an ArrayCoorOp to arithmetic computations when a fir memref is a block argument (PR #182139)
Susan Tan ス-ザン タン via flang-commits
flang-commits at lists.llvm.org
Wed Feb 18 12:56:48 PST 2026
https://github.com/SusanTan created https://github.com/llvm/llvm-project/pull/182139
Remove the special case that handled `fir.array_coor` with a block-argument base by converting the element ref result (`!fir.ref<i32>` -> `memref<i32>`) and leaving `fir.array_coor` alive.
Instead, we now always convert the base (`!fir.ref<!fir.array<...>>` -> `memref<...>`) and compute the memref indices from the `fir.array_coor` operands, so loads/stores become `memref.load`/`memref.store` on `base[indices]`, and `fir.array_coor` can be erased when it is only used by memory ops.
From 24e947ee158ea512abb2163e71bdf275e7af4621 Mon Sep 17 00:00:00 2001
From: Susan Tan <zujunt at nvidia.com>
Date: Wed, 18 Feb 2026 12:36:15 -0800
Subject: [PATCH 1/2] change array-coor lowering
---
.../lib/Optimizer/Transforms/FIRToMemRef.cpp | 76 ++++++++++---------
.../FIRToMemRef/array-coor-block-arg.mlir | 28 +++++++
.../Transforms/FIRToMemRef/no-declare.mlir | 13 ++--
3 files changed, 74 insertions(+), 43 deletions(-)
create mode 100644 flang/test/Transforms/FIRToMemRef/array-coor-block-arg.mlir
diff --git a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
index bf125eb8d04ef..72b50720f918b 100644
--- a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
+++ b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
@@ -459,52 +459,54 @@ FIRToMemRef::convertArrayCoorOp(Operation *memOp, fir::ArrayCoorOp arrayCoorOp,
if (typeConverter.isEmptyArray(firMemref.getType()))
return failure();
- if (auto blockArg = dyn_cast<BlockArgument>(firMemref)) {
- Value elemRef = arrayCoorOp.getResult();
- rewriter.setInsertionPointAfter(arrayCoorOp);
- Location loc = arrayCoorOp->getLoc();
- Type elemMemrefTy = typeConverter.convertMemrefType(elemRef.getType());
- Value converted =
- fir::ConvertOp::create(rewriter, loc, elemMemrefTy, elemRef);
- SmallVector<Value> indices;
- return std::pair{converted, indices};
- }
-
- Operation *memref = firMemref.getDefiningOp();
+ Location loc = arrayCoorOp->getLoc();
+ // Prefer lowering the array-coordinates computation to a memref + indices.
+ // This allows erasing fir.array_coor when it is only used by load/store even
+ // if the base address is a block argument (e.g. region arguments).
+ Operation *memref = nullptr;
FailureOr<Value> converted;
- if (enableFIRConvertOptimizations && isMarshalLike(memref) &&
- !fir::isa_fir_type(firMemref.getType())) {
- converted = firMemref;
+ if (auto blockArg = dyn_cast<BlockArgument>(firMemref)) {
rewriter.setInsertionPoint(arrayCoorOp);
+ Type memrefTy = typeConverter.convertMemrefType(blockArg.getType());
+ converted =
+ fir::ConvertOp::create(rewriter, loc, memrefTy, blockArg).getResult();
+ rewriter.setInsertionPointAfter(arrayCoorOp);
} else {
- Operation *arrayCoorOperation = arrayCoorOp.getOperation();
- rewriter.setInsertionPoint(arrayCoorOp);
- if (memrefIsOptional(memref)) {
- auto ifOp = arrayCoorOperation->getParentOfType<scf::IfOp>();
- if (ifOp) {
- Operation *condition = ifOp.getCondition().getDefiningOp();
- if (condition && isa<fir::IsPresentOp>(condition))
- if (condition->getOperand(0) == firMemref) {
- if (arrayCoorOperation->getParentRegion() == &ifOp.getThenRegion())
- rewriter.setInsertionPointToStart(
- &(ifOp.getThenRegion().front()));
- else if (arrayCoorOperation->getParentRegion() ==
- &ifOp.getElseRegion())
- rewriter.setInsertionPointToStart(
- &(ifOp.getElseRegion().front()));
- }
+ memref = firMemref.getDefiningOp();
+ if (enableFIRConvertOptimizations && isMarshalLike(memref) &&
+ !fir::isa_fir_type(firMemref.getType())) {
+ converted = firMemref;
+ rewriter.setInsertionPoint(arrayCoorOp);
+ } else {
+ Operation *arrayCoorOperation = arrayCoorOp.getOperation();
+ rewriter.setInsertionPoint(arrayCoorOp);
+ if (memrefIsOptional(memref)) {
+ auto ifOp = arrayCoorOperation->getParentOfType<scf::IfOp>();
+ if (ifOp) {
+ Operation *condition = ifOp.getCondition().getDefiningOp();
+ if (condition && isa<fir::IsPresentOp>(condition))
+ if (condition->getOperand(0) == firMemref) {
+ if (arrayCoorOperation->getParentRegion() ==
+ &ifOp.getThenRegion())
+ rewriter.setInsertionPointToStart(
+ &(ifOp.getThenRegion().front()));
+ else if (arrayCoorOperation->getParentRegion() ==
+ &ifOp.getElseRegion())
+ rewriter.setInsertionPointToStart(
+ &(ifOp.getElseRegion().front()));
+ }
+ }
}
- }
- converted = getFIRConvert(memOp, memref, rewriter, typeConverter);
- if (failed(converted))
- return failure();
+ converted = getFIRConvert(memOp, memref, rewriter, typeConverter);
+ if (failed(converted))
+ return failure();
- rewriter.setInsertionPointAfter(arrayCoorOp);
+ rewriter.setInsertionPointAfter(arrayCoorOp);
+ }
}
- Location loc = arrayCoorOp->getLoc();
Value one = arith::ConstantIndexOp::create(rewriter, loc, 1);
FailureOr<SmallVector<Value>> failureOrIndices =
getMemrefIndices(arrayCoorOp, memref, rewriter, *converted, one);
diff --git a/flang/test/Transforms/FIRToMemRef/array-coor-block-arg.mlir b/flang/test/Transforms/FIRToMemRef/array-coor-block-arg.mlir
new file mode 100644
index 0000000000000..7e0acbae78af4
--- /dev/null
+++ b/flang/test/Transforms/FIRToMemRef/array-coor-block-arg.mlir
@@ -0,0 +1,28 @@
+// Verify fir.array_coor lowering when the base is a block argument.
+// This used to take a shortcut (convert the element ref result) which kept the
+// fir.array_coor alive. We prefer converting the base to a memref and
+// computing indices so that fir.array_coor can be erased when only used by
+// load/store.
+//
+// RUN: fir-opt %s --fir-to-memref --allow-unregistered-dialect | FileCheck %s
+
+func.func @block_arg_memref(%arg0: !fir.ref<!fir.array<32xi32>>) {
+ %c32 = arith.constant 32 : index
+ %c1 = arith.constant 1 : index
+ %c1_i32 = arith.constant 1 : i32
+ %shape = fir.shape %c32 : (index) -> !fir.shape<1>
+ %elt = fir.array_coor %arg0(%shape) %c1 : (!fir.ref<!fir.array<32xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
+ fir.store %c1_i32 to %elt : !fir.ref<i32>
+ return
+}
+
+// CHECK-LABEL: func.func @block_arg_memref
+// CHECK: [[BASE:%.+]] = fir.convert %arg0 : (!fir.ref<!fir.array<32xi32>>) -> memref<32xi32>
+// CHECK: [[ONE:%.+]] = arith.constant 1 : index
+// CHECK: [[SUB:%.+]] = arith.subi %c1, [[ONE]] : index
+// CHECK: [[MUL:%.+]] = arith.muli [[SUB]], [[ONE]] : index
+// CHECK: [[SUB2:%.+]] = arith.subi [[ONE]], [[ONE]] : index
+// CHECK: [[IDX:%.+]] = arith.addi [[MUL]], [[SUB2]] : index
+// CHECK: memref.store {{%.+}}, [[BASE]][[[IDX]]] : memref<32xi32>
+// CHECK-NOT: fir.array_coor
+
diff --git a/flang/test/Transforms/FIRToMemRef/no-declare.mlir b/flang/test/Transforms/FIRToMemRef/no-declare.mlir
index 664da0a0b38a2..3971e2552ce56 100644
--- a/flang/test/Transforms/FIRToMemRef/no-declare.mlir
+++ b/flang/test/Transforms/FIRToMemRef/no-declare.mlir
@@ -5,15 +5,15 @@
// CHECK-LABEL: func.func @nodeclare
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[SHAPE:.*]] = fir.shape %[[C1]] : (index) -> !fir.shape<1>
-// CHECK: %[[COOR:.*]] = fir.array_coor %arg0(%[[SHAPE]]) %[[C1]] : (!fir.ref<!fir.array<1xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
-// CHECK: %[[C0:.*]] = fir.convert %[[COOR]] : (!fir.ref<i32>) -> memref<i32>
-// CHECK: %[[C1M:.*]] = fir.convert %[[COOR]] : (!fir.ref<i32>) -> memref<i32>
-// CHECK: %[[L0:.*]] = memref.load %[[C1M]][] : memref<i32>
+// CHECK: %[[M0:.*]] = fir.convert %arg0 : (!fir.ref<!fir.array<1xi32>>) -> memref<1xi32>
+// CHECK: %[[M1:.*]] = fir.convert %arg0 : (!fir.ref<!fir.array<1xi32>>) -> memref<1xi32>
+// CHECK: %[[L0:.*]] = memref.load %[[M0]][%{{.*}}] : memref<1xi32>
// CHECK: %[[CARG1:.*]] = fir.convert %arg1 : (!fir.ref<i32>) -> memref<i32>
// CHECK: memref.store %[[L0]], %[[CARG1]][] : memref<i32>
-// CHECK: %[[L1:.*]] = memref.load %[[C0]][] : memref<i32>
+// CHECK: %[[L1:.*]] = memref.load %[[M1]][%{{.*}}] : memref<1xi32>
// CHECK: %[[CARG2:.*]] = fir.convert %arg2 : (!fir.ref<i32>) -> memref<i32>
// CHECK: memref.store %[[L1]], %[[CARG2]][] : memref<i32>
+// CHECK-NOT: fir.array_coor
func.func @nodeclare(%arg0: !fir.ref<!fir.array<1xi32>> {fir.bindc_name = "a"}, %arg1: !fir.ref<i32> {fir.bindc_name = "b"}, %arg2: !fir.ref<i32> {fir.bindc_name = "c"}) attributes {fir.internal_name = ""} {
%c1 = arith.constant 1 : index
@@ -27,8 +27,9 @@ func.func @nodeclare(%arg0: !fir.ref<!fir.array<1xi32>> {fir.bindc_name = "a"},
}
// CHECK-LABEL: func.func @nodeclare_regions
-// CHECK-COUNT-4: fir.convert %{{.*}} : (!fir.ref<i32>) -> memref<i32>
+// CHECK-COUNT-4: fir.convert %{{.*}} : (!fir.ref<!fir.array<6xi32>>) -> memref<6xi32>
// CHECK-COUNT-1: fir.convert %{{.*}} : (i32) -> f32
+// CHECK-NOT: fir.array_coor
func.func @nodeclare_regions(%arg0: !fir.ref<!fir.array<10xf32>> {fir.bindc_name = "h11"}, %arg1: !fir.ref<!fir.array<6xi32>> {fir.bindc_name = "rslt"}) attributes {fir.internal_name = "_QPsub11"} {
%cst = arith.constant 1.100000e+01 : f32
From a78d304c9330b1ad34a99b55e7c19989a31cb85a Mon Sep 17 00:00:00 2001
From: Susan Tan <zujunt at nvidia.com>
Date: Wed, 18 Feb 2026 12:53:03 -0800
Subject: [PATCH 2/2] refactor
---
.../lib/Optimizer/Transforms/FIRToMemRef.cpp | 55 +++++++++----------
1 file changed, 26 insertions(+), 29 deletions(-)
diff --git a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
index 72b50720f918b..e8c30e62430c0 100644
--- a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
+++ b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
@@ -472,39 +472,36 @@ FIRToMemRef::convertArrayCoorOp(Operation *memOp, fir::ArrayCoorOp arrayCoorOp,
converted =
fir::ConvertOp::create(rewriter, loc, memrefTy, blockArg).getResult();
rewriter.setInsertionPointAfter(arrayCoorOp);
+ } else if ((memref = firMemref.getDefiningOp()) &&
+ enableFIRConvertOptimizations && isMarshalLike(memref) &&
+ !fir::isa_fir_type(firMemref.getType())) {
+ converted = firMemref;
+ rewriter.setInsertionPoint(arrayCoorOp);
} else {
- memref = firMemref.getDefiningOp();
- if (enableFIRConvertOptimizations && isMarshalLike(memref) &&
- !fir::isa_fir_type(firMemref.getType())) {
- converted = firMemref;
- rewriter.setInsertionPoint(arrayCoorOp);
- } else {
- Operation *arrayCoorOperation = arrayCoorOp.getOperation();
- rewriter.setInsertionPoint(arrayCoorOp);
- if (memrefIsOptional(memref)) {
- auto ifOp = arrayCoorOperation->getParentOfType<scf::IfOp>();
- if (ifOp) {
- Operation *condition = ifOp.getCondition().getDefiningOp();
- if (condition && isa<fir::IsPresentOp>(condition))
- if (condition->getOperand(0) == firMemref) {
- if (arrayCoorOperation->getParentRegion() ==
- &ifOp.getThenRegion())
- rewriter.setInsertionPointToStart(
- &(ifOp.getThenRegion().front()));
- else if (arrayCoorOperation->getParentRegion() ==
- &ifOp.getElseRegion())
- rewriter.setInsertionPointToStart(
- &(ifOp.getElseRegion().front()));
- }
- }
+ Operation *arrayCoorOperation = arrayCoorOp.getOperation();
+ rewriter.setInsertionPoint(arrayCoorOp);
+ if (memrefIsOptional(memref)) {
+ auto ifOp = arrayCoorOperation->getParentOfType<scf::IfOp>();
+ if (ifOp) {
+ Operation *condition = ifOp.getCondition().getDefiningOp();
+ if (condition && isa<fir::IsPresentOp>(condition))
+ if (condition->getOperand(0) == firMemref) {
+ if (arrayCoorOperation->getParentRegion() == &ifOp.getThenRegion())
+ rewriter.setInsertionPointToStart(
+ &(ifOp.getThenRegion().front()));
+ else if (arrayCoorOperation->getParentRegion() ==
+ &ifOp.getElseRegion())
+ rewriter.setInsertionPointToStart(
+ &(ifOp.getElseRegion().front()));
+ }
}
+ }
- converted = getFIRConvert(memOp, memref, rewriter, typeConverter);
- if (failed(converted))
- return failure();
+ converted = getFIRConvert(memOp, memref, rewriter, typeConverter);
+ if (failed(converted))
+ return failure();
- rewriter.setInsertionPointAfter(arrayCoorOp);
- }
+ rewriter.setInsertionPointAfter(arrayCoorOp);
}
Value one = arith::ConstantIndexOp::create(rewriter, loc, 1);
More information about the flang-commits
mailing list