[Mlir-commits] [mlir] 73f487d - [mlir][TosaToLinalg] Fix bugs in PointwiseConverter (#132526)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Mar 26 01:33:51 PDT 2025
Author: Longsheng Mou
Date: 2025-03-26T08:33:47Z
New Revision: 73f487d31eeaebff085b12396b16432c4fed78a2
URL: https://github.com/llvm/llvm-project/commit/73f487d31eeaebff085b12396b16432c4fed78a2
DIFF: https://github.com/llvm/llvm-project/commit/73f487d31eeaebff085b12396b16432c4fed78a2.diff
LOG: [mlir][TosaToLinalg] Fix bugs in PointwiseConverter (#132526)
Added:
Modified:
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 6e1e3343ac169..e18fa849e9f30 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -711,50 +711,6 @@ static Value createLinalgBodyCalculationForElementwiseOp(
return nullptr;
}
-static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor,
- int64_t rank) {
- // No need to expand if we are already at the desired rank
- auto tensorType = dyn_cast<RankedTensorType>(tensor.getType());
- assert(tensorType && "expected a ranked tensor type");
- int64_t tensorRank = tensorType.getRank();
- int64_t numExtraDims = rank - tensorRank;
- assert(numExtraDims >= 0 && "cannot expand tensor to a lower rank");
- if (!numExtraDims)
- return tensor;
-
- // Compute reassociation indices
- SmallVector<ReassociationIndices> reassociationIndices(tensorRank);
- int64_t index = 0;
- if (tensorRank != 0) {
- for (index = 0; index <= numExtraDims; index++)
- reassociationIndices[0].push_back(index);
- for (size_t position = 1; position < reassociationIndices.size();
- position++)
- reassociationIndices[position].push_back(index++);
- }
-
- // Compute result type
- SmallVector<int64_t> resultShape;
- for (index = 0; index < numExtraDims; index++)
- resultShape.push_back(1);
- for (auto size : tensorType.getShape())
- resultShape.push_back(size);
- auto resultType =
- RankedTensorType::get(resultShape, tensorType.getElementType());
-
- // Emit 'tensor.expand_shape' op
- return rewriter.create<tensor::ExpandShapeOp>(loc, resultType, tensor,
- reassociationIndices);
-}
-
-static SmallVector<Value> expandInputRanks(PatternRewriter &rewriter,
- Location loc, ValueRange operands,
- int64_t rank) {
- return llvm::map_to_vector(operands, [&](Value operand) {
- return expandRank(rewriter, loc, operand, rank);
- });
-}
-
using IndexPool = DenseMap<int64_t, Value>;
// Emit an 'arith.constant' op for the given index if it has not been created
@@ -1036,6 +992,17 @@ emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
return success();
}
+static ValueRange getBroadcastableOperands(Operation *operation,
+ ValueRange operands) {
+ // tosa.mul: only the first two operands broadcast; the shift operand does not.
+ if (isa<tosa::MulOp>(operation))
+ return operands.take_front(2);
+ // tosa.negate: only input1 broadcasts; input1_zp and output_zp do not.
+ if (isa<tosa::NegateOp>(operation))
+ return operands.take_front(1);
+ return operands; // Any other op: every operand participates in broadcasting.
+}
+
static LogicalResult
elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
ConversionPatternRewriter &rewriter,
@@ -1052,19 +1019,12 @@ elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
// Lower operation
IndexPool indexPool;
auto loc = operation->getLoc();
- auto rank =
- cast<RankedTensorType>(operation->getResultTypes().front()).getRank();
- // For the mul op we need to avoid expanding the rank of the optional shift
- // input.
- auto operandsToExpand =
- isa<tosa::MulOp>(operation) ? operands.take_front(2) : operands;
-
- auto expandedOperands =
- expandInputRanks(rewriter, loc, operandsToExpand, rank);
+ auto operandsToBroadcast = getBroadcastableOperands(operation, operands);
auto [targetShape, masterOperands] =
- computeTargetShape(rewriter, loc, indexPool, expandedOperands);
- auto broadcastOperands = broadcastDynamicDimensions(
- rewriter, loc, indexPool, expandedOperands, targetShape, masterOperands);
+ computeTargetShape(rewriter, loc, indexPool, operandsToBroadcast);
+ auto broadcastOperands =
+ broadcastDynamicDimensions(rewriter, loc, indexPool, operandsToBroadcast,
+ targetShape, masterOperands);
return emitElementwiseComputation(rewriter, loc, operation, broadcastOperands,
targetShape, converter);
}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 18ce8571eeea0..9258442de5a45 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -664,7 +664,7 @@ func.func @test_simple_i32(%arg0: tensor<1xi32>, %unsigned: tensor<1xui32>, %uns
%40 = tosa.int_div %arg0, %arg0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
- // CHECK: ^bb0(%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32):
+ // CHECK: ^bb0(%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32):
// CHECK: [[ZERO:%.+]] = arith.constant 0
// CHECK: arith.subi [[ZERO]], %[[ARG1]]
%in_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
@@ -856,7 +856,7 @@ func.func @test_bool(%arg0: tensor<1xi1>, %arg1: tensor<1xi1>) -> () {
// CHECK-LABEL: @test_negate_quantized
func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
// CHECK: linalg.generic
- // CHECK: ^bb0(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8, %[[BBARG2:.+]]: i8, %[[BBARG3:.+]]: i8
+ // CHECK: ^bb0(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8
// CHECK: [[CNST:%.+]] = arith.constant 7
// CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i16
// CHECK: [[SUB:%.+]] = arith.subi [[CNST]], [[EXT]]
@@ -871,7 +871,7 @@ func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
%0 = tosa.negate %arg0, %in_zp0, %out_zp0 : (tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1xi8>
// CHECK: linalg.generic
- // CHECK: ^bb0(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8, %[[BBARG2:.+]]: i8, %[[BBARG3:.+]]: i8
+ // CHECK: ^bb0(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8
// CHECK: [[C_128:%.+]] = arith.constant -128
// CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i16
// CHECK: [[SUB:%.+]] = arith.subi [[C_128]], [[EXT]]
@@ -2317,3 +2317,23 @@ func.func @clamp_nan_ignore(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> (
return
}
+
+// -----
+
+// CHECK-LABEL: @test_0d_input
+func.func @test_0d_input(%arg0: tensor<i32>) -> () {
+ // CHECK: linalg.generic
+ // CHECK: arith.muli
+ %shift1 = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %0 = tosa.mul %arg0, %arg0, %shift1 : (tensor<i32>, tensor<i32>, tensor<1xi8>) -> tensor<i32>
+
+ // CHECK: linalg.generic
+ // CHECK: ^bb0(%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32):
+ // CHECK: [[ZERO:%.+]] = arith.constant 0
+ // CHECK: arith.subi [[ZERO]], %[[ARG1]]
+ %in_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+ %out_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+ %5 = tosa.negate %arg0, %in_zp, %out_zp : (tensor<i32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
+
+ return
+}
More information about the Mlir-commits mailing list