[Mlir-commits] [mlir] [MLIR] Add reduction interface with tester to mlir-reduce (PR #166096)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Nov 5 12:56:30 PST 2025


https://github.com/aidint updated https://github.com/llvm/llvm-project/pull/166096

>From 289910fbbaa15fd5556da402289bf0cfafa842ac Mon Sep 17 00:00:00 2001
From: aidint <at.aidin at gmail.com>
Date: Sun, 2 Nov 2025 21:39:29 +0100
Subject: [PATCH 1/3] mlir-reduce: add reduction interface with tester

---
 .../mlir/Reducer/ReductionPatternInterface.h  | 26 ++++++++++++++++
 mlir/include/mlir/Reducer/Tester.h            |  6 ++++
 mlir/lib/Reducer/ReductionTreePass.cpp        | 31 +++++++++++++++++--
 3 files changed, 61 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Reducer/ReductionPatternInterface.h b/mlir/include/mlir/Reducer/ReductionPatternInterface.h
index a85562fda4d93..c4d7a94479358 100644
--- a/mlir/include/mlir/Reducer/ReductionPatternInterface.h
+++ b/mlir/include/mlir/Reducer/ReductionPatternInterface.h
@@ -10,6 +10,7 @@
 #define MLIR_REDUCER_REDUCTIONPATTERNINTERFACE_H
 
 #include "mlir/IR/DialectInterface.h"
+#include "mlir/Reducer/Tester.h"
 
 namespace mlir {
 
@@ -51,6 +52,31 @@ class DialectReductionPatternInterface
   DialectReductionPatternInterface(Dialect *dialect) : Base(dialect) {}
 };
 
+/// This interface extends the `dialectreductionpatterninterface` by allowing
+/// reduction patterns to use a `Tester` instance. Some reduction patterns may
+/// need to run tester to determine whether certain transformations preserve the
+/// "interesting" behavior of the program. This is mostly useful when pattern
+/// should choose between multiple modifications.
+/// Implementation follows the same logic as the
+/// `dialectreductionpatterninterface`.
+///
+/// Example:
+///   MyDialectReductionPatternWithTester::populateReductionPatterns(
+///       RewritePatternSet &patterns, Tester &tester) {
+///       patterns.add<PatternWithTester>(patterns.getContext(), tester);
+///   }
+class DialectReductionPatternWithTesterInterface
+    : public DialectInterface::Base<
+          DialectReductionPatternWithTesterInterface> {
+public:
+  virtual void populateReductionPatterns(RewritePatternSet &patterns,
+                                         Tester &tester) const = 0;
+
+protected:
+  DialectReductionPatternWithTesterInterface(Dialect *dialect)
+      : Base(dialect) {}
+};
+
 } // namespace mlir
 
 #endif // MLIR_REDUCER_REDUCTIONPATTERNINTERFACE_H
diff --git a/mlir/include/mlir/Reducer/Tester.h b/mlir/include/mlir/Reducer/Tester.h
index eb44afc7c1c15..bed4408342034 100644
--- a/mlir/include/mlir/Reducer/Tester.h
+++ b/mlir/include/mlir/Reducer/Tester.h
@@ -36,6 +36,9 @@ class Tester {
     Untested,
   };
 
+  Tester() = default;
+  Tester(const Tester &) = default;
+
   Tester(StringRef testScript, ArrayRef<std::string> testScriptArgs);
 
   /// Runs the interestingness testing script on a MLIR test case file. Returns
@@ -46,6 +49,9 @@ class Tester {
   /// Return whether the file in the given path is interesting.
   Interestingness isInteresting(StringRef testCase) const;
 
+  void setTestScript(StringRef script) { testScript = script; }
+  void setTestScriptArgs(ArrayRef<std::string> args) { testScriptArgs = args; }
+
 private:
   StringRef testScript;
   ArrayRef<std::string> testScriptArgs;
diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp
index 5b49204013cc0..af94cd798f629 100644
--- a/mlir/lib/Reducer/ReductionTreePass.cpp
+++ b/mlir/lib/Reducer/ReductionTreePass.cpp
@@ -181,6 +181,24 @@ class ReductionPatternInterfaceCollection
   }
 };
 
+//===----------------------------------------------------------------------===//
+// Reduction Pattern With Tester Interface Collection
+//===----------------------------------------------------------------------===//
+
+class ReductionPatternWithTesterInterfaceCollection
+    : public DialectInterfaceCollection<
+          DialectReductionPatternWithTesterInterface> {
+public:
+  using Base::Base;
+
+  // Collect the reduce patterns defined by each dialect.
+  void populateReductionPatterns(RewritePatternSet &pattern,
+                                 Tester &tester) const {
+    for (const DialectReductionPatternWithTesterInterface &interface : *this)
+      interface.populateReductionPatterns(pattern, tester);
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ReductionTreePass
 //===----------------------------------------------------------------------===//
@@ -201,15 +219,25 @@ class ReductionTreePass
 private:
   LogicalResult reduceOp(ModuleOp module, Region &region);
 
+  Tester tester;
   FrozenRewritePatternSet reducerPatterns;
 };
 
 } // namespace
 
 LogicalResult ReductionTreePass::initialize(MLIRContext *context) {
+  tester.setTestScript(testerName);
+  tester.setTestScriptArgs(testerArgs);
+
   RewritePatternSet patterns(context);
+
   ReductionPatternInterfaceCollection reducePatternCollection(context);
   reducePatternCollection.populateReductionPatterns(patterns);
+
+  ReductionPatternWithTesterInterfaceCollection
+      reducePatternWithTesterCollection(context);
+  reducePatternWithTesterCollection.populateReductionPatterns(patterns, tester);
+
   reducerPatterns = std::move(patterns);
   return success();
 }
@@ -244,11 +272,10 @@ void ReductionTreePass::runOnOperation() {
 }
 
 LogicalResult ReductionTreePass::reduceOp(ModuleOp module, Region &region) {
-  Tester test(testerName, testerArgs);
   switch (traversalModeId) {
   case TraversalMode::SinglePath:
     return findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>(
-        module, region, reducerPatterns, test);
+        module, region, reducerPatterns, tester);
   default:
     return module.emitError() << "unsupported traversal mode detected";
   }

>From 0612a343fd0cbbb3128a316031084c2057a53782 Mon Sep 17 00:00:00 2001
From: aidint <at.aidin at gmail.com>
Date: Wed, 5 Nov 2025 21:40:32 +0100
Subject: [PATCH 2/3] mlir-reduce: use a default-implemented hook instead of a separate interface

---
 .../mlir/Reducer/ReductionPatternInterface.h  | 12 +-------
 mlir/lib/Reducer/ReductionTreePass.cpp        | 29 ++++---------------
 2 files changed, 6 insertions(+), 35 deletions(-)

diff --git a/mlir/include/mlir/Reducer/ReductionPatternInterface.h b/mlir/include/mlir/Reducer/ReductionPatternInterface.h
index c4d7a94479358..06b8b62a6bc2f 100644
--- a/mlir/include/mlir/Reducer/ReductionPatternInterface.h
+++ b/mlir/include/mlir/Reducer/ReductionPatternInterface.h
@@ -47,6 +47,7 @@ class DialectReductionPatternInterface
   /// tensor<?xindex> with a known rank and type, e.g. tensor<1xi32>, or
   /// replacing an operation with a constant.
   virtual void populateReductionPatterns(RewritePatternSet &patterns) const = 0;
+  virtual void populateReductionPatternsWithTester(RewritePatternSet &patterns, Tester &tester) const {}
 
 protected:
   DialectReductionPatternInterface(Dialect *dialect) : Base(dialect) {}
@@ -65,17 +66,6 @@ class DialectReductionPatternInterface
 ///       RewritePatternSet &patterns, Tester &tester) {
 ///       patterns.add<PatternWithTester>(patterns.getContext(), tester);
 ///   }
-class DialectReductionPatternWithTesterInterface
-    : public DialectInterface::Base<
-          DialectReductionPatternWithTesterInterface> {
-public:
-  virtual void populateReductionPatterns(RewritePatternSet &patterns,
-                                         Tester &tester) const = 0;
-
-protected:
-  DialectReductionPatternWithTesterInterface(Dialect *dialect)
-      : Base(dialect) {}
-};
 
 } // namespace mlir
 
diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp
index af94cd798f629..1e00ed645f71e 100644
--- a/mlir/lib/Reducer/ReductionTreePass.cpp
+++ b/mlir/lib/Reducer/ReductionTreePass.cpp
@@ -174,28 +174,13 @@ class ReductionPatternInterfaceCollection
 public:
   using Base::Base;
 
-  // Collect the reduce patterns defined by each dialect.
-  void populateReductionPatterns(RewritePatternSet &pattern) const {
-    for (const DialectReductionPatternInterface &interface : *this)
-      interface.populateReductionPatterns(pattern);
-  }
-};
-
-//===----------------------------------------------------------------------===//
-// Reduction Pattern With Tester Interface Collection
-//===----------------------------------------------------------------------===//
-
-class ReductionPatternWithTesterInterfaceCollection
-    : public DialectInterfaceCollection<
-          DialectReductionPatternWithTesterInterface> {
-public:
-  using Base::Base;
-
   // Collect the reduce patterns defined by each dialect.
   void populateReductionPatterns(RewritePatternSet &pattern,
                                  Tester &tester) const {
-    for (const DialectReductionPatternWithTesterInterface &interface : *this)
-      interface.populateReductionPatterns(pattern, tester);
+    for (const DialectReductionPatternInterface &interface : *this) {
+      interface.populateReductionPatterns(pattern);
+      interface.populateReductionPatternsWithTester(pattern, tester);
+    }
   }
 };
 
@@ -232,11 +217,7 @@ LogicalResult ReductionTreePass::initialize(MLIRContext *context) {
   RewritePatternSet patterns(context);
 
   ReductionPatternInterfaceCollection reducePatternCollection(context);
-  reducePatternCollection.populateReductionPatterns(patterns);
-
-  ReductionPatternWithTesterInterfaceCollection
-      reducePatternWithTesterCollection(context);
-  reducePatternWithTesterCollection.populateReductionPatterns(patterns, tester);
+  reducePatternCollection.populateReductionPatterns(patterns, tester);
 
   reducerPatterns = std::move(patterns);
   return success();

>From 04bdf55ae92873dba8e8fe343890f5a58a1102a6 Mon Sep 17 00:00:00 2001
From: aidint <at.aidin at gmail.com>
Date: Wed, 5 Nov 2025 21:56:14 +0100
Subject: [PATCH 3/3] mlir-reduce: reorganize comments and apply clang-format

---
 .../mlir/Reducer/ReductionPatternInterface.h  | 24 +++++++------------
 1 file changed, 8 insertions(+), 16 deletions(-)

diff --git a/mlir/include/mlir/Reducer/ReductionPatternInterface.h b/mlir/include/mlir/Reducer/ReductionPatternInterface.h
index 06b8b62a6bc2f..a33877dc0bd77 100644
--- a/mlir/include/mlir/Reducer/ReductionPatternInterface.h
+++ b/mlir/include/mlir/Reducer/ReductionPatternInterface.h
@@ -47,26 +47,18 @@ class DialectReductionPatternInterface
   /// tensor<?xindex> with a known rank and type, e.g. tensor<1xi32>, or
   /// replacing an operation with a constant.
   virtual void populateReductionPatterns(RewritePatternSet &patterns) const = 0;
-  virtual void populateReductionPatternsWithTester(RewritePatternSet &patterns, Tester &tester) const {}
+
+  /// This method extends `populateReductionPatterns` by allowing reduction
+  /// patterns to use a `Tester` instance. Some reduction patterns may need to
+  /// run the tester to determine whether certain transformations preserve the
+  /// "interesting" behavior of the program. This is mostly useful when a
+  /// pattern has to choose between multiple candidate modifications.
+  virtual void populateReductionPatternsWithTester(RewritePatternSet &patterns,
+                                                   Tester &tester) const {}
 
 protected:
   DialectReductionPatternInterface(Dialect *dialect) : Base(dialect) {}
 };
-
-/// This interface extends the `dialectreductionpatterninterface` by allowing
-/// reduction patterns to use a `Tester` instance. Some reduction patterns may
-/// need to run tester to determine whether certain transformations preserve the
-/// "interesting" behavior of the program. This is mostly useful when pattern
-/// should choose between multiple modifications.
-/// Implementation follows the same logic as the
-/// `dialectreductionpatterninterface`.
-///
-/// Example:
-///   MyDialectReductionPatternWithTester::populateReductionPatterns(
-///       RewritePatternSet &patterns, Tester &tester) {
-///       patterns.add<PatternWithTester>(patterns.getContext(), tester);
-///   }
-
 } // namespace mlir
 
 #endif // MLIR_REDUCER_REDUCTIONPATTERNINTERFACE_H



More information about the Mlir-commits mailing list