[Mlir-commits] [mlir] bd67b8f - [mlir][tosa] support NegateOp with dynamic extension in TosaToLinalg (#158782)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Sep 22 08:42:10 PDT 2025
Author: ShivaChen
Date: 2025-09-22T16:42:07+01:00
New Revision: bd67b8ff68937371ccc48016f737fdcb381b248e
URL: https://github.com/llvm/llvm-project/commit/bd67b8ff68937371ccc48016f737fdcb381b248e
DIFF: https://github.com/llvm/llvm-project/commit/bd67b8ff68937371ccc48016f737fdcb381b248e.diff
LOG: [mlir][tosa] support NegateOp with dynamic extension in TosaToLinalg (#158782)
Added:
Modified:
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 1955eec9964eb..e3602111cb1dd 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -186,56 +186,63 @@ static Value createLinalgBodyCalculationForElementwiseOp(
if (isa<tosa::NegateOp>(op)) {
auto negate = cast<tosa::NegateOp>(op);
+ int64_t inZp = 0, outZp = 0;
FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
- if (failed(maybeInZp)) {
- (void)rewriter.notifyMatchFailure(
- op, "input1 zero point cannot be statically determined");
- return nullptr;
- }
-
FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
- if (failed(maybeOutZp)) {
- (void)rewriter.notifyMatchFailure(
- op, "output zero point cannot be statically determined");
- return nullptr;
- }
-
- int64_t inZp = *maybeInZp;
- int64_t outZp = *maybeOutZp;
+ bool hasInZp = !failed(maybeInZp);
+ bool hasOutZp = !failed(maybeOutZp);
+ if (hasInZp)
+ inZp = *maybeInZp;
+ if (hasOutZp)
+ outZp = *maybeOutZp;
if (isa<FloatType>(elementTy))
return arith::NegFOp::create(rewriter, loc, resultTypes, args[0]);
if (isa<IntegerType>(elementTy)) {
- if (!inZp && !outZp) {
+ if (hasInZp && hasOutZp && !inZp && !outZp) {
auto constant = arith::ConstantOp::create(
rewriter, loc, IntegerAttr::get(elementTy, 0));
return arith::SubIOp::create(rewriter, loc, resultTypes, constant,
args[0]);
}
+ Value zpAddValue;
+ Type intermediateType;
// Compute the maximum value that can occur in the intermediate buffer.
const int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
- const int64_t zpAdd = inZp + outZp;
- const int64_t maxValue =
- APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
- std::abs(zpAdd) + 1;
-
- // Convert that maximum value into the maximum bitwidth needed to
- // represent it. We assume 48-bit numbers may be supported further in
- // the pipeline.
int intermediateBitWidth = 64;
- if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
- intermediateBitWidth = 16;
- } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
- intermediateBitWidth = 32;
- } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
- intermediateBitWidth = 48;
- }
- Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
- Value zpAddValue = arith::ConstantOp::create(
- rewriter, loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
+ if (hasInZp && hasOutZp) {
+ // Compute the maximum value that can occur in the intermediate buffer.
+ const int64_t zpAdd = inZp + outZp;
+ const int64_t maxValue =
+ APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
+ std::abs(zpAdd) + 1;
+
+ // Convert that maximum value into the maximum bitwidth needed to
+ // represent it. We assume 48-bit numbers may be supported further in
+ // the pipeline.
+ if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
+ intermediateBitWidth = 16;
+ } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
+ intermediateBitWidth = 32;
+ } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
+ intermediateBitWidth = 48;
+ }
+
+ intermediateType = rewriter.getIntegerType(intermediateBitWidth);
+ zpAddValue = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
+ } else {
+ intermediateType = rewriter.getIntegerType(intermediateBitWidth);
+ auto arg1 =
+ rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[1]);
+ auto arg2 =
+ rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[2]);
+ zpAddValue =
+ rewriter.create<arith::AddIOp>(loc, intermediateType, arg1, arg2);
+ }
// The negation can be applied by doing:
// outputValue = inZp + outZp - inputValue
@@ -1013,9 +1020,14 @@ static ValueRange getBroadcastableOperands(Operation *operation,
else
return operands.take_front(3);
}
- // Input1_zp and output_zp cannot broadcast
- if (isa<tosa::NegateOp>(operation))
+ if (auto negate = dyn_cast<tosa::NegateOp>(operation)) {
+ FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
+ FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
+ if (failed(maybeOutZp) && failed(maybeInZp))
+ return operands;
+ // Input1_zp and output_zp cannot broadcast when they are constants.
return operands.take_front(1);
+ }
return operands;
}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 37af8b8859852..2163dbb0d4561 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -899,6 +899,39 @@ func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
// -----
+// CHECK-LABEL: @test_negate_no_const_1
+func.func @test_negate_no_const_1(%arg0: tensor<50x42xf16> ,%arg1: tensor<1xf16> , %arg2: tensor<1xf16> ) -> tensor<*xf16> {
+ // CHECK: %[[GENERIC:.+]] = linalg.generic
+ // CHECK: ^bb0([[ARG0:%.*]]: f16, [[ARG1:%.*]]: f16, [[ARG2:%.*]]: f16, [[OUT:%.*]]: f16)
+ // CHECK: [[ELEMENT:%.*]] = arith.negf [[ARG0]] : f16
+ %0 = tosa.negate %arg0, %arg1, %arg2 : (tensor<50x42xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<50x42xf16>
+ %cast = tensor.cast %0 : tensor<50x42xf16> to tensor<*xf16>
+ return %cast : tensor<*xf16>
+}
+
+// -----
+
+// CHECK-LABEL: @test_negate_no_const_2
+func.func @test_negate_no_const_2(%arg0: tensor<50x42xi16> ,%arg1: tensor<1xi16> , %arg2: tensor<1xi16> ) -> tensor<*xi16> {
+ // CHECK: %[[GENERIC:.+]] = linalg.generic
+ // CHECK: ^bb0([[ARG0:%.*]]: i16, [[ARG1:%.*]]: i16, [[ARG2:%.*]]: i16, [[OUT:%.*]]: i16)
+ // CHECK: [[EXTSI1:%.*]] = arith.extsi [[ARG1]] : i16 to i64
+ // CHECK: [[EXTSI2:%.*]] = arith.extsi [[ARG2]] : i16 to i64
+ // CHECK: [[SUM:%.*]] = arith.addi [[EXTSI1]], [[EXTSI2]] : i64
+ // CHECK: [[EXTSI0:%.*]] = arith.extsi [[ARG0]] : i16 to i64
+ // CHECK: [[SUB:%.*]] = arith.subi [[SUM]], [[EXTSI0]] : i64
+ // CHECK: [[C_32768:%.*]] = arith.constant -32768 : i64
+ // CHECK: [[C32767:%.*]] = arith.constant 32767 : i64
+ // CHECK: [[MAX:%.*]] = arith.maxsi [[C_32768]], [[SUB]] : i64
+ // CHECK: [[MIN:%.*]] = arith.minsi [[C32767]], [[MAX]] : i64
+ // CHECK: [[TRUNC:%.*]] = arith.trunci [[MIN]] : i64 to i16
+ %0 = tosa.negate %arg0, %arg1, %arg2 : (tensor<50x42xi16>, tensor<1xi16>, tensor<1xi16>) -> tensor<50x42xi16>
+ %cast = tensor.cast %0 : tensor<50x42xi16> to tensor<*xi16>
+ return %cast : tensor<*xi16>
+}
+
+// -----
+
// CHECK-LABEL: @test_identity
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]: tensor<1xf32>,
// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]: tensor<1xi32>
More information about the Mlir-commits mailing list