[Mlir-commits] [mlir] 6b0cef3 - Refactor the architecture of mlir-reduce

Jacques Pienaar llvmlistbot at llvm.org
Wed Apr 14 13:41:42 PDT 2021


Author: Chia-hung Duan
Date: 2021-04-14T13:40:44-07:00
New Revision: 6b0cef3e02eef277c60c95e4c1ca71f71091d8ae

URL: https://github.com/llvm/llvm-project/commit/6b0cef3e02eef277c60c95e4c1ca71f71091d8ae
DIFF: https://github.com/llvm/llvm-project/commit/6b0cef3e02eef277c60c95e4c1ca71f71091d8ae.diff

LOG: Refactor the architecture of mlir-reduce

Add iterator for ReductionNode traversal and use range to indicate the
region we would like to keep. Refactor the interaction between
Pass/Tester/ReductionNode.
Now it'll be easier to add new traversal type and OpReducer

Reviewed By: jpienaar, rriddle

Differential Revision: https://reviews.llvm.org/D99713

Added: 
    mlir/tools/mlir-reduce/ReductionTreePass.cpp

Modified: 
    mlir/include/mlir/Reducer/Passes/OpReducer.h
    mlir/include/mlir/Reducer/ReductionNode.h
    mlir/include/mlir/Reducer/ReductionTreePass.h
    mlir/include/mlir/Reducer/Tester.h
    mlir/lib/Reducer/CMakeLists.txt
    mlir/lib/Reducer/Tester.cpp
    mlir/tools/mlir-reduce/CMakeLists.txt
    mlir/tools/mlir-reduce/OptReductionPass.cpp
    mlir/tools/mlir-reduce/ReductionNode.cpp
    mlir/tools/mlir-reduce/mlir-reduce.cpp

Removed: 
    mlir/include/mlir/Reducer/ReductionTreeUtils.h
    mlir/tools/mlir-reduce/Passes/OpReducer.cpp
    mlir/tools/mlir-reduce/ReductionTreeUtils.cpp


################################################################################
diff  --git a/mlir/include/mlir/Reducer/Passes/OpReducer.h b/mlir/include/mlir/Reducer/Passes/OpReducer.h
index b16b5c0d763a9..6e48d41187356 100644
--- a/mlir/include/mlir/Reducer/Passes/OpReducer.h
+++ b/mlir/include/mlir/Reducer/Passes/OpReducer.h
@@ -15,65 +15,52 @@
 #ifndef MLIR_REDUCER_PASSES_OPREDUCER_H
 #define MLIR_REDUCER_PASSES_OPREDUCER_H
 
-#include "mlir/IR/Region.h"
+#include <limits>
+
 #include "mlir/Reducer/ReductionNode.h"
-#include "mlir/Reducer/ReductionTreeUtils.h"
 #include "mlir/Reducer/Tester.h"
 
 namespace mlir {
 
-class OpReducerImpl {
+class OpReducer {
 public:
-  OpReducerImpl(
-      llvm::function_ref<std::vector<Operation *>(ModuleOp)> getSpecificOps);
-
-  /// Return the name of this reducer class.
-  StringRef getName();
-
-  /// Return the initial transformSpace containing the transformable indices.
-  std::vector<bool> initTransformSpace(ModuleOp module);
-
-  /// Generate variants by removing OpType operations from the module in the
-  /// parent and link the variants as childs in the Reduction Tree Pass.
-  void generateVariants(ReductionNode *parent, const Tester &test,
-                        int numVariants);
-
-  /// Generate variants by removing OpType operations from the module in the
-  /// parent and link the variants as childs in the Reduction Tree Pass. The
-  /// transform argument defines the function used to remove the OpTpye
-  /// operations in range of indexed OpType operations.
-  void generateVariants(ReductionNode *parent, const Tester &test,
-                        int numVariants,
-                        llvm::function_ref<void(ModuleOp, int, int)> transform);
-
-private:
-  llvm::function_ref<std::vector<Operation *>(ModuleOp)> getSpecificOps;
+  virtual ~OpReducer() = default;
+  /// According to rangeToKeep, try to reduce the given module. We implicitly
+  /// number each interesting operation and rangeToKeep indicates that if an
+  /// operation's number falls into certain range, then we will not try to
+  /// reduce that operation.
+  virtual void reduce(ModuleOp module,
+                      ArrayRef<ReductionNode::Range> rangeToKeep) = 0;
+  /// Return the number of certain kind of operations that we would like to
+  /// reduce. This can be used to build a range map to exclude uninterested
+  /// operations.
+  virtual int getNumTargetOps(ModuleOp module) const = 0;
 };
 
-/// The OpReducer class defines a variant generator method that produces
-/// multiple variants by eliminating different OpType operations from the
-/// parent module.
+/// Reducer is a helper class to remove potential uninteresting operations from
+/// module.
 template <typename OpType>
-class OpReducer {
+class Reducer : public OpReducer {
 public:
-  OpReducer() : impl(new OpReducerImpl(getSpecificOps)) {}
+  ~Reducer() override = default;
 
-  /// Returns the vector of pointer to the OpType operations in the module.
-  static std::vector<Operation *> getSpecificOps(ModuleOp module) {
-    std::vector<Operation *> ops;
-    for (auto op : module.getOps<OpType>()) {
-      ops.push_back(op);
-    }
-    return ops;
+  int getNumTargetOps(ModuleOp module) const override {
+    return std::distance(module.getOps<OpType>().begin(),
+                         module.getOps<OpType>().end());
   }
 
-  /// Deletes the OpType operations in the module in the specified index.
-  static void deleteOps(ModuleOp module, int start, int end) {
+  void reduce(ModuleOp module,
+              ArrayRef<ReductionNode::Range> rangeToKeep) override {
     std::vector<Operation *> opsToRemove;
+    size_t keepIndex = 0;
 
-    for (auto op : enumerate(getSpecificOps(module))) {
+    for (auto op : enumerate(module.getOps<OpType>())) {
       int index = op.index();
-      if (index >= start && index < end)
+      if (keepIndex < rangeToKeep.size() &&
+          index == rangeToKeep[keepIndex].second)
+        ++keepIndex;
+      if (keepIndex == rangeToKeep.size() ||
+          index < rangeToKeep[keepIndex].first)
         opsToRemove.push_back(op.value());
     }
 
@@ -82,24 +69,6 @@ class OpReducer {
       o->erase();
     }
   }
-
-  /// Return the name of this reducer class.
-  StringRef getName() { return impl->getName(); }
-
-  /// Return the initial transformSpace containing the transformable indices.
-  std::vector<bool> initTransformSpace(ModuleOp module) {
-    return impl->initTransformSpace(module);
-  }
-
-  /// Generate variants by removing OpType operations from the module in the
-  /// parent and link the variants as childs in the Reduction Tree Pass.
-  void generateVariants(ReductionNode *parent, const Tester &test,
-                        int numVariants) {
-    impl->generateVariants(parent, test, numVariants, deleteOps);
-  }
-
-private:
-  std::unique_ptr<OpReducerImpl> impl;
 };
 
 } // end namespace mlir

diff  --git a/mlir/include/mlir/Reducer/ReductionNode.h b/mlir/include/mlir/Reducer/ReductionNode.h
index 79758b0330545..6364157b9040b 100644
--- a/mlir/include/mlir/Reducer/ReductionNode.h
+++ b/mlir/include/mlir/Reducer/ReductionNode.h
@@ -17,82 +17,129 @@
 #ifndef MLIR_REDUCER_REDUCTIONNODE_H
 #define MLIR_REDUCER_REDUCTIONNODE_H
 
+#include <queue>
 #include <vector>
 
 #include "mlir/Reducer/Tester.h"
+#include "llvm/Support/Allocator.h"
 #include "llvm/Support/ToolOutputFile.h"
 
 namespace mlir {
 
-/// This class defines the ReductionNode which is used to wrap the module of
-/// a generated variant and keep track of the necessary metadata for the
-/// reduction pass. The nodes are linked together in a reduction tree structure
-/// which defines the relationship between all the different generated variants.
+/// Defines the traversal method options to be used in the reduction tree
+/// traversal.
+enum TraversalMode { SinglePath, Backtrack, MultiPath };
+
+/// This class defines the ReductionNode which is used to generate variant and
+/// keep track of the necessary metadata for the reduction pass. The nodes are
+/// linked together in a reduction tree structure which defines the relationship
+/// between all the different generated variants.
 class ReductionNode {
 public:
-  ReductionNode(ModuleOp module, ReductionNode *parent);
-
-  ReductionNode(ModuleOp module, ReductionNode *parent,
-                std::vector<bool> transformSpace);
+  template <TraversalMode mode>
+  class iterator;
 
-  /// Calculates and initializes the size and interesting values of the node.
-  void measureAndTest(const Tester &test);
+  using Range = std::pair<int, int>;
 
-  /// Returns the module.
-  ModuleOp getModule() const { return module; }
+  ReductionNode(ReductionNode *parent, std::vector<Range> range,
+                llvm::SpecificBumpPtrAllocator<ReductionNode> &allocator);
 
-  /// Returns true if the size and interestingness have been calculated.
-  bool isEvaluated() const;
+  ReductionNode *getParent() const;
 
-  /// Returns the size in bytes of the module.
-  int getSize() const;
+  size_t getSize() const;
 
   /// Returns true if the module exhibits the interesting behavior.
-  bool isInteresting() const;
-
-  /// Returns the pointer to a child variant by index.
-  ReductionNode *getVariant(unsigned long index) const;
+  Tester::Interestingness isInteresting() const;
 
-  /// Returns the number of child variants.
-  int variantsSize() const;
+  std::vector<Range> getRanges() const;
 
-  /// Returns true if the vector containing the child variants is empty.
-  bool variantsEmpty() const;
+  std::vector<ReductionNode *> &getVariants();
 
-  /// Sort the child variants and remove the uninteresting ones.
-  void organizeVariants(const Tester &test);
+  /// Split the ranges and generate new variants.
+  std::vector<ReductionNode *> generateNewVariants();
 
-  /// Returns the number of child variants.
-  int transformSpaceSize();
-
-  /// Returns a vector indicating the transformed indices as true.
-  const std::vector<bool> getTransformSpace();
+  /// Update the interestingness result from tester.
+  void update(std::pair<Tester::Interestingness, size_t> result);
 
 private:
-  /// Link a child variant node.
-  void linkVariant(ReductionNode *newVariant);
-
-  // This is the MLIR module of this variant.
-  ModuleOp module;
-
-  // This is true if the module has been evaluated and it exhibits the
-  // interesting behavior.
-  bool interesting;
-
-  // This indicates the number of characters in the printed module if the module
-  // has been evaluated.
-  int size;
-
-  // This indicates if the module has been evaluated (measured and tested).
-  bool evaluated;
-
-  // Indicates the indices in the node that have been transformed in previous
-  // levels of the reduction tree.
-  std::vector<bool> transformSpace;
+  /// A custom BFS iterator. The difference between
+  /// llvm/ADT/BreadthFirstIterator.h is the graph we're exploring is dynamic.
+  /// We may explore more neighbors at certain node if we didn't find interested
+  /// event. As a result, we defer pushing adjacent nodes until poping the last
+  /// visited node. The graph exploration strategy will be put in
+  /// getNeighbors().
+  ///
+  /// Subclass BaseIterator and implement traversal strategy in getNeighbors().
+  template <typename T>
+  class BaseIterator {
+  public:
+    BaseIterator(ReductionNode *node) { visitQueue.push(node); }
+    BaseIterator(const BaseIterator &) = default;
+    BaseIterator() = default;
+
+    static BaseIterator end() { return BaseIterator(); }
+
+    bool operator==(const BaseIterator &i) {
+      return visitQueue == i.visitQueue;
+    }
+    bool operator!=(const BaseIterator &i) { return !(*this == i); }
+
+    BaseIterator &operator++() {
+      ReductionNode *top = visitQueue.front();
+      visitQueue.pop();
+      std::vector<ReductionNode *> neighbors = getNeighbors(top);
+      for (ReductionNode *node : neighbors)
+        visitQueue.push(node);
+      return *this;
+    }
+
+    BaseIterator operator++(int) {
+      BaseIterator tmp = *this;
+      ++*this;
+      return tmp;
+    }
+
+    ReductionNode &operator*() const { return *(visitQueue.front()); }
+    ReductionNode *operator->() const { return visitQueue.front(); }
+
+  protected:
+    std::vector<ReductionNode *> getNeighbors(ReductionNode *node) {
+      return static_cast<T *>(this)->getNeighbors(node);
+    }
+
+  private:
+    std::queue<ReductionNode *> visitQueue;
+  };
+
+  /// The size of module after applying the range constraints.
+  size_t size;
+
+  /// This is true if the module has been evaluated and it exhibits the
+  /// interesting behavior.
+  Tester::Interestingness interesting;
+
+  ReductionNode *parent;
+
+  /// We will only keep the operation with index falls into the ranges.
+  /// For example, number each function in a certain module and then we will
+  /// remove the functions with index outside the ranges and see if the
+  /// resulting module is still interesting.
+  std::vector<Range> ranges;
+
+  /// This points to the child variants that were created using this node as a
+  /// starting point.
+  std::vector<ReductionNode *> variants;
+
+  llvm::SpecificBumpPtrAllocator<ReductionNode> &allocator;
+};
 
-  // This points to the child variants that were created using this node as a
-  // starting point.
-  std::vector<std::unique_ptr<ReductionNode>> variants;
+// Specialized iterator for SinglePath traversal
+template <>
+class ReductionNode::iterator<SinglePath>
+    : public BaseIterator<iterator<SinglePath>> {
+  friend BaseIterator<iterator<SinglePath>>;
+  using BaseIterator::BaseIterator;
+  std::vector<ReductionNode *> getNeighbors(ReductionNode *node);
 };
 
 } // end namespace mlir

diff  --git a/mlir/include/mlir/Reducer/ReductionTreePass.h b/mlir/include/mlir/Reducer/ReductionTreePass.h
index be13191dbd384..8f1269e4ae29b 100644
--- a/mlir/include/mlir/Reducer/ReductionTreePass.h
+++ b/mlir/include/mlir/Reducer/ReductionTreePass.h
@@ -22,131 +22,40 @@
 #include "PassDetail.h"
 #include "ReductionNode.h"
 #include "mlir/Reducer/Passes/OpReducer.h"
-#include "mlir/Reducer/ReductionTreeUtils.h"
 #include "mlir/Reducer/Tester.h"
 
 #define DEBUG_TYPE "mlir-reduce"
 
 namespace mlir {
 
-// Defines the traversal method options to be used in the reduction tree
-/// traversal.
-enum TraversalMode { SinglePath, Backtrack, MultiPath };
-
 /// This class defines the Reduction Tree Pass. It provides a framework to
 /// to implement a reduction pass using a tree structure to keep track of the
 /// generated reduced variants.
-template <typename Reducer, TraversalMode mode>
-class ReductionTreePass
-    : public ReductionTreeBase<ReductionTreePass<Reducer, mode>> {
+class ReductionTreePass : public ReductionTreeBase<ReductionTreePass> {
 public:
   ReductionTreePass(const ReductionTreePass &pass)
-      : ReductionTreeBase<ReductionTreePass<Reducer, mode>>(pass),
-        root(new ReductionNode(pass.root->getModule().clone(), nullptr)),
-        test(pass.test) {}
+      : ReductionTreeBase<ReductionTreePass>(pass), opType(pass.opType),
+        mode(pass.mode), test(pass.test) {}
 
-  ReductionTreePass(const Tester &test) : test(test) {}
+  ReductionTreePass(StringRef opType, TraversalMode mode, const Tester &test)
+      : opType(opType), mode(mode), test(test) {}
 
   /// Runs the pass instance in the pass pipeline.
-  void runOnOperation() override {
-    ModuleOp module = this->getOperation();
-    Reducer reducer;
-    std::vector<bool> transformSpace = reducer.initTransformSpace(module);
-    ReductionNode *reduced;
-
-    this->root =
-        std::make_unique<ReductionNode>(module, nullptr, transformSpace);
-
-    root->measureAndTest(test);
-
-    LLVM_DEBUG(llvm::dbgs() << "\nReduction Tree Pass: " << reducer.getName(););
-    switch (mode) {
-    case SinglePath:
-      LLVM_DEBUG(llvm::dbgs() << " (Single Path)\n";);
-      reduced = singlePathTraversal();
-      break;
-    default:
-      llvm::report_fatal_error("Traversal method not currently supported.");
-    }
-
-    ReductionTreeUtils::updateGoldenModule(module,
-                                           reduced->getModule().clone());
-  }
+  void runOnOperation() override;
 
 private:
-  // Points to the root node in this reduction tree.
-  std::unique_ptr<ReductionNode> root;
-
-  // This object defines the variant generation at each level of the reduction
-  // tree.
-  Reducer reducer;
-
-  // This is used to test the interesting behavior of the reduction nodes in the
-  // tree.
-  const Tester &test;
-
-  /// Traverse the most reduced path in the reduction tree by generating the
-  /// variants at each level using the Reducer parameter's generateVariants
-  /// function. Stops when no new successful variants can be created at the
-  /// current level.
-  ReductionNode *singlePathTraversal() {
-    ReductionNode *currNode = root.get();
-    ReductionNode *smallestNode = currNode;
-    int tSpaceSize = currNode->transformSpaceSize();
-    std::vector<int> path;
-
-    ReductionTreeUtils::updateSmallestNode(currNode, smallestNode, path);
-
-    LLVM_DEBUG(llvm::dbgs() << "\nGenerating 1 variant: applying the ");
-    LLVM_DEBUG(llvm::dbgs() << "transformation to the entire module\n");
+  template <typename IteratorType>
+  ModuleOp findOptimal(ModuleOp module, std::unique_ptr<OpReducer> reducer,
+                       ReductionNode *node);
 
-    reducer.generateVariants(currNode, test, 1);
-    LLVM_DEBUG(llvm::dbgs() << "Testing\n");
-    currNode->organizeVariants(test);
+  /// The name of operation that we will try to remove.
+  StringRef opType;
 
-    if (!currNode->variantsEmpty())
-      return currNode->getVariant(0);
+  TraversalMode mode;
 
-    while (tSpaceSize != 1) {
-      ReductionTreeUtils::updateSmallestNode(currNode, smallestNode, path);
-
-      LLVM_DEBUG(llvm::dbgs() << "\nGenerating 2 variants: applying the ");
-      LLVM_DEBUG(llvm::dbgs() << "transformation to two different sections ");
-      LLVM_DEBUG(llvm::dbgs() << "of transformable indices\n");
-
-      reducer.generateVariants(currNode, test, 2);
-      LLVM_DEBUG(llvm::dbgs() << "Testing\n");
-      currNode->organizeVariants(test);
-
-      if (currNode->variantsEmpty())
-        break;
-
-      currNode = currNode->getVariant(0);
-      tSpaceSize = currNode->transformSpaceSize();
-      path.push_back(0);
-    }
-
-    if (tSpaceSize == 1) {
-      ReductionTreeUtils::updateSmallestNode(currNode, smallestNode, path);
-
-      LLVM_DEBUG(llvm::dbgs() << "\nGenerating 1 variants: applying the ");
-      LLVM_DEBUG(llvm::dbgs() << "transformation to the only transformable");
-      LLVM_DEBUG(llvm::dbgs() << "index\n");
-
-      reducer.generateVariants(currNode, test, 1);
-      LLVM_DEBUG(llvm::dbgs() << "Testing\n");
-      currNode->organizeVariants(test);
-
-      if (!currNode->variantsEmpty()) {
-        currNode = currNode->getVariant(0);
-        path.push_back(0);
-
-        ReductionTreeUtils::updateSmallestNode(currNode, smallestNode, path);
-      }
-    }
-
-    return currNode;
-  }
+  /// This is used to test the interesting behavior of the reduction nodes in
+  /// the tree.
+  const Tester &test;
 };
 
 } // end namespace mlir

diff  --git a/mlir/include/mlir/Reducer/ReductionTreeUtils.h b/mlir/include/mlir/Reducer/ReductionTreeUtils.h
deleted file mode 100644
index cb938e2e4765a..0000000000000
--- a/mlir/include/mlir/Reducer/ReductionTreeUtils.h
+++ /dev/null
@@ -1,53 +0,0 @@
-//===- ReductionTreeUtils.h - Reduction Tree utilities ----------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file defines the Reduction Tree Utilities. It defines pass independent
-// methods that help in the reduction passes of the MLIR Reduce tool.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_REDUCER_REDUCTIONTREEUTILS_H
-#define MLIR_REDUCER_REDUCTIONTREEUTILS_H
-
-#include <tuple>
-
-#include "PassDetail.h"
-#include "ReductionNode.h"
-#include "mlir/Reducer/Tester.h"
-#include "llvm/Support/Debug.h"
-
-namespace mlir {
-
-// Defines the utilities for the implementation of custom reduction
-// passes using the ReductionTreePass framework.
-namespace ReductionTreeUtils {
-
-/// Update the golden module's content with that of the reduced module.
-void updateGoldenModule(ModuleOp &golden, ModuleOp reduced);
-
-/// Update the smallest node traversed so far in the reduction tree and
-/// print the debugging information for the currNode being traversed.
-void updateSmallestNode(ReductionNode *currNode, ReductionNode *&smallestNode,
-                        std::vector<int> path);
-
-/// Create a transform space index vector based on the specified number of
-/// indices.
-std::vector<bool> createTransformSpace(ModuleOp module, int numIndices);
-
-/// Create the specified number of variants by applying the transform method
-/// to different ranges of indices in the parent module. The isDeletion boolean
-/// specifies if the transformation is the deletion of indices.
-void createVariants(ReductionNode *parent, const Tester &test, int numVariants,
-                    llvm::function_ref<void(ModuleOp, int, int)> transform,
-                    bool isDeletion);
-
-} // namespace ReductionTreeUtils
-
-} // end namespace mlir
-
-#endif

diff  --git a/mlir/include/mlir/Reducer/Tester.h b/mlir/include/mlir/Reducer/Tester.h
index b54bb94e11b95..5969d63eaee2c 100644
--- a/mlir/include/mlir/Reducer/Tester.h
+++ b/mlir/include/mlir/Reducer/Tester.h
@@ -32,12 +32,21 @@ namespace mlir {
 /// case file.
 class Tester {
 public:
+  enum class Interestingness {
+    True,
+    False,
+    Untested,
+  };
+
   Tester(StringRef testScript, ArrayRef<std::string> testScriptArgs);
 
   /// Runs the interestingness testing script on a MLIR test case file. Returns
   /// true if the interesting behavior is present in the test case or false
   /// otherwise.
-  bool isInteresting(StringRef testCase) const;
+  std::pair<Interestingness, size_t> isInteresting(ModuleOp module) const;
+
+  /// Return whether the file in the given path is interesting.
+  Interestingness isInteresting(StringRef testCase) const;
 
 private:
   StringRef testScript;

diff  --git a/mlir/lib/Reducer/CMakeLists.txt b/mlir/lib/Reducer/CMakeLists.txt
index dd5fd277ce120..73601f031d574 100644
--- a/mlir/lib/Reducer/CMakeLists.txt
+++ b/mlir/lib/Reducer/CMakeLists.txt
@@ -1,7 +1,7 @@
 add_mlir_library(MLIRReduce
    Tester.cpp
-   DEPENDS
+   LINK_LIBS PUBLIC
    MLIRIR
- )
- 
- mlir_check_all_link_libraries(MLIRReduce)
\ No newline at end of file
+)
+
+mlir_check_all_link_libraries(MLIRReduce)

diff  --git a/mlir/lib/Reducer/Tester.cpp b/mlir/lib/Reducer/Tester.cpp
index 3ca0e93935865..c0d4862481016 100644
--- a/mlir/lib/Reducer/Tester.cpp
+++ b/mlir/lib/Reducer/Tester.cpp
@@ -16,15 +16,40 @@
 
 #include "mlir/Reducer/Tester.h"
 
+#include "llvm/Support/ToolOutputFile.h"
+
 using namespace mlir;
 
 Tester::Tester(StringRef scriptName, ArrayRef<std::string> scriptArgs)
     : testScript(scriptName), testScriptArgs(scriptArgs) {}
 
+std::pair<Tester::Interestingness, size_t>
+Tester::isInteresting(ModuleOp module) const {
+  SmallString<128> filepath;
+  int fd;
+
+  // Print module to temporary file.
+  std::error_code ec =
+      llvm::sys::fs::createTemporaryFile("mlir-reduce", "mlir", fd, filepath);
+
+  if (ec)
+    llvm::report_fatal_error("Error making unique filename: " + ec.message());
+
+  llvm::ToolOutputFile out(filepath, fd);
+  module.print(out.os());
+  out.os().close();
+
+  if (out.os().has_error())
+    llvm::report_fatal_error("Error emitting the IR to file '" + filepath);
+
+  size_t size = out.os().tell();
+  return std::make_pair(isInteresting(filepath), size);
+}
+
 /// Runs the interestingness testing script on a MLIR test case file. Returns
 /// true if the interesting behavior is present in the test case or false
 /// otherwise.
-bool Tester::isInteresting(StringRef testCase) const {
+Tester::Interestingness Tester::isInteresting(StringRef testCase) const {
 
   std::vector<StringRef> testerArgs;
   testerArgs.push_back(testCase);
@@ -44,7 +69,7 @@ bool Tester::isInteresting(StringRef testCase) const {
                              false);
 
   if (!result)
-    return false;
+    return Interestingness::False;
 
-  return true;
+  return Interestingness::True;
 }

diff  --git a/mlir/tools/mlir-reduce/CMakeLists.txt b/mlir/tools/mlir-reduce/CMakeLists.txt
index 958c2c94cc684..162306e1a72b4 100644
--- a/mlir/tools/mlir-reduce/CMakeLists.txt
+++ b/mlir/tools/mlir-reduce/CMakeLists.txt
@@ -45,9 +45,8 @@ set(LIBS
 
 add_llvm_tool(mlir-reduce
   OptReductionPass.cpp
-  Passes/OpReducer.cpp
   ReductionNode.cpp
-  ReductionTreeUtils.cpp
+  ReductionTreePass.cpp
   mlir-reduce.cpp
 
   ADDITIONAL_HEADER_DIRS

diff  --git a/mlir/tools/mlir-reduce/OptReductionPass.cpp b/mlir/tools/mlir-reduce/OptReductionPass.cpp
index 2ad55e948618d..97b9b3e3aa372 100644
--- a/mlir/tools/mlir-reduce/OptReductionPass.cpp
+++ b/mlir/tools/mlir-reduce/OptReductionPass.cpp
@@ -36,21 +36,25 @@ void OptReductionPass::runOnOperation() {
   PassManager pmTransform(context);
   pmTransform.addPass(std::move(optPass));
 
+  std::pair<Tester::Interestingness, int> original = test.isInteresting(module);
+
   if (failed(pmTransform.run(moduleVariant)))
     return;
 
-  ReductionNode original(module, nullptr);
-  original.measureAndTest(test);
-
-  ReductionNode reduced(moduleVariant, nullptr);
-  reduced.measureAndTest(test);
+  std::pair<Tester::Interestingness, int> reduced =
+      test.isInteresting(moduleVariant);
 
-  if (reduced.isInteresting() && reduced.getSize() < original.getSize()) {
-    ReductionTreeUtils::updateGoldenModule(module, reduced.getModule().clone());
+  if (reduced.first == Tester::Interestingness::True &&
+      reduced.second < original.second) {
+    module.getBody()->clear();
+    module.getBody()->getOperations().splice(
+        module.getBody()->begin(), moduleVariant.getBody()->getOperations());
     LLVM_DEBUG(llvm::dbgs() << "\nSuccessful Transformed version\n\n");
   } else {
     LLVM_DEBUG(llvm::dbgs() << "\nUnsuccessful Transformed version\n\n");
   }
 
+  moduleVariant->destroy();
+
   LLVM_DEBUG(llvm::dbgs() << "Pass Complete\n\n");
 }

diff  --git a/mlir/tools/mlir-reduce/Passes/OpReducer.cpp b/mlir/tools/mlir-reduce/Passes/OpReducer.cpp
deleted file mode 100644
index 8455b3831c421..0000000000000
--- a/mlir/tools/mlir-reduce/Passes/OpReducer.cpp
+++ /dev/null
@@ -1,41 +0,0 @@
-//===- OpReducer.cpp - Operation Reducer ------------------------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file defines the OpReducer class. It defines a variant generator method
-// with the purpose of producing different variants by eliminating a
-// parameterizable type of operations from the  parent module.
-//
-//===----------------------------------------------------------------------===//
-#include "mlir/Reducer/Passes/OpReducer.h"
-
-using namespace mlir;
-
-OpReducerImpl::OpReducerImpl(
-    llvm::function_ref<std::vector<Operation *>(ModuleOp)> getSpecificOps)
-    : getSpecificOps(getSpecificOps) {}
-
-/// Return the name of this reducer class.
-StringRef OpReducerImpl::getName() {
-  return StringRef("High Level Operation Reduction");
-}
-
-/// Return the initial transformSpace containing the transformable indices.
-std::vector<bool> OpReducerImpl::initTransformSpace(ModuleOp module) {
-  auto ops = getSpecificOps(module);
-  int numOps = std::distance(ops.begin(), ops.end());
-  return ReductionTreeUtils::createTransformSpace(module, numOps);
-}
-
-/// Generate variants by removing opType operations from the module in the
-/// parent and link the variants as childs in the Reduction Tree Pass.
-void OpReducerImpl::generateVariants(
-    ReductionNode *parent, const Tester &test, int numVariants,
-    llvm::function_ref<void(ModuleOp, int, int)> transform) {
-  ReductionTreeUtils::createVariants(parent, test, numVariants, transform,
-                                     true);
-}

diff  --git a/mlir/tools/mlir-reduce/ReductionNode.cpp b/mlir/tools/mlir-reduce/ReductionNode.cpp
index bd4ef51786ec7..a8e4af8c88223 100644
--- a/mlir/tools/mlir-reduce/ReductionNode.cpp
+++ b/mlir/tools/mlir-reduce/ReductionNode.cpp
@@ -15,116 +15,138 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Reducer/ReductionNode.h"
+#include "llvm/ADT/STLExtras.h"
+
+#include <algorithm>
+#include <limits>
+#include <utility>
 
 using namespace mlir;
 
-/// Sets up the metadata and links the node to its parent.
-ReductionNode::ReductionNode(ModuleOp module, ReductionNode *parent)
-    : module(module), evaluated(false) {
+/// Creates an untested node that keeps the target operations in `ranges`.
+/// Passing a null `parent` creates the root node, whose parent pointer refers
+/// to the node itself. Child variants are allocated from `allocator`.
+ReductionNode::ReductionNode(
+    ReductionNode *parent, std::vector<Range> ranges,
+    llvm::SpecificBumpPtrAllocator<ReductionNode> &allocator)
+    : size(std::numeric_limits<size_t>::max()),
+      interesting(Tester::Interestingness::Untested),
+      // The root node's parent pointer points to the node itself.
+      parent(parent == nullptr ? this : parent), ranges(std::move(ranges)),
+      allocator(allocator) {}
 
-  if (parent != nullptr)
-    parent->linkVariant(this);
-}
+/// Returns the size recorded by update(), or the maximum `size_t` value if
+/// this node has not been tested yet.
+size_t ReductionNode::getSize() const { return size; }
 
-ReductionNode::ReductionNode(ModuleOp module, ReductionNode *parent,
-                             std::vector<bool> transformSpace)
-    : module(module), evaluated(false), transformSpace(transformSpace) {
+/// Returns the parent node; the root node returns itself.
+ReductionNode *ReductionNode::getParent() const { return parent; }
 
-  if (parent != nullptr)
-    parent->linkVariant(this);
+/// Returns whether the module exhibits the interesting behavior, or
+/// `Untested` if this node has not been tested yet.
+Tester::Interestingness ReductionNode::isInteresting() const {
+  return interesting;
 }
 
-/// Calculates and updates the size and interesting values of the module.
-void ReductionNode::measureAndTest(const Tester &test) {
-  SmallString<128> filepath;
-  int fd;
-
-  // Print module to temporary file.
-  std::error_code ec =
-      llvm::sys::fs::createTemporaryFile("mlir-reduce", "mlir", fd, filepath);
-
-  if (ec)
-    llvm::report_fatal_error("Error making unique filename: " + ec.message());
-
-  llvm::ToolOutputFile out(filepath, fd);
-  module.print(out.os());
-  out.os().close();
-
-  if (out.os().has_error())
-    llvm::report_fatal_error("Error emitting bitcode to file '" + filepath);
-
-  size = out.os().tell();
-  interesting = test.isInteresting(filepath);
-  evaluated = true;
+/// Returns a copy of the ranges of target operations kept by this node.
+std::vector<ReductionNode::Range> ReductionNode::getRanges() const {
+  return ranges;
 }
 
-/// Returns true if the size and interestingness have been calculated.
-bool ReductionNode::isEvaluated() const { return evaluated; }
-
-/// Returns the size in bytes of the module.
-int ReductionNode::getSize() const { return size; }
+/// Returns the variants generated from this node so far.
+std::vector<ReductionNode *> &ReductionNode::getVariants() { return variants; }
+
+/// Generates new variants from this node, records them in `variants`, and
+/// returns them.
+///
+/// While no variant exists yet and `ranges` does not hold exactly one range,
+/// one variant is created per range, each dropping that range. Otherwise the
+/// longest range is split in half, one variant is created per half, and the
+/// split range is replaced in `ranges` by its two halves. Returns an empty
+/// vector when no range can be split any further.
+std::vector<ReductionNode *> ReductionNode::generateNewVariants() {
+  std::vector<ReductionNode *> newNodes;
+
+  // Children are placement-constructed in memory obtained from `allocator`,
+  // which is handed down to every child.
+  auto createNewNode = [this](const std::vector<Range> &subRanges) {
+    ReductionNode *newNode = allocator.Allocate();
+    new (newNode) ReductionNode(this, subRanges, allocator);
+    return newNode;
+  };
+
+  // If we haven't created any variant yet, create variants by removing each
+  // range in turn. For example, given {{1, 3}, {4, 9}}, we produce variants
+  // with ranges {{4, 9}} and {{1, 3}}.
+  if (variants.empty() && ranges.size() != 1) {
+    for (const Range &range : ranges) {
+      std::vector<Range> subRanges = ranges;
+      llvm::erase_value(subRanges, range);
+      ReductionNode *newNode = createNewNode(subRanges);
+      newNodes.push_back(newNode);
+      variants.push_back(newNode);
+    }
 
-/// Returns true if the module exhibits the interesting behavior.
-bool ReductionNode::isInteresting() const { return interesting; }
+    return newNodes;
+  }
 
-/// Returns the pointers to the child variants.
-ReductionNode *ReductionNode::getVariant(unsigned long index) const {
-  if (index < variants.size())
-    return variants[index].get();
+  // std::max_element returns end() for an empty vector; never dereference it.
+  if (ranges.empty())
+    return {};
+
+  // Split the longest range in two and create two new variants. Continuing
+  // the example above, {4, 9} is split into {4, 6} and {6, 9}, producing
+  // variants with ranges {{1, 3}, {4, 6}} and {{1, 3}, {6, 9}}; afterwards
+  // `ranges` is {{1, 3}, {4, 6}, {6, 9}}.
+  //
+  // std::max_element expects a "less than" comparator: comparing with `<`
+  // yields the first longest range, whereas `>` would pick the shortest one.
+  auto maxElement = std::max_element(
+      ranges.begin(), ranges.end(), [](const Range &lhs, const Range &rhs) {
+        return (lhs.second - lhs.first) < (rhs.second - rhs.first);
+      });
 
-  return nullptr;
+  // A range of length 1 (or 0) can't be split, so no new variant can be
+  // produced. Using `<= 1` also prevents an empty range from being "split"
+  // into two empty halves on every call.
+  if (maxElement->second - maxElement->first <= 1)
+    return {};
+
+  Range maxRange = *maxElement;
+  int half = (maxRange.first + maxRange.second) / 2;
+
+  // One variant per half; all other ranges are kept unchanged.
+  std::vector<Range> subRanges = ranges;
+  auto subRangesIter = subRanges.begin() + (maxElement - ranges.begin());
+  *subRangesIter = std::make_pair(maxRange.first, half);
+  newNodes.push_back(createNewNode(subRanges));
+  *subRangesIter = std::make_pair(half, maxRange.second);
+  newNodes.push_back(createNewNode(subRanges));
+  variants.insert(variants.end(), newNodes.begin(), newNodes.end());
+
+  // Replace the split range with its two halves, in place.
+  *maxElement = std::make_pair(maxRange.first, half);
+  ranges.insert(maxElement + 1, std::make_pair(half, maxRange.second));
+
+  return newNodes;
 }
 
-/// Returns the number of child variants.
-int ReductionNode::variantsSize() const { return variants.size(); }
-
-/// Returns true if the child variants vector is empty.
-bool ReductionNode::variantsEmpty() const { return variants.empty(); }
-
-/// Link a child variant node.
-void ReductionNode::linkVariant(ReductionNode *newVariant) {
-  std::unique_ptr<ReductionNode> ptrVariant(newVariant);
-  variants.push_back(std::move(ptrVariant));
+/// Records the outcome of testing this node: its interestingness and size.
+void ReductionNode::update(std::pair<Tester::Interestingness, size_t> result) {
+  interesting = result.first;
+  size = result.second;
 }
 
-/// Sort the child variants and remove the uninteresting ones.
-void ReductionNode::organizeVariants(const Tester &test) {
-  // Ensure all variants are evaluated.
-  for (auto &var : variants)
-    if (!var->isEvaluated())
-      var->measureAndTest(test);
-
-  // Sort variants by interestingness and size.
-  llvm::array_pod_sort(
-      variants.begin(), variants.end(), [](const auto *lhs, const auto *rhs) {
-        if (lhs->get()->isInteresting() && !rhs->get()->isInteresting())
-          return 0;
-
-        if (!lhs->get()->isInteresting() && rhs->get()->isInteresting())
-          return 1;
-
-        return (lhs->get()->getSize(), rhs->get()->getSize());
-      });
-
-  int interestingCount = 0;
-  for (auto &var : variants) {
-    if (var->isInteresting()) {
-      ++interestingCount;
-    } else {
-      break;
-    }
+/// Single Path traversal: follows the smallest interesting variant at each
+/// level until no new interesting variants can be created at that level.
+///
+/// Returns the next nodes to examine. Returns an empty vector while any
+/// sibling of `node` is still untested, or when no new variant can be made.
+std::vector<ReductionNode *>
+ReductionNode::iterator<SinglePath>::getNeighbors(ReductionNode *node) {
+  ReductionNode *parentNode = node->getParent();
+  llvm::ArrayRef<ReductionNode *> siblings = parentNode->getVariants();
+
+  // The parent node created several variants and they may still be waiting
+  // to be tested. Single Path continues from the smallest of them, so defer
+  // the traversal decision until every sibling has been tested.
+  if (!llvm::all_of(siblings, [](ReductionNode *sibling) {
+        return sibling->isInteresting() != Tester::Interestingness::Untested;
+      })) {
+    return {};
   }
 
-  // Remove uninteresting variants.
-  variants.resize(interestingCount);
-}
+  // Pick the smallest interesting sibling; on ties the earliest one wins.
+  ReductionNode *smallest = nullptr;
+  for (ReductionNode *sibling : siblings) {
+    if (sibling->isInteresting() != Tester::Interestingness::True)
+      continue;
+    if (smallest == nullptr || sibling->getSize() < smallest->getSize())
+      smallest = sibling;
+  }
 
-/// Returns the number of non transformed indices.
-int ReductionNode::transformSpaceSize() {
-  return std::count(transformSpace.begin(), transformSpace.end(), false);
-}
+  // Keep traversing from the smallest interesting variant if there is one.
+  // Otherwise none of the variants is interesting, so let the parent node
+  // generate more variants.
+  ReductionNode *next = smallest != nullptr ? smallest : parentNode;
 
-/// Returns a vector of the transformable indices in the Module.
-const std::vector<bool> ReductionNode::getTransformSpace() {
-  return transformSpace;
+  return next->generateNewVariants();
 }

diff  --git a/mlir/tools/mlir-reduce/ReductionTreePass.cpp b/mlir/tools/mlir-reduce/ReductionTreePass.cpp
new file mode 100644
index 0000000000000..6dbf783d2e6ff
--- /dev/null
+++ b/mlir/tools/mlir-reduce/ReductionTreePass.cpp
@@ -0,0 +1,95 @@
+//===- ReductionTreePass.cpp - ReductionTreePass Implementation -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the Reduction Tree Pass class. It provides a framework for
+// the implementation of different reduction passes in the MLIR Reduce tool. It
+// allows for custom specification of the variant generation behavior. It
+// implements methods that define the different possible traversals of the
+// reduction tree.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Reducer/ReductionTreePass.h"
+
+#include "llvm/Support/Allocator.h"
+
+using namespace mlir;
+
+/// Returns the OpReducer specialized for the operation named `opType`.
+/// Only ModuleOp and FuncOp are supported; any other name is unreachable.
+static std::unique_ptr<OpReducer> getOpReducer(llvm::StringRef opType) {
+  if (opType == FuncOp::getOperationName())
+    return std::make_unique<Reducer<FuncOp>>();
+  if (opType == ModuleOp::getOperationName())
+    return std::make_unique<Reducer<ModuleOp>>();
+  llvm_unreachable("Now only supports two built-in ops");
+}
+
+/// Reduces the current module: explores the reduction tree with the
+/// configured traversal mode and, if a smaller interesting variant was found,
+/// replaces the module's body with that variant's body.
+void ReductionTreePass::runOnOperation() {
+  ModuleOp module = this->getOperation();
+  std::unique_ptr<OpReducer> reducer = getOpReducer(opType);
+
+  // The root node keeps every target operation: the single range [0, N).
+  std::vector<std::pair<int, int>> ranges;
+  ranges.emplace_back(0, reducer->getNumTargetOps(module));
+
+  llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
+  ReductionNode *root = new (allocator.Allocate())
+      ReductionNode(/*parent=*/nullptr, ranges, allocator);
+
+  ModuleOp golden = module;
+  switch (mode) {
+  case TraversalMode::SinglePath:
+    golden = findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>(
+        module, std::move(reducer), root);
+    break;
+  default:
+    llvm_unreachable("Unsupported mode");
+  }
+
+  // findOptimal hands back `module` itself when nothing smaller was found.
+  if (golden == module)
+    return;
+
+  // Move the reduced operations into the original module, then free the clone.
+  auto *moduleBody = module.getBody();
+  moduleBody->clear();
+  moduleBody->getOperations().splice(moduleBody->begin(),
+                                     golden.getBody()->getOperations());
+  golden->destroy();
+}
+
+/// Traverses the reduction tree rooted at `root` in `IteratorType` order. For
+/// each visited node, applies `reducer` to a fresh clone of `module` restricted
+/// to the node's ranges, tests the clone, and records the result on the node.
+///
+/// Returns the smallest interesting clone found, or `module` itself if no
+/// interesting variant was smaller than the input. A returned clone is owned
+/// by the caller; every other clone is destroyed here.
+template <typename IteratorType>
+ModuleOp ReductionTreePass::findOptimal(ModuleOp module,
+                                        std::unique_ptr<OpReducer> reducer,
+                                        ReductionNode *root) {
+  std::pair<Tester::Interestingness, size_t> initStatus =
+      test.isInteresting(module);
+  root->update(initStatus);
+
+  ReductionNode *smallestNode = root;
+  ModuleOp golden = module;
+
+  IteratorType iter(root);
+
+  while (iter != IteratorType::end()) {
+    ModuleOp cloneModule = module.clone();
+
+    ReductionNode &currentNode = *iter;
+    reducer->reduce(cloneModule, currentNode.getRanges());
+
+    std::pair<Tester::Interestingness, size_t> result =
+        test.isInteresting(cloneModule);
+    currentNode.update(result);
+
+    if (result.first == Tester::Interestingness::True &&
+        result.second < smallestNode->getSize()) {
+      // The previous best clone is superseded; free it so it does not leak.
+      // The input module itself is never destroyed here.
+      if (golden != module)
+        golden->destroy();
+      smallestNode = &currentNode;
+      golden = cloneModule;
+    } else {
+      cloneModule->destroy();
+    }
+
+    ++iter;
+  }
+
+  return golden;
+}

diff  --git a/mlir/tools/mlir-reduce/ReductionTreeUtils.cpp b/mlir/tools/mlir-reduce/ReductionTreeUtils.cpp
deleted file mode 100644
index 820c19a4f6673..0000000000000
--- a/mlir/tools/mlir-reduce/ReductionTreeUtils.cpp
+++ /dev/null
@@ -1,159 +0,0 @@
-//===- ReductionTreeUtils.cpp - Reduction Tree Utilities ------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file defines the Reduction Tree Utilities. It defines pass independent
-// methods that help in a reduction pass of the MLIR Reduce tool.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Reducer/ReductionTreeUtils.h"
-
-#define DEBUG_TYPE "mlir-reduce"
-
-using namespace mlir;
-
-/// Update the golden module's content with that of the reduced module.
-void ReductionTreeUtils::updateGoldenModule(ModuleOp &golden,
-                                            ModuleOp reduced) {
-  golden.getBody()->clear();
-
-  golden.getBody()->getOperations().splice(golden.getBody()->begin(),
-                                           reduced.getBody()->getOperations());
-}
-
-/// Update the smallest node traversed so far in the reduction tree and
-/// print the debugging information for the currNode being traversed.
-void ReductionTreeUtils::updateSmallestNode(ReductionNode *currNode,
-                                            ReductionNode *&smallestNode,
-                                            std::vector<int> path) {
-  LLVM_DEBUG(llvm::dbgs() << "\nTree Path: root");
-  #ifndef NDEBUG
-  for (int nodeIndex : path)
-    LLVM_DEBUG(llvm::dbgs() << " -> " << nodeIndex);
-  #endif
-
-  LLVM_DEBUG(llvm::dbgs() << "\nSize (chars): " << currNode->getSize());
-  if (currNode->getSize() < smallestNode->getSize()) {
-    LLVM_DEBUG(llvm::dbgs() << " - new smallest node!");
-    smallestNode = currNode;
-  }
-}
-
-/// Create a transform space index vector based on the specified number of
-/// indices.
-std::vector<bool> ReductionTreeUtils::createTransformSpace(ModuleOp module,
-                                                           int numIndices) {
-  std::vector<bool> transformSpace;
-  for (int i = 0; i < numIndices; ++i)
-    transformSpace.push_back(false);
-
-  return transformSpace;
-}
-
-/// Translate section start and end into a vector of ranges specifying the
-/// section in the non transformed indices in the transform space.
-static std::vector<std::tuple<int, int>> getRanges(std::vector<bool> tSpace,
-                                                   int start, int end) {
-  std::vector<std::tuple<int, int>> ranges;
-  int rangeStart = 0;
-  int rangeEnd = 0;
-  bool inside = false;
-  int transformableCount = 0;
-
-  for (auto element : llvm::enumerate(tSpace)) {
-    int index = element.index();
-    bool value = element.value();
-
-    if (start <= transformableCount && transformableCount < end) {
-      if (!value && !inside) {
-        inside = true;
-        rangeStart = index;
-      }
-      if (value && inside) {
-        rangeEnd = index;
-        ranges.push_back(std::make_tuple(rangeStart, rangeEnd));
-        inside = false;
-      }
-    }
-
-    if (!value)
-      transformableCount++;
-
-    if (transformableCount == end && inside) {
-      ranges.push_back(std::make_tuple(rangeStart, index + 1));
-      inside = false;
-      break;
-    }
-  }
-
-  return ranges;
-}
-
-/// Create the specified number of variants by applying the transform method
-/// to different ranges of indices in the parent module. The isDeletion boolean
-/// specifies if the transformation is the deletion of indices.
-void ReductionTreeUtils::createVariants(
-    ReductionNode *parent, const Tester &test, int numVariants,
-    llvm::function_ref<void(ModuleOp, int, int)> transform, bool isDeletion) {
-  std::vector<bool> newTSpace;
-  ModuleOp module = parent->getModule();
-
-  std::vector<bool> parentTSpace = parent->getTransformSpace();
-  int indexCount = parent->transformSpaceSize();
-  std::vector<std::tuple<int, int>> ranges;
-
-  // No new variants can be created.
-  if (indexCount == 0)
-    return;
-
-  // Create a single variant by transforming the unique index.
-  if (indexCount == 1) {
-    ModuleOp variantModule = module.clone();
-    if (isDeletion) {
-      transform(variantModule, 0, 1);
-    } else {
-      ranges = getRanges(parentTSpace, 0, parentTSpace.size());
-      transform(variantModule, std::get<0>(ranges[0]), std::get<1>(ranges[0]));
-    }
-
-    new ReductionNode(variantModule, parent, newTSpace);
-
-    return;
-  }
-
-  // Create the specified number of variants.
-  for (int i = 0; i < numVariants; ++i) {
-    ModuleOp variantModule = module.clone();
-    newTSpace = parent->getTransformSpace();
-    int sectionSize = indexCount / numVariants;
-    int sectionStart = sectionSize * i;
-    int sectionEnd = sectionSize * (i + 1);
-
-    if (i == numVariants - 1)
-      sectionEnd = indexCount;
-
-    if (isDeletion)
-      transform(variantModule, sectionStart, sectionEnd);
-
-    ranges = getRanges(parentTSpace, sectionStart, sectionEnd);
-
-    for (auto range : ranges) {
-      int rangeStart = std::get<0>(range);
-      int rangeEnd = std::get<1>(range);
-
-      for (int x = rangeStart; x < rangeEnd; ++x)
-        newTSpace[x] = true;
-
-      if (!isDeletion)
-        transform(variantModule, rangeStart, rangeEnd);
-    }
-
-    // Create Reduction Node in the Reduction tree
-    new ReductionNode(variantModule, parent, newTSpace);
-  }
-}

diff  --git a/mlir/tools/mlir-reduce/mlir-reduce.cpp b/mlir/tools/mlir-reduce/mlir-reduce.cpp
index d995683bb30c2..7df1dc155d38f 100644
--- a/mlir/tools/mlir-reduce/mlir-reduce.cpp
+++ b/mlir/tools/mlir-reduce/mlir-reduce.cpp
@@ -103,7 +103,7 @@ int main(int argc, char **argv) {
   // Initialize test environment.
   const Tester test(testFilename, testArguments);
 
-  if (!test.isInteresting(inputFilename))
+  if (test.isInteresting(inputFilename) != Tester::Interestingness::True)
     llvm::report_fatal_error(
         "Input test case does not exhibit interesting behavior");
 
@@ -118,11 +118,10 @@ int main(int argc, char **argv) {
 
   } else if (passTestSpecifier == "function-reducer") {
 
-    // Reduction tree pass with OpReducer variant generation and single path
+    // Reduction tree pass with Reducer variant generation and single path
     // traversal.
-    pm.addPass(
-        std::make_unique<ReductionTreePass<OpReducer<FuncOp>, SinglePath>>(
-            test));
+    pm.addPass(std::make_unique<ReductionTreePass>(
+        FuncOp::getOperationName(), TraversalMode::SinglePath, test));
   }
 
   ModuleOp m = moduleRef.get().clone();


        


More information about the Mlir-commits mailing list