[Mlir-commits] [mlir] [mlir][SPIR-V] Lower boolean vector reductions (PR #192267)
Arseniy Obolenskiy
llvmlistbot at llvm.org
Tue Apr 21 00:26:36 PDT 2026
https://github.com/aobolensk updated https://github.com/llvm/llvm-project/pull/192267
>From 846b813e7b96bf10696e8d4a03708141bb90ffcf Mon Sep 17 00:00:00 2001
From: Arseniy Obolenskiy <arseniy.obolenskiy at amd.com>
Date: Wed, 15 Apr 2026 16:28:14 +0200
Subject: [PATCH 1/3] [mlir][SPIR-V] Add spirv.Any and spirv.All ops and lower
vector boolean reductions to them
Define SPIR-V OpAny/OpAll operations and lower vector.reduction <or>/<and> on boolean vectors to spirv.Any/spirv.All.
---
.../VectorToSPIRV/VectorToSPIRV.cpp | 26 +++++++++++
.../VectorToSPIRV/vector-to-spirv.mlir | 46 +++++++++++++++++++
2 files changed, 72 insertions(+)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 2149892cc603d..3709589d46c56 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -416,6 +416,32 @@ struct VectorReductionPattern final : OpConversionPattern<vector::ReductionOp> {
auto [resultType, extractedElements] = *reductionInfo;
Location loc = reduceOp->getLoc();
+
+ // Handle boolean reductions with spirv.Any / spirv.All.
+ if (resultType.isInteger(1)) {
+ auto kind = reduceOp.getKind();
+ if (kind == vector::CombiningKind::OR ||
+ kind == vector::CombiningKind::AND) {
+ Value result;
+ if (kind == vector::CombiningKind::OR)
+ result = spirv::AnyOp::create(rewriter, loc, resultType,
+ adaptor.getVector());
+ else
+ result = spirv::AllOp::create(rewriter, loc, resultType,
+ adaptor.getVector());
+ if (Value acc = adaptor.getAcc()) {
+ if (kind == vector::CombiningKind::OR)
+ result = spirv::LogicalOrOp::create(rewriter, loc, resultType,
+ result, acc);
+ else
+ result = spirv::LogicalAndOp::create(rewriter, loc, resultType,
+ result, acc);
+ }
+ rewriter.replaceOp(reduceOp, result);
+ return success();
+ }
+ }
+
Value result = extractedElements.front();
for (Value next : llvm::drop_begin(extractedElements)) {
switch (reduceOp.getKind()) {
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 9eaf99203836a..48a1298bc4877 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -955,6 +955,52 @@ func.func @reduction_xor_acc(%v : vector<3xi32>, %s: i32) -> i32 {
// -----
+// CHECK-LABEL: func @reduction_or_bool
+// CHECK-SAME: (%[[V:.+]]: vector<4xi1>)
+// CHECK: %[[R:.+]] = spirv.Any %[[V]] : vector<4xi1>
+// CHECK: return %[[R]]
+func.func @reduction_or_bool(%v : vector<4xi1>) -> i1 {
+ %reduce = vector.reduction <or>, %v : vector<4xi1> into i1
+ return %reduce : i1
+}
+
+// -----
+
+// CHECK-LABEL: func @reduction_and_bool
+// CHECK-SAME: (%[[V:.+]]: vector<4xi1>)
+// CHECK: %[[R:.+]] = spirv.All %[[V]] : vector<4xi1>
+// CHECK: return %[[R]]
+func.func @reduction_and_bool(%v : vector<4xi1>) -> i1 {
+ %reduce = vector.reduction <and>, %v : vector<4xi1> into i1
+ return %reduce : i1
+}
+
+// -----
+
+// CHECK-LABEL: func @reduction_or_bool_acc
+// CHECK-SAME: (%[[V:.+]]: vector<4xi1>, %[[S:.+]]: i1)
+// CHECK: %[[R:.+]] = spirv.Any %[[V]] : vector<4xi1>
+// CHECK: %[[A:.+]] = spirv.LogicalOr %[[R]], %[[S]]
+// CHECK: return %[[A]]
+func.func @reduction_or_bool_acc(%v : vector<4xi1>, %s: i1) -> i1 {
+ %reduce = vector.reduction <or>, %v, %s : vector<4xi1> into i1
+ return %reduce : i1
+}
+
+// -----
+
+// CHECK-LABEL: func @reduction_and_bool_acc
+// CHECK-SAME: (%[[V:.+]]: vector<4xi1>, %[[S:.+]]: i1)
+// CHECK: %[[R:.+]] = spirv.All %[[V]] : vector<4xi1>
+// CHECK: %[[A:.+]] = spirv.LogicalAnd %[[R]], %[[S]]
+// CHECK: return %[[A]]
+func.func @reduction_and_bool_acc(%v : vector<4xi1>, %s: i1) -> i1 {
+ %reduce = vector.reduction <and>, %v, %s : vector<4xi1> into i1
+ return %reduce : i1
+}
+
+// -----
+
module attributes { spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [BFloat16DotProductKHR], [SPV_KHR_bfloat16]>, #spirv.resource_limits<>> } {
// CHECK-LABEL: func @reduction_bf16_addf_mulf
>From 83e97415f1061a39dcc34c4a908e925284b2080c Mon Sep 17 00:00:00 2001
From: Arseniy Obolenskiy <arseniy.obolenskiy at amd.com>
Date: Tue, 21 Apr 2026 09:23:11 +0200
Subject: [PATCH 2/3] Address review comments
---
.../VectorToSPIRV/VectorToSPIRV.cpp | 36 +++++++++----------
1 file changed, 18 insertions(+), 18 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 3709589d46c56..7c856aab3ffe6 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -419,24 +419,24 @@ struct VectorReductionPattern final : OpConversionPattern<vector::ReductionOp> {
// Handle boolean reductions with spirv.Any / spirv.All.
if (resultType.isInteger(1)) {
- auto kind = reduceOp.getKind();
- if (kind == vector::CombiningKind::OR ||
- kind == vector::CombiningKind::AND) {
- Value result;
- if (kind == vector::CombiningKind::OR)
- result = spirv::AnyOp::create(rewriter, loc, resultType,
- adaptor.getVector());
- else
- result = spirv::AllOp::create(rewriter, loc, resultType,
- adaptor.getVector());
- if (Value acc = adaptor.getAcc()) {
- if (kind == vector::CombiningKind::OR)
- result = spirv::LogicalOrOp::create(rewriter, loc, resultType,
- result, acc);
- else
- result = spirv::LogicalAndOp::create(rewriter, loc, resultType,
- result, acc);
- }
+ vector::CombiningKind kind = reduceOp.getKind();
+
+ if (kind == vector::CombiningKind::OR) {
+ Value result = spirv::AnyOp::create(rewriter, loc, resultType,
+ adaptor.getVector());
+ if (Value acc = adaptor.getAcc())
+ result = spirv::LogicalOrOp::create(rewriter, loc, resultType,
+ result, acc);
+ rewriter.replaceOp(reduceOp, result);
+ return success();
+ }
+
+ if (kind == vector::CombiningKind::AND) {
+ Value result = spirv::AllOp::create(rewriter, loc, resultType,
+ adaptor.getVector());
+ if (Value acc = adaptor.getAcc())
+ result = spirv::LogicalAndOp::create(rewriter, loc, resultType,
+ result, acc);
rewriter.replaceOp(reduceOp, result);
return success();
}
>From 28f09035777670a8a6a05879a0019cdf79b1258c Mon Sep 17 00:00:00 2001
From: Arseniy Obolenskiy <arseniy.obolenskiy at amd.com>
Date: Tue, 21 Apr 2026 09:26:22 +0200
Subject: [PATCH 3/3] Fix formatting
---
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 7c856aab3ffe6..921075736e97b 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -425,8 +425,8 @@ struct VectorReductionPattern final : OpConversionPattern<vector::ReductionOp> {
Value result = spirv::AnyOp::create(rewriter, loc, resultType,
adaptor.getVector());
if (Value acc = adaptor.getAcc())
- result = spirv::LogicalOrOp::create(rewriter, loc, resultType,
- result, acc);
+ result = spirv::LogicalOrOp::create(rewriter, loc, resultType, result,
+ acc);
rewriter.replaceOp(reduceOp, result);
return success();
}
More information about the Mlir-commits
mailing list