[Mlir-commits] [mlir] 0763f12 - [mlir][tosa] Handle rescale case where shift > 63

Rob Suderman llvmlistbot at llvm.org
Thu Dec 16 15:37:35 PST 2021


Author: Rob Suderman
Date: 2021-12-16T15:30:48-08:00
New Revision: 0763f12213dc931a4c6926324e4e5d825237405c

URL: https://github.com/llvm/llvm-project/commit/0763f12213dc931a4c6926324e4e5d825237405c
DIFF: https://github.com/llvm/llvm-project/commit/0763f12213dc931a4c6926324e4e5d825237405c.diff

LOG: [mlir][tosa] Handle rescale case where shift > 63

It is possible for the shift value to exceed the number of bits. In these
cases we can just multiply by zero. This is a relatively rare occurrence but
should be handled.

Reviewed By: not-jenni

Differential Revision: https://reviews.llvm.org/D115779

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 7c0909c03e19e..78981fd097158 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1771,6 +1771,14 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
     SmallVector<int8_t> shiftValues;
     getValuesFromIntArrayAttribute(op.shift(), shiftValues);
 
+    // If we shift by more than the bitwidth, this just sets to 0.
+    for (int i = 0, s = multiplierValues.size(); i < s; i++) {
+      if (shiftValues[i] > 63) {
+        shiftValues[i] = 0;
+        multiplierValues[i] = 0;
+      }
+    }
+
     // Double round only occurs if shift is greater than 31, check that this
     // is ever true.
     bool doubleRound =

diff  --git a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
index 6f21e779b37df..ff82acec21e97 100644
--- a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
@@ -43,6 +43,13 @@ static void computeMultiplierAndShiftTosaScale16(double scale,
          "Shifted mantissa exceeds 32-bit signed output type");
 
   multiplier = static_cast<int32_t>(shiftedM);
+
+  // Shifting tops out at 63 bits. Right shift to make 63 bits the max.
+  if (shift > 63) {
+    // Shifting the multiplier by more than 32-bits is unnecessary.
+    multiplier = multiplier >> std::min<int32_t>(32, shift - 63);
+    shift = 63;
+  }
 }
 
 /// From a scale value, generates multiplier and shift values where
@@ -71,6 +78,13 @@ static void computeMultiplierAndShiftTosaScale32(double scale,
          "Shifted mantissa exceeds 32-bit signed output type");
 
   multiplier = static_cast<int32_t>(shiftedM);
+
+  // Shifting tops out at 63 bits. Right shift to make 63 bits the max.
+  if (shift > 63) {
+    // Shifting the multiplier by more than 32-bits is unnecessary.
+    multiplier = multiplier >> std::min<int32_t>(32, shift - 63);
+    shift = 63;
+  }
 }
 
 /// Generates a quantized multiplier/shift from double.

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index a381375a409b1..f8dc2e0bbd08d 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -931,11 +931,11 @@ func @rescale_ui8(%arg0 : tensor<2xui8>) -> () {
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
 
 // CHECK-LABEL: @rescale_per_channel
-func @rescale_per_channel(%arg0 : tensor<2xi8>) -> (tensor<2xi8>) {
-  // CHECK: [[MULTIPLIERS:%.+]] = arith.constant dense<[42, 43]>
-  // CHECK: [[SHIFTS:%.+]] = arith.constant dense<[14, 15]>
-  // CHECK: [[INIT:%.+]] = linalg.init_tensor [2]
-  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%arg0, [[MULTIPLIERS]], [[SHIFTS]] : tensor<2xi8>, tensor<2xi32>, tensor<2xi8>) outs([[INIT]] : tensor<2xi8>)
+func @rescale_per_channel(%arg0 : tensor<3xi8>) -> (tensor<3xi8>) {
+  // CHECK: [[MULTIPLIERS:%.+]] = arith.constant dense<[42, 43, 0]>
+  // CHECK: [[SHIFTS:%.+]] = arith.constant dense<[14, 15, 0]>
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [3]
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%arg0, [[MULTIPLIERS]], [[SHIFTS]] : tensor<3xi8>, tensor<3xi32>, tensor<3xi8>) outs([[INIT]] : tensor<3xi8>)
   // CHECK: ^bb0([[IN:%.+]]: i8, [[MULTIPLIER:%.+]]: i32, [[SHIFT:%.+]]: i8, [[UNUSED:%.+]]: i8):
   // CHECK: [[C243:%.+]] = arith.constant 243
   // CHECK: [[C252:%.+]] = arith.constant 252
@@ -952,10 +952,10 @@ func @rescale_per_channel(%arg0 : tensor<2xi8>) -> (tensor<2xi8>) {
   // CHECK-DAG: [[BOUNDED:%.+]] = select [[MAXLT]], [[CMAX]], [[LOWER]]
   // CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
   // CHECK-DAG: linalg.yield [[TRUNC]]
-  %0 = "tosa.rescale"(%arg0) {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = [42 : i32, 43 : i32], shift = [14 : i32, 15 : i32], scale32 = false, double_round = false, per_channel = false} : (tensor<2xi8>)  -> (tensor<2xi8>)
+  %0 = "tosa.rescale"(%arg0) {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = [42 : i32, 43 : i32, 44 : i32], shift = [14 : i32, 15 : i32, 64 : i32], scale32 = false, double_round = false, per_channel = false} : (tensor<3xi8>)  -> (tensor<3xi8>)
 
   // CHECK: return [[GENERIC]]
-  return %0 : tensor<2xi8>
+  return %0 : tensor<3xi8>
 }
 
 // -----


        


More information about the Mlir-commits mailing list