[Mlir-commits] [mlir] fdd7f9c - [mlir][tosa] Allow scalar int8 tensors to be unranked (#150731)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jul 28 02:03:38 PDT 2025


Author: Yi-Chi Lee
Date: 2025-07-28T10:03:35+01:00
New Revision: fdd7f9c61bbc476bfc6839dec3428e1dea06eacb

URL: https://github.com/llvm/llvm-project/commit/fdd7f9c61bbc476bfc6839dec3428e1dea06eacb
DIFF: https://github.com/llvm/llvm-project/commit/fdd7f9c61bbc476bfc6839dec3428e1dea06eacb.diff

LOG: [mlir][tosa] Allow scalar int8 tensors to be unranked (#150731)

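This PR fixes #150519 by letting `Tosa_ScalarInt8Tensor` also accept unranked i8 tensors (e.g. a `tosa.mul` shift operand of type `tensor<*xi8>`), in line with `Tosa_ScalarTensor` and `Tosa_ScalarIntOrFloatTensor`.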

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
    mlir/test/Dialect/Tosa/invalid.mlir
    mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 349e8edb0ed96..754640dca6561 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -151,7 +151,7 @@ class Tosa_TensorOfOrNone<list<Type> allowedTypes, string description = ""> :
 def Tosa_Rank0Tensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>;
 
 def Tosa_ScalarTensor : AnyTypeOf<[TosaUnrankedTensorOf<[Tosa_Int, AnyFloat]>, TosaScalarTensorOf<[Tosa_AnyNumber], [1]>]>;
-def Tosa_ScalarInt8Tensor : TosaScalarTensorOf<[Tosa_Int8], [1]>;
+def Tosa_ScalarInt8Tensor : AnyTypeOf<[TosaUnrankedTensorOf<[Tosa_Int8]>, TosaScalarTensorOf<[Tosa_Int8], [1]>]>;
 def Tosa_ScalarIntOrFloatTensor : AnyTypeOf<[TosaUnrankedTensorOf<[Tosa_Int, AnyFloat]>, TosaScalarTensorOf<[Tosa_Int, AnyFloat], [1]>]>;
 
 // We include unranked tensors as a supported type for all possible tosa

diff  --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 716362e094652..b90d6f59ba820 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -1125,7 +1125,7 @@ func.func @test_sub_with_unequal_result_ranks(%arg0: tensor<1x21x3xf32>, %arg1:
 // CHECK-LABEL: test_mul_non_scalar_shift_2d
 func.func @test_mul_non_scalar_shift_2d(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
   %shift = "tosa.const"() <{values = dense<0> : tensor<1x1xi8>}> : () -> tensor<1x1xi8>
-  // expected-error@+1 {{'tosa.mul' op operand #2 must be tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<1x1xi8>'}}
+  // expected-error@+1 {{'tosa.mul' op operand #2 must be tosa-conformant unranked tensor of 8-bit signless integer values or tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<1x1xi8>'}}
   %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<1x1xi8>) -> tensor<13x21x3xf32>
   return %0 : tensor<13x21x3xf32>
 }
@@ -1134,7 +1134,7 @@ func.func @test_mul_non_scalar_shift_2d(%arg0: tensor<13x21x3xf32>, %arg1: tenso
 // CHECK-LABEL: test_mul_non_scalar_shift_1d
 func.func @test_mul_non_scalar_shift_1d(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
   %shift = "tosa.const"() <{values = dense<0> : tensor<2xi8>}> : () -> tensor<2xi8>
-  // expected-error@+1 {{'tosa.mul' op operand #2 must be tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<2xi8>'}}
+  // expected-error@+1 {{'tosa.mul' op operand #2 must be tosa-conformant unranked tensor of 8-bit signless integer values or tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<2xi8>'}}
   %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<2xi8>) -> tensor<13x21x3xf32>
   return %0 : tensor<13x21x3xf32>
 }

diff  --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 9d43f8998da93..5a1610635af2c 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -357,6 +357,17 @@ func.func @test_accepts_unranked_scalar_tensor(%arg0: tensor<1x2x2xf32>, %arg1:
 
 // -----
 
+// CHECK-LABEL: @test_unranked_scalar_i8_tensor
+func.func @test_unranked_scalar_i8_tensor(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>, %arg2: tensor<1xi8>) -> tensor<4xi32> {
+  // CHECK: %[[SHIFT:.*]] = tosa.cast %arg2 : (tensor<1xi8>) -> tensor<1xi8>
+  %shift = tosa.cast %arg2 : (tensor<1xi8>) -> tensor<*xi8>
+  // CHECK: tosa.mul %arg0, %arg1, %[[SHIFT]] : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32>
+  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<4xi32>, tensor<4xi32>, tensor<*xi8>) -> tensor<4xi32>
+  return %0 : tensor<4xi32>
+}
+
+// -----
+
 // CHECK-LABEL: @test_table_static
 func.func @test_table_static(%arg0 : tensor<4x5xi16>, %arg1 : tensor<513xi16>) -> () {
   // CHECK:tosa.table %arg0, %arg1 : (tensor<4x5xi16>, tensor<513xi16>) -> tensor<4x5xi16>


        


More information about the Mlir-commits mailing list