[Mlir-commits] [mlir] [MLIR] Add more ops support for flattening memref operands (PR #159841)

Alan Li llvmlistbot at llvm.org
Fri Sep 19 14:59:08 PDT 2025


https://github.com/lialan updated https://github.com/llvm/llvm-project/pull/159841

>From bc742d1f6289ba275a0b61e34654dc198846c6e0 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Tue, 16 Sep 2025 20:57:41 -0400
Subject: [PATCH 1/6] Flatten memref.dealloc operands

---
 .../MemRef/Transforms/FlattenMemRefs.cpp      | 13 +++++--
 mlir/test/Dialect/MemRef/flatten_memref.mlir  | 38 +++++++++++++++++++
 2 files changed, 48 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
index 1208fddf37e0b..e2d5d62f64750 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
@@ -99,7 +99,8 @@ namespace {
 static Value getTargetMemref(Operation *op) {
   return llvm::TypeSwitch<Operation *, Value>(op)
       .template Case<memref::LoadOp, memref::StoreOp, memref::AllocaOp,
-                     memref::AllocOp>([](auto op) { return op.getMemref(); })
+                     memref::AllocOp, memref::DeallocOp>(
+          [](auto op) { return op.getMemref(); })
       .template Case<vector::LoadOp, vector::StoreOp, vector::MaskedLoadOp,
                      vector::MaskedStoreOp, vector::TransferReadOp,
                      vector::TransferWriteOp>(
@@ -189,6 +190,10 @@ static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref,
             rewriter, loc, op.getVector(), flatMemref, ValueRange{offset});
         rewriter.replaceOp(op, newTransferWrite);
       })
+      .template Case<memref::DeallocOp>([&](auto op) {
+        auto newDealloc = memref::DeallocOp::create(rewriter, loc, flatMemref);
+        rewriter.replaceOp(op, newDealloc);
+      })
       .Default([&](auto op) {
         op->emitOpError("unimplemented: do not know how to replace op.");
       });
@@ -197,7 +202,8 @@ static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref,
 template <typename T>
 static ValueRange getIndices(T op) {
   if constexpr (std::is_same_v<T, memref::AllocaOp> ||
-                std::is_same_v<T, memref::AllocOp>) {
+                std::is_same_v<T, memref::AllocOp> ||
+                std::is_same_v<T, memref::DeallocOp>) {
     return ValueRange{};
   } else {
     return op.getIndices();
@@ -286,7 +292,8 @@ void memref::populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns) {
   patterns.insert<MemRefRewritePattern<memref::LoadOp>,
                   MemRefRewritePattern<memref::StoreOp>,
                   MemRefRewritePattern<memref::AllocOp>,
-                  MemRefRewritePattern<memref::AllocaOp>>(
+                  MemRefRewritePattern<memref::AllocaOp>,
+                  MemRefRewritePattern<memref::DeallocOp>>(
       patterns.getContext());
 }
 
diff --git a/mlir/test/Dialect/MemRef/flatten_memref.mlir b/mlir/test/Dialect/MemRef/flatten_memref.mlir
index e45a10ca0d431..88f46b07dad93 100644
--- a/mlir/test/Dialect/MemRef/flatten_memref.mlir
+++ b/mlir/test/Dialect/MemRef/flatten_memref.mlir
@@ -298,3 +298,41 @@ func.func @load_scalar_from_memref_static_dim_col_major(%input: memref<4x8xf32,
 // CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG1]]]
 // CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [100], sizes: [32], strides: [1] : memref<4x8xf32, strided<[1, 4], offset: 100>> to memref<32xf32, strided<[1], offset: 100>>
 // CHECK: memref.load %[[REINT]][%[[IDX]]] : memref<32xf32, strided<[1], offset: 100>>
+
+// -----
+
+func.func @dealloc_static_memref(%input: memref<4x8xf32>) {
+  memref.dealloc %input : memref<4x8xf32>
+  return
+}
+
+// CHECK-LABEL: func @dealloc_static_memref
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xf32>)
+// CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [32], strides: [1] : memref<4x8xf32> to memref<32xf32, strided<[1]>>
+// CHECK-NEXT: memref.dealloc %[[REINT]] : memref<32xf32, strided<[1]>>
+
+// -----
+
+func.func @dealloc_dynamic_memref(%input: memref<?x?xf32>) {
+  memref.dealloc %input : memref<?x?xf32>
+  return
+}
+
+// CHECK-LABEL: func @dealloc_dynamic_memref
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>)
+// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG0]]
+// CHECK: %[[SIZE:.*]] = affine.max #{{.*}}()[%[[STRIDES]]#0, %[[SIZES]]#0, %[[SIZES]]#1]
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [%[[SIZE]]], strides: [1] : memref<?x?xf32> to memref<?xf32, strided<[1]>>
+// CHECK: memref.dealloc %[[REINT]] : memref<?xf32, strided<[1]>>
+
+// -----
+
+func.func @dealloc_strided_memref(%input: memref<4x8xf32, strided<[8, 1], offset: 100>>) {
+  memref.dealloc %input : memref<4x8xf32, strided<[8, 1], offset: 100>>
+  return
+}
+
+// CHECK-LABEL: func @dealloc_strided_memref
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xf32, strided<[8, 1], offset: 100>>)
+// CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [100], sizes: [32], strides: [1] : memref<4x8xf32, strided<[8, 1], offset: 100>> to memref<32xf32, strided<[1], offset: 100>>
+// CHECK-NEXT: memref.dealloc %[[REINT]] : memref<32xf32, strided<[1], offset: 100>>

>From f549e4fbc6215d775bd67d1ec1d679378d111e3a Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Wed, 17 Sep 2025 09:30:55 -0400
Subject: [PATCH 2/6] Flatten memref.global, memref.get_global and memref.collapse_shape

---
 .../MemRef/Transforms/FlattenMemRefs.cpp      | 212 +++++++++++++++++-
 mlir/test/Dialect/MemRef/flatten_memref.mlir  |  48 ++++
 2 files changed, 259 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
index e2d5d62f64750..e0dbc8d8641bd 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
@@ -21,6 +21,7 @@
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/Attributes.h"
+#include "mlir/IR/DialectResourceBlobManager.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
@@ -95,6 +96,99 @@ static bool checkLayout(Value val) {
          isa<StridedLayoutAttr>(type.getLayout());
 }
 
+/// Multiply together the sizes of the dimensions listed in `indices`, folding
+/// constants. A dimension whose `staticShape` entry is static contributes that
+/// constant; a dynamic one contributes `values[idx]` (the mixed sizes produced
+/// by `memref.extract_strided_metadata` at the call sites in this file).
+static OpFoldResult getProductOfValues(ArrayRef<int64_t> indices,
+                                       OpBuilder &builder, Location loc,
+                                       ArrayRef<int64_t> staticShape,
+                                       ArrayRef<OpFoldResult> values) {
+  SmallVector<OpFoldResult> operands;
+  operands.reserve(indices.size());
+  AffineExpr expr = builder.getAffineConstantExpr(1);
+  for (auto [pos, dim] : llvm::enumerate(indices)) {
+    expr = expr * builder.getAffineSymbolExpr(pos);
+    int64_t staticSize = staticShape[dim];
+    operands.push_back(ShapedType::isDynamic(staticSize)
+                           ? values[dim]
+                           : OpFoldResult(builder.getIndexAttr(staticSize)));
+  }
+  return affine::makeComposedFoldedAffineApply(builder, loc, expr, operands);
+}
+
+/// Return the size of result dim `groupId` of `collapseShapeOp` as a one-entry
+/// vector: a constant if static, else the product of the group's source sizes.
+static SmallVector<OpFoldResult>
+getCollapsedSize(memref::CollapseShapeOp collapseShapeOp, OpBuilder &builder,
+                 ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
+  SmallVector<OpFoldResult> collapsedSize;
+
+  MemRefType resultType = collapseShapeOp.getResultType();
+  int64_t dimSize = resultType.getDimSize(groupId);
+  if (!ShapedType::isDynamic(dimSize)) {
+    collapsedSize.push_back(builder.getIndexAttr(dimSize));
+    return collapsedSize;
+  }
+
+  auto sourceType = collapseShapeOp.getSrcType();
+  ArrayRef<int64_t> staticShape = sourceType.getShape();
+  // Copy the group: getReassociationIndices() returns a temporary vector.
+  SmallVector<int64_t, 2> reassocGroup =
+      collapseShapeOp.getReassociationIndices()[groupId];
+  collapsedSize.push_back(getProductOfValues(reassocGroup, builder,
+                                             collapseShapeOp.getLoc(),
+                                             staticShape, origSizes));
+  return collapsedSize;
+}
+
+/// Return (as a one-entry vector) the stride of result dim `groupId`: the
+/// stride of the group's last non-unit source dim, or a fallback if all are 1.
+static SmallVector<OpFoldResult> getCollapsedStride(
+    memref::CollapseShapeOp collapseShapeOp, OpBuilder &builder,
+    ArrayRef<OpFoldResult> origSizes, ArrayRef<OpFoldResult> origStrides,
+    unsigned groupId) {
+  // Copy the group: getReassociationIndices() returns a temporary vector.
+  SmallVector<int64_t, 2> reassocGroup =
+      collapseShapeOp.getReassociationIndices()[groupId];
+  assert(!reassocGroup.empty() &&
+         "reassociation group must contain at least one dimension");
+  auto sourceType = collapseShapeOp.getSrcType();
+  auto [strides, offset] = sourceType.getStridesAndOffset();
+  (void)offset;
+  ArrayRef<int64_t> srcShape = sourceType.getShape();
+
+  OpFoldResult lastValidStride = nullptr;
+  for (int64_t dim : reassocGroup) {
+    if (srcShape[dim] == 1)
+      continue;
+    int64_t currentStride = strides[dim];
+    if (ShapedType::isDynamic(currentStride))
+      lastValidStride = origStrides[dim];
+    else
+      lastValidStride = builder.getIndexAttr(currentStride);
+  }
+
+  if (!lastValidStride) {
+    MemRefType collapsedType = collapseShapeOp.getResultType();
+    auto [collapsedStrides, collapsedOffset] =
+        collapsedType.getStridesAndOffset();
+    (void)collapsedOffset;
+    int64_t finalStride = collapsedStrides[groupId];
+    if (ShapedType::isDynamic(finalStride)) {
+      for (int64_t dim : reassocGroup) {
+        assert(srcShape[dim] == 1 && "expected size-one dimensions");
+        if (ShapedType::isDynamic(strides[dim]))
+          return {origStrides[dim]};
+      }
+      llvm_unreachable("expected to find a dynamic stride");
+    }
+    return {builder.getIndexAttr(finalStride)};
+  }
+
+  return {lastValidStride};
+}
+
 namespace {
 static Value getTargetMemref(Operation *op) {
   return llvm::TypeSwitch<Operation *, Value>(op)
@@ -256,6 +350,82 @@ struct MemRefRewritePattern : public OpRewritePattern<T> {
   }
 };
 
+/// Flattens identity-layout memref.global ops of rank > 1 to rank 1.
+struct FlattenGlobal final : public OpRewritePattern<memref::GlobalOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  /// Reshapes `value`; null/UnitAttr pass through, unsupported kinds -> null.
+  static Attribute flattenAttribute(Attribute value, ShapedType newType) {
+    if (!value || isa<UnitAttr>(value)) // UnitAttr means `uninitialized`.
+      return value;
+    if (auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(value))
+      return denseAttr.reshape(newType); // Also handles SplatElementsAttr.
+    if (auto resourceAttr = llvm::dyn_cast<DenseResourceElementsAttr>(value))
+      return DenseResourceElementsAttr::get(newType,
+                                            resourceAttr.getRawHandle());
+    return {};
+  }
+
+  LogicalResult
+  matchAndRewrite(memref::GlobalOp globalOp,
+                  PatternRewriter &rewriter) const override {
+    auto oldType = llvm::dyn_cast<MemRefType>(globalOp.getType());
+    if (!oldType || !oldType.getLayout().isIdentity() || oldType.getRank() <= 1)
+      return failure();
+
+    auto tensorType = RankedTensorType::get({oldType.getNumElements()},
+                                            oldType.getElementType());
+    auto memRefType =
+        MemRefType::get({oldType.getNumElements()}, oldType.getElementType(),
+                        AffineMap(), oldType.getMemorySpace());
+    Attribute oldInit = globalOp.getInitialValueAttr();
+    Attribute newInit = flattenAttribute(oldInit, tensorType);
+    // Dropping an initializer would turn the definition into a declaration.
+    if (oldInit && !newInit)
+      return rewriter.notifyMatchFailure(globalOp, "unsupported initializer");
+    rewriter.replaceOpWithNewOp<memref::GlobalOp>(
+        globalOp, globalOp.getSymName(), globalOp.getSymVisibilityAttr(),
+        memRefType, newInit, globalOp.getConstant(),
+        /*alignment=*/globalOp.getAlignmentAttr());
+    return success();
+  }
+};
+
+/// Replaces memref.collapse_shape with a memref.reinterpret_cast of its
+/// source, computing each result dim's size/stride from the source metadata.
+struct FlattenCollapseShape final
+    : public OpRewritePattern<memref::CollapseShapeOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::CollapseShapeOp op,
+                                PatternRewriter &rewriter) const override {
+    Value src = op.getSrc();
+    auto meta =
+        memref::ExtractStridedMetadataOp::create(rewriter, op.getLoc(), src);
+    SmallVector<OpFoldResult> srcSizes = meta.getConstifiedMixedSizes();
+    SmallVector<OpFoldResult> srcStrides = meta.getConstifiedMixedStrides();
+
+    // Each reassociation group collapses into exactly one result dimension.
+    const unsigned groupCount = op.getReassociationIndices().size();
+    SmallVector<OpFoldResult> newSizes, newStrides;
+    newSizes.reserve(groupCount);
+    newStrides.reserve(groupCount);
+    for (unsigned group = 0; group != groupCount; ++group) {
+      llvm::append_range(newSizes,
+                         getCollapsedSize(op, rewriter, srcSizes, group));
+      llvm::append_range(
+          newStrides,
+          getCollapsedStride(op, rewriter, srcSizes, srcStrides, group));
+    }
+
+    // The flattened view keeps the source offset unchanged.
+    OpFoldResult srcOffset = meta.getConstifiedMixedOffset();
+    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+        op, op.getType(), src, srcOffset, newSizes, newStrides);
+    return success();
+  }
+};
+
 struct FlattenMemrefsPass
     : public mlir::memref::impl::FlattenMemrefsPassBase<FlattenMemrefsPass> {
   using Base::Base;
@@ -288,12 +458,52 @@ void memref::populateFlattenVectorOpsOnMemrefPatterns(
       patterns.getContext());
 }
 
+/// Rewrites a rank>1 memref.get_global whose global has already been flattened
+/// to rank 1 into a rank-1 get_global plus a reinterpret_cast back to the
+/// original type; it stops matching once the types agree, so it cannot loop.
+struct FlattenGetGlobal : public OpRewritePattern<memref::GetGlobalOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::GetGlobalOp op,
+                                PatternRewriter &rewriter) const override {
+    // Resolve via the nearest symbol table rather than assuming the op sits
+    // directly inside a ModuleOp (getParentOfType may return null).
+    auto globalOp = SymbolTable::lookupNearestSymbolFrom<memref::GlobalOp>(
+        op, op.getNameAttr());
+    if (!globalOp)
+      return failure();
+
+    MemRefType globalType = globalOp.getType();
+    MemRefType resultType = op.getType();
+    // Only apply if the global has been flattened but the get_global hasn't.
+    if (globalType.getRank() != 1 || resultType.getRank() <= 1)
+      return failure();
+
+    // Take sizes/strides from the (static) result type. Reading metadata off
+    // `op`'s own result is wrong: replaceOp would rewire that metadata op to
+    // consume the cast created after it, breaking dominance.
+    auto [strides, offset] = resultType.getStridesAndOffset();
+    MLIRContext *ctx = rewriter.getContext();
+    auto newGetGlobal = memref::GetGlobalOp::create(
+        rewriter, op.getLoc(), globalType, op.getName());
+    // Cast the flattened global back to the original shape.
+    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+        op, resultType, newGetGlobal, rewriter.getIndexAttr(offset),
+        getAsIndexOpFoldResult(ctx, resultType.getShape()),
+        getAsIndexOpFoldResult(ctx, strides));
+    return success();
+  }
+};
+
 void memref::populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns) {
   patterns.insert<MemRefRewritePattern<memref::LoadOp>,
                   MemRefRewritePattern<memref::StoreOp>,
                   MemRefRewritePattern<memref::AllocOp>,
                   MemRefRewritePattern<memref::AllocaOp>,
-                  MemRefRewritePattern<memref::DeallocOp>>(
+                  MemRefRewritePattern<memref::DeallocOp>,
+                  FlattenCollapseShape,
+                  FlattenGetGlobal,
+                  FlattenGlobal>(
       patterns.getContext());
 }
 
diff --git a/mlir/test/Dialect/MemRef/flatten_memref.mlir b/mlir/test/Dialect/MemRef/flatten_memref.mlir
index 88f46b07dad93..5c61bc87418bd 100644
--- a/mlir/test/Dialect/MemRef/flatten_memref.mlir
+++ b/mlir/test/Dialect/MemRef/flatten_memref.mlir
@@ -194,6 +194,35 @@ func.func @mask_load_vector_from_memref_dynamic(%input: memref<3x7xi2>, %row: in
 
 // -----
 
+func.func @collapse_shape_static(%arg0: memref<2x3x4xf32>) -> memref<6x4xf32> {
+  %0 = memref.collapse_shape %arg0 [[0, 1], [2]]
+      : memref<2x3x4xf32> into memref<6x4xf32>
+  return %0 : memref<6x4xf32>
+}
+// CHECK-LABEL: func @collapse_shape_static
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %arg0 to offset: [0], sizes: [6, 4], strides: [4, 1]
+// CHECK: return %[[REINT]]
+
+// -----
+
+func.func @collapse_shape_dynamic(
+    %arg0: memref<2x?x4xf32, strided<[?, ?, ?], offset: ?>>) ->
+    memref<?x4xf32, strided<[?, ?], offset: ?>> {
+  %0 = memref.collapse_shape %arg0 [[0, 1], [2]]
+      : memref<2x?x4xf32, strided<[?, ?, ?], offset: ?>>
+        into memref<?x4xf32, strided<[?, ?], offset: ?>>
+  return %0 : memref<?x4xf32, strided<[?, ?], offset: ?>>
+}
+// CHECK: #map = affine_map<()[s0] -> (s0 * 2)>
+// CHECK: #map1 = affine_map<()[s0, s1] -> (s0 * 8 + s1)>
+// CHECK-LABEL: func @collapse_shape_dynamic
+// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %arg0
+// CHECK: %[[SIZE:.*]] = affine.apply #map()[%[[SIZES]]#1]
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %arg0 to offset: [%[[OFFSET]]], sizes: [%[[SIZE]], 4], strides: [%[[STRIDES]]#1, %[[STRIDES]]#2]
+// CHECK: return %[[REINT]]
+
+// -----
+
 func.func @transfer_read_memref(%input: memref<4x8xi2>, %value: vector<8xi2>, %row: index, %col: index) -> vector<8xi2> {
    %c0 = arith.constant 0 : i2
    %0 = vector.transfer_read %input[%col, %row], %c0 {in_bounds = [true]} : memref<4x8xi2>, vector<8xi2>
@@ -336,3 +365,22 @@ func.func @dealloc_strided_memref(%input: memref<4x8xf32, strided<[8, 1], offset
 // CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xf32, strided<[8, 1], offset: 100>>)
 // CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [100], sizes: [32], strides: [1] : memref<4x8xf32, strided<[8, 1], offset: 100>> to memref<32xf32, strided<[1], offset: 100>>
 // CHECK-NEXT: memref.dealloc %[[REINT]] : memref<32xf32, strided<[1], offset: 100>>
+
+// -----
+
+memref.global "private" constant @constant_3x3x1x1xf32 : memref<3x3x1x1xf32> = dense<[[[[-1.000000e+00]], [[0.000000e+00]], [[1.000000e+00]]], [[[-2.000000e+00]], [[0.000000e+00]], [[2.000000e+00]]], [[[-1.000000e+00]], [[0.000000e+00]], [[1.000000e+00]]]]>
+func.func @load_global_with_offset(%i0: index, %i1: index, %i2: index, %i3: index) -> f32 {
+  %global = memref.get_global @constant_3x3x1x1xf32 : memref<3x3x1x1xf32>
+  %val = memref.load %global[%i0, %i1, %i2, %i3] : memref<3x3x1x1xf32>
+  return %val: f32
+}
+
+//      CHECK: #[[$MAP:.+]] = affine_map<()[s0, s1, s2, s3] -> (s0 * 3 + s1 + s2 + s3)>
+//      CHECK: memref.global "private" constant @constant_3x3x1x1xf32 : memref<9xf32> = dense<[-1.000000e+00, 0.000000e+00, 1.000000e+00, -2.000000e+00, 0.000000e+00, 2.000000e+00, -1.000000e+00, 0.000000e+00, 1.000000e+00]>
+//CHECK-LABEL: func.func @load_global_with_offset
+// CHECK-SAME: (%[[I0:.+]]: index, %[[I1:.+]]: index, %[[I2:.+]]: index, %[[I3:.+]]: index)
+//      CHECK:   %[[GLOBAL:.+]] = memref.get_global @constant_3x3x1x1xf32 : memref<9xf32>
+//      CHECK:   %[[INDEX:.+]] = affine.apply #[[$MAP]]()[%[[I0]], %[[I1]], %[[I2]], %[[I3]]]
+//      CHECK:   %[[REINTERPRET:.+]] = memref.reinterpret_cast %[[GLOBAL]] to offset: [0], sizes: [9], strides: [1] : memref<9xf32> to memref<9xf32, strided<[1]>>
+//      CHECK:   %[[LOAD:.+]] = memref.load %[[REINTERPRET]][%[[INDEX]]] : memref<9xf32, strided<[1]>>
+//      CHECK:   return %[[LOAD]]

>From 79d2edfdb5788c13a8f580c551ca712b32f8b0b7 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Wed, 17 Sep 2025 09:40:18 -0400
Subject: [PATCH 3/6] Flatten memref.expand_shape

---
 .../MemRef/Transforms/FlattenMemRefs.cpp      | 144 ++++++++++++++++++
 mlir/test/Dialect/MemRef/flatten_memref.mlir  |  31 +++-
 2 files changed, 174 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
index e0dbc8d8641bd..5394df5b46fb2 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
@@ -96,6 +96,115 @@ static bool checkLayout(Value val) {
          isa<StridedLayoutAttr>(type.getLayout());
 }
 
+/// Compute the expanded sizes of the given expand_shape for the reassociation
+/// group `groupId`. Adapted from ExpandStridedMetadata.cpp; kept out of
+/// ExpandShapeOp because it uses the Affine dialect, which the MemRef dialect
+/// must not depend on. At most one size per group may be dynamic.
+static SmallVector<OpFoldResult>
+getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
+                 ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
+  // Copy the group: getReassociationIndices() returns a temporary vector.
+  SmallVector<int64_t, 2> reassocGroup =
+      expandShape.getReassociationIndices()[groupId];
+  assert(!reassocGroup.empty() &&
+         "Reassociation group should have at least one dimension");
+  unsigned groupSize = reassocGroup.size();
+  SmallVector<OpFoldResult> expandedSizes(groupSize);
+
+  uint64_t productOfAllStaticSizes = 1;
+  std::optional<unsigned> dynSizeIdx;
+  MemRefType expandShapeType = expandShape.getResultType();
+
+  for (unsigned i = 0; i < groupSize; ++i) {
+    uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
+    if (ShapedType::isDynamic(dimSize)) {
+      assert(!dynSizeIdx && "there must be at most one dynamic size per group");
+      dynSizeIdx = i;
+      continue;
+    }
+    productOfAllStaticSizes *= dimSize;
+    expandedSizes[i] = builder.getIndexAttr(dimSize);
+  }
+
+  if (dynSizeIdx) {
+    AffineExpr s0 = builder.getAffineSymbolExpr(0);
+    expandedSizes[*dynSizeIdx] = affine::makeComposedFoldedAffineApply(
+        builder, expandShape.getLoc(), s0.floorDiv(productOfAllStaticSizes),
+        origSizes[groupId]);
+  }
+
+  return expandedSizes;
+}
+
+/// Compute the strides of the result dims produced by reassociation group
+/// `groupId`: the source stride scaled by the product of inner result sizes.
+static SmallVector<OpFoldResult>
+getExpandedStrides(memref::ExpandShapeOp expandShape, OpBuilder &builder,
+                   ArrayRef<OpFoldResult> origSizes,
+                   ArrayRef<OpFoldResult> origStrides, unsigned groupId) {
+  // Copy the group: getReassociationIndices() returns a temporary vector.
+  SmallVector<int64_t, 2> reassocGroup =
+      expandShape.getReassociationIndices()[groupId];
+  assert(!reassocGroup.empty() &&
+         "Reassociation group should have at least one dimension");
+  unsigned groupSize = reassocGroup.size();
+  MemRefType expandShapeType = expandShape.getResultType();
+
+  std::optional<int64_t> dynSizeIdx;
+  uint64_t currentStride = 1;
+  SmallVector<OpFoldResult> expandedStrides(groupSize);
+  for (int i = groupSize - 1; i >= 0; --i) {
+    expandedStrides[i] = builder.getIndexAttr(currentStride);
+    uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
+    if (ShapedType::isDynamic(dimSize)) {
+      assert(!dynSizeIdx && "there must be at most one dynamic size per group");
+      dynSizeIdx = i;
+      continue;
+    }
+    currentStride *= dimSize;
+  }
+
+  auto sourceType = expandShape.getSrcType();
+  auto [strides, offset] = sourceType.getStridesAndOffset();
+  (void)offset;
+
+  OpFoldResult origStride = ShapedType::isDynamic(strides[groupId])
+                                ? origStrides[groupId]
+                                : builder.getIndexAttr(strides[groupId]);
+
+  int64_t doneStrideIdx = 0;
+  if (dynSizeIdx) {
+    int64_t productOfAllStaticSizes = currentStride;
+    assert(ShapedType::isDynamic(sourceType.getDimSize(groupId)) &&
+           "dynamic reassociation must originate from dynamic source dim");
+    OpFoldResult origSize = origSizes[groupId];
+
+    AffineExpr s0 = builder.getAffineSymbolExpr(0);
+    AffineExpr s1 = builder.getAffineSymbolExpr(1);
+    for (; doneStrideIdx < *dynSizeIdx; ++doneStrideIdx) {
+      auto baseAttr = dyn_cast<Attribute>(expandedStrides[doneStrideIdx]);
+      assert(baseAttr && "expected attribute stride");
+      int64_t baseExpandedStride = cast<IntegerAttr>(baseAttr).getInt();
+      expandedStrides[doneStrideIdx] = affine::makeComposedFoldedAffineApply(
+          builder, expandShape.getLoc(),
+          (s0 * baseExpandedStride).floorDiv(productOfAllStaticSizes) * s1,
+          {origSize, origStride});
+    }
+  }
+
+  AffineExpr s0 = builder.getAffineSymbolExpr(0);
+  for (; doneStrideIdx < groupSize; ++doneStrideIdx) {
+    auto baseAttr = dyn_cast<Attribute>(expandedStrides[doneStrideIdx]);
+    assert(baseAttr && "expected attribute stride");
+    int64_t baseExpandedStride = cast<IntegerAttr>(baseAttr).getInt();
+    expandedStrides[doneStrideIdx] = affine::makeComposedFoldedAffineApply(
+        builder, expandShape.getLoc(), s0 * baseExpandedStride,
+        {origStride});
+  }
+
+  return expandedStrides;
+}
+
 /// Produce an OpFoldResult representing the product of the values or constants
 /// referenced by `indices`. `staticShape` provides the statically known sizes
 /// for the source memref, while `values` contains the mixed (value/attribute)
@@ -426,6 +535,40 @@ struct FlattenCollapseShape final
   }
 };
 
+/// Replaces memref.expand_shape with a memref.reinterpret_cast of its source;
+/// each reassociation group maps one source dim to one or more result dims.
+struct FlattenExpandShape final
+    : public OpRewritePattern<memref::ExpandShapeOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::ExpandShapeOp op,
+                                PatternRewriter &rewriter) const override {
+    Value src = op.getSrc();
+    auto meta =
+        memref::ExtractStridedMetadataOp::create(rewriter, op.getLoc(), src);
+    SmallVector<OpFoldResult> srcSizes = meta.getConstifiedMixedSizes();
+    SmallVector<OpFoldResult> srcStrides = meta.getConstifiedMixedStrides();
+
+    const int64_t resultRank = op.getResultType().getRank();
+    SmallVector<OpFoldResult> newSizes, newStrides;
+    newSizes.reserve(resultRank);
+    newStrides.reserve(resultRank);
+    const unsigned groupCount = op.getReassociationIndices().size();
+    for (unsigned group = 0; group != groupCount; ++group) {
+      llvm::append_range(newSizes,
+                         getExpandedSizes(op, rewriter, srcSizes, group));
+      llvm::append_range(
+          newStrides,
+          getExpandedStrides(op, rewriter, srcSizes, srcStrides, group));
+    }
+
+    OpFoldResult srcOffset = meta.getConstifiedMixedOffset();
+    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+        op, op.getType(), src, srcOffset, newSizes, newStrides);
+    return success();
+  }
+};
+
 struct FlattenMemrefsPass
     : public mlir::memref::impl::FlattenMemrefsPassBase<FlattenMemrefsPass> {
   using Base::Base;
@@ -501,6 +644,7 @@ void memref::populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns) {
                   MemRefRewritePattern<memref::AllocOp>,
                   MemRefRewritePattern<memref::AllocaOp>,
                   MemRefRewritePattern<memref::DeallocOp>,
+                  FlattenExpandShape,
                   FlattenCollapseShape,
                   FlattenGetGlobal,
                   FlattenGlobal>(
diff --git a/mlir/test/Dialect/MemRef/flatten_memref.mlir b/mlir/test/Dialect/MemRef/flatten_memref.mlir
index 5c61bc87418bd..2a5f141dbe328 100644
--- a/mlir/test/Dialect/MemRef/flatten_memref.mlir
+++ b/mlir/test/Dialect/MemRef/flatten_memref.mlir
@@ -214,7 +214,6 @@ func.func @collapse_shape_dynamic(
   return %0 : memref<?x4xf32, strided<[?, ?], offset: ?>>
 }
 // CHECK: #map = affine_map<()[s0] -> (s0 * 2)>
-// CHECK: #map1 = affine_map<()[s0, s1] -> (s0 * 8 + s1)>
 // CHECK-LABEL: func @collapse_shape_dynamic
 // CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %arg0
 // CHECK: %[[SIZE:.*]] = affine.apply #map()[%[[SIZES]]#1]
@@ -223,6 +222,36 @@ func.func @collapse_shape_dynamic(
 
 // -----
 
+func.func @expand_shape_static(%arg0: memref<6x4xf32>) -> memref<2x3x4xf32> {
+  %0 = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [2, 3, 4]
+      : memref<6x4xf32> into memref<2x3x4xf32>
+  return %0 : memref<2x3x4xf32>
+}
+// CHECK-LABEL: func @expand_shape_static
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %arg0 to offset: [0], sizes: [2, 3, 4], strides: [12, 4, 1]
+// CHECK: return %[[REINT]]
+
+// -----
+
+func.func @expand_shape_dynamic(
+    %arg0: memref<?x4xf32, strided<[?, ?], offset: ?>>, %size: index) ->
+    memref<?x3x4xf32, strided<[?, ?, ?], offset: ?>> {
+  %0 = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [%size, 3, 4]
+      : memref<?x4xf32, strided<[?, ?], offset: ?>>
+        into memref<?x3x4xf32, strided<[?, ?, ?], offset: ?>>
+  return %0 : memref<?x3x4xf32, strided<[?, ?, ?], offset: ?>>
+}
+// CHECK: #map = affine_map<()[s0] -> (s0 floordiv 3)>
+// CHECK: #map1 = affine_map<()[s0] -> (s0 * 3)>
+// CHECK-LABEL: func @expand_shape_dynamic
+// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %arg0
+// CHECK: %[[SIZE:.*]] = affine.apply #map()[%[[SIZES]]#0]
+// CHECK: %[[STRIDE:.*]] = affine.apply #map1()[%[[STRIDES]]#0]
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %arg0 to offset: [%[[OFFSET]]], sizes: [%[[SIZE]], 3, 4], strides: [%[[STRIDE]], %[[STRIDES]]#0, %[[STRIDES]]#1]
+// CHECK: return %[[REINT]]
+
+// -----
+
 func.func @transfer_read_memref(%input: memref<4x8xi2>, %value: vector<8xi2>, %row: index, %col: index) -> vector<8xi2> {
    %c0 = arith.constant 0 : i2
    %0 = vector.transfer_read %input[%col, %row], %c0 {in_bounds = [true]} : memref<4x8xi2>, vector<8xi2>

>From d5cda23901336741cc70577bbe15c1b2f1336af6 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Wed, 17 Sep 2025 15:32:52 -0400
Subject: [PATCH 4/6] Expose expand_shape size/stride helpers in Transforms.h and reuse them

---
 .../Dialect/MemRef/Transforms/Transforms.h    |  99 +++++++
 .../Transforms/ExpandStridedMetadata.cpp      |  72 ++---
 .../MemRef/Transforms/FlattenMemRefs.cpp      | 245 +++---------------
 3 files changed, 156 insertions(+), 260 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index 8b76930aed35a..562b8c11225e8 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -14,10 +14,15 @@
 #ifndef MLIR_DIALECT_MEMREF_TRANSFORMS_TRANSFORMS_H
 #define MLIR_DIALECT_MEMREF_TRANSFORMS_TRANSFORMS_H
 
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/STLFunctionalExtras.h"
 
 namespace mlir {
+class Location;
 class OpBuilder;
 class RewritePatternSet;
 class RewriterBase;
@@ -33,7 +38,9 @@ class NarrowTypeEmulationConverter;
 namespace memref {
 class AllocOp;
 class AllocaOp;
+class CollapseShapeOp;
 class DeallocOp;
+class ExpandShapeOp;
 
 //===----------------------------------------------------------------------===//
 // Patterns
@@ -213,6 +220,98 @@ FailureOr<Value> replaceWithIndependentOp(RewriterBase &rewriter,
 memref::AllocaOp allocToAlloca(
     RewriterBase &rewriter, memref::AllocOp alloc,
     function_ref<bool(memref::AllocOp, memref::DeallocOp)> filter = nullptr);
+
+/// Compute the expanded sizes of the given \p expandShape for the
+/// \p groupId-th reassociation group.
+/// \p origSizes hold the sizes of the source shape as values.
+/// This is used to compute the new sizes in cases of dynamic shapes.
+///
+/// sizes#i =
+///     baseSizes#groupId / product(expandShapeSizes#j,
+///                                  for j in group excluding reassIdx#i)
+/// Where reassIdx#i is the reassociation index at index i in \p groupId.
+///
+/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
+///
+/// TODO: Move this utility function directly within ExpandShapeOp. For now,
+/// this is not possible because this function uses the Affine dialect and the
+/// MemRef dialect cannot depend on the Affine dialect.
+SmallVector<OpFoldResult>
+getExpandedSizes(ExpandShapeOp expandShape, OpBuilder &builder,
+                 ArrayRef<OpFoldResult> origSizes, unsigned groupId);
+
+/// Compute the expanded strides of the given \p expandShape for the
+/// \p groupId-th reassociation group.
+/// \p origStrides and \p origSizes hold respectively the strides and sizes
+/// of the source shape as values.
+/// This is used to compute the strides in cases of dynamic shapes and/or
+/// dynamic stride for this reassociation group.
+///
+/// strides#i =
+///     origStrides#reassDim * product(expandShapeSizes#j, for j in
+///                                    reassIdx#i+1..reassIdx#i+group.size-1)
+///
+/// Where reassIdx#i is the reassociation index at index i in \p groupId
+/// and expandShapeSizes#j is either:
+/// - The constant size at dimension j, derived directly from the result type of
+///   the expand_shape op, or
+/// - An affine expression: baseSizes#reassDim / product of all constant sizes
+///   in expandShapeSizes. (Remember expandShapeSizes has at most one dynamic
+///   element.)
+///
+/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
+///
+/// TODO: Move this utility function directly within ExpandShapeOp. For now,
+/// this is not possible because this function uses the Affine dialect and the
+/// MemRef dialect cannot depend on the Affine dialect.
+SmallVector<OpFoldResult> getExpandedStrides(ExpandShapeOp expandShape,
+                                             OpBuilder &builder,
+                                             ArrayRef<OpFoldResult> origSizes,
+                                             ArrayRef<OpFoldResult> origStrides,
+                                             unsigned groupId);
+
+/// Produce an OpFoldResult object with \p builder at \p loc representing
+/// `prod(valueOrConstant#i, for i in {indices})`,
+/// where valueOrConstant#i is maybeConstants[i] when \p isDynamic is false,
+/// values[i] otherwise.
+///
+/// \pre for all index in indices: index < values.size()
+/// \pre for all index in indices: index < maybeConstants.size()
+OpFoldResult
+getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
+                   ArrayRef<int64_t> maybeConstants,
+                   ArrayRef<OpFoldResult> values,
+                   llvm::function_ref<bool(int64_t)> isDynamic);
+
+/// Compute the collapsed size of the given \p collapseShape for the
+/// \p groupId-th reassociation group.
+/// \p origSizes hold the sizes of the source shape as values.
+/// This is used to compute the new sizes in cases of dynamic shapes.
+///
+/// TODO: Move this utility function directly within CollapseShapeOp. For now,
+/// this is not possible because this function uses the Affine dialect and the
+/// MemRef dialect cannot depend on the Affine dialect.
+SmallVector<OpFoldResult>
+getCollapsedSize(CollapseShapeOp collapseShape, OpBuilder &builder,
+                 ArrayRef<OpFoldResult> origSizes, unsigned groupId);
+
+/// Compute the collapsed stride of the given \p collapseShape for the
+/// \p groupId-th reassociation group.
+/// \p origStrides and \p origSizes hold respectively the strides and sizes
+/// of the source shape as values.
+/// This is used to compute the strides in cases of dynamic shapes and/or
+/// dynamic stride for this reassociation group.
+///
+/// Conceptually this helper function returns the stride of the innermost
+/// dimension of that group in the original shape.
+///
+/// \post result.size() == 1, in other words, each group collapses to one
+/// dimension.
+SmallVector<OpFoldResult>
+getCollapsedStride(CollapseShapeOp collapseShape, OpBuilder &builder,
+                   ArrayRef<OpFoldResult> origSizes,
+                   ArrayRef<OpFoldResult> origStrides, unsigned groupId);
+
 } // namespace memref
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index d35566a9c0d29..6b69d0e366903 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -24,6 +24,7 @@
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
+#include "mlir/IR/OpDefinition.h"
 #include <optional>
 
 namespace mlir {
@@ -35,6 +36,7 @@ namespace memref {
 
 using namespace mlir;
 using namespace mlir::affine;
+using namespace mlir::memref;
 
 namespace {
 
@@ -250,23 +252,12 @@ struct ExtractStridedMetadataOpSubviewFolder
   }
 };
 
-/// Compute the expanded sizes of the given \p expandShape for the
-/// \p groupId-th reassociation group.
-/// \p origSizes hold the sizes of the source shape as values.
-/// This is used to compute the new sizes in cases of dynamic shapes.
-///
-/// sizes#i =
-///     baseSizes#groupId / product(expandShapeSizes#j,
-///                                  for j in group excluding reassIdx#i)
-/// Where reassIdx#i is the reassociation index at index i in \p groupId.
-///
-/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
-///
-/// TODO: Move this utility function directly within ExpandShapeOp. For now,
-/// this is not possible because this function uses the Affine dialect and the
-/// MemRef dialect cannot depend on the Affine dialect.
-static SmallVector<OpFoldResult>
-getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
+} // namespace
+
+namespace mlir {
+namespace memref {
+SmallVector<OpFoldResult>
+getExpandedSizes(ExpandShapeOp expandShape, OpBuilder &builder,
                  ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
   SmallVector<int64_t, 2> reassocGroup =
       expandShape.getReassociationIndices()[groupId];
@@ -305,31 +296,7 @@ getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
   return expandedSizes;
 }
 
-/// Compute the expanded strides of the given \p expandShape for the
-/// \p groupId-th reassociation group.
-/// \p origStrides and \p origSizes hold respectively the strides and sizes
-/// of the source shape as values.
-/// This is used to compute the strides in cases of dynamic shapes and/or
-/// dynamic stride for this reassociation group.
-///
-/// strides#i =
-///     origStrides#reassDim * product(expandShapeSizes#j, for j in
-///                                    reassIdx#i+1..reassIdx#i+group.size-1)
-///
-/// Where reassIdx#i is the reassociation index for at index i in \p groupId
-/// and expandShapeSizes#j is either:
-/// - The constant size at dimension j, derived directly from the result type of
-///   the expand_shape op, or
-/// - An affine expression: baseSizes#reassDim / product of all constant sizes
-///   in expandShapeSizes. (Remember expandShapeSizes has at most one dynamic
-///   element.)
-///
-/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
-///
-/// TODO: Move this utility function directly within ExpandShapeOp. For now,
-/// this is not possible because this function uses the Affine dialect and the
-/// MemRef dialect cannot depend on the Affine dialect.
-SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
+SmallVector<OpFoldResult> getExpandedStrides(ExpandShapeOp expandShape,
                                              OpBuilder &builder,
                                              ArrayRef<OpFoldResult> origSizes,
                                              ArrayRef<OpFoldResult> origStrides,
@@ -405,14 +372,7 @@ SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
   return expandedStrides;
 }
 
-/// Produce an OpFoldResult object with \p builder at \p loc representing
-/// `prod(valueOrConstant#i, for i in {indices})`,
-/// where valueOrConstant#i is maybeConstant[i] when \p isDymamic is false,
-/// values[i] otherwise.
-///
-/// \pre for all index in indices: index < values.size()
-/// \pre for all index in indices: index < maybeConstants.size()
-static OpFoldResult
+OpFoldResult
 getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
                    ArrayRef<int64_t> maybeConstants,
                    ArrayRef<OpFoldResult> values,
@@ -450,8 +410,8 @@ getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
 /// TODO: Move this utility function directly within CollapseShapeOp. For now,
 /// this is not possible because this function uses the Affine dialect and the
 /// MemRef dialect cannot depend on the Affine dialect.
-static SmallVector<OpFoldResult>
-getCollapsedSize(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
+SmallVector<OpFoldResult>
+getCollapsedSize(CollapseShapeOp collapseShape, OpBuilder &builder,
                  ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
   SmallVector<OpFoldResult> collapsedSize;
 
@@ -491,8 +451,8 @@ getCollapsedSize(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
 ///
 /// \post result.size() == 1, in other words, each group collapse to one
 /// dimension.
-static SmallVector<OpFoldResult>
-getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
+SmallVector<OpFoldResult>
+getCollapsedStride(CollapseShapeOp collapseShape, OpBuilder &builder,
                    ArrayRef<OpFoldResult> origSizes,
                    ArrayRef<OpFoldResult> origStrides, unsigned groupId) {
   SmallVector<int64_t, 2> reassocGroup =
@@ -546,6 +506,10 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
 
   return {lastValidStride};
 }
+} // namespace memref
+} // namespace mlir
+
+namespace {
 
 /// From `reshape_like(memref, subSizes, subStrides))` compute
 ///
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
index 5394df5b46fb2..f17a0582919fc 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
@@ -96,208 +96,6 @@ static bool checkLayout(Value val) {
          isa<StridedLayoutAttr>(type.getLayout());
 }
 
-/// Compute the expanded sizes of the given expand_shape for the reassociation
-/// group `groupId`. Portions adapted from
-/// `lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp` to avoid a direct
-/// dependency from the MemRef dialect on the Affine dialect.
-static SmallVector<OpFoldResult>
-getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
-                 ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
-  ArrayRef<int64_t> reassocGroup =
-      expandShape.getReassociationIndices()[groupId];
-  assert(!reassocGroup.empty() &&
-         "Reassociation group should have at least one dimension");
-
-  unsigned groupSize = reassocGroup.size();
-  SmallVector<OpFoldResult> expandedSizes(groupSize);
-
-  uint64_t productOfAllStaticSizes = 1;
-  std::optional<unsigned> dynSizeIdx;
-  MemRefType expandShapeType = expandShape.getResultType();
-
-  for (unsigned i = 0; i < groupSize; ++i) {
-    uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
-    if (ShapedType::isDynamic(dimSize)) {
-      assert(!dynSizeIdx && "there must be at most one dynamic size per group");
-      dynSizeIdx = i;
-      continue;
-    }
-    productOfAllStaticSizes *= dimSize;
-    expandedSizes[i] = builder.getIndexAttr(dimSize);
-  }
-
-  if (dynSizeIdx) {
-    AffineExpr s0 = builder.getAffineSymbolExpr(0);
-    expandedSizes[*dynSizeIdx] = affine::makeComposedFoldedAffineApply(
-        builder, expandShape.getLoc(), s0.floorDiv(productOfAllStaticSizes),
-        origSizes[groupId]);
-  }
-
-  return expandedSizes;
-}
-
-/// Compute the expanded strides of the given expand_shape for the reassociation
-/// group `groupId`.
-static SmallVector<OpFoldResult>
-getExpandedStrides(memref::ExpandShapeOp expandShape, OpBuilder &builder,
-                   ArrayRef<OpFoldResult> origSizes,
-                   ArrayRef<OpFoldResult> origStrides, unsigned groupId) {
-  ArrayRef<int64_t> reassocGroup =
-      expandShape.getReassociationIndices()[groupId];
-  assert(!reassocGroup.empty() &&
-         "Reassociation group should have at least one dimension");
-
-  unsigned groupSize = reassocGroup.size();
-  MemRefType expandShapeType = expandShape.getResultType();
-
-  std::optional<int64_t> dynSizeIdx;
-  uint64_t currentStride = 1;
-  SmallVector<OpFoldResult> expandedStrides(groupSize);
-  for (int i = groupSize - 1; i >= 0; --i) {
-    expandedStrides[i] = builder.getIndexAttr(currentStride);
-    uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
-    if (ShapedType::isDynamic(dimSize)) {
-      assert(!dynSizeIdx && "there must be at most one dynamic size per group");
-      dynSizeIdx = i;
-      continue;
-    }
-    currentStride *= dimSize;
-  }
-
-  auto sourceType = expandShape.getSrcType();
-  auto [strides, offset] = sourceType.getStridesAndOffset();
-  (void)offset;
-
-  OpFoldResult origStride = ShapedType::isDynamic(strides[groupId])
-                                ? origStrides[groupId]
-                                : builder.getIndexAttr(strides[groupId]);
-
-  int64_t doneStrideIdx = 0;
-  if (dynSizeIdx) {
-    int64_t productOfAllStaticSizes = currentStride;
-    assert(ShapedType::isDynamic(sourceType.getDimSize(groupId)) &&
-           "dynamic reassociation must originate from dynamic source dim");
-    OpFoldResult origSize = origSizes[groupId];
-
-    AffineExpr s0 = builder.getAffineSymbolExpr(0);
-    AffineExpr s1 = builder.getAffineSymbolExpr(1);
-    for (; doneStrideIdx < *dynSizeIdx; ++doneStrideIdx) {
-      auto baseAttr = expandedStrides[doneStrideIdx].dyn_cast<Attribute>();
-      assert(baseAttr && "expected attribute stride");
-      int64_t baseExpandedStride = cast<IntegerAttr>(baseAttr).getInt();
-      expandedStrides[doneStrideIdx] = affine::makeComposedFoldedAffineApply(
-          builder, expandShape.getLoc(),
-          (s0 * baseExpandedStride).floorDiv(productOfAllStaticSizes) * s1,
-          {origSize, origStride});
-    }
-  }
-
-  AffineExpr s0 = builder.getAffineSymbolExpr(0);
-  for (; doneStrideIdx < groupSize; ++doneStrideIdx) {
-    auto baseAttr = expandedStrides[doneStrideIdx].dyn_cast<Attribute>();
-    assert(baseAttr && "expected attribute stride");
-    int64_t baseExpandedStride = cast<IntegerAttr>(baseAttr).getInt();
-    expandedStrides[doneStrideIdx] = affine::makeComposedFoldedAffineApply(
-        builder, expandShape.getLoc(), s0 * baseExpandedStride,
-        {origStride});
-  }
-
-  return expandedStrides;
-}
-
-/// Produce an OpFoldResult representing the product of the values or constants
-/// referenced by `indices`. `staticShape` provides the statically known sizes
-/// for the source memref, while `values` contains the mixed (value/attribute)
-/// representation produced by `memref.extract_strided_metadata`.
-static OpFoldResult getProductOfValues(ArrayRef<int64_t> indices,
-                                       OpBuilder &builder, Location loc,
-                                       ArrayRef<int64_t> staticShape,
-                                       ArrayRef<OpFoldResult> values) {
-  AffineExpr product = builder.getAffineConstantExpr(1);
-  SmallVector<OpFoldResult> inputs;
-  unsigned numSymbols = 0;
-  for (int64_t idx : indices) {
-    product = product * builder.getAffineSymbolExpr(numSymbols++);
-    if (ShapedType::isDynamic(staticShape[idx]))
-      inputs.push_back(values[idx]);
-    else
-      inputs.push_back(builder.getIndexAttr(staticShape[idx]));
-  }
-  return affine::makeComposedFoldedAffineApply(builder, loc, product, inputs);
-}
-
-/// Return the collapsed size (as OpFoldResult) for the reassociation group
-/// `groupId` of `collapseShapeOp`.
-static SmallVector<OpFoldResult>
-getCollapsedSize(memref::CollapseShapeOp collapseShapeOp, OpBuilder &builder,
-                 ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
-  SmallVector<OpFoldResult> collapsedSize;
-
-  MemRefType resultType = collapseShapeOp.getResultType();
-  int64_t dimSize = resultType.getDimSize(groupId);
-  if (!ShapedType::isDynamic(dimSize)) {
-    collapsedSize.push_back(builder.getIndexAttr(dimSize));
-    return collapsedSize;
-  }
-
-  auto sourceType = collapseShapeOp.getSrcType();
-  ArrayRef<int64_t> staticShape = sourceType.getShape();
-  ArrayRef<int64_t> reassocGroup =
-      collapseShapeOp.getReassociationIndices()[groupId];
-
-  collapsedSize.push_back(getProductOfValues(reassocGroup, builder,
-                                             collapseShapeOp.getLoc(),
-                                             staticShape, origSizes));
-  return collapsedSize;
-}
-
-/// Return the collapsed stride (as OpFoldResult) for the reassociation group
-/// `groupId` of `collapseShapeOp`.
-static SmallVector<OpFoldResult> getCollapsedStride(
-    memref::CollapseShapeOp collapseShapeOp, OpBuilder &builder,
-    ArrayRef<OpFoldResult> origSizes, ArrayRef<OpFoldResult> origStrides,
-    unsigned groupId) {
-  ArrayRef<int64_t> reassocGroup =
-      collapseShapeOp.getReassociationIndices()[groupId];
-  assert(!reassocGroup.empty() &&
-         "reassociation group must contain at least one dimension");
-
-  auto sourceType = collapseShapeOp.getSrcType();
-  auto [strides, offset] = sourceType.getStridesAndOffset();
-  (void)offset;
-  ArrayRef<int64_t> srcShape = sourceType.getShape();
-
-  OpFoldResult lastValidStride = nullptr;
-  for (int64_t dim : reassocGroup) {
-    if (srcShape[dim] == 1)
-      continue;
-    int64_t currentStride = strides[dim];
-    if (ShapedType::isDynamic(currentStride))
-      lastValidStride = origStrides[dim];
-    else
-      lastValidStride = builder.getIndexAttr(currentStride);
-  }
-
-  if (!lastValidStride) {
-    MemRefType collapsedType = collapseShapeOp.getResultType();
-    auto [collapsedStrides, collapsedOffset] =
-        collapsedType.getStridesAndOffset();
-    (void)collapsedOffset;
-    int64_t finalStride = collapsedStrides[groupId];
-    if (ShapedType::isDynamic(finalStride)) {
-      for (int64_t dim : reassocGroup) {
-        assert(srcShape[dim] == 1 && "expected size-one dimensions");
-        if (ShapedType::isDynamic(strides[dim]))
-          return {origStrides[dim]};
-      }
-      llvm_unreachable("expected to find a dynamic stride");
-    }
-    return {builder.getIndexAttr(finalStride)};
-  }
-
-  return {lastValidStride};
-}
-
 namespace {
 static Value getTargetMemref(Operation *op) {
   return llvm::TypeSwitch<Operation *, Value>(op)
@@ -521,9 +319,9 @@ struct FlattenCollapseShape final
     collapsedStrides.reserve(numGroups);
     for (unsigned i = 0; i < numGroups; ++i) {
       SmallVector<OpFoldResult> groupSizes =
-          getCollapsedSize(op, rewriter, origSizes, i);
+          memref::getCollapsedSize(op, rewriter, origSizes, i);
       SmallVector<OpFoldResult> groupStrides =
-          getCollapsedStride(op, rewriter, origSizes, origStrides, i);
+          memref::getCollapsedStride(op, rewriter, origSizes, origStrides, i);
       collapsedSizes.append(groupSizes.begin(), groupSizes.end());
       collapsedStrides.append(groupStrides.begin(), groupStrides.end());
     }
@@ -556,9 +354,9 @@ struct FlattenExpandShape final : public OpRewritePattern<memref::ExpandShapeOp>
 
     for (unsigned i = 0; i < numGroups; ++i) {
       SmallVector<OpFoldResult> groupSizes =
-          getExpandedSizes(op, rewriter, origSizes, i);
+          memref::getExpandedSizes(op, rewriter, origSizes, i);
       SmallVector<OpFoldResult> groupStrides =
-          getExpandedStrides(op, rewriter, origSizes, origStrides, i);
+          memref::getExpandedStrides(op, rewriter, origSizes, origStrides, i);
       expandedSizes.append(groupSizes.begin(), groupSizes.end());
       expandedStrides.append(groupStrides.begin(), groupStrides.end());
     }
@@ -569,6 +367,40 @@ struct FlattenExpandShape final : public OpRewritePattern<memref::ExpandShapeOp>
   }
 };
 
+
+/*
+// Flattens memref subspan ops with more than 1 dimensions to 1 dimension.
+struct FlattenSubView final : public OpConversionPattern<memref::SubViewOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::SubViewOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (!isRankZeroOrOneMemRef(adaptor.getSource().getType())) {
+      return rewriter.notifyMatchFailure(
+          op, "expected converted memref of rank <= 1");
+    }
+    Type neededResultType =
+        getTypeConverter()->convertType(op.getResult().getType());
+    if (!neededResultType || !isRankZeroOrOneMemRef(neededResultType))
+      return failure();
+    Value size = createTotalElementCountValue(op.getType(), op.getSizes(),
+                                              op.getLoc(), rewriter);
+    SmallVector<Value> offsets = mlir::getValueOrCreateConstantIndexOp(
+        rewriter, op.getLoc(), op.getMixedOffsets());
+    Value linearOffset =
+        linearizeIndices(op.getSource(), offsets, op.getLoc(), rewriter);
+    Value stride = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 1);
+    Value newSubView = memref::SubViewOp::create(
+        rewriter, op.getLoc(), adaptor.getSource(), ValueRange({linearOffset}),
+        ValueRange({size}), ValueRange({stride}));
+    rewriter.replaceOpWithNewOp<memref::CastOp>(op, neededResultType,
+                                                newSubView);
+    return success();
+  }
+};
+*/
+
 struct FlattenMemrefsPass
     : public mlir::memref::impl::FlattenMemrefsPassBase<FlattenMemrefsPass> {
   using Base::Base;
@@ -646,6 +478,7 @@ void memref::populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns) {
                   MemRefRewritePattern<memref::DeallocOp>,
                   FlattenExpandShape,
                   FlattenCollapseShape,
+                  //FlattenSubView,
                   FlattenGetGlobal,
                   FlattenGlobal>(
       patterns.getContext());

>From 05cdabb23bf25498c9b4ea5e3bb3b7b1102d037b Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Wed, 17 Sep 2025 21:56:10 -0400
Subject: [PATCH 5/6] subview

---
 .../MemRef/Transforms/FlattenMemRefs.cpp      | 181 ++++++++++++++----
 mlir/test/Dialect/MemRef/flatten_memref.mlir  |  13 ++
 2 files changed, 152 insertions(+), 42 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
index f17a0582919fc..43a67f1fab2be 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
@@ -27,6 +27,7 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
 
 namespace mlir {
@@ -47,6 +48,7 @@ static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc,
   return cast<Value>(in);
 }
 
+
 /// Returns a collapsed memref and the linearized index to access the element
 /// at the specified indices.
 static std::pair<Value, Value> getFlattenMemrefAndOffset(OpBuilder &rewriter,
@@ -90,12 +92,15 @@ static bool needFlattening(Value val) {
   return type.getRank() > 1;
 }
 
-static bool checkLayout(Value val) {
-  auto type = cast<MemRefType>(val.getType());
+static bool checkLayout(MemRefType type) {
   return type.getLayout().isIdentity() ||
          isa<StridedLayoutAttr>(type.getLayout());
 }
 
+static bool checkLayout(Value val) {
+  return checkLayout(cast<MemRefType>(val.getType()));
+}
+
 namespace {
 static Value getTargetMemref(Operation *op) {
   return llvm::TypeSwitch<Operation *, Value>(op)
@@ -368,38 +373,131 @@ struct FlattenExpandShape final : public OpRewritePattern<memref::ExpandShapeOp>
 };
 
 
-/*
-// Flattens memref subspan ops with more than 1 dimensions to 1 dimension.
-struct FlattenSubView final : public OpConversionPattern<memref::SubViewOp> {
-  using OpConversionPattern::OpConversionPattern;
+// Rewrites an N-D memref.subview as a 1-D subview plus a reinterpret_cast.
+struct FlattenSubView final : public OpRewritePattern<memref::SubViewOp> {
+  using OpRewritePattern::OpRewritePattern;
 
-  LogicalResult
-  matchAndRewrite(memref::SubViewOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    if (!isRankZeroOrOneMemRef(adaptor.getSource().getType())) {
-      return rewriter.notifyMatchFailure(
-          op, "expected converted memref of rank <= 1");
-    }
-    Type neededResultType =
-        getTypeConverter()->convertType(op.getResult().getType());
-    if (!neededResultType || !isRankZeroOrOneMemRef(neededResultType))
+  LogicalResult matchAndRewrite(memref::SubViewOp op,
+                                PatternRewriter &rewriter) const override {
+    auto sourceType = dyn_cast<MemRefType>(op.getSource().getType());
+    if (!sourceType || sourceType.getRank() <= 1)
+      return failure();
+    if (!checkLayout(sourceType))
       return failure();
-    Value size = createTotalElementCountValue(op.getType(), op.getSizes(),
-                                              op.getLoc(), rewriter);
-    SmallVector<Value> offsets = mlir::getValueOrCreateConstantIndexOp(
-        rewriter, op.getLoc(), op.getMixedOffsets());
-    Value linearOffset =
-        linearizeIndices(op.getSource(), offsets, op.getLoc(), rewriter);
-    Value stride = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 1);
-    Value newSubView = memref::SubViewOp::create(
-        rewriter, op.getLoc(), adaptor.getSource(), ValueRange({linearOffset}),
-        ValueRange({size}), ValueRange({stride}));
-    rewriter.replaceOpWithNewOp<memref::CastOp>(op, neededResultType,
-                                                newSubView);
+
+    MemRefType resultType = op.getType();
+    if (resultType.getRank() <= 1 || !checkLayout(resultType))
+      return failure();
+
+    // Only int/float elements have a bit width (`index` would assert here).
+    Type elementType = sourceType.getElementType();
+    if (!elementType.isIntOrFloat() || !elementType.getIntOrFloatBitWidth())
+      return failure();
+    unsigned elementBitWidth = elementType.getIntOrFloatBitWidth();
+    Location loc = op.getLoc();
+    // Materialize offsets as values so they can participate in linearization.
+    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
+    SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
+    SmallVector<OpFoldResult> mixedStrides = op.getMixedStrides();
+
+    SmallVector<Value> offsetValues;
+    offsetValues.reserve(mixedOffsets.size());
+    for (OpFoldResult ofr : mixedOffsets)
+      offsetValues.push_back(getValueFromOpFoldResult(rewriter, loc, ofr));
+
+    auto [flatSource, linearOffset] =
+        getFlattenMemrefAndOffset(rewriter, loc, op.getSource(),
+                                  ValueRange(offsetValues));
+
+    memref::ExtractStridedMetadataOp sourceMetadata =
+        memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getSource());
+
+    SmallVector<OpFoldResult> sourceStrides =
+        sourceMetadata.getConstifiedMixedStrides();
+    OpFoldResult sourceOffset = sourceMetadata.getConstifiedMixedOffset();
+
+    llvm::SmallBitVector droppedDims = op.getDroppedDims();
+
+    SmallVector<OpFoldResult> resultSizes;
+    SmallVector<OpFoldResult> resultStrides;
+    resultSizes.reserve(resultType.getRank());
+    resultStrides.reserve(resultType.getRank());
+
+    OpFoldResult resultOffset = sourceOffset;
+    for (auto [idx, it] : llvm::enumerate(llvm::zip_equal(
+             mixedOffsets, sourceStrides, mixedSizes, mixedStrides))) {
+      auto [offsetOfr, strideOfr, sizeOfr, relativeStrideOfr] = it;
+      OpFoldResult contribution = [&]() -> OpFoldResult {
+        if (Attribute offsetAttr = dyn_cast<Attribute>(offsetOfr)) {
+          if (Attribute strideAttr = dyn_cast<Attribute>(strideOfr)) {
+            auto offsetInt = cast<IntegerAttr>(offsetAttr).getInt();
+            auto strideInt = cast<IntegerAttr>(strideAttr).getInt();
+            return rewriter.getIndexAttr(offsetInt * strideInt);
+          }
+        }
+        Value offsetVal = getValueFromOpFoldResult(rewriter, loc, offsetOfr);
+        Value strideVal = getValueFromOpFoldResult(rewriter, loc, strideOfr);
+        return arith::MulIOp::create(rewriter, loc, offsetVal, strideVal)
+            .getResult();
+      }();
+      resultOffset = [&]() -> OpFoldResult {
+        if (Attribute offsetAttr = dyn_cast<Attribute>(resultOffset)) {
+          if (Attribute contribAttr = dyn_cast<Attribute>(contribution)) {
+            auto offsetInt = cast<IntegerAttr>(offsetAttr).getInt();
+            auto contribInt = cast<IntegerAttr>(contribAttr).getInt();
+            return rewriter.getIndexAttr(offsetInt + contribInt);
+          }
+        }
+        Value offsetVal = getValueFromOpFoldResult(rewriter, loc, resultOffset);
+        Value contribVal = getValueFromOpFoldResult(rewriter, loc, contribution);
+        return arith::AddIOp::create(rewriter, loc, offsetVal, contribVal)
+            .getResult();
+      }();
+
+      if (droppedDims.test(idx))
+        continue;
+      // Kept dim: its stride is the subview step times the source stride.
+      resultSizes.push_back(sizeOfr);
+      OpFoldResult combinedStride = [&]() -> OpFoldResult {
+        if (Attribute relStrideAttr = dyn_cast<Attribute>(relativeStrideOfr)) {
+          if (Attribute strideAttr = dyn_cast<Attribute>(strideOfr)) {
+            auto relStrideInt = cast<IntegerAttr>(relStrideAttr).getInt();
+            auto strideInt = cast<IntegerAttr>(strideAttr).getInt();
+            return rewriter.getIndexAttr(relStrideInt * strideInt);
+          }
+        }
+        Value relStrideVal =
+            getValueFromOpFoldResult(rewriter, loc, relativeStrideOfr);
+        Value strideVal = getValueFromOpFoldResult(rewriter, loc, strideOfr);
+        return arith::MulIOp::create(rewriter, loc, relStrideVal, strideVal)
+            .getResult();
+      }();
+      resultStrides.push_back(combinedStride);
+    }
+
+    memref::LinearizedMemRefInfo linearizedInfo;
+    [[maybe_unused]] OpFoldResult linearizedIndex;
+    std::tie(linearizedInfo, linearizedIndex) =
+        memref::getLinearizedMemRefOffsetAndSize(
+            rewriter, loc, elementBitWidth, elementBitWidth, resultOffset,
+            resultSizes, resultStrides);
+
+    Value flattenedSize = getValueFromOpFoldResult(
+        rewriter, loc, linearizedInfo.linearizedSize);
+    Value strideOne = arith::ConstantIndexOp::create(rewriter, loc, 1);
+
+    Value flattenedSubview = memref::SubViewOp::create(
+        rewriter, loc, flatSource, ValueRange{linearOffset},
+        ValueRange{flattenedSize}, ValueRange{strideOne});
+
+    Value replacement = memref::ReinterpretCastOp::create(
+        rewriter, loc, resultType, flattenedSubview, resultOffset, resultSizes,
+        resultStrides);
+
+    rewriter.replaceOp(op, replacement);
     return success();
   }
 };
-*/
 
 struct FlattenMemrefsPass
     : public mlir::memref::impl::FlattenMemrefsPassBase<FlattenMemrefsPass> {
@@ -422,18 +520,6 @@ struct FlattenMemrefsPass
 
 } // namespace
 
-void memref::populateFlattenVectorOpsOnMemrefPatterns(
-    RewritePatternSet &patterns) {
-  patterns.insert<MemRefRewritePattern<vector::LoadOp>,
-                  MemRefRewritePattern<vector::StoreOp>,
-                  MemRefRewritePattern<vector::TransferReadOp>,
-                  MemRefRewritePattern<vector::TransferWriteOp>,
-                  MemRefRewritePattern<vector::MaskedLoadOp>,
-                  MemRefRewritePattern<vector::MaskedStoreOp>>(
-      patterns.getContext());
-}
-
-/// Special pattern for GetGlobalOp to avoid infinite loops
 struct FlattenGetGlobal : public OpRewritePattern<memref::GetGlobalOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -470,6 +556,17 @@ struct FlattenGetGlobal : public OpRewritePattern<memref::GetGlobalOp> {
   }
 };
 
+void memref::populateFlattenVectorOpsOnMemrefPatterns(
+    RewritePatternSet &patterns) {
+  patterns.insert<MemRefRewritePattern<vector::LoadOp>,
+                  MemRefRewritePattern<vector::StoreOp>,
+                  MemRefRewritePattern<vector::TransferReadOp>,
+                  MemRefRewritePattern<vector::TransferWriteOp>,
+                  MemRefRewritePattern<vector::MaskedLoadOp>,
+                  MemRefRewritePattern<vector::MaskedStoreOp>>(
+      patterns.getContext());
+}
+
 void memref::populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns) {
   patterns.insert<MemRefRewritePattern<memref::LoadOp>,
                   MemRefRewritePattern<memref::StoreOp>,
@@ -478,7 +575,7 @@ void memref::populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns) {
                   MemRefRewritePattern<memref::DeallocOp>,
                   FlattenExpandShape,
                   FlattenCollapseShape,
-                  //FlattenSubView,
+                  FlattenSubView,
                   FlattenGetGlobal,
                   FlattenGlobal>(
       patterns.getContext());
diff --git a/mlir/test/Dialect/MemRef/flatten_memref.mlir b/mlir/test/Dialect/MemRef/flatten_memref.mlir
index 2a5f141dbe328..05290f96f45e9 100644
--- a/mlir/test/Dialect/MemRef/flatten_memref.mlir
+++ b/mlir/test/Dialect/MemRef/flatten_memref.mlir
@@ -194,6 +194,19 @@ func.func @mask_load_vector_from_memref_dynamic(%input: memref<3x7xi2>, %row: in
 
 // -----
 
+func.func @flatten_subview_static(%arg0: memref<3x4xf32, strided<[4, 1], offset: 0>>) -> memref<2x2xf32, strided<[4, 1], offset: 1>> {
+  %sub = memref.subview %arg0[0, 1] [2, 2] [1, 1]
+      : memref<3x4xf32, strided<[4, 1], offset: 0>> to memref<2x2xf32, strided<[4, 1], offset: 1>>
+  return %sub : memref<2x2xf32, strided<[4, 1], offset: 1>>
+}
+// CHECK-LABEL: func @flatten_subview_static
+// CHECK: %[[C8:.*]] = arith.constant 8 : index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[FLAT:.*]] = memref.reinterpret_cast %arg0 to offset: [0], sizes: [12], strides: [1]
+// CHECK: %[[SUB:.*]] = memref.subview %[[FLAT]][%[[C1]]] [%[[C8]]] [%[[C1]]]
+// CHECK: %[[CAST:.*]] = memref.reinterpret_cast %[[SUB]] to offset: [1], sizes: [2, 2], strides: [4, 1]
+// CHECK: return %[[CAST]]
+
 func.func @collapse_shape_static(%arg0: memref<2x3x4xf32>) -> memref<6x4xf32> {
   %0 = memref.collapse_shape %arg0 [[0, 1], [2]]
       : memref<2x3x4xf32> into memref<6x4xf32>

>From 70690626399a48b9801f854ec5d311b4c0e8106b Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Fri, 19 Sep 2025 17:58:44 -0400
Subject: [PATCH 6/6] Fix formatting and avoid capturing structured bindings in lambdas (C++20-only)

---
 .../Dialect/MemRef/Transforms/Transforms.h    | 34 +++++-----
 .../Transforms/ExpandStridedMetadata.cpp      | 34 +++++-----
 .../MemRef/Transforms/FlattenMemRefs.cpp      | 67 ++++++++++---------
 3 files changed, 71 insertions(+), 64 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index 562b8c11225e8..d40cb5ee8a064 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -18,8 +18,8 @@
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/ArrayRef.h"
-#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/STLFunctionalExtras.h"
+#include "llvm/ADT/SmallVector.h"
 
 namespace mlir {
 class Location;
@@ -236,9 +236,10 @@ memref::AllocaOp allocToAlloca(
 /// TODO: Move this utility function directly within ExpandShapeOp. For now,
 /// this is not possible because this function uses the Affine dialect and the
 /// MemRef dialect cannot depend on the Affine dialect.
-SmallVector<OpFoldResult>
-getExpandedSizes(ExpandShapeOp expandShape, OpBuilder &builder,
-                 ArrayRef<OpFoldResult> origSizes, unsigned groupId);
+SmallVector<OpFoldResult> getExpandedSizes(ExpandShapeOp expandShape,
+                                           OpBuilder &builder,
+                                           ArrayRef<OpFoldResult> origSizes,
+                                           unsigned groupId);
 
 /// Compute the expanded strides of the given \p expandShape for the
 /// \p groupId-th reassociation group.
@@ -277,11 +278,10 @@ SmallVector<OpFoldResult> getExpandedStrides(ExpandShapeOp expandShape,
 ///
 /// \pre for all index in indices: index < values.size()
 /// \pre for all index in indices: index < maybeConstants.size()
-OpFoldResult
-getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
-                   ArrayRef<int64_t> maybeConstants,
-                   ArrayRef<OpFoldResult> values,
-                   llvm::function_ref<bool(int64_t)> isDynamic);
+OpFoldResult getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder,
+                                Location loc, ArrayRef<int64_t> maybeConstants,
+                                ArrayRef<OpFoldResult> values,
+                                llvm::function_ref<bool(int64_t)> isDynamic);
 
 /// Compute the collapsed size of the given \p collapseShape for the
 /// \p groupId-th reassociation group.
@@ -291,9 +291,10 @@ getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
 /// TODO: Move this utility function directly within CollapseShapeOp. For now,
 /// this is not possible because this function uses the Affine dialect and the
 /// MemRef dialect cannot depend on the Affine dialect.
-SmallVector<OpFoldResult>
-getCollapsedSize(CollapseShapeOp collapseShape, OpBuilder &builder,
-                 ArrayRef<OpFoldResult> origSizes, unsigned groupId);
+SmallVector<OpFoldResult> getCollapsedSize(CollapseShapeOp collapseShape,
+                                           OpBuilder &builder,
+                                           ArrayRef<OpFoldResult> origSizes,
+                                           unsigned groupId);
 
 /// Compute the collapsed stride of the given \p collpaseShape for the
 /// \p groupId-th reassociation group.
@@ -307,10 +308,11 @@ getCollapsedSize(CollapseShapeOp collapseShape, OpBuilder &builder,
 ///
 /// \post result.size() == 1, in other words, each group collapse to one
 /// dimension.
-SmallVector<OpFoldResult>
-getCollapsedStride(CollapseShapeOp collapseShape, OpBuilder &builder,
-                   ArrayRef<OpFoldResult> origSizes,
-                   ArrayRef<OpFoldResult> origStrides, unsigned groupId);
+SmallVector<OpFoldResult> getCollapsedStride(CollapseShapeOp collapseShape,
+                                             OpBuilder &builder,
+                                             ArrayRef<OpFoldResult> origSizes,
+                                             ArrayRef<OpFoldResult> origStrides,
+                                             unsigned groupId);
 
 } // namespace memref
 } // namespace mlir
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index 6b69d0e366903..96bceae88af9d 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -21,10 +21,10 @@
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
-#include "mlir/IR/OpDefinition.h"
 #include <optional>
 
 namespace mlir {
@@ -256,9 +256,10 @@ struct ExtractStridedMetadataOpSubviewFolder
 
 namespace mlir {
 namespace memref {
-SmallVector<OpFoldResult>
-getExpandedSizes(ExpandShapeOp expandShape, OpBuilder &builder,
-                 ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
+SmallVector<OpFoldResult> getExpandedSizes(ExpandShapeOp expandShape,
+                                           OpBuilder &builder,
+                                           ArrayRef<OpFoldResult> origSizes,
+                                           unsigned groupId) {
   SmallVector<int64_t, 2> reassocGroup =
       expandShape.getReassociationIndices()[groupId];
   assert(!reassocGroup.empty() &&
@@ -372,11 +373,10 @@ SmallVector<OpFoldResult> getExpandedStrides(ExpandShapeOp expandShape,
   return expandedStrides;
 }
 
-OpFoldResult
-getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
-                   ArrayRef<int64_t> maybeConstants,
-                   ArrayRef<OpFoldResult> values,
-                   llvm::function_ref<bool(int64_t)> isDynamic) {
+OpFoldResult getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder,
+                                Location loc, ArrayRef<int64_t> maybeConstants,
+                                ArrayRef<OpFoldResult> values,
+                                llvm::function_ref<bool(int64_t)> isDynamic) {
   AffineExpr productOfValues = builder.getAffineConstantExpr(1);
   SmallVector<OpFoldResult> inputValues;
   unsigned numberOfSymbols = 0;
@@ -410,9 +410,10 @@ getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
 /// TODO: Move this utility function directly within CollapseShapeOp. For now,
 /// this is not possible because this function uses the Affine dialect and the
 /// MemRef dialect cannot depend on the Affine dialect.
-SmallVector<OpFoldResult>
-getCollapsedSize(CollapseShapeOp collapseShape, OpBuilder &builder,
-                 ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
+SmallVector<OpFoldResult> getCollapsedSize(CollapseShapeOp collapseShape,
+                                           OpBuilder &builder,
+                                           ArrayRef<OpFoldResult> origSizes,
+                                           unsigned groupId) {
   SmallVector<OpFoldResult> collapsedSize;
 
   MemRefType collapseShapeType = collapseShape.getResultType();
@@ -451,10 +452,11 @@ getCollapsedSize(CollapseShapeOp collapseShape, OpBuilder &builder,
 ///
 /// \post result.size() == 1, in other words, each group collapse to one
 /// dimension.
-SmallVector<OpFoldResult>
-getCollapsedStride(CollapseShapeOp collapseShape, OpBuilder &builder,
-                   ArrayRef<OpFoldResult> origSizes,
-                   ArrayRef<OpFoldResult> origStrides, unsigned groupId) {
+SmallVector<OpFoldResult> getCollapsedStride(CollapseShapeOp collapseShape,
+                                             OpBuilder &builder,
+                                             ArrayRef<OpFoldResult> origSizes,
+                                             ArrayRef<OpFoldResult> origStrides,
+                                             unsigned groupId) {
   SmallVector<int64_t, 2> reassocGroup =
       collapseShape.getReassociationIndices()[groupId];
   assert(!reassocGroup.empty() &&
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
index 43a67f1fab2be..1e6805139f7ff 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
@@ -21,9 +21,9 @@
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/Attributes.h"
-#include "mlir/IR/DialectResourceBlobManager.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/DialectResourceBlobManager.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -48,7 +48,6 @@ static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc,
   return cast<Value>(in);
 }
 
-
 /// Returns a collapsed memref and the linearized index to access the element
 /// at the specified indices.
 static std::pair<Value, Value> getFlattenMemrefAndOffset(OpBuilder &rewriter,
@@ -281,9 +280,8 @@ struct FlattenGlobal final : public OpRewritePattern<memref::GlobalOp> {
     return {};
   }
 
-  LogicalResult
-  matchAndRewrite(memref::GlobalOp globalOp,
-                  PatternRewriter &rewriter) const override {
+  LogicalResult matchAndRewrite(memref::GlobalOp globalOp,
+                                PatternRewriter &rewriter) const override {
     auto oldType = llvm::dyn_cast<MemRefType>(globalOp.getType());
     if (!oldType || !oldType.getLayout().isIdentity() || oldType.getRank() <= 1)
       return failure();
@@ -314,7 +312,8 @@ struct FlattenCollapseShape final
         memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getSrc());
 
     SmallVector<OpFoldResult> origSizes = metadata.getConstifiedMixedSizes();
-    SmallVector<OpFoldResult> origStrides = metadata.getConstifiedMixedStrides();
+    SmallVector<OpFoldResult> origStrides =
+        metadata.getConstifiedMixedStrides();
     OpFoldResult offset = metadata.getConstifiedMixedOffset();
 
     SmallVector<OpFoldResult> collapsedSizes;
@@ -338,7 +337,8 @@ struct FlattenCollapseShape final
   }
 };
 
-struct FlattenExpandShape final : public OpRewritePattern<memref::ExpandShapeOp> {
+struct FlattenExpandShape final
+    : public OpRewritePattern<memref::ExpandShapeOp> {
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(memref::ExpandShapeOp op,
@@ -348,7 +348,8 @@ struct FlattenExpandShape final : public OpRewritePattern<memref::ExpandShapeOp>
         memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getSrc());
 
     SmallVector<OpFoldResult> origSizes = metadata.getConstifiedMixedSizes();
-    SmallVector<OpFoldResult> origStrides = metadata.getConstifiedMixedStrides();
+    SmallVector<OpFoldResult> origStrides =
+        metadata.getConstifiedMixedStrides();
     OpFoldResult offset = metadata.getConstifiedMixedOffset();
 
     SmallVector<OpFoldResult> expandedSizes;
@@ -372,7 +373,6 @@ struct FlattenExpandShape final : public OpRewritePattern<memref::ExpandShapeOp>
   }
 };
 
-
 // Flattens memref subview ops with more than 1 dimension into 1-D accesses.
 struct FlattenSubView final : public OpRewritePattern<memref::SubViewOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -405,9 +405,8 @@ struct FlattenSubView final : public OpRewritePattern<memref::SubViewOp> {
     for (OpFoldResult ofr : mixedOffsets)
       offsetValues.push_back(getValueFromOpFoldResult(rewriter, loc, ofr));
 
-    auto [flatSource, linearOffset] =
-        getFlattenMemrefAndOffset(rewriter, loc, op.getSource(),
-                                  ValueRange(offsetValues));
+    auto [flatSource, linearOffset] = getFlattenMemrefAndOffset(
+        rewriter, loc, op.getSource(), ValueRange(offsetValues));
 
     memref::ExtractStridedMetadataOp sourceMetadata =
         memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getSource());
@@ -424,9 +423,14 @@ struct FlattenSubView final : public OpRewritePattern<memref::SubViewOp> {
     resultStrides.reserve(resultType.getRank());
 
     OpFoldResult resultOffset = sourceOffset;
-    for (auto [idx, it] : llvm::enumerate(llvm::zip_equal(
+    for (auto zipped : llvm::enumerate(llvm::zip_equal(
              mixedOffsets, sourceStrides, mixedSizes, mixedStrides))) {
-      auto [offsetOfr, strideOfr, sizeOfr, relativeStrideOfr] = it;
+      auto idx = zipped.index();
+      auto it = zipped.value();
+      auto offsetOfr = std::get<0>(it);
+      auto strideOfr = std::get<1>(it);
+      auto sizeOfr = std::get<2>(it);
+      auto relativeStrideOfr = std::get<3>(it);
       OpFoldResult contribution = [&]() -> OpFoldResult {
         if (Attribute offsetAttr = dyn_cast<Attribute>(offsetOfr)) {
           if (Attribute strideAttr = dyn_cast<Attribute>(strideOfr)) {
@@ -449,7 +453,8 @@ struct FlattenSubView final : public OpRewritePattern<memref::SubViewOp> {
           }
         }
         Value offsetVal = getValueFromOpFoldResult(rewriter, loc, resultOffset);
-        Value contribVal = getValueFromOpFoldResult(rewriter, loc, contribution);
+        Value contribVal =
+            getValueFromOpFoldResult(rewriter, loc, contribution);
         return rewriter.create<arith::AddIOp>(loc, offsetVal, contribVal)
             .getResult();
       }();
@@ -478,12 +483,12 @@ struct FlattenSubView final : public OpRewritePattern<memref::SubViewOp> {
     memref::LinearizedMemRefInfo linearizedInfo;
     [[maybe_unused]] OpFoldResult linearizedIndex;
     std::tie(linearizedInfo, linearizedIndex) =
-        memref::getLinearizedMemRefOffsetAndSize(
-            rewriter, loc, elementBitWidth, elementBitWidth, resultOffset,
-            resultSizes, resultStrides);
+        memref::getLinearizedMemRefOffsetAndSize(rewriter, loc, elementBitWidth,
+                                                 elementBitWidth, resultOffset,
+                                                 resultSizes, resultStrides);
 
-    Value flattenedSize = getValueFromOpFoldResult(
-        rewriter, loc, linearizedInfo.linearizedSize);
+    Value flattenedSize =
+        getValueFromOpFoldResult(rewriter, loc, linearizedInfo.linearizedSize);
     Value strideOne = arith::ConstantIndexOp::create(rewriter, loc, 1);
 
     Value flattenedSubview = memref::SubViewOp::create(
@@ -524,10 +529,11 @@ struct FlattenGetGlobal : public OpRewritePattern<memref::GetGlobalOp> {
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(memref::GetGlobalOp op,
-                               PatternRewriter &rewriter) const override {
+                                PatternRewriter &rewriter) const override {
     // Check if this get_global references a multi-dimensional global
     auto module = op->template getParentOfType<ModuleOp>();
-    auto globalOp = module.template lookupSymbol<memref::GlobalOp>(op.getName());
+    auto globalOp =
+        module.template lookupSymbol<memref::GlobalOp>(op.getName());
     if (!globalOp) {
       return failure();
     }
@@ -537,12 +543,13 @@ struct FlattenGetGlobal : public OpRewritePattern<memref::GetGlobalOp> {
 
     // Only apply if the global has been flattened but the get_global hasn't
     if (globalType.getRank() == 1 && resultType.getRank() > 1) {
-      auto newGetGlobal = memref::GetGlobalOp::create(
-          rewriter, op.getLoc(), globalType, op.getName());
+      auto newGetGlobal = memref::GetGlobalOp::create(rewriter, op.getLoc(),
+                                                      globalType, op.getName());
 
       // Cast the flattened result back to the original shape
       memref::ExtractStridedMetadataOp stridedMetadata =
-          memref::ExtractStridedMetadataOp::create(rewriter, op.getLoc(), op.getResult());
+          memref::ExtractStridedMetadataOp::create(rewriter, op.getLoc(),
+                                                   op.getResult());
       auto castResult = memref::ReinterpretCastOp::create(
           rewriter, op.getLoc(), resultType, newGetGlobal,
           /*offset=*/rewriter.getIndexAttr(0),
@@ -572,13 +579,9 @@ void memref::populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns) {
                   MemRefRewritePattern<memref::StoreOp>,
                   MemRefRewritePattern<memref::AllocOp>,
                   MemRefRewritePattern<memref::AllocaOp>,
-                  MemRefRewritePattern<memref::DeallocOp>,
-                  FlattenExpandShape,
-                  FlattenCollapseShape,
-                  FlattenSubView,
-                  FlattenGetGlobal,
-                  FlattenGlobal>(
-      patterns.getContext());
+                  MemRefRewritePattern<memref::DeallocOp>, FlattenExpandShape,
+                  FlattenCollapseShape, FlattenSubView, FlattenGetGlobal,
+                  FlattenGlobal>(patterns.getContext());
 }
 
 void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {



More information about the Mlir-commits mailing list