[Mlir-commits] [mlir] [mlir][tosa] Allow shift operand of tosa::MulOp as non-constant (PR #155197)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Aug 24 19:37:13 PDT 2025
https://github.com/ShivaChen created https://github.com/llvm/llvm-project/pull/155197
The shift operand of tosa::MulOp can be non-constant when the dynamic extension is enabled. Given that checkConstantOperandMul can check the shift operand according to the extension, we might be able to relax the check in TosaToLinalg.
The Commutative trait of MulOp might need to be removed to avoid the shift operand being reordered with the other operands when the shift operand is non-constant.
>From d23cfcd20097e1687bcb716c507c080dce19741b Mon Sep 17 00:00:00 2001
From: Shiva Chen <shiva.chen at imgtec.com>
Date: Thu, 24 Jul 2025 05:18:16 +0100
Subject: [PATCH] [mlir][tosa] Allow shift operand of tosa::MulOp as
non-constant
The shift operand of tosa::MulOp can be non-constant when
the dynamic extension is enabled. Given that checkConstantOperandMul
can check the shift operand according to the extension, we
might be able to relax the check in TosaToLinalg.
The Commutative trait of MulOp might need to be removed to avoid
the shift operand being reordered with the other operands when the
shift operand is non-constant.
---
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 1 -
.../Conversion/TosaToLinalg/TosaToLinalg.cpp | 56 +++++++++++++------
.../TosaToLinalg/tosa-to-linalg-invalid.mlir | 8 ---
.../TosaToLinalg/tosa-to-linalg.mlir | 11 ++++
4 files changed, 51 insertions(+), 25 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 20889558be314..eed428da99192 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -983,7 +983,6 @@ def Tosa_MinimumOp : Tosa_ElementwiseOp<"minimum", [
def Tosa_MulOp : Tosa_Op<"mul", [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
- Commutative,
Pure]> {
let summary = "Multiplication operator.";
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 0e3de067736c5..a02d6c97aa5d8 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -126,12 +126,12 @@ static Value createLinalgBodyCalculationForElementwiseOp(
if (isa<tosa::MulOp>(op)) {
auto shiftVal = cast<tosa::MulOp>(op).getShift();
DenseElementsAttr shiftElem;
- if (!matchPattern(shiftVal, m_Constant(&shiftElem))) {
- (void)rewriter.notifyMatchFailure(op, "shift value of mul not found");
- return nullptr;
- }
-
- int32_t shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
+ bool shiftIsConstant = true;
+ int32_t shift = 0;
+ if (matchPattern(shiftVal, m_Constant(&shiftElem)))
+ shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
+ else
+ shiftIsConstant = false;
if (isa<FloatType>(elementTy)) {
if (shift != 0) {
@@ -147,23 +147,24 @@ static Value createLinalgBodyCalculationForElementwiseOp(
Value a = args[0];
Value b = args[1];
- if (shift > 0) {
- auto shiftConst =
- arith::ConstantIntOp::create(rewriter, loc, shift, /*bitwidth=*/8);
+ if (shift > 0 || !shiftIsConstant) {
+ Value shiftConst;
+ if (shiftIsConstant)
+ shiftConst =
+ rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
+
if (!a.getType().isInteger(32))
a = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), a);
if (!b.getType().isInteger(32))
b = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), b);
+ auto shiftAmount = shiftIsConstant ? shiftConst : args[2];
auto result = tosa::ApplyScaleOp::create(
- rewriter, loc, rewriter.getI32Type(), a, b, shiftConst,
+ rewriter, loc, rewriter.getI32Type(), a, b, shiftAmount,
rewriter.getStringAttr("SINGLE_ROUND"));
- if (elementTy.isInteger(32))
- return result;
-
- return arith::TruncIOp::create(rewriter, loc, elementTy, result);
+ return result;
}
int aWidth = a.getType().getIntOrFloatBitWidth();
@@ -909,6 +910,20 @@ static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
return operand;
}
+static bool hasDynamicDimensions(ValueRange operands) {
+ for (auto operand : operands) {
+ auto rankedTensorType = cast_or_null<RankedTensorType>(operand.getType());
+ if (!rankedTensorType)
+ continue;
+ int64_t rank = rankedTensorType.getRank();
+ for (auto dim : llvm::seq<int64_t>(0, rank)) {
+ if (rankedTensorType.isDynamicDim(dim))
+ return true;
+ }
+ }
+ return false;
+}
+
static SmallVector<Value>
broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
IndexPool &indexPool, ValueRange operands,
@@ -918,6 +933,9 @@ broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
if (operands.size() == 1)
return operands;
+ if (!hasDynamicDimensions(operands))
+ return operands;
+
// Broadcast dynamic dimensions operand by operand
return llvm::map_to_vector(operands, [&](Value operand) {
return broadcastDynamicDimensions(rewriter, loc, indexPool, operand,
@@ -990,8 +1008,14 @@ emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
static ValueRange getBroadcastableOperands(Operation *operation,
ValueRange operands) {
// Shift cannot broadcast
- if (isa<tosa::MulOp>(operation))
- return operands.take_front(2);
+ if (isa<tosa::MulOp>(operation)) {
+ DenseElementsAttr shiftElems;
+ // Shift cannot broadcast when it is constant
+ if (matchPattern(operation->getOperand(2), m_Constant(&shiftElems)))
+ return operands.take_front(2);
+ else
+ return operands.take_front(3);
+ }
// Input1_zp and output_zp cannot broadcast
if (isa<tosa::NegateOp>(operation))
return operands.take_front(1);
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index 69d8471df8032..d00846a4c3e02 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -73,11 +73,3 @@ func.func @unranked_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>)
%0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
-
-// -----
-
-func.func @mul_no_const_shift(%arg0: tensor<2x3xi32>, %arg1: tensor<2x3xi32>, %arg2: tensor<1xi8>) -> tensor<2x3xi32> {
- // expected-error at +1 {{failed to legalize operation 'tosa.mul'}}
- %0 = tosa.mul %arg0, %arg1, %arg2 : (tensor<2x3xi32>, tensor<2x3xi32>, tensor<1xi8>) -> tensor<2x3xi32>
- return %0 : tensor<2x3xi32>
-}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index fb912e49ff920..aee0caa91043d 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -2471,3 +2471,14 @@ func.func @test_0d_input(%arg0: tensor<i32>) -> () {
return
}
+
+// -----
+
+// CHECK-LABEL: @mul_no_const_shift
+func.func @mul_no_const_shift(%arg0: tensor<2x3xi32>, %arg1: tensor<2x3xi32>, %arg2: tensor<1xi8>) -> tensor<2x3xi32> {
+ // CHECK: linalg.generic
+ // CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i8, %[[OUT:.*]]: i32):
+ // CHECK: tosa.apply_scale %[[ARG0]], %[[ARG1]], %[[ARG2]]
+ %0 = tosa.mul %arg0, %arg1, %arg2 : (tensor<2x3xi32>, tensor<2x3xi32>, tensor<1xi8>) -> tensor<2x3xi32>
+ return %0 : tensor<2x3xi32>
+}
More information about the Mlir-commits
mailing list