[Mlir-commits] [mlir] [MLIR] Add reduction interface with tester to mlir-reduce (PR #166096)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Nov 2 12:50:02 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: AidinT (aidint)
<details>
<summary>Changes</summary>
Currently, `mlir-reduce` has no support for reduction patterns that need access to a `Tester` instance. This PR adds `DialectReductionPatternWithTesterInterface` to the set of supported interfaces. Dialects can use this interface to inject the tester into their pattern classes.
---
Full diff: https://github.com/llvm/llvm-project/pull/166096.diff
3 Files Affected:
- (modified) mlir/include/mlir/Reducer/ReductionPatternInterface.h (+26)
- (modified) mlir/include/mlir/Reducer/Tester.h (+6)
- (modified) mlir/lib/Reducer/ReductionTreePass.cpp (+29-2)
``````````diff
diff --git a/mlir/include/mlir/Reducer/ReductionPatternInterface.h b/mlir/include/mlir/Reducer/ReductionPatternInterface.h
index a85562fda4d93..c4d7a94479358 100644
--- a/mlir/include/mlir/Reducer/ReductionPatternInterface.h
+++ b/mlir/include/mlir/Reducer/ReductionPatternInterface.h
@@ -10,6 +10,7 @@
#define MLIR_REDUCER_REDUCTIONPATTERNINTERFACE_H
#include "mlir/IR/DialectInterface.h"
+#include "mlir/Reducer/Tester.h"
namespace mlir {
@@ -51,6 +52,31 @@ class DialectReductionPatternInterface
DialectReductionPatternInterface(Dialect *dialect) : Base(dialect) {}
};
+/// This interface extends `DialectReductionPatternInterface` by allowing
+/// reduction patterns to use a `Tester` instance. Some reduction patterns may
+/// need to run the tester to determine whether certain transformations
+/// preserve the "interesting" behavior of the program. This is mostly useful
+/// when a pattern must choose between multiple candidate modifications.
+/// Implementation follows the same logic as `DialectReductionPatternInterface`.
+///
+/// Example:
+/// void MyDialectReductionPatternWithTester::populateReductionPatterns(
+/// RewritePatternSet &patterns, Tester &tester) const {
+/// patterns.add<PatternWithTester>(patterns.getContext(), tester);
+/// }
+class DialectReductionPatternWithTesterInterface
+ : public DialectInterface::Base<
+ DialectReductionPatternWithTesterInterface> {
+public:
+ /// Add this dialect's tester-aware reduction patterns to `patterns`.
+ virtual void populateReductionPatterns(RewritePatternSet &patterns,
+ Tester &tester) const = 0;
+protected:
+ DialectReductionPatternWithTesterInterface(Dialect *dialect)
+ : Base(dialect) {}
+};
+
+
} // namespace mlir
#endif // MLIR_REDUCER_REDUCTIONPATTERNINTERFACE_H
diff --git a/mlir/include/mlir/Reducer/Tester.h b/mlir/include/mlir/Reducer/Tester.h
index eb44afc7c1c15..bed4408342034 100644
--- a/mlir/include/mlir/Reducer/Tester.h
+++ b/mlir/include/mlir/Reducer/Tester.h
@@ -36,6 +36,9 @@ class Tester {
Untested,
};
+ Tester() = default;
+ Tester(const Tester &) = default;
+
Tester(StringRef testScript, ArrayRef<std::string> testScriptArgs);
/// Runs the interestingness testing script on a MLIR test case file. Returns
@@ -46,6 +49,9 @@ class Tester {
/// Return whether the file in the given path is interesting.
Interestingness isInteresting(StringRef testCase) const;
+ void setTestScript(StringRef script) { testScript = script; }
+ void setTestScriptArgs(ArrayRef<std::string> args) { testScriptArgs = args; }
+
private:
StringRef testScript;
ArrayRef<std::string> testScriptArgs;
diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp
index 5b49204013cc0..af94cd798f629 100644
--- a/mlir/lib/Reducer/ReductionTreePass.cpp
+++ b/mlir/lib/Reducer/ReductionTreePass.cpp
@@ -181,6 +181,24 @@ class ReductionPatternInterfaceCollection
}
};
+//===----------------------------------------------------------------------===//
+// Reduction Pattern With Tester Interface Collection
+//===----------------------------------------------------------------------===//
+
+class ReductionPatternWithTesterInterfaceCollection
+ : public DialectInterfaceCollection<
+ DialectReductionPatternWithTesterInterface> {
+public:
+ using Base::Base;
+ // Collect the tester-aware reduction patterns of every dialect into
+ // `pattern`, forwarding `tester` to each dialect's implementation.
+ void populateReductionPatterns(RewritePatternSet &pattern,
+ Tester &tester) const {
+ for (const DialectReductionPatternWithTesterInterface &interface : *this)
+ interface.populateReductionPatterns(pattern, tester);
+ }
+};
+
//===----------------------------------------------------------------------===//
// ReductionTreePass
//===----------------------------------------------------------------------===//
@@ -201,15 +219,25 @@ class ReductionTreePass
private:
LogicalResult reduceOp(ModuleOp module, Region &region);
+ Tester tester;
FrozenRewritePatternSet reducerPatterns;
};
} // namespace
LogicalResult ReductionTreePass::initialize(MLIRContext *context) {
+ tester.setTestScript(testerName);
+ tester.setTestScriptArgs(testerArgs);
+ // The setters above store only views of the pass options (no copies).
RewritePatternSet patterns(context);
+ // Patterns from dialects implementing DialectReductionPatternInterface.
ReductionPatternInterfaceCollection reducePatternCollection(context);
reducePatternCollection.populateReductionPatterns(patterns);
+ // Tester-aware patterns receive a reference to the `tester` member.
+ ReductionPatternWithTesterInterfaceCollection
+ reducePatternWithTesterCollection(context);
+ reducePatternWithTesterCollection.populateReductionPatterns(patterns, tester);
+ // Freeze all collected patterns; reduceOp() applies them via findOptimal.
reducerPatterns = std::move(patterns);
return success();
}
@@ -244,11 +272,10 @@ void ReductionTreePass::runOnOperation() {
}
LogicalResult ReductionTreePass::reduceOp(ModuleOp module, Region &region) {
- Tester test(testerName, testerArgs);
switch (traversalModeId) {
case TraversalMode::SinglePath:
return findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>(
- module, region, reducerPatterns, test);
+ module, region, reducerPatterns, tester);
default:
return module.emitError() << "unsupported traversal mode detected";
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/166096
More information about the Mlir-commits
mailing list