[Mlir-commits] [mlir] [mlir][spirv] Add first 3 data layout ops in TOSA Ext Inst Set (PR #187714)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Mar 20 07:50:22 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-spirv

Author: Davide Grohmann (davidegrohmann)

<details>
<summary>Changes</summary>

This patch introduces the following data layout operators:

spirv.Tosa.Concat
spirv.Tosa.Pad
spirv.Tosa.Reshape

Dialect (parsing and verification) tests and serialization round-trip tests have also been added.

---

Patch is 30.15 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/187714.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td (+134) 
- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td (+29) 
- (modified) mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir (+72) 
- (modified) mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir (+72) 
- (modified) mlir/test/Target/SPIRV/tosa-ops.mlir (+125) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
index 77be0a680c15b..327848f510368 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
@@ -2125,4 +2125,138 @@ def SPIRV_TosaReduceSumOp : SPIRV_TosaReductionOp<"ReduceSum", 53, [NoMemoryEffe
 }
 
 
+def SPIRV_TosaConcatOp : SPIRV_TosaOpWithResult<"Concat", 54, [Pure,
+  VariadicInputWithMinSize<"input1", 1>,               // rejects a concat with no inputs
+  VariadicInputAllSameElementType<"output", "input1">, // every input has the output's element type
+  VariadicInputAllSameRank<"output", "input1">,        // every input has the output's rank
+  AxisValueLessThanRankOf<"output">]> {                // axis < rank(output)
+  let summary = "Concatenates tensors along one dimension.";
+  // NOTE(review): no trait above checks that non-axis dims agree or that the output axis dim equals the sum of the input axis dims.
+  let description = [{
+    Concatenates a list of tensors along a given axis.
+    No data conversion happens during a concat operation.
+
+    References:
+      * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_concat
+      * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_concat
+
+    #### Example:
+    ```mlir
+    %1 = spirv.Tosa.Concat axis = 2, %arg0, %arg1, %arg2, %arg3 : !spirv.arm.tensor<12x13x3x14xi8>, !spirv.arm.tensor<12x13x3x14xi8>, !spirv.arm.tensor<12x13x3x14xi8>, !spirv.arm.tensor<12x13x3x14xi8> -> !spirv.arm.tensor<12x13x12x14xi8>
+    %1 = spirv.Tosa.Concat axis = 1, %arg0, %arg1, %arg2 : !spirv.arm.tensor<40x31x19xf32>, !spirv.arm.tensor<40x15x19xf32>, !spirv.arm.tensor<40x16x19xf32> -> !spirv.arm.tensor<40x62x19xf32>
+    ```
+  }];
+  // Operands: the concat axis (an attribute) followed by one or more input tensors.
+  let arguments = (ins
+    SPIRV_TensorArmAxisAttr: $axis,
+    Variadic<SPIRV_TosaAny_TensorArm>: $input1
+  );
+
+  let results = (outs
+    SPIRV_TosaAny_TensorArm: $output
+  );
+  // Textual form: axis = N, %in0, %in1, ... : <input types> -> <result type>
+  let assemblyFormat = [{
+    `axis` `=` $axis `,`
+    $input1
+    attr-dict `:` type(operands) `->` type(results)
+  }];
+  // No op-specific accessors beyond the base declarations (input1 is variadic).
+  let extraClassDeclaration = extraBaseClassDeclaration#[{
+  }];
+}
+
+
+def SPIRV_TosaPadOp : SPIRV_TosaOpWithResult<"Pad", 55, [Pure,
+  AllElementTypesMatch<["input1", "pad_const", "output"]>,  // the fill value shares the tensors' element type
+  AllRanksMatch<["input1", "output"]>,                      // padding never changes rank
+  ShapeConstraintFromInputRank<"input1", "padding", 2>]> {  // padding holds 2 entries per input dim
+  let summary = "Pads a tensor with value specified.";
+  // NOTE(review): no trait above checks each output dim == input dim + its two padding entries; confirm that check exists elsewhere.
+  let description = [{
+    Pads a tensor along the borders of each dimension with a supplied value.
+    Returns a new tensor with the padding included. The pad_const value includes
+    the zero point if the tensor uses a zero point.
+
+    References:
+      * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_pad
+      * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_pad
+
+    #### Example:
+    ```mlir
+    %2 = spirv.Tosa.Pad %arg0, %0, %1 : !spirv.arm.tensor<4x7xi8>, !spirv.arm.tensor<4xi32>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<21x19xi8>
+    %2 = spirv.Tosa.Pad %arg0, %0, %1 : !spirv.arm.tensor<2x9x2x3xf32>, !spirv.arm.tensor<8xi32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<4x9x4x4xf32>
+    ```
+  }];
+  // padding: i32 pad amounts, 2 per input dim (2..12 entries); pad_const: one-element fill value.
+  let arguments = (ins
+    SPIRV_TosaAny_TensorArm: $input1,
+    SPIRV_Int32_1DTensorArmOfEvenLength2To12: $padding,
+    SPIRV_TosaAny_1DTensorArmOfLength1: $pad_const
+  );
+
+  let results = (outs
+    SPIRV_TosaAny_TensorArm: $output
+  );
+  // Textual form: %input, %padding, %pad_const : <operand types> -> <result type>
+  let assemblyFormat = [{
+    $input1 `,`
+    $padding `,`
+    $pad_const
+    attr-dict `:` type(operands) `->` type(results)
+  }];
+  // Typed accessor; the cast presumably cannot fail because input1 is constrained to a TensorArm type.
+  let extraClassDeclaration = extraBaseClassDeclaration#[{
+    ::mlir::spirv::TensorArmType getInput1Type() {
+      return cast<::mlir::spirv::TensorArmType>(getInput1().getType());
+    }
+  }];
+}
+
+
+def SPIRV_TosaReshapeOp : SPIRV_TosaOpWithResult<"Reshape", 56, [Pure,
+  AllElementTypesMatch<["input1", "output"]>,     // reshape never converts data
+  AllElementCountsMatch<["input1", "output"]>,    // same total number of elements
+  ShapeConstraintFromInputRank<"output", "shape">]> {  // shape has one entry per output dim
+  let summary = "Reshape operator.";
+  // NOTE(review): no trait above compares the values inside `shape` (an SSA operand) with the output dims.
+  let description = [{
+    Returns a tensor with the same type/values as the input, with a new shape
+    specified by the shape argument. Reshape may operate on tensors of any rank.
+    No data conversion happens during a reshape operation.
+
+    References:
+      * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_reshape
+      * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_reshape
+
+    #### Example:
+    ```mlir
+    %1 = spirv.Tosa.Reshape %arg0, %0 : !spirv.arm.tensor<25x6x29x35xi16>, !spirv.arm.tensor<4xi32> -> !spirv.arm.tensor<125x6x7x29xi16>
+    %1 = spirv.Tosa.Reshape %arg0, %0 : !spirv.arm.tensor<1x2x7x2xf32>, !spirv.arm.tensor<3xi32> -> !spirv.arm.tensor<2x1x14xf32>
+    ```
+  }];
+  // shape: i32 target dims, 1..6 entries.
+  let arguments = (ins
+    SPIRV_TosaAny_TensorArm: $input1,
+    SPIRV_Int32_1DTensorArmOfLength1To6: $shape
+  );
+
+  let results = (outs
+    SPIRV_TosaAny_TensorArm: $output
+  );
+  // Textual form: %input, %shape : <operand types> -> <result type>
+  let assemblyFormat = [{
+    $input1 `,`
+    $shape
+    attr-dict `:` type(operands) `->` type(results)
+  }];
+  // Typed accessor; the cast presumably cannot fail because input1 is constrained to a TensorArm type.
+  let extraClassDeclaration = extraBaseClassDeclaration#[{
+    ::mlir::spirv::TensorArmType getInput1Type() {
+      return cast<::mlir::spirv::TensorArmType>(getInput1().getType());
+    }
+  }];
+}
+
+
 #endif // MLIR_DIALECT_SPIRV_IR_TOSA_OPS
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
index 8d862dd87c12a..4b8f18b4cf695 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
@@ -65,6 +65,9 @@ class SPIRV_1DTensorArmOfLengthAndType<list<int> allowedLengths, list<Type> allo
     "rank 1 tensorArm of length " # !interleave(allowedLengths, "/"),
     "::mlir::spirv::TensorArmType">;
 
+def SPIRV_Int32_1DTensorArmOfLength1To6 : SPIRV_1DTensorArmOfLengthAndType<[1,2,3,4,5,6], [SPIRV_Int32]>;  // used by Reshape's `shape`: 1..6 i32 entries
+def SPIRV_Int32_1DTensorArmOfEvenLength2To12 : SPIRV_1DTensorArmOfLengthAndType<[2,4,6,8,10,12], [SPIRV_Int32]>;  // used by Pad's `padding`: an even count of 2..12 i32 entries
+
 def SPIRV_DenseElementAttrsWithTensorArmType : AttrConstraint<
   CPred<"::llvm::isa<::mlir::spirv::TensorArmType>(::llvm::cast<::mlir::DenseElementsAttr>($_self).getType())">,
   "Attr with type = spirv::TensorArmType">;
@@ -77,6 +80,7 @@ def SPIRV_Int32_1DTensorArmOfLength6Attr : ConfinedAttr<RankedI32ElementsAttr<[6
 
 def SPIRV_Int8_1DTensorArmOfLength1 : SPIRV_1DTensorArmOfLengthAndType<[1], [SPIRV_Int8]>;
 def SPIRV_TosaNumerical_1DTensorArmOfLength1 : SPIRV_1DTensorArmOfLengthAndType<[1], [SPIRV_TosaNumerical]>;
+def SPIRV_TosaAny_1DTensorArmOfLength1 : SPIRV_1DTensorArmOfLengthAndType<[1], [SPIRV_TosaAny]>;  // used by Pad's `pad_const`: a single element of any TOSA type
 
 // Struct type
 
@@ -139,4 +143,29 @@ class TableSizeConstraint<string input, Type type, int size>:
       Implies<ElementTypeIsPred<input, type>, [CPred<"::llvm::cast<::mlir::ShapedType>(getTable().getType()).getShape()[0] == " # size>]>
     >;
 
+class ShapeConstraintFromInputRank<string input, string other, int mul=1>:  // requires numElements(other) == mul * rank(input)
+  PredOpTrait< "the number of elements of " # other # " must be rank(" # input # ")" # !if(!eq(mul, 1), "", " * " # mul),
+      Implies<CPred<HasRank<input>.result>,  // assumed Implies<A, [B]> means A -> B, so an unranked `input` is not checked
+        [CPred<ElementCount<other>.result # " == " # mul # " * " # Rank<input>.result>]>  // e.g. Pad: 2 * rank(input1); Reshape: 1 * rank(output)
+    >;
+// Requires the variadic operand `input` to carry at least `min_size` values.
+class VariadicInputWithMinSize<string input, int min_size>:
+    PredOpTrait<"variadic " # input # " must have at least " # min_size # " element" # !if(!eq(min_size, 1), "", "s"),
+      CPred<"static_cast<int64_t>($" # input # ".getTypes().size()) >= " # min_size>>;
+// Requires every value of variadic `input` to have the same element type as `reference`.
+class VariadicInputAllSameElementType<string reference, string input>:
+    PredOpTrait<"all elements of variadic " # input # " must have same element type",
+      CPred<"::llvm::all_of($" # input # ".getTypes(), "
+                     "[&](::mlir::Type t) { return ::llvm::cast<::mlir::ShapedType>(t).getElementType() == "
+                     # ElementType<reference>.result # "; })">>;
+// Requires every value of variadic `input` to be ranked, with the same rank as `reference` (which must itself be ranked).
+class VariadicInputAllSameRank<string reference, string input>:
+    PredOpTrait<"all elements of variadic " # input # " must have same rank",
+      CPred<"::llvm::all_of($" # input # ".getTypes(), "
+                     "[&](::mlir::Type t) { return ::llvm::cast<::mlir::ShapedType>(t).hasRank() && "
+                     # HasRank<reference>.result #
+                     " && ::llvm::cast<::mlir::ShapedType>(t).getRank() == "
+                     # Rank<reference>.result # "; })">>;
+
+
 #endif // MLIR_DIALECT_SPIRV_IR_TOSA_TYPES
diff --git a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
index 28526019551fc..7341d7026a849 100644
--- a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
@@ -1684,3 +1684,75 @@ spirv.ARM.Graph @reducesum_axis_value_not_in_input_rank_range(%arg0: !spirv.arm.
   %0 = spirv.Tosa.ReduceSum axis = 3, %arg0 : !spirv.arm.tensor<20x24x22xi32> -> !spirv.arm.tensor<20x24x22xi32>
   spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<20x24x22xi32>
 }
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Concat
+//===----------------------------------------------------------------------===//
+// Zero inputs, written in generic op form.
+spirv.ARM.Graph @concat_must_have_at_least_one_input() -> (!spirv.arm.tensor<4x12xi8>) {
+  // expected-error @+1 {{op failed to verify that variadic input1 must have at least 1 element}}
+  %0 = "spirv.Tosa.Concat"() <{axis = 0 : i32}> : () -> !spirv.arm.tensor<4x12xi8>
+  spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<4x12xi8>
+}
+// Inputs are i8 but the output is i16.
+spirv.ARM.Graph @concat_input_output_element_types_not_matching(%arg0: !spirv.arm.tensor<4x5xi8>, %arg1: !spirv.arm.tensor<4x7xi8>) -> (!spirv.arm.tensor<4x12xi16>) {
+  // expected-error @+1 {{op failed to verify that all elements of variadic input1 must have same element type}}
+  %0 = spirv.Tosa.Concat axis = 1, %arg0, %arg1 : !spirv.arm.tensor<4x5xi8>, !spirv.arm.tensor<4x7xi8> -> !spirv.arm.tensor<4x12xi16>
+  spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<4x12xi16>
+}
+// Element types all match (i8), but the inputs are rank 2 while the output is rank 3.
+spirv.ARM.Graph @concat_input_output_ranks_not_matching(%arg0: !spirv.arm.tensor<4x5xi8>, %arg1: !spirv.arm.tensor<4x7xi8>) -> (!spirv.arm.tensor<4x12x1xi8>) {
+  // expected-error @+1 {{op failed to verify that all elements of variadic input1 must have same rank}}
+  %0 = spirv.Tosa.Concat axis = 1, %arg0, %arg1 : !spirv.arm.tensor<4x5xi8>, !spirv.arm.tensor<4x7xi8> -> !spirv.arm.tensor<4x12x1xi8>
+  spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<4x12x1xi8>
+}
+// axis = 2 is out of range for the rank-2 output.
+spirv.ARM.Graph @concat_axis_value_not_in_output_rank_range(%arg0: !spirv.arm.tensor<4x5xi8>, %arg1: !spirv.arm.tensor<4x7xi8>) -> (!spirv.arm.tensor<4x12xi8>) {
+  // expected-error @+1 {{op failed to verify that axis attribute value should be lower than rank(output)}}
+  %0 = spirv.Tosa.Concat axis = 2, %arg0, %arg1 : !spirv.arm.tensor<4x5xi8>, !spirv.arm.tensor<4x7xi8> -> !spirv.arm.tensor<4x12xi8>
+  spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<4x12xi8>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Pad
+//===----------------------------------------------------------------------===//
+// pad_const is i16 while input1 and output are i8.
+spirv.ARM.Graph @pad_input_pad_const_output_element_types_not_matching(%arg0: !spirv.arm.tensor<4x7xi8>, %arg1: !spirv.arm.tensor<4xi32>, %arg2: !spirv.arm.tensor<1xi16>) -> (!spirv.arm.tensor<5x8xi8>) {
+  // expected-error @+1 {{op failed to verify that all of {input1, pad_const, output} have same element type}}
+  %0 = spirv.Tosa.Pad %arg0, %arg1, %arg2 : !spirv.arm.tensor<4x7xi8>, !spirv.arm.tensor<4xi32>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<5x8xi8>
+  spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<5x8xi8>
+}
+// input1 is rank 2 but the output is rank 3.
+spirv.ARM.Graph @pad_input_output_ranks_not_matching(%arg0: !spirv.arm.tensor<4x7xi8>, %arg1: !spirv.arm.tensor<4xi32>, %arg2: !spirv.arm.tensor<1xi8>) -> (!spirv.arm.tensor<1x5x8xi8>) {
+  // expected-error @+1 {{op failed to verify that all of {input1, output} have same rank}}
+  %0 = spirv.Tosa.Pad %arg0, %arg1, %arg2 : !spirv.arm.tensor<4x7xi8>, !spirv.arm.tensor<4xi32>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x5x8xi8>
+  spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<1x5x8xi8>
+}
+// A rank-2 input needs 2 * 2 = 4 padding entries, not 6 (6 still passes the even-length operand constraint).
+spirv.ARM.Graph @pad_padding_element_count_not_twice_input_rank(%arg0: !spirv.arm.tensor<4x7xi8>, %arg1: !spirv.arm.tensor<6xi32>, %arg2: !spirv.arm.tensor<1xi8>) -> (!spirv.arm.tensor<5x8xi8>) {
+  // expected-error @+1 {{op failed to verify that the number of elements of padding must be rank(input1) * 2}}
+  %0 = spirv.Tosa.Pad %arg0, %arg1, %arg2 : !spirv.arm.tensor<4x7xi8>, !spirv.arm.tensor<6xi32>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<5x8xi8>
+  spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<5x8xi8>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Reshape
+//===----------------------------------------------------------------------===//
+// An i8 input reshaped to an i16 output.
+spirv.ARM.Graph @reshape_input_output_element_types_not_matching(%arg0: !spirv.arm.tensor<2x3x4xi8>, %arg1: !spirv.arm.tensor<2xi32>) -> (!spirv.arm.tensor<6x4xi16>) {
+  // expected-error @+1 {{op failed to verify that all of {input1, output} have same element type}}
+  %0 = spirv.Tosa.Reshape %arg0, %arg1 : !spirv.arm.tensor<2x3x4xi8>, !spirv.arm.tensor<2xi32> -> !spirv.arm.tensor<6x4xi16>
+  spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<6x4xi16>
+}
+// 2*3*4 = 24 input elements vs 5*4 = 20 output elements.
+spirv.ARM.Graph @reshape_input_output_element_counts_not_matching(%arg0: !spirv.arm.tensor<2x3x4xi8>, %arg1: !spirv.arm.tensor<2xi32>) -> (!spirv.arm.tensor<5x4xi8>) {
+  // expected-error @+1 {{op failed to verify that all of {input1, output} have same element count}}
+  %0 = spirv.Tosa.Reshape %arg0, %arg1 : !spirv.arm.tensor<2x3x4xi8>, !spirv.arm.tensor<2xi32> -> !spirv.arm.tensor<5x4xi8>
+  spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<5x4xi8>
+}
+// shape has 4 entries but the output is rank 2.
+spirv.ARM.Graph @reshape_shape_element_count_not_output_rank(%arg0: !spirv.arm.tensor<2x3x4xi8>, %arg1: !spirv.arm.tensor<4xi32>) -> (!spirv.arm.tensor<6x4xi8>) {
+  // expected-error @+1 {{op failed to verify that the number of elements of shape must be rank(output)}}
+  %0 = spirv.Tosa.Reshape %arg0, %arg1 : !spirv.arm.tensor<2x3x4xi8>, !spirv.arm.tensor<4xi32> -> !spirv.arm.tensor<6x4xi8>
+  spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<6x4xi8>
+}
diff --git a/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir b/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
index d8f495cad3b7f..1a10605c83f69 100644
--- a/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
@@ -885,3 +885,75 @@ spirv.ARM.Graph @reducesum_fp(%arg0: !spirv.arm.tensor<32x32x33xf32>) -> (!spirv
   // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<32x1x33xf32>
   spirv.ARM.GraphOutputs %1 : !spirv.arm.tensor<32x1x33xf32>
 }
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Concat - PRO-INT
+//===----------------------------------------------------------------------===//
+// Four 12x13x3x14 inputs joined along axis 2: 4 * 3 = 12.
+spirv.ARM.Graph @concat_int(%arg0: !spirv.arm.tensor<12x13x3x14xi8>, %arg1: !spirv.arm.tensor<12x13x3x14xi8>, %arg2: !spirv.arm.tensor<12x13x3x14xi8>, %arg3: !spirv.arm.tensor<12x13x3x14xi8>) -> (!spirv.arm.tensor<12x13x12x14xi8>) {
+  // CHECK: {{%.*}} = spirv.Tosa.Concat axis = 2, %arg0, %arg1, %arg2, %arg3 : !spirv.arm.tensor<12x13x3x14xi8>, !spirv.arm.tensor<12x13x3x14xi8>, !spirv.arm.tensor<12x13x3x14xi8>, !spirv.arm.tensor<12x13x3x14xi8> -> !spirv.arm.tensor<12x13x12x14xi8>
+  %1 = spirv.Tosa.Concat axis = 2, %arg0, %arg1, %arg2, %arg3 : !spirv.arm.tensor<12x13x3x14xi8>, !spirv.arm.tensor<12x13x3x14xi8>, !spirv.arm.tensor<12x13x3x14xi8>, !spirv.arm.tensor<12x13x3x14xi8> -> !spirv.arm.tensor<12x13x12x14xi8>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<12x13x12x14xi8>
+  spirv.ARM.GraphOutputs %1 : !spirv.arm.tensor<12x13x12x14xi8>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Concat - PRO-FP
+//===----------------------------------------------------------------------===//
+// Axis-1 sizes 31 + 15 + 16 = 62; the other dims (40, 19) match.
+spirv.ARM.Graph @concat_fp(%arg0: !spirv.arm.tensor<40x31x19xf32>, %arg1: !spirv.arm.tensor<40x15x19xf32>, %arg2: !spirv.arm.tensor<40x16x19xf32>) -> (!spirv.arm.tensor<40x62x19xf32>) {
+  // CHECK: {{%.*}} = spirv.Tosa.Concat axis = 1, %arg0, %arg1, %arg2 : !spirv.arm.tensor<40x31x19xf32>, !spirv.arm.tensor<40x15x19xf32>, !spirv.arm.tensor<40x16x19xf32> -> !spirv.arm.tensor<40x62x19xf32>
+  %1 = spirv.Tosa.Concat axis = 1, %arg0, %arg1, %arg2 : !spirv.arm.tensor<40x31x19xf32>, !spirv.arm.tensor<40x15x19xf32>, !spirv.arm.tensor<40x16x19xf32> -> !spirv.arm.tensor<40x62x19xf32>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<40x62x19xf32>
+  spirv.ARM.GraphOutputs %1 : !spirv.arm.tensor<40x62x19xf32>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Pad - PRO-INT
+//===----------------------------------------------------------------------===//
+// padding [10, 7, 6, 6] grows 4x7 to (4+10+7)x(7+6+6) = 21x19.
+spirv.ARM.Graph @pad_int(%arg0: !spirv.arm.tensor<4x7xi8>) -> (!spirv.arm.tensor<21x19xi8>) {
+  %0 = spirv.Constant dense<[10, 7, 6, 6]> : !spirv.arm.tensor<4xi32>
+  %1 = spirv.Constant dense<-76> : !spirv.arm.tensor<1xi8>
+  // CHECK: {{%.*}} = spirv.Tosa.Pad %arg0, {{%.*}}, {{%.*}} : !spirv.arm.tensor<4x7xi8>, !spirv.arm.tensor<4xi32>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<21x19xi8>
+  %2 = spirv.Tosa.Pad %arg0, %0, %1 : !spirv.arm.tensor<4x7xi8>, !spirv.arm.tensor<4xi32>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<21x19xi8>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<21x19xi8>
+  spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<21x19xi8>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Pad - PRO-FP
+//===----------------------------------------------------------------------===//
+// padding [1, 1, 0, 0, 1, 1, 0, 1] grows 2x9x2x3 to (2+2)x9x(2+2)x(3+1) = 4x9x4x4.
+spirv.ARM.Graph @pad_fp(%arg0: !spirv.arm.tensor<2x9x2x3xf32>) -> (!spirv.arm.tensor<4x9x4x4xf32>) {
+  %0 = spirv.Constant dense<[1, 1, 0, 0, 1, 1, 0, 1]> : !spirv.arm.tensor<8xi32>
+  %1 = spirv.Constant dense<1.21630913E+38> : !spirv.arm.tensor<1xf32>
+  // CHECK: {{%.*}} = spirv.Tosa.Pad %arg0, {{%.*}}, {{%.*}} : !spirv.arm.tensor<2x9x2x3xf32>, !spirv.arm.tensor<8xi32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<4x9x4x4xf32>
+  %2 = spirv.Tosa.Pad %arg0, %0, %1 : !spirv.arm.tensor<2x9x2x3xf32>, !spirv.arm.tensor<8xi32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<4x9x4x4xf32>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<4x9x4x4xf32>
+  spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<4x9x4x4xf32>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Reshape - PRO-INT
+//===----------------------------------------------------------------------===//
+// 25*6*29*35 == 125*6*7*29 == 152250 elements.
+spirv.ARM.Graph @reshape_int(%arg0: !spirv.arm.tensor<25x6x29x35xi16>) -> (!spirv.arm.tensor<125x6x7x29xi16>) {
+  %0 = spirv.Constant dense<[125, 6, 7, 29]> : !spirv.arm.tensor<4xi32>
+  // CHECK: {{%.*}} = spirv.Tosa.Reshape %arg0, {{%.*}} : !spirv.arm.tensor<25x6x29x35xi16>, !spirv.arm.tensor<4xi32> -> !spirv.arm.tensor<125x6x7x29xi16>
+  %1 = spirv.Tosa.Reshape %arg0, %0 : !spirv.arm.tensor<25x6x29x35xi16>, !spirv.arm.tensor<4xi32> -> !spirv.arm.tensor<125x6x7x29xi16>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<125x6x7x29xi16>
+  spirv.ARM.GraphOutputs %1 : !spirv.arm.tensor<125x6x7x29xi16>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Reshape - PRO-FP
+//===----------------------------------------------------------------------===//
+// 1*2*7*2 == 2*1*14 == 28 elements.
+spirv.ARM.Graph @reshape_fp(%arg0: !spirv.arm.tensor<1x2x7x2xf32>) -> (!spirv.arm.tensor<2x1x14xf32>) {
+  %0 = spirv.Constant dense<[2, 1, 14]> : !spirv.arm.tensor<3xi32>
+  // CHECK: {{%.*}} = spirv.Tosa.Reshape %arg0, {{%.*}} : !spirv.arm.tensor<1x2x7x2xf32>, !spirv.arm.tensor<3xi32> -> !spirv.arm.tensor<2x1x14xf32>
+  %1 = spirv.Tosa.Reshape %arg0, %0 : !spirv.arm.tensor<1x2x7x2xf32>, !spirv.arm.tensor<3xi32> -> !spirv.arm.tensor<2x1x14xf32>
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<2x1x14xf32>
+  spirv.ARM.GraphOutputs %1 : !spirv.arm.tensor<2x1x14xf32>
+}
diff --git a/mlir/test/Target/SPIRV/tosa-ops.mlir b/mlir/test/Target/SPIRV/tosa-ops.mlir
index 0176411920245..8119cbb6a220d 100644
--- a/mlir/test/Target/SPIRV/tosa-ops.mlir
+++ b/mlir/test/Target/SPIRV/tosa-ops.mlir
@@ -1555,3 +1555,128 @@ spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader
     spirv.ARM.GraphOutputs %1 : !spirv.arm.tensor<32x1x33xf32>
   }
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Concat - PRO-INT
+//===---------------...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/187714


More information about the Mlir-commits mailing list