[Mlir-commits] [mlir] [mlir][tosa] Make TOSA MUL's Shift an Input (PR #121953)

Jack Frankland llvmlistbot at llvm.org
Tue Jan 7 07:26:17 PST 2025


https://github.com/FranklandJack created https://github.com/llvm/llvm-project/pull/121953

The TOSA-v1.0 specification makes the shift attribute of the MUL (Hadamard product) operator an input. Move the `shift` parameter of the MUL operator in the MLIR TOSA dialect from an attribute to an input and update any lit tests appropriately.

Expand the verifier of the `tosa::MulOp` operation to check the various constraints defined in the TOSA-v1.0 specification. Specifically, ensure that all input operands (excluding the optional shift) are of the same rank. This means that broadcasting tests which previously checked that rank-0 tensors would be broadcast are no longer valid and are removed.

>From 612a4a0cd4828768a091d21744e56920377b9069 Mon Sep 17 00:00:00 2001
From: TatWai Chong <tatwai.chong at arm.com>
Date: Tue, 6 Feb 2024 16:49:05 -0800
Subject: [PATCH] [mlir][tosa] Make TOSA MUL's Shift an Input

The TOSA-v1.0 specification makes the shift attribute of the MUL
(Hadamard product) operator an input. Move the `shift` parameter of the
MUL operator in the MLIR TOSA dialect from an attribute to an input and
update any lit tests appropriately.

Expand the verifier of the `tosa::MulOp` operation to check the various
constraints defined in the TOSA-v1.0 specification. Specifically, ensure
that all input operands (excluding the optional shift) are of the same
rank. This means that broadcasting tests which previously checked that
rank-0 tensors would be broadcast are no longer valid and are removed.

Signed-off-by: Jack Frankland <jack.frankland at arm.com>
---
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td  |  2 +-
 .../Conversion/TosaToLinalg/TosaToLinalg.cpp  | 89 ++++++++++++-------
 .../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 15 +++-
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp          | 81 ++++++++++++++++-
 .../Transforms/TosaDecomposeDepthwise.cpp     |  2 +-
 .../Tosa/Transforms/TosaMakeBroadcastable.cpp |  2 +-
 .../TosaToLinalg/tosa-to-linalg.mlir          | 10 ++-
 mlir/test/Dialect/Tosa/broadcast.mlir         |  9 --
 mlir/test/Dialect/Tosa/canonicalize.mlir      | 25 ++++--
 mlir/test/Dialect/Tosa/constant-op-fold.mlir  | 23 ++---
 mlir/test/Dialect/Tosa/invalid.mlir           | 15 +++-
 mlir/test/Dialect/Tosa/ops.mlir               |  4 +-
 .../Tosa/tosa-decompose-depthwise.mlir        |  4 +-
 mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | 14 +--
 14 files changed, 212 insertions(+), 83 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index e3c725801d1629..dceb36116797c5 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -800,7 +800,7 @@ def Tosa_MulOp : Tosa_ElementwiseOp<"mul", [
   let arguments = (ins
     Tosa_Tensor:$input1,
     Tosa_Tensor:$input2,
-    I8Attr:$shift
+    Optional<TosaTensorRankOf<[Tosa_Int8], [0]>>:$shift
   );
 
   let results = (outs
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 88e544c4e4b5f1..4ec19db2116546 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -90,43 +90,58 @@ static Value createLinalgBodyCalculationForElementwiseOp(
   }
 
   // tosa::MulOp
-  if (isa<tosa::MulOp>(op) && isa<FloatType>(elementTy))
-    return rewriter.create<arith::MulFOp>(loc, resultTypes, args);
-
-  if (isa<tosa::MulOp>(op) && isa<IntegerType>(elementTy)) {
-    Value a = args[0];
-    Value b = args[1];
-    auto shift =
-        cast<IntegerAttr>(op->getAttr("shift")).getValue().getSExtValue();
-    if (shift > 0) {
-      auto shiftConst =
-          rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
-      if (!a.getType().isInteger(32))
-        a = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), a);
-
-      if (!b.getType().isInteger(32))
-        b = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), b);
-
-      auto result = rewriter.create<tosa::ApplyScaleOp>(
-          loc, rewriter.getI32Type(), a, b, shiftConst,
-          rewriter.getBoolAttr(false));
-
-      if (elementTy.isInteger(32))
-        return result;
-
-      return rewriter.create<arith::TruncIOp>(loc, elementTy, result);
+  if (isa<tosa::MulOp>(op)) {
+    auto shift_val = cast<tosa::MulOp>(op).getShift();
+    if (!elementTy.isInteger(32) && shift_val.getImpl()) {
+        (void)rewriter.notifyMatchFailure(op,
+                                          "Cannot have shift value for non i32 output");
+        return nullptr;
+    };
+
+    if (isa<FloatType>(elementTy)) {
+      return rewriter.create<arith::MulFOp>(loc, resultTypes, args[0], args[1]);
     }
 
-    int aWidth = a.getType().getIntOrFloatBitWidth();
-    int bWidth = b.getType().getIntOrFloatBitWidth();
-    int cWidth = resultTypes[0].getIntOrFloatBitWidth();
+    if (isa<IntegerType>(elementTy)) {
+      int32_t shift = 0;
+      ElementsAttr shift_elem;
+      if (shift_val.getImpl() && matchPattern(shift_val, m_Constant(&shift_elem))) {
+        // Explicit shift is set.
+        shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
+      }
+
+      Value a = args[0];
+      Value b = args[1];
+      if (shift > 0) {
+        auto shiftConst =
+            rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
+        if (!a.getType().isInteger(32))
+          a = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), a);
+
+        if (!b.getType().isInteger(32))
+          b = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), b);
+
+        auto result = rewriter.create<tosa::ApplyScaleOp>(
+            loc, rewriter.getI32Type(), a, b, shiftConst,
+            rewriter.getBoolAttr(false));
 
-    if (aWidth < cWidth)
-      a = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], a);
-    if (bWidth < cWidth)
-      b = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], b);
+        if (elementTy.isInteger(32))
+          return result;
 
-    return rewriter.create<arith::MulIOp>(loc, resultTypes, a, b);
+        return rewriter.create<arith::TruncIOp>(loc, elementTy, result);
+      }
+
+      int aWidth = a.getType().getIntOrFloatBitWidth();
+      int bWidth = b.getType().getIntOrFloatBitWidth();
+      int cWidth = resultTypes[0].getIntOrFloatBitWidth();
+
+      if (aWidth < cWidth)
+        a = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], a);
+      if (bWidth < cWidth)
+        b = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], b);
+
+      return rewriter.create<arith::MulIOp>(loc, resultTypes, a, b);
+    }
   }
 
   // tosa::NegateOp
@@ -931,7 +946,13 @@ elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
   auto loc = operation->getLoc();
   auto rank =
       cast<RankedTensorType>(operation->getResultTypes().front()).getRank();
-  auto expandedOperands = expandInputRanks(rewriter, loc, operands, rank);
+  // For the mul op we need to avoid expanding the rank of the optional shift
+  // input.
+  auto operandsToExpand =
+      isa<tosa::MulOp>(operation) ? operands.take_front(2) : operands;
+
+  auto expandedOperands =
+      expandInputRanks(rewriter, loc, operandsToExpand, rank);
   auto [targetShape, masterOperands] =
       computeTargetShape(rewriter, loc, indexPool, expandedOperands);
   auto broadcastOperands = broadcastDynamicDimensions(
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 39d0ee122b1630..892421155733b3 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -614,7 +614,18 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
   auto rhsAttr =
       llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
 
-  const int64_t shift = llvm::isa<IntegerType>(resultETy) ? getShift() : 0;
+  // Result right shift on i32_t data type only. For simplification, synthesize a zero
+  // shift for other data types.
+  int32_t shift = 0;
+  if (resultETy.isInteger(32)) {
+    ElementsAttr shift_elem;
+    if (getShift().getImpl()) {
+      if (!matchPattern(getShift(), m_Constant(&shift_elem)))
+        // cannot be folded when the shift value is unknown.
+        return {};
+      shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
+    }
+  }
 
   if (rhsTy == resultTy) {
     if (isSplatZero(resultETy, lhsAttr))
@@ -629,7 +640,7 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
       return lhs;
   }
 
-  return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, getShift());
+  return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, shift);
 }
 
 OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 631d3c48f2df02..07f5d5e49f6d44 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -866,9 +866,84 @@ LogicalResult tosa::SliceOp::verify() {
 }
 
 LogicalResult tosa::MulOp::verify() {
-  Type elementTy = getInput1().getType().getElementType();
-  if (isa<FloatType>(elementTy) && getShift() != 0)
-    return emitOpError() << "require shift to be 0 for float type";
+  auto resElemType = getElementTypeOrSelf(getOutput());
+
+  // Verify that the element types among operands and result match the TOSA
+  // specification.
+  if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
+    IntegerType lhsIntType =
+      cast<IntegerType>(getElementTypeOrSelf(getInput1()));
+    IntegerType rhsIntType =
+      cast<IntegerType>(getElementTypeOrSelf(getInput2()));
+    if (lhsIntType != rhsIntType)
+      return emitOpError(
+          "requires the same element type for all operands");
+
+    // Though the spec requires the element type of result to be i32, a more
+    // relaxed way is provided at dialect level for easier cooperating with
+    // other dialects.
+    if (lhsIntType.getWidth() > resIntType.getWidth())
+      return emitOpError("invalid data type size for operands or result");
+
+  } else {
+    // For other supported types, the spec requires the same element
+    // type for all operands (excluding the `shift` operand) and results.
+    for (int i = 0; i < 2; ++i) {
+      if (getElementTypeOrSelf(getOperand(i)) != resElemType)
+        return emitOpError(
+            "requires the same element type for all operands and results");
+    }
+  }
+
+  // Check if the shift value apply to non-i32 output type as that is not
+  // allowed in the spec.
+  if (!(llvm::isa<IntegerType>(resElemType) && resElemType.isInteger(32)))
+    if (getShift().getImpl())
+        return emitOpError(
+            "right shift output only on i32 data type");
+
+  // Verify the op has same ranks for all main operands (excludes extra operands
+  // such as shift of mul op, so this is the only difference with the built-in
+  // `SameOperandsAndResultRank` trait) and results types, if known.
+
+  // delegate function that returns true if type is a shaped type with known
+  // rank
+  auto hasRank = [](const Type type) {
+    if (auto shaped_type = dyn_cast<ShapedType>(type))
+      return shaped_type.hasRank();
+
+    return false;
+  };
+
+  auto rankedOperandTypes =
+    llvm::to_vector(llvm::make_filter_range(getOperandTypes(), hasRank));
+
+  auto rankedResultTypes =
+    llvm::make_filter_range(getOperation()->getResultTypes(), hasRank);
+
+  // If all operands and results are unranked, then no further verification.
+  if (rankedOperandTypes.empty() && rankedResultTypes.empty())
+    return success();
+
+  // delegate function that returns rank of shaped type with known rank
+  auto getRank = [](const Type type) {
+    return cast<ShapedType>(type).getRank();
+  };
+
+  auto rank = !rankedOperandTypes.empty() ? getRank(*rankedOperandTypes.begin())
+    : getRank(*rankedResultTypes.begin());
+
+  for (size_t i = 0; i < 2; ++i) {
+    if (rank != getRank(rankedOperandTypes[i])) {
+      return emitOpError("operands don't have matching ranks");
+    }
+  }
+
+  for (const auto type : rankedResultTypes) {
+    if (rank != getRank(type)) {
+      return emitOpError("result type has different rank than operands");
+    }
+  }
 
   return success();
 }
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
index e6fba211dc37ab..537c0cce04491e 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
@@ -137,7 +137,7 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
 
     Value mulValue = rewriter
                          .create<tosa::MulOp>(op.getLoc(), mulShapeType, input,
-                                              weight, /*shift=*/0)
+                                              weight, Value{} /* zero_shift */)
                          .getResult();
 
     // Reshape output to [N, H, W, C * M].
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
index 2a990eed3f681e..79afc75fd6c8ee 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
@@ -113,7 +113,7 @@ struct ConvertTosaOp<tosa::MulOp> : public OpRewritePattern<tosa::MulOp> {
 
     Value input1 = tosaBinaryOp.getInput1();
     Value input2 = tosaBinaryOp.getInput2();
-    int32_t shift = tosaBinaryOp.getShift();
+    Value shift = tosaBinaryOp.getShift();
     Value output = tosaBinaryOp.getResult();
     auto outputType = dyn_cast<RankedTensorType>(output.getType());
     if (!outputType)
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 265a75986c6c8d..7137d24afde3c7 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -451,7 +451,7 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
 
   // CHECK: linalg.generic
   // CHECK: arith.mulf
-  %4 = tosa.mul %0, %1 {shift = 0 : i8} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+  %4 = tosa.mul %0, %1 : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
   // CHECK: arith.negf
@@ -597,7 +597,7 @@ func.func @test_simple_i16(%arg0: tensor<1xi16>) -> () {
   // CHECK: arith.extsi
   // CHECK: arith.extsi
   // CHECK: arith.muli
-  %0 = tosa.mul %arg0, %arg0 {shift = 0 : i8} : (tensor<1xi16>, tensor<1xi16>) -> tensor<1xi32>
+  %0 = tosa.mul %arg0, %arg0 : (tensor<1xi16>, tensor<1xi16>) -> tensor<1xi32>
 
   return
 }
@@ -625,12 +625,14 @@ func.func @test_simple_i32(%arg0: tensor<1xi32>, %unsigned: tensor<1xui32>, %uns
 
   // CHECK: linalg.generic
   // CHECK: arith.muli
-  %2 = tosa.mul %arg0, %arg0 {shift = 0 : i8} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %shift1 = "tosa.const"() <{value = dense<0> : tensor<i8>}> : () -> tensor<i8>
+  %2 = tosa.mul %arg0, %arg0, %shift1 : (tensor<1xi32>, tensor<1xi32>, tensor<i8>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: arith.constant 2
   // CHECK: apply_scale
-  %3 = tosa.mul %arg0, %arg0 {shift = 2 : i8} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %shift2 = "tosa.const"() <{value = dense<2> : tensor<i8>}> : () -> tensor<i8>
+  %3 = tosa.mul %arg0, %arg0, %shift2: (tensor<1xi32>, tensor<1xi32>, tensor<i8>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: arith.divsi
diff --git a/mlir/test/Dialect/Tosa/broadcast.mlir b/mlir/test/Dialect/Tosa/broadcast.mlir
index 7613aa3b8dd03d..f33a1465b6009d 100644
--- a/mlir/test/Dialect/Tosa/broadcast.mlir
+++ b/mlir/test/Dialect/Tosa/broadcast.mlir
@@ -169,15 +169,6 @@ func.func @test_broadcast20(%arg0: tensor<3x3x4x1xf32>, %arg1: tensor<4x5xf32>)
   return %0 : tensor<3x3x4x5xf32>
 }
 
-// -----
-// CHECK-LABEL: broadcast_mul
-func.func @test_broadcast_mul(%arg0: tensor<15x14xi32>, %arg1: tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 15, 14>}
-  // CHECK: %[[VAR1:.*]] = tosa.mul %[[VAR0]], %arg1
-  %0 = tosa.mul %arg0, %arg1 {shift = 1 : i8 } : (tensor<15x14xi32>, tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32>
-  return %0 : tensor<17x16x15x14xi32>
-}
-
 // -----
 // CHECK-LABEL: broadcast_arithmetic_right_shift
 func.func @test_broadcast_arithmetic_right_shift(%arg0: tensor<15x14xi32>, %arg1: tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> {
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 67cd01f62f0bdf..0e783a9a23ae48 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -280,7 +280,7 @@ func.func @mul_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
   // CHECK: return %arg0
   // CHECK-NOT: tosa.mul
   %ones = "tosa.const"() {value = dense<1.0> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
-  %1 = tosa.mul %arg0, %ones {shift = 0 : i8} : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
+  %1 = tosa.mul %arg0, %ones : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
   return %1 : tensor<2x3xf32>
 }
 
@@ -291,7 +291,7 @@ func.func @mul_bcast_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
   // CHECK: return %arg0
   // CHECK-NOT: tosa.mul
   %ones = "tosa.const"() {value = dense<1.0> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
-  %1 = tosa.mul %ones, %arg0 {shift = 0 : i8} : (tensor<1x1xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
+  %1 = tosa.mul %ones, %arg0 : (tensor<1x1xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
   return %1 : tensor<2x3xf32>
 }
 
@@ -302,7 +302,20 @@ func.func @mul_one_int(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> {
   // CHECK: return %arg0
   // CHECK-NOT: tosa.mul
   %ones = "tosa.const"() {value = dense<1> : tensor<2x3xi32>} : () -> tensor<2x3xi32>
-  %1 = tosa.mul %arg0, %ones {shift = 0 : i8} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  %1 = tosa.mul %arg0, %ones : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  return %1 : tensor<2x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @mul_one_int_and_shift
+func.func @mul_one_int_and_shift(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> {
+  // CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<1> : tensor<2x3xi32>}>
+  // CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<31> : tensor<i8>}>
+  // CHECK: %[[VAL_3:.*]] = tosa.mul %arg0, %[[VAL_1]], %[[VAL_2]] : (tensor<2x3xi32>, tensor<2x3xi32>, tensor<i8>)
+  %ones = "tosa.const"() {value = dense<1> : tensor<2x3xi32>} : () -> tensor<2x3xi32>
+  %shift = "tosa.const"() <{value = dense<31> : tensor<i8>}> : () -> tensor<i8>
+  %1 = tosa.mul %arg0, %ones, %shift : (tensor<2x3xi32>, tensor<2x3xi32>, tensor<i8>) -> tensor<2x3xi32>
   return %1 : tensor<2x3xi32>
 }
 
@@ -313,11 +326,11 @@ func.func @mul_zero_broadcast(%arg0: tensor<2x3xf32>) -> (tensor<2x3xf32>, tenso
   // CHECK: %[[ZERO:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<2x3xf32>}
   // CHECK-NOT: tosa.mul
   %zeros = "tosa.const"() {value = dense<0.0> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
-  %1 = tosa.mul %arg0, %zeros {shift = 0 : i8} : (tensor<2x3xf32>, tensor<1x1xf32>) -> tensor<2x3xf32>
+  %1 = tosa.mul %arg0, %zeros : (tensor<2x3xf32>, tensor<1x1xf32>) -> tensor<2x3xf32>
 
   // CHECK-NOT: tosa.mul
   // CHECK: return %[[ZERO]], %[[ZERO]]
-  %2 = tosa.mul %zeros, %arg0 {shift = 0 : i8} : (tensor<1x1xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
+  %2 = tosa.mul %zeros, %arg0 : (tensor<1x1xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
   return %1, %2 : tensor<2x3xf32>, tensor<2x3xf32>
 }
 
@@ -872,7 +885,7 @@ func.func @mul_quant_nofold() -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899
    // CHECK: tosa.mul
    %0 = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
    %1 = "tosa.const"() {value = dense<1> : tensor<1xi8>} : () -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
-   %2 = tosa.mul %0, %1 { shift = 0 : i8} : (tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>, tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>) -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
+   %2 = tosa.mul %0, %1 : (tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>, tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>) -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
    return %2 : tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
 }
 
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index 2902c4a62009e9..b6947e05b26bf1 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -247,7 +247,7 @@ func.func @fold_div_splat_i32() -> tensor<i32> {
 func.func @fold_mul_zero_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
   %zero = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
   // CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0.000000e+00>
-  %mul = tosa.mul %arg0, %zero {shift = 0 : i8} : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  %mul = tosa.mul %arg0, %zero : (tensor<f32>, tensor<f32>) -> tensor<f32>
   // CHECK: return %[[ZERO]]
   return %mul : tensor<f32>
 }
@@ -258,7 +258,7 @@ func.func @fold_mul_zero_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
 func.func @fold_mul_zero_lhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
   %zero = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
   // CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0.000000e+00>
-  %mul = tosa.mul %zero, %arg0 {shift = 0 : i8} : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  %mul = tosa.mul %zero, %arg0 : (tensor<f32>, tensor<f32>) -> tensor<f32>
   // CHECK: return %[[ZERO]]
   return %mul : tensor<f32>
 }
@@ -269,7 +269,7 @@ func.func @fold_mul_zero_lhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
 func.func @fold_mul_zero_rhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
   %zero = "tosa.const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
   // CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0>
-  %mul = tosa.mul %arg0, %zero {shift = 0 : i8} : (tensor<i32>, tensor<i32>) -> tensor<i32>
+  %mul = tosa.mul %arg0, %zero : (tensor<i32>, tensor<i32>) -> tensor<i32>
   // CHECK: return %[[ZERO]]
   return %mul : tensor<i32>
 }
@@ -280,7 +280,7 @@ func.func @fold_mul_zero_rhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
 func.func @fold_mul_zero_lhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
   %zero = "tosa.const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
   // CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0>
-  %mul = tosa.mul %zero, %arg0 {shift = 0 : i8} : (tensor<i32>, tensor<i32>) -> tensor<i32>
+  %mul = tosa.mul %zero, %arg0 : (tensor<i32>, tensor<i32>) -> tensor<i32>
   // CHECK: return %[[ZERO]]
   return %mul : tensor<i32>
 }
@@ -290,7 +290,7 @@ func.func @fold_mul_zero_lhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
 // CHECK-LABEL: @fold_mul_one_rhs_f32
 func.func @fold_mul_one_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
   %one = "tosa.const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
-  %mul = tosa.mul %arg0, %one {shift = 0 : i8} : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  %mul = tosa.mul %arg0, %one : (tensor<f32>, tensor<f32>) -> tensor<f32>
   // CHECK: return %arg0
   return %mul : tensor<f32>
 }
@@ -300,7 +300,7 @@ func.func @fold_mul_one_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
 // CHECK-LABEL: @fold_mul_one_lhs_f32
 func.func @fold_mul_one_lhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
   %one = "tosa.const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
-  %mul = tosa.mul %one, %arg0 {shift = 0 : i8} : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  %mul = tosa.mul %one, %arg0 : (tensor<f32>, tensor<f32>) -> tensor<f32>
   // CHECK: return %arg0
   return %mul : tensor<f32>
 }
@@ -310,7 +310,8 @@ func.func @fold_mul_one_lhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
 // CHECK-LABEL: @fold_mul_one_rhs_i32
 func.func @fold_mul_one_rhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
   %one = "tosa.const"() {value = dense<64> : tensor<i32>} : () -> tensor<i32>
-  %mul = tosa.mul %arg0, %one {shift = 6 : i8} : (tensor<i32>, tensor<i32>) -> tensor<i32>
+  %shift = "tosa.const"() {value = dense<6> : tensor<i8>} : () -> tensor<i8>
+  %mul = tosa.mul %arg0, %one, %shift : (tensor<i32>, tensor<i32>, tensor<i8>) -> tensor<i32>
   // CHECK: return %arg0
   return %mul : tensor<i32>
 }
@@ -320,7 +321,8 @@ func.func @fold_mul_one_rhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
 // CHECK-LABEL: @fold_mul_one_lhs_i32
 func.func @fold_mul_one_lhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
   %one = "tosa.const"() {value = dense<64> : tensor<i32>} : () -> tensor<i32>
-  %mul = tosa.mul %one, %arg0 {shift = 6 : i8} : (tensor<i32>, tensor<i32>) -> tensor<i32>
+  %shift = "tosa.const"() {value = dense<6> : tensor<i8>} : () -> tensor<i8>
+  %mul = tosa.mul %one, %arg0, %shift : (tensor<i32>, tensor<i32>, tensor<i8>) -> tensor<i32>
   // CHECK: return %arg0
   return %mul : tensor<i32>
 }
@@ -331,7 +333,8 @@ func.func @fold_mul_one_lhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
 func.func @fold_mul_splat_i8() -> tensor<10xi32> {
   %one = "tosa.const"() {value = dense<17> : tensor<10xi8>} : () -> tensor<10xi8>
   %two = "tosa.const"() {value = dense<32> : tensor<10xi8>} : () -> tensor<10xi8>
-  %mul = tosa.mul %one, %two {shift = 3 : i8} : (tensor<10xi8>, tensor<10xi8>) -> tensor<10xi32>
+  %shift = "tosa.const"() {value = dense<3> : tensor<i8>} : () -> tensor<i8>
+  %mul = tosa.mul %one, %two, %shift : (tensor<10xi8>, tensor<10xi8>, tensor<i8>) -> tensor<10xi32>
   // CHECK: %[[THREE:.+]] = "tosa.const"() <{value = dense<68> : tensor<10xi32>}
   // CHECK: return %[[THREE]]
   return %mul : tensor<10xi32>
@@ -343,7 +346,7 @@ func.func @fold_mul_splat_i8() -> tensor<10xi32> {
 func.func @fold_mul_splat_f32() -> tensor<10xf32> {
   %one = "tosa.const"() {value = dense<3.0> : tensor<10xf32>} : () -> tensor<10xf32>
   %two = "tosa.const"() {value = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
-  %mul = tosa.mul %one, %two {shift = 0 : i8} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
+  %mul = tosa.mul %one, %two : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
   // CHECK: %[[THREE:.+]] = "tosa.const"() <{value = dense<6.000000e+00> : tensor<10xf32>}
   // CHECK: return %[[THREE]]
   return %mul : tensor<10xf32>
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index cca50b25d14d6b..aadfd4239c1278 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -612,10 +612,21 @@ func.func @test_transpose_conv2d_invalid_outshape(%arg0: tensor<1x32x32x8xf32>,
 
 // -----
 
+// CHECK-LABEL: test_mul_invalid_shift_type
+func.func @test_mul_invalid_shift(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
+  %shift = "tosa.const"() {value = dense<0> : tensor<i8>} : () -> tensor<i8>
+  // expected-error at +1 {{'tosa.mul' op requires the same element type for all operands and results}}
+  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<i8>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+
 // CHECK-LABEL: test_mul_invalid_shift
 func.func @test_mul_invalid_shift(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
-  // expected-error at +1 {{'tosa.mul' op require shift to be 0 for float type}}
-  %0 = tosa.mul %arg0, %arg1 {shift = 1 : i8} : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32>
+  %shift = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
+  // expected-error at +1 {{'tosa.mul' op operand #2 must be 0D tensor of 8-bit signless integer values, but got 'tensor<f32>'}}
+  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<f32>) -> tensor<13x21x3xf32>
   return %0 : tensor<13x21x3xf32>
 }
 
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 88fa94ae90db69..926074a9ca3942 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -316,14 +316,14 @@ func.func @test_min(%arg0: tensor<13x21x3xf32>, %arg1: tensor<1x21x3xf32>) -> te
 // -----
 // CHECK-LABEL: mul
 func.func @test_mul(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
-  %0 = tosa.mul %arg0, %arg1 {shift = 0 : i8} : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32>
+  %0 = tosa.mul %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32>
   return %0 : tensor<13x21x3xf32>
 }
 
 // -----
 // CHECK-LABEL: mul
 func.func @test_mul_relaxed_result_type(%arg0: tensor<13x21x3xi16>, %arg1: tensor<13x1x3xi16>) -> tensor<13x21x3xi16> {
-  %0 = "tosa.mul"(%arg0, %arg1)  { shift = 1 : i8 } : (tensor<13x21x3xi16>, tensor<13x1x3xi16>) -> tensor<13x21x3xi16>
+  %0 = tosa.mul %arg0, %arg1 : (tensor<13x21x3xi16>, tensor<13x1x3xi16>) -> tensor<13x21x3xi16>
   return %0 : tensor<13x21x3xi16>
 }
 
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
index 2224bf3f57b255..fc7e3b42347c4c 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
@@ -34,7 +34,7 @@ func.func @depthwise_conv2d_as_mul_q(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<
   // CHECK: %[[sIn:.+]] = tosa.sub %[[cIn]], %[[iZp]]
   // CHECK: %[[sWe:.+]] = tosa.sub %[[cWe]], %[[wZp]]
   // CHECK: %[[resWe:.+]] = tosa.reshape %[[sWe]] {new_shape = array<i64: 1, 1, 1, 2, 3>}
-  // CHECK: %[[mul:.+]] = tosa.mul %[[sIn]], %[[resWe]] {shift = 0 : i8}
+  // CHECK: %[[mul:.+]] = tosa.mul %[[sIn]], %[[resWe]]
   // CHECK: %[[reO:.+]] = tosa.reshape %[[mul]] {new_shape = array<i64: 4, 10, 10, 6>}
   // CHECK: %[[reArg2:.+]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 1, 6>}
   // CHECK: %[[add:.+]] = tosa.add %[[reO]], %[[reArg2]]
@@ -51,7 +51,7 @@ func.func @depthwise_conv2d_as_mul_padded(%arg0: tensor<4x10x10x2xf32>, %arg1: t
   // CHECK: %[[reIn:.+]] = tosa.reshape %arg0 {new_shape = array<i64: 4, 10, 10, 2, 1>}
   // CHECK: %[[padded:.+]] = tosa.pad %[[reIn]], %[[pad]], %[[zero]] : (tensor<4x10x10x2x1xf32>, tensor<5x2xi64>, tensor<f32>) -> tensor<4x12x12x2x1xf32>
   // CHECK: %[[reArg1:.+]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 2, 3>}
-  // CHECK: %[[mul:.+]] = tosa.mul %3, %[[reArg1]] {shift = 0 : i8}
+  // CHECK: %[[mul:.+]] = tosa.mul %[[padded]], %[[reArg1]]
   // CHECK: %[[reOut:.+]] = tosa.reshape %[[mul]] {new_shape = array<i64: 4, 12, 12, 6>}
   // CHECK: %[[reArg2:.+]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 1, 6>}
   // CHECK: %[[add:.+]] = tosa.add %[[reOut]], %[[reArg2]]
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index d46de740800e93..84415523e7c619 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -114,8 +114,8 @@ func.func @test_binary_scalar_f32(%arg0 : tensor<4xf32>, %arg1 : tensor<f32>) ->
   // CHECK: tosa.minimum %arg0, %arg1 : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
   %2 = tosa.minimum %arg0, %arg1 : (tensor<4xf32>, tensor<f32>) -> tensor<*xf32>
 
-  // CHECK: tosa.mul %arg0, %arg1 {shift = 0 : i8} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
-  %3 = tosa.mul %arg0, %arg1 { shift = 0 : i8 } : (tensor<4xf32>, tensor<f32>) -> tensor<*xf32>
+  %arg1_reshaped = tosa.reshape %arg1 {new_shape = array<i64: 1>} : (tensor<f32>) -> tensor<1xf32>
+  %3 = tosa.mul %arg0, %arg1_reshaped : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
 
   // CHECK: tosa.pow %arg0, %arg1 : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
   %4 = tosa.pow %arg0, %arg1 : (tensor<4xf32>, tensor<f32>) -> tensor<*xf32>
@@ -148,8 +148,8 @@ func.func @test_binary_broadcast_f32(%arg0 : tensor<4xf32>, %arg1 : tensor<1xf32
   // CHECK: tosa.minimum %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
   %2 = tosa.minimum %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
 
-  // CHECK: tosa.mul %arg0, %arg1 {shift = 0 : i8} : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
-  %3 = tosa.mul %arg0, %arg1 { shift = 0 : i8 } : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
+  // CHECK: tosa.mul %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
+  %3 = tosa.mul %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
 
   // CHECK: tosa.pow %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
   %4 = tosa.pow %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
@@ -206,8 +206,10 @@ func.func @test_binary_i32(%arg0 : tensor<4xi32>, %arg1 : tensor<i32>) -> () {
   // CHECK: tosa.minimum %arg0, %arg1 : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
   %10 = tosa.minimum %arg0, %arg1 : (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
 
-  // CHECK: tosa.mul %arg0, %arg1 {shift = 0 : i8} : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
-  %11 = tosa.mul %arg0, %arg1 { shift = 0 : i8 }: (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
+  // CHECK: %[[ARG1_RESHAPED:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1>} : (tensor<i32>) -> tensor<1xi32>
+  %arg1_reshaped = tosa.reshape %arg1 {new_shape = array<i64: 1>} : (tensor<i32>) -> tensor<1xi32>
+  // CHECK: tosa.mul %arg0, %[[ARG1_RESHAPED]] : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32>
+  %11 = tosa.mul %arg0, %arg1_reshaped : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32>
 
   // CHECK: tosa.pow %arg0, %arg1 : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
   %12 = tosa.pow %arg0, %arg1 : (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>



More information about the Mlir-commits mailing list