[Mlir-commits] [mlir] [mlir][SPIR-V] Add spirv.Any and spirv.All ops and lower vector boolean reductions (PR #192267)

Arseniy Obolenskiy llvmlistbot at llvm.org
Wed Apr 15 07:30:37 PDT 2026


https://github.com/aobolensk created https://github.com/llvm/llvm-project/pull/192267

Define the SPIR-V OpAny/OpAll operations and lower vector.reduction <or>/<and> on boolean vectors to spirv.Any/spirv.All.

>From 15b36ec2148e46afc5cf5bc229163ee779a0e26d Mon Sep 17 00:00:00 2001
From: Arseniy Obolenskiy <arseniy.obolenskiy at amd.com>
Date: Wed, 15 Apr 2026 16:28:14 +0200
Subject: [PATCH] [mlir][SPIR-V] Add spirv.Any and spirv.All ops and lower
 vector boolean reductions to them

Define the SPIR-V OpAny/OpAll operations and lower vector.reduction <or>/<and> on boolean vectors to spirv.Any/spirv.All.
---
 .../mlir/Dialect/SPIRV/IR/SPIRVBase.td        |  3 +
 .../mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td  | 64 +++++++++++++++
 .../VectorToSPIRV/VectorToSPIRV.cpp           | 26 +++++++
 .../VectorToSPIRV/vector-to-spirv.mlir        | 46 +++++++++++
 mlir/test/Dialect/SPIRV/IR/logical-ops.mlir   | 78 +++++++++++++++++++
 5 files changed, 217 insertions(+)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 297dde3a67b2a..27abd13b8ddb1 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4512,6 +4512,8 @@ def SPIRV_OC_OpIAddCarry                      : I32EnumAttrCase<"OpIAddCarry", 1
 def SPIRV_OC_OpISubBorrow                     : I32EnumAttrCase<"OpISubBorrow", 150>;
 def SPIRV_OC_OpUMulExtended                   : I32EnumAttrCase<"OpUMulExtended", 151>;
 def SPIRV_OC_OpSMulExtended                   : I32EnumAttrCase<"OpSMulExtended", 152>;
+def SPIRV_OC_OpAny                            : I32EnumAttrCase<"OpAny", 154>;
+def SPIRV_OC_OpAll                            : I32EnumAttrCase<"OpAll", 155>;
 def SPIRV_OC_OpIsNan                          : I32EnumAttrCase<"OpIsNan", 156>;
 def SPIRV_OC_OpIsInf                          : I32EnumAttrCase<"OpIsInf", 157>;
 def SPIRV_OC_OpIsFinite                       : I32EnumAttrCase<"OpIsFinite", 158>;
@@ -4714,6 +4716,7 @@ def SPIRV_OpcodeAttr :
       SPIRV_OC_OpMatrixTimesMatrix, SPIRV_OC_OpOuterProduct, SPIRV_OC_OpDot,
       SPIRV_OC_OpIAddCarry,
       SPIRV_OC_OpISubBorrow, SPIRV_OC_OpUMulExtended, SPIRV_OC_OpSMulExtended,
+      SPIRV_OC_OpAny, SPIRV_OC_OpAll,
       SPIRV_OC_OpIsNan, SPIRV_OC_OpIsInf, SPIRV_OC_OpIsFinite,
       SPIRV_OC_OpOrdered, SPIRV_OC_OpUnordered,
       SPIRV_OC_OpLogicalEqual, SPIRV_OC_OpLogicalNotEqual, SPIRV_OC_OpLogicalOr,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
index 179219042c882..4de76ca884534 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
@@ -48,6 +48,70 @@ class SPIRV_LogicalUnaryOp<string mnemonic, Type operandType,
 
 // -----
 
+def SPIRV_AnyOp : SPIRV_Op<"Any", [Pure]> {
+  let summary = "Result is true if any component of Vector is true.";
+
+  let description = [{
+    Result Type must be a Boolean type scalar.
+
+    Vector must be a vector of Boolean type.
+
+    <!-- End of AutoGen section -->
+
+    #### Example:
+
+    ```mlir
+    %result = spirv.Any %vector : vector<4xi1>
+    ```
+  }];
+
+  let arguments = (ins SPIRV_VectorOf<SPIRV_Bool>:$vector);
+  let results = (outs SPIRV_Bool:$result);
+
+  // Only the operand type is spelled out; the result is always the scalar
+  // bool declared in `results`, so it does not appear in the syntax.
+  let assemblyFormat = "$vector attr-dict `:` type($vector)";
+
+  // The ODS type constraints above already encode the Result Type and
+  // Vector rules from the description, so no custom C++ verifier is
+  // needed.
+  let hasVerifier = 0;
+}
+
+// -----
+
+def SPIRV_AllOp : SPIRV_Op<"All", [Pure]> {
+  let summary = "Result is true if all components of Vector are true.";
+
+  let description = [{
+    Result Type must be a Boolean type scalar.
+
+    Vector must be a vector of Boolean type.
+
+    <!-- End of AutoGen section -->
+
+    #### Example:
+
+    ```mlir
+    %result = spirv.All %vector : vector<4xi1>
+    ```
+  }];
+
+  let arguments = (ins SPIRV_VectorOf<SPIRV_Bool>:$vector);
+  let results = (outs SPIRV_Bool:$result);
+
+  // Only the operand type is spelled out; the result is always the scalar
+  // bool declared in `results`, so it does not appear in the syntax.
+  let assemblyFormat = "$vector attr-dict `:` type($vector)";
+
+  // The ODS type constraints above already encode the Result Type and
+  // Vector rules from the description, so no custom C++ verifier is
+  // needed.
+  let hasVerifier = 0;
+}
+
+// -----
+
 def SPIRV_FOrdEqualOp : SPIRV_LogicalBinaryOp<"FOrdEqual", SPIRV_Float, [Commutative]> {
   let summary = "Floating-point comparison for being ordered and equal.";
 
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index c101a95685a25..3544a5d8fd7a8 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -416,6 +416,32 @@ struct VectorReductionPattern final : OpConversionPattern<vector::ReductionOp> {
 
     auto [resultType, extractedElements] = *reductionInfo;
     Location loc = reduceOp->getLoc();
+
+    // Boolean <or>/<and> become one spirv.Any / spirv.All, then fold in acc.
+    // NOTE(review): `extractedElements` is unused on this path (dead ops?).
+    vector::CombiningKind boolKind = reduceOp.getKind();
+    bool isAnyReduction = boolKind == vector::CombiningKind::OR;
+    bool isAllReduction = boolKind == vector::CombiningKind::AND;
+    if (resultType.isInteger(1) && (isAnyReduction || isAllReduction)) {
+      Value boolVector = adaptor.getVector();
+      Value reduced =
+          isAnyReduction
+              ? spirv::AnyOp::create(rewriter, loc, resultType, boolVector)
+                    .getResult()
+              : spirv::AllOp::create(rewriter, loc, resultType, boolVector)
+                    .getResult();
+      if (Value acc = adaptor.getAcc())
+        reduced = isAnyReduction
+                      ? spirv::LogicalOrOp::create(rewriter, loc, resultType,
+                                                   reduced, acc)
+                            .getResult()
+                      : spirv::LogicalAndOp::create(rewriter, loc, resultType,
+                                                    reduced, acc)
+                            .getResult();
+      rewriter.replaceOp(reduceOp, reduced);
+      return success();
+    }
+
     Value result = extractedElements.front();
     for (Value next : llvm::drop_begin(extractedElements)) {
       switch (reduceOp.getKind()) {
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index c399250151261..64977e8e07f40 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -856,6 +856,52 @@ func.func @reduction_minui(%v : vector<3xi32>, %s: i32) -> i32 {
 
 // -----
 
+// CHECK-LABEL: func @reduction_or_bool
+//  CHECK-SAME: (%[[VEC:.+]]: vector<4xi1>)
+//       CHECK:   %[[ANY:.+]] = spirv.Any %[[VEC]] : vector<4xi1>
+//       CHECK:   return %[[ANY]]
+func.func @reduction_or_bool(%vec : vector<4xi1>) -> i1 {
+  %res = vector.reduction <or>, %vec : vector<4xi1> into i1
+  return %res : i1
+}
+
+// -----
+
+// CHECK-LABEL: func @reduction_and_bool
+//  CHECK-SAME: (%[[VEC:.+]]: vector<4xi1>)
+//       CHECK:   %[[ALL:.+]] = spirv.All %[[VEC]] : vector<4xi1>
+//       CHECK:   return %[[ALL]]
+func.func @reduction_and_bool(%vec : vector<4xi1>) -> i1 {
+  %res = vector.reduction <and>, %vec : vector<4xi1> into i1
+  return %res : i1
+}
+
+// -----
+
+// CHECK-LABEL: func @reduction_or_bool_acc
+//  CHECK-SAME: (%[[VEC:.+]]: vector<4xi1>, %[[ACC:.+]]: i1)
+//       CHECK:   %[[ANY:.+]] = spirv.Any %[[VEC]] : vector<4xi1>
+//       CHECK:   %[[OR:.+]] = spirv.LogicalOr %[[ANY]], %[[ACC]]
+//       CHECK:   return %[[OR]]
+func.func @reduction_or_bool_acc(%vec : vector<4xi1>, %acc: i1) -> i1 {
+  %res = vector.reduction <or>, %vec, %acc : vector<4xi1> into i1
+  return %res : i1
+}
+
+// -----
+
+// CHECK-LABEL: func @reduction_and_bool_acc
+//  CHECK-SAME: (%[[VEC:.+]]: vector<4xi1>, %[[ACC:.+]]: i1)
+//       CHECK:   %[[ALL:.+]] = spirv.All %[[VEC]] : vector<4xi1>
+//       CHECK:   %[[AND:.+]] = spirv.LogicalAnd %[[ALL]], %[[ACC]]
+//       CHECK:   return %[[AND]]
+func.func @reduction_and_bool_acc(%vec : vector<4xi1>, %acc: i1) -> i1 {
+  %res = vector.reduction <and>, %vec, %acc : vector<4xi1> into i1
+  return %res : i1
+}
+
+// -----
+
 module attributes { spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [BFloat16DotProductKHR], [SPV_KHR_bfloat16]>, #spirv.resource_limits<>> } {
 
 // CHECK-LABEL: func @reduction_bf16_addf_mulf
diff --git a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
index 1018751cf65e0..d4108b6285044 100644
--- a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
@@ -1,5 +1,83 @@
 // RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
 
+//===----------------------------------------------------------------------===//
+// spirv.Any
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @any_vector2
+func.func @any_vector2(%vec: vector<2xi1>) -> i1 {
+  // CHECK: %{{.*}} = spirv.Any %{{.*}} : vector<2xi1>
+  %res = spirv.Any %vec : vector<2xi1>
+  return %res : i1
+}
+
+// -----
+
+// CHECK-LABEL: @any_vector3
+func.func @any_vector3(%vec: vector<3xi1>) -> i1 {
+  // CHECK: %{{.*}} = spirv.Any %{{.*}} : vector<3xi1>
+  %res = spirv.Any %vec : vector<3xi1>
+  return %res : i1
+}
+
+// -----
+
+// CHECK-LABEL: @any_vector4
+func.func @any_vector4(%vec: vector<4xi1>) -> i1 {
+  // CHECK: %{{.*}} = spirv.Any %{{.*}} : vector<4xi1>
+  %res = spirv.Any %vec : vector<4xi1>
+  return %res : i1
+}
+
+// -----
+
+func.func @any_scalar(%scalar: i1) -> i1 {
+  // expected-error @+1 {{'spirv.Any' op operand #0 must be fixed-length vector of bool values of length 2/3/4/8/16 of ranks 1, but got 'i1'}}
+  %res = "spirv.Any"(%scalar) : (i1) -> i1
+  return %res : i1
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.All
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @all_vector2
+func.func @all_vector2(%vec: vector<2xi1>) -> i1 {
+  // CHECK: %{{.*}} = spirv.All %{{.*}} : vector<2xi1>
+  %res = spirv.All %vec : vector<2xi1>
+  return %res : i1
+}
+
+// -----
+
+// CHECK-LABEL: @all_vector3
+func.func @all_vector3(%vec: vector<3xi1>) -> i1 {
+  // CHECK: %{{.*}} = spirv.All %{{.*}} : vector<3xi1>
+  %res = spirv.All %vec : vector<3xi1>
+  return %res : i1
+}
+
+// -----
+
+// CHECK-LABEL: @all_vector4
+func.func @all_vector4(%vec: vector<4xi1>) -> i1 {
+  // CHECK: %{{.*}} = spirv.All %{{.*}} : vector<4xi1>
+  %res = spirv.All %vec : vector<4xi1>
+  return %res : i1
+}
+
+// -----
+
+func.func @all_scalar(%scalar: i1) -> i1 {
+  // expected-error @+1 {{'spirv.All' op operand #0 must be fixed-length vector of bool values of length 2/3/4/8/16 of ranks 1, but got 'i1'}}
+  %res = "spirv.All"(%scalar) : (i1) -> i1
+  return %res : i1
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.IEqual
 //===----------------------------------------------------------------------===//



More information about the Mlir-commits mailing list