[flang-commits] [flang] [flang][FIRToMemRef] [flang][fir-to-memref] Lower complex projected slices via memref<...x2xT> reinterpretation (PR #196123)

Susan Tan ス-ザン タン via flang-commits flang-commits at lists.llvm.org
Fri May 8 08:21:34 PDT 2026


https://github.com/SusanTan updated https://github.com/llvm/llvm-project/pull/196123

>From a25881489a6a90f3de167f72fb54fae8d334aa7e Mon Sep 17 00:00:00 2001
From: Susan Tan <zujunt at nvidia.com>
Date: Wed, 6 May 2026 10:13:56 -0700
Subject: [PATCH 1/5] add complex implementation

---
 .../flang/Optimizer/Transforms/Passes.td      |   3 +-
 .../lib/Optimizer/Transforms/FIRToMemRef.cpp  | 127 +++++++++++----
 .../FIRToMemRef/slice-projected.mlir          | 150 +++++++++++++-----
 3 files changed, 204 insertions(+), 76 deletions(-)

diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index e107672adf907..93813dcedb045 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -221,7 +221,8 @@ def FIRToMemRef : Pass<"fir-to-memref", "::mlir::func::FuncOp"> {
     Lower FIR memory operations (`fir.alloca`, `fir.load`, `fir.store`, 'fir.array_coor', and etc.) to MLIR's MemRef core dialect.
   }];
   let dependentDialects = ["fir::FIROpsDialect", "mlir::arith::ArithDialect",
-                           "mlir::memref::MemRefDialect"];
+                           "mlir::memref::MemRefDialect",
+                           "mlir::complex::ComplexDialect"];
 }
 
 // This needs to be a "mlir::ModuleOp" pass, because we are creating debug for
diff --git a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
index ec58d6f3f1447..316641813c37a 100644
--- a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
+++ b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
@@ -44,6 +44,7 @@
 #include "flang/Optimizer/Transforms/FIRToMemRefTypeConverter.h"
 #include "flang/Optimizer/Transforms/Passes.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/OpenACC/OpenACC.h"
@@ -140,6 +141,8 @@ class FIRToMemRef : public fir::impl::FIRToMemRefBase<FIRToMemRef> {
 
   bool memrefIsOptional(Operation *) const;
 
+  std::optional<bool> complexProjectionOf(Value firMemref) const;
+
   Value canonicalizeIndex(Value, PatternRewriter &) const;
 
   // Logical section information used by FIRToMemRef. For projected slices, the
@@ -150,6 +153,7 @@ class FIRToMemRef : public fir::impl::FIRToMemRefBase<FIRToMemRef> {
     SmallVector<Value> shiftVec;
     SmallVector<Value> sliceVec;
     bool hasProjectedSlice = false;
+    std::optional<bool> projectionIsImaginary; // set when hasProjectedSlice
   };
 
   template <typename OpTy>
@@ -171,6 +175,17 @@ class FIRToMemRef : public fir::impl::FIRToMemRefBase<FIRToMemRef> {
     return sliceOp && !sliceOp.getFields().empty();
   }
 
+  // Returns true for imaginary, false for real, nullopt if not a constant.
+  static std::optional<bool> sliceProjectionIsImaginary(fir::SliceOp sliceOp) {
+    auto fields = sliceOp.getFields();
+    if (fields.empty())
+      return std::nullopt;
+    if (auto cst = fields[0].getDefiningOp<arith::ConstantOp>())
+      if (auto attr = mlir::dyn_cast<IntegerAttr>(cst.getValueAttr()))
+        return attr.getInt() != 0;
+    return std::nullopt;
+  }
+
   unsigned getRankFromEmbox(fir::EmboxOp embox) const {
     auto memrefType = embox.getMemref().getType();
     Type unwrappedType = fir::unwrapRefType(memrefType);
@@ -313,15 +328,12 @@ void FIRToMemRef::collectSliceInfoFrom(OpTy op, SliceInfo &info) const {
     }
 
     if (auto sliceOp = getSliceOp(op.getSlice())) {
-      // A slice path changes the physical projection of the boxed entity (for
-      // example, `complex -> real` for `%re`). Preserve shape/shift for logical
-      // indexing, but do not treat the triplets alone as layout information.
       if (hasProjectedSlice(sliceOp)) {
         info.hasProjectedSlice = true;
-      } else {
-        auto triples = sliceOp.getTriples();
-        info.sliceVec.append(triples.begin(), triples.end());
+        info.projectionIsImaginary = sliceProjectionIsImaginary(sliceOp);
       }
+      auto triples = sliceOp.getTriples();
+      info.sliceVec.append(triples.begin(), triples.end());
     }
   }
 }
@@ -473,9 +485,6 @@ FIRToMemRef::getMemrefIndices(fir::ArrayCoorOp arrayCoorOp, Operation *memref,
     rank = getRankFromEmbox(embox);
   }
 
-  // Projected boxed slices leave `sliceVec` empty on purpose: indices are
-  // computed in the logical section coordinate space, while stride/base come
-  // later from the box descriptor.
   SmallVector<Value> &shiftVec = sliceInfo.shiftVec;
   SmallVector<Value> &sliceVec = sliceInfo.sliceVec;
   SmallVector<Value> sliceLbs, sliceStrides;
@@ -626,6 +635,20 @@ FIRToMemRef::convertArrayCoorOp(Operation *memOp, fir::ArrayCoorOp arrayCoorOp,
   bool isDescriptor = mlir::isa<fir::BaseBoxType>(firMemref.getType()) ||
                       firMemref.getDefiningOp<fir::BoxAddrOp>() != nullptr;
 
+  // For complex projections, getFIRConvert uses embox.getMemref() directly,
+  // converting the underlying array to memref<N×complex<T>> and ignoring the
+  // projected box descriptor.  The indices from getMemrefIndices are correct
+  // as-is; no descriptor rebuild is needed.
+  SliceInfo sliceInfo;
+  collectSliceInfoFrom(arrayCoorOp, sliceInfo);
+  if (auto embox = firMemref.getDefiningOp<fir::EmboxOp>())
+    collectSliceInfoFrom(embox, sliceInfo);
+  else if (auto rebox = firMemref.getDefiningOp<fir::ReboxOp>())
+    collectSliceInfoFrom(rebox, sliceInfo);
+  if (sliceInfo.hasProjectedSlice &&
+      isa<mlir::ComplexType>(memRefTy.getElementType()))
+    return std::pair{*converted, indices};
+
   // Static shape does not imply contiguous layout for descriptor-backed
   // entities (e.g. boxed array sections with non-unit stride). Keep the
   // reinterpret-cast path so descriptor strides are preserved.
@@ -633,7 +656,6 @@ FIRToMemRef::convertArrayCoorOp(Operation *memOp, fir::ArrayCoorOp arrayCoorOp,
     return std::pair{*converted, indices};
 
   unsigned rank = arrayCoorOp.getIndices().size();
-
   if (auto embox = firMemref.getDefiningOp<fir::EmboxOp>())
     rank = getRankFromEmbox(embox);
 
@@ -642,29 +664,17 @@ FIRToMemRef::convertArrayCoorOp(Operation *memOp, fir::ArrayCoorOp arrayCoorOp,
   SmallVector<Value> strides;
   strides.reserve(rank);
 
-  SliceInfo sliceInfo;
-  collectSliceInfoFrom(arrayCoorOp, sliceInfo);
-
-  Value box = firMemref;
-  if (!isa<BlockArgument>(firMemref)) {
-    if (auto embox = firMemref.getDefiningOp<fir::EmboxOp>()) {
-      collectSliceInfoFrom(embox, sliceInfo);
-    } else if (auto rebox = firMemref.getDefiningOp<fir::ReboxOp>()) {
-      collectSliceInfoFrom(rebox, sliceInfo);
-    }
-  }
-
   SmallVector<Value> &shapeVec = sliceInfo.shapeVec;
   if (sliceInfo.hasProjectedSlice || shapeVec.empty()) {
     // Projected slices carry their physical layout in the descriptor. Rebuild
     // the MemRef view from box metadata instead of from slice triplets.
     auto boxElementSize =
-        fir::BoxEleSizeOp::create(rewriter, loc, indexTy, box);
+        fir::BoxEleSizeOp::create(rewriter, loc, indexTy, firMemref);
 
     for (unsigned i = 0; i < rank; ++i) {
       Value dim = arith::ConstantIndexOp::create(rewriter, loc, rank - i - 1);
       auto boxDims = fir::BoxDimsOp::create(rewriter, loc, indexTy, indexTy,
-                                            indexTy, box, dim);
+                                            indexTy, firMemref, dim);
 
       Value extent = boxDims->getResult(1);
       sizes.push_back(castTypeToIndexType(extent, rewriter));
@@ -712,6 +722,41 @@ FIRToMemRef::convertArrayCoorOp(Operation *memOp, fir::ArrayCoorOp arrayCoorOp,
   return std::pair{result, indices};
 }
 
+static Value complexComponent(PatternRewriter &rewriter, Location loc,
+                              Type elemTy, Value complexVal, bool isImaginary) {
+  return isImaginary ? complex::ImOp::create(rewriter, loc, elemTy, complexVal)
+                           .getResult()
+                     : complex::ReOp::create(rewriter, loc, elemTy, complexVal)
+                           .getResult();
+}
+
+static Value complexRMWValue(PatternRewriter &rewriter, Location loc,
+                             mlir::ComplexType complexTy, Value converted,
+                             ValueRange indices, Value value,
+                             bool isImaginary) {
+  Type elemTy = complexTy.getElementType();
+  Value old = memref::LoadOp::create(rewriter, loc, converted, indices);
+  Value other = complexComponent(rewriter, loc, elemTy, old, !isImaginary);
+  Value re = isImaginary ? other : value;
+  Value im = isImaginary ? value : other;
+  return complex::CreateOp::create(rewriter, loc, complexTy, re, im);
+}
+
+/// If \p firMemref is defined by a fir.array_coor that indexes a
+/// complex-component projection (z%re / z%im), return false for the real
+/// component or true for the imaginary component.
+/// Returns std::nullopt when the IR does not match the projection pattern.
+std::optional<bool> FIRToMemRef::complexProjectionOf(Value firMemref) const {
+  auto arrayCoor = firMemref.getDefiningOp<fir::ArrayCoorOp>();
+  if (!arrayCoor)
+    return std::nullopt;
+  SliceInfo info;
+  collectSliceInfoFrom(arrayCoor, info);
+  if (auto embox = arrayCoor.getMemref().getDefiningOp<fir::EmboxOp>())
+    collectSliceInfoFrom(embox, info);
+  return info.projectionIsImaginary;
+}
+
 FailureOr<Value>
 FIRToMemRef::getFIRConvert(Operation *memOp, Operation *op,
                            PatternRewriter &rewriter,
@@ -791,9 +836,7 @@ FIRToMemRef::getFIRConvert(Operation *memOp, Operation *op,
                         "the same, bailing out of conversion\n");
           return failure();
         }
-        // Keep `box_addr` on the projected box so the descriptor remains the
-        // source of truth for projected element type and stride.
-        if (!projectedSlice && embox.getSlice() &&
+        if (embox.getSlice() &&
             embox.getSlice().getDefiningOp<fir::SliceOp>()) {
           Type originalType = embox.getMemref().getType();
           basePtr = embox.getMemref();
@@ -1096,8 +1139,13 @@ void FIRToMemRef::rewriteLoadOp(fir::LoadOp load, PatternRewriter &rewriter,
              loadOp.dump(); assert(succeeded(verify(loadOp))));
 
   if (loadOp.getType() != originalType) {
-    Value castVal =
-        createTypeConversion(rewriter, loadOp.getLoc(), originalType, loadOp);
+    // z%re / z%im: extract the scalar component; otherwise type-convert.
+    auto isComplex = complexProjectionOf(firMemref);
+    Value castVal = isComplex
+                        ? complexComponent(rewriter, loadOp.getLoc(),
+                                           originalType, loadOp, *isComplex)
+                        : createTypeConversion(rewriter, loadOp.getLoc(),
+                                               originalType, loadOp);
     loadOp.getResult().replaceAllUsesExcept(castVal, castVal.getDefiningOp());
   }
 
@@ -1135,9 +1183,26 @@ void FIRToMemRef::rewriteStoreOp(fir::StoreOp store, PatternRewriter &rewriter,
     value =
         createTypeConversion(rewriter, store.getLoc(), convertedType, value);
 
-  Attribute attr = (store.getOperation())->getAttr("tbaa");
+  // For a complex-component projection (z%re / z%im), memref holds complex<T>
+  // but the stored value is scalar T.  Read-modify-write to preserve the
+  // untouched component:
+  //   %old  = memref.load %mem[%idx] : memref<N×complex<T>>
+  //   %re   = complex.re %old : complex<T>      // (or %im for imaginary)
+  //   %new  = complex.create %re, %val : complex<T>
+  //   memref.store %new, %mem[%idx] : memref<N×complex<T>>
+  Value storeVal = value;
+  if (auto isComplex = complexProjectionOf(firMemref)) {
+    auto complexTy = dyn_cast<mlir::ComplexType>(
+        dyn_cast<MemRefType>(converted.getType()).getElementType());
+    assert(complexTy &&
+           "complex projection converted memref must hold complex<T>");
+    storeVal = complexRMWValue(rewriter, store.getLoc(), complexTy, converted,
+                               indices, value, *isComplex);
+  }
+
+  Attribute attr = store.getOperation()->getAttr("tbaa");
   memref::StoreOp storeOp = rewriter.replaceOpWithNewOp<memref::StoreOp>(
-      store, value, converted, indices);
+      store, storeVal, converted, indices);
   if (attr)
     storeOp.getOperation()->setAttr("tbaa", attr);
 
diff --git a/flang/test/Transforms/FIRToMemRef/slice-projected.mlir b/flang/test/Transforms/FIRToMemRef/slice-projected.mlir
index 7b0fbdf748173..1ef5305099377 100644
--- a/flang/test/Transforms/FIRToMemRef/slice-projected.mlir
+++ b/flang/test/Transforms/FIRToMemRef/slice-projected.mlir
@@ -2,12 +2,9 @@
 
 // Tests for fir.slice with a path component (projected component slice).
 // A projected slice changes the element type of the boxed view, e.g.
-// z%re projects complex<f32> -> f32.  The layout (strides / base address)
-// must come from the projected box descriptor, NOT from reconstructing the
-// triplets, because memref.reinterpret_cast requires the same element type
-// on both sides and the triplet strides are in storage-element units
-// (complex<f32>) while the MemRef strides must be in projected-element units
-// (f32).
+// z%re projects complex<f32> -> f32.  The pass bypasses the box descriptor
+// and uses the underlying complex array directly, extracting the component
+// with complex.re / complex.im at the load/store site.
 //
 // Derived from:
 //   complex, target :: z(4) = 0.
@@ -16,40 +13,18 @@
 //   r = r + z(4:1:-1)%re
 
 // ----------------------------------------------------------------------------
-// Forward projected slice: z(1:4:1)%re
-// The slice path %c0 projects complex<f32> -> f32 (real part).
-// Expected lowering:
-//   - fir.box_addr on the projected box (!fir.box<!fir.array<4xf32>>)
-//   - fir.convert to memref<4xf32>   (NOT to memref<4xcomplex<f32>>)
-//   - index = i - 1  (1-based, no triplet arithmetic)
-//   - strides from fir.box_dims / fir.box_elesize on the projected box
+// Forward projected slice load: z(1:4:1)%re
+// The fir.convert appears inside the loop body (insertion point tracks the
+// array_coor inside the loop).  elemIdx = (i - 1) * step + (lb - 1) = i - 1
+// for step=1, lb=1.  Indices are reversed (col-major → row-major) but for 1D
+// that is a no-op.
 // ----------------------------------------------------------------------------
 // CHECK-LABEL: func.func @projected_slice_fwd
-// CHECK:       [[C1:%.*]] = arith.constant 1 : index
-// CHECK:       [[C4:%.*]] = arith.constant 4 : index
-// CHECK:       [[C0:%.*]] = arith.constant 0 : index
-// CHECK:       [[SHAPE:%.*]] = fir.shape [[C4]] : (index) -> !fir.shape<1>
-// CHECK:       [[SLICE:%.*]] = fir.slice [[C1]], [[C4]], [[C1]] path [[C0]] : (index, index, index, index) -> !fir.slice<1>
-// CHECK:       [[EMBOX:%.*]] = fir.embox %arg0([[SHAPE]]) {{\[}}[[SLICE]]{{\]}} : (!fir.ref<!fir.array<4xcomplex<f32>>>, !fir.shape<1>, !fir.slice<1>) -> !fir.box<!fir.array<4xf32>>
-// CHECK:       fir.do_loop [[I:%.*]] = [[C1]] to [[C4]] step [[C1]] unordered {
-// Projected box_addr gives f32 pointer, not complex<f32>.
-// CHECK:         [[BOXADDR:%.*]] = fir.box_addr [[EMBOX]] : (!fir.box<!fir.array<4xf32>>) -> !fir.ref<!fir.array<4xf32>>
-// CHECK:         [[CONVERT:%.*]] = fir.convert [[BOXADDR]] : (!fir.ref<!fir.array<4xf32>>) -> memref<4xf32>
-// Index: i-1 (1-based). The lowering emits: delta=i-1, scaled=delta*1,
-// offset=1-1=0, finalIdx=scaled+offset.  The addi result is what feeds the load.
-// CHECK:         [[C1_0:%.*]] = arith.constant 1 : index
-// CHECK:         [[DELTA:%.*]] = arith.subi [[I]], [[C1_0]] : index
-// CHECK:         [[SCALED:%.*]] = arith.muli [[DELTA]], [[C1_0]] : index
-// CHECK:         [[OFFSET:%.*]] = arith.subi [[C1_0]], [[C1_0]] : index
-// CHECK:         [[IDX:%.*]] = arith.addi [[SCALED]], [[OFFSET]] : index
-// Layout: extent and stride come from the projected box descriptor.
-// CHECK:         [[ELE:%.*]] = fir.box_elesize [[EMBOX]] : (!fir.box<!fir.array<4xf32>>) -> index
-// CHECK:         [[C0_0:%.*]] = arith.constant 0 : index
-// CHECK:         [[DIMS:%.*]]:3 = fir.box_dims [[EMBOX]], [[C0_0]] : (!fir.box<!fir.array<4xf32>>, index) -> (index, index, index)
-// CHECK:         [[STRIDE:%.*]] = arith.divsi [[DIMS]]#2, [[ELE]] : index
-// CHECK:         [[C0_1:%.*]] = arith.constant 0 : index
-// CHECK:         [[VIEW:%.*]] = memref.reinterpret_cast [[CONVERT]] to offset: {{\[}}[[C0_1]]{{\]}}, sizes: {{\[}}[[DIMS]]#1{{\]}}, strides: {{\[}}[[STRIDE]]{{\]}} : memref<4xf32> to memref<?xf32, strided<[?], offset: ?>>
-// CHECK:         memref.load [[VIEW]]{{\[}}[[IDX]]{{\]}} : memref<?xf32, strided<[?], offset: ?>>
+// CHECK:       fir.do_loop [[I:%.*]] =
+// CHECK:         [[MEMREF:%.*]] = fir.convert %arg0 : (!fir.ref<!fir.array<4xcomplex<f32>>>) -> memref<4xcomplex<f32>>
+// CHECK:         [[IDX:%.*]] = arith.addi
+// CHECK:         [[CVAL:%.*]] = memref.load [[MEMREF]]{{\[}}[[IDX]]{{\]}} : memref<4xcomplex<f32>>
+// CHECK:         complex.re [[CVAL]] : complex<f32>
 func.func @projected_slice_fwd(%arg0: !fir.ref<!fir.array<4xcomplex<f32>>>) {
   %c1 = arith.constant 1 : index
   %c4 = arith.constant 4 : index
@@ -64,16 +39,103 @@ func.func @projected_slice_fwd(%arg0: !fir.ref<!fir.array<4xcomplex<f32>>>) {
   return
 }
 
+// ----------------------------------------------------------------------------
+// Backward projected slice load: z(4:1:-1)%re
+// step = -1, lb = 4  →  elemIdx = (i - 1) * (-1) + (4 - 1) = 3 - (i-1)
+// ----------------------------------------------------------------------------
+// CHECK-LABEL: func.func @projected_slice_bwd
+// CHECK:       fir.do_loop [[I:%.*]] =
+// CHECK:         [[MEMREF:%.*]] = fir.convert %arg0 : (!fir.ref<!fir.array<4xcomplex<f32>>>) -> memref<4xcomplex<f32>>
+// CHECK:         [[IDX:%.*]] = arith.addi
+// CHECK:         [[CVAL:%.*]] = memref.load [[MEMREF]]{{\[}}[[IDX]]{{\]}} : memref<4xcomplex<f32>>
+// CHECK:         complex.re [[CVAL]] : complex<f32>
+func.func @projected_slice_bwd(%arg0: !fir.ref<!fir.array<4xcomplex<f32>>>) {
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %cm1 = arith.constant -1 : index
+  %c0 = arith.constant 0 : index
+  %shape = fir.shape %c4 : (index) -> !fir.shape<1>
+  %slice = fir.slice %c4, %c1, %cm1 path %c0 : (index, index, index, index) -> !fir.slice<1>
+  %embox = fir.embox %arg0(%shape) [%slice] : (!fir.ref<!fir.array<4xcomplex<f32>>>, !fir.shape<1>, !fir.slice<1>) -> !fir.box<!fir.array<4xf32>>
+  fir.do_loop %i = %c1 to %c4 step %c1 unordered {
+    %coor = fir.array_coor %embox %i : (!fir.box<!fir.array<4xf32>>, index) -> !fir.ref<f32>
+    %val = fir.load %coor : !fir.ref<f32>
+  }
+  return
+}
+
+// ----------------------------------------------------------------------------
+// Imaginary component store: z(1:4:1)%im = val
+// Read-modify-write: load complex, update imaginary, store back.
+// ----------------------------------------------------------------------------
+// CHECK-LABEL: func.func @projected_slice_store_im
+// CHECK:       fir.do_loop [[I:%.*]] =
+// CHECK:         [[MEMREF:%.*]] = fir.convert %arg0 : (!fir.ref<!fir.array<4xcomplex<f32>>>) -> memref<4xcomplex<f32>>
+// CHECK:         [[IDX:%.*]] = arith.addi
+// CHECK:         [[OLD:%.*]] = memref.load [[MEMREF]]{{\[}}[[IDX]]{{\]}} : memref<4xcomplex<f32>>
+// CHECK:         [[RE:%.*]]  = complex.re [[OLD]] : complex<f32>
+// CHECK:         [[NEW:%.*]] = complex.create [[RE]], %arg1 : complex<f32>
+// CHECK:         memref.store [[NEW]], [[MEMREF]]{{\[}}[[IDX]]{{\]}} : memref<4xcomplex<f32>>
+func.func @projected_slice_store_im(%arg0: !fir.ref<!fir.array<4xcomplex<f32>>>,
+                                    %arg1: f32) {
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %c1_im = arith.constant 1 : index  // imaginary component index
+  %shape = fir.shape %c4 : (index) -> !fir.shape<1>
+  %slice = fir.slice %c1, %c4, %c1 path %c1_im : (index, index, index, index) -> !fir.slice<1>
+  %embox = fir.embox %arg0(%shape) [%slice] : (!fir.ref<!fir.array<4xcomplex<f32>>>, !fir.shape<1>, !fir.slice<1>) -> !fir.box<!fir.array<4xf32>>
+  fir.do_loop %i = %c1 to %c4 step %c1 unordered {
+    %coor = fir.array_coor %embox %i : (!fir.box<!fir.array<4xf32>>, index) -> !fir.ref<f32>
+    fir.store %arg1 to %coor : !fir.ref<f32>
+  }
+  return
+}
+
+// ----------------------------------------------------------------------------
+// 2-D boxed projected slice load: z(1:2:1, 1:3:1)%re
+// Storage: !fir.array<2x3xcomplex<f32>>
+//
+// convertMemrefType reverses Fortran column-major extents to MLIR row-major:
+//   !fir.ref<!fir.array<2x3xcomplex<f32>>> → memref<3x2xcomplex<f32>>
+//
+// Per-dimension element index (0-based, column-major):
+//   elemIdx_i = (i-1)*1 + (1-1) = i-1   (Fortran dim 1, size 2)
+//   elemIdx_j = (j-1)*1 + (1-1) = j-1   (Fortran dim 2, size 3)
+//
+// After reversing for MLIR row-major access:
+//   memref.load [elemIdx_j, elemIdx_i]
+// ----------------------------------------------------------------------------
+// CHECK-LABEL: func.func @projected_slice_2d
+// CHECK:       fir.do_loop [[I:%.*]] =
+// CHECK:         fir.do_loop [[J:%.*]] =
+// CHECK:           [[MEMREF:%.*]] = fir.convert %arg0 : (!fir.ref<!fir.array<2x3xcomplex<f32>>>) -> memref<3x2xcomplex<f32>>
+// CHECK:           [[IDX_I:%.*]] = arith.addi
+// CHECK:           [[IDX_J:%.*]] = arith.addi
+// CHECK:           [[CVAL:%.*]] = memref.load [[MEMREF]]{{\[}}[[IDX_J]], [[IDX_I]]{{\]}} : memref<3x2xcomplex<f32>>
+// CHECK:           complex.re [[CVAL]] : complex<f32>
+func.func @projected_slice_2d(%arg0: !fir.ref<!fir.array<2x3xcomplex<f32>>>) {
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %c0 = arith.constant 0 : index
+  %shape = fir.shape %c2, %c3 : (index, index) -> !fir.shape<2>
+  %slice = fir.slice %c1, %c2, %c1, %c1, %c3, %c1 path %c0 : (index, index, index, index, index, index, index) -> !fir.slice<2>
+  %embox = fir.embox %arg0(%shape) [%slice] : (!fir.ref<!fir.array<2x3xcomplex<f32>>>, !fir.shape<2>, !fir.slice<2>) -> !fir.box<!fir.array<2x3xf32>>
+  fir.do_loop %i = %c1 to %c2 step %c1 unordered {
+    fir.do_loop %j = %c1 to %c3 step %c1 unordered {
+      %coor = fir.array_coor %embox %i, %j : (!fir.box<!fir.array<2x3xf32>>, index, index) -> !fir.ref<f32>
+      %val = fir.load %coor : !fir.ref<f32>
+    }
+  }
+  return
+}
+
 // ----------------------------------------------------------------------------
 // Derived-type component projection: a%x where a : TYPE{x:f64, y:complex<f64>}
 //
 // This is NOT a complex projection — the storage element is the derived type T,
-// not complex<T>.  FIRToMemRef cannot safely compute element-unit strides via
-// divsi(byte_stride, elesize) because sizeof(T)/sizeof(component) may not be an
-// integer (e.g. sizeof(T)=24, sizeof(complex<f64>)=16 -> 1.5, truncated to 1).
-//
-// The pass must leave fir.array_coor and fir.store/fir.load unconverted;
-// downstream FIR-to-LLVM lowering handles them correctly via the descriptor.
+// not complex<T>.  FIRToMemRef cannot safely handle this; downstream
+// FIR-to-LLVM lowering handles it correctly via the descriptor.
 //
 // CHECK-LABEL: func.func @derived_component_not_projected
 // The fir.array_coor must survive (not be erased).

>From 4405fa86f5d027fbea99fa876c5787592d10e940 Mon Sep 17 00:00:00 2001
From: Susan Tan <zujunt at nvidia.com>
Date: Wed, 6 May 2026 10:32:14 -0700
Subject: [PATCH 2/5] tweak

---
 flang/lib/Optimizer/Transforms/FIRToMemRef.cpp | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
index 316641813c37a..e5acc48131b62 100644
--- a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
+++ b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
@@ -742,10 +742,6 @@ static Value complexRMWValue(PatternRewriter &rewriter, Location loc,
   return complex::CreateOp::create(rewriter, loc, complexTy, re, im);
 }
 
-/// If \p firMemref is defined by a fir.array_coor that indexes a
-/// complex-component projection (z%re / z%im), return false for the real
-/// component or true for the imaginary component.
-/// Returns std::nullopt when the IR does not match the projection pattern.
 std::optional<bool> FIRToMemRef::complexProjectionOf(Value firMemref) const {
   auto arrayCoor = firMemref.getDefiningOp<fir::ArrayCoorOp>();
   if (!arrayCoor)

>From cfb8260ed2b5fdf8b21792fcc5a4ce0c242385e6 Mon Sep 17 00:00:00 2001
From: Susan Tan <zujunt at nvidia.com>
Date: Wed, 6 May 2026 11:51:56 -0700
Subject: [PATCH 3/5] flatten complex to x2

---
 .../flang/Optimizer/Transforms/Passes.td      |  3 +-
 .../lib/Optimizer/Transforms/FIRToMemRef.cpp  | 86 +++++--------------
 .../FIRToMemRef/slice-projected.mlir          | 37 +++++---
 3 files changed, 44 insertions(+), 82 deletions(-)

diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index 93813dcedb045..e107672adf907 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -221,8 +221,7 @@ def FIRToMemRef : Pass<"fir-to-memref", "::mlir::func::FuncOp"> {
     Lower FIR memory operations (`fir.alloca`, `fir.load`, `fir.store`, 'fir.array_coor', and etc.) to MLIR's MemRef core dialect.
   }];
   let dependentDialects = ["fir::FIROpsDialect", "mlir::arith::ArithDialect",
-                           "mlir::memref::MemRefDialect",
-                           "mlir::complex::ComplexDialect"];
+                           "mlir::memref::MemRefDialect"];
 }
 
 // This needs to be a "mlir::ModuleOp" pass, because we are creating debug for
diff --git a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
index e5acc48131b62..f0c3873366b9d 100644
--- a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
+++ b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
@@ -44,7 +44,6 @@
 #include "flang/Optimizer/Transforms/FIRToMemRefTypeConverter.h"
 #include "flang/Optimizer/Transforms/Passes.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/OpenACC/OpenACC.h"
@@ -141,8 +140,6 @@ class FIRToMemRef : public fir::impl::FIRToMemRefBase<FIRToMemRef> {
 
   bool memrefIsOptional(Operation *) const;
 
-  std::optional<bool> complexProjectionOf(Value firMemref) const;
-
   Value canonicalizeIndex(Value, PatternRewriter &) const;
 
   // Logical section information used by FIRToMemRef. For projected slices, the
@@ -635,19 +632,29 @@ FIRToMemRef::convertArrayCoorOp(Operation *memOp, fir::ArrayCoorOp arrayCoorOp,
   bool isDescriptor = mlir::isa<fir::BaseBoxType>(firMemref.getType()) ||
                       firMemref.getDefiningOp<fir::BoxAddrOp>() != nullptr;
 
-  // For complex projections, getFIRConvert uses embox.getMemref() directly,
-  // converting the underlying array to memref<N×complex<T>> and ignoring the
-  // projected box descriptor.  The indices from getMemrefIndices are correct
-  // as-is; no descriptor rebuild is needed.
+  // For complex projections, reinterpret memref<d0×...×complex<T>> as
+  // memref<d0×...×2×T> and append the component index (0=re, 1=im) so that
+  // each load/store touches exactly sizeof(T) bytes.
   SliceInfo sliceInfo;
   collectSliceInfoFrom(arrayCoorOp, sliceInfo);
   if (auto embox = firMemref.getDefiningOp<fir::EmboxOp>())
     collectSliceInfoFrom(embox, sliceInfo);
   else if (auto rebox = firMemref.getDefiningOp<fir::ReboxOp>())
     collectSliceInfoFrom(rebox, sliceInfo);
-  if (sliceInfo.hasProjectedSlice &&
-      isa<mlir::ComplexType>(memRefTy.getElementType()))
-    return std::pair{*converted, indices};
+  if (sliceInfo.projectionIsImaginary) {
+    auto srcTy = cast<MemRefType>((*converted).getType());
+    auto complexTy = cast<mlir::ComplexType>(srcTy.getElementType());
+    SmallVector<int64_t> shape(srcTy.getShape());
+    shape.push_back(2);
+    Value compMemref =
+        fir::ConvertOp::create(
+            rewriter, loc, MemRefType::get(shape, complexTy.getElementType()),
+            *converted)
+            .getResult();
+    indices.push_back(arith::ConstantIndexOp::create(
+        rewriter, loc, *sliceInfo.projectionIsImaginary ? 1 : 0));
+    return std::pair{compMemref, indices};
+  }
 
   // Static shape does not imply contiguous layout for descriptor-backed
   // entities (e.g. boxed array sections with non-unit stride). Keep the
@@ -722,37 +729,6 @@ FIRToMemRef::convertArrayCoorOp(Operation *memOp, fir::ArrayCoorOp arrayCoorOp,
   return std::pair{result, indices};
 }
 
-static Value complexComponent(PatternRewriter &rewriter, Location loc,
-                              Type elemTy, Value complexVal, bool isImaginary) {
-  return isImaginary ? complex::ImOp::create(rewriter, loc, elemTy, complexVal)
-                           .getResult()
-                     : complex::ReOp::create(rewriter, loc, elemTy, complexVal)
-                           .getResult();
-}
-
-static Value complexRMWValue(PatternRewriter &rewriter, Location loc,
-                             mlir::ComplexType complexTy, Value converted,
-                             ValueRange indices, Value value,
-                             bool isImaginary) {
-  Type elemTy = complexTy.getElementType();
-  Value old = memref::LoadOp::create(rewriter, loc, converted, indices);
-  Value other = complexComponent(rewriter, loc, elemTy, old, !isImaginary);
-  Value re = isImaginary ? other : value;
-  Value im = isImaginary ? value : other;
-  return complex::CreateOp::create(rewriter, loc, complexTy, re, im);
-}
-
-std::optional<bool> FIRToMemRef::complexProjectionOf(Value firMemref) const {
-  auto arrayCoor = firMemref.getDefiningOp<fir::ArrayCoorOp>();
-  if (!arrayCoor)
-    return std::nullopt;
-  SliceInfo info;
-  collectSliceInfoFrom(arrayCoor, info);
-  if (auto embox = arrayCoor.getMemref().getDefiningOp<fir::EmboxOp>())
-    collectSliceInfoFrom(embox, info);
-  return info.projectionIsImaginary;
-}
-
 FailureOr<Value>
 FIRToMemRef::getFIRConvert(Operation *memOp, Operation *op,
                            PatternRewriter &rewriter,
@@ -1135,13 +1111,8 @@ void FIRToMemRef::rewriteLoadOp(fir::LoadOp load, PatternRewriter &rewriter,
              loadOp.dump(); assert(succeeded(verify(loadOp))));
 
   if (loadOp.getType() != originalType) {
-    // z%re / z%im: extract the scalar component; otherwise type-convert.
-    auto isComplex = complexProjectionOf(firMemref);
-    Value castVal = isComplex
-                        ? complexComponent(rewriter, loadOp.getLoc(),
-                                           originalType, loadOp, *isComplex)
-                        : createTypeConversion(rewriter, loadOp.getLoc(),
-                                               originalType, loadOp);
+    Value castVal =
+        createTypeConversion(rewriter, loadOp.getLoc(), originalType, loadOp);
     loadOp.getResult().replaceAllUsesExcept(castVal, castVal.getDefiningOp());
   }
 
@@ -1179,26 +1150,9 @@ void FIRToMemRef::rewriteStoreOp(fir::StoreOp store, PatternRewriter &rewriter,
     value =
         createTypeConversion(rewriter, store.getLoc(), convertedType, value);
 
-  // For a complex-component projection (z%re / z%im), memref holds complex<T>
-  // but the stored value is scalar T.  Read-modify-write to preserve the
-  // untouched component:
-  //   %old  = memref.load %mem[%idx] : memref<N×complex<T>>
-  //   %re   = complex.re %old : complex<T>      // (or %im for imaginary)
-  //   %new  = complex.create %re, %val : complex<T>
-  //   memref.store %new, %mem[%idx] : memref<N×complex<T>>
-  Value storeVal = value;
-  if (auto isComplex = complexProjectionOf(firMemref)) {
-    auto complexTy = dyn_cast<mlir::ComplexType>(
-        dyn_cast<MemRefType>(converted.getType()).getElementType());
-    assert(complexTy &&
-           "complex projection converted memref must hold complex<T>");
-    storeVal = complexRMWValue(rewriter, store.getLoc(), complexTy, converted,
-                               indices, value, *isComplex);
-  }
-
   Attribute attr = store.getOperation()->getAttr("tbaa");
   memref::StoreOp storeOp = rewriter.replaceOpWithNewOp<memref::StoreOp>(
-      store, storeVal, converted, indices);
+      store, value, converted, indices);
   if (attr)
     storeOp.getOperation()->setAttr("tbaa", attr);
 
diff --git a/flang/test/Transforms/FIRToMemRef/slice-projected.mlir b/flang/test/Transforms/FIRToMemRef/slice-projected.mlir
index 1ef5305099377..7bd3bbd19f36e 100644
--- a/flang/test/Transforms/FIRToMemRef/slice-projected.mlir
+++ b/flang/test/Transforms/FIRToMemRef/slice-projected.mlir
@@ -3,8 +3,8 @@
 // Tests for fir.slice with a path component (projected component slice).
 // A projected slice changes the element type of the boxed view, e.g.
 // z%re projects complex<f32> -> f32.  The pass bypasses the box descriptor
-// and uses the underlying complex array directly, extracting the component
-// with complex.re / complex.im at the load/store site.
+// and reinterprets the underlying complex array as memref<...x2xf32>, then
+// appends the component index (0=re, 1=im) as the final memref index.
 //
 // Derived from:
 //   complex, target :: z(4) = 0.
@@ -23,8 +23,10 @@
 // CHECK:       fir.do_loop [[I:%.*]] =
 // CHECK:         [[MEMREF:%.*]] = fir.convert %arg0 : (!fir.ref<!fir.array<4xcomplex<f32>>>) -> memref<4xcomplex<f32>>
 // CHECK:         [[IDX:%.*]] = arith.addi
-// CHECK:         [[CVAL:%.*]] = memref.load [[MEMREF]]{{\[}}[[IDX]]{{\]}} : memref<4xcomplex<f32>>
-// CHECK:         complex.re [[CVAL]] : complex<f32>
+// CHECK:         [[COMP:%.*]] = fir.convert [[MEMREF]] : (memref<4xcomplex<f32>>) -> memref<4x2xf32>
+// CHECK:         arith.constant 0
+// CHECK:         memref.load [[COMP]][[[IDX]], {{%.*}}] : memref<4x2xf32>
+// CHECK-NOT:     complex.re
 func.func @projected_slice_fwd(%arg0: !fir.ref<!fir.array<4xcomplex<f32>>>) {
   %c1 = arith.constant 1 : index
   %c4 = arith.constant 4 : index
@@ -47,8 +49,10 @@ func.func @projected_slice_fwd(%arg0: !fir.ref<!fir.array<4xcomplex<f32>>>) {
 // CHECK:       fir.do_loop [[I:%.*]] =
 // CHECK:         [[MEMREF:%.*]] = fir.convert %arg0 : (!fir.ref<!fir.array<4xcomplex<f32>>>) -> memref<4xcomplex<f32>>
 // CHECK:         [[IDX:%.*]] = arith.addi
-// CHECK:         [[CVAL:%.*]] = memref.load [[MEMREF]]{{\[}}[[IDX]]{{\]}} : memref<4xcomplex<f32>>
-// CHECK:         complex.re [[CVAL]] : complex<f32>
+// CHECK:         [[COMP:%.*]] = fir.convert [[MEMREF]] : (memref<4xcomplex<f32>>) -> memref<4x2xf32>
+// CHECK:         arith.constant 0
+// CHECK:         memref.load [[COMP]][[[IDX]], {{%.*}}] : memref<4x2xf32>
+// CHECK-NOT:     complex.re
 func.func @projected_slice_bwd(%arg0: !fir.ref<!fir.array<4xcomplex<f32>>>) {
   %c1 = arith.constant 1 : index
   %c4 = arith.constant 4 : index
@@ -66,16 +70,18 @@ func.func @projected_slice_bwd(%arg0: !fir.ref<!fir.array<4xcomplex<f32>>>) {
 
 // ----------------------------------------------------------------------------
 // Imaginary component store: z(1:4:1)%im = val
-// Read-modify-write: load complex, update imaginary, store back.
+// Direct scalar store — no read-modify-write, no complex.create.
 // ----------------------------------------------------------------------------
 // CHECK-LABEL: func.func @projected_slice_store_im
 // CHECK:       fir.do_loop [[I:%.*]] =
 // CHECK:         [[MEMREF:%.*]] = fir.convert %arg0 : (!fir.ref<!fir.array<4xcomplex<f32>>>) -> memref<4xcomplex<f32>>
 // CHECK:         [[IDX:%.*]] = arith.addi
-// CHECK:         [[OLD:%.*]] = memref.load [[MEMREF]]{{\[}}[[IDX]]{{\]}} : memref<4xcomplex<f32>>
-// CHECK:         [[RE:%.*]]  = complex.re [[OLD]] : complex<f32>
-// CHECK:         [[NEW:%.*]] = complex.create [[RE]], %arg1 : complex<f32>
-// CHECK:         memref.store [[NEW]], [[MEMREF]]{{\[}}[[IDX]]{{\]}} : memref<4xcomplex<f32>>
+// CHECK:         [[COMP:%.*]] = fir.convert [[MEMREF]] : (memref<4xcomplex<f32>>) -> memref<4x2xf32>
+// CHECK:         arith.constant 1
+// CHECK:         memref.store %arg1, [[COMP]][[[IDX]], {{%.*}}] : memref<4x2xf32>
+// CHECK-NOT:     complex.re
+// CHECK-NOT:     complex.create
+// CHECK-NOT:     memref.load
 func.func @projected_slice_store_im(%arg0: !fir.ref<!fir.array<4xcomplex<f32>>>,
                                     %arg1: f32) {
   %c1 = arith.constant 1 : index
@@ -97,13 +103,15 @@ func.func @projected_slice_store_im(%arg0: !fir.ref<!fir.array<4xcomplex<f32>>>,
 //
 // convertMemrefType reverses Fortran column-major extents to MLIR row-major:
 //   !fir.ref<!fir.array<2x3xcomplex<f32>>> → memref<3x2xcomplex<f32>>
+// Reinterpret adds the component dimension:
+//   memref<3x2xcomplex<f32>> → memref<3x2x2xf32>
 //
 // Per-dimension element index (0-based, column-major):
 //   elemIdx_i = (i-1)*1 + (1-1) = i-1   (Fortran dim 1, size 2)
 //   elemIdx_j = (j-1)*1 + (1-1) = j-1   (Fortran dim 2, size 3)
 //
 // After reversing for MLIR row-major access:
-//   memref.load [elemIdx_j, elemIdx_i]
+//   memref.load [elemIdx_j, elemIdx_i, 0]
 // ----------------------------------------------------------------------------
 // CHECK-LABEL: func.func @projected_slice_2d
 // CHECK:       fir.do_loop [[I:%.*]] =
@@ -111,8 +119,9 @@ func.func @projected_slice_store_im(%arg0: !fir.ref<!fir.array<4xcomplex<f32>>>,
 // CHECK:           [[MEMREF:%.*]] = fir.convert %arg0 : (!fir.ref<!fir.array<2x3xcomplex<f32>>>) -> memref<3x2xcomplex<f32>>
 // CHECK:           [[IDX_I:%.*]] = arith.addi
 // CHECK:           [[IDX_J:%.*]] = arith.addi
-// CHECK:           [[CVAL:%.*]] = memref.load [[MEMREF]]{{\[}}[[IDX_J]], [[IDX_I]]{{\]}} : memref<3x2xcomplex<f32>>
-// CHECK:           complex.re [[CVAL]] : complex<f32>
+// CHECK:           [[COMP:%.*]] = fir.convert [[MEMREF]] : (memref<3x2xcomplex<f32>>) -> memref<3x2x2xf32>
+// CHECK:           arith.constant 0
+// CHECK:           memref.load [[COMP]][[[IDX_J]], [[IDX_I]], {{%.*}}] : memref<3x2x2xf32>
 func.func @projected_slice_2d(%arg0: !fir.ref<!fir.array<2x3xcomplex<f32>>>) {
   %c1 = arith.constant 1 : index
   %c2 = arith.constant 2 : index

>From 62389d77af4afd0c5f3526fdfaf4388c8a6612ea Mon Sep 17 00:00:00 2001
From: Susan Tan <zujunt at nvidia.com>
Date: Wed, 6 May 2026 11:57:44 -0700
Subject: [PATCH 4/5] Drop CHECK-NOT lines from slice-projected.mlir test

---
 flang/test/Transforms/FIRToMemRef/slice-projected.mlir | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/flang/test/Transforms/FIRToMemRef/slice-projected.mlir b/flang/test/Transforms/FIRToMemRef/slice-projected.mlir
index 7bd3bbd19f36e..17af59086122c 100644
--- a/flang/test/Transforms/FIRToMemRef/slice-projected.mlir
+++ b/flang/test/Transforms/FIRToMemRef/slice-projected.mlir
@@ -26,7 +26,6 @@
 // CHECK:         [[COMP:%.*]] = fir.convert [[MEMREF]] : (memref<4xcomplex<f32>>) -> memref<4x2xf32>
 // CHECK:         arith.constant 0
 // CHECK:         memref.load [[COMP]][[[IDX]], {{%.*}}] : memref<4x2xf32>
-// CHECK-NOT:     complex.re
 func.func @projected_slice_fwd(%arg0: !fir.ref<!fir.array<4xcomplex<f32>>>) {
   %c1 = arith.constant 1 : index
   %c4 = arith.constant 4 : index
@@ -52,7 +51,6 @@ func.func @projected_slice_fwd(%arg0: !fir.ref<!fir.array<4xcomplex<f32>>>) {
 // CHECK:         [[COMP:%.*]] = fir.convert [[MEMREF]] : (memref<4xcomplex<f32>>) -> memref<4x2xf32>
 // CHECK:         arith.constant 0
 // CHECK:         memref.load [[COMP]][[[IDX]], {{%.*}}] : memref<4x2xf32>
-// CHECK-NOT:     complex.re
 func.func @projected_slice_bwd(%arg0: !fir.ref<!fir.array<4xcomplex<f32>>>) {
   %c1 = arith.constant 1 : index
   %c4 = arith.constant 4 : index
@@ -79,9 +77,6 @@ func.func @projected_slice_bwd(%arg0: !fir.ref<!fir.array<4xcomplex<f32>>>) {
 // CHECK:         [[COMP:%.*]] = fir.convert [[MEMREF]] : (memref<4xcomplex<f32>>) -> memref<4x2xf32>
 // CHECK:         arith.constant 1
 // CHECK:         memref.store %arg1, [[COMP]][[[IDX]], {{%.*}}] : memref<4x2xf32>
-// CHECK-NOT:     complex.re
-// CHECK-NOT:     complex.create
-// CHECK-NOT:     memref.load
 func.func @projected_slice_store_im(%arg0: !fir.ref<!fir.array<4xcomplex<f32>>>,
                                     %arg1: f32) {
   %c1 = arith.constant 1 : index

>From 1fc896dbb83570d46e8390dc27af670831f98250 Mon Sep 17 00:00:00 2001
From: Susan Tan <zujunt at nvidia.com>
Date: Fri, 8 May 2026 08:21:19 -0700
Subject: [PATCH 5/5] Track projected-slice start field; bail out on non-constant or out-of-range complex selectors

---
 .../lib/Optimizer/Transforms/FIRToMemRef.cpp  | 52 +++++++++++--------
 1 file changed, 31 insertions(+), 21 deletions(-)

diff --git a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
index f0c3873366b9d..6fc9af500eb72 100644
--- a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
+++ b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
@@ -150,7 +150,8 @@ class FIRToMemRef : public fir::impl::FIRToMemRefBase<FIRToMemRef> {
     SmallVector<Value> shiftVec;
     SmallVector<Value> sliceVec;
     bool hasProjectedSlice = false;
-    std::optional<bool> projectionIsImaginary; // set when hasProjectedSlice
+    // Constant value of the first projected-slice field, if any.
+    std::optional<std::int64_t> projectedSliceStart;
   };
 
   template <typename OpTy>
@@ -172,15 +173,13 @@ class FIRToMemRef : public fir::impl::FIRToMemRefBase<FIRToMemRef> {
     return sliceOp && !sliceOp.getFields().empty();
   }
 
-  // Returns true for imaginary, false for real, nullopt if not a constant.
-  static std::optional<bool> sliceProjectionIsImaginary(fir::SliceOp sliceOp) {
+  // Returns the constant first projected-slice field, if available.
+  static std::optional<std::int64_t>
+  getProjectedSliceStartIfConstant(fir::SliceOp sliceOp) {
     auto fields = sliceOp.getFields();
     if (fields.empty())
       return std::nullopt;
-    if (auto cst = fields[0].getDefiningOp<arith::ConstantOp>())
-      if (auto attr = mlir::dyn_cast<IntegerAttr>(cst.getValueAttr()))
-        return attr.getInt() != 0;
-    return std::nullopt;
+    return fir::getIntIfConstant(fields.front());
   }
 
   unsigned getRankFromEmbox(fir::EmboxOp embox) const {
@@ -327,7 +326,7 @@ void FIRToMemRef::collectSliceInfoFrom(OpTy op, SliceInfo &info) const {
     if (auto sliceOp = getSliceOp(op.getSlice())) {
       if (hasProjectedSlice(sliceOp)) {
         info.hasProjectedSlice = true;
-        info.projectionIsImaginary = sliceProjectionIsImaginary(sliceOp);
+        info.projectedSliceStart = getProjectedSliceStartIfConstant(sliceOp);
       }
       auto triples = sliceOp.getTriples();
       info.sliceVec.append(triples.begin(), triples.end());
@@ -641,19 +640,30 @@ FIRToMemRef::convertArrayCoorOp(Operation *memOp, fir::ArrayCoorOp arrayCoorOp,
     collectSliceInfoFrom(embox, sliceInfo);
   else if (auto rebox = firMemref.getDefiningOp<fir::ReboxOp>())
     collectSliceInfoFrom(rebox, sliceInfo);
-  if (sliceInfo.projectionIsImaginary) {
-    auto srcTy = cast<MemRefType>((*converted).getType());
-    auto complexTy = cast<mlir::ComplexType>(srcTy.getElementType());
-    SmallVector<int64_t> shape(srcTy.getShape());
-    shape.push_back(2);
-    Value compMemref =
-        fir::ConvertOp::create(
-            rewriter, loc, MemRefType::get(shape, complexTy.getElementType()),
-            *converted)
-            .getResult();
-    indices.push_back(arith::ConstantIndexOp::create(
-        rewriter, loc, *sliceInfo.projectionIsImaginary ? 1 : 0));
-    return std::pair{compMemref, indices};
+  auto srcTy = cast<MemRefType>((*converted).getType());
+  if (sliceInfo.hasProjectedSlice) {
+    if (auto complexTy = dyn_cast<mlir::ComplexType>(srcTy.getElementType())) {
+      if (!sliceInfo.projectedSliceStart ||
+          (*sliceInfo.projectedSliceStart != 0 &&
+           *sliceInfo.projectedSliceStart != 1)) {
+        LLVM_DEBUG(
+            llvm::dbgs()
+            << "FIRToMemRef: projected complex slice selector must be constant "
+               "0 (real) or 1 (imaginary), bailing out of conversion\n");
+        return failure();
+      }
+      auto projection = *sliceInfo.projectedSliceStart;
+      SmallVector<int64_t> shape(srcTy.getShape());
+      shape.push_back(2);
+      Value compMemref =
+          fir::ConvertOp::create(
+              rewriter, loc, MemRefType::get(shape, complexTy.getElementType()),
+              *converted)
+              .getResult();
+      indices.push_back(
+          arith::ConstantIndexOp::create(rewriter, loc, projection));
+      return std::pair{compMemref, indices};
+    }
   }
 
   // Static shape does not imply contiguous layout for descriptor-backed



More information about the flang-commits mailing list