[Mlir-commits] [mlir] [mlir][Transforms][NFC] Move `ReconcileUnrealizedCasts` implementation (PR #104671)

Matthias Springer llvmlistbot at llvm.org
Sat Aug 17 03:08:26 PDT 2024


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/104671

Move the implementation of `ReconcileUnrealizedCasts` to `DialectConversion.cpp`, so that it can be called from there in a future commit.

This commit is in preparation for decoupling argument/source/target materializations from the dialect conversion framework. The existing logic around unresolved materializations, which predicts IR changes to decide whether a cast op can be folded/erased, will become obsolete, as `ReconcileUnrealizedCasts` will perform these kinds of foldings on fully materialized IR.

>From 8703175c448143e769957d9e2b5e02c9c1ac6e0f Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sat, 17 Aug 2024 12:01:58 +0200
Subject: [PATCH] [mlir][Transforms][NFC] Move `ReconcileUnrealizedCasts`
 implementation

Move the implementation of `ReconcileUnrealizedCasts` to `DialectConversion.cpp`, so that it can be called from there in a future commit. This commit is in preparation for decoupling argument/source/target materializations from the dialect conversion framework. The existing logic around unresolved materializations, which predicts IR changes to decide whether a cast op can be folded/erased, will become obsolete, as `ReconcileUnrealizedCasts` will perform these kinds of foldings on fully materialized IR.
---
 .../mlir/Transforms/DialectConversion.h       | 23 ++++++
 .../ReconcileUnrealizedCasts.cpp              | 58 +--------------
 .../Transforms/Utils/DialectConversion.cpp    | 74 +++++++++++++++++++
 3 files changed, 100 insertions(+), 55 deletions(-)

diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index a51b00271f0aeb..86f0337dd90dfe 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -1126,6 +1126,29 @@ struct ConversionConfig {
   RewriterBase::Listener *listener = nullptr;
 };
 
+//===----------------------------------------------------------------------===//
+// Reconcile Unrealized Casts
+//===----------------------------------------------------------------------===//
+
+/// Try to reconcile all given UnrealizedConversionCastOps and append the
+/// left-over ops of `castOps` to `remainingCastOps` (if provided), in their
+/// original order; the vector is not cleared first.
+///
+/// Cast ops are processed in a worklist-driven fashion. A cast op without
+/// users is erased. Otherwise, if the chain of input casts eventually reaches
+/// a cast op whose input types match the output types of the matched op, the
+/// matched op is replaced with those inputs and erased. Cast ops feeding an
+/// erased op are revisited and may be erased too, even if not in `castOps`.
+///
+/// Example:
+/// %1 = unrealized_conversion_cast %0 : !A to !B
+/// %2 = unrealized_conversion_cast %1 : !B to !C
+/// %3 = unrealized_conversion_cast %2 : !C to !A
+/// Here, %0 can be used instead of %3 and all cast ops are folded away.
+void reconcileUnrealizedCasts(
+    ArrayRef<UnrealizedConversionCastOp> castOps,
+    SmallVector<UnrealizedConversionCastOp> *remainingCastOps = nullptr);
+
 //===----------------------------------------------------------------------===//
 // Op Conversion Entry Points
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp b/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp
index 12e0029cebfd0d..d01e3dcbe8cc45 100644
--- a/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp
+++ b/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp
@@ -10,6 +10,7 @@
 
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
 
 namespace mlir {
 #define GEN_PASS_DEF_RECONCILEUNREALIZEDCASTS
@@ -39,63 +40,11 @@ struct ReconcileUnrealizedCasts
   ReconcileUnrealizedCasts() = default;
 
   void runOnOperation() override {
     // Gather all unrealized_conversion_cast ops.
-    SetVector<UnrealizedConversionCastOp> worklist;
+    SmallVector<UnrealizedConversionCastOp> ops;
     getOperation()->walk(
-        [&](UnrealizedConversionCastOp castOp) { worklist.insert(castOp); });
+        [&](UnrealizedConversionCastOp castOp) { ops.push_back(castOp); });
-
-    // Helper function that adds all operands to the worklist that are an
-    // unrealized_conversion_cast op result.
-    auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
-      for (Value v : castOp.getInputs())
-        if (auto inputCastOp = v.getDefiningOp<UnrealizedConversionCastOp>())
-          worklist.insert(inputCastOp);
-    };
-
-    // Helper function that return the unrealized_conversion_cast op that
-    // defines all inputs of the given op (in the same order). Return "nullptr"
-    // if there is no such op.
-    auto getInputCast =
-        [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
-      if (castOp.getInputs().empty())
-        return {};
-      auto inputCastOp = castOp.getInputs()
-                             .front()
-                             .getDefiningOp<UnrealizedConversionCastOp>();
-      if (!inputCastOp)
-        return {};
-      if (inputCastOp.getOutputs() != castOp.getInputs())
-        return {};
-      return inputCastOp;
-    };
-
-    // Process ops in the worklist bottom-to-top.
-    while (!worklist.empty()) {
-      UnrealizedConversionCastOp castOp = worklist.pop_back_val();
-      if (castOp->use_empty()) {
-        // DCE: If the op has no users, erase it. Add the operands to the
-        // worklist to find additional DCE opportunities.
-        enqueueOperands(castOp);
-        castOp->erase();
-        continue;
-      }
-
-      // Traverse the chain of input cast ops to see if an op with the same
-      // input types can be found.
-      UnrealizedConversionCastOp nextCast = castOp;
-      while (nextCast) {
-        if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
-          // Found a cast where the input types match the output types of the
-          // matched op. We can directly use those inputs and the matched op can
-          // be removed.
-          enqueueOperands(castOp);
-          castOp.replaceAllUsesWith(nextCast.getInputs());
-          castOp->erase();
-          break;
-        }
-        nextCast = getInputCast(nextCast);
-      }
-    }
+    reconcileUnrealizedCasts(ops);
   }
 };
 
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 8a4c7463a69a95..0da8eabadb4ee1 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2869,6 +2869,80 @@ LogicalResult OperationConverter::legalizeErasedResult(
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Reconcile Unrealized Casts
+//===----------------------------------------------------------------------===//
+
+void mlir::reconcileUnrealizedCasts(
+    ArrayRef<UnrealizedConversionCastOp> castOps,
+    SmallVector<UnrealizedConversionCastOp> *remainingCastOps) {
+  // Pending cast ops, popped from the back; duplicates are ignored.
+  SetVector<UnrealizedConversionCastOp> worklist;
+  for (UnrealizedConversionCastOp op : castOps)
+    worklist.insert(op);
+
+  // Addresses of erased ops (never dereferenced); filled only on request.
+  DenseSet<Operation *> erasedOps;
+
+  // Schedule every unrealized_conversion_cast op that defines one of the
+  // inputs of `op`; they may become dead once `op` is erased.
+  auto pushProducers = [&](UnrealizedConversionCastOp op) {
+    for (Value input : op.getInputs()) {
+      auto producer = input.getDefiningOp<UnrealizedConversionCastOp>();
+      if (producer)
+        worklist.insert(producer);
+    }
+  };
+
+  // Erase `op`, recording its address when the left-over ops are requested.
+  auto discard = [&](UnrealizedConversionCastOp op) {
+    if (remainingCastOps)
+      erasedOps.insert(op.getOperation());
+    op->erase();
+  };
+
+  // Return the cast op whose results are exactly the inputs of `op`, or null.
+  auto getProducingCast =
+      [](UnrealizedConversionCastOp op) -> UnrealizedConversionCastOp {
+    auto inputs = op.getInputs();
+    if (inputs.empty())
+      return {};
+    auto producer = inputs.front().getDefiningOp<UnrealizedConversionCastOp>();
+    if (!producer || producer.getOutputs() != inputs)
+      return {};
+    return producer;
+  };
+
+  while (!worklist.empty()) {
+    UnrealizedConversionCastOp current = worklist.pop_back_val();
+    // A cast without users is dead; erasing it may make its producers dead.
+    if (current->use_empty()) {
+      pushProducers(current);
+      discard(current);
+      continue;
+    }
+
+    // Find a producing cast that consumes exactly the types `current` yields.
+    auto typesMatch = [&](UnrealizedConversionCastOp candidate) {
+      return candidate.getInputs().getTypes() == current.getResultTypes();
+    };
+    UnrealizedConversionCastOp source = current;
+    while (source && !typesMatch(source))
+      source = getProducingCast(source);
+    if (!source)
+      continue;
+    pushProducers(current);
+    current.replaceAllUsesWith(source.getInputs());
+    discard(current);
+  }
+
+  if (!remainingCastOps)
+    return;
+  for (UnrealizedConversionCastOp op : castOps)
+    if (!erasedOps.contains(op.getOperation()))
+      remainingCastOps->push_back(op);
+}
+
 //===----------------------------------------------------------------------===//
 // Type Conversion
 //===----------------------------------------------------------------------===//



More information about the Mlir-commits mailing list