[Mlir-commits] [mlir] e7a2cf1 - [mlir][SPIR-V] Lower boolean vector reductions (#192267)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Apr 21 02:12:33 PDT 2026


Author: Arseniy Obolenskiy
Date: 2026-04-21T11:12:28+02:00
New Revision: e7a2cf1243ba3ece7e41fddbe63f6fa810bc55c7

URL: https://github.com/llvm/llvm-project/commit/e7a2cf1243ba3ece7e41fddbe63f6fa810bc55c7
DIFF: https://github.com/llvm/llvm-project/commit/e7a2cf1243ba3ece7e41fddbe63f6fa810bc55c7.diff

LOG: [mlir][SPIR-V] Lower boolean vector reductions (#192267)

Define SPIR-V OpAny/OpAll operations and lower vector.reduction <or>/<and>
on boolean vectors to spirv.Any/spirv.All.

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
    mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 2149892cc603d..921075736e97b 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -416,6 +416,32 @@ struct VectorReductionPattern final : OpConversionPattern<vector::ReductionOp> {
 
     auto [resultType, extractedElements] = *reductionInfo;
     Location loc = reduceOp->getLoc();
+
+    // Handle boolean reductions with spirv.Any / spirv.All.
+    if (resultType.isInteger(1)) {
+      vector::CombiningKind kind = reduceOp.getKind();
+
+      if (kind == vector::CombiningKind::OR) {
+        Value result = spirv::AnyOp::create(rewriter, loc, resultType,
+                                            adaptor.getVector());
+        if (Value acc = adaptor.getAcc())
+          result = spirv::LogicalOrOp::create(rewriter, loc, resultType, result,
+                                              acc);
+        rewriter.replaceOp(reduceOp, result);
+        return success();
+      }
+
+      if (kind == vector::CombiningKind::AND) {
+        Value result = spirv::AllOp::create(rewriter, loc, resultType,
+                                            adaptor.getVector());
+        if (Value acc = adaptor.getAcc())
+          result = spirv::LogicalAndOp::create(rewriter, loc, resultType,
+                                               result, acc);
+        rewriter.replaceOp(reduceOp, result);
+        return success();
+      }
+    }
+
     Value result = extractedElements.front();
     for (Value next : llvm::drop_begin(extractedElements)) {
       switch (reduceOp.getKind()) {

diff  --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 9eaf99203836a..48a1298bc4877 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -955,6 +955,52 @@ func.func @reduction_xor_acc(%v : vector<3xi32>, %s: i32) -> i32 {
 
 // -----
 
+// CHECK-LABEL: func @reduction_or_bool
+//  CHECK-SAME: (%[[V:.+]]: vector<4xi1>)
+//       CHECK:   %[[R:.+]] = spirv.Any %[[V]] : vector<4xi1>
+//       CHECK:   return %[[R]]
+func.func @reduction_or_bool(%v : vector<4xi1>) -> i1 {
+  %reduce = vector.reduction <or>, %v : vector<4xi1> into i1
+  return %reduce : i1
+}
+
+// -----
+
+// CHECK-LABEL: func @reduction_and_bool
+//  CHECK-SAME: (%[[V:.+]]: vector<4xi1>)
+//       CHECK:   %[[R:.+]] = spirv.All %[[V]] : vector<4xi1>
+//       CHECK:   return %[[R]]
+func.func @reduction_and_bool(%v : vector<4xi1>) -> i1 {
+  %reduce = vector.reduction <and>, %v : vector<4xi1> into i1
+  return %reduce : i1
+}
+
+// -----
+
+// CHECK-LABEL: func @reduction_or_bool_acc
+//  CHECK-SAME: (%[[V:.+]]: vector<4xi1>, %[[S:.+]]: i1)
+//       CHECK:   %[[R:.+]] = spirv.Any %[[V]] : vector<4xi1>
+//       CHECK:   %[[A:.+]] = spirv.LogicalOr %[[R]], %[[S]]
+//       CHECK:   return %[[A]]
+func.func @reduction_or_bool_acc(%v : vector<4xi1>, %s: i1) -> i1 {
+  %reduce = vector.reduction <or>, %v, %s : vector<4xi1> into i1
+  return %reduce : i1
+}
+
+// -----
+
+// CHECK-LABEL: func @reduction_and_bool_acc
+//  CHECK-SAME: (%[[V:.+]]: vector<4xi1>, %[[S:.+]]: i1)
+//       CHECK:   %[[R:.+]] = spirv.All %[[V]] : vector<4xi1>
+//       CHECK:   %[[A:.+]] = spirv.LogicalAnd %[[R]], %[[S]]
+//       CHECK:   return %[[A]]
+func.func @reduction_and_bool_acc(%v : vector<4xi1>, %s: i1) -> i1 {
+  %reduce = vector.reduction <and>, %v, %s : vector<4xi1> into i1
+  return %reduce : i1
+}
+
+// -----
+
 module attributes { spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [BFloat16DotProductKHR], [SPV_KHR_bfloat16]>, #spirv.resource_limits<>> } {
 
 // CHECK-LABEL: func @reduction_bf16_addf_mulf


        


More information about the Mlir-commits mailing list