[flang-commits] [flang] [flang][FIRToMemRef] [flang][fir-to-memref] Lower complex projected slices via memref<...x2xT> reinterpretation (PR #196123)
Susan Tan ス-ザン タン via flang-commits
flang-commits at lists.llvm.org
Fri May 8 08:21:34 PDT 2026
https://github.com/SusanTan updated https://github.com/llvm/llvm-project/pull/196123
>From a25881489a6a90f3de167f72fb54fae8d334aa7e Mon Sep 17 00:00:00 2001
From: Susan Tan <zujunt at nvidia.com>
Date: Wed, 6 May 2026 10:13:56 -0700
Subject: [PATCH 1/5] add complex implementation
---
.../flang/Optimizer/Transforms/Passes.td | 3 +-
.../lib/Optimizer/Transforms/FIRToMemRef.cpp | 127 +++++++++++----
.../FIRToMemRef/slice-projected.mlir | 150 +++++++++++++-----
3 files changed, 204 insertions(+), 76 deletions(-)
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index e107672adf907..93813dcedb045 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -221,7 +221,8 @@ def FIRToMemRef : Pass<"fir-to-memref", "::mlir::func::FuncOp"> {
Lower FIR memory operations (`fir.alloca`, `fir.load`, `fir.store`, 'fir.array_coor', and etc.) to MLIR's MemRef core dialect.
}];
let dependentDialects = ["fir::FIROpsDialect", "mlir::arith::ArithDialect",
- "mlir::memref::MemRefDialect"];
+ "mlir::memref::MemRefDialect",
+ "mlir::complex::ComplexDialect"];
}
// This needs to be a "mlir::ModuleOp" pass, because we are creating debug for
diff --git a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
index ec58d6f3f1447..316641813c37a 100644
--- a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
+++ b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
@@ -44,6 +44,7 @@
#include "flang/Optimizer/Transforms/FIRToMemRefTypeConverter.h"
#include "flang/Optimizer/Transforms/Passes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
@@ -140,6 +141,8 @@ class FIRToMemRef : public fir::impl::FIRToMemRefBase<FIRToMemRef> {
bool memrefIsOptional(Operation *) const;
+ std::optional<bool> complexProjectionOf(Value firMemref) const;
+
Value canonicalizeIndex(Value, PatternRewriter &) const;
// Logical section information used by FIRToMemRef. For projected slices, the
@@ -150,6 +153,7 @@ class FIRToMemRef : public fir::impl::FIRToMemRefBase<FIRToMemRef> {
SmallVector<Value> shiftVec;
SmallVector<Value> sliceVec;
bool hasProjectedSlice = false;
+ std::optional<bool> projectionIsImaginary; // set when hasProjectedSlice
};
template <typename OpTy>
@@ -171,6 +175,17 @@ class FIRToMemRef : public fir::impl::FIRToMemRefBase<FIRToMemRef> {
return sliceOp && !sliceOp.getFields().empty();
}
+ // Returns true for imaginary, false for real, nullopt if not a constant.
+ static std::optional<bool> sliceProjectionIsImaginary(fir::SliceOp sliceOp) {
+ auto fields = sliceOp.getFields();
+ if (fields.empty())
+ return std::nullopt;
+ if (auto cst = fields[0].getDefiningOp<arith::ConstantOp>())
+ if (auto attr = mlir::dyn_cast<IntegerAttr>(cst.getValueAttr()))
+ return attr.getInt() != 0;
+ return std::nullopt;
+ }
+
unsigned getRankFromEmbox(fir::EmboxOp embox) const {
auto memrefType = embox.getMemref().getType();
Type unwrappedType = fir::unwrapRefType(memrefType);
@@ -313,15 +328,12 @@ void FIRToMemRef::collectSliceInfoFrom(OpTy op, SliceInfo &info) const {
}
if (auto sliceOp = getSliceOp(op.getSlice())) {
- // A slice path changes the physical projection of the boxed entity (for
- // example, `complex -> real` for `%re`). Preserve shape/shift for logical
- // indexing, but do not treat the triplets alone as layout information.
if (hasProjectedSlice(sliceOp)) {
info.hasProjectedSlice = true;
- } else {
- auto triples = sliceOp.getTriples();
- info.sliceVec.append(triples.begin(), triples.end());
+ info.projectionIsImaginary = sliceProjectionIsImaginary(sliceOp);
}
+ auto triples = sliceOp.getTriples();
+ info.sliceVec.append(triples.begin(), triples.end());
}
}
}
@@ -473,9 +485,6 @@ FIRToMemRef::getMemrefIndices(fir::ArrayCoorOp arrayCoorOp, Operation *memref,
rank = getRankFromEmbox(embox);
}
- // Projected boxed slices leave `sliceVec` empty on purpose: indices are
- // computed in the logical section coordinate space, while stride/base come
- // later from the box descriptor.
SmallVector<Value> &shiftVec = sliceInfo.shiftVec;
SmallVector<Value> &sliceVec = sliceInfo.sliceVec;
SmallVector<Value> sliceLbs, sliceStrides;
@@ -626,6 +635,20 @@ FIRToMemRef::convertArrayCoorOp(Operation *memOp, fir::ArrayCoorOp arrayCoorOp,
bool isDescriptor = mlir::isa<fir::BaseBoxType>(firMemref.getType()) ||
firMemref.getDefiningOp<fir::BoxAddrOp>() != nullptr;
+ // For complex projections, getFIRConvert uses embox.getMemref() directly,
+ // converting the underlying array to memref<N×complex<T>> and ignoring the
+ // projected box descriptor. The indices from getMemrefIndices are correct
+ // as-is; no descriptor rebuild is needed.
+ SliceInfo sliceInfo;
+ collectSliceInfoFrom(arrayCoorOp, sliceInfo);
+ if (auto embox = firMemref.getDefiningOp<fir::EmboxOp>())
+ collectSliceInfoFrom(embox, sliceInfo);
+ else if (auto rebox = firMemref.getDefiningOp<fir::ReboxOp>())
+ collectSliceInfoFrom(rebox, sliceInfo);
+ if (sliceInfo.hasProjectedSlice &&
+ isa<mlir::ComplexType>(memRefTy.getElementType()))
+ return std::pair{*converted, indices};
+
// Static shape does not imply contiguous layout for descriptor-backed
// entities (e.g. boxed array sections with non-unit stride). Keep the
// reinterpret-cast path so descriptor strides are preserved.
@@ -633,7 +656,6 @@ FIRToMemRef::convertArrayCoorOp(Operation *memOp, fir::ArrayCoorOp arrayCoorOp,
return std::pair{*converted, indices};
unsigned rank = arrayCoorOp.getIndices().size();
-
if (auto embox = firMemref.getDefiningOp<fir::EmboxOp>())
rank = getRankFromEmbox(embox);
@@ -642,29 +664,17 @@ FIRToMemRef::convertArrayCoorOp(Operation *memOp, fir::ArrayCoorOp arrayCoorOp,
SmallVector<Value> strides;
strides.reserve(rank);
- SliceInfo sliceInfo;
- collectSliceInfoFrom(arrayCoorOp, sliceInfo);
-
- Value box = firMemref;
- if (!isa<BlockArgument>(firMemref)) {
- if (auto embox = firMemref.getDefiningOp<fir::EmboxOp>()) {
- collectSliceInfoFrom(embox, sliceInfo);
- } else if (auto rebox = firMemref.getDefiningOp<fir::ReboxOp>()) {
- collectSliceInfoFrom(rebox, sliceInfo);
- }
- }
-
SmallVector<Value> &shapeVec = sliceInfo.shapeVec;
if (sliceInfo.hasProjectedSlice || shapeVec.empty()) {
// Projected slices carry their physical layout in the descriptor. Rebuild
// the MemRef view from box metadata instead of from slice triplets.
auto boxElementSize =
- fir::BoxEleSizeOp::create(rewriter, loc, indexTy, box);
+ fir::BoxEleSizeOp::create(rewriter, loc, indexTy, firMemref);
for (unsigned i = 0; i < rank; ++i) {
Value dim = arith::ConstantIndexOp::create(rewriter, loc, rank - i - 1);
auto boxDims = fir::BoxDimsOp::create(rewriter, loc, indexTy, indexTy,
- indexTy, box, dim);
+ indexTy, firMemref, dim);
Value extent = boxDims->getResult(1);
sizes.push_back(castTypeToIndexType(extent, rewriter));
@@ -712,6 +722,41 @@ FIRToMemRef::convertArrayCoorOp(Operation *memOp, fir::ArrayCoorOp arrayCoorOp,
return std::pair{result, indices};
}
+static Value complexComponent(PatternRewriter &rewriter, Location loc,
+ Type elemTy, Value complexVal, bool isImaginary) {
+ return isImaginary ? complex::ImOp::create(rewriter, loc, elemTy, complexVal)
+ .getResult()
+ : complex::ReOp::create(rewriter, loc, elemTy, complexVal)
+ .getResult();
+}
+
+static Value complexRMWValue(PatternRewriter &rewriter, Location loc,
+ mlir::ComplexType complexTy, Value converted,
+ ValueRange indices, Value value,
+ bool isImaginary) {
+ Type elemTy = complexTy.getElementType();
+ Value old = memref::LoadOp::create(rewriter, loc, converted, indices);
+ Value other = complexComponent(rewriter, loc, elemTy, old, !isImaginary);
+ Value re = isImaginary ? other : value;
+ Value im = isImaginary ? value : other;
+ return complex::CreateOp::create(rewriter, loc, complexTy, re, im);
+}
+
+/// If \p firMemref is defined by a fir.array_coor that indexes a
+/// complex-component projection (z%re / z%im), return false for the real
+/// component or true for the imaginary component.
+/// Returns std::nullopt when the IR does not match the projection pattern.
+std::optional<bool> FIRToMemRef::complexProjectionOf(Value firMemref) const {
+ auto arrayCoor = firMemref.getDefiningOp<fir::ArrayCoorOp>();
+ if (!arrayCoor)
+ return std::nullopt;
+ SliceInfo info;
+ collectSliceInfoFrom(arrayCoor, info);
+ if (auto embox = arrayCoor.getMemref().getDefiningOp<fir::EmboxOp>())
+ collectSliceInfoFrom(embox, info);
+ return info.projectionIsImaginary;
+}
+
FailureOr<Value>
FIRToMemRef::getFIRConvert(Operation *memOp, Operation *op,
PatternRewriter &rewriter,
@@ -791,9 +836,7 @@ FIRToMemRef::getFIRConvert(Operation *memOp, Operation *op,
"the same, bailing out of conversion\n");
return failure();
}
- // Keep `box_addr` on the projected box so the descriptor remains the
- // source of truth for projected element type and stride.
- if (!projectedSlice && embox.getSlice() &&
+ if (embox.getSlice() &&
embox.getSlice().getDefiningOp<fir::SliceOp>()) {
Type originalType = embox.getMemref().getType();
basePtr = embox.getMemref();
@@ -1096,8 +1139,13 @@ void FIRToMemRef::rewriteLoadOp(fir::LoadOp load, PatternRewriter &rewriter,
loadOp.dump(); assert(succeeded(verify(loadOp))));
if (loadOp.getType() != originalType) {
- Value castVal =
- createTypeConversion(rewriter, loadOp.getLoc(), originalType, loadOp);
+ // z%re / z%im: extract the scalar component; otherwise type-convert.
+ auto isComplex = complexProjectionOf(firMemref);
+ Value castVal = isComplex
+ ? complexComponent(rewriter, loadOp.getLoc(),
+ originalType, loadOp, *isComplex)
+ : createTypeConversion(rewriter, loadOp.getLoc(),
+ originalType, loadOp);
loadOp.getResult().replaceAllUsesExcept(castVal, castVal.getDefiningOp());
}
@@ -1135,9 +1183,26 @@ void FIRToMemRef::rewriteStoreOp(fir::StoreOp store, PatternRewriter &rewriter,
value =
createTypeConversion(rewriter, store.getLoc(), convertedType, value);
- Attribute attr = (store.getOperation())->getAttr("tbaa");
+ // For a complex-component projection (z%re / z%im), memref holds complex<T>
+ // but the stored value is scalar T. Read-modify-write to preserve the
+ // untouched component:
+ // %old = memref.load %mem[%idx] : memref<N×complex<T>>
+ // %re = complex.re %old : complex<T> // (or %im for imaginary)
+ // %new = complex.create %re, %val : complex<T>
+ // memref.store %new, %mem[%idx] : memref<N×complex<T>>
+ Value storeVal = value;
+ if (auto isComplex = complexProjectionOf(firMemref)) {
+ auto complexTy = dyn_cast<mlir::ComplexType>(
+ dyn_cast<MemRefType>(converted.getType()).getElementType());
+ assert(complexTy &&
+ "complex projection converted memref must hold complex<T>");
+ storeVal = complexRMWValue(rewriter, store.getLoc(), complexTy, converted,
+ indices, value, *isComplex);
+ }
+
+ Attribute attr = store.getOperation()->getAttr("tbaa");
memref::StoreOp storeOp = rewriter.replaceOpWithNewOp<memref::StoreOp>(
- store, value, converted, indices);
+ store, storeVal, converted, indices);
if (attr)
storeOp.getOperation()->setAttr("tbaa", attr);
diff --git a/flang/test/Transforms/FIRToMemRef/slice-projected.mlir b/flang/test/Transforms/FIRToMemRef/slice-projected.mlir
index 7b0fbdf748173..1ef5305099377 100644
--- a/flang/test/Transforms/FIRToMemRef/slice-projected.mlir
+++ b/flang/test/Transforms/FIRToMemRef/slice-projected.mlir
@@ -2,12 +2,9 @@
// Tests for fir.slice with a path component (projected component slice).
// A projected slice changes the element type of the boxed view, e.g.
-// z%re projects complex<f32> -> f32. The layout (strides / base address)
-// must come from the projected box descriptor, NOT from reconstructing the
-// triplets, because memref.reinterpret_cast requires the same element type
-// on both sides and the triplet strides are in storage-element units
-// (complex<f32>) while the MemRef strides must be in projected-element units
-// (f32).
+// z%re projects complex<f32> -> f32. The pass bypasses the box descriptor
+// and uses the underlying complex array directly, extracting the component
+// with complex.re / complex.im at the load/store site.
//
// Derived from:
// complex, target :: z(4) = 0.
@@ -16,40 +13,18 @@
// r = r + z(4:1:-1)%re
// ----------------------------------------------------------------------------
-// Forward projected slice: z(1:4:1)%re
-// The slice path %c0 projects complex<f32> -> f32 (real part).
-// Expected lowering:
-// - fir.box_addr on the projected box (!fir.box<!fir.array<4xf32>>)
-// - fir.convert to memref<4xf32> (NOT to memref<4xcomplex<f32>>)
-// - index = i - 1 (1-based, no triplet arithmetic)
-// - strides from fir.box_dims / fir.box_elesize on the projected box
+// Forward projected slice load: z(1:4:1)%re
+// The fir.convert appears inside the loop body (insertion point tracks the
+// array_coor inside the loop). elemIdx = (i - 1) * step + (lb - 1) = i - 1
+// for step=1, lb=1. Indices are reversed (col-major → row-major) but for 1D
+// that is a no-op.
// ----------------------------------------------------------------------------
// CHECK-LABEL: func.func @projected_slice_fwd
-// CHECK: [[C1:%.*]] = arith.constant 1 : index
-// CHECK: [[C4:%.*]] = arith.constant 4 : index
-// CHECK: [[C0:%.*]] = arith.constant 0 : index
-// CHECK: [[SHAPE:%.*]] = fir.shape [[C4]] : (index) -> !fir.shape<1>
-// CHECK: [[SLICE:%.*]] = fir.slice [[C1]], [[C4]], [[C1]] path [[C0]] : (index, index, index, index) -> !fir.slice<1>
-// CHECK: [[EMBOX:%.*]] = fir.embox %arg0([[SHAPE]]) {{\[}}[[SLICE]]{{\]}} : (!fir.ref<!fir.array<4xcomplex<f32>>>, !fir.shape<1>, !fir.slice<1>) -> !fir.box<!fir.array<4xf32>>
-// CHECK: fir.do_loop [[I:%.*]] = [[C1]] to [[C4]] step [[C1]] unordered {
-// Projected box_addr gives f32 pointer, not complex<f32>.
-// CHECK: [[BOXADDR:%.*]] = fir.box_addr [[EMBOX]] : (!fir.box<!fir.array<4xf32>>) -> !fir.ref<!fir.array<4xf32>>
-// CHECK: [[CONVERT:%.*]] = fir.convert [[BOXADDR]] : (!fir.ref<!fir.array<4xf32>>) -> memref<4xf32>
-// Index: i-1 (1-based). The lowering emits: delta=i-1, scaled=delta*1,
-// offset=1-1=0, finalIdx=scaled+offset. The addi result is what feeds the load.
-// CHECK: [[C1_0:%.*]] = arith.constant 1 : index
-// CHECK: [[DELTA:%.*]] = arith.subi [[I]], [[C1_0]] : index
-// CHECK: [[SCALED:%.*]] = arith.muli [[DELTA]], [[C1_0]] : index
-// CHECK: [[OFFSET:%.*]] = arith.subi [[C1_0]], [[C1_0]] : index
-// CHECK: [[IDX:%.*]] = arith.addi [[SCALED]], [[OFFSET]] : index
-// Layout: extent and stride come from the projected box descriptor.
-// CHECK: [[ELE:%.*]] = fir.box_elesize [[EMBOX]] : (!fir.box<!fir.array<4xf32>>) -> index
-// CHECK: [[C0_0:%.*]] = arith.constant 0 : index
-// CHECK: [[DIMS:%.*]]:3 = fir.box_dims [[EMBOX]], [[C0_0]] : (!fir.box<!fir.array<4xf32>>, index) -> (index, index, index)
-// CHECK: [[STRIDE:%.*]] = arith.divsi [[DIMS]]#2, [[ELE]] : index
-// CHECK: [[C0_1:%.*]] = arith.constant 0 : index
-// CHECK: [[VIEW:%.*]] = memref.reinterpret_cast [[CONVERT]] to offset: {{\[}}[[C0_1]]{{\]}}, sizes: {{\[}}[[DIMS]]#1{{\]}}, strides: {{\[}}[[STRIDE]]{{\]}} : memref<4xf32> to memref<?xf32, strided<[?], offset: ?>>
-// CHECK: memref.load [[VIEW]]{{\[}}[[IDX]]{{\]}} : memref<?xf32, strided<[?], offset: ?>>
+// CHECK: fir.do_loop [[I:%.*]] =
+// CHECK: [[MEMREF:%.*]] = fir.convert %arg0 : (!fir.ref<!fir.array<4xcomplex<f32>>>) -> memref<4xcomplex<f32>>
+// CHECK: [[IDX:%.*]] = arith.addi
+// CHECK: [[CVAL:%.*]] = memref.load [[MEMREF]]{{\[}}[[IDX]]{{\]}} : memref<4xcomplex<f32>>
+// CHECK: complex.re [[CVAL]] : complex<f32>
func.func @projected_slice_fwd(%arg0: !fir.ref<!fir.array<4xcomplex<f32>>>) {
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
@@ -64,16 +39,103 @@ func.func @projected_slice_fwd(%arg0: !fir.ref<!fir.array<4xcomplex<f32>>>) {
return
}
+// ----------------------------------------------------------------------------
+// Backward projected slice load: z(4:1:-1)%re
+// step = -1, lb = 4 → elemIdx = (i - 1) * (-1) + (4 - 1) = 3 - (i-1)
+// ----------------------------------------------------------------------------
+// CHECK-LABEL: func.func @projected_slice_bwd
+// CHECK: fir.do_loop [[I:%.*]] =
+// CHECK: [[MEMREF:%.*]] = fir.convert %arg0 : (!fir.ref<!fir.array<4xcomplex<f32>>>) -> memref<4xcomplex<f32>>
+// CHECK: [[IDX:%.*]] = arith.addi
+// CHECK: [[CVAL:%.*]] = memref.load [[MEMREF]]{{\[}}[[IDX]]{{\]}} : memref<4xcomplex<f32>>
+// CHECK: complex.re [[CVAL]] : complex<f32>
+func.func @projected_slice_bwd(%arg0: !fir.ref<!fir.array<4xcomplex<f32>>>) {
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %cm1 = arith.constant -1 : index
+ %c0 = arith.constant 0 : index
+ %shape = fir.shape %c4 : (index) -> !fir.shape<1>
+ %slice = fir.slice %c4, %c1, %cm1 path %c0 : (index, index, index, index) -> !fir.slice<1>
+ %embox = fir.embox %arg0(%shape) [%slice] : (!fir.ref<!fir.array<4xcomplex<f32>>>, !fir.shape<1>, !fir.slice<1>) -> !fir.box<!fir.array<4xf32>>
+ fir.do_loop %i = %c1 to %c4 step %c1 unordered {
+ %coor = fir.array_coor %embox %i : (!fir.box<!fir.array<4xf32>>, index) -> !fir.ref<f32>
+ %val = fir.load %coor : !fir.ref<f32>
+ }
+ return
+}
+
+// ----------------------------------------------------------------------------
+// Imaginary component store: z(1:4:1)%im = val
+// Read-modify-write: load complex, update imaginary, store back.
+// ----------------------------------------------------------------------------
+// CHECK-LABEL: func.func @projected_slice_store_im
+// CHECK: fir.do_loop [[I:%.*]] =
+// CHECK: [[MEMREF:%.*]] = fir.convert %arg0 : (!fir.ref<!fir.array<4xcomplex<f32>>>) -> memref<4xcomplex<f32>>
+// CHECK: [[IDX:%.*]] = arith.addi
+// CHECK: [[OLD:%.*]] = memref.load [[MEMREF]]{{\[}}[[IDX]]{{\]}} : memref<4xcomplex<f32>>
+// CHECK: [[RE:%.*]] = complex.re [[OLD]] : complex<f32>
+// CHECK: [[NEW:%.*]] = complex.create [[RE]], %arg1 : complex<f32>
+// CHECK: memref.store [[NEW]], [[MEMREF]]{{\[}}[[IDX]]{{\]}} : memref<4xcomplex<f32>>
+func.func @projected_slice_store_im(%arg0: !fir.ref<!fir.array<4xcomplex<f32>>>,
+ %arg1: f32) {
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %c1_im = arith.constant 1 : index // imaginary component index
+ %shape = fir.shape %c4 : (index) -> !fir.shape<1>
+ %slice = fir.slice %c1, %c4, %c1 path %c1_im : (index, index, index, index) -> !fir.slice<1>
+ %embox = fir.embox %arg0(%shape) [%slice] : (!fir.ref<!fir.array<4xcomplex<f32>>>, !fir.shape<1>, !fir.slice<1>) -> !fir.box<!fir.array<4xf32>>
+ fir.do_loop %i = %c1 to %c4 step %c1 unordered {
+ %coor = fir.array_coor %embox %i : (!fir.box<!fir.array<4xf32>>, index) -> !fir.ref<f32>
+ fir.store %arg1 to %coor : !fir.ref<f32>
+ }
+ return
+}
+
+// ----------------------------------------------------------------------------
+// 2-D boxed projected slice load: z(1:2:1, 1:3:1)%re
+// Storage: !fir.array<2x3xcomplex<f32>>
+//
+// convertMemrefType reverses Fortran column-major extents to MLIR row-major:
+// !fir.ref<!fir.array<2x3xcomplex<f32>>> → memref<3x2xcomplex<f32>>
+//
+// Per-dimension element index (0-based, column-major):
+// elemIdx_i = (i-1)*1 + (1-1) = i-1 (Fortran dim 1, size 2)
+// elemIdx_j = (j-1)*1 + (1-1) = j-1 (Fortran dim 2, size 3)
+//
+// After reversing for MLIR row-major access:
+// memref.load [elemIdx_j, elemIdx_i]
+// ----------------------------------------------------------------------------
+// CHECK-LABEL: func.func @projected_slice_2d
+// CHECK: fir.do_loop [[I:%.*]] =
+// CHECK: fir.do_loop [[J:%.*]] =
+// CHECK: [[MEMREF:%.*]] = fir.convert %arg0 : (!fir.ref<!fir.array<2x3xcomplex<f32>>>) -> memref<3x2xcomplex<f32>>
+// CHECK: [[IDX_I:%.*]] = arith.addi
+// CHECK: [[IDX_J:%.*]] = arith.addi
+// CHECK: [[CVAL:%.*]] = memref.load [[MEMREF]]{{\[}}[[IDX_J]], [[IDX_I]]{{\]}} : memref<3x2xcomplex<f32>>
+// CHECK: complex.re [[CVAL]] : complex<f32>
+func.func @projected_slice_2d(%arg0: !fir.ref<!fir.array<2x3xcomplex<f32>>>) {
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+ %c0 = arith.constant 0 : index
+ %shape = fir.shape %c2, %c3 : (index, index) -> !fir.shape<2>
+ %slice = fir.slice %c1, %c2, %c1, %c1, %c3, %c1 path %c0 : (index, index, index, index, index, index, index) -> !fir.slice<2>
+ %embox = fir.embox %arg0(%shape) [%slice] : (!fir.ref<!fir.array<2x3xcomplex<f32>>>, !fir.shape<2>, !fir.slice<2>) -> !fir.box<!fir.array<2x3xf32>>
+ fir.do_loop %i = %c1 to %c2 step %c1 unordered {
+ fir.do_loop %j = %c1 to %c3 step %c1 unordered {
+ %coor = fir.array_coor %embox %i, %j : (!fir.box<!fir.array<2x3xf32>>, index, index) -> !fir.ref<f32>
+ %val = fir.load %coor : !fir.ref<f32>
+ }
+ }
+ return
+}
+
// ----------------------------------------------------------------------------
// Derived-type component projection: a%x where a : TYPE{x:f64, y:complex<f64>}
//
// This is NOT a complex projection — the storage element is the derived type T,
-// not complex<T>. FIRToMemRef cannot safely compute element-unit strides via
-// divsi(byte_stride, elesize) because sizeof(T)/sizeof(component) may not be an
-// integer (e.g. sizeof(T)=24, sizeof(complex<f64>)=16 -> 1.5, truncated to 1).
-//
-// The pass must leave fir.array_coor and fir.store/fir.load unconverted;
-// downstream FIR-to-LLVM lowering handles them correctly via the descriptor.
+// not complex<T>. FIRToMemRef cannot safely handle this; downstream
+// FIR-to-LLVM lowering handles it correctly via the descriptor.
//
// CHECK-LABEL: func.func @derived_component_not_projected
// The fir.array_coor must survive (not be erased).
>From 4405fa86f5d027fbea99fa876c5787592d10e940 Mon Sep 17 00:00:00 2001
From: Susan Tan <zujunt at nvidia.com>
Date: Wed, 6 May 2026 10:32:14 -0700
Subject: [PATCH 2/5] tweak
---
flang/lib/Optimizer/Transforms/FIRToMemRef.cpp | 4 ----
1 file changed, 4 deletions(-)
diff --git a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
index 316641813c37a..e5acc48131b62 100644
--- a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
+++ b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
@@ -742,10 +742,6 @@ static Value complexRMWValue(PatternRewriter &rewriter, Location loc,
return complex::CreateOp::create(rewriter, loc, complexTy, re, im);
}
-/// If \p firMemref is defined by a fir.array_coor that indexes a
-/// complex-component projection (z%re / z%im), return false for the real
-/// component or true for the imaginary component.
-/// Returns std::nullopt when the IR does not match the projection pattern.
std::optional<bool> FIRToMemRef::complexProjectionOf(Value firMemref) const {
auto arrayCoor = firMemref.getDefiningOp<fir::ArrayCoorOp>();
if (!arrayCoor)
>From cfb8260ed2b5fdf8b21792fcc5a4ce0c242385e6 Mon Sep 17 00:00:00 2001
From: Susan Tan <zujunt at nvidia.com>
Date: Wed, 6 May 2026 11:51:56 -0700
Subject: [PATCH 3/5] flatten complex to x2
---
.../flang/Optimizer/Transforms/Passes.td | 3 +-
.../lib/Optimizer/Transforms/FIRToMemRef.cpp | 86 +++++--------------
.../FIRToMemRef/slice-projected.mlir | 37 +++++---
3 files changed, 44 insertions(+), 82 deletions(-)
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index 93813dcedb045..e107672adf907 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -221,8 +221,7 @@ def FIRToMemRef : Pass<"fir-to-memref", "::mlir::func::FuncOp"> {
Lower FIR memory operations (`fir.alloca`, `fir.load`, `fir.store`, 'fir.array_coor', and etc.) to MLIR's MemRef core dialect.
}];
let dependentDialects = ["fir::FIROpsDialect", "mlir::arith::ArithDialect",
- "mlir::memref::MemRefDialect",
- "mlir::complex::ComplexDialect"];
+ "mlir::memref::MemRefDialect"];
}
// This needs to be a "mlir::ModuleOp" pass, because we are creating debug for
diff --git a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
index e5acc48131b62..f0c3873366b9d 100644
--- a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
+++ b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
@@ -44,7 +44,6 @@
#include "flang/Optimizer/Transforms/FIRToMemRefTypeConverter.h"
#include "flang/Optimizer/Transforms/Passes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
@@ -141,8 +140,6 @@ class FIRToMemRef : public fir::impl::FIRToMemRefBase<FIRToMemRef> {
bool memrefIsOptional(Operation *) const;
- std::optional<bool> complexProjectionOf(Value firMemref) const;
-
Value canonicalizeIndex(Value, PatternRewriter &) const;
// Logical section information used by FIRToMemRef. For projected slices, the
@@ -635,19 +632,29 @@ FIRToMemRef::convertArrayCoorOp(Operation *memOp, fir::ArrayCoorOp arrayCoorOp,
bool isDescriptor = mlir::isa<fir::BaseBoxType>(firMemref.getType()) ||
firMemref.getDefiningOp<fir::BoxAddrOp>() != nullptr;
- // For complex projections, getFIRConvert uses embox.getMemref() directly,
- // converting the underlying array to memref<N×complex<T>> and ignoring the
- // projected box descriptor. The indices from getMemrefIndices are correct
- // as-is; no descriptor rebuild is needed.
+ // For complex projections, reinterpret memref<d0×...×complex<T>> as
+ // memref<d0×...×2×T> and append the component index (0=re, 1=im) so that
+ // each load/store touches exactly sizeof(T) bytes.
SliceInfo sliceInfo;
collectSliceInfoFrom(arrayCoorOp, sliceInfo);
if (auto embox = firMemref.getDefiningOp<fir::EmboxOp>())
collectSliceInfoFrom(embox, sliceInfo);
else if (auto rebox = firMemref.getDefiningOp<fir::ReboxOp>())
collectSliceInfoFrom(rebox, sliceInfo);
- if (sliceInfo.hasProjectedSlice &&
- isa<mlir::ComplexType>(memRefTy.getElementType()))
- return std::pair{*converted, indices};
+ if (sliceInfo.projectionIsImaginary) {
+ auto srcTy = cast<MemRefType>((*converted).getType());
+ auto complexTy = cast<mlir::ComplexType>(srcTy.getElementType());
+ SmallVector<int64_t> shape(srcTy.getShape());
+ shape.push_back(2);
+ Value compMemref =
+ fir::ConvertOp::create(
+ rewriter, loc, MemRefType::get(shape, complexTy.getElementType()),
+ *converted)
+ .getResult();
+ indices.push_back(arith::ConstantIndexOp::create(
+ rewriter, loc, *sliceInfo.projectionIsImaginary ? 1 : 0));
+ return std::pair{compMemref, indices};
+ }
// Static shape does not imply contiguous layout for descriptor-backed
// entities (e.g. boxed array sections with non-unit stride). Keep the
@@ -722,37 +729,6 @@ FIRToMemRef::convertArrayCoorOp(Operation *memOp, fir::ArrayCoorOp arrayCoorOp,
return std::pair{result, indices};
}
-static Value complexComponent(PatternRewriter &rewriter, Location loc,
- Type elemTy, Value complexVal, bool isImaginary) {
- return isImaginary ? complex::ImOp::create(rewriter, loc, elemTy, complexVal)
- .getResult()
- : complex::ReOp::create(rewriter, loc, elemTy, complexVal)
- .getResult();
-}
-
-static Value complexRMWValue(PatternRewriter &rewriter, Location loc,
- mlir::ComplexType complexTy, Value converted,
- ValueRange indices, Value value,
- bool isImaginary) {
- Type elemTy = complexTy.getElementType();
- Value old = memref::LoadOp::create(rewriter, loc, converted, indices);
- Value other = complexComponent(rewriter, loc, elemTy, old, !isImaginary);
- Value re = isImaginary ? other : value;
- Value im = isImaginary ? value : other;
- return complex::CreateOp::create(rewriter, loc, complexTy, re, im);
-}
-
-std::optional<bool> FIRToMemRef::complexProjectionOf(Value firMemref) const {
- auto arrayCoor = firMemref.getDefiningOp<fir::ArrayCoorOp>();
- if (!arrayCoor)
- return std::nullopt;
- SliceInfo info;
- collectSliceInfoFrom(arrayCoor, info);
- if (auto embox = arrayCoor.getMemref().getDefiningOp<fir::EmboxOp>())
- collectSliceInfoFrom(embox, info);
- return info.projectionIsImaginary;
-}
-
FailureOr<Value>
FIRToMemRef::getFIRConvert(Operation *memOp, Operation *op,
PatternRewriter &rewriter,
@@ -1135,13 +1111,8 @@ void FIRToMemRef::rewriteLoadOp(fir::LoadOp load, PatternRewriter &rewriter,
loadOp.dump(); assert(succeeded(verify(loadOp))));
if (loadOp.getType() != originalType) {
- // z%re / z%im: extract the scalar component; otherwise type-convert.
- auto isComplex = complexProjectionOf(firMemref);
- Value castVal = isComplex
- ? complexComponent(rewriter, loadOp.getLoc(),
- originalType, loadOp, *isComplex)
- : createTypeConversion(rewriter, loadOp.getLoc(),
- originalType, loadOp);
+ Value castVal =
+ createTypeConversion(rewriter, loadOp.getLoc(), originalType, loadOp);
loadOp.getResult().replaceAllUsesExcept(castVal, castVal.getDefiningOp());
}
@@ -1179,26 +1150,9 @@ void FIRToMemRef::rewriteStoreOp(fir::StoreOp store, PatternRewriter &rewriter,
value =
createTypeConversion(rewriter, store.getLoc(), convertedType, value);
- // For a complex-component projection (z%re / z%im), memref holds complex<T>
- // but the stored value is scalar T. Read-modify-write to preserve the
- // untouched component:
- // %old = memref.load %mem[%idx] : memref<N×complex<T>>
- // %re = complex.re %old : complex<T> // (or %im for imaginary)
- // %new = complex.create %re, %val : complex<T>
- // memref.store %new, %mem[%idx] : memref<N×complex<T>>
- Value storeVal = value;
- if (auto isComplex = complexProjectionOf(firMemref)) {
- auto complexTy = dyn_cast<mlir::ComplexType>(
- dyn_cast<MemRefType>(converted.getType()).getElementType());
- assert(complexTy &&
- "complex projection converted memref must hold complex<T>");
- storeVal = complexRMWValue(rewriter, store.getLoc(), complexTy, converted,
- indices, value, *isComplex);
- }
-
Attribute attr = store.getOperation()->getAttr("tbaa");
memref::StoreOp storeOp = rewriter.replaceOpWithNewOp<memref::StoreOp>(
- store, storeVal, converted, indices);
+ store, value, converted, indices);
if (attr)
storeOp.getOperation()->setAttr("tbaa", attr);
diff --git a/flang/test/Transforms/FIRToMemRef/slice-projected.mlir b/flang/test/Transforms/FIRToMemRef/slice-projected.mlir
index 1ef5305099377..7bd3bbd19f36e 100644
--- a/flang/test/Transforms/FIRToMemRef/slice-projected.mlir
+++ b/flang/test/Transforms/FIRToMemRef/slice-projected.mlir
@@ -3,8 +3,8 @@
// Tests for fir.slice with a path component (projected component slice).
// A projected slice changes the element type of the boxed view, e.g.
// z%re projects complex<f32> -> f32. The pass bypasses the box descriptor
-// and uses the underlying complex array directly, extracting the component
-// with complex.re / complex.im at the load/store site.
+// and reinterprets the underlying complex array as memref<...x2xf32>, then
+// appends the component index (0=re, 1=im) as the final memref index.
//
// Derived from:
// complex, target :: z(4) = 0.
@@ -23,8 +23,10 @@
// CHECK: fir.do_loop [[I:%.*]] =
// CHECK: [[MEMREF:%.*]] = fir.convert %arg0 : (!fir.ref<!fir.array<4xcomplex<f32>>>) -> memref<4xcomplex<f32>>
// CHECK: [[IDX:%.*]] = arith.addi
-// CHECK: [[CVAL:%.*]] = memref.load [[MEMREF]]{{\[}}[[IDX]]{{\]}} : memref<4xcomplex<f32>>
-// CHECK: complex.re [[CVAL]] : complex<f32>
+// CHECK: [[COMP:%.*]] = fir.convert [[MEMREF]] : (memref<4xcomplex<f32>>) -> memref<4x2xf32>
+// CHECK: [[C0:%.*]] = arith.constant 0
+// CHECK: memref.load [[COMP]]{{\[}}[[IDX]], [[C0]]] : memref<4x2xf32>
+// CHECK-NOT: complex.re
func.func @projected_slice_fwd(%arg0: !fir.ref<!fir.array<4xcomplex<f32>>>) {
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
@@ -47,8 +49,10 @@ func.func @projected_slice_fwd(%arg0: !fir.ref<!fir.array<4xcomplex<f32>>>) {
// CHECK: fir.do_loop [[I:%.*]] =
// CHECK: [[MEMREF:%.*]] = fir.convert %arg0 : (!fir.ref<!fir.array<4xcomplex<f32>>>) -> memref<4xcomplex<f32>>
// CHECK: [[IDX:%.*]] = arith.addi
-// CHECK: [[CVAL:%.*]] = memref.load [[MEMREF]]{{\[}}[[IDX]]{{\]}} : memref<4xcomplex<f32>>
-// CHECK: complex.re [[CVAL]] : complex<f32>
+// CHECK: [[COMP:%.*]] = fir.convert [[MEMREF]] : (memref<4xcomplex<f32>>) -> memref<4x2xf32>
+// CHECK: [[C0:%.*]] = arith.constant 0
+// CHECK: memref.load [[COMP]]{{\[}}[[IDX]], [[C0]]] : memref<4x2xf32>
+// CHECK-NOT: complex.re
func.func @projected_slice_bwd(%arg0: !fir.ref<!fir.array<4xcomplex<f32>>>) {
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
@@ -66,16 +70,18 @@ func.func @projected_slice_bwd(%arg0: !fir.ref<!fir.array<4xcomplex<f32>>>) {
// ----------------------------------------------------------------------------
// Imaginary component store: z(1:4:1)%im = val
-// Read-modify-write: load complex, update imaginary, store back.
+// Direct scalar store — no read-modify-write, no complex.create.
// ----------------------------------------------------------------------------
// CHECK-LABEL: func.func @projected_slice_store_im
// CHECK: fir.do_loop [[I:%.*]] =
// CHECK: [[MEMREF:%.*]] = fir.convert %arg0 : (!fir.ref<!fir.array<4xcomplex<f32>>>) -> memref<4xcomplex<f32>>
// CHECK: [[IDX:%.*]] = arith.addi
-// CHECK: [[OLD:%.*]] = memref.load [[MEMREF]]{{\[}}[[IDX]]{{\]}} : memref<4xcomplex<f32>>
-// CHECK: [[RE:%.*]] = complex.re [[OLD]] : complex<f32>
-// CHECK: [[NEW:%.*]] = complex.create [[RE]], %arg1 : complex<f32>
-// CHECK: memref.store [[NEW]], [[MEMREF]]{{\[}}[[IDX]]{{\]}} : memref<4xcomplex<f32>>
+// CHECK: [[COMP:%.*]] = fir.convert [[MEMREF]] : (memref<4xcomplex<f32>>) -> memref<4x2xf32>
+// CHECK: [[C1:%.*]] = arith.constant 1
+// CHECK: memref.store %arg1, [[COMP]]{{\[}}[[IDX]], [[C1]]] : memref<4x2xf32>
+// CHECK-NOT: complex.re
+// CHECK-NOT: complex.create
+// CHECK-NOT: memref.load
func.func @projected_slice_store_im(%arg0: !fir.ref<!fir.array<4xcomplex<f32>>>,
%arg1: f32) {
%c1 = arith.constant 1 : index
@@ -97,13 +103,15 @@ func.func @projected_slice_store_im(%arg0: !fir.ref<!fir.array<4xcomplex<f32>>>,
//
// convertMemrefType reverses Fortran column-major extents to MLIR row-major:
// !fir.ref<!fir.array<2x3xcomplex<f32>>> → memref<3x2xcomplex<f32>>
+// Reinterpret adds the component dimension:
+// memref<3x2xcomplex<f32>> → memref<3x2x2xf32>
//
// Per-dimension element index (0-based, column-major):
// elemIdx_i = (i-1)*1 + (1-1) = i-1 (Fortran dim 1, size 2)
// elemIdx_j = (j-1)*1 + (1-1) = j-1 (Fortran dim 2, size 3)
//
// After reversing for MLIR row-major access:
-// memref.load [elemIdx_j, elemIdx_i]
+// memref.load [elemIdx_j, elemIdx_i, 0]
// ----------------------------------------------------------------------------
// CHECK-LABEL: func.func @projected_slice_2d
// CHECK: fir.do_loop [[I:%.*]] =
@@ -111,8 +119,9 @@ func.func @projected_slice_store_im(%arg0: !fir.ref<!fir.array<4xcomplex<f32>>>,
// CHECK: [[MEMREF:%.*]] = fir.convert %arg0 : (!fir.ref<!fir.array<2x3xcomplex<f32>>>) -> memref<3x2xcomplex<f32>>
// CHECK: [[IDX_I:%.*]] = arith.addi
// CHECK: [[IDX_J:%.*]] = arith.addi
-// CHECK: [[CVAL:%.*]] = memref.load [[MEMREF]]{{\[}}[[IDX_J]], [[IDX_I]]{{\]}} : memref<3x2xcomplex<f32>>
-// CHECK: complex.re [[CVAL]] : complex<f32>
+// CHECK: [[COMP:%.*]] = fir.convert [[MEMREF]] : (memref<3x2xcomplex<f32>>) -> memref<3x2x2xf32>
+// CHECK: [[C0:%.*]] = arith.constant 0
+// CHECK: memref.load [[COMP]]{{\[}}[[IDX_J]], [[IDX_I]], [[C0]]] : memref<3x2x2xf32>
func.func @projected_slice_2d(%arg0: !fir.ref<!fir.array<2x3xcomplex<f32>>>) {
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
>From 62389d77af4afd0c5f3526fdfaf4388c8a6612ea Mon Sep 17 00:00:00 2001
From: Susan Tan <zujunt at nvidia.com>
Date: Wed, 6 May 2026 11:57:44 -0700
Subject: [PATCH 4/5] tweak
---
flang/test/Transforms/FIRToMemRef/slice-projected.mlir | 5 -----
1 file changed, 5 deletions(-)
diff --git a/flang/test/Transforms/FIRToMemRef/slice-projected.mlir b/flang/test/Transforms/FIRToMemRef/slice-projected.mlir
index 7bd3bbd19f36e..17af59086122c 100644
--- a/flang/test/Transforms/FIRToMemRef/slice-projected.mlir
+++ b/flang/test/Transforms/FIRToMemRef/slice-projected.mlir
@@ -26,7 +26,6 @@
// CHECK: [[COMP:%.*]] = fir.convert [[MEMREF]] : (memref<4xcomplex<f32>>) -> memref<4x2xf32>
// CHECK: [[C0:%.*]] = arith.constant 0
// CHECK: memref.load [[COMP]]{{\[}}[[IDX]], [[C0]]] : memref<4x2xf32>
-// CHECK-NOT: complex.re
func.func @projected_slice_fwd(%arg0: !fir.ref<!fir.array<4xcomplex<f32>>>) {
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
@@ -52,7 +51,6 @@ func.func @projected_slice_fwd(%arg0: !fir.ref<!fir.array<4xcomplex<f32>>>) {
// CHECK: [[COMP:%.*]] = fir.convert [[MEMREF]] : (memref<4xcomplex<f32>>) -> memref<4x2xf32>
// CHECK: [[C0:%.*]] = arith.constant 0
// CHECK: memref.load [[COMP]]{{\[}}[[IDX]], [[C0]]] : memref<4x2xf32>
-// CHECK-NOT: complex.re
func.func @projected_slice_bwd(%arg0: !fir.ref<!fir.array<4xcomplex<f32>>>) {
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
@@ -79,9 +77,6 @@ func.func @projected_slice_bwd(%arg0: !fir.ref<!fir.array<4xcomplex<f32>>>) {
// CHECK: [[COMP:%.*]] = fir.convert [[MEMREF]] : (memref<4xcomplex<f32>>) -> memref<4x2xf32>
// CHECK: [[C1:%.*]] = arith.constant 1
// CHECK: memref.store %arg1, [[COMP]]{{\[}}[[IDX]], [[C1]]] : memref<4x2xf32>
-// CHECK-NOT: complex.re
-// CHECK-NOT: complex.create
-// CHECK-NOT: memref.load
func.func @projected_slice_store_im(%arg0: !fir.ref<!fir.array<4xcomplex<f32>>>,
%arg1: f32) {
%c1 = arith.constant 1 : index
>From 1fc896dbb83570d46e8390dc27af670831f98250 Mon Sep 17 00:00:00 2001
From: Susan Tan <zujunt at nvidia.com>
Date: Fri, 8 May 2026 08:21:19 -0700
Subject: [PATCH 5/5] tweak
---
.../lib/Optimizer/Transforms/FIRToMemRef.cpp | 52 +++++++++++--------
1 file changed, 31 insertions(+), 21 deletions(-)
diff --git a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
index f0c3873366b9d..6fc9af500eb72 100644
--- a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
+++ b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
@@ -150,7 +150,8 @@ class FIRToMemRef : public fir::impl::FIRToMemRefBase<FIRToMemRef> {
SmallVector<Value> shiftVec;
SmallVector<Value> sliceVec;
bool hasProjectedSlice = false;
- std::optional<bool> projectionIsImaginary; // set when hasProjectedSlice
+ // Constant value of the first projected-slice field, if any.
+ std::optional<std::int64_t> projectedSliceStart;
};
template <typename OpTy>
@@ -172,15 +173,13 @@ class FIRToMemRef : public fir::impl::FIRToMemRefBase<FIRToMemRef> {
return sliceOp && !sliceOp.getFields().empty();
}
- // Returns true for imaginary, false for real, nullopt if not a constant.
- static std::optional<bool> sliceProjectionIsImaginary(fir::SliceOp sliceOp) {
+ // Returns the constant first projected-slice field, if available.
+ static std::optional<std::int64_t>
+ getProjectedSliceStartIfConstant(fir::SliceOp sliceOp) {
auto fields = sliceOp.getFields();
if (fields.empty())
return std::nullopt;
- if (auto cst = fields[0].getDefiningOp<arith::ConstantOp>())
- if (auto attr = mlir::dyn_cast<IntegerAttr>(cst.getValueAttr()))
- return attr.getInt() != 0;
- return std::nullopt;
+ return fir::getIntIfConstant(fields.front());
}
unsigned getRankFromEmbox(fir::EmboxOp embox) const {
@@ -327,7 +326,7 @@ void FIRToMemRef::collectSliceInfoFrom(OpTy op, SliceInfo &info) const {
if (auto sliceOp = getSliceOp(op.getSlice())) {
if (hasProjectedSlice(sliceOp)) {
info.hasProjectedSlice = true;
- info.projectionIsImaginary = sliceProjectionIsImaginary(sliceOp);
+ info.projectedSliceStart = getProjectedSliceStartIfConstant(sliceOp);
}
auto triples = sliceOp.getTriples();
info.sliceVec.append(triples.begin(), triples.end());
@@ -641,19 +640,30 @@ FIRToMemRef::convertArrayCoorOp(Operation *memOp, fir::ArrayCoorOp arrayCoorOp,
collectSliceInfoFrom(embox, sliceInfo);
else if (auto rebox = firMemref.getDefiningOp<fir::ReboxOp>())
collectSliceInfoFrom(rebox, sliceInfo);
- if (sliceInfo.projectionIsImaginary) {
- auto srcTy = cast<MemRefType>((*converted).getType());
- auto complexTy = cast<mlir::ComplexType>(srcTy.getElementType());
- SmallVector<int64_t> shape(srcTy.getShape());
- shape.push_back(2);
- Value compMemref =
- fir::ConvertOp::create(
- rewriter, loc, MemRefType::get(shape, complexTy.getElementType()),
- *converted)
- .getResult();
- indices.push_back(arith::ConstantIndexOp::create(
- rewriter, loc, *sliceInfo.projectionIsImaginary ? 1 : 0));
- return std::pair{compMemref, indices};
+ auto srcTy = cast<MemRefType>((*converted).getType());
+ if (sliceInfo.hasProjectedSlice) {
+ if (auto complexTy = dyn_cast<mlir::ComplexType>(srcTy.getElementType())) {
+ if (!sliceInfo.projectedSliceStart ||
+ (*sliceInfo.projectedSliceStart != 0 &&
+ *sliceInfo.projectedSliceStart != 1)) {
+ LLVM_DEBUG(
+ llvm::dbgs()
+ << "FIRToMemRef: projected complex slice selector must be constant "
+ "0 (real) or 1 (imaginary), bailing out of conversion\n");
+ return failure();
+ }
+ auto projection = *sliceInfo.projectedSliceStart;
+ SmallVector<int64_t> shape(srcTy.getShape());
+ shape.push_back(2);
+ Value compMemref =
+ fir::ConvertOp::create(
+ rewriter, loc, MemRefType::get(shape, complexTy.getElementType()),
+ *converted)
+ .getResult();
+ indices.push_back(
+ arith::ConstantIndexOp::create(rewriter, loc, projection));
+ return std::pair{compMemref, indices};
+ }
}
// Static shape does not imply contiguous layout for descriptor-backed
More information about the flang-commits
mailing list