[Mlir-commits] [mlir] [mlir][memref] Add foldUseDominateCast function to castOp (PR #168337)
lonely eagle
llvmlistbot at llvm.org
Mon Nov 17 22:46:02 PST 2025
https://github.com/linuxlonelyeagle updated https://github.com/llvm/llvm-project/pull/168337
>From b025178c6406849b16c753fe0fb9f3e9920606ce Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Mon, 17 Nov 2025 09:52:34 +0000
Subject: [PATCH 1/3] add foldUseDominateCast to castOp.
---
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 28 +++++++++++++++++++++-
mlir/test/Dialect/MemRef/canonicalize.mlir | 19 +++++++++++++++
2 files changed, 46 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 1c21a2f270da6..aafd908c7af7e 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -13,10 +13,12 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dominance.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
@@ -793,8 +795,32 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
return false;
}
+static OpFoldResult foldUseDominateCast(CastOp castOp) {
+ auto funcOp = castOp->getParentOfType<FunctionOpInterface>();
+ if (!funcOp)
+ return {};
+ auto castOps = castOp->getBlock()->getOps<CastOp>();
+ CastOp dominateCastOp = castOp;
+ SmallVector<CastOp> ops(castOps);
+ mlir::DominanceInfo dominanceInfo(castOp);
+ for (auto it : castOps) {
+ if (it.getSource() == dominateCastOp.getSource() &&
+ it.getDest().getType() == dominateCastOp.getDest().getType() &&
+ dominanceInfo.dominates(it.getOperation(),
+ dominateCastOp.getOperation())) {
+ dominateCastOp = it;
+ }
+ }
+ return dominateCastOp == castOp ? Value() : dominateCastOp.getResult();
+}
+
OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
- return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
+ OpFoldResult result;
+ if (OpFoldResult value = foldUseDominateCast(*this))
+ result = value;
+ if (succeeded(foldMemRefCast(*this)))
+ result = getResult();
+ return result;
}
FailureOr<std::optional<SmallVector<Value>>>
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 313090272ef90..3638b8d4ac701 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -1367,3 +1367,22 @@ func.func @non_fold_view_same_source_res_types(%0: memref<?xi8>, %arg0 : index)
%res = memref.view %0[%c0][%arg0] : memref<?xi8> to memref<?xi8>
return %res : memref<?xi8>
}
+
+// -----
+
+func.func @fold_use_dominate_cast_foo(%arg0: memref<?xf32, strided<[?], offset:?>>) {
+ return
+}
+
+// CHECK-LABEL: func @fold_use_dominate_cast(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?xf32>)
+func.func @fold_use_dominate_cast(%arg: memref<?xf32>) {
+ // CHECK: %[[CAST_0:.*]] = memref.cast %[[ARG0]]
+ %cast0 = memref.cast %arg : memref<?xf32> to memref<?xf32, strided<[?], offset:?>>
+ %cast1 = memref.cast %arg : memref<?xf32> to memref<?xf32, strided<[?], offset:?>>
+ // CHECK: call @fold_use_dominate_cast_foo(%[[CAST_0]])
+ call @fold_use_dominate_cast_foo(%cast0) : (memref<?xf32, strided<[?], offset:?>>) -> ()
+ // CHECK: call @fold_use_dominate_cast_foo(%[[CAST_0]])
+ call @fold_use_dominate_cast_foo(%cast1) : (memref<?xf32, strided<[?], offset:?>>) -> ()
+ return
+}
>From f3127f08d48266113c1a4f84cbe03c93944af194 Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Tue, 18 Nov 2025 06:42:14 +0000
Subject: [PATCH 2/3] add HoistCastPos pattern.
---
.../mlir/Dialect/MemRef/IR/MemRefOps.td | 1 +
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 51 ++++++++++--------
mlir/test/Dialect/MemRef/canonicalize.mlir | 52 +++++++++++++------
mlir/test/Dialect/SCF/one-shot-bufferize.mlir | 4 +-
4 files changed, 68 insertions(+), 40 deletions(-)
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 0bf22928f6900..c342f25fe61a9 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -565,6 +565,7 @@ def MemRef_CastOp : MemRef_Op<"cast", [
}];
let hasFolder = 1;
+ let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index aafd908c7af7e..b489f71b775e0 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -795,32 +795,37 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
return false;
}
-static OpFoldResult foldUseDominateCast(CastOp castOp) {
- auto funcOp = castOp->getParentOfType<FunctionOpInterface>();
- if (!funcOp)
- return {};
- auto castOps = castOp->getBlock()->getOps<CastOp>();
- CastOp dominateCastOp = castOp;
- SmallVector<CastOp> ops(castOps);
- mlir::DominanceInfo dominanceInfo(castOp);
- for (auto it : castOps) {
- if (it.getSource() == dominateCastOp.getSource() &&
- it.getDest().getType() == dominateCastOp.getDest().getType() &&
- dominanceInfo.dominates(it.getOperation(),
- dominateCastOp.getOperation())) {
- dominateCastOp = it;
+OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
+ return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
+}
+
+namespace {
+struct HoistCastPos : public OpRewritePattern<CastOp> {
+ using OpRewritePattern<CastOp>::OpRewritePattern;
+ LogicalResult matchAndRewrite(CastOp castOp,
+ PatternRewriter &rewriter) const override {
+ if (auto *defineOp = castOp.getSource().getDefiningOp()) {
+ if (defineOp->getBlock() != castOp->getBlock()) {
+ rewriter.moveOpAfter(castOp.getOperation(), defineOp);
+ return success();
+ }
+ return failure();
+ } else {
+ auto argument = cast<BlockArgument>(castOp.getSource());
+ if (argument.getOwner() != castOp->getBlock()) {
+ rewriter.moveOpBefore(castOp.getOperation(),
+ &argument.getOwner()->front());
+ return success();
+ }
+ return failure();
}
}
- return dominateCastOp == castOp ? Value() : dominateCastOp.getResult();
-}
+};
+} // namespace
-OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
- OpFoldResult result;
- if (OpFoldResult value = foldUseDominateCast(*this))
- result = value;
- if (succeeded(foldMemRefCast(*this)))
- result = getResult();
- return result;
+void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<HoistCastPos>(context);
}
FailureOr<std::optional<SmallVector<Value>>>
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 3638b8d4ac701..e435615cc8e26 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -1370,19 +1370,41 @@ func.func @non_fold_view_same_source_res_types(%0: memref<?xi8>, %arg0 : index)
// -----
-func.func @fold_use_dominate_cast_foo(%arg0: memref<?xf32, strided<[?], offset:?>>) {
- return
-}
-
-// CHECK-LABEL: func @fold_use_dominate_cast(
-// CHECK-SAME: %[[ARG0:.*]]: memref<?xf32>)
-func.func @fold_use_dominate_cast(%arg: memref<?xf32>) {
- // CHECK: %[[CAST_0:.*]] = memref.cast %[[ARG0]]
- %cast0 = memref.cast %arg : memref<?xf32> to memref<?xf32, strided<[?], offset:?>>
- %cast1 = memref.cast %arg : memref<?xf32> to memref<?xf32, strided<[?], offset:?>>
- // CHECK: call @fold_use_dominate_cast_foo(%[[CAST_0]])
- call @fold_use_dominate_cast_foo(%cast0) : (memref<?xf32, strided<[?], offset:?>>) -> ()
- // CHECK: call @fold_use_dominate_cast_foo(%[[CAST_0]])
- call @fold_use_dominate_cast_foo(%cast1) : (memref<?xf32, strided<[?], offset:?>>) -> ()
- return
+// CHECK-LABEL: func @hoist_cast_pos
+// CHECK-SAME: %[[ARG0:.*]]: memref<10xf32>,
+// CHECK-SAME: %[[ARG1:.*]]: i1
+func.func @hoist_cast_pos(%arg: memref<10xf32>, %arg1: i1) -> (memref<?xf32>) {
+ // CHECK: %[[CAST_0:.*]] = memref.cast %[[ARG0]]
+ // CHECK: %[[CAST_1:.*]] = memref.cast %[[ARG0]]
+ // CHECK-NEXT: cf.cond_br %[[ARG1]]
+ cf.cond_br %arg1, ^bb1, ^bb2
+^bb1:
+ %cast = memref.cast %arg : memref<10xf32> to memref<?xf32>
+ // CHECK: return %[[CAST_1]]
+ return %cast : memref<?xf32>
+^bb2:
+ %cast1 = memref.cast %arg : memref<10xf32> to memref<?xf32>
+ // CHECK: return %[[CAST_0]]
+ return %cast1 : memref<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @hoist_cast_pos_alloc
+// CHECK-SAME: %[[ARG0:.*]]: i1
+func.func @hoist_cast_pos_alloc(%arg: i1) -> (memref<?xf32>) {
+ // CHECK: %[[ALLOC_0:.*]] = memref.alloc()
+ // CHECK: %[[CAST_0:.*]] = memref.cast %[[ALLOC_0]]
+ // CHECK: %[[CAST_1:.*]] = memref.cast %[[ALLOC_0]]
+ // CHECK-NEXT: cf.cond_br %[[ARG0]]
+ %alloc = memref.alloc() : memref<10xf32>
+ cf.cond_br %arg, ^bb1, ^bb2
+^bb1:
+ %cast = memref.cast %alloc : memref<10xf32> to memref<?xf32>
+ // CHECK: return %[[CAST_1]]
+ return %cast : memref<?xf32>
+^bb2:
+ %cast1 = memref.cast %alloc : memref<10xf32> to memref<?xf32>
+ // CHECK: return %[[CAST_0]]
+ return %cast1 : memref<?xf32>
}
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
index af09dc865e2de..1ae6e3a8a3cf7 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
@@ -922,13 +922,13 @@ func.func @elide_copy_of_non_writing_scf_if(%c: i1, %p1: index, %p2: index, %f:
// CHECK-SAME: %[[pred:.*]]: index, %[[b:.*]]: memref<{{.*}}>, %[[c:.*]]: memref<{{.*}}>) -> memref<{{.*}}>
func.func @index_switch(%pred: index, %b: tensor<5xf32>, %c: tensor<5xf32>) -> tensor<5xf32> {
// Throw in a tensor that bufferizes to a different layout map.
- // CHECK: %[[a:.*]] = memref.alloc() {{.*}} : memref<5xf32>
+ // CHECK: %[[a:.*]] = memref.alloc() {{.*}} : memref<5xf32>
+ // CHECK: %[[cast:.*]] = memref.cast %[[a]] : memref<5xf32> to memref<5xf32, strided<[?], offset: ?>>
%a = bufferization.alloc_tensor() : tensor<5xf32>
// CHECK: %[[r:.*]] = scf.index_switch %[[pred]] -> memref<5xf32, strided<[?], offset: ?>>
%0 = scf.index_switch %pred -> tensor<5xf32>
// CHECK: case 2 {
- // CHECK: %[[cast:.*]] = memref.cast %[[a]] : memref<5xf32> to memref<5xf32, strided<[?], offset: ?>>
// CHECK: scf.yield %[[cast]]
case 2 {
scf.yield %a: tensor<5xf32>
>From 6f0497c201b6d43773a8138b7b38e69ec8ef4b34 Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Tue, 18 Nov 2025 06:45:46 +0000
Subject: [PATCH 3/3] cleanup code.
---
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 2 --
mlir/test/Dialect/SCF/one-shot-bufferize.mlir | 2 +-
2 files changed, 1 insertion(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index b489f71b775e0..e94db0ccb11de 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -13,12 +13,10 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Dominance.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
-#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
index 1ae6e3a8a3cf7..d1c1f1780e353 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
@@ -923,7 +923,7 @@ func.func @elide_copy_of_non_writing_scf_if(%c: i1, %p1: index, %p2: index, %f:
func.func @index_switch(%pred: index, %b: tensor<5xf32>, %c: tensor<5xf32>) -> tensor<5xf32> {
// Throw in a tensor that bufferizes to a different layout map.
// CHECK: %[[a:.*]] = memref.alloc() {{.*}} : memref<5xf32>
- // CHECK: %[[cast:.*]] = memref.cast %[[a]] : memref<5xf32> to memref<5xf32, strided<[?], offset: ?>>
+ // CHECK: %[[cast:.*]] = memref.cast %[[a]] : memref<5xf32> to memref<5xf32, strided<[?], offset: ?>>
%a = bufferization.alloc_tensor() : tensor<5xf32>
// CHECK: %[[r:.*]] = scf.index_switch %[[pred]] -> memref<5xf32, strided<[?], offset: ?>>
More information about the Mlir-commits mailing list