[Mlir-commits] [mlir] [mlir][tosa] Add lowering of `tosa.transpose` to `tosa-to-linalg-named` (PR #75738)
Felix Schneider
llvmlistbot at llvm.org
Wed Dec 20 14:09:42 PST 2023
https://github.com/ubfx updated https://github.com/llvm/llvm-project/pull/75738
From 8d4c185c5e9e85362317d837ece25d8828f2f349 Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Sun, 17 Dec 2023 13:30:38 +0100
Subject: [PATCH 1/2] [mlir][tosa] Add lowering of `tosa.transpose` to
`tosa-to-linalg-named`
Currently, a pattern in `tosa-to-linalg` lowers `tosa.transpose` to `linalg.generic`. This patch adds a pattern to `tosa-to-linalg-named` that lowers `tosa.transpose` to `linalg.transpose` instead. Lowering to the named linalg op has the advantage that subsequent optimization passes can easily identify transposes without having to pattern-match on `linalg.generic` ops. The `linalg.transpose` can then simply be generalized to a `linalg.generic` in a second step.
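For illustration only (a minimal sketch based on the updated CHECK lines below, not part of the patch; the value names %arg0, %perms and %init are made up), the new pattern is expected to rewrite IR along these lines, assuming the permutation operand is constant:

    %perms = arith.constant dense<[1, 0]> : tensor<2xi32>
    %0 = tosa.transpose %arg0, %perms : (tensor<6x3xf32>, tensor<2xi32>) -> tensor<3x6xf32>

into

    %init = tensor.empty() : tensor<3x6xf32>
    %0 = linalg.transpose ins(%arg0 : tensor<6x3xf32>)
                          outs(%init : tensor<3x6xf32>) permutation = [1, 0]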
---
.../TosaToLinalg/TosaToLinalgNamed.cpp | 30 ++++++++++++++++++-
.../TosaToLinalg/TosaToLinalgNamedPass.cpp | 1 +
.../TosaToLinalg/tosa-to-linalg-named.mlir | 15 +++++-----
.../TosaToLinalg/tosa-to-linalg-pipeline.mlir | 10 -------
4 files changed, 38 insertions(+), 18 deletions(-)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index b3fbc7dd0b22c19..8dc2d27bd545ff8 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -19,6 +19,7 @@
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
@@ -984,6 +985,31 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
}
};
+class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
+public:
+ using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tosa::TransposeOp op,
+ PatternRewriter &rewriter) const final {
+ SmallVector<int64_t> constantPerms;
+ if (failed(op.getConstantPerms(constantPerms)))
+ return failure();
+
+ Location loc = op.getLoc();
+ // The verifier should have made sure we have a valid permutation tensor.
+ assert(isPermutationVector(constantPerms) && "Expected valid permutation");
+ SmallVector<OpFoldResult> inputSizes =
+ tensor::getMixedSizes(rewriter, loc, op.getInput1());
+ auto permutedSizes =
+ applyPermutation<OpFoldResult>(inputSizes, constantPerms);
+
+ auto permutedInit = rewriter.create<tensor::EmptyOp>(
+ loc, permutedSizes, op.getInput1().getType().getElementType());
+ rewriter.replaceOpWithNewOp<linalg::TransposeOp>(
+ op, op.getInput1(), permutedInit, constantPerms);
+ return success();
+ }
+};
} // namespace
void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
@@ -1004,6 +1030,8 @@ void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
MatMulConverter,
MaxPool2dConverter,
AvgPool2dConverter,
- FullyConnectedConverter>(patterns->getContext());
+ FullyConnectedConverter,
+ TransposeConverter
+ >(patterns->getContext());
// clang-format on
}
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
index 5312dc164c26c5e..096969391e51b9d 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
@@ -60,6 +60,7 @@ struct TosaToLinalgNamed
target.addIllegalOp<tosa::AvgPool2dOp>();
target.addIllegalOp<tosa::MatMulOp>();
target.addIllegalOp<tosa::FullyConnectedOp>();
+ target.addIllegalOp<tosa::TransposeOp>();
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index aa010e759a0f201..4e545721041be6f 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -88,7 +88,8 @@ func.func @matmul_dyn_output(%arg0: tensor<1x1x8xf32>, %arg1: tensor<1x8x1xf32>)
// CHECK-LABEL: @fully_connected
func.func @fully_connected(%arg0: tensor<5x3xf32>, %arg1: tensor<6x3xf32>, %arg2: tensor<6xf32>) -> (tensor<5x6xf32>) {
// CHECK: %[[PERM:.+]] = arith.constant dense<[1, 0]> : tensor<2xi64>
- // CHECK: %[[TRANSPOSED:.+]] = tosa.transpose %arg1, %[[PERM]] : (tensor<6x3xf32>, tensor<2xi64>) -> tensor<3x6xf32>
+ // CHECK: %[[TRANSPOSEDINIT:.+]] = tensor.empty() : tensor<3x6xf32>
+ // CHECK: %[[TRANSPOSED:.+]] = linalg.transpose ins(%arg1 : tensor<6x3xf32>) outs(%[[TRANSPOSEDINIT]] : tensor<3x6xf32>) permutation = [1, 0]
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<5x6xf32>
// CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg2 : tensor<6xf32>) outs(%[[INIT]] : tensor<5x6xf32>) {
@@ -110,7 +111,7 @@ func.func @fully_connected(%arg0: tensor<5x3xf32>, %arg1: tensor<6x3xf32>, %arg2
// CHECK-LABEL: @quantized_fully_connected
func.func @quantized_fully_connected(%arg0: tensor<5x3xi8>, %arg1: tensor<6x3xi8>, %arg2: tensor<6xi32>) -> (tensor<5x6xi32>) {
// CHECK: %[[PERM:.+]] = arith.constant dense<[1, 0]> : tensor<2xi64>
- // CHECK: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[PERM]] : (tensor<6x3xi8>, tensor<2xi64>) -> tensor<3x6xi8>
+ // CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<6x3xi8>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x6xi8>) permutation = [1, 0]
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<5x6xi32>
// CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg2 : tensor<6xi32>) outs(%[[INIT]] : tensor<5x6xi32>) {
@@ -136,7 +137,7 @@ func.func @fully_connected_dyn(%arg0: tensor<?x3xf32>, %arg1: tensor<6x3xf32>, %
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[DIM0:.+]] = tensor.dim %arg0, %c0 : tensor<?x3xf32>
// CHECK: %[[PERM:.+]] = arith.constant dense<[1, 0]> : tensor<2xi64>
- // CHECK: %[[TRANSPOSED:.+]] = tosa.transpose %arg1, %[[PERM]] : (tensor<6x3xf32>, tensor<2xi64>) -> tensor<3x6xf32>
+ // CHECK: %[[TRANSPOSED:.+]] = linalg.transpose ins(%arg1 : tensor<6x3xf32>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x6xf32>) permutation = [1, 0]
// CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM0]]) : tensor<?x6xf32>
// CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg2 : tensor<6xf32>) outs(%[[INIT]] : tensor<?x6xf32>) {
@@ -377,7 +378,7 @@ func.func @avg_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>)
// CHECK-LABEL: @conv2d_i8
func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi8>, %bias: tensor<28xi8>) -> () {
// HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64>
- // HWCF: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[TRANSPOSE_DIMS]] : (tensor<28x1x1x27xi8>, tensor<4xi64>) -> tensor<1x1x27x28xi8>
+ // HWCF: %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<28x1x1x27xi8>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<1x1x27x28xi8>) permutation = [1, 2, 3, 0]
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x45x40x28xi32>
// CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xi8>) outs(%[[INIT]] : tensor<1x45x40x28xi32>) {
// CHECK: arith.extsi
@@ -398,7 +399,7 @@ func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi
// CHECK-LABEL: @conv2d_f32
func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
// HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64>
- // HWCF: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[TRANSPOSE_DIMS]] : (tensor<28x3x3x27xf32>, tensor<4xi64>) -> tensor<3x3x27x28xf32>
+ // HWCF: %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<28x3x3x27xf32>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x3x27x28xf32>) permutation = [1, 2, 3, 0]
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x45x40x28xf32>
// CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xf32>) outs(%[[INIT]] : tensor<1x45x40x28xf32>) {
@@ -677,7 +678,7 @@ func.func @depthwise_conv2d_dyn_w_h(%arg0: tensor<2x?x?x3xf32>, %arg1: tensor<3x
// CHECK-LABEL: @conv3d_f32
func.func @conv3d_f32(%input: tensor<1x49x48x47x27xf32>, %weights: tensor<28x3x4x5x27xf32>, %bias: tensor<28xf32>) -> () {
// CHECK-DAG: %[[PERMS:.+]] = arith.constant dense<[1, 2, 3, 4, 0]>
- // CHECK-DAG: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[PERMS]]
+ // CHECK-DAG: %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<28x3x4x5x27xf32>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x4x5x27x28xf32>) permutation = [1, 2, 3, 4, 0]
// CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<1x47x45x43x28xf32>
// CHECK: %[[BROADCAST:.+]] = linalg.generic
// CHECK-SAME: {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
@@ -701,7 +702,7 @@ func.func @conv3d_f32(%input: tensor<1x49x48x47x27xf32>, %weights: tensor<28x3x4
// CHECK-LABEL: @conv3d_i8
func.func @conv3d_i8(%input: tensor<1x49x48x47x27xi8>, %weights: tensor<28x3x4x5x27xi8>, %bias: tensor<28xi32>) -> () {
// CHECK-DAG: %[[PERMS:.+]] = arith.constant dense<[1, 2, 3, 4, 0]>
- // CHECK-DAG: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[PERMS]]
+ // CHECK-DAG: %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<28x3x4x5x27xi8>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x4x5x27x28xi8>) permutation = [1, 2, 3, 4, 0]
// CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<1x47x45x43x28xi32>
// CHECK: %[[BROADCAST:.+]] = linalg.generic
// CHECK-SAME: {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
index e009cab4b1bf5bd..c2bbfd5130ebcd0 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
@@ -38,13 +38,3 @@ func.func @avg_pool2d_with_unsupported_quant_type(%arg0: tensor<1x7x7x9x!quant.u
%0 = "tosa.avg_pool2d"(%arg0) {acc_type = i32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>
return %0 : tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>
}
-
-// -----
-
-// check that --tosa-validate=strict-op-spec-alignment does not kick in because tosa-to-linalg-named comes before tosa-validate
-// this would have failed tosa strict-op-spec-alignment because perms of transpose is not constant
-// but tosa.transpose is lowered by tosa-to-linalg-named pass which is earlier than tosa-validate pass in the pipeline
-func.func @test_transpose_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3xi32>) -> tensor<3x13x21xf32> {
- %0 = tosa.transpose %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<3x13x21xf32>
- return %0 : tensor<3x13x21xf32>
-}
From f42055fdb7a7845a26886d7417c3a4fc9e836708 Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Wed, 20 Dec 2023 23:08:45 +0100
Subject: [PATCH 2/2] Remove old TransposeConverter, migrate tests
---
.../Conversion/TosaToLinalg/TosaToLinalg.cpp | 52 +-----------
.../TosaToLinalg/tosa-to-linalg-named.mlir | 62 +++++++++++++-
.../TosaToLinalg/tosa-to-linalg.mlir | 85 -------------------
3 files changed, 62 insertions(+), 137 deletions(-)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index beed71d93d9b0d9..7c35389f1e9290d 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1052,55 +1052,6 @@ class PointwiseConverter : public OpRewritePattern<SrcOp> {
}
};
-class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
-public:
- using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(tosa::TransposeOp op,
- PatternRewriter &rewriter) const final {
- DenseIntElementsAttr perms;
- if (!matchPattern(op.getPerms(), m_Constant(&perms))) {
- return rewriter.notifyMatchFailure(op, "unmatched permutation tensor");
- }
-
- auto loc = op.getLoc();
- auto input = op->getOperand(0);
- auto resultTy = cast<ShapedType>(op.getType());
-
- SmallVector<Value> dynDims;
- dynDims.resize(cast<ShapedType>(op->getResult(0).getType()).getRank());
-
- SmallVector<AffineExpr, 2> inputExprs;
- inputExprs.resize(resultTy.getRank());
- for (const auto &permutation : llvm::enumerate(perms.getValues<APInt>())) {
- auto index = permutation.index();
- auto value = permutation.value().getZExtValue();
- if (!resultTy.hasRank() || resultTy.isDynamicDim(index)) {
- dynDims[index] = rewriter.create<tensor::DimOp>(loc, input, value);
- }
- inputExprs[value] = rewriter.getAffineDimExpr(index);
- }
-
- SmallVector<Value> filteredDims = condenseValues(dynDims);
-
- auto emptyTensor = rewriter.create<tensor::EmptyOp>(
- loc, resultTy.getShape(), resultTy.getElementType(), filteredDims);
-
- SmallVector<AffineMap, 2> affineMaps = {
- AffineMap::get(resultTy.getRank(), /*symbolCount=*/0, inputExprs,
- rewriter.getContext()),
- rewriter.getMultiDimIdentityMap(resultTy.getRank())};
-
- rewriter.replaceOpWithNewOp<linalg::GenericOp>(
- op, resultTy, op.getInput1(), ValueRange{emptyTensor}, affineMaps,
- getNParallelLoopsAttrs(resultTy.getRank()),
- [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
- nestedBuilder.create<linalg::YieldOp>(loc, *args.begin());
- });
- return success();
- }
-};
-
class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
public:
using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;
@@ -2454,7 +2405,6 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
ReverseConverter,
RFFT2dConverter,
TableConverter,
- TileConverter,
- TransposeConverter>(patterns->getContext());
+ TileConverter>(patterns->getContext());
// clang-format on
}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 4e545721041be6f..6616ea7cf699fa5 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -702,7 +702,7 @@ func.func @conv3d_f32(%input: tensor<1x49x48x47x27xf32>, %weights: tensor<28x3x4
// CHECK-LABEL: @conv3d_i8
func.func @conv3d_i8(%input: tensor<1x49x48x47x27xi8>, %weights: tensor<28x3x4x5x27xi8>, %bias: tensor<28xi32>) -> () {
// CHECK-DAG: %[[PERMS:.+]] = arith.constant dense<[1, 2, 3, 4, 0]>
- // CHECK-DAG: %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<28x3x4x5x27xi8>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x4x5x27x28xi8>) permutation = [1, 2, 3, 4, 0]
+ // CHECK-DAG: %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<28x3x4x5x27xi8>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x4x5x27x28xi8>) permutation = [1, 2, 3, 4, 0]
// CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<1x47x45x43x28xi32>
// CHECK: %[[BROADCAST:.+]] = linalg.generic
// CHECK-SAME: {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
@@ -721,3 +721,63 @@ func.func @conv3d_i8(%input: tensor<1x49x48x47x27xi8>, %weights: tensor<28x3x4x5
%0 = tosa.conv3d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -128, weight_zp = 42>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xi8>, tensor<28x3x4x5x27xi8>, tensor<28xi32>) -> tensor<1x47x45x43x28xi32>
return
}
+
+// -----
+
+// CHECK-LABEL: @test_transpose
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<1x2x3xi32>)
+func.func @test_transpose(%arg0: tensor<1x2x3xi32>) -> () {
+ %0 = arith.constant dense<[1, 2, 0]> : tensor<3xi32>
+ // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<2x3x1xi32>
+ // CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[ARG0]] : tensor<1x2x3xi32>) outs(%[[INIT]] : tensor<2x3x1xi32>) permutation = [1, 2, 0]
+ %1 = tosa.transpose %arg0, %0 : (tensor<1x2x3xi32>, tensor<3xi32>) -> tensor<2x3x1xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @test_transpose_dyn
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<1x?x3x4xi32>)
+func.func @test_transpose_dyn(%arg0: tensor<1x?x3x4xi32>) -> () {
+ %0 = arith.constant dense<[1, 3, 0, 2]> : tensor<4xi32>
+ // CHECK: %[[C1:.+]] = arith.constant 1
+ // CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+ // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM]]) : tensor<?x4x1x3xi32>
+ // CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[ARG0]] : tensor<1x?x3x4xi32>) outs(%[[INIT]] : tensor<?x4x1x3xi32>) permutation = [1, 3, 0, 2]
+ %1 = tosa.transpose %arg0, %0 : (tensor<1x?x3x4xi32>, tensor<4xi32>) -> tensor<?x4x1x3xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @test_transpose_dyn_multiple_2d
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<?x?xf32>)
+func.func @test_transpose_dyn_multiple_2d(%arg0: tensor<?x?xf32>) -> () {
+ %0 = arith.constant dense<[1, 0]> : tensor<2xi32>
+ // CHECK-DAG: %[[C0:.+]] = arith.constant 0
+ // CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+ // CHECK-DAG: %[[C1:.+]] = arith.constant 1
+ // CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+ // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM1]], %[[DIM0]])
+ // CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[ARG0]] : tensor<?x?xf32>) outs(%[[INIT]] : tensor<?x?xf32>) permutation = [1, 0]
+ %1 = tosa.transpose %arg0, %0 : (tensor<?x?xf32>, tensor<2xi32>) -> tensor<?x?xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @test_transpose_dyn_multiple_3d
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<?x?x?xf32>)
+func.func @test_transpose_dyn_multiple_3d(%arg0: tensor<?x?x?xf32>) {
+ %0 = arith.constant dense<[2, 0, 1]> : tensor<3xi32>
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[DIM0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32>
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+ // CHECK-DAG: %[[DIM1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
+ // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+ // CHECK-DAG: %[[DIM2:.*]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?xf32>
+ // CHECK: %[[INIT:.*]] = tensor.empty(%[[DIM2]], %[[DIM0]], %[[DIM1]]) : tensor<?x?x?xf32>
+ // CHECK: %[[TRANSPOSE:.*]] = linalg.transpose ins(%[[ARG0]] : tensor<?x?x?xf32>) outs(%[[INIT]] : tensor<?x?x?xf32>) permutation = [2, 0, 1]
+ %1 = "tosa.transpose"(%arg0, %0) : (tensor<?x?x?xf32>, tensor<3xi32>) -> tensor<?x?x?xf32>
+ return
+}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index e0e041139fe4dc2..535316e5a37210e 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -820,7 +820,6 @@ func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
return
}
-
// -----
// CHECK-LABEL: @test_identity
@@ -836,90 +835,6 @@ func.func @test_identity(%arg0: tensor<1xf32>, %arg1: tensor<1xi32>) -> (tensor<
// -----
-// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
-// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-
-// CHECK-LABEL: @test_transpose
-// CHECK-SAME: ([[ARG0:%.+]]: tensor<1x2x3xi32>)
-func.func @test_transpose(%arg0: tensor<1x2x3xi32>) -> () {
- %0 = arith.constant dense<[1, 2, 0]> : tensor<3xi32>
- // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2x3x1xi32>
- // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel"]} ins([[ARG0]] : tensor<1x2x3xi32>) outs([[OUT:%.+]] : tensor<2x3x1xi32>)
- // CHECK: ^bb0([[ARG1:%.+]]: i32, [[ARG2:%.+]]: i32)
- // CHECK: linalg.yield [[ARG1]]
- // CHECK: }
- %1 = tosa.transpose %arg0, %0 : (tensor<1x2x3xi32>, tensor<3xi32>) -> tensor<2x3x1xi32>
- return
-}
-
-// -----
-
-// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d0, d3, d1)>
-// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-
-// CHECK-LABEL: @test_transpose_dyn
-// CHECK-SAME: (%[[ARG0:.+]]: tensor<1x?x3x4xi32>)
-func.func @test_transpose_dyn(%arg0: tensor<1x?x3x4xi32>) -> () {
- %0 = arith.constant dense<[1, 3, 0, 2]> : tensor<4xi32>
- // CHECK: %[[C1:.+]] = arith.constant 1
- // CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C1]]
- // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM]]) : tensor<?x4x1x3xi32>
- // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<1x?x3x4xi32>) outs([[OUT:%.+]] : tensor<?x4x1x3xi32>)
- // CHECK: ^bb0([[ARG1:%.+]]: i32, [[ARG2:%.+]]: i32)
- // CHECK: linalg.yield [[ARG1]]
- // CHECK: }
- %1 = tosa.transpose %arg0, %0 : (tensor<1x?x3x4xi32>, tensor<4xi32>) -> tensor<?x4x1x3xi32>
- return
-}
-
-// -----
-
-// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d1, d0)>
-// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-
-// CHECK-LABEL: @test_transpose_dyn_multiple_2d
-// CHECK-SAME: (%[[ARG0:.+]]: tensor<?x?xf32>)
-func.func @test_transpose_dyn_multiple_2d(%arg0: tensor<?x?xf32>) -> () {
- %0 = arith.constant dense<[1, 0]> : tensor<2xi32>
- // CHECK-DAG: %[[C0:.+]] = arith.constant 0
- // CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
- // CHECK-DAG: %[[C1:.+]] = arith.constant 1
- // CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
- // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM1]], %[[DIM0]])
- // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x?xf32>) outs([[OUT:%.+]] : tensor<?x?xf32>)
- // CHECK: ^bb0([[ARG1:%.+]]: f32, [[ARG2:%.+]]: f32)
- // CHECK: linalg.yield [[ARG1]]
- // CHECK: }
- %1 = tosa.transpose %arg0, %0 : (tensor<?x?xf32>, tensor<2xi32>) -> tensor<?x?xf32>
- return
-}
-
-// -----
-
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-
-// CHECK-LABEL: @test_transpose_dyn_multiple_3d
-// CHECK-SAME: (%[[ARG0:.+]]: tensor<?x?x?xf32>)
-func.func @test_transpose_dyn_multiple_3d(%arg0: tensor<?x?x?xf32>) {
- %0 = arith.constant dense<[2, 0, 1]> : tensor<3xi32>
- // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
- // CHECK-DAG: %[[DIM0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32>
- // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
- // CHECK-DAG: %[[DIM1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
- // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
- // CHECK-DAG: %[[DIM2:.*]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?xf32>
- // CHECK: %[[INIT:.*]] = tensor.empty(%[[DIM2]], %[[DIM0]], %[[DIM1]]) : tensor<?x?x?xf32>
- // CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x?x?xf32>) outs(%[[INIT]] : tensor<?x?x?xf32>) {
- // CHECK: ^bb0(%[[IN0:.*]]: f32, %[[OUT0:.*]]: f32):
- // CHECK: linalg.yield %[[IN0]] : f32
- // CHECK: } -> tensor<?x?x?xf32>
- %1 = "tosa.transpose"(%arg0, %0) : (tensor<?x?x?xf32>, tensor<3xi32>) -> tensor<?x?x?xf32>
- return
-}
-
-// -----
-
// CHECK-LABEL: @reduce_float
// CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xf32>
func.func @reduce_float(%arg0: tensor<5x4xf32>) -> () {