[Mlir-commits] [mlir] bd67b8f - [mlir][tosa] support NegateOp with dynamic extension in TosaToLinalg (#158782)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Sep 22 08:42:10 PDT 2025


Author: ShivaChen
Date: 2025-09-22T16:42:07+01:00
New Revision: bd67b8ff68937371ccc48016f737fdcb381b248e

URL: https://github.com/llvm/llvm-project/commit/bd67b8ff68937371ccc48016f737fdcb381b248e
DIFF: https://github.com/llvm/llvm-project/commit/bd67b8ff68937371ccc48016f737fdcb381b248e.diff

LOG: [mlir][tosa] support NegateOp with dynamic extension in TosaToLinalg (#158782)

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 1955eec9964eb..e3602111cb1dd 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -186,56 +186,63 @@ static Value createLinalgBodyCalculationForElementwiseOp(
   if (isa<tosa::NegateOp>(op)) {
     auto negate = cast<tosa::NegateOp>(op);
 
+    int64_t inZp = 0, outZp = 0;
     FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
-    if (failed(maybeInZp)) {
-      (void)rewriter.notifyMatchFailure(
-          op, "input1 zero point cannot be statically determined");
-      return nullptr;
-    }
-
     FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
-    if (failed(maybeOutZp)) {
-      (void)rewriter.notifyMatchFailure(
-          op, "output zero point cannot be statically determined");
-      return nullptr;
-    }
-
-    int64_t inZp = *maybeInZp;
-    int64_t outZp = *maybeOutZp;
+    bool hasInZp = !failed(maybeInZp);
+    bool hasOutZp = !failed(maybeOutZp);
+    if (hasInZp)
+      inZp = *maybeInZp;
+    if (hasOutZp)
+      outZp = *maybeOutZp;
 
     if (isa<FloatType>(elementTy))
       return arith::NegFOp::create(rewriter, loc, resultTypes, args[0]);
 
     if (isa<IntegerType>(elementTy)) {
-      if (!inZp && !outZp) {
+      if (hasInZp && hasOutZp && !inZp && !outZp) {
         auto constant = arith::ConstantOp::create(
             rewriter, loc, IntegerAttr::get(elementTy, 0));
         return arith::SubIOp::create(rewriter, loc, resultTypes, constant,
                                      args[0]);
       }
 
+      Value zpAddValue;
+      Type intermediateType;
       // Compute the maximum value that can occur in the intermediate buffer.
       const int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
-      const int64_t zpAdd = inZp + outZp;
-      const int64_t maxValue =
-          APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
-          std::abs(zpAdd) + 1;
-
-      // Convert that maximum value into the maximum bitwidth needed to
-      // represent it. We assume 48-bit numbers may be supported further in
-      // the pipeline.
       int intermediateBitWidth = 64;
-      if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
-        intermediateBitWidth = 16;
-      } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
-        intermediateBitWidth = 32;
-      } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
-        intermediateBitWidth = 48;
-      }
 
-      Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
-      Value zpAddValue = arith::ConstantOp::create(
-          rewriter, loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
+      if (hasInZp && hasOutZp) {
+        // Compute the maximum value that can occur in the intermediate buffer.
+        const int64_t zpAdd = inZp + outZp;
+        const int64_t maxValue =
+            APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
+            std::abs(zpAdd) + 1;
+
+        // Convert that maximum value into the maximum bitwidth needed to
+        // represent it. We assume 48-bit numbers may be supported further in
+        // the pipeline.
+        if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
+          intermediateBitWidth = 16;
+        } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
+          intermediateBitWidth = 32;
+        } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
+          intermediateBitWidth = 48;
+        }
+
+        intermediateType = rewriter.getIntegerType(intermediateBitWidth);
+        zpAddValue = rewriter.create<arith::ConstantOp>(
+            loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
+      } else {
+        intermediateType = rewriter.getIntegerType(intermediateBitWidth);
+        auto arg1 =
+            rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[1]);
+        auto arg2 =
+            rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[2]);
+        zpAddValue =
+            rewriter.create<arith::AddIOp>(loc, intermediateType, arg1, arg2);
+      }
 
       // The negation can be applied by doing:
       //  outputValue = inZp + outZp - inputValue
@@ -1013,9 +1020,14 @@ static ValueRange getBroadcastableOperands(Operation *operation,
     else
       return operands.take_front(3);
   }
-  // Input1_zp and output_zp cannot broadcast
-  if (isa<tosa::NegateOp>(operation))
+  if (auto negate = dyn_cast<tosa::NegateOp>(operation)) {
+    FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
+    FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
+    if (failed(maybeOutZp) && failed(maybeInZp))
+      return operands;
+    // Input1_zp and output_zp cannot broadcast when they are constants.
     return operands.take_front(1);
+  }
   return operands;
 }
 

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 37af8b8859852..2163dbb0d4561 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -899,6 +899,39 @@ func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
 
 // -----
 
+// CHECK-LABEL: @test_negate_no_const_1
+func.func @test_negate_no_const_1(%input: tensor<50x42xf16>, %input_zp: tensor<1xf16>, %output_zp: tensor<1xf16>) -> tensor<*xf16> {
+  // CHECK: %[[GENERIC:.+]] = linalg.generic 
+  // CHECK:   ^bb0([[ARG0:%.*]]: f16, [[ARG1:%.*]]: f16, [[ARG2:%.*]]: f16, [[OUT:%.*]]: f16)
+  // CHECK:   [[ELEMENT:%.*]] = arith.negf [[ARG0]] : f16
+  %negated = tosa.negate %input, %input_zp, %output_zp : (tensor<50x42xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<50x42xf16>
+  %unranked = tensor.cast %negated : tensor<50x42xf16> to tensor<*xf16>
+  return %unranked : tensor<*xf16>
+}
+
+// -----
+
+// CHECK-LABEL: @test_negate_no_const_2
+func.func @test_negate_no_const_2(%input: tensor<50x42xi16>, %input_zp: tensor<1xi16>, %output_zp: tensor<1xi16>) -> tensor<*xi16> {
+  // CHECK: %[[GENERIC:.+]] = linalg.generic 
+  // CHECK:   ^bb0([[ARG0:%.*]]: i16, [[ARG1:%.*]]: i16, [[ARG2:%.*]]: i16, [[OUT:%.*]]: i16)
+  // CHECK:   [[EXTSI1:%.*]] = arith.extsi [[ARG1]] : i16 to i64
+  // CHECK:   [[EXTSI2:%.*]] = arith.extsi [[ARG2]] : i16 to i64
+  // CHECK:   [[SUM:%.*]] = arith.addi [[EXTSI1]], [[EXTSI2]] : i64
+  // CHECK:   [[EXTSI0:%.*]] = arith.extsi [[ARG0]] : i16 to i64
+  // CHECK:   [[SUB:%.*]] = arith.subi [[SUM]], [[EXTSI0]] : i64
+  // CHECK:   [[C_32768:%.*]] = arith.constant -32768 : i64
+  // CHECK:   [[C32767:%.*]] = arith.constant 32767 : i64
+  // CHECK:   [[MAX:%.*]] = arith.maxsi [[C_32768]], [[SUB]] : i64
+  // CHECK:   [[MIN:%.*]] = arith.minsi [[C32767]], [[MAX]] : i64
+  // CHECK:   [[TRUNC:%.*]] = arith.trunci [[MIN]] : i64 to i16
+  %negated = tosa.negate %input, %input_zp, %output_zp : (tensor<50x42xi16>, tensor<1xi16>, tensor<1xi16>) -> tensor<50x42xi16>
+  %unranked = tensor.cast %negated : tensor<50x42xi16> to tensor<*xi16>
+  return %unranked : tensor<*xi16>
+}
+
+// -----
+
 // CHECK-LABEL: @test_identity
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]: tensor<1xf32>,
 // CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]: tensor<1xi32>


        


More information about the Mlir-commits mailing list