[Mlir-commits] [mlir] e5551a6 - [mlir][memref] memref.view canonicalizations fixes (#173237)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Dec 22 05:57:24 PST 2025
Author: Ivan Butygin
Date: 2025-12-22T16:57:20+03:00
New Revision: e5551a692ed41f3e8909495a81e78e72cb3a6af4
URL: https://github.com/llvm/llvm-project/commit/e5551a692ed41f3e8909495a81e78e72cb3a6af4
DIFF: https://github.com/llvm/llvm-project/commit/e5551a692ed41f3e8909495a81e78e72cb3a6af4.diff
LOG: [mlir][memref] memref.view canonicalizations fixes (#173237)
* Do not fold memref.view into its source when the byte shift is non-zero
* Remove the unnecessary memref.alloc check from the memref.cast folder
Added:
Modified:
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/test/Dialect/MemRef/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 620cc97b9e3a2..eb321bbc15ded 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3675,7 +3675,8 @@ OpFoldResult ViewOp::fold(FoldAdaptor adaptor) {
MemRefType sourceMemrefType = getSource().getType();
MemRefType resultMemrefType = getResult().getType();
- if (resultMemrefType == sourceMemrefType && resultMemrefType.hasStaticShape())
+ if (resultMemrefType == sourceMemrefType &&
+ resultMemrefType.hasStaticShape() && isZeroInteger(getByteShift()))
return getViewSource();
return {};
@@ -3684,7 +3685,7 @@ OpFoldResult ViewOp::fold(FoldAdaptor adaptor) {
namespace {
struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
- using OpRewritePattern<ViewOp>::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(ViewOp viewOp,
PatternRewriter &rewriter) const override {
@@ -3751,26 +3752,29 @@ struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
}
};
+/// view(memref.cast(%source)) -> view(%source), when %source is a ranked
+/// memref with an identity layout (as memref.view requires of its source).
struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
- using OpRewritePattern<ViewOp>::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(ViewOp viewOp,
PatternRewriter &rewriter) const override {
- Value memrefOperand = viewOp.getOperand(0);
- CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
+ auto memrefCastOp = viewOp.getSource().getDefiningOp<CastOp>();
if (!memrefCastOp)
return failure();
- Value allocOperand = memrefCastOp.getOperand();
- AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
- if (!allocOp)
- return failure();
- rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
- viewOp.getByteShift(),
- viewOp.getSizes());
+
+ // memref.cast may start from an unranked memref or from a memref with a
+ // non-identity layout; neither is a legal memref.view source.
+ auto srcType = dyn_cast<MemRefType>(memrefCastOp.getSource().getType());
+ if (!srcType || !srcType.getLayout().isIdentity())
+ return failure();
+
+ rewriter.replaceOpWithNewOp<ViewOp>(
+ viewOp, viewOp.getType(), memrefCastOp.getSource(),
+ viewOp.getByteShift(), viewOp.getSizes());
return success();
}
};
-
} // namespace
void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 60311306b984d..7b4dea6a24396 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -1336,21 +1336,68 @@ func.func @fold_assume_alignment_chain(%0: memref<128xf32>) -> memref<128xf32> {
// -----
+// CHECK-LABEL: func @fold_view_cast
+// CHECK-SAME: (%[[ARG:.*]]: memref<128xi8>)
+func.func @fold_view_cast(%0: memref<128xi8>) -> memref<i32> {
+ %c0 = arith.constant 0 : index
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[RES:.*]] = memref.view %[[ARG]][%[[C0]]][] : memref<128xi8> to memref<i32>
+ // CHECK: return %[[RES]]
+ %1 = memref.cast %0 : memref<128xi8> to memref<?xi8>
+ %res = memref.view %1[%c0][] : memref<?xi8> to memref<i32>
+ return %res : memref<i32>
+}
+
+// -----
+
+// memref.view requires a ranked source, so a cast from an unranked memref must stay.
+// CHECK-LABEL: func @non_fold_view_cast_unranked
+// CHECK-SAME: (%[[ARG:.*]]: memref<*xi8>)
+func.func @non_fold_view_cast_unranked(%0: memref<*xi8>) -> memref<i32> {
+ %c0 = arith.constant 0 : index
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] : memref<*xi8> to memref<?xi8>
+ // CHECK: %[[RES:.*]] = memref.view %[[CAST]][%[[C0]]][] : memref<?xi8> to memref<i32>
+ // CHECK: return %[[RES]]
+ %1 = memref.cast %0 : memref<*xi8> to memref<?xi8>
+ %res = memref.view %1[%c0][] : memref<?xi8> to memref<i32>
+ return %res : memref<i32>
+}
+
+// -----
+
// CHECK-LABEL: func @fold_view_same_source_result_types
+// CHECK-SAME: (%[[ARG:.*]]: memref<128xi8>)
func.func @fold_view_same_source_result_types(%0: memref<128xi8>) -> memref<128xi8> {
- %c0 = arith.constant 0: index
+ %c0 = arith.constant 0 : index
// CHECK-NOT: memref.view
+ // CHECK: return %[[ARG]]
%res = memref.view %0[%c0][] : memref<128xi8> to memref<128xi8>
return %res : memref<128xi8>
}
// -----
-// CHECK-LABEL: func @non_fold_view_same_source_res_types
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
-func.func @non_fold_view_same_source_res_types(%0: memref<?xi8>, %arg0 : index) -> memref<?xi8> {
+// CHECK-LABEL: func @non_fold_view_non_zero_offset
+// CHECK-SAME: (%[[ARG:.*]]: memref<128xi8>)
+func.func @non_fold_view_non_zero_offset(%0: memref<128xi8>) -> memref<128xi8> {
+ %c1 = arith.constant 1 : index
+ // CHECK: %[[C1:.*]] = arith.constant 1 : index
+ // CHECK: %[[RES:.*]] = memref.view %[[ARG]][%[[C1]]][] : memref<128xi8> to memref<128xi8>
+ // CHECK: return %[[RES]]
+ %res = memref.view %0[%c1][] : memref<128xi8> to memref<128xi8>
+ return %res : memref<128xi8>
+}
+
+// -----
+
+// CHECK-LABEL: func @non_fold_view_same_source_dynamic_size
+// CHECK-SAME: (%[[ARG:.*]]: memref<?xi8>, %[[SIZE:.*]]: index)
+func.func @non_fold_view_same_source_dynamic_size(%0: memref<?xi8>, %arg0 : index) -> memref<?xi8> {
%c0 = arith.constant 0: index
- // CHECK: memref.view
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[RES:.*]] = memref.view %[[ARG]][%[[C0]]][%[[SIZE]]] : memref<?xi8> to memref<?xi8>
+ // CHECK: return %[[RES]]
%res = memref.view %0[%c0][%arg0] : memref<?xi8> to memref<?xi8>
return %res : memref<?xi8>
}
More information about the Mlir-commits
mailing list