[Mlir-commits] [mlir] [mlir][Conversion] Generalize and fix crash in `reconcile-unrealized-casts` (PR #95700)

Matthias Springer llvmlistbot at llvm.org
Sun Jun 16 05:18:28 PDT 2024


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/95700

This commit fixes a crash in `-reconcile-unrealized-casts` when cast ops have multiple operands:
```
DialectConversion.cpp:1583: virtual void mlir::ConversionPatternRewriter::replaceOp(mlir::Operation *, mlir::ValueRange): Assertion `op->getNumResults() == newValues.size() && "incorrect # of replacement values"' failed.
```

This commit also generalizes the pass such that more ops are folded. In particular (letters indicate types):
```
         A
       /   \
      B     C
      |
      A
```
Previously, such IR was not folded at all. The `A -> B -> A` type cast cycle is now folded away. (The `A -> C` cast stays in place.)
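
The folding rule for this diagram can be modeled without MLIR. The following is only a minimal sketch, assuming single-operand, single-result casts; `Cast`, `findFoldSource` and `makeBifurcation` are hypothetical names for illustration and are not part of the patch or of the MLIR API. Starting at a cast, it walks up the chain of producing casts and looks for one whose input type equals the result type of the starting cast.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Simplified stand-in for a single-operand, single-result
// unrealized_conversion_cast. Hypothetical type, not the MLIR class.
struct Cast {
  std::string inputType;  // type of the operand
  std::string resultType; // type of the result
  int producer;           // index of the cast defining the operand, or -1
};

// Walk from `start` up the chain of producing casts. Return the index of the
// first cast whose input type equals the result type of `start`, or -1 if the
// chain ends without a match. The walk begins at `start` itself, so a cast
// from a type to the same type matches immediately.
int findFoldSource(const std::vector<Cast> &casts, int start) {
  assert(start >= 0 && start < static_cast<int>(casts.size()));
  for (int cur = start; cur != -1; cur = casts[cur].producer)
    if (casts[cur].inputType == casts[start].resultType)
      return cur;
  return -1;
}

// The bifurcation from the diagram: one A -> B cast feeds both a B -> A cast
// and a B -> C cast.
std::vector<Cast> makeBifurcation() {
  return {
      {"A", "B", -1}, // 0: A -> B, operand is defined by a non-cast value
      {"B", "A", 0},  // 1: B -> A, closes the A -> B -> A cycle
      {"B", "C", 0},  // 2: B -> C, no cast in its chain takes a C as input
  };
}
```

In this model, `findFoldSource(makeBifurcation(), 1)` finds cast 0, so the result of the `B -> A` cast could be replaced with the operand of cast 0. For the `B -> C` cast (index 2) no source is found, which corresponds to the cast that stays in place.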

This commit also turns the pass from a dialect conversion into a simple IR walk. The pattern and its `populate` function are removed. The pattern was a (non-conversion) rewrite pattern but was used in a dialect conversion, which is generally not supported. In particular, the rewrite pattern may traverse IR that was already scheduled for erasure by the dialect conversion.
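
The new walk in the diff below also erases casts without users and re-enqueues their operands to find further dead casts. A rough standalone model of that cascade, assuming each cast has at most one cast producer and tracking uses with a plain counter (`CastNode` and `eraseDeadCasts` are hypothetical names, not taken from the patch):

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for a cast op: which cast produces its operand and
// how many uses its result still has.
struct CastNode {
  int producer; // index of the cast defining the operand, or -1
  int numUses;  // number of remaining uses of the result
  bool erased;  // set once the cast has been removed
};

// Erase every cast whose result is unused and return how many were erased.
// Erasing a cast drops one use of its producer, so the producer is pushed
// back onto the worklist and may turn out to be dead as well.
int eraseDeadCasts(std::vector<CastNode> &nodes) {
  std::vector<int> worklist;
  for (int i = 0; i < static_cast<int>(nodes.size()); ++i)
    worklist.push_back(i);

  int numErased = 0;
  while (!worklist.empty()) {
    int cur = worklist.back();
    worklist.pop_back();
    // Skip casts that are already gone or that still have uses.
    if (nodes[cur].erased || nodes[cur].numUses > 0)
      continue;
    nodes[cur].erased = true;
    ++numErased;
    int producer = nodes[cur].producer;
    if (producer != -1) {
      assert(nodes[producer].numUses > 0);
      --nodes[producer].numUses;
      worklist.push_back(producer);
    }
  }
  return numErased;
}
```

For a chain of three casts where only the last one is unused, all three end up erased. If the first cast has an additional non-cast use, it survives.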

Note: Some test cases changed slightly (NFC) because the new pass implementation no longer attempts to fold ops.

Note for LLVM integration: If your pipeline uses the removed `populate` function, first try simply removing that function call; chances are that you do not need it at all. If it is in fact needed, run the `-reconcile-unrealized-casts` pass right after the pass that used to populate the pattern.

From dfe30eea88f82905fcb2dcb9af0d86f5e234efcf Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sun, 16 Jun 2024 13:59:08 +0200
Subject: [PATCH] [mlir][Conversion] Generalize and fix crash in
 `reconcile-unrealized-casts`

---
 .../ReconcileUnrealizedCasts.h                |   4 -
 .../ReconcileUnrealizedCasts.cpp              | 168 ++++++++----------
 .../FuncToLLVM/calling-convention.mlir        |  23 ++-
 .../reconcile-unrealized-casts-failure.mlir   |  45 -----
 .../reconcile-unrealized-casts.mlir           |  81 +++++++++
 .../lower-to-llvm-e2e-with-target-tag.mlir    |   4 +-
 ...lvm-e2e-with-top-level-named-sequence.mlir |   4 +-
 7 files changed, 170 insertions(+), 159 deletions(-)
 delete mode 100644 mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts-failure.mlir

diff --git a/mlir/include/mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h b/mlir/include/mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h
index 9df117709c82f..533a57a060c26 100644
--- a/mlir/include/mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h
+++ b/mlir/include/mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h
@@ -21,10 +21,6 @@ class RewritePatternSet;
 /// Creates a pass that eliminates noop `unrealized_conversion_cast` operation
 /// sequences.
 std::unique_ptr<Pass> createReconcileUnrealizedCastsPass();
-
-/// Populates `patterns` with rewrite patterns that eliminate noop
-/// `unrealized_conversion_cast` operation sequences.
-void populateReconcileUnrealizedCastsPatterns(RewritePatternSet &patterns);
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_RECONCILEUNREALIZEDCASTS_RECONCILEUNREALIZEDCASTS_H_
diff --git a/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp b/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp
index 86a3d8b7819b3..d57103bd80ebe 100644
--- a/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp
+++ b/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp
@@ -9,9 +9,7 @@
 #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
 
 #include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/DialectConversion.h"
 
 namespace mlir {
 #define GEN_PASS_DEF_RECONCILEUNREALIZEDCASTS
@@ -22,113 +20,87 @@ using namespace mlir;
 
 namespace {
 
-/// Folds the DAGs of `unrealized_conversion_cast`s that have as exit types
-/// the same as the input ones.
-/// For example, the DAGs `A -> B -> C -> B -> A` and `A -> B -> C -> A`
-/// represent a noop within the IR, and thus the initial input values can be
-/// propagated.
-/// The same does not hold for 'open' chains of casts, such as
-/// `A -> B -> C`. In this last case there is no cycle among the types and thus
-/// the conversion is incomplete. The same hold for 'closed' chains like
-/// `A -> B -> A`, but with the result of type `B` being used by some non-cast
-/// operations.
-/// Bifurcations (that is when a chain starts in between of another one) are
-/// also taken into considerations, and all the above considerations remain
-/// valid.
-/// Special corner cases such as dead casts or single casts with same input and
-/// output types are also covered.
-struct UnrealizedConversionCastPassthrough
-    : public OpRewritePattern<UnrealizedConversionCastOp> {
-  using OpRewritePattern<UnrealizedConversionCastOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(UnrealizedConversionCastOp op,
-                                PatternRewriter &rewriter) const override {
-    // The nodes that either are not used by any operation or have at least
-    // one user that is not an unrealized cast.
-    DenseSet<UnrealizedConversionCastOp> exitNodes;
-
-    // The nodes whose users are all unrealized casts
-    DenseSet<UnrealizedConversionCastOp> intermediateNodes;
-
-    // Stack used for the depth-first traversal of the use-def DAG.
-    SmallVector<UnrealizedConversionCastOp, 2> visitStack;
-    visitStack.push_back(op);
-
-    while (!visitStack.empty()) {
-      UnrealizedConversionCastOp current = visitStack.pop_back_val();
-      auto users = current->getUsers();
-      bool isLive = false;
-
-      for (Operation *user : users) {
-        if (auto other = dyn_cast<UnrealizedConversionCastOp>(user)) {
-          if (other.getInputs() != current.getOutputs())
-            return rewriter.notifyMatchFailure(
-                op, "mismatching values propagation");
-        } else {
-          isLive = true;
-        }
-
-        // Continue traversing the DAG of unrealized casts
-        if (auto other = dyn_cast<UnrealizedConversionCastOp>(user))
-          visitStack.push_back(other);
-      }
-
-      // If the cast is live, then we need to check if the results of the last
-      // cast have the same type of the root inputs. It this is the case (e.g.
-      // `{A -> B, B -> A}`, but also `{A -> A}`), then the cycle is just a
-      // no-op and the inputs can be forwarded. If it's not (e.g.
-      // `{A -> B, B -> C}`, `{A -> B}`), then the cast chain is incomplete.
-
-      bool isCycle = current.getResultTypes() == op.getInputs().getTypes();
-
-      if (isLive && !isCycle)
-        return rewriter.notifyMatchFailure(op,
-                                           "live unrealized conversion cast");
-
-      bool isExitNode = users.empty() || isLive;
-
-      if (isExitNode) {
-        exitNodes.insert(current);
-      } else {
-        intermediateNodes.insert(current);
-      }
-    }
-
-    // Replace the sink nodes with the root input values
-    for (UnrealizedConversionCastOp exitNode : exitNodes)
-      rewriter.replaceOp(exitNode, op.getInputs());
-
-    // Erase all the other casts belonging to the DAG
-    for (UnrealizedConversionCastOp castOp : intermediateNodes)
-      rewriter.eraseOp(castOp);
-
-    return success();
-  }
-};
-
 /// Pass to simplify and eliminate unrealized conversion casts.
+///
+/// This pass processes unrealized_conversion_cast ops in a worklist-driven
+/// fashion. For each matched cast op, if the chain of input casts eventually
+/// reaches a cast op where the input types match the output types of the
+/// matched op, replace the matched op with the inputs.
+///
+/// Example:
+/// %1 = unrealized_conversion_cast %0 : !A to !B
+/// %2 = unrealized_conversion_cast %1 : !B to !C
+/// %3 = unrealized_conversion_cast %2 : !C to !A
+///
+/// In the above example, %0 can be used instead of %3 and all cast ops are
+/// folded away.
 struct ReconcileUnrealizedCasts
     : public impl::ReconcileUnrealizedCastsBase<ReconcileUnrealizedCasts> {
   ReconcileUnrealizedCasts() = default;
 
   void runOnOperation() override {
-    RewritePatternSet patterns(&getContext());
-    populateReconcileUnrealizedCastsPatterns(patterns);
-    ConversionTarget target(getContext());
-    target.addIllegalOp<UnrealizedConversionCastOp>();
-    if (failed(applyPartialConversion(getOperation(), target,
-                                      std::move(patterns))))
-      signalPassFailure();
+    // Gather all unrealized_conversion_cast ops.
+    SetVector<UnrealizedConversionCastOp> worklist;
+    getOperation()->walk(
+        [&](UnrealizedConversionCastOp castOp) { worklist.insert(castOp); });
+
+    // Helper function that adds all operands to the worklist that are an
+    // unrealized_conversion_cast op result.
+    auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
+      for (Value v : castOp.getInputs())
+        if (auto castOp = v.getDefiningOp<UnrealizedConversionCastOp>())
+          worklist.insert(castOp);
+    };
+
+    // Helper function that returns the unrealized_conversion_cast op that
+    // defines all inputs of the given op (in the same order). Return "nullptr"
+    // if there is no such op.
+    auto getInputCast =
+        [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
+      if (castOp.getInputs().empty())
+        return {};
+      auto inputCastOp = castOp.getInputs()
+                             .front()
+                             .getDefiningOp<UnrealizedConversionCastOp>();
+      if (!inputCastOp)
+        return {};
+      if (inputCastOp.getOutputs() != castOp.getInputs())
+        return {};
+      return inputCastOp;
+    };
+
+    // Process ops in the worklist bottom-to-top.
+    while (!worklist.empty()) {
+      UnrealizedConversionCastOp castOp = worklist.pop_back_val();
+      if (castOp->getUsers().empty()) {
+        // DCE: If the op has no users, erase it. Add the operands to the
+        // worklist to find additional DCE opportunities.
+        enqueueOperands(castOp);
+        castOp->erase();
+        continue;
+      }
+
+      // Traverse the chain of input cast ops to see if an op whose input types
+      // match the result types of the matched op can be found.
+      UnrealizedConversionCastOp nextCast = castOp;
+      while (nextCast) {
+        if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
+          // Found a cast where the input types match the output types of the
+          // matched op. We can directly use those inputs and the matched op can
+          // be removed.
+          enqueueOperands(castOp);
+          castOp.getResults().replaceAllUsesWith(nextCast.getInputs());
+          castOp->erase();
+          break;
+        }
+        nextCast = getInputCast(nextCast);
+      }
+    }
   }
 };
 
 } // namespace
 
-void mlir::populateReconcileUnrealizedCastsPatterns(
-    RewritePatternSet &patterns) {
-  patterns.add<UnrealizedConversionCastPassthrough>(patterns.getContext());
-}
-
 std::unique_ptr<Pass> mlir::createReconcileUnrealizedCastsPass() {
   return std::make_unique<ReconcileUnrealizedCasts>();
 }
diff --git a/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir b/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir
index 7cdb89e1f72d2..18734d19e335c 100644
--- a/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir
+++ b/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir
@@ -127,7 +127,7 @@ func.func @return_var_memref_caller(%arg0: memref<4x3xf32>) {
   // CHECK: %[[PTR_SIZE:.*]] = llvm.mlir.constant
   // CHECK: %[[DOUBLE_PTR_SIZE:.*]] = llvm.mul %[[TWO]], %[[PTR_SIZE]]
   // CHECK: %[[RANK:.*]] = llvm.extractvalue %[[CALL_RES]][0] : !llvm.struct<(i64, ptr)>
-  // CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[RANK]], %[[TWO]]
+  // CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[TWO]], %[[RANK]]
   // CHECK: %[[DOUBLE_RANK_INC:.*]] = llvm.add %[[DOUBLE_RANK]], %[[ONE]]
   // CHECK: %[[TABLES_SIZE:.*]] = llvm.mul %[[DOUBLE_RANK_INC]], %[[IDX_SIZE]]
   // CHECK: %[[ALLOC_SIZE:.*]] = llvm.add %[[DOUBLE_PTR_SIZE]], %[[TABLES_SIZE]]
@@ -159,14 +159,17 @@ func.func @return_var_memref(%arg0: memref<4x3xf32>) -> memref<*xf32> attributes
 
   // CHECK: %[[PTR_SIZE:.*]] = llvm.mlir.constant
   // CHECK: %[[DOUBLE_PTR_SIZE:.*]] = llvm.mul %[[TWO]], %[[PTR_SIZE]]
-  // CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[RANK]], %[[TWO]]
+  // CHECK: %[[RANK_EXTR:.*]] = llvm.extractvalue %[[DESC_2]][0]
+  // CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[TWO]], %[[RANK_EXTR]]
   // CHECK: %[[DOUBLE_RANK_INC:.*]] = llvm.add %[[DOUBLE_RANK]], %[[ONE]]
   // CHECK: %[[TABLES_SIZE:.*]] = llvm.mul %[[DOUBLE_RANK_INC]], %[[IDX_SIZE]]
   // CHECK: %[[ALLOC_SIZE:.*]] = llvm.add %[[DOUBLE_PTR_SIZE]], %[[TABLES_SIZE]]
   // CHECK: %[[ALLOCATED:.*]] = llvm.call @malloc(%[[ALLOC_SIZE]])
-  // CHECK: "llvm.intr.memcpy"(%[[ALLOCATED]], %[[ALLOCA]], %[[ALLOC_SIZE]]) <{isVolatile = false}>
+  // CHECK: %[[ALLOCA_EXTRACTED:.*]] = llvm.extractvalue %[[DESC_2]][1]
+  // CHECK: "llvm.intr.memcpy"(%[[ALLOCATED]], %[[ALLOCA_EXTRACTED]], %[[ALLOC_SIZE]]) <{isVolatile = false}>
   // CHECK: %[[NEW_DESC:.*]] = llvm.mlir.undef : !llvm.struct<(i64, ptr)>
-  // CHECK: %[[NEW_DESC_1:.*]] = llvm.insertvalue %[[RANK]], %[[NEW_DESC]][0]
+  // CHECK: %[[RANK_EXTRACTED:.*]] = llvm.extractvalue %[[DESC_2]][0]
+  // CHECK: %[[NEW_DESC_1:.*]] = llvm.insertvalue %[[RANK_EXTRACTED]], %[[NEW_DESC]][0]
   // CHECK: %[[NEW_DESC_2:.*]] = llvm.insertvalue %[[ALLOCATED]], %[[NEW_DESC_1]][1]
   // CHECK: llvm.return %[[NEW_DESC_2]]
   return %0 : memref<*xf32>
@@ -218,13 +221,15 @@ func.func @return_two_var_memref(%arg0: memref<4x3xf32>) -> (memref<*xf32>, memr
   // convention requires the caller to free them and the caller cannot know
   // whether they are the same value or not.
   // CHECK: %[[ALLOCATED_1:.*]] = llvm.call @malloc(%{{.*}})
-  // CHECK: "llvm.intr.memcpy"(%[[ALLOCATED_1]], %[[ALLOCA]], %{{.*}}) <{isVolatile = false}>
+  // CHECK: %[[ALLOCA_EXTRACTED:.*]] = llvm.extractvalue %[[DESC_2]][1]
+  // CHECK: "llvm.intr.memcpy"(%[[ALLOCATED_1]], %[[ALLOCA_EXTRACTED]], %{{.*}}) <{isVolatile = false}>
   // CHECK: %[[RES_1:.*]] = llvm.mlir.undef
   // CHECK: %[[RES_11:.*]] = llvm.insertvalue %{{.*}}, %[[RES_1]][0]
   // CHECK: %[[RES_12:.*]] = llvm.insertvalue %[[ALLOCATED_1]], %[[RES_11]][1]
 
   // CHECK: %[[ALLOCATED_2:.*]] = llvm.call @malloc(%{{.*}})
-  // CHECK: "llvm.intr.memcpy"(%[[ALLOCATED_2]], %[[ALLOCA]], %{{.*}}) <{isVolatile = false}>
+  // CHECK: %[[ALLOCA_EXTRACTED:.*]] = llvm.extractvalue %[[DESC_2]][1]
+  // CHECK: "llvm.intr.memcpy"(%[[ALLOCATED_2]], %[[ALLOCA_EXTRACTED]], %{{.*}}) <{isVolatile = false}>
   // CHECK: %[[RES_2:.*]] = llvm.mlir.undef
   // CHECK: %[[RES_21:.*]] = llvm.insertvalue %{{.*}}, %[[RES_2]][0]
   // CHECK: %[[RES_22:.*]] = llvm.insertvalue %[[ALLOCATED_2]], %[[RES_21]][1]
@@ -265,7 +270,8 @@ func.func @bare_ptr_calling_conv(%arg0: memref<4x3xf32>, %arg1 : index, %arg2 :
   // CHECK: llvm.store %{{.*}}, %[[STOREPTR]]
   memref.store %arg3, %arg0[%arg1, %arg2] : memref<4x3xf32>
 
-  // CHECK: llvm.return %[[ARG0]]
+  // CHECK: %[[EXTRACT_MEMREF:.*]] = llvm.extractvalue %[[INSERT_STRIDE1]][0]
+  // CHECK: llvm.return %[[EXTRACT_MEMREF]]
   return %arg0 : memref<4x3xf32>
 }
 
@@ -298,9 +304,10 @@ func.func @bare_ptr_calling_conv_multiresult(%arg0: memref<4x3xf32>, %arg1 : ind
   // CHECK: %[[RETURN0:.*]] = llvm.load %[[LOADPTR]]
   %0 = memref.load %arg0[%arg1, %arg2] : memref<4x3xf32>
 
+  // CHECK: %[[EXTRACT_MEMREF:.*]] = llvm.extractvalue %[[INSERT_STRIDE1]][0]
   // CHECK: %[[RETURN_DESC:.*]] = llvm.mlir.undef : !llvm.struct<(f32, ptr)>
   // CHECK: %[[INSERT_RETURN0:.*]] = llvm.insertvalue %[[RETURN0]], %[[RETURN_DESC]][0]
-  // CHECK: %[[INSERT_RETURN1:.*]] = llvm.insertvalue %[[ARG0]], %[[INSERT_RETURN0]][1]
+  // CHECK: %[[INSERT_RETURN1:.*]] = llvm.insertvalue %[[EXTRACT_MEMREF]], %[[INSERT_RETURN0]][1]
   // CHECK: llvm.return %[[INSERT_RETURN1]]
   return %0, %arg0 : f32, memref<4x3xf32>
 }
diff --git a/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts-failure.mlir b/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts-failure.mlir
deleted file mode 100644
index f5ceb295e7c4e..0000000000000
--- a/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts-failure.mlir
+++ /dev/null
@@ -1,45 +0,0 @@
-// RUN: not mlir-opt %s -split-input-file -mlir-print-ir-after-failure -reconcile-unrealized-casts 2>&1 | FileCheck %s
-
-// CHECK-LABEL: @liveSingleCast
-// CHECK-SAME: (%[[arg0:.*]]: i64) -> i32
-// CHECK: %[[liveCast:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : i64 to i32
-// CHECK: return %[[liveCast]] : i32
-
-func.func @liveSingleCast(%arg0: i64) -> i32 {
-    %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
-    return %0 : i32
-}
-
-// -----
-
-// CHECK-LABEL: @liveChain
-// CHECK-SAME: (%[[arg0:.*]]: i64) -> i32
-// CHECK: %[[cast0:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : i64 to i1
-// CHECK: %[[cast1:.*]] = builtin.unrealized_conversion_cast %[[cast0]] : i1 to i32
-// CHECK: return %[[cast1]] : i32
-
-func.func @liveChain(%arg0: i64) -> i32 {
-    %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i1
-    %1 = builtin.unrealized_conversion_cast %0 : i1 to i32
-    return %1 : i32
-}
-
-// -----
-
-// CHECK-LABEL: @liveBifurcation
-// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64
-// CHECK: %[[cast0:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : i64 to i32
-// CHECK: %[[cast1:.*]] = builtin.unrealized_conversion_cast %[[cast0]] : i32 to i64
-// CHECK: %[[cast2:.*]] = builtin.unrealized_conversion_cast %[[cast0]] : i32 to i1
-// CHECK: %[[extsi:.*]] = arith.extsi %[[cast2]] : i1 to i64
-// CHECK: %[[result:.*]] = arith.addi %[[cast1]], %[[extsi]] : i64
-// CHECK: return %[[result]] : i64
-
-func.func @liveBifurcation(%arg0: i64) -> i64 {
-    %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
-    %1 = builtin.unrealized_conversion_cast %0 : i32 to i64
-    %2 = builtin.unrealized_conversion_cast %0 : i32 to i1
-    %3 = arith.extsi %2 : i1 to i64
-    %4 = arith.addi %1, %3 : i64
-    return %4 : i64
-}
diff --git a/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir b/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir
index d71cbba1d13c5..20d297a619220 100644
--- a/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir
+++ b/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir
@@ -103,3 +103,84 @@ func.func @unusedBifurcation(%arg0: i64) -> i64 {
     %4 = arith.addi %arg0, %3 : i64
     return %4 : i64
 }
+
+// -----
+
+// CHECK-LABEL: @liveSingleCast
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i32
+// CHECK: %[[liveCast:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : i64 to i32
+// CHECK: return %[[liveCast]] : i32
+
+func.func @liveSingleCast(%arg0: i64) -> i32 {
+    %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
+    return %0 : i32
+}
+
+// -----
+
+// CHECK-LABEL: @liveChain
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i32
+// CHECK: %[[cast0:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : i64 to i1
+// CHECK: %[[cast1:.*]] = builtin.unrealized_conversion_cast %[[cast0]] : i1 to i32
+// CHECK: return %[[cast1]] : i32
+
+func.func @liveChain(%arg0: i64) -> i32 {
+    %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i1
+    %1 = builtin.unrealized_conversion_cast %0 : i1 to i32
+    return %1 : i32
+}
+
+// -----
+
+// CHECK-LABEL: @liveBifurcation
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64
+// CHECK: %[[cast0:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : i64 to i32
+// CHECK: %[[cast2:.*]] = builtin.unrealized_conversion_cast %[[cast0]] : i32 to i1
+// CHECK: %[[extsi:.*]] = arith.extsi %[[cast2]] : i1 to i64
+// CHECK: %[[result:.*]] = arith.addi %[[arg0]], %[[extsi]] : i64
+// CHECK: return %[[result]] : i64
+
+func.func @liveBifurcation(%arg0: i64) -> i64 {
+    %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
+    %1 = builtin.unrealized_conversion_cast %0 : i32 to i64
+    %2 = builtin.unrealized_conversion_cast %0 : i32 to i1
+    %3 = arith.extsi %2 : i1 to i64
+    %4 = arith.addi %1, %3 : i64
+    return %4 : i64
+}
+
+// -----
+
+// CHECK-LABEL: func @deadNToOneCast(
+// CHECK-NEXT:    return
+func.func @deadNToOneCast(%arg0: index, %arg1: index) {
+    %0 = builtin.unrealized_conversion_cast %arg0, %arg1 : index, index to i64
+    return
+}
+
+// -----
+
+// CHECK-LABEL: func @swappingOperands(
+// CHECK-SAME:      %[[arg0:.*]]: index, %[[arg1:.*]]: index
+// CHECK:         %[[cast1:.*]]:2 = builtin.unrealized_conversion_cast %[[arg0]], %[[arg1]]
+// CHECK:         %[[cast2:.*]]:2 = builtin.unrealized_conversion_cast %[[cast1]]#1, %[[cast1]]#0
+// CHECK:         %[[cast3:.*]]:2 = builtin.unrealized_conversion_cast %[[cast2]]#0, %[[cast2]]#1
+// CHECK:         return %[[cast3]]#0, %[[cast3]]#1
+func.func @swappingOperands(%arg0: index, %arg1: index) -> (index, index) {
+    %0:2 = builtin.unrealized_conversion_cast %arg0, %arg1 : index, index to i64, i64
+    %1:2 = builtin.unrealized_conversion_cast %0#1, %0#0 : i64, i64 to i32, i32
+    %2:2 = builtin.unrealized_conversion_cast %1#0, %1#1 : i32, i32 to index, index
+    return %2#0, %2#1 : index, index
+}
+
+// -----
+
+// CHECK-LABEL: func @matchingOperands(
+// CHECK-SAME:      %[[arg0:.*]]: index, %[[arg1:.*]]: index
+// CHECK:         return %[[arg0]], %[[arg1]]
+func.func @matchingOperands(%arg0: index, %arg1: index) -> (index, index) {
+    %0:2 = builtin.unrealized_conversion_cast %arg0, %arg1 : index, index to i64, i64
+    %1:3 = builtin.unrealized_conversion_cast %0#0, %0#1 : i64, i64 to i32, i32, i32
+    %2:2 = builtin.unrealized_conversion_cast %1#0, %1#1, %1#2 : i32, i32, i32 to index, index
+    return %2#0, %2#1 : index, index
+}
diff --git a/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-target-tag.mlir b/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-target-tag.mlir
index cc9759f62431f..f6d3387d99b3c 100644
--- a/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-target-tag.mlir
+++ b/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-target-tag.mlir
@@ -1,11 +1,11 @@
 // Note: We run CSE here to make the pattern matching more direct.
 
-// RUN: mlir-opt %s -test-lower-to-llvm -cse | FileCheck %s
+// RUN: mlir-opt %s -test-lower-to-llvm -cse -canonicalize | FileCheck %s
 
 // RUN: mlir-opt %s \
 // RUN:   -transform-preload-library="transform-library-paths=%p/lower-to-llvm-transform-symbol-def.mlir" \
 // RUN:   -transform-interpreter="debug-payload-root-tag=payload" \
-// RUN:   -test-transform-dialect-erase-schedule -cse \
+// RUN:   -test-transform-dialect-erase-schedule -cse -canonicalize \
 // RUN: | FileCheck %s
 
 module attributes {transform.target_tag="payload"} {
diff --git a/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-top-level-named-sequence.mlir b/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-top-level-named-sequence.mlir
index ac4608e2f83a4..a74553cc2268e 100644
--- a/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-top-level-named-sequence.mlir
+++ b/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-top-level-named-sequence.mlir
@@ -1,12 +1,12 @@
 // Note: We run CSE here to make the pattern matching more direct.
 
-// RUN: mlir-opt %s -test-lower-to-llvm -cse | FileCheck %s
+// RUN: mlir-opt %s -test-lower-to-llvm -cse -canonicalize | FileCheck %s
 
 // RUN: mlir-opt %s \
 // RUN:   -transform-preload-library=transform-library-paths=%p/../Transform/include/Library/lower-to-llvm.mlir \
 // RUN:   -transform-interpreter="entry-point=entry_point" \
 // RUN:   -test-transform-dialect-erase-schedule \
-// RUN:   -cse \
+// RUN:   -cse -canonicalize \
 // RUN: | FileCheck %s
 
 // Check that we properly lower to llvm memref operations that require to be


