[Mlir-commits] [mlir] e90deaf - [MLIR] Reconciliation of chains of unrealized casts

Alex Zinenko llvmlistbot at llvm.org
Wed Aug 3 04:57:31 PDT 2022


Author: Michele Scuttari
Date: 2022-08-03T11:57:20Z
New Revision: e90deaf1217d9ea0316a3ec03e199c658f5757d5

URL: https://github.com/llvm/llvm-project/commit/e90deaf1217d9ea0316a3ec03e199c658f5757d5
DIFF: https://github.com/llvm/llvm-project/commit/e90deaf1217d9ea0316a3ec03e199c658f5757d5.diff

LOG: [MLIR] Reconciliation of chains of unrealized casts

The reconciliation pass has been improved to support chains of casts, so reconciliation is no longer limited to pairs of unrealized casts.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D130711

Added: 
    mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts-failure.mlir
    mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir

Modified: 
    mlir/include/mlir/Conversion/Passes.td
    mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 26ec204849fcb..00ca7af1fabdc 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -619,8 +619,9 @@ def ReconcileUnrealizedCasts : Pass<"reconcile-unrealized-casts"> {
     ```
     %0 = "producer.op"() : () -> !type.A
     %1 = unrealized_conversion_cast %0 : !type.A to !type.B
-    %2 = unrealized_conversion_cast %1 : !type.B to !type.A
-    "consumer.op"(%2) : (!type.A) -> ()
+    %2 = unrealized_conversion_cast %1 : !type.B to !type.C
+    %3 = unrealized_conversion_cast %2 : !type.C to !type.A
+    "consumer.op"(%3) : (!type.A) -> ()
     ```
 
     Such situations appear when the consumer operation is converted by one pass

diff  --git a/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp b/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp
index 51968173e45d2..773e9b267c4f0 100644
--- a/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp
+++ b/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp
@@ -17,37 +17,86 @@ using namespace mlir;
 
 namespace {
 
-/// Removes `unrealized_conversion_cast`s whose results are only used by other
-/// `unrealized_conversion_cast`s converting back to the original type. This
-/// pattern is complementary to the folder and can be used to process operations
-/// starting from the first, i.e. the usual traversal order in dialect
-/// conversion. The folder, on the other hand, can only apply to the last
-/// operation in a chain of conversions because it is not expected to walk
-/// use-def chains. One would need to declare cast ops as dynamically illegal
-/// with a complex condition in order to eliminate them using the folder alone
-/// in the dialect conversion infra.
+/// Folds the DAGs of `unrealized_conversion_cast`s whose exit types are the
+/// same as the input types of the root cast.
+/// For example, the DAGs `A -> B -> C -> B -> A` and `A -> B -> C -> A`
+/// represent a no-op within the IR, and thus the initial input values can be
+/// forwarded to the users of the exit casts.
+/// The same does not hold for 'open' chains of casts, such as
+/// `A -> B -> C`. In this last case there is no cycle among the types and thus
+/// the conversion is incomplete. The same holds for 'closed' chains like
+/// `A -> B -> A`, but with the result of type `B` being used by some non-cast
+/// operation. If such a chain is live, the pattern reports a match failure.
+/// Bifurcations (that is, when a chain branches off from the middle of another
+/// one) are also taken into account, and all the above considerations remain
+/// valid for each branch.
+/// Special corner cases such as dead casts or single casts with the same input
+/// and output types are also covered.
 struct UnrealizedConversionCastPassthrough
     : public OpRewritePattern<UnrealizedConversionCastOp> {
   using OpRewritePattern<UnrealizedConversionCastOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(UnrealizedConversionCastOp op,
                                 PatternRewriter &rewriter) const override {
-    // Match the casts that are _only_ used by other casts, with the overall
-    // cast being a trivial noop: A->B->A.
-    auto users = op->getUsers();
-    if (!llvm::all_of(users, [&](Operation *user) {
-          if (auto other = dyn_cast<UnrealizedConversionCastOp>(user))
-            return other.getResultTypes() == op.getInputs().getTypes() &&
-                   other.getInputs() == op.getOutputs();
-          return false;
-        })) {
-      return rewriter.notifyMatchFailure(op, "live unrealized conversion cast");
+    // Casts that either have no users or have at least one user that is not
+    // an unrealized cast. They are replaced with the root input values.
+    DenseSet<UnrealizedConversionCastOp> exitNodes;
+
+    // Casts that have users, all of which are unrealized casts; they are erased.
+    DenseSet<UnrealizedConversionCastOp> intermediateNodes;
+
+    // Stack for the depth-first traversal of the use-def DAG rooted at `op`.
+    SmallVector<UnrealizedConversionCastOp, 2> visitStack;
+    visitStack.push_back(op);
+
+    while (!visitStack.empty()) {
+      UnrealizedConversionCastOp current = visitStack.pop_back_val();
+      auto users = current->getUsers();
+      bool isLive = false;
+
+      for (Operation *user : users) {
+        if (auto other = dyn_cast<UnrealizedConversionCastOp>(user)) {
+          if (other.getInputs() != current.getOutputs())
+            return rewriter.notifyMatchFailure(
+                op, "mismatching values propagation");
+        } else {
+          isLive = true;
+        }
+
+        // Continue traversing the DAG. NOTE(review): no visited set is kept, so a user listed several times by getUsers() is traversed once per listing.
+        if (auto other = dyn_cast<UnrealizedConversionCastOp>(user))
+          visitStack.push_back(other);
+      }
+
+      // If the cast is live, check whether its results have the same types as
+      // the root inputs. If so (e.g. `{A -> B, B -> A}`, but also `{A -> A}`),
+      // the chain is just a no-op and the root inputs can be forwarded. If not
+      // (e.g. `{A -> B, B -> C}`, `{A -> B}`), the cast chain is incomplete and
+      // the pattern fails to match.
+
+      bool isCycle = current.getResultTypes() == op.getInputs().getTypes();
+
+      if (isLive && !isCycle)
+        return rewriter.notifyMatchFailure(op,
+                                           "live unrealized conversion cast");
+
+      bool isExitNode = users.empty() || isLive;
+
+      if (isExitNode) {
+        exitNodes.insert(current);
+      } else {
+        intermediateNodes.insert(current);
+      }
     }
 
-    for (Operation *user : users)
-      rewriter.replaceOp(user, op.getInputs());
+    // Replace the exit nodes with the root input values. NOTE(review): only live exit nodes were checked against the root input types; dead ones were not.
+    for (UnrealizedConversionCastOp exitNode : exitNodes)
+      rewriter.replaceOp(exitNode, op.getInputs());
+
+    // Erase the intermediate casts; all their users are casts of this DAG.
+    for (UnrealizedConversionCastOp castOp : intermediateNodes)
+      rewriter.eraseOp(castOp);
 
-    rewriter.eraseOp(op);
     return success();
   }
 };

diff  --git a/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts-failure.mlir b/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts-failure.mlir
new file mode 100644
index 0000000000000..f5ceb295e7c4e
--- /dev/null
+++ b/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts-failure.mlir
@@ -0,0 +1,45 @@
+// RUN: not mlir-opt %s -split-input-file -mlir-print-ir-after-failure -reconcile-unrealized-casts 2>&1 | FileCheck %s
+
+// CHECK-LABEL: @liveSingleCast
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i32
+// CHECK: %[[liveCast:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : i64 to i32
+// CHECK: return %[[liveCast]] : i32
+
+func.func @liveSingleCast(%arg0: i64) -> i32 {  // Expected to fail: the lone i64 -> i32 cast feeds the return and is not a no-op.
+    %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
+    return %0 : i32
+}
+
+// -----
+
+// CHECK-LABEL: @liveChain
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i32
+// CHECK: %[[cast0:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : i64 to i1
+// CHECK: %[[cast1:.*]] = builtin.unrealized_conversion_cast %[[cast0]] : i1 to i32
+// CHECK: return %[[cast1]] : i32
+
+func.func @liveChain(%arg0: i64) -> i32 {  // Expected to fail: the open chain i64 -> i1 -> i32 is returned and never gets back to i64.
+    %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i1
+    %1 = builtin.unrealized_conversion_cast %0 : i1 to i32
+    return %1 : i32
+}
+
+// -----
+
+// CHECK-LABEL: @liveBifurcation
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64
+// CHECK: %[[cast0:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : i64 to i32
+// CHECK: %[[cast1:.*]] = builtin.unrealized_conversion_cast %[[cast0]] : i32 to i64
+// CHECK: %[[cast2:.*]] = builtin.unrealized_conversion_cast %[[cast0]] : i32 to i1
+// CHECK: %[[extsi:.*]] = arith.extsi %[[cast2]] : i1 to i64
+// CHECK: %[[result:.*]] = arith.addi %[[cast1]], %[[extsi]] : i64
+// CHECK: return %[[result]] : i64
+
+func.func @liveBifurcation(%arg0: i64) -> i64 {  // Expected to fail: the i32 -> i1 branch is used by arith.extsi, so the whole DAG is kept.
+    %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
+    %1 = builtin.unrealized_conversion_cast %0 : i32 to i64
+    %2 = builtin.unrealized_conversion_cast %0 : i32 to i1
+    %3 = arith.extsi %2 : i1 to i64
+    %4 = arith.addi %1, %3 : i64
+    return %4 : i64
+}

diff  --git a/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir b/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir
new file mode 100644
index 0000000000000..d71cbba1d13c5
--- /dev/null
+++ b/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir
@@ -0,0 +1,105 @@
+// RUN: mlir-opt %s -split-input-file -reconcile-unrealized-casts | FileCheck %s
+
+// CHECK-LABEL: @unusedCast
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64
+// CHECK: return %[[arg0]] : i64
+
+func.func @unusedCast(%arg0: i64) -> i64 {  // A dead i64 -> i32 cast is removed.
+    %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
+    return %arg0 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @sameTypes
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64
+// CHECK: return %[[arg0]] : i64
+
+func.func @sameTypes(%arg0: i64) -> i64 {  // An i64 -> i64 cast is a no-op; its result is replaced with %arg0.
+    %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i64
+    return %0 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @pair
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64
+// CHECK: return %[[arg0]] : i64
+
+func.func @pair(%arg0: i64) -> i64 {  // The closed pair i64 -> i32 -> i64 folds back to %arg0.
+    %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
+    %1 = builtin.unrealized_conversion_cast %0 : i32 to i64
+    return %1 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @symmetricChain
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64
+// CHECK: return %[[arg0]] : i64
+
+func.func @symmetricChain(%arg0: i64) -> i64 {  // The chain i64 -> i32 -> i1 -> i32 -> i64 folds back to %arg0.
+    %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
+    %1 = builtin.unrealized_conversion_cast %0 : i32 to i1
+    %2 = builtin.unrealized_conversion_cast %1 : i1 to i32
+    %3 = builtin.unrealized_conversion_cast %2 : i32 to i64
+    return %3 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @asymmetricChain
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64
+// CHECK: return %[[arg0]] : i64
+
+func.func @asymmetricChain(%arg0: i64) -> i64 {  // The chain i64 -> i32 -> i1 -> i64 folds back to %arg0.
+    %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
+    %1 = builtin.unrealized_conversion_cast %0 : i32 to i1
+    %2 = builtin.unrealized_conversion_cast %1 : i1 to i64
+    return %2 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @unusedChain
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64
+// CHECK: return %[[arg0]] : i64
+
+func.func @unusedChain(%arg0: i64) -> i64 {  // The dead open chain i64 -> i32 -> i1 is removed entirely.
+    %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
+    %1 = builtin.unrealized_conversion_cast %0 : i32 to i1
+    return %arg0 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @bifurcation
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64
+// CHECK: %[[result:.*]] = arith.addi %[[arg0]], %[[arg0]] : i64
+// CHECK: return %[[result]] : i64
+
+func.func @bifurcation(%arg0: i64) -> i64 {  // Both branches from %1 (i1 -> i64 and i1 -> i32 -> i64) end in i64 and fold to %arg0.
+    %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
+    %1 = builtin.unrealized_conversion_cast %0 : i32 to i1
+    %2 = builtin.unrealized_conversion_cast %1 : i1 to i64
+    %3 = builtin.unrealized_conversion_cast %1 : i1 to i32
+    %4 = builtin.unrealized_conversion_cast %3 : i32 to i64
+    %5 = arith.addi %2, %4 : i64
+    return %5 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @unusedBifurcation
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64
+// CHECK: %[[result:.*]] = arith.addi %[[arg0]], %[[arg0]] : i64
+// CHECK: return %[[result]] : i64
+
+func.func @unusedBifurcation(%arg0: i64) -> i64 {  // Branch %0 -> %1 -> %2 is dead and branch %0 -> %3 ends in i64; all casts fold away.
+    %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
+    %1 = builtin.unrealized_conversion_cast %0 : i32 to i1
+    %2 = builtin.unrealized_conversion_cast %1 : i1 to i64
+    %3 = builtin.unrealized_conversion_cast %0 : i32 to i64
+    %4 = arith.addi %arg0, %3 : i64
+    return %4 : i64
+}


        


More information about the Mlir-commits mailing list