[Mlir-commits] [mlir] [MLIR] Add reduction interface with tester to mlir-reduce (PR #166096)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Nov 5 12:56:30 PST 2025
https://github.com/aidint updated https://github.com/llvm/llvm-project/pull/166096
>From 289910fbbaa15fd5556da402289bf0cfafa842ac Mon Sep 17 00:00:00 2001
From: aidint <at.aidin at gmail.com>
Date: Sun, 2 Nov 2025 21:39:29 +0100
Subject: [PATCH 1/3] mlir-reduce: add reduction interface with tester
---
.../mlir/Reducer/ReductionPatternInterface.h | 26 ++++++++++++++++
mlir/include/mlir/Reducer/Tester.h | 6 ++++
mlir/lib/Reducer/ReductionTreePass.cpp | 31 +++++++++++++++++--
3 files changed, 61 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Reducer/ReductionPatternInterface.h b/mlir/include/mlir/Reducer/ReductionPatternInterface.h
index a85562fda4d93..c4d7a94479358 100644
--- a/mlir/include/mlir/Reducer/ReductionPatternInterface.h
+++ b/mlir/include/mlir/Reducer/ReductionPatternInterface.h
@@ -10,6 +10,7 @@
#define MLIR_REDUCER_REDUCTIONPATTERNINTERFACE_H
#include "mlir/IR/DialectInterface.h"
+#include "mlir/Reducer/Tester.h"
namespace mlir {
@@ -51,6 +52,31 @@ class DialectReductionPatternInterface
DialectReductionPatternInterface(Dialect *dialect) : Base(dialect) {}
};
+/// This interface extends the `dialectreductionpatterninterface` by allowing
+/// reduction patterns to use a `Tester` instance. Some reduction patterns may
+/// need to run tester to determine whether certain transformations preserve the
+/// "interesting" behavior of the program. This is mostly useful when pattern
+/// should choose between multiple modifications.
+/// Implementation follows the same logic as the
+/// `dialectreductionpatterninterface`.
+///
+/// Example:
+/// MyDialectReductionPatternWithTester::populateReductionPatterns(
+/// RewritePatternSet &patterns, Tester &tester) {
+/// patterns.add<PatternWithTester>(patterns.getContext(), tester);
+/// }
+class DialectReductionPatternWithTesterInterface
+ : public DialectInterface::Base<
+ DialectReductionPatternWithTesterInterface> {
+public:
+ virtual void populateReductionPatterns(RewritePatternSet &patterns,
+ Tester &tester) const = 0;
+
+protected:
+ DialectReductionPatternWithTesterInterface(Dialect *dialect)
+ : Base(dialect) {}
+};
+
} // namespace mlir
#endif // MLIR_REDUCER_REDUCTIONPATTERNINTERFACE_H
diff --git a/mlir/include/mlir/Reducer/Tester.h b/mlir/include/mlir/Reducer/Tester.h
index eb44afc7c1c15..bed4408342034 100644
--- a/mlir/include/mlir/Reducer/Tester.h
+++ b/mlir/include/mlir/Reducer/Tester.h
@@ -36,6 +36,9 @@ class Tester {
Untested,
};
+ Tester() = default;
+ Tester(const Tester &) = default;
+
Tester(StringRef testScript, ArrayRef<std::string> testScriptArgs);
/// Runs the interestingness testing script on a MLIR test case file. Returns
@@ -46,6 +49,9 @@ class Tester {
/// Return whether the file in the given path is interesting.
Interestingness isInteresting(StringRef testCase) const;
+ void setTestScript(StringRef script) { testScript = script; }
+ void setTestScriptArgs(ArrayRef<std::string> args) { testScriptArgs = args; }
+
private:
StringRef testScript;
ArrayRef<std::string> testScriptArgs;
diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp
index 5b49204013cc0..af94cd798f629 100644
--- a/mlir/lib/Reducer/ReductionTreePass.cpp
+++ b/mlir/lib/Reducer/ReductionTreePass.cpp
@@ -181,6 +181,24 @@ class ReductionPatternInterfaceCollection
}
};
+//===----------------------------------------------------------------------===//
+// Reduction Pattern With Tester Interface Collection
+//===----------------------------------------------------------------------===//
+
+class ReductionPatternWithTesterInterfaceCollection
+ : public DialectInterfaceCollection<
+ DialectReductionPatternWithTesterInterface> {
+public:
+ using Base::Base;
+
+ // Collect the reduce patterns defined by each dialect.
+ void populateReductionPatterns(RewritePatternSet &pattern,
+ Tester &tester) const {
+ for (const DialectReductionPatternWithTesterInterface &interface : *this)
+ interface.populateReductionPatterns(pattern, tester);
+ }
+};
+
//===----------------------------------------------------------------------===//
// ReductionTreePass
//===----------------------------------------------------------------------===//
@@ -201,15 +219,25 @@ class ReductionTreePass
private:
LogicalResult reduceOp(ModuleOp module, Region ®ion);
+ Tester tester;
FrozenRewritePatternSet reducerPatterns;
};
} // namespace
LogicalResult ReductionTreePass::initialize(MLIRContext *context) {
+ tester.setTestScript(testerName);
+ tester.setTestScriptArgs(testerArgs);
+
RewritePatternSet patterns(context);
+
ReductionPatternInterfaceCollection reducePatternCollection(context);
reducePatternCollection.populateReductionPatterns(patterns);
+
+ ReductionPatternWithTesterInterfaceCollection
+ reducePatternWithTesterCollection(context);
+ reducePatternWithTesterCollection.populateReductionPatterns(patterns, tester);
+
reducerPatterns = std::move(patterns);
return success();
}
@@ -244,11 +272,10 @@ void ReductionTreePass::runOnOperation() {
}
LogicalResult ReductionTreePass::reduceOp(ModuleOp module, Region ®ion) {
- Tester test(testerName, testerArgs);
switch (traversalModeId) {
case TraversalMode::SinglePath:
return findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>(
- module, region, reducerPatterns, test);
+ module, region, reducerPatterns, tester);
default:
return module.emitError() << "unsupported traversal mode detected";
}
>From 0612a343fd0cbbb3128a316031084c2057a53782 Mon Sep 17 00:00:00 2001
From: aidint <at.aidin at gmail.com>
Date: Wed, 5 Nov 2025 21:40:32 +0100
Subject: [PATCH 2/3] change to default implementation approach
---
.../mlir/Reducer/ReductionPatternInterface.h | 12 +-------
mlir/lib/Reducer/ReductionTreePass.cpp | 29 ++++---------------
2 files changed, 6 insertions(+), 35 deletions(-)
diff --git a/mlir/include/mlir/Reducer/ReductionPatternInterface.h b/mlir/include/mlir/Reducer/ReductionPatternInterface.h
index c4d7a94479358..06b8b62a6bc2f 100644
--- a/mlir/include/mlir/Reducer/ReductionPatternInterface.h
+++ b/mlir/include/mlir/Reducer/ReductionPatternInterface.h
@@ -47,6 +47,7 @@ class DialectReductionPatternInterface
/// tensor<?xindex> with a known rank and type, e.g. tensor<1xi32>, or
/// replacing an operation with a constant.
virtual void populateReductionPatterns(RewritePatternSet &patterns) const = 0;
+ virtual void populateReductionPatternsWithTester(RewritePatternSet &patterns, Tester &tester) const {}
protected:
DialectReductionPatternInterface(Dialect *dialect) : Base(dialect) {}
@@ -65,17 +66,6 @@ class DialectReductionPatternInterface
/// RewritePatternSet &patterns, Tester &tester) {
/// patterns.add<PatternWithTester>(patterns.getContext(), tester);
/// }
-class DialectReductionPatternWithTesterInterface
- : public DialectInterface::Base<
- DialectReductionPatternWithTesterInterface> {
-public:
- virtual void populateReductionPatterns(RewritePatternSet &patterns,
- Tester &tester) const = 0;
-
-protected:
- DialectReductionPatternWithTesterInterface(Dialect *dialect)
- : Base(dialect) {}
-};
} // namespace mlir
diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp
index af94cd798f629..1e00ed645f71e 100644
--- a/mlir/lib/Reducer/ReductionTreePass.cpp
+++ b/mlir/lib/Reducer/ReductionTreePass.cpp
@@ -174,28 +174,13 @@ class ReductionPatternInterfaceCollection
public:
using Base::Base;
- // Collect the reduce patterns defined by each dialect.
- void populateReductionPatterns(RewritePatternSet &pattern) const {
- for (const DialectReductionPatternInterface &interface : *this)
- interface.populateReductionPatterns(pattern);
- }
-};
-
-//===----------------------------------------------------------------------===//
-// Reduction Pattern With Tester Interface Collection
-//===----------------------------------------------------------------------===//
-
-class ReductionPatternWithTesterInterfaceCollection
- : public DialectInterfaceCollection<
- DialectReductionPatternWithTesterInterface> {
-public:
- using Base::Base;
-
// Collect the reduce patterns defined by each dialect.
void populateReductionPatterns(RewritePatternSet &pattern,
Tester &tester) const {
- for (const DialectReductionPatternWithTesterInterface &interface : *this)
- interface.populateReductionPatterns(pattern, tester);
+ for (const DialectReductionPatternInterface &interface : *this) {
+ interface.populateReductionPatterns(pattern);
+ interface.populateReductionPatternsWithTester(pattern, tester);
+ }
}
};
@@ -232,11 +217,7 @@ LogicalResult ReductionTreePass::initialize(MLIRContext *context) {
RewritePatternSet patterns(context);
ReductionPatternInterfaceCollection reducePatternCollection(context);
- reducePatternCollection.populateReductionPatterns(patterns);
-
- ReductionPatternWithTesterInterfaceCollection
- reducePatternWithTesterCollection(context);
- reducePatternWithTesterCollection.populateReductionPatterns(patterns, tester);
+ reducePatternCollection.populateReductionPatterns(patterns, tester);
reducerPatterns = std::move(patterns);
return success();
>From 04bdf55ae92873dba8e8fe343890f5a58a1102a6 Mon Sep 17 00:00:00 2001
From: aidint <at.aidin at gmail.com>
Date: Wed, 5 Nov 2025 21:56:14 +0100
Subject: [PATCH 3/3] organize comments and format
---
.../mlir/Reducer/ReductionPatternInterface.h | 24 +++++++------------
1 file changed, 8 insertions(+), 16 deletions(-)
diff --git a/mlir/include/mlir/Reducer/ReductionPatternInterface.h b/mlir/include/mlir/Reducer/ReductionPatternInterface.h
index 06b8b62a6bc2f..a33877dc0bd77 100644
--- a/mlir/include/mlir/Reducer/ReductionPatternInterface.h
+++ b/mlir/include/mlir/Reducer/ReductionPatternInterface.h
@@ -47,26 +47,18 @@ class DialectReductionPatternInterface
/// tensor<?xindex> with a known rank and type, e.g. tensor<1xi32>, or
/// replacing an operation with a constant.
virtual void populateReductionPatterns(RewritePatternSet &patterns) const = 0;
- virtual void populateReductionPatternsWithTester(RewritePatternSet &patterns, Tester &tester) const {}
+
+  /// This method extends `populateReductionPatterns` by allowing reduction
+  /// patterns to use a `Tester` instance. Some reduction patterns may need to
+  /// run the tester to determine whether a transformation preserves the
+  /// "interesting" behavior of the program, e.g. when a pattern must choose
+  /// between multiple candidate modifications. The default adds no patterns.
+ virtual void populateReductionPatternsWithTester(RewritePatternSet &patterns,
+ Tester &tester) const {}
protected:
DialectReductionPatternInterface(Dialect *dialect) : Base(dialect) {}
};
-
-/// This interface extends the `dialectreductionpatterninterface` by allowing
-/// reduction patterns to use a `Tester` instance. Some reduction patterns may
-/// need to run tester to determine whether certain transformations preserve the
-/// "interesting" behavior of the program. This is mostly useful when pattern
-/// should choose between multiple modifications.
-/// Implementation follows the same logic as the
-/// `dialectreductionpatterninterface`.
-///
-/// Example:
-/// MyDialectReductionPatternWithTester::populateReductionPatterns(
-/// RewritePatternSet &patterns, Tester &tester) {
-/// patterns.add<PatternWithTester>(patterns.getContext(), tester);
-/// }
-
} // namespace mlir
#endif // MLIR_REDUCER_REDUCTIONPATTERNINTERFACE_H
More information about the Mlir-commits
mailing list