[Mlir-commits] [mlir] [mlir] Add direct vectorization lowering for `tensor.pack` ops (PR #78660)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Feb 7 08:11:24 PST 2024


https://github.com/Max191 updated https://github.com/llvm/llvm-project/pull/78660

>From 345120672dbc0bbec1ea03092a001b66ecaf7e94 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Fri, 5 Jan 2024 13:50:50 -0500
Subject: [PATCH 1/8] [mlir] Add vectorization support for tensor.pack

---
 .../TransformOps/LinalgTransformOps.cpp       |   2 +-
 .../Linalg/Transforms/Vectorization.cpp       | 151 ++++++++++++++++++
 2 files changed, 152 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 6431bbd25396a5..585fd14b40d764 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3152,7 +3152,7 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
 
   // TODO: Check that the correct number of vectorSizes was provided.
   for (Operation *target : targets) {
-    if (!isa<linalg::LinalgOp, tensor::PadOp>(target)) {
+    if (!isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp>(target)) {
       return mlir::emitSilenceableFailure(target->getLoc())
              << "Unsupported Op, cannot vectorize";
     }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 0707625819d1a5..f42e85c68f84b1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -19,10 +19,14 @@
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
 #include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/RegionUtils.h"
@@ -30,7 +34,9 @@
 #include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/ADT/iterator_range.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
 #include <optional>
 #include <type_traits>
@@ -1393,6 +1399,121 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
   return success();
 }
 
+/// Given a tensor::PackOp, return the permutation from the "tiled"
+/// shape to the "packed" shape, defined as the following:
+/// The "packed" shape is the same as the `dest` shape of the pack op.
+/// The "tiled" shape is a permutation of the `dest` shape such that
+/// each outer dimension is in the original `source` order, and the
+/// inner_tile dimensions immediately follow their corresponding outer
+/// dimension.
+/// i.e. for the following tensor.pack:
+/// ```mlir
+/// %pack = tensor.pack %0 padding_value(%1) 
+///   outer_dims_perm = [0, 2, 1] 
+///   inner_dims_pos = [2, 1] 
+///   inner_tiles = [16, 2] 
+///   into %2 : tensor<32x8x16> -> tensor<32x1x4x16x2>
+/// ```
+/// The "packed" shape is `32x1x4x16x2`
+/// The "tiled" shape is `32x(4x2)x(1x16)`
+static SmallVector<int64_t> getTiledShapeToPackedShapePerm(tensor::PackOp packOp) {
+  auto innerTiles = packOp.getInnerTiles();
+  int64_t srcRank = packOp.getSourceRank();
+  auto innerDimsPos = packOp.getInnerDimsPos();
+  if (innerDimsPos.empty())
+    innerDimsPos = to_vector(llvm::seq<int64_t>(innerTiles.size()));
+  auto outerDimsPerm = packOp.getOuterDimsPerm();
+  if (outerDimsPerm.empty())
+    outerDimsPerm = to_vector(llvm::seq<int64_t>(srcRank));
+  auto packedIdxToTiledIdx = [&](int64_t idx) -> int64_t { 
+    int64_t srcIdx;
+    if (idx >= srcRank)
+      srcIdx = innerDimsPos[idx - srcRank];
+    else
+      srcIdx = outerDimsPerm[idx];
+    int64_t tiledIdx = srcIdx;
+    for (int64_t pos : innerDimsPos)
+      if (pos < srcIdx)
+        tiledIdx++;
+    if (idx >= srcRank)
+      tiledIdx++;
+    return tiledIdx;
+  };
+  SmallVector<int64_t> perm;
+  for (int i = 0; i < packOp.getDestRank(); i++) 
+    perm.push_back(packedIdxToTiledIdx(i));
+  return perm;
+}
+
+/// Given a tensor::PackOp, return the "tiled" `dest` shape as described
+/// above in `getTiledShapeToPackedShapePerm`.
+static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp) {
+  auto perm = getTiledShapeToPackedShapePerm(packOp);
+  auto destShape = packOp.getDestType().getShape();
+  return applyPermutation(destShape, invertPermutationVector(perm));
+}
+
+/// 
+static LogicalResult
+vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
+                       ArrayRef<int64_t> inputVectorSizes,
+                       SmallVectorImpl<Value> &newResults) {
+  auto padValue = packOp.getPaddingValue();
+  Location loc = packOp.getLoc();
+  int64_t inputRank = inputVectorSizes.size();
+  int64_t outputRank = packOp.getDestRank();
+  auto maskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
+  auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());
+
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(packOp);
+
+  ReifiedRankedShapedTypeDims reifiedReturnShapes;
+  LogicalResult status =
+      cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
+          .reifyResultShapes(rewriter, reifiedReturnShapes);
+  (void)status; // prevent unused variable warning on non-assert builds
+  assert(succeeded(status) && "failed to reify result shapes");
+  auto emptyOp = rewriter.create<tensor::EmptyOp>(loc, reifiedReturnShapes[0],
+                                                  padValue.getType());
+  SmallVector<OpFoldResult> mixedSourceDims =
+      tensor::getMixedSizes(rewriter, loc, packOp.getSource());
+  Value mask =
+      rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
+  auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
+      loc,
+      /*vectorType=*/vectorType,
+      /*source=*/packOp.getSource(),
+      /*indices=*/SmallVector<Value>(inputRank, zero),
+      /*padding=*/padValue,
+      /*inBounds=*/SmallVector<bool>(inputRank, true));
+  auto maskedOp = cast<vector::MaskOp>(
+      mlir::vector::maskOperation(rewriter, transferReadOp, mask));
+  // ShapeCast
+  auto tiledPackShape = getTiledPackShape(packOp);
+  auto tiledPackType = VectorType::get(tiledPackShape, packOp.getDestType().getElementType());
+  auto shapeCastOp = rewriter.create<vector::ShapeCastOp>(loc, tiledPackType, maskedOp->getResult(0));
+  auto tiledShapeToPackedShapePerm = getTiledShapeToPackedShapePerm(packOp);
+  auto transposeOp = rewriter.create<vector::TransposeOp>(loc, shapeCastOp->getResult(0), tiledShapeToPackedShapePerm);
+  Operation *write = rewriter.create<vector::TransferWriteOp>(
+      loc,
+      /*vector=*/transposeOp->getResult(0),
+      /*source=*/emptyOp,
+      /*indices=*/SmallVector<Value>(outputRank, zero),
+      /*inBounds=*/SmallVector<bool>(outputRank, true));
+  // bool needMaskForWrite = llvm::any_of(
+  //     llvm::zip_equal(inputVectorSizes, packOp.getResultType().getShape()),
+  //     [](auto it) { return std::get<0>(it) != std::get<1>(it); });
+  // if (needMaskForWrite) {
+  //   Value maskForWrite = rewriter.create<vector::CreateMaskOp>(
+  //       loc, maskType, reifiedReturnShapes[0]);
+  //   write = mlir::vector::maskOperation(rewriter, write, maskForWrite);
+  // }
+  newResults.push_back(write->getResult(0));
+  return success();
+}
+
 /// Vectorize a `padOp` with (1) static result type, (2) constant padding value
 /// and (3) all-zero lowPad to
 ///   `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
@@ -1585,6 +1706,30 @@ vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
   return success();
 }
 
+static LogicalResult
+vectorizePackOpPrecondition(tensor::PackOp packOp,
+                           ArrayRef<int64_t> inputVectorSizes) {
+  auto padValue = packOp.getPaddingValue();
+  if (!padValue) {
+    LDBG("pad value is not constant: " << packOp << "\n");
+    return failure();
+  }
+
+  ArrayRef<int64_t> resultTensorShape = packOp.getSourceType().getShape();
+  if (failed(isValidMaskedInputVector(resultTensorShape, inputVectorSizes)))
+    return failure();
+
+  if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
+        std::optional<int64_t> res = getConstantIntValue(v);
+        return !res.has_value();
+      })) {
+    LDBG("inner_tiles must be constant: " << packOp << "\n");
+    return failure();
+  }
+
+  return success();
+}
+
 static LogicalResult
 vectorizePadOpPrecondition(tensor::PadOp padOp,
                            ArrayRef<int64_t> inputVectorSizes) {
@@ -1644,6 +1789,9 @@ LogicalResult mlir::linalg::vectorizeOpPrecondition(
       .Case<tensor::PadOp>([&](auto padOp) {
         return vectorizePadOpPrecondition(padOp, inputVectorSizes);
       })
+      .Case<tensor::PackOp>([&](auto packOp) {
+        return vectorizePackOpPrecondition(packOp, inputVectorSizes);
+      })
       .Default([](auto) { return failure(); });
 }
 
@@ -1732,6 +1880,9 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
             return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
                                           results);
           })
+          .Case<tensor::PackOp>([&](auto packOp) {
+            return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes, results);
+          })
           .Default([](auto) { return failure(); });
 
   if (failed(vectorizeResult)) {

>From 8ad3ad742ed0097f76fe1b84966d1e254559b4c2 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Thu, 18 Jan 2024 20:12:13 -0500
Subject: [PATCH 2/8] Support pack with no padding value

---
 .../Linalg/Transforms/Vectorization.cpp       | 22 ++++++++-----------
 1 file changed, 9 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index f42e85c68f84b1..d0e3b7f4e80287 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1458,16 +1458,20 @@ static LogicalResult
 vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
                        ArrayRef<int64_t> inputVectorSizes,
                        SmallVectorImpl<Value> &newResults) {
-  auto padValue = packOp.getPaddingValue();
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(packOp);
+
   Location loc = packOp.getLoc();
+  auto padValue = packOp.getPaddingValue();
+  if (!padValue) {
+    padValue = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getZeroAttr(packOp.getSourceType().getElementType()));
+  }
   int64_t inputRank = inputVectorSizes.size();
   int64_t outputRank = packOp.getDestRank();
   auto maskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
   auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());
 
-  OpBuilder::InsertionGuard g(rewriter);
-  rewriter.setInsertionPoint(packOp);
-
   ReifiedRankedShapedTypeDims reifiedReturnShapes;
   LogicalResult status =
       cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
@@ -1502,14 +1506,6 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
       /*source=*/emptyOp,
       /*indices=*/SmallVector<Value>(outputRank, zero),
       /*inBounds=*/SmallVector<bool>(outputRank, true));
-  // bool needMaskForWrite = llvm::any_of(
-  //     llvm::zip_equal(inputVectorSizes, packOp.getResultType().getShape()),
-  //     [](auto it) { return std::get<0>(it) != std::get<1>(it); });
-  // if (needMaskForWrite) {
-  //   Value maskForWrite = rewriter.create<vector::CreateMaskOp>(
-  //       loc, maskType, reifiedReturnShapes[0]);
-  //   write = mlir::vector::maskOperation(rewriter, write, maskForWrite);
-  // }
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -1710,7 +1706,7 @@ static LogicalResult
 vectorizePackOpPrecondition(tensor::PackOp packOp,
                            ArrayRef<int64_t> inputVectorSizes) {
   auto padValue = packOp.getPaddingValue();
-  if (!padValue) {
+  if (padValue && getConstantIntValue(padValue) != std::nullopt) {
     LDBG("pad value is not constant: " << packOp << "\n");
     return failure();
   }

>From a0931bdfd9f0bc6f20b5a71c028848ba692b0d8d Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Thu, 18 Jan 2024 21:11:49 -0500
Subject: [PATCH 3/8] Add tests for tensor.pack vectorization

---
 mlir/test/Dialect/Linalg/vectorization.mlir | 61 +++++++++++++++++++++
 1 file changed, 61 insertions(+)

diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index d5fb0cbb9c723b..af1c1337224fa2 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -501,6 +501,67 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+func.func @test_vectorize_dynamic_pack(%arg0: tensor<?x?xf32>, %arg1: tensor<4x1x16x2xf32>) -> tensor<4x1x16x2xf32> {
+  %pack = tensor.pack %arg0 inner_dims_pos = [1, 0] inner_tiles = [16, 2] into %arg1 : tensor<?x?xf32> -> tensor<4x1x16x2xf32>
+  return %pack : tensor<4x1x16x2xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [8, 16] : !transform.any_op
+    transform.yield 
+  }
+}
+//  CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
+//  CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+//  CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+//  CHECK-DAG: %[[d0:.*]] = tensor.dim {{.*}} %[[c0]] : tensor<?x?xf32>
+//  CHECK-DAG: %[[d1:.*]] = tensor.dim {{.*}} %[[c1]] : tensor<?x?xf32>
+//  CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<4x1x16x2xf32>
+//      CHECK: %[[mask:.*]] = vector.create_mask %[[d0]], %[[d1]] : vector<8x16xi1>
+//  CHECK-DAG: %[[c0_2:.*]] = arith.constant 0 : index
+//      CHECK: %[[masked_read:.*]] = vector.mask %[[mask]] {
+// CHECK-SAME:   vector.transfer_read %{{.*}}[%[[c0_2]], %[[c0_2]]], %[[cst]]
+// CHECK-SAME:   {in_bounds = [true, true]} : tensor<?x?xf32>, vector<8x16xf32>
+// CHECK-SAME: } : vector<8x16xi1> -> vector<8x16xf32>
+//      CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[masked_read]] : vector<8x16xf32> to vector<4x2x1x16xf32>
+//      CHECK: %[[transpose:.*]] = vector.transpose %[[shape_cast]], [0, 2, 3, 1] : vector<4x2x1x16xf32> to vector<4x1x16x2xf32>
+//      CHECK: %[[write:.*]] = vector.transfer_write %[[transpose]], %[[empty]][%[[c0_2]], %[[c0_2]], %[[c0_2]], %[[c0_2]]]
+// CHECK-SAME:   {in_bounds = [true, true, true, true]} : vector<4x1x16x2xf32>, tensor<4x1x16x2xf32>
+//      CHECK: return %[[write]] : tensor<4x1x16x2xf32>
+
+// -----
+
+func.func @test_vectorize_pack(%arg0: tensor<32x8x16xf32>, %arg1: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> {
+  %pack = tensor.pack %arg0 inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %arg1 : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
+  return %pack : tensor<32x4x1x16x2xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [32, 8, 16] : !transform.any_op
+    transform.yield 
+  }
+}
+//  CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
+//  CHECK-DAG: %[[c32:.*]] = arith.constant 32 : index
+//  CHECK-DAG: %[[c8:.*]] = arith.constant 8 : index
+//  CHECK-DAG: %[[c16:.*]] = arith.constant 16 : index
+//  CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<32x4x1x16x2xf32>
+//      CHECK: %[[mask:.*]] = vector.create_mask %[[c32]], %[[c8]], %[[c16]] : vector<32x8x16xi1>
+//  CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+//      CHECK: %[[masked_read:.*]] = vector.mask %[[mask]] {
+// CHECK-SAME:   vector.transfer_read %{{.*}}[%[[c0]], %[[c0]], %[[c0]]], %[[cst]]
+// CHECK-SAME:   {in_bounds = [true, true, true]} : tensor<32x8x16xf32>, vector<32x8x16xf32>
+// CHECK-SAME: } : vector<32x8x16xi1> -> vector<32x8x16xf32>
+//      CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[masked_read]] : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
+//      CHECK: %[[transpose:.*]] = vector.transpose %[[shape_cast]], [0, 1, 3, 4, 2] : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
+//      CHECK: %[[write:.*]] = vector.transfer_write %[[transpose]], %[[empty]][%[[c0]], %[[c0]], %[[c0]], %[[c0]], %[[c0]]]
+// CHECK-SAME:   {in_bounds = [true, true, true, true, true]} : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
+//      CHECK: return %[[write]] : tensor<32x4x1x16x2xf32>
+
+// -----
+
 func.func @matmul(%A: memref<?x?xf32>, %B: memref<?x?xf32>, %C: memref<?x?xf32>) {
   linalg.matmul ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
             outs(%C: memref<?x?xf32>)

>From e4950ce644c054372103aafdeaa7e1005ae2a678 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Thu, 18 Jan 2024 21:26:08 -0500
Subject: [PATCH 4/8] Apply clang-format

---
 .../Linalg/Transforms/Vectorization.cpp       | 35 +++++++++++--------
 1 file changed, 20 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index d0e3b7f4e80287..37829fbeb79f7a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1408,15 +1408,16 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
 /// dimension.
 /// i.e. for the following tensor.pack:
 /// ```mlir
-/// %pack = tensor.pack %0 padding_value(%1) 
-///   outer_dims_perm = [0, 2, 1] 
-///   inner_dims_pos = [2, 1] 
-///   inner_tiles = [16, 2] 
+/// %pack = tensor.pack %0 padding_value(%1)
+///   outer_dims_perm = [0, 2, 1]
+///   inner_dims_pos = [2, 1]
+///   inner_tiles = [16, 2]
 ///   into %2 : tensor<32x8x16> -> tensor<32x1x4x16x2>
 /// ```
 /// The "packed" shape is `32x1x4x16x2`
 /// The "tiled" shape is `32x(4x2)x(1x16)`
-static SmallVector<int64_t> getTiledShapeToPackedShapePerm(tensor::PackOp packOp) {
+static SmallVector<int64_t>
+getTiledShapeToPackedShapePerm(tensor::PackOp packOp) {
   auto innerTiles = packOp.getInnerTiles();
   int64_t srcRank = packOp.getSourceRank();
   auto innerDimsPos = packOp.getInnerDimsPos();
@@ -1425,7 +1426,7 @@ static SmallVector<int64_t> getTiledShapeToPackedShapePerm(tensor::PackOp packOp
   auto outerDimsPerm = packOp.getOuterDimsPerm();
   if (outerDimsPerm.empty())
     outerDimsPerm = to_vector(llvm::seq<int64_t>(srcRank));
-  auto packedIdxToTiledIdx = [&](int64_t idx) -> int64_t { 
+  auto packedIdxToTiledIdx = [&](int64_t idx) -> int64_t {
     int64_t srcIdx;
     if (idx >= srcRank)
       srcIdx = innerDimsPos[idx - srcRank];
@@ -1440,7 +1441,7 @@ static SmallVector<int64_t> getTiledShapeToPackedShapePerm(tensor::PackOp packOp
     return tiledIdx;
   };
   SmallVector<int64_t> perm;
-  for (int i = 0; i < packOp.getDestRank(); i++) 
+  for (int i = 0; i < packOp.getDestRank(); i++)
     perm.push_back(packedIdxToTiledIdx(i));
   return perm;
 }
@@ -1453,11 +1454,11 @@ static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp) {
   return applyPermutation(destShape, invertPermutationVector(perm));
 }
 
-/// 
+///
 static LogicalResult
 vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
-                       ArrayRef<int64_t> inputVectorSizes,
-                       SmallVectorImpl<Value> &newResults) {
+                        ArrayRef<int64_t> inputVectorSizes,
+                        SmallVectorImpl<Value> &newResults) {
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(packOp);
 
@@ -1496,10 +1497,13 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
       mlir::vector::maskOperation(rewriter, transferReadOp, mask));
   // ShapeCast
   auto tiledPackShape = getTiledPackShape(packOp);
-  auto tiledPackType = VectorType::get(tiledPackShape, packOp.getDestType().getElementType());
-  auto shapeCastOp = rewriter.create<vector::ShapeCastOp>(loc, tiledPackType, maskedOp->getResult(0));
+  auto tiledPackType =
+      VectorType::get(tiledPackShape, packOp.getDestType().getElementType());
+  auto shapeCastOp = rewriter.create<vector::ShapeCastOp>(
+      loc, tiledPackType, maskedOp->getResult(0));
   auto tiledShapeToPackedShapePerm = getTiledShapeToPackedShapePerm(packOp);
-  auto transposeOp = rewriter.create<vector::TransposeOp>(loc, shapeCastOp->getResult(0), tiledShapeToPackedShapePerm);
+  auto transposeOp = rewriter.create<vector::TransposeOp>(
+      loc, shapeCastOp->getResult(0), tiledShapeToPackedShapePerm);
   Operation *write = rewriter.create<vector::TransferWriteOp>(
       loc,
       /*vector=*/transposeOp->getResult(0),
@@ -1704,7 +1708,7 @@ vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
 
 static LogicalResult
 vectorizePackOpPrecondition(tensor::PackOp packOp,
-                           ArrayRef<int64_t> inputVectorSizes) {
+                            ArrayRef<int64_t> inputVectorSizes) {
   auto padValue = packOp.getPaddingValue();
   if (padValue && getConstantIntValue(padValue) != std::nullopt) {
     LDBG("pad value is not constant: " << packOp << "\n");
@@ -1877,7 +1881,8 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
                                           results);
           })
           .Case<tensor::PackOp>([&](auto packOp) {
-            return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes, results);
+            return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
+                                           results);
           })
           .Default([](auto) { return failure(); });
 

>From 14a73f35673ecdf4d0f7f3615927537bb0ba6664 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Fri, 19 Jan 2024 16:25:58 -0500
Subject: [PATCH 5/8] Use result-shape vector sizes for tensor.pack; clean up

---
 .../Linalg/Transforms/Vectorization.cpp       | 175 +++++++++++-------
 mlir/test/Dialect/Linalg/vectorization.mlir   |  26 +--
 2 files changed, 119 insertions(+), 82 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 37829fbeb79f7a..78c8a62933324d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -24,6 +24,7 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
 #include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
@@ -1454,7 +1455,73 @@ static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp) {
   return applyPermutation(destShape, invertPermutationVector(perm));
 }
 
-///
+/// Create a masked TransferReadOp from `source` with shape `readShape`.
+static vector::MaskOp createMaskedTransferRead(OpBuilder &builder, Location loc,
+                                               Value source,
+                                               ArrayRef<int64_t> readShape,
+                                               Value padValue) {
+  auto maskType = VectorType::get(readShape, builder.getI1Type());
+  auto vectorType = VectorType::get(readShape, padValue.getType());
+  SmallVector<OpFoldResult> mixedSourceDims =
+      tensor::getMixedSizes(builder, loc, source);
+  Value mask =
+      builder.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
+  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+  int64_t readRank = readShape.size();
+  auto transferReadOp = builder.create<vector::TransferReadOp>(
+      loc,
+      /*vectorType=*/vectorType,
+      /*source=*/source,
+      /*indices=*/SmallVector<Value>(readRank, zero),
+      /*padding=*/padValue,
+      /*inBounds=*/SmallVector<bool>(readRank, true));
+  return cast<vector::MaskOp>(
+      mlir::vector::maskOperation(builder, transferReadOp, mask));
+}
+
+/// Given an input, the mixed destSizes, and the vector sizes for vectorization,
+/// create an empty destination tensor and create a TransferWriteOp from the
+/// input to the empty tensor. If the destination shape is not the same as the
+/// inputVectorSizes for the first rank(inputVectorSizes) dims, then create a
+/// mask for the write.
+static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
+                                           Value input,
+                                           SmallVector<OpFoldResult> destSizes,
+                                           ArrayRef<int64_t> inputVectorSizes) {
+  auto inputType = cast<VectorType>(input.getType());
+  Value dest = builder.create<tensor::EmptyOp>(loc, destSizes,
+                                               inputType.getElementType());
+  int64_t rank = cast<ShapedType>(dest.getType()).getRank();
+  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+  Operation *write = builder.create<vector::TransferWriteOp>(
+      loc,
+      /*vector=*/input,
+      /*source=*/dest,
+      /*indices=*/SmallVector<Value>(rank, zero),
+      /*inBounds=*/SmallVector<bool>(rank, true));
+  auto destShape = cast<ShapedType>(dest.getType()).getShape();
+  bool needMaskForWrite =
+      llvm::any_of(llvm::zip(inputVectorSizes, destShape),
+                   [](auto it) { return std::get<0>(it) != std::get<1>(it); });
+  if (needMaskForWrite) {
+    SmallVector<int64_t> writeMaskShape;
+    writeMaskShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
+    writeMaskShape.append(destShape.begin() + inputVectorSizes.size(),
+                          destShape.end());
+    auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
+    Value maskForWrite =
+        builder.create<vector::CreateMaskOp>(loc, writeMaskType, destSizes);
+    write = mlir::vector::maskOperation(builder, write, maskForWrite);
+  }
+  return write;
+}
+
+/// Vectorize tensor::PackOp with (1) static innerTiles and (2) constant
+/// padding value into
+/// transfer_write_in_bounds(
+///     transpose(
+///         shape_cast(
+///             transfer_read_masked(pack_source, pad_value))))
 static LogicalResult
 vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
                         ArrayRef<int64_t> inputVectorSizes,
@@ -1468,48 +1535,41 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
     padValue = rewriter.create<arith::ConstantOp>(
         loc, rewriter.getZeroAttr(packOp.getSourceType().getElementType()));
   }
-  int64_t inputRank = inputVectorSizes.size();
-  int64_t outputRank = packOp.getDestRank();
-  auto maskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
-  auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());
-
   ReifiedRankedShapedTypeDims reifiedReturnShapes;
   LogicalResult status =
       cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
           .reifyResultShapes(rewriter, reifiedReturnShapes);
   (void)status; // prevent unused variable warning on non-assert builds
   assert(succeeded(status) && "failed to reify result shapes");
-  auto emptyOp = rewriter.create<tensor::EmptyOp>(loc, reifiedReturnShapes[0],
-                                                  padValue.getType());
-  SmallVector<OpFoldResult> mixedSourceDims =
-      tensor::getMixedSizes(rewriter, loc, packOp.getSource());
-  Value mask =
-      rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
-  auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
-      loc,
-      /*vectorType=*/vectorType,
-      /*source=*/packOp.getSource(),
-      /*indices=*/SmallVector<Value>(inputRank, zero),
-      /*padding=*/padValue,
-      /*inBounds=*/SmallVector<bool>(inputRank, true));
-  auto maskedOp = cast<vector::MaskOp>(
-      mlir::vector::maskOperation(rewriter, transferReadOp, mask));
-  // ShapeCast
-  auto tiledPackShape = getTiledPackShape(packOp);
-  auto tiledPackType =
-      VectorType::get(tiledPackShape, packOp.getDestType().getElementType());
+
+  // Create masked TransferReadOp
+  SmallVector<int64_t> inputShape(inputVectorSizes);
+  auto innerTiles = packOp.getStaticInnerTiles();
+  auto innerDimsPos = packOp.getInnerDimsPos();
+  auto outerDimsPerm = packOp.getOuterDimsPerm();
+  if (!outerDimsPerm.empty())
+    applyPermutationToVector(inputShape,
+                             invertPermutationVector(outerDimsPerm));
+  for (auto [idx, size] : enumerate(innerTiles))
+    inputShape[innerDimsPos[idx]] *= size;
+  auto maskedOp = createMaskedTransferRead(rewriter, loc, packOp.getSource(),
+                                           inputShape, padValue);
+
+  // Create ShapeCastOp
+  auto tiledPackType = VectorType::get(getTiledPackShape(packOp),
+                                       packOp.getDestType().getElementType());
   auto shapeCastOp = rewriter.create<vector::ShapeCastOp>(
       loc, tiledPackType, maskedOp->getResult(0));
+
+  // Create TransposeOp
   auto tiledShapeToPackedShapePerm = getTiledShapeToPackedShapePerm(packOp);
   auto transposeOp = rewriter.create<vector::TransposeOp>(
-      loc, shapeCastOp->getResult(0), tiledShapeToPackedShapePerm);
-  Operation *write = rewriter.create<vector::TransferWriteOp>(
-      loc,
-      /*vector=*/transposeOp->getResult(0),
-      /*source=*/emptyOp,
-      /*indices=*/SmallVector<Value>(outputRank, zero),
-      /*inBounds=*/SmallVector<bool>(outputRank, true));
+      loc, shapeCastOp.getResult(), tiledShapeToPackedShapePerm);
+
+  // Create TransferWriteOp
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(),
+                               reifiedReturnShapes[0], inputVectorSizes);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -1523,9 +1583,6 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
                        SmallVectorImpl<Value> &newResults) {
   auto padValue = padOp.getConstantPaddingValue();
   Location loc = padOp.getLoc();
-  int64_t rank = inputVectorSizes.size();
-  auto maskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
-  auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());
 
   // transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))
   OpBuilder::InsertionGuard g(rewriter);
@@ -1537,36 +1594,11 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
           .reifyResultShapes(rewriter, reifiedReturnShapes);
   (void)status; // prevent unused variable warning on non-assert builds
   assert(succeeded(status) && "failed to reify result shapes");
-  auto emptyOp = rewriter.create<tensor::EmptyOp>(loc, reifiedReturnShapes[0],
-                                                  padValue.getType());
-  SmallVector<OpFoldResult> mixedSourceDims =
-      tensor::getMixedSizes(rewriter, loc, padOp.getSource());
-  Value mask =
-      rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
-  auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
-      loc,
-      /*vectorType=*/vectorType,
-      /*source=*/padOp.getSource(),
-      /*indices=*/SmallVector<Value>(rank, zero),
-      /*padding=*/padValue,
-      /*inBounds=*/SmallVector<bool>(rank, true));
-  auto maskedOp = cast<vector::MaskOp>(
-      mlir::vector::maskOperation(rewriter, transferReadOp, mask));
-  Operation *write = rewriter.create<vector::TransferWriteOp>(
-      loc,
-      /*vector=*/maskedOp->getResult(0),
-      /*source=*/emptyOp,
-      /*indices=*/SmallVector<Value>(rank, zero),
-      /*inBounds=*/SmallVector<bool>(rank, true));
-  bool needMaskForWrite = llvm::any_of(
-      llvm::zip_equal(inputVectorSizes, padOp.getResultType().getShape()),
-      [](auto it) { return std::get<0>(it) != std::get<1>(it); });
-  if (needMaskForWrite) {
-    Value maskForWrite = rewriter.create<vector::CreateMaskOp>(
-        loc, maskType, reifiedReturnShapes[0]);
-    write = mlir::vector::maskOperation(rewriter, write, maskForWrite);
-  }
+  auto maskedOp = createMaskedTransferRead(rewriter, loc, padOp.getSource(),
+                                           inputVectorSizes, padValue);
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, maskedOp->getResult(0),
+                               reifiedReturnShapes[0], inputVectorSizes);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -1710,18 +1742,19 @@ static LogicalResult
 vectorizePackOpPrecondition(tensor::PackOp packOp,
                             ArrayRef<int64_t> inputVectorSizes) {
   auto padValue = packOp.getPaddingValue();
-  if (padValue && getConstantIntValue(padValue) != std::nullopt) {
+  if (padValue && !getConstantIntValue(padValue).has_value()) {
     LDBG("pad value is not constant: " << packOp << "\n");
     return failure();
   }
 
-  ArrayRef<int64_t> resultTensorShape = packOp.getSourceType().getShape();
-  if (failed(isValidMaskedInputVector(resultTensorShape, inputVectorSizes)))
+  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
+  if (failed(isValidMaskedInputVector(
+          resultTensorShape.take_front(packOp.getSourceRank()),
+          inputVectorSizes)))
     return failure();
 
   if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
-        std::optional<int64_t> res = getConstantIntValue(v);
-        return !res.has_value();
+        return !getConstantIntValue(v).has_value();
       })) {
     LDBG("inner_tiles must be constant: " << packOp << "\n");
     return failure();
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index af1c1337224fa2..ed9a8eb9183bdf 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -426,7 +426,6 @@ func.func @test_masked_vectorize_pad(
 {
   //  CHECK-DAG: %[[c42:.*]] = arith.constant 4.243000e+01 : f32
   //  CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
-  //  CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<2x4xf32>
   //      CHECK: %[[d0:.*]] = tensor.dim {{.*}} : tensor<?x?xf32>
   //      CHECK: %[[d1:.*]] = tensor.dim {{.*}} : tensor<?x?xf32>
   //      CHECK: %[[mask:.*]] = vector.create_mask %[[d0]], %[[d1]] : vector<2x4xi1>
@@ -435,7 +434,9 @@ func.func @test_masked_vectorize_pad(
   // CHECK-SAME:   vector.transfer_read %{{.*}}[%[[c0_2]], %[[c0_2]]], %[[c42]]
   // CHECK-SAME:   {in_bounds = [true, true]} : tensor<?x?xf32>, vector<2x4xf32>
   // CHECK-SAME: } : vector<2x4xi1> -> vector<2x4xf32>
-  //      CHECK: vector.transfer_write %[[masked_read]], %[[empty]][%[[c0_2]], %[[c0_2]]]
+  //  CHECK-DAG: %[[c0_3:.*]] = arith.constant 0 : index
+  //  CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<2x4xf32>
+  //      CHECK: vector.transfer_write %[[masked_read]], %[[empty]][%[[c0_3]], %[[c0_3]]]
   // CHECK-SAME:   {in_bounds = [true, true]} : vector<2x4xf32>, tensor<2x4xf32>
   %cst = arith.constant 42.43 : f32
   %c0 = arith.constant 0 : index
@@ -467,7 +468,6 @@ func.func @test_masked_vectorize_dynamic_pad(
   //  CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
   //  CHECK-DAG: %[[res_d0:.+]] = affine.apply #[[MAP]]()
   //  CHECK-DAG: %[[res_d1:.+]] = affine.apply #[[MAP]]()
-  //  CHECK-DAG: %[[empty:.*]] = tensor.empty(%[[res_d0]], %[[res_d1]]) : tensor<?x?xf32>
   //      CHECK: %[[d0:.*]] = tensor.dim {{.*}} : tensor<?x?xf32>
   //      CHECK: %[[d1:.*]] = tensor.dim {{.*}} : tensor<?x?xf32>
   //      CHECK: %[[mask:.*]] = vector.create_mask %[[d0]], %[[d1]] : vector<2x4xi1>
@@ -476,9 +476,11 @@ func.func @test_masked_vectorize_dynamic_pad(
   // CHECK-SAME:   vector.transfer_read %{{.*}}[%[[c0_2]], %[[c0_2]]], %[[c42]]
   // CHECK-SAME:   {in_bounds = [true, true]} : tensor<?x?xf32>, vector<2x4xf32>
   // CHECK-SAME: } : vector<2x4xi1> -> vector<2x4xf32>
+  //  CHECK-DAG: %[[empty:.*]] = tensor.empty(%[[res_d0]], %[[res_d1]]) : tensor<?x?xf32>
+  //  CHECK-DAG: %[[c0_3:.*]] = arith.constant 0 : index
   //      CHECK: %[[mask_2:.*]] = vector.create_mask %[[res_d0]], %[[res_d1]] : vector<2x4xi1>
   //      CHECK: %[[masked_write:.*]] = vector.mask %[[mask_2]] {
-  // CHECK-SAME: vector.transfer_write %[[masked_read]], %[[empty]][%[[c0_2]], %[[c0_2]]]
+  // CHECK-SAME: vector.transfer_write %[[masked_read]], %[[empty]][%[[c0_3]], %[[c0_3]]]
   // CHECK-SAME:   {in_bounds = [true, true]} : vector<2x4xf32>, tensor<?x?xf32>
   //      CHECK: return %[[masked_write]] : tensor<?x?xf32>
   %cst = arith.constant 42.43 : f32
@@ -508,7 +510,7 @@ func.func @test_vectorize_dynamic_pack(%arg0: tensor<?x?xf32>, %arg1: tensor<4x1
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["tensor.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
-    transform.structured.vectorize %0 vector_sizes [8, 16] : !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [4, 1] : !transform.any_op
     transform.yield 
   }
 }
@@ -517,15 +519,16 @@ module attributes {transform.with_named_sequence} {
 //  CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
 //  CHECK-DAG: %[[d0:.*]] = tensor.dim {{.*}} %[[c0]] : tensor<?x?xf32>
 //  CHECK-DAG: %[[d1:.*]] = tensor.dim {{.*}} %[[c1]] : tensor<?x?xf32>
-//  CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<4x1x16x2xf32>
 //      CHECK: %[[mask:.*]] = vector.create_mask %[[d0]], %[[d1]] : vector<8x16xi1>
-//  CHECK-DAG: %[[c0_2:.*]] = arith.constant 0 : index
+//  CHECK-DAG: %[[c0_1:.*]] = arith.constant 0 : index
 //      CHECK: %[[masked_read:.*]] = vector.mask %[[mask]] {
-// CHECK-SAME:   vector.transfer_read %{{.*}}[%[[c0_2]], %[[c0_2]]], %[[cst]]
+// CHECK-SAME:   vector.transfer_read %{{.*}}[%[[c0_1]], %[[c0_1]]], %[[cst]]
 // CHECK-SAME:   {in_bounds = [true, true]} : tensor<?x?xf32>, vector<8x16xf32>
 // CHECK-SAME: } : vector<8x16xi1> -> vector<8x16xf32>
 //      CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[masked_read]] : vector<8x16xf32> to vector<4x2x1x16xf32>
 //      CHECK: %[[transpose:.*]] = vector.transpose %[[shape_cast]], [0, 2, 3, 1] : vector<4x2x1x16xf32> to vector<4x1x16x2xf32>
+//  CHECK-DAG: %[[c0_2:.*]] = arith.constant 0 : index
+//  CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<4x1x16x2xf32>
 //      CHECK: %[[write:.*]] = vector.transfer_write %[[transpose]], %[[empty]][%[[c0_2]], %[[c0_2]], %[[c0_2]], %[[c0_2]]]
 // CHECK-SAME:   {in_bounds = [true, true, true, true]} : vector<4x1x16x2xf32>, tensor<4x1x16x2xf32>
 //      CHECK: return %[[write]] : tensor<4x1x16x2xf32>
@@ -539,7 +542,7 @@ func.func @test_vectorize_pack(%arg0: tensor<32x8x16xf32>, %arg1: tensor<32x4x1x
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["tensor.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
-    transform.structured.vectorize %0 vector_sizes [32, 8, 16] : !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [32, 4, 1] : !transform.any_op
     transform.yield 
   }
 }
@@ -547,7 +550,6 @@ module attributes {transform.with_named_sequence} {
 //  CHECK-DAG: %[[c32:.*]] = arith.constant 32 : index
 //  CHECK-DAG: %[[c8:.*]] = arith.constant 8 : index
 //  CHECK-DAG: %[[c16:.*]] = arith.constant 16 : index
-//  CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<32x4x1x16x2xf32>
 //      CHECK: %[[mask:.*]] = vector.create_mask %[[c32]], %[[c8]], %[[c16]] : vector<32x8x16xi1>
 //  CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
 //      CHECK: %[[masked_read:.*]] = vector.mask %[[mask]] {
@@ -556,7 +558,9 @@ module attributes {transform.with_named_sequence} {
 // CHECK-SAME: } : vector<32x8x16xi1> -> vector<32x8x16xf32>
 //      CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[masked_read]] : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
 //      CHECK: %[[transpose:.*]] = vector.transpose %[[shape_cast]], [0, 1, 3, 4, 2] : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
-//      CHECK: %[[write:.*]] = vector.transfer_write %[[transpose]], %[[empty]][%[[c0]], %[[c0]], %[[c0]], %[[c0]], %[[c0]]]
+//  CHECK-DAG: %[[c0_1:.*]] = arith.constant 0 : index
+//  CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<32x4x1x16x2xf32>
+//      CHECK: %[[write:.*]] = vector.transfer_write %[[transpose]], %[[empty]][%[[c0_1]], %[[c0_1]], %[[c0_1]], %[[c0_1]], %[[c0_1]]]
 // CHECK-SAME:   {in_bounds = [true, true, true, true, true]} : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
 //      CHECK: return %[[write]] : tensor<32x4x1x16x2xf32>
 

>From b11affebf1c9cb190dde9c637c88c930d299fd69 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Fri, 19 Jan 2024 17:25:50 -0500
Subject: [PATCH 6/8] add dynamic test

---
 .../Linalg/Transforms/Vectorization.cpp       | 10 +++--
 mlir/test/Dialect/Linalg/vectorization.mlir   | 40 +++++++++++++++++++
 2 files changed, 46 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 78c8a62933324d..2961e8cbee7a1f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1442,16 +1442,16 @@ getTiledShapeToPackedShapePerm(tensor::PackOp packOp) {
     return tiledIdx;
   };
   SmallVector<int64_t> perm;
-  for (int i = 0; i < packOp.getDestRank(); i++)
+  for (size_t i = 0; i < packOp.getDestRank(); i++)
     perm.push_back(packedIdxToTiledIdx(i));
   return perm;
 }
 
 /// Given a tensor::PackOp, return the "tiled" `dest` shape as described
 /// above in `getTiledShapeToPackedShapePerm`.
-static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp) {
+static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
+                                              ArrayRef<int64_t> destShape) {
   auto perm = getTiledShapeToPackedShapePerm(packOp);
-  auto destShape = packOp.getDestType().getShape();
   return applyPermutation(destShape, invertPermutationVector(perm));
 }
 
@@ -1556,7 +1556,9 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
                                            inputShape, padValue);
 
   // Create ShapeCastOp
-  auto tiledPackType = VectorType::get(getTiledPackShape(packOp),
+  SmallVector<int64_t> destShape(inputVectorSizes);
+  destShape.append(innerTiles.begin(), innerTiles.end());
+  auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
                                        packOp.getDestType().getElementType());
   auto shapeCastOp = rewriter.create<vector::ShapeCastOp>(
       loc, tiledPackType, maskedOp->getResult(0));
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index ed9a8eb9183bdf..d9546f6da38a39 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -566,6 +566,46 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+func.func @test_vectorize_dynamic_result_pack(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?x16x2xf32>) -> tensor<?x?x16x2xf32> {
+  %pack = tensor.pack %arg0 inner_dims_pos = [1, 0] inner_tiles = [16, 2] into %arg1 : tensor<?x?xf32> -> tensor<?x?x16x2xf32>
+  return %pack : tensor<?x?x16x2xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [4, 1] : !transform.any_op
+    transform.yield 
+  }
+}
+//  CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
+//  CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+//  CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+//  CHECK-DAG: %[[d0:.*]] = tensor.dim {{.*}} %[[c0]] : tensor<?x?x16x2xf32>
+//  CHECK-DAG: %[[d1:.*]] = tensor.dim {{.*}} %[[c1]] : tensor<?x?x16x2xf32>
+//  CHECK-DAG: %[[c0_0:.*]] = arith.constant 0 : index
+//  CHECK-DAG: %[[c1_0:.*]] = arith.constant 1 : index
+//  CHECK-DAG: %[[d0_0:.*]] = tensor.dim {{.*}} %[[c0_0]] : tensor<?x?xf32>
+//  CHECK-DAG: %[[d1_0:.*]] = tensor.dim {{.*}} %[[c1_0]] : tensor<?x?xf32>
+//      CHECK: %[[mask:.*]] = vector.create_mask %[[d0_0]], %[[d1_0]] : vector<8x16xi1>
+//  CHECK-DAG: %[[c0_1:.*]] = arith.constant 0 : index
+//      CHECK: %[[masked_read:.*]] = vector.mask %[[mask]] {
+// CHECK-SAME:   vector.transfer_read %{{.*}}[%[[c0_1]], %[[c0_1]]], %[[cst]]
+// CHECK-SAME:   {in_bounds = [true, true]} : tensor<?x?xf32>, vector<8x16xf32>
+// CHECK-SAME: } : vector<8x16xi1> -> vector<8x16xf32>
+//      CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[masked_read]] : vector<8x16xf32> to vector<4x2x1x16xf32>
+//      CHECK: %[[transpose:.*]] = vector.transpose %[[shape_cast]], [0, 2, 3, 1] : vector<4x2x1x16xf32> to vector<4x1x16x2xf32>
+//  CHECK-DAG: %[[c0_2:.*]] = arith.constant 0 : index
+//  CHECK-DAG: %[[c16:.*]] = arith.constant 16 : index
+//  CHECK-DAG: %[[c2:.*]] = arith.constant 2 : index
+//  CHECK-DAG: %[[empty:.*]] = tensor.empty(%[[d0]], %[[d1]]) : tensor<?x?x16x2xf32>
+//      CHECK: %[[mask_0:.*]] = vector.create_mask %[[d0]], %[[d1]], %[[c16]], %[[c2]] : vector<4x1x16x2xi1>
+//      CHECK: %[[masked_write:.*]] = vector.mask %[[mask_0]] {
+// CHECK-SAME:   vector.transfer_write %[[transpose]], %[[empty]][%[[c0_2]], %[[c0_2]], %[[c0_2]], %[[c0_2]]]
+// CHECK-SAME:   {in_bounds = [true, true, true, true]} : vector<4x1x16x2xf32>, tensor<?x?x16x2xf32>
+//      CHECK: return %[[masked_write]] : tensor<?x?x16x2xf32>
+
+// -----
+
 func.func @matmul(%A: memref<?x?xf32>, %B: memref<?x?xf32>, %C: memref<?x?xf32>) {
   linalg.matmul ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
             outs(%C: memref<?x?xf32>)

>From 6c0e2a1d4e9a03404e6c7546904e86868d6b3f1e Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Thu, 1 Feb 2024 18:14:20 -0500
Subject: [PATCH 7/8] address comments

---
 .../include/mlir/Dialect/Tensor/Utils/Utils.h |   8 +
 .../Dialect/Linalg/Transforms/Transforms.cpp  |  36 +---
 .../Linalg/Transforms/Vectorization.cpp       | 160 ++++++++----------
 mlir/lib/Dialect/Tensor/Utils/Utils.cpp       |  29 ++++
 mlir/test/Dialect/Linalg/vectorization.mlir   |  95 +++++------
 5 files changed, 164 insertions(+), 164 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
index 04b4de4a33a52f..fe9b16cb44b3da 100644
--- a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
@@ -32,6 +32,14 @@ FailureOr<RankedTensorType>
 computeTransposedType(RankedTensorType rankedTensorType,
                       ArrayRef<int64_t> transposeVector);
 
+/// Given a tensor::PackOp, compute the permutation vector to shuffle the
+/// packed shape into the shape before any outer or inner permutations have
+/// been applied.
+/// e.g. for a pack from an ABCD layout to an ABCDba layout:
+/// The packed shape would be ABCDba.
+/// The pre-permutation shape would be AaBbCD.
+SmallVector<int64_t> getPackInverseDestPermutation(PackOp packOp);
+
 /// A tensor.insert_slice is a cast-like operation if it merely rank-extends the
 /// source tensor or inserts the source tensor into a destination tensor with
 /// the same shape.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 02bc3e672bf7a7..596b7c50c1e4e4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -233,31 +233,11 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
   rewriter.setInsertionPoint(packOp);
 
   // 2. Compute the permutation vector to shuffle packed shape into the shape
-  // before any outer or inner permutations have been applied. The permutation
-  // can be obtained from two permutations:
-  //   a) Compute the permutation vector to move the last `numPackedDims` into
-  //      the `innerPosDims` of a shape of rank `packedRank`.
-  //   b) Compute the permutation vector to move outer dims if the pack op
-  //      has outer_dims_perm.
-  // Apply (b) permutation on (a) permutation to get the final permutation.
-  int64_t numPackedDims = packOp.getInnerDimsPos().size();
-  int64_t packedRank = packedTensorType.getRank();
-  auto lastDims = llvm::to_vector(
-      llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
+  // before any outer or inner permutations have been applied.
   PackingMetadata packingMetadata = computePackingMetadata(
       packedTensorType.getRank(), packOp.getInnerDimsPos());
-  SmallVector<int64_t> innerPositionsPerm = computePermutationVector(
-      packedRank, lastDims, packingMetadata.insertPositions);
-
-  SmallVector<int64_t> outerPos = packingMetadata.outerPositions;
-  ArrayRef<int64_t> outerPerm = packOp.getOuterDimsPerm();
-  if (!outerPerm.empty())
-    applyPermutationToVector(outerPos, outerPerm);
-  SmallVector<int64_t> outerPositionPerm = computePermutationVector(
-      packedRank, packingMetadata.outerPositions, outerPos);
-
-  SmallVector<int64_t> packedToStripMinedShapePerm = innerPositionsPerm;
-  applyPermutationToVector(packedToStripMinedShapePerm, outerPositionPerm);
+  SmallVector<int64_t> packedToStripMinedShapePerm =
+      tensor::getPackInverseDestPermutation(packOp);
 
   // 3. Compute the stripMinedShape: this is the packed shape before any outer
   // or inner permutations have been applied.
@@ -304,10 +284,6 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
       DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
                                       DBGS() << "packedShape: ");
       DBGSNL();
-      llvm::interleaveComma(outerPositionPerm, DBGS() << "outerPositionPerm: ");
-      DBGSNL(); llvm::interleaveComma(innerPositionsPerm,
-                                      DBGS() << "innerPositionsPerm: ");
-      DBGSNL();
       llvm::interleaveComma(packedToStripMinedShapePerm,
                             DBGS() << "packedToStripMinedShapePerm: ");
       DBGSNL(); llvm::interleaveComma(
@@ -332,9 +308,11 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
       auto emptyOp =
           rewriter.create<tensor::EmptyOp>(loc, packedTensorType, ValueRange{});
       // Offsets.
-      SmallVector<OpFoldResult> zeros(packedRank, rewriter.getIndexAttr(0));
+      SmallVector<OpFoldResult> zeros(packOp.getDestRank(),
+                                      rewriter.getIndexAttr(0));
       // Strides.
-      SmallVector<OpFoldResult> ones(packedRank, rewriter.getIndexAttr(1));
+      SmallVector<OpFoldResult> ones(packOp.getDestRank(),
+                                     rewriter.getIndexAttr(1));
       SmallVector<OpFoldResult> sizes =
           tensor::getMixedSizes(rewriter, loc, packOp.getDest());
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 2961e8cbee7a1f..7e7de846d99543 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -1400,74 +1401,26 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
   return success();
 }
 
-/// Given a tensor::PackOp, return the permutation from the "tiled"
-/// shape to the "packed" shape, defined as the following:
-/// The "packed" shape is the same as the `dest` shape of the pack op.
-/// The "tiled" shape is a permutation of the `dest` shape such that
-/// each outer dimension is in the original `source` order, and the
-/// inner_tile dimensions immediately follow their corresponding outer
-/// dimension.
-/// i.e. for the following tensor.pack:
-/// ```mlir
-/// %pack = tensor.pack %0 padding_value(%1)
-///   outer_dims_perm = [0, 2, 1]
-///   inner_dims_pos = [2, 1]
-///   inner_tiles = [16, 2]
-///   into %2 : tensor<32x8x16> -> tensor<32x1x4x16x2>
-/// ```
-/// The "packed" shape is `32x1x4x16x2`
-/// The "tiled" shape is `32x(4x2)x(1x16)`
-static SmallVector<int64_t>
-getTiledShapeToPackedShapePerm(tensor::PackOp packOp) {
-  auto innerTiles = packOp.getInnerTiles();
-  int64_t srcRank = packOp.getSourceRank();
-  auto innerDimsPos = packOp.getInnerDimsPos();
-  if (innerDimsPos.empty())
-    innerDimsPos = to_vector(llvm::seq<int64_t>(innerTiles.size()));
-  auto outerDimsPerm = packOp.getOuterDimsPerm();
-  if (outerDimsPerm.empty())
-    outerDimsPerm = to_vector(llvm::seq<int64_t>(srcRank));
-  auto packedIdxToTiledIdx = [&](int64_t idx) -> int64_t {
-    int64_t srcIdx;
-    if (idx >= srcRank)
-      srcIdx = innerDimsPos[idx - srcRank];
-    else
-      srcIdx = outerDimsPerm[idx];
-    int64_t tiledIdx = srcIdx;
-    for (int64_t pos : innerDimsPos)
-      if (pos < srcIdx)
-        tiledIdx++;
-    if (idx >= srcRank)
-      tiledIdx++;
-    return tiledIdx;
-  };
-  SmallVector<int64_t> perm;
-  for (size_t i = 0; i < packOp.getDestRank(); i++)
-    perm.push_back(packedIdxToTiledIdx(i));
-  return perm;
-}
-
-/// Given a tensor::PackOp, return the "tiled" `dest` shape as described
-/// above in `getTiledShapeToPackedShapePerm`.
+/// Given a tensor::PackOp, return the `dest` shape before any packing
+/// permutations.
 static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
                                               ArrayRef<int64_t> destShape) {
-  auto perm = getTiledShapeToPackedShapePerm(packOp);
-  return applyPermutation(destShape, invertPermutationVector(perm));
+  return applyPermutation(destShape,
+                          tensor::getPackInverseDestPermutation(packOp));
 }
 
-/// Create a masked TransferReadOp from `source` with shape `readShape`.
-static vector::MaskOp createMaskedTransferRead(OpBuilder &builder, Location loc,
-                                               Value source,
-                                               ArrayRef<int64_t> readShape,
-                                               Value padValue) {
+/// Create a TransferReadOp from `source` with static shape `readShape`. If
+/// `readShape` does not match the fully static shape of `source`, the read is
+/// masked with a mask built from the runtime sizes of `source`.
+static Value createReadOrMaskedRead(OpBuilder &builder, Location loc,
+                                    Value source, ArrayRef<int64_t> readShape,
+                                    Value padValue) {
+  assert(llvm::none_of(readShape,
+                       [](int64_t s) { return s == ShapedType::kDynamic; }));
   auto maskType = VectorType::get(readShape, builder.getI1Type());
   auto vectorType = VectorType::get(readShape, padValue.getType());
-  SmallVector<OpFoldResult> mixedSourceDims =
-      tensor::getMixedSizes(builder, loc, source);
-  Value mask =
-      builder.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
-  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
   int64_t readRank = readShape.size();
+  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
   auto transferReadOp = builder.create<vector::TransferReadOp>(
       loc,
       /*vectorType=*/vectorType,
@@ -1475,8 +1428,20 @@ static vector::MaskOp createMaskedTransferRead(OpBuilder &builder, Location loc,
       /*indices=*/SmallVector<Value>(readRank, zero),
       /*padding=*/padValue,
       /*inBounds=*/SmallVector<bool>(readRank, true));
-  return cast<vector::MaskOp>(
-      mlir::vector::maskOperation(builder, transferReadOp, mask));
+  auto sourceShape = llvm::dyn_cast<ShapedType>(source.getType()).getShape();
+  if (sourceShape.size() == readShape.size() &&
+      llvm::all_of(llvm::zip_equal(readShape, sourceShape), [](auto it) {
+        return std::get<0>(it) != ShapedType::kDynamic &&
+               std::get<0>(it) == std::get<1>(it);
+      })) {
+    return transferReadOp;
+  }
+  SmallVector<OpFoldResult> mixedSourceDims =
+      tensor::getMixedSizes(builder, loc, source);
+  Value mask =
+      builder.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
+  return mlir::vector::maskOperation(builder, transferReadOp, mask)
+      ->getResult(0);
 }
 
 /// Given an input, the mixed destSizes, and the vector sizes for vectorization,
@@ -1500,9 +1465,14 @@ static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
       /*indices=*/SmallVector<Value>(rank, zero),
       /*inBounds=*/SmallVector<bool>(rank, true));
   auto destShape = cast<ShapedType>(dest.getType()).getShape();
-  bool needMaskForWrite =
-      llvm::any_of(llvm::zip(inputVectorSizes, destShape),
-                   [](auto it) { return std::get<0>(it) != std::get<1>(it); });
+  assert(llvm::none_of(
+             destShape.drop_front(inputVectorSizes.size()),
+             [](int64_t size) { return size == ShapedType::kDynamic; }) &&
+         "Only dims aligned with inputVectorSizes may be dynamic");
+  bool needMaskForWrite = llvm::any_of(
+      llvm::zip_equal(inputVectorSizes,
+                      destShape.take_front(inputVectorSizes.size())),
+      [](auto it) { return std::get<0>(it) != std::get<1>(it); });
   if (needMaskForWrite) {
     SmallVector<int64_t> writeMaskShape;
     writeMaskShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
@@ -1517,11 +1487,28 @@ static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
 }
 
 /// Vectorize tensor::PackOp with (1) static innerTiles and (2) constant
-/// padding value into
-/// transfer_write_in_bounds(
-///     transpose(
-///         shape_cast(
-///             transfer_read_masked(pack_source, pad_value))))
+/// padding value into:
+/// masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
+/// As in the following example:
+/// ```mlir
+/// %pack = tensor.pack %src padding_value(%cst : f32) inner_dims_pos = [2, 1]
+///     inner_tiles = [16, 2] into %dst : tensor<32x7x16xf32> -> tensor<32x4x1x16x2xf32>
+/// ```
+/// This pack would be vectorized to:
+/// ```mlir
+/// %load = vector.mask %mask {
+///     vector.transfer_read %src[%c0, %c0, %c0], %cst
+///         {in_bounds = [true, true, true]} :
+///         tensor<32x7x16xf32>, vector<32x8x16xf32>
+/// } : vector<32x8x16xi1> -> vector<32x8x16xf32>
+/// %shape_cast = vector.shape_cast %load : vector<32x8x16xf32>
+///                                         to vector<32x4x2x1x16xf32>
+/// %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2]
+///     : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
+/// %write = vector.transfer_write %transpose,
+///     %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0]
+///     {in_bounds = [true, true, true, true, true]}
+///     : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
 static LogicalResult
 vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
                         ArrayRef<int64_t> inputVectorSizes,
@@ -1539,10 +1526,10 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
   LogicalResult status =
       cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
           .reifyResultShapes(rewriter, reifiedReturnShapes);
-  (void)status; // prevent unused variable warning on non-assert builds
+  (void)status; // prevent unused variable warning on non-assert builds.
   assert(succeeded(status) && "failed to reify result shapes");
 
-  // Create masked TransferReadOp
+  // Create masked TransferReadOp.
   SmallVector<int64_t> inputShape(inputVectorSizes);
   auto innerTiles = packOp.getStaticInnerTiles();
   auto innerDimsPos = packOp.getInnerDimsPos();
@@ -1552,23 +1539,24 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
                              invertPermutationVector(outerDimsPerm));
   for (auto [idx, size] : enumerate(innerTiles))
     inputShape[innerDimsPos[idx]] *= size;
-  auto maskedOp = createMaskedTransferRead(rewriter, loc, packOp.getSource(),
+  auto maskedRead = createReadOrMaskedRead(rewriter, loc, packOp.getSource(),
                                            inputShape, padValue);
 
-  // Create ShapeCastOp
+  // Create ShapeCastOp.
   SmallVector<int64_t> destShape(inputVectorSizes);
   destShape.append(innerTiles.begin(), innerTiles.end());
   auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
                                        packOp.getDestType().getElementType());
-  auto shapeCastOp = rewriter.create<vector::ShapeCastOp>(
-      loc, tiledPackType, maskedOp->getResult(0));
+  auto shapeCastOp =
+      rewriter.create<vector::ShapeCastOp>(loc, tiledPackType, maskedRead);
 
-  // Create TransposeOp
-  auto tiledShapeToPackedShapePerm = getTiledShapeToPackedShapePerm(packOp);
+  // Create TransposeOp.
+  auto destPermutation =
+      invertPermutationVector(tensor::getPackInverseDestPermutation(packOp));
   auto transposeOp = rewriter.create<vector::TransposeOp>(
-      loc, shapeCastOp.getResult(), tiledShapeToPackedShapePerm);
+      loc, shapeCastOp.getResult(), destPermutation);
 
-  // Create TransferWriteOp
+  // Create TransferWriteOp.
   Operation *write =
       createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(),
                                reifiedReturnShapes[0], inputVectorSizes);
@@ -1596,11 +1584,10 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
           .reifyResultShapes(rewriter, reifiedReturnShapes);
   (void)status; // prevent unused variable warning on non-assert builds
   assert(succeeded(status) && "failed to reify result shapes");
-  auto maskedOp = createMaskedTransferRead(rewriter, loc, padOp.getSource(),
+  auto maskedRead = createReadOrMaskedRead(rewriter, loc, padOp.getSource(),
                                            inputVectorSizes, padValue);
-  Operation *write =
-      createWriteOrMaskedWrite(rewriter, loc, maskedOp->getResult(0),
-                               reifiedReturnShapes[0], inputVectorSizes);
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, maskedRead, reifiedReturnShapes[0], inputVectorSizes);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -1740,11 +1727,12 @@ vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
   return success();
 }
 
+/// TODO: Use a matcher to check for a constant padding value.
 static LogicalResult
 vectorizePackOpPrecondition(tensor::PackOp packOp,
                             ArrayRef<int64_t> inputVectorSizes) {
   auto padValue = packOp.getPaddingValue();
-  if (padValue && !getConstantIntValue(padValue).has_value()) {
+  if (padValue && !padValue.getDefiningOp<arith::ConstantOp>()) {
     LDBG("pad value is not constant: " << packOp << "\n");
     return failure();
   }
diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
index 24cbceb3d11791..f20008a1ed2b2f 100644
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -73,6 +73,35 @@ mlir::tensor::computeTransposedType(RankedTensorType rankedTensorType,
   return transposedTensorType;
 }
 
+SmallVector<int64_t>
+mlir::tensor::getPackInverseDestPermutation(PackOp packOp) {
+  // The result composes two permutations over the packed (dest) rank:
+  //   (a) one that moves the trailing tile dims (one per `inner_dims_pos`
+  //       entry) to the positions recorded in `insertPositions`;
+  //   (b) one that reorders the outer dims per `outer_dims_perm`, if set.
+  // Permutation (b) is applied on top of (a) to produce the final result.
+  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
+  const int64_t destRank = packOp.getDestType().getRank();
+  const int64_t tileCount = innerDimsPos.size();
+  auto tileDims =
+      llvm::to_vector(llvm::seq<int64_t>(destRank - tileCount, destRank));
+  PackingMetadata metadata = computePackingMetadata(destRank, innerDimsPos);
+  SmallVector<int64_t> innerPerm =
+      computePermutationVector(destRank, tileDims, metadata.insertPositions);
+
+  // Step (b): only permute the outer positions when outer_dims_perm exists.
+  SmallVector<int64_t> permutedOuter = metadata.outerPositions;
+  if (ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
+      !outerDimsPerm.empty())
+    applyPermutationToVector(permutedOuter, outerDimsPerm);
+  SmallVector<int64_t> outerPerm = computePermutationVector(
+      destRank, metadata.outerPositions, permutedOuter);
+
+  SmallVector<int64_t> result = std::move(innerPerm);
+  applyPermutationToVector(result, outerPerm);
+  return result;
+}
+
 bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) {
   llvm::SmallBitVector droppedDims = op.getDroppedDims();
   int64_t srcDim = 0;
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index d9546f6da38a39..5d1bef478ee987 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -426,17 +426,17 @@ func.func @test_masked_vectorize_pad(
 {
   //  CHECK-DAG: %[[c42:.*]] = arith.constant 4.243000e+01 : f32
   //  CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+  //  CHECK-DAG: %[[c0_0:.*]] = arith.constant 0 : index
   //      CHECK: %[[d0:.*]] = tensor.dim {{.*}} : tensor<?x?xf32>
   //      CHECK: %[[d1:.*]] = tensor.dim {{.*}} : tensor<?x?xf32>
   //      CHECK: %[[mask:.*]] = vector.create_mask %[[d0]], %[[d1]] : vector<2x4xi1>
-  //  CHECK-DAG: %[[c0_2:.*]] = arith.constant 0 : index
   //      CHECK: %[[masked_read:.*]] = vector.mask %[[mask]] {
-  // CHECK-SAME:   vector.transfer_read %{{.*}}[%[[c0_2]], %[[c0_2]]], %[[c42]]
+  // CHECK-SAME:   vector.transfer_read %{{.*}}[%[[c0_0]], %[[c0_0]]], %[[c42]]
   // CHECK-SAME:   {in_bounds = [true, true]} : tensor<?x?xf32>, vector<2x4xf32>
   // CHECK-SAME: } : vector<2x4xi1> -> vector<2x4xf32>
-  //  CHECK-DAG: %[[c0_3:.*]] = arith.constant 0 : index
+  //  CHECK-DAG: %[[c0_1:.*]] = arith.constant 0 : index
   //  CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<2x4xf32>
-  //      CHECK: vector.transfer_write %[[masked_read]], %[[empty]][%[[c0_3]], %[[c0_3]]]
+  //      CHECK: vector.transfer_write %[[masked_read]], %[[empty]][%[[c0_1]], %[[c0_1]]]
   // CHECK-SAME:   {in_bounds = [true, true]} : vector<2x4xf32>, tensor<2x4xf32>
   %cst = arith.constant 42.43 : f32
   %c0 = arith.constant 0 : index
@@ -468,10 +468,10 @@ func.func @test_masked_vectorize_dynamic_pad(
   //  CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
   //  CHECK-DAG: %[[res_d0:.+]] = affine.apply #[[MAP]]()
   //  CHECK-DAG: %[[res_d1:.+]] = affine.apply #[[MAP]]()
+  //      CHECK: %[[c0_2:.*]] = arith.constant 0 : index
   //      CHECK: %[[d0:.*]] = tensor.dim {{.*}} : tensor<?x?xf32>
   //      CHECK: %[[d1:.*]] = tensor.dim {{.*}} : tensor<?x?xf32>
   //      CHECK: %[[mask:.*]] = vector.create_mask %[[d0]], %[[d1]] : vector<2x4xi1>
-  //  CHECK-DAG: %[[c0_2:.*]] = arith.constant 0 : index
   //      CHECK: %[[masked_read:.*]] = vector.mask %[[mask]] {
   // CHECK-SAME:   vector.transfer_read %{{.*}}[%[[c0_2]], %[[c0_2]]], %[[c42]]
   // CHECK-SAME:   {in_bounds = [true, true]} : tensor<?x?xf32>, vector<2x4xf32>
@@ -503,58 +503,46 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-func.func @test_vectorize_dynamic_pack(%arg0: tensor<?x?xf32>, %arg1: tensor<4x1x16x2xf32>) -> tensor<4x1x16x2xf32> {
-  %pack = tensor.pack %arg0 inner_dims_pos = [1, 0] inner_tiles = [16, 2] into %arg1 : tensor<?x?xf32> -> tensor<4x1x16x2xf32>
-  return %pack : tensor<4x1x16x2xf32>
+func.func @test_vectorize_pack(%arg0: tensor<32x8x16xf32>, %arg1: tensor<4x1x32x16x2xf32>) -> tensor<4x1x32x16x2xf32> {
+  %pack = tensor.pack %arg0 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %arg1 : tensor<32x8x16xf32> -> tensor<4x1x32x16x2xf32>
+  return %pack : tensor<4x1x32x16x2xf32>
 }
+//  CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
+//  CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+//      CHECK: %[[read:.*]] = vector.transfer_read %{{.*}}[%[[c0]], %[[c0]], %[[c0]]], %[[cst]]
+// CHECK-SAME:    {in_bounds = [true, true, true]} : tensor<32x8x16xf32>, vector<32x8x16xf32>
+//      CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[read]] : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
+//      CHECK: %[[transpose:.*]] = vector.transpose %[[shape_cast]], [1, 3, 0, 4, 2] : vector<32x4x2x1x16xf32> to vector<4x1x32x16x2xf32>
+//  CHECK-DAG: %[[c0_1:.*]] = arith.constant 0 : index
+//  CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<4x1x32x16x2xf32>
+//      CHECK: %[[write:.*]] = vector.transfer_write %[[transpose]], %[[empty]][%[[c0_1]], %[[c0_1]], %[[c0_1]], %[[c0_1]], %[[c0_1]]]
+// CHECK-SAME:   {in_bounds = [true, true, true, true, true]} : vector<4x1x32x16x2xf32>, tensor<4x1x32x16x2xf32>
+//      CHECK: return %[[write]] : tensor<4x1x32x16x2xf32>
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["tensor.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
-    transform.structured.vectorize %0 vector_sizes [4, 1] : !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [4, 1, 32] : !transform.any_op
     transform.yield 
   }
 }
-//  CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
-//  CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
-//  CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
-//  CHECK-DAG: %[[d0:.*]] = tensor.dim {{.*}} %[[c0]] : tensor<?x?xf32>
-//  CHECK-DAG: %[[d1:.*]] = tensor.dim {{.*}} %[[c1]] : tensor<?x?xf32>
-//      CHECK: %[[mask:.*]] = vector.create_mask %[[d0]], %[[d1]] : vector<8x16xi1>
-//  CHECK-DAG: %[[c0_1:.*]] = arith.constant 0 : index
-//      CHECK: %[[masked_read:.*]] = vector.mask %[[mask]] {
-// CHECK-SAME:   vector.transfer_read %{{.*}}[%[[c0_1]], %[[c0_1]]], %[[cst]]
-// CHECK-SAME:   {in_bounds = [true, true]} : tensor<?x?xf32>, vector<8x16xf32>
-// CHECK-SAME: } : vector<8x16xi1> -> vector<8x16xf32>
-//      CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[masked_read]] : vector<8x16xf32> to vector<4x2x1x16xf32>
-//      CHECK: %[[transpose:.*]] = vector.transpose %[[shape_cast]], [0, 2, 3, 1] : vector<4x2x1x16xf32> to vector<4x1x16x2xf32>
-//  CHECK-DAG: %[[c0_2:.*]] = arith.constant 0 : index
-//  CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<4x1x16x2xf32>
-//      CHECK: %[[write:.*]] = vector.transfer_write %[[transpose]], %[[empty]][%[[c0_2]], %[[c0_2]], %[[c0_2]], %[[c0_2]]]
-// CHECK-SAME:   {in_bounds = [true, true, true, true]} : vector<4x1x16x2xf32>, tensor<4x1x16x2xf32>
-//      CHECK: return %[[write]] : tensor<4x1x16x2xf32>
 
 // -----
 
-func.func @test_vectorize_pack(%arg0: tensor<32x8x16xf32>, %arg1: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> {
-  %pack = tensor.pack %arg0 inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %arg1 : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
+func.func @test_vectorize_padded_pack(%arg0: tensor<32x7x15xf32>, %arg1: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> {
+  %pad = arith.constant 0.000000e+00 : f32
+  %pack = tensor.pack %arg0 padding_value(%pad : f32) inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %arg1 : tensor<32x7x15xf32> -> tensor<32x4x1x16x2xf32>
   return %pack : tensor<32x4x1x16x2xf32>
 }
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["tensor.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
-    transform.structured.vectorize %0 vector_sizes [32, 4, 1] : !transform.any_op
-    transform.yield 
-  }
-}
 //  CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
-//  CHECK-DAG: %[[c32:.*]] = arith.constant 32 : index
-//  CHECK-DAG: %[[c8:.*]] = arith.constant 8 : index
-//  CHECK-DAG: %[[c16:.*]] = arith.constant 16 : index
-//      CHECK: %[[mask:.*]] = vector.create_mask %[[c32]], %[[c8]], %[[c16]] : vector<32x8x16xi1>
 //  CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+//  CHECK-DAG: %[[c32:.*]] = arith.constant 32 : index
+//  CHECK-DAG: %[[c7:.*]] = arith.constant 7 : index
+//  CHECK-DAG: %[[c15:.*]] = arith.constant 15 : index
+//      CHECK: %[[mask:.*]] = vector.create_mask %[[c32]], %[[c7]], %[[c15]] : vector<32x8x16xi1>
 //      CHECK: %[[masked_read:.*]] = vector.mask %[[mask]] {
 // CHECK-SAME:   vector.transfer_read %{{.*}}[%[[c0]], %[[c0]], %[[c0]]], %[[cst]]
-// CHECK-SAME:   {in_bounds = [true, true, true]} : tensor<32x8x16xf32>, vector<32x8x16xf32>
+// CHECK-SAME:   {in_bounds = [true, true, true]} : tensor<32x7x15xf32>, vector<32x8x16xf32>
 // CHECK-SAME: } : vector<32x8x16xi1> -> vector<32x8x16xf32>
 //      CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[masked_read]] : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
 //      CHECK: %[[transpose:.*]] = vector.transpose %[[shape_cast]], [0, 1, 3, 4, 2] : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
@@ -564,30 +552,31 @@ module attributes {transform.with_named_sequence} {
 // CHECK-SAME:   {in_bounds = [true, true, true, true, true]} : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
 //      CHECK: return %[[write]] : tensor<32x4x1x16x2xf32>
 
-// -----
-
-func.func @test_vectorize_dynamic_result_pack(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?x16x2xf32>) -> tensor<?x?x16x2xf32> {
-  %pack = tensor.pack %arg0 inner_dims_pos = [1, 0] inner_tiles = [16, 2] into %arg1 : tensor<?x?xf32> -> tensor<?x?x16x2xf32>
-  return %pack : tensor<?x?x16x2xf32>
-}
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["tensor.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
-    transform.structured.vectorize %0 vector_sizes [4, 1] : !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [32, 4, 1] : !transform.any_op
     transform.yield 
   }
 }
+
+// -----
+
+func.func @test_vectorize_dynamic_pack(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?x16x2xf32>) -> tensor<?x?x16x2xf32> {
+  %pack = tensor.pack %arg0 inner_dims_pos = [1, 0] inner_tiles = [16, 2] into %arg1 : tensor<?x?xf32> -> tensor<?x?x16x2xf32>
+  return %pack : tensor<?x?x16x2xf32>
+}
 //  CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
 //  CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
 //  CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
 //  CHECK-DAG: %[[d0:.*]] = tensor.dim {{.*}} %[[c0]] : tensor<?x?x16x2xf32>
 //  CHECK-DAG: %[[d1:.*]] = tensor.dim {{.*}} %[[c1]] : tensor<?x?x16x2xf32>
+//  CHECK-DAG: %[[c0_1:.*]] = arith.constant 0 : index
 //  CHECK-DAG: %[[c0_0:.*]] = arith.constant 0 : index
 //  CHECK-DAG: %[[c1_0:.*]] = arith.constant 1 : index
 //  CHECK-DAG: %[[d0_0:.*]] = tensor.dim {{.*}} %[[c0_0]] : tensor<?x?xf32>
 //  CHECK-DAG: %[[d1_0:.*]] = tensor.dim {{.*}} %[[c1_0]] : tensor<?x?xf32>
 //      CHECK: %[[mask:.*]] = vector.create_mask %[[d0_0]], %[[d1_0]] : vector<8x16xi1>
-//  CHECK-DAG: %[[c0_1:.*]] = arith.constant 0 : index
 //      CHECK: %[[masked_read:.*]] = vector.mask %[[mask]] {
 // CHECK-SAME:   vector.transfer_read %{{.*}}[%[[c0_1]], %[[c0_1]]], %[[cst]]
 // CHECK-SAME:   {in_bounds = [true, true]} : tensor<?x?xf32>, vector<8x16xf32>
@@ -604,6 +593,14 @@ module attributes {transform.with_named_sequence} {
 // CHECK-SAME:   {in_bounds = [true, true, true, true]} : vector<4x1x16x2xf32>, tensor<?x?x16x2xf32>
 //      CHECK: return %[[masked_write]] : tensor<?x?x16x2xf32>
 
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [4, 1] : !transform.any_op
+    transform.yield 
+  }
+}
+
 // -----
 
 func.func @matmul(%A: memref<?x?xf32>, %B: memref<?x?xf32>, %C: memref<?x?xf32>) {

>From 5c5278c57426d60bb1770427aed9382adf7c2812 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Wed, 7 Feb 2024 11:09:50 -0500
Subject: [PATCH 8/8] last comments

---
 .../Linalg/Transforms/Vectorization.cpp       | 21 +++++++------------
 1 file changed, 8 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 7e7de846d99543..2bd6929fea6142 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1417,6 +1417,8 @@ static Value createReadOrMaskedRead(OpBuilder &builder, Location loc,
                                     Value padValue) {
   assert(llvm::none_of(readShape,
                        [](int64_t s) { return s == ShapedType::kDynamic; }));
+  auto sourceShape = dyn_cast<ShapedType>(source.getType()).getShape();
+  assert(sourceShape.size() == readShape.size());
   auto maskType = VectorType::get(readShape, builder.getI1Type());
   auto vectorType = VectorType::get(readShape, padValue.getType());
   int64_t readRank = readShape.size();
@@ -1428,12 +1430,7 @@ static Value createReadOrMaskedRead(OpBuilder &builder, Location loc,
       /*indices=*/SmallVector<Value>(readRank, zero),
       /*padding=*/padValue,
       /*inBounds=*/SmallVector<bool>(readRank, true));
-  auto sourceShape = llvm::dyn_cast<ShapedType>(source.getType()).getShape();
-  if (sourceShape.size() == readShape.size() &&
-      llvm::all_of(llvm::zip_equal(readShape, sourceShape), [](auto it) {
-        return std::get<0>(it) != ShapedType::kDynamic &&
-               std::get<0>(it) == std::get<1>(it);
-      })) {
+  if (llvm::equal(readShape, sourceShape)) {
     return transferReadOp;
   }
   SmallVector<OpFoldResult> mixedSourceDims =
@@ -1469,10 +1466,8 @@ static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
              destShape.drop_front(inputVectorSizes.size()),
              [](int64_t size) { return size == ShapedType::kDynamic; }) &&
          "Only dims aligned with inputVectorSizes may be dynamic");
-  bool needMaskForWrite = llvm::any_of(
-      llvm::zip_equal(inputVectorSizes,
-                      destShape.take_front(inputVectorSizes.size())),
-      [](auto it) { return std::get<0>(it) != std::get<1>(it); });
+  bool needMaskForWrite = !llvm::equal(
+      inputVectorSizes, destShape.take_front(inputVectorSizes.size()));
   if (needMaskForWrite) {
     SmallVector<int64_t> writeMaskShape;
     writeMaskShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
@@ -1490,12 +1485,12 @@ static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
 /// padding value into:
 /// masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
 /// As in the following example:
-/// ```mlir
+///
 /// %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
 ///     into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
-/// ```
+///
 /// This pack would be vectorized to:
-/// ```mlir
+///
 /// %load = vector.mask %mask {
 ///     vector.transfer_read %arg0[%c0, %c0, %c0], %cst
 ///         {in_bounds = [true, true, true]} :



More information about the Mlir-commits mailing list