[Mlir-commits] [mlir] [mlir][TosaToLinalg] Fix bugs in PointwiseConverter (PR #132526)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Mar 21 22:54:01 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-tosa

Author: Longsheng Mou (CoTinker)

<details>
<summary>Changes</summary>

- TOSA ensures pointwise op inputs have equal ranks, so the redundant rank-expansion helpers (`expandRank` and `expandInputRanks`) are removed.
- Added `getBroadcastableOperands` to prevent crashes by excluding non-broadcastable operands from broadcasting: the `shift` operand of `tosa.mul` and the `input1_zp`/`output_zp` operands of `tosa.negate`. Fixes #131294.

---
Full diff: https://github.com/llvm/llvm-project/pull/132526.diff


2 Files Affected:

- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+16-56) 
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+23-3) 


``````````diff
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 6e1e3343ac169..e18fa849e9f30 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -711,50 +711,6 @@ static Value createLinalgBodyCalculationForElementwiseOp(
   return nullptr;
 }
 
-static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor,
-                        int64_t rank) {
-  // No need to expand if we are already at the desired rank
-  auto tensorType = dyn_cast<RankedTensorType>(tensor.getType());
-  assert(tensorType && "expected a ranked tensor type");
-  int64_t tensorRank = tensorType.getRank();
-  int64_t numExtraDims = rank - tensorRank;
-  assert(numExtraDims >= 0 && "cannot expand tensor to a lower rank");
-  if (!numExtraDims)
-    return tensor;
-
-  // Compute reassociation indices
-  SmallVector<ReassociationIndices> reassociationIndices(tensorRank);
-  int64_t index = 0;
-  if (tensorRank != 0) {
-    for (index = 0; index <= numExtraDims; index++)
-      reassociationIndices[0].push_back(index);
-    for (size_t position = 1; position < reassociationIndices.size();
-         position++)
-      reassociationIndices[position].push_back(index++);
-  }
-
-  // Compute result type
-  SmallVector<int64_t> resultShape;
-  for (index = 0; index < numExtraDims; index++)
-    resultShape.push_back(1);
-  for (auto size : tensorType.getShape())
-    resultShape.push_back(size);
-  auto resultType =
-      RankedTensorType::get(resultShape, tensorType.getElementType());
-
-  // Emit 'tensor.expand_shape' op
-  return rewriter.create<tensor::ExpandShapeOp>(loc, resultType, tensor,
-                                                reassociationIndices);
-}
-
-static SmallVector<Value> expandInputRanks(PatternRewriter &rewriter,
-                                           Location loc, ValueRange operands,
-                                           int64_t rank) {
-  return llvm::map_to_vector(operands, [&](Value operand) {
-    return expandRank(rewriter, loc, operand, rank);
-  });
-}
-
 using IndexPool = DenseMap<int64_t, Value>;
 
 // Emit an 'arith.constant' op for the given index if it has not been created
@@ -1036,6 +992,17 @@ emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
   return success();
 }
 
+static ValueRange getBroadcastableOperands(Operation *operation,
+                                           ValueRange operands) {
+  // Shift cannot broadcast
+  if (isa<tosa::MulOp>(operation))
+    return operands.take_front(2);
+  // Input1_zp and output_zp cannot broadcast
+  if (isa<tosa::NegateOp>(operation))
+    return operands.take_front(1);
+  return operands;
+}
+
 static LogicalResult
 elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
                                  ConversionPatternRewriter &rewriter,
@@ -1052,19 +1019,12 @@ elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
   // Lower operation
   IndexPool indexPool;
   auto loc = operation->getLoc();
-  auto rank =
-      cast<RankedTensorType>(operation->getResultTypes().front()).getRank();
-  // For the mul op we need to avoid expanding the rank of the optional shift
-  // input.
-  auto operandsToExpand =
-      isa<tosa::MulOp>(operation) ? operands.take_front(2) : operands;
-
-  auto expandedOperands =
-      expandInputRanks(rewriter, loc, operandsToExpand, rank);
+  auto operandsToBroadcast = getBroadcastableOperands(operation, operands);
   auto [targetShape, masterOperands] =
-      computeTargetShape(rewriter, loc, indexPool, expandedOperands);
-  auto broadcastOperands = broadcastDynamicDimensions(
-      rewriter, loc, indexPool, expandedOperands, targetShape, masterOperands);
+      computeTargetShape(rewriter, loc, indexPool, operandsToBroadcast);
+  auto broadcastOperands =
+      broadcastDynamicDimensions(rewriter, loc, indexPool, operandsToBroadcast,
+                                 targetShape, masterOperands);
   return emitElementwiseComputation(rewriter, loc, operation, broadcastOperands,
                                     targetShape, converter);
 }
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 18ce8571eeea0..9258442de5a45 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -664,7 +664,7 @@ func.func @test_simple_i32(%arg0: tensor<1xi32>, %unsigned: tensor<1xui32>, %uns
   %40 = tosa.int_div %arg0, %arg0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
-  // CHECK: ^bb0(%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32):
+  // CHECK: ^bb0(%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32):
   // CHECK: [[ZERO:%.+]] = arith.constant 0
   // CHECK: arith.subi [[ZERO]], %[[ARG1]]
   %in_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
@@ -856,7 +856,7 @@ func.func @test_bool(%arg0: tensor<1xi1>, %arg1: tensor<1xi1>) -> () {
 // CHECK-LABEL: @test_negate_quantized
 func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
   // CHECK: linalg.generic
-  // CHECK: ^bb0(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8, %[[BBARG2:.+]]: i8, %[[BBARG3:.+]]: i8
+  // CHECK: ^bb0(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8
   // CHECK: [[CNST:%.+]] = arith.constant 7
   // CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i16
   // CHECK: [[SUB:%.+]] = arith.subi [[CNST]], [[EXT]]
@@ -871,7 +871,7 @@ func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
   %0 = tosa.negate %arg0, %in_zp0, %out_zp0 : (tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1xi8>
 
   // CHECK: linalg.generic
-  // CHECK: ^bb0(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8, %[[BBARG2:.+]]: i8, %[[BBARG3:.+]]: i8
+  // CHECK: ^bb0(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8
   // CHECK: [[C_128:%.+]] = arith.constant -128
   // CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i16
   // CHECK: [[SUB:%.+]] = arith.subi [[C_128]], [[EXT]]
@@ -2317,3 +2317,23 @@ func.func @clamp_nan_ignore(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> (
 
   return
 }
+
+// -----
+
+// CHECK-LABEL: @test_0d_input
+func.func @test_0d_input(%arg0: tensor<i32>) -> () {
+  // CHECK: linalg.generic
+  // CHECK: arith.muli
+  %shift1 = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %0 = tosa.mul %arg0, %arg0, %shift1 : (tensor<i32>, tensor<i32>, tensor<1xi8>) -> tensor<i32>
+
+  // CHECK: linalg.generic
+  // CHECK: ^bb0(%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32):
+  // CHECK: [[ZERO:%.+]] = arith.constant 0
+  // CHECK: arith.subi [[ZERO]], %[[ARG1]]
+  %in_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+  %out_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+  %5 = tosa.negate %arg0, %in_zp, %out_zp : (tensor<i32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
+
+  return
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/132526


More information about the Mlir-commits mailing list