[Mlir-commits] [mlir] [mlir][vector] Add DialectReductionPatternInterface and broadcast reduction (PR #183181)

Omer Farkash llvmlistbot at llvm.org
Thu Feb 26 01:46:33 PST 2026


https://github.com/OmerFarkash updated https://github.com/llvm/llvm-project/pull/183181

>From 3cac7a6af90c84ddca145be7a8c4e7ee135465b3 Mon Sep 17 00:00:00 2001
From: omer farkash <omerfarkash at gmail.com>
Date: Tue, 24 Feb 2026 15:09:25 +0200
Subject: [PATCH] [mlir][vector] Add generic reduction pattern for
 VectorDialect

This commit addresses reviewer feedback to improve the vector reduction pattern:
1. Dependency Inversion: The reduction logic has been moved from the `Vector` dialect into the `Reducer` library (`ReductionPatterns.cpp`). It is now implemented as a `DialectReductionPatternInterface` that a `DialectRegistry` extension (`registerReducerExtension`) attaches to the Vector dialect, keeping the core dialect free of tooling dependencies.
2. Generic Reduction: The pattern now uses `MatchAnyOpTypeTag()` to generically match *any* operation returning a `VectorType`, replacing those results with `ub.poison`.
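3. Test Fix: Added a dedicated `vector-test.sh` interestingness script that guides the reduction pass: it exits with status 1 (interesting) while the candidate still contains `vector<4xf32>`. The test is marked `UNSUPPORTED: system-windows` because the script needs a POSIX `/bin/sh`; this keeps CI stable on Windows.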
---
 mlir/lib/Reducer/CMakeLists.txt          |  5 +-
 mlir/lib/Reducer/ReductionPatterns.cpp   | 69 ++++++++++++++++++++++++
 mlir/test/mlir-reduce/vector-reduce.mlir | 11 ++++
 mlir/test/mlir-reduce/vector-test.sh     |  7 +++
 mlir/tools/mlir-reduce/mlir-reduce.cpp   |  9 ++++
 5 files changed, 100 insertions(+), 1 deletion(-)
 create mode 100644 mlir/lib/Reducer/ReductionPatterns.cpp
 create mode 100644 mlir/test/mlir-reduce/vector-reduce.mlir
 create mode 100755 mlir/test/mlir-reduce/vector-test.sh

diff --git a/mlir/lib/Reducer/CMakeLists.txt b/mlir/lib/Reducer/CMakeLists.txt
index 68864e373c993..a7bb1fa5817a0 100644
--- a/mlir/lib/Reducer/CMakeLists.txt
+++ b/mlir/lib/Reducer/CMakeLists.txt
@@ -3,16 +3,19 @@ add_mlir_library(MLIRReduce
    ReductionNode.cpp
    ReductionTreePass.cpp
    Tester.cpp
+   ReductionPatterns.cpp
 
    LINK_LIBS PUBLIC
    MLIRIR
    MLIRPass
    MLIRRewrite
    MLIRTransformUtils
+   MLIRVectorDialect
+   MLIRUBDialect
 
    DEPENDS
    MLIRReducerIncGen
    MLIRDialectReductionPatternInterfaceIncGen
 )
 
-mlir_check_all_link_libraries(MLIRReduce)
+mlir_check_all_link_libraries(MLIRReduce)
\ No newline at end of file
diff --git a/mlir/lib/Reducer/ReductionPatterns.cpp b/mlir/lib/Reducer/ReductionPatterns.cpp
new file mode 100644
index 0000000000000..f7fa0ca8bd3ce
--- /dev/null
+++ b/mlir/lib/Reducer/ReductionPatterns.cpp
@@ -0,0 +1,69 @@
+//===- ReductionPatterns.cpp - MLIR Reducer Patterns ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/UB/IR/UBOps.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Reducer/ReductionPatternInterface.h"
+
+using namespace mlir;
+
+namespace {
+/// A generic reduction pattern that replaces an operation whose results are
+/// all VectorTypes with ub.poison values of the same types.
+///
+/// It is anchored on any op kind (MatchAnyOpTypeTag), so it is not limited to
+/// Vector dialect ops; only the result types matter.
+struct GenericVectorPoisonReduction : public RewritePattern {
+  GenericVectorPoisonReduction(MLIRContext *context)
+      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    // ub.poison is what this pattern produces; matching it would loop forever.
+    if (isa<ub::PoisonOp>(op))
+      return failure();
+
+    // replaceOp() erases `op`, so each result needs a replacement not produced
+    // by `op` itself (reusing its own result would leave uses of an erased op).
+    // Require >= 1 result and all results vectors; each gets a fresh poison.
+    if (op->getNumResults() == 0 ||
+        !llvm::all_of(op->getResultTypes(),
+                      [](Type t) { return isa<VectorType>(t); }))
+      return failure();
+
+    SmallVector<Value> replacements;
+    replacements.reserve(op->getNumResults());
+    for (Type type : op->getResultTypes())
+      replacements.push_back(
+          ub::PoisonOp::create(rewriter, op->getLoc(), type));
+
+    rewriter.replaceOp(op, replacements);
+    return success();
+  }
+};
+
+/// Reducer hook on the Vector dialect; contributes GenericVectorPoisonReduction
+/// to the reduction pattern set.
+struct VectorReductionInterface : public DialectReductionPatternInterface {
+  VectorReductionInterface(Dialect *owner)
+      : DialectReductionPatternInterface(owner) {}
+
+  void populateReductionPatterns(RewritePatternSet &patterns) const override {
+    MLIRContext *ctx = patterns.getContext();
+    patterns.add<GenericVectorPoisonReduction>(ctx);
+  }
+};
+} // end anonymous namespace
+
+namespace mlir {
+/// Registers a DialectRegistry extension that attaches VectorReductionInterface
+/// to the Vector dialect.
+void registerReducerExtension(DialectRegistry &registry) {
+  auto attach = +[](MLIRContext *, vector::VectorDialect *vectorDialect) {
+    vectorDialect->addInterfaces<VectorReductionInterface>();
+  };
+  registry.addExtension(attach);
+}
+} // namespace mlir
\ No newline at end of file
diff --git a/mlir/test/mlir-reduce/vector-reduce.mlir b/mlir/test/mlir-reduce/vector-reduce.mlir
new file mode 100644
index 0000000000000..7112609f331b9
--- /dev/null
+++ b/mlir/test/mlir-reduce/vector-reduce.mlir
@@ -0,0 +1,11 @@
+// UNSUPPORTED: system-windows
+// RUN: mlir-reduce %s -reduction-tree='traversal-mode=0 test=%S/vector-test.sh' | FileCheck %s
+
+// Disabled on Windows, presumably because vector-test.sh is a /bin/sh script.
+// vector-test.sh treats any candidate that still contains vector<4xf32> as
+// interesting, so the reducer may replace the broadcast with a ub.poison of
+// the same type; the expectations below pin exactly that outcome.
+// CHECK-LABEL: func.func @reduce_vector_op
+func.func @reduce_vector_op(%arg0: f32) -> vector<4xf32> {
+  // CHECK-NOT: vector.broadcast
+  // CHECK: %[[POISON:.*]] = ub.poison : vector<4xf32>
+  // CHECK: return %[[POISON]] : vector<4xf32>
+  %0 = vector.broadcast %arg0 : f32 to vector<4xf32>
+  return %0 : vector<4xf32>
+}
diff --git a/mlir/test/mlir-reduce/vector-test.sh b/mlir/test/mlir-reduce/vector-test.sh
new file mode 100755
index 0000000000000..b66ab5f51c850
--- /dev/null
+++ b/mlir/test/mlir-reduce/vector-test.sh
@@ -0,0 +1,7 @@
+#!/bin/sh
+# If the file still contains 'vector<4xf32>', it's interesting!
+if grep -q 'vector<4xf32>' "$1"; then
+  exit 1
+else
+  exit 0
+fi
diff --git a/mlir/tools/mlir-reduce/mlir-reduce.cpp b/mlir/tools/mlir-reduce/mlir-reduce.cpp
index 44b21b805e8c3..51e9f2794edd5 100644
--- a/mlir/tools/mlir-reduce/mlir-reduce.cpp
+++ b/mlir/tools/mlir-reduce/mlir-reduce.cpp
@@ -19,8 +19,15 @@
 #include "mlir/InitAllPasses.h"
 #include "mlir/Tools/mlir-reduce/MlirReduceMain.h"
 
+#include "mlir/InitAllPasses.h"
+#include "mlir/Tools/mlir-reduce/MlirReduceMain.h"
+
 using namespace mlir;
 
+namespace mlir {
+/// Defined in mlir/lib/Reducer/ReductionPatterns.cpp; attaches the reducer's
+/// Vector-dialect reduction interface to `registry`.
+// NOTE(review): consider declaring this in a public Reducer header instead of
+// re-declaring it locally in the tool.
+void registerReducerExtension(DialectRegistry &registry);
+} // namespace mlir
+
 namespace test {
 #ifdef MLIR_INCLUDE_TESTS
 void registerTestDialect(DialectRegistry &);
@@ -32,6 +39,8 @@ int main(int argc, char **argv) {
 
   DialectRegistry registry;
   registerAllDialects(registry);
+
+  mlir::registerReducerExtension(registry);
 #ifdef MLIR_INCLUDE_TESTS
   test::registerTestDialect(registry);
 #endif



More information about the Mlir-commits mailing list