[Mlir-commits] [mlir] [mlir][memref]: Fold expand_shape + transfer_read (PR #167679)

Jack Frankland llvmlistbot at llvm.org
Mon Nov 17 03:26:16 PST 2025


https://github.com/FranklandJack updated https://github.com/llvm/llvm-project/pull/167679

>From ffbb4ac70a1e1374ec6865219ea75a2c54c3e928 Mon Sep 17 00:00:00 2001
From: Jack Frankland <jack.frankland at arm.com>
Date: Wed, 12 Nov 2025 10:57:10 +0000
Subject: [PATCH] [mlir][memref]: Fold expand_shape + transfer_read

Extend the load-of-expand-shape rewrite pattern to support folding a
`memref.expand_shape` into a `vector.transfer_read` when the permutation
map on the `vector.transfer_read` is a minor identity.

Signed-off-by: Jack Frankland <jack.frankland at arm.com>
---
 .../MemRef/Transforms/FoldMemRefAliasOps.cpp  | 26 ++++++++++++--
 .../Dialect/MemRef/fold-memref-alias-ops.mlir | 34 +++++++++++++++++++
 2 files changed, 58 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index 214410f78e51c..30df10c1deedc 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -347,28 +347,49 @@ LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
           loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
           isa<affine::AffineLoadOp, memref::LoadOp>(loadOp.getOperation()))))
     return failure();
-  llvm::TypeSwitch<Operation *, void>(loadOp)
+
+  return llvm::TypeSwitch<Operation *, LogicalResult>(loadOp)
       .Case([&](affine::AffineLoadOp op) {
         rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
             loadOp, expandShapeOp.getViewSource(), sourceIndices);
+        return success();
       })
       .Case([&](memref::LoadOp op) {
         rewriter.replaceOpWithNewOp<memref::LoadOp>(
             loadOp, expandShapeOp.getViewSource(), sourceIndices,
             op.getNontemporal());
+        return success();
       })
       .Case([&](vector::LoadOp op) {
         rewriter.replaceOpWithNewOp<vector::LoadOp>(
             op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
             op.getNontemporal());
+        return success();
       })
       .Case([&](vector::MaskedLoadOp op) {
         rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
             op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
             op.getMask(), op.getPassThru());
+        return success();
+      })
+      .Case([&](vector::TransferReadOp op) {
+        // We only support minor identity maps in the permutation attribute.
+        if (!op.getPermutationMap().isMinorIdentity())
+          return failure();
+
+        // We need to construct a new minor identity map since we will have lost
+        // some dimensions in folding away the expand shape.
+        auto minorIdMap = AffineMap::getMinorIdentityMap(
+            sourceIndices.size(), op.getVectorType().getRank(),
+            op.getContext());
+
+        rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
+            op, op.getVectorType(), expandShapeOp.getViewSource(),
+            sourceIndices, minorIdMap, op.getPadding(), op.getMask(),
+            op.getInBounds());
+        return success();
       })
       .DefaultUnreachable("unexpected operation");
-  return success();
 }
 
 template <typename OpTy>
@@ -659,6 +680,7 @@ void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
                LoadOpOfExpandShapeOpFolder<memref::LoadOp>,
                LoadOpOfExpandShapeOpFolder<vector::LoadOp>,
                LoadOpOfExpandShapeOpFolder<vector::MaskedLoadOp>,
+               LoadOpOfExpandShapeOpFolder<vector::TransferReadOp>,
                StoreOpOfExpandShapeOpFolder<affine::AffineStoreOp>,
                StoreOpOfExpandShapeOpFolder<memref::StoreOp>,
                StoreOpOfExpandShapeOpFolder<vector::StoreOp>,
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 106652623933f..87f23457644ae 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -992,6 +992,40 @@ func.func @fold_vector_maskedstore_expand_shape(
 
 // -----
 
+func.func @fold_vector_transfer_read_expand_shape(
+  %arg0 : memref<32xf32>, %arg1 : index) -> vector<8xf32> {
+  %c0 = arith.constant 0 : index
+  %pad = ub.poison : f32
+  %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+  %1 = vector.transfer_read %0[%arg1, %c0], %pad {in_bounds = [true]} : memref<4x8xf32>, vector<8xf32>
+  return %1 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @fold_vector_transfer_read_expand_shape
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+//       CHECK:   %[[C0:.*]] = arith.constant 0
+//       CHECK:   %[[PAD:.*]] = ub.poison : f32
+//       CHECK:   %[[IDX:.*]] = affine.linearize_index [%[[ARG1]], %[[C0]]] by (4, 8)
+//       CHECK:   vector.transfer_read %[[ARG0]][%[[IDX]]], %[[PAD]] {in_bounds = [true]}
+
+// -----
+
+func.func @fold_vector_transfer_read_with_perm_map(
+  %arg0 : memref<32xf32>, %arg1 : index) -> vector<4x4xf32> {
+  %c0 = arith.constant 0 : index
+  %pad = ub.poison : f32
+  %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+  %1 = vector.transfer_read %0[%arg1, %c0], %pad { permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<4x8xf32>, vector<4x4xf32>
+  return %1 : vector<4x4xf32>
+}
+
+// CHECK-LABEL: func @fold_vector_transfer_read_with_perm_map
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
+//       CHECK:   memref.expand_shape %[[ARG0]] {{\[}}[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+
+// -----
+
 func.func @fold_vector_load_collapse_shape(
   %arg0 : memref<4x8xf32>, %arg1 : index) -> vector<8xf32> {
   %0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>



More information about the Mlir-commits mailing list