[flang-commits] [flang] [flang][FIRToMemRef] Preserve descriptor strides for boxed static-shape array_coor (PR #190859)

Susan Tan ス-ザン タン via flang-commits flang-commits at lists.llvm.org
Tue Apr 7 14:43:51 PDT 2026


https://github.com/SusanTan created https://github.com/llvm/llvm-project/pull/190859

Fix FIRToMemRef so that descriptor-backed array operands no longer take the static-shape fast path, which lowers `fir.array_coor` without a `memref.reinterpret_cast`. Boxed sections with static extents then still preserve their runtime stride semantics (e.g. `a(1:10:2)` in an ASSOCIATE construct).

From ccb7f588ca4fb1125773248cf714c0aea44607c2 Mon Sep 17 00:00:00 2001
From: Susan Tan <zujunt at nvidia.com>
Date: Tue, 7 Apr 2026 14:42:00 -0700
Subject: [PATCH] change rebox lowering

---
 .../lib/Optimizer/Transforms/FIRToMemRef.cpp  |  7 ++++-
 .../FIRToMemRef/array-coor-block-arg.mlir     | 21 ++++++++++++++
 flang/test/Transforms/FIRToMemRef/slice.mlir  | 29 +++++++++++++++----
 3 files changed, 51 insertions(+), 6 deletions(-)

diff --git a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
index 4f8c8582fb0e2..144ef9c723398 100644
--- a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
+++ b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
@@ -584,8 +584,13 @@ FIRToMemRef::convertArrayCoorOp(Operation *memOp, fir::ArrayCoorOp arrayCoorOp,
   MemRefType memRefTy = dyn_cast<MemRefType>(convertedVal.getType());
 
   bool isRebox = firMemref.getDefiningOp<fir::ReboxOp>() != nullptr;
+  bool isDescriptor = mlir::isa<fir::BaseBoxType>(firMemref.getType()) ||
+                      firMemref.getDefiningOp<fir::BoxAddrOp>() != nullptr;
 
-  if (memRefTy.hasStaticShape() && !isRebox)
+  // Static shape does not imply contiguous layout for descriptor-backed
+  // entities (e.g. boxed array sections with non-unit stride). Keep the
+  // reinterpret-cast path so descriptor strides are preserved.
+  if (memRefTy.hasStaticShape() && !isRebox && !isDescriptor)
     return std::pair{*converted, indices};
 
   unsigned rank = arrayCoorOp.getIndices().size();
diff --git a/flang/test/Transforms/FIRToMemRef/array-coor-block-arg.mlir b/flang/test/Transforms/FIRToMemRef/array-coor-block-arg.mlir
index e5b90323545f8..168fbe67b767c 100644
--- a/flang/test/Transforms/FIRToMemRef/array-coor-block-arg.mlir
+++ b/flang/test/Transforms/FIRToMemRef/array-coor-block-arg.mlir
@@ -44,3 +44,24 @@ func.func @block_arg_boxed_array(%arg0: !fir.box<!fir.array<?xi32>>) {
 // CHECK:           memref.store {{%.+}}, {{%.+}}[{{%.+}}] : memref<?xi32, strided<[?], offset: ?>>
 // CHECK-NOT:       fir.array_coor
 
+// Verify fir.array_coor lowering keeps descriptor stride semantics when the
+// box element type has static extent. A boxed section may still be strided.
+func.func @block_arg_boxed_static_array(%arg0: !fir.box<!fir.array<5xi32>>) {
+  %c1_i32 = arith.constant 1 : i32
+  %c2_i64 = arith.constant 2 : i64
+  %elt = fir.array_coor %arg0 %c2_i64 : (!fir.box<!fir.array<5xi32>>, i64) -> !fir.ref<i32>
+  fir.store %c1_i32 to %elt : !fir.ref<i32>
+  return
+}
+
+// CHECK-LABEL: func.func @block_arg_boxed_static_array
+// CHECK:         [[BOXADDR:%.+]] = fir.box_addr %arg0
+// CHECK:         [[BASE:%.+]] = fir.convert [[BOXADDR]] : (!fir.ref<!fir.array<5xi32>>) -> memref<5xi32>
+// CHECK:         [[ELE:%.+]] = fir.box_elesize
+// CHECK:         [[DIMS:%.+]]:3 = fir.box_dims
+// CHECK:         [[DIV:%.+]] = arith.divsi {{%.+}}, [[ELE]] : index
+// CHECK:         [[REINT:%.+]] = memref.reinterpret_cast [[BASE]]
+// CHECK-SAME:     : memref<5xi32> to memref<?xi32, strided<[?], offset: ?>>
+// CHECK:         memref.store {{%.+}}, [[REINT]][{{%.+}}] : memref<?xi32, strided<[?], offset: ?>>
+// CHECK-NOT:     fir.array_coor
+
diff --git a/flang/test/Transforms/FIRToMemRef/slice.mlir b/flang/test/Transforms/FIRToMemRef/slice.mlir
index 994807f591085..737babd4733a0 100644
--- a/flang/test/Transforms/FIRToMemRef/slice.mlir
+++ b/flang/test/Transforms/FIRToMemRef/slice.mlir
@@ -55,7 +55,10 @@
 // CHECK:           [[MUL2:%[0-9]+]] = arith.muli [[SUB3]], %[[C2]] : index
 // CHECK:           [[SUB4:%[0-9]+]] = arith.subi %[[C2]], %[[C1_2]] : index
 // CHECK:           [[ADD4:%[0-9]+]] = arith.addi [[MUL2]], [[SUB4]] : index
-// CHECK:           [[LOAD:%[0-9]+]] = memref.load [[CONVERT]][[[ADD4]], [[ADD3]]] : memref<7x5xi32>
+// CHECK:           %[[C1_5:.*]] = arith.constant 1 : index
+// CHECK:           %[[C0_6:.*]] = arith.constant 0 : index
+// CHECK:           [[REINTERPRET:%.*]] = memref.reinterpret_cast [[CONVERT]] to offset: [%[[C0_6]]], sizes: [%[[C7]], %[[C5]]], strides: [%[[C5]], %[[C1_5]]] : memref<7x5xi32> to memref<?x?xi32, strided<[?, ?], offset: ?>>
+// CHECK:           [[LOAD:%[0-9]+]] = memref.load [[REINTERPRET]][[[ADD4]], [[ADD3]]] : memref<?x?xi32, strided<[?, ?], offset: ?>>
 func.func @slice_2d(%arg0: !fir.ref<!fir.array<5x7xi32>>, %arg1: !fir.ref<!fir.array<5x7xi32>>) {
   %c4 = arith.constant 4 : index
   %c2 = arith.constant 2 : index
@@ -143,7 +146,11 @@ func.func @slice_2d(%arg0: !fir.ref<!fir.array<5x7xi32>>, %arg1: !fir.ref<!fir.a
 // CHECK:             [[MUL3:%[0-9]+]] = arith.muli [[SUB5]], %[[C4]] : index
 // CHECK:             [[SUB6:%[0-9]+]] = arith.subi %[[C3]], %[[C1_0]] : index
 // CHECK:             [[ADD3:%[0-9]+]] = arith.addi [[MUL3]], [[SUB6]] : index
-// CHECK:             [[LOAD:%[0-9]+]] = memref.load [[CONVERT2]][[[ADD3]], [[ADD2]], [[ADD1]]] : memref<7x7x5xi32>
+// CHECK:             %[[C1_4:.*]] = arith.constant 1 : index
+// CHECK:             %[[STRIDE0:.*]] = arith.muli %[[C7]], %[[C5]] : index
+// CHECK:             %[[C0_5:.*]] = arith.constant 0 : index
+// CHECK:             [[REINTERPRET3D:%.*]] = memref.reinterpret_cast [[CONVERT2]] to offset: [%[[C0_5]]], sizes: [%[[C7]], %[[C7]], %[[C5]]], strides: [%[[STRIDE0]], %[[C5]], %[[C1_4]]] : memref<7x7x5xi32> to memref<?x?x?xi32, strided<[?, ?, ?], offset: ?>>
+// CHECK:             [[LOAD:%[0-9]+]] = memref.load [[REINTERPRET3D]][[[ADD3]], [[ADD2]], [[ADD1]]] : memref<?x?x?xi32, strided<[?, ?, ?], offset: ?>>
 func.func @slice_3d(%arg0: !fir.ref<!fir.array<5x7x7xi32>> {fir.bindc_name = "a", llvm.nocapture}, %arg1: !fir.ref<!fir.array<5x7x7xi32>> {fir.bindc_name = "b", llvm.nocapture}) attributes {fir.internal_name = "_QPcopy"} {
   %c4 = arith.constant 4 : index
   %c2 = arith.constant 2 : index
@@ -193,7 +200,10 @@ func.func @slice_3d(%arg0: !fir.ref<!fir.array<5x7x7xi32>> {fir.bindc_name = "a"
 // CHECK:         [[MUL1:%[0-9]+]] = arith.muli [[SUB2]], %[[C1]] : index
 // CHECK:         [[SUB3:%[0-9]+]] = arith.subi %[[C1]], %[[C1_1]] : index
 // CHECK:         [[ADD2:%[0-9]+]] = arith.addi [[MUL1]], [[SUB3]] : index
-// CHECK:         [[LOAD:%[0-9]+]] = memref.load [[CONVERT]][[[ADD2]], [[SUB1]]] : memref<3x3xi32>
+// CHECK:         %[[C1_3:.*]] = arith.constant 1 : index
+// CHECK:         %[[C0_4:.*]] = arith.constant 0 : index
+// CHECK:         [[REINTERPRET2D:%.*]] = memref.reinterpret_cast [[CONVERT]] to offset: [%[[C0_4]]], sizes: [%[[C3]], %[[C3]]], strides: [%[[C3]], %[[C1_3]]] : memref<3x3xi32> to memref<?x?xi32, strided<[?, ?], offset: ?>>
+// CHECK:         [[LOAD:%[0-9]+]] = memref.load [[REINTERPRET2D]][[[ADD2]], [[SUB1]]] : memref<?x?xi32, strided<[?, ?], offset: ?>>
 func.func @extract_row(%arg0: !fir.ref<!fir.array<3x3xi32>>) {
   %c1 = arith.constant 1 : index
   %c2 = arith.constant 2 : index
@@ -241,7 +251,10 @@ func.func @extract_row(%arg0: !fir.ref<!fir.array<3x3xi32>>) {
 // CHECK:         [[MUL1:%[0-9]+]] = arith.muli [[SUB2]], %[[C11]] : index
 // CHECK:         [[SUB3:%[0-9]+]] = arith.subi %[[C1]], %[[C1_1]] : index
 // CHECK:         [[ADD2:%[0-9]+]] = arith.addi [[MUL1]], [[SUB3]] : index
-// CHECK:         memref.store %[[CST]], [[CONVERT]][[[SUB1]], [[ADD2]]] : memref<5x100xf32>
+// CHECK:         %[[C1_3:.*]] = arith.constant 1 : index
+// CHECK:         %[[C0_4:.*]] = arith.constant 0 : index
+// CHECK:         [[REINTERPRET:%.*]] = memref.reinterpret_cast [[CONVERT]] to offset: [%[[C0_4]]], sizes: [%[[C5]], %[[C100]]], strides: [%[[C100]], %[[C1_3]]] : memref<5x100xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+// CHECK:         memref.store %[[CST]], [[REINTERPRET]][[[SUB1]], [[ADD2]]] : memref<?x?xf32, strided<[?, ?], offset: ?>>
 func.func @extract_column(%arg0: !fir.ref<!fir.array<100x5xf32>> {fir.bindc_name = "tmp", llvm.nocapture}) attributes {fir.internal_name = "_QPextract_column"} {
   %c10 = arith.constant 10 : index
   %c11 = arith.constant 11 : index
@@ -289,7 +302,13 @@ func.func @extract_column(%arg0: !fir.ref<!fir.array<100x5xf32>> {fir.bindc_name
 // CHECK:         [[MUL1:%.*]] = arith.muli [[SUB1]], %[[C1_0]] : index
 // CHECK:         [[SUB2:%[0-9]+]] = arith.subi %[[C0]], %[[C0]] : index
 // CHECK:         [[ADD3:%.*]] = arith.addi [[MUL1]], [[SUB2]] : index
-// CHECK:         [[LOADVAL:%.*]] = memref.load [[CONVERT]][[[ADD3]]] : memref<7xf32>
+// CHECK:         [[ELE:%.*]] = fir.box_elesize [[EMBOX]] : (!fir.box<!fir.array<7xf32>>) -> index
+// CHECK:         %[[C0_2:.*]] = arith.constant 0 : index
+// CHECK:         [[DIMS:%.*]]:3 = fir.box_dims [[EMBOX]], %[[C0_2]] : (!fir.box<!fir.array<7xf32>>, index) -> (index, index, index)
+// CHECK:         [[DIV:%.*]] = arith.divsi [[DIMS]]#2, [[ELE]] : index
+// CHECK:         %[[C0_3:.*]] = arith.constant 0 : index
+// CHECK:         [[REINTERPRET:%.*]] = memref.reinterpret_cast [[CONVERT]] to offset: [%[[C0_3]]], sizes: [[[DIMS]]#1], strides: [[[DIV]]] : memref<7xf32> to memref<?xf32, strided<[?], offset: ?>>
+// CHECK:         [[LOADVAL:%.*]] = memref.load [[REINTERPRET]][[[ADD3]]] : memref<?xf32, strided<[?], offset: ?>>
 func.func @noslice() {
   %c7 = arith.constant 7 : index
   %c-1 = arith.constant -1 : index



More information about the flang-commits mailing list