[Mlir-commits] [mlir] [mlir][memref] Remove unit-stride restriction in SubViewOp folding (PR #192437)
llvmlistbot at llvm.org
Thu Apr 16 05:31:21 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-memref
Author: Longsheng Mou (CoTinker)
<details>
<summary>Changes</summary>
This PR replaces the manual offset/size resolution in the `SubViewOfSubViewFolder` pattern with `affine::mergeOffsetsSizesAndStrides`. This simplifies the code and extends subview-of-subview folding to support non-unit strides.
---
Full diff: https://github.com/llvm/llvm-project/pull/192437.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp (+8-26)
- (modified) mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir (+22)
``````````diff
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index 6f2752932422a..b7954cf26926d 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -1,4 +1,4 @@
-//===- FoldMemRefAliasOps.cpp - Fold memref alias ops -----===//
+//===- FoldMemRefAliasOps.cpp - Fold memref alias ops ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -176,34 +176,16 @@ class SubViewOfSubViewFolder : public OpRewritePattern<memref::SubViewOp> {
if (!srcSubView)
return failure();
- // TODO: relax unit stride assumption.
- if (!subView.hasUnitStride()) {
- return rewriter.notifyMatchFailure(subView, "requires unit strides");
- }
- if (!srcSubView.hasUnitStride()) {
- return rewriter.notifyMatchFailure(srcSubView, "requires unit strides");
- }
-
- // Resolve sizes according to dropped dims.
- SmallVector<OpFoldResult> resolvedSizes;
- llvm::SmallBitVector srcDroppedDims = srcSubView.getDroppedDims();
- affine::resolveSizesIntoOpWithSizes(srcSubView.getMixedSizes(),
- subView.getMixedSizes(), srcDroppedDims,
- resolvedSizes);
-
- // Resolve offsets according to source offsets and strides.
- SmallVector<Value> resolvedOffsets;
- affine::resolveIndicesIntoOpWithOffsetsAndStrides(
- rewriter, subView.getLoc(), srcSubView.getMixedOffsets(),
- srcSubView.getMixedStrides(), srcDroppedDims, subView.getMixedOffsets(),
- resolvedOffsets);
+ SmallVector<OpFoldResult> newOffsets, newSizes, newStrides;
+ if (failed(affine::mergeOffsetsSizesAndStrides(
+ rewriter, subView.getLoc(), srcSubView, subView,
+ srcSubView.getDroppedDims(), newOffsets, newSizes, newStrides)))
+ return failure();
// Replace original op.
rewriter.replaceOpWithNewOp<memref::SubViewOp>(
- subView, subView.getType(), srcSubView.getSource(),
- getAsOpFoldResult(resolvedOffsets), resolvedSizes,
- srcSubView.getMixedStrides());
-
+ subView, subView.getType(), srcSubView.getSource(), newOffsets,
+ newSizes, newStrides);
return success();
}
};
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 114ba86cda718..fb8ac2e9858e7 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -508,6 +508,28 @@ func.func @subview_of_subview_rank_reducing(%m: memref<?x?x?xf32>,
// -----
+// CHECK-LABEL: func.func @subview_of_subview_no_unit_stride(
+// CHECK-SAME: %[[ARG0:.*]]: memref<8x8xf32, strided<[8, 1]>>)
+// CHECK: %[[SUBVIEW_0:.*]] = memref.subview %[[ARG0]][3, 3] [2, 2] [4, 4] : memref<8x8xf32, strided<[8, 1]>> to memref<2x2xf32, strided<[32, 4], offset: 27>>
+func.func @subview_of_subview_no_unit_stride(%arg0: memref<8x8xf32, strided<[8, 1]>>) -> memref<2x2xf32, strided<[32, 4], offset: 27>> {
+ %subview = memref.subview %arg0[1, 1] [4, 4] [2, 2] : memref<8x8xf32, strided<[8, 1]>> to memref<4x4xf32, strided<[16, 2], offset: 9>>
+ %subview_0 = memref.subview %subview[1, 1] [2, 2] [2, 2] : memref<4x4xf32, strided<[16, 2], offset: 9>> to memref<2x2xf32, strided<[32, 4], offset: 27>>
+ return %subview_0 : memref<2x2xf32, strided<[32, 4], offset: 27>>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @subview_of_subview_rank_reducing_no_unit_stride(
+// CHECK-SAME: %[[ARG0:.*]]: memref<8x8xf32, strided<[8, 1]>>)
+// CHECK: %[[SUBVIEW_0:.*]] = memref.subview %[[ARG0]][3, 3] [1, 2] [4, 4] : memref<8x8xf32, strided<[8, 1]>> to memref<2xf32, strided<[4], offset: 27>>
+func.func @subview_of_subview_rank_reducing_no_unit_stride(%arg0: memref<8x8xf32, strided<[8, 1]>>) -> memref<2xf32, strided<[4], offset: 27>> {
+ %subview = memref.subview %arg0[1, 1] [4, 4] [2, 2] : memref<8x8xf32, strided<[8, 1]>> to memref<4x4xf32, strided<[16, 2], offset: 9>>
+ %subview_0 = memref.subview %subview[1, 1] [1, 2] [2, 2] : memref<4x4xf32, strided<[16, 2], offset: 9>> to memref<2xf32, strided<[4], offset: 27>>
+ return %subview_0 : memref<2xf32, strided<[4], offset: 27>>
+}
+
+// -----
+
// CHECK-LABEL: func @fold_load_keep_nontemporal(
// CHECK: memref.load %{{.+}}[%{{.+}}, %{{.+}}] {nontemporal = true}
func.func @fold_load_keep_nontemporal(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) -> f32 {
``````````
</details>
https://github.com/llvm/llvm-project/pull/192437
More information about the Mlir-commits mailing list