[Mlir-commits] [mlir] [mlir][memref] Add foldUseDominateCast function to castOp (PR #168337)
lonely eagle
llvmlistbot at llvm.org
Mon Nov 17 01:55:44 PST 2025
https://github.com/linuxlonelyeagle created https://github.com/llvm/llvm-project/pull/168337
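This patch adds a foldUseDominateCast helper to the memref.cast folder. It eliminates a redundant cast by folding it into an identical cast (same source operand and result type) that appears earlier in the same block.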
>From b025178c6406849b16c753fe0fb9f3e9920606ce Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Mon, 17 Nov 2025 09:52:34 +0000
Subject: [PATCH] add foldUseDominateCast to castOp.
---
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 28 +++++++++++++++++++++-
mlir/test/Dialect/MemRef/canonicalize.mlir | 19 +++++++++++++++
2 files changed, 46 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 1c21a2f270da6..aafd908c7af7e 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -13,10 +13,12 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dominance.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
@@ -793,8 +795,32 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
return false;
}
+/// Folds `castOp` into an equivalent `memref.cast` that appears earlier in the
+/// same block, so that duplicate casts collapse onto a single SSA value. An
+/// equivalent cast has the same source operand and the same result type.
+///
+/// Only operations preceding `castOp` in its own block are inspected. In a
+/// region with SSA dominance, these are exactly the same-block operations that
+/// dominate `castOp`. That means no DominanceInfo has to be built on every
+/// fold, and the scan stops at `castOp` instead of walking the whole block.
+/// Returns a null result if `castOp` is not attached to a block, or if no
+/// earlier equivalent cast exists.
+static OpFoldResult foldUseDominateCast(CastOp castOp) {
+  // A detached op (presumably still under construction) has no neighbours.
+  Block *block = castOp->getBlock();
+  if (!block)
+    return {};
+  Value source = castOp.getSource();
+  Type resultType = castOp.getDest().getType();
+  for (Operation &op :
+       llvm::make_range(block->begin(), castOp->getIterator())) {
+    auto candidate = dyn_cast<CastOp>(&op);
+    if (candidate && candidate.getSource() == source &&
+        candidate.getDest().getType() == resultType)
+      return candidate.getResult();
+  }
+  return {};
+}
+
OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
- return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
+  // Prefer the in-place fold: on success this op is updated in place and
+  // folds to its own result, so there is no need to search for a duplicate.
+  if (succeeded(foldMemRefCast(*this)))
+    return getResult();
+  // Otherwise reuse an identical cast that appears earlier in the block.
+  return foldUseDominateCast(*this);
}
FailureOr<std::optional<SmallVector<Value>>>
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 313090272ef90..3638b8d4ac701 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -1367,3 +1367,22 @@ func.func @non_fold_view_same_source_res_types(%0: memref<?xi8>, %arg0 : index)
%res = memref.view %0[%c0][%arg0] : memref<?xi8> to memref<?xi8>
return %res : memref<?xi8>
}
+
+// -----
+
+// Opaque external callee; it only gives the cast results real uses so the
+// test can observe which cast value each call is rewritten to.
+func.func private @fold_use_dominate_cast_foo(memref<?xf32, strided<[?], offset:?>>)
+
+// Two identical casts of the same value in one block: the second folds into
+// the first, so exactly one memref.cast survives and both calls use it.
+// CHECK-LABEL: func @fold_use_dominate_cast(
+//  CHECK-SAME:   %[[ARG0:.*]]: memref<?xf32>)
+func.func @fold_use_dominate_cast(%arg: memref<?xf32>) {
+  //     CHECK: %[[CAST_0:.*]] = memref.cast %[[ARG0]]
+  // CHECK-NOT: memref.cast
+  %cast0 = memref.cast %arg : memref<?xf32> to memref<?xf32, strided<[?], offset:?>>
+  %cast1 = memref.cast %arg : memref<?xf32> to memref<?xf32, strided<[?], offset:?>>
+  //     CHECK: call @fold_use_dominate_cast_foo(%[[CAST_0]])
+  call @fold_use_dominate_cast_foo(%cast0) : (memref<?xf32, strided<[?], offset:?>>) -> ()
+  //     CHECK: call @fold_use_dominate_cast_foo(%[[CAST_0]])
+  call @fold_use_dominate_cast_foo(%cast1) : (memref<?xf32, strided<[?], offset:?>>) -> ()
+  return
+}
More information about the Mlir-commits
mailing list