[Mlir-commits] [mlir] [mlir][spirv] Improve type constraints for SPIR-V Tosa CastOp (PR #192227)

Davide Grohmann llvmlistbot at llvm.org
Thu Apr 16 23:57:05 PDT 2026


https://github.com/davidegrohmann updated https://github.com/llvm/llvm-project/pull/192227

>From a818accdc74795e80ba8290f29f81905f6084c42 Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Tue, 14 Apr 2026 13:43:19 +0200
Subject: [PATCH] [mlir][spirv] Improve type constraints for SPIR-V Tosa CastOp

Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
Change-Id: I8e888c338512b9bc71dcdc78264a87bca2fe019d
---
 .../mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td     | 48 +++++++++++++-
 .../SPIRV/IR/tosa-ops-verification.mlir       | 66 +++++++++++++++++++
 2 files changed, 113 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
index 2c086c9e48ebb..d72cb65a5d86c 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
@@ -2661,12 +2661,58 @@ def SPIRV_TosaResizeOp : SPIRV_TosaOpWithResult<"Resize", 63, [Pure,
 
 
 def SPIRV_TosaCastOp : SPIRV_TosaOpWithResult<"Cast", 64, [Pure,
-  AllShapesMatch<["input", "output"]>]> {
+  AllShapesMatch<["input", "output"]>,
+  TypeConstraintImplicationOn<"input", F16, "output", [F32, I16, I32, I8]>,
+  TypeConstraintImplicationOn<"input", F32, "output", [F16, I16, I32, I8, BF16]>,
+  TypeConstraintImplicationOn<"input", I16, "output", [F16, F32, I32, I8, SPIRV_Bool, BF16]>,
+  TypeConstraintImplicationOn<"input", I32, "output", [F16, F32, I16, I8, SPIRV_Bool, BF16]>,
+  TypeConstraintImplicationOn<"input", I8, "output", [F16, F32, I16, I32, SPIRV_Bool, BF16]>,
+  TypeConstraintImplicationOn<"input", SPIRV_Bool, "output", [I16, I32, I8]>,
+  TypeConstraintImplicationOn<"input", BF16, "output", [F32, I16, I32, I8]>]> {
   let summary = "Cast operation.";
 
   let description = [{
     Casts a tensor from one data type to another.
 
+    Valid casting combinations are defined in the following table:
+
+    | From    | To      |
+    |---------|---------|
+    | float16 | float32 |
+    | float16 | int16   |
+    | float16 | int32   |
+    | float16 | int8    |
+    | float32 | float16 |
+    | float32 | int16   |
+    | float32 | int32   |
+    | float32 | int8    |
+    | int16   | float16 |
+    | int16   | float32 |
+    | int32   | float16 |
+    | int32   | float32 |
+    | int8    | float16 |
+    | int8    | float32 |
+    | Boolean | int16   |
+    | Boolean | int32   |
+    | Boolean | int8    |
+    | int16   | Boolean |
+    | int16   | int32   |
+    | int16   | int8    |
+    | int32   | Boolean |
+    | int32   | int16   |
+    | int32   | int8    |
+    | int8    | Boolean |
+    | int8    | int16   |
+    | int8    | int32   |
+    | bf16    | float32 |
+    | bf16    | int16   |
+    | bf16    | int32   |
+    | bf16    | int8    |
+    | float32 | bf16    |
+    | int16   | bf16    |
+    | int32   | bf16    |
+    | int8    | bf16    |
+
     References:
       * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_cast
       * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_cast
diff --git a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
index 057aa353b4ee1..f5a8a3ca4c05d 100644
--- a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
@@ -1951,6 +1951,72 @@ spirv.ARM.Graph @cast_input_output_shapes_not_matching(%arg0: !spirv.arm.tensor<
   spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<2x3x5xi32>
 }
 
+spirv.ARM.Graph @cast_f16_to_bf16_not_supported(%input: !spirv.arm.tensor<2x3x4xf16>) -> (!spirv.arm.tensor<2x3x4xbf16>) {
+  // expected-error @+1 {{op failed to verify that if input has type 16-bit float then output must have a type in [32-bit float,16-bit signless integer,32-bit signless integer,8-bit signless integer]}}
+  %cast = spirv.Tosa.Cast %input : !spirv.arm.tensor<2x3x4xf16> -> !spirv.arm.tensor<2x3x4xbf16>
+  spirv.ARM.GraphOutputs %cast : !spirv.arm.tensor<2x3x4xbf16>
+}
+
+spirv.ARM.Graph @cast_f16_to_f16_not_supported(%input: !spirv.arm.tensor<2x3x4xf16>) -> (!spirv.arm.tensor<2x3x4xf16>) {
+  // expected-error @+1 {{op failed to verify that if input has type 16-bit float then output must have a type in [32-bit float,16-bit signless integer,32-bit signless integer,8-bit signless integer]}}
+  %cast = spirv.Tosa.Cast %input : !spirv.arm.tensor<2x3x4xf16> -> !spirv.arm.tensor<2x3x4xf16>
+  spirv.ARM.GraphOutputs %cast : !spirv.arm.tensor<2x3x4xf16>
+}
+
+spirv.ARM.Graph @cast_f32_to_f32_not_supported(%input: !spirv.arm.tensor<2x3x4xf32>) -> (!spirv.arm.tensor<2x3x4xf32>) {
+  // expected-error @+1 {{op failed to verify that if input has type 32-bit float then output must have a type in [16-bit float,16-bit signless integer,32-bit signless integer,8-bit signless integer,bfloat16 type]}}
+  %cast = spirv.Tosa.Cast %input : !spirv.arm.tensor<2x3x4xf32> -> !spirv.arm.tensor<2x3x4xf32>
+  spirv.ARM.GraphOutputs %cast : !spirv.arm.tensor<2x3x4xf32>
+}
+
+spirv.ARM.Graph @cast_i8_to_i8_not_supported(%input: !spirv.arm.tensor<2x3x4xi8>) -> (!spirv.arm.tensor<2x3x4xi8>) {
+  // expected-error @+1 {{op failed to verify that if input has type 8-bit signless integer then output must have a type in [16-bit float,32-bit float,16-bit signless integer,32-bit signless integer,bool,bfloat16 type]}}
+  %cast = spirv.Tosa.Cast %input : !spirv.arm.tensor<2x3x4xi8> -> !spirv.arm.tensor<2x3x4xi8>
+  spirv.ARM.GraphOutputs %cast : !spirv.arm.tensor<2x3x4xi8>
+}
+
+spirv.ARM.Graph @cast_i16_to_i16_not_supported(%input: !spirv.arm.tensor<2x3x4xi16>) -> (!spirv.arm.tensor<2x3x4xi16>) {
+  // expected-error @+1 {{op failed to verify that if input has type 16-bit signless integer then output must have a type in [16-bit float,32-bit float,32-bit signless integer,8-bit signless integer,bool,bfloat16 type]}}
+  %cast = spirv.Tosa.Cast %input : !spirv.arm.tensor<2x3x4xi16> -> !spirv.arm.tensor<2x3x4xi16>
+  spirv.ARM.GraphOutputs %cast : !spirv.arm.tensor<2x3x4xi16>
+}
+
+spirv.ARM.Graph @cast_i32_to_i32_not_supported(%input: !spirv.arm.tensor<2x3x4xi32>) -> (!spirv.arm.tensor<2x3x4xi32>) {
+  // expected-error @+1 {{op failed to verify that if input has type 32-bit signless integer then output must have a type in [16-bit float,32-bit float,16-bit signless integer,8-bit signless integer,bool,bfloat16 type]}}
+  %cast = spirv.Tosa.Cast %input : !spirv.arm.tensor<2x3x4xi32> -> !spirv.arm.tensor<2x3x4xi32>
+  spirv.ARM.GraphOutputs %cast : !spirv.arm.tensor<2x3x4xi32>
+}
+
+spirv.ARM.Graph @cast_bool_to_f32_not_supported(%input: !spirv.arm.tensor<2x3x4xi1>) -> (!spirv.arm.tensor<2x3x4xf32>) {
+  // expected-error @+1 {{op failed to verify that if input has type bool then output must have a type in [16-bit signless integer,32-bit signless integer,8-bit signless integer]}}
+  %cast = spirv.Tosa.Cast %input : !spirv.arm.tensor<2x3x4xi1> -> !spirv.arm.tensor<2x3x4xf32>
+  spirv.ARM.GraphOutputs %cast : !spirv.arm.tensor<2x3x4xf32>
+}
+
+spirv.ARM.Graph @cast_bool_to_f16_not_supported(%input: !spirv.arm.tensor<2x3x4xi1>) -> (!spirv.arm.tensor<2x3x4xf16>) {
+  // expected-error @+1 {{op failed to verify that if input has type bool then output must have a type in [16-bit signless integer,32-bit signless integer,8-bit signless integer]}}
+  %cast = spirv.Tosa.Cast %input : !spirv.arm.tensor<2x3x4xi1> -> !spirv.arm.tensor<2x3x4xf16>
+  spirv.ARM.GraphOutputs %cast : !spirv.arm.tensor<2x3x4xf16>
+}
+
+spirv.ARM.Graph @cast_bool_to_bf16_not_supported(%input: !spirv.arm.tensor<2x3x4xi1>) -> (!spirv.arm.tensor<2x3x4xbf16>) {
+  // expected-error @+1 {{op failed to verify that if input has type bool then output must have a type in [16-bit signless integer,32-bit signless integer,8-bit signless integer]}}
+  %cast = spirv.Tosa.Cast %input : !spirv.arm.tensor<2x3x4xi1> -> !spirv.arm.tensor<2x3x4xbf16>
+  spirv.ARM.GraphOutputs %cast : !spirv.arm.tensor<2x3x4xbf16>
+}
+
+spirv.ARM.Graph @cast_bf16_to_f16_not_supported(%input: !spirv.arm.tensor<2x3x4xbf16>) -> (!spirv.arm.tensor<2x3x4xf16>) {
+  // expected-error @+1 {{op failed to verify that if input has type bfloat16 type then output must have a type in [32-bit float,16-bit signless integer,32-bit signless integer,8-bit signless integer]}}
+  %cast = spirv.Tosa.Cast %input : !spirv.arm.tensor<2x3x4xbf16> -> !spirv.arm.tensor<2x3x4xf16>
+  spirv.ARM.GraphOutputs %cast : !spirv.arm.tensor<2x3x4xf16>
+}
+
+spirv.ARM.Graph @cast_bf16_to_bf16_not_supported(%input: !spirv.arm.tensor<2x3x4xbf16>) -> (!spirv.arm.tensor<2x3x4xbf16>) {
+  // expected-error @+1 {{op failed to verify that if input has type bfloat16 type then output must have a type in [32-bit float,16-bit signless integer,32-bit signless integer,8-bit signless integer]}}
+  %cast = spirv.Tosa.Cast %input : !spirv.arm.tensor<2x3x4xbf16> -> !spirv.arm.tensor<2x3x4xbf16>
+  spirv.ARM.GraphOutputs %cast : !spirv.arm.tensor<2x3x4xbf16>
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.TOSA.Rescale
 //===----------------------------------------------------------------------===//



More information about the Mlir-commits mailing list