[flang-commits] [flang] [FIR] Route embox + projected complex slice through shapeVec (PR #205042)
via flang-commits
flang-commits at lists.llvm.org
Mon Jun 22 00:10:13 PDT 2026
llvmorg-github-actions[bot] wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-fir-hlfir
Author: Kareem Ergawy (ergawy)
<details>
<summary>Changes</summary>
When the array_coor base is a fir.embox with a projected complex %re/%im slice, take the shapeVec path instead of the descriptor (fir.box_dims) path. The descriptor path iterates over source-rank dims while querying the rank-reduced embox result box, which miscompiles slices that collapse dims (e.g. complex(:,k)%re). For embox-derived boxes the underlying storage is contiguous, so the shape-derived layout is both correct and the natural place to encode the fact that a static shape is available. Non-embox boxes (rebox, assumed-shape) still go through fir.box_dims.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
---
Full diff: https://github.com/llvm/llvm-project/pull/205042.diff
2 Files Affected:
- (modified) flang/lib/Optimizer/Transforms/FIRToMemRef.cpp (+2-13)
- (modified) flang/test/Transforms/FIRToMemRef/slice-projected.mlir (+19-51)
``````````diff
diff --git a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
index 3f738a25ec98b..1dcc056387c48 100644
--- a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
+++ b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
@@ -773,20 +773,9 @@ FIRToMemRef::convertArrayCoorOp(Operation *memOp, fir::ArrayCoorOp arrayCoorOp,
//
// box_dims path: query the descriptor at runtime. Required when:
// (a) we have no shape information at all; or
- // (b) the array_coor base is a fir.box that is NOT a fir.embox result;
- // or a fir.box with a projected slice (layout in the descriptor); or
- // (c) embox cannot supply layout for this coor (non-embox box above).
- // getFIRConvert materializes fir.box_addr(box) -- an opaque pointer
- // with no layout in its type -- so strides must come from the
- // descriptor. This matches CodeGen XArrayCoorOp's boxed branch
- // (getStrideFromBox); shape/shape_shift on array_coor is
- // informational only (lower bounds for index translation).
- // Projected complex %re/%im on a bare ref uses the shapeVec path with
- // strides scaled by two scalar slots per complex.
- const bool boxNeedsDescriptorStrides =
- firMemrefIsBox && (!firMemrefIsEmbox || sliceInfo.hasProjectedSlice);
+ // (b) the array_coor base is a fir.box that is NOT a fir.embox result.
const bool descriptorOwnsLayout =
- shapeVec.empty() || boxNeedsDescriptorStrides;
+ shapeVec.empty() || (firMemrefIsBox && !firMemrefIsEmbox);
if (descriptorOwnsLayout) {
// Plain `!fir.ref` without recoverable shape extents cannot use fir.box_*.
if (shapeVec.empty() && !sliceInfo.hasProjectedSlice && !isDescriptor &&
diff --git a/flang/test/Transforms/FIRToMemRef/slice-projected.mlir b/flang/test/Transforms/FIRToMemRef/slice-projected.mlir
index 7d29fca000fad..0a5cb672333ed 100644
--- a/flang/test/Transforms/FIRToMemRef/slice-projected.mlir
+++ b/flang/test/Transforms/FIRToMemRef/slice-projected.mlir
@@ -29,22 +29,12 @@
// CHECK: [[MEMREF:%.*]] = fir.convert %arg0 : (!fir.ref<!fir.array<4xcomplex<f32>>>) -> memref<4xcomplex<f32>>
// CHECK: [[IDX:%.*]] = arith.addi
// CHECK: [[COMP:%[0-9]+]] = fir.convert [[MEMREF]] : (memref<4xcomplex<f32>>) -> memref<4x2xf32>
-// CHECK: %[[FWD_C_RE:.*]] = arith.constant 0 : index
-// CHECK: %[[FWD_C_SZF32:.*]] = arith.constant 4 : index
-// CHECK: %[[FWD_C_DIM0:.*]] = arith.constant 0 : index
-// CHECK: [[BD:%[0-9]+]]:3 = fir.box_dims %2, %[[FWD_C_DIM0]] : (!fir.box<!fir.array<4xf32>>, index) -> (index, index, index)
-// CHECK: [[STRIDE:%[0-9]+]] = arith.divsi [[BD]]#2, %[[FWD_C_SZF32]] : index
-// Reinterpret applies the embox descriptor layout onto the scalar view:
-// sizes[0] = box extent (section length in f32 slots)
-// sizes[1] = 2 for the (re, im) pair exposed by memref<4x2xf32>
-// strides[0] = box_dims byte_stride / sizeof(f32) (not box_elesize)
-// strides[1] = 1 between adjacent real/imag scalars
-// Without this, memref.load would use dense strides from fir.convert only.
-// CHECK: %[[FWD_C_PAIR:.*]] = arith.constant 2 : index
-// CHECK: %[[FWD_C_COMP_STRIDE:.*]] = arith.constant 1 : index
-// CHECK: %[[FWD_C_OFF:.*]] = arith.constant 0 : index
-// CHECK: [[VIEW:%.*]] = memref.reinterpret_cast [[COMP]] to offset: [%[[FWD_C_OFF]]], sizes: [[[BD]]#1, %[[FWD_C_PAIR]]], strides: [[[STRIDE]], %[[FWD_C_COMP_STRIDE]]] : memref<4x2xf32> to memref<?x?xf32, strided<
-// CHECK: [[LOAD:%[0-9]+]] = memref.load [[VIEW]][[[IDX]], %[[FWD_C_RE]]] : memref<?x?xf32, strided<
+// Reinterpret applies the shape-derived layout onto the scalar view:
+// sizes = [shape extent, 2 (re/im pair)]
+// strides = [2 (one complex == two scalar slots), 1]
+// CHECK-NOT: fir.box_dims
+// CHECK: [[VIEW:%.*]] = memref.reinterpret_cast [[COMP]] to offset: [%c0{{.*}}], sizes: [%c4{{.*}}, %c2{{.*}}], strides: [%c2{{.*}}, %c1{{.*}}] : memref<4x2xf32> to memref<?x?xf32, strided<
+// CHECK: [[LOAD:%[0-9]+]] = memref.load [[VIEW]][[[IDX]], %c0{{.*}}] : memref<?x?xf32, strided<
func.func @projected_slice_fwd(%arg0: !fir.ref<!fir.array<4xcomplex<f32>>>) {
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
@@ -68,17 +58,10 @@ func.func @projected_slice_fwd(%arg0: !fir.ref<!fir.array<4xcomplex<f32>>>) {
// CHECK: [[MEMREF:%.*]] = fir.convert %arg0 : (!fir.ref<!fir.array<4xcomplex<f32>>>) -> memref<4xcomplex<f32>>
// CHECK: [[IDX:%.*]] = arith.addi
// CHECK: [[COMP:%[0-9]+]] = fir.convert [[MEMREF]] : (memref<4xcomplex<f32>>) -> memref<4x2xf32>
-// CHECK: %[[BWD_C_RE:.*]] = arith.constant 0 : index
-// CHECK: %[[BWD_C_SZF32:.*]] = arith.constant 4 : index
-// CHECK: %[[BWD_C_DIM0:.*]] = arith.constant 0 : index
-// CHECK: [[BD:%[0-9]+]]:3 = fir.box_dims %2, %[[BWD_C_DIM0]] : (!fir.box<!fir.array<4xf32>>, index) -> (index, index, index)
-// CHECK: [[STRIDE:%[0-9]+]] = arith.divsi [[BD]]#2, %[[BWD_C_SZF32]] : index
// Same reinterpret as forward; slice triple only changes [[IDX]], not strides.
-// CHECK: %[[BWD_C_PAIR:.*]] = arith.constant 2 : index
-// CHECK: %[[BWD_C_COMP_STRIDE:.*]] = arith.constant 1 : index
-// CHECK: %[[BWD_C_OFF:.*]] = arith.constant 0 : index
-// CHECK: [[VIEW:%.*]] = memref.reinterpret_cast [[COMP]] to offset: [%[[BWD_C_OFF]]], sizes: [[[BD]]#1, %[[BWD_C_PAIR]]], strides: [[[STRIDE]], %[[BWD_C_COMP_STRIDE]]] : memref<4x2xf32> to memref<?x?xf32, strided<
-// CHECK: [[LOAD:%[0-9]+]] = memref.load [[VIEW]][[[IDX]], %[[BWD_C_RE]]] : memref<?x?xf32, strided<
+// CHECK-NOT: fir.box_dims
+// CHECK: [[VIEW:%.*]] = memref.reinterpret_cast [[COMP]] to offset: [%c0{{.*}}], sizes: [%c4{{.*}}, %c2{{.*}}], strides: [%c2{{.*}}, %c1{{.*}}] : memref<4x2xf32> to memref<?x?xf32, strided<
+// CHECK: [[LOAD:%[0-9]+]] = memref.load [[VIEW]][[[IDX]], %c0{{.*}}] : memref<?x?xf32, strided<
func.func @projected_slice_bwd(%arg0: !fir.ref<!fir.array<4xcomplex<f32>>>) {
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
@@ -103,17 +86,10 @@ func.func @projected_slice_bwd(%arg0: !fir.ref<!fir.array<4xcomplex<f32>>>) {
// CHECK: [[MEMREF:%.*]] = fir.convert %arg0 : (!fir.ref<!fir.array<4xcomplex<f32>>>) -> memref<4xcomplex<f32>>
// CHECK: [[IDX:%.*]] = arith.addi
// CHECK: [[COMP:%[0-9]+]] = fir.convert [[MEMREF]] : (memref<4xcomplex<f32>>) -> memref<4x2xf32>
-// CHECK: %[[IM_C_IM:.*]] = arith.constant 1 : index
-// CHECK: %[[IM_C_SZF32:.*]] = arith.constant 4 : index
-// CHECK: %[[IM_C_DIM0:.*]] = arith.constant 0 : index
-// CHECK: [[BD:%[0-9]+]]:3 = fir.box_dims %2, %[[IM_C_DIM0]] : (!fir.box<!fir.array<4xf32>>, index) -> (index, index, index)
-// CHECK: [[STRIDE:%[0-9]+]] = arith.divsi [[BD]]#2, %[[IM_C_SZF32]] : index
// Same layout as %re; store uses component index 1 for imaginary.
-// CHECK: %[[IM_C_PAIR:.*]] = arith.constant 2 : index
-// CHECK: %[[IM_C_COMP_STRIDE:.*]] = arith.constant 1 : index
-// CHECK: %[[IM_C_OFF:.*]] = arith.constant 0 : index
-// CHECK: [[VIEW:%.*]] = memref.reinterpret_cast [[COMP]] to offset: [%[[IM_C_OFF]]], sizes: [[[BD]]#1, %[[IM_C_PAIR]]], strides: [[[STRIDE]], %[[IM_C_COMP_STRIDE]]] : memref<4x2xf32> to memref<?x?xf32, strided<
-// CHECK: memref.store %arg1, [[VIEW]][[[IDX]], %[[IM_C_IM]]] : memref<?x?xf32, strided<
+// CHECK-NOT: fir.box_dims
+// CHECK: [[VIEW:%.*]] = memref.reinterpret_cast [[COMP]] to offset: [%c0{{.*}}], sizes: [%c4{{.*}}, %c2{{.*}}], strides: [%c2{{.*}}, %c1{{.*}}] : memref<4x2xf32> to memref<?x?xf32, strided<
+// CHECK: memref.store %arg1, [[VIEW]][[[IDX]], %c1{{.*}}] : memref<?x?xf32, strided<
func.func @projected_slice_store_im(%arg0: !fir.ref<!fir.array<4xcomplex<f32>>>,
%arg1: f32) {
%c1 = arith.constant 1 : index
@@ -152,21 +128,13 @@ func.func @projected_slice_store_im(%arg0: !fir.ref<!fir.array<4xcomplex<f32>>>,
// CHECK: [[IDX_I:%.*]] = arith.addi
// CHECK: [[IDX_J:%.*]] = arith.addi
// CHECK: [[COMP:%[0-9]+]] = fir.convert [[MEMREF]] : (memref<3x2xcomplex<f32>>) -> memref<3x2x2xf32>
-// CHECK: %[[D2_C_RE:.*]] = arith.constant 0 : index
-// CHECK: %[[D2_C_SZF32:.*]] = arith.constant 4 : index
-// CHECK: %[[D2_C_DIM1:.*]] = arith.constant 1 : index
-// CHECK: [[BD0:%[0-9]+]]:3 = fir.box_dims %2, %[[D2_C_DIM1]] : (!fir.box<!fir.array<2x3xf32>>, index) -> (index, index, index)
-// CHECK: [[STR0:%[0-9]+]] = arith.divsi [[BD0]]#2, %[[D2_C_SZF32]] : index
-// CHECK: %[[D2_C_DIM0:.*]] = arith.constant 0 : index
-// CHECK: [[BD1:%[0-9]+]]:3 = fir.box_dims %2, %[[D2_C_DIM0]] : (!fir.box<!fir.array<2x3xf32>>, index) -> (index, index, index)
-// CHECK: [[STR1:%[0-9]+]] = arith.divsi [[BD1]]#2, %[[D2_C_SZF32]] : index
-// 2-D embox: two box_dims strides (both / sizeof(f32)), plus pair dim (2, 1).
-// Row-major memref indices are [j, i, 0] after Fortran dim reversal.
-// CHECK: %[[D2_C_PAIR:.*]] = arith.constant 2 : index
-// CHECK: %[[D2_C_COMP_STRIDE:.*]] = arith.constant 1 : index
-// CHECK: %[[D2_C_OFF:.*]] = arith.constant 0 : index
-// CHECK: [[VIEW:%.*]] = memref.reinterpret_cast [[COMP]] to offset: [%[[D2_C_OFF]]], sizes: [[[BD0]]#1, [[BD1]]#1, %[[D2_C_PAIR]]], strides: [[[STR0]], [[STR1]], %[[D2_C_COMP_STRIDE]]] : memref<3x2x2xf32> to memref<?x?x?xf32, strided<
-// CHECK: [[LOAD:%[0-9]+]] = memref.load [[VIEW]][[[IDX_J]], [[IDX_I]], %[[D2_C_RE]]] : memref<?x?x?xf32, strided<
+// 2-D shapeVec path: outer stride = inner_extent * 2 (pair slots), inner
+// stride = 2, pair stride = 1. Row-major memref indices are [j, i, 0] after
+// Fortran dim reversal.
+// CHECK-NOT: fir.box_dims
+// CHECK: [[STR0:%.*]] = arith.muli %c2{{.*}}, %c2{{.*}} : index
+// CHECK: [[VIEW:%.*]] = memref.reinterpret_cast [[COMP]] to offset: [%c0{{.*}}], sizes: [%c3{{.*}}, %c2{{.*}}, %c2{{.*}}], strides: [[[STR0]], %c2{{.*}}, %c1{{.*}}] : memref<3x2x2xf32> to memref<?x?x?xf32, strided<
+// CHECK: [[LOAD:%[0-9]+]] = memref.load [[VIEW]][[[IDX_J]], [[IDX_I]], %c0{{.*}}] : memref<?x?x?xf32, strided<
func.func @projected_slice_2d(%arg0: !fir.ref<!fir.array<2x3xcomplex<f32>>>) {
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
``````````
</details>
https://github.com/llvm/llvm-project/pull/205042
More information about the flang-commits
mailing list