[Mlir-commits] [mlir] 81c4ceb - [mlir][SPIR-V] Add spirv.Any and spirv.All ops (#192286)
llvmlistbot at llvm.org
Thu Apr 16 11:14:30 PDT 2026
Author: Arseniy Obolenskiy
Date: 2026-04-16T20:14:24+02:00
New Revision: 81c4ceb90239098e60d706f0a68f68d4dacec7af
URL: https://github.com/llvm/llvm-project/commit/81c4ceb90239098e60d706f0a68f68d4dacec7af
DIFF: https://github.com/llvm/llvm-project/commit/81c4ceb90239098e60d706f0a68f68d4dacec7af.diff
LOG: [mlir][SPIR-V] Add spirv.Any and spirv.All ops (#192286)
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
mlir/test/Target/SPIRV/logical-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 297dde3a67b2a..27abd13b8ddb1 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4512,6 +4512,8 @@ def SPIRV_OC_OpIAddCarry : I32EnumAttrCase<"OpIAddCarry", 1
def SPIRV_OC_OpISubBorrow : I32EnumAttrCase<"OpISubBorrow", 150>;
def SPIRV_OC_OpUMulExtended : I32EnumAttrCase<"OpUMulExtended", 151>;
def SPIRV_OC_OpSMulExtended : I32EnumAttrCase<"OpSMulExtended", 152>;
+def SPIRV_OC_OpAny : I32EnumAttrCase<"OpAny", 154>;
+def SPIRV_OC_OpAll : I32EnumAttrCase<"OpAll", 155>;
def SPIRV_OC_OpIsNan : I32EnumAttrCase<"OpIsNan", 156>;
def SPIRV_OC_OpIsInf : I32EnumAttrCase<"OpIsInf", 157>;
def SPIRV_OC_OpIsFinite : I32EnumAttrCase<"OpIsFinite", 158>;
@@ -4714,6 +4716,7 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpMatrixTimesMatrix, SPIRV_OC_OpOuterProduct, SPIRV_OC_OpDot,
SPIRV_OC_OpIAddCarry,
SPIRV_OC_OpISubBorrow, SPIRV_OC_OpUMulExtended, SPIRV_OC_OpSMulExtended,
+ SPIRV_OC_OpAny, SPIRV_OC_OpAll,
SPIRV_OC_OpIsNan, SPIRV_OC_OpIsInf, SPIRV_OC_OpIsFinite,
SPIRV_OC_OpOrdered, SPIRV_OC_OpUnordered,
SPIRV_OC_OpLogicalEqual, SPIRV_OC_OpLogicalNotEqual, SPIRV_OC_OpLogicalOr,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
index 179219042c882..4de76ca884534 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
@@ -48,6 +48,70 @@ class SPIRV_LogicalUnaryOp<string mnemonic, Type operandType,
// -----
+def SPIRV_AnyOp : SPIRV_Op<"Any", [Pure]> {
+ let summary = "Result is true if any component of Vector is true.";
+
+ let description = [{
+ Result Type must be a Boolean type scalar.
+
+ Vector must be a vector of Boolean type.
+
+ <!-- End of AutoGen section -->
+
+ #### Example:
+
+ ```mlir
+ %result = spirv.Any %vector : vector<4xi1>
+ ```
+ }];
+
+ let arguments = (ins
+ SPIRV_VectorOf<SPIRV_Bool>:$vector
+ );
+
+ let results = (outs
+ SPIRV_Bool:$result
+ );
+
+ let assemblyFormat = "$vector attr-dict `:` type($vector)";
+
+ let hasVerifier = 0;
+}
+
+// -----
+
+def SPIRV_AllOp : SPIRV_Op<"All", [Pure]> {
+ let summary = "Result is true if all components of Vector are true.";
+
+ let description = [{
+ Result Type must be a Boolean type scalar.
+
+ Vector must be a vector of Boolean type.
+
+ <!-- End of AutoGen section -->
+
+ #### Example:
+
+ ```mlir
+ %result = spirv.All %vector : vector<4xi1>
+ ```
+ }];
+
+ let arguments = (ins
+ SPIRV_VectorOf<SPIRV_Bool>:$vector
+ );
+
+ let results = (outs
+ SPIRV_Bool:$result
+ );
+
+ let assemblyFormat = "$vector attr-dict `:` type($vector)";
+
+ let hasVerifier = 0;
+}
+
+// -----
+
def SPIRV_FOrdEqualOp : SPIRV_LogicalBinaryOp<"FOrdEqual", SPIRV_Float, [Commutative]> {
let summary = "Floating-point comparison for being ordered and equal.";
diff --git a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
index 1018751cf65e0..d8d80ddfae097 100644
--- a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
@@ -1,5 +1,99 @@
// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
+//===----------------------------------------------------------------------===//
+// spirv.Any
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @any_vector2
+func.func @any_vector2(%arg0: vector<2xi1>) -> i1 {
+ // CHECK: spirv.Any %{{.*}} : vector<2xi1>
+ %0 = spirv.Any %arg0 : vector<2xi1>
+ return %0 : i1
+}
+
+// -----
+
+// CHECK-LABEL: @any_vector3
+func.func @any_vector3(%arg0: vector<3xi1>) -> i1 {
+ // CHECK: spirv.Any %{{.*}} : vector<3xi1>
+ %0 = spirv.Any %arg0 : vector<3xi1>
+ return %0 : i1
+}
+
+// -----
+
+// CHECK-LABEL: @any_vector4
+func.func @any_vector4(%arg0: vector<4xi1>) -> i1 {
+ // CHECK: spirv.Any %{{.*}} : vector<4xi1>
+ %0 = spirv.Any %arg0 : vector<4xi1>
+ return %0 : i1
+}
+
+// -----
+
+func.func @any_scalar(%arg0: i1) -> i1 {
+ // expected-error @+1 {{invalid kind of type specified: expected builtin.vector, but found 'i1'}}
+ %0 = spirv.Any %arg0 : i1
+ return %0 : i1
+}
+
+// -----
+
+func.func @any_wrong_result_type(%arg0: vector<2xi1>) -> vector<2xi1> {
+ // expected-error @+1 {{'spirv.Any' op result #0 must be bool, but got 'vector<2xi1>'}}
+ %0 = "spirv.Any"(%arg0) : (vector<2xi1>) -> vector<2xi1>
+ return %0 : vector<2xi1>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.All
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @all_vector2
+func.func @all_vector2(%arg0: vector<2xi1>) -> i1 {
+ // CHECK: spirv.All %{{.*}} : vector<2xi1>
+ %0 = spirv.All %arg0 : vector<2xi1>
+ return %0 : i1
+}
+
+// -----
+
+// CHECK-LABEL: @all_vector3
+func.func @all_vector3(%arg0: vector<3xi1>) -> i1 {
+ // CHECK: spirv.All %{{.*}} : vector<3xi1>
+ %0 = spirv.All %arg0 : vector<3xi1>
+ return %0 : i1
+}
+
+// -----
+
+// CHECK-LABEL: @all_vector4
+func.func @all_vector4(%arg0: vector<4xi1>) -> i1 {
+ // CHECK: spirv.All %{{.*}} : vector<4xi1>
+ %0 = spirv.All %arg0 : vector<4xi1>
+ return %0 : i1
+}
+
+// -----
+
+func.func @all_scalar(%arg0: i1) -> i1 {
+ // expected-error @+1 {{invalid kind of type specified: expected builtin.vector, but found 'i1'}}
+ %0 = spirv.All %arg0 : i1
+ return %0 : i1
+}
+
+// -----
+
+func.func @all_wrong_result_type(%arg0: vector<2xi1>) -> vector<2xi1> {
+ // expected-error @+1 {{'spirv.All' op result #0 must be bool, but got 'vector<2xi1>'}}
+ %0 = "spirv.All"(%arg0) : (vector<2xi1>) -> vector<2xi1>
+ return %0 : vector<2xi1>
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.IEqual
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/SPIRV/logical-ops.mlir b/mlir/test/Target/SPIRV/logical-ops.mlir
index d570815448b7c..83459b980a1d7 100644
--- a/mlir/test/Target/SPIRV/logical-ops.mlir
+++ b/mlir/test/Target/SPIRV/logical-ops.mlir
@@ -6,6 +6,16 @@
// RUN: %if spirv-tools %{ spirv-val %t %}
spirv.module Logical OpenCL requires #spirv.vce<v1.0, [Kernel, Linkage], []> {
+ spirv.func @any_vector(%arg0: vector<4xi1>) "None" {
+ // CHECK: {{.*}} = spirv.Any {{.*}} : vector<4xi1>
+ %0 = spirv.Any %arg0 : vector<4xi1>
+ spirv.Return
+ }
+ spirv.func @all_vector(%arg0: vector<4xi1>) "None" {
+ // CHECK: {{.*}} = spirv.All {{.*}} : vector<4xi1>
+ %0 = spirv.All %arg0 : vector<4xi1>
+ spirv.Return
+ }
spirv.func @iequal_scalar(%arg0: i32, %arg1: i32) "None" {
// CHECK: {{.*}} = spirv.IEqual {{.*}}, {{.*}} : i32
%0 = spirv.IEqual %arg0, %arg1 : i32
More information about the Mlir-commits mailing list