[Mlir-commits] [mlir] [mlir][tosa] Add lowering of `tosa.transpose` to `tosa-to-linalg-named` (PR #75738)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Dec 17 04:38:37 PST 2023


llvmbot wrote:


@llvm/pr-subscribers-mlir-linalg

Author: Felix Schneider (ubfx)


Currently, `tosa.transpose` is lowered to `linalg.generic` by a pattern in `tosa-to-linalg`. This patch adds a pattern to `tosa-to-linalg-named` that lowers `tosa.transpose` to the named `linalg.transpose` op instead. Lowering to the named linalg op has the advantage that subsequent optimization passes can identify transpositions directly, without pattern matching on `linalg.generic` ops. When the generic form is needed, the `linalg.transpose` can simply be generalized to a `linalg.generic` in a second step.
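
For illustration, here is a minimal sketch of the IR before and after the new pattern fires. The function name and tensor shapes are made up for this example and do not appear in the patch:

```mlir
// Before: tosa.transpose with a constant permutation tensor.
func.func @example_transpose(%arg0: tensor<1x2x3xf32>) -> tensor<2x3x1xf32> {
  %perms = arith.constant dense<[1, 2, 0]> : tensor<3xi32>
  %0 = tosa.transpose %arg0, %perms : (tensor<1x2x3xf32>, tensor<3xi32>) -> tensor<2x3x1xf32>
  return %0 : tensor<2x3x1xf32>
}

// After tosa-to-linalg-named: an empty init tensor with the permuted shape,
// followed by the named linalg.transpose carrying the permutation directly.
func.func @example_transpose(%arg0: tensor<1x2x3xf32>) -> tensor<2x3x1xf32> {
  %init = tensor.empty() : tensor<2x3x1xf32>
  %transposed = linalg.transpose ins(%arg0 : tensor<1x2x3xf32>)
                outs(%init : tensor<2x3x1xf32>) permutation = [1, 2, 0]
  return %transposed : tensor<2x3x1xf32>
}
```

Note that the pattern only applies when the permutation operand is a compile-time constant (i.e. `getConstantPerms` succeeds); `tosa.transpose` ops with non-constant permutations are left untouched.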

---
Full diff: https://github.com/llvm/llvm-project/pull/75738.diff


4 Files Affected:

- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp (+29-1) 
- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp (+1) 
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir (+8-7) 
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir (-10) 


``````````diff
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index b3fbc7dd0b22c1..8dc2d27bd545ff 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/Tensor/Utils/Utils.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
@@ -984,6 +985,31 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
   }
 };
 
+class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
+public:
+  using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::TransposeOp op,
+                                PatternRewriter &rewriter) const final {
+    SmallVector<int64_t> constantPerms;
+    if (failed(op.getConstantPerms(constantPerms)))
+      return failure();
+
+    Location loc = op.getLoc();
+    // The verifier should have made sure we have a valid permutation tensor.
+    assert(isPermutationVector(constantPerms) && "Expected valid permutation");
+    SmallVector<OpFoldResult> inputSizes =
+        tensor::getMixedSizes(rewriter, loc, op.getInput1());
+    auto permutedSizes =
+        applyPermutation<OpFoldResult>(inputSizes, constantPerms);
+
+    auto permutedInit = rewriter.create<tensor::EmptyOp>(
+        loc, permutedSizes, op.getInput1().getType().getElementType());
+    rewriter.replaceOpWithNewOp<linalg::TransposeOp>(
+        op, op.getInput1(), permutedInit, constantPerms);
+    return success();
+  }
+};
 } // namespace
 
 void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
@@ -1004,6 +1030,8 @@ void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
       MatMulConverter,
       MaxPool2dConverter,
       AvgPool2dConverter,
-      FullyConnectedConverter>(patterns->getContext());
+      FullyConnectedConverter,
+      TransposeConverter
+  >(patterns->getContext());
   // clang-format on
 }
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
index 5312dc164c26c5..096969391e51b9 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
@@ -60,6 +60,7 @@ struct TosaToLinalgNamed
     target.addIllegalOp<tosa::AvgPool2dOp>();
     target.addIllegalOp<tosa::MatMulOp>();
     target.addIllegalOp<tosa::FullyConnectedOp>();
+    target.addIllegalOp<tosa::TransposeOp>();
 
     target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
 
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index aa010e759a0f20..4e545721041be6 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -88,7 +88,8 @@ func.func @matmul_dyn_output(%arg0: tensor<1x1x8xf32>, %arg1: tensor<1x8x1xf32>)
 // CHECK-LABEL: @fully_connected
 func.func @fully_connected(%arg0: tensor<5x3xf32>, %arg1: tensor<6x3xf32>, %arg2: tensor<6xf32>) -> (tensor<5x6xf32>) {
   // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 0]> : tensor<2xi64>
-  // CHECK: %[[TRANSPOSED:.+]] = tosa.transpose %arg1, %[[PERM]] : (tensor<6x3xf32>, tensor<2xi64>) -> tensor<3x6xf32>
+  // CHECK: %[[TRANSPOSEDINIT:.+]] = tensor.empty() : tensor<3x6xf32>
+  // CHECK: %[[TRANSPOSED:.+]] = linalg.transpose ins(%arg1 : tensor<6x3xf32>) outs(%[[TRANSPOSEDINIT]] : tensor<3x6xf32>) permutation = [1, 0]
   // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<5x6xf32>
 
   // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg2 : tensor<6xf32>) outs(%[[INIT]] : tensor<5x6xf32>) {
@@ -110,7 +111,7 @@ func.func @fully_connected(%arg0: tensor<5x3xf32>, %arg1: tensor<6x3xf32>, %arg2
 // CHECK-LABEL: @quantized_fully_connected
 func.func @quantized_fully_connected(%arg0: tensor<5x3xi8>, %arg1: tensor<6x3xi8>, %arg2: tensor<6xi32>) -> (tensor<5x6xi32>) {
   // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 0]> : tensor<2xi64>
-  // CHECK: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[PERM]] : (tensor<6x3xi8>, tensor<2xi64>) -> tensor<3x6xi8>
+  // CHECK: %[[TRANSPOSE:.+]] =  linalg.transpose ins(%arg1 : tensor<6x3xi8>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x6xi8>) permutation = [1, 0]
   // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<5x6xi32>
 
   // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg2 : tensor<6xi32>) outs(%[[INIT]] : tensor<5x6xi32>) {
@@ -136,7 +137,7 @@ func.func @fully_connected_dyn(%arg0: tensor<?x3xf32>, %arg1: tensor<6x3xf32>, %
   // CHECK: %[[C0:.+]] = arith.constant 0 : index
   // CHECK: %[[DIM0:.+]] = tensor.dim %arg0, %c0 : tensor<?x3xf32>
   // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 0]> : tensor<2xi64>
-  // CHECK: %[[TRANSPOSED:.+]] = tosa.transpose %arg1, %[[PERM]] : (tensor<6x3xf32>, tensor<2xi64>) -> tensor<3x6xf32>
+  // CHECK: %[[TRANSPOSED:.+]] = linalg.transpose ins(%arg1 : tensor<6x3xf32>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x6xf32>) permutation = [1, 0]
   // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM0]]) : tensor<?x6xf32>
 
   // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg2 : tensor<6xf32>) outs(%[[INIT]] : tensor<?x6xf32>) {
@@ -377,7 +378,7 @@ func.func @avg_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>)
 // CHECK-LABEL: @conv2d_i8
 func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi8>, %bias: tensor<28xi8>) -> () {
   // HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64>
-  // HWCF: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[TRANSPOSE_DIMS]] : (tensor<28x1x1x27xi8>, tensor<4xi64>) -> tensor<1x1x27x28xi8>
+  // HWCF: %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<28x1x1x27xi8>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<1x1x27x28xi8>) permutation = [1, 2, 3, 0]
   // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x45x40x28xi32>
   // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xi8>) outs(%[[INIT]] : tensor<1x45x40x28xi32>) {
   // CHECK:   arith.extsi
@@ -398,7 +399,7 @@ func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi
 // CHECK-LABEL: @conv2d_f32
 func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
   // HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64>
-  // HWCF: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[TRANSPOSE_DIMS]] : (tensor<28x3x3x27xf32>, tensor<4xi64>) -> tensor<3x3x27x28xf32>
+  // HWCF: %[[TRANSPOSE:.+]] =  linalg.transpose ins(%arg1 : tensor<28x3x3x27xf32>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x3x27x28xf32>) permutation = [1, 2, 3, 0]
 
   // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x45x40x28xf32>
   // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xf32>) outs(%[[INIT]] : tensor<1x45x40x28xf32>) {
@@ -677,7 +678,7 @@ func.func @depthwise_conv2d_dyn_w_h(%arg0: tensor<2x?x?x3xf32>, %arg1: tensor<3x
 // CHECK-LABEL: @conv3d_f32
 func.func @conv3d_f32(%input: tensor<1x49x48x47x27xf32>, %weights: tensor<28x3x4x5x27xf32>, %bias: tensor<28xf32>) -> () {
   // CHECK-DAG:  %[[PERMS:.+]] = arith.constant dense<[1, 2, 3, 4, 0]>
-  // CHECK-DAG:  %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[PERMS]]
+  // CHECK-DAG:  %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<28x3x4x5x27xf32>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x4x5x27x28xf32>) permutation = [1, 2, 3, 4, 0]
   // CHECK-DAG:  %[[INIT:.+]] = tensor.empty() : tensor<1x47x45x43x28xf32>
   // CHECK:      %[[BROADCAST:.+]] = linalg.generic
   // CHECK-SAME: {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
@@ -701,7 +702,7 @@ func.func @conv3d_f32(%input: tensor<1x49x48x47x27xf32>, %weights: tensor<28x3x4
 // CHECK-LABEL: @conv3d_i8
 func.func @conv3d_i8(%input: tensor<1x49x48x47x27xi8>, %weights: tensor<28x3x4x5x27xi8>, %bias: tensor<28xi32>) -> () {
   // CHECK-DAG:  %[[PERMS:.+]] = arith.constant dense<[1, 2, 3, 4, 0]>
-  // CHECK-DAG:  %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[PERMS]]
+  // CHECK-DAG:  %[[TRANSPOSE:.+]] =  linalg.transpose ins(%arg1 : tensor<28x3x4x5x27xi8>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x4x5x27x28xi8>) permutation = [1, 2, 3, 4, 0]
   // CHECK-DAG:  %[[INIT:.+]] = tensor.empty() : tensor<1x47x45x43x28xi32>
   // CHECK:      %[[BROADCAST:.+]] = linalg.generic
   // CHECK-SAME: {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
index e009cab4b1bf5b..c2bbfd5130ebcd 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
@@ -38,13 +38,3 @@ func.func @avg_pool2d_with_unsupported_quant_type(%arg0: tensor<1x7x7x9x!quant.u
   %0 = "tosa.avg_pool2d"(%arg0) {acc_type = i32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>
   return %0 : tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>
 }
-
-// -----
-
-// check that --tosa-validate=strict-op-spec-alignment does not kick in because tosa-to-linalg-named comes before tosa-validate
-// this would have failed tosa strict-op-spec-alignment because perms of transpose is not constant
-// but tosa.transpose is lowered by tosa-to-linalg-named pass which is earlier than tosa-validate pass in the pipeline
-func.func @test_transpose_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3xi32>) -> tensor<3x13x21xf32> {
-  %0 = tosa.transpose %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<3x13x21xf32>
-  return %0 : tensor<3x13x21xf32>
-}

``````````



https://github.com/llvm/llvm-project/pull/75738


More information about the Mlir-commits mailing list