[Mlir-commits] [mlir] [mlir][memref] Add foldUseDominateCast function to castOp (PR #168337)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Nov 17 01:56:20 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: lonely eagle (linuxlonelyeagle)
<details>
<summary>Changes</summary>
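The new `foldUseDominateCast` helper, called from `CastOp::fold`, eliminates redundant casts: a `memref.cast` is folded into an earlier `memref.cast` in the same block that has the same source and the same result type.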
---
Full diff: https://github.com/llvm/llvm-project/pull/168337.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+27-1)
- (modified) mlir/test/Dialect/MemRef/canonicalize.mlir (+19)
``````````diff
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 1c21a2f270da6..aafd908c7af7e 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -13,10 +13,12 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dominance.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
@@ -793,8 +795,32 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
return false;
}
+/// Returns the result of an earlier `memref.cast` in the same block that has
+/// the same source value and the same result type as `castOp`, or a null
+/// value if there is none. Because that cast precedes `castOp` in the block,
+/// its result is available wherever `castOp`'s result is used, so `castOp`
+/// can be replaced by it.
+///
+/// Only the users of the source value are inspected and no DominanceInfo is
+/// built, so the cost is proportional to the number of uses of the source
+/// rather than to the size of the block; the folder may be invoked many times
+/// during canonicalization.
+static OpFoldResult foldUseDominateCast(CastOp castOp) {
+  Block *block = castOp->getBlock();
+  // Nothing to compare against for an op that is not attached to a block.
+  if (!block)
+    return {};
+  Type destType = castOp.getDest().getType();
+  CastOp earliest = castOp;
+  for (Operation *user : castOp.getSource().getUsers()) {
+    // A cast has a single operand, so any cast user shares this source.
+    auto other = dyn_cast<CastOp>(user);
+    if (!other || other->getBlock() != block ||
+        other.getDest().getType() != destType)
+      continue;
+    if (other->isBeforeInBlock(earliest.getOperation()))
+      earliest = other;
+  }
+  return earliest == castOp ? Value() : earliest.getResult();
+}
+
OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
- return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
+  // Replacing this cast with an identical earlier cast removes the op
+  // entirely, so try that first and return its result directly rather than
+  // letting the in-place fold below take precedence over it.
+  if (OpFoldResult dominating = foldUseDominateCast(*this))
+    return dominating;
+  // Otherwise try `foldMemRefCast`, which updates this op in place; returning
+  // this op's own result signals an in-place fold.
+  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
}
FailureOr<std::optional<SmallVector<Value>>>
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 313090272ef90..3638b8d4ac701 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -1367,3 +1367,22 @@ func.func @non_fold_view_same_source_res_types(%0: memref<?xi8>, %arg0 : index)
%res = memref.view %0[%c0][%arg0] : memref<?xi8> to memref<?xi8>
return %res : memref<?xi8>
}
+
+// -----
+
+func.func private @fold_use_dominate_cast_foo(memref<?xf32, strided<[?], offset:?>>)
+
+// CHECK-LABEL: func @fold_use_dominate_cast(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?xf32>)
+func.func @fold_use_dominate_cast(%arg: memref<?xf32>) {
+  // CHECK: %[[CAST_0:.*]] = memref.cast %[[ARG0]]
+  // CHECK-NOT: memref.cast
+  %cast0 = memref.cast %arg : memref<?xf32> to memref<?xf32, strided<[?], offset:?>>
+  %cast1 = memref.cast %arg : memref<?xf32> to memref<?xf32, strided<[?], offset:?>>
+  // CHECK: call @fold_use_dominate_cast_foo(%[[CAST_0]])
+  call @fold_use_dominate_cast_foo(%cast0) : (memref<?xf32, strided<[?], offset:?>>) -> ()
+  // CHECK: call @fold_use_dominate_cast_foo(%[[CAST_0]])
+  call @fold_use_dominate_cast_foo(%cast1) : (memref<?xf32, strided<[?], offset:?>>) -> ()
+  return
+}
+
+// -----
+
+// Casts of the same source to different result types must not be merged.
+// CHECK-LABEL: func @non_fold_use_dominate_cast_different_types(
+// CHECK-SAME: %[[ARG0:.*]]: memref<4xf32>)
+// CHECK: %[[CAST_0:.*]] = memref.cast %[[ARG0]] : memref<4xf32> to memref<?xf32>
+// CHECK: %[[CAST_1:.*]] = memref.cast %[[ARG0]] : memref<4xf32> to memref<4xf32, strided
+// CHECK: return %[[CAST_0]], %[[CAST_1]]
+func.func @non_fold_use_dominate_cast_different_types(%arg: memref<4xf32>) -> (memref<?xf32>, memref<4xf32, strided<[?], offset:?>>) {
+  %cast0 = memref.cast %arg : memref<4xf32> to memref<?xf32>
+  %cast1 = memref.cast %arg : memref<4xf32> to memref<4xf32, strided<[?], offset:?>>
+  return %cast0, %cast1 : memref<?xf32>, memref<4xf32, strided<[?], offset:?>>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/168337
More information about the Mlir-commits mailing list