[Mlir-commits] [mlir] [mlir][vector] Add DialectReductionPatternInterface and broadcast reduction (PR #183181)
Omer Farkash
llvmlistbot at llvm.org
Thu Feb 26 01:46:33 PST 2026
https://github.com/OmerFarkash updated https://github.com/llvm/llvm-project/pull/183181
>From 3cac7a6af90c84ddca145be7a8c4e7ee135465b3 Mon Sep 17 00:00:00 2001
From: omer farkash <omerfarkash at gmail.com>
Date: Tue, 24 Feb 2026 15:09:25 +0200
Subject: [PATCH] [mlir][vector] Add generic reduction pattern for
VectorDialect
This commit addresses reviewer feedback to improve the vector reduction pattern:
1. Dependency Inversion: The reduction logic has been moved from the `Vector` dialect into the `Reducer` library (`ReductionPatterns.cpp`). It is now implemented as an `ExternalModel` attached via `DialectReductionPatternInterface`, keeping the core dialect free of tooling dependencies.
2. Generic Reduction: The pattern now uses `MatchAnyOpTypeTag()` to generically match *any* operation returning a `VectorType`, replacing those results with `ub.poison`.
3. Test Fix: Added `UNSUPPORTED: system-windows` and a dedicated `vector-test.sh` script to ensure cross-platform CI stability and properly guide the reduction pass based on LLVM's expected exit codes.
---
mlir/lib/Reducer/CMakeLists.txt | 5 +-
mlir/lib/Reducer/ReductionPatterns.cpp | 69 ++++++++++++++++++++++++
mlir/test/mlir-reduce/vector-reduce.mlir | 11 ++++
mlir/test/mlir-reduce/vector-test.sh | 7 +++
mlir/tools/mlir-reduce/mlir-reduce.cpp | 9 ++++
5 files changed, 100 insertions(+), 1 deletion(-)
create mode 100644 mlir/lib/Reducer/ReductionPatterns.cpp
create mode 100644 mlir/test/mlir-reduce/vector-reduce.mlir
create mode 100755 mlir/test/mlir-reduce/vector-test.sh
diff --git a/mlir/lib/Reducer/CMakeLists.txt b/mlir/lib/Reducer/CMakeLists.txt
index 68864e373c993..a7bb1fa5817a0 100644
--- a/mlir/lib/Reducer/CMakeLists.txt
+++ b/mlir/lib/Reducer/CMakeLists.txt
@@ -3,16 +3,19 @@ add_mlir_library(MLIRReduce
ReductionNode.cpp
ReductionTreePass.cpp
Tester.cpp
+ ReductionPatterns.cpp
LINK_LIBS PUBLIC
MLIRIR
MLIRPass
MLIRRewrite
MLIRTransformUtils
+ MLIRVectorDialect
+ MLIRUBDialect
DEPENDS
MLIRReducerIncGen
MLIRDialectReductionPatternInterfaceIncGen
)
-mlir_check_all_link_libraries(MLIRReduce)
+mlir_check_all_link_libraries(MLIRReduce)
\ No newline at end of file
diff --git a/mlir/lib/Reducer/ReductionPatterns.cpp b/mlir/lib/Reducer/ReductionPatterns.cpp
new file mode 100644
index 0000000000000..f7fa0ca8bd3ce
--- /dev/null
+++ b/mlir/lib/Reducer/ReductionPatterns.cpp
@@ -0,0 +1,69 @@
+//===- ReductionPatterns.cpp - MLIR Reducer Patterns ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/UB/IR/UBOps.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Reducer/ReductionPatternInterface.h"
+
+using namespace mlir;
+
+namespace {
+/// A generic reduction pattern that redirects the users of every vector-typed
+/// result of an operation to a `ub.poison` value of the same type.
+///
+/// Results that are not of VectorType are never replaced. The operation itself
+/// is erased only once none of its results has any remaining use, so an
+/// operation whose non-vector results are still used is kept in place.
+struct GenericVectorPoisonReduction : public RewritePattern {
+  GenericVectorPoisonReduction(MLIRContext *context)
+      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
+
+  /// Returns success if at least one use was redirected to poison or if `op`
+  /// was erased; returns failure if nothing was changed.
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    // Prevent infinite loops by ignoring operations that are already poison.
+    if (isa<ub::PoisonOp>(op))
+      return failure();
+
+    // Only operations that produce at least one vector are candidates.
+    if (llvm::none_of(op->getResultTypes(),
+                      [](Type type) { return isa<VectorType>(type); }))
+      return failure();
+
+    // Redirect the users of each vector result to a fresh poison value.
+    // Unused vector results are skipped so that an operation that has to be
+    // kept (see below) does not match this pattern over and over again.
+    bool changed = false;
+    for (OpResult result : op->getResults()) {
+      if (!isa<VectorType>(result.getType()) || result.use_empty())
+        continue;
+      Value poison =
+          ub::PoisonOp::create(rewriter, op->getLoc(), result.getType());
+      rewriter.replaceAllUsesWith(result, poison);
+      changed = true;
+    }
+
+    // Handing a result of `op` to `replaceOp` as its own replacement would
+    // leave its users pointing at an erased operation. Erase `op` only when
+    // nothing refers to it any longer.
+    if (op->use_empty()) {
+      rewriter.eraseOp(op);
+      return success();
+    }
+    return success(changed);
+  }
+};
+
+/// Reduction-pattern interface implementation for the Vector dialect; it
+/// contributes GenericVectorPoisonReduction to the reducer's pattern set.
+struct VectorReductionInterface : public DialectReductionPatternInterface {
+  VectorReductionInterface(Dialect *dialect)
+      : DialectReductionPatternInterface(dialect) {}
+
+  /// Adds the generic vector-to-poison reduction pattern to `patterns`.
+  void populateReductionPatterns(RewritePatternSet &patterns) const override {
+    MLIRContext *context = patterns.getContext();
+    patterns.add<GenericVectorPoisonReduction>(context);
+  }
+};
+} // end anonymous namespace
+
+namespace mlir {
+/// Registers a dialect extension on `registry` that attaches
+/// VectorReductionInterface to the Vector dialect.
+void registerReducerExtension(DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, vector::VectorDialect *dialect) {
+    // The attached pattern creates `ub.poison` operations, so make sure the UB
+    // dialect is loaded in this context before the pattern can run.
+    ctx->loadDialect<ub::UBDialect>();
+    dialect->addInterfaces<VectorReductionInterface>();
+  });
+}
+} // namespace mlir
\ No newline at end of file
diff --git a/mlir/test/mlir-reduce/vector-reduce.mlir b/mlir/test/mlir-reduce/vector-reduce.mlir
new file mode 100644
index 0000000000000..7112609f331b9
--- /dev/null
+++ b/mlir/test/mlir-reduce/vector-reduce.mlir
@@ -0,0 +1,11 @@
+// UNSUPPORTED: system-windows
+// RUN: mlir-reduce %s -reduction-tree='traversal-mode=0 test=%S/vector-test.sh' | FileCheck %s
+
+// The broadcast producing the returned vector is expected to be replaced by a
+// poison value of the same vector type, which is then returned directly.
+// CHECK-LABEL: func.func @reduce_vector_op
+func.func @reduce_vector_op(%arg0: f32) -> vector<4xf32> {
+ // CHECK-NOT: vector.broadcast
+ // CHECK: %[[POISON:.*]] = ub.poison : vector<4xf32>
+ // CHECK: return %[[POISON]] : vector<4xf32>
+ %0 = vector.broadcast %arg0 : f32 to vector<4xf32>
+ return %0 : vector<4xf32>
+}
diff --git a/mlir/test/mlir-reduce/vector-test.sh b/mlir/test/mlir-reduce/vector-test.sh
new file mode 100755
index 0000000000000..b66ab5f51c850
--- /dev/null
+++ b/mlir/test/mlir-reduce/vector-test.sh
@@ -0,0 +1,7 @@
+#!/bin/sh
+# Interestingness test: a file that still mentions 'vector<4xf32>' is
+# interesting, which is reported with exit status 1; anything else exits 0.
+grep -q 'vector<4xf32>' "$1" && exit 1
+exit 0
diff --git a/mlir/tools/mlir-reduce/mlir-reduce.cpp b/mlir/tools/mlir-reduce/mlir-reduce.cpp
index 44b21b805e8c3..51e9f2794edd5 100644
--- a/mlir/tools/mlir-reduce/mlir-reduce.cpp
+++ b/mlir/tools/mlir-reduce/mlir-reduce.cpp
@@ -19,8 +19,15 @@
#include "mlir/InitAllPasses.h"
#include "mlir/Tools/mlir-reduce/MlirReduceMain.h"
+#include "mlir/InitAllPasses.h"
+#include "mlir/Tools/mlir-reduce/MlirReduceMain.h"
+
using namespace mlir;
+namespace mlir {
+/// Defined in mlir/lib/Reducer/ReductionPatterns.cpp; attaches the reducer's
+/// reduction-pattern interface to the Vector dialect through `registry`.
+void registerReducerExtension(DialectRegistry &registry);
+} // namespace mlir
+
namespace test {
#ifdef MLIR_INCLUDE_TESTS
void registerTestDialect(DialectRegistry &);
@@ -32,6 +39,8 @@ int main(int argc, char **argv) {
DialectRegistry registry;
registerAllDialects(registry);
+
+ mlir::registerReducerExtension(registry);
#ifdef MLIR_INCLUDE_TESTS
test::registerTestDialect(registry);
#endif
More information about the Mlir-commits
mailing list