[Mlir-commits] [mlir] [mlir][memref]: Fold ExpandShape into TransferRead (PR #176786)
Jack Frankland
llvmlistbot at llvm.org
Wed Jan 21 03:36:33 PST 2026
https://github.com/FranklandJack updated https://github.com/llvm/llvm-project/pull/176786
>From e63d8bc47f27ad9fcd7d4c02f53ca7b834561cd1 Mon Sep 17 00:00:00 2001
From: Jack Frankland <jack.frankland at arm.com>
Date: Mon, 19 Jan 2026 17:26:03 +0000
Subject: [PATCH] [mlir][memref]: Fold ExpandShape into TransferRead
Add support for folding `memref.expand_shape` ops into
`vector.transfer_read` ops when the permutation map is not a
minor identity.
When the permutation map indexes only into expanded dimensions that
remain contiguous within the original source shape, it is safe to
make this transformation.
Signed-off-by: Jack Frankland <jack.frankland at arm.com>
---
.../MemRef/Transforms/FoldMemRefAliasOps.cpp | 37 ++++++++++++++-----
.../Dialect/MemRef/fold-memref-alias-ops.mlir | 36 ++++++++++++++++++
2 files changed, 63 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index 3cacb7e29263b..06c3392cd6732 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -323,25 +323,42 @@ LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
return success();
})
.Case([&](vector::TransferReadOp op) {
- // We only support minor identity maps in the permutation attribute.
- if (!op.getPermutationMap().isMinorIdentity())
- return failure();
-
// We only support the case where the source of the expand shape has
// rank greater than or equal to the vector rank.
- const int64_t sourceRank = sourceIndices.size();
const int64_t vectorRank = op.getVectorType().getRank();
+ const int64_t sourceRank = sourceIndices.size();
if (sourceRank < vectorRank)
return failure();
- // We need to construct a new minor identity map since we will have lost
- // some dimensions in folding away the expand shape.
- auto minorIdMap = AffineMap::getMinorIdentityMap(sourceRank, vectorRank,
- op.getContext());
+ SmallVector<AffineExpr> newResults;
+ // We can only fold if the permutation map uses only the least
+ // significant dimension from an expanded shape.
+ for (AffineExpr result : op.getPermutationMap().getResults()) {
+ bool foundExpr = false;
+
+ for (auto reassocationIndices :
+ llvm::enumerate(expandShapeOp.getReassociationIndices())) {
+ auto reassociation = reassocationIndices.value();
+
+ AffineExpr dim = getAffineDimExpr(
+ reassociation[reassociation.size() - 1], rewriter.getContext());
+ if (dim == result) {
+ newResults.push_back(getAffineDimExpr(reassocationIndices.index(),
+ rewriter.getContext()));
+ foundExpr = true;
+ break;
+ }
+ }
+ if (!foundExpr)
+ return failure();
+ }
+
+ auto newMap =
+ AffineMap::get(sourceRank, 0, newResults, op.getContext());
rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
op, op.getVectorType(), expandShapeOp.getViewSource(),
- sourceIndices, minorIdMap, op.getPadding(), op.getMask(),
+ sourceIndices, newMap, op.getPadding(), op.getMask(),
op.getInBounds());
return success();
})
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 93e5ba462584a..79156df0ebe1e 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -820,6 +820,42 @@ func.func @fold_vector_transfer_read_expand_shape(
// -----
+func.func @fold_vector_transfer_read_expand_shape_non_identity(
+ %arg0 : memref<32x32xf32>, %arg1 : index, %arg2 : index) -> vector<8x8xf32> {
+ %c0 = arith.constant 0 : index
+ %pad = ub.poison : f32
+ %0 = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [4, 8, 4, 8] : memref<32x32xf32> into memref<4x8x4x8xf32>
+ %1 = vector.transfer_read %0[%arg1, %c0, %arg2, %c0], %pad {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d3)>} : memref<4x8x4x8xf32>, vector<8x8xf32>
+ return %1 : vector<8x8xf32>
+}
+
+// CHECK-LABEL: func @fold_vector_transfer_read_expand_shape_non_identity
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<32x32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
+// CHECK: %[[C0:.*]] = arith.constant 0
+// CHECK: %[[PAD:.*]] = ub.poison : f32
+// CHECK: %[[IDX1:.*]] = affine.linearize_index [%[[ARG1]], %[[C0]]] by (4, 8)
+// CHECK: %[[IDX2:.*]] = affine.linearize_index [%[[ARG2]], %[[C0]]] by (4, 8)
+// CHECK: vector.transfer_read %[[ARG0]][%[[IDX1]], %[[IDX2]]], %[[PAD]] {in_bounds = [true, true]}
+
+// -----
+
+func.func @fold_vector_transfer_read_expand_shape_non_identity_non_contiguous(
+ %arg0 : memref<32x32xf32>, %arg1 : index, %arg2 : index) -> vector<8x8xf32> {
+ %c0 = arith.constant 0 : index
+ %pad = ub.poison : f32
+ %0 = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [4, 8, 4, 8] : memref<32x32xf32> into memref<4x8x4x8xf32>
+ %1 = vector.transfer_read %0[%arg1, %c0, %arg2, %c0], %pad {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d0, d3)>} : memref<4x8x4x8xf32>, vector<8x8xf32>
+ return %1 : vector<8x8xf32>
+}
+
+// CHECK-LABEL: func @fold_vector_transfer_read_expand_shape_non_identity_non_contiguous
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<32x32xf32>
+// CHECK: memref.expand_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]] output_shape [4, 8, 4, 8] : memref<32x32xf32> into memref<4x8x4x8xf32>
+
+// -----
+
func.func @fold_vector_transfer_read_with_perm_map(
%arg0 : memref<32xf32>, %arg1 : index) -> vector<4x4xf32> {
%c0 = arith.constant 0 : index
More information about the Mlir-commits
mailing list