[Mlir-commits] [mlir] [mlir][Transforms] Dialect conversion: Hardening `replaceOp` (PR #109540)

Matthias Springer llvmlistbot at llvm.org
Sat Sep 28 05:22:07 PDT 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/109540

From 4b71bece34043d31a165166216b53011c04597b0 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sat, 21 Sep 2024 20:03:17 +0200
Subject: [PATCH] [mlir][Transforms] Dialect conversion: Extra checks during
 `replaceOp`

This commit adds extra checks/assertions to `ConversionPatternRewriterImpl::notifyOpReplaced` to improve its robustness.

Replacing an `unrealized_conversion_cast` op that was created by the driver is forbidden and is now caught early during `replaceOp`. It may work in some cases, but is generally dangerous because the conversion driver keeps track of these ops. (Erasing them is fine.) This change is also in preparation for a subsequent commit that splits the `ConversionValueMapping` into replacements and materializations (with the goal of simplifying block signature conversions).

`null` replacement values are no longer registered in the `ConversionValueMapping`. This was an oversight in #106760. `null` values in the mapping could result in crashes when using the `ConversionValueMapping` API.
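
As a rough illustration of the two rules above, here is a minimal standalone sketch. It does not use MLIR: `ToyRewriter`, `ValueId`, and `OpId` are hypothetical stand-ins, values and ops are plain integers, and the function returns `false` where the real code would hit the new assertion. The "out of thin air" source materialization for dropped results of ordinary ops is deliberately omitted.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <set>
#include <unordered_map>
#include <vector>

// Hypothetical stand-ins for mlir::Value and mlir::Operation*.
using ValueId = int;
using OpId = int;

struct ToyRewriter {
  // Toy version of the value mapping: original result -> replacement value.
  std::unordered_map<ValueId, ValueId> mapping;
  // Ops that the (toy) driver created as unresolved materializations.
  std::set<OpId> unresolvedMaterializations;

  // Models the checks in notifyOpReplaced. std::nullopt plays the role of a
  // null replacement value (the result is dropped). Returns false in the
  // situation where the real code would assert.
  bool notifyOpReplaced(OpId op, const std::vector<ValueId> &results,
                        const std::vector<std::optional<ValueId>> &newValues) {
    assert(results.size() == newValues.size());
    const bool isUnresolvedMaterialization =
        unresolvedMaterializations.count(op) > 0;

    // Rule 1: a driver-created materialization may be erased (all null
    // replacements) but must not be replaced with another value.
    for (const auto &newValue : newValues)
      if (newValue && isUnresolvedMaterialization)
        return false;

    // Rule 2: only non-null replacements are registered in the mapping.
    for (std::size_t i = 0; i < results.size(); ++i)
      if (newValues[i])
        mapping[results[i]] = *newValues[i];
    return true;
  }
};
```

In this toy model, erasing op 7 (registered as an unresolved materialization) succeeds and leaves the mapping untouched, while replacing it with a value is rejected before anything is recorded.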
---
 .../Transforms/Utils/DialectConversion.cpp    | 31 ++++++++++++++-----
 1 file changed, 23 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 69036e947ebdb0..69bd3b64ff2eab 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1361,16 +1361,21 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
   assert(newValues.size() == op->getNumResults());
   assert(!ignoredOps.contains(op) && "operation was already replaced");
 
+  // Check if replaced op is an unresolved materialization, i.e., an
+  // unrealized_conversion_cast op that was created by the conversion driver.
+  bool isUnresolvedMaterialization = false;
+  if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
+    if (unresolvedMaterializations.contains(castOp))
+      isUnresolvedMaterialization = true;
+
   // Create mappings for each of the new result values.
   for (auto [newValue, result] : llvm::zip(newValues, op->getResults())) {
     if (!newValue) {
       // This result was dropped and no replacement value was provided.
-      if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
-        if (unresolvedMaterializations.contains(castOp)) {
-          // Do not create another materializations if we are erasing a
-          // materialization.
-          continue;
-        }
+      if (isUnresolvedMaterialization) {
+        // Do not create another materializations if we are erasing a
+        // materialization.
+        continue;
       }
 
       // Materialize a replacement value "out of thin air".
@@ -1378,10 +1383,20 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
           MaterializationKind::Source, computeInsertPoint(result),
           result.getLoc(), /*inputs=*/ValueRange(),
           /*outputType=*/result.getType(), currentTypeConverter);
+    } else {
+      // Make sure that the user does not mess with unresolved materializations
+      // that were inserted by the conversion driver. We keep track of these
+      // ops in internal data structures. Erasing them must be allowed because
+      // this can happen when the user is erasing an entire block (including
+      // its body). But replacing them with another value should be forbidden
+      // to avoid problems with the `mapping`.
+      assert(!isUnresolvedMaterialization &&
+             "attempting to replace an unresolved materialization");
     }
 
-    // Remap, and check for any result type changes.
-    mapping.map(result, newValue);
+    // Remap result to replacement value.
+    if (newValue)
+      mapping.map(result, newValue);
   }
 
   appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter);


