[Mlir-commits] [mlir] [mlir][Vectorizer] Added support to Vectorize tensor.unpack (PR #76087)

Balaji V. Iyer. llvmlistbot at llvm.org
Tue Feb 6 15:47:04 PST 2024


https://github.com/bviyer updated https://github.com/llvm/llvm-project/pull/76087

From d97a5729f496fb603f4fb9cf2977b015d8e37ed6 Mon Sep 17 00:00:00 2001
From: "Balaji V. Iyer" <bviyer at gmail.com>
Date: Thu, 30 Nov 2023 20:39:55 +0000
Subject: [PATCH 1/4] [mlir][Vectorizer] Vectorize `tensor.unpack`

This patch allows vectorization of a `tensor.unpack` operation.
---
 .../Linalg/Transforms/Vectorization.cpp       | 96 +++++++++++++++++++
 1 file changed, 96 insertions(+)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index f9a53a8451a60..7a9846154bf34 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
@@ -1385,6 +1386,88 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
 
   return success();
 }
+// Vectorize a `tensor::UnPackOp` without outer_dims_perm to these 4 ops:
+//   vector::TransferReadOp - reads a vector from the source tensor
+//   vector::TransposeOp - transposes the read vector
+//   vector::ShapeCastOp - reshapes the data based on the target
+//   vector::TransferWriteOp - writes the result vector back
+
+static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
+                                         tensor::UnPackOp unpackOp,
+                                         ArrayRef<int64_t> inputVectorSizes,
+                                         SmallVectorImpl<Value> &newResults) {
+
+  if (!unpackOp.getOuterDimsPerm().empty()) {
+    LDBG("outer dimensions perms NYI for: " << unpackOp);
+    return failure();
+  }
+
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(unpackOp);
+
+  RankedTensorType packTensorType = unpackOp.getSourceType();
+  auto maskType =
+      VectorType::get(packTensorType.getShape(), rewriter.getI1Type());
+  auto vectorType = VectorType::get(packTensorType.getShape(),
+                                    packTensorType.getElementType());
+  ReifiedRankedShapedTypeDims reifiedRetShapes;
+  LogicalResult status =
+      cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
+          .reifyResultShapes(rewriter, reifiedRetShapes);
+  if (status.failed()) {
+    LDBG("Unable to reify result shapes of " << unpackOp);
+    return failure();
+  }
+
+  arith::ConstantIndexOp zeroOp =
+      rewriter.create<arith::ConstantIndexOp>(unpackOp->getLoc(), 0);
+  Value mask = rewriter.create<vector::CreateMaskOp>(
+      unpackOp.getLoc(), maskType,
+      tensor::getMixedSizes(rewriter, unpackOp.getLoc(), unpackOp.getSource()));
+
+  vector::TransferReadOp readOp = rewriter.create<vector::TransferReadOp>(
+      unpackOp.getLoc(), vectorType, unpackOp.getSource(),
+      SmallVector<Value>(packTensorType.getRank(), zeroOp),
+      rewriter.getMultiDimIdentityMap(packTensorType.getRank()));
+
+  vector::MaskOp maskedOp =
+      cast<vector::MaskOp>(mlir::vector::maskOperation(rewriter, readOp, mask));
+
+  int64_t numPackedDim = unpackOp.getInnerDimsPos().size();
+  int64_t packRank = packTensorType.getRank();
+  auto lastDims =
+      llvm::to_vector(llvm::seq<int64_t>(packRank - numPackedDim, packRank));
+  PackingMetadata packMetadata =
+      computePackingMetadata(packRank, unpackOp.getInnerDimsPos());
+  SmallVector<int64_t> lastDimToInsertPosPerm = computePermutationVector(
+      packRank, lastDims, packMetadata.insertPositions);
+  SmallVector<int64_t> stripMineShape(packTensorType.getShape());
+  applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
+
+  RankedTensorType stripMineTensorType =
+      RankedTensorType::Builder(packTensorType).setShape(stripMineShape);
+
+  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
+      stripMineTensorType, packMetadata.reassociations);
+  auto vecCollapsedType =
+      VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
+
+  vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
+      unpackOp.getLoc(), maskedOp.getResult(0), lastDimToInsertPosPerm);
+
+  vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
+      unpackOp.getLoc(), vecCollapsedType, transposeOp->getResult(0));
+  tensor::EmptyOp emptyOp = rewriter.create<tensor::EmptyOp>(
+      unpackOp.getLoc(), reifiedRetShapes[0], packTensorType.getElementType());
+
+  vector::TransferWriteOp writeOp = rewriter.create<vector::TransferWriteOp>(
+      unpackOp.getLoc(), shapeCastOp->getResult(0), emptyOp,
+      SmallVector<Value>(lastDims.size(), zeroOp),
+      SmallVector<bool>(lastDims.size(), true));
+
+  newResults.push_back(writeOp->getResult(0));
+  return success();
+}
 
 /// Vectorize a `padOp` with (1) static result type, (2) constant padding value
 /// and (3) all-zero lowPad to
@@ -1578,6 +1661,12 @@ vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
   return success();
 }
 
+static LogicalResult
+vectorizeUnpackOpPrecondition(tensor::UnPackOp unpackOp,
+                              ArrayRef<int64_t> inputVectorSizes) {
+  return success();
+}
+
 static LogicalResult
 vectorizePadOpPrecondition(tensor::PadOp padOp,
                            ArrayRef<int64_t> inputVectorSizes) {
@@ -1637,6 +1726,9 @@ LogicalResult mlir::linalg::vectorizeOpPrecondition(
       .Case<tensor::PadOp>([&](auto padOp) {
         return vectorizePadOpPrecondition(padOp, inputVectorSizes);
       })
+      .Case<tensor::UnPackOp>([&](auto unpackOp) {
+        return vectorizeUnpackOpPrecondition(unpackOp, inputVectorSizes);
+      })
       .Default([](auto) { return failure(); });
 }
 
@@ -1724,6 +1816,10 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
             return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
                                           results);
           })
+          .Case<tensor::UnPackOp>([&](auto unpackOp) {
+            return vectorizeAsUnpackOp(rewriter, unpackOp, inputVectorSizes,
+                                       results);
+          })
           .Default([](auto) { return failure(); });
 
   if (failed(vectorizeResult)) {

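A note on the transpose step in this patch: `lastDimToInsertPosPerm` is the permutation that moves each inner tile dimension (the trailing dims of the packed source) next to the outer dimension it was split from, so the subsequent shape cast can collapse each (outer, tile) pair into one result dimension. Below is a minimal standalone sketch of that permutation computation, written from scratch for illustration; `computePermutation` is a simplified stand-in for `computePermutationVector` from mlir/Dialect/Utils/IndexingUtils.h (which the patch actually uses, together with `computePackingMetadata` for the insert positions), not the library implementation itself.

    #include <algorithm>
    #include <cassert>
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Build a permutation `res` of size `permSize` such that
    // res[desiredPositions[i]] == positions[i]; all remaining dimensions
    // keep their original relative order.
    static std::vector<int64_t>
    computePermutation(int64_t permSize, const std::vector<int64_t> &positions,
                       const std::vector<int64_t> &desiredPositions) {
      std::vector<int64_t> res(permSize, -1);
      for (std::size_t i = 0; i < positions.size(); ++i)
        res[desiredPositions[i]] = positions[i];
      int64_t nextDim = 0;
      for (int64_t &slot : res) {
        if (slot != -1)
          continue;
        // Skip dims that were already placed at their desired positions.
        while (std::find(positions.begin(), positions.end(), nextDim) !=
               positions.end())
          ++nextDim;
        slot = nextDim++;
      }
      return res;
    }

    int main() {
      // tensor<7x1136x16x16xf32> unpacked with inner_dims_pos = [0, 1]:
      // the tile dims (2, 3) are inserted at positions (1, 3).
      std::vector<int64_t> perm = computePermutation(4, {2, 3}, {1, 3});
      assert((perm == std::vector<int64_t>{0, 2, 1, 3}));
      return 0;
    }

For the `tensor<7x1136x16x16xf32>` source this yields the `[0, 2, 1, 3]` transpose: `vector<7x1136x16x16xf32>` becomes `vector<7x16x1136x16xf32>`, which the shape cast collapses to `vector<112x18176xf32>`. Since 112 rows exceed the 100 rows of the `tensor<100x18176xf32>` destination, the unmasked transfer_write in this first revision can write out of bounds; the write mask introduced in the last patch below accounts for this.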
From a52d650fd8ff0f761fd2aa77671ffa858f3e9237 Mon Sep 17 00:00:00 2001
From: "Balaji V. Iyer" <bviyer at gmail.com>
Date: Wed, 20 Dec 2023 11:45:20 -0600
Subject: [PATCH 2/4] Enabled tensor.unpack vectorization and added test case.

---
 .../TransformOps/LinalgTransformOps.cpp       |  2 +-
 mlir/test/Dialect/Linalg/vectorization.mlir   | 25 +++++++++++++++++++
 2 files changed, 26 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 14404d837ff74..4c456a5e671f5 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3078,7 +3078,7 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
 
   // TODO: Check that the correct number of vectorSizes was provided.
   for (Operation *target : targets) {
-    if (!isa<linalg::LinalgOp, tensor::PadOp>(target)) {
+    if (!isa<linalg::LinalgOp, tensor::PadOp, tensor::UnPackOp>(target)) {
       return mlir::emitSilenceableFailure(target->getLoc())
              << "Unsupported Op, cannot vectorize";
     }
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 610339405d1c2..acf7276626a4c 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -419,6 +419,31 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// CHECK-LABEL: func @test_vectorize_unpack
+func.func @test_vectorize_unpack(%0 : tensor<7x1136x16x16xf32>) -> tensor<100x18176xf32>  {
+    // CHECK: %[[c0:.*]] = arith.constant 0 : index
+    // CHECK: %[[tr0:.*]] = vector.mask %[[m0:.*]] {{.*}} vector.transfer_read %{{.*}} : tensor<7x1136x16x16xf32>, vector<7x1136x16x16xf32> } : vector<7x1136x16x16xi1> -> vector<7x1136x16x16xf32>
+    // CHECK: %[[trans0:.*]] = vector.transpose %[[tr0]], [0, 2, 1, 3] : vector<7x1136x16x16xf32> to vector<7x16x1136x16xf32>
+    // CHECK: %[[sc0:.*]] = vector.shape_cast %[[trans0]]  : vector<7x16x1136x16xf32> to vector<112x18176xf32>
+    // CHECK: %[[empt0:.*]] = tensor.empty() : tensor<100x18176xf32>
+    // CHECK: %[[tw0:.*]] = vector.transfer_write %[[sc0]], %[[empt0]]
+    // CHECK: return %[[tw0]]
+    %8 = tensor.empty() : tensor<100x18176xf32>
+    %unpack = tensor.unpack %0 inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %8 : tensor<7x1136x16x16xf32> -> tensor<100x18176xf32>
+    return %unpack : tensor<100x18176xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.unpack"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [2, 4] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 // CHECK-LABEL: func @test_masked_vectorize_pad
 func.func @test_masked_vectorize_pad(
   %0 : tensor<?x?xf32>, %h0 : index, %h1 : index)

From a77ac60ddd46e25a9183d83be302445552766e3f Mon Sep 17 00:00:00 2001
From: "Balaji V. Iyer" <bviyer at gmail.com>
Date: Fri, 19 Jan 2024 19:11:04 -0600
Subject: [PATCH 3/4] Added some of the changes requested by Diego and HanHan

---
 .../Linalg/Transforms/Vectorization.cpp       | 20 ++++++-------------
 1 file changed, 6 insertions(+), 14 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 7a9846154bf34..611dec4af0f93 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1386,11 +1386,11 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
 
   return success();
 }
-// Vectorize a `tensor::UnPackOp` without outer_dims_perm to these 4 ops:
-//   vector::TransferReadOp - reads a vector from the source tensor
-//   vector::TransposeOp - transposes the read vector
-//   vector::ShapeCastOp - reshapes the data based on the target
-//   vector::TransferWriteOp - writes the result vector back
+/// Vectorize a `tensor::UnPackOp` without outer_dims_perm to these 4 ops:
+///   vector::TransferReadOp - reads a vector from the source tensor
+///   vector::TransposeOp - transposes the read vector
+///   vector::ShapeCastOp - reshapes the data based on the target
+///   vector::TransferWriteOp - writes the result vector back
 
 static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
                                          tensor::UnPackOp unpackOp,
@@ -1661,12 +1661,6 @@ vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
   return success();
 }
 
-static LogicalResult
-vectorizeUnpackOpPrecondition(tensor::UnPackOp unpackOp,
-                              ArrayRef<int64_t> inputVectorSizes) {
-  return success();
-}
-
 static LogicalResult
 vectorizePadOpPrecondition(tensor::PadOp padOp,
                            ArrayRef<int64_t> inputVectorSizes) {
@@ -1726,9 +1720,7 @@ LogicalResult mlir::linalg::vectorizeOpPrecondition(
       .Case<tensor::PadOp>([&](auto padOp) {
         return vectorizePadOpPrecondition(padOp, inputVectorSizes);
       })
-      .Case<tensor::UnPackOp>([&](auto unpackOp) {
-        return vectorizeUnpackOpPrecondition(unpackOp, inputVectorSizes);
-      })
+      .Case<tensor::UnPackOp>([&](auto unpackOp) { return success(); })
       .Default([](auto) { return failure(); });
 }
 

From 432ce04dfe64e9c24f27061de811b4e34e113776 Mon Sep 17 00:00:00 2001
From: "Balaji V. Iyer" <bviyer at gmail.com>
Date: Tue, 6 Feb 2024 23:45:59 +0000
Subject: [PATCH 4/4] Used vectorSizes for masks and added a dynamic shapes
 test case.

---
 .../Linalg/Transforms/Vectorization.cpp       | 113 +++++++++++++-----
 mlir/test/Dialect/Linalg/vectorization.mlir   |  33 ++++-
 2 files changed, 114 insertions(+), 32 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 4aa604610e217..c450b0635edfb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1393,17 +1393,18 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
 
   return success();
 }
+
 /// Vectorize a `tensor::UnPackOp` without outer_dims_perm to these 4 ops:
 ///   vector::TransferReadOp - reads a vector from the source tensor
 ///   vector::TransposeOp - transposes the read vector
 ///   vector::ShapeCastOp - reshapes the data based on the target
 ///   vector::TransferWriteOp - writes the result vector back
-
 static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
                                          tensor::UnPackOp unpackOp,
                                          ArrayRef<int64_t> inputVectorSizes,
                                          SmallVectorImpl<Value> &newResults) {
-
+  // Handling outer_dims_perm requires more changes; for now, only unpack
+  // ops without it are vectorized.
   if (!unpackOp.getOuterDimsPerm().empty()) {
     LDBG("outer dimensions perms NYI for: " << unpackOp);
     return failure();
@@ -1412,11 +1413,19 @@ static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(unpackOp);
 
-  RankedTensorType packTensorType = unpackOp.getSourceType();
-  auto maskType =
-      VectorType::get(packTensorType.getShape(), rewriter.getI1Type());
-  auto vectorType = VectorType::get(packTensorType.getShape(),
-                                    packTensorType.getElementType());
+  RankedTensorType unpackTensorType = unpackOp.getSourceType();
+  // readMaskShape is the shape of the vector used to read the source and
+  // apply the mask. Let the vectorSizes (VS) array have size 'N' and the
+  // sourceShape (SS) have rank 'M', with M >= N. Then:
+  //   ReadMaskShape = [VS[0], ..., VS[N-1], SS[N], ..., SS[M-1]]
+  // i.e. leading dims from the vector sizes, trailing dims from the source.
+  llvm::SmallVector<int64_t> readMaskShape(unpackTensorType.getShape());
+  for (unsigned i = 0; i < inputVectorSizes.size(); i++) {
+    readMaskShape[i] = inputVectorSizes[i];
+  }
+
+  auto vectorType =
+      VectorType::get(readMaskShape, unpackTensorType.getElementType());
   ReifiedRankedShapedTypeDims reifiedRetShapes;
   LogicalResult status =
       cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
@@ -1425,54 +1434,87 @@ static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
     LDBG("Unable to reify result shapes of " << unpackOp);
     return failure();
   }
-
+  int64_t unpackRank = unpackTensorType.getRank();
   arith::ConstantIndexOp zeroOp =
       rewriter.create<arith::ConstantIndexOp>(unpackOp->getLoc(), 0);
-  Value mask = rewriter.create<vector::CreateMaskOp>(
-      unpackOp.getLoc(), maskType,
-      tensor::getMixedSizes(rewriter, unpackOp.getLoc(), unpackOp.getSource()));
 
   vector::TransferReadOp readOp = rewriter.create<vector::TransferReadOp>(
       unpackOp.getLoc(), vectorType, unpackOp.getSource(),
-      SmallVector<Value>(packTensorType.getRank(), zeroOp),
-      rewriter.getMultiDimIdentityMap(packTensorType.getRank()));
+      SmallVector<Value>(unpackRank, zeroOp),
+      rewriter.getMultiDimIdentityMap(unpackRank));
 
+  auto readMaskType = VectorType::get(readMaskShape, rewriter.getI1Type());
+  Value mask = rewriter.create<vector::CreateMaskOp>(
+      unpackOp.getLoc(), readMaskType,
+      tensor::getMixedSizes(rewriter, unpackOp.getLoc(), unpackOp.getSource()));
   vector::MaskOp maskedOp =
       cast<vector::MaskOp>(mlir::vector::maskOperation(rewriter, readOp, mask));
 
   int64_t numPackedDim = unpackOp.getInnerDimsPos().size();
-  int64_t packRank = packTensorType.getRank();
-  auto lastDims =
-      llvm::to_vector(llvm::seq<int64_t>(packRank - numPackedDim, packRank));
+  llvm::SmallVector<int64_t> lastDims = llvm::to_vector(
+      llvm::seq<int64_t>(unpackRank - numPackedDim, unpackRank));
   PackingMetadata packMetadata =
-      computePackingMetadata(packRank, unpackOp.getInnerDimsPos());
+      computePackingMetadata(unpackRank, unpackOp.getInnerDimsPos());
   SmallVector<int64_t> lastDimToInsertPosPerm = computePermutationVector(
-      packRank, lastDims, packMetadata.insertPositions);
-  SmallVector<int64_t> stripMineShape(packTensorType.getShape());
+      unpackRank, lastDims, packMetadata.insertPositions);
+  ShapedType maskedOpShapedType =
+      cast<ShapedType>(maskedOp.getResult(0).getType());
+  SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
+  mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
   applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
 
   RankedTensorType stripMineTensorType =
-      RankedTensorType::Builder(packTensorType).setShape(stripMineShape);
+      RankedTensorType::Builder(stripMineShape, stripMineElemType, {})
+          .setShape(stripMineShape);
 
+  // Collapse the tensor to the size required by result.
   RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
       stripMineTensorType, packMetadata.reassociations);
   auto vecCollapsedType =
       VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
 
+  // Transpose the appropriate rows to match output.
   vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
       unpackOp.getLoc(), maskedOp.getResult(0), lastDimToInsertPosPerm);
 
   vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
       unpackOp.getLoc(), vecCollapsedType, transposeOp->getResult(0));
-  tensor::EmptyOp emptyOp = rewriter.create<tensor::EmptyOp>(
-      unpackOp.getLoc(), reifiedRetShapes[0], packTensorType.getElementType());
+  tensor::EmptyOp emptyOp =
+      rewriter.create<tensor::EmptyOp>(unpackOp.getLoc(), reifiedRetShapes[0],
+                                       unpackTensorType.getElementType());
 
-  vector::TransferWriteOp writeOp = rewriter.create<vector::TransferWriteOp>(
+  int64_t destRank = cast<ShapedType>(emptyOp.getType()).getRank();
+  Operation *writeOp = rewriter.create<vector::TransferWriteOp>(
       unpackOp.getLoc(), shapeCastOp->getResult(0), emptyOp,
-      SmallVector<Value>(lastDims.size(), zeroOp),
-      SmallVector<bool>(lastDims.size(), true));
-
-  newResults.push_back(writeOp->getResult(0));
+      SmallVector<Value>(destRank, zeroOp), SmallVector<bool>(destRank, true));
+  auto resultShape = unpackOp.getResult().getType().getShape();
+
+  // If the shape of the result doesn't match the inputVectorSizes, a mask
+  // is necessary.
+  bool needMaskForWrite =
+      llvm::any_of(llvm::zip_equal(inputVectorSizes, resultShape),
+                   [](auto it) { return std::get<0>(it) != std::get<1>(it); });
+  mlir::OpResult result = writeOp->getResult(0);
+  if (needMaskForWrite) {
+    // WriteMaskShape is computed from the vectorSizes, innerDimsPos and
+    // innerTiles: it starts as the inputVectorSizes, and each inner tile
+    // scales the destination dimension it was split from:
+    //   for each (index, tile) in innerTiles:
+    //     WriteMaskShape[innerDimsPos[index]] *= tile
+    SmallVector<int64_t> writeMaskShape(inputVectorSizes);
+    llvm::ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
+    llvm::ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
+    for (auto [index, size] : enumerate(innerTiles)) {
+      writeMaskShape[innerDimPos[index]] *= size;
+    }
+    auto writeMaskType = VectorType::get(writeMaskShape, rewriter.getI1Type());
+    Value writeMask = rewriter.create<vector::CreateMaskOp>(
+        unpackOp.getLoc(), writeMaskType, reifiedRetShapes[0]);
+    Operation *writeOpWithMask =
+        mlir::vector::maskOperation(rewriter, writeOp, writeMask);
+    result = writeOpWithMask->getResult(0);
+  }
+  newResults.push_back(result);
   return success();
 }
 
@@ -1601,6 +1643,19 @@ isValidMaskedInputVector(ArrayRef<int64_t> shape,
   return success();
 }
 
+/// The inner tiles must be static/constant for the unpack op to be vectorized.
+static LogicalResult
+vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
+                              ArrayRef<int64_t> inputVectorSizes) {
+  if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
+        return !getConstantIntValue(res).has_value();
+      })) {
+    LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
+    return failure();
+  }
+  return success();
+}
+
 static LogicalResult
 vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
                               ArrayRef<int64_t> inputVectorSizes,
@@ -1727,7 +1782,9 @@ LogicalResult mlir::linalg::vectorizeOpPrecondition(
       .Case<tensor::PadOp>([&](auto padOp) {
         return vectorizePadOpPrecondition(padOp, inputVectorSizes);
       })
-      .Case<tensor::UnPackOp>([&](auto unpackOp) { return success(); })
+      .Case<tensor::UnPackOp>([&](auto unpackOp) {
+        return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
+      })
       .Default([](auto) { return failure(); });
 }
 
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index a4bcfec33d11b..91ada32e7b68f 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -422,11 +422,13 @@ module attributes {transform.with_named_sequence} {
 // CHECK-LABEL: func @test_vectorize_unpack
 func.func @test_vectorize_unpack(%0 : tensor<7x1136x16x16xf32>) -> tensor<100x18176xf32>  {
     // CHECK: %[[c0:.*]] = arith.constant 0 : index
-    // CHECK: %[[tr0:.*]] = vector.mask %[[m0:.*]] {{.*}} vector.transfer_read %{{.*}} : tensor<7x1136x16x16xf32>, vector<7x1136x16x16xf32> } : vector<7x1136x16x16xi1> -> vector<7x1136x16x16xf32>
-    // CHECK: %[[trans0:.*]] = vector.transpose %[[tr0]], [0, 2, 1, 3] : vector<7x1136x16x16xf32> to vector<7x16x1136x16xf32>
-    // CHECK: %[[sc0:.*]] = vector.shape_cast %[[trans0]]  : vector<7x16x1136x16xf32> to vector<112x18176xf32>
+    // CHECK: %[[m0:.*]] = vector.create_mask %c7, %c1136, %c16, %c16_0 : vector<2x4x16x16xi1>
+    // CHECK: %[[tr0:.*]] = vector.mask %[[m0]] {{.*}} vector.transfer_read %{{.*}} : tensor<7x1136x16x16xf32>, vector<2x4x16x16xf32> } : vector<2x4x16x16xi1> -> vector<2x4x16x16xf32>
+    // CHECK: %[[trans0:.*]] = vector.transpose %[[tr0]], [0, 2, 1, 3] : vector<2x4x16x16xf32> to vector<2x16x4x16xf32>
+    // CHECK: %[[sc0:.*]] = vector.shape_cast %[[trans0]]  : vector<2x16x4x16xf32> to vector<32x64xf32>
     // CHECK: %[[empt0:.*]] = tensor.empty() : tensor<100x18176xf32>
-    // CHECK: %[[tw0:.*]] = vector.transfer_write %[[sc0]], %[[empt0]]
+    // CHECK: %[[mask0:.*]] = vector.create_mask %c100, %c18176 : vector<32x64xi1>
+    // CHECK: %[[tw0:.*]] = vector.mask %[[mask0]] {{.*}} vector.transfer_write %[[sc0]], %[[empt0]]
     // CHECK: return %[[tw0]]
     %8 = tensor.empty() : tensor<100x18176xf32>
     %unpack = tensor.unpack %0 inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %8 : tensor<7x1136x16x16xf32> -> tensor<100x18176xf32>
@@ -444,6 +446,29 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// CHECK-LABEL: func @test_vectorize_dynamic_shapes_unpack
+func.func @test_vectorize_dynamic_shapes_unpack(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?x16x2xf32>) -> tensor<?x?xf32> {
+ // CHECK: %[[readMsk0:.*]] = vector.create_mask %dim_3, %dim_5, %c16, %c2 : vector<4x1x16x2xi1>
+ // CHECK: %[[read0:.*]] = vector.mask %[[readMsk0]] {{.*}} vector.transfer_read %{{.*}} : tensor<?x?x16x2xf32>, vector<4x1x16x2xf32> } : vector<4x1x16x2xi1> -> vector<4x1x16x2xf32>
+ // CHECK: %[[trans0:.*]] = vector.transpose %[[read0]], [0, 3, 1, 2] : vector<4x1x16x2xf32> to vector<4x2x1x16xf32>
+ // CHECK: %[[sc0:.*]] = vector.shape_cast %[[trans0]] : vector<4x2x1x16xf32> to vector<8x16xf32>
+ // CHECK: %[[empt0:.*]] = tensor.empty
+ // CHECK: %[[writeMsk0:.*]] = vector.create_mask {{.*}} : vector<8x16xi1>
+ // CHECK: %[[write0:.*]] = vector.mask %[[writeMsk0:.*]] {{.*}} vector.transfer_write %[[sc0]], %[[empt0]]
+ // CHECK: return %[[write0]]
+ %ret = tensor.unpack %arg1 inner_dims_pos = [1, 0] inner_tiles = [16, 2] into %arg0 : tensor<?x?x16x2xf32> -> tensor<?x?xf32>
+ return %ret : tensor<?x?xf32>
+}
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+   %0 = transform.structured.match ops{["tensor.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+   transform.structured.vectorize %0 vector_sizes [4, 1] : !transform.any_op
+   transform.yield
+ }
+}
+
+// -----
+
 // CHECK-LABEL: func @test_masked_vectorize_pad
 func.func @test_masked_vectorize_pad(
   %0 : tensor<?x?xf32>, %h0 : index, %h1 : index)

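To make the read- and write-mask shape rules described in the comments of this patch concrete, here is a minimal standalone sketch; the helper names `readMaskShape` and `writeMaskShape` are hypothetical and merely mirror the inline logic of `vectorizeAsUnpackOp`, checked against the dynamic-shapes test case above:

    #include <cassert>
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // ReadMaskShape = [VS[0], ..., VS[N-1], SS[N], ..., SS[M-1]]: leading
    // dims come from the requested vector sizes, the rest from the source.
    static std::vector<int64_t>
    readMaskShape(std::vector<int64_t> sourceShape,
                  const std::vector<int64_t> &vectorSizes) {
      for (std::size_t i = 0; i < vectorSizes.size(); ++i)
        sourceShape[i] = vectorSizes[i];
      return sourceShape;
    }

    // The write mask starts as the input vector sizes; each inner tile
    // scales the destination dimension it was split from.
    static std::vector<int64_t>
    writeMaskShape(std::vector<int64_t> vectorSizes,
                   const std::vector<int64_t> &innerDimsPos,
                   const std::vector<int64_t> &innerTiles) {
      for (std::size_t i = 0; i < innerTiles.size(); ++i)
        vectorSizes[innerDimsPos[i]] *= innerTiles[i];
      return vectorSizes;
    }

    int main() {
      // tensor<?x?x16x2xf32> unpacked with inner_dims_pos = [1, 0],
      // inner_tiles = [16, 2], vector_sizes [4, 1]; -1 marks a dynamic dim.
      assert((readMaskShape({-1, -1, 16, 2}, {4, 1}) ==
              std::vector<int64_t>{4, 1, 16, 2})); // vector<4x1x16x2xi1>
      assert((writeMaskShape({4, 1}, {1, 0}, {16, 2}) ==
              std::vector<int64_t>{8, 16}));       // vector<8x16xi1>
      return 0;
    }

Both results match the `vector<4x1x16x2xi1>` read mask and the `vector<8x16xi1>` write mask in the CHECK lines of @test_vectorize_dynamic_shapes_unpack.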

