[Mlir-commits] [mlir] [MLIR] Handle materializeConstant failure in GreedyPatternRewriteDriver (PR #77258)

Billy Zhu llvmlistbot at llvm.org
Sun Jan 7 14:13:38 PST 2024


https://github.com/zyx-billy created https://github.com/llvm/llvm-project/pull/77258

GreedyPatternRewriteDriver needs to handle failures of `materializeConstant` gracefully. Previously, it did not check whether the returned op was null, and crashed when it was.

>From 7b290f4d538f8c5b8b0ea06e3584f45990468fd6 Mon Sep 17 00:00:00 2001
From: Billy Zhu <billyzhu at modular.com>
Date: Sun, 7 Jan 2024 13:42:52 -0800
Subject: [PATCH] handle materializeConstant failure

---
 .../Utils/GreedyPatternRewriteDriver.cpp      | 31 ++++++++++++++++---
 mlir/test/Transforms/canonicalize.mlir        | 11 +++++++
 2 files changed, 37 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 82438e2bf706c1..146091f96b8fd3 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -434,10 +434,10 @@ bool GreedyPatternRewriteDriver::processWorklist() {
       SmallVector<OpFoldResult> foldResults;
       if (succeeded(op->fold(foldResults))) {
         LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
-        changed = true;
         if (foldResults.empty()) {
           // Op was modified in-place.
           notifyOperationModified(op);
+          changed = true;
 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
           if (config.scope && failed(verify(config.scope->getParentOp())))
             llvm::report_fatal_error("IR failed to verify after folding");
@@ -451,6 +451,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
         OpBuilder::InsertionGuard g(*this);
         setInsertionPoint(op);
         SmallVector<Value> replacements;
+        bool materializationSucceeded = true;
         for (auto [ofr, resultType] :
              llvm::zip_equal(foldResults, op->getResultTypes())) {
           if (auto value = ofr.dyn_cast<Value>()) {
@@ -462,18 +463,38 @@ bool GreedyPatternRewriteDriver::processWorklist() {
           // Materialize Attributes as SSA values.
           Operation *constOp = op->getDialect()->materializeConstant(
               *this, ofr.get<Attribute>(), resultType, op->getLoc());
+
+          if (!constOp) {
+            // Materialization failed: erase the ops materialized for earlier
+            // results. (std::transform cannot insert through a set iterator.)
+            llvm::SmallDenseSet<Operation *> replacementOps;
+            for (Value replacement : replacements) {
+              assert(replacement.use_empty() &&
+                     "folder reused an existing value; cannot roll back");
+              replacementOps.insert(replacement.getDefiningOp());
+            }
+            for (Operation *replacementOp : replacementOps)
+              eraseOp(replacementOp);
+            materializationSucceeded = false;
+            break;
+          }
+
           assert(constOp->hasTrait<OpTrait::ConstantLike>() &&
                  "materializeConstant produced op that is not a ConstantLike");
           assert(constOp->getResultTypes()[0] == resultType &&
                  "materializeConstant produced incorrect result type");
           replacements.push_back(constOp->getResult(0));
         }
-        replaceOp(op, replacements);
+
+        if (materializationSucceeded) {
+          replaceOp(op, replacements);
+          changed = true;
 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
-        if (config.scope && failed(verify(config.scope->getParentOp())))
-          llvm::report_fatal_error("IR failed to verify after folding");
+          if (config.scope && failed(verify(config.scope->getParentOp())))
+            llvm::report_fatal_error("IR failed to verify after folding");
 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
-        continue;
+          continue;
+        }
       }
     }
 
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index 47a19bb598c25e..9b578e6c2631a7 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -1224,3 +1224,14 @@ func.func @clone_nested_region(%arg0: index, %arg1: index, %arg2: index) -> memr
 // CHECK-NEXT: scf.yield %[[ALLOC3_2]]
 //      CHECK: memref.dealloc %[[ALLOC1]]
 // CHECK-NEXT: return %[[ALLOC2]]
+
+// -----
+
+// CHECK-LABEL: func @test_materialize_failure
+func.func @test_materialize_failure() -> i64 {
+  %const = index.constant 1234
+  // castu folds, but its i64 result can't be materialized; the op must remain.
+  // CHECK: index.castu
+  %u = index.castu %const : index to i64
+  return %u: i64
+}



More information about the Mlir-commits mailing list