[Mlir-commits] [mlir] 3205ee7 - [mlir][tosa] Support UInt8 inputs and outputs for tosa.rescale

Rob Suderman llvmlistbot at llvm.org
Thu Aug 19 18:59:47 PDT 2021


Author: Rob Suderman
Date: 2021-08-19T18:58:44-07:00
New Revision: 3205ee7e812fa6d18d68c00f04c1cffef145e279

URL: https://github.com/llvm/llvm-project/commit/3205ee7e812fa6d18d68c00f04c1cffef145e279
DIFF: https://github.com/llvm/llvm-project/commit/3205ee7e812fa6d18d68c00f04c1cffef145e279.diff

LOG: [mlir][tosa] Support UInt8 inputs and outputs for tosa.rescale

The tosa.rescale op can have uint8 input and output types. Added support for
these types using an unrealized conversion cast. Ideally a bitcast would be
used instead; however, it does not support unsigned integers.

Differential Revision: https://reviews.llvm.org/D108427

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 484628109b6f2..02129a5f1bcca 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -38,6 +38,8 @@ class Tosa_QuantizedType<string n, list<int> params, bit signed>
 // Used to express accumulator results or compare results.
 //===----------------------------------------------------------------------===//
 
+def Tosa_UInt8 : UI<8>;
+
 def Tosa_Int8 : I<8>;
 def Tosa_Int16 : I<16>;
 def Tosa_Int32 : I<32>;
@@ -54,6 +56,7 @@ def Tosa_Bool : I<1>;
 
 // No unsigned unquantized int types.
 def Tosa_Int : AnyTypeOf<[Tosa_Bool,
+                          Tosa_UInt8,
                           Tosa_SignedInt]>;
 
 def Tosa_Int32Or64 : AnyTypeOf<[Tosa_Int32,

diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index f600c27d992f0..cfd1696d55d15 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1544,12 +1544,12 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
         [&](OpBuilder &nestedBuilder, Location nestedLoc,
             ValueRange blockArgs) {
           Value value = blockArgs[0];
+          Type valueTy = value.getType();
 
           // For now we do all of our math in 64-bit. This is not optimal but
           // should be correct for now, consider computing correct bit depth
           // later.
-          int32_t inBitwidth =
-              value.getType().getIntOrFloatBitWidth() > 32 ? 48 : 32;
+          int32_t inBitwidth = valueTy.getIntOrFloatBitWidth() > 32 ? 48 : 32;
 
           auto inputZp = createConstFromIntAttribute<int32_t>(
               op, "input_zp", nestedBuilder.getIntegerType(inBitwidth),
@@ -1561,9 +1561,21 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
                                                 : blockArgs[multiplierArg];
           Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];
 
-          if (value.getType().getIntOrFloatBitWidth() < 32) {
-            value = nestedBuilder.create<SignExtendIOp>(
-                nestedLoc, nestedBuilder.getI32Type(), value);
+          if (valueTy.getIntOrFloatBitWidth() < 32) {
+            if (valueTy.isUnsignedInteger()) {
+              value = nestedBuilder
+                          .create<UnrealizedConversionCastOp>(
+                              nestedLoc,
+                              nestedBuilder.getIntegerType(
+                                  valueTy.getIntOrFloatBitWidth()),
+                              value)
+                          .getResult(0);
+              value = nestedBuilder.create<ZeroExtendIOp>(
+                  nestedLoc, nestedBuilder.getI32Type(), value);
+            } else {
+              value = nestedBuilder.create<SignExtendIOp>(
+                  nestedLoc, nestedBuilder.getI32Type(), value);
+            }
           }
 
           value = nestedBuilder.create<SubIOp>(nestedLoc, value, inputZp);
@@ -1579,21 +1591,38 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
           IntegerType outIntType =
               blockArgs.back().getType().cast<IntegerType>();
           unsigned outBitWidth = outIntType.getWidth();
-          auto intMin = nestedBuilder.create<ConstantOp>(
-              loc, nestedBuilder.getIntegerAttr(
-                       nestedBuilder.getI32Type(),
-                       APInt::getSignedMinValue(outBitWidth).getSExtValue()));
-          auto intMax = nestedBuilder.create<ConstantOp>(
-              loc, nestedBuilder.getIntegerAttr(
-                       nestedBuilder.getI32Type(),
-                       APInt::getSignedMaxValue(outBitWidth).getSExtValue()));
-
-          value = clampHelper<mlir::CmpIOp>(nestedLoc, value, intMin, intMax,
-                                            CmpIPredicate::slt, nestedBuilder);
+
+          int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
+          int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();
+
+          // Unsigned integers have a difference output value.
+          if (outIntType.isUnsignedInteger()) {
+            intMin = 0;
+            intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
+          }
+
+          auto intMinVal = nestedBuilder.create<ConstantOp>(
+              loc,
+              nestedBuilder.getIntegerAttr(nestedBuilder.getI32Type(), intMin));
+          auto intMaxVal = nestedBuilder.create<ConstantOp>(
+              loc,
+              nestedBuilder.getIntegerAttr(nestedBuilder.getI32Type(), intMax));
+
+          value =
+              clampHelper<mlir::CmpIOp>(nestedLoc, value, intMinVal, intMaxVal,
+                                        CmpIPredicate::slt, nestedBuilder);
 
           if (outIntType.getWidth() < 32) {
-            value =
-                nestedBuilder.create<TruncateIOp>(nestedLoc, outIntType, value);
+            value = nestedBuilder.create<TruncateIOp>(
+                nestedLoc, rewriter.getIntegerType(outIntType.getWidth()),
+                value);
+
+            if (outIntType.isUnsignedInteger()) {
+              value = nestedBuilder
+                          .create<UnrealizedConversionCastOp>(nestedLoc,
+                                                              outIntType, value)
+                          .getResult(0);
+            }
           }
 
           nestedBuilder.create<linalg::YieldOp>(loc, value);

diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 99c33d9b8eba7..68f56bd84cc9c 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -727,20 +727,19 @@ func @concat(%arg0: tensor<5x1xf32>, %arg1: tensor<6x1xf32>) -> () {
 
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
 
-// CHECK-LABEL: @rescale
-func @rescale(%arg0 : tensor<2xi8>) -> (tensor<2xi8>) {
+// CHECK-LABEL: @rescale_i8
+func @rescale_i8(%arg0 : tensor<2xi8>) -> () {
   // CHECK: [[C0:%.+]] = constant 19689
   // CHECK: [[C1:%.+]] = constant 15
   // CHECK: [[INIT:%.+]] = linalg.init_tensor [2]
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%arg0 : tensor<2xi8>) outs([[INIT]] : tensor<2xi8>)
   // CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: i8):
-  // CHECK: [[C243:%.+]] = constant 243
-  // CHECK: [[C252:%.+]] = constant 252
-
+  // CHECK: [[C17:%.+]] = constant 17
+  // CHECK: [[C22:%.+]] = constant 22
   // CHECK-DAG: [[IN32:%.+]] = sexti [[IN]]
-  // CHECK-DAG: [[IN_ZEROED:%.+]] = subi [[IN32]], [[C243]]
+  // CHECK-DAG: [[IN_ZEROED:%.+]] = subi [[IN32]], [[C17]]
   // CHECK-DAG: [[SCALED:%.+]] = "tosa.apply_scale"([[IN_ZEROED]], [[C0]], [[C1]]) {double_round = false}
-  // CHECK-DAG: [[SCALED_ZEROED:%.+]] = addi [[SCALED]], [[C252]]
+  // CHECK-DAG: [[SCALED_ZEROED:%.+]] = addi [[SCALED]], [[C22]]
   // CHECK-DAG: [[CMIN:%.+]] = constant -128
   // CHECK-DAG: [[CMAX:%.+]] = constant 127
   // CHECK-DAG: [[MINLT:%.+]] = cmpi slt, [[SCALED_ZEROED]], [[CMIN]]
@@ -749,10 +748,63 @@ func @rescale(%arg0 : tensor<2xi8>) -> (tensor<2xi8>) {
   // CHECK-DAG: [[BOUNDED:%.+]] = select [[MAXLT]], [[CMAX]], [[LOWER]]
   // CHECK-DAG: [[TRUNC:%.+]] = trunci [[BOUNDED]]
   // CHECK-DAG: linalg.yield [[TRUNC]]
-  %0 = "tosa.rescale"(%arg0) {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = [19689 : i32], shift = [15 : i32], scale32 = false, double_round = false, per_channel = false} : (tensor<2xi8>)  -> (tensor<2xi8>)
+  %0 = "tosa.rescale"(%arg0) {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = [19689 : i32], shift = [15 : i32], scale32 = false, double_round = false, per_channel = false} : (tensor<2xi8>)  -> (tensor<2xi8>)
 
-  // CHECK: return [[GENERIC]]
-  return %0 : tensor<2xi8>
+  // CHECK: [[C0:%.+]] = constant 19689
+  // CHECK: [[C1:%.+]] = constant 15
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [2]
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%arg0 : tensor<2xi8>) outs([[INIT]] : tensor<2xui8>)
+  // CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: ui8):
+  // CHECK: [[C17:%.+]] = constant 17
+  // CHECK: [[C22:%.+]] = constant 22
+  // CHECK-DAG: [[IN32:%.+]] = sexti [[IN]]
+  // CHECK-DAG: [[IN_ZEROED:%.+]] = subi [[IN32]], [[C17]]
+  // CHECK-DAG: [[SCALED:%.+]] = "tosa.apply_scale"([[IN_ZEROED]], [[C0]], [[C1]]) {double_round = false}
+  // CHECK-DAG: [[SCALED_ZEROED:%.+]] = addi [[SCALED]], [[C22]]
+  // CHECK-DAG: [[CMIN:%.+]] = constant 0
+  // CHECK-DAG: [[CMAX:%.+]] = constant 255
+  // CHECK-DAG: [[MINLT:%.+]] = cmpi slt, [[SCALED_ZEROED]], [[CMIN]]
+  // CHECK-DAG: [[LOWER:%.+]] = select [[MINLT]], [[CMIN]], [[SCALED_ZEROED]]
+  // CHECK-DAG: [[MAXLT:%.+]] = cmpi slt, [[CMAX]], [[SCALED_ZEROED]]
+  // CHECK-DAG: [[BOUNDED:%.+]] = select [[MAXLT]], [[CMAX]], [[LOWER]]
+  // CHECK-DAG: [[TRUNC:%.+]] = trunci [[BOUNDED]]
+  // CHECK-DAG: [[CAST:%.+]] = builtin.unrealized_conversion_cast [[TRUNC]] : i8 to ui8
+  // CHECK: linalg.yield [[CAST]]
+  %1 = "tosa.rescale"(%arg0) {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = [19689 : i32], shift = [15 : i32], scale32 = false, double_round = false, per_channel = false} : (tensor<2xi8>)  -> (tensor<2xui8>)
+
+  // CHECK: return
+  return
+}
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
+
+// CHECK-LABEL: @rescale_ui8
+func @rescale_ui8(%arg0 : tensor<2xui8>) -> () {
+  // CHECK: [[C0:%.+]] = constant 19689
+  // CHECK: [[C1:%.+]] = constant 15
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [2]
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%arg0 : tensor<2xui8>) outs([[INIT]] : tensor<2xi8>)
+  // CHECK: ^bb0([[IN:%.+]]: ui8, [[UNUSED:%.+]]: i8):
+  // CHECK: [[C17:%.+]] = constant 17
+  // CHECK: [[C22:%.+]] = constant 22
+  // CHECK-DAG: [[CAST:%.+]] = builtin.unrealized_conversion_cast [[IN]] : ui8 to i8
+  // CHECK-DAG: [[IN32:%.+]] = zexti [[CAST]]
+  // CHECK-DAG: [[IN_ZEROED:%.+]] = subi [[IN32]], [[C17]]
+  // CHECK-DAG: [[SCALED:%.+]] = "tosa.apply_scale"([[IN_ZEROED]], [[C0]], [[C1]]) {double_round = false}
+  // CHECK-DAG: [[SCALED_ZEROED:%.+]] = addi [[SCALED]], [[C22]]
+  // CHECK-DAG: [[CMIN:%.+]] = constant -128
+  // CHECK-DAG: [[CMAX:%.+]] = constant 127
+  // CHECK-DAG: [[MINLT:%.+]] = cmpi slt, [[SCALED_ZEROED]], [[CMIN]]
+  // CHECK-DAG: [[LOWER:%.+]] = select [[MINLT]], [[CMIN]], [[SCALED_ZEROED]]
+  // CHECK-DAG: [[MAXLT:%.+]] = cmpi slt, [[CMAX]], [[SCALED_ZEROED]]
+  // CHECK-DAG: [[BOUNDED:%.+]] = select [[MAXLT]], [[CMAX]], [[LOWER]]
+  // CHECK-DAG: [[TRUNC:%.+]] = trunci [[BOUNDED]]
+  // CHECK: linalg.yield [[TRUNC]]
+  %0 = "tosa.rescale"(%arg0) {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = [19689 : i32], shift = [15 : i32], scale32 = false, double_round = false, per_channel = false} : (tensor<2xui8>)  -> (tensor<2xi8>)
+
+  return
 }
 
 // -----


        


More information about the Mlir-commits mailing list