[Mlir-commits] [mlir] 20ae283 - [mlir][tosa] Change the shift of mul to be required (#125297)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Feb 11 11:02:48 PST 2025
Author: Tai Ly
Date: 2025-02-11T11:02:44-08:00
New Revision: 20ae283d087224f6b82b7308054bd34a6764d926
URL: https://github.com/llvm/llvm-project/commit/20ae283d087224f6b82b7308054bd34a6764d926
DIFF: https://github.com/llvm/llvm-project/commit/20ae283d087224f6b82b7308054bd34a6764d926.diff
LOG: [mlir][tosa] Change the shift of mul to be required (#125297)
Change the shift operand of the mul operator to be a required operand.
Also define shift to be Tosa_ScalarInt8Tensor, which requires that it is
a rank-1 tensor
whose shape is [1] (i.e., a tensor containing a single element).
Signed-off-by: Tai Ly <tai.ly at arm.com>
Added:
Modified:
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
mlir/test/Dialect/Tosa/canonicalize.mlir
mlir/test/Dialect/Tosa/constant-op-fold.mlir
mlir/test/Dialect/Tosa/invalid.mlir
mlir/test/Dialect/Tosa/ops.mlir
mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
mlir/test/Dialect/Tosa/tosa-reduce-transposes.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 7a65417db0eabf5..b8755da8db32e9e 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -105,8 +105,8 @@ def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> {
Tosa_Tensor4D:$input,
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
Tosa_Tensor1D:$bias,
- Optional<Tosa_ZeroPointTensor>:$input_zp,
- Optional<Tosa_ZeroPointTensor>:$weight_zp,
+ Optional<Tosa_ScalarTensor>:$input_zp,
+ Optional<Tosa_ScalarTensor>:$weight_zp,
Tosa_IntArrayAttr4:$pad,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr2:$dilation,
@@ -136,8 +136,8 @@ def Tosa_Conv3DOp : Tosa_ConvOp<"conv3d"> {
Tosa_Tensor5D:$input,
TosaTensorRankOf<[Tosa_Weight], [5]>:$weight,
Tosa_Tensor1D:$bias,
- Optional<Tosa_ZeroPointTensor>:$input_zp,
- Optional<Tosa_ZeroPointTensor>:$weight_zp,
+ Optional<Tosa_ScalarTensor>:$input_zp,
+ Optional<Tosa_ScalarTensor>:$weight_zp,
Tosa_IntArrayAttr6:$pad,
Tosa_IntArrayAttr3:$stride,
Tosa_IntArrayAttr3:$dilation,
@@ -168,8 +168,8 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
Tosa_Tensor4D:$input,
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
Tosa_Tensor1D:$bias,
- Optional<Tosa_ZeroPointTensor>:$input_zp,
- Optional<Tosa_ZeroPointTensor>:$weight_zp,
+ Optional<Tosa_ScalarTensor>:$input_zp,
+ Optional<Tosa_ScalarTensor>:$weight_zp,
Tosa_IntArrayAttr4:$pad,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr2:$dilation,
@@ -356,8 +356,8 @@ def Tosa_TransposeConv2DOp : Tosa_ConvOp<"transpose_conv2d"> {
Tosa_Tensor4D:$input,
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
Tosa_Tensor1D:$bias,
- Optional<Tosa_ZeroPointTensor>:$input_zp,
- Optional<Tosa_ZeroPointTensor>:$weight_zp,
+ Optional<Tosa_ScalarTensor>:$input_zp,
+ Optional<Tosa_ScalarTensor>:$weight_zp,
Tosa_IntArrayAttr4:$out_pad,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr4:$out_shape,
@@ -817,7 +817,8 @@ def Tosa_MulOp : Tosa_Op<"mul", [
let arguments = (ins
Tosa_Tensor:$input1,
Tosa_Tensor:$input2,
- Optional<TosaTensorRankOf<[Tosa_Int8], [1]>>:$shift
+ // Apply right shift on i32_t input data only
+ Tosa_ScalarInt8Tensor:$shift
);
let results = (outs
@@ -1590,7 +1591,7 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> {
let arguments = (ins
Tosa_RankedTensor:$input1,
Tosa_Shape:$padding,
- Optional<Tosa_ScalarTensor>:$pad_const,
+ Optional<Tosa_Rank0Tensor>:$pad_const,
OptionalAttr<I32Attr>:$input_zp
);
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 120adf82249e074..6457bb8749ee065 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -93,6 +93,10 @@ def HasNo0Dimensions : And<[
IsRankedTensorTypePred,
CPred<"::llvm::all_of(::llvm::cast<::mlir::RankedTensorType>($_self).getShape(), [](auto v) { return v != 0; })">]>;
+def AllDimensionsAreSizeOne : And<[
+ IsRankedTensorTypePred,
+ CPred<"::llvm::all_of(::llvm::cast<::mlir::RankedTensorType>($_self).getShape(), [](auto v) { return v == 1; })">]>;
+
class TosaTensorOf<
list<Type> allowedTypes, string summary = "tosa-conformant tensor">
: TensorOf<allowedTypes, [Or<[HasNo0Dimensions, IsUnrankedTensorTypePred]>], summary>;
@@ -109,6 +113,11 @@ class TosaTensorRankOf<list<Type> allowedTypes, list<int> ranks>
[HasAnyRankOfPred<ranks>],
!interleave(!foreach(rank, ranks, rank # "D"), "/") # " tensor">;
+class TosaScalarTensorOf<list<Type> allowedTypes, list<int> ranks>
+ : TosaRankedTensorOf<allowedTypes,
+ [HasAnyRankOfPred<ranks>, AllDimensionsAreSizeOne],
+ "tosa-conformant scalar tensor">;
+
//===----------------------------------------------------------------------===//
// Tensor types
//===----------------------------------------------------------------------===//
@@ -136,8 +145,10 @@ class Tosa_TensorOfOrNone<list<Type> allowedTypes, string description = ""> :
// Tensor types with constrained ranks.
//===----------------------------------------------------------------------===//
-// Rank-0 (scalar) tensor
-def Tosa_ScalarTensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>;
+def Tosa_Rank0Tensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>;
+
+def Tosa_ScalarTensor : TosaScalarTensorOf<[Tosa_AnyNumber], [1]>;
+def Tosa_ScalarInt8Tensor : TosaScalarTensorOf<[Tosa_Int8], [1]>;
// We include unranked tensors as a supported type for all possible tosa
// Tensors as unranked does not guarantee invalid. If unranked tensors exist
@@ -296,9 +307,4 @@ def Rank1TosaShape : TosaShapeOfRank<1>;
def Rank2TosaShape : TosaShapeOfRank<2>;
def Rank4TosaShape : TosaShapeOfRank<4>;
-// NOTE: Tosa_ScalarTensor is currently defined as rank-0. If and when this
-// becomes rank-1 it can be used in place of Tosa_ZeroPointTensor and the
-// following def can be removed.
-def Tosa_ZeroPointTensor : TosaTensorRankOf<[Tosa_AnyNumber], [1]>;
-
#endif // TOSA_TYPES_BASE
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 28bc8732b7978ad..d849c782bf08bac 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -92,22 +92,27 @@ static Value createLinalgBodyCalculationForElementwiseOp(
// tosa::MulOp
if (isa<tosa::MulOp>(op)) {
auto shift_val = cast<tosa::MulOp>(op).getShift();
+ ElementsAttr shift_elem;
+ if (!shift_val.getImpl() ||
+ !matchPattern(shift_val, m_Constant(&shift_elem))) {
+ (void)rewriter.notifyMatchFailure(op, "shift value of mul not found");
+ }
+
+ int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
if (isa<FloatType>(elementTy)) {
+ if (shift != 0) {
+ (void)rewriter.notifyMatchFailure(op,
+ "Cannot have shift value for float");
+ return nullptr;
+ }
return rewriter.create<arith::MulFOp>(loc, resultTypes, args[0], args[1]);
}
if (isa<IntegerType>(elementTy)) {
- int32_t shift = 0;
- ElementsAttr shift_elem;
- if (shift_val.getImpl() &&
- matchPattern(shift_val, m_Constant(&shift_elem))) {
- // Explicit shift is set.
- shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
- }
-
Value a = args[0];
Value b = args[1];
+
if (shift > 0) {
auto shiftConst =
rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index cff24d825d3f530..4928be38476a90e 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1130,16 +1130,10 @@ LogicalResult tosa::MulOp::inferReturnTypeComponents(
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
- LogicalResult status = success();
+ // mul op's output shape only depend on input1 and input2, not on shift
+ ValueShapeRange twoInputs = operands.drop_back();
llvm::SmallVector<int64_t> outShape;
- if (operands.size() == 2) {
- status = resolveBroadcastShape(operands, outShape);
- } else {
- // mul op's output shape only depend on input1 and input2, not on shift
- ValueShapeRange two_inputs = operands.drop_back();
- status = resolveBroadcastShape(two_inputs, outShape);
- }
- if (status.failed()) {
+ if (resolveBroadcastShape(twoInputs, outShape).failed()) {
inferredReturnShapes.push_back(ShapedTypeComponents());
} else {
inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
@@ -1174,6 +1168,15 @@ LogicalResult tosa::MulOp::verify() {
return emitOpError(
"requires the same element type for all operands and results");
}
+
+ // verify shift has value 0 for non-integer types
+ ElementsAttr shift_elem;
+ if (matchPattern(getShift(), m_Constant(&shift_elem))) {
+ int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
+ if (shift != 0) {
+ return emitOpError() << "require shift to be 0 for float type";
+ }
+ }
}
// Verify the op has same ranks for all main operands (excludes extra operands
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
index 281f0529a5c0816..64e5c31793f8463 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
@@ -287,8 +287,7 @@ bool TosaReduceTransposes::collectFanIn(Operation *op,
for (Value operand : op->getOperands()) {
// If this is a problem in future, think about alternatives to recursion.
- if (llvm::isa<tosa::MulOp>(op) && op->getNumOperands() == 3 &&
- operand == op->getOperand(2)) {
+ if (llvm::isa<tosa::MulOp>(op) && operand == op->getOperand(2)) {
// do not recurse into MulOp's shift operand
continue;
}
@@ -332,8 +331,7 @@ std::optional<Value> TosaReduceTransposes::buildMappedToValue(
for (Value v : op->getOperands()) {
if (valuesMap.contains(v)) {
operands.push_back(valuesMap.at(v));
- } else if (llvm::isa<tosa::MulOp>(op) && op->getNumOperands() == 3 &&
- v == op->getOperand(2)) {
+ } else if (llvm::isa<tosa::MulOp>(op) && v == op->getOperand(2)) {
// special case for MulOp's shift operand
operands.push_back(v);
} else {
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 56521fb67ef0cb0..17add2d41afe7d3 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -472,7 +472,8 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
// CHECK: linalg.generic
// CHECK: arith.mulf
- %4 = tosa.mul %0, %1 : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+ %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %4 = tosa.mul %0, %1, %shift : (tensor<1xf32>, tensor<1xf32>, tensor<1xi8>) -> tensor<1xf32>
// CHECK: linalg.generic
// CHECK: arith.negf
@@ -618,7 +619,8 @@ func.func @test_simple_i16(%arg0: tensor<1xi16>) -> () {
// CHECK: arith.extsi
// CHECK: arith.extsi
// CHECK: arith.muli
- %0 = tosa.mul %arg0, %arg0 : (tensor<1xi16>, tensor<1xi16>) -> tensor<1xi32>
+ %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %0 = tosa.mul %arg0, %arg0, %shift : (tensor<1xi16>, tensor<1xi16>, tensor<1xi8>) -> tensor<1xi32>
return
}
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 58a70fb03a09235..24d572244a9b050 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -322,8 +322,9 @@ func.func @pad_determine_val_quant(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi3
func.func @mul_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
// CHECK: return %arg0
// CHECK-NOT: tosa.mul
+ %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
%ones = "tosa.const"() {value = dense<1.0> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
- %1 = tosa.mul %arg0, %ones : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
+ %1 = tosa.mul %arg0, %ones, %shift : (tensor<2x3xf32>, tensor<2x3xf32>, tensor<1xi8>) -> tensor<2x3xf32>
return %1 : tensor<2x3xf32>
}
@@ -334,7 +335,8 @@ func.func @mul_bcast_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
// CHECK: return %arg0
// CHECK-NOT: tosa.mul
%ones = "tosa.const"() {value = dense<1.0> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
- %1 = tosa.mul %ones, %arg0 : (tensor<1x1xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
+ %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %1 = tosa.mul %ones, %arg0, %shift : (tensor<1x1xf32>, tensor<2x3xf32>, tensor<1xi8>) -> tensor<2x3xf32>
return %1 : tensor<2x3xf32>
}
@@ -370,11 +372,12 @@ func.func @mul_zero_broadcast(%arg0: tensor<2x3xf32>) -> (tensor<2x3xf32>, tenso
// CHECK: %[[ZERO:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<2x3xf32>}
// CHECK-NOT: tosa.mul
%zeros = "tosa.const"() {value = dense<0.0> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
- %1 = tosa.mul %arg0, %zeros : (tensor<2x3xf32>, tensor<1x1xf32>) -> tensor<2x3xf32>
+ %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %1 = tosa.mul %arg0, %zeros, %shift : (tensor<2x3xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<2x3xf32>
// CHECK-NOT: tosa.mul
// CHECK: return %[[ZERO]], %[[ZERO]]
- %2 = tosa.mul %zeros, %arg0 : (tensor<1x1xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
+ %2 = tosa.mul %zeros, %arg0, %shift : (tensor<1x1xf32>, tensor<2x3xf32>, tensor<1xi8>) -> tensor<2x3xf32>
return %1, %2 : tensor<2x3xf32>, tensor<2x3xf32>
}
@@ -974,7 +977,8 @@ func.func @mul_quant_nofold() -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899
// CHECK: tosa.mul
%0 = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
%1 = "tosa.const"() {value = dense<1> : tensor<1xi8>} : () -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
- %2 = tosa.mul %0, %1 : (tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>, tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>)-> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
+ %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %2 = tosa.mul %0, %1, %shift : (tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>, tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>, tensor<1xi8>) -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
return %2 : tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
}
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index 40469987d89d087..e6fb741df95985a 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -238,7 +238,8 @@ func.func @fold_div_splat_i32() -> tensor<i32> {
func.func @fold_mul_zero_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
%zero = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
// CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0.000000e+00>
- %mul = tosa.mul %arg0, %zero : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %mul = tosa.mul %arg0, %zero, %shift : (tensor<f32>, tensor<f32>, tensor<1xi8>) -> tensor<f32>
// CHECK: return %[[ZERO]]
return %mul : tensor<f32>
}
@@ -249,7 +250,8 @@ func.func @fold_mul_zero_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
func.func @fold_mul_zero_lhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
%zero = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
// CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0.000000e+00>
- %mul = tosa.mul %zero, %arg0 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %mul = tosa.mul %zero, %arg0, %shift : (tensor<f32>, tensor<f32>, tensor<1xi8>) -> tensor<f32>
// CHECK: return %[[ZERO]]
return %mul : tensor<f32>
}
@@ -283,7 +285,8 @@ func.func @fold_mul_zero_lhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
// CHECK-LABEL: @fold_mul_one_rhs_f32
func.func @fold_mul_one_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
%one = "tosa.const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
- %mul = tosa.mul %arg0, %one : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %mul = tosa.mul %arg0, %one, %shift : (tensor<f32>, tensor<f32>, tensor<1xi8>) -> tensor<f32>
// CHECK: return %arg0
return %mul : tensor<f32>
}
@@ -293,7 +296,8 @@ func.func @fold_mul_one_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
// CHECK-LABEL: @fold_mul_one_lhs_f32
func.func @fold_mul_one_lhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
%one = "tosa.const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
- %mul = tosa.mul %one, %arg0 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %mul = tosa.mul %one, %arg0, %shift : (tensor<f32>, tensor<f32>, tensor<1xi8>) -> tensor<f32>
// CHECK: return %arg0
return %mul : tensor<f32>
}
@@ -339,7 +343,8 @@ func.func @fold_mul_splat_i8() -> tensor<10xi32> {
func.func @fold_mul_splat_f32() -> tensor<10xf32> {
%one = "tosa.const"() {value = dense<3.0> : tensor<10xf32>} : () -> tensor<10xf32>
%two = "tosa.const"() {value = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
- %mul = tosa.mul %one, %two : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
+ %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %mul = tosa.mul %one, %two, %shift : (tensor<10xf32>, tensor<10xf32>, tensor<1xi8>) -> tensor<10xf32>
// CHECK: %[[THREE:.+]] = "tosa.const"() <{value = dense<6.000000e+00> : tensor<10xf32>}
// CHECK: return %[[THREE]]
return %mul : tensor<10xf32>
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index e77078161d06385..913191be86f8513 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -768,26 +768,27 @@ func.func @test_transpose_conv2d_invalid_outshape(%arg0: tensor<1x32x32x8xf32>,
// CHECK-LABEL: test_mul_type_mismatch
func.func @test_mul_type_mismatch(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf16>) -> tensor<13x21x3xf32> {
+ %shift = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
// expected-error at +1 {{'tosa.mul' op requires the same element type for all operands}}
- %0 = tosa.mul %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x1x3xf16>) -> tensor<13x21x3xf32>
+ %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf16>, tensor<1xi8>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
// -----
// CHECK-LABEL: test_mul_invalid_shift
-func.func @test_mul_invalid_shift(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x1x3xi32>) -> tensor<13x21x3xi32> {
- %shift = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
- // expected-error at +1 {{'tosa.mul' op operand #2 must be 1D tensor of 8-bit signless integer values, but got 'tensor<f32>'}}
- %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi32>, tensor<13x1x3xi32>, tensor<f32>) -> tensor<13x21x3xi32>
- return %0 : tensor<13x21x3xi32>
+func.func @test_mul_invalid_shift(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
+ %shift = "tosa.const"() {value = dense<1> : tensor<1xi8>} : () -> tensor<1xi8>
+ // expected-error at +1 {{'tosa.mul' op require shift to be 0 for float type}}
+ %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
}
// -----
// CHECK-LABEL: test_mul_missing_shift
func.func @test_mul_missing_shift(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x1x3xi32>) -> tensor<13x21x3xi32> {
- // this is ok because mul's shift operand is optional for now
+ // expected-error at +1 {{'tosa.mul' op expected 3 operands, but found 2}}
%0 = tosa.mul %arg0, %arg1 : (tensor<13x21x3xi32>, tensor<13x1x3xi32>) -> tensor<13x21x3xi32>
return %0 : tensor<13x21x3xi32>
}
@@ -1099,3 +1100,30 @@ func.func @test_sub_with_unequal_result_ranks(%arg0: tensor<1x21x3xf32>, %arg1:
%0 = tosa.sub %arg0, %arg1 : (tensor<1x21x3xf32>, tensor<13x21x3xf32>) -> tensor<1x13x21x3xf32>
return %0 : tensor<1x13x21x3xf32>
}
+
+// -----
+// CHECK-LABEL: test_mul_non_scalar_shift_2d
+func.func @test_mul_non_scalar_shift_2d(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
+ %shift = "tosa.const"() <{value = dense<0> : tensor<1x1xi8>}> : () -> tensor<1x1xi8>
+ // expected-error at +1 {{'tosa.mul' op operand #2 must be tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<1x1xi8>'}}
+ %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<1x1xi8>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: test_mul_non_scalar_shift_1d
+func.func @test_mul_non_scalar_shift_1d(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
+ %shift = "tosa.const"() <{value = dense<0> : tensor<2xi8>}> : () -> tensor<2xi8>
+ // expected-error at +1 {{'tosa.mul' op operand #2 must be tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<2xi8>'}}
+ %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<2xi8>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: test_mul_non_broadcast
+func.func @test_mul_non_broadcast(%arg0: tensor<13x21x2xf32>, %arg1: tensor<3x1x3xf32>) -> tensor<13x21x3xf32> {
+ %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ // expected-error at +1 {{'tosa.mul' op operands don't have broadcast-compatible shapes}}
+ %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x2xf32>, tensor<3x1x3xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 0d3dfead4a7a316..348849cfaa572c6 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -355,7 +355,8 @@ func.func @test_mul_scalar_with_unranked_output(%arg0: tensor<f32>, %arg1: tenso
// -----
// CHECK-LABEL: mul
func.func @test_mul(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
- %0 = tosa.mul %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32>
+ %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index a0a7e5dec6ed050..7dc9b048085fa98 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -114,23 +114,24 @@ func.func @test_binary_scalar_f32(%arg0 : tensor<4xf32>, %arg1 : tensor<1xf32>)
// CHECK: tosa.minimum %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
%2 = tosa.minimum %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
- // CHECK: tosa.mul %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
- %3 = tosa.mul %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
+ %3 = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ // CHECK: tosa.mul %arg0, %arg1, %3 : (tensor<4xf32>, tensor<1xf32>, tensor<1xi8>) -> tensor<4xf32>
+ %4 = tosa.mul %arg0, %arg1, %3 : (tensor<4xf32>, tensor<1xf32>, tensor<1xi8>) -> tensor<*xf32>
// CHECK: tosa.pow %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
- %4 = tosa.pow %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
+ %5 = tosa.pow %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
// CHECK: tosa.sub %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
- %5 = tosa.sub %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
+ %6 = tosa.sub %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
// CHECK: tosa.equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1>
- %6 = tosa.equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1>
+ %7 = tosa.equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1>
// CHECK: tosa.greater %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1>
- %7 = tosa.greater %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1>
+ %8 = tosa.greater %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1>
// CHECK: tosa.greater_equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1>
- %8 = tosa.greater_equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1>
+ %9 = tosa.greater_equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1>
return
}
@@ -148,23 +149,24 @@ func.func @test_binary_broadcast_f32(%arg0 : tensor<4xf32>, %arg1 : tensor<1xf32
// CHECK: tosa.minimum %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
%2 = tosa.minimum %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
- // CHECK: tosa.mul %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
- %3 = tosa.mul %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
+ %3 = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ // CHECK: tosa.mul %arg0, %arg1, %3 : (tensor<4xf32>, tensor<1xf32>, tensor<1xi8>) -> tensor<4xf32>
+ %4 = tosa.mul %arg0, %arg1, %3 : (tensor<4xf32>, tensor<1xf32>, tensor<1xi8>) -> tensor<*xf32>
// CHECK: tosa.pow %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
- %4 = tosa.pow %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
+ %5 = tosa.pow %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
// CHECK: tosa.sub %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
- %5 = tosa.sub %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
+ %6 = tosa.sub %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
// CHECK: tosa.equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1>
- %6 = tosa.equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1>
+ %7 = tosa.equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1>
// CHECK: tosa.greater %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1>
- %7 = tosa.greater %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1>
+ %8 = tosa.greater %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1>
// CHECK: tosa.greater_equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1>
- %8 = tosa.greater_equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1>
+ %9 = tosa.greater_equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1>
return
}
@@ -211,10 +213,10 @@ func.func @test_binary_i32(%arg0 : tensor<4xi32>, %arg1 : tensor<1xi32>) -> () {
%11 = tosa.mul %arg0, %arg1, %shift : (tensor<4xi32>, tensor<1xi32>, tensor<1xi8>) -> tensor<*xi32>
// CHECK: tosa.pow %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32>
- %12 = tosa.pow %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32>
+ %13 = tosa.pow %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32>
// CHECK: tosa.sub %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32>
- %13 = tosa.sub %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32>
+ %14 = tosa.sub %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32>
return
}
diff --git a/mlir/test/Dialect/Tosa/tosa-reduce-transposes.mlir b/mlir/test/Dialect/Tosa/tosa-reduce-transposes.mlir
index e70f3644da6465e..3a293009a54559d 100644
--- a/mlir/test/Dialect/Tosa/tosa-reduce-transposes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-reduce-transposes.mlir
@@ -193,6 +193,7 @@ func.func @test_reshape_for_broadcast(%arg0: tensor<4x3x2xi32>) -> tensor<4x3x2x
// CHECK-LABEL: @test_resnet18_common_case
// COM: note that %74 is now represented by %arg2
+// CHECK-DAG: %[[CONST0:.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
// CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() <{value = dense_resource<torch_tensor_64_torch.float32_1> : tensor<64xf32>}> : () -> tensor<64xf32>
// CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() <{value = dense_resource<torch_tensor_64_torch.float32> : tensor<64xf32>}> : () -> tensor<64xf32>
// CHECK-DAG: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<9.99999974E-6> : tensor<1xf32>}> : () -> tensor<1xf32>
@@ -205,15 +206,16 @@ func.func @test_reshape_for_broadcast(%arg0: tensor<4x3x2xi32>) -> tensor<4x3x2x
// CHECK-DAG: %[[VAL_12:.*]] = tosa.sub %arg2, %[[VAL_11]] : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>) -> tensor<1x112x112x64xf32>
// CHECK-DAG: %[[VAL_13:.*]] = tosa.const_shape {value = dense<[1, 1, 1, 64]> : tensor<4xindex>} : () -> !tosa.shape<4>
// CHECK-DAG: %[[VAL_14:.*]] = tosa.reshape %[[VAL_9]], %[[VAL_13]] : (tensor<64xf32>, !tosa.shape<4>) -> tensor<1x1x1x64xf32>
-// CHECK-DAG: %[[VAL_15:.*]] = tosa.mul %[[VAL_12]], %[[VAL_14]] : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>) -> tensor<1x112x112x64xf32>
+// CHECK-DAG: %[[VAL_15:.*]] = tosa.mul %[[VAL_12]], %[[VAL_14]], %[[CONST0]] : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>, tensor<1xi8>) -> tensor<1x112x112x64xf32>
// CHECK-DAG: %[[VAL_16:.*]] = tosa.const_shape {value = dense<[1, 1, 1, 64]> : tensor<4xindex>} : () -> !tosa.shape<4>
// CHECK-DAG: %[[VAL_17:.*]] = tosa.reshape %[[VAL_4]], %[[VAL_16]] : (tensor<64xf32>, !tosa.shape<4>) -> tensor<1x1x1x64xf32>
-// CHECK-DAG: %[[VAL_18:.*]] = tosa.mul %[[VAL_15]], %[[VAL_17]] : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>) -> tensor<1x112x112x64xf32>
+// CHECK-DAG: %[[VAL_18:.*]] = tosa.mul %[[VAL_15]], %[[VAL_17]], %[[CONST0]] : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>, tensor<1xi8>) -> tensor<1x112x112x64xf32>
// CHECK-DAG: %[[VAL_19:.*]] = tosa.const_shape {value = dense<[1, 1, 1, 64]> : tensor<4xindex>} : () -> !tosa.shape<4>
// CHECK-DAG: %[[VAL_20:.*]] = tosa.reshape %[[VAL_3]], %[[VAL_19]] : (tensor<64xf32>, !tosa.shape<4>) -> tensor<1x1x1x64xf32>
// CHECK-DAG: %[[VAL_21:.*]] = tosa.add %[[VAL_18]], %[[VAL_20]] : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>) -> tensor<1x112x112x64xf32>
// CHECK-DAG: %[[VAL_22:.*]] = tosa.clamp %[[VAL_21]] {max_val = 3.40282347E+38 : f32, min_val = 0.000000e+00 : f32} : (tensor<1x112x112x64xf32>) -> tensor<1x112x112x64xf32>
func.func @test_resnet18_common_case(%arg0: tensor<64xf32>, %arg1: tensor<64xf32>, %74: tensor<1x112x112x64xf32>) -> tensor<1x112x112x64xf32> {
+ %shift = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
%58 = tosa.const_shape {value = dense<[1, 64, 1, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
%59 = "tosa.const"() <{value = dense_resource<torch_tensor_64_torch.float32_1> : tensor<64xf32>}> : () -> tensor<64xf32>
%60 = "tosa.const"() <{value = dense_resource<torch_tensor_64_torch.float32> : tensor<64xf32>}> : () -> tensor<64xf32>
@@ -228,9 +230,9 @@ func.func @test_resnet18_common_case(%arg0: tensor<64xf32>, %arg1: tensor<64xf32
%79 = tosa.reshape %arg0, %58 : (tensor<64xf32>, !tosa.shape<4>) -> tensor<1x64x1x1xf32>
%80 = tosa.sub %75, %79 : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>) -> tensor<1x64x112x112xf32>
%81 = tosa.reshape %78, %58 : (tensor<64xf32>, !tosa.shape<4>) -> tensor<1x64x1x1xf32>
- %82 = tosa.mul %80, %81 : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>) -> tensor<1x64x112x112xf32>
+ %82 = tosa.mul %80, %81, %shift : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>, tensor<1xi8>) -> tensor<1x64x112x112xf32>
%83 = tosa.reshape %60, %58 : (tensor<64xf32>, !tosa.shape<4>) -> tensor<1x64x1x1xf32>
- %84 = tosa.mul %82, %83 : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>) -> tensor<1x64x112x112xf32>
+ %84 = tosa.mul %82, %83, %shift : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>, tensor<1xi8>) -> tensor<1x64x112x112xf32>
%85 = tosa.reshape %59, %58 : (tensor<64xf32>, !tosa.shape<4>) -> tensor<1x64x1x1xf32>
%86 = tosa.add %84, %85 : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>) -> tensor<1x64x112x112xf32>
%87 = tosa.clamp %86 {max_val = 3.40282347E+38 : f32, min_val = 0.000000e+00 : f32} : (tensor<1x64x112x112xf32>) -> tensor<1x64x112x112xf32>
More information about the Mlir-commits
mailing list