[Mlir-commits] [mlir] [MLIR] Add a utility pass to linearize `memref` (PR #136797)

Alan Li llvmlistbot at llvm.org
Tue May 6 09:43:43 PDT 2025


https://github.com/lialan updated https://github.com/llvm/llvm-project/pull/136797

>From b57c07fea33fb1f696b6410d14e1fbfff23babf6 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Tue, 22 Apr 2025 14:42:38 -0400
Subject: [PATCH 01/10] First commit.

---
 .../include/mlir/Dialect/MemRef/Transforms/Passes.td | 12 ++++++++++++
 .../mlir/Dialect/MemRef/Transforms/Transforms.h      |  3 +++
 mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt    |  2 ++
 3 files changed, 17 insertions(+)

diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
index a46f73350bb3c..a2a9047bda808 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
@@ -245,5 +245,17 @@ def ExpandReallocPass : Pass<"expand-realloc"> {
   ];
 }
 
+def FlattenMemrefsPass : Pass<"flatten-memref"> {
+  let summary = "Flatten a multi-dimensional memref to 1-dimensional";
+  let description = [{
+    Rewrites memref.load/store, vector.load/store, vector.maskedload/
+    maskedstore and vector.transfer_read/transfer_write ops that access a
+    multi-dimensional memref with an identity or strided layout. Each access
+    goes through a 1-D memref.reinterpret_cast of the original memref and
+    uses a linearized index computed with affine.apply. memref.subview ops
+    on such memrefs are rewritten into memref.reinterpret_cast ops of the
+    underlying base buffer.
+  }];
+
+  let constructor = "mlir::memref::createFlattenMemrefsPass()";
+  let dependentDialects = [
+      "affine::AffineDialect", "arith::ArithDialect", "memref::MemRefDialect",
+      "vector::VectorDialect"
+  ];
+}
+
 #endif // MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES
 
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index 62a2297c80e78..0649bf9c099f9 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -144,6 +144,9 @@ FailureOr<memref::AllocOp> multiBuffer(memref::AllocOp allocOp,
 /// ```
 void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns);
 
+
+/// Appends to `patterns` the rewrites that turn memref.load/store,
+/// vector.load/store, vector.maskedload/maskedstore and
+/// vector.transfer_read/transfer_write accesses to multi-dimensional memrefs
+/// into accesses to 1-D memrefs, and that rewrite memref.subview into
+/// memref.reinterpret_cast.
+void populateFlattenMemrefsPatterns(RewritePatternSet &patterns);
+
 /// Build a new memref::AllocaOp whose dynamic sizes are independent of all
 /// given independencies. If the op is already independent of all
 /// independencies, the same AllocaOp result is returned.
diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
index ecab97bc2b8e7..48e8bccd369fa 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
   EmulateWideInt.cpp
   EmulateNarrowType.cpp
   ExtractAddressComputations.cpp
+  FlattenMemRefs.cpp
   FoldMemRefAliasOps.cpp
   IndependenceTransforms.cpp
   MultiBuffer.cpp
@@ -23,6 +24,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
 
   LINK_LIBS PUBLIC
   MLIRAffineTransforms
+  MLIRAffineDialect
   MLIRAffineUtils
   MLIRArithDialect
   MLIRArithTransforms

>From c003e70ea3d9d43a0679304fdfeceeb4f18940cb Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Tue, 22 Apr 2025 14:45:51 -0400
Subject: [PATCH 02/10] Adding test cases

---
 mlir/test/Dialect/MemRef/flatten_memref.mlir | 225 +++++++++++++++++++
 1 file changed, 225 insertions(+)
 create mode 100644 mlir/test/Dialect/MemRef/flatten_memref.mlir

diff --git a/mlir/test/Dialect/MemRef/flatten_memref.mlir b/mlir/test/Dialect/MemRef/flatten_memref.mlir
new file mode 100644
index 0000000000000..6c9b09985acf7
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/flatten_memref.mlir
@@ -0,0 +1,225 @@
+// RUN: mlir-opt --flatten-memref %s --split-input-file --verify-diagnostics | FileCheck %s
+
+// Constant indices on a strided layout whose innermost stride is 1: the
+// linearized index folds to the constant 1 * 8 + 2 = 10.
+func.func @load_scalar_from_memref(%input: memref<4x8xf32, strided<[8, 1], offset: 100>>) -> f32 {
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %value = memref.load %input[%c1, %c2] : memref<4x8xf32, strided<[8, 1], offset: 100>>
+  return %value : f32
+}
+// CHECK: func @load_scalar_from_memref
+// CHECK: %[[C10:.*]] = arith.constant 10 : index
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %arg0 to offset: [100], sizes: [32], strides: [1]
+// CHECK-SAME: memref<4x8xf32, strided<[8, 1], offset: 100>> to memref<32xf32, strided<[1], offset: 100>>
+// CHECK: memref.load %[[REINT]][%[[C10]]] : memref<32xf32, strided<[1], offset: 100>>
+
+// -----
+
+// Dynamic indices on a static layout whose innermost stride is 12.
+// NOTE(review): the linearized index already multiplies the innermost index
+// by 12, so a flat view with `strides: [12]` scales it a second time; the
+// flattened memref should use a unit stride.
+func.func @load_scalar_from_memref_static_dim_2(%input: memref<4x8xf32, strided<[8, 12], offset: 100>>, %row: index, %col: index) -> f32 {
+  %value = memref.load %input[%col, %row] : memref<4x8xf32, strided<[8, 12], offset: 100>>
+  return %value : f32
+}
+// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 8 + s1 * 12)>
+// CHECK: func @load_scalar_from_memref_static_dim_2
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xf32, strided<[8, 12], offset: 100>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG1]]]
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [100], sizes: [32], strides: [12]
+// CHECK-SAME: to memref<32xf32, strided<[12], offset: 100>>
+// CHECK: memref.load %[[REINT]][%[[IDX]]]
+
+// -----
+
+// Fully dynamic layout: the offset, strides and outermost size come from
+// memref.extract_strided_metadata.
+// NOTE(review): the flat view reuses the innermost stride (STRIDES#1), which
+// the linearized index already includes; a unit stride is expected instead.
+func.func @load_scalar_from_memref_dynamic_dim(%input: memref<?x?xf32, strided<[?, ?], offset: ?>>, %row: index, %col: index) -> f32 {
+  %value = memref.load %input[%col, %row] : memref<?x?xf32, strided<[?, ?], offset: ?>>
+  return %value : f32
+}
+
+// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1, s2, s3] -> (s0 * s1 + s2 * s3)>
+// CHECK: #[[MAP1:.*]] = affine_map<()[s0, s1] -> (s0 * s1)>
+// CHECK: func @load_scalar_from_memref_dynamic_dim
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32, strided<[?, ?], offset: ?>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG0]]
+// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[STRIDES]]#0, %[[ARG1]], %[[STRIDES]]#1]
+// CHECK: %[[SIZE:.*]] = affine.apply #[[MAP1]]()[%[[STRIDES]]#0, %[[SIZES]]#0]
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [%[[OFFSET]]], sizes: [%[[SIZE]]], strides: [%[[STRIDES]]#1]
+// CHECK: memref.load %[[REINT]][%[[IDX]]]
+
+// -----
+
+// memref.subview of an identity-layout memref is rewritten into a
+// memref.reinterpret_cast of the base buffer with an explicit dynamic offset.
+func.func @load_scalar_from_memref_subview(%input: memref<4x8xf32>, %row: index, %col: index) -> memref<1x1xf32, strided<[8, 1], offset: ?>> {
+  %subview = memref.subview %input[%col, %row] [1, 1] [1, 1] : memref<4x8xf32> to memref<1x1xf32, strided<[8, 1], offset: ?>>
+  return %subview : memref<1x1xf32, strided<[8, 1], offset: ?>>
+}
+// CHECK: func @load_scalar_from_memref_subview
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG0]]
+// CHECK: %[[IDX:.*]] = affine.apply
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [1, 1], strides: [8, 1]
+// CHECK: return %[[REINT]]
+
+// -----
+
+// memref.store with dynamic indices on a layout whose innermost stride is 12.
+// NOTE(review): the linearized index already includes the stride of 12, so
+// the expected flat type `strided<[12], ...>` scales it twice; a unit stride
+// is expected instead.
+func.func @store_scalar_from_memref_static_dim(%input: memref<4x8xf32, strided<[8, 12], offset: 100>>, %row: index, %col: index, %value: f32) {
+  memref.store %value, %input[%col, %row] : memref<4x8xf32, strided<[8, 12], offset: 100>>
+  return
+}
+// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 8 + s1 * 12)>
+// CHECK: func @store_scalar_from_memref_static_dim
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xf32, strided<[8, 12], offset: 100>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: f32)
+// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG1]]]
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]]
+// CHECK: memref.store %[[ARG3]], %[[REINT]][%[[IDX]]] : memref<32xf32, strided<[12], offset: 100>>
+
+// -----
+
+// memref.store on a fully dynamic layout; metadata comes from
+// memref.extract_strided_metadata.
+// NOTE(review): the flat view reuses the innermost stride (STRIDES#1), which
+// the linearized index already includes; a unit stride is expected instead.
+func.func @store_scalar_from_memref_dynamic_dim(%input: memref<?x?xf32, strided<[?, ?], offset: ?>>, %row: index, %col: index, %value: f32) {
+  memref.store %value, %input[%col, %row] : memref<?x?xf32, strided<[?, ?], offset: ?>>
+  return
+}
+// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1, s2, s3] -> (s0 * s1 + s2 * s3)>
+// CHECK: #[[MAP1:.*]] = affine_map<()[s0, s1] -> (s0 * s1)>
+// CHECK: func @store_scalar_from_memref_dynamic_dim
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32, strided<[?, ?], offset: ?>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: f32)
+// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG0]]
+// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[STRIDES]]#0, %[[ARG1]], %[[STRIDES]]#1]
+// CHECK: %[[SIZE:.*]] = affine.apply #[[MAP1]]()[%[[STRIDES]]#0, %[[SIZES]]#0]
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [%[[OFFSET]]], sizes: [%[[SIZE]]], strides: [%[[STRIDES]]#1]
+// CHECK: memref.store %[[ARG3]], %[[REINT]][%[[IDX]]]
+
+// -----
+
+// vector.load with constant indices on an identity layout: 3 * 8 + 6 = 30.
+func.func @load_vector_from_memref(%input: memref<4x8xf32>) -> vector<8xf32> {
+  %c3 = arith.constant 3 : index
+  %c6 = arith.constant 6 : index
+  %value = vector.load %input[%c3, %c6] : memref<4x8xf32>, vector<8xf32>
+  return %value : vector<8xf32>
+}
+// CHECK: func @load_vector_from_memref
+// CHECK: %[[C30:.*]] = arith.constant 30
+// CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast %arg0 to offset: [0], sizes: [32], strides: [1]
+// CHECK-NEXT: vector.load %[[REINT]][%[[C30]]]
+
+// -----
+
+// Sub-byte element type with an odd inner size: 1 * 7 + 3 = 10.
+func.func @load_vector_from_memref_odd(%input: memref<3x7xi2>) -> vector<3xi2> {
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  %value = vector.load %input[%c1, %c3] : memref<3x7xi2>, vector<3xi2>
+  return %value : vector<3xi2>
+}
+// CHECK: func @load_vector_from_memref_odd
+// CHECK: %[[C10:.*]] = arith.constant 10 : index
+// CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast
+// CHECK-NEXT: vector.load %[[REINT]][%[[C10]]]
+
+// -----
+
+// vector.load with dynamic indices: the index is col * 7 + row via affine.apply.
+func.func @load_vector_from_memref_dynamic(%input: memref<3x7xi2>, %row: index, %col: index) -> vector<3xi2> {
+  %value = vector.load %input[%col, %row] : memref<3x7xi2>, vector<3xi2>
+  return %value : vector<3xi2>
+}
+// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 7 + s1)>
+// CHECK: func @load_vector_from_memref_dynamic
+// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast
+// CHECK: vector.load %[[REINT]][%[[IDX]]] : memref<21xi2, strided<[1]>>, vector<3xi2>
+
+// -----
+
+// vector.store with constant indices: 1 * 7 + 3 = 10.
+func.func @store_vector_to_memref_odd(%input: memref<3x7xi2>, %value: vector<3xi2>) {
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  vector.store %value, %input[%c1, %c3] : memref<3x7xi2>, vector<3xi2>
+  return
+}
+// CHECK: func @store_vector_to_memref_odd
+// CHECK: %[[C10:.*]] = arith.constant 10 : index
+// CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast
+// CHECK-NEXT: vector.store %arg1, %[[REINT]][%[[C10]]] : memref<21xi2, strided<[1]>
+
+// -----
+
+// vector.store with dynamic indices on an identity layout.
+func.func @store_vector_to_memref_dynamic(%input: memref<3x7xi2>, %value: vector<3xi2>, %row: index, %col: index) {
+  vector.store %value, %input[%col, %row] : memref<3x7xi2>, vector<3xi2>
+  return
+}
+// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 7 + s1)>
+// CHECK: func @store_vector_to_memref_dynamic
+// CHECK-SAME: (%[[ARG0:.*]]: memref<3x7xi2>, %[[ARG1:.*]]: vector<3xi2>, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
+// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG3]], %[[ARG2]]]
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [21], strides: [1]
+// CHECK: vector.store %[[ARG1]], %[[REINT]][%[[IDX]]]
+
+// -----
+
+// vector.maskedstore with constant indices; mask and value are forwarded.
+func.func @mask_store_vector_to_memref_odd(%input: memref<3x7xi2>, %value: vector<3xi2>, %mask: vector<3xi1>) {
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  vector.maskedstore %input[%c1, %c3], %mask, %value  : memref<3x7xi2>, vector<3xi1>, vector<3xi2>
+  return
+}
+// CHECK: func @mask_store_vector_to_memref_odd
+// CHECK-SAME: (%[[ARG0:.*]]: memref<3x7xi2>, %[[ARG1:.*]]: vector<3xi2>, %[[ARG2:.*]]: vector<3xi1>)
+// CHECK: %[[C10:.*]] = arith.constant 10 : index
+// CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast
+// CHECK: vector.maskedstore %[[REINT]][%[[C10]]], %[[ARG2]], %[[ARG1]]
+
+// -----
+
+// vector.maskedstore with dynamic indices.
+func.func @mask_store_vector_to_memref_dynamic(%input: memref<3x7xi2>, %value: vector<3xi2>, %row: index, %col: index, %mask: vector<3xi1>) {
+  vector.maskedstore %input[%col, %row], %mask, %value : memref<3x7xi2>, vector<3xi1>, vector<3xi2>
+  return
+}
+// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 7 + s1)>
+// CHECK: func @mask_store_vector_to_memref_dynamic
+// CHECK-SAME: (%[[ARG0:.*]]: memref<3x7xi2>, %[[ARG1:.*]]: vector<3xi2>, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: vector<3xi1>)
+// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG3]], %[[ARG2]]]
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]]
+// CHECK: vector.maskedstore %[[REINT]][%[[IDX]]], %[[ARG4]], %[[ARG1]]
+
+// -----
+// vector.maskedload with constant indices; mask and pass-through are forwarded.
+func.func @mask_load_vector_from_memref_odd(%input: memref<3x7xi2>, %mask: vector<3xi1>, %passthru: vector<3xi2>) -> vector<3xi2> {
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  %result = vector.maskedload %input[%c1, %c3], %mask, %passthru : memref<3x7xi2>, vector<3xi1>, vector<3xi2> into vector<3xi2>
+  return %result : vector<3xi2>
+}
+// CHECK: func @mask_load_vector_from_memref_odd
+// CHECK-SAME: (%[[ARG0:.*]]: memref<3x7xi2>, %[[MASK:.*]]: vector<3xi1>, %[[PASSTHRU:.*]]: vector<3xi2>)
+// CHECK: %[[C10:.*]] = arith.constant 10 : index
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [21], strides: [1]
+// CHECK: vector.maskedload %[[REINT]][%[[C10]]], %[[MASK]], %[[PASSTHRU]]
+
+// -----
+
+// vector.maskedload with dynamic indices.
+func.func @mask_load_vector_from_memref_dynamic(%input: memref<3x7xi2>, %row: index, %col: index, %mask: vector<3xi1>, %passthru: vector<3xi2>) -> vector<3xi2> {
+  %result = vector.maskedload %input[%col, %row], %mask, %passthru : memref<3x7xi2>, vector<3xi1>, vector<3xi2> into vector<3xi2>
+  return %result : vector<3xi2>
+}
+// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 7 + s1)>
+// CHECK: func @mask_load_vector_from_memref_dynamic
+// CHECK-SAME: (%[[ARG0:.*]]: memref<3x7xi2>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: vector<3xi1>, %[[ARG4:.*]]: vector<3xi2>)
+// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG1]]]
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]]
+// CHECK: vector.maskedload %[[REINT]][%[[IDX]]], %[[ARG3]]
+
+// -----
+
+// vector.transfer_read with dynamic indices; the padding value is forwarded.
+func.func @transfer_read_memref(%input: memref<4x8xi2>, %value: vector<8xi2>, %row: index, %col: index) -> vector<8xi2> {
+   %c0 = arith.constant 0 : i2
+   %0 = vector.transfer_read %input[%col, %row], %c0 : memref<4x8xi2>, vector<8xi2>
+   return %0 : vector<8xi2>
+}
+// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 8 + s1)>
+// CHECK: func @transfer_read_memref
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xi2>, %[[ARG1:.*]]: vector<8xi2>, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
+// CHECK: %[[C0:.*]] = arith.constant 0 : i2
+// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG3]], %[[ARG2]]]
+// CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]]
+// CHECK-NEXT: vector.transfer_read %[[REINT]][%[[IDX]]], %[[C0]]
+
+// -----
+
+// vector.transfer_write with dynamic indices.
+func.func @transfer_write_memref(%input: memref<4x8xi2>, %value: vector<8xi2>, %row: index, %col: index) {
+   vector.transfer_write %value, %input[%col, %row] : vector<8xi2>, memref<4x8xi2>
+   return
+}
+// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 8 + s1)>
+// CHECK: func @transfer_write_memref
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xi2>, %[[ARG1:.*]]: vector<8xi2>, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
+// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG3]], %[[ARG2]]]
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]]
+// CHECK: vector.transfer_write %[[ARG1]], %[[REINT]][%[[IDX]]]

>From 9170789b65289d8fe8f53253a2d989f9b33cd630 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Tue, 22 Apr 2025 16:23:39 -0400
Subject: [PATCH 03/10] missing the file

---
 .../MemRef/Transforms/FlattenMemRefs.cpp      | 357 ++++++++++++++++++
 1 file changed, 357 insertions(+)
 create mode 100644 mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp

diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
new file mode 100644
index 0000000000000..6685896624536
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
@@ -0,0 +1,357 @@
+//===----- FlattenMemRefs.cpp - MemRef ops flattener pass  ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains patterns for flattening multi-rank memref-related ops
+// into ops on 1-D memrefs.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+
+namespace mlir {
+namespace memref {
+// Pull in the TableGen-generated pass base class
+// (impl::FlattenMemrefsPassBase) that FlattenMemrefsPass derives from.
+#define GEN_PASS_DEF_FLATTENMEMREFSPASS
+#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
+} // namespace memref
+} // namespace mlir
+
+using namespace mlir;
+
+/// Moves the insertion point of `builder` right after the op defining `val`,
+/// or to the start of the owning block when `val` is a block argument, so
+/// that ops created afterwards can use `val`.
+static void setInsertionPointToStart(OpBuilder &builder, Value val) {
+  Operation *definingOp = val.getDefiningOp();
+  if (!definingOp) {
+    builder.setInsertionPointToStart(val.getParentBlock());
+    return;
+  }
+  builder.setInsertionPointAfter(definingOp);
+}
+
+static std::tuple<Value, OpFoldResult, SmallVector<OpFoldResult>, OpFoldResult,
+                  OpFoldResult>
+getFlatOffsetAndStrides(OpBuilder &rewriter, Location loc, Value source,
+                        ArrayRef<OpFoldResult> subOffsets,
+                        ArrayRef<OpFoldResult> subStrides = std::nullopt) {
+  auto sourceType = cast<MemRefType>(source.getType());
+  auto sourceRank = static_cast<unsigned>(sourceType.getRank());
+
+  memref::ExtractStridedMetadataOp newExtractStridedMetadata;
+  {
+    OpBuilder::InsertionGuard g(rewriter);
+    setInsertionPointToStart(rewriter, source);
+    newExtractStridedMetadata =
+        rewriter.create<memref::ExtractStridedMetadataOp>(loc, source);
+  }
+
+  auto &&[sourceStrides, sourceOffset] = sourceType.getStridesAndOffset();
+
+  auto getDim = [&](int64_t dim, Value dimVal) -> OpFoldResult {
+    return ShapedType::isDynamic(dim) ? getAsOpFoldResult(dimVal)
+                                      : rewriter.getIndexAttr(dim);
+  };
+
+  OpFoldResult origOffset =
+      getDim(sourceOffset, newExtractStridedMetadata.getOffset());
+  ValueRange sourceStridesVals = newExtractStridedMetadata.getStrides();
+  OpFoldResult outmostDim =
+      getDim(sourceType.getShape().front(),
+             newExtractStridedMetadata.getSizes().front());
+
+  SmallVector<OpFoldResult> origStrides;
+  origStrides.reserve(sourceRank);
+
+  SmallVector<OpFoldResult> strides;
+  strides.reserve(sourceRank);
+
+  AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
+  AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
+  for (auto i : llvm::seq(0u, sourceRank)) {
+    OpFoldResult origStride = getDim(sourceStrides[i], sourceStridesVals[i]);
+
+    if (!subStrides.empty()) {
+      strides.push_back(affine::makeComposedFoldedAffineApply(
+          rewriter, loc, s0 * s1, {subStrides[i], origStride}));
+    }
+
+    origStrides.emplace_back(origStride);
+  }
+
+  // Compute linearized index:
+  auto &&[expr, values] =
+      computeLinearIndex(rewriter.getIndexAttr(0), origStrides, subOffsets);
+  OpFoldResult linearizedIndex =
+      affine::makeComposedFoldedAffineApply(rewriter, loc, expr, values);
+
+  // Compute collapsed size: (the outmost stride * outmost dimension).
+  SmallVector<OpFoldResult> ops{origStrides.front(), outmostDim};
+  OpFoldResult collapsedSize = computeProduct(loc, rewriter, ops);
+
+  return {newExtractStridedMetadata.getBaseBuffer(), linearizedIndex,
+          origStrides, origOffset, collapsedSize};
+}
+
+/// Materializes `in` as an SSA value. A Value is returned as-is; an
+/// attribute (expected to be an IntegerAttr) becomes an index
+/// `arith.constant` created at the builder's current insertion point.
+static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc,
+                                      OpFoldResult in) {
+  Attribute attr = dyn_cast<Attribute>(in);
+  if (!attr)
+    return cast<Value>(in);
+  int64_t constant = cast<IntegerAttr>(attr).getInt();
+  return rewriter.create<arith::ConstantIndexOp>(loc, constant);
+}
+
+/// Returns a rank-1 view of `source` and the linearized index that addresses
+/// the element at `indices` in that view.
+///
+/// The view is a memref.reinterpret_cast of `source` that keeps the offset
+/// of `source`. The linearized index is sum_i(indices[i] * stride_i), so it
+/// already accounts for every source stride, including the innermost one.
+/// The view therefore uses a unit stride. Reusing the innermost source
+/// stride would scale the index a second time whenever that stride is not 1.
+static std::pair<Value, Value> getFlattenMemrefAndOffset(OpBuilder &rewriter,
+                                                         Location loc,
+                                                         Value source,
+                                                         ValueRange indices) {
+  auto &&[base, index, strides, offset, collapsedShape] =
+      getFlatOffsetAndStrides(rewriter, loc, source,
+                              getAsOpFoldResult(indices));
+
+  return std::make_pair(
+      rewriter.create<memref::ReinterpretCastOp>(
+          loc, source,
+          /* offset = */ offset,
+          /* shapes = */ ArrayRef<OpFoldResult>{collapsedShape},
+          /* strides = */ ArrayRef<OpFoldResult>{rewriter.getIndexAttr(1)}),
+      getValueFromOpFoldResult(rewriter, loc, index));
+}
+
+/// Returns true when `val` (which must be of MemRefType) has rank 2 or more
+/// and therefore has to be flattened.
+static bool needFlattenning(Value val) {
+  return cast<MemRefType>(val.getType()).getRank() > 1;
+}
+
+/// Returns true when the layout of memref `val` is supported by the
+/// flattening patterns: either an identity layout or a strided layout.
+static bool checkLayout(Value val) {
+  auto layout = cast<MemRefType>(val.getType()).getLayout();
+  if (isa<StridedLayoutAttr>(layout))
+    return true;
+  return layout.isIdentity();
+}
+
+namespace {
+/// Returns the memref operand accessed by `op`. Only the op kinds handled by
+/// the patterns below are supported; any other `T` yields a null Value.
+template <typename T>
+static Value getTargetMemref(T op) {
+  if constexpr (std::is_same_v<T, memref::LoadOp>) {
+    return op.getMemref();
+  } else if constexpr (std::is_same_v<T, vector::LoadOp>) {
+    return op.getBase();
+  } else if constexpr (std::is_same_v<T, memref::StoreOp>) {
+    return op.getMemref();
+  } else if constexpr (std::is_same_v<T, vector::StoreOp>) {
+    return op.getBase();
+  } else if constexpr (std::is_same_v<T, vector::MaskedLoadOp>) {
+    return op.getBase();
+  } else if constexpr (std::is_same_v<T, vector::MaskedStoreOp>) {
+    return op.getBase();
+  } else if constexpr (std::is_same_v<T, vector::TransferReadOp>) {
+    return op.getSource();
+  } else if constexpr (std::is_same_v<T, vector::TransferWriteOp>) {
+    return op.getSource();
+  }
+  return {};
+}
+
+/// Replaces `op` with an op of the same kind that accesses the 1-D memref
+/// `flatMemref` at the single index `offset`.
+///
+/// Loads, stores and masked loads/stores copy all attributes of `op`.
+/// Transfer reads/writes are rebuilt from their vector/padding operands
+/// only. Their mask, permutation map and in_bounds attribute are not
+/// carried over, so the pattern below only lets through transfers for which
+/// dropping the mask and permutation map is a no-op.
+/// NOTE(review): for transfers that are not in-bounds along the innermost
+/// source dimension, the flattened access may touch elements of the next
+/// row instead of producing padding / being skipped. Confirm this is
+/// acceptable, or require in-bounds transfers.
+template <typename T>
+static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref,
+                      Value offset) {
+  if constexpr (std::is_same_v<T, memref::LoadOp>) {
+    auto newLoad = rewriter.create<memref::LoadOp>(
+        op->getLoc(), op->getResultTypes(), flatMemref, ValueRange{offset});
+    newLoad->setAttrs(op->getAttrs());
+    rewriter.replaceOp(op, newLoad.getResult());
+  } else if constexpr (std::is_same_v<T, vector::LoadOp>) {
+    auto newLoad = rewriter.create<vector::LoadOp>(
+        op->getLoc(), op->getResultTypes(), flatMemref, ValueRange{offset});
+    newLoad->setAttrs(op->getAttrs());
+    rewriter.replaceOp(op, newLoad.getResult());
+  } else if constexpr (std::is_same_v<T, memref::StoreOp>) {
+    auto newStore = rewriter.create<memref::StoreOp>(
+        op->getLoc(), op->getOperands().front(), flatMemref,
+        ValueRange{offset});
+    newStore->setAttrs(op->getAttrs());
+    rewriter.replaceOp(op, newStore);
+  } else if constexpr (std::is_same_v<T, vector::StoreOp>) {
+    auto newStore = rewriter.create<vector::StoreOp>(
+        op->getLoc(), op->getOperands().front(), flatMemref,
+        ValueRange{offset});
+    newStore->setAttrs(op->getAttrs());
+    rewriter.replaceOp(op, newStore);
+  } else if constexpr (std::is_same_v<T, vector::TransferReadOp>) {
+    auto newTransferRead = rewriter.create<vector::TransferReadOp>(
+        op->getLoc(), op.getType(), flatMemref, ValueRange{offset},
+        op.getPadding());
+    rewriter.replaceOp(op, newTransferRead.getResult());
+  } else if constexpr (std::is_same_v<T, vector::TransferWriteOp>) {
+    auto newTransferWrite = rewriter.create<vector::TransferWriteOp>(
+        op->getLoc(), op.getVector(), flatMemref, ValueRange{offset});
+    rewriter.replaceOp(op, newTransferWrite);
+  } else if constexpr (std::is_same_v<T, vector::MaskedLoadOp>) {
+    auto newMaskedLoad = rewriter.create<vector::MaskedLoadOp>(
+        op->getLoc(), op.getType(), flatMemref, ValueRange{offset},
+        op.getMask(), op.getPassThru());
+    newMaskedLoad->setAttrs(op->getAttrs());
+    rewriter.replaceOp(op, newMaskedLoad.getResult());
+  } else if constexpr (std::is_same_v<T, vector::MaskedStoreOp>) {
+    auto newMaskedStore = rewriter.create<vector::MaskedStoreOp>(
+        op->getLoc(), flatMemref, ValueRange{offset}, op.getMask(),
+        op.getValueToStore());
+    newMaskedStore->setAttrs(op->getAttrs());
+    rewriter.replaceOp(op, newMaskedStore);
+  } else {
+    // Not reached for the op kinds instantiated in this file.
+    op.emitOpError("unimplemented: do not know how to replace op.");
+  }
+}
+
+/// Shared implementation of the access-flattening patterns. It rewrites an
+/// access to a multi-dimensional memref with identity or strided layout
+/// into an access to a 1-D memref.reinterpret_cast of that memref, at a
+/// linearized index.
+template <typename T>
+struct MemRefRewritePatternBase : public OpRewritePattern<T> {
+  using OpRewritePattern<T>::OpRewritePattern;
+  LogicalResult matchAndRewrite(T op,
+                                PatternRewriter &rewriter) const override {
+    Value memref = getTargetMemref<T>(op);
+    if (!needFlattenning(memref) || !checkLayout(memref))
+      return rewriter.notifyMatchFailure(op,
+                                         "nothing to do or unsupported layout");
+
+    if constexpr (!std::is_same_v<T, memref::LoadOp> &&
+                  !std::is_same_v<T, memref::StoreOp>) {
+      // An n-D vector access covers several rows of the source memref, and
+      // that meaning is lost once the memref is flattened to 1-D.
+      if (op.getVectorType().getRank() > 1)
+        return rewriter.notifyMatchFailure(op, "unsupported n-D vector access");
+
+      if constexpr (std::is_same_v<T, vector::TransferReadOp> ||
+                    std::is_same_v<T, vector::TransferWriteOp>) {
+        // replaceOp rebuilds transfers without a mask and with the default
+        // permutation map. Bail out when either of them is meaningful.
+        if (op.getMask() || !op.getPermutationMap().isMinorIdentity())
+          return rewriter.notifyMatchFailure(
+              op, "masked or non-minor-identity transfers are not supported");
+      }
+    }
+
+    auto &&[flatMemref, offset] = getFlattenMemrefAndOffset(
+        rewriter, op->getLoc(), memref, op.getIndices());
+    replaceOp<T>(op, rewriter, flatMemref, offset);
+    return success();
+  }
+};
+
+// One pattern per supported memory op kind; all share MemRefRewritePatternBase.
+struct FlattenMemrefLoad : public MemRefRewritePatternBase<memref::LoadOp> {
+  using MemRefRewritePatternBase<memref::LoadOp>::MemRefRewritePatternBase;
+};
+
+struct FlattenVectorLoad : public MemRefRewritePatternBase<vector::LoadOp> {
+  using MemRefRewritePatternBase<vector::LoadOp>::MemRefRewritePatternBase;
+};
+
+struct FlattenMemrefStore : public MemRefRewritePatternBase<memref::StoreOp> {
+  using MemRefRewritePatternBase<memref::StoreOp>::MemRefRewritePatternBase;
+};
+
+struct FlattenVectorStore : public MemRefRewritePatternBase<vector::StoreOp> {
+  using MemRefRewritePatternBase<vector::StoreOp>::MemRefRewritePatternBase;
+};
+
+struct FlattenVectorMaskedLoad
+    : public MemRefRewritePatternBase<vector::MaskedLoadOp> {
+  using MemRefRewritePatternBase<
+      vector::MaskedLoadOp>::MemRefRewritePatternBase;
+};
+
+struct FlattenVectorMaskedStore
+    : public MemRefRewritePatternBase<vector::MaskedStoreOp> {
+  using MemRefRewritePatternBase<
+      vector::MaskedStoreOp>::MemRefRewritePatternBase;
+};
+
+struct FlattenVectorTransferRead
+    : public MemRefRewritePatternBase<vector::TransferReadOp> {
+  using MemRefRewritePatternBase<
+      vector::TransferReadOp>::MemRefRewritePatternBase;
+};
+
+struct FlattenVectorTransferWrite
+    : public MemRefRewritePatternBase<vector::TransferWriteOp> {
+  using MemRefRewritePatternBase<
+      vector::TransferWriteOp>::MemRefRewritePatternBase;
+};
+
+/// Rewrites a memref.subview of a multi-dimensional memref with identity or
+/// strided layout into a memref.reinterpret_cast of the source's base
+/// buffer. The offset, sizes and strides of the subview are spelled out
+/// explicitly. Dimensions dropped by a rank-reducing subview are skipped.
+struct FlattenSubview : public OpRewritePattern<memref::SubViewOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::SubViewOp op,
+                                PatternRewriter &rewriter) const override {
+    Value memref = op.getSource();
+    if (!needFlattenning(memref))
+      return rewriter.notifyMatchFailure(op, "nothing to do");
+
+    if (!checkLayout(memref))
+      return rewriter.notifyMatchFailure(op, "unsupported layout");
+
+    Location loc = op.getLoc();
+    SmallVector<OpFoldResult> subOffsets = op.getMixedOffsets();
+    SmallVector<OpFoldResult> subSizes = op.getMixedSizes();
+    SmallVector<OpFoldResult> subStrides = op.getMixedStrides();
+    // `srcStrides` are the strides of the source memref. `linearizedOffset`
+    // only accounts for the subview offsets.
+    auto &&[base, linearizedOffset, srcStrides, srcOffset, collapsedSize] =
+        getFlatOffsetAndStrides(rewriter, loc, memref, subOffsets);
+
+    AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
+    AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
+
+    // The new view is built on the base buffer, which has no offset, so the
+    // source's own offset must be added to the subview offset.
+    OpFoldResult finalOffset = affine::makeComposedFoldedAffineApply(
+        rewriter, loc, s0 + s1, {srcOffset, linearizedOffset});
+
+    auto srcType = cast<MemRefType>(memref.getType());
+    auto resultType = cast<MemRefType>(op.getType());
+    unsigned subRank = static_cast<unsigned>(resultType.getRank());
+
+    llvm::SmallBitVector droppedDims = op.getDroppedDims();
+
+    SmallVector<OpFoldResult> finalSizes;
+    finalSizes.reserve(subRank);
+
+    SmallVector<OpFoldResult> finalStrides;
+    finalStrides.reserve(subRank);
+
+    for (auto i : llvm::seq(0u, static_cast<unsigned>(srcType.getRank()))) {
+      if (droppedDims.test(i))
+        continue;
+
+      finalSizes.push_back(subSizes[i]);
+      // A subview step along dimension i scales the source stride.
+      finalStrides.push_back(affine::makeComposedFoldedAffineApply(
+          rewriter, loc, s0 * s1, {subStrides[i], srcStrides[i]}));
+    }
+
+    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+        op, resultType, base, finalOffset, finalSizes, finalStrides);
+    return success();
+  }
+};
+
+/// Pass that greedily applies memref::populateFlattenMemrefsPatterns to the
+/// operation it runs on.
+struct FlattenMemrefsPass
+    : public mlir::memref::impl::FlattenMemrefsPassBase<FlattenMemrefsPass> {
+  using Base::Base;
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    // Dialects whose ops the patterns create (affine.apply, arith.constant,
+    // memref.*, vector.*).
+    registry.insert<affine::AffineDialect, arith::ArithDialect,
+                    memref::MemRefDialect, vector::VectorDialect>();
+  }
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+
+    memref::populateFlattenMemrefsPatterns(patterns);
+
+    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+
+} // namespace
+
+/// Registers every flattening pattern defined in this file on `patterns`.
+void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<FlattenMemrefLoad, FlattenMemrefStore, FlattenSubview,
+               FlattenVectorMaskedLoad, FlattenVectorMaskedStore,
+               FlattenVectorLoad, FlattenVectorStore, FlattenVectorTransferRead,
+               FlattenVectorTransferWrite>(ctx);
+}
+
+/// Creates a new instance of the flatten-memref pass.
+std::unique_ptr<Pass> mlir::memref::createFlattenMemrefsPass() {
+  return std::make_unique<FlattenMemrefsPass>();
+}
+

>From abe4379bbcd525eee99bfccedc6864ca1cab14a5 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Tue, 22 Apr 2025 21:00:29 -0400
Subject: [PATCH 04/10] Fix linking issue

---
 mlir/lib/Dialect/Affine/IR/AffineOps.cpp      | 48 ++++++++++---------
 .../MemRef/Transforms/FlattenMemRefs.cpp      |  2 +-
 2 files changed, 27 insertions(+), 23 deletions(-)

diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index eb23403a68813..e729e73f6ae0c 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -5083,6 +5083,31 @@ SmallVector<OpFoldResult> AffineLinearizeIndexOp::getPaddedBasis() {
   return ret;
 }
 
+namespace mlir {
+namespace affine {
+/// Returns the product of `terms` as an OpFoldResult.
+///
+/// Constant terms are multiplied into the affine expression directly; each
+/// non-constant term becomes a fresh symbol. If the resulting expression is a
+/// constant (including the empty product, 1) it is returned as an index
+/// attribute. Otherwise an affine.apply is created at the builder's
+/// insertion point. If any term is null, that null OpFoldResult is returned.
+/// NOTE(review): moved out of an anonymous namespace so other libraries can
+/// call it; confirm a declaration is exported in a header for those callers.
+OpFoldResult computeProduct(Location loc, OpBuilder &builder,
+                            ArrayRef<OpFoldResult> terms) {
+  int64_t nDynamic = 0;
+  SmallVector<Value> dynamicPart;
+  AffineExpr result = builder.getAffineConstantExpr(1);
+  for (OpFoldResult term : terms) {
+    if (!term)
+      return term;
+    std::optional<int64_t> maybeConst = getConstantIntValue(term);
+    if (maybeConst) {
+      result = result * builder.getAffineConstantExpr(*maybeConst);
+    } else {
+      dynamicPart.push_back(cast<Value>(term));
+      result = result * builder.getAffineSymbolExpr(nDynamic++);
+    }
+  }
+  if (auto constant = dyn_cast<AffineConstantExpr>(result))
+    return getAsIndexOpFoldResult(builder.getContext(), constant.getValue());
+  return builder.create<AffineApplyOp>(loc, result, dynamicPart).getResult();
+}
+} // namespace affine
+} // namespace mlir
+
 namespace {
 /// Rewrite `affine.linearize_index disjoint [%...a, %x, %...b] by (%...c, 1,
 /// %...d)` to `affine.linearize_index disjoint [%...a, %...b] by (%...c,
@@ -5142,27 +5167,6 @@ struct DropLinearizeUnitComponentsIfDisjointOrZero final
   }
 };
 
-OpFoldResult computeProduct(Location loc, OpBuilder &builder,
-                            ArrayRef<OpFoldResult> terms) {
-  int64_t nDynamic = 0;
-  SmallVector<Value> dynamicPart;
-  AffineExpr result = builder.getAffineConstantExpr(1);
-  for (OpFoldResult term : terms) {
-    if (!term)
-      return term;
-    std::optional<int64_t> maybeConst = getConstantIntValue(term);
-    if (maybeConst) {
-      result = result * builder.getAffineConstantExpr(*maybeConst);
-    } else {
-      dynamicPart.push_back(cast<Value>(term));
-      result = result * builder.getAffineSymbolExpr(nDynamic++);
-    }
-  }
-  if (auto constant = dyn_cast<AffineConstantExpr>(result))
-    return getAsIndexOpFoldResult(builder.getContext(), constant.getValue());
-  return builder.create<AffineApplyOp>(loc, result, dynamicPart).getResult();
-}
-
 /// If conseceutive outputs of a delinearize_index are linearized with the same
 /// bounds, canonicalize away the redundant arithmetic.
 ///
@@ -5309,7 +5313,7 @@ struct CancelLinearizeOfDelinearizePortion final
       // We use the slice from the linearize's basis above because of the
       // "bounds inferred from `disjoint`" case above.
       OpFoldResult newSize =
-          computeProduct(linearizeOp.getLoc(), rewriter, basisToMerge);
+          affine::computeProduct(linearizeOp.getLoc(), rewriter, basisToMerge);
 
       // Trivial case where we can just skip past the delinearize all together
       if (m.length == m.delinearize.getNumResults()) {
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
index 6685896624536..8299dc1716121 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
@@ -103,7 +103,7 @@ getFlatOffsetAndStrides(OpBuilder &rewriter, Location loc, Value source,
 
   // Compute collapsed size: (the outmost stride * outmost dimension).
   SmallVector<OpFoldResult> ops{origStrides.front(), outmostDim};
-  OpFoldResult collapsedSize = computeProduct(loc, rewriter, ops);
+  OpFoldResult collapsedSize = affine::computeProduct(loc, rewriter, ops);
 
   return {newExtractStridedMetadata.getBaseBuffer(), linearizedIndex,
           origStrides, origOffset, collapsedSize};

>From d4da14ad207ca57cf3d76b2d6e22c948d47d3845 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Tue, 22 Apr 2025 22:47:17 -0400
Subject: [PATCH 05/10] linting

---
 mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h | 1 -
 mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp    | 7 +++----
 2 files changed, 3 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index 0649bf9c099f9..c2b8cb05be922 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -144,7 +144,6 @@ FailureOr<memref::AllocOp> multiBuffer(memref::AllocOp allocOp,
 /// ```
 void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns);
 
-
 void populateFlattenMemrefsPatterns(RewritePatternSet &patterns);
 
 /// Build a new memref::AllocaOp whose dynamic sizes are independent of all
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
index 8299dc1716121..5ec524967444a 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
@@ -14,8 +14,8 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -28,7 +28,6 @@
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
-
 namespace mlir {
 namespace memref {
 #define GEN_PASS_DEF_FLATTENMEMREFSPASS
@@ -323,7 +322,8 @@ struct FlattenSubview : public OpRewritePattern<memref::SubViewOp> {
   }
 };
 
-struct FlattenMemrefsPass : public mlir::memref::impl::FlattenMemrefsPassBase<FlattenMemrefsPass> {
+struct FlattenMemrefsPass
+    : public mlir::memref::impl::FlattenMemrefsPassBase<FlattenMemrefsPass> {
   using Base::Base;
 
   void getDependentDialects(DialectRegistry &registry) const override {
@@ -354,4 +354,3 @@ void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {
 std::unique_ptr<Pass> mlir::memref::createFlattenMemrefsPass() {
   return std::make_unique<FlattenMemrefsPass>();
 }
-

>From 6aad0e7cfda39b291247053af0410a127d3cd001 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Wed, 23 Apr 2025 10:10:03 -0400
Subject: [PATCH 06/10] Fix misspelling

---
 mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
index 5ec524967444a..dda02be9a9c3a 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
@@ -136,7 +136,7 @@ static std::pair<Value, Value> getFlattenMemrefAndOffset(OpBuilder &rewriter,
       getValueFromOpFoldResult(rewriter, loc, index));
 }
 
-static bool needFlattenning(Value val) {
+static bool needFlattening(Value val) {
   auto type = cast<MemRefType>(val.getType());
   return type.getRank() > 1;
 }
@@ -227,7 +227,7 @@ struct MemRefRewritePatternBase : public OpRewritePattern<T> {
   LogicalResult matchAndRewrite(T op,
                                 PatternRewriter &rewriter) const override {
     Value memref = getTargetMemref<T>(op);
-    if (!needFlattenning(memref) || !checkLayout(memref))
+    if (!needFlattening(memref) || !checkLayout(memref))
       return rewriter.notifyMatchFailure(op,
                                          "nothing to do or unsupported layout");
     auto &&[flatMemref, offset] = getFlattenMemrefAndOffset(
@@ -283,7 +283,7 @@ struct FlattenSubview : public OpRewritePattern<memref::SubViewOp> {
   LogicalResult matchAndRewrite(memref::SubViewOp op,
                                 PatternRewriter &rewriter) const override {
     Value memref = op.getSource();
-    if (!needFlattenning(memref))
+    if (!needFlattening(memref))
       return rewriter.notifyMatchFailure(op, "nothing to do");
 
     if (!checkLayout(memref))

>From ce849953b7ae10c42bc1d7b02f4664eed87d4bef Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Fri, 25 Apr 2025 21:49:41 -0400
Subject: [PATCH 07/10] amend comments

---
 .../mlir/Dialect/MemRef/Transforms/Passes.td  |   2 -
 .../MemRef/Transforms/FlattenMemRefs.cpp      | 203 +++++++-----------
 2 files changed, 79 insertions(+), 126 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
index a2a9047bda808..a8d135caa74f0 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
@@ -250,8 +250,6 @@ def FlattenMemrefsPass : Pass<"flatten-memref"> {
   let description = [{
 
   }];
-
-  let constructor = "mlir::memref::createFlattenMemrefsPass()";
   let dependentDialects = [
       "affine::AffineDialect", "memref::MemRefDialect", "vector::VectorDialect"
   ];
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
index dda02be9a9c3a..8336d9b5715e6 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
@@ -27,6 +27,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/TypeSwitch.h"
 
 namespace mlir {
 namespace memref {
@@ -148,135 +149,90 @@ static bool checkLayout(Value val) {
 }
 
 namespace {
-template <typename T>
-static Value getTargetMemref(T op) {
-  if constexpr (std::is_same_v<T, memref::LoadOp>) {
-    return op.getMemref();
-  } else if constexpr (std::is_same_v<T, vector::LoadOp>) {
-    return op.getBase();
-  } else if constexpr (std::is_same_v<T, memref::StoreOp>) {
-    return op.getMemref();
-  } else if constexpr (std::is_same_v<T, vector::StoreOp>) {
-    return op.getBase();
-  } else if constexpr (std::is_same_v<T, vector::MaskedLoadOp>) {
-    return op.getBase();
-  } else if constexpr (std::is_same_v<T, vector::MaskedStoreOp>) {
-    return op.getBase();
-  } else if constexpr (std::is_same_v<T, vector::TransferReadOp>) {
-    return op.getSource();
-  } else if constexpr (std::is_same_v<T, vector::TransferWriteOp>) {
-    return op.getSource();
-  }
-  return {};
+static Value getTargetMemref(Operation *op) {
+  return llvm::TypeSwitch<Operation *, Value>(op)
+      .template Case<memref::LoadOp, memref::StoreOp>(
+          [](auto op) { return op.getMemref(); })
+      .template Case<vector::LoadOp, vector::StoreOp, vector::MaskedLoadOp,
+                     vector::MaskedStoreOp>(
+          [](auto op) { return op.getBase(); })
+      .template Case<vector::TransferReadOp, vector::TransferWriteOp>(
+          [](auto op) { return op.getSource(); })
+      .Default([](auto) { return Value{}; });
 }
 
-template <typename T>
-static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref,
-                      Value offset) {
-  if constexpr (std::is_same_v<T, memref::LoadOp>) {
-    auto newLoad = rewriter.create<memref::LoadOp>(
-        op->getLoc(), op->getResultTypes(), flatMemref, ValueRange{offset});
-    newLoad->setAttrs(op->getAttrs());
-    rewriter.replaceOp(op, newLoad.getResult());
-  } else if constexpr (std::is_same_v<T, vector::LoadOp>) {
-    auto newLoad = rewriter.create<vector::LoadOp>(
-        op->getLoc(), op->getResultTypes(), flatMemref, ValueRange{offset});
-    newLoad->setAttrs(op->getAttrs());
-    rewriter.replaceOp(op, newLoad.getResult());
-  } else if constexpr (std::is_same_v<T, memref::StoreOp>) {
-    auto newStore = rewriter.create<memref::StoreOp>(
-        op->getLoc(), op->getOperands().front(), flatMemref,
-        ValueRange{offset});
-    newStore->setAttrs(op->getAttrs());
-    rewriter.replaceOp(op, newStore);
-  } else if constexpr (std::is_same_v<T, vector::StoreOp>) {
-    auto newStore = rewriter.create<vector::StoreOp>(
-        op->getLoc(), op->getOperands().front(), flatMemref,
-        ValueRange{offset});
-    newStore->setAttrs(op->getAttrs());
-    rewriter.replaceOp(op, newStore);
-  } else if constexpr (std::is_same_v<T, vector::TransferReadOp>) {
-    auto newTransferRead = rewriter.create<vector::TransferReadOp>(
-        op->getLoc(), op.getType(), flatMemref, ValueRange{offset},
-        op.getPadding());
-    rewriter.replaceOp(op, newTransferRead.getResult());
-  } else if constexpr (std::is_same_v<T, vector::TransferWriteOp>) {
-    auto newTransferWrite = rewriter.create<vector::TransferWriteOp>(
-        op->getLoc(), op.getVector(), flatMemref, ValueRange{offset});
-    rewriter.replaceOp(op, newTransferWrite);
-  } else if constexpr (std::is_same_v<T, vector::MaskedLoadOp>) {
-    auto newMaskedLoad = rewriter.create<vector::MaskedLoadOp>(
-        op->getLoc(), op.getType(), flatMemref, ValueRange{offset},
-        op.getMask(), op.getPassThru());
-    newMaskedLoad->setAttrs(op->getAttrs());
-    rewriter.replaceOp(op, newMaskedLoad.getResult());
-  } else if constexpr (std::is_same_v<T, vector::MaskedStoreOp>) {
-    auto newMaskedStore = rewriter.create<vector::MaskedStoreOp>(
-        op->getLoc(), flatMemref, ValueRange{offset}, op.getMask(),
-        op.getValueToStore());
-    newMaskedStore->setAttrs(op->getAttrs());
-    rewriter.replaceOp(op, newMaskedStore);
-  } else {
-    op.emitOpError("unimplemented: do not know how to replace op.");
-  }
+static void replaceOp(Operation *op, PatternRewriter &rewriter,
+                      Value flatMemref, Value offset) {
+  auto loc = op->getLoc();
+  llvm::TypeSwitch<Operation *>(op)
+      .Case<memref::LoadOp>([&](auto op) {
+        auto newLoad = rewriter.create<memref::LoadOp>(
+            loc, op->getResultTypes(), flatMemref, ValueRange{offset});
+        newLoad->setAttrs(op->getAttrs());
+        rewriter.replaceOp(op, newLoad.getResult());
+      })
+      .Case<memref::StoreOp>([&](auto op) {
+        auto newStore = rewriter.create<memref::StoreOp>(
+            loc, op->getOperands().front(), flatMemref, ValueRange{offset});
+        newStore->setAttrs(op->getAttrs());
+        rewriter.replaceOp(op, newStore);
+      })
+      .Case<vector::LoadOp>([&](auto op) {
+        auto newLoad = rewriter.create<vector::LoadOp>(
+            loc, op->getResultTypes(), flatMemref, ValueRange{offset});
+        newLoad->setAttrs(op->getAttrs());
+        rewriter.replaceOp(op, newLoad.getResult());
+      })
+      .Case<vector::StoreOp>([&](auto op) {
+        auto newStore = rewriter.create<vector::StoreOp>(
+            loc, op->getOperands().front(), flatMemref, ValueRange{offset});
+        newStore->setAttrs(op->getAttrs());
+        rewriter.replaceOp(op, newStore);
+      })
+      .Case<vector::MaskedLoadOp>([&](auto op) {
+        auto newMaskedLoad = rewriter.create<vector::MaskedLoadOp>(
+            loc, op.getType(), flatMemref, ValueRange{offset}, op.getMask(),
+            op.getPassThru());
+        newMaskedLoad->setAttrs(op->getAttrs());
+        rewriter.replaceOp(op, newMaskedLoad.getResult());
+      })
+      .Case<vector::MaskedStoreOp>([&](auto op) {
+        auto newMaskedStore = rewriter.create<vector::MaskedStoreOp>(
+            loc, flatMemref, ValueRange{offset}, op.getMask(),
+            op.getValueToStore());
+        newMaskedStore->setAttrs(op->getAttrs());
+        rewriter.replaceOp(op, newMaskedStore);
+      })
+      .Case<vector::TransferReadOp>([&](auto op) {
+        auto newTransferRead = rewriter.create<vector::TransferReadOp>(
+            loc, op.getType(), flatMemref, ValueRange{offset}, op.getPadding());
+        rewriter.replaceOp(op, newTransferRead.getResult());
+      })
+      .Case<vector::TransferWriteOp>([&](auto op) {
+        auto newTransferWrite = rewriter.create<vector::TransferWriteOp>(
+            loc, op.getVector(), flatMemref, ValueRange{offset});
+        rewriter.replaceOp(op, newTransferWrite);
+      })
+      .Default([&](auto op) {
+        op->emitOpError("unimplemented: do not know how to replace op.");
+      });
 }
 
 template <typename T>
-struct MemRefRewritePatternBase : public OpRewritePattern<T> {
+struct MemRefRewritePattern : public OpRewritePattern<T> {
   using OpRewritePattern<T>::OpRewritePattern;
   LogicalResult matchAndRewrite(T op,
                                 PatternRewriter &rewriter) const override {
-    Value memref = getTargetMemref<T>(op);
+    Value memref = getTargetMemref(op);
     if (!needFlattening(memref) || !checkLayout(memref))
-      return rewriter.notifyMatchFailure(op,
-                                         "nothing to do or unsupported layout");
+      return failure();
     auto &&[flatMemref, offset] = getFlattenMemrefAndOffset(
         rewriter, op->getLoc(), memref, op.getIndices());
-    replaceOp<T>(op, rewriter, flatMemref, offset);
+    replaceOp(op, rewriter, flatMemref, offset);
     return success();
   }
 };
 
-struct FlattenMemrefLoad : public MemRefRewritePatternBase<memref::LoadOp> {
-  using MemRefRewritePatternBase<memref::LoadOp>::MemRefRewritePatternBase;
-};
-
-struct FlattenVectorLoad : public MemRefRewritePatternBase<vector::LoadOp> {
-  using MemRefRewritePatternBase<vector::LoadOp>::MemRefRewritePatternBase;
-};
-
-struct FlattenMemrefStore : public MemRefRewritePatternBase<memref::StoreOp> {
-  using MemRefRewritePatternBase<memref::StoreOp>::MemRefRewritePatternBase;
-};
-
-struct FlattenVectorStore : public MemRefRewritePatternBase<vector::StoreOp> {
-  using MemRefRewritePatternBase<vector::StoreOp>::MemRefRewritePatternBase;
-};
-
-struct FlattenVectorMaskedLoad
-    : public MemRefRewritePatternBase<vector::MaskedLoadOp> {
-  using MemRefRewritePatternBase<
-      vector::MaskedLoadOp>::MemRefRewritePatternBase;
-};
-
-struct FlattenVectorMaskedStore
-    : public MemRefRewritePatternBase<vector::MaskedStoreOp> {
-  using MemRefRewritePatternBase<
-      vector::MaskedStoreOp>::MemRefRewritePatternBase;
-};
-
-struct FlattenVectorTransferRead
-    : public MemRefRewritePatternBase<vector::TransferReadOp> {
-  using MemRefRewritePatternBase<
-      vector::TransferReadOp>::MemRefRewritePatternBase;
-};
-
-struct FlattenVectorTransferWrite
-    : public MemRefRewritePatternBase<vector::TransferWriteOp> {
-  using MemRefRewritePatternBase<
-      vector::TransferWriteOp>::MemRefRewritePatternBase;
-};
-
 struct FlattenSubview : public OpRewritePattern<memref::SubViewOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -284,7 +240,7 @@ struct FlattenSubview : public OpRewritePattern<memref::SubViewOp> {
                                 PatternRewriter &rewriter) const override {
     Value memref = op.getSource();
     if (!needFlattening(memref))
-      return rewriter.notifyMatchFailure(op, "nothing to do");
+      return rewriter.notifyMatchFailure(op, "already flattened");
 
     if (!checkLayout(memref))
       return rewriter.notifyMatchFailure(op, "unsupported layout");
@@ -344,13 +300,12 @@ struct FlattenMemrefsPass
 } // namespace
 
 void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {
-  patterns.insert<FlattenMemrefLoad, FlattenMemrefStore, FlattenSubview,
-                  FlattenVectorMaskedLoad, FlattenVectorMaskedStore,
-                  FlattenVectorLoad, FlattenVectorStore,
-                  FlattenVectorTransferRead, FlattenVectorTransferWrite>(
-      patterns.getContext());
-}
-
-std::unique_ptr<Pass> mlir::memref::createFlattenMemrefsPass() {
-  return std::make_unique<FlattenMemrefsPass>();
+  patterns
+      .insert<MemRefRewritePattern<memref::LoadOp>,
+              MemRefRewritePattern<memref::StoreOp>,
+              MemRefRewritePattern<vector::LoadOp>,
+              MemRefRewritePattern<vector::StoreOp>,
+              MemRefRewritePattern<vector::TransferReadOp>,
+              MemRefRewritePattern<vector::TransferWriteOp>, FlattenSubview>(
+          patterns.getContext());
 }

>From 7b097f47abcfaec5e12eba1096c0f3913aba509b Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Mon, 28 Apr 2025 09:31:23 -0400
Subject: [PATCH 08/10] Not working yet.

---
 .../MemRef/Transforms/FlattenMemRefs.cpp      | 68 ++++++++++++++++++-
 1 file changed, 66 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
index 8336d9b5715e6..65d92cc8e8c52 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
@@ -46,6 +46,69 @@ static void setInsertionPointToStart(OpBuilder &builder, Value val) {
   }
 }
 
+OpFoldResult computeMemRefSpan(Value memref, OpBuilder &builder) {
+  Location loc = memref.getLoc();
+  MemRefType type = cast<MemRefType>(memref.getType());
+  ArrayRef<int64_t> shape = type.getShape();
+  
+  // Check for empty memref
+  if (type.hasStaticShape() && 
+      llvm::any_of(shape, [](int64_t dim) { return dim == 0; })) {
+    return builder.getIndexAttr(0);
+  }
+  
+  // Get strides of the memref
+  SmallVector<int64_t, 4> strides;
+  int64_t offset;
+  if (failed(type.getStridesAndOffset(strides, offset))) {
+    // Cannot extract strides, return a dynamic value
+    return Value();
+  }
+  
+  // Static case: compute at compile time if possible
+  if (type.hasStaticShape()) {
+    int64_t span = 0;
+    for (unsigned i = 0; i < type.getRank(); ++i) {
+      span += (shape[i] - 1) * strides[i];
+    }
+    return builder.getIndexAttr(span);
+  }
+  
+  // Dynamic case: emit IR to compute at runtime
+  Value result = builder.create<arith::ConstantIndexOp>(loc, 0);
+  
+  for (unsigned i = 0; i < type.getRank(); ++i) {
+    // Get dimension size
+    Value dimSize;
+    if (shape[i] == ShapedType::kDynamic) {
+      dimSize = builder.create<memref::DimOp>(loc, memref, i);
+    } else {
+      dimSize = builder.create<arith::ConstantIndexOp>(loc, shape[i]);
+    }
+    
+    // Compute (dim - 1)
+    Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
+    Value dimMinusOne = builder.create<arith::SubIOp>(loc, dimSize, one);
+    
+    // Get stride
+    Value stride;
+    if (strides[i] == ShapedType::kDynamicStrideOrOffset) {
+      // For dynamic strides, need to extract from memref descriptor
+      // This would require runtime support, possibly using extractStride
+      // As a placeholder, return a dynamic value
+      return Value();
+    } else {
+      stride = builder.create<arith::ConstantIndexOp>(loc, strides[i]);
+    }
+    
+    // Add (dim - 1) * stride to result
+    Value term = builder.create<arith::MulIOp>(loc, dimMinusOne, stride);
+    result = builder.create<arith::AddIOp>(loc, result, term);
+  }
+  
+  return result;
+}
+
 static std::tuple<Value, OpFoldResult, SmallVector<OpFoldResult>, OpFoldResult,
                   OpFoldResult>
 getFlatOffsetAndStrides(OpBuilder &rewriter, Location loc, Value source,
@@ -102,8 +165,9 @@ getFlatOffsetAndStrides(OpBuilder &rewriter, Location loc, Value source,
       affine::makeComposedFoldedAffineApply(rewriter, loc, expr, values);
 
   // Compute collapsed size: (the outmost stride * outmost dimension).
-  SmallVector<OpFoldResult> ops{origStrides.front(), outmostDim};
-  OpFoldResult collapsedSize = affine::computeProduct(loc, rewriter, ops);
+  //SmallVector<OpFoldResult> ops{origStrides.front(), outmostDim};
+  //OpFoldResult collapsedSize = affine::computeProduct(loc, rewriter, ops);
+  OpFoldResult collapsedSize = computeMemRefSpan(source, rewriter);
 
   return {newExtractStridedMetadata.getBaseBuffer(), linearizedIndex,
           origStrides, origOffset, collapsedSize};

>From c9db9bffd282bb0dbc4313eb065e15e92d10f24b Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Thu, 1 May 2025 12:33:32 -0400
Subject: [PATCH 09/10] Some updates

---
 .../MemRef/Transforms/FlattenMemRefs.cpp      | 207 +++++-------------
 mlir/test/Dialect/MemRef/flatten_memref.mlir  |  26 ++-
 2 files changed, 69 insertions(+), 164 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
index 65d92cc8e8c52..ba4a00f9c0ed6 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -38,141 +39,6 @@ namespace memref {
 
 using namespace mlir;
 
-static void setInsertionPointToStart(OpBuilder &builder, Value val) {
-  if (auto *parentOp = val.getDefiningOp()) {
-    builder.setInsertionPointAfter(parentOp);
-  } else {
-    builder.setInsertionPointToStart(val.getParentBlock());
-  }
-}
-
-OpFoldResult computeMemRefSpan(Value memref, OpBuilder &builder) {
-  Location loc = memref.getLoc();
-  MemRefType type = cast<MemRefType>(memref.getType());
-  ArrayRef<int64_t> shape = type.getShape();
-  
-  // Check for empty memref
-  if (type.hasStaticShape() && 
-      llvm::any_of(shape, [](int64_t dim) { return dim == 0; })) {
-    return builder.getIndexAttr(0);
-  }
-  
-  // Get strides of the memref
-  SmallVector<int64_t, 4> strides;
-  int64_t offset;
-  if (failed(type.getStridesAndOffset(strides, offset))) {
-    // Cannot extract strides, return a dynamic value
-    return Value();
-  }
-  
-  // Static case: compute at compile time if possible
-  if (type.hasStaticShape()) {
-    int64_t span = 0;
-    for (unsigned i = 0; i < type.getRank(); ++i) {
-      span += (shape[i] - 1) * strides[i];
-    }
-    return builder.getIndexAttr(span);
-  }
-  
-  // Dynamic case: emit IR to compute at runtime
-  Value result = builder.create<arith::ConstantIndexOp>(loc, 0);
-  
-  for (unsigned i = 0; i < type.getRank(); ++i) {
-    // Get dimension size
-    Value dimSize;
-    if (shape[i] == ShapedType::kDynamic) {
-      dimSize = builder.create<memref::DimOp>(loc, memref, i);
-    } else {
-      dimSize = builder.create<arith::ConstantIndexOp>(loc, shape[i]);
-    }
-    
-    // Compute (dim - 1)
-    Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
-    Value dimMinusOne = builder.create<arith::SubIOp>(loc, dimSize, one);
-    
-    // Get stride
-    Value stride;
-    if (strides[i] == ShapedType::kDynamicStrideOrOffset) {
-      // For dynamic strides, need to extract from memref descriptor
-      // This would require runtime support, possibly using extractStride
-      // As a placeholder, return a dynamic value
-      return Value();
-    } else {
-      stride = builder.create<arith::ConstantIndexOp>(loc, strides[i]);
-    }
-    
-    // Add (dim - 1) * stride to result
-    Value term = builder.create<arith::MulIOp>(loc, dimMinusOne, stride);
-    result = builder.create<arith::AddIOp>(loc, result, term);
-  }
-  
-  return result;
-}
-
-static std::tuple<Value, OpFoldResult, SmallVector<OpFoldResult>, OpFoldResult,
-                  OpFoldResult>
-getFlatOffsetAndStrides(OpBuilder &rewriter, Location loc, Value source,
-                        ArrayRef<OpFoldResult> subOffsets,
-                        ArrayRef<OpFoldResult> subStrides = std::nullopt) {
-  auto sourceType = cast<MemRefType>(source.getType());
-  auto sourceRank = static_cast<unsigned>(sourceType.getRank());
-
-  memref::ExtractStridedMetadataOp newExtractStridedMetadata;
-  {
-    OpBuilder::InsertionGuard g(rewriter);
-    setInsertionPointToStart(rewriter, source);
-    newExtractStridedMetadata =
-        rewriter.create<memref::ExtractStridedMetadataOp>(loc, source);
-  }
-
-  auto &&[sourceStrides, sourceOffset] = sourceType.getStridesAndOffset();
-
-  auto getDim = [&](int64_t dim, Value dimVal) -> OpFoldResult {
-    return ShapedType::isDynamic(dim) ? getAsOpFoldResult(dimVal)
-                                      : rewriter.getIndexAttr(dim);
-  };
-
-  OpFoldResult origOffset =
-      getDim(sourceOffset, newExtractStridedMetadata.getOffset());
-  ValueRange sourceStridesVals = newExtractStridedMetadata.getStrides();
-  OpFoldResult outmostDim =
-      getDim(sourceType.getShape().front(),
-             newExtractStridedMetadata.getSizes().front());
-
-  SmallVector<OpFoldResult> origStrides;
-  origStrides.reserve(sourceRank);
-
-  SmallVector<OpFoldResult> strides;
-  strides.reserve(sourceRank);
-
-  AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
-  AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
-  for (auto i : llvm::seq(0u, sourceRank)) {
-    OpFoldResult origStride = getDim(sourceStrides[i], sourceStridesVals[i]);
-
-    if (!subStrides.empty()) {
-      strides.push_back(affine::makeComposedFoldedAffineApply(
-          rewriter, loc, s0 * s1, {subStrides[i], origStride}));
-    }
-
-    origStrides.emplace_back(origStride);
-  }
-
-  // Compute linearized index:
-  auto &&[expr, values] =
-      computeLinearIndex(rewriter.getIndexAttr(0), origStrides, subOffsets);
-  OpFoldResult linearizedIndex =
-      affine::makeComposedFoldedAffineApply(rewriter, loc, expr, values);
-
-  // Compute collapsed size: (the outmost stride * outmost dimension).
-  //SmallVector<OpFoldResult> ops{origStrides.front(), outmostDim};
-  //OpFoldResult collapsedSize = affine::computeProduct(loc, rewriter, ops);
-  OpFoldResult collapsedSize = computeMemRefSpan(source, rewriter);
-
-  return {newExtractStridedMetadata.getBaseBuffer(), linearizedIndex,
-          origStrides, origOffset, collapsedSize};
-}
-
 static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc,
                                       OpFoldResult in) {
   if (Attribute offsetAttr = dyn_cast<Attribute>(in)) {
@@ -188,17 +54,36 @@ static std::pair<Value, Value> getFlattenMemrefAndOffset(OpBuilder &rewriter,
                                                          Location loc,
                                                          Value source,
                                                          ValueRange indices) {
-  auto &&[base, index, strides, offset, collapsedShape] =
-      getFlatOffsetAndStrides(rewriter, loc, source,
-                              getAsOpFoldResult(indices));
+  int64_t sourceOffset;
+  SmallVector<int64_t, 4> sourceStrides;
+  auto sourceType = cast<MemRefType>(source.getType());
+  if (failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset))) {
+    assert(false);
+  }
+
+  memref::ExtractStridedMetadataOp stridedMetadata =
+      rewriter.create<memref::ExtractStridedMetadataOp>(loc, source);
+
+  auto typeBit = sourceType.getElementType().getIntOrFloatBitWidth();
+  OpFoldResult linearizedIndices;
+  memref::LinearizedMemRefInfo linearizedInfo;
+  std::tie(linearizedInfo, linearizedIndices) =
+      memref::getLinearizedMemRefOffsetAndSize(
+          rewriter, loc, typeBit, typeBit,
+          stridedMetadata.getConstifiedMixedOffset(),
+          stridedMetadata.getConstifiedMixedSizes(),
+          stridedMetadata.getConstifiedMixedStrides(),
+          getAsOpFoldResult(indices));
 
   return std::make_pair(
       rewriter.create<memref::ReinterpretCastOp>(
           loc, source,
-          /* offset = */ offset,
-          /* shapes = */ ArrayRef<OpFoldResult>{collapsedShape},
-          /* strides = */ ArrayRef<OpFoldResult>{strides.back()}),
-      getValueFromOpFoldResult(rewriter, loc, index));
+          /* offset = */ linearizedInfo.linearizedOffset,
+          /* shapes = */ ArrayRef<OpFoldResult>{linearizedInfo.linearizedSize},
+          /* strides = */
+          ArrayRef<OpFoldResult>{
+              stridedMetadata.getConstifiedMixedStrides().back()}),
+      getValueFromOpFoldResult(rewriter, loc, linearizedIndices));
 }
 
 static bool needFlattening(Value val) {
@@ -313,8 +198,23 @@ struct FlattenSubview : public OpRewritePattern<memref::SubViewOp> {
     SmallVector<OpFoldResult> subOffsets = op.getMixedOffsets();
     SmallVector<OpFoldResult> subSizes = op.getMixedSizes();
     SmallVector<OpFoldResult> subStrides = op.getMixedStrides();
-    auto &&[base, finalOffset, strides, _, __] =
-        getFlatOffsetAndStrides(rewriter, loc, memref, subOffsets, subStrides);
+
+    // base, finalOffset, strides
+    memref::ExtractStridedMetadataOp stridedMetadata =
+        rewriter.create<memref::ExtractStridedMetadataOp>(loc, memref);
+
+    auto sourceType = cast<MemRefType>(memref.getType());
+    auto typeBit = sourceType.getElementType().getIntOrFloatBitWidth();
+    OpFoldResult linearizedIndices;
+    memref::LinearizedMemRefInfo linearizedInfo;
+    std::tie(linearizedInfo, linearizedIndices) =
+        memref::getLinearizedMemRefOffsetAndSize(
+            rewriter, loc, typeBit, typeBit,
+            stridedMetadata.getConstifiedMixedOffset(),
+            stridedMetadata.getConstifiedMixedSizes(),
+            stridedMetadata.getConstifiedMixedStrides(), op.getMixedOffsets());
+    auto finalOffset = linearizedInfo.linearizedOffset;
+    auto strides = stridedMetadata.getConstifiedMixedStrides();
 
     auto srcType = cast<MemRefType>(memref.getType());
     auto resultType = cast<MemRefType>(op.getType());
@@ -337,7 +237,7 @@ struct FlattenSubview : public OpRewritePattern<memref::SubViewOp> {
     }
 
     rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
-        op, resultType, base, finalOffset, finalSizes, finalStrides);
+        op, resultType, memref, finalOffset, finalSizes, finalStrides);
     return success();
   }
 };
@@ -364,12 +264,13 @@ struct FlattenMemrefsPass
 } // namespace
 
 void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {
-  patterns
-      .insert<MemRefRewritePattern<memref::LoadOp>,
-              MemRefRewritePattern<memref::StoreOp>,
-              MemRefRewritePattern<vector::LoadOp>,
-              MemRefRewritePattern<vector::StoreOp>,
-              MemRefRewritePattern<vector::TransferReadOp>,
-              MemRefRewritePattern<vector::TransferWriteOp>, FlattenSubview>(
-          patterns.getContext());
+  patterns.insert<MemRefRewritePattern<memref::LoadOp>,
+                  MemRefRewritePattern<memref::StoreOp>,
+                  MemRefRewritePattern<vector::LoadOp>,
+                  MemRefRewritePattern<vector::StoreOp>,
+                  MemRefRewritePattern<vector::TransferReadOp>,
+                  MemRefRewritePattern<vector::TransferWriteOp>,
+                  MemRefRewritePattern<vector::MaskedLoadOp>,
+                  MemRefRewritePattern<vector::MaskedStoreOp>, FlattenSubview>(
+      patterns.getContext());
 }
diff --git a/mlir/test/Dialect/MemRef/flatten_memref.mlir b/mlir/test/Dialect/MemRef/flatten_memref.mlir
index 6c9b09985acf7..f65e12ad6916d 100644
--- a/mlir/test/Dialect/MemRef/flatten_memref.mlir
+++ b/mlir/test/Dialect/MemRef/flatten_memref.mlir
@@ -6,7 +6,7 @@ func.func @load_scalar_from_memref(%input: memref<4x8xf32, strided<[8, 1], offse
   %value = memref.load %input[%c1, %c2] : memref<4x8xf32, strided<[8, 1], offset: 100>>
   return %value : f32
 }
-// CHECK: func @load_scalar_from_memref
+// CHECK-LABEL: func @load_scalar_from_memref
 // CHECK: %[[C10:.*]] = arith.constant 10 : index
 // CHECK: %[[REINT:.*]] = memref.reinterpret_cast %arg0 to offset: [100], sizes: [32], strides: [1]
 // CHECK-SAME: memref<4x8xf32, strided<[8, 1], offset: 100>> to memref<32xf32, strided<[1], offset: 100>>
@@ -18,6 +18,7 @@ func.func @load_scalar_from_memref_static_dim_2(%input: memref<4x8xf32, strided<
   %value = memref.load %input[%col, %row] : memref<4x8xf32, strided<[8, 12], offset: 100>>
   return %value : f32
 }
+
 // CHECK: [[MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 8 + s1 * 12)>
 // CHECK: func @load_scalar_from_memref_static_dim_2
 // CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xf32, strided<[8, 12], offset: 100>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
@@ -39,7 +40,7 @@ func.func @load_scalar_from_memref_dynamic_dim(%input: memref<?x?xf32, strided<[
 // CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32, strided<[?, ?], offset: ?>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
 // CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG0]]
 // CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[STRIDES]]#0, %[[ARG1]], %[[STRIDES]]#1]
-// CHECK: %[[SIZE:.*]] = affine.apply #[[MAP1]]()[%[[STRIDES]]#0, %[[SIZES]]#0]
+// CHECK: %[[SIZE:.*]] = affine.apply #[[MAP1]]()[%[[SIZES]]#0, %[[SIZES]]#1]
 // CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [%[[OFFSET]]], sizes: [%[[SIZE]]], strides: [%[[STRIDES]]#1]
 // CHECK: memref.load %[[REINT]][%[[IDX]]]
 
@@ -49,7 +50,9 @@ func.func @load_scalar_from_memref_subview(%input: memref<4x8xf32>, %row: index,
   %subview = memref.subview %input[%col, %row] [1, 1] [1, 1] : memref<4x8xf32> to memref<1x1xf32, strided<[8, 1], offset: ?>>
   return %subview : memref<1x1xf32, strided<[8, 1], offset: ?>>
 }
-// CHECK: func @load_scalar_from_memref_subview
+// CHECK-LABEL: func @load_scalar_from_memref_subview
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+// CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [1, 1], strides: [8, 1]
 
 // -----
 
@@ -76,7 +79,7 @@ func.func @store_scalar_from_memref_dynamic_dim(%input: memref<?x?xf32, strided<
 // CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32, strided<[?, ?], offset: ?>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: f32)
 // CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG0]]
 // CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[STRIDES]]#0, %[[ARG1]], %[[STRIDES]]#1]
-// CHECK: %[[SIZE:.*]] = affine.apply #[[MAP1]]()[%[[STRIDES]]#0, %[[SIZES]]#0]
+// CHECK: %[[SIZE:.*]] = affine.apply #[[MAP1]]()[%[[SIZES]]#0, %[[SIZES]]#1]
 // CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [%[[OFFSET]]], sizes: [%[[SIZE]]], strides: [%[[STRIDES]]#1]
 // CHECK: memref.store %[[ARG3]], %[[REINT]][%[[IDX]]]
 
@@ -88,7 +91,7 @@ func.func @load_vector_from_memref(%input: memref<4x8xf32>) -> vector<8xf32> {
   %value = vector.load %input[%c3, %c6] : memref<4x8xf32>, vector<8xf32>
   return %value : vector<8xf32>
 }
-// CHECK: func @load_vector_from_memref
+// CHECK-LABEL: func @load_vector_from_memref
 // CHECK: %[[C30:.*]] = arith.constant 30
 // CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast %arg0 to offset: [0], sizes: [32], strides: [1]
 // CHECK-NEXT: vector.load %[[REINT]][%[[C30]]]
@@ -101,7 +104,7 @@ func.func @load_vector_from_memref_odd(%input: memref<3x7xi2>) -> vector<3xi2> {
   %value = vector.load %input[%c1, %c3] : memref<3x7xi2>, vector<3xi2>
   return %value : vector<3xi2>
 }
-// CHECK: func @load_vector_from_memref_odd
+// CHECK-LABEL: func @load_vector_from_memref_odd
 // CHECK: %[[C10:.*]] = arith.constant 10 : index
 // CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast
 // CHECK-NEXT: vector.load %[[REINT]][%[[C10]]]
@@ -126,10 +129,11 @@ func.func @store_vector_to_memref_odd(%input: memref<3x7xi2>, %value: vector<3xi
   vector.store %value, %input[%c1, %c3] : memref<3x7xi2>, vector<3xi2>
   return
 }
-// CHECK: func @store_vector_to_memref_odd
+// CHECK-LABEL: func @store_vector_to_memref_odd
+// CHECK-SAME: (%[[ARG0:.*]]: memref<3x7xi2>, %[[ARG1:.*]]: vector<3xi2>)
 // CHECK: %[[C10:.*]] = arith.constant 10 : index
 // CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast
-// CHECK-NEXT: vector.store %arg1, %[[REINT]][%[[C10]]] : memref<21xi2, strided<[1]>
+// CHECK-NEXT: vector.store %[[ARG1]], %[[REINT]][%[[C10]]] : memref<21xi2, strided<[1]>
 
 // -----
 
@@ -152,7 +156,7 @@ func.func @mask_store_vector_to_memref_odd(%input: memref<3x7xi2>, %value: vecto
   vector.maskedstore %input[%c1, %c3], %mask, %value  : memref<3x7xi2>, vector<3xi1>, vector<3xi2>
   return
 }
-// CHECK: func @mask_store_vector_to_memref_odd
+// CHECK-LABEL: func @mask_store_vector_to_memref_odd
 // CHECK-SAME: (%[[ARG0:.*]]: memref<3x7xi2>, %[[ARG1:.*]]: vector<3xi2>, %[[ARG2:.*]]: vector<3xi1>)
 // CHECK: %[[C10:.*]] = arith.constant 10 : index
 // CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast
@@ -178,7 +182,7 @@ func.func @mask_load_vector_from_memref_odd(%input: memref<3x7xi2>, %mask: vecto
   %result = vector.maskedload %input[%c1, %c3], %mask, %passthru : memref<3x7xi2>, vector<3xi1>, vector<3xi2> into vector<3xi2>
   return %result : vector<3xi2>
 }
-// CHECK: func @mask_load_vector_from_memref_odd
+// CHECK-LABEL: func @mask_load_vector_from_memref_odd
 // CHECK-SAME: (%[[ARG0:.*]]: memref<3x7xi2>, %[[MASK:.*]]: vector<3xi1>, %[[PASSTHRU:.*]]: vector<3xi2>)
 // CHECK: %[[C10:.*]] = arith.constant 10 : index
 // CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [21], strides: [1]
@@ -204,7 +208,7 @@ func.func @transfer_read_memref(%input: memref<4x8xi2>, %value: vector<8xi2>, %r
    %0 = vector.transfer_read %input[%col, %row], %c0 : memref<4x8xi2>, vector<8xi2>
    return %0 : vector<8xi2>
 }
-// CHECK: func @transfer_read_memref
+// CHECK-LABEL: func @transfer_read_memref
 // CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xi2>, %[[ARG1:.*]]: vector<8xi2>, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
 // CHECK: %[[C0:.*]] = arith.constant 0 : i2
 // CHECK: %[[IDX:.*]] = affine.apply #map()[%[[ARG3]], %[[ARG2]]]

>From 01d131e936f03844a3f919a1a70cde916273be36 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Mon, 5 May 2025 23:38:08 -0400
Subject: [PATCH 10/10] Remove subview

---
 .../MemRef/Transforms/FlattenMemRefs.cpp      | 62 +------------------
 mlir/test/Dialect/MemRef/flatten_memref.mlir  | 10 ---
 2 files changed, 1 insertion(+), 71 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
index ba4a00f9c0ed6..32fe64bb616bc 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
@@ -182,66 +182,6 @@ struct MemRefRewritePattern : public OpRewritePattern<T> {
   }
 };
 
-struct FlattenSubview : public OpRewritePattern<memref::SubViewOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(memref::SubViewOp op,
-                                PatternRewriter &rewriter) const override {
-    Value memref = op.getSource();
-    if (!needFlattening(memref))
-      return rewriter.notifyMatchFailure(op, "already flattened");
-
-    if (!checkLayout(memref))
-      return rewriter.notifyMatchFailure(op, "unsupported layout");
-
-    Location loc = op.getLoc();
-    SmallVector<OpFoldResult> subOffsets = op.getMixedOffsets();
-    SmallVector<OpFoldResult> subSizes = op.getMixedSizes();
-    SmallVector<OpFoldResult> subStrides = op.getMixedStrides();
-
-    // base, finalOffset, strides
-    memref::ExtractStridedMetadataOp stridedMetadata =
-        rewriter.create<memref::ExtractStridedMetadataOp>(loc, memref);
-
-    auto sourceType = cast<MemRefType>(memref.getType());
-    auto typeBit = sourceType.getElementType().getIntOrFloatBitWidth();
-    OpFoldResult linearizedIndices;
-    memref::LinearizedMemRefInfo linearizedInfo;
-    std::tie(linearizedInfo, linearizedIndices) =
-        memref::getLinearizedMemRefOffsetAndSize(
-            rewriter, loc, typeBit, typeBit,
-            stridedMetadata.getConstifiedMixedOffset(),
-            stridedMetadata.getConstifiedMixedSizes(),
-            stridedMetadata.getConstifiedMixedStrides(), op.getMixedOffsets());
-    auto finalOffset = linearizedInfo.linearizedOffset;
-    auto strides = stridedMetadata.getConstifiedMixedStrides();
-
-    auto srcType = cast<MemRefType>(memref.getType());
-    auto resultType = cast<MemRefType>(op.getType());
-    unsigned subRank = static_cast<unsigned>(resultType.getRank());
-
-    llvm::SmallBitVector droppedDims = op.getDroppedDims();
-
-    SmallVector<OpFoldResult> finalSizes;
-    finalSizes.reserve(subRank);
-
-    SmallVector<OpFoldResult> finalStrides;
-    finalStrides.reserve(subRank);
-
-    for (auto i : llvm::seq(0u, static_cast<unsigned>(srcType.getRank()))) {
-      if (droppedDims.test(i))
-        continue;
-
-      finalSizes.push_back(subSizes[i]);
-      finalStrides.push_back(strides[i]);
-    }
-
-    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
-        op, resultType, memref, finalOffset, finalSizes, finalStrides);
-    return success();
-  }
-};
-
 struct FlattenMemrefsPass
     : public mlir::memref::impl::FlattenMemrefsPassBase<FlattenMemrefsPass> {
   using Base::Base;
@@ -271,6 +211,6 @@ void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {
                   MemRefRewritePattern<vector::TransferReadOp>,
                   MemRefRewritePattern<vector::TransferWriteOp>,
                   MemRefRewritePattern<vector::MaskedLoadOp>,
-                  MemRefRewritePattern<vector::MaskedStoreOp>, FlattenSubview>(
+                  MemRefRewritePattern<vector::MaskedStoreOp>>(
       patterns.getContext());
 }
diff --git a/mlir/test/Dialect/MemRef/flatten_memref.mlir b/mlir/test/Dialect/MemRef/flatten_memref.mlir
index f65e12ad6916d..a182ae58683dd 100644
--- a/mlir/test/Dialect/MemRef/flatten_memref.mlir
+++ b/mlir/test/Dialect/MemRef/flatten_memref.mlir
@@ -46,16 +46,6 @@ func.func @load_scalar_from_memref_dynamic_dim(%input: memref<?x?xf32, strided<[
 
 // -----
 
-func.func @load_scalar_from_memref_subview(%input: memref<4x8xf32>, %row: index, %col: index) -> memref<1x1xf32, strided<[8, 1], offset: ?>> {
-  %subview = memref.subview %input[%col, %row] [1, 1] [1, 1] : memref<4x8xf32> to memref<1x1xf32, strided<[8, 1], offset: ?>>
-  return %subview : memref<1x1xf32, strided<[8, 1], offset: ?>>
-}
-// CHECK-LABEL: func @load_scalar_from_memref_subview
-// CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
-// CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [1, 1], strides: [8, 1]
-
-// -----
-
 func.func @store_scalar_from_memref_static_dim(%input: memref<4x8xf32, strided<[8, 12], offset: 100>>, %row: index, %col: index, %value: f32) {
   memref.store %value, %input[%col, %row] : memref<4x8xf32, strided<[8, 12], offset: 100>>
   return



More information about the Mlir-commits mailing list