[Mlir-commits] [mlir] [mlir][SPIR-V] Lower AND/OR/XOR vector reductions (PR #192293)
Arseniy Obolenskiy
llvmlistbot at llvm.org
Fri Apr 17 03:39:09 PDT 2026
https://github.com/aobolensk updated https://github.com/llvm/llvm-project/pull/192293
>From e6b1de6fb1b3e7446fa839a85c2e9e8fd615da0c Mon Sep 17 00:00:00 2001
From: Arseniy Obolenskiy <arseniy.obolenskiy at amd.com>
Date: Wed, 15 Apr 2026 19:18:57 +0200
Subject: [PATCH 1/2] [mlir][SPIR-V] Lower AND/OR/XOR vector reductions
Lower vector.reduction <and>, <or>, and <xor> to spirv.BitwiseAnd, spirv.BitwiseOr, and spirv.BitwiseXor, respectively.
---
.../VectorToSPIRV/VectorToSPIRV.cpp | 15 ++-
.../VectorToSPIRV/vector-to-spirv.mlir | 99 +++++++++++++++++++
2 files changed, 110 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index c101a95685a25..3bf5eb67c70fb 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -442,10 +442,17 @@ struct VectorReductionPattern final : OpConversionPattern<vector::ReductionOp> {
INT_OR_FLOAT_CASE(MAXUI, SPIRVUMaxOp);
INT_OR_FLOAT_CASE(MAXSI, SPIRVSMaxOp);
- case vector::CombiningKind::AND:
- case vector::CombiningKind::OR:
- case vector::CombiningKind::XOR:
- return rewriter.notifyMatchFailure(reduceOp, "unimplemented");
+#define INT_CASE(kind, iop) \
+ case vector::CombiningKind::kind: \
+ assert(isa<IntegerType>(resultType)); \
+ result = spirv::iop::create(rewriter, loc, resultType, result, next); \
+ break
+
+ INT_CASE(AND, BitwiseAndOp);
+ INT_CASE(OR, BitwiseOrOp);
+ INT_CASE(XOR, BitwiseXorOp);
+
+#undef INT_CASE
default:
return rewriter.notifyMatchFailure(reduceOp, "not handled here");
}
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index c399250151261..9eaf99203836a 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -856,6 +856,105 @@ func.func @reduction_minui(%v : vector<3xi32>, %s: i32) -> i32 {
// -----
+// CHECK-LABEL: func @reduction_and
+// CHECK-SAME: (%[[V:.+]]: vector<4xi32>)
+// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<4xi32>
+// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<4xi32>
+// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<4xi32>
+// CHECK: %[[S3:.+]] = spirv.CompositeExtract %[[V]][3 : i32] : vector<4xi32>
+// CHECK: %[[AND0:.+]] = spirv.BitwiseAnd %[[S0]], %[[S1]]
+// CHECK: %[[AND1:.+]] = spirv.BitwiseAnd %[[AND0]], %[[S2]]
+// CHECK: %[[AND2:.+]] = spirv.BitwiseAnd %[[AND1]], %[[S3]]
+// CHECK: return %[[AND2]]
+func.func @reduction_and(%v : vector<4xi32>) -> i32 {
+ %reduce = vector.reduction <and>, %v : vector<4xi32> into i32
+ return %reduce : i32
+}
+
+// -----
+
+// CHECK-LABEL: func @reduction_and_acc
+// CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32)
+// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xi32>
+// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xi32>
+// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xi32>
+// CHECK: %[[AND0:.+]] = spirv.BitwiseAnd %[[S0]], %[[S1]]
+// CHECK: %[[AND1:.+]] = spirv.BitwiseAnd %[[AND0]], %[[S2]]
+// CHECK: %[[AND2:.+]] = spirv.BitwiseAnd %[[AND1]], %[[S]]
+// CHECK: return %[[AND2]]
+func.func @reduction_and_acc(%v : vector<3xi32>, %s: i32) -> i32 {
+ %reduce = vector.reduction <and>, %v, %s : vector<3xi32> into i32
+ return %reduce : i32
+}
+
+// -----
+
+// CHECK-LABEL: func @reduction_or
+// CHECK-SAME: (%[[V:.+]]: vector<4xi32>)
+// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<4xi32>
+// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<4xi32>
+// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<4xi32>
+// CHECK: %[[S3:.+]] = spirv.CompositeExtract %[[V]][3 : i32] : vector<4xi32>
+// CHECK: %[[OR0:.+]] = spirv.BitwiseOr %[[S0]], %[[S1]]
+// CHECK: %[[OR1:.+]] = spirv.BitwiseOr %[[OR0]], %[[S2]]
+// CHECK: %[[OR2:.+]] = spirv.BitwiseOr %[[OR1]], %[[S3]]
+// CHECK: return %[[OR2]]
+func.func @reduction_or(%v : vector<4xi32>) -> i32 {
+ %reduce = vector.reduction <or>, %v : vector<4xi32> into i32
+ return %reduce : i32
+}
+
+// -----
+
+// CHECK-LABEL: func @reduction_or_acc
+// CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32)
+// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xi32>
+// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xi32>
+// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xi32>
+// CHECK: %[[OR0:.+]] = spirv.BitwiseOr %[[S0]], %[[S1]]
+// CHECK: %[[OR1:.+]] = spirv.BitwiseOr %[[OR0]], %[[S2]]
+// CHECK: %[[OR2:.+]] = spirv.BitwiseOr %[[OR1]], %[[S]]
+// CHECK: return %[[OR2]]
+func.func @reduction_or_acc(%v : vector<3xi32>, %s: i32) -> i32 {
+ %reduce = vector.reduction <or>, %v, %s : vector<3xi32> into i32
+ return %reduce : i32
+}
+
+// -----
+
+// CHECK-LABEL: func @reduction_xor
+// CHECK-SAME: (%[[V:.+]]: vector<4xi32>)
+// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<4xi32>
+// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<4xi32>
+// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<4xi32>
+// CHECK: %[[S3:.+]] = spirv.CompositeExtract %[[V]][3 : i32] : vector<4xi32>
+// CHECK: %[[XOR0:.+]] = spirv.BitwiseXor %[[S0]], %[[S1]]
+// CHECK: %[[XOR1:.+]] = spirv.BitwiseXor %[[XOR0]], %[[S2]]
+// CHECK: %[[XOR2:.+]] = spirv.BitwiseXor %[[XOR1]], %[[S3]]
+// CHECK: return %[[XOR2]]
+func.func @reduction_xor(%v : vector<4xi32>) -> i32 {
+ %reduce = vector.reduction <xor>, %v : vector<4xi32> into i32
+ return %reduce : i32
+}
+
+// -----
+
+// CHECK-LABEL: func @reduction_xor_acc
+// CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32)
+// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xi32>
+// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xi32>
+// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xi32>
+// CHECK: %[[XOR0:.+]] = spirv.BitwiseXor %[[S0]], %[[S1]]
+// CHECK: %[[XOR1:.+]] = spirv.BitwiseXor %[[XOR0]], %[[S2]]
+// CHECK: %[[XOR2:.+]] = spirv.BitwiseXor %[[XOR1]], %[[S]]
+// CHECK: return %[[XOR2]]
+func.func @reduction_xor_acc(%v : vector<3xi32>, %s: i32) -> i32 {
+ %reduce = vector.reduction <xor>, %v, %s : vector<3xi32> into i32
+ return %reduce : i32
+}
+
+// -----
+
module attributes { spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [BFloat16DotProductKHR], [SPV_KHR_bfloat16]>, #spirv.resource_limits<>> } {
// CHECK-LABEL: func @reduction_bf16_addf_mulf
>From eff0b7adda19fdc7d9f0719512be727b43adbc04 Mon Sep 17 00:00:00 2001
From: Arseniy Obolenskiy <arseniy.obolenskiy at amd.com>
Date: Fri, 17 Apr 2026 12:38:57 +0200
Subject: [PATCH 2/2] Address part of the review comments
---
.../lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 14 +++++++-------
1 file changed, 7 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 3bf5eb67c70fb..cac5932607871 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -435,6 +435,12 @@ struct VectorReductionPattern final : OpConversionPattern<vector::ReductionOp> {
result = fop::create(rewriter, loc, resultType, result, next); \
break
+#define INT_CASE(kind, iop) \
+ case vector::CombiningKind::kind: \
+ assert(isa<IntegerType>(resultType)); \
+ result = spirv::iop::create(rewriter, loc, resultType, result, next); \
+ break
+
INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp);
INT_AND_FLOAT_CASE(MUL, IMulOp, FMulOp);
INT_OR_FLOAT_CASE(MINUI, SPIRVUMinOp);
@@ -442,22 +448,16 @@ struct VectorReductionPattern final : OpConversionPattern<vector::ReductionOp> {
INT_OR_FLOAT_CASE(MAXUI, SPIRVUMaxOp);
INT_OR_FLOAT_CASE(MAXSI, SPIRVSMaxOp);
-#define INT_CASE(kind, iop) \
- case vector::CombiningKind::kind: \
- assert(isa<IntegerType>(resultType)); \
- result = spirv::iop::create(rewriter, loc, resultType, result, next); \
- break
-
INT_CASE(AND, BitwiseAndOp);
INT_CASE(OR, BitwiseOrOp);
INT_CASE(XOR, BitwiseXorOp);
-#undef INT_CASE
default:
return rewriter.notifyMatchFailure(reduceOp, "not handled here");
}
#undef INT_AND_FLOAT_CASE
#undef INT_OR_FLOAT_CASE
+#undef INT_CASE
}
rewriter.replaceOp(reduceOp, result);
More information about the Mlir-commits mailing list