[Mlir-commits] [mlir] [mlir][SPIR-V] Lower AND/OR/XOR vector reductions (PR #192293)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 15 10:21:02 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-spirv

Author: Arseniy Obolenskiy (aobolensk)

<details>
<summary>Changes</summary>

Lower vector.reduction <and>, <or>, and <xor> to spirv.BitwiseAnd, spirv.BitwiseOr, and spirv.BitwiseXor, respectively.

---
Full diff: https://github.com/llvm/llvm-project/pull/192293.diff


2 Files Affected:

- (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (+11-4) 
- (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+99) 


``````````diff
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index c101a95685a25..3bf5eb67c70fb 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -442,10 +442,17 @@ struct VectorReductionPattern final : OpConversionPattern<vector::ReductionOp> {
         INT_OR_FLOAT_CASE(MAXUI, SPIRVUMaxOp);
         INT_OR_FLOAT_CASE(MAXSI, SPIRVSMaxOp);
 
-      case vector::CombiningKind::AND:
-      case vector::CombiningKind::OR:
-      case vector::CombiningKind::XOR:
-        return rewriter.notifyMatchFailure(reduceOp, "unimplemented");
+#define INT_CASE(kind, iop)                                                    \
+  case vector::CombiningKind::kind:                                            \
+    assert(isa<IntegerType>(resultType));                                      \
+    result = spirv::iop::create(rewriter, loc, resultType, result, next);      \
+    break
+
+        INT_CASE(AND, BitwiseAndOp);
+        INT_CASE(OR, BitwiseOrOp);
+        INT_CASE(XOR, BitwiseXorOp);
+
+#undef INT_CASE
       default:
         return rewriter.notifyMatchFailure(reduceOp, "not handled here");
       }
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index c399250151261..9eaf99203836a 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -856,6 +856,105 @@ func.func @reduction_minui(%v : vector<3xi32>, %s: i32) -> i32 {
 
 // -----
 
+// CHECK-LABEL: func @reduction_and
+//  CHECK-SAME: (%[[V:.+]]: vector<4xi32>)
+//       CHECK:   %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<4xi32>
+//       CHECK:   %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<4xi32>
+//       CHECK:   %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<4xi32>
+//       CHECK:   %[[S3:.+]] = spirv.CompositeExtract %[[V]][3 : i32] : vector<4xi32>
+//       CHECK:   %[[AND0:.+]] = spirv.BitwiseAnd %[[S0]], %[[S1]]
+//       CHECK:   %[[AND1:.+]] = spirv.BitwiseAnd %[[AND0]], %[[S2]]
+//       CHECK:   %[[AND2:.+]] = spirv.BitwiseAnd %[[AND1]], %[[S3]]
+//       CHECK:   return %[[AND2]]
+func.func @reduction_and(%v : vector<4xi32>) -> i32 {
+  %reduce = vector.reduction <and>, %v : vector<4xi32> into i32
+  return %reduce : i32
+}
+
+// -----
+
+// CHECK-LABEL: func @reduction_and_acc
+//  CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32)
+//       CHECK:   %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xi32>
+//       CHECK:   %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xi32>
+//       CHECK:   %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xi32>
+//       CHECK:   %[[AND0:.+]] = spirv.BitwiseAnd %[[S0]], %[[S1]]
+//       CHECK:   %[[AND1:.+]] = spirv.BitwiseAnd %[[AND0]], %[[S2]]
+//       CHECK:   %[[AND2:.+]] = spirv.BitwiseAnd %[[AND1]], %[[S]]
+//       CHECK:   return %[[AND2]]
+func.func @reduction_and_acc(%v : vector<3xi32>, %s: i32) -> i32 {
+  %reduce = vector.reduction <and>, %v, %s : vector<3xi32> into i32
+  return %reduce : i32
+}
+
+// -----
+
+// CHECK-LABEL: func @reduction_or
+//  CHECK-SAME: (%[[V:.+]]: vector<4xi32>)
+//       CHECK:   %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<4xi32>
+//       CHECK:   %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<4xi32>
+//       CHECK:   %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<4xi32>
+//       CHECK:   %[[S3:.+]] = spirv.CompositeExtract %[[V]][3 : i32] : vector<4xi32>
+//       CHECK:   %[[OR0:.+]] = spirv.BitwiseOr %[[S0]], %[[S1]]
+//       CHECK:   %[[OR1:.+]] = spirv.BitwiseOr %[[OR0]], %[[S2]]
+//       CHECK:   %[[OR2:.+]] = spirv.BitwiseOr %[[OR1]], %[[S3]]
+//       CHECK:   return %[[OR2]]
+func.func @reduction_or(%v : vector<4xi32>) -> i32 {
+  %reduce = vector.reduction <or>, %v : vector<4xi32> into i32
+  return %reduce : i32
+}
+
+// -----
+
+// CHECK-LABEL: func @reduction_or_acc
+//  CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32)
+//       CHECK:   %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xi32>
+//       CHECK:   %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xi32>
+//       CHECK:   %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xi32>
+//       CHECK:   %[[OR0:.+]] = spirv.BitwiseOr %[[S0]], %[[S1]]
+//       CHECK:   %[[OR1:.+]] = spirv.BitwiseOr %[[OR0]], %[[S2]]
+//       CHECK:   %[[OR2:.+]] = spirv.BitwiseOr %[[OR1]], %[[S]]
+//       CHECK:   return %[[OR2]]
+func.func @reduction_or_acc(%v : vector<3xi32>, %s: i32) -> i32 {
+  %reduce = vector.reduction <or>, %v, %s : vector<3xi32> into i32
+  return %reduce : i32
+}
+
+// -----
+
+// CHECK-LABEL: func @reduction_xor
+//  CHECK-SAME: (%[[V:.+]]: vector<4xi32>)
+//       CHECK:   %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<4xi32>
+//       CHECK:   %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<4xi32>
+//       CHECK:   %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<4xi32>
+//       CHECK:   %[[S3:.+]] = spirv.CompositeExtract %[[V]][3 : i32] : vector<4xi32>
+//       CHECK:   %[[XOR0:.+]] = spirv.BitwiseXor %[[S0]], %[[S1]]
+//       CHECK:   %[[XOR1:.+]] = spirv.BitwiseXor %[[XOR0]], %[[S2]]
+//       CHECK:   %[[XOR2:.+]] = spirv.BitwiseXor %[[XOR1]], %[[S3]]
+//       CHECK:   return %[[XOR2]]
+func.func @reduction_xor(%v : vector<4xi32>) -> i32 {
+  %reduce = vector.reduction <xor>, %v : vector<4xi32> into i32
+  return %reduce : i32
+}
+
+// -----
+
+// CHECK-LABEL: func @reduction_xor_acc
+//  CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32)
+//       CHECK:   %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xi32>
+//       CHECK:   %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xi32>
+//       CHECK:   %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xi32>
+//       CHECK:   %[[XOR0:.+]] = spirv.BitwiseXor %[[S0]], %[[S1]]
+//       CHECK:   %[[XOR1:.+]] = spirv.BitwiseXor %[[XOR0]], %[[S2]]
+//       CHECK:   %[[XOR2:.+]] = spirv.BitwiseXor %[[XOR1]], %[[S]]
+//       CHECK:   return %[[XOR2]]
+func.func @reduction_xor_acc(%v : vector<3xi32>, %s: i32) -> i32 {
+  %reduce = vector.reduction <xor>, %v, %s : vector<3xi32> into i32
+  return %reduce : i32
+}
+
+// -----
+
 module attributes { spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [BFloat16DotProductKHR], [SPV_KHR_bfloat16]>, #spirv.resource_limits<>> } {
 
 // CHECK-LABEL: func @reduction_bf16_addf_mulf

``````````

</details>


https://github.com/llvm/llvm-project/pull/192293


More information about the Mlir-commits mailing list