[Mlir-commits] [mlir] [mlir][Vectorizer] Added support to Vectorize tensor.unpack (PR #76087)

Balaji V. Iyer. llvmlistbot at llvm.org
Tue Feb 20 13:46:48 PST 2024


https://github.com/bviyer updated https://github.com/llvm/llvm-project/pull/76087

>From c0c6432dce8f94b2b2f07595de0973dc12f90d45 Mon Sep 17 00:00:00 2001
From: "Balaji V. Iyer" <bviyer at gmail.com>
Date: Thu, 30 Nov 2023 20:39:55 +0000
Subject: [PATCH 01/12] [mlir][Vectorizer] Vectorize `tensor.unpack`

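This patch adds support for vectorizing a `tensor.unpack` operation that has
no `outer_dims_perm`. The op is lowered to a masked `vector.transfer_read` of
the source, a `vector.transpose`, a `vector.shape_cast` to the collapsed result
shape, and a `vector.transfer_write` into a `tensor.empty` of the result shape.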
---
 .../Linalg/Transforms/Vectorization.cpp       | 348 +++++++++++-------
 1 file changed, 220 insertions(+), 128 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 2bd6929fea6142..0760ad114b07b1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1400,6 +1400,88 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
 
   return success();
 }
+// Vectorize an `tensor::UnPackOp` without OuterDimsPerms to these 4 Ops:
+//   Vector::TransferReadOp - Reads the Vector Array of Source data
+//   vector::TransposeOp - Transpose the Source
+//   ShapeCastOp - Reshapes the data based on the target.
+//   vector::TransferWriteOp. - Write the result vector back.
+
+static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
+                                         tensor::UnPackOp unpackOp,
+                                         ArrayRef<int64_t> inputVectorSizes,
+                                         SmallVectorImpl<Value> &newResults) {
+
+  if (!unpackOp.getOuterDimsPerm().empty()) {
+    LDBG("outer dimensions perms NYI for: " << unpackOp);
+    return failure();
+  }
+
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(unpackOp);
+
+  RankedTensorType packTensorType = unpackOp.getSourceType();
+  auto maskType =
+      VectorType::get(packTensorType.getShape(), rewriter.getI1Type());
+  auto vectorType = VectorType::get(packTensorType.getShape(),
+                                    packTensorType.getElementType());
+  ReifiedRankedShapedTypeDims reifiedRetShapes;
+  LogicalResult status =
+      cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
+          .reifyResultShapes(rewriter, reifiedRetShapes);
+  if (status.failed()) {
+    LDBG("Unable to reify result shapes of " << unpackOp);
+    return failure();
+  }
+
+  arith::ConstantIndexOp zeroOp =
+      rewriter.create<arith::ConstantIndexOp>(unpackOp->getLoc(), 0);
+  Value mask = rewriter.create<vector::CreateMaskOp>(
+      unpackOp.getLoc(), maskType,
+      tensor::getMixedSizes(rewriter, unpackOp.getLoc(), unpackOp.getSource()));
+
+  vector::TransferReadOp readOp = rewriter.create<vector::TransferReadOp>(
+      unpackOp.getLoc(), vectorType, unpackOp.getSource(),
+      SmallVector<Value>(packTensorType.getRank(), zeroOp),
+      rewriter.getMultiDimIdentityMap(packTensorType.getRank()));
+
+  vector::MaskOp maskedOp =
+      cast<vector::MaskOp>(mlir::vector::maskOperation(rewriter, readOp, mask));
+
+  int64_t numPackedDim = unpackOp.getInnerDimsPos().size();
+  int64_t packRank = packTensorType.getRank();
+  auto lastDims =
+      llvm::to_vector(llvm::seq<int64_t>(packRank - numPackedDim, packRank));
+  PackingMetadata packMetadata =
+      computePackingMetadata(packRank, unpackOp.getInnerDimsPos());
+  SmallVector<int64_t> lastDimToInsertPosPerm = computePermutationVector(
+      packRank, lastDims, packMetadata.insertPositions);
+  SmallVector<int64_t> stripMineShape(packTensorType.getShape());
+  applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
+
+  RankedTensorType stripMineTensorType =
+      RankedTensorType::Builder(packTensorType).setShape(stripMineShape);
+
+  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
+      stripMineTensorType, packMetadata.reassociations);
+  auto vecCollapsedType =
+      VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
+
+  vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
+      unpackOp.getLoc(), maskedOp.getResult(0), lastDimToInsertPosPerm);
+
+  vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
+      unpackOp.getLoc(), vecCollapsedType, transposeOp->getResult(0));
+  tensor::EmptyOp emptyOp = rewriter.create<tensor::EmptyOp>(
+      unpackOp.getLoc(), reifiedRetShapes[0], packTensorType.getElementType());
+
+  vector::TransferWriteOp writeOp = rewriter.create<vector::TransferWriteOp>(
+      unpackOp.getLoc(), shapeCastOp->getResult(0), emptyOp,
+      SmallVector<Value>(lastDims.size(), zeroOp),
+      SmallVector<bool>(lastDims.size(), true));
+
+  newResults.push_back(writeOp->getResult(0));
+  return success();
+}
 
 /// Given a tensor::PackOp, return the `dest` shape before any packing
 /// permutations.
@@ -1748,6 +1830,12 @@ vectorizePackOpPrecondition(tensor::PackOp packOp,
   return success();
 }
 
+static LogicalResult
+vectorizeUnpackOpPrecondition(tensor::UnPackOp unpackOp,
+                              ArrayRef<int64_t> inputVectorSizes) {
+  return success();
+}
+
 static LogicalResult
 vectorizePadOpPrecondition(tensor::PadOp padOp,
                            ArrayRef<int64_t> inputVectorSizes) {
@@ -1801,31 +1889,32 @@ LogicalResult mlir::linalg::vectorizeOpPrecondition(
 
   return TypeSwitch<Operation *, LogicalResult>(op)
       .Case<linalg::LinalgOp>([&](auto linalgOp) {
-        return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
-                                             vectorizeNDExtract);
+    return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
+                                         vectorizeNDExtract);
       })
       .Case<tensor::PadOp>([&](auto padOp) {
-        return vectorizePadOpPrecondition(padOp, inputVectorSizes);
+    return vectorizePadOpPrecondition(padOp, inputVectorSizes);
       })
       .Case<tensor::PackOp>([&](auto packOp) {
-        return vectorizePackOpPrecondition(packOp, inputVectorSizes);
-      })
-      .Default([](auto) { return failure(); });
+    return vectorizePackOpPrecondition(packOp, inputVectorSizes);
+    .Case<tensor::UnPackOp>([&](auto unpackOp) {
+      return vectorizeUnpackOpPrecondition(unpackOp, inputVectorSizes);
+    }).Default([](auto) { return failure(); });
 }
 
 /// Converts affine.apply Ops to arithmetic operations.
 static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
-  OpBuilder::InsertionGuard g(rewriter);
-  auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();
-
-  for (auto op : make_early_inc_range(toReplace)) {
-    rewriter.setInsertionPoint(op);
-    auto expanded = affine::expandAffineExpr(
-        rewriter, op->getLoc(), op.getAffineMap().getResult(0),
-        op.getOperands().take_front(op.getAffineMap().getNumDims()),
-        op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
-    rewriter.replaceOp(op, expanded);
-  }
+    OpBuilder::InsertionGuard g(rewriter);
+    auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();
+
+    for (auto op : make_early_inc_range(toReplace)) {
+      rewriter.setInsertionPoint(op);
+      auto expanded = affine::expandAffineExpr(
+          rewriter, op->getLoc(), op.getAffineMap().getResult(0),
+          op.getOperands().take_front(op.getAffineMap().getNumDims()),
+          op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
+      rewriter.replaceOp(op, expanded);
+    }
 }
 
 /// Emit a suitable vector form for an operation. If provided,
@@ -1839,117 +1928,119 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
                                       ArrayRef<bool> inputScalableVecDims,
                                       bool vectorizeNDExtract,
                                       bool flatten1DDepthwiseConv) {
-  LDBG("Attempting to vectorize:\n" << *op << "\n");
-  LDBG("Input vector sizes: ");
-  LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
-  LLVM_DEBUG(llvm::dbgs() << "\n");
-  LDBG("Input scalable vector dims: ");
-  LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs()));
-  LLVM_DEBUG(llvm::dbgs() << "\n");
-
-  if (failed(vectorizeOpPrecondition(op, inputVectorSizes, inputScalableVecDims,
-                                     vectorizeNDExtract))) {
-    LDBG("Vectorization pre-conditions failed\n");
-    return failure();
-  }
-
-  // Initialize vectorization state.
-  VectorizationState state(rewriter);
-  if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
-    if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
-                               inputScalableVecDims))) {
-      LDBG("Vectorization state couldn't be initialized\n");
+    LDBG("Attempting to vectorize:\n" << *op << "\n");
+    LDBG("Input vector sizes: ");
+    LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
+    LLVM_DEBUG(llvm::dbgs() << "\n");
+    LDBG("Input scalable vector dims: ");
+    LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs()));
+    LLVM_DEBUG(llvm::dbgs() << "\n");
+
+    if (failed(vectorizeOpPrecondition(
+            op, inputVectorSizes, inputScalableVecDims, vectorizeNDExtract))) {
+      LDBG("Vectorization pre-conditions failed\n");
       return failure();
     }
-  }
 
-  SmallVector<Value> results;
-  auto vectorizeResult =
-      TypeSwitch<Operation *, LogicalResult>(op)
-          .Case<linalg::LinalgOp>([&](auto linalgOp) {
-            // TODO: isaConvolutionOpInterface that can also infer from generic
-            // features. Will require stride/dilation attributes inference.
-            if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
-              FailureOr<Operation *> convOr = vectorizeConvolution(
-                  rewriter, linalgOp, flatten1DDepthwiseConv);
-              if (succeeded(convOr)) {
-                llvm::append_range(results, (*convOr)->getResults());
-                return success();
-              }
-
-              LDBG("Unsupported convolution can't be vectorized.\n");
-              return failure();
-            }
-
-            LDBG("Vectorize generic by broadcasting to the canonical vector "
-                 "shape\n");
-
-            // Pre-process before proceeding.
-            convertAffineApply(rewriter, linalgOp);
-
-            // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted
-            // to 'OpBuilder' when it is passed over to some methods like
-            // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we
-            // erase an op within these methods, the actual rewriter won't be
-            // notified and we will end up with read-after-free issues!
-            return vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results);
-          })
-          .Case<tensor::PadOp>([&](auto padOp) {
-            return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
-                                          results);
-          })
-          .Case<tensor::PackOp>([&](auto packOp) {
-            return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
-                                           results);
-          })
-          .Default([](auto) { return failure(); });
-
-  if (failed(vectorizeResult)) {
-    LDBG("Vectorization failed\n");
-    return failure();
-  }
+    // Initialize vectorization state.
+    VectorizationState state(rewriter);
+    if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
+      if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
+                                 inputScalableVecDims))) {
+        LDBG("Vectorization state couldn't be initialized\n");
+        return failure();
+      }
+    }
 
-  if (!results.empty())
-    rewriter.replaceOp(op, results);
-  else
-    rewriter.eraseOp(op);
+    SmallVector<Value> results;
+    auto vectorizeResult =
+        TypeSwitch<Operation *, LogicalResult>(op)
+            .Case<linalg::LinalgOp>([&](auto linalgOp) {
+      // TODO: isaConvolutionOpInterface that can also infer from
+      // generic features. Will require stride/dilation attributes
+      // inference.
+      if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
+        FailureOr<Operation *> convOr =
+            vectorizeConvolution(rewriter, linalgOp, flatten1DDepthwiseConv);
+        if (succeeded(convOr)) {
+          llvm::append_range(results, (*convOr)->getResults());
+          return success();
+        }
 
-  return success();
+        LDBG("Unsupported convolution can't be vectorized.\n");
+        return failure();
+      }
+
+      LDBG("Vectorize generic by broadcasting to the canonical vector "
+           "shape\n");
+
+      // Pre-process before proceeding.
+      convertAffineApply(rewriter, linalgOp);
+
+      // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted
+      // to 'OpBuilder' when it is passed over to some methods like
+      // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we
+      // erase an op within these methods, the actual rewriter won't be
+      // notified and we will end up with read-after-free issues!
+      return vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results);
+            })
+            .Case<tensor::PadOp>([&](auto padOp) {
+      return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes, results);
+            })
+            .Case<tensor::PackOp>([&](auto packOp) {
+      return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
+                                     results);
+      .Case<tensor::UnPackOp>([&](auto unpackOp) {
+        return vectorizeAsUnpackOp(rewriter, unpackOp, inputVectorSizes,
+                                   results);
+      }).Default([](auto) { return failure(); });
+
+      if (failed(vectorizeResult)) {
+        LDBG("Vectorization failed\n");
+        return failure();
+      }
+
+      if (!results.empty())
+        rewriter.replaceOp(op, results);
+      else
+        rewriter.eraseOp(op);
+
+      return success();
 }
 
 LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
                                           memref::CopyOp copyOp) {
+      auto srcType = cast<MemRefType>(copyOp.getSource().getType());
+      auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
+      if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
+        return failure();
 
-  auto srcType = cast<MemRefType>(copyOp.getSource().getType());
-  auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
-  if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
-    return failure();
+      auto srcElementType = getElementTypeOrSelf(srcType);
+      auto dstElementType = getElementTypeOrSelf(dstType);
+      if (!VectorType::isValidElementType(srcElementType) ||
+          !VectorType::isValidElementType(dstElementType))
+        return failure();
 
-  auto srcElementType = getElementTypeOrSelf(srcType);
-  auto dstElementType = getElementTypeOrSelf(dstType);
-  if (!VectorType::isValidElementType(srcElementType) ||
-      !VectorType::isValidElementType(dstElementType))
-    return failure();
+      auto readType = VectorType::get(srcType.getShape(), srcElementType);
+      auto writeType = VectorType::get(dstType.getShape(), dstElementType);
 
-  auto readType = VectorType::get(srcType.getShape(), srcElementType);
-  auto writeType = VectorType::get(dstType.getShape(), dstElementType);
+      Location loc = copyOp->getLoc();
+      Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+      SmallVector<Value> indices(srcType.getRank(), zero);
 
-  Location loc = copyOp->getLoc();
-  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-  SmallVector<Value> indices(srcType.getRank(), zero);
-
-  Value readValue = rewriter.create<vector::TransferReadOp>(
-      loc, readType, copyOp.getSource(), indices,
-      rewriter.getMultiDimIdentityMap(srcType.getRank()));
-  if (cast<VectorType>(readValue.getType()).getRank() == 0) {
-    readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
-    readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
-  }
-  Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
-      loc, readValue, copyOp.getTarget(), indices,
-      rewriter.getMultiDimIdentityMap(srcType.getRank()));
-  rewriter.replaceOp(copyOp, writeValue->getResults());
-  return success();
+      Value readValue = rewriter.create<vector::TransferReadOp>(
+          loc, readType, copyOp.getSource(), indices,
+          rewriter.getMultiDimIdentityMap(srcType.getRank()));
+      if (cast<VectorType>(readValue.getType()).getRank() == 0) {
+        readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
+        readValue =
+            rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
+      }
+      Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
+          loc, readValue, copyOp.getTarget(), indices,
+          rewriter.getMultiDimIdentityMap(srcType.getRank()));
+      rewriter.replaceOp(copyOp, writeValue->getResults());
+      return success();
 }
 
 //----------------------------------------------------------------------------//
@@ -1958,7 +2049,7 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
 
 /// Helper function that retrieves the value of an IntegerAttr.
 static int64_t getIntFromAttr(Attribute attr) {
-  return cast<IntegerAttr>(attr).getInt();
+      return cast<IntegerAttr>(attr).getInt();
 }
 
 /// Given an ArrayRef of OpFoldResults, return a vector of Values.
@@ -1966,16 +2057,16 @@ static int64_t getIntFromAttr(Attribute attr) {
 /// not supported.
 static SmallVector<Value> ofrToIndexValues(RewriterBase &rewriter, Location loc,
                                            ArrayRef<OpFoldResult> ofrs) {
-  SmallVector<Value> result;
-  for (auto o : ofrs) {
-    if (auto val = llvm::dyn_cast_if_present<Value>(o)) {
-      result.push_back(val);
-    } else {
-      result.push_back(rewriter.create<arith::ConstantIndexOp>(
-          loc, getIntFromAttr(o.template get<Attribute>())));
-    }
-  }
-  return result;
+      SmallVector<Value> result;
+      for (auto o : ofrs) {
+        if (auto val = llvm::dyn_cast_if_present<Value>(o)) {
+          result.push_back(val);
+        } else {
+          result.push_back(rewriter.create<arith::ConstantIndexOp>(
+              loc, getIntFromAttr(o.template get<Attribute>())));
+        }
+      }
+      return result;
 }
 
 /// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and
@@ -2050,7 +2141,8 @@ struct GenericPadOpVectorizationPattern : public GeneralizePadOpPattern {
     // If `dest` is a FillOp and the TransferWriteOp would overwrite the
     // entire tensor, write directly to the FillOp's operand.
     if (llvm::equal(vecShape, resultType.getShape()) &&
-        llvm::all_of(writeInBounds, [](bool b) { return b; }))
+        llvm::all_of(writeInBounds, [](bool b) {
+      return b; }))
       if (auto fill = dest.getDefiningOp<FillOp>())
         dest = fill.output();
 
@@ -2061,7 +2153,7 @@ struct GenericPadOpVectorizationPattern : public GeneralizePadOpPattern {
         padOp, read, dest, writeIndices, ArrayRef<bool>{writeInBounds});
 
     return success();
-  }
+}
 };
 
 /// Base pattern for rewriting tensor::PadOps whose result is consumed by a

>From 853a735ad233ad24c30a36ba7a1a870ced8eb947 Mon Sep 17 00:00:00 2001
From: "Balaji V. Iyer" <bviyer at gmail.com>
Date: Wed, 20 Dec 2023 11:45:20 -0600
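Subject: [PATCH 02/12] Enable tensor.unpack vectorization and add a test
 case

Allow `transform.structured.vectorize` to target `tensor.unpack` ops and add a
static-shape `tensor.unpack` test to vectorization.mlir.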

---
 .../TransformOps/LinalgTransformOps.cpp       |  3 ++-
 mlir/test/Dialect/Linalg/vectorization.mlir   | 25 +++++++++++++++++++
 2 files changed, 27 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 585fd14b40d764..12e3f1a5d0d31e 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3152,7 +3152,8 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
 
   // TODO: Check that the correct number of vectorSizes was provided.
   for (Operation *target : targets) {
-    if (!isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp>(target)) {
+    if (!isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp, tensor::UnPackOp>(
+            target)) {
       return mlir::emitSilenceableFailure(target->getLoc())
              << "Unsupported Op, cannot vectorize";
     }
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 5d1bef478ee987..50d872c95128ba 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -419,6 +419,31 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// CHECK-LABEL: func @test_vectorize_unpack
+func.func @test_vectorize_unpack(%0 : tensor<7x1136x16x16xf32>) -> tensor<100x18176xf32>  {
+    // CHECK %[[c0:.*]] = arith.constant 0 : index
+    // CHECK: %[[tr0:.*]] = vector.mask %[[m0:.*]] {{.*}} vector.transfer_read %{{.*}} : tensor<7x1136x16x16xf32>, vector<7x1136x16x16xf32> } : vector<7x1136x16x16xi1> -> vector<7x1136x16x16xf32>
+    // CHECK: %[[trans0:.*]] = vector.transpose %[[tr0]], [0, 2, 1, 3] : vector<7x1136x16x16xf32> to vector<7x16x1136x16xf32>
+    // CHECK: %[[sc0:.*]] = vector.shape_cast %[[trans0]]  : vector<7x16x1136x16xf32> to vector<112x18176xf32>
+    // CHECK: %[[empt0:.*]] = tensor.empty() : tensor<100x18176xf32>
+    // CHECK: %[[tw0:.*]] = vector.transfer_write %[[sc0]], %[[empt0]]
+    // CHECK: return %[[tw0]]
+    %8 = tensor.empty() : tensor<100x18176xf32>
+    %unpack = tensor.unpack %0 inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %8 : tensor<7x1136x16x16xf32> -> tensor<100x18176xf32>
+    return %unpack : tensor<100x18176xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.unpack"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [2, 4] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 // CHECK-LABEL: func @test_masked_vectorize_pad
 func.func @test_masked_vectorize_pad(
   %0 : tensor<?x?xf32>, %h0 : index, %h1 : index)

>From a48dfac0d89493134e75c36ea307c64d0c941875 Mon Sep 17 00:00:00 2001
From: "Balaji V. Iyer" <bviyer at gmail.com>
Date: Fri, 19 Jan 2024 19:11:04 -0600
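Subject: [PATCH 03/12] Address review feedback from Diego and HanHan

Turn the comment on `vectorizeAsUnpackOp` into a `///` doc comment and remove
the placeholder `vectorizeUnpackOpPrecondition`, which always returned success.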

---
 .../Dialect/Linalg/Transforms/Vectorization.cpp  | 16 +++++-----------
 1 file changed, 5 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 0760ad114b07b1..f0b9da7aca4171 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1400,11 +1400,11 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
 
   return success();
 }
-// Vectorize an `tensor::UnPackOp` without OuterDimsPerms to these 4 Ops:
-//   Vector::TransferReadOp - Reads the Vector Array of Source data
-//   vector::TransposeOp - Transpose the Source
-//   ShapeCastOp - Reshapes the data based on the target.
-//   vector::TransferWriteOp. - Write the result vector back.
+/// Vectorize an `tensor::UnPackOp` without OuterDimsPerms to these 4 Ops:
+///   Vector::TransferReadOp - Reads the Vector Array of Source data
+///   vector::TransposeOp - Transpose the Source
+///   ShapeCastOp - Reshapes the data based on the target.
+///   vector::TransferWriteOp. - Write the result vector back.
 
 static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
                                          tensor::UnPackOp unpackOp,
@@ -1830,12 +1830,6 @@ vectorizePackOpPrecondition(tensor::PackOp packOp,
   return success();
 }
 
-static LogicalResult
-vectorizeUnpackOpPrecondition(tensor::UnPackOp unpackOp,
-                              ArrayRef<int64_t> inputVectorSizes) {
-  return success();
-}
-
 static LogicalResult
 vectorizePadOpPrecondition(tensor::PadOp padOp,
                            ArrayRef<int64_t> inputVectorSizes) {

>From 70cc122d0f8b77668f51592f6fa7fd78534b1f16 Mon Sep 17 00:00:00 2001
From: "Balaji V. Iyer" <bviyer at gmail.com>
Date: Tue, 6 Feb 2024 23:45:59 +0000
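Subject: [PATCH 04/12] Use the input vector sizes for the masks and add a
 dynamic-shape test case

The read mask now takes its leading dimensions from the input vector sizes.
The write is masked when the vector sizes differ from the result shape. A
precondition helper that requires constant inner tiles is also added.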

---
 .../Linalg/Transforms/Vectorization.cpp       | 109 +++++++++++++-----
 mlir/test/Dialect/Linalg/vectorization.mlir   |  33 +++++-
 2 files changed, 111 insertions(+), 31 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index f0b9da7aca4171..866b4e8774f5e2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1400,17 +1400,18 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
 
   return success();
 }
+
 /// Vectorize an `tensor::UnPackOp` without OuterDimsPerms to these 4 Ops:
 ///   Vector::TransferReadOp - Reads the Vector Array of Source data
 ///   vector::TransposeOp - Transpose the Source
 ///   ShapeCastOp - Reshapes the data based on the target.
 ///   vector::TransferWriteOp. - Write the result vector back.
-
 static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
                                          tensor::UnPackOp unpackOp,
                                          ArrayRef<int64_t> inputVectorSizes,
                                          SmallVectorImpl<Value> &newResults) {
-
+  // Handling this case requires a bit more change. Right now
+  // just the required attributes are handled.
   if (!unpackOp.getOuterDimsPerm().empty()) {
     LDBG("outer dimensions perms NYI for: " << unpackOp);
     return failure();
@@ -1419,11 +1420,19 @@ static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(unpackOp);
 
-  RankedTensorType packTensorType = unpackOp.getSourceType();
-  auto maskType =
-      VectorType::get(packTensorType.getShape(), rewriter.getI1Type());
-  auto vectorType = VectorType::get(packTensorType.getShape(),
-                                    packTensorType.getElementType());
+  RankedTensorType unpackTensorType = unpackOp.getSourceType();
+  llvm::SmallVector<int64_t> readMaskShape(unpackTensorType.getShape());
+  for (unsigned int ii = 0; ii < inputVectorSizes.size(); ii++) {
+    readMaskShape[ii] = inputVectorSizes[ii];
+  }
+
+  // ReadMask is the size of tensor used to read and apply mask. It is
+  // set like this. Let's say the vectorSize (VS) array is size 'N' and
+  // the sourceShape(SS) is 'M' where M >= N
+  // Thus:
+  // ReadMaskShape = [VS[0], ..., VS[N-1], SS[N], ..., SS[M-1]]
+  auto vectorType =
+      VectorType::get(readMaskShape, unpackTensorType.getElementType());
   ReifiedRankedShapedTypeDims reifiedRetShapes;
   LogicalResult status =
       cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
@@ -1432,54 +1441,87 @@ static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
     LDBG("Unable to reify result shapes of " << unpackOp);
     return failure();
   }
-
+  int64_t unpackRank = unpackTensorType.getRank();
   arith::ConstantIndexOp zeroOp =
       rewriter.create<arith::ConstantIndexOp>(unpackOp->getLoc(), 0);
-  Value mask = rewriter.create<vector::CreateMaskOp>(
-      unpackOp.getLoc(), maskType,
-      tensor::getMixedSizes(rewriter, unpackOp.getLoc(), unpackOp.getSource()));
 
   vector::TransferReadOp readOp = rewriter.create<vector::TransferReadOp>(
       unpackOp.getLoc(), vectorType, unpackOp.getSource(),
-      SmallVector<Value>(packTensorType.getRank(), zeroOp),
-      rewriter.getMultiDimIdentityMap(packTensorType.getRank()));
+      SmallVector<Value>(unpackRank, zeroOp),
+      rewriter.getMultiDimIdentityMap(unpackRank));
 
+  auto readMaskType = VectorType::get(readMaskShape, rewriter.getI1Type());
+  Value mask = rewriter.create<vector::CreateMaskOp>(
+      unpackOp.getLoc(), readMaskType,
+      tensor::getMixedSizes(rewriter, unpackOp.getLoc(), unpackOp.getSource()));
   vector::MaskOp maskedOp =
       cast<vector::MaskOp>(mlir::vector::maskOperation(rewriter, readOp, mask));
 
   int64_t numPackedDim = unpackOp.getInnerDimsPos().size();
-  int64_t packRank = packTensorType.getRank();
-  auto lastDims =
-      llvm::to_vector(llvm::seq<int64_t>(packRank - numPackedDim, packRank));
+  llvm::SmallVector<int64_t> lastDims = llvm::to_vector(
+      llvm::seq<int64_t>(unpackRank - numPackedDim, unpackRank));
   PackingMetadata packMetadata =
-      computePackingMetadata(packRank, unpackOp.getInnerDimsPos());
+      computePackingMetadata(unpackRank, unpackOp.getInnerDimsPos());
   SmallVector<int64_t> lastDimToInsertPosPerm = computePermutationVector(
-      packRank, lastDims, packMetadata.insertPositions);
-  SmallVector<int64_t> stripMineShape(packTensorType.getShape());
+      unpackRank, lastDims, packMetadata.insertPositions);
+  ShapedType maskedOpShapedType =
+      cast<ShapedType>(maskedOp.getResult(0).getType());
+  SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
+  mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
   applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
 
   RankedTensorType stripMineTensorType =
-      RankedTensorType::Builder(packTensorType).setShape(stripMineShape);
+      RankedTensorType::Builder(stripMineShape, stripMineElemType, {})
+          .setShape(stripMineShape);
 
+  // Collapse the tensor to the size required by result.
   RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
       stripMineTensorType, packMetadata.reassociations);
   auto vecCollapsedType =
       VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
 
+  // Transpose the appropriate rows to match output.
   vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
       unpackOp.getLoc(), maskedOp.getResult(0), lastDimToInsertPosPerm);
 
   vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
       unpackOp.getLoc(), vecCollapsedType, transposeOp->getResult(0));
-  tensor::EmptyOp emptyOp = rewriter.create<tensor::EmptyOp>(
-      unpackOp.getLoc(), reifiedRetShapes[0], packTensorType.getElementType());
+  tensor::EmptyOp emptyOp =
+      rewriter.create<tensor::EmptyOp>(unpackOp.getLoc(), reifiedRetShapes[0],
+                                       unpackTensorType.getElementType());
 
-  vector::TransferWriteOp writeOp = rewriter.create<vector::TransferWriteOp>(
+  int64_t destRank = cast<ShapedType>(emptyOp.getType()).getRank();
+  Operation *writeOp = rewriter.create<vector::TransferWriteOp>(
       unpackOp.getLoc(), shapeCastOp->getResult(0), emptyOp,
-      SmallVector<Value>(lastDims.size(), zeroOp),
-      SmallVector<bool>(lastDims.size(), true));
-
-  newResults.push_back(writeOp->getResult(0));
+      SmallVector<Value>(destRank, zeroOp), SmallVector<bool>(destRank, true));
+  auto resultShape = unpackOp.getResult().getType().getShape();
+
+  // If the shape of the result doesn't match the inputVectorSizes, a mask
+  // is necessary.
+  bool needMaskForWrite =
+      llvm::any_of(llvm::zip_equal(inputVectorSizes, resultShape),
+                   [](auto it) { return std::get<0>(it) != std::get<1>(it); });
+  mlir::OpResult result = writeOp->getResult(0);
+  if (needMaskForWrite) {
+    SmallVector<int64_t> writeMaskShape(inputVectorSizes);
+    llvm::ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
+    llvm::ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
+    for (auto [index, size] : enumerate(innerTiles)) {
+      writeMaskShape[innerDimPos[index]] *= size;
+    }
+    // WriteMaskShape is computed using the vectorSizes, inner Dim Position and
+    // innerTiles.
+    // WriteMaskShape (WMS) initialized to [inputVectorSizes]
+    // for-each index, value in inner-Tiles vector:
+    //      WMS[innerDimPos[index]] = WMS[innerDimPos[index]] * value
+    auto writeMaskType = VectorType::get(writeMaskShape, rewriter.getI1Type());
+    Value writeMask = rewriter.create<vector::CreateMaskOp>(
+        unpackOp.getLoc(), writeMaskType, reifiedRetShapes[0]);
+    Operation *writeOpWithMask =
+        mlir::vector::maskOperation(rewriter, writeOp, writeMask);
+    result = writeOpWithMask->getResult(0);
+  }
+  newResults.push_back(result);
   return success();
 }
 
@@ -1737,6 +1779,19 @@ isValidMaskedInputVector(ArrayRef<int64_t> shape,
   return success();
 }
 
+/// Need to check if the inner-tiles are static/constant.
+static LogicalResult
+vectorizeUnPackOpPreCondition(tensor::UnPackOp unpackOp,
+                              ArrayRef<int64_t> inputVectorSizes) {
+  if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
+        return !getConstantIntValue(res).has_value();
+      })) {
+    LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
+    return failure();
+  }
+  return success();
+}
+
 static LogicalResult
 vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
                               ArrayRef<int64_t> inputVectorSizes,
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 50d872c95128ba..a79ff6bd75795c 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -422,11 +422,13 @@ module attributes {transform.with_named_sequence} {
 // CHECK-LABEL: func @test_vectorize_unpack
 func.func @test_vectorize_unpack(%0 : tensor<7x1136x16x16xf32>) -> tensor<100x18176xf32>  {
     // CHECK %[[c0:.*]] = arith.constant 0 : index
-    // CHECK: %[[tr0:.*]] = vector.mask %[[m0:.*]] {{.*}} vector.transfer_read %{{.*}} : tensor<7x1136x16x16xf32>, vector<7x1136x16x16xf32> } : vector<7x1136x16x16xi1> -> vector<7x1136x16x16xf32>
-    // CHECK: %[[trans0:.*]] = vector.transpose %[[tr0]], [0, 2, 1, 3] : vector<7x1136x16x16xf32> to vector<7x16x1136x16xf32>
-    // CHECK: %[[sc0:.*]] = vector.shape_cast %[[trans0]]  : vector<7x16x1136x16xf32> to vector<112x18176xf32>
+    // CHECK: %[[m0:.*]] = vector.create_mask %c7, %c1136, %c16, %c16_0 : vector<2x4x16x16xi1>
+    // CHECK: %[[tr0:.*]] = vector.mask %[[m0]] {{.*}} vector.transfer_read %{{.*}} : tensor<7x1136x16x16xf32>, vector<2x4x16x16xf32> } : vector<2x4x16x16xi1> -> vector<2x4x16x16xf32>
+    // CHECK: %[[trans0:.*]] = vector.transpose %[[tr0]], [0, 2, 1, 3] : vector<2x4x16x16xf32> to vector<2x16x4x16xf32>
+    // CHECK: %[[sc0:.*]] = vector.shape_cast %[[trans0]]  : vector<2x16x4x16xf32> to vector<32x64xf32>
     // CHECK: %[[empt0:.*]] = tensor.empty() : tensor<100x18176xf32>
-    // CHECK: %[[tw0:.*]] = vector.transfer_write %[[sc0]], %[[empt0]]
+    // CHECK: %[[mask0:.*]] = vector.create_mask %c100, %c18176 : vector<32x64xi1>
+    // CHECK: %[[tw0:.*]] = vector.mask %[[mask0]] {{.*}} vector.transfer_write %[[sc0]], %[[empt0]]
     // CHECK: return %[[tw0]]
     %8 = tensor.empty() : tensor<100x18176xf32>
     %unpack = tensor.unpack %0 inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %8 : tensor<7x1136x16x16xf32> -> tensor<100x18176xf32>
@@ -444,6 +446,29 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// CHECK-LABEL: func @test_vectorize_dynamic_shapes_unpack
+func.func @test_vectorize_dynamic_shapes_unpack(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?x16x2xf32>) -> tensor<?x?xf32> {
+ // CHECK: %[[readMsk0:.*]] = vector.create_mask %dim_3, %dim_5, %c16, %c2 : vector<4x1x16x2xi1>
+ // CHECK: %[[read0:.*]] = vector.mask %[[readMsk0]] {{.*}} vector.transfer_read %{{.*}} : tensor<?x?x16x2xf32>, vector<4x1x16x2xf32> } : vector<4x1x16x2xi1> -> vector<4x1x16x2xf32>
+ // CHECK: %[[trans0:.*]] = vector.transpose %[[read0]], [0, 3, 1, 2] : vector<4x1x16x2xf32> to vector<4x2x1x16xf32>
+ // CHECK: %[[sc0:.*]] = vector.shape_cast %[[trans0]] : vector<4x2x1x16xf32> to vector<8x16xf32>
+ // CHECK: %[[empt0:.*]] = tensor.empty
+ // CHECK: %[[writeMsk0:.*]] = vector.create_mask {{.*}} : vector<8x16xi1>
+ // CHECK: %[[write0:.*]] = vector.mask %[[writeMsk0:.*]] {{.*}} vector.transfer_write %[[sc0]], %[[empt0]]
+ // CHECK: return %[[write0]]
+ %ret = tensor.unpack %arg1 inner_dims_pos = [1, 0] inner_tiles = [16, 2] into %arg0 : tensor<?x?x16x2xf32> -> tensor<?x?xf32>
+ return %ret : tensor<?x?xf32>
+}
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+   %0 = transform.structured.match ops{["tensor.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+   transform.structured.vectorize %0 vector_sizes [4, 1] : !transform.any_op
+   transform.yield
+ }
+}
+
+// -----
+
 // CHECK-LABEL: func @test_masked_vectorize_pad
 func.func @test_masked_vectorize_pad(
   %0 : tensor<?x?xf32>, %h0 : index, %h1 : index)

>From c33642b2da876026e13968199adbbba6f6bb2432 Mon Sep 17 00:00:00 2001
From: "Balaji V. Iyer" <bviyer at gmail.com>
Date: Wed, 7 Feb 2024 19:43:48 +0000
Subject: [PATCH 05/12] Added some changes proposed by HanHan.

---
 .../Linalg/Transforms/Vectorization.cpp       | 52 +++++++++++--------
 1 file changed, 29 insertions(+), 23 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 866b4e8774f5e2..4d0c62fbd46cfa 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1410,12 +1410,6 @@ static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
                                          tensor::UnPackOp unpackOp,
                                          ArrayRef<int64_t> inputVectorSizes,
                                          SmallVectorImpl<Value> &newResults) {
-  // Handling this case requires a bit more change. Right now
-  // just the required attributes are handled.
-  if (!unpackOp.getOuterDimsPerm().empty()) {
-    LDBG("outer dimensions perms NYI for: " << unpackOp);
-    return failure();
-  }
 
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(unpackOp);
@@ -1442,18 +1436,19 @@ static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
     return failure();
   }
   int64_t unpackRank = unpackTensorType.getRank();
+  Location loc = unpackOp->getLoc();
   arith::ConstantIndexOp zeroOp =
-      rewriter.create<arith::ConstantIndexOp>(unpackOp->getLoc(), 0);
+      rewriter.create<arith::ConstantIndexOp>(loc, 0);
 
   vector::TransferReadOp readOp = rewriter.create<vector::TransferReadOp>(
-      unpackOp.getLoc(), vectorType, unpackOp.getSource(),
+      loc, vectorType, unpackOp.getSource(),
       SmallVector<Value>(unpackRank, zeroOp),
       rewriter.getMultiDimIdentityMap(unpackRank));
 
   auto readMaskType = VectorType::get(readMaskShape, rewriter.getI1Type());
   Value mask = rewriter.create<vector::CreateMaskOp>(
-      unpackOp.getLoc(), readMaskType,
-      tensor::getMixedSizes(rewriter, unpackOp.getLoc(), unpackOp.getSource()));
+      loc, readMaskType,
+      tensor::getMixedSizes(rewriter, loc, unpackOp.getSource()));
   vector::MaskOp maskedOp =
       cast<vector::MaskOp>(mlir::vector::maskOperation(rewriter, readOp, mask));
 
@@ -1474,25 +1469,23 @@ static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
       RankedTensorType::Builder(stripMineShape, stripMineElemType, {})
           .setShape(stripMineShape);
 
-  // Collapse the tensor to the size required by result.
-  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
-      stripMineTensorType, packMetadata.reassociations);
-  auto vecCollapsedType =
-      VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
-
   // Transpose the appropriate rows to match output.
   vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
-      unpackOp.getLoc(), maskedOp.getResult(0), lastDimToInsertPosPerm);
+      loc, maskedOp.getResult(0), lastDimToInsertPosPerm);
 
+  // Collapse the vector to the size required by result.
+  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
+      stripMineTensorType, packMetadata.reassociations);
+  mlir::VectorType vecCollapsedType =
+      VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
   vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
-      unpackOp.getLoc(), vecCollapsedType, transposeOp->getResult(0));
-  tensor::EmptyOp emptyOp =
-      rewriter.create<tensor::EmptyOp>(unpackOp.getLoc(), reifiedRetShapes[0],
-                                       unpackTensorType.getElementType());
+      loc, vecCollapsedType, transposeOp->getResult(0));
+  tensor::EmptyOp emptyOp = rewriter.create<tensor::EmptyOp>(
+      loc, reifiedRetShapes[0], unpackTensorType.getElementType());
 
   int64_t destRank = cast<ShapedType>(emptyOp.getType()).getRank();
   Operation *writeOp = rewriter.create<vector::TransferWriteOp>(
-      unpackOp.getLoc(), shapeCastOp->getResult(0), emptyOp,
+      loc, shapeCastOp->getResult(0), emptyOp,
       SmallVector<Value>(destRank, zeroOp), SmallVector<bool>(destRank, true));
   auto resultShape = unpackOp.getResult().getType().getShape();
 
@@ -1516,7 +1509,7 @@ static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
     //      WMS[innerDimPos[index]] = WMS[innerDimPos[index]] * value
     auto writeMaskType = VectorType::get(writeMaskShape, rewriter.getI1Type());
     Value writeMask = rewriter.create<vector::CreateMaskOp>(
-        unpackOp.getLoc(), writeMaskType, reifiedRetShapes[0]);
+        loc, writeMaskType, reifiedRetShapes[0]);
     Operation *writeOpWithMask =
         mlir::vector::maskOperation(rewriter, writeOp, writeMask);
     result = writeOpWithMask->getResult(0);
@@ -1783,12 +1776,25 @@ isValidMaskedInputVector(ArrayRef<int64_t> shape,
 static LogicalResult
 vectorizeUnPackOpPreCondition(tensor::UnPackOp unpackOp,
                               ArrayRef<int64_t> inputVectorSizes) {
+
+  // Handling this case requires a bit more change. Right now
+  // just the required attributes are handled.
+  if (!unpackOp.getOuterDimsPerm().empty()) {
+    LDBG("outer dimensions perms NYI for: " << unpackOp);
+    return failure();
+  }
+
   if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
         return !getConstantIntValue(res).has_value();
       })) {
     LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
     return failure();
   }
+  llvm::ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
+  if (inputVectorSizes.empty() == false &&
+      failed(isValidMaskedInputVector(resultShape, inputVectorSizes)))
+    return failure();
+
   return success();
 }
 

>From 744a291b346a3f5bf36f8c744b4d28e152ca4f5e Mon Sep 17 00:00:00 2001
From: "Balaji V. Iyer" <bviyer at gmail.com>
Date: Fri, 9 Feb 2024 17:33:25 +0000
Subject: [PATCH 06/12] Fixed all issues pointed out by HanHan except factoring
 in StripMineTensorType

---
 .../include/mlir/Dialect/Tensor/Utils/Utils.h |   3 +-
 .../Dialect/Linalg/Transforms/Transforms.cpp  |   2 +-
 .../Linalg/Transforms/Vectorization.cpp       | 544 +++++++++---------
 mlir/lib/Dialect/Tensor/Utils/Utils.cpp       |  27 +-
 mlir/test/Dialect/Linalg/vectorization.mlir   |  85 ++-
 5 files changed, 322 insertions(+), 339 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
index fe9b16cb44b3da..60522ac48d95b5 100644
--- a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
@@ -38,7 +38,8 @@ computeTransposedType(RankedTensorType rankedTensorType,
 /// i.e. for a pack from an ABCD layout to an ABCDba:
 /// The packed shape would be ABCDba.
 /// The pre-permutation shape would be AaBbCD.
-SmallVector<int64_t> getPackInverseDestPermutation(PackOp packOp);
+SmallVector<int64_t> getPackUnPackInverseDestPerm(
+    std::variant<tensor::PackOp, tensor::UnPackOp> packOp);
 
 /// A tensor.insert_slice is a cast-like operation if it merely rank-extends the
 /// source tensor or inserts the source tensor into a destination tensor with
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 596b7c50c1e4e4..9f8ea7f1f3969b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -237,7 +237,7 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
   PackingMetadata packingMetadata = computePackingMetadata(
       packedTensorType.getRank(), packOp.getInnerDimsPos());
   SmallVector<int64_t> packedToStripMinedShapePerm =
-      tensor::getPackInverseDestPermutation(packOp);
+      tensor::getPackUnPackInverseDestPerm(packOp);
 
   // 3. Compute the stripMinedShape: this is the packed shape before any outer
   // or inner permutations have been applied.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 4d0c62fbd46cfa..420ffe533ff0b3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1401,129 +1401,12 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
   return success();
 }
 
-/// Vectorize an `tensor::UnPackOp` without OuterDimsPerms to these 4 Ops:
-///   Vector::TransferReadOp - Reads the Vector Array of Source data
-///   vector::TransposeOp - Transpose the Source
-///   ShapeCastOp - Reshapes the data based on the target.
-///   vector::TransferWriteOp. - Write the result vector back.
-static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
-                                         tensor::UnPackOp unpackOp,
-                                         ArrayRef<int64_t> inputVectorSizes,
-                                         SmallVectorImpl<Value> &newResults) {
-
-  OpBuilder::InsertionGuard g(rewriter);
-  rewriter.setInsertionPoint(unpackOp);
-
-  RankedTensorType unpackTensorType = unpackOp.getSourceType();
-  llvm::SmallVector<int64_t> readMaskShape(unpackTensorType.getShape());
-  for (unsigned int ii = 0; ii < inputVectorSizes.size(); ii++) {
-    readMaskShape[ii] = inputVectorSizes[ii];
-  }
-
-  // ReadMask is the size of tensor used to read and apply mask. It is
-  // set like this. Let's say the vectorSize (VS) array is size 'N' and
-  // the sourceShape(SS) is 'M' where M >= N
-  // Thus:
-  // ReadMaskShape = [VS[0], ..., VS[N-1], SS[N], ..., SS[M-1]]
-  auto vectorType =
-      VectorType::get(readMaskShape, unpackTensorType.getElementType());
-  ReifiedRankedShapedTypeDims reifiedRetShapes;
-  LogicalResult status =
-      cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
-          .reifyResultShapes(rewriter, reifiedRetShapes);
-  if (status.failed()) {
-    LDBG("Unable to reify result shapes of " << unpackOp);
-    return failure();
-  }
-  int64_t unpackRank = unpackTensorType.getRank();
-  Location loc = unpackOp->getLoc();
-  arith::ConstantIndexOp zeroOp =
-      rewriter.create<arith::ConstantIndexOp>(loc, 0);
-
-  vector::TransferReadOp readOp = rewriter.create<vector::TransferReadOp>(
-      loc, vectorType, unpackOp.getSource(),
-      SmallVector<Value>(unpackRank, zeroOp),
-      rewriter.getMultiDimIdentityMap(unpackRank));
-
-  auto readMaskType = VectorType::get(readMaskShape, rewriter.getI1Type());
-  Value mask = rewriter.create<vector::CreateMaskOp>(
-      loc, readMaskType,
-      tensor::getMixedSizes(rewriter, loc, unpackOp.getSource()));
-  vector::MaskOp maskedOp =
-      cast<vector::MaskOp>(mlir::vector::maskOperation(rewriter, readOp, mask));
-
-  int64_t numPackedDim = unpackOp.getInnerDimsPos().size();
-  llvm::SmallVector<int64_t> lastDims = llvm::to_vector(
-      llvm::seq<int64_t>(unpackRank - numPackedDim, unpackRank));
-  PackingMetadata packMetadata =
-      computePackingMetadata(unpackRank, unpackOp.getInnerDimsPos());
-  SmallVector<int64_t> lastDimToInsertPosPerm = computePermutationVector(
-      unpackRank, lastDims, packMetadata.insertPositions);
-  ShapedType maskedOpShapedType =
-      cast<ShapedType>(maskedOp.getResult(0).getType());
-  SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
-  mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
-  applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
-
-  RankedTensorType stripMineTensorType =
-      RankedTensorType::Builder(stripMineShape, stripMineElemType, {})
-          .setShape(stripMineShape);
-
-  // Transpose the appropriate rows to match output.
-  vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
-      loc, maskedOp.getResult(0), lastDimToInsertPosPerm);
-
-  // Collapse the vector to the size required by result.
-  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
-      stripMineTensorType, packMetadata.reassociations);
-  mlir::VectorType vecCollapsedType =
-      VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
-  vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
-      loc, vecCollapsedType, transposeOp->getResult(0));
-  tensor::EmptyOp emptyOp = rewriter.create<tensor::EmptyOp>(
-      loc, reifiedRetShapes[0], unpackTensorType.getElementType());
-
-  int64_t destRank = cast<ShapedType>(emptyOp.getType()).getRank();
-  Operation *writeOp = rewriter.create<vector::TransferWriteOp>(
-      loc, shapeCastOp->getResult(0), emptyOp,
-      SmallVector<Value>(destRank, zeroOp), SmallVector<bool>(destRank, true));
-  auto resultShape = unpackOp.getResult().getType().getShape();
-
-  // If the shape of the result doesn't match the inputVectorSizes, a mask
-  // is necessary.
-  bool needMaskForWrite =
-      llvm::any_of(llvm::zip_equal(inputVectorSizes, resultShape),
-                   [](auto it) { return std::get<0>(it) != std::get<1>(it); });
-  mlir::OpResult result = writeOp->getResult(0);
-  if (needMaskForWrite) {
-    SmallVector<int64_t> writeMaskShape(inputVectorSizes);
-    llvm::ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
-    llvm::ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
-    for (auto [index, size] : enumerate(innerTiles)) {
-      writeMaskShape[innerDimPos[index]] *= size;
-    }
-    // WriteMaskShape is computed using the vectorSizes, inner Dim Position and
-    // innerTiles.
-    // WriteMaskShape (WMS) initialized to [inputVectorSizes]
-    // for-each index, value in inner-Tiles vector:
-    //      WMS[innerDimPos[index]] = WMS[innerDimPos[index]] * value
-    auto writeMaskType = VectorType::get(writeMaskShape, rewriter.getI1Type());
-    Value writeMask = rewriter.create<vector::CreateMaskOp>(
-        loc, writeMaskType, reifiedRetShapes[0]);
-    Operation *writeOpWithMask =
-        mlir::vector::maskOperation(rewriter, writeOp, writeMask);
-    result = writeOpWithMask->getResult(0);
-  }
-  newResults.push_back(result);
-  return success();
-}
-
 /// Given a tensor::PackOp, return the `dest` shape before any packing
 /// permutations.
 static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
                                               ArrayRef<int64_t> destShape) {
   return applyPermutation(destShape,
-                          tensor::getPackInverseDestPermutation(packOp));
+                          tensor::getPackUnPackInverseDestPerm(packOp));
 }
 
 /// Create a TransferReadOp from `source` with static shape `readShape`. If the
@@ -1537,16 +1420,28 @@ static Value createReadOrMaskedRead(OpBuilder &builder, Location loc,
   auto sourceShape = dyn_cast<ShapedType>(source.getType()).getShape();
   assert(sourceShape.size() == readShape.size());
   auto maskType = VectorType::get(readShape, builder.getI1Type());
-  auto vectorType = VectorType::get(readShape, padValue.getType());
+  Type vecElemType = padValue != nullptr
+                         ? padValue.getType()
+                         : cast<ShapedType>(source.getType()).getElementType();
+  auto vectorType = VectorType::get(readShape, vecElemType);
   int64_t readRank = readShape.size();
   auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
-  auto transferReadOp = builder.create<vector::TransferReadOp>(
-      loc,
-      /*vectorType=*/vectorType,
-      /*source=*/source,
-      /*indices=*/SmallVector<Value>(readRank, zero),
-      /*padding=*/padValue,
-      /*inBounds=*/SmallVector<bool>(readRank, true));
+  vector::TransferReadOp transferReadOp = nullptr;
+  if (padValue == nullptr) {
+    transferReadOp = builder.create<vector::TransferReadOp>(
+        loc,
+        /*vectorType=*/vectorType,
+        /*source=*/source,
+        /*indices=*/SmallVector<Value>(readRank, zero));
+  } else {
+    transferReadOp = builder.create<vector::TransferReadOp>(
+        loc,
+        /*vectorType=*/vectorType,
+        /*source=*/source,
+        /*indices=*/SmallVector<Value>(readRank, zero),
+        /*padding=*/padValue,
+        /*inBounds=*/SmallVector<bool>(readRank, true));
+  }
   if (llvm::equal(readShape, sourceShape)) {
     return transferReadOp;
   }
@@ -1664,7 +1559,7 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
 
   // Create TransposeOp.
   auto destPermutation =
-      invertPermutationVector(tensor::getPackInverseDestPermutation(packOp));
+      invertPermutationVector(tensor::getPackUnPackInverseDestPerm(packOp));
   auto transposeOp = rewriter.create<vector::TransposeOp>(
       loc, shapeCastOp.getResult(), destPermutation);
 
@@ -1676,6 +1571,90 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
   return success();
 }
 
+/// Vectorize an `tensor::UnPackOp` without OuterDimsPerms to these 4 Ops:
+///   Vector::TransferReadOp - Reads the Vector Array of Source data
+///   vector::TransposeOp - Transpose the Source
+///   ShapeCastOp - Reshapes the data based on the target.
+///   vector::TransferWriteOp. - Write the result vector back.
+static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
+                                         tensor::UnPackOp unpackOp,
+                                         ArrayRef<int64_t> inputVectorSizes,
+                                         SmallVectorImpl<Value> &newResults) {
+
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(unpackOp);
+
+  RankedTensorType unpackTensorType = unpackOp.getSourceType();
+
+  SmallVector<int64_t> readMaskShape(unpackTensorType.getShape());
+  llvm::ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
+  llvm::ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
+  for (unsigned int i = 0; i < inputVectorSizes.size(); i++) {
+    readMaskShape[i] = inputVectorSizes[i];
+  }
+  for (auto [index, size] : enumerate(innerTiles)) {
+    readMaskShape[innerDimPos[index]] =
+        llvm::divideCeil(readMaskShape[innerDimPos[index]], size);
+  }
+
+  // ReadMask is the size of tensor used to read and apply mask. It is
+  // set like this. Let's say the vectorSize (VS) array is size 'N' and
+  // the sourceShape(SS) is 'M' where M >= N
+  // Thus:
+  // ReadMaskShape = [VS[0], ..., VS[N-1], SS[N], ..., SS[M-1]]
+  ReifiedRankedShapedTypeDims reifiedRetShapes;
+  LogicalResult status =
+      cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
+          .reifyResultShapes(rewriter, reifiedRetShapes);
+  if (status.failed()) {
+    LDBG("Unable to reify result shapes of " << unpackOp);
+    return failure();
+  }
+  int64_t unpackRank = unpackTensorType.getRank();
+  Location loc = unpackOp->getLoc();
+
+  Value readResult = createReadOrMaskedRead(
+      rewriter, loc, unpackOp.getSource(),
+      llvm::ArrayRef<int64_t>(readMaskShape.begin(), readMaskShape.end()),
+      nullptr);
+
+  int64_t numPackedDim = unpackOp.getInnerDimsPos().size();
+  llvm::SmallVector<int64_t> lastDims = llvm::to_vector(
+      llvm::seq<int64_t>(unpackRank - numPackedDim, unpackRank));
+  PackingMetadata packMetadata =
+      computePackingMetadata(unpackRank, unpackOp.getInnerDimsPos());
+  SmallVector<int64_t> lastDimToInsertPosPerm = computePermutationVector(
+      unpackRank, lastDims, packMetadata.insertPositions);
+  ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
+  SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
+  mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
+  applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
+
+  RankedTensorType stripMineTensorType =
+      RankedTensorType::Builder(stripMineShape, stripMineElemType, {})
+          .setShape(stripMineShape);
+
+  // Transpose the appropriate rows to match output.
+  vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
+      loc, readResult, lastDimToInsertPosPerm);
+
+  // Collapse the vector to the size required by result.
+  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
+      stripMineTensorType, packMetadata.reassociations);
+  mlir::VectorType vecCollapsedType =
+      VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
+  vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
+      loc, vecCollapsedType, transposeOp->getResult(0));
+
+  SmallVector<int64_t> writeMaskShape(
+      shapeCastOp.getResultVectorType().getShape());
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, shapeCastOp.getResult(),
+                               reifiedRetShapes[0], writeMaskShape);
+  newResults.push_back(write->getResult(0));
+  return success();
+}
+
 /// Vectorize a `padOp` with (1) static result type, (2) constant padding value
 /// and (3) all-zero lowPad to
 ///   `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
@@ -1774,11 +1753,12 @@ isValidMaskedInputVector(ArrayRef<int64_t> shape,
 
 /// Need to check if the inner-tiles are static/constant.
 static LogicalResult
-vectorizeUnPackOpPreCondition(tensor::UnPackOp unpackOp,
+vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
                               ArrayRef<int64_t> inputVectorSizes) {
 
   // Handling this case requires a bit more change. Right now
   // just the required attributes are handled.
+  // TODO: Handle OuterDimsPerm.
   if (!unpackOp.getOuterDimsPerm().empty()) {
     LDBG("outer dimensions perms NYI for: " << unpackOp);
     return failure();
@@ -1846,9 +1826,10 @@ vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
   }
   if (isElementwise(linalgOp))
     return success();
-  // TODO: isaConvolutionOpInterface that can also infer from generic features.
-  // But we will still need stride/dilation attributes that will be annoying to
-  // reverse-engineer...
+
+  // TODO: isaConvolutionOpInterface that can also infer from generic
+  // features. But we will still need stride/dilation attributes that will be
+  // annoying to reverse-engineer...
   if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
     return success();
   // TODO: the common vector shape is equal to the static loop sizes only when
@@ -1944,158 +1925,162 @@ LogicalResult mlir::linalg::vectorizeOpPrecondition(
 
   return TypeSwitch<Operation *, LogicalResult>(op)
       .Case<linalg::LinalgOp>([&](auto linalgOp) {
-    return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
-                                         vectorizeNDExtract);
+        return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
+                                             vectorizeNDExtract);
       })
       .Case<tensor::PadOp>([&](auto padOp) {
-    return vectorizePadOpPrecondition(padOp, inputVectorSizes);
+        return vectorizePadOpPrecondition(padOp, inputVectorSizes);
       })
       .Case<tensor::PackOp>([&](auto packOp) {
-    return vectorizePackOpPrecondition(packOp, inputVectorSizes);
-    .Case<tensor::UnPackOp>([&](auto unpackOp) {
-      return vectorizeUnpackOpPrecondition(unpackOp, inputVectorSizes);
-    }).Default([](auto) { return failure(); });
+        return vectorizePackOpPrecondition(packOp, inputVectorSizes);
+      })
+      .Case<tensor::UnPackOp>([&](auto unpackOp) {
+        return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
+      })
+      .Default([](auto) { return failure(); });
 }
 
 /// Converts affine.apply Ops to arithmetic operations.
 static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
-    OpBuilder::InsertionGuard g(rewriter);
-    auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();
-
-    for (auto op : make_early_inc_range(toReplace)) {
-      rewriter.setInsertionPoint(op);
-      auto expanded = affine::expandAffineExpr(
-          rewriter, op->getLoc(), op.getAffineMap().getResult(0),
-          op.getOperands().take_front(op.getAffineMap().getNumDims()),
-          op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
-      rewriter.replaceOp(op, expanded);
-    }
+  OpBuilder::InsertionGuard g(rewriter);
+  auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();
+
+  for (auto op : make_early_inc_range(toReplace)) {
+    rewriter.setInsertionPoint(op);
+    auto expanded = affine::expandAffineExpr(
+        rewriter, op->getLoc(), op.getAffineMap().getResult(0),
+        op.getOperands().take_front(op.getAffineMap().getNumDims()),
+        op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
+    rewriter.replaceOp(op, expanded);
+  }
 }
 
 /// Emit a suitable vector form for an operation. If provided,
-/// `inputVectorSizes` are used to vectorize this operation. `inputVectorSizes`
-/// must match the rank of the iteration space of the operation and the input
-/// vector sizes must be greater than or equal to their counterpart iteration
-/// space sizes, if static. `inputVectorShapes` also allows the vectorization of
-/// operations with dynamic shapes.
+/// `inputVectorSizes` are used to vectorize this operation.
+/// `inputVectorSizes` must match the rank of the iteration space of the
+/// operation and the input vector sizes must be greater than or equal to
+/// their counterpart iteration space sizes, if static. `inputVectorSizes`
+/// also allows the vectorization of operations with dynamic shapes.
 LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
                                       ArrayRef<int64_t> inputVectorSizes,
                                       ArrayRef<bool> inputScalableVecDims,
                                       bool vectorizeNDExtract,
                                       bool flatten1DDepthwiseConv) {
-    LDBG("Attempting to vectorize:\n" << *op << "\n");
-    LDBG("Input vector sizes: ");
-    LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
-    LLVM_DEBUG(llvm::dbgs() << "\n");
-    LDBG("Input scalable vector dims: ");
-    LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs()));
-    LLVM_DEBUG(llvm::dbgs() << "\n");
-
-    if (failed(vectorizeOpPrecondition(
-            op, inputVectorSizes, inputScalableVecDims, vectorizeNDExtract))) {
-      LDBG("Vectorization pre-conditions failed\n");
-      return failure();
-    }
-
-    // Initialize vectorization state.
-    VectorizationState state(rewriter);
-    if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
-      if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
-                                 inputScalableVecDims))) {
-        LDBG("Vectorization state couldn't be initialized\n");
-        return failure();
-      }
-    }
-
-    SmallVector<Value> results;
-    auto vectorizeResult =
-        TypeSwitch<Operation *, LogicalResult>(op)
-            .Case<linalg::LinalgOp>([&](auto linalgOp) {
-      // TODO: isaConvolutionOpInterface that can also infer from
-      // generic features. Will require stride/dilation attributes
-      // inference.
-      if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
-        FailureOr<Operation *> convOr =
-            vectorizeConvolution(rewriter, linalgOp, flatten1DDepthwiseConv);
-        if (succeeded(convOr)) {
-          llvm::append_range(results, (*convOr)->getResults());
-          return success();
-        }
+  LDBG("Attempting to vectorize:\n" << *op << "\n");
+  LDBG("Input vector sizes: ");
+  LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
+  LLVM_DEBUG(llvm::dbgs() << "\n");
+  LDBG("Input scalable vector dims: ");
+  LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs()));
+  LLVM_DEBUG(llvm::dbgs() << "\n");
 
-        LDBG("Unsupported convolution can't be vectorized.\n");
-        return failure();
-      }
+  if (failed(vectorizeOpPrecondition(op, inputVectorSizes, inputScalableVecDims,
+                                     vectorizeNDExtract))) {
+    LDBG("Vectorization pre-conditions failed\n");
+    return failure();
+  }
 
-      LDBG("Vectorize generic by broadcasting to the canonical vector "
-           "shape\n");
-
-      // Pre-process before proceeding.
-      convertAffineApply(rewriter, linalgOp);
-
-      // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted
-      // to 'OpBuilder' when it is passed over to some methods like
-      // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we
-      // erase an op within these methods, the actual rewriter won't be
-      // notified and we will end up with read-after-free issues!
-      return vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results);
-            })
-            .Case<tensor::PadOp>([&](auto padOp) {
-      return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes, results);
-            })
-            .Case<tensor::PackOp>([&](auto packOp) {
-      return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
-                                     results);
-      .Case<tensor::UnPackOp>([&](auto unpackOp) {
-        return vectorizeAsUnpackOp(rewriter, unpackOp, inputVectorSizes,
-                                   results);
-      }).Default([](auto) { return failure(); });
+  // Initialize vectorization state.
+  VectorizationState state(rewriter);
+  if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
+    if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
+                               inputScalableVecDims))) {
+      LDBG("Vectorization state couldn't be initialized\n");
+      return failure();
+    }
+  }
 
-      if (failed(vectorizeResult)) {
-        LDBG("Vectorization failed\n");
-        return failure();
-      }
+  SmallVector<Value> results;
+  auto vectorizeResult =
+      TypeSwitch<Operation *, LogicalResult>(op)
+          .Case<linalg::LinalgOp>([&](auto linalgOp) {
+            // TODO: isaConvolutionOpInterface that can also infer from
+            // generic features. Will require stride/dilation attributes
+            // inference.
+            if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
+              FailureOr<Operation *> convOr = vectorizeConvolution(
+                  rewriter, linalgOp, flatten1DDepthwiseConv);
+              if (succeeded(convOr)) {
+                llvm::append_range(results, (*convOr)->getResults());
+                return success();
+              }
+
+              LDBG("Unsupported convolution can't be vectorized.\n");
+              return failure();
+            }
+
+            LDBG("Vectorize generic by broadcasting to the canonical vector "
+                 "shape\n");
+
+            // Pre-process before proceeding.
+            convertAffineApply(rewriter, linalgOp);
+
+            // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted
+            // to 'OpBuilder' when it is passed over to some methods like
+            // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we
+            // erase an op within these methods, the actual rewriter won't be
+            // notified and we will end up with read-after-free issues!
+            return vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results);
+          })
+          .Case<tensor::PadOp>([&](auto padOp) {
+            return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
+                                          results);
+          })
+          .Case<tensor::PackOp>([&](auto packOp) {
+            return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
+                                           results);
+          })
+          .Case<tensor::UnPackOp>([&](auto unpackOp) {
+            return vectorizeAsUnpackOp(rewriter, unpackOp, inputVectorSizes,
+                                       results);
+          })
+          .Default([](auto) { return failure(); });
+
+  if (failed(vectorizeResult)) {
+    LDBG("Vectorization failed\n");
+    return failure();
+  }
 
-      if (!results.empty())
-        rewriter.replaceOp(op, results);
-      else
-        rewriter.eraseOp(op);
+  if (!results.empty())
+    rewriter.replaceOp(op, results);
+  else
+    rewriter.eraseOp(op);
 
-      return success();
+  return success();
 }
 
 LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
                                           memref::CopyOp copyOp) {
-      auto srcType = cast<MemRefType>(copyOp.getSource().getType());
-      auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
-      if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
-        return failure();
-
-      auto srcElementType = getElementTypeOrSelf(srcType);
-      auto dstElementType = getElementTypeOrSelf(dstType);
-      if (!VectorType::isValidElementType(srcElementType) ||
-          !VectorType::isValidElementType(dstElementType))
-        return failure();
+  auto srcType = cast<MemRefType>(copyOp.getSource().getType());
+  auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
+  if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
+    return failure();
 
-      auto readType = VectorType::get(srcType.getShape(), srcElementType);
-      auto writeType = VectorType::get(dstType.getShape(), dstElementType);
+  auto srcElementType = getElementTypeOrSelf(srcType);
+  auto dstElementType = getElementTypeOrSelf(dstType);
+  if (!VectorType::isValidElementType(srcElementType) ||
+      !VectorType::isValidElementType(dstElementType))
+    return failure();
 
-      Location loc = copyOp->getLoc();
-      Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-      SmallVector<Value> indices(srcType.getRank(), zero);
+  auto readType = VectorType::get(srcType.getShape(), srcElementType);
+  auto writeType = VectorType::get(dstType.getShape(), dstElementType);
 
-      Value readValue = rewriter.create<vector::TransferReadOp>(
-          loc, readType, copyOp.getSource(), indices,
-          rewriter.getMultiDimIdentityMap(srcType.getRank()));
-      if (cast<VectorType>(readValue.getType()).getRank() == 0) {
-        readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
-        readValue =
-            rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
-      }
-      Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
-          loc, readValue, copyOp.getTarget(), indices,
-          rewriter.getMultiDimIdentityMap(srcType.getRank()));
-      rewriter.replaceOp(copyOp, writeValue->getResults());
-      return success();
+  Location loc = copyOp->getLoc();
+  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  SmallVector<Value> indices(srcType.getRank(), zero);
+
+  Value readValue = rewriter.create<vector::TransferReadOp>(
+      loc, readType, copyOp.getSource(), indices,
+      rewriter.getMultiDimIdentityMap(srcType.getRank()));
+  if (cast<VectorType>(readValue.getType()).getRank() == 0) {
+    readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
+    readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
+  }
+  Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
+      loc, readValue, copyOp.getTarget(), indices,
+      rewriter.getMultiDimIdentityMap(srcType.getRank()));
+  rewriter.replaceOp(copyOp, writeValue->getResults());
+  return success();
 }
 
 //----------------------------------------------------------------------------//
@@ -2104,7 +2089,7 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
 
 /// Helper function that retrieves the value of an IntegerAttr.
 static int64_t getIntFromAttr(Attribute attr) {
-      return cast<IntegerAttr>(attr).getInt();
+  return cast<IntegerAttr>(attr).getInt();
 }
 
 /// Given an ArrayRef of OpFoldResults, return a vector of Values.
@@ -2112,16 +2097,16 @@ static int64_t getIntFromAttr(Attribute attr) {
 /// not supported.
 static SmallVector<Value> ofrToIndexValues(RewriterBase &rewriter, Location loc,
                                            ArrayRef<OpFoldResult> ofrs) {
-      SmallVector<Value> result;
-      for (auto o : ofrs) {
-        if (auto val = llvm::dyn_cast_if_present<Value>(o)) {
-          result.push_back(val);
-        } else {
-          result.push_back(rewriter.create<arith::ConstantIndexOp>(
-              loc, getIntFromAttr(o.template get<Attribute>())));
-        }
-      }
-      return result;
+  SmallVector<Value> result;
+  for (auto o : ofrs) {
+    if (auto val = llvm::dyn_cast_if_present<Value>(o)) {
+      result.push_back(val);
+    } else {
+      result.push_back(rewriter.create<arith::ConstantIndexOp>(
+          loc, getIntFromAttr(o.template get<Attribute>())));
+    }
+  }
+  return result;
 }
 
 /// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and
@@ -2196,8 +2181,7 @@ struct GenericPadOpVectorizationPattern : public GeneralizePadOpPattern {
     // If `dest` is a FillOp and the TransferWriteOp would overwrite the
     // entire tensor, write directly to the FillOp's operand.
     if (llvm::equal(vecShape, resultType.getShape()) &&
-        llvm::all_of(writeInBounds, [](bool b) {
-      return b; }))
+        llvm::all_of(writeInBounds, [](bool b) { return b; }))
       if (auto fill = dest.getDefiningOp<FillOp>())
         dest = fill.output();
 
@@ -2208,7 +2192,7 @@ struct GenericPadOpVectorizationPattern : public GeneralizePadOpPattern {
         padOp, read, dest, writeIndices, ArrayRef<bool>{writeInBounds});
 
     return success();
-}
+  }
 };
 
 /// Base pattern for rewriting tensor::PadOps whose result is consumed by a
@@ -2980,8 +2964,8 @@ struct Conv1DGenerator
     Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
                                                         resPadding);
 
-    // The base vectorization case for channeled convolution is input: {n,w,c},
-    // weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
+    // The base vectorization case for channeled convolution is input:
+    // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
     // vectorization case, we do pre transpose on input, weight, and output.
     switch (conv1DOpOrder) {
     case Conv1DOpOrder::W:
@@ -3024,9 +3008,9 @@ struct Conv1DGenerator
       return kw * (wSize / wSizeStep) + w;
     };
 
-    // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f} or
-    // perform outerproduct for non-channeled convolution or
-    // perform simple arith operation for pooling
+    // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f},
+    // or perform an outer product for non-channeled convolution, or perform a
+    // simple arith operation for pooling.
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       for (int64_t w = 0; w < wSize; w += wSizeStep) {
         switch (oper) {
@@ -3055,9 +3039,9 @@ struct Conv1DGenerator
     // End vector-only rewrite part
     //===------------------------------------------------------------------===//
 
-    // The base vectorization case for channeled convolution is output: {n,w,f}
-    // To reuse the result from base pattern vectorization case, we post
-    // transpose the base case result.
+    // The base vectorization case for channeled convolution is output:
+    // {n,w,f}. To reuse the result from the base pattern vectorization case,
+    // we post-transpose the base case result.
     switch (conv1DOpOrder) {
     case Conv1DOpOrder::W:
     case Conv1DOpOrder::Nwc:
@@ -3495,9 +3479,9 @@ static FailureOr<Operation *>
 vectorizeConvolution(RewriterBase &rewriter, LinalgOp op,
                      bool flatten1DDepthwiseConv) {
   // The ConvolutionOpInterface gives us guarantees of existence for
-  // strides/dilations. However, we do not need to rely on those, we can simply
-  // use them if present, otherwise use the default and let the generic conv.
-  // matcher in the ConvGenerator succeed or fail.
+  // strides/dilations. However, we do not need to rely on those; we can
+  // simply use them if present, otherwise use the defaults and let the
+  // generic conv matcher in the ConvGenerator succeed or fail.
   auto strides = op->getAttrOfType<DenseIntElementsAttr>("strides");
   auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations");
   auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
index f20008a1ed2b2f..6303dec81327a0 100644
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -73,25 +73,38 @@ mlir::tensor::computeTransposedType(RankedTensorType rankedTensorType,
   return transposedTensorType;
 }
 
-SmallVector<int64_t>
-mlir::tensor::getPackInverseDestPermutation(PackOp packOp) {
+SmallVector<int64_t> mlir::tensor::getPackUnPackInverseDestPerm(
+    std::variant<tensor::PackOp, tensor::UnPackOp> op) {
+
+  llvm::ArrayRef<int64_t> innerDimsPos, outerPerm;
+  RankedTensorType destType;
+  if (std::holds_alternative<tensor::PackOp>(op)) {
+    tensor::PackOp packOp = std::get<tensor::PackOp>(op);
+    innerDimsPos = packOp.getInnerDimsPos();
+    destType = packOp.getDestType();
+    outerPerm = packOp.getOuterDimsPerm();
+  } else {
+    tensor::UnPackOp unpackOp = std::get<tensor::UnPackOp>(op);
+    innerDimsPos = unpackOp.getInnerDimsPos();
+    destType = unpackOp.getDestType();
+    outerPerm = unpackOp.getOuterDimsPerm();
+  }
   // The permutation can be obtained from two permutations:
   //   a) Compute the permutation vector to move the last `numPackedDims` into
   //      the `innerPosDims` of a shape of rank `packedRank`.
   //   b) Compute the permutation vector to move outer dims if the pack op
   //      has outer_dims_perm.
   // Apply (b) permutation on (a) permutation to get the final permutation.
-  int64_t numPackedDims = packOp.getInnerDimsPos().size();
-  int64_t packedRank = packOp.getDestType().getRank();
+  int64_t numPackedDims = innerDimsPos.size();
+  int64_t packedRank = destType.getRank();
   auto lastDims = llvm::to_vector(
       llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
-  PackingMetadata packingMetadata = computePackingMetadata(
-      packOp.getDestType().getRank(), packOp.getInnerDimsPos());
+  PackingMetadata packingMetadata =
+      computePackingMetadata(destType.getRank(), innerDimsPos);
   SmallVector<int64_t> innerPositionsPerm = computePermutationVector(
       packedRank, lastDims, packingMetadata.insertPositions);
 
   SmallVector<int64_t> outerPos = packingMetadata.outerPositions;
-  ArrayRef<int64_t> outerPerm = packOp.getOuterDimsPerm();
   if (!outerPerm.empty())
     applyPermutationToVector(outerPos, outerPerm);
   SmallVector<int64_t> outerPositionPerm = computePermutationVector(
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index a79ff6bd75795c..76ea8d83b3c0cf 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -419,56 +419,6 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-// CHECK-LABEL: func @test_vectorize_unpack
-func.func @test_vectorize_unpack(%0 : tensor<7x1136x16x16xf32>) -> tensor<100x18176xf32>  {
-    // CHECK %[[c0:.*]] = arith.constant 0 : index
-    // CHECK: %[[m0:.*]] = vector.create_mask %c7, %c1136, %c16, %c16_0 : vector<2x4x16x16xi1>
-    // CHECK: %[[tr0:.*]] = vector.mask %[[m0]] {{.*}} vector.transfer_read %{{.*}} : tensor<7x1136x16x16xf32>, vector<2x4x16x16xf32> } : vector<2x4x16x16xi1> -> vector<2x4x16x16xf32>
-    // CHECK: %[[trans0:.*]] = vector.transpose %[[tr0]], [0, 2, 1, 3] : vector<2x4x16x16xf32> to vector<2x16x4x16xf32>
-    // CHECK: %[[sc0:.*]] = vector.shape_cast %[[trans0]]  : vector<2x16x4x16xf32> to vector<32x64xf32>
-    // CHECK: %[[empt0:.*]] = tensor.empty() : tensor<100x18176xf32>
-    // CHECK: %[[mask0:.*]] = vector.create_mask %c100, %c18176 : vector<32x64xi1>
-    // CHECK: %[[tw0:.*]] = vector.mask %[[mask0]] {{.*}} vector.transfer_write %[[sc0]], %[[empt0]]
-    // CHECK: return %[[tw0]]
-    %8 = tensor.empty() : tensor<100x18176xf32>
-    %unpack = tensor.unpack %0 inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %8 : tensor<7x1136x16x16xf32> -> tensor<100x18176xf32>
-    return %unpack : tensor<100x18176xf32>
-}
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["tensor.unpack"]} in %arg1
-      : (!transform.any_op) -> !transform.any_op
-    transform.structured.vectorize %0 vector_sizes [2, 4] : !transform.any_op
-    transform.yield
-  }
-}
-
-// -----
-
-// CHECK-LABEL: func @test_vectorize_dynamic_shapes_unpack
-func.func @test_vectorize_dynamic_shapes_unpack(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?x16x2xf32>) -> tensor<?x?xf32> {
- // CHECK: %[[readMsk0:.*]] = vector.create_mask %dim_3, %dim_5, %c16, %c2 : vector<4x1x16x2xi1>
- // CHECK: %[[read0:.*]] = vector.mask %[[readMsk0]] {{.*}} vector.transfer_read %{{.*}} : tensor<?x?x16x2xf32>, vector<4x1x16x2xf32> } : vector<4x1x16x2xi1> -> vector<4x1x16x2xf32>
- // CHECK: %[[trans0:.*]] = vector.transpose %[[read0]], [0, 3, 1, 2] : vector<4x1x16x2xf32> to vector<4x2x1x16xf32>
- // CHECK: %[[sc0:.*]] = vector.shape_cast %[[trans0]] : vector<4x2x1x16xf32> to vector<8x16xf32>
- // CHECK: %[[empt0:.*]] = tensor.empty
- // CHECK: %[[writeMsk0:.*]] = vector.create_mask {{.*}} : vector<8x16xi1>
- // CHECK: %[[write0:.*]] = vector.mask %[[writeMsk0:.*]] {{.*}} vector.transfer_write %[[sc0]], %[[empt0]]
- // CHECK: return %[[write0]]
- %ret = tensor.unpack %arg1 inner_dims_pos = [1, 0] inner_tiles = [16, 2] into %arg0 : tensor<?x?x16x2xf32> -> tensor<?x?xf32>
- return %ret : tensor<?x?xf32>
-}
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
-   %0 = transform.structured.match ops{["tensor.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
-   transform.structured.vectorize %0 vector_sizes [4, 1] : !transform.any_op
-   transform.yield
- }
-}
-
-// -----
-
 // CHECK-LABEL: func @test_masked_vectorize_pad
 func.func @test_masked_vectorize_pad(
   %0 : tensor<?x?xf32>, %h0 : index, %h1 : index)
@@ -722,3 +672,38 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL: func @test_vectorize_dynamic_shapes_unpack
+func.func @test_vectorize_dynamic_shapes_unpack(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?x16x2xf32>) -> tensor<?x?xf32> {
+// CHECK: %[[C0:.*]] = arith.constant 0
+// CHECK: %[[DIM:.*]] = tensor.dim %arg0, %[[C0]] : tensor<?x?xf32>
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM0:.*]] = tensor.dim %arg0, %[[C1]] : tensor<?x?xf32>
+// CHECK: %[[C01:.*]] = arith.constant 0
+// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[C02:.*]] = arith.constant 0
+// CHECK: %[[DIM4:.*]] = tensor.dim %arg1, %[[C02]] : tensor<?x?x16x2xf32>
+// CHECK: %[[CNST15:.*]] = arith.constant 1
+// CHECK: %[[DIM6:.*]] = tensor.dim %arg1, %[[CNST15]] : tensor<?x?x16x2xf32>
+// CHECK: %[[CNST16:.*]] = arith.constant 16 : index
+// CHECK: %[[CNST2:.*]] = arith.constant 2 : index
+// CHECK: %[[readMsk0:.*]] = vector.create_mask %[[DIM4]], %[[DIM6]], %[[CNST16]], %[[CNST2]] : vector<2x1x16x2xi1>
+// CHECK: %[[read0:.*]] = vector.mask %[[readMsk0]] {{.*}} vector.transfer_read %{{.*}} : tensor<?x?x16x2xf32>, vector<2x1x16x2xf32> } : vector<2x1x16x2xi1> -> vector<2x1x16x2xf32>
+// CHECK: %[[trans0:.*]] = vector.transpose %[[read0]], [0, 3, 1, 2] : vector<2x1x16x2xf32> to vector<2x2x1x16xf32>
+// CHEdCK: %[[sc0:.*]] = vector.shape_cast %[[trans0]] : vector<4x2x1x16xf32> to vector<4x16xf32>
+// CHEdCK: %[[empt0:.*]] = tensor.empty
+// CHEdCK: %[[writeMsk0:.*]] = vector.create_mask {{.*}} : vector<4x16xi1>
+// CHEdCK: %[[write0:.*]] = vector.mask %[[writeMsk0:.*]] {{.*}} vector.transfer_write %[[sc0]], %[[empt0]]
+// CHEdCK: return %[[write0]]
+ %ret = tensor.unpack %arg1 inner_dims_pos = [1, 0] inner_tiles = [16, 2] into %arg0 : tensor<?x?x16x2xf32> -> tensor<?x?xf32>
+ return %ret : tensor<?x?xf32>
+}
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+   %0 = transform.structured.match ops{["tensor.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+   transform.structured.vectorize %0 vector_sizes [4, 1] : !transform.any_op
+   transform.yield
+ }
+}

>From d5a0dec6194c6f843765ef202bf5cd3a4150ed99 Mon Sep 17 00:00:00 2001
From: "Balaji V. Iyer" <bviyer at gmail.com>
Date: Fri, 9 Feb 2024 23:14:22 +0000
Subject: [PATCH 07/12] Fixed all the issues pointed out by HanHan and Diego.

---
 .../include/mlir/Dialect/Tensor/Utils/Utils.h |  5 ++
 .../Linalg/Transforms/Vectorization.cpp       | 30 ++++-----
 mlir/lib/Dialect/Tensor/Utils/Utils.cpp       | 48 +++++++------
 mlir/test/Dialect/Linalg/vectorization.mlir   | 67 +++++++++++++++++--
 4 files changed, 108 insertions(+), 42 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
index 60522ac48d95b5..8c8107e0507d70 100644
--- a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
@@ -41,6 +41,11 @@ computeTransposedType(RankedTensorType rankedTensorType,
 SmallVector<int64_t> getPackUnPackInverseDestPerm(
     std::variant<tensor::PackOp, tensor::UnPackOp> packOp);
 
+/// Same as above, but also returns the PackingMetadata used to compute the
+/// permutation through an out-parameter (needed for unpack vectorization).
+SmallVector<int64_t> getPackUnPackInverseDestPerm(
+    std::variant<tensor::PackOp, tensor::UnPackOp> packOp,
+    PackingMetadata &PackingMetadata);
 /// A tensor.insert_slice is a cast-like operation if it merely rank-extends the
 /// source tensor or inserts the source tensor into a destination tensor with
 /// the same shape.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 420ffe533ff0b3..8c5fb1b03d033f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1571,11 +1571,12 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
   return success();
 }
 
-/// Vectorize an `tensor::UnPackOp` without OuterDimsPerms to these 4 Ops:
-///   Vector::TransferReadOp - Reads the Vector Array of Source data
-///   vector::TransposeOp - Transpose the Source
-///   ShapeCastOp - Reshapes the data based on the target.
-///   vector::TransferWriteOp. - Write the result vector back.
+/// Vectorize a `tensor::UnPackOp` without OuterDimsPerms to these 4 Ops:
+///   vector::TransferReadOp - Reads a vector from the source tensor
+///   vector::TransposeOp - Transposes the vector that was read
+///   vector::ShapeCastOp - Reshapes the data based on the target.
+///   vector::TransferWriteOp - Writes the result vector back to the
+///   destination tensor
 static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
                                          tensor::UnPackOp unpackOp,
                                          ArrayRef<int64_t> inputVectorSizes,
@@ -1610,26 +1611,21 @@ static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
     LDBG("Unable to reify result shapes of " << unpackOp);
     return failure();
   }
-  int64_t unpackRank = unpackTensorType.getRank();
   Location loc = unpackOp->getLoc();
 
+  // Read result, mask if necessary.
   Value readResult = createReadOrMaskedRead(
       rewriter, loc, unpackOp.getSource(),
       llvm::ArrayRef<int64_t>(readMaskShape.begin(), readMaskShape.end()),
       nullptr);
 
-  int64_t numPackedDim = unpackOp.getInnerDimsPos().size();
-  llvm::SmallVector<int64_t> lastDims = llvm::to_vector(
-      llvm::seq<int64_t>(unpackRank - numPackedDim, unpackRank));
-  PackingMetadata packMetadata =
-      computePackingMetadata(unpackRank, unpackOp.getInnerDimsPos());
-  SmallVector<int64_t> lastDimToInsertPosPerm = computePermutationVector(
-      unpackRank, lastDims, packMetadata.insertPositions);
+  PackingMetadata packMetadata;
+  SmallVector<int64_t> lastDimToInsertPosPerm = invertPermutationVector(
+      tensor::getPackUnPackInverseDestPerm(unpackOp, packMetadata));
   ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
   SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
   mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
   applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
-
   RankedTensorType stripMineTensorType =
       RankedTensorType::Builder(stripMineShape, stripMineElemType, {})
           .setShape(stripMineShape);
@@ -1646,8 +1642,12 @@ static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
   vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
       loc, vecCollapsedType, transposeOp->getResult(0));
 
+  // For a dynamically shaped destination, the write mask shape must match the
+  // shape_cast result shape; otherwise the mask size is rejected as invalid.
   SmallVector<int64_t> writeMaskShape(
-      shapeCastOp.getResultVectorType().getShape());
+      unpackOp.getDestType().hasStaticShape()
+          ? inputVectorSizes
+          : shapeCastOp.getResultVectorType().getShape());
   Operation *write =
       createWriteOrMaskedWrite(rewriter, loc, shapeCastOp.getResult(),
                                reifiedRetShapes[0], writeMaskShape);
diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
index 6303dec81327a0..0902e33a1f19fd 100644
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -75,18 +75,26 @@ mlir::tensor::computeTransposedType(RankedTensorType rankedTensorType,
 
 SmallVector<int64_t> mlir::tensor::getPackUnPackInverseDestPerm(
     std::variant<tensor::PackOp, tensor::UnPackOp> op) {
+  PackingMetadata pMetaData;
+  return getPackUnPackInverseDestPerm(op, pMetaData);
+}
+
+SmallVector<int64_t> mlir::tensor::getPackUnPackInverseDestPerm(
+    std::variant<tensor::PackOp, tensor::UnPackOp> op,
+    PackingMetadata &packingMetadata) {
 
   llvm::ArrayRef<int64_t> innerDimsPos, outerPerm;
-  RankedTensorType destType;
-  if (std::holds_alternative<tensor::PackOp>(op)) {
+  int64_t rank = 0;
+  bool isPackOp = std::holds_alternative<tensor::PackOp>(op);
+  if (isPackOp) {
     tensor::PackOp packOp = std::get<tensor::PackOp>(op);
     innerDimsPos = packOp.getInnerDimsPos();
-    destType = packOp.getDestType();
+    rank = packOp.getDestType().getRank();
     outerPerm = packOp.getOuterDimsPerm();
   } else {
     tensor::UnPackOp unpackOp = std::get<tensor::UnPackOp>(op);
     innerDimsPos = unpackOp.getInnerDimsPos();
-    destType = unpackOp.getDestType();
+    rank = unpackOp.getSourceType().getRank();
     outerPerm = unpackOp.getOuterDimsPerm();
   }
   // The permutation can be obtained from two permutations:
@@ -96,23 +104,21 @@ SmallVector<int64_t> mlir::tensor::getPackUnPackInverseDestPerm(
   //      has outer_dims_perm.
   // Apply (b) permutation on (a) permutation to get the final permutation.
   int64_t numPackedDims = innerDimsPos.size();
-  int64_t packedRank = destType.getRank();
-  auto lastDims = llvm::to_vector(
-      llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
-  PackingMetadata packingMetadata =
-      computePackingMetadata(destType.getRank(), innerDimsPos);
-  SmallVector<int64_t> innerPositionsPerm = computePermutationVector(
-      packedRank, lastDims, packingMetadata.insertPositions);
-
-  SmallVector<int64_t> outerPos = packingMetadata.outerPositions;
-  if (!outerPerm.empty())
-    applyPermutationToVector(outerPos, outerPerm);
-  SmallVector<int64_t> outerPositionPerm = computePermutationVector(
-      packedRank, packingMetadata.outerPositions, outerPos);
-
-  SmallVector<int64_t> packInverseDestPermutation = innerPositionsPerm;
-  applyPermutationToVector(packInverseDestPermutation, outerPositionPerm);
-  return packInverseDestPermutation;
+  auto lastDims =
+      llvm::to_vector(llvm::seq<int64_t>(rank - numPackedDims, rank));
+  packingMetadata = computePackingMetadata(rank, innerDimsPos);
+  SmallVector<int64_t> innerPositionsPerm =
+      computePermutationVector(rank, lastDims, packingMetadata.insertPositions);
+
+  if (isPackOp) {
+    SmallVector<int64_t> outerPos = packingMetadata.outerPositions;
+    if (!outerPerm.empty())
+      applyPermutationToVector(outerPos, outerPerm);
+    SmallVector<int64_t> outerPositionPerm = computePermutationVector(
+        rank, packingMetadata.outerPositions, outerPos);
+    applyPermutationToVector(innerPositionsPerm, outerPositionPerm);
+  }
+  return innerPositionsPerm;
 }
 
 bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) {
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 76ea8d83b3c0cf..0c8a76d5231f02 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -691,12 +691,12 @@ func.func @test_vectorize_dynamic_shapes_unpack(%arg0: tensor<?x?xf32>, %arg1: t
 // CHECK: %[[CNST2:.*]] = arith.constant 2 : index
 // CHECK: %[[readMsk0:.*]] = vector.create_mask %[[DIM4]], %[[DIM6]], %[[CNST16]], %[[CNST2]] : vector<2x1x16x2xi1>
 // CHECK: %[[read0:.*]] = vector.mask %[[readMsk0]] {{.*}} vector.transfer_read %{{.*}} : tensor<?x?x16x2xf32>, vector<2x1x16x2xf32> } : vector<2x1x16x2xi1> -> vector<2x1x16x2xf32>
-// CHECK: %[[trans0:.*]] = vector.transpose %[[read0]], [0, 3, 1, 2] : vector<2x1x16x2xf32> to vector<2x2x1x16xf32>
-// CHEdCK: %[[sc0:.*]] = vector.shape_cast %[[trans0]] : vector<4x2x1x16xf32> to vector<4x16xf32>
-// CHEdCK: %[[empt0:.*]] = tensor.empty
-// CHEdCK: %[[writeMsk0:.*]] = vector.create_mask {{.*}} : vector<4x16xi1>
-// CHEdCK: %[[write0:.*]] = vector.mask %[[writeMsk0:.*]] {{.*}} vector.transfer_write %[[sc0]], %[[empt0]]
-// CHEdCK: return %[[write0]]
+// CHECK: %[[trans0:.*]] = vector.transpose %[[read0]], [0, 2, 3, 1] : vector<2x1x16x2xf32> to vector<2x16x2x1xf32>
+// CHECK: %[[sc0:.*]] = vector.shape_cast %[[trans0]] : vector<2x16x2x1xf32> to vector<32x2xf32>
+// CHECK: %[[empt0:.*]] = tensor.empty
+// CHECK: %[[writeMsk0:.*]] = vector.create_mask {{.*}} : vector<32x2xi1>
+// CHECK: %[[write0:.*]] = vector.mask %[[writeMsk0:.*]] {{.*}} vector.transfer_write %[[sc0]], %[[empt0]]
+// CHECK: return %[[write0]]
  %ret = tensor.unpack %arg1 inner_dims_pos = [1, 0] inner_tiles = [16, 2] into %arg0 : tensor<?x?x16x2xf32> -> tensor<?x?xf32>
  return %ret : tensor<?x?xf32>
 }
@@ -707,3 +707,58 @@ module attributes {transform.with_named_sequence} {
    transform.yield
  }
 }
+
+// -----
+
+// CHECK-LABEL: func @test_vectorize_unpack
+func.func @test_vectorize_unpack(%source: tensor<8x8x32x16xf32>, %dest: tensor<256x128xf32>) -> tensor<256x128xf32> {
+    // CHECK: %[[C0:.*]]= arith.constant 0 : index
+    // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+    // CHECK: %[[C8:.*]] = arith.constant 8 : index
+    // CHECK: %[[C80:.*]] = arith.constant 8 : index
+    // CHECK: %[[C32:.*]] = arith.constant 32 : index
+    // CHECK: %[[C16:.*]] = arith.constant 16 : index
+    // CHECK: %[[MSK0:.*]] = vector.create_mask %c8, %c8_0, %c32, %c16 : vector<16x8x32x16xi1>
+    // CHECK: %[[READ0:.*]] = vector.mask %[[MSK0]] {{.*}} : vector<16x8x32x16xi1> -> vector<16x8x32x16xf32>
+    // CHECK: %[[TRANSP0:.*]] = vector.transpose %[[READ0]], [0, 2, 1, 3] : vector<16x8x32x16xf32> to vector<16x32x8x16xf32>
+    // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP0]] : vector<16x32x8x16xf32> to vector<512x128xf32>
+    // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<256x128xf32>
+    // CHECK: %[[C01:.*]] = arith.constant 0 : index
+    // CHECK: %[[C256:.*]] = arith.constant 256 : index
+    // CHECK: %[[C128:.*]] = arith.constant 128 : index
+    // CHECK: %[[WRITEMSK:.*]] = vector.create_mask %[[C256]], %[[C128]] : vector<512x128xi1>
+    // CHECK: %[[WRIT:.*]] = vector.mask %[[WRITEMSK]] {{.*}} : vector<512x128xi1> -> tensor<256x128xf32>
+    // CHECK: return %[[WRIT]] : tensor<256x128xf32>
+   %0 = tensor.unpack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
+   return %0 : tensor<256x128xf32>
+ }
+ module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+   transform.structured.vectorize %0 vector_sizes [512, 128] : !transform.any_op
+    transform.yield
+  } 
+}
+
+// -----
+
+// CHECK-LABEL: func @test_vectorize_unpack_no_masks
+func.func @test_vectorize_unpack_no_masks(%source: tensor<8x8x32x16xf32>, %dest: tensor<256x128xf32>) -> tensor<256x128xf32> {
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK: %[[READ:.*]] = vector.transfer_read {{.*}} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
+  // CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [0, 2, 1, 3] : vector<8x8x32x16xf32> to vector<8x32x8x16xf32>
+  // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<8x32x8x16xf32> to vector<256x128xf32>
+  // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<256x128xf32>
+  // CHECK: %[[C00:.*]] = arith.constant 0 : index
+  // CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], {{.*}} : vector<256x128xf32>, tensor<256x128xf32>
+  // CHECK: return %[[WRIT]] : tensor<256x128xf32>
+   %0 = tensor.unpack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
+   return %0 : tensor<256x128xf32>
+ }
+ module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+   transform.structured.vectorize %0 vector_sizes [256, 128] : !transform.any_op
+    transform.yield
+  } }
\ No newline at end of file

>From 59d761fae181acf4e66075696cd46bca5c609db5 Mon Sep 17 00:00:00 2001
From: "Balaji V. Iyer" <bviyer at gmail.com>
Date: Thu, 15 Feb 2024 23:24:24 +0000
Subject: [PATCH 08/12] Address review comments from Diego and Max
 (except handling of the outer_dims_perm attribute)

---
 .../Linalg/Transforms/Vectorization.cpp       | 66 +++++++++----------
 mlir/test/Dialect/Linalg/vectorization.mlir   | 17 ++---
 2 files changed, 42 insertions(+), 41 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 8c5fb1b03d033f..f57fae3baa9e6b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1420,28 +1420,16 @@ static Value createReadOrMaskedRead(OpBuilder &builder, Location loc,
   auto sourceShape = dyn_cast<ShapedType>(source.getType()).getShape();
   assert(sourceShape.size() == readShape.size());
   auto maskType = VectorType::get(readShape, builder.getI1Type());
-  Type vecElemType = padValue != nullptr
-                         ? padValue.getType()
-                         : cast<ShapedType>(source.getType()).getElementType();
-  auto vectorType = VectorType::get(readShape, vecElemType);
+  auto vectorType = VectorType::get(readShape, padValue.getType());
   int64_t readRank = readShape.size();
   auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
-  vector::TransferReadOp transferReadOp = nullptr;
-  if (padValue == nullptr) {
-    transferReadOp = builder.create<vector::TransferReadOp>(
-        loc,
-        /*vectorType=*/vectorType,
-        /*source=*/source,
-        /*indices=*/SmallVector<Value>(readRank, zero));
-  } else {
-    transferReadOp = builder.create<vector::TransferReadOp>(
-        loc,
-        /*vectorType=*/vectorType,
-        /*source=*/source,
-        /*indices=*/SmallVector<Value>(readRank, zero),
-        /*padding=*/padValue,
-        /*inBounds=*/SmallVector<bool>(readRank, true));
-  }
+  auto transferReadOp = builder.create<vector::TransferReadOp>(
+      loc,
+      /*vectorType=*/vectorType,
+      /*source=*/source,
+      /*indices=*/SmallVector<Value>(readRank, zero),
+      /*padding=*/padValue,
+      /*inBounds=*/SmallVector<bool>(readRank, true));
   if (llvm::equal(readShape, sourceShape)) {
     return transferReadOp;
   }
@@ -1588,21 +1576,32 @@ static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
   RankedTensorType unpackTensorType = unpackOp.getSourceType();
 
   SmallVector<int64_t> readMaskShape(unpackTensorType.getShape());
-  llvm::ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
-  llvm::ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
+  ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
+  ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
   for (unsigned int i = 0; i < inputVectorSizes.size(); i++) {
     readMaskShape[i] = inputVectorSizes[i];
   }
+
+  // ReadMask is the size of tensor used to read and apply mask. It is
+  // set like this. Let's say the vectorSize (VS) array is size 'N' and
+  // the sourceShape(SS) is 'M' where M >= N and InnerTileSizes (IT) of
+  // size M-N
+  // Thus:
+  // ReadMaskShape (initial) = [VS[0], ..., VS[N-1], SS[N], ..., SS[M-1]]
+  // Then divide all the readMaskShape locations pointed by innerDimPos
+  // by the innerTileSize attribute value.
+  // E.g. let's say let's say unpackTensorType.getShape() = <8x8x32x16>
+  // inner Dim Pos = [0, 1] and Inner Tiles = [32, 16], vector_sizes are [512,
+  // 128] then read shape is:
+  //   ReadMaskShape(initial): [8, 8, 32, 16]
+  //   After settin vectorSizes: [512, 128, 32, 16]
+  //   Final Value(after innerDim Adjustment): [512/32, 128/16, 32, 16]
+  //                                           = [16, 8, 32, 16]
   for (auto [index, size] : enumerate(innerTiles)) {
     readMaskShape[innerDimPos[index]] =
         llvm::divideCeil(readMaskShape[innerDimPos[index]], size);
   }
 
-  // ReadMask is the size of tensor used to read and apply mask. It is
-  // set like this. Let's say the vectorSize (VS) array is size 'N' and
-  // the sourceShape(SS) is 'M' where M >= N
-  // Thus:
-  // ReadMaskShape = [VS[0], ..., VS[N-1], SS[N], ..., SS[M-1]]
   ReifiedRankedShapedTypeDims reifiedRetShapes;
   LogicalResult status =
       cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
@@ -1613,11 +1612,14 @@ static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
   }
   Location loc = unpackOp->getLoc();
 
-  // Read result, mask if necessary.
+  auto padValue = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));
+
+  // Read result, mask if necessary. If transferReadOp shape is not equal
+  // to shape of source, then a mask is necessary.
   Value readResult = createReadOrMaskedRead(
       rewriter, loc, unpackOp.getSource(),
-      llvm::ArrayRef<int64_t>(readMaskShape.begin(), readMaskShape.end()),
-      nullptr);
+      ArrayRef<int64_t>(readMaskShape.begin(), readMaskShape.end()), padValue);
 
   PackingMetadata packMetadata;
   SmallVector<int64_t> lastDimToInsertPosPerm = invertPermutationVector(
@@ -1627,9 +1629,7 @@ static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
   mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
   applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
   RankedTensorType stripMineTensorType =
-      RankedTensorType::Builder(stripMineShape, stripMineElemType, {})
-          .setShape(stripMineShape);
-
+      RankedTensorType::get(stripMineShape, stripMineElemType);
   // Transpose the appropriate rows to match output.
   vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
       loc, readResult, lastDimToInsertPosPerm);
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 0c8a76d5231f02..757cc46093daf9 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -681,12 +681,12 @@ func.func @test_vectorize_dynamic_shapes_unpack(%arg0: tensor<?x?xf32>, %arg1: t
 // CHECK: %[[DIM:.*]] = tensor.dim %arg0, %[[C0]] : tensor<?x?xf32>
 // CHECK: %[[C1:.*]] = arith.constant 1 : index
 // CHECK: %[[DIM0:.*]] = tensor.dim %arg0, %[[C1]] : tensor<?x?xf32>
+// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00
 // CHECK: %[[C01:.*]] = arith.constant 0
-// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK: %[[C02:.*]] = arith.constant 0
 // CHECK: %[[DIM4:.*]] = tensor.dim %arg1, %[[C02]] : tensor<?x?x16x2xf32>
-// CHECK: %[[CNST15:.*]] = arith.constant 1
-// CHECK: %[[DIM6:.*]] = tensor.dim %arg1, %[[CNST15]] : tensor<?x?x16x2xf32>
+// CHECK: %[[CNST14:.*]] = arith.constant 1
+// CHECK: %[[DIM6:.*]] = tensor.dim %arg1, %[[CNST14]] : tensor<?x?x16x2xf32>
 // CHECK: %[[CNST16:.*]] = arith.constant 16 : index
 // CHECK: %[[CNST2:.*]] = arith.constant 2 : index
 // CHECK: %[[readMsk0:.*]] = vector.create_mask %[[DIM4]], %[[DIM6]], %[[CNST16]], %[[CNST2]] : vector<2x1x16x2xi1>
@@ -703,7 +703,7 @@ func.func @test_vectorize_dynamic_shapes_unpack(%arg0: tensor<?x?xf32>, %arg1: t
 module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["tensor.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
-   transform.structured.vectorize %0 vector_sizes [4, 1] : !transform.any_op
+   transform.structured.vectorize %0 vector_sizes [4, 16] : !transform.any_op
    transform.yield
  }
 }
@@ -712,13 +712,13 @@ module attributes {transform.with_named_sequence} {
 
 // CHECK-LABEL: func @test_vectorize_unpack
 func.func @test_vectorize_unpack(%source: tensor<8x8x32x16xf32>, %dest: tensor<256x128xf32>) -> tensor<256x128xf32> {
-    // CHECK: %[[C0:.*]]= arith.constant 0 : index
     // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+    // CHECK: %[[C0:.*]]= arith.constant 0 : index
     // CHECK: %[[C8:.*]] = arith.constant 8 : index
     // CHECK: %[[C80:.*]] = arith.constant 8 : index
     // CHECK: %[[C32:.*]] = arith.constant 32 : index
     // CHECK: %[[C16:.*]] = arith.constant 16 : index
-    // CHECK: %[[MSK0:.*]] = vector.create_mask %c8, %c8_0, %c32, %c16 : vector<16x8x32x16xi1>
+    // CHECK: %[[MSK0:.*]] = vector.create_mask %[[C8]], %[[C80]], %[[C32]], %[[C16]] : vector<16x8x32x16xi1>
     // CHECK: %[[READ0:.*]] = vector.mask %[[MSK0]] {{.*}} : vector<16x8x32x16xi1> -> vector<16x8x32x16xf32>
     // CHECK: %[[TRANSP0:.*]] = vector.transpose %[[READ0]], [0, 2, 1, 3] : vector<16x8x32x16xf32> to vector<16x32x8x16xf32>
     // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP0]] : vector<16x32x8x16xf32> to vector<512x128xf32>
@@ -744,8 +744,8 @@ func.func @test_vectorize_unpack(%source: tensor<8x8x32x16xf32>, %dest: tensor<2
 
 // CHECK-LABEL: func @test_vectorize_unpack_no_masks
 func.func @test_vectorize_unpack_no_masks(%source: tensor<8x8x32x16xf32>, %dest: tensor<256x128xf32>) -> tensor<256x128xf32> {
-  // CHECK: %[[C0:.*]] = arith.constant 0 : index
   // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
   // CHECK: %[[READ:.*]] = vector.transfer_read {{.*}} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
   // CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [0, 2, 1, 3] : vector<8x8x32x16xf32> to vector<8x32x8x16xf32>
   // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<8x32x8x16xf32> to vector<256x128xf32>
@@ -761,4 +761,5 @@ func.func @test_vectorize_unpack_no_masks(%source: tensor<8x8x32x16xf32>, %dest:
     %0 = transform.structured.match ops{["tensor.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    transform.structured.vectorize %0 vector_sizes [256, 128] : !transform.any_op
     transform.yield
-  } }
\ No newline at end of file
+  } 
+}

>From c7ed75e39f79fdf8c4de880c7ea8d1800be347d4 Mon Sep 17 00:00:00 2001
From: "Balaji V. Iyer" <bviyer at gmail.com>
Date: Fri, 16 Feb 2024 01:45:50 +0000
Subject: [PATCH 09/12] Added outer_dims_perm support to unpack.

---
 .../Linalg/Transforms/Vectorization.cpp       | 41 ++++++++++---------
 mlir/test/Dialect/Linalg/vectorization.mlir   | 32 +++++++++++++++
 2 files changed, 53 insertions(+), 20 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index f57fae3baa9e6b..0aa43b6c863e28 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1575,28 +1575,37 @@ static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
 
   RankedTensorType unpackTensorType = unpackOp.getSourceType();
 
-  SmallVector<int64_t> readMaskShape(unpackTensorType.getShape());
   ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
   ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
-  for (unsigned int i = 0; i < inputVectorSizes.size(); i++) {
-    readMaskShape[i] = inputVectorSizes[i];
+
+  SmallVector<int64_t> readMaskShape(inputVectorSizes.begin(),
+                                     inputVectorSizes.end());
+  ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
+  if (outerDimsPerm.empty() == false) {
+    applyPermutationToVector(readMaskShape, outerDimsPerm);
   }
+  ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
+  readMaskShape.append(sourceShape.begin() + inputVectorSizes.size(),
+                       sourceShape.end());
 
   // ReadMask is the size of tensor used to read and apply mask. It is
-  // set like this. Let's say the vectorSize (VS) array is size 'N' and
+  // set like this: Let's say the vectorSize (VS) array is size 'N' and
   // the sourceShape(SS) is 'M' where M >= N and InnerTileSizes (IT) of
   // size M-N
   // Thus:
-  // ReadMaskShape (initial) = [VS[0], ..., VS[N-1], SS[N], ..., SS[M-1]]
-  // Then divide all the readMaskShape locations pointed by innerDimPos
-  // by the innerTileSize attribute value.
+  // - initially: ReadMaskShape = vectorInputSizes
+  // - if outer_dims_perms is present: do that permutation on readMaskShape.
+  // - Append the remaining shape from SS
+  // - Divide all teh readMaskShape locations pointed by innerDimPos
+  //   by the innerTileSize attribute value.
   // E.g. let's say let's say unpackTensorType.getShape() = <8x8x32x16>
   // inner Dim Pos = [0, 1] and Inner Tiles = [32, 16], vector_sizes are [512,
-  // 128] then read shape is:
-  //   ReadMaskShape(initial): [8, 8, 32, 16]
-  //   After settin vectorSizes: [512, 128, 32, 16]
-  //   Final Value(after innerDim Adjustment): [512/32, 128/16, 32, 16]
-  //                                           = [16, 8, 32, 16]
+  // 128] and outer_dims_perm is [1, 0] then read shape is:
+  //   ReadMaskShape(initial): [512, 128]
+  //   After applying outer_dims_perm: [128, 512]
+  //   After appending the rest of the sourceShape: [128, 512, 32, 16]
+  //   Final Value(after innerDim Adjustment): [128/32, 512/16, 32, 16]
+  //                                           = [4, 32, 32, 16]
   for (auto [index, size] : enumerate(innerTiles)) {
     readMaskShape[innerDimPos[index]] =
         llvm::divideCeil(readMaskShape[innerDimPos[index]], size);
@@ -1756,14 +1765,6 @@ static LogicalResult
 vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
                               ArrayRef<int64_t> inputVectorSizes) {
 
-  // Handling this case requires a bit more change. Right now
-  // just the required attributes are handled.
-  // TODO: Handle OuterDimsPerm.
-  if (!unpackOp.getOuterDimsPerm().empty()) {
-    LDBG("outer dimensions perms NYI for: " << unpackOp);
-    return failure();
-  }
-
   if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
         return !getConstantIntValue(res).has_value();
       })) {
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 757cc46093daf9..3d37c657740055 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -762,4 +762,36 @@ func.func @test_vectorize_unpack_no_masks(%source: tensor<8x8x32x16xf32>, %dest:
    transform.structured.vectorize %0 vector_sizes [256, 128] : !transform.any_op
     transform.yield
   } 
+ }
+
+  // -----
+
+  // This test is same as the one test_vectorize_unpack_no_masks but with outer_dims_perm.
+  // Note that adding this attribute causes a read mask.
+
+  // CHECK-LABEL: test_vectorize_unpack_with_outer_perm
+  func.func @test_vectorize_unpack_with_outer_perm(%source: tensor<8x8x32x16xf32>, %dest: tensor<256x128xf32>) -> tensor<256x128xf32> {
+  // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[C8:.*]] = arith.constant 8 : index
+  // CHECK: %[[C80:.*]] = arith.constant 8 : index
+  // CHECK: %[[C32:.*]] = arith.constant 32 : index
+  // CHECK: %[[C16:.*]] = arith.constant 16 : index
+  // CHECK: %[[MSK0:.*]] = vector.create_mask %[[C8]], %[[C80]], %[[C32]], %[[C16]] : vector<4x16x32x16xi1>
+  // CHECK: %[[READ:.*]] = vector.mask %[[MSK0]] {{.*}} : vector<4x16x32x16xi1> -> vector<4x16x32x16xf32>
+  // CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [0, 2, 1, 3] : vector<4x16x32x16xf32> to vector<4x32x16x16xf32>
+  // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<4x32x16x16xf32> to vector<128x256xf32>
+  // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<256x128xf32>
+  // CHECK: %[[C00:.*]] = arith.constant 0 : index
+  // CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], {{.*}} : vector<128x256xf32>, tensor<256x128xf32>
+  // CHECK: return %[[WRIT]] : tensor<256x128xf32>
+   %0 = tensor.unpack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
+   return %0 : tensor<256x128xf32>
+ }
+ module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+   transform.structured.vectorize %0 vector_sizes [256, 128] : !transform.any_op
+    transform.yield
+  } 
 }

>From a349b1446ee399dca59164ffef4dfc130bb1202f Mon Sep 17 00:00:00 2001
From: "Balaji V. Iyer" <bviyer at gmail.com>
Date: Fri, 16 Feb 2024 17:37:58 +0000
Subject: [PATCH 10/12] Fixed all the issues mentioned by Diego on 2/16.

---
 .../include/mlir/Dialect/Tensor/Utils/Utils.h | 18 ++--
 .../Dialect/Linalg/Transforms/Transforms.cpp  |  2 +-
 .../Linalg/Transforms/Vectorization.cpp       | 13 ++-
 mlir/lib/Dialect/Tensor/Utils/Utils.cpp       | 85 ++++++++++---------
 mlir/test/Dialect/Linalg/vectorization.mlir   |  4 +-
 5 files changed, 68 insertions(+), 54 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
index 8c8107e0507d70..009702f126eaf3 100644
--- a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
@@ -38,14 +38,22 @@ computeTransposedType(RankedTensorType rankedTensorType,
 /// i.e. for a pack from an ABCD layout to an ABCDba:
 /// The packed shape would be ABCDba.
 /// The pre-permutation shape would be AaBbCD.
-SmallVector<int64_t> getPackUnPackInverseDestPerm(
-    std::variant<tensor::PackOp, tensor::UnPackOp> packOp);
+SmallVector<int64_t> computePackUnPackPerm(int64_t rank,
+                                           ArrayRef<int64_t> &innerDimsPos,
+                                           ArrayRef<int64_t> &outerPerm,
+                                           PackingMetadata &packingMetadata);
+
+/// This function uses the helper function `computePackUnPackPerm` to get
+/// the permutation vector. Only major difference between UnPack and Pack is
+/// that packOp uses destination rank whereas unpack Uses source rank.
+SmallVector<int64_t> getPackInverseDestPerm(tensor::PackOp packOp);
+SmallVector<int64_t> getUnPackInverseSrcPerm(tensor::UnPackOp unpackOp);
 
 /// Unpack requires some packing metadata data, so create another
 /// function where this value is passed by reference.
-SmallVector<int64_t> getPackUnPackInverseDestPerm(
-    std::variant<tensor::PackOp, tensor::UnPackOp> packOp,
-    PackingMetadata &PackingMetadata);
+SmallVector<int64_t> getUnPackInverseSrcPerm(tensor::UnPackOp,
+                                             PackingMetadata &metadata);
+
 /// A tensor.insert_slice is a cast-like operation if it merely rank-extends the
 /// source tensor or inserts the source tensor into a destination tensor with
 /// the same shape.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 9f8ea7f1f3969b..850cb861672ad6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -237,7 +237,7 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
   PackingMetadata packingMetadata = computePackingMetadata(
       packedTensorType.getRank(), packOp.getInnerDimsPos());
   SmallVector<int64_t> packedToStripMinedShapePerm =
-      tensor::getPackUnPackInverseDestPerm(packOp);
+      tensor::getPackInverseDestPerm(packOp);
 
   // 3. Compute the stripMinedShape: this is the packed shape before any outer
   // or inner permutations have been applied.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 0aa43b6c863e28..f066967c4a9097 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1405,8 +1405,7 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
 /// permutations.
 static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
                                               ArrayRef<int64_t> destShape) {
-  return applyPermutation(destShape,
-                          tensor::getPackUnPackInverseDestPerm(packOp));
+  return applyPermutation(destShape, tensor::getPackInverseDestPerm(packOp));
 }
 
 /// Create a TransferReadOp from `source` with static shape `readShape`. If the
@@ -1547,7 +1546,7 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
 
   // Create TransposeOp.
   auto destPermutation =
-      invertPermutationVector(tensor::getPackUnPackInverseDestPerm(packOp));
+      invertPermutationVector(tensor::getPackInverseDestPerm(packOp));
   auto transposeOp = rewriter.create<vector::TransposeOp>(
       loc, shapeCastOp.getResult(), destPermutation);
 
@@ -1559,7 +1558,7 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
   return success();
 }
 
-/// Vectorize a `tensor::UnPackOp` without OuterDimsPerms to these 4 Ops:
+/// Vectorize a `tensor::UnPackOp` to these 4 Ops:
 ///   Vector::TransferReadOp - Reads a vector from the source tensor
 ///   vector::TransposeOp - Transpose the Source tensor
 ///   ShapeCastOp - Reshape the data based on the target.
@@ -1581,7 +1580,7 @@ static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
   SmallVector<int64_t> readMaskShape(inputVectorSizes.begin(),
                                      inputVectorSizes.end());
   ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
-  if (outerDimsPerm.empty() == false) {
+  if (!outerDimsPerm.empty()) {
     applyPermutationToVector(readMaskShape, outerDimsPerm);
   }
   ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
@@ -1632,7 +1631,7 @@ static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
 
   PackingMetadata packMetadata;
   SmallVector<int64_t> lastDimToInsertPosPerm = invertPermutationVector(
-      tensor::getPackUnPackInverseDestPerm(unpackOp, packMetadata));
+      tensor::getUnPackInverseSrcPerm(unpackOp, packMetadata));
   ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
   SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
   mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
@@ -1772,7 +1771,7 @@ vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
     return failure();
   }
   llvm::ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
-  if (inputVectorSizes.empty() == false &&
+  if (!inputVectorSizes.empty() &&
       failed(isValidMaskedInputVector(resultShape, inputVectorSizes)))
     return failure();
 
diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
index 0902e33a1f19fd..f1126aaf44c76c 100644
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -72,37 +72,15 @@ mlir::tensor::computeTransposedType(RankedTensorType rankedTensorType,
       RTTBuilder(rankedTensorType).setShape(transposedShape);
   return transposedTensorType;
 }
-
-SmallVector<int64_t> mlir::tensor::getPackUnPackInverseDestPerm(
-    std::variant<tensor::PackOp, tensor::UnPackOp> op) {
-  PackingMetadata pMetaData;
-  return getPackUnPackInverseDestPerm(op, pMetaData);
-}
-
-SmallVector<int64_t> mlir::tensor::getPackUnPackInverseDestPerm(
-    std::variant<tensor::PackOp, tensor::UnPackOp> op,
+/// The permutation can be obtained from two permutations:
+///   a) Compute the permutation vector to move the last `numPackedDims` into
+///      the `innerPosDims` of a shape of rank `rank`.
+///   b) Compute the permutation vector to move outer dims if the
+///      `outerPerm` parameter is not empty.
+/// Apply (b) permutation on (a) permutation to get the final permutation.
+SmallVector<int64_t> mlir::tensor::computePackUnPackPerm(
+    int64_t rank, ArrayRef<int64_t> &innerDimsPos, ArrayRef<int64_t> &outerPerm,
     PackingMetadata &packingMetadata) {
-
-  llvm::ArrayRef<int64_t> innerDimsPos, outerPerm;
-  int64_t rank = 0;
-  bool isPackOp = std::holds_alternative<tensor::PackOp>(op);
-  if (isPackOp) {
-    tensor::PackOp packOp = std::get<tensor::PackOp>(op);
-    innerDimsPos = packOp.getInnerDimsPos();
-    rank = packOp.getDestType().getRank();
-    outerPerm = packOp.getOuterDimsPerm();
-  } else {
-    tensor::UnPackOp unpackOp = std::get<tensor::UnPackOp>(op);
-    innerDimsPos = unpackOp.getInnerDimsPos();
-    rank = unpackOp.getSourceType().getRank();
-    outerPerm = unpackOp.getOuterDimsPerm();
-  }
-  // The permutation can be obtained from two permutations:
-  //   a) Compute the permutation vector to move the last `numPackedDims` into
-  //      the `innerPosDims` of a shape of rank `packedRank`.
-  //   b) Compute the permutation vector to move outer dims if the pack op
-  //      has outer_dims_perm.
-  // Apply (b) permutation on (a) permutation to get the final permutation.
   int64_t numPackedDims = innerDimsPos.size();
   auto lastDims =
       llvm::to_vector(llvm::seq<int64_t>(rank - numPackedDims, rank));
@@ -110,15 +88,44 @@ SmallVector<int64_t> mlir::tensor::getPackUnPackInverseDestPerm(
   SmallVector<int64_t> innerPositionsPerm =
       computePermutationVector(rank, lastDims, packingMetadata.insertPositions);
 
-  if (isPackOp) {
-    SmallVector<int64_t> outerPos = packingMetadata.outerPositions;
-    if (!outerPerm.empty())
-      applyPermutationToVector(outerPos, outerPerm);
-    SmallVector<int64_t> outerPositionPerm = computePermutationVector(
-        rank, packingMetadata.outerPositions, outerPos);
-    applyPermutationToVector(innerPositionsPerm, outerPositionPerm);
-  }
-  return innerPositionsPerm;
+  SmallVector<int64_t> outerPos = packingMetadata.outerPositions;
+  if (!outerPerm.empty())
+    applyPermutationToVector(outerPos, outerPerm);
+  SmallVector<int64_t> outerPositionPerm =
+      computePermutationVector(rank, packingMetadata.outerPositions, outerPos);
+
+  SmallVector<int64_t> packInverseDestPermutation = innerPositionsPerm;
+  applyPermutationToVector(packInverseDestPermutation, outerPositionPerm);
+  return packInverseDestPermutation;
+}
+
+/// Shell function to compute the Destination Permutation of PackOp
+SmallVector<int64_t> mlir::tensor::getPackInverseDestPerm(PackOp packOp) {
+
+  PackingMetadata pMetadata;
+  int64_t packedRank = packOp.getDestType().getRank();
+  ArrayRef<int64_t> innerDimPos = packOp.getInnerDimsPos();
+  ArrayRef<int64_t> outerPerm = packOp.getOuterDimsPerm();
+  SmallVector<int64_t> packInvDestPerm = mlir::tensor::computePackUnPackPerm(
+      packedRank, innerDimPos, outerPerm, pMetadata);
+  return packInvDestPerm;
+}
+
+SmallVector<int64_t> mlir::tensor::getUnPackInverseSrcPerm(UnPackOp unpackOp) {
+  PackingMetadata metadata;
+  return mlir::tensor::getUnPackInverseSrcPerm(unpackOp, metadata);
+}
+
+/// Shell function to compute the Source rank permutation for unpackOp
+SmallVector<int64_t>
+mlir::tensor::getUnPackInverseSrcPerm(UnPackOp unpackOp,
+                                      PackingMetadata &metadata) {
+  int64_t unpackRank = unpackOp.getSourceType().getRank();
+  ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
+  ArrayRef<int64_t> outerPerm = unpackOp.getOuterDimsPerm();
+  SmallVector<int64_t> unpackInvSrcPerm = mlir::tensor::computePackUnPackPerm(
+      unpackRank, innerDimPos, outerPerm, metadata);
+  return unpackInvSrcPerm;
 }
 
 bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) {
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 3d37c657740055..36106312be4f95 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -779,8 +779,8 @@ func.func @test_vectorize_unpack_no_masks(%source: tensor<8x8x32x16xf32>, %dest:
   // CHECK: %[[C16:.*]] = arith.constant 16 : index
   // CHECK: %[[MSK0:.*]] = vector.create_mask %[[C8]], %[[C80]], %[[C32]], %[[C16]] : vector<4x16x32x16xi1>
   // CHECK: %[[READ:.*]] = vector.mask %[[MSK0]] {{.*}} : vector<4x16x32x16xi1> -> vector<4x16x32x16xf32>
-  // CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [0, 2, 1, 3] : vector<4x16x32x16xf32> to vector<4x32x16x16xf32>
-  // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<4x32x16x16xf32> to vector<128x256xf32>
+  // CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [2, 0, 1, 3] : vector<4x16x32x16xf32> to vector<32x4x16x16xf32>
+  // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<32x4x16x16xf32> to vector<128x256xf32>
   // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<256x128xf32>
   // CHECK: %[[C00:.*]] = arith.constant 0 : index
   // CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], {{.*}} : vector<128x256xf32>, tensor<256x128xf32>

>From e8e0d88d33dbdcfb1b26838c6d68e49070c0e1f3 Mon Sep 17 00:00:00 2001
From: "Balaji V. Iyer" <bviyer at gmail.com>
Date: Sat, 17 Feb 2024 05:15:55 +0000
Subject: [PATCH 11/12] Added all the comment changes requested by Diego.

---
 .../include/mlir/Dialect/Tensor/Utils/Utils.h | 16 ------------
 .../Linalg/Transforms/Vectorization.cpp       |  2 +-
 mlir/lib/Dialect/Tensor/Utils/Utils.cpp       | 25 +++++++++++++------
 3 files changed, 19 insertions(+), 24 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
index 009702f126eaf3..d09c9e36f6ff88 100644
--- a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
@@ -32,25 +32,9 @@ FailureOr<RankedTensorType>
 computeTransposedType(RankedTensorType rankedTensorType,
                       ArrayRef<int64_t> transposeVector);
 
-/// Given a tensor::PackOp, compute the permutation vector to shuffle the
-/// packed shape into the shape before any outer or inner permutations have
-/// been applied.
-/// i.e. for a pack from an ABCD layout to an ABCDba:
-/// The packed shape would be ABCDba.
-/// The pre-permutation shape would be AaBbCD.
-SmallVector<int64_t> computePackUnPackPerm(int64_t rank,
-                                           ArrayRef<int64_t> &innerDimsPos,
-                                           ArrayRef<int64_t> &outerPerm,
-                                           PackingMetadata &packingMetadata);
-
-/// This function uses the helper function `computePackUnPackPerm` to get
-/// the permutation vector. Only major difference between UnPack and Pack is
-/// that packOp uses destination rank whereas unpack Uses source rank.
 SmallVector<int64_t> getPackInverseDestPerm(tensor::PackOp packOp);
 SmallVector<int64_t> getUnPackInverseSrcPerm(tensor::UnPackOp unpackOp);
 
-/// Unpack requires some packing metadata data, so create another
-/// function where this value is passed by reference.
 SmallVector<int64_t> getUnPackInverseSrcPerm(tensor::UnPackOp,
                                              PackingMetadata &metadata);
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index f066967c4a9097..a8b64eb149ed63 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1595,7 +1595,7 @@ static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
   // - initially: ReadMaskShape = vectorInputSizes
   // - if outer_dims_perms is present: do that permutation on readMaskShape.
   // - Append the remaining shape from SS
-  // - Divide all teh readMaskShape locations pointed by innerDimPos
+  // - Divide all the readMaskShape locations pointed by innerDimPos
   //   by the innerTileSize attribute value.
   // E.g. let's say let's say unpackTensorType.getShape() = <8x8x32x16>
   // inner Dim Pos = [0, 1] and Inner Tiles = [32, 16], vector_sizes are [512,
diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
index f1126aaf44c76c..186f85d2ce20a6 100644
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -78,9 +78,10 @@ mlir::tensor::computeTransposedType(RankedTensorType rankedTensorType,
 ///   b) Compute the permutation vector to move outer dims if the
 ///      `outerPerm` parameter is not empty.
 /// Apply (b) permutation on (a) permutation to get the final permutation.
-SmallVector<int64_t> mlir::tensor::computePackUnPackPerm(
-    int64_t rank, ArrayRef<int64_t> &innerDimsPos, ArrayRef<int64_t> &outerPerm,
-    PackingMetadata &packingMetadata) {
+static SmallVector<int64_t>
+computePackUnPackPerm(int64_t rank, ArrayRef<int64_t> &innerDimsPos,
+                      ArrayRef<int64_t> &outerPerm,
+                      PackingMetadata &packingMetadata) {
   int64_t numPackedDims = innerDimsPos.size();
   auto lastDims =
       llvm::to_vector(llvm::seq<int64_t>(rank - numPackedDims, rank));
@@ -100,31 +101,41 @@ SmallVector<int64_t> mlir::tensor::computePackUnPackPerm(
 }
 
 /// Shell function to compute the Destination Permutation of PackOp
+/// It delegates to `computePackUnPackPerm` for the permutation vector. The
+/// only difference from the UnPackOp variant is that PackOp uses the
+/// destination rank whereas UnPackOp uses the source rank.
 SmallVector<int64_t> mlir::tensor::getPackInverseDestPerm(PackOp packOp) {
 
   PackingMetadata pMetadata;
   int64_t packedRank = packOp.getDestType().getRank();
   ArrayRef<int64_t> innerDimPos = packOp.getInnerDimsPos();
   ArrayRef<int64_t> outerPerm = packOp.getOuterDimsPerm();
-  SmallVector<int64_t> packInvDestPerm = mlir::tensor::computePackUnPackPerm(
-      packedRank, innerDimPos, outerPerm, pMetadata);
+  SmallVector<int64_t> packInvDestPerm =
+      computePackUnPackPerm(packedRank, innerDimPos, outerPerm, pMetadata);
   return packInvDestPerm;
 }
 
+/// Shell function to compute the source permutation of an UnPackOp.
+/// Convenience overload for callers that do not need the PackingMetadata:
+/// it forwards to the two-argument overload with a local, discarded value.
+/// Like `getPackInverseDestPerm`, the result comes from
+/// `computePackUnPackPerm`, but computed with the source rank.
 SmallVector<int64_t> mlir::tensor::getUnPackInverseSrcPerm(UnPackOp unpackOp) {
   PackingMetadata metadata;
   return mlir::tensor::getUnPackInverseSrcPerm(unpackOp, metadata);
 }
 
 /// Shell function to compute the Source rank permutation for unpackOp
+/// Unpack callers may also need the PackingMetadata, so this overload takes
+/// it by reference and passes it through to `computePackUnPackPerm`.
 SmallVector<int64_t>
 mlir::tensor::getUnPackInverseSrcPerm(UnPackOp unpackOp,
                                       PackingMetadata &metadata) {
   int64_t unpackRank = unpackOp.getSourceType().getRank();
   ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
   ArrayRef<int64_t> outerPerm = unpackOp.getOuterDimsPerm();
-  SmallVector<int64_t> unpackInvSrcPerm = mlir::tensor::computePackUnPackPerm(
-      unpackRank, innerDimPos, outerPerm, metadata);
+  SmallVector<int64_t> unpackInvSrcPerm =
+      computePackUnPackPerm(unpackRank, innerDimPos, outerPerm, metadata);
   return unpackInvSrcPerm;
 }
 

>From 524c0d97228d94f9c994bc12754850b3a6c641d6 Mon Sep 17 00:00:00 2001
From: "Balaji V. Iyer" <bviyer at gmail.com>
Date: Tue, 20 Feb 2024 21:46:02 +0000
Subject: [PATCH 12/12] Fixed all the issues mentioned by Max on 2/20.

---
 .../Linalg/Transforms/Vectorization.cpp       | 39 ++++++++++---------
 mlir/test/Dialect/Linalg/vectorization.mlir   | 22 ++++-------
 2 files changed, 27 insertions(+), 34 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index a8b64eb149ed63..ac043e87223dfe 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1564,10 +1564,10 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
 ///   ShapeCastOp - Reshape the data based on the target.
 ///   vector::TransferWriteOp. - Write the result vector back to the destination
 ///   tensor
-static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
-                                         tensor::UnPackOp unpackOp,
-                                         ArrayRef<int64_t> inputVectorSizes,
-                                         SmallVectorImpl<Value> &newResults) {
+static LogicalResult
+vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
+                          ArrayRef<int64_t> inputVectorSizes,
+                          SmallVectorImpl<Value> &newResults) {
 
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(unpackOp);
@@ -1580,12 +1580,7 @@ static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
   SmallVector<int64_t> readMaskShape(inputVectorSizes.begin(),
                                      inputVectorSizes.end());
   ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
-  if (!outerDimsPerm.empty()) {
-    applyPermutationToVector(readMaskShape, outerDimsPerm);
-  }
   ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
-  readMaskShape.append(sourceShape.begin() + inputVectorSizes.size(),
-                       sourceShape.end());
 
   // ReadMask is the size of tensor used to read and apply mask. It is
   // set like this: Let's say the vectorSize (VS) array is size 'N' and
@@ -1593,22 +1588,28 @@ static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
   // size M-N
   // Thus:
   // - initially: ReadMaskShape = vectorInputSizes
-  // - if outer_dims_perms is present: do that permutation on readMaskShape.
-  // - Append the remaining shape from SS
   // - Divide all the readMaskShape locations pointed by innerDimPos
   //   by the innerTileSize attribute value.
+  // - if outer_dims_perm is present: apply that permutation to readMaskShape.
+  // - Append the remaining M-N dims of SS.
   // E.g. let's say let's say unpackTensorType.getShape() = <8x8x32x16>
   // inner Dim Pos = [0, 1] and Inner Tiles = [32, 16], vector_sizes are [512,
   // 128] and outer_dims_perm is [1, 0] then read shape is:
   //   ReadMaskShape(initial): [512, 128]
-  //   After applying outer_dims_perm: [128, 512]
-  //   After appending the rest of the sourceShape: [128, 512, 32, 16]
-  //   Final Value(after innerDim Adjustment): [128/32, 512/16, 32, 16]
-  //                                           = [4, 32, 32, 16]
+  //   After dividing by the inner tiles: [512/32, 128/16]
+  //                                      = [16, 8]
+  //   After applying outer_dims_perm: [8, 16]
+  //   Final value (after appending the rest of SS): [8, 16, 32, 16]
+
   for (auto [index, size] : enumerate(innerTiles)) {
     readMaskShape[innerDimPos[index]] =
         llvm::divideCeil(readMaskShape[innerDimPos[index]], size);
   }
+  if (!outerDimsPerm.empty()) {
+    applyPermutationToVector(readMaskShape, outerDimsPerm);
+  }
+  readMaskShape.append(sourceShape.begin() + inputVectorSizes.size(),
+                       sourceShape.end());
 
   ReifiedRankedShapedTypeDims reifiedRetShapes;
   LogicalResult status =
@@ -1630,8 +1631,8 @@ static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
       ArrayRef<int64_t>(readMaskShape.begin(), readMaskShape.end()), padValue);
 
   PackingMetadata packMetadata;
-  SmallVector<int64_t> lastDimToInsertPosPerm = invertPermutationVector(
-      tensor::getUnPackInverseSrcPerm(unpackOp, packMetadata));
+  SmallVector<int64_t> lastDimToInsertPosPerm =
+      tensor::getUnPackInverseSrcPerm(unpackOp, packMetadata);
   ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
   SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
   mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
@@ -2031,8 +2032,8 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
                                            results);
           })
           .Case<tensor::UnPackOp>([&](auto unpackOp) {
-            return vectorizeAsUnpackOp(rewriter, unpackOp, inputVectorSizes,
-                                       results);
+            return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
+                                             inputVectorSizes, results);
           })
           .Default([](auto) { return failure(); });
 
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 36106312be4f95..64f9439d6fe3a8 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -691,10 +691,10 @@ func.func @test_vectorize_dynamic_shapes_unpack(%arg0: tensor<?x?xf32>, %arg1: t
 // CHECK: %[[CNST2:.*]] = arith.constant 2 : index
 // CHECK: %[[readMsk0:.*]] = vector.create_mask %[[DIM4]], %[[DIM6]], %[[CNST16]], %[[CNST2]] : vector<2x1x16x2xi1>
 // CHECK: %[[read0:.*]] = vector.mask %[[readMsk0]] {{.*}} vector.transfer_read %{{.*}} : tensor<?x?x16x2xf32>, vector<2x1x16x2xf32> } : vector<2x1x16x2xi1> -> vector<2x1x16x2xf32>
-// CHECK: %[[trans0:.*]] = vector.transpose %[[read0]], [0, 2, 3, 1] : vector<2x1x16x2xf32> to vector<2x16x2x1xf32>
-// CHECK: %[[sc0:.*]] = vector.shape_cast %[[trans0]] : vector<2x16x2x1xf32> to vector<32x2xf32>
+// CHECK: %[[trans0:.*]] = vector.transpose %[[read0]], [0, 3, 1, 2] : vector<2x1x16x2xf32> to vector<2x2x1x16xf32>
+// CHECK: %[[sc0:.*]] = vector.shape_cast %[[trans0]] : vector<2x2x1x16xf32> to vector<4x16xf32>
 // CHECK: %[[empt0:.*]] = tensor.empty
-// CHECK: %[[writeMsk0:.*]] = vector.create_mask {{.*}} : vector<32x2xi1>
+// CHECK: %[[writeMsk0:.*]] = vector.create_mask {{.*}} : vector<4x16xi1>
 // CHECK: %[[write0:.*]] = vector.mask %[[writeMsk0:.*]] {{.*}} vector.transfer_write %[[sc0]], %[[empt0]]
 // CHECK: return %[[write0]]
  %ret = tensor.unpack %arg1 inner_dims_pos = [1, 0] inner_tiles = [16, 2] into %arg0 : tensor<?x?x16x2xf32> -> tensor<?x?xf32>
@@ -766,24 +766,16 @@ func.func @test_vectorize_unpack_no_masks(%source: tensor<8x8x32x16xf32>, %dest:
 
   // -----
 
-  // This test is same as the one test_vectorize_unpack_no_masks but with outer_dims_perm.
-  // Note that adding this attribute causes a read mask.
-
   // CHECK-LABEL: test_vectorize_unpack_with_outer_perm
   func.func @test_vectorize_unpack_with_outer_perm(%source: tensor<8x8x32x16xf32>, %dest: tensor<256x128xf32>) -> tensor<256x128xf32> {
   // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
   // CHECK: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK: %[[C8:.*]] = arith.constant 8 : index
-  // CHECK: %[[C80:.*]] = arith.constant 8 : index
-  // CHECK: %[[C32:.*]] = arith.constant 32 : index
-  // CHECK: %[[C16:.*]] = arith.constant 16 : index
-  // CHECK: %[[MSK0:.*]] = vector.create_mask %[[C8]], %[[C80]], %[[C32]], %[[C16]] : vector<4x16x32x16xi1>
-  // CHECK: %[[READ:.*]] = vector.mask %[[MSK0]] {{.*}} : vector<4x16x32x16xi1> -> vector<4x16x32x16xf32>
-  // CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [2, 0, 1, 3] : vector<4x16x32x16xf32> to vector<32x4x16x16xf32>
-  // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<32x4x16x16xf32> to vector<128x256xf32>
+  // CHECK: %[[READ:.*]] = vector.transfer_read {{.*}} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
+  // CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [1, 2, 0, 3] : vector<8x8x32x16xf32> to vector<8x32x8x16xf32>
+  // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<8x32x8x16xf32> to vector<256x128xf32>
   // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<256x128xf32>
   // CHECK: %[[C00:.*]] = arith.constant 0 : index
-  // CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], {{.*}} : vector<128x256xf32>, tensor<256x128xf32>
+  // CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], {{.*}} : vector<256x128xf32>, tensor<256x128xf32>
   // CHECK: return %[[WRIT]] : tensor<256x128xf32>
    %0 = tensor.unpack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
    return %0 : tensor<256x128xf32>



More information about the Mlir-commits mailing list