[Mlir-commits] [mlir] [mlir][tosa] support NegateOp with dynamic extension in TosaToLinalg (PR #158782)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Sep 21 18:17:07 PDT 2025
https://github.com/ShivaChen updated https://github.com/llvm/llvm-project/pull/158782
>From 60372437bdc11b6c595eb1cefe7972cf54011a41 Mon Sep 17 00:00:00 2001
From: Shiva Chen <shiva.chen at imgtec.com>
Date: Fri, 29 Aug 2025 10:08:07 +0100
Subject: [PATCH 1/3] [mlir][tosa] support NegateOp with dynamic extension in TosaToLinalg
---
.../Conversion/TosaToLinalg/TosaToLinalg.cpp | 80 +++++++++++--------
.../TosaToLinalg/tosa-to-linalg.mlir | 31 +++++++
2 files changed, 76 insertions(+), 35 deletions(-)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 1955eec9964eb..91e0f235349f0 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -186,56 +186,61 @@ static Value createLinalgBodyCalculationForElementwiseOp(
if (isa<tosa::NegateOp>(op)) {
auto negate = cast<tosa::NegateOp>(op);
+ int64_t inZp = 0, outZp = 0;
FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
- if (failed(maybeInZp)) {
- (void)rewriter.notifyMatchFailure(
- op, "input1 zero point cannot be statically determined");
- return nullptr;
- }
-
FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
- if (failed(maybeOutZp)) {
- (void)rewriter.notifyMatchFailure(
- op, "output zero point cannot be statically determined");
- return nullptr;
- }
-
- int64_t inZp = *maybeInZp;
- int64_t outZp = *maybeOutZp;
+ if (!failed(maybeInZp))
+ inZp = *maybeInZp;
+ if (!failed(maybeOutZp))
+ outZp = *maybeOutZp;
if (isa<FloatType>(elementTy))
return arith::NegFOp::create(rewriter, loc, resultTypes, args[0]);
if (isa<IntegerType>(elementTy)) {
- if (!inZp && !outZp) {
+ if (!failed(maybeInZp) && !failed(maybeOutZp) && !inZp && !outZp) {
auto constant = arith::ConstantOp::create(
rewriter, loc, IntegerAttr::get(elementTy, 0));
return arith::SubIOp::create(rewriter, loc, resultTypes, constant,
args[0]);
}
+ Value zpAddValue;
+ Type intermediateType;
// Compute the maximum value that can occur in the intermediate buffer.
const int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
- const int64_t zpAdd = inZp + outZp;
- const int64_t maxValue =
- APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
- std::abs(zpAdd) + 1;
-
- // Convert that maximum value into the maximum bitwidth needed to
- // represent it. We assume 48-bit numbers may be supported further in
- // the pipeline.
int intermediateBitWidth = 64;
- if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
- intermediateBitWidth = 16;
- } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
- intermediateBitWidth = 32;
- } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
- intermediateBitWidth = 48;
- }
- Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
- Value zpAddValue = arith::ConstantOp::create(
- rewriter, loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
+ if (!failed(maybeInZp) && !failed(maybeOutZp)) {
+ // Compute the maximum value that can occur in the intermediate buffer.
+ const int64_t zpAdd = inZp + outZp;
+ const int64_t maxValue =
+ APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
+ std::abs(zpAdd) + 1;
+
+ // Convert that maximum value into the maximum bitwidth needed to
+ // represent it. We assume 48-bit numbers may be supported further in
+ // the pipeline.
+ if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
+ intermediateBitWidth = 16;
+ } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
+ intermediateBitWidth = 32;
+ } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
+ intermediateBitWidth = 48;
+ }
+
+ intermediateType = rewriter.getIntegerType(intermediateBitWidth);
+ zpAddValue = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
+ } else {
+ intermediateType = rewriter.getIntegerType(intermediateBitWidth);
+ auto arg1 =
+ rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[1]);
+ auto arg2 =
+ rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[2]);
+ zpAddValue =
+ rewriter.create<arith::AddIOp>(loc, intermediateType, arg1, arg2);
+ }
// The negation can be applied by doing:
// outputValue = inZp + outZp - inputValue
@@ -1013,9 +1018,14 @@ static ValueRange getBroadcastableOperands(Operation *operation,
else
return operands.take_front(3);
}
- // Input1_zp and output_zp cannot broadcast
- if (isa<tosa::NegateOp>(operation))
+ if (auto negate = dyn_cast<tosa::NegateOp>(operation)) {
+ FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
+ FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
+ if (failed(maybeOutZp) && failed(maybeInZp))
+ return operands;
+ // Input1_zp and output_zp cannot broadcast when they are constants.
return operands.take_front(1);
+ }
return operands;
}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 37af8b8859852..780344764e014 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -899,6 +899,37 @@ func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
// -----
+func.func @negate_no_const_1(%arg0: tensor<50x42xf16> ,%arg1: tensor<1xf16> , %arg2: tensor<1xf16> ) -> tensor<*xf16> {
+ // CHECK: %[[GENERIC:.+]] = linalg.generic
+ // CHECK: ^bb0([[ARG0:%.*]]: f16, [[ARG1:%.*]]: f16, [[ARG2:%.*]]: f16, [[OUT:%.*]]: f16)
+ // CHECK: [[ELEMENT:%.*]] = arith.negf [[ARG0]] : f16
+ %0 = tosa.negate %arg0, %arg1, %arg2 : (tensor<50x42xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<50x42xf16>
+ %cast = tensor.cast %0 : tensor<50x42xf16> to tensor<*xf16>
+ return %cast : tensor<*xf16>
+}
+
+// -----
+
+func.func @negate_no_const_2(%arg0: tensor<50x42xi16> ,%arg1: tensor<1xi16> , %arg2: tensor<1xi16> ) -> tensor<*xi16> {
+ // CHECK: %[[GENERIC:.+]] = linalg.generic
+ // CHECK: ^bb0([[ARG0:%.*]]: i16, [[ARG1:%.*]]: i16, [[ARG2:%.*]]: i16, [[OUT:%.*]]: i16)
+ // CHECK: [[EXTSI1:%.*]] = arith.extsi [[ARG1]] : i16 to i64
+ // CHECK: [[EXTSI2:%.*]] = arith.extsi [[ARG2]] : i16 to i64
+ // CHECK: [[SUM:%.*]] = arith.addi [[EXTSI1]], [[EXTSI2]] : i64
+ // CHECK: [[EXTSI0:%.*]] = arith.extsi [[ARG0]] : i16 to i64
+ // CHECK: [[SUB:%.*]] = arith.subi [[SUM]], [[EXTSI0]] : i64
+ // CHECK: [[C_32768:%.*]] = arith.constant -32768 : i64
+ // CHECK: [[C32767:%.*]] = arith.constant 32767 : i64
+ // CHECK: [[MAX:%.*]] = arith.maxsi [[C_32768]], [[SUB]] : i64
+ // CHECK: [[MIN:%.*]] = arith.minsi [[C32767]], [[MAX]] : i64
+ // CHECK: [[TRUNC:%.*]] = arith.trunci [[MIN]] : i64 to i16
+ %0 = tosa.negate %arg0, %arg1, %arg2 : (tensor<50x42xi16>, tensor<1xi16>, tensor<1xi16>) -> tensor<50x42xi16>
+ %cast = tensor.cast %0 : tensor<50x42xi16> to tensor<*xi16>
+ return %cast : tensor<*xi16>
+}
+
+// -----
+
// CHECK-LABEL: @test_identity
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]: tensor<1xf32>,
// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]: tensor<1xi32>
>From 873292f42550f78c8108effb1b6d52dbea658247 Mon Sep 17 00:00:00 2001
From: Shiva Chen <shiva.chen at imgtec.com>
Date: Mon, 22 Sep 2025 01:56:05 +0100
Subject: [PATCH 2/3] Define hasInZp and hasOutZp
---
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp | 8 +++++---
1 file changed, 5 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 91e0f235349f0..d8016eb09efb3 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -189,9 +189,11 @@ static Value createLinalgBodyCalculationForElementwiseOp(
int64_t inZp = 0, outZp = 0;
FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
- if (!failed(maybeInZp))
+ bool hasInZp = !failed(maybeInZp);
+ bool hasOutZp = !failed(maybeOutZp);
+ if (hasInZp)
inZp = *maybeInZp;
- if (!failed(maybeOutZp))
+ if (hasOutZp)
outZp = *maybeOutZp;
if (isa<FloatType>(elementTy))
@@ -211,7 +213,7 @@ static Value createLinalgBodyCalculationForElementwiseOp(
const int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
int intermediateBitWidth = 64;
- if (!failed(maybeInZp) && !failed(maybeOutZp)) {
+ if (hasInZp && hasOutZp) {
// Compute the maximum value that can occur in the intermediate buffer.
const int64_t zpAdd = inZp + outZp;
const int64_t maxValue =
>From b5dff7579c34ef46f680c6df187ad900a2c6fec3 Mon Sep 17 00:00:00 2001
From: Shiva Chen <shiva.chen at imgtec.com>
Date: Mon, 22 Sep 2025 02:03:13 +0100
Subject: [PATCH 3/3] Add CHECK-LABEL in test cases
---
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 780344764e014..2163dbb0d4561 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -899,7 +899,8 @@ func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
// -----
-func.func @negate_no_const_1(%arg0: tensor<50x42xf16> ,%arg1: tensor<1xf16> , %arg2: tensor<1xf16> ) -> tensor<*xf16> {
+// CHECK-LABEL: @test_negate_no_const_1
+func.func @test_negate_no_const_1(%arg0: tensor<50x42xf16> ,%arg1: tensor<1xf16> , %arg2: tensor<1xf16> ) -> tensor<*xf16> {
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK: ^bb0([[ARG0:%.*]]: f16, [[ARG1:%.*]]: f16, [[ARG2:%.*]]: f16, [[OUT:%.*]]: f16)
// CHECK: [[ELEMENT:%.*]] = arith.negf [[ARG0]] : f16
@@ -910,7 +911,8 @@ func.func @negate_no_const_1(%arg0: tensor<50x42xf16> ,%arg1: tensor<1xf16> , %a
// -----
-func.func @negate_no_const_2(%arg0: tensor<50x42xi16> ,%arg1: tensor<1xi16> , %arg2: tensor<1xi16> ) -> tensor<*xi16> {
+// CHECK-LABEL: @test_negate_no_const_2
+func.func @test_negate_no_const_2(%arg0: tensor<50x42xi16> ,%arg1: tensor<1xi16> , %arg2: tensor<1xi16> ) -> tensor<*xi16> {
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK: ^bb0([[ARG0:%.*]]: i16, [[ARG1:%.*]]: i16, [[ARG2:%.*]]: i16, [[OUT:%.*]]: i16)
// CHECK: [[EXTSI1:%.*]] = arith.extsi [[ARG1]] : i16 to i64
More information about the Mlir-commits mailing list