[Mlir-commits] [mlir] af78e5d - [mlir][tosa]Fix Rescale shift attr data type (#71084)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jan 10 08:57:43 PST 2024


Author: Tai Ly
Date: 2024-01-10T16:57:39Z
New Revision: af78e5daf0791135485dbd7972ffedb927727a6b

URL: https://github.com/llvm/llvm-project/commit/af78e5daf0791135485dbd7972ffedb927727a6b
DIFF: https://github.com/llvm/llvm-project/commit/af78e5daf0791135485dbd7972ffedb927727a6b.diff

LOG: [mlir][tosa]Fix Rescale shift attr data type (#71084)

Change Rescale shift attribute to be DenseI8ArrayAttr to match spec
(instead of DenseI32ArrayAttr)

This replaces https://reviews.llvm.org/D157439

Signed-off-by: Tai Ly <tai.ly at arm.com>

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
    mlir/test/Dialect/Tosa/ops.mlir
    mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
    mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 9dde59f634d739..3257ecd9d91f11 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1848,7 +1848,7 @@ def Tosa_RescaleOp: Tosa_Op<"rescale", [Pure,
     I32Attr:$input_zp,
     I32Attr:$output_zp,
     DenseI32ArrayAttr:$multiplier,
-    DenseI32ArrayAttr:$shift,
+    DenseI8ArrayAttr:$shift,
     BoolAttr:$scale32,
     BoolAttr:$double_round,
     BoolAttr:$per_channel

diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 3931e454da2e22..8a29752ff8d7f8 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -93,7 +93,7 @@ func.func @test_add_0d(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
   // CHECK:   linalg.yield [[ADDF]] : f32
   // CHECK: } -> tensor<f32>
   %0 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
-  
+
   // CHECK: return [[RESULT]] : tensor<f32>
   return %0 : tensor<f32>
 }
@@ -223,7 +223,7 @@ func.func @test_add_1d_broadcast_static_to_static(%arg0: tensor<1xf32>, %arg1: t
   // CHECK:   linalg.yield %[[VAL_4]] : f32
   // CHECK: } -> tensor<3xf32>
   %0 = tosa.add %arg0, %arg1 : (tensor<1xf32>, tensor<3xf32>) -> tensor<3xf32>
-  
+
   // CHECK: return %[[RESULT]] : tensor<3xf32>
   return %0 : tensor<3xf32>
 }
@@ -352,7 +352,7 @@ func.func @test_add_2d_different_ranks(%arg0: tensor<3x4xf32>, %arg1: tensor<2x3
   // CHECK:   linalg.yield %[[VAL_4]] : f32
   // CHECK: } -> tensor<2x3x4xf32>
   %0 = tosa.add %arg0, %arg1 : (tensor<3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
-  
+
   // CHECK: return %[[RESULT]] : tensor<2x3x4xf32>
   return %0 : tensor<2x3x4xf32>
 }
@@ -1057,7 +1057,7 @@ func.func @rescale_i8(%arg0 : tensor<2xi8>) -> () {
   // CHECK-DAG: [[BOUNDED:%.+]] = arith.select [[MAXLT]], [[CMAX]], [[LOWER]]
   // CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
   // CHECK-DAG: linalg.yield [[TRUNC]]
-  %0 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i32: 15>, scale32 = false, double_round = false, per_channel = false} : (tensor<2xi8>) -> tensor<2xi8>
+  %0 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false} : (tensor<2xi8>) -> tensor<2xi8>
 
   // CHECK: [[C0:%.+]] = arith.constant 19689
   // CHECK: [[C1:%.+]] = arith.constant 15
@@ -1079,7 +1079,7 @@ func.func @rescale_i8(%arg0 : tensor<2xi8>) -> () {
   // CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
   // CHECK-DAG: [[CAST:%.+]] = builtin.unrealized_conversion_cast [[TRUNC]] : i8 to ui8
   // CHECK: linalg.yield [[CAST]]
-  %1 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i32: 15>, scale32 = false, double_round = false, per_channel = false} : (tensor<2xi8>) -> tensor<2xui8>
+  %1 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false} : (tensor<2xi8>) -> tensor<2xui8>
 
   // CHECK: return
   return
@@ -1096,13 +1096,13 @@ func.func @rescale_i8_dyn_batch(%arg0 : tensor<?x2xi8>) -> () {
   // CHECK: %[[BATCH:.+]] = tensor.dim %[[ARG0]], %[[C0]]
   // CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x2xi8>
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x2xi8>) outs(%[[INIT]] : tensor<?x2xi8>)
-  %0 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i32: 15>, scale32 = false, double_round = false, per_channel = false} : (tensor<?x2xi8>) -> tensor<?x2xi8>
+  %0 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false} : (tensor<?x2xi8>) -> tensor<?x2xi8>
 
   // CHECK: %[[C0:.+]] = arith.constant 0
   // CHECK: %[[BATCH:.+]] = tensor.dim %[[ARG0]], %[[C0]]
   // CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x2xui8>
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x2xi8>) outs(%[[INIT]] : tensor<?x2xui8>)
-  %1 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i32: 15>, scale32 = false, double_round = false, per_channel = false} : (tensor<?x2xi8>) -> tensor<?x2xui8>
+  %1 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false} : (tensor<?x2xi8>) -> tensor<?x2xui8>
 
   return
 }
@@ -1120,7 +1120,7 @@ func.func @rescale_dyn(%arg0 : tensor<1x?x?x32xi32>) -> () {
   // CHECK: %[[DIM2:.+]] = tensor.dim %[[ARG0]], %[[C2]]
   // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM1]], %[[DIM2]])
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<1x?x?x32xi32>) outs(%[[INIT]] : tensor<1x?x?x32xi8>)
-  %0 = tosa.rescale %arg0 {double_round = true, input_zp = 0 : i32, multiplier = array<i32: 1376784203>, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array<i32: 38>} : (tensor<1x?x?x32xi32>) -> tensor<1x?x?x32xi8>
+  %0 = tosa.rescale %arg0 {double_round = true, input_zp = 0 : i32, multiplier = array<i32: 1376784203>, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array<i8: 38>} : (tensor<1x?x?x32xi32>) -> tensor<1x?x?x32xi8>
   return
 }
 
@@ -1151,7 +1151,7 @@ func.func @rescale_ui8(%arg0 : tensor<2xui8>) -> () {
   // CHECK-DAG: [[BOUNDED:%.+]] = arith.select [[MAXLT]], [[CMAX]], [[LOWER]]
   // CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
   // CHECK: linalg.yield [[TRUNC]]
-  %0 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i32: 15>, scale32 = false, double_round = false, per_channel = false} : (tensor<2xui8>) -> tensor<2xi8>
+  %0 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false} : (tensor<2xui8>) -> tensor<2xi8>
 
   return
 }
@@ -1183,7 +1183,7 @@ func.func @rescale_per_channel(%arg0 : tensor<3xi8>) -> (tensor<3xi8>) {
   // CHECK-DAG: [[BOUNDED:%.+]] = arith.select [[MAXLT]], [[CMAX]], [[LOWER]]
   // CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
   // CHECK-DAG: linalg.yield [[TRUNC]]
-  %0 = tosa.rescale %arg0 {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = array<i32: 42, 43, 44>, shift = array<i32: 14, 15, 64>, scale32 = false, double_round = false, per_channel = false} : (tensor<3xi8>) -> tensor<3xi8>
+  %0 = tosa.rescale %arg0 {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = array<i32: 42, 43, 44>, shift = array<i8: 14, 15, 64>, scale32 = false, double_round = false, per_channel = false} : (tensor<3xi8>) -> tensor<3xi8>
 
   // CHECK: return [[GENERIC]]
   return %0 : tensor<3xi8>
@@ -1194,18 +1194,18 @@ func.func @rescale_per_channel(%arg0 : tensor<3xi8>) -> (tensor<3xi8>) {
 // CHECK-LABEL: @rescaleDoubleRound
 func.func @rescaleDoubleRound(%arg0 : tensor<2xi8>) -> (tensor<2xi8>) {
   // CHECK: linalg.generic
-  // CHECK: tosa.apply_scale 
+  // CHECK: tosa.apply_scale
   // CHECK-SAME:  {double_round = true}
-  %0 = tosa.rescale %arg0 {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = array<i32: 19689>, shift = array<i32: 33>, scale32 = true, double_round = true, per_channel = false} : (tensor<2xi8>) -> tensor<2xi8>
+  %0 = tosa.rescale %arg0 {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = array<i32: 19689>, shift = array<i8: 33>, scale32 = true, double_round = true, per_channel = false} : (tensor<2xi8>) -> tensor<2xi8>
   return %0 : tensor<2xi8>
 }
 
 // CHECK-LABEL: @rescaleUnnecessaryDoubleRound
 func.func @rescaleUnnecessaryDoubleRound(%arg0 : tensor<2xi8>) -> (tensor<2xi8>) {
   // CHECK: linalg.generic
-  // CHECK: tosa.apply_scale 
+  // CHECK: tosa.apply_scale
   // CHECK-SAME:  {double_round = false}
-  %0 = tosa.rescale %arg0 {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = array<i32: 19689>, shift = array<i32: 15>, scale32 = true, double_round = true, per_channel = false} : (tensor<2xi8>) -> tensor<2xi8>
+  %0 = tosa.rescale %arg0 {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = true, double_round = true, per_channel = false} : (tensor<2xi8>) -> tensor<2xi8>
   return %0 : tensor<2xi8>
 }
 

diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index a3d0b5e5447aa3..3d68464ebf0b30 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -64,7 +64,7 @@ func.func @test_conv2d_q8xi4(%arg0: tensor<1x11x11x3xi8>) -> tensor<1x1x1x3xi8>
   %0 = "tosa.const"() {value = dense<0> : tensor<3x11x11x3xi4>} : () -> tensor<3x11x11x3xi4>
   %1 = "tosa.const"() {value = dense<[12, 23, 55]> : tensor<3xi32>} : () -> tensor<3xi32>
   %2 = "tosa.conv2d"(%arg0, %0, %1) {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = 0, weight_zp = 0>, stride = array<i64: 1, 1>} : (tensor<1x11x11x3xi8>, tensor<3x11x11x3xi4>, tensor<3xi32>) -> tensor<1x1x1x3xi32>
-  %3 = "tosa.rescale"(%2) {double_round = true, input_zp = 0 : i32, multiplier = array<i32: 2026291432, 1079222024, 1693132724>, output_zp = 27 : i32, per_channel = true, scale32 = true, shift = array<i32: 37, 36, 37>} : (tensor<1x1x1x3xi32>) -> tensor<1x1x1x3xi8>
+  %3 = "tosa.rescale"(%2) {double_round = true, input_zp = 0 : i32, multiplier = array<i32: 2026291432, 1079222024, 1693132724>, output_zp = 27 : i32, per_channel = true, scale32 = true, shift = array<i8: 37, 36, 37>} : (tensor<1x1x1x3xi32>) -> tensor<1x1x1x3xi8>
   return %3 : tensor<1x1x1x3xi8>
 }
 
@@ -604,7 +604,7 @@ func.func @test_cast3(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3x!quant.unifo
 // -----
 // CHECK-LABEL: rescale
 func.func @test_rescale(%arg0: tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>> {
-    %0 = tosa.rescale %arg0 {double_round = false, input_zp = 127 : i32, multiplier = array<i32: 1073741824>, output_zp = -1 : i32, per_channel = false, scale32 = true, shift = array<i32: 30>} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
+    %0 = tosa.rescale %arg0 {double_round = false, input_zp = 127 : i32, multiplier = array<i32: 1073741824>, output_zp = -1 : i32, per_channel = false, scale32 = true, shift = array<i8: 30>} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
     return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
 }
 

diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index f5cd8bd48de1cc..1f0cfaf92c5c74 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -94,7 +94,7 @@ func.func @test_unary_i32(%arg0 : tensor<4xi32>) -> () {
   %5 = tosa.reverse %arg0 { axis = 0 : i32 } : (tensor<4xi32>) -> tensor<?xi32>
 
   // CHECK: tosa.rescale %arg0 {{.+}} : (tensor<4xi32>) -> tensor<4xi16>
-  %6 = tosa.rescale %arg0 {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = array<i32: 42, 43>, shift = array<i32: 14, 15>, scale32 = false, double_round = false, per_channel = false} : (tensor<4xi32>) -> tensor<*xi16>
+  %6 = tosa.rescale %arg0 {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = array<i32: 42, 43>, shift = array<i8: 14, 15>, scale32 = false, double_round = false, per_channel = false} : (tensor<4xi32>) -> tensor<*xi16>
 
   // CHECK: tosa.identity %arg0 : (tensor<4xi32>) -> tensor<4xi32>
   %7 = tosa.identity %arg0 : (tensor<4xi32>) -> tensor<?xi32>

diff --git a/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp b/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
index 9642301e8111c1..e5a3e2b6fccaa3 100644
--- a/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
+++ b/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
@@ -169,8 +169,9 @@ ConvertTosaConv2DOp::matchAndRewrite(Operation *op,
       op->getLoc(), outputType, newTosaConv2DOp.getResult(),
       rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(outputZp),
       rewriter.getDenseI32ArrayAttr({multiplier}),
-      rewriter.getDenseI32ArrayAttr({shift}), rewriter.getBoolAttr(true),
-      rewriter.getBoolAttr(true), rewriter.getBoolAttr(false));
+      rewriter.getDenseI8ArrayAttr({static_cast<int8_t>(shift)}),
+      rewriter.getBoolAttr(true), rewriter.getBoolAttr(true),
+      rewriter.getBoolAttr(false));
 
   rewriter.replaceOp(op, {newTosaRescaleOp.getResult()});
   return success();


        


More information about the Mlir-commits mailing list