[Mlir-commits] [mlir] c484c7d - [mlir-reduce] Reducer refactor.
Chia-hung Duan
llvmlistbot at llvm.org
Tue Jun 1 16:46:45 PDT 2021
Author: Chia-hung Duan
Date: 2021-06-02T07:45:00+08:00
New Revision: c484c7dd9d2382f07216ae9142ceb76272e21dc4
URL: https://github.com/llvm/llvm-project/commit/c484c7dd9d2382f07216ae9142ceb76272e21dc4
DIFF: https://github.com/llvm/llvm-project/commit/c484c7dd9d2382f07216ae9142ceb76272e21dc4.diff
LOG: [mlir-reduce] Reducer refactor.
* A Reducer is a kind of RewritePattern, so it's just the same as
  writing graph rewrite.
* ReductionTreePass operates on Operation rather than ModuleOp, so that
  we are able to reduce a nested structure (e.g., module in module) by
  self-nesting.
Reviewed By: jpienaar, rriddle
Differential Revision: https://reviews.llvm.org/D101046
Added:
mlir/include/mlir/Reducer/ReductionPatternInterface.h
mlir/lib/Reducer/OptReductionPass.cpp
mlir/lib/Reducer/ReductionNode.cpp
mlir/lib/Reducer/ReductionTreePass.cpp
mlir/test/mlir-reduce/crashop-reduction.mlir
Modified:
mlir/include/mlir/Reducer/Passes.h
mlir/include/mlir/Reducer/Passes.td
mlir/include/mlir/Reducer/ReductionNode.h
mlir/lib/Reducer/CMakeLists.txt
mlir/lib/Reducer/Tester.cpp
mlir/test/lib/Dialect/Test/CMakeLists.txt
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestDialect.h
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/lib/Dialect/Test/TestPatterns.cpp
mlir/test/lib/Reducer/MLIRTestReducer.cpp
mlir/test/mlir-reduce/dce-test.mlir
mlir/test/mlir-reduce/multiple-function.mlir
mlir/test/mlir-reduce/simple-test.mlir
mlir/test/mlir-reduce/single-function.mlir
mlir/tools/mlir-reduce/CMakeLists.txt
mlir/tools/mlir-reduce/mlir-reduce.cpp
Removed:
mlir/include/mlir/Reducer/OptReductionPass.h
mlir/include/mlir/Reducer/Passes/OpReducer.h
mlir/include/mlir/Reducer/ReductionTreePass.h
mlir/tools/mlir-reduce/OptReductionPass.cpp
mlir/tools/mlir-reduce/ReductionNode.cpp
mlir/tools/mlir-reduce/ReductionTreePass.cpp
################################################################################
diff --git a/mlir/include/mlir/Reducer/OptReductionPass.h b/mlir/include/mlir/Reducer/OptReductionPass.h
deleted file mode 100644
index f2735395d7590..0000000000000
--- a/mlir/include/mlir/Reducer/OptReductionPass.h
+++ /dev/null
@@ -1,41 +0,0 @@
-//===- OptReductionPass.h - Optimization Reduction Pass Wrapper -*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file defines the Opt Reduction Pass Wrapper. It creates a MLIR pass to
-// run any optimization pass within it and only replaces the output module with
-// the transformed version if it is smaller and interesting.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_REDUCER_OPTREDUCTIONPASS_H
-#define MLIR_REDUCER_OPTREDUCTIONPASS_H
-
-#include "PassDetail.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassManager.h"
-#include "mlir/Reducer/ReductionNode.h"
-#include "mlir/Reducer/ReductionTreePass.h"
-#include "mlir/Reducer/Tester.h"
-#include "mlir/Transforms/Passes.h"
-#include "llvm/Support/Debug.h"
-
-namespace mlir {
-
-class OptReductionPass : public OptReductionBase<OptReductionPass> {
-public:
- OptReductionPass() = default;
-
- OptReductionPass(const OptReductionPass &srcPass) = default;
-
- /// Runs the pass instance in the pass pipeline.
- void runOnOperation() override;
-};
-
-} // end namespace mlir
-
-#endif
diff --git a/mlir/include/mlir/Reducer/Passes.h b/mlir/include/mlir/Reducer/Passes.h
index 8eb1d59b47cfa..438f47b329d04 100644
--- a/mlir/include/mlir/Reducer/Passes.h
+++ b/mlir/include/mlir/Reducer/Passes.h
@@ -9,8 +9,6 @@
#define MLIR_REDUCER_PASSES_H
#include "mlir/Pass/Pass.h"
-#include "mlir/Reducer/OptReductionPass.h"
-#include "mlir/Reducer/ReductionTreePass.h"
namespace mlir {
diff --git a/mlir/include/mlir/Reducer/Passes.td b/mlir/include/mlir/Reducer/Passes.td
index ce99507706e08..43e90122e6682 100644
--- a/mlir/include/mlir/Reducer/Passes.td
+++ b/mlir/include/mlir/Reducer/Passes.td
@@ -24,14 +24,12 @@ def CommonReductionPassOptions {
];
}
-def ReductionTree : Pass<"reduction-tree", "ModuleOp"> {
+def ReductionTree : Pass<"reduction-tree"> {
let summary = "A general reduction tree pass for the MLIR Reduce Tool";
let constructor = "mlir::createReductionTreePass()";
let options = [
- Option<"opReducerName", "op-reducer", "std::string", /* default */"",
- "The OpReducer to reduce the module">,
Option<"traversalModeId", "traversal-mode", "unsigned",
/* default */"0", "The graph traversal mode">,
] # CommonReductionPassOptions.options;
diff --git a/mlir/include/mlir/Reducer/Passes/OpReducer.h b/mlir/include/mlir/Reducer/Passes/OpReducer.h
deleted file mode 100644
index 6e48d41187356..0000000000000
--- a/mlir/include/mlir/Reducer/Passes/OpReducer.h
+++ /dev/null
@@ -1,76 +0,0 @@
-//===- OpReducer.h - MLIR Reduce Operation Reducer ------------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file defines the OpReducer class. It defines a variant generator method
-// with the purpose of producing different variants by eliminating a
-// parameterizable type of operations from the parent module.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_REDUCER_PASSES_OPREDUCER_H
-#define MLIR_REDUCER_PASSES_OPREDUCER_H
-
-#include <limits>
-
-#include "mlir/Reducer/ReductionNode.h"
-#include "mlir/Reducer/Tester.h"
-
-namespace mlir {
-
-class OpReducer {
-public:
- virtual ~OpReducer() = default;
- /// According to rangeToKeep, try to reduce the given module. We implicitly
- /// number each interesting operation and rangeToKeep indicates that if an
- /// operation's number falls into certain range, then we will not try to
- /// reduce that operation.
- virtual void reduce(ModuleOp module,
- ArrayRef<ReductionNode::Range> rangeToKeep) = 0;
- /// Return the number of certain kind of operations that we would like to
- /// reduce. This can be used to build a range map to exclude uninterested
- /// operations.
- virtual int getNumTargetOps(ModuleOp module) const = 0;
-};
-
-/// Reducer is a helper class to remove potential uninteresting operations from
-/// module.
-template <typename OpType>
-class Reducer : public OpReducer {
-public:
- ~Reducer() override = default;
-
- int getNumTargetOps(ModuleOp module) const override {
- return std::distance(module.getOps<OpType>().begin(),
- module.getOps<OpType>().end());
- }
-
- void reduce(ModuleOp module,
- ArrayRef<ReductionNode::Range> rangeToKeep) override {
- std::vector<Operation *> opsToRemove;
- size_t keepIndex = 0;
-
- for (auto op : enumerate(module.getOps<OpType>())) {
- int index = op.index();
- if (keepIndex < rangeToKeep.size() &&
- index == rangeToKeep[keepIndex].second)
- ++keepIndex;
- if (keepIndex == rangeToKeep.size() ||
- index < rangeToKeep[keepIndex].first)
- opsToRemove.push_back(op.value());
- }
-
- for (Operation *o : opsToRemove) {
- o->dropAllUses();
- o->erase();
- }
- }
-};
-
-} // end namespace mlir
-
-#endif
diff --git a/mlir/include/mlir/Reducer/ReductionNode.h b/mlir/include/mlir/Reducer/ReductionNode.h
index 6364157b9040b..65b2928bfd546 100644
--- a/mlir/include/mlir/Reducer/ReductionNode.h
+++ b/mlir/include/mlir/Reducer/ReductionNode.h
@@ -21,19 +21,25 @@
#include <vector>
#include "mlir/Reducer/Tester.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/ArrayRef.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/ToolOutputFile.h"
namespace mlir {
+class ModuleOp;
+class Region;
+
/// Defines the traversal method options to be used in the reduction tree
/// traversal.
enum TraversalMode { SinglePath, Backtrack, MultiPath };
-/// This class defines the ReductionNode which is used to generate variant and
-/// keep track of the necessary metadata for the reduction pass. The nodes are
-/// linked together in a reduction tree structure which defines the relationship
-/// between all the different generated variants.
+/// ReductionTreePass will build a reduction tree during module reduction and
+/// the ReductionNode represents the vertex of the tree. A ReductionNode records
+/// the information such as the reduced module, how this node is reduced from
+/// the parent node, etc. This information will be used to construct a reduction
+/// path to reduce the certain module.
class ReductionNode {
public:
template <TraversalMode mode>
@@ -44,23 +50,46 @@ class ReductionNode {
ReductionNode(ReductionNode *parent, std::vector<Range> range,
llvm::SpecificBumpPtrAllocator<ReductionNode> &allocator);
- ReductionNode *getParent() const;
+ ReductionNode *getParent() const { return parent; }
+
+ /// If the ReductionNode hasn't been tested the interestingness, it'll be the
+ /// same module as the one in the parent node. Otherwise, the returned module
+ /// will have been applied certain reduction strategies. Note that it's not
+ /// necessary to be an interesting case or a reduced module (has smaller size
+ /// than parent's).
+ ModuleOp getModule() const { return module; }
+
+ /// Return the region we're reducing.
+ Region &getRegion() const { return *region; }
- size_t getSize() const;
+ /// Return the size of the module.
+ size_t getSize() const { return size; }
/// Returns true if the module exhibits the interesting behavior.
- Tester::Interestingness isInteresting() const;
+ Tester::Interestingness isInteresting() const { return interesting; }
- std::vector<Range> getRanges() const;
+ /// Return the range information that how this node is reduced from the parent
+ /// node.
+ ArrayRef<Range> getStartRanges() const { return startRanges; }
- std::vector<ReductionNode *> &getVariants();
+ /// Return the range set we are using to generate variants.
+ ArrayRef<Range> getRanges() const { return ranges; }
+
+ /// Return the generated variants(the child nodes).
+ ArrayRef<ReductionNode *> getVariants() const { return variants; }
/// Split the ranges and generate new variants.
- std::vector<ReductionNode *> generateNewVariants();
+ ArrayRef<ReductionNode *> generateNewVariants();
/// Update the interestingness result from tester.
void update(std::pair<Tester::Interestingness, size_t> result);
+ /// Each Reduction Node contains a copy of module for applying rewrite
+ /// patterns. In addition, we only apply rewrite patterns in a certain region.
+ /// In init(), we will duplicate the module from parent node and locate the
+ /// corresponding region.
+ LogicalResult initialize(ModuleOp parentModule, Region &parentRegion);
+
private:
/// A custom BFS iterator. The difference between
/// llvm/ADT/BreadthFirstIterator.h is the graph we're exploring is dynamic.
@@ -87,8 +116,7 @@ class ReductionNode {
BaseIterator &operator++() {
ReductionNode *top = visitQueue.front();
visitQueue.pop();
- std::vector<ReductionNode *> neighbors = getNeighbors(top);
- for (ReductionNode *node : neighbors)
+ for (ReductionNode *node : getNeighbors(top))
visitQueue.push(node);
return *this;
}
@@ -103,7 +131,7 @@ class ReductionNode {
ReductionNode *operator->() const { return visitQueue.front(); }
protected:
- std::vector<ReductionNode *> getNeighbors(ReductionNode *node) {
+ ArrayRef<ReductionNode *> getNeighbors(ReductionNode *node) {
return static_cast<T *>(this)->getNeighbors(node);
}
@@ -111,21 +139,42 @@ class ReductionNode {
std::queue<ReductionNode *> visitQueue;
};
- /// The size of module after applying the range constraints.
+ /// This is a copy of module from parent node. All the reducer patterns will
+ /// be applied to this instance.
+ ModuleOp module;
+
+ /// The region of certain operation we're reducing in the module
+ Region *region;
+
+ /// The node we are reduced from. It means we will be in variants of parent
+ /// node.
+ ReductionNode *parent;
+
+ /// The size of module after applying the reducer patterns with range
+ /// constraints. This is only valid while the interestingness has been tested.
size_t size;
/// This is true if the module has been evaluated and it exhibits the
/// interesting behavior.
Tester::Interestingness interesting;
- ReductionNode *parent;
-
- /// We will only keep the operation with index falls into the ranges.
- /// For example, number each function in a certain module and then we will
- /// remove the functions with index outside the ranges and see if the
- /// resulting module is still interesting.
+ /// `ranges` represents the selected subset of operations in the region. We
+ /// implictly number each operation in the region and ReductionTreePass will
+ /// apply reducer patterns on the operation falls into the `ranges`. We will
+ /// generate new ReductionNode with subset of `ranges` to see if we can do
+ /// further reduction. we may split the element in the `ranges` so that we can
+ /// have more subset variants from `ranges`.
+ /// Note that after applying the reducer patterns the number of operation in
+ /// the region may have changed, we need to update the `ranges` after that.
std::vector<Range> ranges;
+ /// `startRanges` records the ranges of operations selected from the parent
+ /// node to produce this ReductionNode. It can be used to construct the
+ /// reduction path from the root. I.e., if we apply the same reducer patterns
+ /// and `startRanges` selection on the parent region, we will get the same
+ /// module as this node.
+ const std::vector<Range> startRanges;
+
/// This points to the child variants that were created using this node as a
/// starting point.
std::vector<ReductionNode *> variants;
@@ -139,9 +188,9 @@ class ReductionNode::iterator<SinglePath>
: public BaseIterator<iterator<SinglePath>> {
friend BaseIterator<iterator<SinglePath>>;
using BaseIterator::BaseIterator;
- std::vector<ReductionNode *> getNeighbors(ReductionNode *node);
+ ArrayRef<ReductionNode *> getNeighbors(ReductionNode *node);
};
} // end namespace mlir
-#endif
+#endif // MLIR_REDUCER_REDUCTIONNODE_H
diff --git a/mlir/include/mlir/Reducer/ReductionPatternInterface.h b/mlir/include/mlir/Reducer/ReductionPatternInterface.h
new file mode 100644
index 0000000000000..887d120e4c352
--- /dev/null
+++ b/mlir/include/mlir/Reducer/ReductionPatternInterface.h
@@ -0,0 +1,56 @@
+//===- ReducePatternInterface.h - Collecting Reduce Patterns ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_REDUCER_REDUCTIONPATTERNINTERFACE_H
+#define MLIR_REDUCER_REDUCTIONPATTERNINTERFACE_H
+
+#include "mlir/IR/DialectInterface.h"
+
+namespace mlir {
+
+class RewritePatternSet;
+
+/// This is used to report the reduction patterns for a Dialect. While using
+/// mlir-reduce to reduce a module, we may want to transform certain cases into
+/// simpler forms by applying certain rewrite patterns. Implement the
+/// `populateReductionPatterns` to report those patterns by adding them to the
+/// RewritePatternSet.
+///
+/// Example:
+/// MyDialectReductionPattern::populateReductionPatterns(
+/// RewritePatternSet &patterns) {
+/// patterns.add<TensorOpReduction>(patterns.getContext());
+/// }
+///
+/// For DRR, mlir-tblgen will generate a helper function
+/// `populateWithGenerated` which has the same signature therefore you can
+/// delegate to the helper function as well.
+///
+/// Example:
+/// MyDialectReductionPattern::populateReductionPatterns(
+/// RewritePatternSet &patterns) {
+/// // Include the autogen file somewhere above.
+/// populateWithGenerated(patterns);
+/// }
+class DialectReductionPatternInterface
+ : public DialectInterface::Base<DialectReductionPatternInterface> {
+public:
+ /// Patterns provided here are intended to transform operations from a complex
+ /// form to a simpler form, without breaking the semantics of the program
+ /// being reduced. For example, you may want to replace the
+ /// tensor<?xindex> with a known rank and type, e.g. tensor<1xi32>, or
+ /// replacing an operation with a constant.
+ virtual void populateReductionPatterns(RewritePatternSet &patterns) const = 0;
+
+protected:
+ DialectReductionPatternInterface(Dialect *dialect) : Base(dialect) {}
+};
+
+} // end namespace mlir
+
+#endif // MLIR_REDUCER_REDUCTIONPATTERNINTERFACE_H
diff --git a/mlir/include/mlir/Reducer/ReductionTreePass.h b/mlir/include/mlir/Reducer/ReductionTreePass.h
deleted file mode 100644
index a657d6187597b..0000000000000
--- a/mlir/include/mlir/Reducer/ReductionTreePass.h
+++ /dev/null
@@ -1,50 +0,0 @@
-//===- ReductionTreePass.h - Reduction Tree Pass Implementation -*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file defines the Reduction Tree Pass class. It provides a framework for
-// the implementation of different reduction passes in the MLIR Reduce tool. It
-// allows for custom specification of the variant generation behavior. It
-// implements methods that define the different possible traversals of the
-// reduction tree.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_REDUCER_REDUCTIONTREEPASS_H
-#define MLIR_REDUCER_REDUCTIONTREEPASS_H
-
-#include <vector>
-
-#include "PassDetail.h"
-#include "ReductionNode.h"
-#include "mlir/Reducer/Passes/OpReducer.h"
-#include "mlir/Reducer/Tester.h"
-
-#define DEBUG_TYPE "mlir-reduce"
-
-namespace mlir {
-
-/// This class defines the Reduction Tree Pass. It provides a framework to
-/// to implement a reduction pass using a tree structure to keep track of the
-/// generated reduced variants.
-class ReductionTreePass : public ReductionTreeBase<ReductionTreePass> {
-public:
- ReductionTreePass() = default;
- ReductionTreePass(const ReductionTreePass &pass) = default;
-
- /// Runs the pass instance in the pass pipeline.
- void runOnOperation() override;
-
-private:
- template <typename IteratorType>
- ModuleOp findOptimal(ModuleOp module, std::unique_ptr<OpReducer> reducer,
- ReductionNode *node);
-};
-
-} // end namespace mlir
-
-#endif
diff --git a/mlir/lib/Reducer/CMakeLists.txt b/mlir/lib/Reducer/CMakeLists.txt
index 73601f031d574..e6dd933921792 100644
--- a/mlir/lib/Reducer/CMakeLists.txt
+++ b/mlir/lib/Reducer/CMakeLists.txt
@@ -1,7 +1,13 @@
add_mlir_library(MLIRReduce
+ OptReductionPass.cpp
+ ReductionNode.cpp
+ ReductionTreePass.cpp
Tester.cpp
LINK_LIBS PUBLIC
MLIRIR
+ MLIRPass
+ MLIRRewrite
+ MLIRTransformUtils
)
mlir_check_all_link_libraries(MLIRReduce)
diff --git a/mlir/tools/mlir-reduce/OptReductionPass.cpp b/mlir/lib/Reducer/OptReductionPass.cpp
similarity index 88%
rename from mlir/tools/mlir-reduce/OptReductionPass.cpp
rename to mlir/lib/Reducer/OptReductionPass.cpp
index 533afc0f2aeed..2d0f47ced930d 100644
--- a/mlir/tools/mlir-reduce/OptReductionPass.cpp
+++ b/mlir/lib/Reducer/OptReductionPass.cpp
@@ -12,15 +12,27 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Reducer/OptReductionPass.h"
+#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Reducer/PassDetail.h"
#include "mlir/Reducer/Passes.h"
#include "mlir/Reducer/Tester.h"
+#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "mlir-reduce"
using namespace mlir;
+namespace {
+
+class OptReductionPass : public OptReductionBase<OptReductionPass> {
+public:
+ /// Runs the pass instance in the pass pipeline.
+ void runOnOperation() override;
+};
+
+} // end anonymous namespace
+
/// Runs the pass instance in the pass pipeline.
void OptReductionPass::runOnOperation() {
LLVM_DEBUG(llvm::dbgs() << "\nOptimization Reduction pass: ");
diff --git a/mlir/tools/mlir-reduce/ReductionNode.cpp b/mlir/lib/Reducer/ReductionNode.cpp
similarity index 61%
rename from mlir/tools/mlir-reduce/ReductionNode.cpp
rename to mlir/lib/Reducer/ReductionNode.cpp
index a8e4af8c88223..e2a96681e9aa1 100644
--- a/mlir/tools/mlir-reduce/ReductionNode.cpp
+++ b/mlir/lib/Reducer/ReductionNode.cpp
@@ -15,6 +15,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Reducer/ReductionNode.h"
+#include "mlir/IR/BlockAndValueMapping.h"
#include "llvm/ADT/STLExtras.h"
#include <algorithm>
@@ -23,102 +24,102 @@
using namespace mlir;
ReductionNode::ReductionNode(
- ReductionNode *parent, std::vector<Range> ranges,
+ ReductionNode *parentNode, std::vector<Range> ranges,
llvm::SpecificBumpPtrAllocator<ReductionNode> &allocator)
- : size(std::numeric_limits<size_t>::max()),
- interesting(Tester::Interestingness::Untested),
- /// Root node will have the parent pointer point to themselves.
- parent(parent == nullptr ? this : parent), ranges(ranges),
- allocator(allocator) {}
-
-/// Returns the size in bytes of the module.
-size_t ReductionNode::getSize() const { return size; }
-
-ReductionNode *ReductionNode::getParent() const { return parent; }
-
-/// Returns true if the module exhibits the interesting behavior.
-Tester::Interestingness ReductionNode::isInteresting() const {
- return interesting;
+ /// Root node will have the parent pointer point to themselves.
+ : parent(parentNode == nullptr ? this : parentNode),
+ size(std::numeric_limits<size_t>::max()),
+ interesting(Tester::Interestingness::Untested), ranges(ranges),
+ startRanges(ranges), allocator(allocator) {
+ if (parent != this)
+ if (failed(initialize(parent->getModule(), parent->getRegion())))
+ llvm_unreachable("unexpected initialization failure");
}
-std::vector<ReductionNode::Range> ReductionNode::getRanges() const {
- return ranges;
+LogicalResult ReductionNode::initialize(ModuleOp parentModule,
+ Region &targetRegion) {
+ // Use the mapper help us find the corresponding region after module clone.
+ BlockAndValueMapping mapper;
+ module = cast<ModuleOp>(parentModule->clone(mapper));
+ // Use the first block of targetRegion to locate the cloned region.
+ Block *block = mapper.lookup(&*targetRegion.begin());
+ region = block->getParent();
+ return success();
}
-std::vector<ReductionNode *> &ReductionNode::getVariants() { return variants; }
-
-#include <iostream>
-
/// If we haven't explored any variants from this node, we will create N
/// variants, N is the length of `ranges` if N > 1. Otherwise, we will split the
/// max element in `ranges` and create 2 new variants for each call.
-std::vector<ReductionNode *> ReductionNode::generateNewVariants() {
- std::vector<ReductionNode *> newNodes;
+ArrayRef<ReductionNode *> ReductionNode::generateNewVariants() {
+ int oldNumVariant = getVariants().size();
+
+ auto createNewNode = [this](std::vector<Range> ranges) {
+ return new (allocator.Allocate())
+ ReductionNode(this, std::move(ranges), allocator);
+ };
// If we haven't created new variant, then we can create varients by removing
// each of them respectively. For example, given {{1, 3}, {4, 9}}, we can
// produce variants with range {{1, 3}} and {{4, 9}}.
- if (variants.size() == 0 && ranges.size() != 1) {
- for (const Range &range : ranges) {
- std::vector<Range> subRanges = ranges;
+ if (variants.size() == 0 && getRanges().size() > 1) {
+ for (const Range &range : getRanges()) {
+ std::vector<Range> subRanges = getRanges();
llvm::erase_value(subRanges, range);
- ReductionNode *newNode = allocator.Allocate();
- new (newNode) ReductionNode(this, subRanges, allocator);
- newNodes.push_back(newNode);
- variants.push_back(newNode);
+ variants.push_back(createNewNode(std::move(subRanges)));
}
- return newNodes;
+ return getVariants().drop_front(oldNumVariant);
}
// At here, we have created the type of variants mentioned above. We would
// like to split the max range into 2 to create 2 new variants. Continue on
// the above example, we split the range {4, 9} into {4, 6}, {6, 9}, and
// create two variants with range {{1, 3}, {4, 6}} and {{1, 3}, {6, 9}}. The
- // result ranges vector will be {{1, 3}, {4, 6}, {6, 9}}.
+ // final ranges vector will be {{1, 3}, {4, 6}, {6, 9}}.
auto maxElement = std::max_element(
ranges.begin(), ranges.end(), [](const Range &lhs, const Range &rhs) {
return (lhs.second - lhs.first) > (rhs.second - rhs.first);
});
- // We can't split range with lenght 1, which means we can't produce new
+ // The length of range is less than 1, we can't split it to create new
// variant.
- if (maxElement->second - maxElement->first == 1)
+ if (maxElement->second - maxElement->first <= 1)
return {};
- auto createNewNode = [this](const std::vector<Range> &ranges) {
- ReductionNode *newNode = allocator.Allocate();
- new (newNode) ReductionNode(this, ranges, allocator);
- return newNode;
- };
-
Range maxRange = *maxElement;
- std::vector<Range> subRanges = ranges;
+ std::vector<Range> subRanges = getRanges();
auto subRangesIter = subRanges.begin() + (maxElement - ranges.begin());
int half = (maxRange.first + maxRange.second) / 2;
*subRangesIter = std::make_pair(maxRange.first, half);
- newNodes.push_back(createNewNode(subRanges));
+ variants.push_back(createNewNode(subRanges));
*subRangesIter = std::make_pair(half, maxRange.second);
- newNodes.push_back(createNewNode(subRanges));
+ variants.push_back(createNewNode(std::move(subRanges)));
- variants.insert(variants.end(), newNodes.begin(), newNodes.end());
auto it = ranges.insert(maxElement, std::make_pair(half, maxRange.second));
it = ranges.insert(it, std::make_pair(maxRange.first, half));
// Remove the range that has been split.
ranges.erase(it + 2);
- return newNodes;
+ return getVariants().drop_front(oldNumVariant);
}
void ReductionNode::update(std::pair<Tester::Interestingness, size_t> result) {
std::tie(interesting, size) = result;
+ // After applying reduction, the number of operation in the region may have
+ // changed. Non-interesting case won't be explored thus it's safe to keep it
+ // in a stale status.
+ if (interesting == Tester::Interestingness::True) {
+ // This module may has been updated. Reset the range.
+ ranges.clear();
+ ranges.push_back({0, std::distance(region->op_begin(), region->op_end())});
+ }
}
-std::vector<ReductionNode *>
+ArrayRef<ReductionNode *>
ReductionNode::iterator<SinglePath>::getNeighbors(ReductionNode *node) {
// Single Path: Traverses the smallest successful variant at each level until
// no new successful variants can be created at that level.
- llvm::ArrayRef<ReductionNode *> variantsFromParent =
+ ArrayRef<ReductionNode *> variantsFromParent =
node->getParent()->getVariants();
// The parent node created several variants and they may be waiting for
@@ -139,7 +140,8 @@ ReductionNode::iterator<SinglePath>::getNeighbors(ReductionNode *node) {
smallest = node;
}
- if (smallest != nullptr) {
+ if (smallest != nullptr &&
+ smallest->getSize() < node->getParent()->getSize()) {
// We got a smallest one, keep traversing from this node.
node = smallest;
} else {
diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp
new file mode 100644
index 0000000000000..f564c3b42410f
--- /dev/null
+++ b/mlir/lib/Reducer/ReductionTreePass.cpp
@@ -0,0 +1,247 @@
+//===- ReductionTreePass.cpp - ReductionTreePass Implementation -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the Reduction Tree Pass class. It provides a framework for
+// the implementation of different reduction passes in the MLIR Reduce tool. It
+// allows for custom specification of the variant generation behavior. It
+// implements methods that define the different possible traversals of the
+// reduction tree.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/DialectInterface.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Reducer/PassDetail.h"
+#include "mlir/Reducer/Passes.h"
+#include "mlir/Reducer/ReductionNode.h"
+#include "mlir/Reducer/ReductionPatternInterface.h"
+#include "mlir/Reducer/Tester.h"
+#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Allocator.h"
+#include "llvm/Support/ManagedStatic.h"
+
+using namespace mlir;
+
+/// Implicitly numbers each operation in `region` (in iteration order) and
+/// partitions the operations using `rangeToKeep`, a list of half-open
+/// [first, second) index ranges. Operations whose number falls into one of the
+/// ranges are kept and have `patterns` applied to them. Operations outside
+/// every range are erased when `eraseOpNotInRange` is set.
+///
+/// `rangeToKeep` is walked with a single forward cursor, so the ranges are
+/// expected to be sorted and non-overlapping.
+static void applyPatterns(Region &region,
+                          const FrozenRewritePatternSet &patterns,
+                          ArrayRef<ReductionNode::Range> rangeToKeep,
+                          bool eraseOpNotInRange) {
+  std::vector<Operation *> opsNotInRange;
+  std::vector<Operation *> opsInRange;
+  size_t keepIndex = 0;
+  for (auto op : enumerate(region.getOps())) {
+    int index = op.index();
+    // Move on to the next range once the (exclusive) end of the current one is
+    // reached.
+    if (keepIndex < rangeToKeep.size() &&
+        index == rangeToKeep[keepIndex].second)
+      ++keepIndex;
+    if (keepIndex == rangeToKeep.size() || index < rangeToKeep[keepIndex].first)
+      opsNotInRange.push_back(&op.value());
+    else
+      opsInRange.push_back(&op.value());
+  }
+
+  // `applyOpPatternsAndFold` may erase ops, so the pattern matching can't be
+  // done during the iteration above. Erasing the not-in-range ops may leave the
+  // module invalid, so `applyOpPatternsAndFold` has to run before that.
+  for (Operation *op : opsInRange) {
+    // `applyOpPatternsAndFold` reports whether the op converged. The result is
+    // ignored because a given reduction is not expected to succeed.
+    (void)applyOpPatternsAndFold(op, patterns);
+  }
+
+  if (eraseOpNotInRange) {
+    for (Operation *op : opsNotInRange) {
+      // The results may still have users (possibly kept ops); drop those uses
+      // before erasing. Any IR left invalid is rejected later by the tester.
+      op->dropAllUses();
+      op->erase();
+    }
+  }
+}
+
+/// Explores the reduction tree of `region` with `IteratorType`, then replays
+/// the path to the smallest interesting node on the original `region`.
+///
+/// Reducer patterns are applied to the operations in the ranges specified by
+/// each ReductionNode. A pattern cannot remove an operation without replacing
+/// it with another valid operation. However, the validity of a reduced module
+/// is decided by the Tester provided by the user, so certain invalid modules
+/// may still be interesting to the user. An alternative way to remove
+/// operations is therefore provided: `eraseOpNotInRange` erases the operations
+/// not in the ranges specified by the ReductionNode.
+///
+/// Does nothing if `module` is not interesting to begin with. Reports a fatal
+/// error if the replayed result is not interesting, or if its size disagrees
+/// with the smallest node found during exploration.
+template <typename IteratorType>
+static void findOptimal(ModuleOp module, Region &region,
+                        const FrozenRewritePatternSet &patterns,
+                        const Tester &test, bool eraseOpNotInRange) {
+  std::pair<Tester::Interestingness, size_t> initStatus =
+      test.isInteresting(module);
+  // While exploring the reduction tree, we always branch from an interesting
+  // node. Thus the root node must be interesting.
+  if (initStatus.first != Tester::Interestingness::True)
+    return;
+
+  llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
+
+  std::vector<ReductionNode::Range> ranges{
+      {0, std::distance(region.op_begin(), region.op_end())}};
+
+  ReductionNode *root = allocator.Allocate();
+  new (root) ReductionNode(nullptr, std::move(ranges), allocator);
+  // Duplicate the module for root node and locate the region in the copy.
+  if (failed(root->initialize(module, region)))
+    llvm_unreachable("unexpected initialization failure");
+  root->update(initStatus);
+
+  ReductionNode *smallestNode = root;
+  IteratorType iter(root);
+
+  while (iter != IteratorType::end()) {
+    ReductionNode &currentNode = *iter;
+    Region &curRegion = currentNode.getRegion();
+
+    applyPatterns(curRegion, patterns, currentNode.getRanges(),
+                  eraseOpNotInRange);
+    currentNode.update(test.isInteresting(currentNode.getModule()));
+
+    if (currentNode.isInteresting() == Tester::Interestingness::True &&
+        currentNode.getSize() < smallestNode->getSize())
+      smallestNode = &currentNode;
+
+    ++iter;
+  }
+
+  // At this point we have found an optimal path to reduce the given region.
+  // Retrieve the path (leaf to root) and apply the reducer along it.
+  SmallVector<ReductionNode *> trace;
+  ReductionNode *curNode = smallestNode;
+  trace.push_back(curNode);
+  while (curNode != root) {
+    curNode = curNode->getParent();
+    trace.push_back(curNode);
+  }
+
+  // Reduce the region through the optimal path, starting from the root.
+  while (!trace.empty()) {
+    ReductionNode *top = trace.pop_back_val();
+    applyPatterns(region, patterns, top->getStartRanges(), eraseOpNotInRange);
+  }
+
+  // Each isInteresting() call verifies the module, writes it out and runs the
+  // user's test script, so evaluate the final result only once.
+  std::pair<Tester::Interestingness, size_t> finalStatus =
+      test.isInteresting(module);
+  if (finalStatus.first != Tester::Interestingness::True)
+    llvm::report_fatal_error("Reduced module is not interesting");
+  if (finalStatus.second != smallestNode->getSize())
+    llvm::report_fatal_error(
+        "Reduced module doesn't have consistent size with smallestNode");
+}
+
+/// Reduces `region` in two phases. The first phase erases redundant
+/// operations. The second phase applies the reducer patterns to the operations
+/// that remain.
+template <typename IteratorType>
+static void findOptimal(ModuleOp module, Region &region,
+                        const FrozenRewritePatternSet &patterns,
+                        const Tester &test) {
+  // Phase 1: apply no patterns, only select which ranges of operations to keep
+  // so that the module stays interesting.
+  findOptimal<IteratorType>(module, region, /*patterns=*/{}, test,
+                            /*eraseOpNotInRange=*/true);
+  // Phase 2: assume no operation is redundant any more and try to rewrite the
+  // operations into simpler forms.
+  findOptimal<IteratorType>(module, region, patterns, test,
+                            /*eraseOpNotInRange=*/false);
+}
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Reduction Pattern Interface Collection
+//===----------------------------------------------------------------------===//
+
+/// Collects the DialectReductionPatternInterface implementations provided by
+/// the dialects of a context.
+class ReductionPatternInterfaceCollection
+    : public DialectInterfaceCollection<DialectReductionPatternInterface> {
+public:
+  using Base::Base;
+
+  // Collect the reduce patterns defined by each dialect.
+  void populateReductionPatterns(RewritePatternSet &pattern) const {
+    for (const DialectReductionPatternInterface &interface : *this)
+      interface.populateReductionPatterns(pattern);
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ReductionTreePass
+//===----------------------------------------------------------------------===//
+
+/// This class defines the Reduction Tree Pass. It provides a framework to
+/// implement a reduction pass using a tree structure to keep track of the
+/// generated reduced variants.
+class ReductionTreePass : public ReductionTreeBase<ReductionTreePass> {
+public:
+  ReductionTreePass() = default;
+  ReductionTreePass(const ReductionTreePass &pass) = default;
+
+  /// Gathers the reducer patterns from every dialect reduction interface.
+  LogicalResult initialize(MLIRContext *context) override;
+
+  /// Runs the pass instance in the pass pipeline.
+  void runOnOperation() override;
+
+private:
+  /// Reduces `region`, evaluating interestingness on the enclosing `module`.
+  void reduceOp(ModuleOp module, Region &region);
+
+  /// Patterns gathered in `initialize` and applied during the reduction.
+  FrozenRewritePatternSet reducerPatterns;
+};
+
+} // end anonymous namespace
+
+/// Gathers the reduction patterns contributed by every dialect that implements
+/// DialectReductionPatternInterface and stores them in frozen form.
+LogicalResult ReductionTreePass::initialize(MLIRContext *context) {
+  RewritePatternSet collected(context);
+  ReductionPatternInterfaceCollection interfaces(context);
+  interfaces.populateReductionPatterns(collected);
+  reducerPatterns = std::move(collected);
+  return success();
+}
+
+/// Reduces every non-empty region of the pass's operation. It then uses a
+/// worklist to do the same for every nested operation that has regions.
+/// Interestingness is always evaluated on the outermost module.
+void ReductionTreePass::runOnOperation() {
+  // The tester consumes a whole module, so walk up to the outermost operation.
+  // The `cast` expects that operation to be a ModuleOp.
+  Operation *topOperation = getOperation();
+  while (topOperation->getParentOp() != nullptr)
+    topOperation = topOperation->getParentOp();
+  ModuleOp module = cast<ModuleOp>(topOperation);
+
+  SmallVector<Operation *, 8> workList;
+  workList.push_back(getOperation());
+
+  do {
+    Operation *op = workList.pop_back_val();
+
+    for (Region &region : op->getRegions())
+      if (!region.empty())
+        reduceOp(module, region);
+
+    // Queue the region-holding children only after `op`'s regions have been
+    // reduced, so only the surviving operations are visited.
+    for (Region &region : op->getRegions())
+      for (Operation &nestedOp : region.getOps())
+        if (nestedOp.getNumRegions() != 0)
+          workList.push_back(&nestedOp);
+  } while (!workList.empty());
+}
+
+/// Reduces a single region with the traversal strategy selected by
+/// `traversalModeId` (the `traversal-mode` pass option).
+void ReductionTreePass::reduceOp(ModuleOp module, Region &region) {
+  Tester test(testerName, testerArgs);
+  switch (traversalModeId) {
+  case TraversalMode::SinglePath:
+    findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>(
+        module, region, reducerPatterns, test);
+    break;
+  default:
+    llvm_unreachable("Unsupported mode");
+  }
+}
+
+/// Creates a new ReductionTreePass instance.
+std::unique_ptr<Pass> mlir::createReductionTreePass() {
+  std::unique_ptr<Pass> pass = std::make_unique<ReductionTreePass>();
+  return pass;
+}
diff --git a/mlir/lib/Reducer/Tester.cpp b/mlir/lib/Reducer/Tester.cpp
index c0d4862481016..b5531a9a0ca02 100644
--- a/mlir/lib/Reducer/Tester.cpp
+++ b/mlir/lib/Reducer/Tester.cpp
@@ -15,7 +15,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Reducer/Tester.h"
-
+#include "mlir/IR/Verifier.h"
#include "llvm/Support/ToolOutputFile.h"
using namespace mlir;
@@ -25,6 +25,12 @@ Tester::Tester(StringRef scriptName, ArrayRef<std::string> scriptArgs)
std::pair<Tester::Interestingness, size_t>
Tester::isInteresting(ModuleOp module) const {
+ // The reduced module should always be vaild, or we may end up retaining the
+ // error message by an invalid case. Besides, an invalid module may not be
+ // able to print properly.
+ if (failed(verify(module)))
+ return std::make_pair(Interestingness::False, /*size=*/0);
+
SmallString<128> filepath;
int fd;
@@ -50,7 +56,6 @@ Tester::isInteresting(ModuleOp module) const {
/// true if the interesting behavior is present in the test case or false
/// otherwise.
Tester::Interestingness Tester::isInteresting(StringRef testCase) const {
-
std::vector<StringRef> testerArgs;
testerArgs.push_back(testCase);
diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt
index 32870af2a4fe9..d1cf46ae5788b 100644
--- a/mlir/test/lib/Dialect/Test/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt
@@ -60,6 +60,7 @@ add_mlir_library(MLIRTestDialect
MLIRInferTypeOpInterface
MLIRLinalgTransforms
MLIRPass
+ MLIRReduce
MLIRStandard
MLIRStandardOpsTransforms
MLIRTransformUtils
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 396f71eac802f..77aa20956bf51 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -8,6 +8,7 @@
#include "TestDialect.h"
#include "TestAttributes.h"
+#include "TestInterfaces.h"
#include "TestTypes.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -16,6 +17,7 @@
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Reducer/ReductionPatternInterface.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/StringSwitch.h"
@@ -170,6 +172,18 @@ struct TestInlinerInterface : public DialectInlinerInterface {
return builder.create<TestCastOp>(conversionLoc, resultType, input);
}
};
+
+/// Exposes the test dialect's reduction patterns to the reduction tree pass.
+struct TestReductionPatternInterface : public DialectReductionPatternInterface {
+  explicit TestReductionPatternInterface(Dialect *dialect)
+      : DialectReductionPatternInterface(dialect) {}
+
+  void populateReductionPatterns(RewritePatternSet &patterns) const final {
+    populateTestReductionPatterns(patterns);
+  }
+};
+
} // end anonymous namespace
//===----------------------------------------------------------------------===//
@@ -207,7 +221,7 @@ void TestDialect::initialize() {
#include "TestOps.cpp.inc"
>();
addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
- TestInlinerInterface>();
+ TestInlinerInterface, TestReductionPatternInterface>();
allowUnknownOperations();
// Instantiate our fallback op interface that we'll use on specific
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h
index 79cc8dd789267..d57a2c119723a 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.h
+++ b/mlir/test/lib/Dialect/Test/TestDialect.h
@@ -34,6 +34,7 @@
namespace mlir {
class DLTIDialect;
+class RewritePatternSet;
} // namespace mlir
#include "TestOpEnums.h.inc"
@@ -47,6 +48,7 @@ class DLTIDialect;
namespace mlir {
namespace test {
void registerTestDialect(DialectRegistry ®istry);
+void populateTestReductionPatterns(RewritePatternSet &patterns);
} // namespace test
} // namespace mlir
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 8c3e45600b9f8..847436ea3fb09 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2113,4 +2113,19 @@ def DataLayoutQueryOp : TEST_Op<"data_layout_query"> {
let results = (outs AnyType:$res);
}
+//===----------------------------------------------------------------------===//
+// Test Reducer Patterns
+//===----------------------------------------------------------------------===//
+
+def OpCrashLong : TEST_Op<"op_crash_long"> {
+ let arguments = (ins I32, I32, I32);
+ let results = (outs I32);
+}
+
+def OpCrashShort : TEST_Op<"op_crash_short"> {
+ let results = (outs I32);
+}
+
+def : Pat<(OpCrashLong $_, $_, $_), (OpCrashShort)>;
+
#endif // TEST_OPS
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 32f47f40ad36a..eabf9e0110eb7 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -58,6 +58,14 @@ namespace {
#include "TestPatterns.inc"
} // end anonymous namespace
+//===----------------------------------------------------------------------===//
+// Test Reduce Pattern Interface
+//===----------------------------------------------------------------------===//
+
+/// Adds every DRR pattern generated into TestPatterns.inc to `patterns`.
+/// NOTE(review): this registers all generated test patterns as reducer
+/// patterns, not only the reducer-specific ones; confirm that is intended.
+void mlir::test::populateTestReductionPatterns(RewritePatternSet &patterns) {
+ populateWithGenerated(patterns);
+}
+
//===----------------------------------------------------------------------===//
// Canonicalizer Driver.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Reducer/MLIRTestReducer.cpp b/mlir/test/lib/Reducer/MLIRTestReducer.cpp
index cb9d219349095..fe16390129e94 100644
--- a/mlir/test/lib/Reducer/MLIRTestReducer.cpp
+++ b/mlir/test/lib/Reducer/MLIRTestReducer.cpp
@@ -38,7 +38,7 @@ void TestReducer::runOnFunction() {
op.walk([&](Operation *op) {
StringRef opName = op->getName().getStringRef();
- if (opName == "test.crashOp") {
+ if (opName.contains("op_crash")) {
llvm::errs() << "MLIR Reducer Test generated failure: Found "
"\"crashOp\" operation\n";
exit(1);
diff --git a/mlir/test/mlir-reduce/crashop-reduction.mlir b/mlir/test/mlir-reduce/crashop-reduction.mlir
new file mode 100644
index 0000000000000..dd65b0fbc0854
--- /dev/null
+++ b/mlir/test/mlir-reduce/crashop-reduction.mlir
@@ -0,0 +1,20 @@
+// UNSUPPORTED: system-windows
+// RUN: mlir-reduce %s -reduction-tree='traversal-mode=0 test=%S/failure-test.sh' | FileCheck %s
+// "test.op_crash_long" should be replaced with a shorter form "test.op_crash_short".
+
+// CHECK-NOT: func @simple1() {
+func @simple1() {
+  return
+}
+
+// CHECK-LABEL: func @simple2(%arg0: i32, %arg1: i32, %arg2: i32) {
+func @simple2(%arg0: i32, %arg1: i32, %arg2: i32) {
+  // CHECK: %0 = "test.op_crash_short"() : () -> i32
+  %0 = "test.op_crash_long" (%arg0, %arg1, %arg2) : (i32, i32, i32) -> i32
+  return
+}
+
+// CHECK-NOT: func @simple5() {
+func @simple5() {
+  return
+}
diff --git a/mlir/test/mlir-reduce/dce-test.mlir b/mlir/test/mlir-reduce/dce-test.mlir
index 7faa1505b2aaf..c0a93c8eb6ed6 100644
--- a/mlir/test/mlir-reduce/dce-test.mlir
+++ b/mlir/test/mlir-reduce/dce-test.mlir
@@ -12,6 +12,6 @@ func nested @dead_nested_function()
// CHECK-LABEL: func @simple1(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
func @simple1(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
- "test.crashOp" () : () -> ()
+ "test.op_crash" () : () -> ()
return
}
diff --git a/mlir/test/mlir-reduce/multiple-function.mlir b/mlir/test/mlir-reduce/multiple-function.mlir
index 06e51a1a8f468..22a444040fabe 100644
--- a/mlir/test/mlir-reduce/multiple-function.mlir
+++ b/mlir/test/mlir-reduce/multiple-function.mlir
@@ -1,5 +1,5 @@
// UNSUPPORTED: system-windows
-// RUN: mlir-reduce %s -reduction-tree='op-reducer=func traversal-mode=0 test=%S/failure-test.sh' | FileCheck %s
+// RUN: mlir-reduce %s -reduction-tree='traversal-mode=0 test=%S/failure-test.sh' | FileCheck %s
// This input should be reduced by the pass pipeline so that only
// the @simple5 function remains as this is the shortest function
// containing the interesting behavior.
@@ -16,7 +16,7 @@ func @simple2() {
// CHECK-LABEL: func @simple3() {
func @simple3() {
- "test.crashOp" () : () -> ()
+ "test.op_crash" () : () -> ()
return
}
@@ -29,7 +29,7 @@ func @simple4(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
%0 = memref.alloc() : memref<2xf32>
br ^bb3(%0 : memref<2xf32>)
^bb3(%1: memref<2xf32>):
- "test.crashOp"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
+ "test.op_crash"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
return
}
diff --git a/mlir/test/mlir-reduce/simple-test.mlir b/mlir/test/mlir-reduce/simple-test.mlir
index 663003009209a..fd90da3e08392 100644
--- a/mlir/test/mlir-reduce/simple-test.mlir
+++ b/mlir/test/mlir-reduce/simple-test.mlir
@@ -1,5 +1,5 @@
// UNSUPPORTED: system-windows
-// RUN: mlir-reduce %s -reduction-tree='op-reducer=func traversal-mode=0 test=%S/test.sh'
+// RUN: mlir-reduce %s -reduction-tree='traversal-mode=0 test=%S/test.sh'
func @simple1(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
cond_br %arg0, ^bb1, ^bb2
diff --git a/mlir/test/mlir-reduce/single-function.mlir b/mlir/test/mlir-reduce/single-function.mlir
index 732963553e900..adb2376745b41 100644
--- a/mlir/test/mlir-reduce/single-function.mlir
+++ b/mlir/test/mlir-reduce/single-function.mlir
@@ -2,6 +2,6 @@
// RUN: not mlir-opt %s -test-mlir-reducer -pass-test function-reducer
func @test() {
- "test.crashOp"() : () -> ()
+ "test.op_crash"() : () -> ()
return
}
diff --git a/mlir/tools/mlir-reduce/CMakeLists.txt b/mlir/tools/mlir-reduce/CMakeLists.txt
index d58fe915a2864..b59bfcb441123 100644
--- a/mlir/tools/mlir-reduce/CMakeLists.txt
+++ b/mlir/tools/mlir-reduce/CMakeLists.txt
@@ -43,9 +43,6 @@ set(LIBS
)
add_llvm_tool(mlir-reduce
- OptReductionPass.cpp
- ReductionNode.cpp
- ReductionTreePass.cpp
mlir-reduce.cpp
ADDITIONAL_HEADER_DIRS
diff --git a/mlir/tools/mlir-reduce/ReductionTreePass.cpp b/mlir/tools/mlir-reduce/ReductionTreePass.cpp
deleted file mode 100644
index e3a7565bd6a58..0000000000000
--- a/mlir/tools/mlir-reduce/ReductionTreePass.cpp
+++ /dev/null
@@ -1,107 +0,0 @@
-//===- ReductionTreePass.cpp - ReductionTreePass Implementation -----------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file defines the Reduction Tree Pass class. It provides a framework for
-// the implementation of different reduction passes in the MLIR Reduce tool. It
-// allows for custom specification of the variant generation behavior. It
-// implements methods that define the different possible traversals of the
-// reduction tree.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Reducer/ReductionTreePass.h"
-#include "mlir/Reducer/Passes.h"
-
-#include "llvm/Support/Allocator.h"
-
-using namespace mlir;
-
-static std::unique_ptr<OpReducer> getOpReducer(llvm::StringRef opType) {
- if (opType == ModuleOp::getOperationName())
- return std::make_unique<Reducer<ModuleOp>>();
- else if (opType == FuncOp::getOperationName())
- return std::make_unique<Reducer<FuncOp>>();
- llvm_unreachable("Now only supports two built-in ops");
-}
-
-void ReductionTreePass::runOnOperation() {
- ModuleOp module = this->getOperation();
- std::unique_ptr<OpReducer> reducer = getOpReducer(opReducerName);
- std::vector<std::pair<int, int>> ranges = {
- {0, reducer->getNumTargetOps(module)}};
-
- llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
-
- ReductionNode *root = allocator.Allocate();
- new (root) ReductionNode(nullptr, ranges, allocator);
-
- ModuleOp golden = module;
- switch (traversalModeId) {
- case TraversalMode::SinglePath:
- golden = findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>(
- module, std::move(reducer), root);
- break;
- default:
- llvm_unreachable("Unsupported mode");
- }
-
- if (golden != module) {
- module.getBody()->clear();
- module.getBody()->getOperations().splice(module.getBody()->begin(),
- golden.getBody()->getOperations());
- golden->destroy();
- }
-}
-
-template <typename IteratorType>
-ModuleOp ReductionTreePass::findOptimal(ModuleOp module,
- std::unique_ptr<OpReducer> reducer,
- ReductionNode *root) {
- Tester test(testerName, testerArgs);
- std::pair<Tester::Interestingness, size_t> initStatus =
- test.isInteresting(module);
-
- if (initStatus.first != Tester::Interestingness::True) {
- LLVM_DEBUG(llvm::dbgs() << "\nThe original input is not interested");
- return module;
- }
-
- root->update(initStatus);
-
- ReductionNode *smallestNode = root;
- ModuleOp golden = module;
-
- IteratorType iter(root);
-
- while (iter != IteratorType::end()) {
- ModuleOp cloneModule = module.clone();
-
- ReductionNode ¤tNode = *iter;
- reducer->reduce(cloneModule, currentNode.getRanges());
-
- std::pair<Tester::Interestingness, size_t> result =
- test.isInteresting(cloneModule);
- currentNode.update(result);
-
- if (result.first == Tester::Interestingness::True &&
- result.second < smallestNode->getSize()) {
- smallestNode = ¤tNode;
- golden = cloneModule;
- } else {
- cloneModule->destroy();
- }
-
- ++iter;
- }
-
- return golden;
-}
-
-std::unique_ptr<Pass> mlir::createReductionTreePass() {
- return std::make_unique<ReductionTreePass>();
-}
diff --git a/mlir/tools/mlir-reduce/mlir-reduce.cpp b/mlir/tools/mlir-reduce/mlir-reduce.cpp
index b23177d42fce4..1801ce1ffb58a 100644
--- a/mlir/tools/mlir-reduce/mlir-reduce.cpp
+++ b/mlir/tools/mlir-reduce/mlir-reduce.cpp
@@ -13,22 +13,14 @@
//
//===----------------------------------------------------------------------===//
-#include <vector>
-
#include "mlir/InitAllDialects.h"
#include "mlir/InitAllPasses.h"
#include "mlir/Parser.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
-#include "mlir/Reducer/OptReductionPass.h"
#include "mlir/Reducer/Passes.h"
-#include "mlir/Reducer/Passes/OpReducer.h"
-#include "mlir/Reducer/ReductionNode.h"
-#include "mlir/Reducer/ReductionTreePass.h"
-#include "mlir/Reducer/Tester.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/Passes.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/ToolOutputFile.h"
More information about the Mlir-commits
mailing list