[Mlir-commits] [mlir] [mlir][tosa] Change the shift of mul to be required (PR #125297)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 31 13:28:40 PST 2025


llvmbot wrote:


@llvm/pr-subscribers-mlir-tosa

Author: Tai Ly (Tai78641)

<details>
<summary>Changes</summary>

Change the shift operand of the mul operator to be a required operand.

Also define shift to be Tosa_ScalarInt8Tensor, which requires a rank-1 tensor
of shape [1] (i.e., a tensor containing a single element).
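
With this change, every tosa.mul must pass the shift operand explicitly, even for floating-point inputs, where the shift must be the constant 0. A minimal sketch of the new form, mirroring the updated test in ops.mlir (the function name here is illustrative):

```mlir
// Shift is now a required third operand: a rank-1, single-element i8 tensor.
func.func @test_mul_required_shift(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
  // For float inputs the shift must be the constant 0.
  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
  return %0 : tensor<13x21x3xf32>
}
```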



---

Patch is 29.12 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/125297.diff


12 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+2-1) 
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td (+12) 
- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+13-8) 
- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+12-9) 
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp (+2-4) 
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+4-2) 
- (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+9-5) 
- (modified) mlir/test/Dialect/Tosa/constant-op-fold.mlir (+10-5) 
- (modified) mlir/test/Dialect/Tosa/invalid.mlir (+35-7) 
- (modified) mlir/test/Dialect/Tosa/ops.mlir (+2-1) 
- (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+18-16) 
- (modified) mlir/test/Dialect/Tosa/tosa-reduce-transposes.mlir (+6-4) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index c59c582a1f5221..13e4376de8aa96 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -812,7 +812,8 @@ def Tosa_MulOp : Tosa_Op<"mul", [
   let arguments = (ins
     Tosa_Tensor:$input1,
     Tosa_Tensor:$input2,
-    Optional<TosaTensorRankOf<[Tosa_Int8], [1]>>:$shift
+    // Apply right shift on i32_t input data only
+    Tosa_ScalarInt8Tensor:$shift
   );
 
   let results = (outs
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 5693acf3a01db4..d02bf1589f44b0 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -93,6 +93,10 @@ def HasNo0Dimensions : And<[
     IsRankedTensorTypePred,
     CPred<"::llvm::all_of(::llvm::cast<::mlir::RankedTensorType>($_self).getShape(), [](auto v) { return v != 0; })">]>;
 
+def AllDimensionsAreSizeOne : And<[
+    IsRankedTensorTypePred,
+    CPred<"::llvm::all_of(::llvm::cast<::mlir::RankedTensorType>($_self).getShape(), [](auto v) { return v == 1; })">]>;
+
 class TosaTensorOf<
     list<Type> allowedTypes, string summary = "tosa-conformant tensor">
     : TensorOf<allowedTypes, [Or<[HasNo0Dimensions, IsUnrankedTensorTypePred]>], summary>;
@@ -109,6 +113,11 @@ class TosaTensorRankOf<list<Type> allowedTypes, list<int> ranks>
       [HasAnyRankOfPred<ranks>],
       !interleave(!foreach(rank, ranks, rank # "D"), "/") # " tensor">;
 
+class TosaScalarTensorOf<list<Type> allowedTypes, list<int> ranks>
+    : TosaRankedTensorOf<allowedTypes,
+      [HasAnyRankOfPred<ranks>, AllDimensionsAreSizeOne],
+      "tosa-conformant scalar tensor">;
+
 //===----------------------------------------------------------------------===//
 // Tensor types
 //===----------------------------------------------------------------------===//
@@ -139,6 +148,9 @@ class Tosa_TensorOfOrNone<list<Type> allowedTypes, string description = ""> :
 // Rank-0 (scalar) tensor
 def Tosa_ScalarTensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>;
 
+// Scalar tensors: Rank-1 (with only one element)
+def Tosa_ScalarInt8Tensor : TosaScalarTensorOf<[Tosa_Int8], [1]>;
+
 // We include unranked tensors as a supported type for all possible tosa
 // Tensors as unranked does not guarantee invalid. If unranked tensors exist
 // they should be shape propagate used Tosa's shape inference pass and verified
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index b0eb2d6cbc30b6..a0dfee80360688 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -92,22 +92,27 @@ static Value createLinalgBodyCalculationForElementwiseOp(
   // tosa::MulOp
   if (isa<tosa::MulOp>(op)) {
     auto shift_val = cast<tosa::MulOp>(op).getShift();
+    ElementsAttr shift_elem;
+    if (!shift_val.getImpl() ||
+        !matchPattern(shift_val, m_Constant(&shift_elem))) {
+      (void)rewriter.notifyMatchFailure(op, "shift value of mul not found");
+    }
+
+    int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
 
     if (isa<FloatType>(elementTy)) {
+      if (shift != 0) {
+        (void)rewriter.notifyMatchFailure(op,
+                                          "Cannot have shift value for float");
+        return nullptr;
+      }
       return rewriter.create<arith::MulFOp>(loc, resultTypes, args[0], args[1]);
     }
 
     if (isa<IntegerType>(elementTy)) {
-      int32_t shift = 0;
-      ElementsAttr shift_elem;
-      if (shift_val.getImpl() &&
-          matchPattern(shift_val, m_Constant(&shift_elem))) {
-        // Explicit shift is set.
-        shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
-      }
-
       Value a = args[0];
       Value b = args[1];
+
       if (shift > 0) {
         auto shiftConst =
             rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 0a10439db40803..43470a81cd57ab 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -963,16 +963,10 @@ LogicalResult tosa::MulOp::inferReturnTypeComponents(
     ValueShapeRange operands, DictionaryAttr attributes,
     OpaqueProperties properties, RegionRange regions,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  LogicalResult status = success();
+  // mul op's output shape only depend on input1 and input2, not on shift
+  ValueShapeRange twoInputs = operands.drop_back();
   llvm::SmallVector<int64_t> outShape;
-  if (operands.size() == 2) {
-    status = resolveBroadcastShape(operands, outShape);
-  } else {
-    // mul op's output shape only depend on input1 and input2, not on shift
-    ValueShapeRange two_inputs = operands.drop_back();
-    status = resolveBroadcastShape(two_inputs, outShape);
-  }
-  if (status.failed()) {
+  if (resolveBroadcastShape(twoInputs, outShape).failed()) {
     inferredReturnShapes.push_back(ShapedTypeComponents());
   } else {
     inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
@@ -1007,6 +1001,15 @@ LogicalResult tosa::MulOp::verify() {
         return emitOpError(
             "requires the same element type for all operands and results");
     }
+
+    // verify shift has value 0 for non-integer types
+    ElementsAttr shift_elem;
+    if (matchPattern(getShift(), m_Constant(&shift_elem))) {
+      int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
+      if (shift != 0) {
+        return emitOpError() << "require shift to be 0 for float type";
+      }
+    }
   }
 
   // Verify the op has same ranks for all main operands (excludes extra operands
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
index 520f283a3ba888..4c312ffd124e24 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
@@ -287,8 +287,7 @@ bool TosaReduceTransposes::collectFanIn(Operation *op,
 
     for (Value operand : op->getOperands()) {
       // If this is a problem in future, think about alternatives to recursion.
-      if (llvm::isa<tosa::MulOp>(op) && op->getNumOperands() == 3 &&
-          operand == op->getOperand(2)) {
+      if (llvm::isa<tosa::MulOp>(op) && operand == op->getOperand(2)) {
         // do not recurse into MulOp's shift operand
         continue;
       }
@@ -332,8 +331,7 @@ std::optional<Value> TosaReduceTransposes::buildMappedToValue(
   for (Value v : op->getOperands()) {
     if (valuesMap.contains(v)) {
       operands.push_back(valuesMap.at(v));
-    } else if (llvm::isa<tosa::MulOp>(op) && op->getNumOperands() == 3 &&
-               v == op->getOperand(2)) {
+    } else if (llvm::isa<tosa::MulOp>(op) && v == op->getOperand(2)) {
       // special case for MulOp's shift operand
       operands.push_back(v);
     } else {
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index f9bdcefa35317a..3704b4c29fceaf 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -472,7 +472,8 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
 
   // CHECK: linalg.generic
   // CHECK: arith.mulf
-  %4 = tosa.mul %0, %1 : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %4 = tosa.mul %0, %1, %shift : (tensor<1xf32>, tensor<1xf32>, tensor<1xi8>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
   // CHECK: arith.negf
@@ -618,7 +619,8 @@ func.func @test_simple_i16(%arg0: tensor<1xi16>) -> () {
   // CHECK: arith.extsi
   // CHECK: arith.extsi
   // CHECK: arith.muli
-  %0 = tosa.mul %arg0, %arg0 : (tensor<1xi16>, tensor<1xi16>) -> tensor<1xi32>
+  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %0 = tosa.mul %arg0, %arg0, %shift : (tensor<1xi16>, tensor<1xi16>, tensor<1xi8>) -> tensor<1xi32>
 
   return
 }
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 71a7e2826a63cc..a9895dd45d62bd 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -331,8 +331,9 @@ func.func @pad_determine_val_quant(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi3
 func.func @mul_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
   // CHECK: return %arg0
   // CHECK-NOT: tosa.mul
+  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
   %ones = "tosa.const"() {value = dense<1.0> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
-  %1 = tosa.mul %arg0, %ones : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
+  %1 = tosa.mul %arg0, %ones, %shift : (tensor<2x3xf32>, tensor<2x3xf32>, tensor<1xi8>) -> tensor<2x3xf32>
   return %1 : tensor<2x3xf32>
 }
 
@@ -343,7 +344,8 @@ func.func @mul_bcast_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
   // CHECK: return %arg0
   // CHECK-NOT: tosa.mul
   %ones = "tosa.const"() {value = dense<1.0> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
-  %1 = tosa.mul %ones, %arg0 : (tensor<1x1xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
+  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %1 = tosa.mul %ones, %arg0, %shift : (tensor<1x1xf32>, tensor<2x3xf32>, tensor<1xi8>) -> tensor<2x3xf32>
   return %1 : tensor<2x3xf32>
 }
 
@@ -379,11 +381,12 @@ func.func @mul_zero_broadcast(%arg0: tensor<2x3xf32>) -> (tensor<2x3xf32>, tenso
   // CHECK: %[[ZERO:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<2x3xf32>}
   // CHECK-NOT: tosa.mul
   %zeros = "tosa.const"() {value = dense<0.0> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
-  %1 = tosa.mul %arg0, %zeros : (tensor<2x3xf32>, tensor<1x1xf32>) -> tensor<2x3xf32>
+  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %1 = tosa.mul %arg0, %zeros, %shift : (tensor<2x3xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<2x3xf32>
 
   // CHECK-NOT: tosa.mul
   // CHECK: return %[[ZERO]], %[[ZERO]]
-  %2 = tosa.mul %zeros, %arg0 : (tensor<1x1xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
+  %2 = tosa.mul %zeros, %arg0, %shift : (tensor<1x1xf32>, tensor<2x3xf32>, tensor<1xi8>) -> tensor<2x3xf32>
   return %1, %2 : tensor<2x3xf32>, tensor<2x3xf32>
 }
 
@@ -966,7 +969,8 @@ func.func @mul_quant_nofold() -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899
    // CHECK: tosa.mul
    %0 = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
    %1 = "tosa.const"() {value = dense<1> : tensor<1xi8>} : () -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
-   %2 = tosa.mul %0, %1 : (tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>, tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>)-> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
+   %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+   %2 = tosa.mul %0, %1, %shift : (tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>, tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>, tensor<1xi8>) -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
    return %2 : tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
 }
 
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index 32677f06e22523..89c17fa1ab5c83 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -238,7 +238,8 @@ func.func @fold_div_splat_i32() -> tensor<i32> {
 func.func @fold_mul_zero_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
   %zero = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
   // CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0.000000e+00>
-  %mul = tosa.mul %arg0, %zero : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %mul = tosa.mul %arg0, %zero, %shift : (tensor<f32>, tensor<f32>, tensor<1xi8>) -> tensor<f32>
   // CHECK: return %[[ZERO]]
   return %mul : tensor<f32>
 }
@@ -249,7 +250,8 @@ func.func @fold_mul_zero_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
 func.func @fold_mul_zero_lhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
   %zero = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
   // CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0.000000e+00>
-  %mul = tosa.mul %zero, %arg0 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %mul = tosa.mul %zero, %arg0, %shift : (tensor<f32>, tensor<f32>, tensor<1xi8>) -> tensor<f32>
   // CHECK: return %[[ZERO]]
   return %mul : tensor<f32>
 }
@@ -283,7 +285,8 @@ func.func @fold_mul_zero_lhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
 // CHECK-LABEL: @fold_mul_one_rhs_f32
 func.func @fold_mul_one_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
   %one = "tosa.const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
-  %mul = tosa.mul %arg0, %one : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %mul = tosa.mul %arg0, %one, %shift : (tensor<f32>, tensor<f32>, tensor<1xi8>) -> tensor<f32>
   // CHECK: return %arg0
   return %mul : tensor<f32>
 }
@@ -293,7 +296,8 @@ func.func @fold_mul_one_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
 // CHECK-LABEL: @fold_mul_one_lhs_f32
 func.func @fold_mul_one_lhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
   %one = "tosa.const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
-  %mul = tosa.mul %one, %arg0 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %mul = tosa.mul %one, %arg0, %shift : (tensor<f32>, tensor<f32>, tensor<1xi8>) -> tensor<f32>
   // CHECK: return %arg0
   return %mul : tensor<f32>
 }
@@ -339,7 +343,8 @@ func.func @fold_mul_splat_i8() -> tensor<10xi32> {
 func.func @fold_mul_splat_f32() -> tensor<10xf32> {
   %one = "tosa.const"() {value = dense<3.0> : tensor<10xf32>} : () -> tensor<10xf32>
   %two = "tosa.const"() {value = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
-  %mul = tosa.mul %one, %two : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
+  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %mul = tosa.mul %one, %two, %shift : (tensor<10xf32>, tensor<10xf32>, tensor<1xi8>) -> tensor<10xf32>
   // CHECK: %[[THREE:.+]] = "tosa.const"() <{value = dense<6.000000e+00> : tensor<10xf32>}
   // CHECK: return %[[THREE]]
   return %mul : tensor<10xf32>
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index ac4d466aef94b2..5c1dbcac1bcb83 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -730,26 +730,27 @@ func.func @test_transpose_conv2d_invalid_outshape(%arg0: tensor<1x32x32x8xf32>,
 
 // CHECK-LABEL: test_mul_type_mismatch
 func.func @test_mul_type_mismatch(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf16>) -> tensor<13x21x3xf32> {
+  %shift = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
   // expected-error at +1 {{'tosa.mul' op requires the same element type for all operands}}
-  %0 = tosa.mul %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x1x3xf16>) -> tensor<13x21x3xf32>
+  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf16>, tensor<1xi8>) -> tensor<13x21x3xf32>
   return %0 : tensor<13x21x3xf32>
 }
 
 // -----
 
 // CHECK-LABEL: test_mul_invalid_shift
-func.func @test_mul_invalid_shift(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x1x3xi32>) -> tensor<13x21x3xi32> {
-  %shift = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
-  // expected-error at +1 {{'tosa.mul' op operand #2 must be 1D tensor of 8-bit signless integer values, but got 'tensor<f32>'}}
-  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi32>, tensor<13x1x3xi32>, tensor<f32>) -> tensor<13x21x3xi32>
-  return %0 : tensor<13x21x3xi32>
+func.func @test_mul_invalid_shift(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
+  %shift = "tosa.const"() {value = dense<1> : tensor<1xi8>} : () -> tensor<1xi8>
+  // expected-error at +1 {{'tosa.mul' op require shift to be 0 for float type}}
+  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
 }
 
 // -----
 
 // CHECK-LABEL: test_mul_missing_shift
 func.func @test_mul_missing_shift(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x1x3xi32>) -> tensor<13x21x3xi32> {
-  // this is ok because mul's shift operand is optional for now
+  // expected-error at +1 {{'tosa.mul' op expected 3 operands, but found 2}}
   %0 = tosa.mul %arg0, %arg1 : (tensor<13x21x3xi32>, tensor<13x1x3xi32>) -> tensor<13x21x3xi32>
   return %0 : tensor<13x21x3xi32>
 }
@@ -1061,3 +1062,30 @@ func.func @test_sub_with_unequal_result_ranks(%arg0: tensor<1x21x3xf32>, %arg1:
   %0 = tosa.sub %arg0, %arg1 : (tensor<1x21x3xf32>, tensor<13x21x3xf32>) -> tensor<1x13x21x3xf32>
   return %0 : tensor<1x13x21x3xf32>
 }
+
+// -----
+// CHECK-LABEL: test_mul_non_scalar_shift_2d
+func.func @test_mul_non_scalar_shift_2d(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
+  %shift = "tosa.const"() <{value = dense<0> : tensor<1x1xi8>}> : () -> tensor<1x1xi8>
+  // expected-error at +1 {{'tosa.mul' op operand #2 must be tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<1x1xi8>'}}
+  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<1x1xi8>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: test_mul_non_scalar_shift_1d
+func.func @test_mul_non_scalar_shift_1d(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
+  %shift = "tosa.const"() <{value = dense<0> : tensor<2xi8>}> : () -> tensor<2xi8>
+  // expected-error at +1 {{'tosa.mul' op operand #2 must be tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<2xi8>'}}
+  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<2xi8>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: test_mul_non_broadcast
+func.func @test_mul_non_broadcast(%arg0: tensor<13x21x2xf32>, %arg1: tensor<3x1x3xf32>) -> tensor<13x21x3xf32> {
+  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  // expected-error at +1 {{'tosa.mul' op operands don't have broadcast-compatible shapes}}
+  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x2xf32>, tensor<3x1x3xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index a4596c8f9d5362..2774a82d6fb8b5 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -338,7 +338,8 @@ func.func @test_mul_scalar_with_unranked_output(%arg0: tensor<f32>, %arg1: tenso
 // -----
 // CHECK-LABEL: mul
 func.func @test_mul(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
-  %0 = tosa.mul %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32>
+  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
   return %0 : tensor<13x21x3xf32>
 }
 
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 73eabab657f380..028105855ce25b 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -114,23 +114,24 @@ func.func @test_binary_scalar_f32(%arg0 : tensor<4xf32>, %arg1 : tensor<1xf32>)
   // CHECK: tosa.minimum %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
   %2 = tosa.minimum %arg0, %arg1 ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/125297


More information about the Mlir-commits mailing list