[Mlir-commits] [mlir] 469c02d - [mlir] Add support for merging identical blocks during canonicalization

River Riddle llvmlistbot at llvm.org
Mon May 4 20:00:23 PDT 2020


Author: River Riddle
Date: 2020-05-04T19:56:46-07:00
New Revision: 469c02d0581a4bd7539c7dd62063c29072b55852

URL: https://github.com/llvm/llvm-project/commit/469c02d0581a4bd7539c7dd62063c29072b55852
DIFF: https://github.com/llvm/llvm-project/commit/469c02d0581a4bd7539c7dd62063c29072b55852.diff

LOG: [mlir] Add support for merging identical blocks during canonicalization

This revision adds support for merging identical blocks, i.e. blocks that contain the same operations and branch to the same successors. Operands that mismatch between the blocks are replaced with new block arguments added to the merged block.

Differential Revision: https://reviews.llvm.org/D79134

Added: 
    mlir/test/Transforms/canonicalize-block-merge.mlir

Modified: 
    llvm/include/llvm/ADT/STLExtras.h
    mlir/include/mlir/IR/BlockSupport.h
    mlir/include/mlir/IR/Operation.h
    mlir/include/mlir/IR/OperationSupport.h
    mlir/include/mlir/IR/Value.h
    mlir/lib/IR/OperationSupport.cpp
    mlir/lib/IR/Value.cpp
    mlir/lib/Transforms/Utils/RegionUtils.cpp
    mlir/test/Dialect/SPIRV/canonicalize.mlir
    mlir/test/Transforms/canonicalize-dce.mlir
    mlir/test/Transforms/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h
index 30bcdf5268a8..71ad4fc1fe2b 100644
--- a/llvm/include/llvm/ADT/STLExtras.h
+++ b/llvm/include/llvm/ADT/STLExtras.h
@@ -1174,6 +1174,9 @@ class indexed_accessor_range_base {
     return RangeT(iterator_range<iterator>(*this));
   }
 
+  /// Returns the base of this range, i.e. the value that element indices are
+  /// offset from (see `offset_base`). This allows ranges to be compared by
+  /// identity rather than element-wise.
+  const BaseT &getBase() const { return base; }
+
 private:
   /// Offset the given base by the given amount.
   static BaseT offset_base(const BaseT &base, size_t n) {

diff  --git a/mlir/include/mlir/IR/BlockSupport.h b/mlir/include/mlir/IR/BlockSupport.h
index 10b8c48c6db6..f3dd6140420e 100644
--- a/mlir/include/mlir/IR/BlockSupport.h
+++ b/mlir/include/mlir/IR/BlockSupport.h
@@ -119,6 +119,29 @@ class op_iterator
 
 namespace llvm {
 
+/// Provide support for hashing successor ranges, so that they can be used as
+/// DenseMap keys (e.g. to group blocks by the successors they branch to).
+///
+/// The empty and tombstone keys are zero-length ranges whose base is the
+/// corresponding `BlockOperand *` sentinel pointer. Because they have no
+/// elements, an element-wise comparison would consider them equal to any real
+/// zero-length range, so sentinels are recognized by their base instead.
+template <>
+struct DenseMapInfo<mlir::SuccessorRange> {
+  static mlir::SuccessorRange getEmptyKey() {
+    auto *pointer = llvm::DenseMapInfo<mlir::BlockOperand *>::getEmptyKey();
+    return mlir::SuccessorRange(pointer, 0);
+  }
+  static mlir::SuccessorRange getTombstoneKey() {
+    auto *pointer = llvm::DenseMapInfo<mlir::BlockOperand *>::getTombstoneKey();
+    return mlir::SuccessorRange(pointer, 0);
+  }
+  /// Hash the elements of the range (the successors), so that ranges backed
+  /// by different operand storage but naming the same successors hash alike.
+  static unsigned getHashValue(mlir::SuccessorRange value) {
+    return llvm::hash_combine_range(value.begin(), value.end());
+  }
+  static bool isEqual(mlir::SuccessorRange lhs, mlir::SuccessorRange rhs) {
+    // Sentinel keys only compare equal to the same sentinel.
+    if (rhs.getBase() == getEmptyKey().getBase())
+      return lhs.getBase() == getEmptyKey().getBase();
+    if (rhs.getBase() == getTombstoneKey().getBase())
+      return lhs.getBase() == getTombstoneKey().getBase();
+    // Otherwise compare the successors element-wise.
+    return lhs == rhs;
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ilist_traits for Operation
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index 5c9408199abf..fcde73efd566 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -554,6 +554,14 @@ class Operation final
                         [](OpResult result) { return result.use_empty(); });
   }
 
+  /// Returns true if the results of this operation are used outside of the
+  /// given block, i.e. if any result has a user whose parent block is not
+  /// `block`. An operation without results always returns false.
+  bool isUsedOutsideOfBlock(Block *block) {
+    return llvm::any_of(getOpResults(), [block](OpResult result) {
+      return result.isUsedOutsideOfBlock(block);
+    });
+  }
+
   //===--------------------------------------------------------------------===//
   // Users
   //===--------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index edfe89ad97f2..8c0a3f12d426 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -20,6 +20,7 @@
 #include "mlir/IR/Types.h"
 #include "mlir/IR/Value.h"
 #include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/BitmaskEnum.h"
 #include "llvm/ADT/PointerUnion.h"
 #include "llvm/Support/PointerLikeTypeTraits.h"
 #include "llvm/Support/TrailingObjects.h"
@@ -617,6 +618,17 @@ class ValueTypeRange final
       ValueTypeIterator<typename ValueRangeT::iterator>>::iterator_range;
   template <typename Container>
   ValueTypeRange(Container &&c) : ValueTypeRange(c.begin(), c.end()) {}
+
+  /// Compare this range with another. The ranges are equal if they have the
+  /// same number of elements and the elements compare equal pairwise;
+  /// `OtherT` may be any range whose elements are comparable with this
+  /// range's types.
+  template <typename OtherT>
+  bool operator==(const OtherT &other) const {
+    return llvm::size(*this) == llvm::size(other) &&
+           std::equal(this->begin(), this->end(), other.begin());
+  }
+  /// Inverse of `operator==`.
+  template <typename OtherT>
+  bool operator!=(const OtherT &other) const {
+    return !(*this == other);
+  }
 };
 
 template <typename RangeT>
@@ -829,12 +841,29 @@ class ValueRange final
 /// This class provides utilities for computing if two operations are
 /// equivalent.
 struct OperationEquivalence {
+  /// Flags controlling which parts of an operation are considered when
+  /// hashing or comparing it. The values may be combined as a bitmask.
+  enum Flags {
+    None = 0,
+
+    /// This flag signals that operands should not be considered when checking
+    /// for equivalence. This allows for users to implement their own
+    /// equivalence schemes for operand values. The number of operands is still
+    /// checked, just not the operands themselves.
+    IgnoreOperands = 1,
+
+    LLVM_MARK_AS_BITMASK_ENUM(/* LargestValue = */ IgnoreOperands)
+  };
+
   /// Compute a hash for the given operation.
-  static llvm::hash_code computeHash(Operation *op);
+  static llvm::hash_code computeHash(Operation *op, Flags flags = Flags::None);
 
   /// Compare two operations and return if they are equivalent.
-  static bool isEquivalentTo(Operation *lhs, Operation *rhs);
+  static bool isEquivalentTo(Operation *lhs, Operation *rhs,
+                             Flags flags = Flags::None);
 };
+
+/// Enable Bitmask enums for OperationEquivalence::Flags. This provides the
+/// bitwise operators (e.g. `&`) used to test individual flags.
+LLVM_ENABLE_BITMASK_ENUMS_IN_NAMESPACE();
+
 } // end namespace mlir
 
 namespace llvm {

diff  --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h
index 95def7686792..78517309468d 100644
--- a/mlir/include/mlir/IR/Value.h
+++ b/mlir/include/mlir/IR/Value.h
@@ -123,6 +123,9 @@ class Value {
   /// Return the Region in which this Value is defined.
   Region *getParentRegion();
 
+  /// Return the Block in which this Value is defined: the owning block of a
+  /// block argument, or the block containing the defining operation of an
+  /// operation result.
+  Block *getParentBlock();
+
   //===--------------------------------------------------------------------===//
   // UseLists
   //===--------------------------------------------------------------------===//
@@ -150,6 +153,9 @@ class Value {
   void replaceUsesWithIf(Value newValue,
                          function_ref<bool(OpOperand &)> shouldReplace);
 
+  /// Returns true if the value is used outside of the given block, i.e. if
+  /// any of its users has a parent block other than `block`.
+  bool isUsedOutsideOfBlock(Block *block);
+
   //===--------------------------------------------------------------------===//
   // Uses
 

diff  --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index a08762326143..91842cf95e56 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -412,7 +412,7 @@ Value ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) {
 // Operation Equivalency
 //===----------------------------------------------------------------------===//
 
-llvm::hash_code OperationEquivalence::computeHash(Operation *op) {
+llvm::hash_code OperationEquivalence::computeHash(Operation *op, Flags flags) {
   // Hash operations based upon their:
   //   - Operation Name
   //   - Attributes
@@ -438,12 +438,17 @@ llvm::hash_code OperationEquivalence::computeHash(Operation *op) {
   }
 
   //   - Operands
-  // TODO: Allow commutative operations to have different ordering.
-  return llvm::hash_combine(
-      hash, llvm::hash_combine_range(op->operand_begin(), op->operand_end()));
+  bool ignoreOperands = flags & Flags::IgnoreOperands;
+  if (!ignoreOperands) {
+    // TODO: Allow commutative operations to have different ordering.
+    hash = llvm::hash_combine(
+        hash, llvm::hash_combine_range(op->operand_begin(), op->operand_end()));
+  }
+  return hash;
 }
 
-bool OperationEquivalence::isEquivalentTo(Operation *lhs, Operation *rhs) {
+bool OperationEquivalence::isEquivalentTo(Operation *lhs, Operation *rhs,
+                                          Flags flags) {
   if (lhs == rhs)
     return true;
 
@@ -478,6 +483,9 @@ bool OperationEquivalence::isEquivalentTo(Operation *lhs, Operation *rhs) {
     break;
   }
   // Compare operands.
+  bool ignoreOperands = flags & Flags::IgnoreOperands;
+  if (ignoreOperands)
+    return true;
   // TODO: Allow commutative operations to have different ordering.
   return std::equal(lhs->operand_begin(), lhs->operand_end(),
                     rhs->operand_begin());

diff  --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp
index fdc5ad6be887..6467a7f2295b 100644
--- a/mlir/lib/IR/Value.cpp
+++ b/mlir/lib/IR/Value.cpp
@@ -87,6 +87,13 @@ Region *Value::getParentRegion() {
   return cast<BlockArgument>().getOwner()->getParent();
 }
 
+/// Return the Block in which this Value is defined.
+Block *Value::getParentBlock() {
+  // Operation results live in the block of their defining operation.
+  if (Operation *op = getDefiningOp())
+    return op->getBlock();
+  // Otherwise this is a block argument, owned by its block.
+  return cast<BlockArgument>().getOwner();
+}
+
 //===----------------------------------------------------------------------===//
 // Value::UseLists
 //===----------------------------------------------------------------------===//
@@ -134,6 +141,13 @@ void Value::replaceUsesWithIf(Value newValue,
       use.set(newValue);
 }
 
+/// Returns true if the value is used outside of the given block, i.e. if any
+/// of its users has a parent block other than `block`.
+bool Value::isUsedOutsideOfBlock(Block *block) {
+  // Only the user's immediate parent block is inspected: a user nested in a
+  // region of an operation within `block` is treated as an outside use.
+  return llvm::any_of(getUsers(), [block](Operation *user) {
+    return user->getBlock() != block;
+  });
+}
+
 //===--------------------------------------------------------------------===//
 // Uses
 

diff  --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index 7a00032650b2..76b393183708 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -367,6 +367,324 @@ static LogicalResult runRegionDCE(MutableArrayRef<Region> regions) {
   return deleteDeadness(regions, liveMap);
 }
 
+//===----------------------------------------------------------------------===//
+// Block Merging
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// BlockEquivalenceData
+
+namespace {
+/// This class contains the information for comparing the equivalencies of two
+/// blocks. Blocks are considered equivalent if they contain the same operations
+/// in the same order. The only allowed divergence is for operands that come
+/// from sources outside of the parent block, i.e. the uses of values produced
+/// within the block must be equivalent.
+///   e.g.,
+/// Equivalent (%foo and %bar are defined outside of the blocks and may differ;
+/// they become a new block argument when the blocks are merged):
+///  ^bb1(%arg0: i32)
+///    return %arg0, %foo : i32, i32
+///  ^bb2(%arg1: i32)
+///    return %arg1, %bar : i32, i32
+/// Not Equivalent (the block argument is used at different positions):
+///  ^bb1(%arg0: i32)
+///    return %foo, %arg0 : i32, i32
+///  ^bb2(%arg1: i32)
+///    return %arg1, %bar : i32, i32
+struct BlockEquivalenceData {
+  /// Compute the hash and value ordering for `block`.
+  BlockEquivalenceData(Block *block);
+
+  /// Return the order index for the given value that is within the block of
+  /// this data. Block arguments are numbered first (by argument number),
+  /// followed by operation results in program order.
+  unsigned getOrderOf(Value value) const;
+
+  /// The block this data refers to.
+  Block *block;
+  /// A hash value for this block, combining the hashes of its operations
+  /// computed while ignoring their operands.
+  llvm::hash_code hash;
+  /// A map of result producing operations to their relative orders within this
+  /// block. The order of an operation is the number of values defined within
+  /// the block (block arguments included) before this operation's first
+  /// result.
+  DenseMap<Operation *, unsigned> opOrderIndex;
+};
+} // end anonymous namespace
+
+BlockEquivalenceData::BlockEquivalenceData(Block *block)
+    : block(block), hash(0) {
+  // Block arguments occupy the first order indices, so operation results are
+  // numbered starting after them.
+  unsigned orderIt = block->getNumArguments();
+  for (Operation &op : *block) {
+    // Record the order index of the first result of each result-producing
+    // operation, then skip over all of its results.
+    if (unsigned numResults = op.getNumResults()) {
+      opOrderIndex.try_emplace(&op, orderIt);
+      orderIt += numResults;
+    }
+    // Operands are deliberately excluded from the hash; they are compared
+    // separately when a block is added to a cluster, where operands defined
+    // outside of the block are allowed to differ.
+    auto opHash = OperationEquivalence::computeHash(
+        &op, OperationEquivalence::Flags::IgnoreOperands);
+    hash = llvm::hash_combine(hash, opHash);
+  }
+}
+
+unsigned BlockEquivalenceData::getOrderOf(Value value) const {
+  assert(value.getParentBlock() == block && "expected value of this block");
+
+  // Arguments use the argument number as the order index; they occupy the
+  // first `getNumArguments()` indices.
+  if (BlockArgument arg = value.dyn_cast<BlockArgument>())
+    return arg.getArgNumber();
+
+  // Otherwise, the result order is offset from the parent op's order. The
+  // constructor recorded an order for every operation that has results.
+  OpResult result = value.cast<OpResult>();
+  auto opOrderIt = opOrderIndex.find(result.getDefiningOp());
+  assert(opOrderIt != opOrderIndex.end() && "expected op to have an order");
+  return opOrderIt->second + result.getResultNumber();
+}
+
+//===----------------------------------------------------------------------===//
+// BlockMergeCluster
+
+namespace {
+/// This class represents a cluster of blocks to be merged together. The block
+/// whose data is passed to the constructor is the leader; blocks accepted by
+/// `addToCluster` are later merged into it by `merge`.
+class BlockMergeCluster {
+public:
+  BlockMergeCluster(BlockEquivalenceData &&leaderData)
+      : leaderData(std::move(leaderData)) {}
+
+  /// Attempt to add the given block to this cluster. Returns success if the
+  /// block was added to the cluster (the actual merge happens in `merge`),
+  /// failure otherwise.
+  LogicalResult addToCluster(BlockEquivalenceData &blockData);
+
+  /// Try to merge all of the blocks within this cluster into the leader block.
+  /// Returns success if blocks were merged, failure otherwise.
+  LogicalResult merge();
+
+private:
+  /// The equivalence data for the leader of the cluster.
+  BlockEquivalenceData leaderData;
+
+  /// The set of blocks that can be merged into the leader.
+  llvm::SmallSetVector<Block *, 1> blocksToMerge;
+
+  /// A set of (operation index, operand index) pairs identifying operands that
+  /// need to be replaced by arguments when the cluster gets merged. The set is
+  /// ordered so that `merge` can visit the operations of every block in a
+  /// single forward walk.
+  std::set<std::pair<int, int>> operandsToMerge;
+
+  /// A map of operations with external uses to a replacement within the leader
+  /// block.
+  DenseMap<Operation *, Operation *> opsToReplace;
+};
+} // end anonymous namespace
+
+/// Attempt to add `blockData`'s block to this cluster. The block is accepted
+/// if it has the leader's argument types and each of its operations matches
+/// the leader's operation at the same position. Operands must either be
+/// identical, refer to values at the same position within their respective
+/// blocks, or be values of the same type defined outside of the blocks (these
+/// become new block arguments on merge). Nothing is recorded on failure.
+LogicalResult BlockMergeCluster::addToCluster(BlockEquivalenceData &blockData) {
+  // Cheap rejection: the block hashes are computed ignoring operands, so
+  // blocks with different hashes cannot be equivalent.
+  if (leaderData.hash != blockData.hash)
+    return failure();
+  Block *leaderBlock = leaderData.block, *mergeBlock = blockData.block;
+  if (leaderBlock->getArgumentTypes() != mergeBlock->getArgumentTypes())
+    return failure();
+
+  // The new block is erased when the cluster is merged, and no replacement is
+  // recorded for its arguments, so any use of them outside of the block would
+  // be left dangling. Conservatively reject such blocks.
+  if (llvm::any_of(mergeBlock->getArguments(), [&](BlockArgument arg) {
+        return arg.isUsedOutsideOfBlock(mergeBlock);
+      }))
+    return failure();
+
+  // A set of (operation index, operand index) pairs for operands that
+  // mismatch between the leader and the new block.
+  SmallVector<std::pair<int, int>, 8> mismatchedOperands;
+  // Operations of the new block whose results are used outside of it, paired
+  // with the leader operation that will replace them.
+  SmallVector<std::pair<Operation *, Operation *>, 2> newOpsToReplace;
+  auto lhsIt = leaderBlock->begin(), lhsE = leaderBlock->end();
+  auto rhsIt = mergeBlock->begin(), rhsE = mergeBlock->end();
+  for (int opI = 0; lhsIt != lhsE && rhsIt != rhsE; ++lhsIt, ++rhsIt, ++opI) {
+    // Check that the operations are equivalent, ignoring their operands.
+    if (!OperationEquivalence::isEquivalentTo(
+            &*lhsIt, &*rhsIt, OperationEquivalence::Flags::IgnoreOperands))
+      return failure();
+
+    // Compare the operands of the two operations. If the operand is within
+    // the block, it must refer to the same operation.
+    auto lhsOperands = lhsIt->getOperands(), rhsOperands = rhsIt->getOperands();
+    for (int operand : llvm::seq<int>(0, lhsIt->getNumOperands())) {
+      Value lhsOperand = lhsOperands[operand];
+      Value rhsOperand = rhsOperands[operand];
+      if (lhsOperand == rhsOperand)
+        continue;
+
+      // With IgnoreOperands the equivalence check above does not compare the
+      // operands themselves, so their types are checked here. A mismatched
+      // operand is replaced by a new leader block argument typed after the
+      // leader's operand, and the new block's predecessors pass their own
+      // value for it; differing types would produce invalid IR.
+      if (lhsOperand.getType() != rhsOperand.getType())
+        return failure();
+
+      // Check that these uses are both external, or both internal.
+      bool lhsIsInBlock = lhsOperand.getParentBlock() == leaderBlock;
+      bool rhsIsInBlock = rhsOperand.getParentBlock() == mergeBlock;
+      if (lhsIsInBlock != rhsIsInBlock)
+        return failure();
+      // Let the operands differ if they are defined in a different block. These
+      // will become new arguments if the blocks get merged.
+      if (!lhsIsInBlock) {
+        mismatchedOperands.emplace_back(opI, operand);
+        continue;
+      }
+
+      // Otherwise, these operands must have the same logical order within the
+      // parent block.
+      if (leaderData.getOrderOf(lhsOperand) != blockData.getOrderOf(rhsOperand))
+        return failure();
+    }
+
+    // If the rhs has external uses, it will need to be replaced.
+    if (rhsIt->isUsedOutsideOfBlock(mergeBlock))
+      newOpsToReplace.emplace_back(&*rhsIt, &*lhsIt);
+  }
+  // Make sure that the block sizes are equivalent.
+  if (lhsIt != lhsE || rhsIt != rhsE)
+    return failure();
+
+  // If we get here, the blocks are equivalent and can be merged.
+  operandsToMerge.insert(mismatchedOperands.begin(), mismatchedOperands.end());
+  opsToReplace.insert(newOpsToReplace.begin(), newOpsToReplace.end());
+  blocksToMerge.insert(mergeBlock);
+  return success();
+}
+
+/// Returns true if every predecessor of the given block is terminated by a
+/// `BranchOpInterface` operation that exposes mutable successor operands for
+/// its edge to `block`, i.e. if the values passed along every incoming edge
+/// can be updated. Returns false otherwise.
+static bool ableToUpdatePredOperands(Block *block) {
+  for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) {
+    // Unknown terminators, or branches that do not expose the operands for
+    // this successor, cannot have new successor operands appended.
+    auto branch = dyn_cast<BranchOpInterface>((*it)->getTerminator());
+    if (!branch || !branch.getMutableSuccessorOperands(it.getSuccessorIndex()))
+      return false;
+  }
+  return true;
+}
+
+/// Merge all of the blocks in this cluster into the leader block. New leader
+/// block arguments are added for mismatched operands, and the predecessors of
+/// every block are updated to pass the corresponding values. Returns failure,
+/// without modifying the IR, if there is nothing to merge or if the
+/// predecessors cannot be updated.
+LogicalResult BlockMergeCluster::merge() {
+  // Don't consider clusters that don't have blocks to merge.
+  if (blocksToMerge.empty())
+    return failure();
+
+  Block *leaderBlock = leaderData.block;
+  const bool hasOperandsToMerge = !operandsToMerge.empty();
+
+  // If the cluster has operands to merge, verify that the predecessor
+  // terminators of each of the blocks can have their successor operands
+  // updated. This is checked before any IR is modified.
+  // TODO: We could try and sub-partition this cluster if only some blocks
+  // cause the mismatch.
+  if (hasOperandsToMerge &&
+      (!ableToUpdatePredOperands(leaderBlock) ||
+       !llvm::all_of(blocksToMerge, ableToUpdatePredOperands)))
+    return failure();
+
+  // Redirect external uses of results produced in the blocks being merged to
+  // the equivalent leader operations. This is required even when there are no
+  // operands to merge: the merged blocks are erased below, and their
+  // operations must not have any remaining uses at that point.
+  // NOTE(review): this assumes the leader operation yields the value that the
+  // replaced result had at each use (e.g. across loop back-edges); confirm
+  // this holds for operations with side effects.
+  for (std::pair<Operation *, Operation *> &it : opsToReplace)
+    it.first->replaceAllUsesWith(it.second);
+
+  if (hasOperandsToMerge) {
+    // Collect the iterators for each of the blocks to merge. We will walk all
+    // of the iterators at once to avoid operand index invalidation.
+    SmallVector<Block::iterator, 2> blockIterators;
+    blockIterators.reserve(blocksToMerge.size() + 1);
+    blockIterators.push_back(leaderBlock->begin());
+    for (Block *mergeBlock : blocksToMerge)
+      blockIterators.push_back(mergeBlock->begin());
+
+    // Record, for each block (leader first), the value it uses for every
+    // mismatched operand. The leader's operands are rewritten to use new
+    // block arguments.
+    SmallVector<SmallVector<Value, 8>, 2> newArguments(
+        1 + blocksToMerge.size(),
+        SmallVector<Value, 8>(operandsToMerge.size()));
+    unsigned curOpIndex = 0;
+    for (auto it : llvm::enumerate(operandsToMerge)) {
+      unsigned nextOpOffset = it.value().first - curOpIndex;
+      curOpIndex = it.value().first;
+
+      // Process the operand for each of the block iterators.
+      for (unsigned i = 0, e = blockIterators.size(); i != e; ++i) {
+        Block::iterator &blockIter = blockIterators[i];
+        std::advance(blockIter, nextOpOffset);
+        auto &operand = blockIter->getOpOperand(it.value().second);
+        newArguments[i][it.index()] = operand.get();
+
+        // Update the operand and insert an argument if this is the leader.
+        if (i == 0)
+          operand.set(leaderBlock->addArgument(operand.get().getType()));
+      }
+    }
+    // Update the predecessors for each of the blocks to pass the recorded
+    // values as the new trailing successor operands.
+    auto updatePredecessors = [&](Block *block, unsigned clusterIndex) {
+      for (auto predIt = block->pred_begin(), predE = block->pred_end();
+           predIt != predE; ++predIt) {
+        auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());
+        unsigned succIndex = predIt.getSuccessorIndex();
+        branch.getMutableSuccessorOperands(succIndex)->append(
+            newArguments[clusterIndex]);
+      }
+    };
+    updatePredecessors(leaderBlock, /*clusterIndex=*/0);
+    for (unsigned i = 0, e = blocksToMerge.size(); i != e; ++i)
+      updatePredecessors(blocksToMerge[i], /*clusterIndex=*/i + 1);
+  }
+
+  // Replace all uses of the merged blocks with the leader and erase them.
+  for (Block *block : blocksToMerge) {
+    block->replaceAllUsesWith(leaderBlock);
+    block->erase();
+  }
+  return success();
+}
+
+/// Identify identical blocks within the given region and merge them, inserting
+/// new block arguments as necessary. Returns success if any blocks were merged,
+/// failure otherwise.
+static LogicalResult mergeIdenticalBlocks(Region &region) {
+  // There is nothing to merge with at most one block.
+  if (region.empty() || llvm::hasSingleElement(region))
+    return failure();
+
+  // Identify sets of blocks, other than the entry block, that branch to the
+  // same successors. We will use these groups to create clusters of equivalent
+  // blocks.
+  DenseMap<SuccessorRange, SmallVector<Block *, 1>> matchingSuccessors;
+  for (Block &block : llvm::drop_begin(region, 1))
+    matchingSuccessors[block.getSuccessors()].push_back(&block);
+
+  bool mergedAnyBlocks = false;
+  for (ArrayRef<Block *> blocks : llvm::make_second_range(matchingSuccessors)) {
+    // A lone block has nothing to be merged with.
+    if (blocks.size() == 1)
+      continue;
+
+    SmallVector<BlockMergeCluster, 1> clusters;
+    for (Block *block : blocks) {
+      // Don't allow merging if this block has any regions. This is checked
+      // before computing the equivalence data so that ineligible blocks are
+      // not hashed needlessly.
+      // TODO: Add support for regions if necessary.
+      bool hasNonEmptyRegion = llvm::any_of(*block, [](Operation &op) {
+        return llvm::any_of(op.getRegions(),
+                            [](Region &region) { return !region.empty(); });
+      });
+      if (hasNonEmptyRegion)
+        continue;
+
+      BlockEquivalenceData data(block);
+
+      // Try to add this block to an existing cluster; otherwise it becomes the
+      // leader of a new cluster.
+      bool addedToCluster = false;
+      for (auto &cluster : clusters)
+        if ((addedToCluster = succeeded(cluster.addToCluster(data))))
+          break;
+      if (!addedToCluster)
+        clusters.emplace_back(std::move(data));
+    }
+    for (auto &cluster : clusters)
+      mergedAnyBlocks |= succeeded(cluster.merge());
+  }
+
+  return success(mergedAnyBlocks);
+}
+
+/// Identify identical blocks within the given regions and merge them, inserting
+/// new block arguments as necessary. Regions nested within operations of these
+/// regions are processed as well. Returns success if any blocks were merged,
+/// failure otherwise.
+static LogicalResult mergeIdenticalBlocks(MutableArrayRef<Region> regions) {
+  llvm::SmallSetVector<Region *, 1> worklist;
+  for (auto &region : regions)
+    worklist.insert(&region);
+  bool anyChanged = false;
+  while (!worklist.empty()) {
+    Region *region = worklist.pop_back_val();
+    // A successful merge may expose further opportunities (e.g. blocks whose
+    // successors now coincide), so revisit the region until nothing changes.
+    if (succeeded(mergeIdenticalBlocks(*region))) {
+      worklist.insert(region);
+      anyChanged = true;
+    }
+
+    // Add any nested regions to the worklist. Nested regions that were
+    // already processed are queued again whenever their parent is revisited.
+    for (Block &block : *region)
+      for (auto &op : block)
+        for (auto &nestedRegion : op.getRegions())
+          worklist.insert(&nestedRegion);
+  }
+
+  return success(anyChanged);
+}
+
 //===----------------------------------------------------------------------===//
 // Region Simplification
 //===----------------------------------------------------------------------===//
@@ -376,7 +694,9 @@ static LogicalResult runRegionDCE(MutableArrayRef<Region> regions) {
 /// elimination, as well as some other DCE. This function returns success if any
 /// of the regions were simplified, failure otherwise.
 LogicalResult mlir::simplifyRegions(MutableArrayRef<Region> regions) {
-  LogicalResult eliminatedBlocks = eraseUnreachableBlocks(regions);
-  LogicalResult eliminatedOpsOrArgs = runRegionDCE(regions);
-  return success(succeeded(eliminatedBlocks) || succeeded(eliminatedOpsOrArgs));
+  // The simplifications run in a fixed order, each operating on the output of
+  // the previous one: block merging sees the regions after unreachable blocks
+  // and dead operations/arguments have been removed.
+  bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(regions));
+  bool eliminatedOpsOrArgs = succeeded(runRegionDCE(regions));
+  bool mergedIdenticalBlocks = succeeded(mergeIdenticalBlocks(regions));
+  // Report success if any of the simplifications changed the IR.
+  return success(eliminatedBlocks || eliminatedOpsOrArgs ||
+                 mergedIdenticalBlocks);
 }

diff  --git a/mlir/test/Dialect/SPIRV/canonicalize.mlir b/mlir/test/Dialect/SPIRV/canonicalize.mlir
index f8c3bdebda39..20ed6e96be8d 100644
--- a/mlir/test/Dialect/SPIRV/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/canonicalize.mlir
@@ -559,15 +559,18 @@ func @cannot_canonicalize_selection_op_0(%cond: i1) -> () {
 
   // CHECK: spv.selection {
   spv.selection {
+    // CHECK: spv.BranchConditional
+    // CHECK-SAME: ^bb1(%[[DST_VAR_0]], %[[SRC_VALUE_0]]
+    // CHECK-SAME: ^bb1(%[[DST_VAR_1]], %[[SRC_VALUE_1]]
     spv.BranchConditional %cond, ^then, ^else
 
   ^then:
-    // CHECK: spv.Store "Function" %[[DST_VAR_0]], %[[SRC_VALUE_0]] ["Aligned", 8] : vector<3xi32>
+    // CHECK: ^bb1(%[[ARG0:.*]]: !spv.ptr<vector<3xi32>, Function>, %[[ARG1:.*]]: vector<3xi32>):
+    // CHECK: spv.Store "Function" %[[ARG0]], %[[ARG1]] ["Aligned", 8] : vector<3xi32>
     spv.Store "Function" %3, %1 ["Aligned", 8]:  vector<3xi32>
     spv.Branch ^merge
 
   ^else:
-    // CHECK: spv.Store "Function" %[[DST_VAR_1]], %[[SRC_VALUE_1]] ["Aligned", 8] : vector<3xi32>
     spv.Store "Function" %4, %2 ["Aligned", 8] : vector<3xi32>
     spv.Branch ^merge
 

diff  --git a/mlir/test/Transforms/canonicalize-block-merge.mlir b/mlir/test/Transforms/canonicalize-block-merge.mlir
new file mode 100644
index 000000000000..86cac9dddbb3
--- /dev/null
+++ b/mlir/test/Transforms/canonicalize-block-merge.mlir
@@ -0,0 +1,204 @@
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='func(canonicalize)' -split-input-file | FileCheck %s
+
+// Check the simple case of single operation blocks with a return. Both
+// successors of the branch are identical and no operands differ, so they are
+// merged even though `foo.cond_br` is an unregistered terminator.
+
+// CHECK-LABEL: func @return_blocks(
+func @return_blocks() {
+  // CHECK: "foo.cond_br"()[^bb1, ^bb1]
+  // CHECK: ^bb1:
+  // CHECK-NEXT: return
+  // CHECK-NOT: ^bb2
+
+  "foo.cond_br"() [^bb1, ^bb2] : () -> ()
+
+^bb1:
+  return
+^bb2:
+  return
+}
+
+// Check the case of identical blocks with matching arguments. The arguments
+// are used at the same positions, so the blocks merge without any new
+// arguments being added.
+
+// CHECK-LABEL: func @matching_arguments(
+func @matching_arguments() -> i32 {
+  // CHECK: "foo.cond_br"()[^bb1, ^bb1]
+  // CHECK: ^bb1(%{{.*}}: i32):
+  // CHECK-NEXT: return
+  // CHECK-NOT: ^bb2
+
+  "foo.cond_br"() [^bb1, ^bb2] : () -> ()
+
+^bb1(%arg0 : i32):
+  return %arg0 : i32
+^bb2(%arg1 : i32):
+  return %arg1 : i32
+}
+
+// Check that no merging occurs if there is an operand mismatch and we can't
+// update the predecessor: `foo.cond_br` is an unregistered operation, so the
+// mismatched values cannot be passed to the merged block as successor operands.
+
+// CHECK-LABEL: func @mismatch_unknown_terminator
+func @mismatch_unknown_terminator(%arg0 : i32, %arg1 : i32) -> i32 {
+  // CHECK: "foo.cond_br"()[^bb1, ^bb2]
+
+  "foo.cond_br"() [^bb1, ^bb2] : () -> ()
+
+^bb1:
+  return %arg0 : i32
+^bb2:
+  return %arg1 : i32
+}
+
+// Check that merging does occur if there is an operand mismatch and we can
+// update the predecessor. The mismatched operand becomes a new block argument,
+// after which the branch is expected to simplify to a `select`.
+
+// CHECK-LABEL: func @mismatch_operands
+// CHECK-SAME: %[[COND:.*]]: i1, %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32
+func @mismatch_operands(%cond : i1, %arg0 : i32, %arg1 : i32) -> i32 {
+  // CHECK: %[[RES:.*]] = select %[[COND]], %[[ARG0]], %[[ARG1]]
+  // CHECK: return %[[RES]]
+
+  cond_br %cond, ^bb1, ^bb2
+
+^bb1:
+  return %arg0 : i32
+^bb2:
+  return %arg1 : i32
+}
+
+// Check the same as above, but with pre-existing arguments. The argument added
+// for the mismatched operand is appended after the existing one.
+
+// CHECK-LABEL: func @mismatch_operands_matching_arguments(
+// CHECK-SAME: %[[COND:.*]]: i1, %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32
+func @mismatch_operands_matching_arguments(%cond : i1, %arg0 : i32, %arg1 : i32) -> (i32, i32) {
+  // CHECK: %[[RES0:.*]] = select %[[COND]], %[[ARG1]], %[[ARG0]]
+  // CHECK: %[[RES1:.*]] = select %[[COND]], %[[ARG0]], %[[ARG1]]
+  // CHECK: return %[[RES1]], %[[RES0]]
+
+  cond_br %cond, ^bb1(%arg1 : i32), ^bb2(%arg0 : i32)
+
+^bb1(%arg2 : i32):
+  return %arg0, %arg2 : i32, i32
+^bb2(%arg3 : i32):
+  return %arg1, %arg3 : i32, i32
+}
+
+// Check that merging does not occur if the uses of the arguments differ: the
+// argument of ^bb1 is returned second, while the argument of ^bb2 is returned
+// first.
+
+// CHECK-LABEL: func @mismatch_argument_uses(
+func @mismatch_argument_uses(%cond : i1, %arg0 : i32, %arg1 : i32) -> (i32, i32) {
+  // CHECK: cond_br %{{.*}}, ^bb1(%{{.*}}), ^bb2
+
+  cond_br %cond, ^bb1(%arg1 : i32), ^bb2(%arg0 : i32)
+
+^bb1(%arg2 : i32):
+  return %arg0, %arg2 : i32, i32
+^bb2(%arg3 : i32):
+  return %arg3, %arg1 : i32, i32
+}
+
+// Check that merging does not occur if the types of the arguments differ.
+
+// CHECK-LABEL: func @mismatch_argument_types(
+func @mismatch_argument_types(%cond : i1, %arg0 : i32, %arg1 : i16) {
+  // CHECK: cond_br %{{.*}}, ^bb1(%{{.*}}), ^bb2
+
+  cond_br %cond, ^bb1(%arg0 : i32), ^bb2(%arg1 : i16)
+
+^bb1(%arg2 : i32):
+  "foo.return"(%arg2) : (i32) -> ()
+^bb2(%arg3 : i16):
+  "foo.return"(%arg3) : (i16) -> ()
+}
+
+// Check that merging does not occur if the number of arguments differs.
+
+// CHECK-LABEL: func @mismatch_argument_count(
+func @mismatch_argument_count(%cond : i1, %arg0 : i32) {
+  // CHECK: cond_br %{{.*}}, ^bb1(%{{.*}}), ^bb2
+
+  cond_br %cond, ^bb1(%arg0 : i32), ^bb2
+
+^bb1(%arg2 : i32):
+  "foo.return"(%arg2) : (i32) -> ()
+^bb2:
+  "foo.return"() : () -> ()
+}
+
+// Check that merging does not occur if the operations differ.
+
+// CHECK-LABEL: func @mismatch_operations(
+func @mismatch_operations(%cond : i1) {
+  // CHECK: cond_br %{{.*}}, ^bb1, ^bb2
+
+  cond_br %cond, ^bb1, ^bb2
+
+^bb1:
+  "foo.return"() : () -> ()
+^bb2:
+  return
+}
+
+// Check that merging does not occur if the number of operations differs.
+
+// CHECK-LABEL: func @mismatch_operation_count(
+func @mismatch_operation_count(%cond : i1) {
+  // CHECK: cond_br %{{.*}}, ^bb1, ^bb2
+
+  cond_br %cond, ^bb1, ^bb2
+
+^bb1:
+  "foo.op"() : () -> ()
+  return
+^bb2:
+  return
+}
+
+// Check that merging does not occur if the blocks contain non-empty regions,
+// even when those regions are identical.
+
+// CHECK-LABEL: func @contains_regions(
+func @contains_regions(%cond : i1) {
+  // CHECK: cond_br %{{.*}}, ^bb1, ^bb2
+
+  cond_br %cond, ^bb1, ^bb2
+
+^bb1:
+  loop.if %cond {
+    "foo.op"() : () -> ()
+  }
+  return
+^bb2:
+  loop.if %cond {
+    "foo.op"() : () -> ()
+  }
+  return
+}
+
+// Check that merging properly handles back edges and the case where a value
+// from one block is used in another: %cond2, defined in ^bb2, is used by the
+// terminator of ^bb1. After merging, that use refers to the equivalent
+// operation of the surviving block, and the value is carried along the back
+// edge as a block argument.
+
+// CHECK-LABEL: func @mismatch_loop(
+// CHECK-SAME: %[[ARG:.*]]: i1
+func @mismatch_loop(%cond : i1) {
+  // CHECK: cond_br %{{.*}}, ^bb1(%[[ARG]] : i1), ^bb2
+
+  cond_br %cond, ^bb2, ^bb3
+
+^bb1:
+  // CHECK: ^bb1(%[[ARG2:.*]]: i1):
+  // CHECK-NEXT: %[[LOOP_CARRY:.*]] = "foo.op"
+  // CHECK-NEXT: cond_br %[[ARG2]], ^bb1(%[[LOOP_CARRY]] : i1), ^bb2
+
+  %ignored = "foo.op"() : () -> (i1)
+  cond_br %cond2, ^bb1, ^bb3
+
+^bb2:
+  %cond2 = "foo.op"() : () -> (i1)
+  cond_br %cond, ^bb1, ^bb3
+
+^bb3:
+  // CHECK: ^bb2:
+  // CHECK-NEXT: return
+
+  return
+}

diff  --git a/mlir/test/Transforms/canonicalize-dce.mlir b/mlir/test/Transforms/canonicalize-dce.mlir
index b93af002823a..6028821934ff 100644
--- a/mlir/test/Transforms/canonicalize-dce.mlir
+++ b/mlir/test/Transforms/canonicalize-dce.mlir
@@ -62,10 +62,6 @@ func @f(%arg0: f32) {
 // Test case: Delete block arguments for cond_br.
 
 // CHECK:      func @f(%arg0: f32, %arg1: i1)
-// CHECK-NEXT:   cond_br %arg1, ^bb1, ^bb2
-// CHECK-NEXT: ^bb1:
-// CHECK-NEXT:   return
-// CHECK-NEXT: ^bb2:
 // CHECK-NEXT:   return
 
 func @f(%arg0: f32, %pred: i1) {

diff  --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index 69e8b398296b..1cff314d731a 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='func(canonicalize)' -split-input-file | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='func(canonicalize)' -split-input-file | FileCheck %s -dump-input-on-failure
 
 // CHECK-LABEL: func @test_subi_zero
 func @test_subi_zero(%arg0: i32) -> i32 {
@@ -361,19 +361,15 @@ func @dead_dealloc_fold() {
 
 // CHECK-LABEL: func @dead_dealloc_fold_multi_use
 func @dead_dealloc_fold_multi_use(%cond : i1) {
-  // CHECK-NEXT: cond_br
+  // CHECK-NEXT: return
+  // Both successors are identical (dealloc %a, then return), so they are
+  // merged; the test then expects the whole function to reduce to a return.
   %a = alloc() : memref<4xf32>
   cond_br %cond, ^bb1, ^bb2
 
-  // CHECK-LABEL: bb1:
 ^bb1:
-  // CHECK-NEXT: return
   dealloc %a: memref<4xf32>
   return
 
-  // CHECK-LABEL: bb2:
 ^bb2:
-  // CHECK-NEXT: return
   dealloc %a: memref<4xf32>
   return
 }


        


More information about the Mlir-commits mailing list