[Mlir-commits] [mlir] [mlir][SPIR-V] Lower boolean vector reductions (PR #192267)

Arseniy Obolenskiy llvmlistbot at llvm.org
Tue Apr 21 00:26:36 PDT 2026


https://github.com/aobolensk updated https://github.com/llvm/llvm-project/pull/192267

>From 846b813e7b96bf10696e8d4a03708141bb90ffcf Mon Sep 17 00:00:00 2001
From: Arseniy Obolenskiy <arseniy.obolenskiy at amd.com>
Date: Wed, 15 Apr 2026 16:28:14 +0200
Subject: [PATCH 1/3] [mlir][SPIR-V] Lower boolean vector reductions to
 spirv.Any and spirv.All

Lower vector.reduction <or>/<and> on boolean vectors to spirv.Any/spirv.All. When an accumulator is present, combine it with the reduced value using spirv.LogicalOr/spirv.LogicalAnd.
---
 .../VectorToSPIRV/VectorToSPIRV.cpp           | 26 +++++++++++
 .../VectorToSPIRV/vector-to-spirv.mlir        | 46 +++++++++++++++++++
 2 files changed, 72 insertions(+)

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 2149892cc603d..3709589d46c56 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -416,6 +416,32 @@ struct VectorReductionPattern final : OpConversionPattern<vector::ReductionOp> {
 
     auto [resultType, extractedElements] = *reductionInfo;
     Location loc = reduceOp->getLoc();
+
+    // Handle boolean reductions with spirv.Any / spirv.All.
+    if (resultType.isInteger(1)) {
+      auto kind = reduceOp.getKind();
+      if (kind == vector::CombiningKind::OR ||
+          kind == vector::CombiningKind::AND) {
+        Value result;
+        if (kind == vector::CombiningKind::OR)
+          result = spirv::AnyOp::create(rewriter, loc, resultType,
+                                        adaptor.getVector());
+        else
+          result = spirv::AllOp::create(rewriter, loc, resultType,
+                                        adaptor.getVector());
+        if (Value acc = adaptor.getAcc()) {
+          if (kind == vector::CombiningKind::OR)
+            result = spirv::LogicalOrOp::create(rewriter, loc, resultType,
+                                                result, acc);
+          else
+            result = spirv::LogicalAndOp::create(rewriter, loc, resultType,
+                                                 result, acc);
+        }
+        rewriter.replaceOp(reduceOp, result);
+        return success();
+      }
+    }
+
     Value result = extractedElements.front();
     for (Value next : llvm::drop_begin(extractedElements)) {
       switch (reduceOp.getKind()) {
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 9eaf99203836a..48a1298bc4877 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -955,6 +955,52 @@ func.func @reduction_xor_acc(%v : vector<3xi32>, %s: i32) -> i32 {
 
 // -----
 
+// CHECK-LABEL: func @reduction_or_bool
+//  CHECK-SAME: (%[[V:.+]]: vector<4xi1>)
+//       CHECK:   %[[R:.+]] = spirv.Any %[[V]] : vector<4xi1>
+//       CHECK:   return %[[R]]
+func.func @reduction_or_bool(%v : vector<4xi1>) -> i1 {
+  %reduce = vector.reduction <or>, %v : vector<4xi1> into i1
+  return %reduce : i1
+}
+
+// -----
+
+// CHECK-LABEL: func @reduction_and_bool
+//  CHECK-SAME: (%[[V:.+]]: vector<4xi1>)
+//       CHECK:   %[[R:.+]] = spirv.All %[[V]] : vector<4xi1>
+//       CHECK:   return %[[R]]
+func.func @reduction_and_bool(%v : vector<4xi1>) -> i1 {
+  %reduce = vector.reduction <and>, %v : vector<4xi1> into i1
+  return %reduce : i1
+}
+
+// -----
+
+// CHECK-LABEL: func @reduction_or_bool_acc
+//  CHECK-SAME: (%[[V:.+]]: vector<4xi1>, %[[S:.+]]: i1)
+//       CHECK:   %[[R:.+]] = spirv.Any %[[V]] : vector<4xi1>
+//       CHECK:   %[[A:.+]] = spirv.LogicalOr %[[R]], %[[S]]
+//       CHECK:   return %[[A]]
+func.func @reduction_or_bool_acc(%v : vector<4xi1>, %s: i1) -> i1 {
+  %reduce = vector.reduction <or>, %v, %s : vector<4xi1> into i1
+  return %reduce : i1
+}
+
+// -----
+
+// CHECK-LABEL: func @reduction_and_bool_acc
+//  CHECK-SAME: (%[[V:.+]]: vector<4xi1>, %[[S:.+]]: i1)
+//       CHECK:   %[[R:.+]] = spirv.All %[[V]] : vector<4xi1>
+//       CHECK:   %[[A:.+]] = spirv.LogicalAnd %[[R]], %[[S]]
+//       CHECK:   return %[[A]]
+func.func @reduction_and_bool_acc(%v : vector<4xi1>, %s: i1) -> i1 {
+  %reduce = vector.reduction <and>, %v, %s : vector<4xi1> into i1
+  return %reduce : i1
+}
+
+// -----
+
 module attributes { spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [BFloat16DotProductKHR], [SPV_KHR_bfloat16]>, #spirv.resource_limits<>> } {
 
 // CHECK-LABEL: func @reduction_bf16_addf_mulf

>From 83e97415f1061a39dcc34c4a908e925284b2080c Mon Sep 17 00:00:00 2001
From: Arseniy Obolenskiy <arseniy.obolenskiy at amd.com>
Date: Tue, 21 Apr 2026 09:23:11 +0200
Subject: [PATCH 2/3] Address review comments: split boolean OR/AND lowering into separate branches

---
 .../VectorToSPIRV/VectorToSPIRV.cpp           | 36 +++++++++----------
 1 file changed, 18 insertions(+), 18 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 3709589d46c56..7c856aab3ffe6 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -419,24 +419,24 @@ struct VectorReductionPattern final : OpConversionPattern<vector::ReductionOp> {
 
     // Handle boolean reductions with spirv.Any / spirv.All.
     if (resultType.isInteger(1)) {
-      auto kind = reduceOp.getKind();
-      if (kind == vector::CombiningKind::OR ||
-          kind == vector::CombiningKind::AND) {
-        Value result;
-        if (kind == vector::CombiningKind::OR)
-          result = spirv::AnyOp::create(rewriter, loc, resultType,
-                                        adaptor.getVector());
-        else
-          result = spirv::AllOp::create(rewriter, loc, resultType,
-                                        adaptor.getVector());
-        if (Value acc = adaptor.getAcc()) {
-          if (kind == vector::CombiningKind::OR)
-            result = spirv::LogicalOrOp::create(rewriter, loc, resultType,
-                                                result, acc);
-          else
-            result = spirv::LogicalAndOp::create(rewriter, loc, resultType,
-                                                 result, acc);
-        }
+      vector::CombiningKind kind = reduceOp.getKind();
+
+      if (kind == vector::CombiningKind::OR) {
+        Value result = spirv::AnyOp::create(rewriter, loc, resultType,
+                                            adaptor.getVector());
+        if (Value acc = adaptor.getAcc())
+          result = spirv::LogicalOrOp::create(rewriter, loc, resultType,
+                                              result, acc);
+        rewriter.replaceOp(reduceOp, result);
+        return success();
+      }
+
+      if (kind == vector::CombiningKind::AND) {
+        Value result = spirv::AllOp::create(rewriter, loc, resultType,
+                                            adaptor.getVector());
+        if (Value acc = adaptor.getAcc())
+          result = spirv::LogicalAndOp::create(rewriter, loc, resultType,
+                                               result, acc);
         rewriter.replaceOp(reduceOp, result);
         return success();
       }

>From 28f09035777670a8a6a05879a0019cdf79b1258c Mon Sep 17 00:00:00 2001
From: Arseniy Obolenskiy <arseniy.obolenskiy at amd.com>
Date: Tue, 21 Apr 2026 09:26:22 +0200
Subject: [PATCH 3/3] Apply clang-format

---
 mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 7c856aab3ffe6..921075736e97b 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -425,8 +425,8 @@ struct VectorReductionPattern final : OpConversionPattern<vector::ReductionOp> {
         Value result = spirv::AnyOp::create(rewriter, loc, resultType,
                                             adaptor.getVector());
         if (Value acc = adaptor.getAcc())
-          result = spirv::LogicalOrOp::create(rewriter, loc, resultType,
-                                              result, acc);
+          result = spirv::LogicalOrOp::create(rewriter, loc, resultType, result,
+                                              acc);
         rewriter.replaceOp(reduceOp, result);
         return success();
       }



More information about the Mlir-commits mailing list