[Mlir-commits] [mlir] [mlir][tosa] Allow zero-points to be unranked (PR #143770)
Luke Hutton
llvmlistbot at llvm.org
Wed Jun 11 12:22:50 PDT 2025
https://github.com/lhutton1 created https://github.com/llvm/llvm-project/pull/143770
This commit allows the zero-point operands of a number of TOSA operations to be unranked, so that the shape inference pass can propagate shape information.
From 015ab397da99a0fd6b3e14c32986aea21d22dfed Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Wed, 11 Jun 2025 14:10:41 +0000
Subject: [PATCH] [mlir][tosa] Allow zero-points to be unranked
This commit allows the zero-point operands of a number
of TOSA operations to be unranked, so that the shape
inference pass can propagate shape information.
Change-Id: I20c1a04eb2d306f181ffd5c1574f69fd9e410102
---
mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td | 2 +-
mlir/test/Dialect/Tosa/invalid.mlir | 2 +-
mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | 13 +++++++++++++
3 files changed, 15 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 536551c8f8437..1cfe6eee576b3 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -152,7 +152,7 @@ def Tosa_Rank0Tensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>;
def Tosa_ScalarTensor : TosaScalarTensorOf<[Tosa_AnyNumber], [1]>;
def Tosa_ScalarInt8Tensor : TosaScalarTensorOf<[Tosa_Int8], [1]>;
-def Tosa_ScalarIntOrFloatTensor : TosaScalarTensorOf<[Tosa_Int, AnyFloat], [1]>;
+def Tosa_ScalarIntOrFloatTensor : AnyTypeOf<[TosaUnrankedTensorOf<[Tosa_Int, AnyFloat]>, TosaScalarTensorOf<[Tosa_Int, AnyFloat], [1]>]>;
// We include unranked tensors as a supported type for all possible tosa
// Tensors as unranked does not guarantee invalid. If unranked tensors exist
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index a4617fc6fba8b..9982e1a7fe197 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -1007,7 +1007,7 @@ func.func @test_pad_rank0_pad_const(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<1
func.func @test_conv2d_rank0_zp(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi32> {
%input_zp = "tosa.const"() <{values = dense<0> : tensor<i8>}> : () -> tensor<i8>
%weight_zp = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
- // expected-error at +1 {{'tosa.conv2d' op operand #3 must be tosa-conformant scalar tensor of unsigned integer or signless integer or floating-point values, but got 'tensor<i8>'}}
+ // expected-error at +1 {{'tosa.conv2d' op operand #3 must be tosa-conformant unranked tensor of unsigned integer or signless integer or floating-point values or tosa-conformant scalar tensor of unsigned integer or signless integer or floating-point values, but got 'tensor<i8>'}}
%0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
: (tensor<1x29x29x4xi8>, tensor<16x3x3x4xi8>, tensor<16xi8>, tensor<i8>, tensor<1xi8>) -> tensor<1x27x27x16xi32>
return %0 : tensor<1x27x27x16xi32>
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 591a3f0acf65d..f5ab4a2241358 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -333,6 +333,19 @@ func.func @test_dynamic_mixed_matmul(%arg0 : tensor<?x3x?xi32>, %arg1 : tensor<?
// -----
+// CHECK-LABEL: @test_unranked_zero_points_matmul
+func.func @test_unranked_zero_points_matmul(%arg0: tensor<1x2x3xf32>, %arg1: tensor<1x3x4xf32>) -> tensor<1x2x4xf32> {
+ %a_zp = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %b_zp = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %a_zp_cast = "tosa.cast"(%a_zp) : (tensor<1xi8>) -> tensor<*xf32>
+ %b_zp_cast = "tosa.cast"(%b_zp) : (tensor<1xi8>) -> tensor<*xf32>
+ // CHECK: tosa.matmul %arg0, %arg1, %2, %3 : (tensor<1x2x3xf32>, tensor<1x3x4xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x2x4xf32>
+ %0 = tosa.matmul %arg0, %arg1, %a_zp_cast, %b_zp_cast : (tensor<1x2x3xf32>, tensor<1x3x4xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<1x2x4xf32>
+ return %0 : tensor<1x2x4xf32>
+}
+
+// -----
+
// CHECK-LABEL: @test_table_static
func.func @test_table_static(%arg0 : tensor<4x5xi16>, %arg1 : tensor<513xi16>) -> () {
// CHECK:tosa.table %arg0, %arg1 : (tensor<4x5xi16>, tensor<513xi16>) -> tensor<4x5xi16>
More information about the Mlir-commits
mailing list