[Mlir-commits] [mlir] [mlir][tosa] Switch zero point of negate to input variable type (PR #129758)
Luke Hutton
llvmlistbot at llvm.org
Mon Mar 10 07:42:17 PDT 2025
https://github.com/lhutton1 updated https://github.com/llvm/llvm-project/pull/129758
>From 56aec5788295f83816e3ce5643f7cf7bb7b724bb Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Wed, 27 Nov 2024 10:50:11 +0000
Subject: [PATCH] [mlir][tosa] Switch zero point of negate to input variable
type
This commit changes the input1_zp and output_zp attributes of tosa.negate
into scalar-tensor operands to align with the TOSA 1.0 spec.
Change-Id: Ibc9e5959b36c182a9e0c5c23a2f9d42a572a1184
Signed-off-by: Tai Ly <tai.ly at arm.com>
---
.../Dialect/Tosa/IR/TosaComplianceData.h.inc | 10 +-
.../mlir/Dialect/Tosa/IR/TosaOpBase.td | 6 +-
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 21 +++-
.../Conversion/TosaToLinalg/TosaToLinalg.cpp | 29 +++--
.../Dialect/Tosa/IR/ShardingInterfaceImpl.cpp | 42 +++++++-
.../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 31 +++++-
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 101 +++++++++++++++---
.../TosaToLinalg/tosa-to-linalg.mlir | 50 ++++-----
.../Dialect/Mesh/sharding-propagation.mlir | 13 ++-
mlir/test/Dialect/Mesh/spmdization.mlir | 15 ++-
mlir/test/Dialect/Tosa/availability.mlir | 4 +-
mlir/test/Dialect/Tosa/canonicalize.mlir | 45 +++++++-
mlir/test/Dialect/Tosa/invalid.mlir | 64 +++++++++++
mlir/test/Dialect/Tosa/level_check.mlir | 4 +-
mlir/test/Dialect/Tosa/ops.mlir | 4 +-
mlir/test/Dialect/Tosa/quant-test.mlir | 4 +-
mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | 12 ++-
17 files changed, 364 insertions(+), 91 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index d3fd4c3d1d3e1..efc329ee48849 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -114,8 +114,12 @@ profileComplianceMap = {
{"tosa.logical_not",
{{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}}}},
{"tosa.negate",
- {{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int},
+ {{i8T, i8T, i8T, i8T},
+ {i16T, i16T, i16T, i16T},
+ {i32T, i32T, i32T, i32T}}},
+ {{Profile::pro_fp},
+ {{fp16T, fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T, fp32T}}}}},
{"tosa.reciprocal",
{{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
{"tosa.rsqrt", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
@@ -310,7 +314,7 @@ extensionComplianceMap = {
{"tosa.exp", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
{"tosa.floor", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
{"tosa.log", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.negate", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+ {"tosa.negate", {{{Extension::bf16}, {{bf16T, bf16T, bf16T, bf16T}}}}},
{"tosa.reciprocal", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
{"tosa.rsqrt", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
{"tosa.select", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index f2328003e49c5..da4daa03aa652 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -178,13 +178,13 @@ def Tosa_AvgPool2dOpQuantInfoBuilder : OpBuilder<
input, kernel, stride, pad, acc_type);
}]>;
-// This builder is called on single-parameter unary operators that have a scale
+// This builder is called on single-parameter negate operators that have a scale
// relationship between their input and output, expressed by the
// UnaryOpQuantizationAttr.
-def Tosa_UnaryOpQuantInfoBuilder : OpBuilder<
+def Tosa_NegateOpQuantInfoBuilder : OpBuilder<
(ins "Type":$outputType, "Value":$input),
[{
- buildUnaryOpWithQuantInfo($_builder, $_state, outputType, input);
+ buildNegateOpWithQuantInfo($_builder, $_state, outputType, input);
}]>;
// These builders are called on the TOSA pad operator that needs to create its
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index ecddc9fe9a13d..52c80e975f290 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1356,7 +1356,9 @@ def Tosa_LogicalNotOp : Tosa_ElementwiseUnaryOp<"logical_not"> {
//===----------------------------------------------------------------------===//
// Operator: negate
//===----------------------------------------------------------------------===//
-def Tosa_NegateOp : Tosa_ElementwiseUnaryOp<"negate"> {
+def Tosa_NegateOp : Tosa_InferShapedTypeOp<"negate", [
+ TosaElementwiseOperator,
+ Pure]> {
let summary = "Elementwise negate op";
let description = [{
@@ -1365,8 +1367,8 @@ def Tosa_NegateOp : Tosa_ElementwiseUnaryOp<"negate"> {
let arguments = (ins
Tosa_Tensor:$input1,
- OptionalAttr<I32Attr>:$input1_zp,
- OptionalAttr<I32Attr>:$output_zp
+ Tosa_ScalarIntOrFloatTensor:$input1_zp,
+ Tosa_ScalarIntOrFloatTensor:$output_zp
);
let results = (outs
@@ -1378,9 +1380,20 @@ def Tosa_NegateOp : Tosa_ElementwiseUnaryOp<"negate"> {
Extension<[Tosa_EXT_BF16]>,
];
- let builders = [Tosa_UnaryOpQuantInfoBuilder];
+ let builders = [Tosa_NegateOpQuantInfoBuilder];
+
+ let extraClassDeclaration = [{
+ FailureOr<int64_t> getInput1ZeroPoint();
+ FailureOr<int64_t> getOutputZeroPoint();
+ LogicalResult verifyInput1ZeroPoint(int64_t zp);
+ LogicalResult verifyOutputZeroPoint(int64_t zp);
+ }];
let hasFolder = 1;
+ let hasVerifier = 1;
+
+ let assemblyFormat =
+ "operands attr-dict `:` functional-type(operands, results)";
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index f7dd33c7e8b53..7772c186b526a 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -193,18 +193,29 @@ static Value createLinalgBodyCalculationForElementwiseOp(
// tosa::NegateOp
if (isa<tosa::NegateOp>(op)) {
- if (isa<FloatType>(elementTy))
- return rewriter.create<arith::NegFOp>(loc, resultTypes, args);
+ auto negate = cast<tosa::NegateOp>(op);
- if (isa<IntegerType>(elementTy)) {
- auto inputZpAttr = cast<tosa::NegateOp>(op).getInput1ZpAttr();
- auto outputZpAttr = cast<tosa::NegateOp>(op).getOutputZpAttr();
+ FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
+ if (failed(maybeInZp)) {
+ (void)rewriter.notifyMatchFailure(
+ op, "input1 zero point cannot be statically determined");
+ return nullptr;
+ }
+
+ FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
+ if (failed(maybeOutZp)) {
+ (void)rewriter.notifyMatchFailure(
+ op, "output zero point cannot be statically determined");
+ return nullptr;
+ }
- const int64_t inZp =
- inputZpAttr ? inputZpAttr.getValue().getSExtValue() : 0;
- const int64_t outZp =
- outputZpAttr ? outputZpAttr.getValue().getSExtValue() : 0;
+ int64_t inZp = *maybeInZp;
+ int64_t outZp = *maybeOutZp;
+ if (isa<FloatType>(elementTy))
+ return rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]);
+
+ if (isa<IntegerType>(elementTy)) {
if (!inZp && !outZp) {
auto constant = rewriter.create<arith::ConstantOp>(
loc, IntegerAttr::get(elementTy, 0));
diff --git a/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
index 6dcb7c845b21f..be29298a35aeb 100644
--- a/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
@@ -62,6 +62,45 @@ struct MatMulOpSharding
}
};
+struct NegateOpSharding
+ : public ShardingInterface::ExternalModel<NegateOpSharding, NegateOp> {
+ SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
+ Value val = op->getOperand(0);
+ auto type = dyn_cast<RankedTensorType>(val.getType());
+ if (!type)
+ return {};
+ SmallVector<utils::IteratorType> types(type.getRank(),
+ utils::IteratorType::parallel);
+ return types;
+ }
+
+ SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
+ MLIRContext *ctx = op->getContext();
+ Value val = op->getOperand(0);
+ auto type = dyn_cast<RankedTensorType>(val.getType());
+ if (!type)
+ return {};
+ int64_t rank = type.getRank();
+ SmallVector<AffineMap> maps = {
+ AffineMap::getMultiDimIdentityMap(rank, ctx),
+ AffineMap::get(0, 0, {}, ctx), AffineMap::get(0, 0, {}, ctx),
+ AffineMap::getMultiDimIdentityMap(rank, ctx)};
+ return maps;
+ }
+
+ LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
+ ArrayRef<MeshSharding> operandShardings,
+ ArrayRef<MeshSharding> resultShardings,
+ IRMapping &spmdizationMap,
+ SymbolTableCollection &symbolTable,
+ OpBuilder &builder) const {
+ spmdizeTriviallyShardableOperation(*op, spmdizedOperands, operandShardings,
+ resultShardings, spmdizationMap,
+ symbolTable, builder);
+ return success();
+ }
+};
+
template <typename OpType>
static void registerElemwiseOne(MLIRContext *ctx) {
OpType::template attachInterface<ElementwiseShardingInterface<OpType>>(*ctx);
@@ -84,9 +123,10 @@ void mlir::tosa::registerShardingInterfaceExternalModels(
BitwiseOrOp, BitwiseXorOp, IntDivOp, LogicalAndOp, LogicalLeftShiftOp,
LogicalRightShiftOp, LogicalOrOp, LogicalXorOp, MaximumOp, MinimumOp,
MulOp, PowOp, SubOp, AbsOp, BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp,
- LogOp, LogicalNotOp, NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp,
+ LogOp, LogicalNotOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp,
GreaterOp, GreaterEqualOp>(ctx);
MatMulOp::attachInterface<MatMulOpSharding>(*ctx);
+ NegateOp::attachInterface<NegateOpSharding>(*ctx);
});
}
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 3e99c1f717d09..09d2c5d35263c 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1143,13 +1143,36 @@ OpFoldResult tosa::ExpOp::fold(FoldAdaptor adaptor) {
}
OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
- auto input = getInput1();
// Element-wise negate(negate(x)) = x
- if (auto op = input.getDefiningOp<tosa::NegateOp>()) {
- return op.getInput1();
+ // iff all zero points are constant 0
+ auto definingOp = getInput1().getDefiningOp<tosa::NegateOp>();
+ if (!definingOp) {
+ // defining op of input1 is not a negate, cannot fold
+ return {};
}
- return {};
+ if (FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
+ failed(maybeIZp) || *maybeIZp != 0) {
+ // input1 zero point is not constant 0, cannot fold
+ return {};
+ }
+ if (FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
+ failed(maybeOZp) || *maybeOZp != 0) {
+ // output zero point is not constant 0, cannot fold
+ return {};
+ }
+ if (FailureOr<int64_t> maybeIZp = definingOp.getInput1ZeroPoint();
+ failed(maybeIZp) || *maybeIZp != 0) {
+ // definingOp's input1 zero point is not constant 0, cannot fold
+ return {};
+ }
+ if (FailureOr<int64_t> maybeOZp = definingOp.getOutputZeroPoint();
+ failed(maybeOZp) || *maybeOZp != 0) {
+ // definingOp's output zero point is not constant 0, cannot fold
+ return {};
+ }
+
+ return definingOp.getInput1();
}
OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 7a991b3876f69..219775c31bd56 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -697,23 +697,43 @@ buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
result.types.push_back(outputType);
}
-/// This builder is called on single-parameter unary operators that have scale
-/// relationship between their input and output, expressed by the
-/// UnaryOpQuantizationAttr.
-static void buildUnaryOpWithQuantInfo(OpBuilder &builder,
- OperationState &result, Type outputType,
- Value input) {
- result.addOperands(input);
+/// This builder is called on the single-parameter negate operator to create
+/// its input1 and output zero-point operands from the quantization info of
+/// the input type and of the output type, respectively.
+static void buildNegateOpWithQuantInfo(OpBuilder &builder,
+ OperationState &result, Type outputType,
+ Value input) {
+ const Location loc{result.location};
+ int64_t input1Zp{0};
+ int64_t outputZp{0};
auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
if (quantAttr) {
- // note: negateOp has attributes input1_zp and output_zp
- result.addAttribute("input1_zp",
- builder.getI32IntegerAttr(
- static_cast<int32_t>(quantAttr.getInputZp())));
- result.addAttribute("output_zp",
- builder.getI32IntegerAttr(
- static_cast<int32_t>(quantAttr.getOutputZp())));
+ input1Zp = quantAttr.getInputZp();
+ outputZp = quantAttr.getOutputZp();
+ }
+ const std::optional<Value> input1ZpOp =
+ createZeroPointTensor(builder, loc, input.getType(), input1Zp);
+ if (!input1ZpOp) {
+ (void)emitError(
+ loc, "Failed to create input1 zero point for quantized NEGATE op");
+ }
+
+ const std::optional<Value> outputZpOp =
+ createZeroPointTensor(builder, loc, outputType, outputZp);
+ if (!outputZpOp) {
+ (void)emitError(
+ loc, "Failed to create output zero point for quantized NEGATE op");
}
+
+ if (input1ZpOp && outputZpOp) {
+ result.addOperands({input, input1ZpOp.value(), outputZpOp.value()});
+ } else {
+ // Failed to create one or more zero points above: add only the input as
+ // an operand. Building the op will then report an error because the
+ // zero-point operands are missing.
+ result.addOperands({input});
+ }
+
result.types.push_back(outputType);
}
@@ -1728,6 +1748,9 @@ ZERO_POINT_HELPER(AvgPool2dOp, Input)
ZERO_POINT_HELPER(AvgPool2dOp, Output)
ZERO_POINT_HELPER(MatMulOp, A)
ZERO_POINT_HELPER(MatMulOp, B)
+ZERO_POINT_HELPER(NegateOp, Input1)
+ZERO_POINT_HELPER(NegateOp, Output)
+
#undef ZERO_POINT_HELPER
LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
@@ -2230,7 +2253,6 @@ NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
NARY_SHAPE_INFER(tosa::LogicalXorOp)
NARY_SHAPE_INFER(tosa::MaximumOp)
NARY_SHAPE_INFER(tosa::MinimumOp)
-NARY_SHAPE_INFER(tosa::NegateOp)
NARY_SHAPE_INFER(tosa::PowOp)
NARY_SHAPE_INFER(tosa::ReciprocalOp)
NARY_SHAPE_INFER(tosa::ReverseOp)
@@ -2243,6 +2265,55 @@ NARY_SHAPE_INFER(tosa::ErfOp)
NARY_SHAPE_INFER(tosa::SigmoidOp)
#undef PRED_SHAPE_INFER
+LogicalResult tosa::NegateOp::inferReturnTypeComponents(
+ MLIRContext *context, ::std::optional<Location> location,
+ NegateOp::Adaptor adaptor,
+ SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+ ShapeAdaptor inputShape(adaptor.getInput1().getType());
+ inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
+ return success();
+}
+
+LogicalResult tosa::NegateOp::verify() {
+ // Verify same element type
+ const Type input1Type = getInput1().getType();
+ const Type outputType = getOutput().getType();
+ if (verifySameElementTypes(*this, input1Type, outputType).failed())
+ return failure();
+
+ // Verify same shape
+ const SmallVector<Type, 2> types = {input1Type, outputType};
+ if (failed(verifyCompatibleShapes(types)))
+ return emitOpError() << "requires the same shape for input1 and output";
+
+ const Type input1EType = getStorageElementTypeOrSelf(input1Type);
+ const Type input1ZpEType =
+ getStorageElementTypeOrSelf(getInput1Zp().getType());
+ if (input1EType != input1ZpEType) {
+ return emitOpError("expect both input1 and its zero point are the same "
+ "element type, got ")
+ << input1EType << " and " << input1ZpEType;
+ }
+ const Type outputEType = getStorageElementTypeOrSelf(outputType);
+ const Type outputZpEType =
+ getStorageElementTypeOrSelf(getOutputZp().getType());
+ if (outputEType != outputZpEType) {
+ return emitOpError("expect both output and its zero point are the same "
+ "element type, got ")
+ << outputEType << " and " << outputZpEType;
+ }
+
+ FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
+ if (succeeded(maybeIZp) && verifyInput1ZeroPoint(*maybeIZp).failed())
+ return failure();
+
+ FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
+ if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
+ return failure();
+
+ return success();
+}
+
static LogicalResult poolingInferReturnTypes(
ShapeAdaptor inputShape, ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride,
ArrayRef<int64_t> pad,
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index a3ed8c2805282..1c7be0ed6f107 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -477,7 +477,9 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
// CHECK: linalg.generic
// CHECK: arith.negf
- %5 = tosa.negate %0 : (tensor<1xf32>) -> tensor<1xf32>
+ %in_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+ %out_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+ %5 = tosa.negate %0, %in_zp, %out_zp : (tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
// CHECK: linalg.generic
// CHECK: pow
@@ -662,10 +664,12 @@ func.func @test_simple_i32(%arg0: tensor<1xi32>, %unsigned: tensor<1xui32>, %uns
%40 = tosa.int_div %arg0, %arg0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
- // CHECK: ^bb0(%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32):
+ // CHECK: ^bb0(%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32):
// CHECK: [[ZERO:%.+]] = arith.constant 0
// CHECK: arith.subi [[ZERO]], %[[ARG1]]
- %5 = tosa.negate %arg0 : (tensor<1xi32>) -> tensor<1xi32>
+ %in_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+ %out_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+ %5 = tosa.negate %arg0, %in_zp, %out_zp : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
// CHECK: and
@@ -852,40 +856,22 @@ func.func @test_bool(%arg0: tensor<1xi1>, %arg1: tensor<1xi1>) -> () {
// CHECK-LABEL: @test_negate_quantized
func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
// CHECK: linalg.generic
- // CHECK: ^bb0(%[[BBARG0:.+]]: i8,
- // CHECK: [[ZERO:%.+]] = arith.constant 0
- // CHECK: [[SUB:%.+]] = arith.subi [[ZERO]], %[[BBARG0]]
- // CHECK: linalg.yield [[SUB]]
- %0 = tosa.negate %arg0 {input_zp1 = 0 : i32, output_zp = 0 : i32} : (tensor<1xi8>) -> tensor<1xi8>
-
- // CHECK: linalg.generic
- // CHECK: ^bb0(%[[BBARG0:.+]]: i8,
- // CHECK: [[C32639:%.+]] = arith.constant 32639
+ // CHECK: ^bb0(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8, %[[BBARG2:.+]]: i8, %[[BBARG3:.+]]: i8
+ // CHECK: [[CNST:%.+]] = arith.constant 7
// CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i16
- // CHECK: [[SUB:%.+]] = arith.subi [[C32639]], [[EXT]]
- // CHECK: [[MIN:%.+]] = arith.constant -128
- // CHECK: [[MAX:%.+]] = arith.constant 127
- // CHECK: [[LBOUND:%.+]] = arith.maxsi [[MIN]], [[SUB]]
- // CHECK: [[UBOUND:%.+]] = arith.minsi [[MAX]], [[LBOUND]]
- // CHECK: [[TRUNC:%.+]] = arith.trunci [[UBOUND]]
- // CHECK: linalg.yield [[TRUNC]]
- %1 = tosa.negate %arg0 {input1_zp = 32639 : i32, output_zp = 0 : i32} : (tensor<1xi8>) -> tensor<1xi8>
-
- // CHECK: linalg.generic
- // CHECK: ^bb0(%[[BBARG0:.+]]: i8,
- // CHECK: [[C32640:%.+]] = arith.constant 32640
- // CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i32
- // CHECK: [[SUB:%.+]] = arith.subi [[C32640]], [[EXT]]
+ // CHECK: [[SUB:%.+]] = arith.subi [[CNST]], [[EXT]]
// CHECK: [[MIN:%.+]] = arith.constant -128
// CHECK: [[MAX:%.+]] = arith.constant 127
// CHECK: [[LBOUND:%.+]] = arith.maxsi [[MIN]], [[SUB]]
// CHECK: [[UBOUND:%.+]] = arith.minsi [[MAX]], [[LBOUND]]
// CHECK: [[TRUNC:%.+]] = arith.trunci [[UBOUND]]
// CHECK: linalg.yield [[TRUNC]]
- %2 = tosa.negate %arg0 {input1_zp = 32640 : i32, output_zp = 0 : i32} : (tensor<1xi8>) -> tensor<1xi8>
+ %in_zp0 = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %out_zp0 = "tosa.const"() <{values = dense<7> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %0 = tosa.negate %arg0, %in_zp0, %out_zp0 : (tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1xi8>
// CHECK: linalg.generic
- // CHECK: ^bb0(%[[BBARG0:.+]]: i8,
+ // CHECK: ^bb0(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8, %[[BBARG2:.+]]: i8, %[[BBARG3:.+]]: i8
// CHECK: [[C_128:%.+]] = arith.constant -128
// CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i16
// CHECK: [[SUB:%.+]] = arith.subi [[C_128]], [[EXT]]
@@ -895,14 +881,18 @@ func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
// CHECK: [[UBOUND:%.+]] = arith.minsi [[MAX]], [[LBOUND]]
// CHECK: [[TRUNC:%.+]] = arith.trunci [[UBOUND]]
// CHECK: linalg.yield [[TRUNC]]
- %3 = tosa.negate %arg0 {input1_zp = -128 : i32, output_zp = 0 : i32} : (tensor<1xi8>) -> tensor<1xi8>
+ %in_zp3 = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %out_zp3 = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %3 = tosa.negate %arg0, %in_zp3, %out_zp3 : (tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1xi8>
// CHECK: linalg.generic
// CHECK: ^bb0(%[[BBARG0:.+]]: i8,
// CHECK: [[ZERO:%.+]] = arith.constant 0
// CHECK: [[SUB:%.+]] = arith.subi [[ZERO]],
// CHECK: linalg.yield [[SUB]]
- %4 = tosa.negate %arg0 {quantization_info = #tosa.unary_quant<input_zp = 0, output_zp = 0>} : (tensor<1xi8>) -> tensor<1xi8>
+ %in_zp4 = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %out_zp4 = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %4 = tosa.negate %arg0, %in_zp4, %out_zp4 : (tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1xi8>
return
}
diff --git a/mlir/test/Dialect/Mesh/sharding-propagation.mlir b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
index 14c67e670e921..aa5fa00488f08 100644
--- a/mlir/test/Dialect/Mesh/sharding-propagation.mlir
+++ b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
@@ -77,7 +77,7 @@ func.func @element_wise_on_graph_input(%arg0: tensor<8x16xf32>) -> tensor<8x16xf
// CHECK-LABEL: func.func @arrow_structure
// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
-func.func @arrow_structure(%arg0: tensor<8x16xf32>) -> (tensor<8x16xf32>, tensor<8x16xf32>) {
+func.func @arrow_structure(%arg0: tensor<8x16xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> (tensor<8x16xf32>, tensor<8x16xf32>) {
// CHECK-NEXT: %[[S1:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
// CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[ARG]] to %[[S1]] annotate_for_users : tensor<8x16xf32>
// CHECK-NEXT: %[[V2:.*]] = tosa.tanh %[[V1]]
@@ -85,12 +85,15 @@ func.func @arrow_structure(%arg0: tensor<8x16xf32>) -> (tensor<8x16xf32>, tensor
%0 = tosa.tanh %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
// CHECK-NEXT: %[[V4:.*]] = mesh.shard %[[V3]] to %[[S1]] annotate_for_users : tensor<8x16xf32>
// CHECK-NEXT: %[[V5:.*]] = tosa.abs %[[V4]]
- // CHECK-NEXT: %[[V6:.*]] = mesh.shard %[[V5]] to %[[S1]] : tensor<8x16xf32>
- %1 = tosa.abs %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
- // CHECK-NEXT: %[[V7:.*]] = tosa.negate %[[V4]]
+ // CHECK-NEXT: %[[V6:.*]] = mesh.shard %[[V5]] to %[[S1]] : tensor<8x16xf32>
+ %1 = tosa.abs %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+ // CHECK-NEXT: %[[S3:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding
+ // CHECK-NEXT: %[[ZP1:.*]] = mesh.shard %arg1 to %[[S3]] annotate_for_users : tensor<1xf32>
+ // CHECK-NEXT: %[[ZP2:.*]] = mesh.shard %arg2 to %[[S3]] annotate_for_users : tensor<1xf32>
+ // CHECK-NEXT: %[[V7:.*]] = tosa.negate %[[V4]], %[[ZP1]], %[[ZP2]]
// CHECK-NEXT: %[[S8:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
// CHECK-NEXT: %[[V8:.*]] = mesh.shard %[[V7]] to %[[S8]] : tensor<8x16xf32>
- %2 = tosa.negate %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+ %2 = tosa.negate %0, %arg1, %arg2 : (tensor<8x16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<8x16xf32>
%s3 = mesh.sharding @mesh_2d split_axes = [[0], [1]] : !mesh.sharding
%3 = mesh.shard %2 to %s3 : tensor<8x16xf32>
// CHECK-NEXT: return %[[V6]], %[[V8]]
diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir
index 59f7162e21013..5c9fd29444f04 100644
--- a/mlir/test/Dialect/Mesh/spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/spmdization.mlir
@@ -176,7 +176,7 @@ func.func @multiple_chained_ops(
%4 = mesh.shard %3 to %s4 annotate_for_users : tensor<2xi8>
// CHECK: %[[ABS2:.*]] = tosa.abs %[[RESHARD2]] : (tensor<2xi8>) -> tensor<2xi8>
%5 = tosa.abs %4 : (tensor<2xi8>) -> tensor<2xi8>
- // CHECK: %[[RESHARD3:.*]] = mesh.all_slice %[[ABS2]] on @mesh_1d mesh_axes = [0] slice_axis = 0 :
+ // CHECK: %[[RESHARD3:.*]] = mesh.all_slice %[[ABS2]] on @mesh_1d mesh_axes = [0] slice_axis = 0 :
// CHECK-SAME: tensor<2xi8> -> tensor<1xi8>
%s6 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
%6 = mesh.shard %5 to %s6 : tensor<2xi8>
@@ -207,7 +207,11 @@ mesh.mesh @mesh_1d_4(shape = 4)
// CHECK-LABEL: func @ew_chain_with_halo
func.func @ew_chain_with_halo(
// CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<5x16xf32>
- %arg0: tensor<8x16xf32>)
+ %arg0: tensor<8x16xf32>,
+ // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<1xf32>
+ %arg1: tensor<1xf32>,
+ // CHECK-SAME: %[[IN3:[A-Za-z0-9_]+]]: tensor<1xf32>
+ %arg2: tensor<1xf32>)
// CHECK-SAME: -> tensor<5x16xf32>
-> tensor<8x16xf32> {
%ssharding_annotated = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
@@ -224,8 +228,11 @@ func.func @ew_chain_with_halo(
%sharding_annotated_2 = mesh.shard %1 to %ssharding_annotated_2 : tensor<8x16xf32>
%ssharding_annotated_4 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
%sharding_annotated_4 = mesh.shard %sharding_annotated_2 to %ssharding_annotated_4 annotate_for_users : tensor<8x16xf32>
- // CHECK-NEXT: %[[TMP3:.*]] = tosa.negate %[[TMP2]] : (tensor<5x16xf32>) -> tensor<5x16xf32>
- %2 = tosa.negate %sharding_annotated_4 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+ // CHECK-NEXT: %[[TMP3:.*]] = tosa.negate %[[TMP2]], %[[IN2]], %[[IN3]] : (tensor<5x16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x16xf32>
+ %sharding_1 = mesh.sharding @mesh_1d_4 split_axes = [[]] : !mesh.sharding
+ %zero_point_1 = mesh.shard %arg1 to %sharding_1 annotate_for_users : tensor<1xf32>
+ %zero_point_2 = mesh.shard %arg2 to %sharding_1 annotate_for_users : tensor<1xf32>
+ %2 = tosa.negate %sharding_annotated_4, %zero_point_1, %zero_point_2 : (tensor<8x16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<8x16xf32>
%ssharding_annotated_5 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
%sharding_annotated_5 = mesh.shard %2 to %ssharding_annotated_5 : tensor<8x16xf32>
%ssharding_annotated_6 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir
index b786264d84106..820ea6559b848 100644
--- a/mlir/test/Dialect/Tosa/availability.mlir
+++ b/mlir/test/Dialect/Tosa/availability.mlir
@@ -380,7 +380,9 @@ func.func @test_logical_not(%arg0: tensor<1x21x3xi1>) -> tensor<1x21x3xi1> {
func.func @test_negate(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
// CHECK: profiles: [ [pro_int, pro_fp] ]
// CHECK: extensions: [ [bf16] ]
- %0 = tosa.negate %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+ %output_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+ %0 = tosa.negate %arg0, %input_zp, %output_zp : (tensor<13x21x3xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 4242f68609634..e2575c764fdfe 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -856,13 +856,54 @@ func.func @fold_exp_log(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
// CHECK-LABEL: @fold_negate_negate
func.func @fold_negate_negate(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
// CHECK: return %arg{{.*}} : tensor<?x1xf32>
- %0 = tosa.negate %arg0 : (tensor<?x1xf32>) -> tensor<?x1xf32>
- %1 = tosa.negate %0 : (tensor<?x1xf32>) -> tensor<?x1xf32>
+ %in_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+ %out_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+ %0 = tosa.negate %arg0, %in_zp, %out_zp : (tensor<?x1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x1xf32>
+ %1 = tosa.negate %0, %in_zp, %out_zp : (tensor<?x1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x1xf32>
return %1 : tensor<?x1xf32>
}
// -----
+// CHECK-LABEL: @no_fold_negate_negate_non_const_zp
+func.func @no_fold_negate_negate_non_const_zp(%arg0: tensor<?x1xf32>, %in_zp: tensor<1xf32>) -> tensor<?x1xf32> {
+ // cannot fold if any zp is not constant
+ // CHECK: tosa.negate
+ // CHECK: tosa.negate
+ // CHECK: tosa.negate
+ // CHECK: tosa.negate
+ // CHECK: tosa.negate
+ %zero = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+ %0 = tosa.negate %arg0, %in_zp, %zero : (tensor<?x1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x1xf32>
+ %1 = tosa.negate %0, %zero, %zero : (tensor<?x1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x1xf32>
+ %2 = tosa.negate %1, %zero, %in_zp : (tensor<?x1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x1xf32>
+ %3 = tosa.negate %2, %zero, %zero : (tensor<?x1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x1xf32>
+ %4 = tosa.negate %3, %in_zp, %zero : (tensor<?x1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x1xf32>
+ return %4 : tensor<?x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @no_fold_negate_negate_non_zero_zp
+func.func @no_fold_negate_negate_non_zero_zp(%arg0: tensor<?x1xi8>) -> tensor<?x1xi8> {
+ // cannot fold if any zp is not constant 0
+ // CHECK: tosa.negate
+ // CHECK: tosa.negate
+ // CHECK: tosa.negate
+ // CHECK: tosa.negate
+ // CHECK: tosa.negate
+ %zero = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %one = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %0 = tosa.negate %arg0, %zero, %one : (tensor<?x1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<?x1xi8>
+ %1 = tosa.negate %0, %zero, %zero : (tensor<?x1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<?x1xi8>
+ %2 = tosa.negate %1, %one, %zero : (tensor<?x1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<?x1xi8>
+ %3 = tosa.negate %2, %zero, %zero : (tensor<?x1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<?x1xi8>
+ %4 = tosa.negate %3, %zero, %one : (tensor<?x1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<?x1xi8>
+ return %4 : tensor<?x1xi8>
+}
+
+// -----
+
// CHECK-LABEL: @fold_abs_abs
func.func @fold_abs_abs(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
// CHECK: %[[ABS:.*]] = tosa.abs %arg{{.*}} : (tensor<?x1xf32>) -> tensor<?x1xf32>
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index f536444f6379e..c4fe9c1a6cabc 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -1652,3 +1652,67 @@ func.func @test_matmul_b_zp_non_zero(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1
%0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x14x28xf32>
return %0 : tensor<1x14x28xf32>
}
+
+// -----
+
+// CHECK-LABEL: test_negate_same_element_type
+func.func @test_negate_same_element_type(%arg0: tensor<1x16x16x8xf16>, %arg1: tensor<1xf16>, %arg2: tensor<1xf16>) -> tensor<1x16x16x8xf32> {
+ // expected-error@+1 {{'tosa.negate' op expect input and output to have same element type, got 'f16' and 'f32'}}
+ %0 = tosa.negate %arg0, %arg1, %arg2
+ : (tensor<1x16x16x8xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x16x16x8xf32>
+ return %0 : tensor<1x16x16x8xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_negate_same_shape
+func.func @test_negate_same_shape(%arg0: tensor<1x16x16x16xf16>, %arg1: tensor<1xf16>, %arg2: tensor<1xf16>) -> tensor<1x16x16x8xf16> {
+ // expected-error@+1 {{'tosa.negate' op requires the same shape for input1 and output}}
+ %0 = tosa.negate %arg0, %arg1, %arg2
+ : (tensor<1x16x16x16xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x16x16x8xf16>
+ return %0 : tensor<1x16x16x8xf16>
+}
+
+// -----
+
+// CHECK-LABEL: test_negate_input_zp_same_element_type
+func.func @test_negate_input_zp_same_element_type(%arg0: tensor<1x16x16x8xf16>, %arg1: tensor<1xi8>, %arg2: tensor<1xf16>) -> tensor<1x16x16x8xf16> {
+ // expected-error@+1 {{'tosa.negate' op expect both input1 and its zero point are the same element type, got 'f16' and 'i8'}}
+ %0 = tosa.negate %arg0, %arg1, %arg2
+ : (tensor<1x16x16x8xf16>, tensor<1xi8>, tensor<1xf16>) -> tensor<1x16x16x8xf16>
+ return %0 : tensor<1x16x16x8xf16>
+}
+
+// -----
+
+// CHECK-LABEL: test_negate_output_zp_same_element_type
+func.func @test_negate_output_zp_same_element_type(%arg0: tensor<1x16x16x8xi8>, %arg1: tensor<1xi8>, %arg2: tensor<1xf16>) -> tensor<1x16x16x8xi8> {
+ // expected-error@+1 {{'tosa.negate' op expect both output and its zero point are the same element type, got 'i8' and 'f16'}}
+ %0 = tosa.negate %arg0, %arg1, %arg2
+ : (tensor<1x16x16x8xi8>, tensor<1xi8>, tensor<1xf16>) -> tensor<1x16x16x8xi8>
+ return %0 : tensor<1x16x16x8xi8>
+}
+
+// -----
+
+// CHECK-LABEL: test_negate_input_zp_non_zero
+func.func @test_negate_input_zp_non_zero(%arg0: tensor<1x16x16x8xf32>) -> tensor<1x16x16x8xf32> {
+ %input_zp = "tosa.const"() {values = dense<-1.0> : tensor<1xf32>} : () -> tensor<1xf32>
+ %output_zp = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
+ // expected-error@+1 {{'tosa.negate' op input1 zero point must be zero for non-int8 integer types}}
+ %0 = tosa.negate %arg0, %input_zp, %output_zp
+ : (tensor<1x16x16x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x16x16x8xf32>
+ return %0 : tensor<1x16x16x8xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_negate_output_zp_non_zero
+func.func @test_negate_output_zp_non_zero(%arg0: tensor<1x16x16x8xf32>) -> tensor<1x16x16x8xf32> {
+ %input_zp = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
+ %output_zp = "tosa.const"() {values = dense<-1.0> : tensor<1xf32>} : () -> tensor<1xf32>
+ // expected-error@+1 {{'tosa.negate' op output zero point must be zero for non-int8 integer types}}
+ %0 = tosa.negate %arg0, %input_zp, %output_zp
+ : (tensor<1x16x16x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x16x16x8xf32>
+ return %0 : tensor<1x16x16x8xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 6d8237635d0ec..2f7250dabe162 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -249,9 +249,9 @@ func.func @test_logical_not_rank_invalid(%arg0: tensor<1x1x1x1x1x21x3xi1>) -> te
// -----
-func.func @test_negate_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
+func.func @test_negate_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>, %arg1: tensor<1xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
// expected-error@+1 {{'tosa.negate' op failed level check: operand rank(shape) <= MAX_RANK}}
- %0 = tosa.negate %arg0 : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
+ %0 = tosa.negate %arg0, %arg1, %arg1 : (tensor<1x1x1x1x13x21x3xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x1x1x1x13x21x3xf32>
return %0 : tensor<1x1x1x1x13x21x3xf32>
}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 480d8c327ab83..916886025cb0e 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -487,8 +487,8 @@ func.func @test_logical_not(%arg0: tensor<1x21x3xi1>) -> tensor<1x21x3xi1> {
// -----
// CHECK-LABEL: negate
-func.func @test_negate(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
- %0 = tosa.negate %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+func.func @test_negate(%arg0: tensor<13x21x3xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> tensor<13x21x3xf32> {
+ %0 = tosa.negate %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
diff --git a/mlir/test/Dialect/Tosa/quant-test.mlir b/mlir/test/Dialect/Tosa/quant-test.mlir
index 447a6ef7f9e05..f0ad4eb4fdb0b 100644
--- a/mlir/test/Dialect/Tosa/quant-test.mlir
+++ b/mlir/test/Dialect/Tosa/quant-test.mlir
@@ -2,9 +2,9 @@
// -----
// CHECK-LABEL: test_build_qtype
-func.func @test_build_qtype(%arg0 : tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>) -> tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>> {
+func.func @test_build_qtype(%arg0 : tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>, %arg1: tensor<1xi8>, %arg2: tensor<1xi8>) -> tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>> {
// CHECK: tosa.negate
- %0 = "tosa.negate"(%arg0) : (tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>) -> tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>
+ %0 = "tosa.negate"(%arg0, %arg1, %arg2) : (tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>, tensor<1xi8>, tensor<1xi8>) -> tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>
return %0 : tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>
}
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index deede4b0afadc..3d0ded8c58ac5 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -45,8 +45,10 @@ func.func @test_unary_f32(%arg0 : tensor<4xf32>) -> () {
// CHECK: tosa.log %arg0 : (tensor<4xf32>) -> tensor<4xf32>
%5 = tosa.log %arg0 : (tensor<4xf32>) -> tensor<*xf32>
- // CHECK: tosa.negate %arg0 : (tensor<4xf32>) -> tensor<4xf32>
- %6 = tosa.negate %arg0 : (tensor<4xf32>) -> tensor<*xf32>
+ %in_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+ %out_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+ // CHECK: tosa.negate %arg0, {{.+}} : (tensor<4xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<4xf32>
+ %6 = tosa.negate %arg0, %in_zp, %out_zp : (tensor<4xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<*xf32>
// CHECK: tosa.reciprocal %arg0 : (tensor<4xf32>) -> tensor<4xf32>
%7 = tosa.reciprocal %arg0 : (tensor<4xf32>) -> tensor<*xf32>
@@ -87,8 +89,10 @@ func.func @test_unary_i32(%arg0 : tensor<4xi32>, %arg1 : tensor<2xi8>) -> () {
// CHECK: tosa.clz %arg0 : (tensor<4xi32>) -> tensor<4xi32>
%3 = tosa.clz %arg0 : (tensor<4xi32>) -> tensor<*xi32>
- // CHECK: tosa.negate %arg0 : (tensor<4xi32>) -> tensor<4xi32>
- %4 = tosa.negate %arg0 : (tensor<4xi32>) -> tensor<*xi32>
+ %in_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+ %out_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+ // CHECK: tosa.negate %arg0, {{.+}} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32>
+ %4 = tosa.negate %arg0, %in_zp, %out_zp : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xi32>
// CHECK: tosa.reverse %arg0 {axis = 0 : i32} : (tensor<4xi32>) -> tensor<4xi32>
%5 = tosa.reverse %arg0 { axis = 0 : i32 } : (tensor<4xi32>) -> tensor<?xi32>
More information about the Mlir-commits
mailing list