[Mlir-commits] [mlir] [MLIR] Handle materializeConstant failure in GreedyPatternRewriteDriver (PR #77258)
Billy Zhu
llvmlistbot at llvm.org
Mon Jan 8 10:05:20 PST 2024
https://github.com/zyx-billy updated https://github.com/llvm/llvm-project/pull/77258
From 7b290f4d538f8c5b8b0ea06e3584f45990468fd6 Mon Sep 17 00:00:00 2001
From: Billy Zhu <billyzhu at modular.com>
Date: Sun, 7 Jan 2024 13:42:52 -0800
Subject: [PATCH 1/2] handle materializeConstant failure
---
.../Utils/GreedyPatternRewriteDriver.cpp | 31 ++++++++++++++++---
mlir/test/Transforms/canonicalize.mlir | 11 +++++++
2 files changed, 37 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 82438e2bf706c1..146091f96b8fd3 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -434,10 +434,10 @@ bool GreedyPatternRewriteDriver::processWorklist() {
SmallVector<OpFoldResult> foldResults;
if (succeeded(op->fold(foldResults))) {
LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
- changed = true;
if (foldResults.empty()) {
// Op was modified in-place.
notifyOperationModified(op);
+ changed = true;
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
if (config.scope && failed(verify(config.scope->getParentOp())))
llvm::report_fatal_error("IR failed to verify after folding");
@@ -451,6 +451,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
OpBuilder::InsertionGuard g(*this);
setInsertionPoint(op);
SmallVector<Value> replacements;
+ bool materializationSucceeded = true;
for (auto [ofr, resultType] :
llvm::zip_equal(foldResults, op->getResultTypes())) {
if (auto value = ofr.dyn_cast<Value>()) {
@@ -462,18 +463,38 @@ bool GreedyPatternRewriteDriver::processWorklist() {
// Materialize Attributes as SSA values.
Operation *constOp = op->getDialect()->materializeConstant(
*this, ofr.get<Attribute>(), resultType, op->getLoc());
+
+ if (!constOp) {
+ // If materialization fails, cleanup any operations generated for
+ // the previous results.
+ llvm::SmallDenseSet<Operation *> replacementOps;
+ std::transform(replacements.begin(), replacements.end(),
+ replacementOps.begin(), [](Value replacement) {
+ return replacement.getDefiningOp();
+ });
+ for (Operation *op : replacementOps)
+ eraseOp(op);
+
+ materializationSucceeded = false;
+ break;
+ }
+
assert(constOp->hasTrait<OpTrait::ConstantLike>() &&
"materializeConstant produced op that is not a ConstantLike");
assert(constOp->getResultTypes()[0] == resultType &&
"materializeConstant produced incorrect result type");
replacements.push_back(constOp->getResult(0));
}
- replaceOp(op, replacements);
+
+ if (materializationSucceeded) {
+ replaceOp(op, replacements);
+ changed = true;
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
- if (config.scope && failed(verify(config.scope->getParentOp())))
- llvm::report_fatal_error("IR failed to verify after folding");
+ if (config.scope && failed(verify(config.scope->getParentOp())))
+ llvm::report_fatal_error("IR failed to verify after folding");
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
- continue;
+ continue;
+ }
}
}
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index 47a19bb598c25e..9b578e6c2631a7 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -1224,3 +1224,14 @@ func.func @clone_nested_region(%arg0: index, %arg1: index, %arg2: index) -> memr
// CHECK-NEXT: scf.yield %[[ALLOC3_2]]
// CHECK: memref.dealloc %[[ALLOC1]]
// CHECK-NEXT: return %[[ALLOC2]]
+
+// -----
+
+// CHECK-LABEL: func @test_materialize_failure
+func.func @test_materialize_failure() -> i64 {
+ %const = index.constant 1234
+ // Cannot materialize this castu's output constant.
+ // CHECK: index.castu
+ %u = index.castu %const : index to i64
+ return %u: i64
+}
From f49a49153490d125e71cb42219f53d01109f53e5 Mon Sep 17 00:00:00 2001
From: Billy Zhu <billyzhu at modular.com>
Date: Mon, 8 Jan 2024 10:05:07 -0800
Subject: [PATCH 2/2] check that replacement values have no uses
---
.../Transforms/Utils/GreedyPatternRewriteDriver.cpp | 13 ++++++++-----
1 file changed, 8 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 146091f96b8fd3..67c2d9d59f4c92 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -468,12 +468,15 @@ bool GreedyPatternRewriteDriver::processWorklist() {
// If materialization fails, cleanup any operations generated for
// the previous results.
llvm::SmallDenseSet<Operation *> replacementOps;
- std::transform(replacements.begin(), replacements.end(),
- replacementOps.begin(), [](Value replacement) {
- return replacement.getDefiningOp();
- });
- for (Operation *op : replacementOps)
+ for (Value replacement : replacements) {
+ assert(replacement.use_empty() &&
+ "folder reused existing op for one result but constant "
+ "materialization failed for another result");
+ replacementOps.insert(replacement.getDefiningOp());
+ }
+ for (Operation *op : replacementOps) {
eraseOp(op);
+ }
materializationSucceeded = false;
break;
More information about the Mlir-commits mailing list