[Mlir-commits] [mlir] 0bfbecf - [mlir][TransformDialect] Simplify the lowering of pack/unpack when these are just pad/unpad

Quentin Colombet llvmlistbot at llvm.org
Thu Apr 13 03:47:41 PDT 2023


Author: Quentin Colombet
Date: 2023-04-13T10:46:36Z
New Revision: 0bfbecf52e8fa04785cf0b5c38b25c13442fb53d

URL: https://github.com/llvm/llvm-project/commit/0bfbecf52e8fa04785cf0b5c38b25c13442fb53d
DIFF: https://github.com/llvm/llvm-project/commit/0bfbecf52e8fa04785cf0b5c38b25c13442fb53d.diff

LOG: [mlir][TransformDialect] Simplify the lowering of pack/unpack when these are just pad/unpad

This patch recognizes when tensor.pack/unpack operations are simple
tensor.pad/unpad (a.k.a. tensor.extract_slice) operations and lowers them to
a simpler sequence of instructions.

For pack, instead of doing:
```
pad
expand_shape
transpose
```
we do
```
pad
insert_slice
```

For unpack, instead of doing:
```
transpose
collapse_shape
extract_slice
```
we do
```
extract_slice
```

Note: returning nullptr to the transform dialect for the ops that are not
created is fine. The related handles are simply ignored by the
transformations that follow.

Differential Revision: https://reviews.llvm.org/D148159

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
    mlir/include/mlir/Dialect/Utils/IndexingUtils.h
    mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/lib/Dialect/Utils/IndexingUtils.cpp
    mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
    mlir/test/Dialect/Linalg/transform-lower-pack.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 6d701d2ea44a0..e628c9dfef647 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1817,6 +1817,13 @@ def Tensor_PackOp : Tensor_RelayoutOp<"pack", [
                                  Location loc,
                                  ArrayRef<int64_t> innerPermutation,
                                  ArrayRef<int64_t> outerPermutation);
+
+    /// Check if this PackOp is like a simple pad operation.
+    /// In other words, this operation:
+    /// 1. adds useless dimensions (dimension of size 1),
+    /// 2. pads the other ones, and
+    /// 3. doesn't shuffle the dimensions
+    bool isLikePad();
   }];
 
   let hasCanonicalizeMethod = 1;
@@ -1892,6 +1899,12 @@ def Tensor_UnPackOp : Tensor_RelayoutOp<"unpack"> {
                                    Value transposedSource,
                                    ArrayRef<int64_t> innerPermutation,
                                    ArrayRef<int64_t> outerPermutation);
+
+    /// Check if this UnPackOp is like a simple unpad operation.
+    /// In other words, this operation:
+    /// 1. drops useless dimensions (dimension of size 1), and
+    /// 2. reduces dimensions in place (i.e., no tranpose.)
+    bool isLikeUnPad();
   }];
 
   let hasCanonicalizeMethod = 1;

diff  --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
index 1969bd3a33121..39ae6dc015651 100644
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -198,6 +198,15 @@ SmallVector<int64_t> invertPermutationVector(ArrayRef<int64_t> permutation);
 /// Method to check if an interchange vector is a permutation.
 bool isPermutationVector(ArrayRef<int64_t> interchange);
 
+/// Return a permutation vector of size permSize that would result in moving
+/// positions into desiredPositions.
+///
+/// For example, permSize == 5, positions = {2, 4}, desiredPositions = {1, 0}
+/// would result in a {4, 2, 0, 1, 3} permutation vector.
+SmallVector<int64_t>
+computePermutationVector(int64_t permSize, ArrayRef<int64_t> positions,
+                         ArrayRef<int64_t> desiredPositions);
+
 /// Helper to return a subset of `arrayAttr` as a vector of int64_t.
 // TODO: Port everything relevant to DenseArrayAttr and drop this util.
 SmallVector<int64_t> getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront = 0,

diff  --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index a6cb78b8a2f1e..3d1ae9c7121b6 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -520,6 +520,19 @@ getSimplifyCollapseShapeWithRankReducingSliceInfo(
     RankedTensorType sourceType,
     ArrayRef<ReassociationIndices> reassociationIndices);
 
+struct PackingMetadata {
+  SmallVector<int64_t> insertPositions;
+  SmallVector<ReassociationIndices> reassociations;
+};
+
+/// Given a vector of `positions` indices representing desired packing insertion
+/// points into a target vector (i.e. pack/unpack.inner_dim_pos), compute the
+/// final positions in the target shape as well as the reshape reassociations.
+// Note: This should not be called with a large positions array (or the
+// implementation needs to be updated to use an N.log N sort instead of
+// repeated N^2 counts).
+PackingMetadata computePackingMetadata(int64_t packedRank,
+                                       ArrayRef<int64_t> innerDimPos);
 } // namespace mlir
 
 #endif // MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 53a8fae8560b5..a33abe9a508c6 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -139,81 +139,6 @@ static DiagnosedSilenceableFailure unpackSingleIndexResultPDLOperations(
   return DiagnosedSilenceableFailure::success();
 }
 
-/// Return a permutation vector of size permSize that would result in moving
-/// positions into desiredPositions.
-///
-/// For example, permSize == 5, positions = {2, 4}, desiredPositions = {1, 0}
-/// would result in a {4, 2, 0, 1, 3} permutation vector.
-static SmallVector<int64_t>
-computePermutationVector(int64_t permSize, ArrayRef<int64_t> positions,
-                         ArrayRef<int64_t> desiredPositions) {
-  SmallVector<int64_t> res(permSize, -1);
-  DenseSet<int64_t> seen;
-  for (auto [pos, desiredPos] : llvm::zip_equal(positions, desiredPositions)) {
-    res[desiredPos] = pos;
-    seen.insert(pos);
-  }
-  int64_t nextPos = 0;
-  for (int64_t &entry : res) {
-    if (entry != -1)
-      continue;
-    while (seen.contains(nextPos))
-      ++nextPos;
-    entry = nextPos;
-    ++nextPos;
-  }
-  return res;
-}
-
-struct PackingMetadata {
-  SmallVector<int64_t> insertPositions;
-  SmallVector<ReassociationIndices> reassociations;
-};
-/// Given a vector of `positions` indices representing desired packing insertion
-/// points into a target vector (i.e. pack/unpack.inner_dim_pos), compute the
-/// final positions in the target shape as well as the reshape reassociations.
-// Note: This should not be called with a large positions array (or the
-// implementation needs to be updated to use an N.log N sort instead of
-// repeated N^2 counts).
-static PackingMetadata computePackingMetadata(int64_t packedRank,
-                                              ArrayRef<int64_t> innerDimPos) {
-  PackingMetadata res;
-  res.insertPositions.reserve(innerDimPos.size());
-  // The pack insert position is the position + the number of previously
-  // inserted positions + offset.
-  // The offset controls whether the packing dimension is the first or last.
-  //
-  // Example
-  // =======
-  // Consider packing from a hypothetical ABCD layout to ABCDba whose
-  // pack.inner_dims is [1, 0]. The first step consists in undoing the
-  // permutation and producing AaBbCD. This is achieved purely by computing the
-  // insert positions of `b` and `a` into `ABCD`, starting from [1, 0]. One
-  // possibility, is to produce insert positions [2, 0], this would result in an
-  // aAbBCD layout (i.e. offset 0). The other possibility, is to produce insert
-  // positions [3, 1], this would result in an AaBbCD layout (i.e. offset 1).
-  // The latter is what we expect from packing.
-  int64_t offset = 1;
-  for (int64_t pos : innerDimPos) {
-    int64_t numInsertedBefore = llvm::count_if(
-        innerDimPos, [&pos](int64_t pos2) { return pos > pos2; });
-    res.insertPositions.push_back(pos + numInsertedBefore + offset);
-  }
-
-  DenseSet<int64_t> posSet(res.insertPositions.begin(),
-                           res.insertPositions.end());
-  res.reassociations.reserve(packedRank);
-  for (int64_t i = 1; i <= packedRank; ++i) {
-    if (!posSet.contains(i)) {
-      res.reassociations.push_back(ReassociationIndices{i - 1});
-      continue;
-    }
-    res.reassociations.push_back(ReassociationIndices{i - 1, i});
-    ++i;
-  }
-  return res;
-}
-
 //===----------------------------------------------------------------------===//
 // BufferizeToAllocationOp
 //===----------------------------------------------------------------------===//
@@ -888,6 +813,30 @@ static FailureOr<LowerPackResult> lowerPack(RewriterBase &rewriter,
       llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
       DBGSNL(); DBGS() << "collapsed type: " << collapsed; DBGSNL(););
 
+  if (packOp.isLikePad()) {
+    // This pack is just a plain pad.
+    // Just insert the pad in the higher ranked tensor.
+    auto emptyOp =
+        rewriter.create<tensor::EmptyOp>(loc, packedTensorType, ValueRange{});
+    // Offsets.
+    SmallVector<OpFoldResult> zeros(packedRank, rewriter.getIndexAttr(0));
+    // Strides.
+    SmallVector<OpFoldResult> ones(packedRank, rewriter.getIndexAttr(1));
+    SmallVector<OpFoldResult> sizes =
+        getMixedDimensions(rewriter, loc, packOp.getDest());
+
+    auto insertSliceOp = rewriter.create<tensor::InsertSliceOp>(
+        loc, /*source=*/padOp, /*dest=*/emptyOp,
+        /*offsets=*/zeros, sizes,
+        /*strides=*/ones);
+
+    LLVM_DEBUG(DBGS() << "insert_slice op: " << insertSliceOp; DBGSNL(););
+
+    rewriter.replaceOp(packOp, insertSliceOp->getResults());
+
+    return LowerPackResult{padOp, /*reshapeOp=*/nullptr,
+                           /*transposeOp=*/nullptr};
+  }
   // 5. Expand from the padded result to the stripMinedShape.
   auto reshapeOp = rewriter.create<tensor::ExpandShapeOp>(
       loc,
@@ -958,10 +907,32 @@ static FailureOr<LowerUnPackOpResult> lowerUnPack(RewriterBase &rewriter,
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(unPackOp);
 
-  // 2. Compute the permutation vector to move the last `numPackedDims` into the
-  // `innerPosDims` of a shape of rank `packedRank`.
-  int64_t numPackedDims = unPackOp.getInnerDimsPos().size();
   int64_t packedRank = packedTensorType.getRank();
+
+  OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
+  auto destTensorType = unPackOp.getDest().getType().cast<RankedTensorType>();
+  if (unPackOp.isLikeUnPad()) {
+    // This unpack is just a plain unpad.
+    // Just extract the slice from the higher ranked tensor.
+    ArrayRef<int64_t> destShape = destTensorType.getShape();
+    // The inner dimensions stay the same as the destination tensor, but the
+    // outer ones are additional 1s.
+    SmallVector<OpFoldResult> sizes(packedRank - destShape.size(), one);
+    sizes.append(getMixedDimensions(rewriter, loc, unPackOp.getDest()));
+
+    auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
+        loc, destTensorType, unPackOp.getSource(),
+        SmallVector<OpFoldResult>(packedRank, zero), sizes,
+        SmallVector<OpFoldResult>(packedRank, one));
+
+    rewriter.replaceOp(unPackOp, extractSliceOp->getResults());
+
+    return LowerUnPackOpResult{/*emptyOp=*/nullptr, /*transposeOp=*/nullptr,
+                               /*reshapeOp=*/nullptr, extractSliceOp};
+  }
+  // 2. Compute the permutation vector to move the last `numPackedDims` into
+  // the `innerPosDims` of a shape of rank `packedRank`.
+  int64_t numPackedDims = unPackOp.getInnerDimsPos().size();
   auto lastDims = llvm::to_vector(
       llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
   PackingMetadata packingMetadata =
@@ -1007,16 +978,14 @@ static FailureOr<LowerUnPackOpResult> lowerUnPack(RewriterBase &rewriter,
       packingMetadata.reassociations);
 
   // 6. ExtractSlice
-  auto destTensorType = unPackOp.getDest().getType().cast<RankedTensorType>();
   int64_t destRank = destTensorType.getRank();
-  OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
   auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
       loc, destTensorType, reshapeOp->getResult(0),
       SmallVector<OpFoldResult>(destRank, zero),
       tensor::getMixedSizes(rewriter, loc, unPackOp->getResult(0)),
       SmallVector<OpFoldResult>(destRank, one));
 
-  // 7. Replace unPackOp by transposeOp.
+  // 7. Replace unPackOp by extractSliceOp.
   rewriter.replaceOp(unPackOp, extractSliceOp->getResults());
 
   return LowerUnPackOpResult{emptyOp, transposeOp, reshapeOp, extractSliceOp};

diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index e092c6ea0f4a1..fd59afbc44447 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3709,6 +3709,45 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
   return success();
 }
 
+template <typename PackOrUnpackOp>
+static bool isLikePadUnPad(PackOrUnpackOp packOp,
+                           RankedTensorType packedTensorType) {
+  static_assert(std::is_same<PackOrUnpackOp, tensor::PackOp>::value ||
+                    std::is_same<PackOrUnpackOp, tensor::UnPackOp>::value,
+                "Function meant for pack/unpack");
+  // This is a pad if packing only adds ones and we don't transpose dimensions.
+
+  // Check that we are not transposing any dimensions.
+  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
+  int64_t numPackedDims = innerDimsPos.size();
+  auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));
+  if (orderedDims != innerDimsPos) {
+    // Dimensions don't happen in order.
+    return false;
+  }
+
+  ArrayRef<int64_t> packedShape = packedTensorType.getShape();
+  int64_t packedRank = packedTensorType.getRank();
+  // At this point we know that we are taking numPackedDims outer
+  // dimensions and pushing them all the way as the inner most dimensions.
+  // What's left on the outer most dimensions is, in this order:
+  // - the factor of the packed dimensions, then
+  // - the untouched dimensions
+  // This shifting inward of dimensions is a no-op (as opposed to a transpose)
+  // if all the dimensions that bubble outerward are ones.
+  // Therefore check that all the dimensions but the numPackedDims inner most
+  // ones are ones.
+  return llvm::all_of(
+      llvm::seq<int64_t>(0, packedRank - numPackedDims),
+      [&packedShape](int64_t i) { return packedShape[i] == 1; });
+}
+
+bool PackOp::isLikePad() {
+  auto packedTensorType =
+      (*this)->getResultTypes().front().cast<RankedTensorType>();
+  return isLikePadUnPad(*this, packedTensorType);
+}
+
 //===----------------------------------------------------------------------===//
 // UnPackOp
 //===----------------------------------------------------------------------===//
@@ -3822,6 +3861,10 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
   return success();
 }
 
+bool UnPackOp::isLikeUnPad() {
+  RankedTensorType packedTensorType = getSourceType();
+  return isLikePadUnPad(*this, packedTensorType);
+}
 //===----------------------------------------------------------------------===//
 // Common Canonicalizers and Folders.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index eb86e0f782a78..e3efa9ca97e30 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -213,6 +213,27 @@ bool mlir::isPermutationVector(ArrayRef<int64_t> interchange) {
   return seenVals.size() == interchange.size();
 }
 
+SmallVector<int64_t>
+mlir::computePermutationVector(int64_t permSize, ArrayRef<int64_t> positions,
+                               ArrayRef<int64_t> desiredPositions) {
+  SmallVector<int64_t> res(permSize, -1);
+  DenseSet<int64_t> seen;
+  for (auto [pos, desiredPos] : llvm::zip_equal(positions, desiredPositions)) {
+    res[desiredPos] = pos;
+    seen.insert(pos);
+  }
+  int64_t nextPos = 0;
+  for (int64_t &entry : res) {
+    if (entry != -1)
+      continue;
+    while (seen.contains(nextPos))
+      ++nextPos;
+    entry = nextPos;
+    ++nextPos;
+  }
+  return res;
+}
+
 SmallVector<int64_t> mlir::getI64SubArray(ArrayAttr arrayAttr,
                                           unsigned dropFront,
                                           unsigned dropBack) {

diff  --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 750c8f9ccb381..383c77f3b7340 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -450,3 +450,42 @@ mlir::getSimplifyCollapseShapeWithRankReducingSliceInfo(
   return CollapseShapeRankReducingSliceSimplificationInfo{
       sliceType, newReassociationIndices};
 }
+
+PackingMetadata mlir::computePackingMetadata(int64_t packedRank,
+                                             ArrayRef<int64_t> innerDimPos) {
+  PackingMetadata res;
+  res.insertPositions.reserve(innerDimPos.size());
+  // The pack insert position is the position + the number of previously
+  // inserted positions + offset.
+  // The offset controls whether the packing dimension is the first or last.
+  //
+  // Example
+  // =======
+  // Consider packing from a hypothetical ABCD layout to ABCDba whose
+  // pack.inner_dims is [1, 0]. The first step consists in undoing the
+  // permutation and producing AaBbCD. This is achieved purely by computing the
+  // insert positions of `b` and `a` into `ABCD`, starting from [1, 0]. One
+  // possibility, is to produce insert positions [2, 0], this would result in an
+  // aAbBCD layout (i.e. offset 0). The other possibility, is to produce insert
+  // positions [3, 1], this would result in an AaBbCD layout (i.e. offset 1).
+  // The latter is what we expect from packing.
+  int64_t offset = 1;
+  for (int64_t pos : innerDimPos) {
+    int64_t numInsertedBefore = llvm::count_if(
+        innerDimPos, [&pos](int64_t pos2) { return pos > pos2; });
+    res.insertPositions.push_back(pos + numInsertedBefore + offset);
+  }
+
+  DenseSet<int64_t> posSet(res.insertPositions.begin(),
+                           res.insertPositions.end());
+  res.reassociations.reserve(packedRank);
+  for (int64_t i = 1; i <= packedRank; ++i) {
+    if (!posSet.contains(i)) {
+      res.reassociations.push_back(ReassociationIndices{i - 1});
+      continue;
+    }
+    res.reassociations.push_back(ReassociationIndices{i - 1, i});
+    ++i;
+  }
+  return res;
+}

diff  --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index 7a89cf90214cc..f42e3e3dbbcb6 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -21,14 +21,79 @@ func.func @pack(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<17x2x16x16x32x8xf
 
 transform.sequence failures(propagate) {
 ^bb1(%module_op: !pdl.operation):
-  %pack = transform.structured.match ops{["tensor.pack"]} in %module_op 
+  %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
     : (!pdl.operation) -> !transform.op<"tensor.pack">
-  transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">) 
+  transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
     -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
 }
 
 // -----
 
+// CHECK-LABEL: func.func @pack_as_pad(
+func.func @pack_as_pad(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
+  %cst_0 = arith.constant 0.0 : f32
+
+  // tensor.pack is lowered to tensor.pad + tensor.insert_slice
+  //      CHECK: %[[C0:.*]] = arith.constant 0 : index
+  //      CHECK: %[[PAD:.*]] = tensor.pad {{.*}} low[%[[C0]], %[[C0]], %[[C0]], %[[C0]]]
+  //      CHECK:   : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
+  //      CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1x1x1x1x136x64x16x16xf32>
+  //      CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[EMPTY]]
+  // offsets.
+  // CHECK-SAME:   [0, 0, 0, 0, 0, 0, 0, 0]
+  // sizes.
+  // CHECK-SAME:   [1, 1, 1, 1, 136, 64, 16, 16]
+  // strides multipliers.
+  // CHECK-SAME:   [1, 1, 1, 1, 1, 1, 1, 1]
+  // CHECK-SAME:   : tensor<136x64x16x16xf32> into tensor<1x1x1x1x136x64x16x16xf32>
+  //      CHECK: return %[[RES]]
+  %pack = tensor.pack %arg0 padding_value(%cst_0 : f32) inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1
+    : tensor<129x47x16x16xf32> -> tensor<1x1x1x1x136x64x16x16xf32>
+  return %pack :  tensor<1x1x1x1x136x64x16x16xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+  %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
+    : (!pdl.operation) -> !transform.op<"tensor.pack">
+  transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
+    -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
+}
+
+// -----
+
+// Check that we don't lower the following pack as a pad.
+// Although all the outer most dimensions in the resulting shape are 1s,
+// some of the original dimensions are not part of the inner_dims_pos, hence
+// some transpose needs to happen.
+// CHECK-LABEL: func.func @pack_not_a_pad(
+func.func @pack_not_a_pad(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x16x16x136x64xf32>) -> tensor<1x1x16x16x136x64xf32> {
+  %cst_0 = arith.constant 0.0 : f32
+
+  //      CHECK: %[[C0:.*]] = arith.constant 0 : index
+  //      CHECK: tensor.pad {{.*}} low[%[[C0]], %[[C0]], %[[C0]], %[[C0]]]
+  //      CHECK:   : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
+  //      CHECK: tensor.expand_shape %{{.*}} [{{.*}}[0, 1], [2, 3], [4], [5]]
+  // CHECK-SAME:   : tensor<136x64x16x16xf32> into tensor<1x136x1x64x16x16xf32>
+  //      CHECK: linalg.transpose
+  // CHECK-SAME:   ins(%{{.*}} : tensor<1x136x1x64x16x16xf32>)
+  // CHECK-SAME:   outs(%{{.*}} : tensor<1x1x16x16x136x64xf32>)
+  // CHECK-SAME:   permutation = [0, 2, 4, 5, 1, 3]
+
+  %pack = tensor.pack %arg0 padding_value(%cst_0 : f32) inner_dims_pos = [0, 1] inner_tiles = [136, 64] into %arg1
+    : tensor<129x47x16x16xf32> -> tensor<1x1x16x16x136x64xf32>
+  return %pack :  tensor<1x1x16x16x136x64xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+  %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
+    : (!pdl.operation) -> !transform.op<"tensor.pack">
+  transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
+    -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
+}
+
+// -----
 // CHECK-LABEL: func.func @unpack(
 func.func @unpack(%arg0: tensor<17x2x16x16x32x8xf32>, %arg1: tensor<129x47x16x16xf32>) -> tensor<129x47x16x16xf32> {
   %cst_0 = arith.constant 0.0 : f32
@@ -38,9 +103,9 @@ func.func @unpack(%arg0: tensor<17x2x16x16x32x8xf32>, %arg1: tensor<129x47x16x16
   // CHECK-SAME:    ins(%{{.*}} : tensor<17x2x16x16x32x8xf32>)
   // CHECK-SAME:   outs(%{{.*}} : tensor<17x8x2x32x16x16xf32>)
   // CHECK-SAME:   permutation = [0, 5, 1, 4, 2, 3]
-  //      CHECK: tensor.collapse_shape {{.*}}[0, 1], [2, 3], [4], [5]] 
+  //      CHECK: tensor.collapse_shape {{.*}}[0, 1], [2, 3], [4], [5]]
   // CHECK-SAME:   : tensor<17x8x2x32x16x16xf32> into tensor<136x64x16x16xf32>
-  //      CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [129, 47, 16, 16] [1, 1, 1, 1]  
+  //      CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [129, 47, 16, 16] [1, 1, 1, 1]
   // CHECK-SAME:   : tensor<136x64x16x16xf32> to tensor<129x47x16x16xf32>
   %pack = tensor.unpack %arg0 inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %arg1
     : tensor<17x2x16x16x32x8xf32> -> tensor<129x47x16x16xf32>
@@ -49,10 +114,41 @@ func.func @unpack(%arg0: tensor<17x2x16x16x32x8xf32>, %arg1: tensor<129x47x16x16
 
 transform.sequence failures(propagate) {
 ^bb1(%module_op: !pdl.operation):
-  %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op 
+  %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
+    : (!pdl.operation) -> !transform.op<"tensor.unpack">
+  transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
+    -> (!transform.op<"tensor.empty">,
+        !transform.op<"linalg.transpose">,
+        !transform.op<"tensor.collapse_shape">,
+        !transform.op<"tensor.extract_slice">)
+}
+
+// -----
+// When an unpack is a plain 'unpad', lower it to a simple extract_slice.
+// CHECK-LABEL: func.func @unpack_as_pad(
+func.func @unpack_as_pad(%arg0: tensor<1x1x1x1x136x64x16x16xf32>, %arg1: tensor<129x47x16x16xf32>) -> tensor<129x47x16x16xf32> {
+  %cst_0 = arith.constant 0.0 : f32
+
+  // CHECK-SAME: %[[ARG0:[^:]*]]: tensor<1x1x1x1x136x64x16x16xf32>
+  //      CHECK: %[[RES:.*]] = tensor.extract_slice %[[ARG0]]
+  // offsets.
+  // CHECK-SAME:   [0, 0, 0, 0, 0, 0, 0, 0]
+  // sizes.
+  // CHECK-SAME:   [1, 1, 1, 1, 129, 47, 16, 16]
+  // strides multiplers.
+  // CHECK-SAME:   [1, 1, 1, 1, 1, 1, 1, 1]
+  // CHECK-SAME:   : tensor<1x1x1x1x136x64x16x16xf32> to tensor<129x47x16x16xf32>
+  %pack = tensor.unpack %arg0 inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1
+    : tensor<1x1x1x1x136x64x16x16xf32> -> tensor<129x47x16x16xf32>
+  return %pack : tensor<129x47x16x16xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+  %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
     : (!pdl.operation) -> !transform.op<"tensor.unpack">
-  transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">) 
-    -> (!transform.op<"tensor.empty">, 
+  transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
+    -> (!transform.op<"tensor.empty">,
         !transform.op<"linalg.transpose">,
         !transform.op<"tensor.collapse_shape">,
         !transform.op<"tensor.extract_slice">)


        


More information about the Mlir-commits mailing list