[Mlir-commits] [mlir] 2778d9d - [TOSA] tosa.negate operator lowering update (#107924)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Sep 10 02:47:03 PDT 2024


Author: Dmitriy Smirnov
Date: 2024-09-10T10:46:58+01:00
New Revision: 2778d9d63bcb6eb7d58bbab1131b4e711c1f60c4

URL: https://github.com/llvm/llvm-project/commit/2778d9d63bcb6eb7d58bbab1131b4e711c1f60c4
DIFF: https://github.com/llvm/llvm-project/commit/2778d9d63bcb6eb7d58bbab1131b4e711c1f60c4.diff

LOG: [TOSA] tosa.negate operator lowering update (#107924)

This PR makes the integer lowering of the tosa.negate op use the simplified
calculation branch (0 - x) when both input_zp and output_zp are zero, not
only when quantization_info is absent.

Signed-off-by: Dmitriy Smirnov <dmitriy.smirnov at arm.com>

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index ba259d4b84fceb..93e284af051883 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -139,19 +139,22 @@ static Value createLinalgBodyCalculationForElementwiseOp(
   if (isa<tosa::NegateOp>(op) && isa<FloatType>(elementTy))
     return rewriter.create<arith::NegFOp>(loc, resultTypes, args);
 
-  if (isa<tosa::NegateOp>(op) && isa<IntegerType>(elementTy) &&
-      !cast<tosa::NegateOp>(op).getQuantizationInfo()) {
-    auto constant =
-        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
-    return rewriter.create<arith::SubIOp>(loc, resultTypes, constant, args[0]);
-  }
+  if (isa<tosa::NegateOp>(op) && isa<IntegerType>(elementTy)) {
+    int64_t inZp = 0, outZp = 0;
+
+    if (cast<tosa::NegateOp>(op).getQuantizationInfo()) {
+      auto quantizationInfo = cast<tosa::NegateOp>(op).getQuantizationInfo();
+      inZp = quantizationInfo.value().getInputZp();
+      outZp = quantizationInfo.value().getOutputZp();
+    }
 
-  if (isa<tosa::NegateOp>(op) && isa<IntegerType>(elementTy) &&
-      cast<tosa::NegateOp>(op).getQuantizationInfo()) {
-    auto quantizationInfo = cast<tosa::NegateOp>(op).getQuantizationInfo();
     int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
-    int64_t inZp = quantizationInfo.value().getInputZp();
-    int64_t outZp = quantizationInfo.value().getOutputZp();
+    if (!inZp && !outZp) {
+      auto constant = rewriter.create<arith::ConstantOp>(
+          loc, IntegerAttr::get(elementTy, 0));
+      return rewriter.create<arith::SubIOp>(loc, resultTypes, constant,
+                                            args[0]);
+    }
 
     // Compute the maximum value that can occur in the intermediate buffer.
     int64_t zpAdd = inZp + outZp;
@@ -402,17 +405,19 @@ static Value createLinalgBodyCalculationForElementwiseOp(
     if (intTy.isUnsignedInteger()) {
       minRepresentable = 0;
       if (intTy.getIntOrFloatBitWidth() <= 63) {
-        maxRepresentable = (int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
-                          .getZExtValue();
+        maxRepresentable =
+            (int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
+                .getZExtValue();
       }
-    } else if(intTy.getIntOrFloatBitWidth() <= 64) {
+    } else if (intTy.getIntOrFloatBitWidth() <= 64) {
       // Ensure that min & max fit into signed n-bit constants.
       minRepresentable = APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
-                            .getSExtValue();
+                             .getSExtValue();
       maxRepresentable = APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
-                            .getSExtValue();
+                             .getSExtValue();
     }
-    // Ensure that the bounds are representable as n-bit signed/unsigned integers.
+    // Ensure that the bounds are representable as n-bit signed/unsigned
+    // integers.
     min = std::max(min, minRepresentable);
     max = std::max(max, minRepresentable);
     min = std::min(min, maxRepresentable);

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 0e35f8ea9d0cd1..f9d37f9427d4f4 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -857,16 +857,16 @@ func.func @test_bool(%arg0: tensor<1xi1>, %arg1: tensor<1xi1>) -> () {
 func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
   // CHECK: linalg.generic
   // CHECK: ^bb0(%[[BBARG0:.+]]: i8,
-  // CHECK: [[ZERO:%.+]] = arith.constant 0
+  // CHECK: [[CNST:%.+]] = arith.constant 7
   // CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i16
-  // CHECK: [[SUB:%.+]] = arith.subi [[ZERO]], [[EXT]]
+  // CHECK: [[SUB:%.+]] = arith.subi [[CNST]], [[EXT]]
   // CHECK: [[MIN:%.+]] = arith.constant -128
   // CHECK: [[MAX:%.+]] = arith.constant 127
   // CHECK: [[LBOUND:%.+]] = arith.maxsi [[MIN]], [[SUB]]
   // CHECK: [[UBOUND:%.+]] = arith.minsi [[MAX]], [[LBOUND]]
   // CHECK: [[TRUNC:%.+]] = arith.trunci [[UBOUND]]
   // CHECK: linalg.yield [[TRUNC]]
-  %0 = tosa.negate %arg0 {quantization_info = #tosa.unary_quant<input_zp = 0, output_zp = 0>} : (tensor<1xi8>) -> tensor<1xi8>
+  %0 = tosa.negate %arg0 {quantization_info = #tosa.unary_quant<input_zp = 0, output_zp = 7>} : (tensor<1xi8>) -> tensor<1xi8>
 
   // CHECK: linalg.generic
   // CHECK: ^bb0(%[[BBARG0:.+]]: i8,
@@ -878,6 +878,13 @@ func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
   // CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i32
   %2 = tosa.negate %arg0 {quantization_info = #tosa.unary_quant<input_zp = 32640, output_zp = 0>} : (tensor<1xi8>) -> tensor<1xi8>
 
+  // CHECK: linalg.generic
+  // CHECK: ^bb0(%[[BBARG0:.+]]: i8,
+  // CHECK: [[ZERO:%.+]] = arith.constant 0
+  // CHECK: [[SUB:%.+]] = arith.subi [[ZERO]],
+  // CHECK: linalg.yield [[SUB]]
+  %3 = tosa.negate %arg0 {quantization_info = #tosa.unary_quant<input_zp = 0, output_zp = 0>} : (tensor<1xi8>) -> tensor<1xi8>
+
   return
 }
 


        


More information about the Mlir-commits mailing list