[Mlir-commits] [mlir] [mlir][tosa] support NegateOp with dynamic extension in TosaToLinalg (PR #158782)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Sep 21 18:17:07 PDT 2025


https://github.com/ShivaChen updated https://github.com/llvm/llvm-project/pull/158782

>From 60372437bdc11b6c595eb1cefe7972cf54011a41 Mon Sep 17 00:00:00 2001
From: Shiva Chen <shiva.chen at imgtec.com>
Date: Fri, 29 Aug 2025 10:08:07 +0100
Subject: [PATCH 1/3] [mlir][tosa] support NegateOp with dynamic extension in
 TosaToLinalg

---
 .../Conversion/TosaToLinalg/TosaToLinalg.cpp  | 80 +++++++++++--------
 .../TosaToLinalg/tosa-to-linalg.mlir          | 31 +++++++
 2 files changed, 76 insertions(+), 35 deletions(-)

diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 1955eec9964eb..91e0f235349f0 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -186,56 +186,61 @@ static Value createLinalgBodyCalculationForElementwiseOp(
   if (isa<tosa::NegateOp>(op)) {
     auto negate = cast<tosa::NegateOp>(op);
 
+    int64_t inZp = 0, outZp = 0;
     FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
-    if (failed(maybeInZp)) {
-      (void)rewriter.notifyMatchFailure(
-          op, "input1 zero point cannot be statically determined");
-      return nullptr;
-    }
-
     FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
-    if (failed(maybeOutZp)) {
-      (void)rewriter.notifyMatchFailure(
-          op, "output zero point cannot be statically determined");
-      return nullptr;
-    }
-
-    int64_t inZp = *maybeInZp;
-    int64_t outZp = *maybeOutZp;
+    if (!failed(maybeInZp))
+      inZp = *maybeInZp;
+    if (!failed(maybeOutZp))
+      outZp = *maybeOutZp;
 
     if (isa<FloatType>(elementTy))
       return arith::NegFOp::create(rewriter, loc, resultTypes, args[0]);
 
     if (isa<IntegerType>(elementTy)) {
-      if (!inZp && !outZp) {
+      if (!failed(maybeInZp) && !failed(maybeOutZp) && !inZp && !outZp) {
         auto constant = arith::ConstantOp::create(
             rewriter, loc, IntegerAttr::get(elementTy, 0));
         return arith::SubIOp::create(rewriter, loc, resultTypes, constant,
                                      args[0]);
       }
 
+      Value zpAddValue;
+      Type intermediateType;
       // Compute the maximum value that can occur in the intermediate buffer.
       const int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
-      const int64_t zpAdd = inZp + outZp;
-      const int64_t maxValue =
-          APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
-          std::abs(zpAdd) + 1;
-
-      // Convert that maximum value into the maximum bitwidth needed to
-      // represent it. We assume 48-bit numbers may be supported further in
-      // the pipeline.
       int intermediateBitWidth = 64;
-      if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
-        intermediateBitWidth = 16;
-      } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
-        intermediateBitWidth = 32;
-      } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
-        intermediateBitWidth = 48;
-      }
 
-      Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
-      Value zpAddValue = arith::ConstantOp::create(
-          rewriter, loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
+      if (!failed(maybeInZp) && !failed(maybeOutZp)) {
+        // Compute the maximum value that can occur in the intermediate buffer.
+        const int64_t zpAdd = inZp + outZp;
+        const int64_t maxValue =
+            APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
+            std::abs(zpAdd) + 1;
+
+        // Convert that maximum value into the maximum bitwidth needed to
+        // represent it. We assume 48-bit numbers may be supported further in
+        // the pipeline.
+        if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
+          intermediateBitWidth = 16;
+        } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
+          intermediateBitWidth = 32;
+        } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
+          intermediateBitWidth = 48;
+        }
+
+        intermediateType = rewriter.getIntegerType(intermediateBitWidth);
+        zpAddValue = rewriter.create<arith::ConstantOp>(
+            loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
+      } else {
+        intermediateType = rewriter.getIntegerType(intermediateBitWidth);
+        auto arg1 =
+            rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[1]);
+        auto arg2 =
+            rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[2]);
+        zpAddValue =
+            rewriter.create<arith::AddIOp>(loc, intermediateType, arg1, arg2);
+      }
 
       // The negation can be applied by doing:
       //  outputValue = inZp + outZp - inputValue
@@ -1013,9 +1018,14 @@ static ValueRange getBroadcastableOperands(Operation *operation,
     else
       return operands.take_front(3);
   }
-  // Input1_zp and output_zp cannot broadcast
-  if (isa<tosa::NegateOp>(operation))
+  if (auto negate = dyn_cast<tosa::NegateOp>(operation)) {
+    FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
+    FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
+    if (failed(maybeOutZp) && failed(maybeInZp))
+      return operands;
+    // Input1_zp and output_zp cannot broadcast when they are constants.
     return operands.take_front(1);
+  }
   return operands;
 }
 
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 37af8b8859852..780344764e014 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -899,6 +899,37 @@ func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
 
 // -----
 
+func.func @negate_no_const_1(%arg0: tensor<50x42xf16> ,%arg1: tensor<1xf16> , %arg2: tensor<1xf16> ) -> tensor<*xf16> {
+  // CHECK: %[[GENERIC:.+]] = linalg.generic 
+  // CHECK:   ^bb0([[ARG0:%.*]]: f16, [[ARG1:%.*]]: f16, [[ARG2:%.*]]: f16, [[OUT:%.*]]: f16)
+  // CHECK:   [[ELEMENT:%.*]] = arith.negf [[ARG0]] : f16
+  %0 = tosa.negate %arg0, %arg1, %arg2 : (tensor<50x42xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<50x42xf16>
+  %cast = tensor.cast %0 : tensor<50x42xf16> to tensor<*xf16>
+  return %cast : tensor<*xf16>
+}
+
+// -----
+
+func.func @negate_no_const_2(%arg0: tensor<50x42xi16> ,%arg1: tensor<1xi16> , %arg2: tensor<1xi16> ) -> tensor<*xi16> {
+  // CHECK: %[[GENERIC:.+]] = linalg.generic 
+  // CHECK:   ^bb0([[ARG0:%.*]]: i16, [[ARG1:%.*]]: i16, [[ARG2:%.*]]: i16, [[OUT:%.*]]: i16)
+  // CHECK:   [[EXTSI1:%.*]] = arith.extsi [[ARG1]] : i16 to i64
+  // CHECK:   [[EXTSI2:%.*]] = arith.extsi [[ARG2]] : i16 to i64
+  // CHECK:   [[SUM:%.*]] = arith.addi [[EXTSI1]], [[EXTSI2]] : i64
+  // CHECK:   [[EXTSI0:%.*]] = arith.extsi [[ARG0]] : i16 to i64
+  // CHECK:   [[SUB:%.*]] = arith.subi [[SUM]], [[EXTSI0]] : i64
+  // CHECK:   [[C_32768:%.*]] = arith.constant -32768 : i64
+  // CHECK:   [[C32767:%.*]] = arith.constant 32767 : i64
+  // CHECK:   [[MAX:%.*]] = arith.maxsi [[C_32768]], [[SUB]] : i64
+  // CHECK:   [[MIN:%.*]] = arith.minsi [[C32767]], [[MAX]] : i64
+  // CHECK:   [[TRUNC:%.*]] = arith.trunci [[MIN]] : i64 to i16
+  %0 = tosa.negate %arg0, %arg1, %arg2 : (tensor<50x42xi16>, tensor<1xi16>, tensor<1xi16>) -> tensor<50x42xi16>
+  %cast = tensor.cast %0 : tensor<50x42xi16> to tensor<*xi16>
+  return %cast : tensor<*xi16>
+}
+
+// -----
+
 // CHECK-LABEL: @test_identity
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]: tensor<1xf32>,
 // CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]: tensor<1xi32>

>From 873292f42550f78c8108effb1b6d52dbea658247 Mon Sep 17 00:00:00 2001
From: Shiva Chen <shiva.chen at imgtec.com>
Date: Mon, 22 Sep 2025 01:56:05 +0100
Subject: [PATCH 2/3] Define hasInZp and hasOutZp

---
 mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 91e0f235349f0..d8016eb09efb3 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -189,9 +189,11 @@ static Value createLinalgBodyCalculationForElementwiseOp(
     int64_t inZp = 0, outZp = 0;
     FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
     FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
-    if (!failed(maybeInZp))
+    bool hasInZp = !failed(maybeInZp);
+    bool hasOutZp = !failed(maybeOutZp);
+    if (hasInZp)
       inZp = *maybeInZp;
-    if (!failed(maybeOutZp))
+    if (hasOutZp)
       outZp = *maybeOutZp;
 
     if (isa<FloatType>(elementTy))
@@ -211,7 +213,7 @@ static Value createLinalgBodyCalculationForElementwiseOp(
       const int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
       int intermediateBitWidth = 64;
 
-      if (!failed(maybeInZp) && !failed(maybeOutZp)) {
+      if (hasInZp && hasOutZp) {
         // Compute the maximum value that can occur in the intermediate buffer.
         const int64_t zpAdd = inZp + outZp;
         const int64_t maxValue =

>From b5dff7579c34ef46f680c6df187ad900a2c6fec3 Mon Sep 17 00:00:00 2001
From: Shiva Chen <shiva.chen at imgtec.com>
Date: Mon, 22 Sep 2025 02:03:13 +0100
Subject: [PATCH 3/3] Add CHECK-LABEL in test cases

---
 mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 780344764e014..2163dbb0d4561 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -899,7 +899,8 @@ func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
 
 // -----
 
-func.func @negate_no_const_1(%arg0: tensor<50x42xf16> ,%arg1: tensor<1xf16> , %arg2: tensor<1xf16> ) -> tensor<*xf16> {
+// CHECK-LABEL: @test_negate_no_const_1
+func.func @test_negate_no_const_1(%arg0: tensor<50x42xf16> ,%arg1: tensor<1xf16> , %arg2: tensor<1xf16> ) -> tensor<*xf16> {
   // CHECK: %[[GENERIC:.+]] = linalg.generic 
   // CHECK:   ^bb0([[ARG0:%.*]]: f16, [[ARG1:%.*]]: f16, [[ARG2:%.*]]: f16, [[OUT:%.*]]: f16)
   // CHECK:   [[ELEMENT:%.*]] = arith.negf [[ARG0]] : f16
@@ -910,7 +911,8 @@ func.func @negate_no_const_1(%arg0: tensor<50x42xf16> ,%arg1: tensor<1xf16> , %a
 
 // -----
 
-func.func @negate_no_const_2(%arg0: tensor<50x42xi16> ,%arg1: tensor<1xi16> , %arg2: tensor<1xi16> ) -> tensor<*xi16> {
+// CHECK-LABEL: @test_negate_no_const_2
+func.func @test_negate_no_const_2(%arg0: tensor<50x42xi16> ,%arg1: tensor<1xi16> , %arg2: tensor<1xi16> ) -> tensor<*xi16> {
   // CHECK: %[[GENERIC:.+]] = linalg.generic 
   // CHECK:   ^bb0([[ARG0:%.*]]: i16, [[ARG1:%.*]]: i16, [[ARG2:%.*]]: i16, [[OUT:%.*]]: i16)
   // CHECK:   [[EXTSI1:%.*]] = arith.extsi [[ARG1]] : i16 to i64



More information about the Mlir-commits mailing list