[Mlir-commits] [mlir] [mlir][memref] Add foldUseDominateCast function to castOp (PR #168337)

lonely eagle llvmlistbot at llvm.org
Mon Nov 17 22:46:02 PST 2025


https://github.com/linuxlonelyeagle updated https://github.com/llvm/llvm-project/pull/168337

>From b025178c6406849b16c753fe0fb9f3e9920606ce Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Mon, 17 Nov 2025 09:52:34 +0000
Subject: [PATCH 1/3] Add foldUseDominateCast to CastOp::fold.

---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp   | 28 +++++++++++++++++++++-
 mlir/test/Dialect/MemRef/canonicalize.mlir | 19 +++++++++++++++
 2 files changed, 46 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 1c21a2f270da6..aafd908c7af7e 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -13,10 +13,12 @@
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dominance.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
@@ -793,8 +795,32 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
   return false;
 }
 
+static OpFoldResult foldUseDominateCast(CastOp castOp) {
+  auto funcOp = castOp->getParentOfType<FunctionOpInterface>();
+  if (!funcOp)
+    return {};
+  auto castOps = castOp->getBlock()->getOps<CastOp>();
+  CastOp dominateCastOp = castOp;
+  SmallVector<CastOp> ops(castOps);
+  mlir::DominanceInfo dominanceInfo(castOp);
+  for (auto it : castOps) {
+    if (it.getSource() == dominateCastOp.getSource() &&
+        it.getDest().getType() == dominateCastOp.getDest().getType() &&
+        dominanceInfo.dominates(it.getOperation(),
+                                dominateCastOp.getOperation())) {
+      dominateCastOp = it;
+    }
+  }
+  return dominateCastOp == castOp ? Value() : dominateCastOp.getResult();
+}
+
 OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
-  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
+  OpFoldResult result;
+  if (OpFoldResult value = foldUseDominateCast(*this))
+    result = value;
+  if (succeeded(foldMemRefCast(*this)))
+    result = getResult();
+  return result;
 }
 
 FailureOr<std::optional<SmallVector<Value>>>
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 313090272ef90..3638b8d4ac701 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -1367,3 +1367,22 @@ func.func @non_fold_view_same_source_res_types(%0: memref<?xi8>, %arg0 : index)
   %res = memref.view %0[%c0][%arg0] : memref<?xi8> to memref<?xi8>
   return %res : memref<?xi8>
 }
+
+// -----
+
+func.func @fold_use_dominate_cast_foo(%arg0: memref<?xf32, strided<[?], offset:?>>) {
+  return
+}
+
+// CHECK-LABEL: func @fold_use_dominate_cast(
+//  CHECK-SAME:   %[[ARG0:.*]]: memref<?xf32>)
+func.func @fold_use_dominate_cast(%arg: memref<?xf32>) {
+  // CHECK: %[[CAST_0:.*]] = memref.cast %[[ARG0]]
+  %cast0 = memref.cast %arg : memref<?xf32> to memref<?xf32, strided<[?], offset:?>>
+  %cast1 = memref.cast %arg : memref<?xf32> to memref<?xf32, strided<[?], offset:?>>
+  // CHECK: call @fold_use_dominate_cast_foo(%[[CAST_0]])
+  call @fold_use_dominate_cast_foo(%cast0) : (memref<?xf32, strided<[?], offset:?>>) -> ()
+  // CHECK: call @fold_use_dominate_cast_foo(%[[CAST_0]])
+  call @fold_use_dominate_cast_foo(%cast1) : (memref<?xf32, strided<[?], offset:?>>) -> () 
+  return
+}

>From f3127f08d48266113c1a4f84cbe03c93944af194 Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Tue, 18 Nov 2025 06:42:14 +0000
Subject: [PATCH 2/3] Add HoistCastPos canonicalization pattern to CastOp.

---
 .../mlir/Dialect/MemRef/IR/MemRefOps.td       |  1 +
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp      | 51 ++++++++++--------
 mlir/test/Dialect/MemRef/canonicalize.mlir    | 52 +++++++++++++------
 mlir/test/Dialect/SCF/one-shot-bufferize.mlir |  4 +-
 4 files changed, 68 insertions(+), 40 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 0bf22928f6900..c342f25fe61a9 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -565,6 +565,7 @@ def MemRef_CastOp : MemRef_Op<"cast", [
   }];
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index aafd908c7af7e..b489f71b775e0 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -795,32 +795,37 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
   return false;
 }
 
-static OpFoldResult foldUseDominateCast(CastOp castOp) {
-  auto funcOp = castOp->getParentOfType<FunctionOpInterface>();
-  if (!funcOp)
-    return {};
-  auto castOps = castOp->getBlock()->getOps<CastOp>();
-  CastOp dominateCastOp = castOp;
-  SmallVector<CastOp> ops(castOps);
-  mlir::DominanceInfo dominanceInfo(castOp);
-  for (auto it : castOps) {
-    if (it.getSource() == dominateCastOp.getSource() &&
-        it.getDest().getType() == dominateCastOp.getDest().getType() &&
-        dominanceInfo.dominates(it.getOperation(),
-                                dominateCastOp.getOperation())) {
-      dominateCastOp = it;
+OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
+  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
+}
+
+namespace {
+struct HoistCastPos : public OpRewritePattern<CastOp> {
+  using OpRewritePattern<CastOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(CastOp castOp,
+                                PatternRewriter &rewriter) const override {
+    if (auto *defineOp = castOp.getSource().getDefiningOp()) {
+      if (defineOp->getBlock() != castOp->getBlock()) {
+        rewriter.moveOpAfter(castOp.getOperation(), defineOp);
+        return success();
+      }
+      return failure();
+    } else {
+      auto argument = cast<BlockArgument>(castOp.getSource());
+      if (argument.getOwner() != castOp->getBlock()) {
+        rewriter.moveOpBefore(castOp.getOperation(),
+                              &argument.getOwner()->front());
+        return success();
+      }
+      return failure();
     }
   }
-  return dominateCastOp == castOp ? Value() : dominateCastOp.getResult();
-}
+};
+} // namespace
 
-OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
-  OpFoldResult result;
-  if (OpFoldResult value = foldUseDominateCast(*this))
-    result = value;
-  if (succeeded(foldMemRefCast(*this)))
-    result = getResult();
-  return result;
+void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                         MLIRContext *context) {
+  results.add<HoistCastPos>(context);
 }
 
 FailureOr<std::optional<SmallVector<Value>>>
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 3638b8d4ac701..e435615cc8e26 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -1370,19 +1370,41 @@ func.func @non_fold_view_same_source_res_types(%0: memref<?xi8>, %arg0 : index)
 
 // -----
 
-func.func @fold_use_dominate_cast_foo(%arg0: memref<?xf32, strided<[?], offset:?>>) {
-  return
-}
-
-// CHECK-LABEL: func @fold_use_dominate_cast(
-//  CHECK-SAME:   %[[ARG0:.*]]: memref<?xf32>)
-func.func @fold_use_dominate_cast(%arg: memref<?xf32>) {
-  // CHECK: %[[CAST_0:.*]] = memref.cast %[[ARG0]]
-  %cast0 = memref.cast %arg : memref<?xf32> to memref<?xf32, strided<[?], offset:?>>
-  %cast1 = memref.cast %arg : memref<?xf32> to memref<?xf32, strided<[?], offset:?>>
-  // CHECK: call @fold_use_dominate_cast_foo(%[[CAST_0]])
-  call @fold_use_dominate_cast_foo(%cast0) : (memref<?xf32, strided<[?], offset:?>>) -> ()
-  // CHECK: call @fold_use_dominate_cast_foo(%[[CAST_0]])
-  call @fold_use_dominate_cast_foo(%cast1) : (memref<?xf32, strided<[?], offset:?>>) -> () 
-  return
+// CHECK-LABEL: func @hoist_cast_pos
+//  CHECK-SAME:   %[[ARG0:.*]]: memref<10xf32>,
+//  CHECK-SAME:   %[[ARG1:.*]]: i1
+func.func @hoist_cast_pos(%arg: memref<10xf32>, %arg1: i1) -> (memref<?xf32>) {
+  //      CHECK: %[[CAST_0:.*]] = memref.cast %[[ARG0]]
+  //      CHECK: %[[CAST_1:.*]] = memref.cast %[[ARG0]]
+  // CHECK-NEXT: cf.cond_br %[[ARG1]]
+  cf.cond_br %arg1, ^bb1, ^bb2
+^bb1:
+  %cast = memref.cast %arg : memref<10xf32> to memref<?xf32>
+  // CHECK: return %[[CAST_1]]
+  return %cast : memref<?xf32>
+^bb2:
+  %cast1 = memref.cast %arg : memref<10xf32> to memref<?xf32>
+  // CHECK: return %[[CAST_0]]
+  return %cast1 : memref<?xf32> 
+}
+
+// -----
+
+// CHECK-LABEL: func.func @hoist_cast_pos_alloc
+//  CHECK-SAME:   %[[ARG0:.*]]: i1
+func.func @hoist_cast_pos_alloc(%arg: i1) -> (memref<?xf32>) {
+  //      CHECK: %[[ALLOC_0:.*]] = memref.alloc()
+  //      CHECK: %[[CAST_0:.*]] = memref.cast %[[ALLOC_0]]
+  //      CHECK: %[[CAST_1:.*]] = memref.cast %[[ALLOC_0]]
+  // CHECK-NEXT: cf.cond_br %[[ARG0]]
+  %alloc = memref.alloc() : memref<10xf32>
+  cf.cond_br %arg, ^bb1, ^bb2
+^bb1:
+  %cast = memref.cast %alloc : memref<10xf32> to memref<?xf32>
+  // CHECK: return %[[CAST_1]]
+  return %cast : memref<?xf32>
+^bb2:
+  %cast1 = memref.cast %alloc : memref<10xf32> to memref<?xf32>
+  // CHECK: return %[[CAST_0]]
+  return %cast1 : memref<?xf32> 
 }
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
index af09dc865e2de..1ae6e3a8a3cf7 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
@@ -922,13 +922,13 @@ func.func @elide_copy_of_non_writing_scf_if(%c: i1, %p1: index, %p2: index, %f:
 //  CHECK-SAME:     %[[pred:.*]]: index, %[[b:.*]]: memref<{{.*}}>, %[[c:.*]]: memref<{{.*}}>) -> memref<{{.*}}>
 func.func @index_switch(%pred: index, %b: tensor<5xf32>, %c: tensor<5xf32>) -> tensor<5xf32> {
   // Throw in a tensor that bufferizes to a different layout map.
-  // CHECK: %[[a:.*]] = memref.alloc() {{.*}} : memref<5xf32>
+  // CHECK: %[[a:.*]] = memref.alloc() {{.*}} : memref<5xf32>  
+  // CHECK:   %[[cast:.*]] = memref.cast %[[a]] : memref<5xf32> to memref<5xf32, strided<[?], offset: ?>>
   %a = bufferization.alloc_tensor() : tensor<5xf32>
 
   // CHECK: %[[r:.*]] = scf.index_switch %[[pred]] -> memref<5xf32, strided<[?], offset: ?>>
   %0 = scf.index_switch %pred -> tensor<5xf32>
   // CHECK: case 2 {
-  // CHECK:   %[[cast:.*]] = memref.cast %[[a]] : memref<5xf32> to memref<5xf32, strided<[?], offset: ?>>
   // CHECK:   scf.yield %[[cast]]
   case 2 {
     scf.yield %a: tensor<5xf32>

>From 6f0497c201b6d43773a8138b7b38e69ec8ef4b34 Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Tue, 18 Nov 2025 06:45:46 +0000
Subject: [PATCH 3/3] Clean up: remove unused includes and fix CHECK line indentation.

---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp      | 2 --
 mlir/test/Dialect/SCF/one-shot-bufferize.mlir | 2 +-
 2 files changed, 1 insertion(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index b489f71b775e0..e94db0ccb11de 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -13,12 +13,10 @@
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Dominance.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
-#include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
index 1ae6e3a8a3cf7..d1c1f1780e353 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
@@ -923,7 +923,7 @@ func.func @elide_copy_of_non_writing_scf_if(%c: i1, %p1: index, %p2: index, %f:
 func.func @index_switch(%pred: index, %b: tensor<5xf32>, %c: tensor<5xf32>) -> tensor<5xf32> {
   // Throw in a tensor that bufferizes to a different layout map.
   // CHECK: %[[a:.*]] = memref.alloc() {{.*}} : memref<5xf32>  
-  // CHECK:   %[[cast:.*]] = memref.cast %[[a]] : memref<5xf32> to memref<5xf32, strided<[?], offset: ?>>
+  // CHECK: %[[cast:.*]] = memref.cast %[[a]] : memref<5xf32> to memref<5xf32, strided<[?], offset: ?>>
   %a = bufferization.alloc_tensor() : tensor<5xf32>
 
   // CHECK: %[[r:.*]] = scf.index_switch %[[pred]] -> memref<5xf32, strided<[?], offset: ?>>



More information about the Mlir-commits mailing list