[Mlir-commits] [mlir] [MLIR] Add a utility pass to linearize `memref` (PR #136797)
Alan Li
llvmlistbot at llvm.org
Tue Apr 22 18:45:08 PDT 2025
https://github.com/lialan created https://github.com/llvm/llvm-project/pull/136797
This PR adds a memref linearizer that simplifies memory access patterns. It is based on the GPU/DecomposeMemRefs pass, with the following changes:
* Support vector dialect ops.
* Instead of decomposing memrefs into rank-0 memrefs, flatten higher-ranked memrefs to rank 1.
Notes:
* After linearization, a memref's offset is preserved, so `memref<4x8xf32, strided<[8, 1], offset: 100>>` becomes `memref<32xf32, strided<[1], offset: 100>>`.
* It also works with dynamic shapes, strides, and offsets (see the test cases for details).
* The cast memref is flattened to a 1-D shape whose size is computed as `outermostStride * outermostDimSize` (see the before/after sketch below).
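For illustration, here is a minimal before/after sketch of the rewrite, distilled from the first test case in this PR (SSA names are illustrative; the pass folds the linearized index when the indices are static):

```mlir
// Before: a 2-D strided load.
func.func @load_scalar(%arg0: memref<4x8xf32, strided<[8, 1], offset: 100>>) -> f32 {
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %v = memref.load %arg0[%c1, %c2] : memref<4x8xf32, strided<[8, 1], offset: 100>>
  return %v : f32
}

// After --flatten-memref: the memref is reinterpreted as a rank-1 view whose size is
// outermostStride * outermostDimSize = 8 * 4 = 32 (the offset is kept), and the access
// index is linearized to 1 * 8 + 2 * 1 = 10.
func.func @load_scalar(%arg0: memref<4x8xf32, strided<[8, 1], offset: 100>>) -> f32 {
  %c10 = arith.constant 10 : index
  %flat = memref.reinterpret_cast %arg0 to offset: [100], sizes: [32], strides: [1]
      : memref<4x8xf32, strided<[8, 1], offset: 100>> to memref<32xf32, strided<[1], offset: 100>>
  %v = memref.load %flat[%c10] : memref<32xf32, strided<[1], offset: 100>>
  return %v : f32
}
```

For dynamic shapes, the same patterns emit `memref.extract_strided_metadata` plus `affine.apply` to compute the linearized index and the collapsed size, as shown in the dynamic test cases.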
From fc1275e21926006d8fc373e5fbbde9709aaabbda Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Tue, 22 Apr 2025 14:42:38 -0400
Subject: [PATCH 1/4] First commit.
---
mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h | 4 ++++
.../include/mlir/Dialect/MemRef/Transforms/Passes.td | 12 ++++++++++++
.../mlir/Dialect/MemRef/Transforms/Transforms.h | 3 +++
mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt | 2 ++
4 files changed, 21 insertions(+)
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
index d7050156862df..7580985754843 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
@@ -77,6 +77,10 @@ std::unique_ptr<Pass> createExpandStridedMetadataPass();
/// components.
std::unique_ptr<Pass> createExpandReallocPass(bool emitDeallocs = true);
+/// Creates an operation pass to flatten multi-dimensional memrefs into
+/// 1-D memrefs.
+std::unique_ptr<Pass> createFlattenMemrefsPass();
+
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
index 651ee05ae1f3c..c87472851fd78 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
@@ -253,5 +253,17 @@ def ExpandRealloc : Pass<"expand-realloc"> {
];
}
+def FlattenMemrefsPass : Pass<"flatten-memref"> {
+  let summary = "Flatten multi-dimensional memrefs to 1-dimensional memrefs";
+ let description = [{
+    Flattens multi-dimensional memref accesses (and the supported vector dialect operations on memrefs) into accesses on rank-1 memrefs by linearizing the access indices.
+ }];
+
+ let constructor = "mlir::memref::createFlattenMemrefsPass()";
+ let dependentDialects = [
+ "affine::AffineDialect", "memref::MemRefDialect", "vector::VectorDialect"
+ ];
+}
+
#endif // MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index 62a2297c80e78..0649bf9c099f9 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -144,6 +144,9 @@ FailureOr<memref::AllocOp> multiBuffer(memref::AllocOp allocOp,
/// ```
void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns);
+
+void populateFlattenMemrefsPatterns(RewritePatternSet &patterns);
+
/// Build a new memref::AllocaOp whose dynamic sizes are independent of all
/// given independencies. If the op is already independent of all
/// independencies, the same AllocaOp result is returned.
diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
index ecab97bc2b8e7..48e8bccd369fa 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
EmulateWideInt.cpp
EmulateNarrowType.cpp
ExtractAddressComputations.cpp
+ FlattenMemRefs.cpp
FoldMemRefAliasOps.cpp
IndependenceTransforms.cpp
MultiBuffer.cpp
@@ -23,6 +24,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
LINK_LIBS PUBLIC
MLIRAffineTransforms
+ MLIRAffineDialect
MLIRAffineUtils
MLIRArithDialect
MLIRArithTransforms
From 5187c2037eee1b14984ef3a6ba0df1201cf61d86 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Tue, 22 Apr 2025 14:45:51 -0400
Subject: [PATCH 2/4] Adding test cases
---
mlir/test/Dialect/MemRef/flatten_memref.mlir | 225 +++++++++++++++++++
1 file changed, 225 insertions(+)
create mode 100644 mlir/test/Dialect/MemRef/flatten_memref.mlir
diff --git a/mlir/test/Dialect/MemRef/flatten_memref.mlir b/mlir/test/Dialect/MemRef/flatten_memref.mlir
new file mode 100644
index 0000000000000..6c9b09985acf7
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/flatten_memref.mlir
@@ -0,0 +1,225 @@
+// RUN: mlir-opt --flatten-memref %s --split-input-file --verify-diagnostics | FileCheck %s
+
+func.func @load_scalar_from_memref(%input: memref<4x8xf32, strided<[8, 1], offset: 100>>) -> f32 {
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %value = memref.load %input[%c1, %c2] : memref<4x8xf32, strided<[8, 1], offset: 100>>
+ return %value : f32
+}
+// CHECK: func @load_scalar_from_memref
+// CHECK: %[[C10:.*]] = arith.constant 10 : index
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %arg0 to offset: [100], sizes: [32], strides: [1]
+// CHECK-SAME: memref<4x8xf32, strided<[8, 1], offset: 100>> to memref<32xf32, strided<[1], offset: 100>>
+// CHECK: memref.load %[[REINT]][%[[C10]]] : memref<32xf32, strided<[1], offset: 100>>
+
+// -----
+
+func.func @load_scalar_from_memref_static_dim_2(%input: memref<4x8xf32, strided<[8, 12], offset: 100>>, %row: index, %col: index) -> f32 {
+ %value = memref.load %input[%col, %row] : memref<4x8xf32, strided<[8, 12], offset: 100>>
+ return %value : f32
+}
+// CHECK: [[MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 8 + s1 * 12)>
+// CHECK: func @load_scalar_from_memref_static_dim_2
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xf32, strided<[8, 12], offset: 100>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+// CHECK: %[[IDX:.*]] = affine.apply [[MAP]]()[%[[ARG2]], %[[ARG1]]]
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [100], sizes: [32], strides: [12]
+// CHECK-SAME: to memref<32xf32, strided<[12], offset: 100>>
+// CHECK: memref.load %[[REINT]][%[[IDX]]]
+
+// -----
+
+func.func @load_scalar_from_memref_dynamic_dim(%input: memref<?x?xf32, strided<[?, ?], offset: ?>>, %row: index, %col: index) -> f32 {
+ %value = memref.load %input[%col, %row] : memref<?x?xf32, strided<[?, ?], offset: ?>>
+ return %value : f32
+}
+
+// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1, s2, s3] -> (s0 * s1 + s2 * s3)>
+// CHECK: #[[MAP1:.*]] = affine_map<()[s0, s1] -> (s0 * s1)>
+// CHECK: func @load_scalar_from_memref_dynamic_dim
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32, strided<[?, ?], offset: ?>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG0]]
+// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[STRIDES]]#0, %[[ARG1]], %[[STRIDES]]#1]
+// CHECK: %[[SIZE:.*]] = affine.apply #[[MAP1]]()[%[[STRIDES]]#0, %[[SIZES]]#0]
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [%[[OFFSET]]], sizes: [%[[SIZE]]], strides: [%[[STRIDES]]#1]
+// CHECK: memref.load %[[REINT]][%[[IDX]]]
+
+// -----
+
+func.func @load_scalar_from_memref_subview(%input: memref<4x8xf32>, %row: index, %col: index) -> memref<1x1xf32, strided<[8, 1], offset: ?>> {
+ %subview = memref.subview %input[%col, %row] [1, 1] [1, 1] : memref<4x8xf32> to memref<1x1xf32, strided<[8, 1], offset: ?>>
+ return %subview : memref<1x1xf32, strided<[8, 1], offset: ?>>
+}
+// CHECK: func @load_scalar_from_memref_subview
+
+// -----
+
+func.func @store_scalar_from_memref_static_dim(%input: memref<4x8xf32, strided<[8, 12], offset: 100>>, %row: index, %col: index, %value: f32) {
+ memref.store %value, %input[%col, %row] : memref<4x8xf32, strided<[8, 12], offset: 100>>
+ return
+}
+// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 8 + s1 * 12)>
+// CHECK: func @store_scalar_from_memref_static_dim
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xf32, strided<[8, 12], offset: 100>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: f32)
+// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG1]]]
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]]
+// CHECK: memref.store %[[ARG3]], %[[REINT]][%[[IDX]]] : memref<32xf32, strided<[12], offset: 100>>
+
+// -----
+
+func.func @store_scalar_from_memref_dynamic_dim(%input: memref<?x?xf32, strided<[?, ?], offset: ?>>, %row: index, %col: index, %value: f32) {
+ memref.store %value, %input[%col, %row] : memref<?x?xf32, strided<[?, ?], offset: ?>>
+ return
+}
+// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1, s2, s3] -> (s0 * s1 + s2 * s3)>
+// CHECK: #[[MAP1:.*]] = affine_map<()[s0, s1] -> (s0 * s1)>
+// CHECK: func @store_scalar_from_memref_dynamic_dim
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32, strided<[?, ?], offset: ?>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: f32)
+// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG0]]
+// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[STRIDES]]#0, %[[ARG1]], %[[STRIDES]]#1]
+// CHECK: %[[SIZE:.*]] = affine.apply #[[MAP1]]()[%[[STRIDES]]#0, %[[SIZES]]#0]
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [%[[OFFSET]]], sizes: [%[[SIZE]]], strides: [%[[STRIDES]]#1]
+// CHECK: memref.store %[[ARG3]], %[[REINT]][%[[IDX]]]
+
+// -----
+
+func.func @load_vector_from_memref(%input: memref<4x8xf32>) -> vector<8xf32> {
+ %c3 = arith.constant 3 : index
+ %c6 = arith.constant 6 : index
+ %value = vector.load %input[%c3, %c6] : memref<4x8xf32>, vector<8xf32>
+ return %value : vector<8xf32>
+}
+// CHECK: func @load_vector_from_memref
+// CHECK: %[[C30:.*]] = arith.constant 30
+// CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast %arg0 to offset: [0], sizes: [32], strides: [1]
+// CHECK-NEXT: vector.load %[[REINT]][%[[C30]]]
+
+// -----
+
+func.func @load_vector_from_memref_odd(%input: memref<3x7xi2>) -> vector<3xi2> {
+ %c1 = arith.constant 1 : index
+ %c3 = arith.constant 3 : index
+ %value = vector.load %input[%c1, %c3] : memref<3x7xi2>, vector<3xi2>
+ return %value : vector<3xi2>
+}
+// CHECK: func @load_vector_from_memref_odd
+// CHECK: %[[C10:.*]] = arith.constant 10 : index
+// CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast
+// CHECK-NEXT: vector.load %[[REINT]][%[[C10]]]
+
+// -----
+
+func.func @load_vector_from_memref_dynamic(%input: memref<3x7xi2>, %row: index, %col: index) -> vector<3xi2> {
+ %value = vector.load %input[%col, %row] : memref<3x7xi2>, vector<3xi2>
+ return %value : vector<3xi2>
+}
+// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 7 + s1)>
+// CHECK: func @load_vector_from_memref_dynamic
+// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast
+// CHECK: vector.load %[[REINT]][%[[IDX]]] : memref<21xi2, strided<[1]>>, vector<3xi2>
+
+// -----
+
+func.func @store_vector_to_memref_odd(%input: memref<3x7xi2>, %value: vector<3xi2>) {
+ %c1 = arith.constant 1 : index
+ %c3 = arith.constant 3 : index
+ vector.store %value, %input[%c1, %c3] : memref<3x7xi2>, vector<3xi2>
+ return
+}
+// CHECK: func @store_vector_to_memref_odd
+// CHECK: %[[C10:.*]] = arith.constant 10 : index
+// CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast
+// CHECK-NEXT: vector.store %arg1, %[[REINT]][%[[C10]]] : memref<21xi2, strided<[1]>
+
+// -----
+
+func.func @store_vector_to_memref_dynamic(%input: memref<3x7xi2>, %value: vector<3xi2>, %row: index, %col: index) {
+ vector.store %value, %input[%col, %row] : memref<3x7xi2>, vector<3xi2>
+ return
+}
+// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 7 + s1)>
+// CHECK: func @store_vector_to_memref_dynamic
+// CHECK-SAME: (%[[ARG0:.*]]: memref<3x7xi2>, %[[ARG1:.*]]: vector<3xi2>, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
+// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG3]], %[[ARG2]]]
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [21], strides: [1]
+// CHECK: vector.store %[[ARG1]], %[[REINT]][%[[IDX]]]
+
+// -----
+
+func.func @mask_store_vector_to_memref_odd(%input: memref<3x7xi2>, %value: vector<3xi2>, %mask: vector<3xi1>) {
+ %c1 = arith.constant 1 : index
+ %c3 = arith.constant 3 : index
+ vector.maskedstore %input[%c1, %c3], %mask, %value : memref<3x7xi2>, vector<3xi1>, vector<3xi2>
+ return
+}
+// CHECK: func @mask_store_vector_to_memref_odd
+// CHECK-SAME: (%[[ARG0:.*]]: memref<3x7xi2>, %[[ARG1:.*]]: vector<3xi2>, %[[ARG2:.*]]: vector<3xi1>)
+// CHECK: %[[C10:.*]] = arith.constant 10 : index
+// CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast
+// CHECK: vector.maskedstore %[[REINT]][%[[C10]]], %[[ARG2]], %[[ARG1]]
+
+// -----
+
+func.func @mask_store_vector_to_memref_dynamic(%input: memref<3x7xi2>, %value: vector<3xi2>, %row: index, %col: index, %mask: vector<3xi1>) {
+ vector.maskedstore %input[%col, %row], %mask, %value : memref<3x7xi2>, vector<3xi1>, vector<3xi2>
+ return
+}
+// CHECK: #map = affine_map<()[s0, s1] -> (s0 * 7 + s1)>
+// CHECK: func @mask_store_vector_to_memref_dynamic
+// CHECK-SAME: (%[[ARG0:.*]]: memref<3x7xi2>, %[[ARG1:.*]]: vector<3xi2>, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: vector<3xi1>)
+// CHECK: %[[IDX:.*]] = affine.apply #map()[%[[ARG3]], %[[ARG2]]]
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]]
+// CHECK: vector.maskedstore %[[REINT]][%[[IDX]]], %[[ARG4]], %[[ARG1]]
+
+// -----
+func.func @mask_load_vector_from_memref_odd(%input: memref<3x7xi2>, %mask: vector<3xi1>, %passthru: vector<3xi2>) -> vector<3xi2> {
+ %c1 = arith.constant 1 : index
+ %c3 = arith.constant 3 : index
+ %result = vector.maskedload %input[%c1, %c3], %mask, %passthru : memref<3x7xi2>, vector<3xi1>, vector<3xi2> into vector<3xi2>
+ return %result : vector<3xi2>
+}
+// CHECK: func @mask_load_vector_from_memref_odd
+// CHECK-SAME: (%[[ARG0:.*]]: memref<3x7xi2>, %[[MASK:.*]]: vector<3xi1>, %[[PASSTHRU:.*]]: vector<3xi2>)
+// CHECK: %[[C10:.*]] = arith.constant 10 : index
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [21], strides: [1]
+// CHECK: vector.maskedload %[[REINT]][%[[C10]]], %[[MASK]], %[[PASSTHRU]]
+
+// -----
+
+func.func @mask_load_vector_from_memref_dynamic(%input: memref<3x7xi2>, %row: index, %col: index, %mask: vector<3xi1>, %passthru: vector<3xi2>) -> vector<3xi2> {
+ %result = vector.maskedload %input[%col, %row], %mask, %passthru : memref<3x7xi2>, vector<3xi1>, vector<3xi2> into vector<3xi2>
+ return %result : vector<3xi2>
+}
+// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 7 + s1)>
+// CHECK: func @mask_load_vector_from_memref_dynamic
+// CHECK-SAME: (%[[ARG0:.*]]: memref<3x7xi2>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: vector<3xi1>, %[[ARG4:.*]]: vector<3xi2>)
+// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG1]]]
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]]
+// CHECK: vector.maskedload %[[REINT]][%[[IDX]]], %[[ARG3]]
+
+// -----
+
+func.func @transfer_read_memref(%input: memref<4x8xi2>, %value: vector<8xi2>, %row: index, %col: index) -> vector<8xi2> {
+ %c0 = arith.constant 0 : i2
+ %0 = vector.transfer_read %input[%col, %row], %c0 : memref<4x8xi2>, vector<8xi2>
+ return %0 : vector<8xi2>
+}
+// CHECK: func @transfer_read_memref
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xi2>, %[[ARG1:.*]]: vector<8xi2>, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
+// CHECK: %[[C0:.*]] = arith.constant 0 : i2
+// CHECK: %[[IDX:.*]] = affine.apply #map()[%[[ARG3]], %[[ARG2]]]
+// CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]]
+// CHECK-NEXT: vector.transfer_read %[[REINT]][%[[IDX]]], %[[C0]]
+
+// -----
+
+func.func @transfer_write_memref(%input: memref<4x8xi2>, %value: vector<8xi2>, %row: index, %col: index) {
+ vector.transfer_write %value, %input[%col, %row] : vector<8xi2>, memref<4x8xi2>
+ return
+}
+// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 8 + s1)>
+// CHECK: func @transfer_write_memref
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xi2>, %[[ARG1:.*]]: vector<8xi2>, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
+// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG3]], %[[ARG2]]]
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]]
+// CHECK: vector.transfer_write %[[ARG1]], %[[REINT]][%[[IDX]]]
From 90377146fbf7f6c630be51373bf09f90bf082454 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Tue, 22 Apr 2025 16:23:39 -0400
Subject: [PATCH 3/4] missing the file
---
.../MemRef/Transforms/FlattenMemRefs.cpp | 357 ++++++++++++++++++
1 file changed, 357 insertions(+)
create mode 100644 mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
new file mode 100644
index 0000000000000..6685896624536
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
@@ -0,0 +1,357 @@
+//===----- FlattenMemRefs.cpp - MemRef ops flattener pass ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains patterns for flattening multi-rank memref-related ops
+// into 1-D memref ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+
+namespace mlir {
+namespace memref {
+#define GEN_PASS_DEF_FLATTENMEMREFSPASS
+#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
+} // namespace memref
+} // namespace mlir
+
+using namespace mlir;
+
+static void setInsertionPointToStart(OpBuilder &builder, Value val) {
+ if (auto *parentOp = val.getDefiningOp()) {
+ builder.setInsertionPointAfter(parentOp);
+ } else {
+ builder.setInsertionPointToStart(val.getParentBlock());
+ }
+}
+
+static std::tuple<Value, OpFoldResult, SmallVector<OpFoldResult>, OpFoldResult,
+ OpFoldResult>
+getFlatOffsetAndStrides(OpBuilder &rewriter, Location loc, Value source,
+ ArrayRef<OpFoldResult> subOffsets,
+ ArrayRef<OpFoldResult> subStrides = std::nullopt) {
+ auto sourceType = cast<MemRefType>(source.getType());
+ auto sourceRank = static_cast<unsigned>(sourceType.getRank());
+
+ memref::ExtractStridedMetadataOp newExtractStridedMetadata;
+ {
+ OpBuilder::InsertionGuard g(rewriter);
+ setInsertionPointToStart(rewriter, source);
+ newExtractStridedMetadata =
+ rewriter.create<memref::ExtractStridedMetadataOp>(loc, source);
+ }
+
+ auto &&[sourceStrides, sourceOffset] = sourceType.getStridesAndOffset();
+
+ auto getDim = [&](int64_t dim, Value dimVal) -> OpFoldResult {
+ return ShapedType::isDynamic(dim) ? getAsOpFoldResult(dimVal)
+ : rewriter.getIndexAttr(dim);
+ };
+
+ OpFoldResult origOffset =
+ getDim(sourceOffset, newExtractStridedMetadata.getOffset());
+ ValueRange sourceStridesVals = newExtractStridedMetadata.getStrides();
+ OpFoldResult outmostDim =
+ getDim(sourceType.getShape().front(),
+ newExtractStridedMetadata.getSizes().front());
+
+ SmallVector<OpFoldResult> origStrides;
+ origStrides.reserve(sourceRank);
+
+ SmallVector<OpFoldResult> strides;
+ strides.reserve(sourceRank);
+
+ AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
+ AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
+ for (auto i : llvm::seq(0u, sourceRank)) {
+ OpFoldResult origStride = getDim(sourceStrides[i], sourceStridesVals[i]);
+
+ if (!subStrides.empty()) {
+ strides.push_back(affine::makeComposedFoldedAffineApply(
+ rewriter, loc, s0 * s1, {subStrides[i], origStride}));
+ }
+
+ origStrides.emplace_back(origStride);
+ }
+
+ // Compute linearized index:
+ auto &&[expr, values] =
+ computeLinearIndex(rewriter.getIndexAttr(0), origStrides, subOffsets);
+ OpFoldResult linearizedIndex =
+ affine::makeComposedFoldedAffineApply(rewriter, loc, expr, values);
+
+ // Compute collapsed size: (the outmost stride * outmost dimension).
+ SmallVector<OpFoldResult> ops{origStrides.front(), outmostDim};
+ OpFoldResult collapsedSize = computeProduct(loc, rewriter, ops);
+
+ return {newExtractStridedMetadata.getBaseBuffer(), linearizedIndex,
+ origStrides, origOffset, collapsedSize};
+}
+
+static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc,
+ OpFoldResult in) {
+ if (Attribute offsetAttr = dyn_cast<Attribute>(in)) {
+ return rewriter.create<arith::ConstantIndexOp>(
+ loc, cast<IntegerAttr>(offsetAttr).getInt());
+ }
+ return cast<Value>(in);
+}
+
+/// Returns a collapsed memref and the linearized index to access the element
+/// at the specified indices.
+static std::pair<Value, Value> getFlattenMemrefAndOffset(OpBuilder &rewriter,
+ Location loc,
+ Value source,
+ ValueRange indices) {
+ auto &&[base, index, strides, offset, collapsedShape] =
+ getFlatOffsetAndStrides(rewriter, loc, source,
+ getAsOpFoldResult(indices));
+
+ return std::make_pair(
+ rewriter.create<memref::ReinterpretCastOp>(
+ loc, source,
+ /* offset = */ offset,
+ /* shapes = */ ArrayRef<OpFoldResult>{collapsedShape},
+ /* strides = */ ArrayRef<OpFoldResult>{strides.back()}),
+ getValueFromOpFoldResult(rewriter, loc, index));
+}
+
+static bool needFlattening(Value val) {
+ auto type = cast<MemRefType>(val.getType());
+ return type.getRank() > 1;
+}
+
+static bool checkLayout(Value val) {
+ auto type = cast<MemRefType>(val.getType());
+ return type.getLayout().isIdentity() ||
+ isa<StridedLayoutAttr>(type.getLayout());
+}
+
+namespace {
+template <typename T>
+static Value getTargetMemref(T op) {
+ if constexpr (std::is_same_v<T, memref::LoadOp>) {
+ return op.getMemref();
+ } else if constexpr (std::is_same_v<T, vector::LoadOp>) {
+ return op.getBase();
+ } else if constexpr (std::is_same_v<T, memref::StoreOp>) {
+ return op.getMemref();
+ } else if constexpr (std::is_same_v<T, vector::StoreOp>) {
+ return op.getBase();
+ } else if constexpr (std::is_same_v<T, vector::MaskedLoadOp>) {
+ return op.getBase();
+ } else if constexpr (std::is_same_v<T, vector::MaskedStoreOp>) {
+ return op.getBase();
+ } else if constexpr (std::is_same_v<T, vector::TransferReadOp>) {
+ return op.getSource();
+ } else if constexpr (std::is_same_v<T, vector::TransferWriteOp>) {
+ return op.getSource();
+ }
+ return {};
+}
+
+template <typename T>
+static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref,
+ Value offset) {
+ if constexpr (std::is_same_v<T, memref::LoadOp>) {
+ auto newLoad = rewriter.create<memref::LoadOp>(
+ op->getLoc(), op->getResultTypes(), flatMemref, ValueRange{offset});
+ newLoad->setAttrs(op->getAttrs());
+ rewriter.replaceOp(op, newLoad.getResult());
+ } else if constexpr (std::is_same_v<T, vector::LoadOp>) {
+ auto newLoad = rewriter.create<vector::LoadOp>(
+ op->getLoc(), op->getResultTypes(), flatMemref, ValueRange{offset});
+ newLoad->setAttrs(op->getAttrs());
+ rewriter.replaceOp(op, newLoad.getResult());
+ } else if constexpr (std::is_same_v<T, memref::StoreOp>) {
+ auto newStore = rewriter.create<memref::StoreOp>(
+ op->getLoc(), op->getOperands().front(), flatMemref,
+ ValueRange{offset});
+ newStore->setAttrs(op->getAttrs());
+ rewriter.replaceOp(op, newStore);
+ } else if constexpr (std::is_same_v<T, vector::StoreOp>) {
+ auto newStore = rewriter.create<vector::StoreOp>(
+ op->getLoc(), op->getOperands().front(), flatMemref,
+ ValueRange{offset});
+ newStore->setAttrs(op->getAttrs());
+ rewriter.replaceOp(op, newStore);
+ } else if constexpr (std::is_same_v<T, vector::TransferReadOp>) {
+ auto newTransferRead = rewriter.create<vector::TransferReadOp>(
+ op->getLoc(), op.getType(), flatMemref, ValueRange{offset},
+ op.getPadding());
+ rewriter.replaceOp(op, newTransferRead.getResult());
+ } else if constexpr (std::is_same_v<T, vector::TransferWriteOp>) {
+ auto newTransferWrite = rewriter.create<vector::TransferWriteOp>(
+ op->getLoc(), op.getVector(), flatMemref, ValueRange{offset});
+ rewriter.replaceOp(op, newTransferWrite);
+ } else if constexpr (std::is_same_v<T, vector::MaskedLoadOp>) {
+ auto newMaskedLoad = rewriter.create<vector::MaskedLoadOp>(
+ op->getLoc(), op.getType(), flatMemref, ValueRange{offset},
+ op.getMask(), op.getPassThru());
+ newMaskedLoad->setAttrs(op->getAttrs());
+ rewriter.replaceOp(op, newMaskedLoad.getResult());
+ } else if constexpr (std::is_same_v<T, vector::MaskedStoreOp>) {
+ auto newMaskedStore = rewriter.create<vector::MaskedStoreOp>(
+ op->getLoc(), flatMemref, ValueRange{offset}, op.getMask(),
+ op.getValueToStore());
+ newMaskedStore->setAttrs(op->getAttrs());
+ rewriter.replaceOp(op, newMaskedStore);
+ } else {
+ op.emitOpError("unimplemented: do not know how to replace op.");
+ }
+}
+
+template <typename T>
+struct MemRefRewritePatternBase : public OpRewritePattern<T> {
+ using OpRewritePattern<T>::OpRewritePattern;
+ LogicalResult matchAndRewrite(T op,
+ PatternRewriter &rewriter) const override {
+ Value memref = getTargetMemref<T>(op);
+    if (!needFlattening(memref) || !checkLayout(memref))
+ return rewriter.notifyMatchFailure(op,
+ "nothing to do or unsupported layout");
+ auto &&[flatMemref, offset] = getFlattenMemrefAndOffset(
+ rewriter, op->getLoc(), memref, op.getIndices());
+ replaceOp<T>(op, rewriter, flatMemref, offset);
+ return success();
+ }
+};
+
+struct FlattenMemrefLoad : public MemRefRewritePatternBase<memref::LoadOp> {
+ using MemRefRewritePatternBase<memref::LoadOp>::MemRefRewritePatternBase;
+};
+
+struct FlattenVectorLoad : public MemRefRewritePatternBase<vector::LoadOp> {
+ using MemRefRewritePatternBase<vector::LoadOp>::MemRefRewritePatternBase;
+};
+
+struct FlattenMemrefStore : public MemRefRewritePatternBase<memref::StoreOp> {
+ using MemRefRewritePatternBase<memref::StoreOp>::MemRefRewritePatternBase;
+};
+
+struct FlattenVectorStore : public MemRefRewritePatternBase<vector::StoreOp> {
+ using MemRefRewritePatternBase<vector::StoreOp>::MemRefRewritePatternBase;
+};
+
+struct FlattenVectorMaskedLoad
+ : public MemRefRewritePatternBase<vector::MaskedLoadOp> {
+ using MemRefRewritePatternBase<
+ vector::MaskedLoadOp>::MemRefRewritePatternBase;
+};
+
+struct FlattenVectorMaskedStore
+ : public MemRefRewritePatternBase<vector::MaskedStoreOp> {
+ using MemRefRewritePatternBase<
+ vector::MaskedStoreOp>::MemRefRewritePatternBase;
+};
+
+struct FlattenVectorTransferRead
+ : public MemRefRewritePatternBase<vector::TransferReadOp> {
+ using MemRefRewritePatternBase<
+ vector::TransferReadOp>::MemRefRewritePatternBase;
+};
+
+struct FlattenVectorTransferWrite
+ : public MemRefRewritePatternBase<vector::TransferWriteOp> {
+ using MemRefRewritePatternBase<
+ vector::TransferWriteOp>::MemRefRewritePatternBase;
+};
+
+struct FlattenSubview : public OpRewritePattern<memref::SubViewOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(memref::SubViewOp op,
+ PatternRewriter &rewriter) const override {
+ Value memref = op.getSource();
+    if (!needFlattening(memref))
+ return rewriter.notifyMatchFailure(op, "nothing to do");
+
+ if (!checkLayout(memref))
+ return rewriter.notifyMatchFailure(op, "unsupported layout");
+
+ Location loc = op.getLoc();
+ SmallVector<OpFoldResult> subOffsets = op.getMixedOffsets();
+ SmallVector<OpFoldResult> subSizes = op.getMixedSizes();
+ SmallVector<OpFoldResult> subStrides = op.getMixedStrides();
+ auto &&[base, finalOffset, strides, _, __] =
+ getFlatOffsetAndStrides(rewriter, loc, memref, subOffsets, subStrides);
+
+ auto srcType = cast<MemRefType>(memref.getType());
+ auto resultType = cast<MemRefType>(op.getType());
+ unsigned subRank = static_cast<unsigned>(resultType.getRank());
+
+ llvm::SmallBitVector droppedDims = op.getDroppedDims();
+
+ SmallVector<OpFoldResult> finalSizes;
+ finalSizes.reserve(subRank);
+
+ SmallVector<OpFoldResult> finalStrides;
+ finalStrides.reserve(subRank);
+
+ for (auto i : llvm::seq(0u, static_cast<unsigned>(srcType.getRank()))) {
+ if (droppedDims.test(i))
+ continue;
+
+ finalSizes.push_back(subSizes[i]);
+ finalStrides.push_back(strides[i]);
+ }
+
+ rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+ op, resultType, base, finalOffset, finalSizes, finalStrides);
+ return success();
+ }
+};
+
+struct FlattenMemrefsPass : public mlir::memref::impl::FlattenMemrefsPassBase<FlattenMemrefsPass> {
+ using Base::Base;
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<affine::AffineDialect, arith::ArithDialect,
+ memref::MemRefDialect, vector::VectorDialect>();
+ }
+
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+
+ memref::populateFlattenMemrefsPatterns(patterns);
+
+ if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+ return signalPassFailure();
+ }
+};
+
+} // namespace
+
+void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {
+ patterns.insert<FlattenMemrefLoad, FlattenMemrefStore, FlattenSubview,
+ FlattenVectorMaskedLoad, FlattenVectorMaskedStore,
+ FlattenVectorLoad, FlattenVectorStore,
+ FlattenVectorTransferRead, FlattenVectorTransferWrite>(
+ patterns.getContext());
+}
+
+std::unique_ptr<Pass> mlir::memref::createFlattenMemrefsPass() {
+ return std::make_unique<FlattenMemrefsPass>();
+}
+
From 35d05d43be26c232d86b080fa7fedf230c0ad364 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Tue, 22 Apr 2025 21:00:29 -0400
Subject: [PATCH 4/4] Fix linking issue
---
mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 48 ++++++++++---------
.../MemRef/Transforms/FlattenMemRefs.cpp | 2 +-
2 files changed, 27 insertions(+), 23 deletions(-)
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index aa49c49062c76..43224de5604ed 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -5022,6 +5022,31 @@ SmallVector<OpFoldResult> AffineLinearizeIndexOp::getPaddedBasis() {
return ret;
}
+namespace mlir {
+namespace affine {
+OpFoldResult computeProduct(Location loc, OpBuilder &builder,
+ ArrayRef<OpFoldResult> terms) {
+ int64_t nDynamic = 0;
+ SmallVector<Value> dynamicPart;
+ AffineExpr result = builder.getAffineConstantExpr(1);
+ for (OpFoldResult term : terms) {
+ if (!term)
+ return term;
+ std::optional<int64_t> maybeConst = getConstantIntValue(term);
+ if (maybeConst) {
+ result = result * builder.getAffineConstantExpr(*maybeConst);
+ } else {
+ dynamicPart.push_back(cast<Value>(term));
+ result = result * builder.getAffineSymbolExpr(nDynamic++);
+ }
+ }
+ if (auto constant = dyn_cast<AffineConstantExpr>(result))
+ return getAsIndexOpFoldResult(builder.getContext(), constant.getValue());
+ return builder.create<AffineApplyOp>(loc, result, dynamicPart).getResult();
+}
+} // namespace affine
+} // namespace mlir
+
namespace {
/// Rewrite `affine.linearize_index disjoint [%...a, %x, %...b] by (%...c, 1,
/// %...d)` to `affine.linearize_index disjoint [%...a, %...b] by (%...c,
@@ -5081,27 +5106,6 @@ struct DropLinearizeUnitComponentsIfDisjointOrZero final
}
};
-OpFoldResult computeProduct(Location loc, OpBuilder &builder,
- ArrayRef<OpFoldResult> terms) {
- int64_t nDynamic = 0;
- SmallVector<Value> dynamicPart;
- AffineExpr result = builder.getAffineConstantExpr(1);
- for (OpFoldResult term : terms) {
- if (!term)
- return term;
- std::optional<int64_t> maybeConst = getConstantIntValue(term);
- if (maybeConst) {
- result = result * builder.getAffineConstantExpr(*maybeConst);
- } else {
- dynamicPart.push_back(cast<Value>(term));
- result = result * builder.getAffineSymbolExpr(nDynamic++);
- }
- }
- if (auto constant = dyn_cast<AffineConstantExpr>(result))
- return getAsIndexOpFoldResult(builder.getContext(), constant.getValue());
- return builder.create<AffineApplyOp>(loc, result, dynamicPart).getResult();
-}
-
/// If conseceutive outputs of a delinearize_index are linearized with the same
/// bounds, canonicalize away the redundant arithmetic.
///
@@ -5248,7 +5252,7 @@ struct CancelLinearizeOfDelinearizePortion final
// We use the slice from the linearize's basis above because of the
// "bounds inferred from `disjoint`" case above.
OpFoldResult newSize =
- computeProduct(linearizeOp.getLoc(), rewriter, basisToMerge);
+ affine::computeProduct(linearizeOp.getLoc(), rewriter, basisToMerge);
// Trivial case where we can just skip past the delinearize all together
if (m.length == m.delinearize.getNumResults()) {
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
index 6685896624536..8299dc1716121 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
@@ -103,7 +103,7 @@ getFlatOffsetAndStrides(OpBuilder &rewriter, Location loc, Value source,
// Compute collapsed size: (the outmost stride * outmost dimension).
SmallVector<OpFoldResult> ops{origStrides.front(), outmostDim};
- OpFoldResult collapsedSize = computeProduct(loc, rewriter, ops);
+ OpFoldResult collapsedSize = affine::computeProduct(loc, rewriter, ops);
return {newExtractStridedMetadata.getBaseBuffer(), linearizedIndex,
origStrides, origOffset, collapsedSize};