[Mlir-commits] [mlir] [TOSA] tosa.negate operator lowering update (PR #107924)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Sep 9 14:54:54 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-tosa
Author: Dmitriy Smirnov (d-smirnov)
<details>
<summary>Changes</summary>
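This PR updates the lowering of the tosa.negate op for integer types so that it also uses the simplified calculation branch (a plain `0 - x` subtraction) when quantization info is present but both the input_zp and output_zp values are zero.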
---
Full diff: https://github.com/llvm/llvm-project/pull/107924.diff
2 Files Affected:
- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+22-17)
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+10-3)
``````````diff
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index ba259d4b84fceb..93e284af051883 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -139,19 +139,22 @@ static Value createLinalgBodyCalculationForElementwiseOp(
if (isa<tosa::NegateOp>(op) && isa<FloatType>(elementTy))
return rewriter.create<arith::NegFOp>(loc, resultTypes, args);
- if (isa<tosa::NegateOp>(op) && isa<IntegerType>(elementTy) &&
- !cast<tosa::NegateOp>(op).getQuantizationInfo()) {
- auto constant =
- rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
- return rewriter.create<arith::SubIOp>(loc, resultTypes, constant, args[0]);
- }
+ if (isa<tosa::NegateOp>(op) && isa<IntegerType>(elementTy)) {
+ int64_t inZp = 0, outZp = 0;
+
+ if (cast<tosa::NegateOp>(op).getQuantizationInfo()) {
+ auto quantizationInfo = cast<tosa::NegateOp>(op).getQuantizationInfo();
+ inZp = quantizationInfo.value().getInputZp();
+ outZp = quantizationInfo.value().getOutputZp();
+ }
- if (isa<tosa::NegateOp>(op) && isa<IntegerType>(elementTy) &&
- cast<tosa::NegateOp>(op).getQuantizationInfo()) {
- auto quantizationInfo = cast<tosa::NegateOp>(op).getQuantizationInfo();
int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
- int64_t inZp = quantizationInfo.value().getInputZp();
- int64_t outZp = quantizationInfo.value().getOutputZp();
+ if (!inZp && !outZp) {
+ auto constant = rewriter.create<arith::ConstantOp>(
+ loc, IntegerAttr::get(elementTy, 0));
+ return rewriter.create<arith::SubIOp>(loc, resultTypes, constant,
+ args[0]);
+ }
// Compute the maximum value that can occur in the intermediate buffer.
int64_t zpAdd = inZp + outZp;
@@ -402,17 +405,19 @@ static Value createLinalgBodyCalculationForElementwiseOp(
if (intTy.isUnsignedInteger()) {
minRepresentable = 0;
if (intTy.getIntOrFloatBitWidth() <= 63) {
- maxRepresentable = (int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
- .getZExtValue();
+ maxRepresentable =
+ (int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
+ .getZExtValue();
}
- } else if(intTy.getIntOrFloatBitWidth() <= 64) {
+ } else if (intTy.getIntOrFloatBitWidth() <= 64) {
// Ensure that min & max fit into signed n-bit constants.
minRepresentable = APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
- .getSExtValue();
+ .getSExtValue();
maxRepresentable = APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
- .getSExtValue();
+ .getSExtValue();
}
- // Ensure that the bounds are representable as n-bit signed/unsigned integers.
+ // Ensure that the bounds are representable as n-bit signed/unsigned
+ // integers.
min = std::max(min, minRepresentable);
max = std::max(max, minRepresentable);
min = std::min(min, maxRepresentable);
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 0e35f8ea9d0cd1..f9d37f9427d4f4 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -857,16 +857,16 @@ func.func @test_bool(%arg0: tensor<1xi1>, %arg1: tensor<1xi1>) -> () {
func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
// CHECK: linalg.generic
// CHECK: ^bb0(%[[BBARG0:.+]]: i8,
- // CHECK: [[ZERO:%.+]] = arith.constant 0
+ // CHECK: [[CNST:%.+]] = arith.constant 7
// CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i16
- // CHECK: [[SUB:%.+]] = arith.subi [[ZERO]], [[EXT]]
+ // CHECK: [[SUB:%.+]] = arith.subi [[CNST]], [[EXT]]
// CHECK: [[MIN:%.+]] = arith.constant -128
// CHECK: [[MAX:%.+]] = arith.constant 127
// CHECK: [[LBOUND:%.+]] = arith.maxsi [[MIN]], [[SUB]]
// CHECK: [[UBOUND:%.+]] = arith.minsi [[MAX]], [[LBOUND]]
// CHECK: [[TRUNC:%.+]] = arith.trunci [[UBOUND]]
// CHECK: linalg.yield [[TRUNC]]
- %0 = tosa.negate %arg0 {quantization_info = #tosa.unary_quant<input_zp = 0, output_zp = 0>} : (tensor<1xi8>) -> tensor<1xi8>
+ %0 = tosa.negate %arg0 {quantization_info = #tosa.unary_quant<input_zp = 0, output_zp = 7>} : (tensor<1xi8>) -> tensor<1xi8>
// CHECK: linalg.generic
// CHECK: ^bb0(%[[BBARG0:.+]]: i8,
@@ -878,6 +878,13 @@ func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
// CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i32
%2 = tosa.negate %arg0 {quantization_info = #tosa.unary_quant<input_zp = 32640, output_zp = 0>} : (tensor<1xi8>) -> tensor<1xi8>
+ // CHECK: linalg.generic
+ // CHECK: ^bb0(%[[BBARG0:.+]]: i8,
+ // CHECK: [[ZERO:%.+]] = arith.constant 0
+ // CHECK: [[SUB:%.+]] = arith.subi [[ZERO]],
+ // CHECK: linalg.yield [[SUB]]
+ %3 = tosa.negate %arg0 {quantization_info = #tosa.unary_quant<input_zp = 0, output_zp = 0>} : (tensor<1xi8>) -> tensor<1xi8>
+
return
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/107924
More information about the Mlir-commits
mailing list