[Mlir-commits] [mlir] [mlir][spirv] Improve type constraints for SPIR-V Tosa CastOp (PR #192227)
Davide Grohmann
llvmlistbot at llvm.org
Thu Apr 16 23:57:05 PDT 2026
https://github.com/davidegrohmann updated https://github.com/llvm/llvm-project/pull/192227
>From a818accdc74795e80ba8290f29f81905f6084c42 Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Tue, 14 Apr 2026 13:43:19 +0200
Subject: [PATCH] [mlir][spirv] Improve type constraints for SPIR-V Tosa CastOp
Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
Change-Id: I8e888c338512b9bc71dcdc78264a87bca2fe019d
---
.../mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td | 48 +++++++++++++-
.../SPIRV/IR/tosa-ops-verification.mlir | 66 +++++++++++++++++++
2 files changed, 113 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
index 2c086c9e48ebb..d72cb65a5d86c 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
@@ -2661,12 +2661,58 @@ def SPIRV_TosaResizeOp : SPIRV_TosaOpWithResult<"Resize", 63, [Pure,
def SPIRV_TosaCastOp : SPIRV_TosaOpWithResult<"Cast", 64, [Pure,
- AllShapesMatch<["input", "output"]>]> {
+ AllShapesMatch<["input", "output"]>,
+ TypeConstraintImplicationOn<"input", F16, "output", [F32, I16, I32, I8]>,
+ TypeConstraintImplicationOn<"input", F32, "output", [F16, I16, I32, I8, BF16]>,
+ TypeConstraintImplicationOn<"input", I16, "output", [F16, F32, I32, I8, SPIRV_Bool, BF16]>,
+ TypeConstraintImplicationOn<"input", I32, "output", [F16, F32, I16, I8, SPIRV_Bool, BF16]>,
+ TypeConstraintImplicationOn<"input", I8, "output", [F16, F32, I16, I32, SPIRV_Bool, BF16]>,
+ TypeConstraintImplicationOn<"input", SPIRV_Bool, "output", [I16, I32, I8]>,
+ TypeConstraintImplicationOn<"input", BF16, "output", [F32, I16, I32, I8]>]> {
let summary = "Cast operation.";
let description = [{
Casts a tensor from one data type to another.
+ Valid casting combinations are defined in the following table:
+
+ | From | To |
+ |---------|---------|
+ | float16 | float32 |
+ | float16 | int16 |
+ | float16 | int32 |
+ | float16 | int8 |
+ | float32 | float16 |
+ | float32 | int16 |
+ | float32 | int32 |
+ | float32 | int8 |
+ | int16 | float16 |
+ | int16 | float32 |
+ | int32 | float16 |
+ | int32 | float32 |
+ | int8 | float16 |
+ | int8 | float32 |
+ | Boolean | int16 |
+ | Boolean | int32 |
+ | Boolean | int8 |
+ | int16 | Boolean |
+ | int16 | int32 |
+ | int16 | int8 |
+ | int32 | Boolean |
+ | int32 | int16 |
+ | int32 | int8 |
+ | int8 | Boolean |
+ | int8 | int16 |
+ | int8 | int32 |
+ | bf16 | float32 |
+ | bf16 | int16 |
+ | bf16 | int32 |
+ | bf16 | int8 |
+ | float32 | bf16 |
+ | int16 | bf16 |
+ | int32 | bf16 |
+ | int8 | bf16 |
+
References:
* https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_cast
* https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_cast
diff --git a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
index 057aa353b4ee1..f5a8a3ca4c05d 100644
--- a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
@@ -1951,6 +1951,72 @@ spirv.ARM.Graph @cast_input_output_shapes_not_matching(%arg0: !spirv.arm.tensor<
spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<2x3x5xi32>
}
+spirv.ARM.Graph @cast_f16_to_bf16_not_supported(%input: !spirv.arm.tensor<2x3x4xf16>) -> (!spirv.arm.tensor<2x3x4xbf16>) {
+ // expected-error @+1 {{op failed to verify that if input has type 16-bit float then output must have a type in [32-bit float,16-bit signless integer,32-bit signless integer,8-bit signless integer]}}
+ %cast = spirv.Tosa.Cast %input : !spirv.arm.tensor<2x3x4xf16> -> !spirv.arm.tensor<2x3x4xbf16>
+ spirv.ARM.GraphOutputs %cast : !spirv.arm.tensor<2x3x4xbf16>
+}
+
+spirv.ARM.Graph @cast_f16_to_f16_not_supported(%input: !spirv.arm.tensor<2x3x4xf16>) -> (!spirv.arm.tensor<2x3x4xf16>) {
+ // expected-error @+1 {{op failed to verify that if input has type 16-bit float then output must have a type in [32-bit float,16-bit signless integer,32-bit signless integer,8-bit signless integer]}}
+ %cast = spirv.Tosa.Cast %input : !spirv.arm.tensor<2x3x4xf16> -> !spirv.arm.tensor<2x3x4xf16>
+ spirv.ARM.GraphOutputs %cast : !spirv.arm.tensor<2x3x4xf16>
+}
+
+spirv.ARM.Graph @cast_f32_to_f32_not_supported(%input: !spirv.arm.tensor<2x3x4xf32>) -> (!spirv.arm.tensor<2x3x4xf32>) {
+ // expected-error @+1 {{op failed to verify that if input has type 32-bit float then output must have a type in [16-bit float,16-bit signless integer,32-bit signless integer,8-bit signless integer,bfloat16 type]}}
+ %cast = spirv.Tosa.Cast %input : !spirv.arm.tensor<2x3x4xf32> -> !spirv.arm.tensor<2x3x4xf32>
+ spirv.ARM.GraphOutputs %cast : !spirv.arm.tensor<2x3x4xf32>
+}
+
+spirv.ARM.Graph @cast_i8_to_i8_not_supported(%input: !spirv.arm.tensor<2x3x4xi8>) -> (!spirv.arm.tensor<2x3x4xi8>) {
+ // expected-error @+1 {{op failed to verify that if input has type 8-bit signless integer then output must have a type in [16-bit float,32-bit float,16-bit signless integer,32-bit signless integer,bool,bfloat16 type]}}
+ %cast = spirv.Tosa.Cast %input : !spirv.arm.tensor<2x3x4xi8> -> !spirv.arm.tensor<2x3x4xi8>
+ spirv.ARM.GraphOutputs %cast : !spirv.arm.tensor<2x3x4xi8>
+}
+
+spirv.ARM.Graph @cast_i16_to_i16_not_supported(%input: !spirv.arm.tensor<2x3x4xi16>) -> (!spirv.arm.tensor<2x3x4xi16>) {
+ // expected-error @+1 {{op failed to verify that if input has type 16-bit signless integer then output must have a type in [16-bit float,32-bit float,32-bit signless integer,8-bit signless integer,bool,bfloat16 type]}}
+ %cast = spirv.Tosa.Cast %input : !spirv.arm.tensor<2x3x4xi16> -> !spirv.arm.tensor<2x3x4xi16>
+ spirv.ARM.GraphOutputs %cast : !spirv.arm.tensor<2x3x4xi16>
+}
+
+spirv.ARM.Graph @cast_i32_to_i32_not_supported(%input: !spirv.arm.tensor<2x3x4xi32>) -> (!spirv.arm.tensor<2x3x4xi32>) {
+ // expected-error @+1 {{op failed to verify that if input has type 32-bit signless integer then output must have a type in [16-bit float,32-bit float,16-bit signless integer,8-bit signless integer,bool,bfloat16 type]}}
+ %cast = spirv.Tosa.Cast %input : !spirv.arm.tensor<2x3x4xi32> -> !spirv.arm.tensor<2x3x4xi32>
+ spirv.ARM.GraphOutputs %cast : !spirv.arm.tensor<2x3x4xi32>
+}
+
+spirv.ARM.Graph @cast_bool_to_f32_not_supported(%input: !spirv.arm.tensor<2x3x4xi1>) -> (!spirv.arm.tensor<2x3x4xf32>) {
+ // expected-error @+1 {{op failed to verify that if input has type bool then output must have a type in [16-bit signless integer,32-bit signless integer,8-bit signless integer]}}
+ %cast = spirv.Tosa.Cast %input : !spirv.arm.tensor<2x3x4xi1> -> !spirv.arm.tensor<2x3x4xf32>
+ spirv.ARM.GraphOutputs %cast : !spirv.arm.tensor<2x3x4xf32>
+}
+
+spirv.ARM.Graph @cast_bool_to_f16_not_supported(%input: !spirv.arm.tensor<2x3x4xi1>) -> (!spirv.arm.tensor<2x3x4xf16>) {
+ // expected-error @+1 {{op failed to verify that if input has type bool then output must have a type in [16-bit signless integer,32-bit signless integer,8-bit signless integer]}}
+ %cast = spirv.Tosa.Cast %input : !spirv.arm.tensor<2x3x4xi1> -> !spirv.arm.tensor<2x3x4xf16>
+ spirv.ARM.GraphOutputs %cast : !spirv.arm.tensor<2x3x4xf16>
+}
+
+spirv.ARM.Graph @cast_bool_to_bf16_not_supported(%input: !spirv.arm.tensor<2x3x4xi1>) -> (!spirv.arm.tensor<2x3x4xbf16>) {
+ // expected-error @+1 {{op failed to verify that if input has type bool then output must have a type in [16-bit signless integer,32-bit signless integer,8-bit signless integer]}}
+ %cast = spirv.Tosa.Cast %input : !spirv.arm.tensor<2x3x4xi1> -> !spirv.arm.tensor<2x3x4xbf16>
+ spirv.ARM.GraphOutputs %cast : !spirv.arm.tensor<2x3x4xbf16>
+}
+
+spirv.ARM.Graph @cast_bf16_to_f16_not_supported(%input: !spirv.arm.tensor<2x3x4xbf16>) -> (!spirv.arm.tensor<2x3x4xf16>) {
+ // expected-error @+1 {{op failed to verify that if input has type bfloat16 type then output must have a type in [32-bit float,16-bit signless integer,32-bit signless integer,8-bit signless integer]}}
+ %cast = spirv.Tosa.Cast %input : !spirv.arm.tensor<2x3x4xbf16> -> !spirv.arm.tensor<2x3x4xf16>
+ spirv.ARM.GraphOutputs %cast : !spirv.arm.tensor<2x3x4xf16>
+}
+
+spirv.ARM.Graph @cast_bf16_to_bf16_not_supported(%input: !spirv.arm.tensor<2x3x4xbf16>) -> (!spirv.arm.tensor<2x3x4xbf16>) {
+ // expected-error @+1 {{op failed to verify that if input has type bfloat16 type then output must have a type in [32-bit float,16-bit signless integer,32-bit signless integer,8-bit signless integer]}}
+ %cast = spirv.Tosa.Cast %input : !spirv.arm.tensor<2x3x4xbf16> -> !spirv.arm.tensor<2x3x4xbf16>
+ spirv.ARM.GraphOutputs %cast : !spirv.arm.tensor<2x3x4xbf16>
+}
+
//===----------------------------------------------------------------------===//
// spirv.TOSA.Rescale
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list