[Mlir-commits] [mlir] [mlir] Add direct vectorization lowering for `tensor.pack` ops (PR #78660)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jan 18 18:17:08 PST 2024


https://github.com/Max191 created https://github.com/llvm/llvm-project/pull/78660

This PR adds a direct vectorization lowering of `tensor.pack` into the sequence `mask(vector.transfer_read)` -> `vector.shape_cast` -> `vector.transpose` -> `vector.transfer_write`.

>From 67ca50d8ad974eea83d5fc099c56839f7ae5ed32 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Sat, 6 Jan 2024 00:44:18 +0000
Subject: [PATCH 1/4] Revert "[mlir][tosa] Move lowering of `tosa.transpose` to
 `tosa-to-linalg-named` (#75738)"

This reverts commit 11ac97c67a9315dce48fd938d68ae991e3559f10.
---
 .../Conversion/TosaToLinalg/TosaToLinalg.cpp  | 52 +++++++++++-
 .../TosaToLinalg/TosaToLinalgNamed.cpp        | 30 +------
 .../TosaToLinalg/TosaToLinalgNamedPass.cpp    |  1 -
 .../TosaToLinalg/tosa-to-linalg-named.mlir    | 75 ++--------------
 .../TosaToLinalg/tosa-to-linalg-pipeline.mlir | 10 +++
 .../TosaToLinalg/tosa-to-linalg.mlir          | 85 +++++++++++++++++++
 6 files changed, 154 insertions(+), 99 deletions(-)

diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 678081837b81382..b4f18d57404cc29 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1052,6 +1052,55 @@ class PointwiseConverter : public OpRewritePattern<SrcOp> {
   }
 };
 
+class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
+public:
+  using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::TransposeOp op,
+                                PatternRewriter &rewriter) const final {
+    DenseIntElementsAttr perms;
+    if (!matchPattern(op.getPerms(), m_Constant(&perms))) {
+      return rewriter.notifyMatchFailure(op, "unmatched permutation tensor");
+    }
+
+    auto loc = op.getLoc();
+    auto input = op->getOperand(0);
+    auto resultTy = cast<ShapedType>(op.getType());
+
+    SmallVector<Value> dynDims;
+    dynDims.resize(cast<ShapedType>(op->getResult(0).getType()).getRank());
+
+    SmallVector<AffineExpr, 2> inputExprs;
+    inputExprs.resize(resultTy.getRank());
+    for (const auto &permutation : llvm::enumerate(perms.getValues<APInt>())) {
+      auto index = permutation.index();
+      auto value = permutation.value().getZExtValue();
+      if (!resultTy.hasRank() || resultTy.isDynamicDim(index)) {
+        dynDims[index] = rewriter.create<tensor::DimOp>(loc, input, value);
+      }
+      inputExprs[value] = rewriter.getAffineDimExpr(index);
+    }
+
+    SmallVector<Value> filteredDims = condenseValues(dynDims);
+
+    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
+        loc, resultTy.getShape(), resultTy.getElementType(), filteredDims);
+
+    SmallVector<AffineMap, 2> affineMaps = {
+        AffineMap::get(resultTy.getRank(), /*symbolCount=*/0, inputExprs,
+                       rewriter.getContext()),
+        rewriter.getMultiDimIdentityMap(resultTy.getRank())};
+
+    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
+        op, resultTy, op.getInput1(), ValueRange{emptyTensor}, affineMaps,
+        getNParallelLoopsAttrs(resultTy.getRank()),
+        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+          nestedBuilder.create<linalg::YieldOp>(loc, *args.begin());
+        });
+    return success();
+  }
+};
+
 class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
 public:
   using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;
@@ -2408,6 +2457,7 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
       ReverseConverter,
       RFFT2dConverter,
       TableConverter,
-      TileConverter>(patterns->getContext());
+      TileConverter,
+      TransposeConverter>(patterns->getContext());
   // clang-format on
 }
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 8dc2d27bd545ff8..b3fbc7dd0b22c19 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -19,7 +19,6 @@
 #include "mlir/Dialect/Tensor/Utils/Utils.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
-#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
@@ -985,31 +984,6 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
   }
 };
 
-class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
-public:
-  using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tosa::TransposeOp op,
-                                PatternRewriter &rewriter) const final {
-    SmallVector<int64_t> constantPerms;
-    if (failed(op.getConstantPerms(constantPerms)))
-      return failure();
-
-    Location loc = op.getLoc();
-    // The verifier should have made sure we have a valid permutation tensor.
-    assert(isPermutationVector(constantPerms) && "Expected valid permutation");
-    SmallVector<OpFoldResult> inputSizes =
-        tensor::getMixedSizes(rewriter, loc, op.getInput1());
-    auto permutedSizes =
-        applyPermutation<OpFoldResult>(inputSizes, constantPerms);
-
-    auto permutedInit = rewriter.create<tensor::EmptyOp>(
-        loc, permutedSizes, op.getInput1().getType().getElementType());
-    rewriter.replaceOpWithNewOp<linalg::TransposeOp>(
-        op, op.getInput1(), permutedInit, constantPerms);
-    return success();
-  }
-};
 } // namespace
 
 void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
@@ -1030,8 +1004,6 @@ void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
       MatMulConverter,
       MaxPool2dConverter,
       AvgPool2dConverter,
-      FullyConnectedConverter,
-      TransposeConverter
-  >(patterns->getContext());
+      FullyConnectedConverter>(patterns->getContext());
   // clang-format on
 }
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
index 096969391e51b9d..5312dc164c26c5e 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
@@ -60,7 +60,6 @@ struct TosaToLinalgNamed
     target.addIllegalOp<tosa::AvgPool2dOp>();
     target.addIllegalOp<tosa::MatMulOp>();
     target.addIllegalOp<tosa::FullyConnectedOp>();
-    target.addIllegalOp<tosa::TransposeOp>();
 
     target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
 
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 6616ea7cf699fa5..aa010e759a0f201 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -88,8 +88,7 @@ func.func @matmul_dyn_output(%arg0: tensor<1x1x8xf32>, %arg1: tensor<1x8x1xf32>)
 // CHECK-LABEL: @fully_connected
 func.func @fully_connected(%arg0: tensor<5x3xf32>, %arg1: tensor<6x3xf32>, %arg2: tensor<6xf32>) -> (tensor<5x6xf32>) {
   // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 0]> : tensor<2xi64>
-  // CHECK: %[[TRANSPOSEDINIT:.+]] = tensor.empty() : tensor<3x6xf32>
-  // CHECK: %[[TRANSPOSED:.+]] = linalg.transpose ins(%arg1 : tensor<6x3xf32>) outs(%[[TRANSPOSEDINIT]] : tensor<3x6xf32>) permutation = [1, 0]
+  // CHECK: %[[TRANSPOSED:.+]] = tosa.transpose %arg1, %[[PERM]] : (tensor<6x3xf32>, tensor<2xi64>) -> tensor<3x6xf32>
   // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<5x6xf32>
 
   // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg2 : tensor<6xf32>) outs(%[[INIT]] : tensor<5x6xf32>) {
@@ -111,7 +110,7 @@ func.func @fully_connected(%arg0: tensor<5x3xf32>, %arg1: tensor<6x3xf32>, %arg2
 // CHECK-LABEL: @quantized_fully_connected
 func.func @quantized_fully_connected(%arg0: tensor<5x3xi8>, %arg1: tensor<6x3xi8>, %arg2: tensor<6xi32>) -> (tensor<5x6xi32>) {
   // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 0]> : tensor<2xi64>
-  // CHECK: %[[TRANSPOSE:.+]] =  linalg.transpose ins(%arg1 : tensor<6x3xi8>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x6xi8>) permutation = [1, 0]
+  // CHECK: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[PERM]] : (tensor<6x3xi8>, tensor<2xi64>) -> tensor<3x6xi8>
   // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<5x6xi32>
 
   // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg2 : tensor<6xi32>) outs(%[[INIT]] : tensor<5x6xi32>) {
@@ -137,7 +136,7 @@ func.func @fully_connected_dyn(%arg0: tensor<?x3xf32>, %arg1: tensor<6x3xf32>, %
   // CHECK: %[[C0:.+]] = arith.constant 0 : index
   // CHECK: %[[DIM0:.+]] = tensor.dim %arg0, %c0 : tensor<?x3xf32>
   // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 0]> : tensor<2xi64>
-  // CHECK: %[[TRANSPOSED:.+]] = linalg.transpose ins(%arg1 : tensor<6x3xf32>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x6xf32>) permutation = [1, 0]
+  // CHECK: %[[TRANSPOSED:.+]] = tosa.transpose %arg1, %[[PERM]] : (tensor<6x3xf32>, tensor<2xi64>) -> tensor<3x6xf32>
   // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM0]]) : tensor<?x6xf32>
 
   // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg2 : tensor<6xf32>) outs(%[[INIT]] : tensor<?x6xf32>) {
@@ -378,7 +377,7 @@ func.func @avg_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>)
 // CHECK-LABEL: @conv2d_i8
 func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi8>, %bias: tensor<28xi8>) -> () {
   // HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64>
-  // HWCF: %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<28x1x1x27xi8>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<1x1x27x28xi8>) permutation = [1, 2, 3, 0]
+  // HWCF: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[TRANSPOSE_DIMS]] : (tensor<28x1x1x27xi8>, tensor<4xi64>) -> tensor<1x1x27x28xi8>
   // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x45x40x28xi32>
   // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xi8>) outs(%[[INIT]] : tensor<1x45x40x28xi32>) {
   // CHECK:   arith.extsi
@@ -399,7 +398,7 @@ func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi
 // CHECK-LABEL: @conv2d_f32
 func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
   // HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64>
-  // HWCF: %[[TRANSPOSE:.+]] =  linalg.transpose ins(%arg1 : tensor<28x3x3x27xf32>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x3x27x28xf32>) permutation = [1, 2, 3, 0]
+  // HWCF: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[TRANSPOSE_DIMS]] : (tensor<28x3x3x27xf32>, tensor<4xi64>) -> tensor<3x3x27x28xf32>
 
   // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x45x40x28xf32>
   // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xf32>) outs(%[[INIT]] : tensor<1x45x40x28xf32>) {
@@ -678,7 +677,7 @@ func.func @depthwise_conv2d_dyn_w_h(%arg0: tensor<2x?x?x3xf32>, %arg1: tensor<3x
 // CHECK-LABEL: @conv3d_f32
 func.func @conv3d_f32(%input: tensor<1x49x48x47x27xf32>, %weights: tensor<28x3x4x5x27xf32>, %bias: tensor<28xf32>) -> () {
   // CHECK-DAG:  %[[PERMS:.+]] = arith.constant dense<[1, 2, 3, 4, 0]>
-  // CHECK-DAG:  %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<28x3x4x5x27xf32>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x4x5x27x28xf32>) permutation = [1, 2, 3, 4, 0]
+  // CHECK-DAG:  %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[PERMS]]
   // CHECK-DAG:  %[[INIT:.+]] = tensor.empty() : tensor<1x47x45x43x28xf32>
   // CHECK:      %[[BROADCAST:.+]] = linalg.generic
   // CHECK-SAME: {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
@@ -702,7 +701,7 @@ func.func @conv3d_f32(%input: tensor<1x49x48x47x27xf32>, %weights: tensor<28x3x4
 // CHECK-LABEL: @conv3d_i8
 func.func @conv3d_i8(%input: tensor<1x49x48x47x27xi8>, %weights: tensor<28x3x4x5x27xi8>, %bias: tensor<28xi32>) -> () {
   // CHECK-DAG:  %[[PERMS:.+]] = arith.constant dense<[1, 2, 3, 4, 0]>
-  // CHECK-DAG:  %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<28x3x4x5x27xi8>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x4x5x27x28xi8>) permutation = [1, 2, 3, 4, 0]
+  // CHECK-DAG:  %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[PERMS]]
   // CHECK-DAG:  %[[INIT:.+]] = tensor.empty() : tensor<1x47x45x43x28xi32>
   // CHECK:      %[[BROADCAST:.+]] = linalg.generic
   // CHECK-SAME: {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
@@ -721,63 +720,3 @@ func.func @conv3d_i8(%input: tensor<1x49x48x47x27xi8>, %weights: tensor<28x3x4x5
   %0 = tosa.conv3d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -128, weight_zp = 42>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xi8>, tensor<28x3x4x5x27xi8>, tensor<28xi32>)  -> tensor<1x47x45x43x28xi32>
   return
 }
-
-// -----
-
-// CHECK-LABEL: @test_transpose
-// CHECK-SAME: (%[[ARG0:.+]]: tensor<1x2x3xi32>)
-func.func @test_transpose(%arg0: tensor<1x2x3xi32>) -> () {
-  %0 = arith.constant dense<[1, 2, 0]> : tensor<3xi32>
-  // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<2x3x1xi32>
-  // CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[ARG0]] : tensor<1x2x3xi32>) outs(%[[INIT]] : tensor<2x3x1xi32>) permutation = [1, 2, 0]
-  %1 = tosa.transpose %arg0, %0 : (tensor<1x2x3xi32>, tensor<3xi32>)  -> tensor<2x3x1xi32>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: @test_transpose_dyn
-// CHECK-SAME: (%[[ARG0:.+]]: tensor<1x?x3x4xi32>)
-func.func @test_transpose_dyn(%arg0: tensor<1x?x3x4xi32>) -> () {
-  %0 = arith.constant dense<[1, 3, 0, 2]> : tensor<4xi32>
-  // CHECK: %[[C1:.+]] = arith.constant 1
-  // CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C1]]
-  // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM]]) : tensor<?x4x1x3xi32>
-  // CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[ARG0]] : tensor<1x?x3x4xi32>) outs(%[[INIT]] : tensor<?x4x1x3xi32>) permutation = [1, 3, 0, 2]
-  %1 = tosa.transpose %arg0, %0 : (tensor<1x?x3x4xi32>, tensor<4xi32>)  -> tensor<?x4x1x3xi32>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: @test_transpose_dyn_multiple_2d
-// CHECK-SAME: (%[[ARG0:.+]]: tensor<?x?xf32>)
-func.func @test_transpose_dyn_multiple_2d(%arg0: tensor<?x?xf32>) -> () {
-  %0 = arith.constant dense<[1, 0]> : tensor<2xi32>
-  // CHECK-DAG: %[[C0:.+]] = arith.constant 0
-  // CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
-  // CHECK-DAG: %[[C1:.+]] = arith.constant 1
-  // CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
-  // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM1]], %[[DIM0]])
-  // CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[ARG0]] : tensor<?x?xf32>) outs(%[[INIT]] : tensor<?x?xf32>) permutation = [1, 0]
-  %1 = tosa.transpose %arg0, %0 : (tensor<?x?xf32>, tensor<2xi32>)  -> tensor<?x?xf32>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: @test_transpose_dyn_multiple_3d
-// CHECK-SAME: (%[[ARG0:.+]]: tensor<?x?x?xf32>)
-func.func @test_transpose_dyn_multiple_3d(%arg0: tensor<?x?x?xf32>) {
-  %0 = arith.constant dense<[2, 0, 1]> : tensor<3xi32>
-  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK-DAG: %[[DIM0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32>
-  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-  // CHECK-DAG: %[[DIM1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
-  // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-  // CHECK-DAG: %[[DIM2:.*]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?xf32>
-  // CHECK: %[[INIT:.*]] = tensor.empty(%[[DIM2]], %[[DIM0]], %[[DIM1]]) : tensor<?x?x?xf32>
-  // CHECK: %[[TRANSPOSE:.*]] = linalg.transpose ins(%[[ARG0]] : tensor<?x?x?xf32>) outs(%[[INIT]] : tensor<?x?x?xf32>) permutation = [2, 0, 1]
-  %1 = "tosa.transpose"(%arg0, %0) : (tensor<?x?x?xf32>, tensor<3xi32>) -> tensor<?x?x?xf32>
-  return
-}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
index c2bbfd5130ebcd0..e009cab4b1bf5bd 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
@@ -38,3 +38,13 @@ func.func @avg_pool2d_with_unsupported_quant_type(%arg0: tensor<1x7x7x9x!quant.u
   %0 = "tosa.avg_pool2d"(%arg0) {acc_type = i32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>
   return %0 : tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>
 }
+
+// -----
+
+// Check that --tosa-validate=strict-op-spec-alignment does not kick in because tosa-to-linalg-named comes before tosa-validate.
+// This would have failed tosa strict-op-spec-alignment because the perms of the transpose are not constant,
+// but tosa.transpose is lowered by the tosa-to-linalg-named pass, which runs earlier than the tosa-validate pass in the pipeline.
+func.func @test_transpose_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3xi32>) -> tensor<3x13x21xf32> {
+  %0 = tosa.transpose %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<3x13x21xf32>
+  return %0 : tensor<3x13x21xf32>
+}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 3931e454da2e228..3672e6214635133 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -820,6 +820,7 @@ func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
   return
 }
 
+
 // -----
 
 // CHECK-LABEL: @test_identity
@@ -835,6 +836,90 @@ func.func @test_identity(%arg0: tensor<1xf32>, %arg1: tensor<1xi32>) -> (tensor<
 
 // -----
 
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK-LABEL: @test_transpose
+// CHECK-SAME: ([[ARG0:%.+]]: tensor<1x2x3xi32>)
+func.func @test_transpose(%arg0: tensor<1x2x3xi32>) -> () {
+  %0 = arith.constant dense<[1, 2, 0]> : tensor<3xi32>
+  // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2x3x1xi32>
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel"]} ins([[ARG0]] : tensor<1x2x3xi32>) outs([[OUT:%.+]] : tensor<2x3x1xi32>)
+  // CHECK: ^bb0([[ARG1:%.+]]: i32, [[ARG2:%.+]]: i32)
+  // CHECK:   linalg.yield [[ARG1]]
+  // CHECK: }
+  %1 = tosa.transpose %arg0, %0 : (tensor<1x2x3xi32>, tensor<3xi32>)  -> tensor<2x3x1xi32>
+  return
+}
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d0, d3, d1)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+// CHECK-LABEL: @test_transpose_dyn
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<1x?x3x4xi32>)
+func.func @test_transpose_dyn(%arg0: tensor<1x?x3x4xi32>) -> () {
+  %0 = arith.constant dense<[1, 3, 0, 2]> : tensor<4xi32>
+  // CHECK: %[[C1:.+]] = arith.constant 1
+  // CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+  // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM]]) : tensor<?x4x1x3xi32>
+  // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<1x?x3x4xi32>) outs([[OUT:%.+]] : tensor<?x4x1x3xi32>)
+  // CHECK: ^bb0([[ARG1:%.+]]: i32, [[ARG2:%.+]]: i32)
+  // CHECK:   linalg.yield [[ARG1]]
+  // CHECK: }
+  %1 = tosa.transpose %arg0, %0 : (tensor<1x?x3x4xi32>, tensor<4xi32>)  -> tensor<?x4x1x3xi32>
+  return
+}
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d1, d0)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+
+// CHECK-LABEL: @test_transpose_dyn_multiple_2d
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<?x?xf32>)
+func.func @test_transpose_dyn_multiple_2d(%arg0: tensor<?x?xf32>) -> () {
+  %0 = arith.constant dense<[1, 0]> : tensor<2xi32>
+  // CHECK-DAG: %[[C0:.+]] = arith.constant 0
+  // CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+  // CHECK-DAG: %[[C1:.+]] = arith.constant 1
+  // CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+  // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM1]], %[[DIM0]])
+  // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x?xf32>) outs([[OUT:%.+]] : tensor<?x?xf32>)
+  // CHECK: ^bb0([[ARG1:%.+]]: f32, [[ARG2:%.+]]: f32)
+  // CHECK:   linalg.yield [[ARG1]]
+  // CHECK: }
+  %1 = tosa.transpose %arg0, %0 : (tensor<?x?xf32>, tensor<2xi32>)  -> tensor<?x?xf32>
+  return
+}
+
+// -----
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK-LABEL: @test_transpose_dyn_multiple_3d
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<?x?x?xf32>)
+func.func @test_transpose_dyn_multiple_3d(%arg0: tensor<?x?x?xf32>) {
+  %0 = arith.constant dense<[2, 0, 1]> : tensor<3xi32>
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[DIM0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32>
+  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK-DAG: %[[DIM1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
+  // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+  // CHECK-DAG: %[[DIM2:.*]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?xf32>
+  // CHECK: %[[INIT:.*]] = tensor.empty(%[[DIM2]], %[[DIM0]], %[[DIM1]]) : tensor<?x?x?xf32>
+  // CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x?x?xf32>) outs(%[[INIT]] : tensor<?x?x?xf32>) {
+  // CHECK: ^bb0(%[[IN0:.*]]: f32, %[[OUT0:.*]]: f32):
+  // CHECK:   linalg.yield %[[IN0]] : f32
+  // CHECK: } -> tensor<?x?x?xf32>
+  %1 = "tosa.transpose"(%arg0, %0) : (tensor<?x?x?xf32>, tensor<3xi32>) -> tensor<?x?x?xf32>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @reduce_float
 // CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xf32>
 func.func @reduce_float(%arg0: tensor<5x4xf32>) -> () {

>From 1b52507f4a182e8c99ba9b872e0528ca926ead89 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Fri, 5 Jan 2024 13:50:50 -0500
Subject: [PATCH 2/4] [mlir] Add vectorization support for tensor.pack

---
 .../TransformOps/LinalgTransformOps.cpp       |   2 +-
 .../Linalg/Transforms/Vectorization.cpp       | 151 ++++++++++++++++++
 2 files changed, 152 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 5254aac976f462d..2e58eb3376a1c8e 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3134,7 +3134,7 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
 
   // TODO: Check that the correct number of vectorSizes was provided.
   for (Operation *target : targets) {
-    if (!isa<linalg::LinalgOp, tensor::PadOp>(target)) {
+    if (!isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp>(target)) {
       return mlir::emitSilenceableFailure(target->getLoc())
              << "Unsupported Op, cannot vectorize";
     }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 5d99951ef09a92b..056b2739ab3973f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -19,10 +19,14 @@
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
 #include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/RegionUtils.h"
@@ -30,7 +34,9 @@
 #include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/ADT/iterator_range.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
 #include <optional>
 #include <type_traits>
@@ -1393,6 +1399,121 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
   return success();
 }
 
+/// Given a tensor::PackOp, return the permutation from the "tiled"
+/// shape to the "packed" shape, defined as follows:
+/// The "packed" shape is the same as the `dest` shape of the pack op.
+/// The "tiled" shape is a permutation of the `dest` shape such that
+/// each outer dimension is in the original `source` order, and each
+/// inner_tile dimension immediately follows its corresponding outer
+/// dimension.
+/// For example, for the following tensor.pack:
+/// ```mlir
+/// %pack = tensor.pack %0 padding_value(%1)
+///   outer_dims_perm = [0, 2, 1]
+///   inner_dims_pos = [2, 1]
+///   inner_tiles = [16, 2]
+///   into %2 : tensor<32x8x16> -> tensor<32x1x4x16x2>
+/// ```
+/// The "packed" shape is `32x1x4x16x2`.
+/// The "tiled" shape is `32x(4x2)x(1x16)`, so the result is `[0, 3, 1, 4, 2]`.
+static SmallVector<int64_t> getTiledShapeToPackedShapePerm(tensor::PackOp packOp) {
+  auto innerTiles = packOp.getInnerTiles();
+  int64_t srcRank = packOp.getSourceRank();
+  auto innerDimsPos = packOp.getInnerDimsPos();
+  if (innerDimsPos.empty())
+    innerDimsPos = to_vector(llvm::seq<int64_t>(innerTiles.size()));
+  auto outerDimsPerm = packOp.getOuterDimsPerm();
+  if (outerDimsPerm.empty())
+    outerDimsPerm = to_vector(llvm::seq<int64_t>(srcRank));
+  auto packedIdxToTiledIdx = [&](int64_t idx) -> int64_t { 
+    int64_t srcIdx;
+    if (idx >= srcRank)
+      srcIdx = innerDimsPos[idx - srcRank];
+    else
+      srcIdx = outerDimsPerm[idx];
+    int64_t tiledIdx = srcIdx;
+    for (int64_t pos : innerDimsPos)
+      if (pos < srcIdx)
+        tiledIdx++;
+    if (idx >= srcRank)
+      tiledIdx++;
+    return tiledIdx;
+  };
+  SmallVector<int64_t> perm;
+  for (int i = 0; i < packOp.getDestRank(); i++) 
+    perm.push_back(packedIdxToTiledIdx(i));
+  return perm;
+}
+
+/// Given a tensor::PackOp, return the "tiled" `dest` shape as described
+/// above in `getTiledShapeToPackedShapePerm`.
+static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp) {
+  auto perm = getTiledShapeToPackedShapePerm(packOp);
+  auto destShape = packOp.getDestType().getShape();
+  return applyPermutation(destShape, invertPermutationVector(perm));
+}
+
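+/// Vectorize `packOp` as masked vector.transfer_read -> vector.shape_cast -> vector.transpose -> vector.transfer_write into a new tensor.empty.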
+static LogicalResult
+vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
+                       ArrayRef<int64_t> inputVectorSizes,
+                       SmallVectorImpl<Value> &newResults) {
+  auto padValue = packOp.getPaddingValue();
+  Location loc = packOp.getLoc();
+  int64_t inputRank = inputVectorSizes.size();
+  int64_t outputRank = packOp.getDestRank();
+  auto maskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
+  auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());
+
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(packOp);
+
+  ReifiedRankedShapedTypeDims reifiedReturnShapes;
+  LogicalResult status =
+      cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
+          .reifyResultShapes(rewriter, reifiedReturnShapes);
+  (void)status; // prevent unused variable warning on non-assert builds
+  assert(succeeded(status) && "failed to reify result shapes");
+  auto emptyOp = rewriter.create<tensor::EmptyOp>(loc, reifiedReturnShapes[0],
+                                                  padValue.getType());
+  SmallVector<OpFoldResult> mixedSourceDims =
+      tensor::getMixedSizes(rewriter, loc, packOp.getSource());
+  Value mask =
+      rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
+  auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
+      loc,
+      /*vectorType=*/vectorType,
+      /*source=*/packOp.getSource(),
+      /*indices=*/SmallVector<Value>(inputRank, zero),
+      /*padding=*/padValue,
+      /*inBounds=*/SmallVector<bool>(inputRank, true));
+  auto maskedOp = cast<vector::MaskOp>(
+      mlir::vector::maskOperation(rewriter, transferReadOp, mask));
+  // ShapeCast
+  auto tiledPackShape = getTiledPackShape(packOp);
+  auto tiledPackType = VectorType::get(tiledPackShape, packOp.getDestType().getElementType());
+  auto shapeCastOp = rewriter.create<vector::ShapeCastOp>(loc, tiledPackType, maskedOp->getResult(0));
+  auto tiledShapeToPackedShapePerm = getTiledShapeToPackedShapePerm(packOp);
+  auto transposeOp = rewriter.create<vector::TransposeOp>(loc, shapeCastOp->getResult(0), tiledShapeToPackedShapePerm);
+  Operation *write = rewriter.create<vector::TransferWriteOp>(
+      loc,
+      /*vector=*/transposeOp->getResult(0),
+      /*source=*/emptyOp,
+      /*indices=*/SmallVector<Value>(outputRank, zero),
+      /*inBounds=*/SmallVector<bool>(outputRank, true));
+  // bool needMaskForWrite = llvm::any_of(
+  //     llvm::zip_equal(inputVectorSizes, packOp.getResultType().getShape()),
+  //     [](auto it) { return std::get<0>(it) != std::get<1>(it); });
+  // if (needMaskForWrite) {
+  //   Value maskForWrite = rewriter.create<vector::CreateMaskOp>(
+  //       loc, maskType, reifiedReturnShapes[0]);
+  //   write = mlir::vector::maskOperation(rewriter, write, maskForWrite);
+  // }
+  newResults.push_back(write->getResult(0));
+  return success();
+}
+
 /// Vectorize a `padOp` with (1) static result type, (2) constant padding value
 /// and (3) all-zero lowPad to
 ///   `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
@@ -1585,6 +1706,30 @@ vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
   return success();
 }
 
+static LogicalResult
+vectorizePackOpPrecondition(tensor::PackOp packOp,
+                           ArrayRef<int64_t> inputVectorSizes) {
+  auto padValue = packOp.getPaddingValue();
+  if (!padValue) {
+    LDBG("pad value is not constant: " << packOp << "\n");
+    return failure();
+  }
+
+  ArrayRef<int64_t> resultTensorShape = packOp.getSourceType().getShape();
+  if (failed(isValidMaskedInputVector(resultTensorShape, inputVectorSizes)))
+    return failure();
+
+  if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
+        std::optional<int64_t> res = getConstantIntValue(v);
+        return !res.has_value();
+      })) {
+    LDBG("inner_tiles must be constant: " << packOp << "\n");
+    return failure();
+  }
+
+  return success();
+}
+
 static LogicalResult
 vectorizePadOpPrecondition(tensor::PadOp padOp,
                            ArrayRef<int64_t> inputVectorSizes) {
@@ -1644,6 +1789,9 @@ LogicalResult mlir::linalg::vectorizeOpPrecondition(
       .Case<tensor::PadOp>([&](auto padOp) {
         return vectorizePadOpPrecondition(padOp, inputVectorSizes);
       })
+      .Case<tensor::PackOp>([&](auto packOp) {
+        return vectorizePackOpPrecondition(packOp, inputVectorSizes);
+      })
       .Default([](auto) { return failure(); });
 }
 
@@ -1732,6 +1880,9 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
             return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
                                           results);
           })
+          .Case<tensor::PackOp>([&](auto packOp) {
+            return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes, results);
+          })
           .Default([](auto) { return failure(); });
 
   if (failed(vectorizeResult)) {

>From 2c1b749eeab69fe96812811c74b17718c9a840e6 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Thu, 18 Jan 2024 20:12:13 -0500
Subject: [PATCH 3/4] Support pack with no padding value

---
 .../Linalg/Transforms/Vectorization.cpp       | 22 ++++++++-----------
 1 file changed, 9 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 056b2739ab3973f..b56289b560272d0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1458,16 +1458,20 @@ static LogicalResult
 vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
                        ArrayRef<int64_t> inputVectorSizes,
                        SmallVectorImpl<Value> &newResults) {
-  auto padValue = packOp.getPaddingValue();
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(packOp);
+
   Location loc = packOp.getLoc();
+  auto padValue = packOp.getPaddingValue();
+  if (!padValue) {
+    padValue = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getZeroAttr(packOp.getSourceType().getElementType()));
+  }
   int64_t inputRank = inputVectorSizes.size();
   int64_t outputRank = packOp.getDestRank();
   auto maskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
   auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());
 
-  OpBuilder::InsertionGuard g(rewriter);
-  rewriter.setInsertionPoint(packOp);
-
   ReifiedRankedShapedTypeDims reifiedReturnShapes;
   LogicalResult status =
       cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
@@ -1502,14 +1506,6 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
       /*source=*/emptyOp,
       /*indices=*/SmallVector<Value>(outputRank, zero),
       /*inBounds=*/SmallVector<bool>(outputRank, true));
-  // bool needMaskForWrite = llvm::any_of(
-  //     llvm::zip_equal(inputVectorSizes, packOp.getResultType().getShape()),
-  //     [](auto it) { return std::get<0>(it) != std::get<1>(it); });
-  // if (needMaskForWrite) {
-  //   Value maskForWrite = rewriter.create<vector::CreateMaskOp>(
-  //       loc, maskType, reifiedReturnShapes[0]);
-  //   write = mlir::vector::maskOperation(rewriter, write, maskForWrite);
-  // }
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -1710,7 +1706,7 @@ static LogicalResult
 vectorizePackOpPrecondition(tensor::PackOp packOp,
                            ArrayRef<int64_t> inputVectorSizes) {
   auto padValue = packOp.getPaddingValue();
-  if (!padValue) {
+  if (padValue && getConstantIntValue(padValue) != std::nullopt) {
     LDBG("pad value is not constant: " << packOp << "\n");
     return failure();
   }

>From 41cd8f8cce10102b7e98fdbf53f6e99bafae716c Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Thu, 18 Jan 2024 21:11:49 -0500
Subject: [PATCH 4/4] add tests

---
 mlir/test/Dialect/Linalg/vectorization.mlir | 61 +++++++++++++++++++++
 1 file changed, 61 insertions(+)

diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 610339405d1c2c9..729596226db0821 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -501,6 +501,67 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+func.func @test_vectorize_dynamic_pack(%arg0: tensor<?x?xf32>, %arg1: tensor<4x1x16x2xf32>) -> tensor<4x1x16x2xf32> {
+  %pack = tensor.pack %arg0 inner_dims_pos = [1, 0] inner_tiles = [16, 2] into %arg1 : tensor<?x?xf32> -> tensor<4x1x16x2xf32>
+  return %pack : tensor<4x1x16x2xf32>
+} // Packs a dynamically shaped 2-D source; no padding value is given.
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [8, 16] : !transform.any_op
+    transform.yield 
+  }
+} // Transform script: match the tensor.pack op and vectorize it.
+//  CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
+//  CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+//  CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+//  CHECK-DAG: %[[d0:.*]] = tensor.dim {{.*}} %[[c0]] : tensor<?x?xf32>
+//  CHECK-DAG: %[[d1:.*]] = tensor.dim {{.*}} %[[c1]] : tensor<?x?xf32>
+//  CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<4x1x16x2xf32>
+//      CHECK: %[[mask:.*]] = vector.create_mask %[[d0]], %[[d1]] : vector<8x16xi1>
+//  CHECK-DAG: %[[c0_2:.*]] = arith.constant 0 : index
+//      CHECK: %[[masked_read:.*]] = vector.mask %[[mask]] {
+// CHECK-SAME:   vector.transfer_read %{{.*}}[%[[c0_2]], %[[c0_2]]], %[[cst]]
+// CHECK-SAME:   {in_bounds = [true, true]} : tensor<?x?xf32>, vector<8x16xf32>
+// CHECK-SAME: } : vector<8x16xi1> -> vector<8x16xf32>
+//      CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[masked_read]] : vector<8x16xf32> to vector<4x2x1x16xf32>
+//      CHECK: %[[transpose:.*]] = vector.transpose %[[shape_cast]], [0, 2, 3, 1] : vector<4x2x1x16xf32> to vector<4x1x16x2xf32>
+//      CHECK: %[[write:.*]] = vector.transfer_write %[[transpose]], %[[empty]][%[[c0_2]], %[[c0_2]], %[[c0_2]], %[[c0_2]]]
+// CHECK-SAME:   {in_bounds = [true, true, true, true]} : vector<4x1x16x2xf32>, tensor<4x1x16x2xf32>
+//      CHECK: return %[[write]] : tensor<4x1x16x2xf32>
+
+// -----
+
+func.func @test_vectorize_pack(%arg0: tensor<32x8x16xf32>, %arg1: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> {
+  %pack = tensor.pack %arg0 inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %arg1 : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
+  return %pack : tensor<32x4x1x16x2xf32>
+} // Packs a fully static 3-D source; only dims 2 and 1 are tiled.
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [32, 8, 16] : !transform.any_op
+    transform.yield 
+  }
+} // Transform script: match the tensor.pack op and vectorize it.
+//  CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
+//  CHECK-DAG: %[[c32:.*]] = arith.constant 32 : index
+//  CHECK-DAG: %[[c8:.*]] = arith.constant 8 : index
+//  CHECK-DAG: %[[c16:.*]] = arith.constant 16 : index
+//  CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<32x4x1x16x2xf32>
+//      CHECK: %[[mask:.*]] = vector.create_mask %[[c32]], %[[c8]], %[[c16]] : vector<32x8x16xi1>
+//  CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+//      CHECK: %[[masked_read:.*]] = vector.mask %[[mask]] {
+// CHECK-SAME:   vector.transfer_read %{{.*}}[%[[c0]], %[[c0]], %[[c0]]], %[[cst]]
+// CHECK-SAME:   {in_bounds = [true, true, true]} : tensor<32x8x16xf32>, vector<32x8x16xf32>
+// CHECK-SAME: } : vector<32x8x16xi1> -> vector<32x8x16xf32>
+//      CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[masked_read]] : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
+//      CHECK: %[[transpose:.*]] = vector.transpose %[[shape_cast]], [0, 1, 3, 4, 2] : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
+//      CHECK: %[[write:.*]] = vector.transfer_write %[[transpose]], %[[empty]][%[[c0]], %[[c0]], %[[c0]], %[[c0]], %[[c0]]]
+// CHECK-SAME:   {in_bounds = [true, true, true, true, true]} : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
+//      CHECK: return %[[write]] : tensor<32x4x1x16x2xf32>
+
+// -----
+
 func.func @matmul(%A: memref<?x?xf32>, %B: memref<?x?xf32>, %C: memref<?x?xf32>) {
   linalg.matmul ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
             outs(%C: memref<?x?xf32>)



More information about the Mlir-commits mailing list