[Mlir-commits] [mlir] [mlir][tosa] Make TOSA MUL's Shift an Input (PR #121953)
Jack Frankland
llvmlistbot at llvm.org
Mon Jan 27 06:48:25 PST 2025
https://github.com/FranklandJack updated https://github.com/llvm/llvm-project/pull/121953
>From ad05fa2fe8859d2476d0a2e8b7e91e3f831bd5fa Mon Sep 17 00:00:00 2001
From: TatWai Chong <tatwai.chong at arm.com>
Date: Tue, 6 Feb 2024 16:49:05 -0800
Subject: [PATCH] [mlir][tosa] Make TOSA MUL's Shift an Input
The TOSA-v1.0 specification makes the shift attribute of the MUL
(Hadamard product) operator an input. Move the `shift` parameter of the
MUL operator in the MLIR TOSA dialect from an attribute to an input and
update any lit tests appropriately.
Expand the verifier of the `tosa::MulOp` operation to check the various
constraints defined in the TOSA-v1.0 specification. Specifically, ensure
that all input operands (excluding the optional shift) are of the same
rank. This means that broadcasting tests which previously checked rank-0
tensors would be broadcast are no longer valid and are removed.
Signed-off-by: Jack Frankland <jack.frankland at arm.com>
---
.../mlir/Dialect/Tosa/IR/TosaOpBase.td | 4 +-
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h | 51 ++++++-----
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 91 ++++++++++++++-----
.../Conversion/TosaToLinalg/TosaToLinalg.cpp | 85 ++++++++++-------
.../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 15 ++-
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 73 ++++++++++++++-
.../Transforms/TosaDecomposeDepthwise.cpp | 9 +-
.../Tosa/Transforms/TosaMakeBroadcastable.cpp | 2 +-
.../TosaToLinalg/tosa-to-linalg.mlir | 12 ++-
mlir/test/Dialect/Tosa/canonicalize.mlir | 30 ++++--
mlir/test/Dialect/Tosa/constant-op-fold.mlir | 30 ++++--
mlir/test/Dialect/Tosa/invalid.mlir | 15 ++-
mlir/test/Dialect/Tosa/ops.mlir | 6 +-
.../Tosa/tosa-decompose-depthwise.mlir | 4 +-
mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | 17 ++--
15 files changed, 321 insertions(+), 123 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 4975530a9588ca..29afd6c27302cc 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -239,9 +239,7 @@ class Tosa_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
Tosa_Op<mnemonic, !listconcat(traits, [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
- ResultsBroadcastableShape,
TosaElementwiseOperator,
- SameOperandsAndResultRank,
Pure])> {
let assemblyFormat =
"operands attr-dict `:` functional-type(operands, results)";
@@ -250,6 +248,8 @@ class Tosa_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
class Tosa_ElementwiseUnaryOp<string mnemonic, list<Trait> traits = []> :
Tosa_ElementwiseOp<mnemonic, !listconcat(traits, [
SameOperandsAndResultShape,
+ ResultsBroadcastableShape,
+ SameOperandsAndResultRank,
SameOperandsAndResultElementType])> {}
class Tosa_InferTensorTypeOp<string mnemonic, list<Trait> traits = []>
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
index e4f5d09064cd75..4d62a15110764e 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -14,6 +14,7 @@
#define MLIR_DIALECT_TOSA_IR_TOSAOPS_H
#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
@@ -28,6 +29,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Tosa/IR/TosaOpsDialect.h.inc"
+#include "llvm/Support/LogicalResult.h"
namespace mlir {
class PatternRewriter;
@@ -53,34 +55,37 @@ class MulOperandsAndResultElementType
: public TraitBase<ConcreteType, MulOperandsAndResultElementType> {
public:
static LogicalResult verifyTrait(Operation *op) {
- auto resElemType = getElementTypeOrSelf(op->getResult(0));
-
- // In cases of floating point type, op requires the same element
- // type for all operands and result.
- if (llvm::isa<FloatType>(resElemType))
- return impl::verifySameOperandsAndResultElementType(op);
-
+ // Check we have the three operands; lhs, rhs and shift
+ // and a single result.
+ if (failed(impl::verifyNOperands(op, 3)) ||
+ failed(impl::verifyNResults(op, 1)))
+ return failure();
+
+ Type resElemType = getElementTypeOrSelf(op->getResult(0));
+ Type lhsElemType = getElementTypeOrSelf(op->getOperand(0));
+ Type rhsElemType = getElementTypeOrSelf(op->getOperand(1));
+
+ // Verify operands type match (ignoring the shift parameter which will
+ // always be i8).
+ if (lhsElemType != rhsElemType)
+ return op->emitOpError("requires the same element type for all operands");
+
+ // Though the spec requires the element type of result to be i32, a more
+ // relaxed way is provided at dialect level for easier cooperating with
+ // other dialects.
if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
- IntegerType lhsIntType =
- cast<IntegerType>(getElementTypeOrSelf(op->getOperand(0)));
- IntegerType rhsIntType =
- cast<IntegerType>(getElementTypeOrSelf(op->getOperand(1)));
- if (lhsIntType != rhsIntType)
- return op->emitOpError(
- "requires the same element type for all operands");
-
- // Though the spec requires the element type of result to be i32, a more
- // relaxed way is provided at dialect level for easier cooperating with
- // other dialects.
+ auto lhsIntType = cast<IntegerType>(lhsElemType);
if (lhsIntType.getWidth() > resIntType.getWidth())
return op->emitOpError("invalid data type size for operands or result");
-
- return success();
+ } else {
+ // In cases of floating point type or quant types, op requires the same
+ // element type for all operands and result (excluding shift).
+ if (resElemType != lhsElemType)
+ return op->emitOpError(
+ "requires the same element type for all operands and results");
}
- // In cases of all other types, op requires the same element
- // type for all operands and result.
- return impl::verifySameOperandsAndResultElementType(op);
+ return llvm::success();
}
};
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 2186510e7db1e1..758f49079e74e5 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -482,7 +482,9 @@ def Tosa_ErfOp : Tosa_ElementwiseUnaryOp<"erf"> {
//===----------------------------------------------------------------------===//
def Tosa_AddOp : Tosa_ElementwiseOp<"add", [
Commutative,
- SameOperandsAndResultElementType]> {
+ ResultsBroadcastableShape,
+ SameOperandsAndResultElementType,
+ SameOperandsAndResultRank]> {
let summary = "Elementwise addition operator";
let description = [{
@@ -515,8 +517,10 @@ def Tosa_AddOp : Tosa_ElementwiseOp<"add", [
//===----------------------------------------------------------------------===//
// Operator: arithmetic_right_shift
//===----------------------------------------------------------------------===//
-def Tosa_ArithmeticRightShiftOp : Tosa_ElementwiseOp<"arithmetic_right_shift",
- [SameOperandsAndResultElementType]> {
+def Tosa_ArithmeticRightShiftOp : Tosa_ElementwiseOp<"arithmetic_right_shift", [
+ ResultsBroadcastableShape,
+ SameOperandsAndResultElementType,
+ SameOperandsAndResultRank]> {
let summary = "Elementwise Arithmetic Right Shift";
let description = [{
@@ -540,7 +544,9 @@ def Tosa_ArithmeticRightShiftOp : Tosa_ElementwiseOp<"arithmetic_right_shift",
//===----------------------------------------------------------------------===//
def Tosa_BitwiseAndOp : Tosa_ElementwiseOp<"bitwise_and", [
Commutative,
- SameOperandsAndResultElementType]> {
+ ResultsBroadcastableShape,
+ SameOperandsAndResultElementType,
+ SameOperandsAndResultRank]> {
let summary = "Bitwise AND operator";
let description = [{
@@ -563,7 +569,9 @@ def Tosa_BitwiseAndOp : Tosa_ElementwiseOp<"bitwise_and", [
//===----------------------------------------------------------------------===//
def Tosa_BitwiseOrOp : Tosa_ElementwiseOp<"bitwise_or", [
Commutative,
- SameOperandsAndResultElementType]> {
+ ResultsBroadcastableShape,
+ SameOperandsAndResultElementType,
+ SameOperandsAndResultRank]> {
let summary = "Bitwise OR operator";
let description = [{
@@ -586,7 +594,9 @@ def Tosa_BitwiseOrOp : Tosa_ElementwiseOp<"bitwise_or", [
//===----------------------------------------------------------------------===//
def Tosa_BitwiseXorOp : Tosa_ElementwiseOp<"bitwise_xor", [
Commutative,
- SameOperandsAndResultElementType]> {
+ ResultsBroadcastableShape,
+ SameOperandsAndResultElementType,
+ SameOperandsAndResultRank]> {
let summary = "Bitwise XOR operator";
let description = [{
@@ -607,7 +617,10 @@ def Tosa_BitwiseXorOp : Tosa_ElementwiseOp<"bitwise_xor", [
//===----------------------------------------------------------------------===//
// Operator: int_div
//===----------------------------------------------------------------------===//
-def Tosa_IntDivOp : Tosa_ElementwiseOp<"int_div", [SameOperandsAndResultElementType]> {
+def Tosa_IntDivOp : Tosa_ElementwiseOp<"int_div", [
+ ResultsBroadcastableShape,
+ SameOperandsAndResultRank,
+ SameOperandsAndResultElementType]> {
let summary = "Integer divide operator";
let description = [{
@@ -632,7 +645,9 @@ def Tosa_IntDivOp : Tosa_ElementwiseOp<"int_div", [SameOperandsAndResultElementT
//===----------------------------------------------------------------------===//
def Tosa_LogicalAndOp : Tosa_ElementwiseOp<"logical_and", [
Commutative,
- SameOperandsAndResultElementType]> {
+ ResultsBroadcastableShape,
+ SameOperandsAndResultElementType,
+ SameOperandsAndResultRank]> {
let summary = "Returns the truth value of x AND y element-wise.";
let description = [{
@@ -653,8 +668,10 @@ def Tosa_LogicalAndOp : Tosa_ElementwiseOp<"logical_and", [
//===----------------------------------------------------------------------===//
// Operator: logical_left_shift
//===----------------------------------------------------------------------===//
-def Tosa_LogicalLeftShiftOp : Tosa_ElementwiseOp<"logical_left_shift",
- [SameOperandsAndResultElementType]> {
+def Tosa_LogicalLeftShiftOp : Tosa_ElementwiseOp<"logical_left_shift", [
+ ResultsBroadcastableShape,
+ SameOperandsAndResultElementType,
+ SameOperandsAndResultRank]> {
let summary = "Elementwise Logical Left Shift";
let description = [{
@@ -675,8 +692,10 @@ def Tosa_LogicalLeftShiftOp : Tosa_ElementwiseOp<"logical_left_shift",
//===----------------------------------------------------------------------===//
// Operator: logical_right_shift
//===----------------------------------------------------------------------===//
-def Tosa_LogicalRightShiftOp : Tosa_ElementwiseOp<"logical_right_shift",
- [SameOperandsAndResultElementType]> {
+def Tosa_LogicalRightShiftOp : Tosa_ElementwiseOp<"logical_right_shift", [
+ ResultsBroadcastableShape,
+ SameOperandsAndResultElementType,
+ SameOperandsAndResultRank]> {
let summary = "Elementwise Logical Right Shift";
let description = [{
@@ -699,7 +718,9 @@ def Tosa_LogicalRightShiftOp : Tosa_ElementwiseOp<"logical_right_shift",
//===----------------------------------------------------------------------===//
def Tosa_LogicalOrOp : Tosa_ElementwiseOp<"logical_or", [
Commutative,
- SameOperandsAndResultElementType]> {
+ ResultsBroadcastableShape,
+ SameOperandsAndResultElementType,
+ SameOperandsAndResultRank]> {
let summary = "Returns the truth value of x OR y element-wise.";
let description = [{
@@ -722,7 +743,9 @@ def Tosa_LogicalOrOp : Tosa_ElementwiseOp<"logical_or", [
//===----------------------------------------------------------------------===//
def Tosa_LogicalXorOp : Tosa_ElementwiseOp<"logical_xor", [
Commutative,
- SameOperandsAndResultElementType]> {
+ ResultsBroadcastableShape,
+ SameOperandsAndResultElementType,
+ SameOperandsAndResultRank]> {
let summary = "Returns the truth value of x XOR y element-wise.";
let description = [{
@@ -745,7 +768,9 @@ def Tosa_LogicalXorOp : Tosa_ElementwiseOp<"logical_xor", [
//===----------------------------------------------------------------------===//
def Tosa_MaximumOp : Tosa_ElementwiseOp<"maximum", [
Commutative,
- SameOperandsAndResultElementType]> {
+ ResultsBroadcastableShape,
+ SameOperandsAndResultElementType,
+ SameOperandsAndResultRank]> {
let summary = "Elementwise Maximum";
let description = [{
@@ -769,7 +794,9 @@ def Tosa_MaximumOp : Tosa_ElementwiseOp<"maximum", [
//===----------------------------------------------------------------------===//
def Tosa_MinimumOp : Tosa_ElementwiseOp<"minimum", [
Commutative,
- SameOperandsAndResultElementType]> {
+ ResultsBroadcastableShape,
+ SameOperandsAndResultElementType,
+ SameOperandsAndResultRank]> {
let summary = "Elementwise Minimum";
let description = [{
@@ -810,7 +837,7 @@ def Tosa_MulOp : Tosa_ElementwiseOp<"mul", [
let arguments = (ins
Tosa_Tensor:$input1,
Tosa_Tensor:$input2,
- I8Attr:$shift
+ TosaTensorRankOf<[Tosa_Int8], [1]>:$shift
);
let results = (outs
@@ -824,7 +851,10 @@ def Tosa_MulOp : Tosa_ElementwiseOp<"mul", [
//===----------------------------------------------------------------------===//
// Operator: pow
//===----------------------------------------------------------------------===//
-def Tosa_PowOp : Tosa_ElementwiseOp<"pow", [SameOperandsAndResultElementType]> {
+def Tosa_PowOp : Tosa_ElementwiseOp<"pow", [
+ ResultsBroadcastableShape,
+ SameOperandsAndResultElementType,
+ SameOperandsAndResultRank]> {
let summary = "Computes the power of one value to another.";
let description = [{
@@ -845,7 +875,10 @@ def Tosa_PowOp : Tosa_ElementwiseOp<"pow", [SameOperandsAndResultElementType]> {
//===----------------------------------------------------------------------===//
// Operator: sub
//===----------------------------------------------------------------------===//
-def Tosa_SubOp : Tosa_ElementwiseOp<"sub", [SameOperandsAndResultElementType]> {
+def Tosa_SubOp : Tosa_ElementwiseOp<"sub", [
+ ResultsBroadcastableShape,
+ SameOperandsAndResultElementType,
+ SameOperandsAndResultRank]> {
let summary = "Elementwise subtraction operator";
let description = [{
@@ -1196,7 +1229,9 @@ def Tosa_SinOp : Tosa_ElementwiseUnaryOp<"sin"> {
//===----------------------------------------------------------------------===//
// Operator: select
//===----------------------------------------------------------------------===//
-def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
+def Tosa_SelectOp : Tosa_ElementwiseOp<"select", [
+ ResultsBroadcastableShape,
+ SameOperandsAndResultRank]> {
let summary = "Elementwise select operator";
let description = [{
@@ -1232,7 +1267,9 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
def Tosa_EqualOp : Tosa_ElementwiseOp<"equal", [
InferTensorType,
Commutative,
- SameOperandsElementType]> {
+ ResultsBroadcastableShape,
+ SameOperandsElementType,
+ SameOperandsAndResultRank]> {
let summary = "Returns the truth value of (x == y) element-wise.";
let description = [{
@@ -1260,7 +1297,10 @@ def Tosa_EqualOp : Tosa_ElementwiseOp<"equal", [
//===----------------------------------------------------------------------===//
// Operator: greater
//===----------------------------------------------------------------------===//
-def Tosa_GreaterOp : Tosa_ElementwiseOp<"greater", [SameOperandsElementType]> {
+def Tosa_GreaterOp : Tosa_ElementwiseOp<"greater", [
+ ResultsBroadcastableShape,
+ SameOperandsElementType,
+ SameOperandsAndResultRank]> {
let summary = "Returns the truth value of (x > y) element-wise.";
let description = [{
@@ -1282,8 +1322,11 @@ def Tosa_GreaterOp : Tosa_ElementwiseOp<"greater", [SameOperandsElementType]> {
//===----------------------------------------------------------------------===//
// Operator: greater_equal
//===----------------------------------------------------------------------===//
-def Tosa_GreaterEqualOp : Tosa_ElementwiseOp<"greater_equal",
- [SameOperandsElementType]> {
+def Tosa_GreaterEqualOp : Tosa_ElementwiseOp<"greater_equal", [
+ ResultsBroadcastableShape,
+ SameOperandsElementType,
+ SameOperandsAndResultRank
+ ]> {
let summary = "Returns the truth value of (x >= y) element-wise.";
let description = [{
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index f97e0ff1e30ea7..b0eb2d6cbc30b6 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -90,43 +90,54 @@ static Value createLinalgBodyCalculationForElementwiseOp(
}
// tosa::MulOp
- if (isa<tosa::MulOp>(op) && isa<FloatType>(elementTy))
- return rewriter.create<arith::MulFOp>(loc, resultTypes, args);
-
- if (isa<tosa::MulOp>(op) && isa<IntegerType>(elementTy)) {
- Value a = args[0];
- Value b = args[1];
- auto shift =
- cast<IntegerAttr>(op->getAttr("shift")).getValue().getSExtValue();
- if (shift > 0) {
- auto shiftConst =
- rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
- if (!a.getType().isInteger(32))
- a = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), a);
-
- if (!b.getType().isInteger(32))
- b = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), b);
-
- auto result = rewriter.create<tosa::ApplyScaleOp>(
- loc, rewriter.getI32Type(), a, b, shiftConst,
- rewriter.getBoolAttr(false));
-
- if (elementTy.isInteger(32))
- return result;
-
- return rewriter.create<arith::TruncIOp>(loc, elementTy, result);
+ if (isa<tosa::MulOp>(op)) {
+ auto shift_val = cast<tosa::MulOp>(op).getShift();
+
+ if (isa<FloatType>(elementTy)) {
+ return rewriter.create<arith::MulFOp>(loc, resultTypes, args[0], args[1]);
}
- int aWidth = a.getType().getIntOrFloatBitWidth();
- int bWidth = b.getType().getIntOrFloatBitWidth();
- int cWidth = resultTypes[0].getIntOrFloatBitWidth();
+ if (isa<IntegerType>(elementTy)) {
+ int32_t shift = 0;
+ ElementsAttr shift_elem;
+ if (shift_val.getImpl() &&
+ matchPattern(shift_val, m_Constant(&shift_elem))) {
+ // Explicit shift is set.
+ shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
+ }
+
+ Value a = args[0];
+ Value b = args[1];
+ if (shift > 0) {
+ auto shiftConst =
+ rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
+ if (!a.getType().isInteger(32))
+ a = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), a);
- if (aWidth < cWidth)
- a = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], a);
- if (bWidth < cWidth)
- b = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], b);
+ if (!b.getType().isInteger(32))
+ b = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), b);
- return rewriter.create<arith::MulIOp>(loc, resultTypes, a, b);
+ auto result = rewriter.create<tosa::ApplyScaleOp>(
+ loc, rewriter.getI32Type(), a, b, shiftConst,
+ rewriter.getBoolAttr(false));
+
+ if (elementTy.isInteger(32))
+ return result;
+
+ return rewriter.create<arith::TruncIOp>(loc, elementTy, result);
+ }
+
+ int aWidth = a.getType().getIntOrFloatBitWidth();
+ int bWidth = b.getType().getIntOrFloatBitWidth();
+ int cWidth = resultTypes[0].getIntOrFloatBitWidth();
+
+ if (aWidth < cWidth)
+ a = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], a);
+ if (bWidth < cWidth)
+ b = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], b);
+
+ return rewriter.create<arith::MulIOp>(loc, resultTypes, a, b);
+ }
}
// tosa::NegateOp
@@ -940,7 +951,13 @@ elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
auto loc = operation->getLoc();
auto rank =
cast<RankedTensorType>(operation->getResultTypes().front()).getRank();
- auto expandedOperands = expandInputRanks(rewriter, loc, operands, rank);
+ // For the mul op we need to avoid expanding the rank of the optional shift
+ // input.
+ auto operandsToExpand =
+ isa<tosa::MulOp>(operation) ? operands.take_front(2) : operands;
+
+ auto expandedOperands =
+ expandInputRanks(rewriter, loc, operandsToExpand, rank);
auto [targetShape, masterOperands] =
computeTargetShape(rewriter, loc, indexPool, expandedOperands);
auto broadcastOperands = broadcastDynamicDimensions(
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index b8e0005dc1bc03..ddfcde6de14f14 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -665,7 +665,18 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
auto rhsAttr =
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
- const int64_t shift = llvm::isa<IntegerType>(resultETy) ? getShift() : 0;
+ // Result right shift on i32_t data type only. For simplification, synthesize
+ // a zero shift for other data type.
+ int32_t shift = 0;
+ if (resultETy.isInteger(32)) {
+ ElementsAttr shift_elem;
+ if (getShift().getImpl()) {
+ if (!matchPattern(getShift(), m_Constant(&shift_elem)))
+ // cannot be folded when the shift value is unknown.
+ return {};
+ shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
+ }
+ }
if (rhsTy == resultTy) {
if (isSplatZero(resultETy, lhsAttr))
@@ -680,7 +691,7 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
return lhs;
}
- return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, getShift());
+ return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, shift);
}
OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index fdccce60fe1d86..ae4e09a1e324c6 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -945,9 +945,76 @@ LogicalResult tosa::SliceOp::verify() {
}
LogicalResult tosa::MulOp::verify() {
- Type elementTy = getInput1().getType().getElementType();
- if (isa<FloatType>(elementTy) && getShift() != 0)
- return emitOpError() << "require shift to be 0 for float type";
+ auto resElemType = getElementTypeOrSelf(getOutput());
+
+ // Verify if the element type among operands and result match tosa
+ // specification.
+ if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
+ IntegerType lhsIntType =
+ cast<IntegerType>(getElementTypeOrSelf(getInput1()));
+ IntegerType rhsIntType =
+ cast<IntegerType>(getElementTypeOrSelf(getInput2()));
+ if (lhsIntType != rhsIntType)
+ return emitOpError("requires the same element type for all operands");
+
+ // Though the spec requires the element type of result to be i32, a more
+ // relaxed way is provided at dialect level for easier cooperating with
+ // other dialects.
+ if (lhsIntType.getWidth() > resIntType.getWidth())
+ return emitOpError("invalid data type size for operands or result");
+
+ } else {
+ // For other supported type, the spec requires requires the same element
+ // type for all operands (excludes `shift` operand) and results.
+ for (int i = 0; i < 2; ++i) {
+ if (getElementTypeOrSelf(getOperand(i)) != resElemType)
+ return emitOpError(
+ "requires the same element type for all operands and results");
+ }
+ }
+
+ // Verify the op has same ranks for all main operands (excludes extra operands
+ // such as shift of mul op, so this is the only difference with the built-in
+ // `SameOperandsAndResultRank` trait) and results types, if known.
+
+ // delegate function that returns true if type is a shaped type with known
+ // rank
+ auto hasRank = [](const Type type) {
+ if (auto shaped_type = dyn_cast<ShapedType>(type))
+ return shaped_type.hasRank();
+
+ return false;
+ };
+
+ auto rankedOperandTypes =
+ llvm::to_vector(llvm::make_filter_range(getOperandTypes(), hasRank));
+
+ auto rankedResultTypes =
+ llvm::make_filter_range(getOperation()->getResultTypes(), hasRank);
+
+ // If all operands and results are unranked, then no further verification.
+ if (rankedOperandTypes.empty() && rankedResultTypes.empty())
+ return success();
+
+ // delegate function that returns rank of shaped type with known rank
+ auto getRank = [](const Type type) {
+ return cast<ShapedType>(type).getRank();
+ };
+
+ auto rank = !rankedOperandTypes.empty() ? getRank(*rankedOperandTypes.begin())
+ : getRank(*rankedResultTypes.begin());
+
+ for (size_t i = 0; i < 2; ++i) {
+ if (rank != getRank(rankedOperandTypes[i])) {
+ return emitOpError("operands don't have matching ranks");
+ }
+ }
+
+ for (const auto type : rankedResultTypes) {
+ if (rank != getRank(type)) {
+ return emitOpError("result type has different rank than operands");
+ }
+ }
return success();
}
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
index 45f4419875b485..181aff3a9ce04f 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Pass/Pass.h"
using namespace mlir;
@@ -131,9 +132,15 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
return failure();
}
+ auto shiftElementType = IntegerType::get(rewriter.getContext(), 8);
+ auto shiftType = RankedTensorType::get({1}, shiftElementType);
+ auto shiftZeroAttr = DenseElementsAttr::get(
+ shiftType, rewriter.getIntegerAttr(shiftElementType, 0));
+ Value constZero =
+ rewriter.create<tosa::ConstOp>(op.getLoc(), shiftType, shiftZeroAttr);
Value mulValue = rewriter
.create<tosa::MulOp>(op.getLoc(), mulShapeType, input,
- weight, /*shift=*/0)
+ weight, constZero)
.getResult();
// Reshape output to [N, H, W, C * M].
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
index 2a990eed3f681e..79afc75fd6c8ee 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
@@ -113,7 +113,7 @@ struct ConvertTosaOp<tosa::MulOp> : public OpRewritePattern<tosa::MulOp> {
Value input1 = tosaBinaryOp.getInput1();
Value input2 = tosaBinaryOp.getInput2();
- int32_t shift = tosaBinaryOp.getShift();
+ Value shift = tosaBinaryOp.getShift();
Value output = tosaBinaryOp.getResult();
auto outputType = dyn_cast<RankedTensorType>(output.getType());
if (!outputType)
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index f860dca85c9e9c..3704b4c29fceaf 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -472,7 +472,8 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
// CHECK: linalg.generic
// CHECK: arith.mulf
- %4 = tosa.mul %0, %1 {shift = 0 : i8} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+ %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %4 = tosa.mul %0, %1, %shift : (tensor<1xf32>, tensor<1xf32>, tensor<1xi8>) -> tensor<1xf32>
// CHECK: linalg.generic
// CHECK: arith.negf
@@ -618,7 +619,8 @@ func.func @test_simple_i16(%arg0: tensor<1xi16>) -> () {
// CHECK: arith.extsi
// CHECK: arith.extsi
// CHECK: arith.muli
- %0 = tosa.mul %arg0, %arg0 {shift = 0 : i8} : (tensor<1xi16>, tensor<1xi16>) -> tensor<1xi32>
+ %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %0 = tosa.mul %arg0, %arg0, %shift : (tensor<1xi16>, tensor<1xi16>, tensor<1xi8>) -> tensor<1xi32>
return
}
@@ -646,12 +648,14 @@ func.func @test_simple_i32(%arg0: tensor<1xi32>, %unsigned: tensor<1xui32>, %uns
// CHECK: linalg.generic
// CHECK: arith.muli
- %2 = tosa.mul %arg0, %arg0 {shift = 0 : i8} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+ %shift1 = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %2 = tosa.mul %arg0, %arg0, %shift1 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi8>) -> tensor<1xi32>
// CHECK: linalg.generic
// CHECK: arith.constant 2
// CHECK: apply_scale
- %3 = tosa.mul %arg0, %arg0 {shift = 2 : i8} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+ %shift2 = "tosa.const"() <{value = dense<2> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %3 = tosa.mul %arg0, %arg0, %shift2: (tensor<1xi32>, tensor<1xi32>, tensor<1xi8>) -> tensor<1xi32>
// CHECK: linalg.generic
// CHECK: arith.divsi
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 6f47f041b9199a..7d3e49d7392dc5 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -332,7 +332,8 @@ func.func @mul_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
// CHECK: return %arg0
// CHECK-NOT: tosa.mul
%ones = "tosa.const"() {value = dense<1.0> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
- %1 = tosa.mul %arg0, %ones {shift = 0 : i8} : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
+ %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %1 = tosa.mul %arg0, %ones, %shift : (tensor<2x3xf32>, tensor<2x3xf32>, tensor<1xi8>) -> tensor<2x3xf32>
return %1 : tensor<2x3xf32>
}
@@ -343,7 +344,8 @@ func.func @mul_bcast_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
// CHECK: return %arg0
// CHECK-NOT: tosa.mul
%ones = "tosa.const"() {value = dense<1.0> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
- %1 = tosa.mul %ones, %arg0 {shift = 0 : i8} : (tensor<1x1xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
+ %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %1 = tosa.mul %ones, %arg0, %shift : (tensor<1x1xf32>, tensor<2x3xf32>, tensor<1xi8>) -> tensor<2x3xf32>
return %1 : tensor<2x3xf32>
}
@@ -353,8 +355,22 @@ func.func @mul_bcast_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
func.func @mul_one_int(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> {
// CHECK: return %arg0
// CHECK-NOT: tosa.mul
+ %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
%ones = "tosa.const"() {value = dense<1> : tensor<2x3xi32>} : () -> tensor<2x3xi32>
- %1 = tosa.mul %arg0, %ones {shift = 0 : i8} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+ %1 = tosa.mul %arg0, %ones, %shift : (tensor<2x3xi32>, tensor<2x3xi32>, tensor<1xi8>) -> tensor<2x3xi32>
+ return %1 : tensor<2x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @mul_one_int_and_shift
+func.func @mul_one_int_and_shift(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> {
+ // CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<1> : tensor<2x3xi32>}>
+ // CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<31> : tensor<1xi8>}>
+ // CHECK: %[[VAL_3:.*]] = tosa.mul %arg0, %[[VAL_1]], %[[VAL_2]] : (tensor<2x3xi32>, tensor<2x3xi32>, tensor<1xi8>)
+ %ones = "tosa.const"() {value = dense<1> : tensor<2x3xi32>} : () -> tensor<2x3xi32>
+ %shift = "tosa.const"() <{value = dense<31> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %1 = tosa.mul %arg0, %ones, %shift : (tensor<2x3xi32>, tensor<2x3xi32>, tensor<1xi8>) -> tensor<2x3xi32>
return %1 : tensor<2x3xi32>
}
@@ -365,11 +381,12 @@ func.func @mul_zero_broadcast(%arg0: tensor<2x3xf32>) -> (tensor<2x3xf32>, tenso
// CHECK: %[[ZERO:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<2x3xf32>}
// CHECK-NOT: tosa.mul
%zeros = "tosa.const"() {value = dense<0.0> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
- %1 = tosa.mul %arg0, %zeros {shift = 0 : i8} : (tensor<2x3xf32>, tensor<1x1xf32>) -> tensor<2x3xf32>
+ %shift = "tosa.const"() <{value = dense<31> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %1 = tosa.mul %arg0, %zeros, %shift : (tensor<2x3xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<2x3xf32>
// CHECK-NOT: tosa.mul
// CHECK: return %[[ZERO]], %[[ZERO]]
- %2 = tosa.mul %zeros, %arg0 {shift = 0 : i8} : (tensor<1x1xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
+ %2 = tosa.mul %zeros, %arg0, %shift : (tensor<1x1xf32>, tensor<2x3xf32>, tensor<1xi8>) -> tensor<2x3xf32>
return %1, %2 : tensor<2x3xf32>, tensor<2x3xf32>
}
@@ -927,7 +944,8 @@ func.func @mul_quant_nofold() -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899
// CHECK: tosa.mul
%0 = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
%1 = "tosa.const"() {value = dense<1> : tensor<1xi8>} : () -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
- %2 = tosa.mul %0, %1 { shift = 0 : i8} : (tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>, tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>) -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
+ %shift = "tosa.const"() <{value = dense<31> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %2 = tosa.mul %0, %1, %shift : (tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>, tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>, tensor<1xi8>) -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
return %2 : tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
}
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index 8198903b78ac05..4c872e02fd03e4 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -237,8 +237,9 @@ func.func @fold_div_splat_i32() -> tensor<i32> {
// CHECK-LABEL: @fold_mul_zero_rhs_f32
func.func @fold_mul_zero_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
%zero = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
+ %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
// CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0.000000e+00>
- %mul = tosa.mul %arg0, %zero {shift = 0 : i8} : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ %mul = tosa.mul %arg0, %zero, %shift : (tensor<f32>, tensor<f32>, tensor<1xi8>) -> tensor<f32>
// CHECK: return %[[ZERO]]
return %mul : tensor<f32>
}
@@ -248,8 +249,9 @@ func.func @fold_mul_zero_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
// CHECK-LABEL: @fold_mul_zero_lhs_f32
func.func @fold_mul_zero_lhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
%zero = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
+ %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
// CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0.000000e+00>
- %mul = tosa.mul %zero, %arg0 {shift = 0 : i8} : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ %mul = tosa.mul %zero, %arg0, %shift : (tensor<f32>, tensor<f32>, tensor<1xi8>) -> tensor<f32>
// CHECK: return %[[ZERO]]
return %mul : tensor<f32>
}
@@ -259,8 +261,9 @@ func.func @fold_mul_zero_lhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
// CHECK-LABEL: @fold_mul_zero_rhs_i32
func.func @fold_mul_zero_rhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
%zero = "tosa.const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+ %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
// CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0>
- %mul = tosa.mul %arg0, %zero {shift = 0 : i8} : (tensor<i32>, tensor<i32>) -> tensor<i32>
+ %mul = tosa.mul %arg0, %zero, %shift : (tensor<i32>, tensor<i32>, tensor<1xi8>) -> tensor<i32>
// CHECK: return %[[ZERO]]
return %mul : tensor<i32>
}
@@ -270,8 +273,9 @@ func.func @fold_mul_zero_rhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
// CHECK-LABEL: @fold_mul_zero_lhs_i32
func.func @fold_mul_zero_lhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
%zero = "tosa.const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+ %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
// CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0>
- %mul = tosa.mul %zero, %arg0 {shift = 0 : i8} : (tensor<i32>, tensor<i32>) -> tensor<i32>
+ %mul = tosa.mul %zero, %arg0, %shift : (tensor<i32>, tensor<i32>, tensor<1xi8>) -> tensor<i32>
// CHECK: return %[[ZERO]]
return %mul : tensor<i32>
}
@@ -281,7 +285,8 @@ func.func @fold_mul_zero_lhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
// CHECK-LABEL: @fold_mul_one_rhs_f32
func.func @fold_mul_one_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
%one = "tosa.const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
- %mul = tosa.mul %arg0, %one {shift = 0 : i8} : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %mul = tosa.mul %arg0, %one, %shift : (tensor<f32>, tensor<f32>, tensor<1xi8>) -> tensor<f32>
// CHECK: return %arg0
return %mul : tensor<f32>
}
@@ -291,7 +296,8 @@ func.func @fold_mul_one_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
// CHECK-LABEL: @fold_mul_one_lhs_f32
func.func @fold_mul_one_lhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
%one = "tosa.const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
- %mul = tosa.mul %one, %arg0 {shift = 0 : i8} : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %mul = tosa.mul %one, %arg0, %shift : (tensor<f32>, tensor<f32>, tensor<1xi8>) -> tensor<f32>
// CHECK: return %arg0
return %mul : tensor<f32>
}
@@ -301,7 +307,8 @@ func.func @fold_mul_one_lhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
// CHECK-LABEL: @fold_mul_one_rhs_i32
func.func @fold_mul_one_rhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
%one = "tosa.const"() {value = dense<64> : tensor<i32>} : () -> tensor<i32>
- %mul = tosa.mul %arg0, %one {shift = 6 : i8} : (tensor<i32>, tensor<i32>) -> tensor<i32>
+ %shift = "tosa.const"() {value = dense<6> : tensor<1xi8>} : () -> tensor<1xi8>
+ %mul = tosa.mul %arg0, %one, %shift : (tensor<i32>, tensor<i32>, tensor<1xi8>) -> tensor<i32>
// CHECK: return %arg0
return %mul : tensor<i32>
}
@@ -311,7 +318,8 @@ func.func @fold_mul_one_rhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
// CHECK-LABEL: @fold_mul_one_lhs_i32
func.func @fold_mul_one_lhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
%one = "tosa.const"() {value = dense<64> : tensor<i32>} : () -> tensor<i32>
- %mul = tosa.mul %one, %arg0 {shift = 6 : i8} : (tensor<i32>, tensor<i32>) -> tensor<i32>
+ %shift = "tosa.const"() {value = dense<6> : tensor<1xi8>} : () -> tensor<1xi8>
+ %mul = tosa.mul %one, %arg0, %shift : (tensor<i32>, tensor<i32>, tensor<1xi8>) -> tensor<i32>
// CHECK: return %arg0
return %mul : tensor<i32>
}
@@ -322,7 +330,8 @@ func.func @fold_mul_one_lhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
func.func @fold_mul_splat_i8() -> tensor<10xi32> {
%one = "tosa.const"() {value = dense<17> : tensor<10xi8>} : () -> tensor<10xi8>
%two = "tosa.const"() {value = dense<32> : tensor<10xi8>} : () -> tensor<10xi8>
- %mul = tosa.mul %one, %two {shift = 3 : i8} : (tensor<10xi8>, tensor<10xi8>) -> tensor<10xi32>
+ %shift = "tosa.const"() {value = dense<3> : tensor<1xi8>} : () -> tensor<1xi8>
+ %mul = tosa.mul %one, %two, %shift : (tensor<10xi8>, tensor<10xi8>, tensor<1xi8>) -> tensor<10xi32>
// CHECK: %[[THREE:.+]] = "tosa.const"() <{value = dense<68> : tensor<10xi32>}
// CHECK: return %[[THREE]]
return %mul : tensor<10xi32>
@@ -334,7 +343,8 @@ func.func @fold_mul_splat_i8() -> tensor<10xi32> {
func.func @fold_mul_splat_f32() -> tensor<10xf32> {
%one = "tosa.const"() {value = dense<3.0> : tensor<10xf32>} : () -> tensor<10xf32>
%two = "tosa.const"() {value = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
- %mul = tosa.mul %one, %two {shift = 0 : i8} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
+ %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %mul = tosa.mul %one, %two, %shift : (tensor<10xf32>, tensor<10xf32>, tensor<1xi8>) -> tensor<10xf32>
// CHECK: %[[THREE:.+]] = "tosa.const"() <{value = dense<6.000000e+00> : tensor<10xf32>}
// CHECK: return %[[THREE]]
return %mul : tensor<10xf32>
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 4808867b28bb97..e9fb93d59e38a6 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -724,10 +724,21 @@ func.func @test_transpose_conv2d_invalid_outshape(%arg0: tensor<1x32x32x8xf32>,
// -----
+// CHECK-LABEL: test_mul_invalid_shift_type
+func.func @test_mul_invalid_shift_type(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf16>) -> tensor<13x21x3xf32> {
+ %shift = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+ // expected-error at +1 {{'tosa.mul' op requires the same element type for all operands}}
+ %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf16>, tensor<1xi8>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+
// CHECK-LABEL: test_mul_invalid_shift
func.func @test_mul_invalid_shift(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
- // expected-error at +1 {{'tosa.mul' op require shift to be 0 for float type}}
- %0 = tosa.mul %arg0, %arg1 {shift = 1 : i8} : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32>
+ %shift = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
+ // expected-error at +1 {{'tosa.mul' op operand #2 must be 1D tensor of 8-bit signless integer values, but got 'tensor<f32>'}}
+ %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<f32>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 19b93d7611854d..b2773a4f7f02f9 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -330,14 +330,16 @@ func.func @test_min(%arg0: tensor<13x21x3xf32>, %arg1: tensor<1x21x3xf32>) -> te
// -----
// CHECK-LABEL: mul
func.func @test_mul(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
- %0 = tosa.mul %arg0, %arg1 {shift = 0 : i8} : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32>
+ %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
// -----
// CHECK-LABEL: mul
func.func @test_mul_relaxed_result_type(%arg0: tensor<13x21x3xi16>, %arg1: tensor<13x1x3xi16>) -> tensor<13x21x3xi16> {
- %0 = "tosa.mul"(%arg0, %arg1) { shift = 1 : i8 } : (tensor<13x21x3xi16>, tensor<13x1x3xi16>) -> tensor<13x21x3xi16>
+ %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi16>, tensor<13x1x3xi16>, tensor<1xi8>) -> tensor<13x21x3xi16>
return %0 : tensor<13x21x3xi16>
}
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
index bbcc206e1490c7..5f36dd3b3d137c 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
@@ -34,7 +34,7 @@ func.func @depthwise_conv2d_as_mul_q(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<
// CHECK: %[[sIn:.+]] = tosa.sub %[[cIn]], %[[iZp]]
// CHECK: %[[sWe:.+]] = tosa.sub %[[cWe]], %[[wZp]]
// CHECK: %[[resWe:.+]] = tosa.reshape %[[sWe]] {new_shape = array<i64: 1, 1, 1, 2, 3>}
- // CHECK: %[[mul:.+]] = tosa.mul %[[sIn]], %[[resWe]] {shift = 0 : i8}
+ // CHECK: %[[mul:.+]] = tosa.mul %[[sIn]], %[[resWe]]
// CHECK: %[[reO:.+]] = tosa.reshape %[[mul]] {new_shape = array<i64: 4, 10, 10, 6>}
// CHECK: %[[reArg2:.+]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 1, 6>}
// CHECK: %[[add:.+]] = tosa.add %[[reO]], %[[reArg2]]
@@ -51,7 +51,7 @@ func.func @depthwise_conv2d_as_mul_padded(%arg0: tensor<4x10x10x2xf32>, %arg1: t
// CHECK: %[[reIn:.+]] = tosa.reshape %arg0 {new_shape = array<i64: 4, 10, 10, 2, 1>}
// CHECK: %[[padded:.+]] = tosa.pad %[[reIn]], %[[pad]], %[[zero]] : (tensor<4x10x10x2x1xf32>, !tosa.shape<10>, tensor<f32>) -> tensor<4x12x12x2x1xf32>
// CHECK: %[[reArg1:.+]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 2, 3>}
- // CHECK: %[[mul:.+]] = tosa.mul %3, %[[reArg1]] {shift = 0 : i8}
+ // CHECK: %[[mul:.+]] = tosa.mul %[[padded]], %[[reArg1]]
// CHECK: %[[reOut:.+]] = tosa.reshape %[[mul]] {new_shape = array<i64: 4, 12, 12, 6>}
// CHECK: %[[reArg2:.+]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 1, 6>}
// CHECK: %[[add:.+]] = tosa.add %[[reOut]], %[[reArg2]]
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 6beb1ad6296135..dedfad3c6d2073 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -114,8 +114,9 @@ func.func @test_binary_scalar_f32(%arg0 : tensor<4xf32>, %arg1 : tensor<1xf32>)
// CHECK: tosa.minimum %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
%2 = tosa.minimum %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
- // CHECK: tosa.mul %arg0, %arg1 {shift = 0 : i8} : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
- %3 = tosa.mul %arg0, %arg1 { shift = 0 : i8 } : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
+ // CHECK: tosa.mul %arg0, %arg1, %{{.*}} : (tensor<4xf32>, tensor<1xf32>, tensor<1xi8>) -> tensor<4xf32>
+ %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %3 = tosa.mul %arg0, %arg1, %shift : (tensor<4xf32>, tensor<1xf32>, tensor<1xi8>) -> tensor<*xf32>
// CHECK: tosa.pow %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
%4 = tosa.pow %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
@@ -148,8 +149,9 @@ func.func @test_binary_broadcast_f32(%arg0 : tensor<4xf32>, %arg1 : tensor<1xf32
// CHECK: tosa.minimum %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
%2 = tosa.minimum %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
- // CHECK: tosa.mul %arg0, %arg1 {shift = 0 : i8} : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
- %3 = tosa.mul %arg0, %arg1 { shift = 0 : i8 } : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
+ // CHECK: tosa.mul %arg0, %arg1, %{{.*}} : (tensor<4xf32>, tensor<1xf32>, tensor<1xi8>) -> tensor<4xf32>
+ %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %3 = tosa.mul %arg0, %arg1, %shift : (tensor<4xf32>, tensor<1xf32>, tensor<1xi8>) -> tensor<*xf32>
// CHECK: tosa.pow %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
%4 = tosa.pow %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
@@ -206,8 +208,9 @@ func.func @test_binary_i32(%arg0 : tensor<4xi32>, %arg1 : tensor<1xi32>) -> () {
// CHECK: tosa.minimum %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32>
%10 = tosa.minimum %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32>
- // CHECK: tosa.mul %arg0, %arg1 {shift = 0 : i8} : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32>
- %11 = tosa.mul %arg0, %arg1 { shift = 0 : i8 }: (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32>
+ // CHECK: tosa.mul %arg0, %arg1, %{{.*}} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi8>) -> tensor<4xi32>
+ %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %11 = tosa.mul %arg0, %arg1, %shift : (tensor<4xi32>, tensor<1xi32>, tensor<1xi8>) -> tensor<*xi32>
// CHECK: tosa.pow %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32>
%12 = tosa.pow %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32>
@@ -1369,7 +1372,7 @@ func.func @test_non_tosa_consumer_shape(%arg0: tensor<4x4xf32>) -> !shape.shape
// -----
-// CHECK-LABEL: test_non_tosa_consumer_shape2
+// CHECK-LABEL: test_non_tosa_consumer_shape2
func.func @test_non_tosa_consumer_shape2(%arg0: tensor<4x4xf32>) -> tensor<?xindex> {
// CHECK: tosa.log %arg0 : (tensor<4x4xf32>) -> tensor<4x4xf32>
%0 = tosa.log %arg0 : (tensor<4x4xf32>) -> tensor<*xf32>
More information about the Mlir-commits
mailing list