[Mlir-commits] [mlir] [mlir][spirv] Add reduction ops in TOSA Ext Inst Set (PR #187278)

Davide Grohmann llvmlistbot at llvm.org
Wed Mar 18 07:00:37 PDT 2026


https://github.com/davidegrohmann created https://github.com/llvm/llvm-project/pull/187278

This patch introduces the following reduction operators:

spirv.Tosa.ReduceAll
spirv.Tosa.ReduceAny
spirv.Tosa.ReduceMax
spirv.Tosa.ReduceMin
spirv.Tosa.ReduceProduct
spirv.Tosa.ReduceSum

Dialect and serialization round-trip tests have also been added.

From 0d6e561f08a1405e030117e18f9740c236c83ad0 Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Wed, 28 Jan 2026 13:39:57 +0100
Subject: [PATCH] [mlir][spirv] Add reduction ops in TOSA Ext Inst Set

This patch introduces the following reduction operators:

spirv.Tosa.ReduceAll
spirv.Tosa.ReduceAny
spirv.Tosa.ReduceMax
spirv.Tosa.ReduceMin
spirv.Tosa.ReduceProduct
spirv.Tosa.ReduceSum

Dialect and serialization round-trip tests have also been added.

Change-Id: I457be7a7072c9674230aa71bb538079122c0677f
Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
---
 .../mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td     | 220 ++++++++++++++++++
 .../SPIRV/IR/tosa-ops-verification.mlir       | 120 ++++++++++
 mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir      |  99 ++++++++
 mlir/test/Target/SPIRV/tosa-ops.mlir          | 171 ++++++++++++++
 4 files changed, 610 insertions(+)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
index bc26e0bb9cb0a..9eab69986e4e3 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
@@ -213,6 +213,19 @@ class SPIRV_TosaComparisonOp<string mnemonic, int opcode, list<Trait> traits = [
   }];
 }
 
+class SPIRV_TosaReductionOp<string mnemonic, int opcode, list<Trait> traits = []> :
+  SPIRV_TosaOpWithResult<mnemonic, opcode, !listconcat(traits, [Pure,
+  AllElementTypesMatch<["input", "output"]>,
+  AllRanksMatch<["input", "output"]>,
+  AxisValueLessThanRankOf<"input">])> {
+
+  let extraClassDeclaration = extraBaseClassDeclaration#[{
+    ::mlir::spirv::TensorArmType getInputType() {
+      return cast<::mlir::spirv::TensorArmType>(getInput().getType());
+    }
+  }];
+}
+
 def SPIRV_TosaArgMaxOp : SPIRV_TosaOpWithResult<"ArgMax", 0, [Pure,
   OutputRankIsInputRankMinusOne<"input", "output">,
   AxisValueLessThanRankOf<"input">]> {
@@ -1902,4 +1915,211 @@ def SPIRV_TosaGreaterEqualOp : SPIRV_TosaComparisonOp<"GreaterEqual", 47> {
 }
 
 
+def SPIRV_TosaReduceAllOp : SPIRV_TosaReductionOp<"ReduceAll", 48> {
+  let summary = "Reduce All operator.";
+
+  let description = [{
+    Reduces a tensor along the given axis with a Logical AND operation.
+
+    References:
+      * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_reduce_all
+      * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_reduce_all
+
+    #### Example:
+    ```mlir
+    %1 = spirv.Tosa.ReduceAll axis = 2, %arg0 : !spirv.arm.tensor<18x22x23x12xi1> -> !spirv.arm.tensor<18x22x1x12xi1>
+    ```
+  }];
+
+  let arguments = (ins
+    SPIRV_TensorArmAxisAttr: $axis,
+    SPIRV_Bool_TensorArm: $input
+  );
+
+  let results = (outs
+    SPIRV_Bool_TensorArm: $output
+  );
+
+  let assemblyFormat = [{
+    `axis` `=` $axis `,`
+    $input
+    attr-dict `:` type(operands) `->` type(results)
+  }];
+}
+
+
+def SPIRV_TosaReduceAnyOp : SPIRV_TosaReductionOp<"ReduceAny", 49> {
+  let summary = "Reduce Any operator.";
+
+  let description = [{
+    Reduces a tensor along the given axis with a Logical OR operation.
+
+    References:
+      * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_reduce_any
+      * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_reduce_any
+
+    #### Example:
+    ```mlir
+    %1 = spirv.Tosa.ReduceAny axis = 2, %arg0 : !spirv.arm.tensor<25x13x30x8xi1> -> !spirv.arm.tensor<25x13x1x8xi1>
+    ```
+  }];
+
+  let arguments = (ins
+    SPIRV_TensorArmAxisAttr: $axis,
+    SPIRV_Bool_TensorArm: $input
+  );
+
+  let results = (outs
+    SPIRV_Bool_TensorArm: $output
+  );
+
+  let assemblyFormat = [{
+    `axis` `=` $axis `,`
+    $input
+    attr-dict `:` type(operands) `->` type(results)
+  }];
+}
+
+
+def SPIRV_TosaReduceMaxOp : SPIRV_TosaReductionOp<"ReduceMax", 50> {
+  let summary = "Reduce Max operator.";
+
+  let description = [{
+    Reduces a tensor along the given axis with a Maximum operation.
+    NaN Propagation Mode is ignored for inputs using integer element types.
+
+    References:
+      * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_reduce_max
+      * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_reduce_max
+
+    #### Example:
+    ```mlir
+    %2 = spirv.Tosa.ReduceMax axis = 2, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<8x30x12x3xi8> -> !spirv.arm.tensor<8x30x1x3xi8>
+    %2 = spirv.Tosa.ReduceMax axis = 2, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<16x20x10xf16> -> !spirv.arm.tensor<16x20x1xf16>
+    ```
+  }];
+
+  let arguments = (ins
+    SPIRV_TensorArmAxisAttr: $axis,
+    SPIRV_TosaExtNaNPropagationModeAttr: $nan_mode,
+    SPIRV_TosaNumerical_TensorArm: $input
+  );
+
+  let results = (outs
+    SPIRV_TosaNumerical_TensorArm: $output
+  );
+
+  let assemblyFormat = [{
+    `axis` `=` $axis `,`
+    `nan_mode` `=` $nan_mode `,`
+    $input
+    attr-dict `:` type(operands) `->` type(results)
+  }];
+}
+
+
+def SPIRV_TosaReduceMinOp : SPIRV_TosaReductionOp<"ReduceMin", 51> {
+  let summary = "Reduce Min operator.";
+
+  let description = [{
+    Reduces a tensor along the given axis with a Minimum operation.
+    NaN Propagation Mode is ignored for inputs using integer element types.
+
+    References:
+      * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_reduce_min
+      * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_reduce_min
+
+    #### Example:
+    ```mlir
+    %2 = spirv.Tosa.ReduceMin axis = 2, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<2x5x5x1xi8> -> !spirv.arm.tensor<2x5x1x1xi8>
+    %2 = spirv.Tosa.ReduceMin axis = 2, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<27x10x25x9xf16> -> !spirv.arm.tensor<27x10x1x9xf16>
+    ```
+  }];
+
+  let arguments = (ins
+    SPIRV_TensorArmAxisAttr: $axis,
+    SPIRV_TosaExtNaNPropagationModeAttr: $nan_mode,
+    SPIRV_TosaNumerical_TensorArm: $input
+  );
+
+  let results = (outs
+    SPIRV_TosaNumerical_TensorArm: $output
+  );
+
+  let assemblyFormat = [{
+    `axis` `=` $axis `,`
+    `nan_mode` `=` $nan_mode `,`
+    $input
+    attr-dict `:` type(operands) `->` type(results)
+  }];
+}
+
+
+def SPIRV_TosaReduceProductOp : SPIRV_TosaReductionOp<"ReduceProduct", 52> {
+  let summary = "Reduce Product operator.";
+
+  let description = [{
+    Reduces a tensor along the given axis by computing the Product of the axis.
+
+    References:
+      * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_reduce_product
+      * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_reduce_product
+
+    #### Example:
+    ```mlir
+    %1 = spirv.Tosa.ReduceProduct axis = 2, %arg0 : !spirv.arm.tensor<2x16x25xf16> -> !spirv.arm.tensor<2x16x1xf16>
+    ```
+  }];
+
+  let arguments = (ins
+    SPIRV_TensorArmAxisAttr: $axis,
+    SPIRV_TosaFloat_TensorArm: $input
+  );
+
+  let results = (outs
+    SPIRV_TosaFloat_TensorArm: $output
+  );
+
+  let assemblyFormat = [{
+    `axis` `=` $axis `,`
+    $input
+    attr-dict `:` type(operands) `->` type(results)
+  }];
+}
+
+
+def SPIRV_TosaReduceSumOp : SPIRV_TosaReductionOp<"ReduceSum", 53> {
+  let summary = "Reduce Sum operator.";
+
+  let description = [{
+    Reduces a tensor along the given axis by computing the Sum of the axis.
+
+    References:
+      * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_reduce_sum
+      * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_reduce_sum
+
+    #### Example:
+    ```mlir
+    %1 = spirv.Tosa.ReduceSum axis = 1, %arg0 : !spirv.arm.tensor<20x24x22xi32> -> !spirv.arm.tensor<20x1x22xi32>
+    %1 = spirv.Tosa.ReduceSum axis = 1, %arg0 : !spirv.arm.tensor<32x32x33xf32> -> !spirv.arm.tensor<32x1x33xf32>
+    ```
+  }];
+
+  let arguments = (ins
+    SPIRV_TensorArmAxisAttr: $axis,
+    SPIRV_TosaNumerical_TensorArm: $input
+  );
+
+  let results = (outs
+    SPIRV_TosaNumerical_TensorArm: $output
+  );
+
+  let assemblyFormat = [{
+    `axis` `=` $axis `,`
+    $input
+    attr-dict `:` type(operands) `->` type(results)
+  }];
+}
+
+
 #endif // MLIR_DIALECT_SPIRV_IR_TOSA_OPS
diff --git a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
index 7b9c88fe60d80..28526019551fc 100644
--- a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
@@ -1564,3 +1564,123 @@ spirv.ARM.Graph @greaterequal_output_shape_not_broadcast_shape(%arg0: !spirv.arm
   %0 = spirv.Tosa.GreaterEqual %arg0, %arg1 : !spirv.arm.tensor<10x17x7x16xf32>, !spirv.arm.tensor<1x17x7x16xf32> -> !spirv.arm.tensor<1x17x7x16xi1>
   spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<1x17x7x16xi1>
 }
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.ReduceAll
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @reduceall_input_output_ranks_not_matching(%arg0: !spirv.arm.tensor<18x22x23x12xi1>) -> (!spirv.arm.tensor<18x22x12xi1>) {
+  // expected-error @+1 {{op failed to verify that all of {input, output} have same rank}}
+  %0 = spirv.Tosa.ReduceAll axis = 2, %arg0 : !spirv.arm.tensor<18x22x23x12xi1> -> !spirv.arm.tensor<18x22x12xi1>
+  spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<18x22x12xi1>
+}
+
+spirv.ARM.Graph @reduceall_axis_value_not_in_input_rank_range(%arg0: !spirv.arm.tensor<18x22x23x12xi1>) -> (!spirv.arm.tensor<18x22x23x12xi1>) {
+  // expected-error @+1 {{op failed to verify that axis attribute value should be lower than rank(input)}}
+  %0 = spirv.Tosa.ReduceAll axis = 4, %arg0 : !spirv.arm.tensor<18x22x23x12xi1> -> !spirv.arm.tensor<18x22x23x12xi1>
+  spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<18x22x23x12xi1>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.ReduceAny
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @reduceany_input_output_ranks_not_matching(%arg0: !spirv.arm.tensor<25x13x30x8xi1>) -> (!spirv.arm.tensor<25x13x8xi1>) {
+  // expected-error @+1 {{op failed to verify that all of {input, output} have same rank}}
+  %0 = spirv.Tosa.ReduceAny axis = 2, %arg0 : !spirv.arm.tensor<25x13x30x8xi1> -> !spirv.arm.tensor<25x13x8xi1>
+  spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<25x13x8xi1>
+}
+
+spirv.ARM.Graph @reduceany_axis_value_not_in_input_rank_range(%arg0: !spirv.arm.tensor<25x13x30x8xi1>) -> (!spirv.arm.tensor<25x13x30x8xi1>) {
+  // expected-error @+1 {{op failed to verify that axis attribute value should be lower than rank(input)}}
+  %0 = spirv.Tosa.ReduceAny axis = 4, %arg0 : !spirv.arm.tensor<25x13x30x8xi1> -> !spirv.arm.tensor<25x13x30x8xi1>
+  spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<25x13x30x8xi1>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.ReduceMax
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @reducemax_input_output_element_types_not_matching(%arg0: !spirv.arm.tensor<16x20x10xf16>) -> (!spirv.arm.tensor<16x20x10xf32>) {
+  // expected-error @+1 {{op failed to verify that all of {input, output} have same element type}}
+  %0 = spirv.Tosa.ReduceMax axis = 2, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<16x20x10xf16> -> !spirv.arm.tensor<16x20x10xf32>
+  spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<16x20x10xf32>
+}
+
+spirv.ARM.Graph @reducemax_input_output_ranks_not_matching(%arg0: !spirv.arm.tensor<8x30x12x3xi8>) -> (!spirv.arm.tensor<8x30x3xi8>) {
+  // expected-error @+1 {{op failed to verify that all of {input, output} have same rank}}
+  %0 = spirv.Tosa.ReduceMax axis = 2, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<8x30x12x3xi8> -> !spirv.arm.tensor<8x30x3xi8>
+  spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<8x30x3xi8>
+}
+
+spirv.ARM.Graph @reducemax_axis_value_not_in_input_rank_range(%arg0: !spirv.arm.tensor<8x30x12x3xi8>) -> (!spirv.arm.tensor<8x30x12x3xi8>) {
+  // expected-error @+1 {{op failed to verify that axis attribute value should be lower than rank(input)}}
+  %0 = spirv.Tosa.ReduceMax axis = 4, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<8x30x12x3xi8> -> !spirv.arm.tensor<8x30x12x3xi8>
+  spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<8x30x12x3xi8>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.ReduceMin
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @reducemin_input_output_element_types_not_matching(%arg0: !spirv.arm.tensor<27x10x25x9xf16>) -> (!spirv.arm.tensor<27x10x25x9xf32>) {
+  // expected-error @+1 {{op failed to verify that all of {input, output} have same element type}}
+  %0 = spirv.Tosa.ReduceMin axis = 2, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<27x10x25x9xf16> -> !spirv.arm.tensor<27x10x25x9xf32>
+  spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<27x10x25x9xf32>
+}
+
+spirv.ARM.Graph @reducemin_input_output_ranks_not_matching(%arg0: !spirv.arm.tensor<2x5x5x1xi8>) -> (!spirv.arm.tensor<2x5x1xi8>) {
+  // expected-error @+1 {{op failed to verify that all of {input, output} have same rank}}
+  %0 = spirv.Tosa.ReduceMin axis = 2, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<2x5x5x1xi8> -> !spirv.arm.tensor<2x5x1xi8>
+  spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<2x5x1xi8>
+}
+
+spirv.ARM.Graph @reducemin_axis_value_not_in_input_rank_range(%arg0: !spirv.arm.tensor<27x10x25x9xf16>) -> (!spirv.arm.tensor<27x10x25x9xf16>) {
+  // expected-error @+1 {{op failed to verify that axis attribute value should be lower than rank(input)}}
+  %0 = spirv.Tosa.ReduceMin axis = 4, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<27x10x25x9xf16> -> !spirv.arm.tensor<27x10x25x9xf16>
+  spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<27x10x25x9xf16>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.ReduceProduct
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @reduceproduct_input_output_element_types_not_matching(%arg0: !spirv.arm.tensor<2x16x25xf16>) -> (!spirv.arm.tensor<2x16x25xf32>) {
+  // expected-error @+1 {{op failed to verify that all of {input, output} have same element type}}
+  %0 = spirv.Tosa.ReduceProduct axis = 2, %arg0 : !spirv.arm.tensor<2x16x25xf16> -> !spirv.arm.tensor<2x16x25xf32>
+  spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<2x16x25xf32>
+}
+
+spirv.ARM.Graph @reduceproduct_input_output_ranks_not_matching(%arg0: !spirv.arm.tensor<2x16x25xf16>) -> (!spirv.arm.tensor<2x25xf16>) {
+  // expected-error @+1 {{op failed to verify that all of {input, output} have same rank}}
+  %0 = spirv.Tosa.ReduceProduct axis = 1, %arg0 : !spirv.arm.tensor<2x16x25xf16> -> !spirv.arm.tensor<2x25xf16>
+  spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<2x25xf16>
+}
+
+spirv.ARM.Graph @reduceproduct_axis_value_not_in_input_rank_range(%arg0: !spirv.arm.tensor<2x16x25xf16>) -> (!spirv.arm.tensor<2x16x25xf16>) {
+  // expected-error @+1 {{op failed to verify that axis attribute value should be lower than rank(input)}}
+  %0 = spirv.Tosa.ReduceProduct axis = 3, %arg0 : !spirv.arm.tensor<2x16x25xf16> -> !spirv.arm.tensor<2x16x25xf16>
+  spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<2x16x25xf16>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.ReduceSum
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @reducesum_input_output_element_types_not_matching(%arg0: !spirv.arm.tensor<32x32x33xf32>) -> (!spirv.arm.tensor<32x32x33xf16>) {
+  // expected-error @+1 {{op failed to verify that all of {input, output} have same element type}}
+  %0 = spirv.Tosa.ReduceSum axis = 1, %arg0 : !spirv.arm.tensor<32x32x33xf32> -> !spirv.arm.tensor<32x32x33xf16>
+  spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<32x32x33xf16>
+}
+
+spirv.ARM.Graph @reducesum_input_output_ranks_not_matching(%arg0: !spirv.arm.tensor<20x24x22xi32>) -> (!spirv.arm.tensor<20x22xi32>) {
+  // expected-error @+1 {{op failed to verify that all of {input, output} have same rank}}
+  %0 = spirv.Tosa.ReduceSum axis = 1, %arg0 : !spirv.arm.tensor<20x24x22xi32> -> !spirv.arm.tensor<20x22xi32>
+  spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<20x22xi32>
+}
+
+spirv.ARM.Graph @reducesum_axis_value_not_in_input_rank_range(%arg0: !spirv.arm.tensor<20x24x22xi32>) -> (!spirv.arm.tensor<20x24x22xi32>) {
+  // expected-error @+1 {{op failed to verify that axis attribute value should be lower than rank(input)}}
+  %0 = spirv.Tosa.ReduceSum axis = 3, %arg0 : !spirv.arm.tensor<20x24x22xi32> -> !spirv.arm.tensor<20x24x22xi32>
+  spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<20x24x22xi32>
+}
diff --git a/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir b/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
index a4a26cb394603..d8f495cad3b7f 100644
--- a/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
@@ -786,3 +786,102 @@ spirv.ARM.Graph @greaterequal_fp(%arg0: !spirv.arm.tensor<3x17x6x3xf32>, %arg1:
   // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<3x17x6x3xi1>
   spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<3x17x6x3xi1>
 }
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.ReduceAll - PRO-INT or PRO-FP
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @reduceall_any(%arg0: !spirv.arm.tensor<18x22x23x12xi1>) -> (!spirv.arm.tensor<18x22x1x12xi1>) {
+  // CHECK: {{%.*}} = spirv.Tosa.ReduceAll axis = 2, %arg0 : !spirv.arm.tensor<18x22x23x12xi1> -> !spirv.arm.tensor<18x22x1x12xi1>
+  %1 = spirv.Tosa.ReduceAll axis = 2, %arg0 : !spirv.arm.tensor<18x22x23x12xi1> -> !spirv.arm.tensor<18x22x1x12xi1>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<18x22x1x12xi1>
+  spirv.ARM.GraphOutputs %1 : !spirv.arm.tensor<18x22x1x12xi1>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.ReduceAny - PRO-INT or PRO-FP
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @reduceany_any(%arg0: !spirv.arm.tensor<25x13x30x8xi1>) -> (!spirv.arm.tensor<25x13x1x8xi1>) {
+  // CHECK: {{%.*}} = spirv.Tosa.ReduceAny axis = 2, %arg0 : !spirv.arm.tensor<25x13x30x8xi1> -> !spirv.arm.tensor<25x13x1x8xi1>
+  %1 = spirv.Tosa.ReduceAny axis = 2, %arg0 : !spirv.arm.tensor<25x13x30x8xi1> -> !spirv.arm.tensor<25x13x1x8xi1>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<25x13x1x8xi1>
+  spirv.ARM.GraphOutputs %1 : !spirv.arm.tensor<25x13x1x8xi1>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.ReduceMax - PRO-INT
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @reducemax_int(%arg0: !spirv.arm.tensor<8x30x12x3xi8>) -> (!spirv.arm.tensor<8x30x1x3xi8>) {
+  // CHECK: {{%.*}} = spirv.Tosa.ReduceMax axis = 2, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<8x30x12x3xi8> -> !spirv.arm.tensor<8x30x1x3xi8>
+  %2 = spirv.Tosa.ReduceMax axis = 2, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<8x30x12x3xi8> -> !spirv.arm.tensor<8x30x1x3xi8>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<8x30x1x3xi8>
+  spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<8x30x1x3xi8>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.ReduceMax - PRO-FP
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @reducemax_fp(%arg0: !spirv.arm.tensor<16x20x10xf16>) -> (!spirv.arm.tensor<16x20x1xf16>) {
+  // CHECK: {{%.*}} = spirv.Tosa.ReduceMax axis = 2, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<16x20x10xf16> -> !spirv.arm.tensor<16x20x1xf16>
+  %2 = spirv.Tosa.ReduceMax axis = 2, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<16x20x10xf16> -> !spirv.arm.tensor<16x20x1xf16>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<16x20x1xf16>
+  spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<16x20x1xf16>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.ReduceMin - PRO-INT
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @reducemin_int(%arg0: !spirv.arm.tensor<2x5x5x1xi8>) -> (!spirv.arm.tensor<2x5x1x1xi8>) {
+  // CHECK: {{%.*}} = spirv.Tosa.ReduceMin axis = 2, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<2x5x5x1xi8> -> !spirv.arm.tensor<2x5x1x1xi8>
+  %2 = spirv.Tosa.ReduceMin axis = 2, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<2x5x5x1xi8> -> !spirv.arm.tensor<2x5x1x1xi8>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<2x5x1x1xi8>
+  spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<2x5x1x1xi8>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.ReduceMin - PRO-FP
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @reducemin_fp(%arg0: !spirv.arm.tensor<27x10x25x9xf16>) -> (!spirv.arm.tensor<27x10x1x9xf16>) {
+  // CHECK: {{%.*}} = spirv.Tosa.ReduceMin axis = 2, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<27x10x25x9xf16> -> !spirv.arm.tensor<27x10x1x9xf16>
+  %2 = spirv.Tosa.ReduceMin axis = 2, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<27x10x25x9xf16> -> !spirv.arm.tensor<27x10x1x9xf16>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<27x10x1x9xf16>
+  spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<27x10x1x9xf16>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.ReduceProduct - PRO-FP
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @reduceproduct_fp(%arg0: !spirv.arm.tensor<2x16x25xf16>) -> (!spirv.arm.tensor<2x16x1xf16>) {
+  // CHECK: {{%.*}} = spirv.Tosa.ReduceProduct axis = 2, %arg0 : !spirv.arm.tensor<2x16x25xf16> -> !spirv.arm.tensor<2x16x1xf16>
+  %1 = spirv.Tosa.ReduceProduct axis = 2, %arg0 : !spirv.arm.tensor<2x16x25xf16> -> !spirv.arm.tensor<2x16x1xf16>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<2x16x1xf16>
+  spirv.ARM.GraphOutputs %1 : !spirv.arm.tensor<2x16x1xf16>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.ReduceSum - PRO-INT
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @reducesum_int(%arg0: !spirv.arm.tensor<20x24x22xi32>) -> (!spirv.arm.tensor<20x1x22xi32>) {
+  // CHECK: {{%.*}} = spirv.Tosa.ReduceSum axis = 1, %arg0 : !spirv.arm.tensor<20x24x22xi32> -> !spirv.arm.tensor<20x1x22xi32>
+  %1 = spirv.Tosa.ReduceSum axis = 1, %arg0 : !spirv.arm.tensor<20x24x22xi32> -> !spirv.arm.tensor<20x1x22xi32>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<20x1x22xi32>
+  spirv.ARM.GraphOutputs %1 : !spirv.arm.tensor<20x1x22xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.ReduceSum - PRO-FP
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @reducesum_fp(%arg0: !spirv.arm.tensor<32x32x33xf32>) -> (!spirv.arm.tensor<32x1x33xf32>) {
+  // CHECK: {{%.*}} = spirv.Tosa.ReduceSum axis = 1, %arg0 : !spirv.arm.tensor<32x32x33xf32> -> !spirv.arm.tensor<32x1x33xf32>
+  %1 = spirv.Tosa.ReduceSum axis = 1, %arg0 : !spirv.arm.tensor<32x32x33xf32> -> !spirv.arm.tensor<32x1x33xf32>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<32x1x33xf32>
+  spirv.ARM.GraphOutputs %1 : !spirv.arm.tensor<32x1x33xf32>
+}
diff --git a/mlir/test/Target/SPIRV/tosa-ops.mlir b/mlir/test/Target/SPIRV/tosa-ops.mlir
index 856fde46a4866..0176411920245 100644
--- a/mlir/test/Target/SPIRV/tosa-ops.mlir
+++ b/mlir/test/Target/SPIRV/tosa-ops.mlir
@@ -1384,3 +1384,174 @@ spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader
     spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<3x17x6x3xi1>
   }
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.ReduceAll - PRO-INT or PRO-FP
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> { // Module holding the ReduceAll case on an i1 tensor.
+  spirv.GlobalVariable @reduceall_any_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<18x22x23x12xi1>, UniformConstant> // Pointee type equals the graph argument type.
+  spirv.GlobalVariable @reduceall_any_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<18x22x1x12xi1>, UniformConstant> // Pointee type equals the graph result type.
+  spirv.ARM.GraphEntryPoint @reduceall_any, @reduceall_any_arg_0, @reduceall_any_res_0 // Names the graph below together with the two global variables above.
+  spirv.ARM.Graph @reduceall_any(%arg0: !spirv.arm.tensor<18x22x23x12xi1>) -> (!spirv.arm.tensor<18x22x1x12xi1>) {
+    // CHECK: {{%.*}} = spirv.Tosa.ReduceAll axis = 2, %arg0 : !spirv.arm.tensor<18x22x23x12xi1> -> !spirv.arm.tensor<18x22x1x12xi1>
+    %1 = spirv.Tosa.ReduceAll axis = 2, %arg0 : !spirv.arm.tensor<18x22x23x12xi1> -> !spirv.arm.tensor<18x22x1x12xi1> // axis 2 has extent 23 in the input and extent 1 in the result.
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<18x22x1x12xi1>
+    spirv.ARM.GraphOutputs %1 : !spirv.arm.tensor<18x22x1x12xi1> // The reduction result is the graph's only output.
+  }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.ReduceAny - PRO-INT or PRO-FP
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> { // Module holding the ReduceAny case on an i1 tensor.
+  spirv.GlobalVariable @reduceany_any_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<25x13x30x8xi1>, UniformConstant> // Pointee type equals the graph argument type.
+  spirv.GlobalVariable @reduceany_any_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<25x13x1x8xi1>, UniformConstant> // Pointee type equals the graph result type.
+  spirv.ARM.GraphEntryPoint @reduceany_any, @reduceany_any_arg_0, @reduceany_any_res_0 // Names the graph below together with the two global variables above.
+  spirv.ARM.Graph @reduceany_any(%arg0: !spirv.arm.tensor<25x13x30x8xi1>) -> (!spirv.arm.tensor<25x13x1x8xi1>) {
+    // CHECK: {{%.*}} = spirv.Tosa.ReduceAny axis = 2, %arg0 : !spirv.arm.tensor<25x13x30x8xi1> -> !spirv.arm.tensor<25x13x1x8xi1>
+    %1 = spirv.Tosa.ReduceAny axis = 2, %arg0 : !spirv.arm.tensor<25x13x30x8xi1> -> !spirv.arm.tensor<25x13x1x8xi1> // axis 2 has extent 30 in the input and extent 1 in the result.
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<25x13x1x8xi1>
+    spirv.ARM.GraphOutputs %1 : !spirv.arm.tensor<25x13x1x8xi1> // The reduction result is the graph's only output.
+  }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.ReduceMax - PRO-INT
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> { // Module holding the ReduceMax case on an i8 tensor.
+  spirv.GlobalVariable @reducemax_int_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<8x30x12x3xi8>, UniformConstant> // Pointee type equals the graph argument type.
+  spirv.GlobalVariable @reducemax_int_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<8x30x1x3xi8>, UniformConstant> // Pointee type equals the graph result type.
+  spirv.ARM.GraphEntryPoint @reducemax_int, @reducemax_int_arg_0, @reducemax_int_res_0 // Names the graph below together with the two global variables above.
+  spirv.ARM.Graph @reducemax_int(%arg0: !spirv.arm.tensor<8x30x12x3xi8>) -> (!spirv.arm.tensor<8x30x1x3xi8>) {
+    // CHECK: {{%.*}} = spirv.Tosa.ReduceMax axis = 2, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<8x30x12x3xi8> -> !spirv.arm.tensor<8x30x1x3xi8>
+    %2 = spirv.Tosa.ReduceMax axis = 2, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<8x30x12x3xi8> -> !spirv.arm.tensor<8x30x1x3xi8> // axis 2 goes from extent 12 to 1. NOTE(review): nan_mode is spelled out on an i8 tensor; presumably it has no effect for integers -- confirm.
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<8x30x1x3xi8>
+    spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<8x30x1x3xi8> // The reduction result is the graph's only output.
+  }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.ReduceMax - PRO-FP
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> { // Module holding the ReduceMax case on an f16 tensor.
+  spirv.GlobalVariable @reducemax_fp_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<16x20x10xf16>, UniformConstant> // Pointee type equals the graph argument type.
+  spirv.GlobalVariable @reducemax_fp_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<16x20x1xf16>, UniformConstant> // Pointee type equals the graph result type.
+  spirv.ARM.GraphEntryPoint @reducemax_fp, @reducemax_fp_arg_0, @reducemax_fp_res_0 // Names the graph below together with the two global variables above.
+  spirv.ARM.Graph @reducemax_fp(%arg0: !spirv.arm.tensor<16x20x10xf16>) -> (!spirv.arm.tensor<16x20x1xf16>) {
+    // CHECK: {{%.*}} = spirv.Tosa.ReduceMax axis = 2, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<16x20x10xf16> -> !spirv.arm.tensor<16x20x1xf16>
+    %2 = spirv.Tosa.ReduceMax axis = 2, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<16x20x10xf16> -> !spirv.arm.tensor<16x20x1xf16> // axis 2 has extent 10 in the input and extent 1 in the result; nan_mode is given explicitly as <Propagate>.
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<16x20x1xf16>
+    spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<16x20x1xf16> // The reduction result is the graph's only output.
+  }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.ReduceMin - PRO-INT
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> { // Module holding the ReduceMin case on an i8 tensor.
+  spirv.GlobalVariable @reducemin_int_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<2x5x5x1xi8>, UniformConstant> // Pointee type equals the graph argument type.
+  spirv.GlobalVariable @reducemin_int_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<2x5x1x1xi8>, UniformConstant> // Pointee type equals the graph result type.
+  spirv.ARM.GraphEntryPoint @reducemin_int, @reducemin_int_arg_0, @reducemin_int_res_0 // Names the graph below together with the two global variables above.
+  spirv.ARM.Graph @reducemin_int(%arg0: !spirv.arm.tensor<2x5x5x1xi8>) -> (!spirv.arm.tensor<2x5x1x1xi8>) {
+    // CHECK: {{%.*}} = spirv.Tosa.ReduceMin axis = 2, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<2x5x5x1xi8> -> !spirv.arm.tensor<2x5x1x1xi8>
+    %2 = spirv.Tosa.ReduceMin axis = 2, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<2x5x5x1xi8> -> !spirv.arm.tensor<2x5x1x1xi8> // axis 2 goes from extent 5 to 1. NOTE(review): nan_mode is spelled out on an i8 tensor; presumably it has no effect for integers -- confirm.
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<2x5x1x1xi8>
+    spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<2x5x1x1xi8> // The reduction result is the graph's only output.
+  }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.ReduceMin - PRO-FP
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> { // Module holding the ReduceMin case on an f16 tensor.
+  spirv.GlobalVariable @reducemin_fp_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<27x10x25x9xf16>, UniformConstant> // Pointee type equals the graph argument type.
+  spirv.GlobalVariable @reducemin_fp_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<27x10x1x9xf16>, UniformConstant> // Pointee type equals the graph result type.
+  spirv.ARM.GraphEntryPoint @reducemin_fp, @reducemin_fp_arg_0, @reducemin_fp_res_0 // Names the graph below together with the two global variables above.
+  spirv.ARM.Graph @reducemin_fp(%arg0: !spirv.arm.tensor<27x10x25x9xf16>) -> (!spirv.arm.tensor<27x10x1x9xf16>) {
+    // CHECK: {{%.*}} = spirv.Tosa.ReduceMin axis = 2, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<27x10x25x9xf16> -> !spirv.arm.tensor<27x10x1x9xf16>
+    %2 = spirv.Tosa.ReduceMin axis = 2, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<27x10x25x9xf16> -> !spirv.arm.tensor<27x10x1x9xf16> // axis 2 has extent 25 in the input and extent 1 in the result; nan_mode is given explicitly as <Propagate>.
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<27x10x1x9xf16>
+    spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<27x10x1x9xf16> // The reduction result is the graph's only output.
+  }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.ReduceProduct - PRO-FP
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> { // Module holding the ReduceProduct case on an f16 tensor.
+  spirv.GlobalVariable @reduceproduct_fp_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<2x16x25xf16>, UniformConstant> // Pointee type equals the graph argument type.
+  spirv.GlobalVariable @reduceproduct_fp_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<2x16x1xf16>, UniformConstant> // Pointee type equals the graph result type.
+  spirv.ARM.GraphEntryPoint @reduceproduct_fp, @reduceproduct_fp_arg_0, @reduceproduct_fp_res_0 // Names the graph below together with the two global variables above.
+  spirv.ARM.Graph @reduceproduct_fp(%arg0: !spirv.arm.tensor<2x16x25xf16>) -> (!spirv.arm.tensor<2x16x1xf16>) {
+    // CHECK: {{%.*}} = spirv.Tosa.ReduceProduct axis = 2, %arg0 : !spirv.arm.tensor<2x16x25xf16> -> !spirv.arm.tensor<2x16x1xf16>
+    %1 = spirv.Tosa.ReduceProduct axis = 2, %arg0 : !spirv.arm.tensor<2x16x25xf16> -> !spirv.arm.tensor<2x16x1xf16> // axis 2 has extent 25 in the input and extent 1 in the result.
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<2x16x1xf16>
+    spirv.ARM.GraphOutputs %1 : !spirv.arm.tensor<2x16x1xf16> // The reduction result is the graph's only output.
+  }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.ReduceSum - PRO-INT
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> { // Module holding the ReduceSum case on an i32 tensor.
+  spirv.GlobalVariable @reducesum_int_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<20x24x22xi32>, UniformConstant> // Pointee type equals the graph argument type.
+  spirv.GlobalVariable @reducesum_int_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<20x1x22xi32>, UniformConstant> // Pointee type equals the graph result type.
+  spirv.ARM.GraphEntryPoint @reducesum_int, @reducesum_int_arg_0, @reducesum_int_res_0 // Names the graph below together with the two global variables above.
+  spirv.ARM.Graph @reducesum_int(%arg0: !spirv.arm.tensor<20x24x22xi32>) -> (!spirv.arm.tensor<20x1x22xi32>) {
+    // CHECK: {{%.*}} = spirv.Tosa.ReduceSum axis = 1, %arg0 : !spirv.arm.tensor<20x24x22xi32> -> !spirv.arm.tensor<20x1x22xi32>
+    %1 = spirv.Tosa.ReduceSum axis = 1, %arg0 : !spirv.arm.tensor<20x24x22xi32> -> !spirv.arm.tensor<20x1x22xi32> // axis 1 has extent 24 in the input and extent 1 in the result.
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<20x1x22xi32>
+    spirv.ARM.GraphOutputs %1 : !spirv.arm.tensor<20x1x22xi32> // The reduction result is the graph's only output.
+  }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.ReduceSum - PRO-FP
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> { // Module holding the ReduceSum case on an f32 tensor.
+  spirv.GlobalVariable @reducesum_fp_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<32x32x33xf32>, UniformConstant> // Pointee type equals the graph argument type.
+  spirv.GlobalVariable @reducesum_fp_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<32x1x33xf32>, UniformConstant> // Pointee type equals the graph result type.
+  spirv.ARM.GraphEntryPoint @reducesum_fp, @reducesum_fp_arg_0, @reducesum_fp_res_0 // Names the graph below together with the two global variables above.
+  spirv.ARM.Graph @reducesum_fp(%arg0: !spirv.arm.tensor<32x32x33xf32>) -> (!spirv.arm.tensor<32x1x33xf32>) {
+    // CHECK: {{%.*}} = spirv.Tosa.ReduceSum axis = 1, %arg0 : !spirv.arm.tensor<32x32x33xf32> -> !spirv.arm.tensor<32x1x33xf32>
+    %1 = spirv.Tosa.ReduceSum axis = 1, %arg0 : !spirv.arm.tensor<32x32x33xf32> -> !spirv.arm.tensor<32x1x33xf32> // axis 1 has extent 32 in the input and extent 1 in the result.
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<32x1x33xf32>
+    spirv.ARM.GraphOutputs %1 : !spirv.arm.tensor<32x1x33xf32> // The reduction result is the graph's only output.
+  }
+}



More information about the Mlir-commits mailing list