[Mlir-commits] [mlir] [mlir][tosa] Allow shift operand of tosa::MulOp as non-constant (PR #155197)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Aug 27 19:59:08 PDT 2025


https://github.com/ShivaChen updated https://github.com/llvm/llvm-project/pull/155197

>From f3d6e68aa2686e2f9dd71bfbbc9da851c4909cbc Mon Sep 17 00:00:00 2001
From: Shiva Chen <shiva.chen at imgtec.com>
Date: Thu, 24 Jul 2025 05:18:16 +0100
Subject: [PATCH 1/3] [mlir][tosa] Allow shift operand of tosa::MulOp as
 non-constant

The shift operand of tosa::MulOp can be non-constant when the dynamic
extension is enabled. Given that checkConstantOperandMul can check the
shift operand according to the extension, we might be able to relax
the check in TosaToLinalg.

Commutative might need to be removed from MulOp to avoid the shift
operand being reordered with the other operands when the shift
operand is non-constant.
---
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td  |  1 -
 .../Conversion/TosaToLinalg/TosaToLinalg.cpp  | 56 +++++++++++++------
 .../TosaToLinalg/tosa-to-linalg-invalid.mlir  |  8 ---
 .../TosaToLinalg/tosa-to-linalg.mlir          | 11 ++++
 4 files changed, 51 insertions(+), 25 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 416df6e87b11f..7918812914735 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -983,7 +983,6 @@ def Tosa_MinimumOp : Tosa_ElementwiseOp<"minimum", [
 def Tosa_MulOp : Tosa_Op<"mul", [
     DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
                               ["inferReturnTypeComponents"]>,
-    Commutative,
     Pure]> {
   let summary = "Multiplication operator.";
 
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 0e3de067736c5..a02d6c97aa5d8 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -126,12 +126,12 @@ static Value createLinalgBodyCalculationForElementwiseOp(
   if (isa<tosa::MulOp>(op)) {
     auto shiftVal = cast<tosa::MulOp>(op).getShift();
     DenseElementsAttr shiftElem;
-    if (!matchPattern(shiftVal, m_Constant(&shiftElem))) {
-      (void)rewriter.notifyMatchFailure(op, "shift value of mul not found");
-      return nullptr;
-    }
-
-    int32_t shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
+    bool shiftIsConstant = true;
+    int32_t shift = 0;
+    if (matchPattern(shiftVal, m_Constant(&shiftElem)))
+      shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
+    else
+      shiftIsConstant = false;
 
     if (isa<FloatType>(elementTy)) {
       if (shift != 0) {
@@ -147,23 +147,24 @@ static Value createLinalgBodyCalculationForElementwiseOp(
       Value a = args[0];
       Value b = args[1];
 
-      if (shift > 0) {
-        auto shiftConst =
-            arith::ConstantIntOp::create(rewriter, loc, shift, /*bitwidth=*/8);
+      if (shift > 0 || !shiftIsConstant) {
+        Value shiftConst;
+        if (shiftIsConstant)
+          shiftConst =
+              rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
+
         if (!a.getType().isInteger(32))
           a = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), a);
 
         if (!b.getType().isInteger(32))
           b = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), b);
 
+        auto shiftAmount = shiftIsConstant ? shiftConst : args[2];
         auto result = tosa::ApplyScaleOp::create(
-            rewriter, loc, rewriter.getI32Type(), a, b, shiftConst,
+            rewriter, loc, rewriter.getI32Type(), a, b, shiftAmount,
             rewriter.getStringAttr("SINGLE_ROUND"));
 
-        if (elementTy.isInteger(32))
-          return result;
-
-        return arith::TruncIOp::create(rewriter, loc, elementTy, result);
+        return result;
       }
 
       int aWidth = a.getType().getIntOrFloatBitWidth();
@@ -909,6 +910,20 @@ static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
   return operand;
 }
 
+static bool hasDynamicDimensions(ValueRange operands) {
+  for (auto operand : operands) {
+    auto rankedTensorType = cast_or_null<RankedTensorType>(operand.getType());
+    if (!rankedTensorType)
+      continue;
+    int64_t rank = rankedTensorType.getRank();
+    for (auto dim : llvm::seq<int64_t>(0, rank)) {
+      if (rankedTensorType.isDynamicDim(dim))
+        return true;
+    }
+  }
+  return false;
+}
+
 static SmallVector<Value>
 broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
                            IndexPool &indexPool, ValueRange operands,
@@ -918,6 +933,9 @@ broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
   if (operands.size() == 1)
     return operands;
 
+  if (!hasDynamicDimensions(operands))
+    return operands;
+
   // Broadcast dynamic dimensions operand by operand
   return llvm::map_to_vector(operands, [&](Value operand) {
     return broadcastDynamicDimensions(rewriter, loc, indexPool, operand,
@@ -990,8 +1008,14 @@ emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
 static ValueRange getBroadcastableOperands(Operation *operation,
                                            ValueRange operands) {
   // Shift cannot broadcast
-  if (isa<tosa::MulOp>(operation))
-    return operands.take_front(2);
+  if (isa<tosa::MulOp>(operation)) {
+    DenseElementsAttr shiftElems;
+    // Shift cannot broadcast when it is constant
+    if (matchPattern(operation->getOperand(2), m_Constant(&shiftElems)))
+      return operands.take_front(2);
+    else
+      return operands.take_front(3);
+  }
   // Input1_zp and output_zp cannot broadcast
   if (isa<tosa::NegateOp>(operation))
     return operands.take_front(1);
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index 69d8471df8032..d00846a4c3e02 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -73,11 +73,3 @@ func.func @unranked_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>)
   %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<*xf32>
   return %0 : tensor<*xf32>
 }
-
-// -----
-
-func.func @mul_no_const_shift(%arg0: tensor<2x3xi32>, %arg1: tensor<2x3xi32>, %arg2: tensor<1xi8>) -> tensor<2x3xi32> {
-  // expected-error at +1 {{failed to legalize operation 'tosa.mul'}}
-  %0 = tosa.mul %arg0, %arg1, %arg2 : (tensor<2x3xi32>, tensor<2x3xi32>, tensor<1xi8>) -> tensor<2x3xi32>
-  return %0 : tensor<2x3xi32>
-}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index fb912e49ff920..aee0caa91043d 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -2471,3 +2471,14 @@ func.func @test_0d_input(%arg0: tensor<i32>) -> () {
 
   return
 }
+
+// -----
+
+// CHECK-LABEL: @mul_no_const_shift
+func.func @mul_no_const_shift(%arg0: tensor<2x3xi32>, %arg1: tensor<2x3xi32>, %arg2: tensor<1xi8>) -> tensor<2x3xi32> {
+  // CHECK: linalg.generic
+  // CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i8, %[[OUT:.*]]: i32):
+  // CHECK: tosa.apply_scale %[[ARG0]], %[[ARG1]], %[[ARG2]]
+  %0 = tosa.mul %arg0, %arg1, %arg2 : (tensor<2x3xi32>, tensor<2x3xi32>, tensor<1xi8>) -> tensor<2x3xi32>
+  return %0 : tensor<2x3xi32>
+}

>From 04a803a3e98e1d76b4ef12f3c9d80312325c5e64 Mon Sep 17 00:00:00 2001
From: Shiva Chen <shiva.chen at imgtec.com>
Date: Thu, 28 Aug 2025 03:07:10 +0100
Subject: [PATCH 2/3] Remove hasDynamicDimensions

---
 .../Conversion/TosaToLinalg/TosaToLinalg.cpp  | 25 ++++++++-----------
 1 file changed, 10 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index a02d6c97aa5d8..73046e0da361a 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -910,20 +910,6 @@ static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
   return operand;
 }
 
-static bool hasDynamicDimensions(ValueRange operands) {
-  for (auto operand : operands) {
-    auto rankedTensorType = cast_or_null<RankedTensorType>(operand.getType());
-    if (!rankedTensorType)
-      continue;
-    int64_t rank = rankedTensorType.getRank();
-    for (auto dim : llvm::seq<int64_t>(0, rank)) {
-      if (rankedTensorType.isDynamicDim(dim))
-        return true;
-    }
-  }
-  return false;
-}
-
 static SmallVector<Value>
 broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
                            IndexPool &indexPool, ValueRange operands,
@@ -933,7 +919,16 @@ broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
   if (operands.size() == 1)
     return operands;
 
-  if (!hasDynamicDimensions(operands))
+  // No need to broadcast for static shape
+  bool hasDynamic = false;
+  for (auto op : operands) {
+    const auto tType = dyn_cast<RankedTensorType>(op.getType());
+    if (tType && !tType.hasStaticShape()) {
+      hasDynamic = true;
+      break;
+    }
+  }
+  if (!hasDynamic)
     return operands;
 
   // Broadcast dynamic dimensions operand by operand

>From 76cbfe5962f2a75305a2931fc6a4e6dee0bf2851 Mon Sep 17 00:00:00 2001
From: Shiva Chen <shiva.chen at imgtec.com>
Date: Thu, 28 Aug 2025 03:08:46 +0100
Subject: [PATCH 3/3] Add back Commutative

---
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 7918812914735..416df6e87b11f 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -983,6 +983,7 @@ def Tosa_MinimumOp : Tosa_ElementwiseOp<"minimum", [
 def Tosa_MulOp : Tosa_Op<"mul", [
     DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
                               ["inferReturnTypeComponents"]>,
+    Commutative,
     Pure]> {
   let summary = "Multiplication operator.";
 



More information about the Mlir-commits mailing list