[Mlir-commits] [mlir] 21f8d41 - Refactor Reduction Tree Pass

Mauricio Sifontes llvmlistbot at llvm.org
Thu Aug 20 21:59:49 PDT 2020


Author: Mauricio Sifontes
Date: 2020-08-21T04:59:24Z
New Revision: 21f8d414689387d97120a49df3dabca10e3262e4

URL: https://github.com/llvm/llvm-project/commit/21f8d414689387d97120a49df3dabca10e3262e4
DIFF: https://github.com/llvm/llvm-project/commit/21f8d414689387d97120a49df3dabca10e3262e4.diff

LOG: Refactor Reduction Tree Pass

Refactor the way the reduction tree pass works in the MLIR Reduce tool by introducing a set of utilities that facilitate the implementation of new Reducer classes to be used in the passes.

This will allow for the fast implementation of general transformations to operate on all MLIR modules as well as custom transformations for different dialects.

These utilities allow for the implementation of Reducer classes by simply defining a method that indexes the operations/blocks/regions to be transformed and a method to perform the deletion or transformation based on the indexes.

Create the transformSpace class member in the ReductionNode class to keep track of the indexes that have already been transformed or deleted at the current level.

Delete the FunctionReducer class and replace it with the OpReducer class to reflect this new API while performing the same transformation and allowing the instantiation of a reduction pass for different types of operations at the module's highest hierarchical level.

Modify the SinglePath Traversal method to reflect the use of the new API.

Reviewed: jpienaar

Differential Revision: https://reviews.llvm.org/D85591

Added: 
    mlir/include/mlir/Reducer/Passes/OpReducer.h
    mlir/include/mlir/Reducer/ReductionTreeUtils.h
    mlir/tools/mlir-reduce/Passes/OpReducer.cpp
    mlir/tools/mlir-reduce/ReductionTreeUtils.cpp

Modified: 
    mlir/include/mlir/Reducer/OptReductionPass.h
    mlir/include/mlir/Reducer/ReductionNode.h
    mlir/include/mlir/Reducer/ReductionTreePass.h
    mlir/tools/mlir-reduce/CMakeLists.txt
    mlir/tools/mlir-reduce/OptReductionPass.cpp
    mlir/tools/mlir-reduce/ReductionNode.cpp
    mlir/tools/mlir-reduce/mlir-reduce.cpp

Removed: 
    mlir/include/mlir/Reducer/Passes/FunctionReducer.h
    mlir/tools/mlir-reduce/Passes/FunctionReducer.cpp
    mlir/tools/mlir-reduce/ReductionTreePass.cpp


################################################################################
diff  --git a/mlir/include/mlir/Reducer/OptReductionPass.h b/mlir/include/mlir/Reducer/OptReductionPass.h
index 2168ea215950..3549c2801f34 100644
--- a/mlir/include/mlir/Reducer/OptReductionPass.h
+++ b/mlir/include/mlir/Reducer/OptReductionPass.h
@@ -6,9 +6,9 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file defines the Opt Reduction Pass Wrapper. It creates a pass to run
-// any optimization pass within it and only replaces the output module with the
-// transformed version if it is smaller and interesting.
+// This file defines the Opt Reduction Pass Wrapper. It creates an MLIR pass to
+// run any optimization pass within it and only replaces the output module with
+// the transformed version if it is smaller and interesting.
 //
 //===----------------------------------------------------------------------===//
 
@@ -28,7 +28,7 @@ namespace mlir {
 
 class OptReductionPass : public OptReductionBase<OptReductionPass> {
 public:
-  OptReductionPass(const Tester *test, MLIRContext *context,
+  OptReductionPass(const Tester &test, MLIRContext *context,
                    std::unique_ptr<Pass> optPass);
 
   OptReductionPass(const OptReductionPass &srcPass);
@@ -41,7 +41,7 @@ class OptReductionPass : public OptReductionBase<OptReductionPass> {
   MLIRContext *context;
 
   // This is used to test the interesting behavior of the transformed module.
-  const Tester *test;
+  const Tester &test;
 
   // Points to the mlir-opt pass to be called.
   std::unique_ptr<Pass> optPass;

diff  --git a/mlir/include/mlir/Reducer/Passes/FunctionReducer.h b/mlir/include/mlir/Reducer/Passes/FunctionReducer.h
deleted file mode 100644
index f4b094bb692e..000000000000
--- a/mlir/include/mlir/Reducer/Passes/FunctionReducer.h
+++ /dev/null
@@ -1,36 +0,0 @@
-//===- FunctionReducer.h - MLIR Reduce Function Reducer ---------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file defines the FunctionReducer class. It defines a variant generator
-// method with the purpose of producing different variants by eliminating
-// functions from the  parent module.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_REDUCER_PASSES_FUNCTIONREDUCER_H
-#define MLIR_REDUCER_PASSES_FUNCTIONREDUCER_H
-
-#include "mlir/Reducer/ReductionNode.h"
-#include "mlir/Reducer/Tester.h"
-
-namespace mlir {
-
-/// The FunctionReducer class defines a variant generator method that produces
-/// multiple variants by eliminating different operations from the
-/// parent module.
-class FunctionReducer {
-public:
-  /// Generate variants by removing functions from the module in the parent
-  /// Reduction Node and link the variants as children in the Reduction Tree
-  /// Pass.
-  void generateVariants(ReductionNode *parent, const Tester *test);
-};
-
-} // end namespace mlir
-
-#endif

diff  --git a/mlir/include/mlir/Reducer/Passes/OpReducer.h b/mlir/include/mlir/Reducer/Passes/OpReducer.h
new file mode 100644
index 000000000000..e12bd790ace2
--- /dev/null
+++ b/mlir/include/mlir/Reducer/Passes/OpReducer.h
@@ -0,0 +1,107 @@
+//===- OpReducer.h - MLIR Reduce Operation Reducer ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the OpReducer class. It defines a variant generator method
+// with the purpose of producing different variants by eliminating a
+// parameterizable type of operations from the parent module.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_REDUCER_PASSES_OPREDUCER_H
+#define MLIR_REDUCER_PASSES_OPREDUCER_H
+
+#include "mlir/IR/Region.h"
+#include "mlir/Reducer/ReductionNode.h"
+#include "mlir/Reducer/ReductionTreeUtils.h"
+#include "mlir/Reducer/Tester.h"
+
+namespace mlir {
+
+class OpReducerImpl {
+public:
+  OpReducerImpl(
+      llvm::function_ref<std::vector<Operation *>(ModuleOp)> getSpecificOps);
+
+  /// Return the name of this reducer class.
+  StringRef getName();
+
+  /// Return the initial transformSpace containing the transformable indices.
+  std::vector<bool> initTransformSpace(ModuleOp module);
+
+  /// Generate variants by removing OpType operations from the module in the
+  /// parent and link the variants as children in the Reduction Tree Pass.
+  void generateVariants(ReductionNode *parent, const Tester &test,
+                        int numVariants);
+
+  /// Generate variants by removing OpType operations from the module in the
+  /// parent and link the variants as children in the Reduction Tree Pass. The
+  /// transform argument defines the function used to remove the OpType
+  /// operations in range of indexed OpType operations.
+  void generateVariants(ReductionNode *parent, const Tester &test,
+                        int numVariants,
+                        llvm::function_ref<void(ModuleOp, int, int)> transform);
+
+private:
+  llvm::function_ref<std::vector<Operation *>(ModuleOp)> getSpecificOps;
+};
+
+/// The OpReducer class defines a variant generator method that produces
+/// multiple variants by eliminating different OpType operations from the
+/// parent module.
+template <typename OpType>
+class OpReducer {
+public:
+  OpReducer() : impl(new OpReducerImpl(getSpecificOps)) {}
+
+  /// Returns the vector of pointers to the OpType operations in the module.
+  static std::vector<Operation *> getSpecificOps(ModuleOp module) {
+    std::vector<Operation *> ops;
+    for (auto op : module.getOps<OpType>()) {
+      ops.push_back(op);
+    }
+    return ops;
+  }
+
+  /// Deletes the OpType operations whose indices lie in [start, end).
+  static void deleteOps(ModuleOp module, int start, int end) {
+    std::vector<Operation *> opsToRemove;
+
+    for (auto op : enumerate(getSpecificOps(module))) {
+      int index = op.index();
+      if (index >= start && index < end)
+        opsToRemove.push_back(op.value());
+    }
+
+    for (Operation *o : opsToRemove) {
+      o->dropAllUses();
+      o->erase();
+    }
+  }
+
+  /// Return the name of this reducer class.
+  StringRef getName() { return impl->getName(); }
+
+  /// Return the initial transformSpace containing the transformable indices.
+  std::vector<bool> initTransformSpace(ModuleOp module) {
+    return impl->initTransformSpace(module);
+  }
+
+  /// Generate variants by removing OpType operations from the module in the
+  /// parent and link the variants as children in the Reduction Tree Pass.
+  void generateVariants(ReductionNode *parent, const Tester &test,
+                        int numVariants) {
+    impl->generateVariants(parent, test, numVariants, deleteOps);
+  }
+
+private:
+  std::unique_ptr<OpReducerImpl> impl;
+};
+
+} // end namespace mlir
+
+#endif

diff  --git a/mlir/include/mlir/Reducer/ReductionNode.h b/mlir/include/mlir/Reducer/ReductionNode.h
index bac282b784ef..0a4aba8858f4 100644
--- a/mlir/include/mlir/Reducer/ReductionNode.h
+++ b/mlir/include/mlir/Reducer/ReductionNode.h
@@ -32,8 +32,11 @@ class ReductionNode {
 public:
   ReductionNode(ModuleOp module, ReductionNode *parent);
 
+  ReductionNode(ModuleOp module, ReductionNode *parent,
+                std::vector<bool> transformSpace);
+
   /// Calculates and initializes the size and interesting values of the node.
-  void measureAndTest(const Tester *test);
+  void measureAndTest(const Tester &test);
 
   /// Returns the module.
   ModuleOp getModule() const { return module; }
@@ -50,11 +53,20 @@ class ReductionNode {
   /// Returns the pointer to a child variant by index.
   ReductionNode *getVariant(unsigned long index) const;
 
+  /// Returns the number of child variants.
+  int variantsSize() const;
+
   /// Returns true if the vector containing the child variants is empty.
   bool variantsEmpty() const;
 
   /// Sort the child variants and remove the uninteresting ones.
-  void organizeVariants(const Tester *test);
+  void organizeVariants(const Tester &test);
+
+  /// Returns the number of indices that have not been transformed yet.
+  int transformSpaceSize();
+
+  /// Returns a vector indicating the transformed indices as true.
+  const std::vector<bool> getTransformSpace();
 
 private:
   /// Link a child variant node.
@@ -74,6 +86,10 @@ class ReductionNode {
   // This indicates if the module has been evalueated (measured and tested).
   bool evaluated;
 
+  // Indicates the indices in the node that have been transformed in previous
+  // levels of the reduction tree.
+  std::vector<bool> transformSpace;
+
   // This points to the child variants that were created using this node as a
   // starting point.
   std::vector<std::unique_ptr<ReductionNode>> variants;

diff  --git a/mlir/include/mlir/Reducer/ReductionTreePass.h b/mlir/include/mlir/Reducer/ReductionTreePass.h
index d07a475e4f99..be13191dbd38 100644
--- a/mlir/include/mlir/Reducer/ReductionTreePass.h
+++ b/mlir/include/mlir/Reducer/ReductionTreePass.h
@@ -21,21 +21,17 @@
 
 #include "PassDetail.h"
 #include "ReductionNode.h"
-#include "mlir/Reducer/Passes/FunctionReducer.h"
+#include "mlir/Reducer/Passes/OpReducer.h"
+#include "mlir/Reducer/ReductionTreeUtils.h"
 #include "mlir/Reducer/Tester.h"
 
+#define DEBUG_TYPE "mlir-reduce"
+
 namespace mlir {
 
-/// Defines the traversal method options to be used in the reduction tree
+/// Defines the traversal method options to be used in the reduction tree
 /// traversal.
-enum TraversalMode { SinglePath, MultiPath, Concurrent, Backtrack };
-
-// This class defines the non- templated utilities used by the ReductionTreePass
-// class.
-class ReductionTreeUtils {
-public:
-  static void updateGoldenModule(ModuleOp &golden, ModuleOp reduced);
-};
+enum TraversalMode { SinglePath, Backtrack, MultiPath };
 
 /// This class defines the Reduction Tree Pass. It provides a framework to
 /// to implement a reduction pass using a tree structure to keep track of the
@@ -44,29 +40,37 @@ template <typename Reducer, TraversalMode mode>
 class ReductionTreePass
     : public ReductionTreeBase<ReductionTreePass<Reducer, mode>> {
 public:
-  ReductionTreePass(const Tester *test) : test(test) {}
-
   ReductionTreePass(const ReductionTreePass &pass)
       : ReductionTreeBase<ReductionTreePass<Reducer, mode>>(pass),
         root(new ReductionNode(pass.root->getModule().clone(), nullptr)),
         test(pass.test) {}
 
+  ReductionTreePass(const Tester &test) : test(test) {}
+
   /// Runs the pass instance in the pass pipeline.
   void runOnOperation() override {
     ModuleOp module = this->getOperation();
-    this->root = std::make_unique<ReductionNode>(module, nullptr);
+    Reducer reducer;
+    std::vector<bool> transformSpace = reducer.initTransformSpace(module);
     ReductionNode *reduced;
 
+    this->root =
+        std::make_unique<ReductionNode>(module, nullptr, transformSpace);
+
+    root->measureAndTest(test);
+
+    LLVM_DEBUG(llvm::dbgs() << "\nReduction Tree Pass: " << reducer.getName(););
     switch (mode) {
     case SinglePath:
+      LLVM_DEBUG(llvm::dbgs() << " (Single Path)\n";);
       reduced = singlePathTraversal();
       break;
     default:
       llvm::report_fatal_error("Traversal method not currently supported.");
     }
 
-    ReductionTreeUtils utils;
-    utils.updateGoldenModule(module, reduced->getModule());
+    ReductionTreeUtils::updateGoldenModule(module,
+                                           reduced->getModule().clone());
   }
 
 private:
@@ -79,26 +83,69 @@ class ReductionTreePass
 
   // This is used to test the interesting behavior of the reduction nodes in the
   // tree.
-  const Tester *test;
+  const Tester &test;
 
   /// Traverse the most reduced path in the reduction tree by generating the
   /// variants at each level using the Reducer parameter's generateVariants
   /// function. Stops when no new successful variants can be created at the
   /// current level.
   ReductionNode *singlePathTraversal() {
-    ReductionNode *currLevel = root.get();
+    ReductionNode *currNode = root.get();
+    ReductionNode *smallestNode = currNode;
+    int tSpaceSize = currNode->transformSpaceSize();
+    std::vector<int> path;
+
+    ReductionTreeUtils::updateSmallestNode(currNode, smallestNode, path);
+
+    LLVM_DEBUG(llvm::dbgs() << "\nGenerating 1 variant: applying the ");
+    LLVM_DEBUG(llvm::dbgs() << "transformation to the entire module\n");
 
-    while (true) {
-      reducer.generateVariants(currLevel, test);
-      currLevel->organizeVariants(test);
+    reducer.generateVariants(currNode, test, 1);
+    LLVM_DEBUG(llvm::dbgs() << "Testing\n");
+    currNode->organizeVariants(test);
 
-      if (currLevel->variantsEmpty())
+    if (!currNode->variantsEmpty())
+      return currNode->getVariant(0);
+
+    while (tSpaceSize != 1) {
+      ReductionTreeUtils::updateSmallestNode(currNode, smallestNode, path);
+
+      LLVM_DEBUG(llvm::dbgs() << "\nGenerating 2 variants: applying the ");
+      LLVM_DEBUG(llvm::dbgs() << "transformation to two different sections ");
+      LLVM_DEBUG(llvm::dbgs() << "of transformable indices\n");
+
+      reducer.generateVariants(currNode, test, 2);
+      LLVM_DEBUG(llvm::dbgs() << "Testing\n");
+      currNode->organizeVariants(test);
+
+      if (currNode->variantsEmpty())
         break;
 
-      currLevel = currLevel->getVariant(0);
+      currNode = currNode->getVariant(0);
+      tSpaceSize = currNode->transformSpaceSize();
+      path.push_back(0);
+    }
+
+    if (tSpaceSize == 1) {
+      ReductionTreeUtils::updateSmallestNode(currNode, smallestNode, path);
+
+      LLVM_DEBUG(llvm::dbgs() << "\nGenerating 1 variants: applying the ");
+      LLVM_DEBUG(llvm::dbgs() << "transformation to the only transformable");
+      LLVM_DEBUG(llvm::dbgs() << "index\n");
+
+      reducer.generateVariants(currNode, test, 1);
+      LLVM_DEBUG(llvm::dbgs() << "Testing\n");
+      currNode->organizeVariants(test);
+
+      if (!currNode->variantsEmpty()) {
+        currNode = currNode->getVariant(0);
+        path.push_back(0);
+
+        ReductionTreeUtils::updateSmallestNode(currNode, smallestNode, path);
+      }
     }
 
-    return currLevel;
+    return currNode;
   }
 };
 

diff  --git a/mlir/include/mlir/Reducer/ReductionTreeUtils.h b/mlir/include/mlir/Reducer/ReductionTreeUtils.h
new file mode 100644
index 000000000000..411186f29faa
--- /dev/null
+++ b/mlir/include/mlir/Reducer/ReductionTreeUtils.h
@@ -0,0 +1,53 @@
+//===- ReductionTreeUtils.h - Reduction Tree utilities ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the Reduction Tree Utilities. It defines pass-independent
+// methods that help in the reduction passes of the MLIR Reduce tool.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_REDUCER_REDUCTIONTREEUTILS_H
+#define MLIR_REDUCER_REDUCTIONTREEUTILS_H
+
+#include <tuple>
+
+#include "PassDetail.h"
+#include "ReductionNode.h"
+#include "mlir/Reducer/Tester.h"
+#include "llvm/Support/Debug.h"
+
+namespace mlir {
+
+// Defines the utilities for the implementation of custom reduction
+// passes using the ReductionTreePass framework.
+namespace ReductionTreeUtils {
+
+/// Clear the golden module's body and move the reduced module's ops into it.
+void updateGoldenModule(ModuleOp &golden, ModuleOp reduced);
+
+/// Update the smallest node traversed so far in the reduction tree and
+/// print the debugging information for the currNode being traversed.
+void updateSmallestNode(ReductionNode *currNode, ReductionNode *&smallestNode,
+                        std::vector<int> path);
+
+/// Create a transform space index vector based on the specified number of
+/// indices.
+std::vector<bool> createTransformSpace(ModuleOp module, int numIndices);
+
+/// Create the specified number of variants by applying the transform method
+/// to different ranges of indices in the parent module. The isDeletion boolean
+/// specifies if the transformation is the deletion of indices.
+void createVariants(ReductionNode *parent, const Tester &test, int numVariants,
+                    llvm::function_ref<void(ModuleOp, int, int)> transform,
+                    bool isDeletion);
+
+} // namespace ReductionTreeUtils
+
+} // end namespace mlir
+
+#endif

diff  --git a/mlir/tools/mlir-reduce/CMakeLists.txt b/mlir/tools/mlir-reduce/CMakeLists.txt
index f581eee21fab..8262dc94b6f2 100644
--- a/mlir/tools/mlir-reduce/CMakeLists.txt
+++ b/mlir/tools/mlir-reduce/CMakeLists.txt
@@ -33,9 +33,9 @@ set(LIBS
 
 add_llvm_tool(mlir-reduce
   OptReductionPass.cpp
-  Passes/FunctionReducer.cpp
+  Passes/OpReducer.cpp
   ReductionNode.cpp
-  ReductionTreePass.cpp
+  ReductionTreeUtils.cpp
   mlir-reduce.cpp
 
   ADDITIONAL_HEADER_DIRS

diff  --git a/mlir/tools/mlir-reduce/OptReductionPass.cpp b/mlir/tools/mlir-reduce/OptReductionPass.cpp
index dbb3d97046d4..fbc3053aa826 100644
--- a/mlir/tools/mlir-reduce/OptReductionPass.cpp
+++ b/mlir/tools/mlir-reduce/OptReductionPass.cpp
@@ -6,9 +6,9 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file defines the Opt Reduction Pass class. It creates a pass to run
-// any optimization pass within it and only replaces the output module with the
-// transformed version if it is smaller and interesting.
+// This file defines the Opt Reduction Pass Wrapper. It creates an MLIR pass to
+// run any optimization pass within it and only replaces the output module with
+// the transformed version if it is smaller and interesting.
 //
 //===----------------------------------------------------------------------===//
 
@@ -18,7 +18,7 @@
 
 using namespace mlir;
 
-OptReductionPass::OptReductionPass(const Tester *test, MLIRContext *context,
+OptReductionPass::OptReductionPass(const Tester &test, MLIRContext *context,
                                    std::unique_ptr<Pass> optPass)
     : context(context), test(test), optPass(std::move(optPass)) {}
 

diff  --git a/mlir/tools/mlir-reduce/Passes/FunctionReducer.cpp b/mlir/tools/mlir-reduce/Passes/FunctionReducer.cpp
deleted file mode 100644
index ac97848c96ea..000000000000
--- a/mlir/tools/mlir-reduce/Passes/FunctionReducer.cpp
+++ /dev/null
@@ -1,72 +0,0 @@
-//===- FunctionReducer.cpp - MLIR Reduce Function Reducer -----------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file defines the FunctionReducer class. It defines a variant generator
-// class to be used in a Reduction Tree Pass instantiation with the aim of
-// reducing the number of function operations in an MLIR Module.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Reducer/Passes/FunctionReducer.h"
-#include "mlir/IR/Function.h"
-
-using namespace mlir;
-
-/// Return the number of function operations in the module's body.
-int countFunctions(ModuleOp module) {
-  auto ops = module.getOps<FuncOp>();
-  return std::distance(ops.begin(), ops.end());
-}
-
-/// Generate variants by removing function operations from the module in the
-/// parent and link the variants as children in the Reduction Tree Pass.
-void FunctionReducer::generateVariants(ReductionNode *parent,
-                                       const Tester *test) {
-  ModuleOp module = parent->getModule();
-  int opCount = countFunctions(module);
-  int sectionSize = opCount / 2;
-  std::vector<Operation *> opsToRemove;
-
-  if (opCount == 0)
-    return;
-
-  // Create a variant by deleting all ops.
-  if (opCount == 1) {
-    opsToRemove.clear();
-    ModuleOp moduleVariant = module.clone();
-
-    for (FuncOp op : moduleVariant.getOps<FuncOp>())
-      opsToRemove.push_back(op);
-
-    for (Operation *o : opsToRemove)
-      o->erase();
-
-    new ReductionNode(moduleVariant, parent);
-
-    return;
-  }
-
-  // Create two variants by bisecting the module.
-  for (int i = 0; i < 2; ++i) {
-    opsToRemove.clear();
-    ModuleOp moduleVariant = module.clone();
-
-    for (auto op : enumerate(moduleVariant.getOps<FuncOp>())) {
-      int index = op.index();
-      if (index >= sectionSize * i && index < sectionSize * (i + 1))
-        opsToRemove.push_back(op.value());
-    }
-
-    for (Operation *o : opsToRemove)
-      o->erase();
-
-    new ReductionNode(moduleVariant, parent);
-  }
-
-  return;
-}

diff  --git a/mlir/tools/mlir-reduce/Passes/OpReducer.cpp b/mlir/tools/mlir-reduce/Passes/OpReducer.cpp
new file mode 100644
index 000000000000..d94e80384cf7
--- /dev/null
+++ b/mlir/tools/mlir-reduce/Passes/OpReducer.cpp
@@ -0,0 +1,41 @@
+//===- OpReducer.cpp - Operation Reducer ------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the OpReducer class. It defines a variant generator method
+// with the purpose of producing different variants by eliminating a
+// parameterizable type of operations from the parent module.
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Reducer/Passes/OpReducer.h"
+
+using namespace mlir;
+
+OpReducerImpl::OpReducerImpl(
+    llvm::function_ref<std::vector<Operation *>(ModuleOp)> getSpecificOps)
+    : getSpecificOps(getSpecificOps) {}
+
+/// Return the name of this reducer class.
+StringRef OpReducerImpl::getName() {
+  return StringRef("High Level Operation Reduction");
+}
+
+/// Return the initial transformSpace containing the transformable indices.
+std::vector<bool> OpReducerImpl::initTransformSpace(ModuleOp module) {
+  auto ops = getSpecificOps(module);
+  int numOps = std::distance(ops.begin(), ops.end());
+  return ReductionTreeUtils::createTransformSpace(module, numOps);
+}
+
+/// Generate variants by removing opType operations from the module in the
+/// parent and link the variants as children in the Reduction Tree Pass.
+void OpReducerImpl::generateVariants(
+    ReductionNode *parent, const Tester &test, int numVariants,
+    llvm::function_ref<void(ModuleOp, int, int)> transform) {
+  ReductionTreeUtils::createVariants(parent, test, numVariants, transform,
+                                     true);
+}

diff  --git a/mlir/tools/mlir-reduce/ReductionNode.cpp b/mlir/tools/mlir-reduce/ReductionNode.cpp
index ed1ff8534f7c..2b2ce6ed0851 100644
--- a/mlir/tools/mlir-reduce/ReductionNode.cpp
+++ b/mlir/tools/mlir-reduce/ReductionNode.cpp
@@ -26,8 +26,16 @@ ReductionNode::ReductionNode(ModuleOp module, ReductionNode *parent)
     parent->linkVariant(this);
 }
 
+ReductionNode::ReductionNode(ModuleOp module, ReductionNode *parent,
+                             std::vector<bool> transformSpace)
+    : module(module), evaluated(false), transformSpace(transformSpace) {
+
+  if (parent != nullptr)
+    parent->linkVariant(this);
+}
+
 /// Calculates and updates the size and interesting values of the module.
-void ReductionNode::measureAndTest(const Tester *test) {
+void ReductionNode::measureAndTest(const Tester &test) {
   SmallString<128> filepath;
   int fd;
 
@@ -46,7 +54,7 @@ void ReductionNode::measureAndTest(const Tester *test) {
     llvm::report_fatal_error("Error emitting bitcode to file '" + filepath);
 
   size = out.os().tell();
-  interesting = test->isInteresting(filepath);
+  interesting = test.isInteresting(filepath);
   evaluated = true;
 }
 
@@ -67,6 +75,9 @@ ReductionNode *ReductionNode::getVariant(unsigned long index) const {
   return nullptr;
 }
 
+/// Returns the number of child variants.
+int ReductionNode::variantsSize() const { return variants.size(); }
+
 /// Returns true if the child variants vector is empty.
 bool ReductionNode::variantsEmpty() const { return variants.empty(); }
 
@@ -77,7 +88,7 @@ void ReductionNode::linkVariant(ReductionNode *newVariant) {
 }
 
 /// Sort the child variants and remove the uninteresting ones.
-void ReductionNode::organizeVariants(const Tester *test) {
+void ReductionNode::organizeVariants(const Tester &test) {
   // Ensure all variants are evaluated.
   for (auto &var : variants)
     if (!var->isEvaluated())
@@ -107,3 +118,13 @@ void ReductionNode::organizeVariants(const Tester *test) {
   // Remove uninteresting variants.
   variants.resize(interestingCount);
 }
+
+/// Returns the number of non-transformed indices.
+int ReductionNode::transformSpaceSize() {
+  return std::count(transformSpace.begin(), transformSpace.end(), false);
+}
+
+/// Returns a vector in which the transformed indices are marked as true.
+const std::vector<bool> ReductionNode::getTransformSpace() {
+  return transformSpace;
+}

diff  --git a/mlir/tools/mlir-reduce/ReductionTreePass.cpp b/mlir/tools/mlir-reduce/ReductionTreePass.cpp
deleted file mode 100644
index d18c69364168..000000000000
--- a/mlir/tools/mlir-reduce/ReductionTreePass.cpp
+++ /dev/null
@@ -1,28 +0,0 @@
-//===- ReductionTreePass.cpp - Reduction Tree Pass Implementation ---------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file defines the Reduction Tree Pass. It provides a framework for
-// the implementation of different reduction passes in the MLIR Reduce tool. It
-// allows for custom specification of the variant generation behavior. It
-// implements methods that define the different possible traversals of the
-// reduction tree.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Reducer/ReductionTreePass.h"
-
-using namespace mlir;
-
-/// Update the golden module's content with that of the reduced module.
-void ReductionTreeUtils::updateGoldenModule(ModuleOp &golden,
-                                            ModuleOp reduced) {
-  golden.getBody()->clear();
-
-  golden.getBody()->getOperations().splice(golden.getBody()->begin(),
-                                           reduced.getBody()->getOperations());
-}

diff  --git a/mlir/tools/mlir-reduce/ReductionTreeUtils.cpp b/mlir/tools/mlir-reduce/ReductionTreeUtils.cpp
new file mode 100644
index 000000000000..5fdb1341f89c
--- /dev/null
+++ b/mlir/tools/mlir-reduce/ReductionTreeUtils.cpp
@@ -0,0 +1,157 @@
+//===- ReductionTreeUtils.cpp - Reduction Tree Utilities ------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the Reduction Tree Utilities. It defines pass independent
+// methods that help in a reduction pass of the MLIR Reduce tool.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Reducer/ReductionTreeUtils.h"
+
+#define DEBUG_TYPE "mlir-reduce"
+
+using namespace mlir;
+
+/// Replace the contents of the golden module with those of the reduced module.
+/// Everything in the golden module's body is cleared first; then all
+/// operations of the reduced module's body are spliced (moved, not copied)
+/// into it, so the reduced module's body ends up empty.
+void ReductionTreeUtils::updateGoldenModule(ModuleOp &golden,
+                                            ModuleOp reduced) {
+  auto *goldenBody = golden.getBody();
+  goldenBody->clear();
+
+  auto &reducedOps = reduced.getBody()->getOperations();
+  goldenBody->getOperations().splice(goldenBody->begin(), reducedOps);
+}
+
+/// Update the smallest node traversed so far in the reduction tree: `currNode`
+/// replaces `smallestNode` only if its size is strictly smaller. In debug
+/// builds, the path from the root to `currNode` and its size (in characters)
+/// are printed as well.
+void ReductionTreeUtils::updateSmallestNode(ReductionNode *currNode,
+                                            ReductionNode *&smallestNode,
+                                            std::vector<int> path) {
+  LLVM_DEBUG({
+    llvm::dbgs() << "\nTree Path: root";
+    for (int childIdx : path)
+      llvm::dbgs() << " -> " << childIdx;
+    llvm::dbgs() << "\nSize (chars): " << currNode->getSize();
+  });
+
+  // Ties keep the previously recorded node.
+  if (currNode->getSize() >= smallestNode->getSize())
+    return;
+
+  LLVM_DEBUG(llvm::dbgs() << " - new smallest node!");
+  smallestNode = currNode;
+}
+
+/// Create a transform space holding `numIndices` entries, all initially false
+/// (not yet transformed). A non-positive `numIndices` yields an empty space.
+/// `module` is currently unused.
+std::vector<bool> ReductionTreeUtils::createTransformSpace(ModuleOp module,
+                                                           int numIndices) {
+  const std::size_t entryCount =
+      numIndices > 0 ? static_cast<std::size_t>(numIndices) : 0;
+  return std::vector<bool>(entryCount, false);
+}
+
+/// Translate a section of the not-yet-transformed indices into concrete index
+/// ranges of the transform space.
+///
+/// `start` and `end` count only the untransformed (false) entries of `tSpace`:
+/// the section consists of the untransformed entries whose ordinal position
+/// among all untransformed entries lies in [start, end). The result is the list
+/// of half-open [first, last) position ranges in `tSpace` that cover exactly
+/// those entries, skipping over already-transformed (true) entries.
+///
+/// Example: tSpace = {F, T, F, F, T, F}, start = 1, end = 3 selects the 2nd and
+/// 3rd untransformed entries (positions 2 and 3) and yields {(2, 4)}.
+static std::vector<std::tuple<int, int>>
+getRanges(const std::vector<bool> &tSpace, int start, int end) {
+  std::vector<std::tuple<int, int>> ranges;
+  int rangeStart = 0;
+  bool inside = false;
+  int transformableCount = 0;
+
+  for (auto element : llvm::enumerate(tSpace)) {
+    int index = element.index();
+    bool value = element.value();
+
+    // `transformableCount` is the ordinal of the next untransformed entry.
+    if (start <= transformableCount && transformableCount < end) {
+      if (!value && !inside) {
+        inside = true;
+        rangeStart = index;
+      }
+      // An already-transformed entry interrupts the current range.
+      if (value && inside) {
+        ranges.push_back(std::make_tuple(rangeStart, index));
+        inside = false;
+      }
+    }
+
+    if (!value)
+      transformableCount++;
+
+    if (transformableCount == end && inside) {
+      ranges.push_back(std::make_tuple(rangeStart, index + 1));
+      inside = false;
+      break;
+    }
+  }
+
+  // A range can still be open when the transform space ends before `end`
+  // untransformed entries were seen (e.g. the last entry is untransformed and
+  // `end` exceeds the number of untransformed entries). Close it at the end of
+  // the space instead of silently dropping it.
+  if (inside)
+    ranges.push_back(
+        std::make_tuple(rangeStart, static_cast<int>(tSpace.size())));
+
+  return ranges;
+}
+
+/// Create up to `numVariants` variants of `parent` by applying `transform` to
+/// different sections of the parent's not-yet-transformed indices.
+///
+/// The untransformed indices are split into consecutive sections of equal size
+/// (the last section also takes the remainder). Each variant clones the parent
+/// module, transforms one section, and records that section as transformed in
+/// its own copy of the parent's transform space.
+///
+/// The isDeletion boolean specifies whether the transformation deletes the
+/// indexed elements. If it does, `transform` is called once per variant with
+/// the section bounds counted among the untransformed indices only; otherwise
+/// it is called once per contiguous range of transform-space positions that
+/// the section covers.
+///
+/// If exactly one untransformed index is left, a single variant transforming
+/// it is created regardless of `numVariants`. Otherwise no more variants than
+/// untransformed indices are created, so no variant gets an empty section.
+/// `test` is not used here.
+void ReductionTreeUtils::createVariants(
+    ReductionNode *parent, const Tester &test, int numVariants,
+    llvm::function_ref<void(ModuleOp, int, int)> transform, bool isDeletion) {
+  ModuleOp module = parent->getModule();
+  std::vector<bool> parentTSpace = parent->getTransformSpace();
+  int indexCount = parent->transformSpaceSize();
+
+  // No new variants can be created.
+  if (indexCount == 0)
+    return;
+
+  // Create a single variant by transforming the unique index.
+  if (indexCount == 1) {
+    ModuleOp variantModule = module.clone();
+    if (isDeletion) {
+      transform(variantModule, 0, 1);
+    } else {
+      // Bound the section by the number of untransformed indices (one), not
+      // by the size of the transform space: with the latter, the range is not
+      // found when the unique index is the last entry, and `ranges[0]` would
+      // be read out of bounds.
+      std::vector<std::tuple<int, int>> ranges =
+          getRanges(parentTSpace, 0, indexCount);
+      transform(variantModule, std::get<0>(ranges[0]), std::get<1>(ranges[0]));
+    }
+
+    // NOTE(review): the child is given an empty transform space, which leaves
+    // it with no further indices to transform.
+    std::vector<bool> emptyTSpace;
+    new ReductionNode(variantModule, parent, emptyTSpace);
+    return;
+  }
+
+  // Nothing was requested; this also keeps the division below well-defined.
+  if (numVariants <= 0)
+    return;
+
+  // With more variants than untransformed indices, some sections would be
+  // empty and those variants would merely duplicate the parent module.
+  if (numVariants > indexCount)
+    numVariants = indexCount;
+
+  int sectionSize = indexCount / numVariants;
+  for (int i = 0; i < numVariants; ++i) {
+    ModuleOp variantModule = module.clone();
+    std::vector<bool> newTSpace = parentTSpace;
+    int sectionStart = sectionSize * i;
+    // The last section also takes the remainder of the integer division.
+    int sectionEnd =
+        (i == numVariants - 1) ? indexCount : sectionSize * (i + 1);
+
+    if (isDeletion)
+      transform(variantModule, sectionStart, sectionEnd);
+
+    // Mark the section as transformed in the variant's transform space; for
+    // non-deleting transformations, apply `transform` to each covered range.
+    for (const auto &range :
+         getRanges(parentTSpace, sectionStart, sectionEnd)) {
+      int rangeStart = std::get<0>(range);
+      int rangeEnd = std::get<1>(range);
+
+      for (int x = rangeStart; x < rangeEnd; ++x)
+        newTSpace[x] = true;
+
+      if (!isDeletion)
+        transform(variantModule, rangeStart, rangeEnd);
+    }
+
+    // Create Reduction Node in the Reduction tree. NOTE(review): presumably
+    // the ReductionNode constructor registers the node with `parent`, which
+    // then owns it; confirm against ReductionNode.
+    new ReductionNode(variantModule, parent, newTSpace);
+  }
+}

diff  --git a/mlir/tools/mlir-reduce/mlir-reduce.cpp b/mlir/tools/mlir-reduce/mlir-reduce.cpp
index 4c69aa0ad217..e60aa2f95c59 100644
--- a/mlir/tools/mlir-reduce/mlir-reduce.cpp
+++ b/mlir/tools/mlir-reduce/mlir-reduce.cpp
@@ -20,6 +20,7 @@
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Reducer/OptReductionPass.h"
+#include "mlir/Reducer/Passes/OpReducer.h"
 #include "mlir/Reducer/ReductionNode.h"
 #include "mlir/Reducer/ReductionTreePass.h"
 #include "mlir/Reducer/Tester.h"
@@ -103,15 +104,16 @@ int main(int argc, char **argv) {
   if (passTestSpecifier == "DCE") {
 
     // Opt Reduction Pass with SymbolDCEPass as opt pass.
-    pm.addPass(std::make_unique<OptReductionPass>(&test, &context,
+    pm.addPass(std::make_unique<OptReductionPass>(test, &context,
                                                   createSymbolDCEPass()));
 
   } else if (passTestSpecifier == "function-reducer") {
 
     // Reduction tree pass with OpReducer variant generation and single path
     // traversal.
-    pm.addPass(std::make_unique<ReductionTreePass<FunctionReducer, SinglePath>>(
-        &test));
+    pm.addPass(
+        std::make_unique<ReductionTreePass<OpReducer<FuncOp>, SinglePath>>(
+            test));
   }
 
   ModuleOp m = moduleRef.get().clone();


        


More information about the Mlir-commits mailing list