[Mlir-commits] [mlir] 2191f5a - [MLIR][TOSA] Add missing SameOperandsAndResultShape Trait to tosa.cast (#153826)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Aug 26 02:03:29 PDT 2025


Author: Jonas Rickert
Date: 2025-08-26T10:03:26+01:00
New Revision: 2191f5a3c1c79fa3c61e85f01a9893360add2a6d

URL: https://github.com/llvm/llvm-project/commit/2191f5a3c1c79fa3c61e85f01a9893360add2a6d
DIFF: https://github.com/llvm/llvm-project/commit/2191f5a3c1c79fa3c61e85f01a9893360add2a6d.diff

LOG: [MLIR][TOSA] Add missing SameOperandsAndResultShape Trait to tosa.cast (#153826)

According to the TOSA spec, tosa.cast only changes the element type of
the input tensor, not its shape.

Signed-off-by: Rickert, Jonas <jonas.rickert at amd.com>

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/test/Dialect/Tosa/canonicalize.mlir
    mlir/test/Dialect/Tosa/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 20889558be314..416df6e87b11f 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2247,7 +2247,7 @@ def Tosa_ResizeOp : Tosa_InferShapedTypeOp<"resize"> {
 //===----------------------------------------------------------------------===//
 // Operator: cast
 //===----------------------------------------------------------------------===//
-def Tosa_CastOp: Tosa_Op<"cast", [Pure,
+def Tosa_CastOp: Tosa_Op<"cast", [Pure, SameOperandsAndResultShape,
       DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
                               ["inferReturnTypeComponents"]>]> {
 

diff  --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 930bb9fe96811..b6d68c94f436b 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1304,10 +1304,9 @@ func.func nested @fold_reciprocal() -> tensor<3x600x1200xf32> {
   // CHECK:           %[[VAL_0:.*]] = "tosa.const"() <{values = dense<8.620690e-03> : tensor<3x600x1200xf32>}> : () -> tensor<3x600x1200xf32>
   // CHECK:           return %[[VAL_0]] : tensor<3x600x1200xf32>
   // CHECK:         }
-  %0 = "tosa.const"(){ values = dense<116.0>: tensor<f32> }: () -> tensor<f32>
-  %1 = "tosa.cast"(%0) : (tensor<f32>) -> tensor<3x600x1200xf32>
-  %2 = "tosa.reciprocal"(%1): (tensor<3x600x1200xf32>) -> tensor<3x600x1200xf32>
-  return %2 : tensor<3x600x1200xf32>
+  %0 = "tosa.const"(){ values = dense<116.0>: tensor<3x600x1200xf32> }: () -> tensor<3x600x1200xf32>
+  %1 = "tosa.reciprocal"(%0): (tensor<3x600x1200xf32>) -> tensor<3x600x1200xf32>
+  return %1 : tensor<3x600x1200xf32>
 }
 
 // -----
@@ -1315,10 +1314,9 @@ func.func nested @fold_reciprocal() -> tensor<3x600x1200xf32> {
 // CHECK-LABEL: @do_not_fold_reciprocal_int
 func.func nested @do_not_fold_reciprocal_int() -> tensor<3x600x1200xi32> {
   // CHECK:           tosa.reciprocal
-  %0 = "tosa.const"(){ values = dense<11>: tensor<i32> }: () -> tensor<i32>
-  %1 = "tosa.cast"(%0) : (tensor<i32>) -> tensor<3x600x1200xi32>
-  %2 = "tosa.reciprocal"(%1): (tensor<3x600x1200xi32>) -> tensor<3x600x1200xi32>
-  return %2 : tensor<3x600x1200xi32>
+  %0 = "tosa.const"(){ values = dense<11>: tensor<3x600x1200xi32> }: () -> tensor<3x600x1200xi32>
+  %1 = "tosa.reciprocal"(%0): (tensor<3x600x1200xi32>) -> tensor<3x600x1200xi32>
+  return %1 : tensor<3x600x1200xi32>
 }
 
 // -----

diff  --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 3bccb32c5b9f4..7c76c178436c7 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -6,6 +6,15 @@
 
 // RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround strict-op-spec-alignment"
 
+
+func.func @test_cast(%arg0: tensor<i1>) -> tensor<5xi32> {
+  // expected-error@+1{{'tosa.cast' op requires the same shape for all operands and results}}
+  %1 = "tosa.cast"(%arg0) : (tensor<i1>) -> tensor<5xi32>
+  return %1 : tensor<5xi32>
+}
+
+// -----
+
 func.func @test_const() -> tensor<1xf32> {
   // expected-error@+1{{'tosa.const' op expected same attr/result element types}}
   %0 = "tosa.const"() {values = dense<1> : tensor<1xi32>} : () -> tensor<1xf32>


        


More information about the Mlir-commits mailing list