[Mlir-commits] [mlir] [mlir][tosa] Make TOSA MUL's Shift an Input (PR #121953)
llvmlistbot at llvm.org
Tue Jan 7 07:26:53 PST 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir-tosa
Author: Jack Frankland (FranklandJack)
The TOSA-v1.0 specification makes the shift attribute of the MUL (Hadamard product) operator an input. Move the `shift` parameter of the MUL operator in the MLIR TOSA dialect from an attribute to an input, and update the lit tests appropriately.
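For illustration, a minimal before/after sketch of the op syntax (value names are hypothetical; the rank-0 i8 shift operand follows the updated lit tests in the diff below):

```mlir
// Before this patch: shift carried as an attribute.
%out0 = tosa.mul %a, %b {shift = 0 : i8} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>

// After this patch: shift passed as an optional rank-0 i8 tensor input.
%shift = "tosa.const"() <{value = dense<0> : tensor<i8>}> : () -> tensor<i8>
%out1 = tosa.mul %a, %b, %shift : (tensor<1xi32>, tensor<1xi32>, tensor<i8>) -> tensor<1xi32>
```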
Expand the verifier of the `tosa::MulOp` operation to check the various constraints defined in the TOSA-v1.0 specification. Specifically, ensure that all input operands (excluding the optional shift) have the same rank. This means that broadcasting tests which previously checked that rank-0 tensors would be broadcast are no longer valid and have been removed.
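As a sketch of the new rank constraint, an op like the one in the removed broadcast test no longer verifies (types taken from that test; diagnostic text matches the expanded verifier in the diff):

```mlir
// Mixed-rank operands (from the removed broadcast test) are now rejected;
// an explicit tosa.reshape is required first.
%0 = tosa.mul %arg0, %arg1 : (tensor<15x14xi32>, tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32>
// error: 'tosa.mul' op operands don't have matching ranks
```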
---
Patch is 30.80 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/121953.diff
14 Files Affected:
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+1-1)
- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+55-34)
- (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+13-2)
- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+78-3)
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp (+1-1)
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp (+1-1)
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+6-4)
- (modified) mlir/test/Dialect/Tosa/broadcast.mlir (-9)
- (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+19-6)
- (modified) mlir/test/Dialect/Tosa/constant-op-fold.mlir (+13-10)
- (modified) mlir/test/Dialect/Tosa/invalid.mlir (+13-2)
- (modified) mlir/test/Dialect/Tosa/ops.mlir (+2-2)
- (modified) mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir (+2-2)
- (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+8-6)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index e3c725801d1629..dceb36116797c5 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -800,7 +800,7 @@ def Tosa_MulOp : Tosa_ElementwiseOp<"mul", [
let arguments = (ins
Tosa_Tensor:$input1,
Tosa_Tensor:$input2,
- I8Attr:$shift
+ Optional<TosaTensorRankOf<[Tosa_Int8], [0]>>:$shift
);
let results = (outs
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 88e544c4e4b5f1..4ec19db2116546 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -90,43 +90,58 @@ static Value createLinalgBodyCalculationForElementwiseOp(
}
// tosa::MulOp
- if (isa<tosa::MulOp>(op) && isa<FloatType>(elementTy))
- return rewriter.create<arith::MulFOp>(loc, resultTypes, args);
-
- if (isa<tosa::MulOp>(op) && isa<IntegerType>(elementTy)) {
- Value a = args[0];
- Value b = args[1];
- auto shift =
- cast<IntegerAttr>(op->getAttr("shift")).getValue().getSExtValue();
- if (shift > 0) {
- auto shiftConst =
- rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
- if (!a.getType().isInteger(32))
- a = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), a);
-
- if (!b.getType().isInteger(32))
- b = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), b);
-
- auto result = rewriter.create<tosa::ApplyScaleOp>(
- loc, rewriter.getI32Type(), a, b, shiftConst,
- rewriter.getBoolAttr(false));
-
- if (elementTy.isInteger(32))
- return result;
-
- return rewriter.create<arith::TruncIOp>(loc, elementTy, result);
+ if (isa<tosa::MulOp>(op)) {
+ auto shift_val = cast<tosa::MulOp>(op).getShift();
+ if (!elementTy.isInteger(32) && shift_val.getImpl()) {
+ (void)rewriter.notifyMatchFailure(op,
+ "Cannot have shift value for non i32 output");
+ return nullptr;
+    }
+
+ if (isa<FloatType>(elementTy)) {
+ return rewriter.create<arith::MulFOp>(loc, resultTypes, args[0], args[1]);
}
- int aWidth = a.getType().getIntOrFloatBitWidth();
- int bWidth = b.getType().getIntOrFloatBitWidth();
- int cWidth = resultTypes[0].getIntOrFloatBitWidth();
+ if (isa<IntegerType>(elementTy)) {
+ int32_t shift = 0;
+ ElementsAttr shift_elem;
+ if (shift_val.getImpl() && matchPattern(shift_val, m_Constant(&shift_elem))) {
+ // Explicit shift is set.
+ shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
+ }
+
+ Value a = args[0];
+ Value b = args[1];
+ if (shift > 0) {
+ auto shiftConst =
+ rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
+ if (!a.getType().isInteger(32))
+ a = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), a);
+
+ if (!b.getType().isInteger(32))
+ b = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), b);
+
+ auto result = rewriter.create<tosa::ApplyScaleOp>(
+ loc, rewriter.getI32Type(), a, b, shiftConst,
+ rewriter.getBoolAttr(false));
- if (aWidth < cWidth)
- a = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], a);
- if (bWidth < cWidth)
- b = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], b);
+ if (elementTy.isInteger(32))
+ return result;
- return rewriter.create<arith::MulIOp>(loc, resultTypes, a, b);
+ return rewriter.create<arith::TruncIOp>(loc, elementTy, result);
+ }
+
+ int aWidth = a.getType().getIntOrFloatBitWidth();
+ int bWidth = b.getType().getIntOrFloatBitWidth();
+ int cWidth = resultTypes[0].getIntOrFloatBitWidth();
+
+ if (aWidth < cWidth)
+ a = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], a);
+ if (bWidth < cWidth)
+ b = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], b);
+
+ return rewriter.create<arith::MulIOp>(loc, resultTypes, a, b);
+ }
}
// tosa::NegateOp
@@ -931,7 +946,13 @@ elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
auto loc = operation->getLoc();
auto rank =
cast<RankedTensorType>(operation->getResultTypes().front()).getRank();
- auto expandedOperands = expandInputRanks(rewriter, loc, operands, rank);
+ // For the mul op we need to avoid expanding the rank of the optional shift
+ // input.
+ auto operandsToExpand =
+ isa<tosa::MulOp>(operation) ? operands.take_front(2) : operands;
+
+ auto expandedOperands =
+ expandInputRanks(rewriter, loc, operandsToExpand, rank);
auto [targetShape, masterOperands] =
computeTargetShape(rewriter, loc, indexPool, expandedOperands);
auto broadcastOperands = broadcastDynamicDimensions(
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 39d0ee122b1630..892421155733b3 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -614,7 +614,18 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
auto rhsAttr =
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
- const int64_t shift = llvm::isa<IntegerType>(resultETy) ? getShift() : 0;
+  // The result right shift applies to the i32 data type only. For
+  // simplicity, synthesize a zero shift for other data types.
+ int32_t shift = 0;
+ if (resultETy.isInteger(32)) {
+ ElementsAttr shift_elem;
+ if (getShift().getImpl()) {
+ if (!matchPattern(getShift(), m_Constant(&shift_elem)))
+ // cannot be folded when the shift value is unknown.
+ return {};
+ shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
+ }
+ }
if (rhsTy == resultTy) {
if (isSplatZero(resultETy, lhsAttr))
@@ -629,7 +640,7 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
return lhs;
}
- return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, getShift());
+ return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, shift);
}
OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 631d3c48f2df02..07f5d5e49f6d44 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -866,9 +866,84 @@ LogicalResult tosa::SliceOp::verify() {
}
LogicalResult tosa::MulOp::verify() {
- Type elementTy = getInput1().getType().getElementType();
- if (isa<FloatType>(elementTy) && getShift() != 0)
- return emitOpError() << "require shift to be 0 for float type";
+ auto resElemType = getElementTypeOrSelf(getOutput());
+
+  // Verify that the element types of the operands and result match the
+  // TOSA specification.
+ if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
+ IntegerType lhsIntType =
+ cast<IntegerType>(getElementTypeOrSelf(getInput1()));
+ IntegerType rhsIntType =
+ cast<IntegerType>(getElementTypeOrSelf(getInput2()));
+ if (lhsIntType != rhsIntType)
+ return emitOpError(
+ "requires the same element type for all operands");
+
+    // Though the spec requires the element type of the result to be i32, a
+    // more relaxed constraint is provided at the dialect level for easier
+    // interoperation with other dialects.
+ if (lhsIntType.getWidth() > resIntType.getWidth())
+ return emitOpError("invalid data type size for operands or result");
+
+ } else {
+    // For other supported types, the spec requires the same element type
+    // for all operands (excluding the `shift` operand) and results.
+ for (int i = 0; i < 2; ++i) {
+ if (getElementTypeOrSelf(getOperand(i)) != resElemType)
+ return emitOpError(
+ "requires the same element type for all operands and results");
+ }
+ }
+
+  // Check that a shift value is not applied to a non-i32 output type, as
+  // that is not allowed by the spec.
+ if (!(llvm::isa<IntegerType>(resElemType) && resElemType.isInteger(32)))
+ if (getShift().getImpl())
+ return emitOpError(
+ "right shift output only on i32 data type");
+
+  // Verify that the op has the same rank for all main operands (excluding
+  // extra operands such as the shift of mul op; this is the only difference
+  // from the built-in `SameOperandsAndResultRank` trait) and result types, if known.
+
+  // Helper that returns true if the type is a shaped type with a known
+  // rank.
+ auto hasRank = [](const Type type) {
+ if (auto shaped_type = dyn_cast<ShapedType>(type))
+ return shaped_type.hasRank();
+
+ return false;
+ };
+
+ auto rankedOperandTypes =
+ llvm::to_vector(llvm::make_filter_range(getOperandTypes(), hasRank));
+
+ auto rankedResultTypes =
+ llvm::make_filter_range(getOperation()->getResultTypes(), hasRank);
+
+ // If all operands and results are unranked, then no further verification.
+ if (rankedOperandTypes.empty() && rankedResultTypes.empty())
+ return success();
+
+  // Helper that returns the rank of a shaped type with known rank.
+ auto getRank = [](const Type type) {
+ return cast<ShapedType>(type).getRank();
+ };
+
+ auto rank = !rankedOperandTypes.empty() ? getRank(*rankedOperandTypes.begin())
+ : getRank(*rankedResultTypes.begin());
+
+ for (size_t i = 0; i < 2; ++i) {
+ if (rank != getRank(rankedOperandTypes[i])) {
+ return emitOpError("operands don't have matching ranks");
+ }
+ }
+
+ for (const auto type : rankedResultTypes) {
+ if (rank != getRank(type)) {
+ return emitOpError("result type has different rank than operands");
+ }
+ }
return success();
}
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
index e6fba211dc37ab..537c0cce04491e 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
@@ -137,7 +137,7 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
Value mulValue = rewriter
.create<tosa::MulOp>(op.getLoc(), mulShapeType, input,
- weight, /*shift=*/0)
+ weight, Value{} /* zero_shift */)
.getResult();
// Reshape output to [N, H, W, C * M].
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
index 2a990eed3f681e..79afc75fd6c8ee 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
@@ -113,7 +113,7 @@ struct ConvertTosaOp<tosa::MulOp> : public OpRewritePattern<tosa::MulOp> {
Value input1 = tosaBinaryOp.getInput1();
Value input2 = tosaBinaryOp.getInput2();
- int32_t shift = tosaBinaryOp.getShift();
+ Value shift = tosaBinaryOp.getShift();
Value output = tosaBinaryOp.getResult();
auto outputType = dyn_cast<RankedTensorType>(output.getType());
if (!outputType)
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 265a75986c6c8d..7137d24afde3c7 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -451,7 +451,7 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
// CHECK: linalg.generic
// CHECK: arith.mulf
- %4 = tosa.mul %0, %1 {shift = 0 : i8} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+ %4 = tosa.mul %0, %1 : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
// CHECK: linalg.generic
// CHECK: arith.negf
@@ -597,7 +597,7 @@ func.func @test_simple_i16(%arg0: tensor<1xi16>) -> () {
// CHECK: arith.extsi
// CHECK: arith.extsi
// CHECK: arith.muli
- %0 = tosa.mul %arg0, %arg0 {shift = 0 : i8} : (tensor<1xi16>, tensor<1xi16>) -> tensor<1xi32>
+ %0 = tosa.mul %arg0, %arg0 : (tensor<1xi16>, tensor<1xi16>) -> tensor<1xi32>
return
}
@@ -625,12 +625,14 @@ func.func @test_simple_i32(%arg0: tensor<1xi32>, %unsigned: tensor<1xui32>, %uns
// CHECK: linalg.generic
// CHECK: arith.muli
- %2 = tosa.mul %arg0, %arg0 {shift = 0 : i8} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+ %shift1 = "tosa.const"() <{value = dense<0> : tensor<i8>}> : () -> tensor<i8>
+ %2 = tosa.mul %arg0, %arg0, %shift1 : (tensor<1xi32>, tensor<1xi32>, tensor<i8>) -> tensor<1xi32>
// CHECK: linalg.generic
// CHECK: arith.constant 2
// CHECK: apply_scale
- %3 = tosa.mul %arg0, %arg0 {shift = 2 : i8} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+ %shift2 = "tosa.const"() <{value = dense<2> : tensor<i8>}> : () -> tensor<i8>
+    %3 = tosa.mul %arg0, %arg0, %shift2 : (tensor<1xi32>, tensor<1xi32>, tensor<i8>) -> tensor<1xi32>
// CHECK: linalg.generic
// CHECK: arith.divsi
diff --git a/mlir/test/Dialect/Tosa/broadcast.mlir b/mlir/test/Dialect/Tosa/broadcast.mlir
index 7613aa3b8dd03d..f33a1465b6009d 100644
--- a/mlir/test/Dialect/Tosa/broadcast.mlir
+++ b/mlir/test/Dialect/Tosa/broadcast.mlir
@@ -169,15 +169,6 @@ func.func @test_broadcast20(%arg0: tensor<3x3x4x1xf32>, %arg1: tensor<4x5xf32>)
return %0 : tensor<3x3x4x5xf32>
}
-// -----
-// CHECK-LABEL: broadcast_mul
-func.func @test_broadcast_mul(%arg0: tensor<15x14xi32>, %arg1: tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> {
- // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 15, 14>}
- // CHECK: %[[VAR1:.*]] = tosa.mul %[[VAR0]], %arg1
- %0 = tosa.mul %arg0, %arg1 {shift = 1 : i8 } : (tensor<15x14xi32>, tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32>
- return %0 : tensor<17x16x15x14xi32>
-}
-
// -----
// CHECK-LABEL: broadcast_arithmetic_right_shift
func.func @test_broadcast_arithmetic_right_shift(%arg0: tensor<15x14xi32>, %arg1: tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> {
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 67cd01f62f0bdf..0e783a9a23ae48 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -280,7 +280,7 @@ func.func @mul_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
// CHECK: return %arg0
// CHECK-NOT: tosa.mul
%ones = "tosa.const"() {value = dense<1.0> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
- %1 = tosa.mul %arg0, %ones {shift = 0 : i8} : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
+ %1 = tosa.mul %arg0, %ones : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
return %1 : tensor<2x3xf32>
}
@@ -291,7 +291,7 @@ func.func @mul_bcast_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
// CHECK: return %arg0
// CHECK-NOT: tosa.mul
%ones = "tosa.const"() {value = dense<1.0> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
- %1 = tosa.mul %ones, %arg0 {shift = 0 : i8} : (tensor<1x1xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
+ %1 = tosa.mul %ones, %arg0 : (tensor<1x1xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
return %1 : tensor<2x3xf32>
}
@@ -302,7 +302,20 @@ func.func @mul_one_int(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> {
// CHECK: return %arg0
// CHECK-NOT: tosa.mul
%ones = "tosa.const"() {value = dense<1> : tensor<2x3xi32>} : () -> tensor<2x3xi32>
- %1 = tosa.mul %arg0, %ones {shift = 0 : i8} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+ %1 = tosa.mul %arg0, %ones : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+ return %1 : tensor<2x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @mul_one_int_and_shift
+func.func @mul_one_int_and_shift(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> {
+ // CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<1> : tensor<2x3xi32>}>
+ // CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<31> : tensor<i8>}>
+ // CHECK: %[[VAL_3:.*]] = tosa.mul %arg0, %[[VAL_1]], %[[VAL_2]] : (tensor<2x3xi32>, tensor<2x3xi32>, tensor<i8>)
+ %ones = "tosa.const"() {value = dense<1> : tensor<2x3xi32>} : () -> tensor<2x3xi32>
+ %shift = "tosa.const"() <{value = dense<31> : tensor<i8>}> : () -> tensor<i8>
+ %1 = tosa.mul %arg0, %ones, %shift : (tensor<2x3xi32>, tensor<2x3xi32>, tensor<i8>) -> tensor<2x3xi32>
return %1 : tensor<2x3xi32>
}
@@ -313,11 +326,11 @@ func.func @mul_zero_broadcast(%arg0: tensor<2x3xf32>) -> (tensor<2x3xf32>, tenso
// CHECK: %[[ZERO:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<2x3xf32>}
// CHECK-NOT: tosa.mul
%zeros = "tosa.const"() {value = dense<0.0> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
- %1 = tosa.mul %arg0, %zeros {shift = 0 : i8} : (tensor<2x3xf32>, tensor<1x1xf32>) -> tensor<2x3xf32>
+ %1 = tosa.mul %arg0, %zeros : (tensor<2x3xf32>, tensor<1x1xf32>) -> tensor<2x3xf32>
// CHECK-NOT: tosa.mul
// CHECK: return %[[ZERO]], %[[ZERO]]
- %2 = tosa.mul %zeros, %arg0 {shift = 0 : i8} : (tensor<1x1xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
+ %2 = tosa.mul %zeros, %arg0 : (tensor<1x1xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
return %1, %2 : tensor<2x3xf32>, tensor<2x3xf32>
}
@@ -872,7 +885,7 @@ func.func @mul_quant_nofold() -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899
// CHECK: tosa.mul
%0 = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
%1 = "tosa.const"() {value = dense<1> : tensor<1xi8>} : () -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
- %2 = tosa.mul %0, %1 { shift = 0 : i8} : (tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>, tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>) -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
+ %2 = tosa.mul %0, %1 : (tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>, tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>) -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
return %2 : tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
}
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index 2902c4a62009e9..b6947e05b26bf1 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -247,7 +247,7 @@ func.func @fold_div_splat_i32() -> tensor<i32> {
func.func @fold_mul_zero_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
%zero = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
// CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0.000000e+00>
- %mul = tosa.mul %arg0, %zero {shift = 0 : i8} : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ %mul = tosa.mul %arg0, %zero : (tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: return %[[ZERO]]
return %mul : tensor<f32>
}
@@ -258,7 +258,7 @@ func.func @fold_mul_zero_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
func.func @fold_mul_zero_lhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
%zero = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
// CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0.000000e+00>
- %mul = tosa.mul %zero, %arg0 {shift = 0 : i8} : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ %mul = tosa.mul %zero, %arg0 : (tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: return %[[ZERO]]
return %mul : tensor<f32>
}
@@ -269,7 +269,7 @@ func.func @fold_mul_zero_lhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
func.func @fold_mul_zero_rhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
%zero = "tosa.const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
// CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0>
- %mul = tosa.mul %arg0, %zero {shift = 0 : i8} : (tensor<i32>, tensor<i32>) -> tensor<i32>
+ %mul = tosa.mul %arg0, %zero : (tensor<i32>, tensor<i32>) -> tensor<i32>
// CHECK: return %[[ZERO]]
return %mul : tensor<i32>
}
@@ -280,7 +280,7 @@ func.func @fold_mul_zero_rhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
func.func @fold_mul_zero_lhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
%zero = "tosa.const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
// CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0>
- %mul = tosa.mul %zero, %arg0 {shift = 0 : i8} : (te...
[truncated]
``````````
https://github.com/llvm/llvm-project/pull/121953