[Mlir-commits] [mlir] a9f6224 - [mlir][Transforms][NFC] Move `ReconcileUnrealizedCasts` implementation (#104671)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Aug 23 08:46:35 PDT 2024


Author: Matthias Springer
Date: 2024-08-23T08:46:31-07:00
New Revision: a9f62244f28a64e7b7338c2299ba169df70fbb03

URL: https://github.com/llvm/llvm-project/commit/a9f62244f28a64e7b7338c2299ba169df70fbb03
DIFF: https://github.com/llvm/llvm-project/commit/a9f62244f28a64e7b7338c2299ba169df70fbb03.diff

LOG: [mlir][Transforms][NFC] Move `ReconcileUnrealizedCasts` implementation (#104671)

Move the implementation of `ReconcileUnrealizedCasts` to
`DialectConversion.cpp`, so that it can be called from there in a future
commit.

This commit is in preparation for decoupling argument/source/target
materializations from the dialect conversion framework. The existing
logic around unresolved materializations that predicts IR changes to
decide if a cast op can be folded/erased will become obsolete, as
`ReconcileUnrealizedCasts` will perform these kinds of foldings on fully
materialized IR.

---------

Co-authored-by: Markus Böck <markus.boeck02 at gmail.com>

Added: 
    

Modified: 
    mlir/include/mlir/Transforms/DialectConversion.h
    mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp
    mlir/lib/Transforms/Utils/DialectConversion.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index a51b00271f0aeb..60113bdef16a23 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -1126,6 +1126,29 @@ struct ConversionConfig {
   RewriterBase::Listener *listener = nullptr;
 };
 
+//===----------------------------------------------------------------------===//
+// Reconcile Unrealized Casts
+//===----------------------------------------------------------------------===//
+
+/// Try to reconcile all given UnrealizedConversionCastOps and store the
+/// left-over ops in `remainingCastOps` (if provided).
+///
+/// This function processes cast ops in a worklist-driven fashion. For each
+/// cast op, if the chain of input casts eventually reaches a cast op where the
+/// input types match the output types of the matched op, replace the matched
+/// op with the inputs.
+///
+/// Example:
+/// %1 = unrealized_conversion_cast %0 : !A to !B
+/// %2 = unrealized_conversion_cast %1 : !B to !C
+/// %3 = unrealized_conversion_cast %2 : !C to !A
+///
+/// In the above example, %0 can be used instead of %3 and all cast ops are
+/// folded away.
+void reconcileUnrealizedCasts(
+    ArrayRef<UnrealizedConversionCastOp> castOps,
+    SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps = nullptr);
+
 //===----------------------------------------------------------------------===//
 // Op Conversion Entry Points
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp b/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp
index 12e0029cebfd0d..2ce6dcbb490149 100644
--- a/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp
+++ b/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp
@@ -10,6 +10,7 @@
 
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
 
 namespace mlir {
 #define GEN_PASS_DEF_RECONCILEUNREALIZEDCASTS
@@ -39,63 +40,10 @@ struct ReconcileUnrealizedCasts
   ReconcileUnrealizedCasts() = default;
 
   void runOnOperation() override {
-    // Gather all unrealized_conversion_cast ops.
-    SetVector<UnrealizedConversionCastOp> worklist;
+    SmallVector<UnrealizedConversionCastOp> ops;
     getOperation()->walk(
-        [&](UnrealizedConversionCastOp castOp) { worklist.insert(castOp); });
-
-    // Helper function that adds all operands to the worklist that are an
-    // unrealized_conversion_cast op result.
-    auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
-      for (Value v : castOp.getInputs())
-        if (auto inputCastOp = v.getDefiningOp<UnrealizedConversionCastOp>())
-          worklist.insert(inputCastOp);
-    };
-
-    // Helper function that return the unrealized_conversion_cast op that
-    // defines all inputs of the given op (in the same order). Return "nullptr"
-    // if there is no such op.
-    auto getInputCast =
-        [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
-      if (castOp.getInputs().empty())
-        return {};
-      auto inputCastOp = castOp.getInputs()
-                             .front()
-                             .getDefiningOp<UnrealizedConversionCastOp>();
-      if (!inputCastOp)
-        return {};
-      if (inputCastOp.getOutputs() != castOp.getInputs())
-        return {};
-      return inputCastOp;
-    };
-
-    // Process ops in the worklist bottom-to-top.
-    while (!worklist.empty()) {
-      UnrealizedConversionCastOp castOp = worklist.pop_back_val();
-      if (castOp->use_empty()) {
-        // DCE: If the op has no users, erase it. Add the operands to the
-        // worklist to find additional DCE opportunities.
-        enqueueOperands(castOp);
-        castOp->erase();
-        continue;
-      }
-
-      // Traverse the chain of input cast ops to see if an op with the same
-      // input types can be found.
-      UnrealizedConversionCastOp nextCast = castOp;
-      while (nextCast) {
-        if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
-          // Found a cast where the input types match the output types of the
-          // matched op. We can directly use those inputs and the matched op can
-          // be removed.
-          enqueueOperands(castOp);
-          castOp.replaceAllUsesWith(nextCast.getInputs());
-          castOp->erase();
-          break;
-        }
-        nextCast = getInputCast(nextCast);
-      }
-    }
+        [&](UnrealizedConversionCastOp castOp) { ops.push_back(castOp); });
+    reconcileUnrealizedCasts(ops);
   }
 };
 

diff  --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index bf2990c257bad2..adf012a261cb7e 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2870,6 +2870,80 @@ LogicalResult OperationConverter::legalizeErasedResult(
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Reconcile Unrealized Casts
+//===----------------------------------------------------------------------===//
+
+void mlir::reconcileUnrealizedCasts(
+    ArrayRef<UnrealizedConversionCastOp> castOps,
+    SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
+  SetVector<UnrealizedConversionCastOp> worklist(castOps.begin(),
+                                                 castOps.end());
+  // This set is maintained only if `remainingCastOps` is provided.
+  DenseSet<Operation *> erasedOps;
+
+  // Helper function that adds all operands to the worklist that are an
+  // unrealized_conversion_cast op result.
+  auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
+    for (Value v : castOp.getInputs())
+      if (auto inputCastOp = v.getDefiningOp<UnrealizedConversionCastOp>())
+        worklist.insert(inputCastOp);
+  };
+
+  // Helper function that returns the unrealized_conversion_cast op that
+  // defines all inputs of the given op (in the same order). Return "nullptr"
+  // if there is no such op.
+  auto getInputCast =
+      [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
+    if (castOp.getInputs().empty())
+      return {};
+    auto inputCastOp =
+        castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
+    if (!inputCastOp)
+      return {};
+    if (inputCastOp.getOutputs() != castOp.getInputs())
+      return {};
+    return inputCastOp;
+  };
+
+  // Process ops in the worklist bottom-to-top.
+  while (!worklist.empty()) {
+    UnrealizedConversionCastOp castOp = worklist.pop_back_val();
+    if (castOp->use_empty()) {
+      // DCE: If the op has no users, erase it. Add the operands to the
+      // worklist to find additional DCE opportunities.
+      enqueueOperands(castOp);
+      if (remainingCastOps)
+        erasedOps.insert(castOp.getOperation());
+      castOp->erase();
+      continue;
+    }
+
+    // Traverse the chain of input cast ops to see if an op with the same
+    // input types can be found.
+    UnrealizedConversionCastOp nextCast = castOp;
+    while (nextCast) {
+      if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
+        // Found a cast where the input types match the output types of the
+        // matched op. We can directly use those inputs and the matched op can
+        // be removed.
+        enqueueOperands(castOp);
+        castOp.replaceAllUsesWith(nextCast.getInputs());
+        if (remainingCastOps)
+          erasedOps.insert(castOp.getOperation());
+        castOp->erase();
+        break;
+      }
+      nextCast = getInputCast(nextCast);
+    }
+  }
+
+  if (remainingCastOps)
+    for (UnrealizedConversionCastOp op : castOps)
+      if (!erasedOps.contains(op.getOperation()))
+        remainingCastOps->push_back(op);
+}
+
 //===----------------------------------------------------------------------===//
 // Type Conversion
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list