[Mlir-commits] [mlir] c484c7d - [mlir-reduce] Reducer refactor.

Chia-hung Duan llvmlistbot at llvm.org
Tue Jun 1 16:46:45 PDT 2021


Author: Chia-hung Duan
Date: 2021-06-02T07:45:00+08:00
New Revision: c484c7dd9d2382f07216ae9142ceb76272e21dc4

URL: https://github.com/llvm/llvm-project/commit/c484c7dd9d2382f07216ae9142ceb76272e21dc4
DIFF: https://github.com/llvm/llvm-project/commit/c484c7dd9d2382f07216ae9142ceb76272e21dc4.diff

LOG: [mlir-reduce] Reducer refactor.

* A Reducer is a kind of RewritePattern, so it's just the same as
  writing graph rewrites.
* ReductionTreePass operates on Operation rather than ModuleOp, so that
  we are able to reduce a nested structure (e.g., module in module) by
  self-nesting.

Reviewed By: jpienaar, rriddle

Differential Revision: https://reviews.llvm.org/D101046

Added: 
    mlir/include/mlir/Reducer/ReductionPatternInterface.h
    mlir/lib/Reducer/OptReductionPass.cpp
    mlir/lib/Reducer/ReductionNode.cpp
    mlir/lib/Reducer/ReductionTreePass.cpp
    mlir/test/mlir-reduce/crashop-reduction.mlir

Modified: 
    mlir/include/mlir/Reducer/Passes.h
    mlir/include/mlir/Reducer/Passes.td
    mlir/include/mlir/Reducer/ReductionNode.h
    mlir/lib/Reducer/CMakeLists.txt
    mlir/lib/Reducer/Tester.cpp
    mlir/test/lib/Dialect/Test/CMakeLists.txt
    mlir/test/lib/Dialect/Test/TestDialect.cpp
    mlir/test/lib/Dialect/Test/TestDialect.h
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/lib/Dialect/Test/TestPatterns.cpp
    mlir/test/lib/Reducer/MLIRTestReducer.cpp
    mlir/test/mlir-reduce/dce-test.mlir
    mlir/test/mlir-reduce/multiple-function.mlir
    mlir/test/mlir-reduce/simple-test.mlir
    mlir/test/mlir-reduce/single-function.mlir
    mlir/tools/mlir-reduce/CMakeLists.txt
    mlir/tools/mlir-reduce/mlir-reduce.cpp

Removed: 
    mlir/include/mlir/Reducer/OptReductionPass.h
    mlir/include/mlir/Reducer/Passes/OpReducer.h
    mlir/include/mlir/Reducer/ReductionTreePass.h
    mlir/tools/mlir-reduce/OptReductionPass.cpp
    mlir/tools/mlir-reduce/ReductionNode.cpp
    mlir/tools/mlir-reduce/ReductionTreePass.cpp


################################################################################
diff --git a/mlir/include/mlir/Reducer/OptReductionPass.h b/mlir/include/mlir/Reducer/OptReductionPass.h
deleted file mode 100644
index f2735395d7590..0000000000000
--- a/mlir/include/mlir/Reducer/OptReductionPass.h
+++ /dev/null
@@ -1,41 +0,0 @@
-//===- OptReductionPass.h - Optimization Reduction Pass Wrapper -*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file defines the Opt Reduction Pass Wrapper. It creates a MLIR pass to
-// run any optimization pass within it and only replaces the output module with
-// the transformed version if it is smaller and interesting.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_REDUCER_OPTREDUCTIONPASS_H
-#define MLIR_REDUCER_OPTREDUCTIONPASS_H
-
-#include "PassDetail.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassManager.h"
-#include "mlir/Reducer/ReductionNode.h"
-#include "mlir/Reducer/ReductionTreePass.h"
-#include "mlir/Reducer/Tester.h"
-#include "mlir/Transforms/Passes.h"
-#include "llvm/Support/Debug.h"
-
-namespace mlir {
-
-class OptReductionPass : public OptReductionBase<OptReductionPass> {
-public:
-  OptReductionPass() = default;
-
-  OptReductionPass(const OptReductionPass &srcPass) = default;
-
-  /// Runs the pass instance in the pass pipeline.
-  void runOnOperation() override;
-};
-
-} // end namespace mlir
-
-#endif

diff --git a/mlir/include/mlir/Reducer/Passes.h b/mlir/include/mlir/Reducer/Passes.h
index 8eb1d59b47cfa..438f47b329d04 100644
--- a/mlir/include/mlir/Reducer/Passes.h
+++ b/mlir/include/mlir/Reducer/Passes.h
@@ -9,8 +9,6 @@
 #define MLIR_REDUCER_PASSES_H
 
 #include "mlir/Pass/Pass.h"
-#include "mlir/Reducer/OptReductionPass.h"
-#include "mlir/Reducer/ReductionTreePass.h"
 
 namespace mlir {
 

diff --git a/mlir/include/mlir/Reducer/Passes.td b/mlir/include/mlir/Reducer/Passes.td
index ce99507706e08..43e90122e6682 100644
--- a/mlir/include/mlir/Reducer/Passes.td
+++ b/mlir/include/mlir/Reducer/Passes.td
@@ -24,14 +24,12 @@ def CommonReductionPassOptions {
   ];
 }
 
-def ReductionTree : Pass<"reduction-tree", "ModuleOp"> {
+def ReductionTree : Pass<"reduction-tree"> {
   let summary = "A general reduction tree pass for the MLIR Reduce Tool";
 
   let constructor = "mlir::createReductionTreePass()";
 
   let options = [
-    Option<"opReducerName", "op-reducer", "std::string", /* default */"",
-           "The OpReducer to reduce the module">,
     Option<"traversalModeId", "traversal-mode", "unsigned",
            /* default */"0", "The graph traversal mode">,
   ] # CommonReductionPassOptions.options;

diff --git a/mlir/include/mlir/Reducer/Passes/OpReducer.h b/mlir/include/mlir/Reducer/Passes/OpReducer.h
deleted file mode 100644
index 6e48d41187356..0000000000000
--- a/mlir/include/mlir/Reducer/Passes/OpReducer.h
+++ /dev/null
@@ -1,76 +0,0 @@
-//===- OpReducer.h - MLIR Reduce Operation Reducer ------------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file defines the OpReducer class. It defines a variant generator method
-// with the purpose of producing different variants by eliminating a
-// parameterizable type of operations from the  parent module.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_REDUCER_PASSES_OPREDUCER_H
-#define MLIR_REDUCER_PASSES_OPREDUCER_H
-
-#include <limits>
-
-#include "mlir/Reducer/ReductionNode.h"
-#include "mlir/Reducer/Tester.h"
-
-namespace mlir {
-
-class OpReducer {
-public:
-  virtual ~OpReducer() = default;
-  /// According to rangeToKeep, try to reduce the given module. We implicitly
-  /// number each interesting operation and rangeToKeep indicates that if an
-  /// operation's number falls into certain range, then we will not try to
-  /// reduce that operation.
-  virtual void reduce(ModuleOp module,
-                      ArrayRef<ReductionNode::Range> rangeToKeep) = 0;
-  /// Return the number of certain kind of operations that we would like to
-  /// reduce. This can be used to build a range map to exclude uninterested
-  /// operations.
-  virtual int getNumTargetOps(ModuleOp module) const = 0;
-};
-
-/// Reducer is a helper class to remove potential uninteresting operations from
-/// module.
-template <typename OpType>
-class Reducer : public OpReducer {
-public:
-  ~Reducer() override = default;
-
-  int getNumTargetOps(ModuleOp module) const override {
-    return std::distance(module.getOps<OpType>().begin(),
-                         module.getOps<OpType>().end());
-  }
-
-  void reduce(ModuleOp module,
-              ArrayRef<ReductionNode::Range> rangeToKeep) override {
-    std::vector<Operation *> opsToRemove;
-    size_t keepIndex = 0;
-
-    for (auto op : enumerate(module.getOps<OpType>())) {
-      int index = op.index();
-      if (keepIndex < rangeToKeep.size() &&
-          index == rangeToKeep[keepIndex].second)
-        ++keepIndex;
-      if (keepIndex == rangeToKeep.size() ||
-          index < rangeToKeep[keepIndex].first)
-        opsToRemove.push_back(op.value());
-    }
-
-    for (Operation *o : opsToRemove) {
-      o->dropAllUses();
-      o->erase();
-    }
-  }
-};
-
-} // end namespace mlir
-
-#endif

diff --git a/mlir/include/mlir/Reducer/ReductionNode.h b/mlir/include/mlir/Reducer/ReductionNode.h
index 6364157b9040b..65b2928bfd546 100644
--- a/mlir/include/mlir/Reducer/ReductionNode.h
+++ b/mlir/include/mlir/Reducer/ReductionNode.h
@@ -21,19 +21,25 @@
 #include <vector>
 
 #include "mlir/Reducer/Tester.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/Support/Allocator.h"
 #include "llvm/Support/ToolOutputFile.h"
 
 namespace mlir {
 
+class ModuleOp;
+class Region;
+
 /// Defines the traversal method options to be used in the reduction tree
 /// traversal.
 enum TraversalMode { SinglePath, Backtrack, MultiPath };
 
-/// This class defines the ReductionNode which is used to generate variant and
-/// keep track of the necessary metadata for the reduction pass. The nodes are
-/// linked together in a reduction tree structure which defines the relationship
-/// between all the different generated variants.
+/// ReductionTreePass will build a reduction tree during module reduction and
+/// the ReductionNode represents the vertex of the tree. A ReductionNode records
+/// the information such as the reduced module, how this node is reduced from
+/// the parent node, etc. This information will be used to construct a reduction
+/// path to reduce the certain module.
 class ReductionNode {
 public:
   template <TraversalMode mode>
@@ -44,23 +50,46 @@ class ReductionNode {
   ReductionNode(ReductionNode *parent, std::vector<Range> range,
                 llvm::SpecificBumpPtrAllocator<ReductionNode> &allocator);
 
-  ReductionNode *getParent() const;
+  ReductionNode *getParent() const { return parent; }
+
+  /// If the ReductionNode's interestingness hasn't been tested, it'll be the
+  /// same module as the one in the parent node. Otherwise, the returned module
+  /// will have had certain reduction strategies applied. Note that it's not
+  /// necessarily an interesting case or a reduced module (i.e. one smaller
+  /// than the parent's).
+  ModuleOp getModule() const { return module; }
+
+  /// Return the region we're reducing.
+  Region &getRegion() const { return *region; }
 
-  size_t getSize() const;
+  /// Return the size of the module.
+  size_t getSize() const { return size; }
 
   /// Returns true if the module exhibits the interesting behavior.
-  Tester::Interestingness isInteresting() const;
+  Tester::Interestingness isInteresting() const { return interesting; }
 
-  std::vector<Range> getRanges() const;
+  /// Return the range information describing how this node is reduced from the
+  /// parent node.
+  ArrayRef<Range> getStartRanges() const { return startRanges; }
 
-  std::vector<ReductionNode *> &getVariants();
+  /// Return the range set we are using to generate variants.
+  ArrayRef<Range> getRanges() const { return ranges; }
+
+  /// Return the generated variants (the child nodes).
+  ArrayRef<ReductionNode *> getVariants() const { return variants; }
 
   /// Split the ranges and generate new variants.
-  std::vector<ReductionNode *> generateNewVariants();
+  ArrayRef<ReductionNode *> generateNewVariants();
 
   /// Update the interestingness result from tester.
   void update(std::pair<Tester::Interestingness, size_t> result);
 
+  /// Each Reduction Node contains a copy of module for applying rewrite
+  /// patterns. In addition, we only apply rewrite patterns in a certain region.
+  /// In initialize(), we will duplicate the module from parent node and locate
+  /// corresponding region.
+  LogicalResult initialize(ModuleOp parentModule, Region &parentRegion);
+
 private:
   /// A custom BFS iterator. The difference between
   /// llvm/ADT/BreadthFirstIterator.h is the graph we're exploring is dynamic.
@@ -87,8 +116,7 @@ class ReductionNode {
     BaseIterator &operator++() {
       ReductionNode *top = visitQueue.front();
       visitQueue.pop();
-      std::vector<ReductionNode *> neighbors = getNeighbors(top);
-      for (ReductionNode *node : neighbors)
+      for (ReductionNode *node : getNeighbors(top))
         visitQueue.push(node);
       return *this;
     }
@@ -103,7 +131,7 @@ class ReductionNode {
     ReductionNode *operator->() const { return visitQueue.front(); }
 
   protected:
-    std::vector<ReductionNode *> getNeighbors(ReductionNode *node) {
+    ArrayRef<ReductionNode *> getNeighbors(ReductionNode *node) {
       return static_cast<T *>(this)->getNeighbors(node);
     }
 
@@ -111,21 +139,42 @@ class ReductionNode {
     std::queue<ReductionNode *> visitQueue;
   };
 
-  /// The size of module after applying the range constraints.
+  /// This is a copy of module from parent node. All the reducer patterns will
+  /// be applied to this instance.
+  ModuleOp module;
+
+  /// The region of the operation we're reducing in the module.
+  Region *region;
+
+  /// The node we are reduced from. This node is one of the variants of the
+  /// parent node.
+  ReductionNode *parent;
+
+  /// The size of module after applying the reducer patterns with range
+  /// constraints. This is only valid once the interestingness has been tested.
   size_t size;
 
   /// This is true if the module has been evaluated and it exhibits the
   /// interesting behavior.
   Tester::Interestingness interesting;
 
-  ReductionNode *parent;
-
-  /// We will only keep the operation with index falls into the ranges.
-  /// For example, number each function in a certain module and then we will
-  /// remove the functions with index outside the ranges and see if the
-  /// resulting module is still interesting.
+  /// `ranges` represents the selected subset of operations in the region. We
+  /// implicitly number each operation in the region and ReductionTreePass will
+  /// apply reducer patterns on the operations that fall into `ranges`. We will
+  /// generate new ReductionNodes with subsets of `ranges` to see if we can do
+  /// further reduction. We may split the elements in `ranges` so that we can
+  /// have more subset variants from `ranges`.
+  /// Note that after applying the reducer patterns the number of operations in
+  /// the region may have changed, so we need to update `ranges` afterwards.
   std::vector<Range> ranges;
 
+  /// `startRanges` records the ranges of operations selected from the parent
+  /// node to produce this ReductionNode. It can be used to construct the
+  /// reduction path from the root. I.e., if we apply the same reducer patterns
+  /// and `startRanges` selection on the parent region, we will get the same
+  /// module as this node.
+  const std::vector<Range> startRanges;
+
   /// This points to the child variants that were created using this node as a
   /// starting point.
   std::vector<ReductionNode *> variants;
@@ -139,9 +188,9 @@ class ReductionNode::iterator<SinglePath>
     : public BaseIterator<iterator<SinglePath>> {
   friend BaseIterator<iterator<SinglePath>>;
   using BaseIterator::BaseIterator;
-  std::vector<ReductionNode *> getNeighbors(ReductionNode *node);
+  ArrayRef<ReductionNode *> getNeighbors(ReductionNode *node);
 };
 
 } // end namespace mlir
 
-#endif
+#endif // MLIR_REDUCER_REDUCTIONNODE_H

diff --git a/mlir/include/mlir/Reducer/ReductionPatternInterface.h b/mlir/include/mlir/Reducer/ReductionPatternInterface.h
new file mode 100644
index 0000000000000..887d120e4c352
--- /dev/null
+++ b/mlir/include/mlir/Reducer/ReductionPatternInterface.h
@@ -0,0 +1,56 @@
+//===- ReductionPatternInterface.h - Collecting Reduce Patterns -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_REDUCER_REDUCTIONPATTERNINTERFACE_H
+#define MLIR_REDUCER_REDUCTIONPATTERNINTERFACE_H
+
+#include "mlir/IR/DialectInterface.h"
+
+namespace mlir {
+
+class RewritePatternSet;
+
+/// This is used to report the reduction patterns for a Dialect. While using
+/// mlir-reduce to reduce a module, we may want to transform certain cases into
+/// simpler forms by applying certain rewrite patterns. Implement the
+/// `populateReductionPatterns` to report those patterns by adding them to the
+/// RewritePatternSet.
+///
+/// Example:
+///   void MyDialectReductionPattern::populateReductionPatterns(
+///       RewritePatternSet &patterns) {
+///       patterns.add<TensorOpReduction>(patterns.getContext());
+///   }
+///
+/// For DRR, mlir-tblgen will generate a helper function
+/// `populateWithGenerated` which has the same signature, so you can
+/// delegate to the helper function as well.
+///
+/// Example:
+///   void MyDialectReductionPattern::populateReductionPatterns(
+///       RewritePatternSet &patterns) {
+///       // Include the autogen file somewhere above.
+///       populateWithGenerated(patterns);
+///   }
+class DialectReductionPatternInterface
+    : public DialectInterface::Base<DialectReductionPatternInterface> {
+public:
+  /// Patterns provided here are intended to transform operations from a complex
+  /// form to a simpler form, without breaking the semantics of the program
+  /// being reduced. For example, you may want to replace the
+  /// tensor<?xindex> with a known rank and type, e.g. tensor<1xi32>, or
+  /// replace an operation with a constant.
+  virtual void populateReductionPatterns(RewritePatternSet &patterns) const = 0;
+
+protected:
+  DialectReductionPatternInterface(Dialect *dialect) : Base(dialect) {}
+};
+
+} // end namespace mlir
+
+#endif // MLIR_REDUCER_REDUCTIONPATTERNINTERFACE_H

diff --git a/mlir/include/mlir/Reducer/ReductionTreePass.h b/mlir/include/mlir/Reducer/ReductionTreePass.h
deleted file mode 100644
index a657d6187597b..0000000000000
--- a/mlir/include/mlir/Reducer/ReductionTreePass.h
+++ /dev/null
@@ -1,50 +0,0 @@
-//===- ReductionTreePass.h - Reduction Tree Pass Implementation -*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file defines the Reduction Tree Pass class. It provides a framework for
-// the implementation of different reduction passes in the MLIR Reduce tool. It
-// allows for custom specification of the variant generation behavior. It
-// implements methods that define the different possible traversals of the
-// reduction tree.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_REDUCER_REDUCTIONTREEPASS_H
-#define MLIR_REDUCER_REDUCTIONTREEPASS_H
-
-#include <vector>
-
-#include "PassDetail.h"
-#include "ReductionNode.h"
-#include "mlir/Reducer/Passes/OpReducer.h"
-#include "mlir/Reducer/Tester.h"
-
-#define DEBUG_TYPE "mlir-reduce"
-
-namespace mlir {
-
-/// This class defines the Reduction Tree Pass. It provides a framework to
-/// to implement a reduction pass using a tree structure to keep track of the
-/// generated reduced variants.
-class ReductionTreePass : public ReductionTreeBase<ReductionTreePass> {
-public:
-  ReductionTreePass() = default;
-  ReductionTreePass(const ReductionTreePass &pass) = default;
-
-  /// Runs the pass instance in the pass pipeline.
-  void runOnOperation() override;
-
-private:
-  template <typename IteratorType>
-  ModuleOp findOptimal(ModuleOp module, std::unique_ptr<OpReducer> reducer,
-                       ReductionNode *node);
-};
-
-} // end namespace mlir
-
-#endif

diff --git a/mlir/lib/Reducer/CMakeLists.txt b/mlir/lib/Reducer/CMakeLists.txt
index 73601f031d574..e6dd933921792 100644
--- a/mlir/lib/Reducer/CMakeLists.txt
+++ b/mlir/lib/Reducer/CMakeLists.txt
@@ -1,7 +1,13 @@
 add_mlir_library(MLIRReduce
+   OptReductionPass.cpp
+   ReductionNode.cpp
+   ReductionTreePass.cpp
    Tester.cpp
    LINK_LIBS PUBLIC
    MLIRIR
+   MLIRPass
+   MLIRRewrite
+   MLIRTransformUtils
 )
 
 mlir_check_all_link_libraries(MLIRReduce)

diff --git a/mlir/tools/mlir-reduce/OptReductionPass.cpp b/mlir/lib/Reducer/OptReductionPass.cpp
similarity index 88%
rename from mlir/tools/mlir-reduce/OptReductionPass.cpp
rename to mlir/lib/Reducer/OptReductionPass.cpp
index 533afc0f2aeed..2d0f47ced930d 100644
--- a/mlir/tools/mlir-reduce/OptReductionPass.cpp
+++ b/mlir/lib/Reducer/OptReductionPass.cpp
@@ -12,15 +12,27 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Reducer/OptReductionPass.h"
+#include "mlir/Pass/PassManager.h"
 #include "mlir/Pass/PassRegistry.h"
+#include "mlir/Reducer/PassDetail.h"
 #include "mlir/Reducer/Passes.h"
 #include "mlir/Reducer/Tester.h"
+#include "llvm/Support/Debug.h"
 
 #define DEBUG_TYPE "mlir-reduce"
 
 using namespace mlir;
 
+namespace {
+
+class OptReductionPass : public OptReductionBase<OptReductionPass> {
+public:
+  /// Runs the pass instance in the pass pipeline.
+  void runOnOperation() override;
+};
+
+} // end anonymous namespace
+
 /// Runs the pass instance in the pass pipeline.
 void OptReductionPass::runOnOperation() {
   LLVM_DEBUG(llvm::dbgs() << "\nOptimization Reduction pass: ");

diff --git a/mlir/tools/mlir-reduce/ReductionNode.cpp b/mlir/lib/Reducer/ReductionNode.cpp
similarity index 61%
rename from mlir/tools/mlir-reduce/ReductionNode.cpp
rename to mlir/lib/Reducer/ReductionNode.cpp
index a8e4af8c88223..e2a96681e9aa1 100644
--- a/mlir/tools/mlir-reduce/ReductionNode.cpp
+++ b/mlir/lib/Reducer/ReductionNode.cpp
@@ -15,6 +15,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Reducer/ReductionNode.h"
+#include "mlir/IR/BlockAndValueMapping.h"
 #include "llvm/ADT/STLExtras.h"
 
 #include <algorithm>
@@ -23,102 +24,102 @@
 using namespace mlir;
 
 ReductionNode::ReductionNode(
-    ReductionNode *parent, std::vector<Range> ranges,
+    ReductionNode *parentNode, std::vector<Range> ranges,
     llvm::SpecificBumpPtrAllocator<ReductionNode> &allocator)
-    : size(std::numeric_limits<size_t>::max()),
-      interesting(Tester::Interestingness::Untested),
-      /// Root node will have the parent pointer point to themselves.
-      parent(parent == nullptr ? this : parent), ranges(ranges),
-      allocator(allocator) {}
-
-/// Returns the size in bytes of the module.
-size_t ReductionNode::getSize() const { return size; }
-
-ReductionNode *ReductionNode::getParent() const { return parent; }
-
-/// Returns true if the module exhibits the interesting behavior.
-Tester::Interestingness ReductionNode::isInteresting() const {
-  return interesting;
+    /// The root node has its parent pointer point to itself.
+    : parent(parentNode == nullptr ? this : parentNode),
+      size(std::numeric_limits<size_t>::max()),
+      interesting(Tester::Interestingness::Untested), ranges(ranges),
+      startRanges(ranges), allocator(allocator) {
+  if (parent != this)
+    if (failed(initialize(parent->getModule(), parent->getRegion())))
+      llvm_unreachable("unexpected initialization failure");
 }
 
-std::vector<ReductionNode::Range> ReductionNode::getRanges() const {
-  return ranges;
+LogicalResult ReductionNode::initialize(ModuleOp parentModule,
+                                        Region &targetRegion) {
+  // Use the mapper to help us find the corresponding region after the clone.
+  BlockAndValueMapping mapper;
+  module = cast<ModuleOp>(parentModule->clone(mapper));
+  // Use the first block of targetRegion to locate the cloned region.
+  Block *block = mapper.lookup(&*targetRegion.begin());
+  region = block->getParent();
+  return success();
 }
 
-std::vector<ReductionNode *> &ReductionNode::getVariants() { return variants; }
-
-#include <iostream>
-
 /// If we haven't explored any variants from this node, we will create N
 /// variants, N is the length of `ranges` if N > 1. Otherwise, we will split the
 /// max element in `ranges` and create 2 new variants for each call.
-std::vector<ReductionNode *> ReductionNode::generateNewVariants() {
-  std::vector<ReductionNode *> newNodes;
+ArrayRef<ReductionNode *> ReductionNode::generateNewVariants() {
+  int oldNumVariant = getVariants().size();
+
+  auto createNewNode = [this](std::vector<Range> ranges) {
+    return new (allocator.Allocate())
+        ReductionNode(this, std::move(ranges), allocator);
+  };
 
   // If we haven't created new variant, then we can create varients by removing
   // each of them respectively. For example, given {{1, 3}, {4, 9}}, we can
   // produce variants with range {{1, 3}} and {{4, 9}}.
-  if (variants.size() == 0 && ranges.size() != 1) {
-    for (const Range &range : ranges) {
-      std::vector<Range> subRanges = ranges;
+  if (variants.size() == 0 && getRanges().size() > 1) {
+    for (const Range &range : getRanges()) {
+      std::vector<Range> subRanges = getRanges();
       llvm::erase_value(subRanges, range);
-      ReductionNode *newNode = allocator.Allocate();
-      new (newNode) ReductionNode(this, subRanges, allocator);
-      newNodes.push_back(newNode);
-      variants.push_back(newNode);
+      variants.push_back(createNewNode(std::move(subRanges)));
     }
 
-    return newNodes;
+    return getVariants().drop_front(oldNumVariant);
   }
 
   // At here, we have created the type of variants mentioned above. We would
   // like to split the max range into 2 to create 2 new variants. Continue on
   // the above example, we split the range {4, 9} into {4, 6}, {6, 9}, and
   // create two variants with range {{1, 3}, {4, 6}} and {{1, 3}, {6, 9}}. The
-  // result ranges vector will be {{1, 3}, {4, 6}, {6, 9}}.
+  // final ranges vector will be {{1, 3}, {4, 6}, {6, 9}}.
   auto maxElement = std::max_element(
       ranges.begin(), ranges.end(), [](const Range &lhs, const Range &rhs) {
         return (lhs.second - lhs.first) > (rhs.second - rhs.first);
       });
 
-  // We can't split range with lenght 1, which means we can't produce new
+  // If the length of the range is at most 1, we can't split it to create a new
   // variant.
-  if (maxElement->second - maxElement->first == 1)
+  if (maxElement->second - maxElement->first <= 1)
     return {};
 
-  auto createNewNode = [this](const std::vector<Range> &ranges) {
-    ReductionNode *newNode = allocator.Allocate();
-    new (newNode) ReductionNode(this, ranges, allocator);
-    return newNode;
-  };
-
   Range maxRange = *maxElement;
-  std::vector<Range> subRanges = ranges;
+  std::vector<Range> subRanges = getRanges();
   auto subRangesIter = subRanges.begin() + (maxElement - ranges.begin());
   int half = (maxRange.first + maxRange.second) / 2;
   *subRangesIter = std::make_pair(maxRange.first, half);
-  newNodes.push_back(createNewNode(subRanges));
+  variants.push_back(createNewNode(subRanges));
   *subRangesIter = std::make_pair(half, maxRange.second);
-  newNodes.push_back(createNewNode(subRanges));
+  variants.push_back(createNewNode(std::move(subRanges)));
 
-  variants.insert(variants.end(), newNodes.begin(), newNodes.end());
   auto it = ranges.insert(maxElement, std::make_pair(half, maxRange.second));
   it = ranges.insert(it, std::make_pair(maxRange.first, half));
   // Remove the range that has been split.
   ranges.erase(it + 2);
 
-  return newNodes;
+  return getVariants().drop_front(oldNumVariant);
 }
 
 void ReductionNode::update(std::pair<Tester::Interestingness, size_t> result) {
   std::tie(interesting, size) = result;
+  // After applying reduction, the number of operations in the region may have
+  // changed. Non-interesting cases won't be explored, so it's safe to keep
+  // them in a stale state.
+  if (interesting == Tester::Interestingness::True) {
+    // This module may have been updated. Reset the range.
+    ranges.clear();
+    ranges.push_back({0, std::distance(region->op_begin(), region->op_end())});
+  }
 }
 
-std::vector<ReductionNode *>
+ArrayRef<ReductionNode *>
 ReductionNode::iterator<SinglePath>::getNeighbors(ReductionNode *node) {
   // Single Path: Traverses the smallest successful variant at each level until
   // no new successful variants can be created at that level.
-  llvm::ArrayRef<ReductionNode *> variantsFromParent =
+  ArrayRef<ReductionNode *> variantsFromParent =
       node->getParent()->getVariants();
 
   // The parent node created several variants and they may be waiting for
@@ -139,7 +140,8 @@ ReductionNode::iterator<SinglePath>::getNeighbors(ReductionNode *node) {
       smallest = node;
   }
 
-  if (smallest != nullptr) {
+  if (smallest != nullptr &&
+      smallest->getSize() < node->getParent()->getSize()) {
     // We got a smallest one, keep traversing from this node.
     node = smallest;
   } else {

diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp
new file mode 100644
index 0000000000000..f564c3b42410f
--- /dev/null
+++ b/mlir/lib/Reducer/ReductionTreePass.cpp
@@ -0,0 +1,247 @@
+//===- ReductionTreePass.cpp - ReductionTreePass Implementation -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the Reduction Tree Pass class. It provides a framework for
+// the implementation of different reduction passes in the MLIR Reduce tool. It
+// allows for custom specification of the variant generation behavior. It
+// implements methods that define the different possible traversals of the
+// reduction tree.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/DialectInterface.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Reducer/PassDetail.h"
+#include "mlir/Reducer/Passes.h"
+#include "mlir/Reducer/ReductionNode.h"
+#include "mlir/Reducer/ReductionPatternInterface.h"
+#include "mlir/Reducer/Tester.h"
+#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Allocator.h"
+#include "llvm/Support/ManagedStatic.h"
+
+using namespace mlir;
+
+/// Applies `patterns` to the operations of `region` selected by `rangeToKeep`
+/// and, when `eraseOpNotInRange` is set, erases every other operation.
+///
+/// Operations are numbered implicitly by their position in the region. Each
+/// range is half-open: reaching a range's `second` index moves the scan on to
+/// the next range. The scan walks the ranges monotonically, so it assumes they
+/// are sorted in ascending order.
+static void applyPatterns(Region &region,
+                          const FrozenRewritePatternSet &patterns,
+                          ArrayRef<ReductionNode::Range> rangeToKeep,
+                          bool eraseOpNotInRange) {
+  std::vector<Operation *> droppedOps;
+  std::vector<Operation *> keptOps;
+  size_t curRange = 0;
+  int opNumber = 0;
+  for (Operation &candidate : region.getOps()) {
+    // Leaving the current range: advance to the next one.
+    if (curRange != rangeToKeep.size() &&
+        opNumber == rangeToKeep[curRange].second)
+      ++curRange;
+    bool inRange = curRange != rangeToKeep.size() &&
+                   opNumber >= rangeToKeep[curRange].first;
+    (inRange ? keptOps : droppedOps).push_back(&candidate);
+    ++opNumber;
+  }
+
+  // `applyOpPatternsAndFold` may erase operations, so it cannot run while the
+  // region is being iterated above. It also has to run before the out-of-range
+  // operations are erased, since erasing them may leave the module invalid.
+  // Its result only says whether rewriting converged; a reduction attempt is
+  // not required to succeed, so the result is deliberately ignored.
+  for (Operation *target : keptOps)
+    (void)applyOpPatternsAndFold(target, patterns);
+
+  if (!eraseOpNotInRange)
+    return;
+  for (Operation *victim : droppedOps) {
+    victim->dropAllUses();
+    victim->erase();
+  }
+}
+
+/// Explores the reduction tree for `region` and then replays, on the original
+/// region, the reductions along the path to the smallest interesting variant.
+///
+/// Each ReductionNode selects ranges of operations to keep, and the reducer
+/// patterns are applied to the operations in those ranges. Normally an
+/// operation cannot be removed without replacing it with another valid
+/// operation. However, whether a reduction is acceptable is decided by the
+/// user-provided Tester, so a module that is otherwise invalid may still be
+/// interesting to the user. `eraseOpNotInRange` therefore offers an
+/// alternative way to remove operations: erase every operation that is not in
+/// the ranges selected by the ReductionNode.
+///
+/// Aborts via report_fatal_error if the replayed module is not interesting, or
+/// if its size disagrees with the best node found during the search.
+template <typename IteratorType>
+static void findOptimal(ModuleOp module, Region &region,
+                        const FrozenRewritePatternSet &patterns,
+                        const Tester &test, bool eraseOpNotInRange) {
+  std::pair<Tester::Interestingness, size_t> initStatus =
+      test.isInteresting(module);
+  // While exploring the reduction tree, we always branch from an interesting
+  // node. Thus the root node must be interesting.
+  if (initStatus.first != Tester::Interestingness::True)
+    return;
+
+  llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
+
+  std::vector<ReductionNode::Range> ranges{
+      {0, std::distance(region.op_begin(), region.op_end())}};
+
+  ReductionNode *root = allocator.Allocate();
+  new (root) ReductionNode(nullptr, std::move(ranges), allocator);
+  // Duplicate the module for root node and locate the region in the copy.
+  if (failed(root->initialize(module, region)))
+    llvm_unreachable("unexpected initialization failure");
+  root->update(initStatus);
+
+  ReductionNode *smallestNode = root;
+  IteratorType iter(root);
+
+  while (iter != IteratorType::end()) {
+    ReductionNode &currentNode = *iter;
+    Region &curRegion = currentNode.getRegion();
+
+    applyPatterns(curRegion, patterns, currentNode.getRanges(),
+                  eraseOpNotInRange);
+    currentNode.update(test.isInteresting(currentNode.getModule()));
+
+    if (currentNode.isInteresting() == Tester::Interestingness::True &&
+        currentNode.getSize() < smallestNode->getSize())
+      smallestNode = &currentNode;
+
+    ++iter;
+  }
+
+  // We have now found the optimal path for reducing the given region. Walk it
+  // back from the smallest node to the root so the reductions can be replayed.
+  SmallVector<ReductionNode *> trace;
+  ReductionNode *curNode = smallestNode;
+  trace.push_back(curNode);
+  while (curNode != root) {
+    curNode = curNode->getParent();
+    trace.push_back(curNode);
+  }
+
+  // `trace` is ordered leaf-first, so popping from the back replays the
+  // reductions on the original region starting at the root.
+  while (!trace.empty()) {
+    ReductionNode *top = trace.pop_back_val();
+    applyPatterns(region, patterns, top->getStartRanges(), eraseOpNotInRange);
+  }
+
+  // Query the tester only once. It runs a user-provided test on the module, so
+  // a second call is wasted work and, if the test is not fully deterministic,
+  // could disagree with the first.
+  std::pair<Tester::Interestingness, size_t> finalStatus =
+      test.isInteresting(module);
+  if (finalStatus.first != Tester::Interestingness::True)
+    llvm::report_fatal_error("Reduced module is not interesting");
+  if (finalStatus.second != smallestNode->getSize())
+    llvm::report_fatal_error(
+        "Reduced module doesn't have consistent size with smallestNode");
+}
+
+template <typename IteratorType>
+static void findOptimal(ModuleOp module, Region &region,
+                        const FrozenRewritePatternSet &patterns,
+                        const Tester &test) {
+  // Reduction runs in two phases.
+  //
+  // Phase 1 only chooses which operations to keep. No reducer patterns are
+  // supplied, and operations outside the selected ranges are erased as long as
+  // the module stays interesting.
+  FrozenRewritePatternSet noPatterns;
+  findOptimal<IteratorType>(module, region, noPatterns, test,
+                            /*eraseOpNotInRange=*/true);
+
+  // Phase 2 assumes no remaining operation is redundant. It keeps every
+  // operation and instead tries to rewrite them into simpler forms using the
+  // reducer patterns.
+  findOptimal<IteratorType>(module, region, patterns, test,
+                            /*eraseOpNotInRange=*/false);
+}
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Reduction Pattern Interface Collection
+//===----------------------------------------------------------------------===//
+
+/// Collection of every DialectReductionPatternInterface available in a
+/// context, used to gather the reducer patterns contributed by each dialect.
+class ReductionPatternInterfaceCollection
+    : public DialectInterfaceCollection<DialectReductionPatternInterface> {
+public:
+  using Base::Base;
+
+  /// Appends the reduction patterns of every collected dialect interface to
+  /// `patterns`.
+  void populateReductionPatterns(RewritePatternSet &patterns) const {
+    for (const auto &dialectIface : *this)
+      dialectIface.populateReductionPatterns(patterns);
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ReductionTreePass
+//===----------------------------------------------------------------------===//
+
+/// The Reduction Tree Pass: a framework for implementing a reduction pass that
+/// keeps track of the generated reduced variants in a tree structure.
+class ReductionTreePass : public ReductionTreeBase<ReductionTreePass> {
+public:
+  ReductionTreePass() = default;
+  ReductionTreePass(const ReductionTreePass &) = default;
+
+  /// Gathers the reducer patterns from the dialect interfaces.
+  LogicalResult initialize(MLIRContext *context) override;
+
+  /// Reduces each non-empty region nested under the pass's operation.
+  void runOnOperation() override;
+
+private:
+  /// Reduces `region` with the configured traversal; `module` is the
+  /// top-level module handed to the tester.
+  void reduceOp(ModuleOp module, Region &region);
+
+  /// Reducer patterns gathered in `initialize`.
+  FrozenRewritePatternSet reducerPatterns;
+};
+
+} // end anonymous namespace
+
+LogicalResult ReductionTreePass::initialize(MLIRContext *context) {
+  // Ask every dialect that implements DialectReductionPatternInterface for its
+  // reduction patterns, then freeze them so runOnOperation can reuse them.
+  RewritePatternSet collected(context);
+  ReductionPatternInterfaceCollection dialectInterfaces(context);
+  dialectInterfaces.populateReductionPatterns(collected);
+  reducerPatterns = FrozenRewritePatternSet(std::move(collected));
+  return success();
+}
+
+void ReductionTreePass::runOnOperation() {
+  // The tester is run on the whole top-level module, which is expected to be
+  // a ModuleOp, so climb to the outermost operation first.
+  Operation *outermost = getOperation();
+  for (Operation *parent = outermost->getParentOp(); parent;
+       parent = parent->getParentOp())
+    outermost = parent;
+  ModuleOp module = cast<ModuleOp>(outermost);
+
+  // Depth-first over a stack: reduce each operation's regions first, then
+  // queue the nested operations that carry regions of their own.
+  SmallVector<Operation *, 8> pending;
+  pending.push_back(getOperation());
+
+  while (!pending.empty()) {
+    Operation *current = pending.pop_back_val();
+
+    for (Region &region : current->getRegions())
+      if (!region.empty())
+        reduceOp(module, region);
+
+    for (Region &region : current->getRegions())
+      for (Operation &nested : region.getOps())
+        if (nested.getNumRegions() != 0)
+          pending.push_back(&nested);
+  }
+}
+
+void ReductionTreePass::reduceOp(ModuleOp module, Region &region) {
+  // Dispatch on the requested traversal mode; only single-path traversal is
+  // supported at the moment.
+  switch (traversalModeId) {
+  case TraversalMode::SinglePath: {
+    using SinglePathIterator =
+        ReductionNode::iterator<TraversalMode::SinglePath>;
+    Tester tester(testerName, testerArgs);
+    findOptimal<SinglePathIterator>(module, region, reducerPatterns, tester);
+    return;
+  }
+  default:
+    llvm_unreachable("Unsupported mode");
+  }
+}
+
+/// Creates a new instance of the reduction tree pass.
+std::unique_ptr<Pass> mlir::createReductionTreePass() {
+  std::unique_ptr<Pass> pass = std::make_unique<ReductionTreePass>();
+  return pass;
+}

diff  --git a/mlir/lib/Reducer/Tester.cpp b/mlir/lib/Reducer/Tester.cpp
index c0d4862481016..b5531a9a0ca02 100644
--- a/mlir/lib/Reducer/Tester.cpp
+++ b/mlir/lib/Reducer/Tester.cpp
@@ -15,7 +15,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Reducer/Tester.h"
-
+#include "mlir/IR/Verifier.h"
 #include "llvm/Support/ToolOutputFile.h"
 
 using namespace mlir;
@@ -25,6 +25,12 @@ Tester::Tester(StringRef scriptName, ArrayRef<std::string> scriptArgs)
 
 std::pair<Tester::Interestingness, size_t>
 Tester::isInteresting(ModuleOp module) const {
+  // The reduced module should always be vaild, or we may end up retaining the
+  // error message by an invalid case. Besides, an invalid module may not be
+  // able to print properly.
+  if (failed(verify(module)))
+    return std::make_pair(Interestingness::False, /*size=*/0);
+
   SmallString<128> filepath;
   int fd;
 
@@ -50,7 +56,6 @@ Tester::isInteresting(ModuleOp module) const {
 /// true if the interesting behavior is present in the test case or false
 /// otherwise.
 Tester::Interestingness Tester::isInteresting(StringRef testCase) const {
-
   std::vector<StringRef> testerArgs;
   testerArgs.push_back(testCase);
 

diff  --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt
index 32870af2a4fe9..d1cf46ae5788b 100644
--- a/mlir/test/lib/Dialect/Test/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt
@@ -60,6 +60,7 @@ add_mlir_library(MLIRTestDialect
   MLIRInferTypeOpInterface
   MLIRLinalgTransforms
   MLIRPass
+  MLIRReduce
   MLIRStandard
   MLIRStandardOpsTransforms
   MLIRTransformUtils

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 396f71eac802f..77aa20956bf51 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -8,6 +8,7 @@
 
 #include "TestDialect.h"
 #include "TestAttributes.h"
+#include "TestInterfaces.h"
 #include "TestTypes.h"
 #include "mlir/Dialect/DLTI/DLTI.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -16,6 +17,7 @@
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/Reducer/ReductionPatternInterface.h"
 #include "mlir/Transforms/FoldUtils.h"
 #include "mlir/Transforms/InliningUtils.h"
 #include "llvm/ADT/StringSwitch.h"
@@ -170,6 +172,18 @@ struct TestInlinerInterface : public DialectInlinerInterface {
     return builder.create<TestCastOp>(conversionLoc, resultType, input);
   }
 };
+
+/// Exposes the test dialect's reduction patterns through
+/// DialectReductionPatternInterface, so that the reduction tree pass can
+/// collect them.
+struct TestReductionPatternInterface : public DialectReductionPatternInterface {
+  explicit TestReductionPatternInterface(Dialect *dialect)
+      : DialectReductionPatternInterface(dialect) {}
+
+  /// Adds the test dialect's reduction patterns to `patterns`.
+  void populateReductionPatterns(RewritePatternSet &patterns) const final {
+    populateTestReductionPatterns(patterns);
+  }
+};
+
 } // end anonymous namespace
 
 //===----------------------------------------------------------------------===//
@@ -207,7 +221,7 @@ void TestDialect::initialize() {
 #include "TestOps.cpp.inc"
       >();
   addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
-                TestInlinerInterface>();
+                TestInlinerInterface, TestReductionPatternInterface>();
   allowUnknownOperations();
 
   // Instantiate our fallback op interface that we'll use on specific

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h
index 79cc8dd789267..d57a2c119723a 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.h
+++ b/mlir/test/lib/Dialect/Test/TestDialect.h
@@ -34,6 +34,7 @@
 
 namespace mlir {
 class DLTIDialect;
+class RewritePatternSet;
 } // namespace mlir
 
 #include "TestOpEnums.h.inc"
@@ -47,6 +48,7 @@ class DLTIDialect;
 namespace mlir {
 namespace test {
 void registerTestDialect(DialectRegistry &registry);
+void populateTestReductionPatterns(RewritePatternSet &patterns);
 } // namespace test
 } // namespace mlir
 

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 8c3e45600b9f8..847436ea3fb09 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2113,4 +2113,19 @@ def DataLayoutQueryOp : TEST_Op<"data_layout_query"> {
   let results = (outs AnyType:$res);
 }
 
+//===----------------------------------------------------------------------===//
+// Test Reducer Patterns
+//===----------------------------------------------------------------------===//
+
+// Ops used to exercise reducer patterns in mlir-reduce. Both op names contain
+// "op_crash", which MLIRTestReducer's check (`opName.contains("op_crash")`)
+// treats as the failure being preserved during reduction.
+
+// Long form: three i32 operands and one i32 result.
+def OpCrashLong : TEST_Op<"op_crash_long"> {
+  let arguments = (ins I32, I32, I32);
+  let results = (outs I32);
+}
+
+// Short form: no operands and one i32 result.
+def OpCrashShort : TEST_Op<"op_crash_short"> {
+  let results = (outs I32);
+}
+
+// Reducer pattern: rewrite the long form into the operand-free short form,
+// discarding all three operands.
+def : Pat<(OpCrashLong $_, $_, $_), (OpCrashShort)>;
+
 #endif // TEST_OPS

diff  --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 32f47f40ad36a..eabf9e0110eb7 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -58,6 +58,14 @@ namespace {
 #include "TestPatterns.inc"
 } // end anonymous namespace
 
+//===----------------------------------------------------------------------===//
+// Test Reduce Pattern Interface
+//===----------------------------------------------------------------------===//
+
+/// Populates `patterns` with the test dialect's reduction patterns.
+// NOTE(review): populateWithGenerated presumably registers every pattern
+// generated into TestPatterns.inc, not only the op_crash_long ->
+// op_crash_short rewrite; confirm that is intended for the reducer.
+void mlir::test::populateTestReductionPatterns(RewritePatternSet &patterns) {
+  populateWithGenerated(patterns);
+}
+
 //===----------------------------------------------------------------------===//
 // Canonicalizer Driver.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/lib/Reducer/MLIRTestReducer.cpp b/mlir/test/lib/Reducer/MLIRTestReducer.cpp
index cb9d219349095..fe16390129e94 100644
--- a/mlir/test/lib/Reducer/MLIRTestReducer.cpp
+++ b/mlir/test/lib/Reducer/MLIRTestReducer.cpp
@@ -38,7 +38,7 @@ void TestReducer::runOnFunction() {
     op.walk([&](Operation *op) {
       StringRef opName = op->getName().getStringRef();
 
-      if (opName == "test.crashOp") {
+      if (opName.contains("op_crash")) {
         llvm::errs() << "MLIR Reducer Test generated failure: Found "
                         "\"crashOp\" operation\n";
         exit(1);

diff  --git a/mlir/test/mlir-reduce/crashop-reduction.mlir b/mlir/test/mlir-reduce/crashop-reduction.mlir
new file mode 100644
index 0000000000000..dd65b0fbc0854
--- /dev/null
+++ b/mlir/test/mlir-reduce/crashop-reduction.mlir
@@ -0,0 +1,20 @@
+// UNSUPPORTED: system-windows
+// RUN: mlir-reduce %s -reduction-tree='traversal-mode=0 test=%S/failure-test.sh' | FileCheck %s
+// "test.op_crash_long" should be replaced with a shorter form "test.op_crash_short".
+
+// CHECK-NOT: func @simple1() {
+func @simple1() {
+  return
+}
+
+// CHECK-LABEL: func @simple2(%arg0: i32, %arg1: i32, %arg2: i32) {
+func @simple2(%arg0: i32, %arg1: i32, %arg2: i32) {
+  // CHECK: %0 = "test.op_crash_short"() : () -> i32
+  %0 = "test.op_crash_long" (%arg0, %arg1, %arg2) : (i32, i32, i32) -> i32
+  return
+}
+
+// CHECK-NOT: func @simple5() {
+func @simple5() {
+  return
+}

diff  --git a/mlir/test/mlir-reduce/dce-test.mlir b/mlir/test/mlir-reduce/dce-test.mlir
index 7faa1505b2aaf..c0a93c8eb6ed6 100644
--- a/mlir/test/mlir-reduce/dce-test.mlir
+++ b/mlir/test/mlir-reduce/dce-test.mlir
@@ -12,6 +12,6 @@ func nested @dead_nested_function()
 
 // CHECK-LABEL: func @simple1(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
 func @simple1(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
-  "test.crashOp" () : () -> ()
+  "test.op_crash" () : () -> ()
   return
 }

diff  --git a/mlir/test/mlir-reduce/multiple-function.mlir b/mlir/test/mlir-reduce/multiple-function.mlir
index 06e51a1a8f468..22a444040fabe 100644
--- a/mlir/test/mlir-reduce/multiple-function.mlir
+++ b/mlir/test/mlir-reduce/multiple-function.mlir
@@ -1,5 +1,5 @@
 // UNSUPPORTED: system-windows
-// RUN: mlir-reduce %s -reduction-tree='op-reducer=func traversal-mode=0 test=%S/failure-test.sh' | FileCheck %s
+// RUN: mlir-reduce %s -reduction-tree='traversal-mode=0 test=%S/failure-test.sh' | FileCheck %s
 // This input should be reduced by the pass pipeline so that only
 // the @simple5 function remains as this is the shortest function
 // containing the interesting behavior.
@@ -16,7 +16,7 @@ func @simple2() {
 
 // CHECK-LABEL: func @simple3() {
 func @simple3() {
-  "test.crashOp" () : () -> ()
+  "test.op_crash" () : () -> ()
   return
 }
 
@@ -29,7 +29,7 @@ func @simple4(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
   %0 = memref.alloc() : memref<2xf32>
   br ^bb3(%0 : memref<2xf32>)
 ^bb3(%1: memref<2xf32>):
-  "test.crashOp"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
+  "test.op_crash"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
   return
 }
 

diff  --git a/mlir/test/mlir-reduce/simple-test.mlir b/mlir/test/mlir-reduce/simple-test.mlir
index 663003009209a..fd90da3e08392 100644
--- a/mlir/test/mlir-reduce/simple-test.mlir
+++ b/mlir/test/mlir-reduce/simple-test.mlir
@@ -1,5 +1,5 @@
 // UNSUPPORTED: system-windows
-// RUN: mlir-reduce %s -reduction-tree='op-reducer=func traversal-mode=0 test=%S/test.sh'
+// RUN: mlir-reduce %s -reduction-tree='traversal-mode=0 test=%S/test.sh'
 
 func @simple1(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
   cond_br %arg0, ^bb1, ^bb2

diff  --git a/mlir/test/mlir-reduce/single-function.mlir b/mlir/test/mlir-reduce/single-function.mlir
index 732963553e900..adb2376745b41 100644
--- a/mlir/test/mlir-reduce/single-function.mlir
+++ b/mlir/test/mlir-reduce/single-function.mlir
@@ -2,6 +2,6 @@
 // RUN: not mlir-opt %s -test-mlir-reducer -pass-test function-reducer
 
 func @test() {
-  "test.crashOp"() : () -> ()
+  "test.op_crash"() : () -> ()
   return
 }

diff  --git a/mlir/tools/mlir-reduce/CMakeLists.txt b/mlir/tools/mlir-reduce/CMakeLists.txt
index d58fe915a2864..b59bfcb441123 100644
--- a/mlir/tools/mlir-reduce/CMakeLists.txt
+++ b/mlir/tools/mlir-reduce/CMakeLists.txt
@@ -43,9 +43,6 @@ set(LIBS
   )
 
 add_llvm_tool(mlir-reduce
-  OptReductionPass.cpp
-  ReductionNode.cpp
-  ReductionTreePass.cpp
   mlir-reduce.cpp
 
   ADDITIONAL_HEADER_DIRS

diff  --git a/mlir/tools/mlir-reduce/ReductionTreePass.cpp b/mlir/tools/mlir-reduce/ReductionTreePass.cpp
deleted file mode 100644
index e3a7565bd6a58..0000000000000
--- a/mlir/tools/mlir-reduce/ReductionTreePass.cpp
+++ /dev/null
@@ -1,107 +0,0 @@
-//===- ReductionTreePass.cpp - ReductionTreePass Implementation -----------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file defines the Reduction Tree Pass class. It provides a framework for
-// the implementation of different reduction passes in the MLIR Reduce tool. It
-// allows for custom specification of the variant generation behavior. It
-// implements methods that define the different possible traversals of the
-// reduction tree.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Reducer/ReductionTreePass.h"
-#include "mlir/Reducer/Passes.h"
-
-#include "llvm/Support/Allocator.h"
-
-using namespace mlir;
-
-static std::unique_ptr<OpReducer> getOpReducer(llvm::StringRef opType) {
-  if (opType == ModuleOp::getOperationName())
-    return std::make_unique<Reducer<ModuleOp>>();
-  else if (opType == FuncOp::getOperationName())
-    return std::make_unique<Reducer<FuncOp>>();
-  llvm_unreachable("Now only supports two built-in ops");
-}
-
-void ReductionTreePass::runOnOperation() {
-  ModuleOp module = this->getOperation();
-  std::unique_ptr<OpReducer> reducer = getOpReducer(opReducerName);
-  std::vector<std::pair<int, int>> ranges = {
-      {0, reducer->getNumTargetOps(module)}};
-
-  llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
-
-  ReductionNode *root = allocator.Allocate();
-  new (root) ReductionNode(nullptr, ranges, allocator);
-
-  ModuleOp golden = module;
-  switch (traversalModeId) {
-  case TraversalMode::SinglePath:
-    golden = findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>(
-        module, std::move(reducer), root);
-    break;
-  default:
-    llvm_unreachable("Unsupported mode");
-  }
-
-  if (golden != module) {
-    module.getBody()->clear();
-    module.getBody()->getOperations().splice(module.getBody()->begin(),
-                                             golden.getBody()->getOperations());
-    golden->destroy();
-  }
-}
-
-template <typename IteratorType>
-ModuleOp ReductionTreePass::findOptimal(ModuleOp module,
-                                        std::unique_ptr<OpReducer> reducer,
-                                        ReductionNode *root) {
-  Tester test(testerName, testerArgs);
-  std::pair<Tester::Interestingness, size_t> initStatus =
-      test.isInteresting(module);
-
-  if (initStatus.first != Tester::Interestingness::True) {
-    LLVM_DEBUG(llvm::dbgs() << "\nThe original input is not interested");
-    return module;
-  }
-
-  root->update(initStatus);
-
-  ReductionNode *smallestNode = root;
-  ModuleOp golden = module;
-
-  IteratorType iter(root);
-
-  while (iter != IteratorType::end()) {
-    ModuleOp cloneModule = module.clone();
-
-    ReductionNode &currentNode = *iter;
-    reducer->reduce(cloneModule, currentNode.getRanges());
-
-    std::pair<Tester::Interestingness, size_t> result =
-        test.isInteresting(cloneModule);
-    currentNode.update(result);
-
-    if (result.first == Tester::Interestingness::True &&
-        result.second < smallestNode->getSize()) {
-      smallestNode = ¤tNode;
-      golden = cloneModule;
-    } else {
-      cloneModule->destroy();
-    }
-
-    ++iter;
-  }
-
-  return golden;
-}
-
-std::unique_ptr<Pass> mlir::createReductionTreePass() {
-  return std::make_unique<ReductionTreePass>();
-}

diff  --git a/mlir/tools/mlir-reduce/mlir-reduce.cpp b/mlir/tools/mlir-reduce/mlir-reduce.cpp
index b23177d42fce4..1801ce1ffb58a 100644
--- a/mlir/tools/mlir-reduce/mlir-reduce.cpp
+++ b/mlir/tools/mlir-reduce/mlir-reduce.cpp
@@ -13,22 +13,14 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include <vector>
-
 #include "mlir/InitAllDialects.h"
 #include "mlir/InitAllPasses.h"
 #include "mlir/Parser.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
-#include "mlir/Reducer/OptReductionPass.h"
 #include "mlir/Reducer/Passes.h"
-#include "mlir/Reducer/Passes/OpReducer.h"
-#include "mlir/Reducer/ReductionNode.h"
-#include "mlir/Reducer/ReductionTreePass.h"
-#include "mlir/Reducer/Tester.h"
 #include "mlir/Support/FileUtilities.h"
 #include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/Passes.h"
 #include "llvm/Support/InitLLVM.h"
 #include "llvm/Support/ToolOutputFile.h"
 


        


More information about the Mlir-commits mailing list