[Mlir-commits] [mlir] [mlir][memref] Add foldUseDominateCast function to castOp (PR #168337)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Nov 17 01:56:20 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: lonely eagle (linuxlonelyeagle)

<details>
<summary>Changes</summary>

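The new `foldUseDominateCast` helper eliminates redundant `memref.cast` ops: a cast is folded to the result of a dominating `memref.cast` in the same block that has the same source and the same result type.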

---
Full diff: https://github.com/llvm/llvm-project/pull/168337.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+27-1) 
- (modified) mlir/test/Dialect/MemRef/canonicalize.mlir (+19) 


``````````diff
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 1c21a2f270da6..aafd908c7af7e 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -13,10 +13,12 @@
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dominance.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
@@ -793,8 +795,32 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
   return false;
 }
 
+/// Folds `castOp` to the result of the earliest `memref.cast` in the same
+/// block that has the same source and result type. Replacing a use with an
+/// earlier op of the same block is always dominance-safe, so no DominanceInfo
+/// is built; only the source's users are scanned, not the whole block.
+static OpFoldResult foldUseDominateCast(CastOp castOp) {
+  Block *block = castOp->getBlock();
+  Type destType = castOp.getDest().getType();
+  CastOp earliest = castOp;
+  for (Operation *user : castOp.getSource().getUsers()) {
+    auto other = dyn_cast<CastOp>(user);
+    if (!other || other->getBlock() != block ||
+        other.getDest().getType() != destType)
+      continue;
+    if (other->isBeforeInBlock(earliest.getOperation()))
+      earliest = other;
+  }
+  return earliest == castOp ? Value() : earliest.getResult();
+}
+
 OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
-  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
+  // Prefer an earlier identical cast; fold in place only when there is none.
+  if (OpFoldResult dominating = foldUseDominateCast(*this))
+    return dominating;
+  if (succeeded(foldMemRefCast(*this)))
+    return getResult();
+  return {};
 }
 
 FailureOr<std::optional<SmallVector<Value>>>
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 313090272ef90..3638b8d4ac701 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -1367,3 +1367,22 @@ func.func @non_fold_view_same_source_res_types(%0: memref<?xi8>, %arg0 : index)
   %res = memref.view %0[%c0][%arg0] : memref<?xi8> to memref<?xi8>
   return %res : memref<?xi8>
 }
+
+// -----
+
+func.func private @fold_use_dominate_cast_foo(memref<?xf32, strided<[?], offset:?>>)
+
+// Two identical casts of the same value: the second folds into the first.
+// CHECK-LABEL: func @fold_use_dominate_cast(
+//  CHECK-SAME:   %[[ARG0:.*]]: memref<?xf32>)
+func.func @fold_use_dominate_cast(%arg: memref<?xf32>) {
+  // CHECK: %[[CAST_0:.*]] = memref.cast %[[ARG0]]
+  // CHECK-NOT: memref.cast
+  %cast0 = memref.cast %arg : memref<?xf32> to memref<?xf32, strided<[?], offset:?>>
+  %cast1 = memref.cast %arg : memref<?xf32> to memref<?xf32, strided<[?], offset:?>>
+  // CHECK: call @fold_use_dominate_cast_foo(%[[CAST_0]])
+  call @fold_use_dominate_cast_foo(%cast0) : (memref<?xf32, strided<[?], offset:?>>) -> ()
+  // CHECK: call @fold_use_dominate_cast_foo(%[[CAST_0]])
+  call @fold_use_dominate_cast_foo(%cast1) : (memref<?xf32, strided<[?], offset:?>>) -> ()
+  return
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/168337


More information about the Mlir-commits mailing list