[flang-commits] [flang] [flang][FIRToMemRef] Use complex dialect for lowering projection (z%re/z%im) (PR #196123)

Susan Tan ス-ザン タン via flang-commits flang-commits at lists.llvm.org
Wed May 6 10:19:00 PDT 2026


https://github.com/SusanTan created https://github.com/llvm/llvm-project/pull/196123

Use the complex dialect to lower accesses to the real or imaginary component of a complex number (the component selector appears as the `path` operand in various ops). Load sites extract the component via complex.re/complex.im; store sites do a read-modify-write using complex.create. SliceInfo is extended with projectionIsImaginary to avoid re-traversing the IR at load/store sites.

>From a25881489a6a90f3de167f72fb54fae8d334aa7e Mon Sep 17 00:00:00 2001
From: Susan Tan <zujunt at nvidia.com>
Date: Wed, 6 May 2026 10:13:56 -0700
Subject: [PATCH] add complex implementation

---
 .../flang/Optimizer/Transforms/Passes.td      |   3 +-
 .../lib/Optimizer/Transforms/FIRToMemRef.cpp  | 127 +++++++++++----
 .../FIRToMemRef/slice-projected.mlir          | 150 +++++++++++++-----
 3 files changed, 204 insertions(+), 76 deletions(-)

diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index e107672adf907..93813dcedb045 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -221,7 +221,8 @@ def FIRToMemRef : Pass<"fir-to-memref", "::mlir::func::FuncOp"> {
     Lower FIR memory operations (`fir.alloca`, `fir.load`, `fir.store`, 'fir.array_coor', and etc.) to MLIR's MemRef core dialect.
   }];
   let dependentDialects = ["fir::FIROpsDialect", "mlir::arith::ArithDialect",
-                           "mlir::memref::MemRefDialect"];
+                           "mlir::memref::MemRefDialect",
+                           "mlir::complex::ComplexDialect"];
 }
 
 // This needs to be a "mlir::ModuleOp" pass, because we are creating debug for
diff --git a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
index ec58d6f3f1447..316641813c37a 100644
--- a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
+++ b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
@@ -44,6 +44,7 @@
 #include "flang/Optimizer/Transforms/FIRToMemRefTypeConverter.h"
 #include "flang/Optimizer/Transforms/Passes.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/OpenACC/OpenACC.h"
@@ -140,6 +141,8 @@ class FIRToMemRef : public fir::impl::FIRToMemRefBase<FIRToMemRef> {
 
   bool memrefIsOptional(Operation *) const;
 
+  std::optional<bool> complexProjectionOf(Value firMemref) const;
+
   Value canonicalizeIndex(Value, PatternRewriter &) const;
 
   // Logical section information used by FIRToMemRef. For projected slices, the
@@ -150,6 +153,7 @@ class FIRToMemRef : public fir::impl::FIRToMemRefBase<FIRToMemRef> {
     SmallVector<Value> shiftVec;
     SmallVector<Value> sliceVec;
     bool hasProjectedSlice = false;
+    std::optional<bool> projectionIsImaginary; // set when hasProjectedSlice
   };
 
   template <typename OpTy>
@@ -171,6 +175,17 @@ class FIRToMemRef : public fir::impl::FIRToMemRefBase<FIRToMemRef> {
     return sliceOp && !sliceOp.getFields().empty();
   }
 
+  // True for imaginary, false for real; nullopt if no constant integer path.
+  static std::optional<bool> sliceProjectionIsImaginary(fir::SliceOp sliceOp) {
+    auto fields = sliceOp.getFields();
+    if (fields.empty())
+      return std::nullopt;
+    if (auto cst = fields[0].getDefiningOp<arith::ConstantOp>())
+      if (auto attr = mlir::dyn_cast<IntegerAttr>(cst.getValueAttr()))
+        return attr.getInt() != 0;
+    return std::nullopt;
+  }
+
   unsigned getRankFromEmbox(fir::EmboxOp embox) const {
     auto memrefType = embox.getMemref().getType();
     Type unwrappedType = fir::unwrapRefType(memrefType);
@@ -313,15 +328,12 @@ void FIRToMemRef::collectSliceInfoFrom(OpTy op, SliceInfo &info) const {
     }
 
     if (auto sliceOp = getSliceOp(op.getSlice())) {
-      // A slice path changes the physical projection of the boxed entity (for
-      // example, `complex -> real` for `%re`). Preserve shape/shift for logical
-      // indexing, but do not treat the triplets alone as layout information.
       if (hasProjectedSlice(sliceOp)) {
         info.hasProjectedSlice = true;
-      } else {
-        auto triples = sliceOp.getTriples();
-        info.sliceVec.append(triples.begin(), triples.end());
+        info.projectionIsImaginary = sliceProjectionIsImaginary(sliceOp);
       }
+      auto triples = sliceOp.getTriples();
+      info.sliceVec.append(triples.begin(), triples.end());
     }
   }
 }
@@ -473,9 +485,6 @@ FIRToMemRef::getMemrefIndices(fir::ArrayCoorOp arrayCoorOp, Operation *memref,
     rank = getRankFromEmbox(embox);
   }
 
-  // Projected boxed slices leave `sliceVec` empty on purpose: indices are
-  // computed in the logical section coordinate space, while stride/base come
-  // later from the box descriptor.
   SmallVector<Value> &shiftVec = sliceInfo.shiftVec;
   SmallVector<Value> &sliceVec = sliceInfo.sliceVec;
   SmallVector<Value> sliceLbs, sliceStrides;
@@ -626,6 +635,20 @@ FIRToMemRef::convertArrayCoorOp(Operation *memOp, fir::ArrayCoorOp arrayCoorOp,
   bool isDescriptor = mlir::isa<fir::BaseBoxType>(firMemref.getType()) ||
                       firMemref.getDefiningOp<fir::BoxAddrOp>() != nullptr;
 
+  // For complex projections, getFIRConvert uses embox.getMemref() directly,
+  // converting the underlying array to memref<N×complex<T>> and ignoring the
+  // projected box descriptor.  The indices from getMemrefIndices are correct
+  // as-is; no descriptor rebuild is needed.
+  SliceInfo sliceInfo;
+  collectSliceInfoFrom(arrayCoorOp, sliceInfo);
+  if (auto embox = firMemref.getDefiningOp<fir::EmboxOp>())
+    collectSliceInfoFrom(embox, sliceInfo);
+  else if (auto rebox = firMemref.getDefiningOp<fir::ReboxOp>())
+    collectSliceInfoFrom(rebox, sliceInfo);
+  if (sliceInfo.hasProjectedSlice &&
+      isa<mlir::ComplexType>(memRefTy.getElementType()))
+    return std::pair{*converted, indices};
+
   // Static shape does not imply contiguous layout for descriptor-backed
   // entities (e.g. boxed array sections with non-unit stride). Keep the
   // reinterpret-cast path so descriptor strides are preserved.
@@ -633,7 +656,6 @@ FIRToMemRef::convertArrayCoorOp(Operation *memOp, fir::ArrayCoorOp arrayCoorOp,
     return std::pair{*converted, indices};
 
   unsigned rank = arrayCoorOp.getIndices().size();
-
   if (auto embox = firMemref.getDefiningOp<fir::EmboxOp>())
     rank = getRankFromEmbox(embox);
 
@@ -642,29 +664,17 @@ FIRToMemRef::convertArrayCoorOp(Operation *memOp, fir::ArrayCoorOp arrayCoorOp,
   SmallVector<Value> strides;
   strides.reserve(rank);
 
-  SliceInfo sliceInfo;
-  collectSliceInfoFrom(arrayCoorOp, sliceInfo);
-
-  Value box = firMemref;
-  if (!isa<BlockArgument>(firMemref)) {
-    if (auto embox = firMemref.getDefiningOp<fir::EmboxOp>()) {
-      collectSliceInfoFrom(embox, sliceInfo);
-    } else if (auto rebox = firMemref.getDefiningOp<fir::ReboxOp>()) {
-      collectSliceInfoFrom(rebox, sliceInfo);
-    }
-  }
-
   SmallVector<Value> &shapeVec = sliceInfo.shapeVec;
   if (sliceInfo.hasProjectedSlice || shapeVec.empty()) {
     // Projected slices carry their physical layout in the descriptor. Rebuild
     // the MemRef view from box metadata instead of from slice triplets.
     auto boxElementSize =
-        fir::BoxEleSizeOp::create(rewriter, loc, indexTy, box);
+        fir::BoxEleSizeOp::create(rewriter, loc, indexTy, firMemref);
 
     for (unsigned i = 0; i < rank; ++i) {
       Value dim = arith::ConstantIndexOp::create(rewriter, loc, rank - i - 1);
       auto boxDims = fir::BoxDimsOp::create(rewriter, loc, indexTy, indexTy,
-                                            indexTy, box, dim);
+                                            indexTy, firMemref, dim);
 
       Value extent = boxDims->getResult(1);
       sizes.push_back(castTypeToIndexType(extent, rewriter));
@@ -712,6 +722,41 @@ FIRToMemRef::convertArrayCoorOp(Operation *memOp, fir::ArrayCoorOp arrayCoorOp,
   return std::pair{result, indices};
 }
 
+static Value complexComponent(PatternRewriter &rewriter, Location loc,
+                              Type elemTy, Value complexVal, bool isImaginary) {
+  return isImaginary ? complex::ImOp::create(rewriter, loc, elemTy, complexVal)
+                           .getResult()
+                     : complex::ReOp::create(rewriter, loc, elemTy, complexVal)
+                           .getResult();
+}
+
+static Value complexRMWValue(PatternRewriter &rewriter, Location loc,
+                             mlir::ComplexType complexTy, Value converted,
+                             ValueRange indices, Value value,
+                             bool isImaginary) {
+  Type elemTy = complexTy.getElementType();
+  Value old = memref::LoadOp::create(rewriter, loc, converted, indices);
+  Value other = complexComponent(rewriter, loc, elemTy, old, !isImaginary);
+  Value re = isImaginary ? other : value;
+  Value im = isImaginary ? value : other;
+  return complex::CreateOp::create(rewriter, loc, complexTy, re, im);
+}
+
+/// If \p firMemref is defined by a fir.array_coor that indexes a
+/// complex-component projection (z%re / z%im), return false for the real
+/// component or true for the imaginary component.
+/// Returns std::nullopt when the IR does not match the projection pattern.
+std::optional<bool> FIRToMemRef::complexProjectionOf(Value firMemref) const {
+  auto arrayCoor = firMemref.getDefiningOp<fir::ArrayCoorOp>();
+  if (!arrayCoor)
+    return std::nullopt;
+  SliceInfo info;
+  collectSliceInfoFrom(arrayCoor, info);
+  if (auto embox = arrayCoor.getMemref().getDefiningOp<fir::EmboxOp>())
+    collectSliceInfoFrom(embox, info);
+  return info.projectionIsImaginary;
+}
+
 FailureOr<Value>
 FIRToMemRef::getFIRConvert(Operation *memOp, Operation *op,
                            PatternRewriter &rewriter,
@@ -791,9 +836,7 @@ FIRToMemRef::getFIRConvert(Operation *memOp, Operation *op,
                         "the same, bailing out of conversion\n");
           return failure();
         }
-        // Keep `box_addr` on the projected box so the descriptor remains the
-        // source of truth for projected element type and stride.
-        if (!projectedSlice && embox.getSlice() &&
+        if (embox.getSlice() &&
             embox.getSlice().getDefiningOp<fir::SliceOp>()) {
           Type originalType = embox.getMemref().getType();
           basePtr = embox.getMemref();
@@ -1096,8 +1139,13 @@ void FIRToMemRef::rewriteLoadOp(fir::LoadOp load, PatternRewriter &rewriter,
              loadOp.dump(); assert(succeeded(verify(loadOp))));
 
   if (loadOp.getType() != originalType) {
-    Value castVal =
-        createTypeConversion(rewriter, loadOp.getLoc(), originalType, loadOp);
+    // z%re / z%im: extract the scalar component; otherwise type-convert.
+    auto isComplex = complexProjectionOf(firMemref);
+    Value castVal = isComplex
+                        ? complexComponent(rewriter, loadOp.getLoc(),
+                                           originalType, loadOp, *isComplex)
+                        : createTypeConversion(rewriter, loadOp.getLoc(),
+                                               originalType, loadOp);
     loadOp.getResult().replaceAllUsesExcept(castVal, castVal.getDefiningOp());
   }
 
@@ -1135,9 +1183,26 @@ void FIRToMemRef::rewriteStoreOp(fir::StoreOp store, PatternRewriter &rewriter,
     value =
         createTypeConversion(rewriter, store.getLoc(), convertedType, value);
 
-  Attribute attr = (store.getOperation())->getAttr("tbaa");
+  // For a complex-component projection (z%re / z%im), memref holds complex<T>
+  // but the stored value is scalar T.  Read-modify-write to preserve the
+  // untouched component:
+  //   %old  = memref.load %mem[%idx] : memref<N×complex<T>>
+  //   %re   = complex.re %old : complex<T>   // (complex.im if storing z%re)
+  //   %new  = complex.create %re, %val : complex<T>
+  //   memref.store %new, %mem[%idx] : memref<N×complex<T>>
+  Value storeVal = value;
+  if (auto isComplex = complexProjectionOf(firMemref)) {
+    auto complexTy = dyn_cast<mlir::ComplexType>(
+        dyn_cast<MemRefType>(converted.getType()).getElementType());
+    assert(complexTy &&
+           "complex projection converted memref must hold complex<T>");
+    storeVal = complexRMWValue(rewriter, store.getLoc(), complexTy, converted,
+                               indices, value, *isComplex);
+  }
+
+  Attribute attr = store.getOperation()->getAttr("tbaa");
   memref::StoreOp storeOp = rewriter.replaceOpWithNewOp<memref::StoreOp>(
-      store, value, converted, indices);
+      store, storeVal, converted, indices);
   if (attr)
     storeOp.getOperation()->setAttr("tbaa", attr);
 
diff --git a/flang/test/Transforms/FIRToMemRef/slice-projected.mlir b/flang/test/Transforms/FIRToMemRef/slice-projected.mlir
index 7b0fbdf748173..1ef5305099377 100644
--- a/flang/test/Transforms/FIRToMemRef/slice-projected.mlir
+++ b/flang/test/Transforms/FIRToMemRef/slice-projected.mlir
@@ -2,12 +2,9 @@
 
 // Tests for fir.slice with a path component (projected component slice).
 // A projected slice changes the element type of the boxed view, e.g.
-// z%re projects complex<f32> -> f32.  The layout (strides / base address)
-// must come from the projected box descriptor, NOT from reconstructing the
-// triplets, because memref.reinterpret_cast requires the same element type
-// on both sides and the triplet strides are in storage-element units
-// (complex<f32>) while the MemRef strides must be in projected-element units
-// (f32).
+// z%re projects complex<f32> -> f32.  The pass bypasses the box descriptor
+// and uses the underlying complex array directly, extracting the component
+// with complex.re / complex.im at the load/store site.
 //
 // Derived from:
 //   complex, target :: z(4) = 0.
@@ -16,40 +13,18 @@
 //   r = r + z(4:1:-1)%re
 
 // ----------------------------------------------------------------------------
-// Forward projected slice: z(1:4:1)%re
-// The slice path %c0 projects complex<f32> -> f32 (real part).
-// Expected lowering:
-//   - fir.box_addr on the projected box (!fir.box<!fir.array<4xf32>>)
-//   - fir.convert to memref<4xf32>   (NOT to memref<4xcomplex<f32>>)
-//   - index = i - 1  (1-based, no triplet arithmetic)
-//   - strides from fir.box_dims / fir.box_elesize on the projected box
+// Forward projected slice load: z(1:4:1)%re
+// The fir.convert appears inside the loop body (insertion point tracks the
+// array_coor inside the loop).  elemIdx = (i - 1) * step + (lb - 1) = i - 1
+// for step=1, lb=1.  Indices are reversed (col-major → row-major) but for 1D
+// that is a no-op.
 // ----------------------------------------------------------------------------
 // CHECK-LABEL: func.func @projected_slice_fwd
-// CHECK:       [[C1:%.*]] = arith.constant 1 : index
-// CHECK:       [[C4:%.*]] = arith.constant 4 : index
-// CHECK:       [[C0:%.*]] = arith.constant 0 : index
-// CHECK:       [[SHAPE:%.*]] = fir.shape [[C4]] : (index) -> !fir.shape<1>
-// CHECK:       [[SLICE:%.*]] = fir.slice [[C1]], [[C4]], [[C1]] path [[C0]] : (index, index, index, index) -> !fir.slice<1>
-// CHECK:       [[EMBOX:%.*]] = fir.embox %arg0([[SHAPE]]) {{\[}}[[SLICE]]{{\]}} : (!fir.ref<!fir.array<4xcomplex<f32>>>, !fir.shape<1>, !fir.slice<1>) -> !fir.box<!fir.array<4xf32>>
-// CHECK:       fir.do_loop [[I:%.*]] = [[C1]] to [[C4]] step [[C1]] unordered {
-// Projected box_addr gives f32 pointer, not complex<f32>.
-// CHECK:         [[BOXADDR:%.*]] = fir.box_addr [[EMBOX]] : (!fir.box<!fir.array<4xf32>>) -> !fir.ref<!fir.array<4xf32>>
-// CHECK:         [[CONVERT:%.*]] = fir.convert [[BOXADDR]] : (!fir.ref<!fir.array<4xf32>>) -> memref<4xf32>
-// Index: i-1 (1-based). The lowering emits: delta=i-1, scaled=delta*1,
-// offset=1-1=0, finalIdx=scaled+offset.  The addi result is what feeds the load.
-// CHECK:         [[C1_0:%.*]] = arith.constant 1 : index
-// CHECK:         [[DELTA:%.*]] = arith.subi [[I]], [[C1_0]] : index
-// CHECK:         [[SCALED:%.*]] = arith.muli [[DELTA]], [[C1_0]] : index
-// CHECK:         [[OFFSET:%.*]] = arith.subi [[C1_0]], [[C1_0]] : index
-// CHECK:         [[IDX:%.*]] = arith.addi [[SCALED]], [[OFFSET]] : index
-// Layout: extent and stride come from the projected box descriptor.
-// CHECK:         [[ELE:%.*]] = fir.box_elesize [[EMBOX]] : (!fir.box<!fir.array<4xf32>>) -> index
-// CHECK:         [[C0_0:%.*]] = arith.constant 0 : index
-// CHECK:         [[DIMS:%.*]]:3 = fir.box_dims [[EMBOX]], [[C0_0]] : (!fir.box<!fir.array<4xf32>>, index) -> (index, index, index)
-// CHECK:         [[STRIDE:%.*]] = arith.divsi [[DIMS]]#2, [[ELE]] : index
-// CHECK:         [[C0_1:%.*]] = arith.constant 0 : index
-// CHECK:         [[VIEW:%.*]] = memref.reinterpret_cast [[CONVERT]] to offset: {{\[}}[[C0_1]]{{\]}}, sizes: {{\[}}[[DIMS]]#1{{\]}}, strides: {{\[}}[[STRIDE]]{{\]}} : memref<4xf32> to memref<?xf32, strided<[?], offset: ?>>
-// CHECK:         memref.load [[VIEW]]{{\[}}[[IDX]]{{\]}} : memref<?xf32, strided<[?], offset: ?>>
+// CHECK:       fir.do_loop [[I:%.*]] =
+// CHECK:         [[MEMREF:%.*]] = fir.convert %arg0 : (!fir.ref<!fir.array<4xcomplex<f32>>>) -> memref<4xcomplex<f32>>
+// CHECK:         [[IDX:%.*]] = arith.addi
+// CHECK:         [[CVAL:%.*]] = memref.load [[MEMREF]]{{\[}}[[IDX]]{{\]}} : memref<4xcomplex<f32>>
+// CHECK:         complex.re [[CVAL]] : complex<f32>
 func.func @projected_slice_fwd(%arg0: !fir.ref<!fir.array<4xcomplex<f32>>>) {
   %c1 = arith.constant 1 : index
   %c4 = arith.constant 4 : index
@@ -64,16 +39,103 @@ func.func @projected_slice_fwd(%arg0: !fir.ref<!fir.array<4xcomplex<f32>>>) {
   return
 }
 
+// ----------------------------------------------------------------------------
+// Backward projected slice load: z(4:1:-1)%re
+// step = -1, lb = 4  →  elemIdx = (i - 1) * (-1) + (4 - 1) = 3 - (i-1)
+// ----------------------------------------------------------------------------
+// CHECK-LABEL: func.func @projected_slice_bwd
+// CHECK:       fir.do_loop [[I:%.*]] =
+// CHECK:         [[MEMREF:%.*]] = fir.convert %arg0 : (!fir.ref<!fir.array<4xcomplex<f32>>>) -> memref<4xcomplex<f32>>
+// CHECK:         [[IDX:%.*]] = arith.addi
+// CHECK:         [[CVAL:%.*]] = memref.load [[MEMREF]]{{\[}}[[IDX]]{{\]}} : memref<4xcomplex<f32>>
+// CHECK:         complex.re [[CVAL]] : complex<f32>
+func.func @projected_slice_bwd(%arg0: !fir.ref<!fir.array<4xcomplex<f32>>>) {
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %cm1 = arith.constant -1 : index
+  %c0 = arith.constant 0 : index
+  %shape = fir.shape %c4 : (index) -> !fir.shape<1>
+  %slice = fir.slice %c4, %c1, %cm1 path %c0 : (index, index, index, index) -> !fir.slice<1>
+  %embox = fir.embox %arg0(%shape) [%slice] : (!fir.ref<!fir.array<4xcomplex<f32>>>, !fir.shape<1>, !fir.slice<1>) -> !fir.box<!fir.array<4xf32>>
+  fir.do_loop %i = %c1 to %c4 step %c1 unordered {
+    %coor = fir.array_coor %embox %i : (!fir.box<!fir.array<4xf32>>, index) -> !fir.ref<f32>
+    %val = fir.load %coor : !fir.ref<f32>
+  }
+  return
+}
+
+// ----------------------------------------------------------------------------
+// Imaginary component store: z(1:4:1)%im = val
+// Read-modify-write: load complex, update imaginary, store back.
+// ----------------------------------------------------------------------------
+// CHECK-LABEL: func.func @projected_slice_store_im
+// CHECK:       fir.do_loop [[I:%.*]] =
+// CHECK:         [[MEMREF:%.*]] = fir.convert %arg0 : (!fir.ref<!fir.array<4xcomplex<f32>>>) -> memref<4xcomplex<f32>>
+// CHECK:         [[IDX:%.*]] = arith.addi
+// CHECK:         [[OLD:%.*]] = memref.load [[MEMREF]]{{\[}}[[IDX]]{{\]}} : memref<4xcomplex<f32>>
+// CHECK:         [[RE:%.*]]  = complex.re [[OLD]] : complex<f32>
+// CHECK:         [[NEW:%.*]] = complex.create [[RE]], %arg1 : complex<f32>
+// CHECK:         memref.store [[NEW]], [[MEMREF]]{{\[}}[[IDX]]{{\]}} : memref<4xcomplex<f32>>
+func.func @projected_slice_store_im(%arg0: !fir.ref<!fir.array<4xcomplex<f32>>>,
+                                    %arg1: f32) {
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %c1_im = arith.constant 1 : index  // imaginary component index
+  %shape = fir.shape %c4 : (index) -> !fir.shape<1>
+  %slice = fir.slice %c1, %c4, %c1 path %c1_im : (index, index, index, index) -> !fir.slice<1>
+  %embox = fir.embox %arg0(%shape) [%slice] : (!fir.ref<!fir.array<4xcomplex<f32>>>, !fir.shape<1>, !fir.slice<1>) -> !fir.box<!fir.array<4xf32>>
+  fir.do_loop %i = %c1 to %c4 step %c1 unordered {
+    %coor = fir.array_coor %embox %i : (!fir.box<!fir.array<4xf32>>, index) -> !fir.ref<f32>
+    fir.store %arg1 to %coor : !fir.ref<f32>
+  }
+  return
+}
+
+// ----------------------------------------------------------------------------
+// 2-D boxed projected slice load: z(1:2:1, 1:3:1)%re
+// Storage: !fir.array<2x3xcomplex<f32>>
+//
+// convertMemrefType reverses Fortran column-major extents to MLIR row-major:
+//   !fir.ref<!fir.array<2x3xcomplex<f32>>> → memref<3x2xcomplex<f32>>
+//
+// Per-dimension element index (0-based, column-major):
+//   elemIdx_i = (i-1)*1 + (1-1) = i-1   (Fortran dim 1, size 2)
+//   elemIdx_j = (j-1)*1 + (1-1) = j-1   (Fortran dim 2, size 3)
+//
+// After reversing for MLIR row-major access:
+//   memref.load [elemIdx_j, elemIdx_i]
+// ----------------------------------------------------------------------------
+// CHECK-LABEL: func.func @projected_slice_2d
+// CHECK:       fir.do_loop [[I:%.*]] =
+// CHECK:         fir.do_loop [[J:%.*]] =
+// CHECK:           [[MEMREF:%.*]] = fir.convert %arg0 : (!fir.ref<!fir.array<2x3xcomplex<f32>>>) -> memref<3x2xcomplex<f32>>
+// CHECK:           [[IDX_I:%.*]] = arith.addi
+// CHECK:           [[IDX_J:%.*]] = arith.addi
+// CHECK:           [[CVAL:%.*]] = memref.load [[MEMREF]]{{\[}}[[IDX_J]], [[IDX_I]]{{\]}} : memref<3x2xcomplex<f32>>
+// CHECK:           complex.re [[CVAL]] : complex<f32>
+func.func @projected_slice_2d(%arg0: !fir.ref<!fir.array<2x3xcomplex<f32>>>) {
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %c0 = arith.constant 0 : index
+  %shape = fir.shape %c2, %c3 : (index, index) -> !fir.shape<2>
+  %slice = fir.slice %c1, %c2, %c1, %c1, %c3, %c1 path %c0 : (index, index, index, index, index, index, index) -> !fir.slice<2>
+  %embox = fir.embox %arg0(%shape) [%slice] : (!fir.ref<!fir.array<2x3xcomplex<f32>>>, !fir.shape<2>, !fir.slice<2>) -> !fir.box<!fir.array<2x3xf32>>
+  fir.do_loop %i = %c1 to %c2 step %c1 unordered {
+    fir.do_loop %j = %c1 to %c3 step %c1 unordered {
+      %coor = fir.array_coor %embox %i, %j : (!fir.box<!fir.array<2x3xf32>>, index, index) -> !fir.ref<f32>
+      %val = fir.load %coor : !fir.ref<f32>
+    }
+  }
+  return
+}
+
 // ----------------------------------------------------------------------------
 // Derived-type component projection: a%x where a : TYPE{x:f64, y:complex<f64>}
 //
 // This is NOT a complex projection — the storage element is the derived type T,
-// not complex<T>.  FIRToMemRef cannot safely compute element-unit strides via
-// divsi(byte_stride, elesize) because sizeof(T)/sizeof(component) may not be an
-// integer (e.g. sizeof(T)=24, sizeof(complex<f64>)=16 -> 1.5, truncated to 1).
-//
-// The pass must leave fir.array_coor and fir.store/fir.load unconverted;
-// downstream FIR-to-LLVM lowering handles them correctly via the descriptor.
+// not complex<T>.  FIRToMemRef cannot safely handle this; downstream
+// FIR-to-LLVM lowering handles it correctly via the descriptor.
 //
 // CHECK-LABEL: func.func @derived_component_not_projected
 // The fir.array_coor must survive (not be erased).



More information about the flang-commits mailing list