[Mlir-commits] [mlir] [mlir][memref]: Fold ExpandShape into TransferRead (PR #176786)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jan 19 09:55:21 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Jack Frankland (FranklandJack)

<details>
<summary>Changes</summary>

Add support for folding `memref.expand_shape` ops into `vector.transfer_read` ops when the permutation map is not a minor identity.

The fold is safe when every dimension used by the permutation map is the innermost (least-significant) dimension of its reassociation group. Such a dimension is contiguous within the original source shape, so the read can be expressed directly on the source memref.

---
Full diff: https://github.com/llvm/llvm-project/pull/176786.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp (+28-10) 
- (modified) mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir (+36) 


``````````diff
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index 3cacb7e29263b..3344186b130e1 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -323,25 +323,43 @@ LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
         return success();
       })
       .Case([&](vector::TransferReadOp op) {
-        // We only support minor identity maps in the permutation attribute.
-        if (!op.getPermutationMap().isMinorIdentity())
-          return failure();
-
         // We only support the case where the source of the expand shape has
         // rank greater than or equal to the vector rank.
-        const int64_t sourceRank = sourceIndices.size();
         const int64_t vectorRank = op.getVectorType().getRank();
+        const int64_t sourceRank = sourceIndices.size();
         if (sourceRank < vectorRank)
           return failure();
 
-        // We need to construct a new minor identity map since we will have lost
-        // some dimensions in folding away the expand shape.
-        auto minorIdMap = AffineMap::getMinorIdentityMap(sourceRank, vectorRank,
-                                                         op.getContext());
+        SmallVector<AffineExpr> newResults;
+        // We can only fold if the permutation map uses only the least
+        // significant dimension from an expanded shape.
+        for (AffineExpr result : op.getPermutationMap().getResults()) {
+          bool foundExpr = false;
+
+          uint32_t newDim = 0;
+          for (auto reassocationIndices :
+               llvm::enumerate(expandShapeOp.getReassociationIndices())) {
+            auto reassociation = reassocationIndices.value();
+
+            AffineExpr dim = getAffineDimExpr(
+                reassociation[reassociation.size() - 1], rewriter.getContext());
+            if (dim == result) {
+              newResults.push_back(getAffineDimExpr(reassocationIndices.index(),
+                                                    rewriter.getContext()));
+              foundExpr = true;
+              break;
+            }
+          }
+          if (!foundExpr)
+            return failure();
+        }
+
+        auto newMap =
+            AffineMap::get(sourceRank, 0, newResults, op.getContext());
 
         rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
             op, op.getVectorType(), expandShapeOp.getViewSource(),
-            sourceIndices, minorIdMap, op.getPadding(), op.getMask(),
+            sourceIndices, newMap, op.getPadding(), op.getMask(),
             op.getInBounds());
         return success();
       })
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 93e5ba462584a..79156df0ebe1e 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -820,6 +820,42 @@ func.func @fold_vector_transfer_read_expand_shape(
 
 // -----
 
+func.func @fold_vector_transfer_read_expand_shape_non_identity(
+  %arg0 : memref<32x32xf32>, %arg1 : index, %arg2 : index) -> vector<8x8xf32> {
+  %c0 = arith.constant 0 : index
+  %pad = ub.poison : f32
+  %0 = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [4, 8, 4, 8] : memref<32x32xf32> into memref<4x8x4x8xf32>
+  %1 = vector.transfer_read %0[%arg1, %c0, %arg2, %c0], %pad {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d3)>} : memref<4x8x4x8xf32>, vector<8x8xf32>
+  return %1 : vector<8x8xf32>
+}
+
+// CHECK-LABEL: func @fold_vector_transfer_read_expand_shape_non_identity
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<32x32xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
+//       CHECK:   %[[C0:.*]] = arith.constant 0
+//       CHECK:   %[[PAD:.*]] = ub.poison : f32
+//       CHECK:   %[[IDX1:.*]] = affine.linearize_index [%[[ARG1]], %[[C0]]] by (4, 8)
+//       CHECK:   %[[IDX2:.*]] = affine.linearize_index [%[[ARG2]], %[[C0]]] by (4, 8)
+//       CHECK:   vector.transfer_read %[[ARG0]][%[[IDX1]], %[[IDX2]]], %[[PAD]] {in_bounds = [true, true]}
+
+// -----
+
+func.func @fold_vector_transfer_read_expand_shape_non_identity_non_contiguous(
+  %arg0 : memref<32x32xf32>, %arg1 : index, %arg2 : index) -> vector<8x8xf32> {
+  %c0 = arith.constant 0 : index
+  %pad = ub.poison : f32
+  %0 = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [4, 8, 4, 8] : memref<32x32xf32> into memref<4x8x4x8xf32>
+  %1 = vector.transfer_read %0[%arg1, %c0, %arg2, %c0], %pad {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d0, d3)>} : memref<4x8x4x8xf32>, vector<8x8xf32>
+  return %1 : vector<8x8xf32>
+}
+
+// CHECK-LABEL: func @fold_vector_transfer_read_expand_shape_non_identity_non_contiguous
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<32x32xf32>
+//       CHECK:   memref.expand_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]] output_shape [4, 8, 4, 8] : memref<32x32xf32> into memref<4x8x4x8xf32>
+
+// -----
+
 func.func @fold_vector_transfer_read_with_perm_map(
   %arg0 : memref<32xf32>, %arg1 : index) -> vector<4x4xf32> {
   %c0 = arith.constant 0 : index

``````````

</details>


https://github.com/llvm/llvm-project/pull/176786


More information about the Mlir-commits mailing list