[Mlir-commits] [mlir] 8f190b1 - [mlir][tosa] Add tosa.negate lowerings for quantized cases

Rob Suderman llvmlistbot at llvm.org
Tue Apr 27 17:22:41 PDT 2021


Author: Rob Suderman
Date: 2021-04-27T17:16:39-07:00
New Revision: 8f190b13bab16d44819aa9aaf83a327ac2ead68d

URL: https://github.com/llvm/llvm-project/commit/8f190b13bab16d44819aa9aaf83a327ac2ead68d
DIFF: https://github.com/llvm/llvm-project/commit/8f190b13bab16d44819aa9aaf83a327ac2ead68d.diff

LOG: [mlir][tosa] Add tosa.negate lowerings for quantized cases

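Quantized negation can be performed using integer operations at a wider bit width.
The smallest intermediate bit width (16, 32, 48, or 64) that can hold the result is picked.
The input is sign-extended to that width, subtracted from (input_zp + output_zp),
clamped to the input type's signed range, and truncated back to the element type.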

Differential Revision: https://reviews.llvm.org/D101225

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 042626e2a477..51de267170ad 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -159,14 +159,65 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
   }
 
   // tosa::NegateOp
-  if (isa<tosa::NegateOp>(op) && elementTy.isa<IntegerType>()) {
+  if (isa<tosa::NegateOp>(op) && elementTy.isa<FloatType>())
+    return rewriter.create<mlir::NegFOp>(loc, resultTypes, args);
+
+  if (isa<tosa::NegateOp>(op) && elementTy.isa<IntegerType>() &&
+      !cast<tosa::NegateOp>(op).quantization_info()) {
     auto constant =
-        rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(elementTy, -1));
-    return rewriter.create<mlir::MulIOp>(loc, resultTypes, args[0], constant);
+        rewriter.create<ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
+    return rewriter.create<SubIOp>(loc, resultTypes, constant, args[0]);
   }
 
-  if (isa<tosa::NegateOp>(op) && elementTy.isa<FloatType>())
-    return rewriter.create<mlir::NegFOp>(loc, resultTypes, args);
+  if (isa<tosa::NegateOp>(op) && elementTy.isa<IntegerType>() &&
+      cast<tosa::NegateOp>(op).quantization_info()) {
+    auto quantizationInfo = cast<tosa::NegateOp>(op).quantization_info();
+    int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
+    int64_t inZp =
+        quantizationInfo.getValue().input_zp().getValue().getSExtValue();
+    int64_t outZp =
+        quantizationInfo.getValue().output_zp().getValue().getSExtValue();
+
+    // Compute the maximum value that can occur in the intermediate buffer.
+    int64_t zpAdd = inZp + outZp;
+    int64_t maxValue = APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
+                       std::abs(zpAdd) + 1;
+
+    // Convert that maximum value into the maximum bitwidth needed to represent
+    // it. We assume 48-bit numbers may be supported further in the pipeline.
+    int intermediateBitWidth = 64;
+    if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
+      intermediateBitWidth = 16;
+    } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
+      intermediateBitWidth = 32;
+    } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
+      intermediateBitWidth = 48;
+    }
+
+    Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
+    Value zpAddValue = rewriter.create<ConstantOp>(
+        loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
+
+    // The negation can be applied by doing:
+    //  outputValue = inZp + outZp - inputValue
+    auto ext = rewriter.create<SignExtendIOp>(loc, intermediateType, args[0]);
+    auto sub = rewriter.create<SubIOp>(loc, zpAddValue, ext);
+
+    // Clamp to the negation range.
+    auto min = rewriter.create<ConstantOp>(
+        loc, rewriter.getIntegerAttr(
+                 intermediateType,
+                 APInt::getSignedMinValue(inputBitWidth).getSExtValue()));
+    auto max = rewriter.create<ConstantOp>(
+        loc, rewriter.getIntegerAttr(
+                 intermediateType,
+                 APInt::getSignedMaxValue(inputBitWidth).getSExtValue()));
+    auto clamp = clampHelper<mlir::CmpIOp>(loc, sub, min, max,
+                                           CmpIPredicate::slt, rewriter);
+
+    // Truncate to the final value.
+    return rewriter.create<TruncateIOp>(loc, elementTy, clamp);
+  }
 
   // tosa::BitwiseAndOp
   if (isa<tosa::BitwiseAndOp>(op) && elementTy.isa<IntegerType>())

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 489bdd3cd94f..3f9940e4a203 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -258,7 +258,8 @@ func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
   %3 = "tosa.mul"(%arg0, %arg0) {shift = 2 : i32} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
-  // CHECK: muli
+  // CHECK: [[ZERO:%.+]] = constant 0
+  // CHECK: subi [[ZERO]], %arg1
   %4 = "tosa.negate"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
@@ -363,6 +364,35 @@ func @test_bool(%arg0: tensor<1xi1>, %arg1: tensor<1xi1>) -> () {
 
 // -----
 
+// CHECK-LABEL: @test_negate_quantized
+func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
+  // CHECK: linalg.generic
+  // CHECK: [[ZERO:%.+]] = constant 0
+  // CHECK: [[EXT:%.+]] = sexti %arg1 : i8 to i16
+  // CHECK: [[SUB:%.+]] = subi [[ZERO]], [[EXT]]
+  // CHECK: [[MIN:%.+]] = constant -128
+  // CHECK: [[MAX:%.+]] = constant 127
+  // CHECK: [[PRED1:%.+]] = cmpi slt, [[SUB]], [[MIN]]
+  // CHECK: [[LBOUND:%.+]] = select [[PRED1]], [[MIN]], [[SUB]]
+  // CHECK: [[PRED2:%.+]] = cmpi slt, [[MAX]], [[SUB]]
+  // CHECK: [[UBOUND:%.+]] = select [[PRED2]], [[MAX]], [[LBOUND]]
+  // CHECK: [[TRUNC:%.+]] = trunci [[UBOUND]]
+  // CHECK: linalg.yield [[TRUNC]]
+  %0 = "tosa.negate"(%arg0) {quantization_info = { input_zp = 0 : i32, output_zp = 0 : i32}} : (tensor<1xi8>) -> tensor<1xi8>
+
+  // CHECK: linalg.generic
+  // CHECK: [[EXT:%.+]] = sexti %arg1 : i8 to i16
+  %1 = "tosa.negate"(%arg0) {quantization_info = { input_zp = 32639 : i32, output_zp = 0 : i32}} : (tensor<1xi8>) -> tensor<1xi8>
+
+  // CHECK: linalg.generic
+  // CHECK: [[EXT:%.+]] = sexti %arg1 : i8 to i32
+  %2 = "tosa.negate"(%arg0) {quantization_info = { input_zp = 32640 : i32, output_zp = 0 : i32}} : (tensor<1xi8>) -> tensor<1xi8>
+
+  return
+}
+
+// -----
+
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-LABEL: @test_reshape_downrank
 func @test_reshape_downrank(%arg0: tensor<2x3xf32>) -> tensor<6xf32> {


        


More information about the Mlir-commits mailing list