[Mlir-commits] [mlir] e7a2cf1 - [mlir][SPIR-V] Lower boolean vector reductions (#192267)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Apr 21 02:12:33 PDT 2026
Author: Arseniy Obolenskiy
Date: 2026-04-21T11:12:28+02:00
New Revision: e7a2cf1243ba3ece7e41fddbe63f6fa810bc55c7
URL: https://github.com/llvm/llvm-project/commit/e7a2cf1243ba3ece7e41fddbe63f6fa810bc55c7
DIFF: https://github.com/llvm/llvm-project/commit/e7a2cf1243ba3ece7e41fddbe63f6fa810bc55c7.diff
LOG: [mlir][SPIR-V] Lower boolean vector reductions (#192267)
Define SPIR-V OpAny/OpAll operations and lower vector.reduction <or>/<and>
on boolean vectors to spirv.Any/spirv.All.
Added:
Modified:
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 2149892cc603d..921075736e97b 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -416,6 +416,32 @@ struct VectorReductionPattern final : OpConversionPattern<vector::ReductionOp> {
auto [resultType, extractedElements] = *reductionInfo;
Location loc = reduceOp->getLoc();
+
+ // Handle boolean reductions with spirv.Any / spirv.All.
+ if (resultType.isInteger(1)) {
+ vector::CombiningKind kind = reduceOp.getKind();
+
+ if (kind == vector::CombiningKind::OR) {
+ Value result = spirv::AnyOp::create(rewriter, loc, resultType,
+ adaptor.getVector());
+ if (Value acc = adaptor.getAcc())
+ result = spirv::LogicalOrOp::create(rewriter, loc, resultType, result,
+ acc);
+ rewriter.replaceOp(reduceOp, result);
+ return success();
+ }
+
+ if (kind == vector::CombiningKind::AND) {
+ Value result = spirv::AllOp::create(rewriter, loc, resultType,
+ adaptor.getVector());
+ if (Value acc = adaptor.getAcc())
+ result = spirv::LogicalAndOp::create(rewriter, loc, resultType,
+ result, acc);
+ rewriter.replaceOp(reduceOp, result);
+ return success();
+ }
+ }
+
Value result = extractedElements.front();
for (Value next : llvm::drop_begin(extractedElements)) {
switch (reduceOp.getKind()) {
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 9eaf99203836a..48a1298bc4877 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -955,6 +955,52 @@ func.func @reduction_xor_acc(%v : vector<3xi32>, %s: i32) -> i32 {
// -----
+// CHECK-LABEL: func @reduction_or_bool
+// CHECK-SAME: (%[[V:.+]]: vector<4xi1>)
+// CHECK: %[[R:.+]] = spirv.Any %[[V]] : vector<4xi1>
+// CHECK: return %[[R]]
+func.func @reduction_or_bool(%v : vector<4xi1>) -> i1 {
+ %reduce = vector.reduction <or>, %v : vector<4xi1> into i1
+ return %reduce : i1
+}
+
+// -----
+
+// CHECK-LABEL: func @reduction_and_bool
+// CHECK-SAME: (%[[V:.+]]: vector<4xi1>)
+// CHECK: %[[R:.+]] = spirv.All %[[V]] : vector<4xi1>
+// CHECK: return %[[R]]
+func.func @reduction_and_bool(%v : vector<4xi1>) -> i1 {
+ %reduce = vector.reduction <and>, %v : vector<4xi1> into i1
+ return %reduce : i1
+}
+
+// -----
+
+// CHECK-LABEL: func @reduction_or_bool_acc
+// CHECK-SAME: (%[[V:.+]]: vector<4xi1>, %[[S:.+]]: i1)
+// CHECK: %[[R:.+]] = spirv.Any %[[V]] : vector<4xi1>
+// CHECK: %[[A:.+]] = spirv.LogicalOr %[[R]], %[[S]]
+// CHECK: return %[[A]]
+func.func @reduction_or_bool_acc(%v : vector<4xi1>, %s: i1) -> i1 {
+ %reduce = vector.reduction <or>, %v, %s : vector<4xi1> into i1
+ return %reduce : i1
+}
+
+// -----
+
+// CHECK-LABEL: func @reduction_and_bool_acc
+// CHECK-SAME: (%[[V:.+]]: vector<4xi1>, %[[S:.+]]: i1)
+// CHECK: %[[R:.+]] = spirv.All %[[V]] : vector<4xi1>
+// CHECK: %[[A:.+]] = spirv.LogicalAnd %[[R]], %[[S]]
+// CHECK: return %[[A]]
+func.func @reduction_and_bool_acc(%v : vector<4xi1>, %s: i1) -> i1 {
+ %reduce = vector.reduction <and>, %v, %s : vector<4xi1> into i1
+ return %reduce : i1
+}
+
+// -----
+
module attributes { spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [BFloat16DotProductKHR], [SPV_KHR_bfloat16]>, #spirv.resource_limits<>> } {
// CHECK-LABEL: func @reduction_bf16_addf_mulf
More information about the Mlir-commits mailing list