[Mlir-commits] [mlir] [mlir][tosa] Switch zero point of negate to input variable type (PR #129758)

Luke Hutton llvmlistbot at llvm.org
Mon Mar 10 07:42:17 PDT 2025


https://github.com/lhutton1 updated https://github.com/llvm/llvm-project/pull/129758

>From 56aec5788295f83816e3ce5643f7cf7bb7b724bb Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Wed, 27 Nov 2024 10:50:11 +0000
Subject: [PATCH] [mlir][tosa] Switch zero point of negate to input variable
 type

This commit changes the input1 and output zero points of tosa.negate from
optional attributes to inputs (operands) to align with the TOSA 1.0 spec.

Change-Id: Ibc9e5959b36c182a9e0c5c23a2f9d42a572a1184
Signed-off-by: Tai Ly <tai.ly at arm.com>
---
 .../Dialect/Tosa/IR/TosaComplianceData.h.inc  |  10 +-
 .../mlir/Dialect/Tosa/IR/TosaOpBase.td        |   6 +-
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td  |  21 +++-
 .../Conversion/TosaToLinalg/TosaToLinalg.cpp  |  29 +++--
 .../Dialect/Tosa/IR/ShardingInterfaceImpl.cpp |  42 +++++++-
 .../Dialect/Tosa/IR/TosaCanonicalizations.cpp |  31 +++++-
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp          | 101 +++++++++++++++---
 .../TosaToLinalg/tosa-to-linalg.mlir          |  50 ++++-----
 .../Dialect/Mesh/sharding-propagation.mlir    |  13 ++-
 mlir/test/Dialect/Mesh/spmdization.mlir       |  15 ++-
 mlir/test/Dialect/Tosa/availability.mlir      |   4 +-
 mlir/test/Dialect/Tosa/canonicalize.mlir      |  45 +++++++-
 mlir/test/Dialect/Tosa/invalid.mlir           |  64 +++++++++++
 mlir/test/Dialect/Tosa/level_check.mlir       |   4 +-
 mlir/test/Dialect/Tosa/ops.mlir               |   4 +-
 mlir/test/Dialect/Tosa/quant-test.mlir        |   4 +-
 mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir |  12 ++-
 17 files changed, 364 insertions(+), 91 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index d3fd4c3d1d3e1..efc329ee48849 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -114,8 +114,12 @@ profileComplianceMap = {
     {"tosa.logical_not",
      {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}}}},
     {"tosa.negate",
-     {{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
-      {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+     {{{Profile::pro_int},
+       {{i8T, i8T, i8T, i8T},
+        {i16T, i16T, i16T, i16T},
+        {i32T, i32T, i32T, i32T}}},
+      {{Profile::pro_fp},
+       {{fp16T, fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T, fp32T}}}}},
     {"tosa.reciprocal",
      {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
     {"tosa.rsqrt", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
@@ -310,7 +314,7 @@ extensionComplianceMap = {
     {"tosa.exp", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
     {"tosa.floor", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
     {"tosa.log", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
-    {"tosa.negate", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.negate", {{{Extension::bf16}, {{bf16T, bf16T, bf16T, bf16T}}}}},
     {"tosa.reciprocal", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
     {"tosa.rsqrt", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
     {"tosa.select", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index f2328003e49c5..da4daa03aa652 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -178,13 +178,13 @@ def Tosa_AvgPool2dOpQuantInfoBuilder : OpBuilder<
                                   input, kernel, stride, pad, acc_type);
   }]>;
 
-// This builder is called on single-parameter unary operators that have a scale
+// This builder is called on single-parameter negate operators that have a scale
 // relationship between their input and output, expressed by the
 // UnaryOpQuantizationAttr.
-def Tosa_UnaryOpQuantInfoBuilder : OpBuilder<
+def Tosa_NegateOpQuantInfoBuilder : OpBuilder<
   (ins "Type":$outputType, "Value":$input),
   [{
-    buildUnaryOpWithQuantInfo($_builder, $_state, outputType, input);
+    buildNegateOpWithQuantInfo($_builder, $_state, outputType, input);
   }]>;
 
 // These builders are called on the TOSA pad operator that needs to create its
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index ecddc9fe9a13d..52c80e975f290 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1356,7 +1356,9 @@ def Tosa_LogicalNotOp : Tosa_ElementwiseUnaryOp<"logical_not"> {
 //===----------------------------------------------------------------------===//
 // Operator: negate
 //===----------------------------------------------------------------------===//
-def Tosa_NegateOp : Tosa_ElementwiseUnaryOp<"negate"> {
+def Tosa_NegateOp : Tosa_InferShapedTypeOp<"negate", [
+    TosaElementwiseOperator,
+    Pure]> {
   let summary = "Elementwise negate op";
 
   let description = [{
@@ -1365,8 +1367,8 @@ def Tosa_NegateOp : Tosa_ElementwiseUnaryOp<"negate"> {
 
   let arguments = (ins
       Tosa_Tensor:$input1,
-      OptionalAttr<I32Attr>:$input1_zp,
-      OptionalAttr<I32Attr>:$output_zp
+      Tosa_ScalarIntOrFloatTensor:$input1_zp,
+      Tosa_ScalarIntOrFloatTensor:$output_zp
   );
 
   let results = (outs
@@ -1378,9 +1380,20 @@ def Tosa_NegateOp : Tosa_ElementwiseUnaryOp<"negate"> {
     Extension<[Tosa_EXT_BF16]>,
   ];
 
-  let builders = [Tosa_UnaryOpQuantInfoBuilder];
+  let builders = [Tosa_NegateOpQuantInfoBuilder];
+
+  let extraClassDeclaration = [{
+    FailureOr<int64_t> getInput1ZeroPoint();
+    FailureOr<int64_t> getOutputZeroPoint();
+    LogicalResult verifyInput1ZeroPoint(int64_t zp);
+    LogicalResult verifyOutputZeroPoint(int64_t zp);
+  }];
 
   let hasFolder = 1;
+  let hasVerifier = 1;
+
+  let assemblyFormat =
+      "operands attr-dict `:` functional-type(operands, results)";
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index f7dd33c7e8b53..7772c186b526a 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -193,18 +193,29 @@ static Value createLinalgBodyCalculationForElementwiseOp(
 
   // tosa::NegateOp
   if (isa<tosa::NegateOp>(op)) {
-    if (isa<FloatType>(elementTy))
-      return rewriter.create<arith::NegFOp>(loc, resultTypes, args);
+    auto negate = cast<tosa::NegateOp>(op);
 
-    if (isa<IntegerType>(elementTy)) {
-      auto inputZpAttr = cast<tosa::NegateOp>(op).getInput1ZpAttr();
-      auto outputZpAttr = cast<tosa::NegateOp>(op).getOutputZpAttr();
+    FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
+    if (failed(maybeInZp)) {
+      (void)rewriter.notifyMatchFailure(
+          op, "input1 zero point cannot be statically determined");
+      return nullptr;
+    }
+
+    FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
+    if (failed(maybeOutZp)) {
+      (void)rewriter.notifyMatchFailure(
+          op, "output zero point cannot be statically determined");
+      return nullptr;
+    }
 
-      const int64_t inZp =
-          inputZpAttr ? inputZpAttr.getValue().getSExtValue() : 0;
-      const int64_t outZp =
-          outputZpAttr ? outputZpAttr.getValue().getSExtValue() : 0;
+    int64_t inZp = *maybeInZp;
+    int64_t outZp = *maybeOutZp;
 
+    if (isa<FloatType>(elementTy))
+      return rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]);
+
+    if (isa<IntegerType>(elementTy)) {
       if (!inZp && !outZp) {
         auto constant = rewriter.create<arith::ConstantOp>(
             loc, IntegerAttr::get(elementTy, 0));
diff --git a/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
index 6dcb7c845b21f..be29298a35aeb 100644
--- a/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
@@ -62,6 +62,45 @@ struct MatMulOpSharding
   }
 };
 
+struct NegateOpSharding
+    : public ShardingInterface::ExternalModel<NegateOpSharding, NegateOp> {
+  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
+    Value val = op->getOperand(0);
+    auto type = dyn_cast<RankedTensorType>(val.getType());
+    if (!type)
+      return {};
+    SmallVector<utils::IteratorType> types(type.getRank(),
+                                           utils::IteratorType::parallel);
+    return types;
+  }
+
+  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
+    MLIRContext *ctx = op->getContext();
+    Value val = op->getOperand(0);
+    auto type = dyn_cast<RankedTensorType>(val.getType());
+    if (!type)
+      return {};
+    int64_t rank = type.getRank();
+    SmallVector<AffineMap> maps = {
+        AffineMap::getMultiDimIdentityMap(rank, ctx),
+        AffineMap::get(0, 0, {}, ctx), AffineMap::get(0, 0, {}, ctx),
+        AffineMap::getMultiDimIdentityMap(rank, ctx)};
+    return maps;
+  }
+
+  LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
+                        ArrayRef<MeshSharding> operandShardings,
+                        ArrayRef<MeshSharding> resultShardings,
+                        IRMapping &spmdizationMap,
+                        SymbolTableCollection &symbolTable,
+                        OpBuilder &builder) const {
+    spmdizeTriviallyShardableOperation(*op, spmdizedOperands, operandShardings,
+                                       resultShardings, spmdizationMap,
+                                       symbolTable, builder);
+    return success();
+  }
+};
+
 template <typename OpType>
 static void registerElemwiseOne(MLIRContext *ctx) {
   OpType::template attachInterface<ElementwiseShardingInterface<OpType>>(*ctx);
@@ -84,9 +123,10 @@ void mlir::tosa::registerShardingInterfaceExternalModels(
         BitwiseOrOp, BitwiseXorOp, IntDivOp, LogicalAndOp, LogicalLeftShiftOp,
         LogicalRightShiftOp, LogicalOrOp, LogicalXorOp, MaximumOp, MinimumOp,
         MulOp, PowOp, SubOp, AbsOp, BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp,
-        LogOp, LogicalNotOp, NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp,
+        LogOp, LogicalNotOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp,
         GreaterOp, GreaterEqualOp>(ctx);
 
     MatMulOp::attachInterface<MatMulOpSharding>(*ctx);
+    NegateOp::attachInterface<NegateOpSharding>(*ctx);
   });
 }
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 3e99c1f717d09..09d2c5d35263c 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1143,13 +1143,36 @@ OpFoldResult tosa::ExpOp::fold(FoldAdaptor adaptor) {
 }
 
 OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
-  auto input = getInput1();
   // Element-wise negate(negate(x)) = x
-  if (auto op = input.getDefiningOp<tosa::NegateOp>()) {
-    return op.getInput1();
+  // iff all zero points are constant 0
+  auto definingOp = getInput1().getDefiningOp<tosa::NegateOp>();
+  if (!definingOp) {
+    // defining op of input1 is not a negate, cannot fold
+    return {};
   }
 
-  return {};
+  if (FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
+      failed(maybeIZp) || *maybeIZp != 0) {
+    // input1 zero point is not constant 0, cannot fold
+    return {};
+  }
+  if (FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
+      failed(maybeOZp) || *maybeOZp != 0) {
+    // output zero point is not constant 0, cannot fold
+    return {};
+  }
+  if (FailureOr<int64_t> maybeIZp = definingOp.getInput1ZeroPoint();
+      failed(maybeIZp) || *maybeIZp != 0) {
+    // definingOp's input1 zero point is not constant 0, cannot fold
+    return {};
+  }
+  if (FailureOr<int64_t> maybeOZp = definingOp.getOutputZeroPoint();
+      failed(maybeOZp) || *maybeOZp != 0) {
+    // definingOp's output zero point is not constant 0, cannot fold
+    return {};
+  }
+
+  return definingOp.getInput1();
 }
 
 OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 7a991b3876f69..219775c31bd56 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -697,23 +697,43 @@ buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
   result.types.push_back(outputType);
 }
 
-/// This builder is called on single-parameter unary operators that have scale
-/// relationship between their input and output, expressed by the
-/// UnaryOpQuantizationAttr.
-static void buildUnaryOpWithQuantInfo(OpBuilder &builder,
-                                      OperationState &result, Type outputType,
-                                      Value input) {
-  result.addOperands(input);
+/// This builder is called on the single-parameter negate operator
+/// to construct input and output zero points based on their
+/// types.
+static void buildNegateOpWithQuantInfo(OpBuilder &builder,
+                                       OperationState &result, Type outputType,
+                                       Value input) {
+  const Location loc{result.location};
+  int64_t input1Zp{0};
+  int64_t outputZp{0};
   auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
   if (quantAttr) {
-    // note: negateOp has attributes input1_zp and output_zp
-    result.addAttribute("input1_zp",
-                        builder.getI32IntegerAttr(
-                            static_cast<int32_t>(quantAttr.getInputZp())));
-    result.addAttribute("output_zp",
-                        builder.getI32IntegerAttr(
-                            static_cast<int32_t>(quantAttr.getOutputZp())));
+    input1Zp = quantAttr.getInputZp();
+    outputZp = quantAttr.getOutputZp();
+  }
+  const std::optional<Value> input1ZpOp =
+      createZeroPointTensor(builder, loc, input.getType(), input1Zp);
+  if (!input1ZpOp) {
+    (void)emitError(
+        loc, "Failed to create input1 zero point for quantized NEGATE op");
+  }
+
+  const std::optional<Value> outputZpOp =
+      createZeroPointTensor(builder, loc, input.getType(), outputZp);
+  if (!outputZpOp) {
+    (void)emitError(
+        loc, "Failed to create output zero point for quantized NEGATE op");
   }
+
+  if (input1ZpOp && outputZpOp) {
+    result.addOperands({input, input1ZpOp.value(), outputZpOp.value()});
+  } else {
+    // Failed to create one or more zero points above: add only input as the
+    // operand. This will trigger an error when building the op because the
+    // zero point operands are missing.
+    result.addOperands({input});
+  }
+
   result.types.push_back(outputType);
 }
 
@@ -1728,6 +1748,9 @@ ZERO_POINT_HELPER(AvgPool2dOp, Input)
 ZERO_POINT_HELPER(AvgPool2dOp, Output)
 ZERO_POINT_HELPER(MatMulOp, A)
 ZERO_POINT_HELPER(MatMulOp, B)
+ZERO_POINT_HELPER(NegateOp, Input1)
+ZERO_POINT_HELPER(NegateOp, Output)
+
 #undef ZERO_POINT_HELPER
 
 LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
@@ -2230,7 +2253,6 @@ NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
 NARY_SHAPE_INFER(tosa::LogicalXorOp)
 NARY_SHAPE_INFER(tosa::MaximumOp)
 NARY_SHAPE_INFER(tosa::MinimumOp)
-NARY_SHAPE_INFER(tosa::NegateOp)
 NARY_SHAPE_INFER(tosa::PowOp)
 NARY_SHAPE_INFER(tosa::ReciprocalOp)
 NARY_SHAPE_INFER(tosa::ReverseOp)
@@ -2243,6 +2265,55 @@ NARY_SHAPE_INFER(tosa::ErfOp)
 NARY_SHAPE_INFER(tosa::SigmoidOp)
 #undef PRED_SHAPE_INFER
 
+LogicalResult tosa::NegateOp::inferReturnTypeComponents(
+    MLIRContext *context, ::std::optional<Location> location,
+    NegateOp::Adaptor adaptor,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  ShapeAdaptor inputShape(adaptor.getInput1().getType());
+  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
+  return success();
+}
+
+LogicalResult tosa::NegateOp::verify() {
+  // Verify same element type
+  const Type input1Type = getInput1().getType();
+  const Type outputType = getOutput().getType();
+  if (verifySameElementTypes(*this, input1Type, outputType).failed())
+    return failure();
+
+  // Verify same shape
+  const SmallVector<Type, 2> types = {input1Type, outputType};
+  if (failed(verifyCompatibleShapes(types)))
+    return emitOpError() << "requires the same shape for input1 and output";
+
+  const Type input1EType = getStorageElementTypeOrSelf(getInput1().getType());
+  const Type input1ZpEType =
+      getStorageElementTypeOrSelf(getInput1Zp().getType());
+  if (input1EType != input1ZpEType) {
+    return emitOpError("expect both input1 and its zero point are the same "
+                       "element type, got ")
+           << input1EType << " and " << input1ZpEType;
+  }
+  const Type outputEType = getStorageElementTypeOrSelf(getOutput().getType());
+  const Type outputZpEType =
+      getStorageElementTypeOrSelf(getOutputZp().getType());
+  if (outputEType != outputZpEType) {
+    return emitOpError("expect both output and its zero point are the same "
+                       "element type, got ")
+           << outputEType << " and " << outputZpEType;
+  }
+
+  FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
+  if (succeeded(maybeIZp) && verifyInput1ZeroPoint(*maybeIZp).failed())
+    return failure();
+
+  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
+  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
+    return failure();
+
+  return success();
+}
+
 static LogicalResult poolingInferReturnTypes(
     ShapeAdaptor inputShape, ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride,
     ArrayRef<int64_t> pad,
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index a3ed8c2805282..1c7be0ed6f107 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -477,7 +477,9 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
 
   // CHECK: linalg.generic
   // CHECK: arith.negf
-  %5 = tosa.negate %0 : (tensor<1xf32>) -> tensor<1xf32>
+  %in_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %out_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %5 = tosa.negate %0, %in_zp, %out_zp : (tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
   // CHECK: pow
@@ -662,10 +664,12 @@ func.func @test_simple_i32(%arg0: tensor<1xi32>, %unsigned: tensor<1xui32>, %uns
   %40 = tosa.int_div %arg0, %arg0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
-  // CHECK: ^bb0(%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32):
+  // CHECK: ^bb0(%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32):
   // CHECK: [[ZERO:%.+]] = arith.constant 0
   // CHECK: arith.subi [[ZERO]], %[[ARG1]]
-  %5 = tosa.negate %arg0 : (tensor<1xi32>) -> tensor<1xi32>
+  %in_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+  %out_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+  %5 = tosa.negate %arg0, %in_zp, %out_zp : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: and
@@ -852,40 +856,22 @@ func.func @test_bool(%arg0: tensor<1xi1>, %arg1: tensor<1xi1>) -> () {
 // CHECK-LABEL: @test_negate_quantized
 func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
   // CHECK: linalg.generic
-  // CHECK: ^bb0(%[[BBARG0:.+]]: i8,
-  // CHECK: [[ZERO:%.+]] = arith.constant 0
-  // CHECK: [[SUB:%.+]] = arith.subi [[ZERO]], %[[BBARG0]]
-  // CHECK: linalg.yield [[SUB]]
-  %0 = tosa.negate %arg0 {input_zp1 = 0 : i32, output_zp = 0 : i32} : (tensor<1xi8>) -> tensor<1xi8>
-
-  // CHECK: linalg.generic
-  // CHECK: ^bb0(%[[BBARG0:.+]]: i8,
-  // CHECK: [[C32639:%.+]] = arith.constant 32639
+  // CHECK: ^bb0(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8, %[[BBARG2:.+]]: i8, %[[BBARG3:.+]]: i8
+  // CHECK: [[CNST:%.+]] = arith.constant 7
   // CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i16
-  // CHECK: [[SUB:%.+]] = arith.subi [[C32639]], [[EXT]]
-  // CHECK: [[MIN:%.+]] = arith.constant -128
-  // CHECK: [[MAX:%.+]] = arith.constant 127
-  // CHECK: [[LBOUND:%.+]] = arith.maxsi [[MIN]], [[SUB]]
-  // CHECK: [[UBOUND:%.+]] = arith.minsi [[MAX]], [[LBOUND]]
-  // CHECK: [[TRUNC:%.+]] = arith.trunci [[UBOUND]]
-  // CHECK: linalg.yield [[TRUNC]]
-  %1 = tosa.negate %arg0 {input1_zp = 32639 : i32, output_zp = 0 : i32} : (tensor<1xi8>) -> tensor<1xi8>
-
-  // CHECK: linalg.generic
-  // CHECK: ^bb0(%[[BBARG0:.+]]: i8,
-  // CHECK: [[C32640:%.+]] = arith.constant 32640
-  // CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i32
-  // CHECK: [[SUB:%.+]] = arith.subi [[C32640]], [[EXT]]
+  // CHECK: [[SUB:%.+]] = arith.subi [[CNST]], [[EXT]]
   // CHECK: [[MIN:%.+]] = arith.constant -128
   // CHECK: [[MAX:%.+]] = arith.constant 127
   // CHECK: [[LBOUND:%.+]] = arith.maxsi [[MIN]], [[SUB]]
   // CHECK: [[UBOUND:%.+]] = arith.minsi [[MAX]], [[LBOUND]]
   // CHECK: [[TRUNC:%.+]] = arith.trunci [[UBOUND]]
   // CHECK: linalg.yield [[TRUNC]]
-  %2 = tosa.negate %arg0 {input1_zp = 32640 : i32, output_zp = 0 : i32} : (tensor<1xi8>) -> tensor<1xi8>
+  %in_zp0 = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %out_zp0 = "tosa.const"() <{values = dense<7> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %0 = tosa.negate %arg0, %in_zp0, %out_zp0 : (tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1xi8>
 
   // CHECK: linalg.generic
-  // CHECK: ^bb0(%[[BBARG0:.+]]: i8,
+  // CHECK: ^bb0(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8, %[[BBARG2:.+]]: i8, %[[BBARG3:.+]]: i8
   // CHECK: [[C_128:%.+]] = arith.constant -128
   // CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i16
   // CHECK: [[SUB:%.+]] = arith.subi [[C_128]], [[EXT]]
@@ -895,14 +881,18 @@ func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
   // CHECK: [[UBOUND:%.+]] = arith.minsi [[MAX]], [[LBOUND]]
   // CHECK: [[TRUNC:%.+]] = arith.trunci [[UBOUND]]
   // CHECK: linalg.yield [[TRUNC]]
-  %3 = tosa.negate %arg0 {input1_zp = -128 : i32, output_zp = 0 : i32} : (tensor<1xi8>) -> tensor<1xi8>
+  %in_zp3 = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %out_zp3 = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %3 = tosa.negate %arg0, %in_zp3, %out_zp3 : (tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1xi8>
 
   // CHECK: linalg.generic
   // CHECK: ^bb0(%[[BBARG0:.+]]: i8,
   // CHECK: [[ZERO:%.+]] = arith.constant 0
   // CHECK: [[SUB:%.+]] = arith.subi [[ZERO]],
   // CHECK: linalg.yield [[SUB]]
-  %4 = tosa.negate %arg0 {quantization_info = #tosa.unary_quant<input_zp = 0, output_zp = 0>} : (tensor<1xi8>) -> tensor<1xi8>
+  %in_zp4 = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %out_zp4 = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %4 = tosa.negate %arg0, %in_zp4, %out_zp4 : (tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1xi8>
 
   return
 }
diff --git a/mlir/test/Dialect/Mesh/sharding-propagation.mlir b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
index 14c67e670e921..aa5fa00488f08 100644
--- a/mlir/test/Dialect/Mesh/sharding-propagation.mlir
+++ b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
@@ -77,7 +77,7 @@ func.func @element_wise_on_graph_input(%arg0: tensor<8x16xf32>) -> tensor<8x16xf
 
 // CHECK-LABEL: func.func @arrow_structure
 // CHECK-SAME:    %[[ARG:.*]]: tensor<8x16xf32>
-func.func @arrow_structure(%arg0: tensor<8x16xf32>) -> (tensor<8x16xf32>, tensor<8x16xf32>) {
+func.func @arrow_structure(%arg0: tensor<8x16xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> (tensor<8x16xf32>, tensor<8x16xf32>) {
   // CHECK-NEXT:  %[[S1:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
   // CHECK-NEXT:  %[[V1:.*]] = mesh.shard %[[ARG]] to %[[S1]] annotate_for_users  : tensor<8x16xf32>
   // CHECK-NEXT:  %[[V2:.*]] = tosa.tanh %[[V1]]
@@ -85,12 +85,15 @@ func.func @arrow_structure(%arg0: tensor<8x16xf32>) -> (tensor<8x16xf32>, tensor
   %0 = tosa.tanh %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
   // CHECK-NEXT:  %[[V4:.*]] = mesh.shard %[[V3]] to %[[S1]] annotate_for_users  : tensor<8x16xf32>
   // CHECK-NEXT:  %[[V5:.*]] = tosa.abs %[[V4]]
- // CHECK-NEXT:   %[[V6:.*]] = mesh.shard %[[V5]] to %[[S1]]  : tensor<8x16xf32>
-  %1 = tosa.abs %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
-  // CHECK-NEXT:  %[[V7:.*]] = tosa.negate %[[V4]]
+  // CHECK-NEXT:  %[[V6:.*]] = mesh.shard %[[V5]] to %[[S1]]  : tensor<8x16xf32>
+  %1 = tosa.abs %0: (tensor<8x16xf32>) -> tensor<8x16xf32>
+  // CHECK-NEXT:  %[[S3:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding
+  // CHECK-NEXT:  %[[ZP1:.*]] = mesh.shard %arg1 to %[[S3]] annotate_for_users : tensor<1xf32>
+  // CHECK-NEXT:  %[[ZP2:.*]] = mesh.shard %arg2 to %[[S3]] annotate_for_users : tensor<1xf32>
+  // CHECK-NEXT:  %[[V7:.*]] = tosa.negate %[[V4]], %[[ZP1]], %[[ZP2]]
   // CHECK-NEXT:  %[[S8:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
   // CHECK-NEXT:  %[[V8:.*]] = mesh.shard %[[V7]] to %[[S8]]  : tensor<8x16xf32>
-  %2 = tosa.negate %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+  %2 = tosa.negate %0, %arg1, %arg2 : (tensor<8x16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<8x16xf32>
   %s3 = mesh.sharding @mesh_2d split_axes = [[0], [1]] : !mesh.sharding
   %3 = mesh.shard %2 to %s3  : tensor<8x16xf32>
   // CHECK-NEXT: return %[[V6]], %[[V8]]
diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir
index 59f7162e21013..5c9fd29444f04 100644
--- a/mlir/test/Dialect/Mesh/spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/spmdization.mlir
@@ -176,7 +176,7 @@ func.func @multiple_chained_ops(
   %4 = mesh.shard %3 to %s4  annotate_for_users : tensor<2xi8>
   // CHECK: %[[ABS2:.*]] = tosa.abs %[[RESHARD2]] : (tensor<2xi8>) -> tensor<2xi8>
   %5 = tosa.abs %4 : (tensor<2xi8>) -> tensor<2xi8>
-  // CHECK: %[[RESHARD3:.*]] = mesh.all_slice %[[ABS2]] on @mesh_1d mesh_axes = [0] slice_axis = 0 : 
+  // CHECK: %[[RESHARD3:.*]] = mesh.all_slice %[[ABS2]] on @mesh_1d mesh_axes = [0] slice_axis = 0 :
   // CHECK-SAME: tensor<2xi8> -> tensor<1xi8>
   %s6 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
   %6 = mesh.shard %5 to %s6  : tensor<2xi8>
@@ -207,7 +207,11 @@ mesh.mesh @mesh_1d_4(shape = 4)
 // CHECK-LABEL: func @ew_chain_with_halo
 func.func @ew_chain_with_halo(
   // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<5x16xf32>
-  %arg0: tensor<8x16xf32>)
+  %arg0: tensor<8x16xf32>,
+  // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<1xf32>
+  %arg1: tensor<1xf32>,
+  // CHECK-SAME: %[[IN3:[A-Za-z0-9_]+]]: tensor<1xf32>
+  %arg2: tensor<1xf32>)
   // CHECK-SAME: -> tensor<5x16xf32>
    -> tensor<8x16xf32> {
   %ssharding_annotated = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
@@ -224,8 +228,11 @@ func.func @ew_chain_with_halo(
   %sharding_annotated_2 = mesh.shard %1 to %ssharding_annotated_2  : tensor<8x16xf32>
   %ssharding_annotated_4 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
   %sharding_annotated_4 = mesh.shard %sharding_annotated_2 to %ssharding_annotated_4  annotate_for_users : tensor<8x16xf32>
-  // CHECK-NEXT: %[[TMP3:.*]] = tosa.negate %[[TMP2]] : (tensor<5x16xf32>) -> tensor<5x16xf32>
-  %2 = tosa.negate %sharding_annotated_4 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+  // CHECK-NEXT: %[[TMP3:.*]] = tosa.negate %[[TMP2]], %[[IN2]], %[[IN3]] : (tensor<5x16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x16xf32>
+  %sharding_1 = mesh.sharding @mesh_1d_4 split_axes = [[]] : !mesh.sharding
+  %zero_point_1 = mesh.shard %arg1 to %sharding_1 annotate_for_users : tensor<1xf32>
+  %zero_point_2 = mesh.shard %arg2 to %sharding_1 annotate_for_users : tensor<1xf32>
+  %2 = tosa.negate %sharding_annotated_4, %zero_point_1, %zero_point_2 : (tensor<8x16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<8x16xf32>
   %ssharding_annotated_5 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
   %sharding_annotated_5 = mesh.shard %2 to %ssharding_annotated_5  : tensor<8x16xf32>
   %ssharding_annotated_6 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir
index b786264d84106..820ea6559b848 100644
--- a/mlir/test/Dialect/Tosa/availability.mlir
+++ b/mlir/test/Dialect/Tosa/availability.mlir
@@ -380,7 +380,9 @@ func.func @test_logical_not(%arg0: tensor<1x21x3xi1>) -> tensor<1x21x3xi1> {
 func.func @test_negate(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
   // CHECK: profiles: [ [pro_int, pro_fp] ]
   // CHECK: extensions: [ [bf16] ]
-  %0 = tosa.negate %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+  %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %output_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %0 = tosa.negate %arg0, %input_zp, %output_zp : (tensor<13x21x3xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<13x21x3xf32>
   return %0 : tensor<13x21x3xf32>
 }
 
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 4242f68609634..e2575c764fdfe 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -856,13 +856,54 @@ func.func @fold_exp_log(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
 // CHECK-LABEL: @fold_negate_negate
 func.func @fold_negate_negate(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: return %arg{{.*}} : tensor<?x1xf32>
-  %0 = tosa.negate %arg0 : (tensor<?x1xf32>) -> tensor<?x1xf32>
-  %1 = tosa.negate %0 : (tensor<?x1xf32>) -> tensor<?x1xf32>
+  %in_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %out_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %0 = tosa.negate %arg0, %in_zp, %out_zp : (tensor<?x1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x1xf32>
+  %1 = tosa.negate %0, %in_zp, %out_zp : (tensor<?x1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x1xf32>
   return %1 : tensor<?x1xf32>
 }
 
 // -----
 
+// CHECK-LABEL: @no_fold_negate_negate_non_const_zp
+func.func @no_fold_negate_negate_non_const_zp(%arg0: tensor<?x1xf32>, %in_zp: tensor<1xf32>) -> tensor<?x1xf32> {
+  // cannot fold if any zp is not constant
+  // CHECK: tosa.negate
+  // CHECK: tosa.negate
+  // CHECK: tosa.negate
+  // CHECK: tosa.negate
+  // CHECK: tosa.negate
+  %zero = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %0 = tosa.negate %arg0, %in_zp, %zero : (tensor<?x1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x1xf32>
+  %1 = tosa.negate %0, %zero, %zero : (tensor<?x1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x1xf32>
+  %2 = tosa.negate %1, %zero, %in_zp : (tensor<?x1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x1xf32>
+  %3 = tosa.negate %2, %zero, %zero : (tensor<?x1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x1xf32>
+  %4 = tosa.negate %3, %in_zp, %zero : (tensor<?x1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x1xf32>
+  return %4 : tensor<?x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @no_fold_negate_negate_non_zero_zp
+func.func @no_fold_negate_negate_non_zero_zp(%arg0: tensor<?x1xi8>) -> tensor<?x1xi8> {
+  // cannot fold if any zp is not constant 0
+  // CHECK: tosa.negate
+  // CHECK: tosa.negate
+  // CHECK: tosa.negate
+  // CHECK: tosa.negate
+  // CHECK: tosa.negate
+  %zero = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %one = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %0 = tosa.negate %arg0, %zero, %one : (tensor<?x1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<?x1xi8>
+  %1 = tosa.negate %0, %zero, %zero : (tensor<?x1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<?x1xi8>
+  %2 = tosa.negate %1, %one, %zero : (tensor<?x1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<?x1xi8>
+  %3 = tosa.negate %2, %zero, %zero : (tensor<?x1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<?x1xi8>
+  %4 = tosa.negate %3, %zero, %one : (tensor<?x1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<?x1xi8>
+  return %4 : tensor<?x1xi8>
+}
+
+// -----
+
 // CHECK-LABEL: @fold_abs_abs
 func.func @fold_abs_abs(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: %[[ABS:.*]] = tosa.abs %arg{{.*}} : (tensor<?x1xf32>) -> tensor<?x1xf32>
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index f536444f6379e..c4fe9c1a6cabc 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -1652,3 +1652,67 @@ func.func @test_matmul_b_zp_non_zero(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1
   %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf32>, tensor<1xf32>)  -> tensor<1x14x28xf32>
   return %0 : tensor<1x14x28xf32>
 }
+
+// -----
+
+// CHECK-LABEL: test_negate_same_element_type
+func.func @test_negate_same_element_type(%arg0: tensor<1x16x16x8xf16>, %arg1: tensor<1xf16>, %arg2: tensor<1xf16>) -> tensor<1x16x16x8xf32> {
+  // expected-error@+1 {{'tosa.negate' op expect input and output to have same element type, got 'f16' and 'f32'}}
+  %0 = tosa.negate %arg0, %arg1, %arg2
+      : (tensor<1x16x16x8xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x16x16x8xf32>
+  return %0 : tensor<1x16x16x8xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_negate_same_shape
+func.func @test_negate_same_shape(%arg0: tensor<1x16x16x16xf16>, %arg1: tensor<1xf16>, %arg2: tensor<1xf16>) -> tensor<1x16x16x8xf16> {
+  // expected-error@+1 {{'tosa.negate' op requires the same shape for input1 and output}}
+  %0 = tosa.negate %arg0, %arg1, %arg2
+      : (tensor<1x16x16x16xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x16x16x8xf16>
+  return %0 : tensor<1x16x16x8xf16>
+}
+
+// -----
+
+// CHECK-LABEL: test_negate_input_zp_same_element_type
+func.func @test_negate_input_zp_same_element_type(%arg0: tensor<1x16x16x8xf16>, %arg1: tensor<1xi8>, %arg2: tensor<1xf16>) -> tensor<1x16x16x8xf16> {
+  // expected-error@+1 {{'tosa.negate' op expect both input1 and its zero point are the same element type, got 'f16' and 'i8'}}
+  %0 = tosa.negate %arg0, %arg1, %arg2
+      : (tensor<1x16x16x8xf16>, tensor<1xi8>, tensor<1xf16>) -> tensor<1x16x16x8xf16>
+  return %0 : tensor<1x16x16x8xf16>
+}
+
+// -----
+
+// CHECK-LABEL: test_negate_output_zp_same_element_type
+func.func @test_negate_output_zp_same_element_type(%arg0: tensor<1x16x16x8xi8>, %arg1: tensor<1xi8>, %arg2: tensor<1xf16>) -> tensor<1x16x16x8xi8> {
+  // expected-error@+1 {{'tosa.negate' op expect both output and its zero point are the same element type, got 'i8' and 'f16'}}
+  %0 = tosa.negate %arg0, %arg1, %arg2
+      : (tensor<1x16x16x8xi8>, tensor<1xi8>, tensor<1xf16>) -> tensor<1x16x16x8xi8>
+  return %0 : tensor<1x16x16x8xi8>
+}
+
+// -----
+
+// CHECK-LABEL: test_negate_input_zp_non_zero
+func.func @test_negate_input_zp_non_zero(%arg0: tensor<1x16x16x8xf32>) -> tensor<1x16x16x8xf32> {
+  %input_zp = "tosa.const"() {values = dense<-1.0> : tensor<1xf32>} : () -> tensor<1xf32>
+  %output_zp = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
+  // expected-error@+1 {{'tosa.negate' op input1 zero point must be zero for non-int8 integer types}}
+  %0 = tosa.negate %arg0, %input_zp, %output_zp
+      : (tensor<1x16x16x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x16x16x8xf32>
+  return %0 : tensor<1x16x16x8xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_negate_output_zp_non_zero
+func.func @test_negate_output_zp_non_zero(%arg0: tensor<1x16x16x8xf32>) -> tensor<1x16x16x8xf32> {
+  %input_zp = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
+  %output_zp = "tosa.const"() {values = dense<-1.0> : tensor<1xf32>} : () -> tensor<1xf32>
+  // expected-error@+1 {{'tosa.negate' op output zero point must be zero for non-int8 integer types}}
+  %0 = tosa.negate %arg0, %input_zp, %output_zp
+      : (tensor<1x16x16x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x16x16x8xf32>
+  return %0 : tensor<1x16x16x8xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 6d8237635d0ec..2f7250dabe162 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -249,9 +249,9 @@ func.func @test_logical_not_rank_invalid(%arg0: tensor<1x1x1x1x1x21x3xi1>) -> te
 
 // -----
 
-func.func @test_negate_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
+func.func @test_negate_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>, %arg1: tensor<1xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
   // expected-error@+1 {{'tosa.negate' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = tosa.negate %arg0 : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
+  %0 = tosa.negate %arg0, %arg1, %arg1 : (tensor<1x1x1x1x13x21x3xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x1x1x1x13x21x3xf32>
   return %0 : tensor<1x1x1x1x13x21x3xf32>
 }
 
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 480d8c327ab83..916886025cb0e 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -487,8 +487,8 @@ func.func @test_logical_not(%arg0: tensor<1x21x3xi1>) -> tensor<1x21x3xi1> {
 
 // -----
 // CHECK-LABEL: negate
-func.func @test_negate(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
-  %0 = tosa.negate %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+func.func @test_negate(%arg0: tensor<13x21x3xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> tensor<13x21x3xf32> {
+  %0 = tosa.negate %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<13x21x3xf32>
   return %0 : tensor<13x21x3xf32>
 }
 
diff --git a/mlir/test/Dialect/Tosa/quant-test.mlir b/mlir/test/Dialect/Tosa/quant-test.mlir
index 447a6ef7f9e05..f0ad4eb4fdb0b 100644
--- a/mlir/test/Dialect/Tosa/quant-test.mlir
+++ b/mlir/test/Dialect/Tosa/quant-test.mlir
@@ -2,9 +2,9 @@
 
 // -----
 // CHECK-LABEL: test_build_qtype
-func.func @test_build_qtype(%arg0 : tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>) -> tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>> {
+func.func @test_build_qtype(%arg0 : tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>, %arg1: tensor<1xi8>, %arg2: tensor<1xi8>) -> tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>> {
   //  CHECK: tosa.negate
-  %0 = "tosa.negate"(%arg0) : (tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>) -> tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>
+  %0 = "tosa.negate"(%arg0, %arg1, %arg2) : (tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>, tensor<1xi8>, tensor<1xi8>) -> tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>
   return %0 : tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>
 }
 
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index deede4b0afadc..3d0ded8c58ac5 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -45,8 +45,10 @@ func.func @test_unary_f32(%arg0 : tensor<4xf32>) -> () {
   // CHECK: tosa.log %arg0 : (tensor<4xf32>) -> tensor<4xf32>
   %5 = tosa.log %arg0 : (tensor<4xf32>) -> tensor<*xf32>
 
-  // CHECK: tosa.negate %arg0 : (tensor<4xf32>) -> tensor<4xf32>
-  %6 = tosa.negate %arg0 : (tensor<4xf32>) -> tensor<*xf32>
+  %in_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %out_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  // CHECK: tosa.negate %arg0, {{.+}} : (tensor<4xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<4xf32>
+  %6 = tosa.negate %arg0, %in_zp, %out_zp : (tensor<4xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<*xf32>
 
   // CHECK: tosa.reciprocal %arg0 : (tensor<4xf32>) -> tensor<4xf32>
   %7 = tosa.reciprocal %arg0 : (tensor<4xf32>) -> tensor<*xf32>
@@ -87,8 +89,10 @@ func.func @test_unary_i32(%arg0 : tensor<4xi32>, %arg1 : tensor<2xi8>) -> () {
   // CHECK: tosa.clz %arg0 : (tensor<4xi32>) -> tensor<4xi32>
   %3 = tosa.clz %arg0 : (tensor<4xi32>) -> tensor<*xi32>
 
-  // CHECK: tosa.negate %arg0 : (tensor<4xi32>) -> tensor<4xi32>
-  %4 = tosa.negate %arg0 : (tensor<4xi32>) -> tensor<*xi32>
+  %in_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+  %out_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+  // CHECK: tosa.negate %arg0, {{.+}} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32>
+  %4 = tosa.negate %arg0, %in_zp, %out_zp : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xi32>
 
   // CHECK: tosa.reverse %arg0 {axis = 0 : i32} : (tensor<4xi32>) -> tensor<4xi32>
   %5 = tosa.reverse %arg0 { axis = 0 : i32 } : (tensor<4xi32>) -> tensor<?xi32>



More information about the Mlir-commits mailing list