[Mlir-commits] [mlir] [mlir][memref] Remove unit-stride restriction in SubViewOp folding (PR #192437)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Apr 16 05:31:21 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-memref

Author: Longsheng Mou (CoTinker)

<details>
<summary>Changes</summary>

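This PR replaces the manual offset/size resolution in the `SubViewOfSubViewFolder` pattern with a call to `affine::mergeOffsetsSizesAndStrides`. This simplifies the code and extends subview-of-subview folding to non-unit strides. The folded subview's offsets become `srcOffset + offset * srcStride`, and its strides become `srcStride * stride`.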

---
Full diff: https://github.com/llvm/llvm-project/pull/192437.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp (+8-26) 
- (modified) mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir (+22) 


``````````diff
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index 6f2752932422a..b7954cf26926d 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -1,4 +1,4 @@
-//===- FoldMemRefAliasOps.cpp - Fold memref alias ops -----===//
+//===- FoldMemRefAliasOps.cpp - Fold memref alias ops ---------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -176,34 +176,16 @@ class SubViewOfSubViewFolder : public OpRewritePattern<memref::SubViewOp> {
     if (!srcSubView)
       return failure();
 
-    // TODO: relax unit stride assumption.
-    if (!subView.hasUnitStride()) {
-      return rewriter.notifyMatchFailure(subView, "requires unit strides");
-    }
-    if (!srcSubView.hasUnitStride()) {
-      return rewriter.notifyMatchFailure(srcSubView, "requires unit strides");
-    }
-
-    // Resolve sizes according to dropped dims.
-    SmallVector<OpFoldResult> resolvedSizes;
-    llvm::SmallBitVector srcDroppedDims = srcSubView.getDroppedDims();
-    affine::resolveSizesIntoOpWithSizes(srcSubView.getMixedSizes(),
-                                        subView.getMixedSizes(), srcDroppedDims,
-                                        resolvedSizes);
-
-    // Resolve offsets according to source offsets and strides.
-    SmallVector<Value> resolvedOffsets;
-    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
-        rewriter, subView.getLoc(), srcSubView.getMixedOffsets(),
-        srcSubView.getMixedStrides(), srcDroppedDims, subView.getMixedOffsets(),
-        resolvedOffsets);
+    SmallVector<OpFoldResult> newOffsets, newSizes, newStrides;
+    if (failed(affine::mergeOffsetsSizesAndStrides(
+            rewriter, subView.getLoc(), srcSubView, subView,
+            srcSubView.getDroppedDims(), newOffsets, newSizes, newStrides)))
+      return failure();
 
     // Replace original op.
     rewriter.replaceOpWithNewOp<memref::SubViewOp>(
-        subView, subView.getType(), srcSubView.getSource(),
-        getAsOpFoldResult(resolvedOffsets), resolvedSizes,
-        srcSubView.getMixedStrides());
-
+        subView, subView.getType(), srcSubView.getSource(), newOffsets,
+        newSizes, newStrides);
     return success();
   }
 };
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 114ba86cda718..fb8ac2e9858e7 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -508,6 +508,28 @@ func.func @subview_of_subview_rank_reducing(%m: memref<?x?x?xf32>,
 
 // -----
 
+// CHECK-LABEL:   func.func @subview_of_subview_no_unit_stride(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<8x8xf32, strided<[8, 1]>>)
+// CHECK:           %[[SUBVIEW_0:.*]] = memref.subview %[[ARG0]][3, 3] [2, 2] [4, 4] : memref<8x8xf32, strided<[8, 1]>> to memref<2x2xf32, strided<[32, 4], offset: 27>>
+func.func @subview_of_subview_no_unit_stride(%arg0: memref<8x8xf32, strided<[8, 1]>>) -> memref<2x2xf32, strided<[32, 4], offset: 27>> {
+  %subview = memref.subview %arg0[1, 1] [4, 4] [2, 2] : memref<8x8xf32, strided<[8, 1]>> to memref<4x4xf32, strided<[16, 2], offset: 9>>
+  %subview_0 = memref.subview %subview[1, 1] [2, 2] [2, 2] : memref<4x4xf32, strided<[16, 2], offset: 9>> to memref<2x2xf32, strided<[32, 4], offset: 27>>
+  return %subview_0 : memref<2x2xf32, strided<[32, 4], offset: 27>>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @subview_of_subview_rank_reducing_no_unit_stride(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<8x8xf32, strided<[8, 1]>>)
+// CHECK:           %[[SUBVIEW_0:.*]] = memref.subview %[[ARG0]][3, 3] [1, 2] [4, 4] : memref<8x8xf32, strided<[8, 1]>> to memref<2xf32, strided<[4], offset: 27>>
+func.func @subview_of_subview_rank_reducing_no_unit_stride(%arg0: memref<8x8xf32, strided<[8, 1]>>) -> memref<2xf32, strided<[4], offset: 27>> {
+  %subview = memref.subview %arg0[1, 1] [4, 4] [2, 2] : memref<8x8xf32, strided<[8, 1]>> to memref<4x4xf32, strided<[16, 2], offset: 9>>
+  %subview_0 = memref.subview %subview[1, 1] [1, 2] [2, 2] : memref<4x4xf32, strided<[16, 2], offset: 9>> to memref<2xf32, strided<[4], offset: 27>>
+  return %subview_0 : memref<2xf32, strided<[4], offset: 27>>
+}
+
+// -----
+
 // CHECK-LABEL: func @fold_load_keep_nontemporal(
 //      CHECK:   memref.load %{{.+}}[%{{.+}}, %{{.+}}] {nontemporal = true}
 func.func @fold_load_keep_nontemporal(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) -> f32 {

``````````

</details>


https://github.com/llvm/llvm-project/pull/192437


More information about the Mlir-commits mailing list