[Mlir-commits] [mlir] adf838d - [mlir][Vectorizer] Added support to Vectorize tensor.unpack (#76087)

llvmlistbot at llvm.org
Tue Feb 20 14:10:18 PST 2024


Author: Balaji V. Iyer
Date: 2024-02-20T16:10:14-06:00
New Revision: adf838daee63b3245c8822957988da5367e1572c

URL: https://github.com/llvm/llvm-project/commit/adf838daee63b3245c8822957988da5367e1572c
DIFF: https://github.com/llvm/llvm-project/commit/adf838daee63b3245c8822957988da5367e1572c.diff

LOG: [mlir][Vectorizer] Added support to Vectorize tensor.unpack (#76087)

Added support to vectorize tensor.unpack. The unpack op is split into a
`vector.transfer_read`, a `vector.transpose`, a `vector.shape_cast`, and a
`vector.transfer_write`.
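
For example, the static unpack exercised by the
`test_vectorize_unpack_no_masks` test added below,

    %0 = tensor.unpack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16]
        into %dest : tensor<8x8x32x16xf32> -> tensor<256x128xf32>

becomes roughly the following sequence (a sketch: `[...]` stands for the
all-zero index operands, and when the requested vector sizes do not match
the static shapes, the read and write are additionally wrapped in
`vector.mask`):

    %read  = vector.transfer_read %source[...], %cst
               : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
    %tr    = vector.transpose %read, [0, 2, 1, 3]
               : vector<8x8x32x16xf32> to vector<8x32x8x16xf32>
    %sc    = vector.shape_cast %tr
               : vector<8x32x8x16xf32> to vector<256x128xf32>
    %empty = tensor.empty() : tensor<256x128xf32>
    %write = vector.transfer_write %sc, %empty[...]
               : vector<256x128xf32>, tensor<256x128xf32>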

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/lib/Dialect/Tensor/Utils/Utils.cpp
    mlir/test/Dialect/Linalg/vectorization.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
index fe9b16cb44b3da..d09c9e36f6ff88 100644
--- a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
@@ -32,13 +32,11 @@ FailureOr<RankedTensorType>
 computeTransposedType(RankedTensorType rankedTensorType,
                       ArrayRef<int64_t> transposeVector);
 
-/// Given a tensor::PackOp, compute the permutation vector to shuffle the
-/// packed shape into the shape before any outer or inner permutations have
-/// been applied.
-/// i.e. for a pack from an ABCD layout to an ABCDba:
-/// The packed shape would be ABCDba.
-/// The pre-permutation shape would be AaBbCD.
-SmallVector<int64_t> getPackInverseDestPermutation(PackOp packOp);
+/// Compute the permutation vector to shuffle the packed shape of a
+/// tensor::PackOp into the shape before any outer or inner permutations
+/// have been applied.
+SmallVector<int64_t> getPackInverseDestPerm(tensor::PackOp packOp);
+
+/// Compute the analogous permutation vector for the source shape of a
+/// tensor::UnPackOp; the second overload also fills in the computed
+/// PackingMetadata.
+SmallVector<int64_t> getUnPackInverseSrcPerm(tensor::UnPackOp unpackOp);
+SmallVector<int64_t> getUnPackInverseSrcPerm(tensor::UnPackOp,
+                                             PackingMetadata &metadata);
 
 /// A tensor.insert_slice is a cast-like operation if it merely rank-extends the
 /// source tensor or inserts the source tensor into a destination tensor with

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 4ef8859fd5c430..299965bcfc3ab3 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3152,7 +3152,8 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
 
   // TODO: Check that the correct number of vectorSizes was provided.
   for (Operation *target : targets) {
-    if (!isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp>(target)) {
+    if (!isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp, tensor::UnPackOp>(
+            target)) {
       return mlir::emitSilenceableFailure(target->getLoc())
              << "Unsupported Op, cannot vectorize";
     }

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 01b393644679c5..a17bc8e4cd318f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -237,7 +237,7 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
   PackingMetadata packingMetadata = computePackingMetadata(
       packedTensorType.getRank(), packOp.getInnerDimsPos());
   SmallVector<int64_t> packedToStripMinedShapePerm =
-      tensor::getPackInverseDestPermutation(packOp);
+      tensor::getPackInverseDestPerm(packOp);
 
   // 3. Compute the stripMinedShape: this is the packed shape before any outer
   // or inner permutations have been applied.

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 2bd6929fea6142..ac043e87223dfe 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1405,8 +1405,7 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
 /// permutations.
 static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
                                               ArrayRef<int64_t> destShape) {
-  return applyPermutation(destShape,
-                          tensor::getPackInverseDestPermutation(packOp));
+  return applyPermutation(destShape, tensor::getPackInverseDestPerm(packOp));
 }
 
 /// Create a TransferReadOp from `source` with static shape `readShape`. If the
@@ -1547,7 +1546,7 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
 
   // Create TransposeOp.
   auto destPermutation =
-      invertPermutationVector(tensor::getPackInverseDestPermutation(packOp));
+      invertPermutationVector(tensor::getPackInverseDestPerm(packOp));
   auto transposeOp = rewriter.create<vector::TransposeOp>(
       loc, shapeCastOp.getResult(), destPermutation);
 
@@ -1559,6 +1558,112 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
   return success();
 }
 
+/// Vectorize a `tensor::UnPackOp` into these 4 ops:
+///   vector::TransferReadOp - Reads a vector from the source tensor.
+///   vector::TransposeOp - Transposes the read vector.
+///   vector::ShapeCastOp - Reshapes the data based on the target.
+///   vector::TransferWriteOp - Writes the result vector back to the
+///   destination tensor.
+static LogicalResult
+vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
+                          ArrayRef<int64_t> inputVectorSizes,
+                          SmallVectorImpl<Value> &newResults) {
+
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(unpackOp);
+
+  RankedTensorType unpackTensorType = unpackOp.getSourceType();
+
+  ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
+  ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
+
+  SmallVector<int64_t> readMaskShape(inputVectorSizes.begin(),
+                                     inputVectorSizes.end());
+  ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
+  ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
+
+  // readMaskShape is the shape of the vector used to read and to apply the
+  // mask. Say the vectorSize (VS) array has size 'N', the sourceShape (SS)
+  // has size 'M' with M >= N, and the InnerTileSizes (IT) have size M-N.
+  // Then readMaskShape is computed as follows:
+  // - Initially: readMaskShape = vectorInputSizes.
+  // - Divide (using ceiling division) each readMaskShape entry indexed by
+  //   innerDimPos by the corresponding innerTileSize attribute value.
+  // - If outer_dims_perm is present: apply that permutation to
+  //   readMaskShape.
+  // - Append the remaining dimensions from SS.
+  // E.g. say unpackTensorType.getShape() = <8x8x32x16>,
+  // inner_dims_pos = [0, 1], inner_tiles = [32, 16], vector_sizes = [512,
+  // 128] and outer_dims_perm = [1, 0]; then the read shape is:
+  //   readMaskShape (initial): [512, 128]
+  //   Final value (after inner dim adjustment): [512/32, 128/16]
+  //                                             = [16, 8]
+  //   After applying outer_dims_perm: [8, 16]
+  //   After appending the rest of the sourceShape: [8, 16, 32, 16]
+
+  for (auto [index, size] : enumerate(innerTiles)) {
+    readMaskShape[innerDimPos[index]] =
+        llvm::divideCeil(readMaskShape[innerDimPos[index]], size);
+  }
+  if (!outerDimsPerm.empty()) {
+    applyPermutationToVector(readMaskShape, outerDimsPerm);
+  }
+  readMaskShape.append(sourceShape.begin() + inputVectorSizes.size(),
+                       sourceShape.end());
+
+  ReifiedRankedShapedTypeDims reifiedRetShapes;
+  LogicalResult status =
+      cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
+          .reifyResultShapes(rewriter, reifiedRetShapes);
+  if (status.failed()) {
+    LDBG("Unable to reify result shapes of " << unpackOp);
+    return failure();
+  }
+  Location loc = unpackOp->getLoc();
+
+  auto padValue = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));
+
+  // Read from the source, masking if necessary. If the transferReadOp shape
+  // is not equal to the shape of the source, then a mask is necessary.
+  Value readResult = createReadOrMaskedRead(
+      rewriter, loc, unpackOp.getSource(),
+      ArrayRef<int64_t>(readMaskShape.begin(), readMaskShape.end()), padValue);
+
+  PackingMetadata packMetadata;
+  SmallVector<int64_t> lastDimToInsertPosPerm =
+      tensor::getUnPackInverseSrcPerm(unpackOp, packMetadata);
+  ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
+  SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
+  mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
+  applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
+  RankedTensorType stripMineTensorType =
+      RankedTensorType::get(stripMineShape, stripMineElemType);
+  // Transpose the appropriate dimensions to match the output.
+  vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
+      loc, readResult, lastDimToInsertPosPerm);
+
+  // Collapse the vector to the size required by the result.
+  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
+      stripMineTensorType, packMetadata.reassociations);
+  mlir::VectorType vecCollapsedType =
+      VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
+  vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
+      loc, vecCollapsedType, transposeOp->getResult(0));
+
+  // The writeMaskShape must match the shape_cast shape for dynamic sizes;
+  // otherwise the verifier complains that the mask size is invalid.
+  SmallVector<int64_t> writeMaskShape(
+      unpackOp.getDestType().hasStaticShape()
+          ? inputVectorSizes
+          : shapeCastOp.getResultVectorType().getShape());
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, shapeCastOp.getResult(),
+                               reifiedRetShapes[0], writeMaskShape);
+  newResults.push_back(write->getResult(0));
+  return success();
+}
+
 /// Vectorize a `padOp` with (1) static result type, (2) constant padding value
 /// and (3) all-zero lowPad to
 ///   `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
@@ -1655,6 +1760,25 @@ isValidMaskedInputVector(ArrayRef<int64_t> shape,
   return success();
 }
 
+/// Checks that the inner tile sizes are static/constant.
+static LogicalResult
+vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
+                              ArrayRef<int64_t> inputVectorSizes) {
+
+  if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
+        return !getConstantIntValue(res).has_value();
+      })) {
+    LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
+    return failure();
+  }
+  llvm::ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
+  if (!inputVectorSizes.empty() &&
+      failed(isValidMaskedInputVector(resultShape, inputVectorSizes)))
+    return failure();
+
+  return success();
+}
+
 static LogicalResult
 vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
                               ArrayRef<int64_t> inputVectorSizes,
@@ -1703,9 +1827,10 @@ vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
   }
   if (isElementwise(linalgOp))
     return success();
-  // TODO: isaConvolutionOpInterface that can also infer from generic features.
-  // But we will still need stride/dilation attributes that will be annoying to
-  // reverse-engineer...
+
+  // TODO: isaConvolutionOpInterface that can also infer from generic
+  // features. But we will still need stride/dilation attributes that will be
+  // annoying to reverse-engineer...
   if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
     return success();
   // TODO: the common vector shape is equal to the static loop sizes only when
@@ -1810,6 +1935,9 @@ LogicalResult mlir::linalg::vectorizeOpPrecondition(
       .Case<tensor::PackOp>([&](auto packOp) {
         return vectorizePackOpPrecondition(packOp, inputVectorSizes);
       })
+      .Case<tensor::UnPackOp>([&](auto unpackOp) {
+        return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
+      })
       .Default([](auto) { return failure(); });
 }
 
@@ -1829,11 +1957,11 @@ static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
 }
 
 /// Emit a suitable vector form for an operation. If provided,
-/// `inputVectorSizes` are used to vectorize this operation. `inputVectorSizes`
-/// must match the rank of the iteration space of the operation and the input
-/// vector sizes must be greater than or equal to their counterpart iteration
-/// space sizes, if static. `inputVectorShapes` also allows the vectorization of
-/// operations with dynamic shapes.
+/// `inputVectorSizes` are used to vectorize this operation.
+/// `inputVectorSizes` must match the rank of the iteration space of the
+/// operation and the input vector sizes must be greater than or equal to
+/// their counterpart iteration space sizes, if static. `inputVectorSizes`
+/// also allows the vectorization of operations with dynamic shapes.
 LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
                                       ArrayRef<int64_t> inputVectorSizes,
                                       ArrayRef<bool> inputScalableVecDims,
@@ -1867,8 +1995,9 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
   auto vectorizeResult =
       TypeSwitch<Operation *, LogicalResult>(op)
           .Case<linalg::LinalgOp>([&](auto linalgOp) {
-            // TODO: isaConvolutionOpInterface that can also infer from generic
-            // features. Will require stride/dilation attributes inference.
+            // TODO: isaConvolutionOpInterface that can also infer from
+            // generic features. Will require stride/dilation attributes
+            // inference.
             if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
               FailureOr<Operation *> convOr = vectorizeConvolution(
                   rewriter, linalgOp, flatten1DDepthwiseConv);
@@ -1902,6 +2031,10 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
             return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
                                            results);
           })
+          .Case<tensor::UnPackOp>([&](auto unpackOp) {
+            return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
+                                             inputVectorSizes, results);
+          })
           .Default([](auto) { return failure(); });
 
   if (failed(vectorizeResult)) {
@@ -1919,7 +2052,6 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
 
 LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
                                           memref::CopyOp copyOp) {
-
   auto srcType = cast<MemRefType>(copyOp.getSource().getType());
   auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
   if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
@@ -2833,8 +2965,8 @@ struct Conv1DGenerator
     Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
                                                         resPadding);
 
-    // The base vectorization case for channeled convolution is input: {n,w,c},
-    // weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
+    // The base vectorization case for channeled convolution is input:
+    // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
     // vectorization case, we do pre transpose on input, weight, and output.
     switch (conv1DOpOrder) {
     case Conv1DOpOrder::W:
@@ -2877,9 +3009,9 @@ struct Conv1DGenerator
       return kw * (wSize / wSizeStep) + w;
     };
 
-    // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f} or
-    // perform outerproduct for non-channeled convolution or
-    // perform simple arith operation for pooling
+    // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f}
+    // or perform outerproduct for non-channeled convolution or perform simple
+    // arith operation for pooling
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       for (int64_t w = 0; w < wSize; w += wSizeStep) {
         switch (oper) {
@@ -2908,9 +3040,9 @@ struct Conv1DGenerator
     // End vector-only rewrite part
     //===------------------------------------------------------------------===//
 
-    // The base vectorization case for channeled convolution is output: {n,w,f}
-    // To reuse the result from base pattern vectorization case, we post
-    // transpose the base case result.
+    // The base vectorization case for channeled convolution is output:
+    // {n,w,f}. To reuse the result from the base pattern vectorization case,
+    // we post-transpose the base case result.
     switch (conv1DOpOrder) {
     case Conv1DOpOrder::W:
     case Conv1DOpOrder::Nwc:
@@ -3348,9 +3480,9 @@ static FailureOr<Operation *>
 vectorizeConvolution(RewriterBase &rewriter, LinalgOp op,
                      bool flatten1DDepthwiseConv) {
   // The ConvolutionOpInterface gives us guarantees of existence for
-  // strides/dilations. However, we do not need to rely on those, we can simply
-  // use them if present, otherwise use the default and let the generic conv.
-  // matcher in the ConvGenerator succeed or fail.
+  // strides/dilations. However, we do not need to rely on those; we can
+  // simply use them if present, otherwise use the default, and let the
+  // generic conv matcher in the ConvGenerator succeed or fail.
   auto strides = op->getAttrOfType<DenseIntElementsAttr>("strides");
   auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations");
   auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;

diff  --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
index f20008a1ed2b2f..186f85d2ce20a6 100644
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -72,36 +72,73 @@ mlir::tensor::computeTransposedType(RankedTensorType rankedTensorType,
       RTTBuilder(rankedTensorType).setShape(transposedShape);
   return transposedTensorType;
 }
-
-SmallVector<int64_t>
-mlir::tensor::getPackInverseDestPermutation(PackOp packOp) {
-  // The permutation can be obtained from two permutations:
-  //   a) Compute the permutation vector to move the last `numPackedDims` into
-  //      the `innerPosDims` of a shape of rank `packedRank`.
-  //   b) Compute the permutation vector to move outer dims if the pack op
-  //      has outer_dims_perm.
-  // Apply (b) permutation on (a) permutation to get the final permutation.
-  int64_t numPackedDims = packOp.getInnerDimsPos().size();
-  int64_t packedRank = packOp.getDestType().getRank();
-  auto lastDims = llvm::to_vector(
-      llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
-  PackingMetadata packingMetadata = computePackingMetadata(
-      packOp.getDestType().getRank(), packOp.getInnerDimsPos());
-  SmallVector<int64_t> innerPositionsPerm = computePermutationVector(
-      packedRank, lastDims, packingMetadata.insertPositions);
+/// The permutation can be obtained from two permutations:
+///   a) Compute the permutation vector to move the last `numPackedDims` into
+///      the `innerPosDims` of a shape of rank `rank`.
+///   b) Compute the permutation vector to move outer dims if the
+///      `outerPerm` parameter is not empty.
+/// Apply (b) permutation on (a) permutation to get the final permutation.
+static SmallVector<int64_t>
+computePackUnPackPerm(int64_t rank, ArrayRef<int64_t> &innerDimsPos,
+                      ArrayRef<int64_t> &outerPerm,
+                      PackingMetadata &packingMetadata) {
+  int64_t numPackedDims = innerDimsPos.size();
+  auto lastDims =
+      llvm::to_vector(llvm::seq<int64_t>(rank - numPackedDims, rank));
+  packingMetadata = computePackingMetadata(rank, innerDimsPos);
+  SmallVector<int64_t> innerPositionsPerm =
+      computePermutationVector(rank, lastDims, packingMetadata.insertPositions);
 
   SmallVector<int64_t> outerPos = packingMetadata.outerPositions;
-  ArrayRef<int64_t> outerPerm = packOp.getOuterDimsPerm();
   if (!outerPerm.empty())
     applyPermutationToVector(outerPos, outerPerm);
-  SmallVector<int64_t> outerPositionPerm = computePermutationVector(
-      packedRank, packingMetadata.outerPositions, outerPos);
+  SmallVector<int64_t> outerPositionPerm =
+      computePermutationVector(rank, packingMetadata.outerPositions, outerPos);
 
   SmallVector<int64_t> packInverseDestPermutation = innerPositionsPerm;
   applyPermutationToVector(packInverseDestPermutation, outerPositionPerm);
   return packInverseDestPermutation;
 }
 
+/// Shell function to compute the destination permutation of a PackOp.
+/// This function uses the helper function `computePackUnPackPerm` to get
+/// the permutation vector. The only major difference between UnPack and
+/// Pack is that PackOp uses the destination rank whereas UnPackOp uses the
+/// source rank.
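+/// For example (restating the documentation previously in the header): for
+/// a pack from an ABCD layout to ABCDba, the packed shape is ABCDba and the
+/// pre-permutation (strip-mined) shape is AaBbCD; the returned vector
+/// shuffles the former into the latter.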
+SmallVector<int64_t> mlir::tensor::getPackInverseDestPerm(PackOp packOp) {
+
+  PackingMetadata pMetadata;
+  int64_t packedRank = packOp.getDestType().getRank();
+  ArrayRef<int64_t> innerDimPos = packOp.getInnerDimsPos();
+  ArrayRef<int64_t> outerPerm = packOp.getOuterDimsPerm();
+  SmallVector<int64_t> packInvDestPerm =
+      computePackUnPackPerm(packedRank, innerDimPos, outerPerm, pMetadata);
+  return packInvDestPerm;
+}
+
+/// Shell function to compute the source permutation of an UnPackOp.
+/// This function, like getPackInverseDestPerm, uses the helper function
+/// `computePackUnPackPerm` to get the permutation vector.
+/// The only major difference between UnPack and Pack is that PackOp uses
+/// the destination rank whereas UnPackOp uses the source rank.
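+/// For example (values matching the vectorization tests added in this
+/// commit): for an unpack from tensor<8x8x32x16xf32> with
+/// inner_dims_pos = [0, 1], the returned permutation is [0, 2, 1, 3]; with
+/// outer_dims_perm = [1, 0] in addition, it is [1, 2, 0, 3].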
+SmallVector<int64_t> mlir::tensor::getUnPackInverseSrcPerm(UnPackOp unpackOp) {
+  PackingMetadata metadata;
+  return mlir::tensor::getUnPackInverseSrcPerm(unpackOp, metadata);
+}
+
+/// Shell function to compute the source rank permutation for UnPackOp.
+/// Unpack requires some packing metadata, so this overload takes the
+/// PackingMetadata as an out-parameter and fills it in.
+SmallVector<int64_t>
+mlir::tensor::getUnPackInverseSrcPerm(UnPackOp unpackOp,
+                                      PackingMetadata &metadata) {
+  int64_t unpackRank = unpackOp.getSourceType().getRank();
+  ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
+  ArrayRef<int64_t> outerPerm = unpackOp.getOuterDimsPerm();
+  SmallVector<int64_t> unpackInvSrcPerm =
+      computePackUnPackPerm(unpackRank, innerDimPos, outerPerm, metadata);
+  return unpackInvSrcPerm;
+}
+
 bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) {
   llvm::SmallBitVector droppedDims = op.getDroppedDims();
   int64_t srcDim = 0;

diff  --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 0272ac599aa3db..2d01d57304013c 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -697,3 +697,118 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL: func @test_vectorize_dynamic_shapes_unpack
+func.func @test_vectorize_dynamic_shapes_unpack(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?x16x2xf32>) -> tensor<?x?xf32> {
+// CHECK: %[[C0:.*]] = arith.constant 0
+// CHECK: %[[DIM:.*]] = tensor.dim %arg0, %[[C0]] : tensor<?x?xf32>
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM0:.*]] = tensor.dim %arg0, %[[C1]] : tensor<?x?xf32>
+// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00
+// CHECK: %[[C01:.*]] = arith.constant 0
+// CHECK: %[[C02:.*]] = arith.constant 0
+// CHECK: %[[DIM4:.*]] = tensor.dim %arg1, %[[C02]] : tensor<?x?x16x2xf32>
+// CHECK: %[[CNST14:.*]] = arith.constant 1
+// CHECK: %[[DIM6:.*]] = tensor.dim %arg1, %[[CNST14]] : tensor<?x?x16x2xf32>
+// CHECK: %[[CNST16:.*]] = arith.constant 16 : index
+// CHECK: %[[CNST2:.*]] = arith.constant 2 : index
+// CHECK: %[[readMsk0:.*]] = vector.create_mask %[[DIM4]], %[[DIM6]], %[[CNST16]], %[[CNST2]] : vector<2x1x16x2xi1>
+// CHECK: %[[read0:.*]] = vector.mask %[[readMsk0]] {{.*}} vector.transfer_read %{{.*}} : tensor<?x?x16x2xf32>, vector<2x1x16x2xf32> } : vector<2x1x16x2xi1> -> vector<2x1x16x2xf32>
+// CHECK: %[[trans0:.*]] = vector.transpose %[[read0]], [0, 3, 1, 2] : vector<2x1x16x2xf32> to vector<2x2x1x16xf32>
+// CHECK: %[[sc0:.*]] = vector.shape_cast %[[trans0]] : vector<2x2x1x16xf32> to vector<4x16xf32>
+// CHECK: %[[empt0:.*]] = tensor.empty
+// CHECK: %[[writeMsk0:.*]] = vector.create_mask {{.*}} : vector<4x16xi1>
+// CHECK: %[[write0:.*]] = vector.mask %[[writeMsk0]] {{.*}} vector.transfer_write %[[sc0]], %[[empt0]]
+// CHECK: return %[[write0]]
+  %ret = tensor.unpack %arg1 inner_dims_pos = [1, 0] inner_tiles = [16, 2] into %arg0 : tensor<?x?x16x2xf32> -> tensor<?x?xf32>
+  return %ret : tensor<?x?xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [4, 16] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func @test_vectorize_unpack
+func.func @test_vectorize_unpack(%source: tensor<8x8x32x16xf32>, %dest: tensor<256x128xf32>) -> tensor<256x128xf32> {
+  // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[C8:.*]] = arith.constant 8 : index
+  // CHECK: %[[C80:.*]] = arith.constant 8 : index
+  // CHECK: %[[C32:.*]] = arith.constant 32 : index
+  // CHECK: %[[C16:.*]] = arith.constant 16 : index
+  // CHECK: %[[MSK0:.*]] = vector.create_mask %[[C8]], %[[C80]], %[[C32]], %[[C16]] : vector<16x8x32x16xi1>
+  // CHECK: %[[READ0:.*]] = vector.mask %[[MSK0]] {{.*}} : vector<16x8x32x16xi1> -> vector<16x8x32x16xf32>
+  // CHECK: %[[TRANSP0:.*]] = vector.transpose %[[READ0]], [0, 2, 1, 3] : vector<16x8x32x16xf32> to vector<16x32x8x16xf32>
+  // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP0]] : vector<16x32x8x16xf32> to vector<512x128xf32>
+  // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<256x128xf32>
+  // CHECK: %[[C01:.*]] = arith.constant 0 : index
+  // CHECK: %[[C256:.*]] = arith.constant 256 : index
+  // CHECK: %[[C128:.*]] = arith.constant 128 : index
+  // CHECK: %[[WRITEMSK:.*]] = vector.create_mask %[[C256]], %[[C128]] : vector<512x128xi1>
+  // CHECK: %[[WRIT:.*]] = vector.mask %[[WRITEMSK]] {{.*}} : vector<512x128xi1> -> tensor<256x128xf32>
+  // CHECK: return %[[WRIT]] : tensor<256x128xf32>
+  %0 = tensor.unpack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
+  return %0 : tensor<256x128xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [512, 128] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func @test_vectorize_unpack_no_masks
+func.func @test_vectorize_unpack_no_masks(%source: tensor<8x8x32x16xf32>, %dest: tensor<256x128xf32>) -> tensor<256x128xf32> {
+  // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[READ:.*]] = vector.transfer_read {{.*}} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
+  // CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [0, 2, 1, 3] : vector<8x8x32x16xf32> to vector<8x32x8x16xf32>
+  // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<8x32x8x16xf32> to vector<256x128xf32>
+  // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<256x128xf32>
+  // CHECK: %[[C00:.*]] = arith.constant 0 : index
+  // CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], {{.*}} : vector<256x128xf32>, tensor<256x128xf32>
+  // CHECK: return %[[WRIT]] : tensor<256x128xf32>
+  %0 = tensor.unpack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
+  return %0 : tensor<256x128xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [256, 128] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func @test_vectorize_unpack_with_outer_perm
+func.func @test_vectorize_unpack_with_outer_perm(%source: tensor<8x8x32x16xf32>, %dest: tensor<256x128xf32>) -> tensor<256x128xf32> {
+  // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[READ:.*]] = vector.transfer_read {{.*}} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
+  // CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [1, 2, 0, 3] : vector<8x8x32x16xf32> to vector<8x32x8x16xf32>
+  // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<8x32x8x16xf32> to vector<256x128xf32>
+  // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<256x128xf32>
+  // CHECK: %[[C00:.*]] = arith.constant 0 : index
+  // CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], {{.*}} : vector<256x128xf32>, tensor<256x128xf32>
+  // CHECK: return %[[WRIT]] : tensor<256x128xf32>
+  %0 = tensor.unpack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
+  return %0 : tensor<256x128xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [256, 128] : !transform.any_op
+    transform.yield
+  }
+}
