[Mlir-commits] [mlir] 0bfbecf - [mlir][TransformDialect] Simplify the lowering of pack/unpack when these are just pad/unpad
Quentin Colombet
llvmlistbot at llvm.org
Thu Apr 13 03:47:41 PDT 2023
Author: Quentin Colombet
Date: 2023-04-13T10:46:36Z
New Revision: 0bfbecf52e8fa04785cf0b5c38b25c13442fb53d
URL: https://github.com/llvm/llvm-project/commit/0bfbecf52e8fa04785cf0b5c38b25c13442fb53d
DIFF: https://github.com/llvm/llvm-project/commit/0bfbecf52e8fa04785cf0b5c38b25c13442fb53d.diff
LOG: [mlir][TransformDialect] Simplify the lowering of pack/unpack when these are just pad/unpad
This patch recognizes when tensor.pack/unpack operations are simple
tensor.pad/unpad operations (where unpad is a tensor.extract_slice) and lowers
them to a simpler sequence of instructions.
For pack, instead of doing:
```
pad
expand_shape
transpose
```
we do
```
pad
insert_slice
```
For unpack, instead of doing:
```
transpose
collapse_shape
extract_slice
```
we do
```
extract_slice
```
Note: returning nullptr for some of the produced ops is fine as far as the
transform dialect is concerned; the related handles are simply ignored by
subsequent transformations.
Differential Revision: https://reviews.llvm.org/D148159
Added:
Modified:
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/include/mlir/Dialect/Utils/IndexingUtils.h
mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/lib/Dialect/Utils/IndexingUtils.cpp
mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
mlir/test/Dialect/Linalg/transform-lower-pack.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 6d701d2ea44a0..e628c9dfef647 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1817,6 +1817,13 @@ def Tensor_PackOp : Tensor_RelayoutOp<"pack", [
Location loc,
ArrayRef<int64_t> innerPermutation,
ArrayRef<int64_t> outerPermutation);
+
+ /// Check if this PackOp is like a simple pad operation.
+ /// In other words, this operation:
+ /// 1. adds useless dimensions (dimensions of size 1),
+ /// 2. pads the other ones, and
+ /// 3. doesn't shuffle the dimensions.
+ bool isLikePad();
}];
let hasCanonicalizeMethod = 1;
@@ -1892,6 +1899,12 @@ def Tensor_UnPackOp : Tensor_RelayoutOp<"unpack"> {
Value transposedSource,
ArrayRef<int64_t> innerPermutation,
ArrayRef<int64_t> outerPermutation);
+
+ /// Check if this UnPackOp is like a simple unpad operation.
+ /// In other words, this operation:
+ /// 1. drops useless dimensions (dimensions of size 1), and
+ /// 2. reduces dimensions in place (i.e., no transpose).
+ bool isLikeUnPad();
}];
let hasCanonicalizeMethod = 1;
diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
index 1969bd3a33121..39ae6dc015651 100644
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -198,6 +198,15 @@ SmallVector<int64_t> invertPermutationVector(ArrayRef<int64_t> permutation);
/// Method to check if an interchange vector is a permutation.
bool isPermutationVector(ArrayRef<int64_t> interchange);
+/// Return a permutation vector of size permSize that would result in moving
+/// positions into desiredPositions.
+///
+/// For example, permSize == 5, positions = {2, 4}, desiredPositions = {1, 0}
+/// would result in a {4, 2, 0, 1, 3} permutation vector.
+SmallVector<int64_t>
+computePermutationVector(int64_t permSize, ArrayRef<int64_t> positions,
+ ArrayRef<int64_t> desiredPositions);
+
/// Helper to return a subset of `arrayAttr` as a vector of int64_t.
// TODO: Port everything relevant to DenseArrayAttr and drop this util.
SmallVector<int64_t> getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront = 0,
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index a6cb78b8a2f1e..3d1ae9c7121b6 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -520,6 +520,19 @@ getSimplifyCollapseShapeWithRankReducingSliceInfo(
RankedTensorType sourceType,
ArrayRef<ReassociationIndices> reassociationIndices);
+struct PackingMetadata {
+ SmallVector<int64_t> insertPositions;
+ SmallVector<ReassociationIndices> reassociations;
+};
+
+/// Given a vector of `positions` indices representing desired packing insertion
+/// points into a target vector (i.e. pack/unpack.inner_dim_pos), compute the
+/// final positions in the target shape as well as the reshape reassociations.
+// Note: This should not be called with a large positions array (or the
+// implementation needs to be updated to use an N.log N sort instead of
+// repeated N^2 counts).
+PackingMetadata computePackingMetadata(int64_t packedRank,
+ ArrayRef<int64_t> innerDimPos);
} // namespace mlir
#endif // MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 53a8fae8560b5..a33abe9a508c6 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -139,81 +139,6 @@ static DiagnosedSilenceableFailure unpackSingleIndexResultPDLOperations(
return DiagnosedSilenceableFailure::success();
}
-/// Return a permutation vector of size permSize that would result in moving
-/// positions into desiredPositions.
-///
-/// For example, permSize == 5, positions = {2, 4}, desiredPositions = {1, 0}
-/// would result in a {4, 2, 0, 1, 3} permutation vector.
-static SmallVector<int64_t>
-computePermutationVector(int64_t permSize, ArrayRef<int64_t> positions,
- ArrayRef<int64_t> desiredPositions) {
- SmallVector<int64_t> res(permSize, -1);
- DenseSet<int64_t> seen;
- for (auto [pos, desiredPos] : llvm::zip_equal(positions, desiredPositions)) {
- res[desiredPos] = pos;
- seen.insert(pos);
- }
- int64_t nextPos = 0;
- for (int64_t &entry : res) {
- if (entry != -1)
- continue;
- while (seen.contains(nextPos))
- ++nextPos;
- entry = nextPos;
- ++nextPos;
- }
- return res;
-}
-
-struct PackingMetadata {
- SmallVector<int64_t> insertPositions;
- SmallVector<ReassociationIndices> reassociations;
-};
-/// Given a vector of `positions` indices representing desired packing insertion
-/// points into a target vector (i.e. pack/unpack.inner_dim_pos), compute the
-/// final positions in the target shape as well as the reshape reassociations.
-// Note: This should not be called with a large positions array (or the
-// implementation needs to be updated to use an N.log N sort instead of
-// repeated N^2 counts).
-static PackingMetadata computePackingMetadata(int64_t packedRank,
- ArrayRef<int64_t> innerDimPos) {
- PackingMetadata res;
- res.insertPositions.reserve(innerDimPos.size());
- // The pack insert position is the position + the number of previously
- // inserted positions + offset.
- // The offset controls whether the packing dimension is the first or last.
- //
- // Example
- // =======
- // Consider packing from a hypothetical ABCD layout to ABCDba whose
- // pack.inner_dims is [1, 0]. The first step consists in undoing the
- // permutation and producing AaBbCD. This is achieved purely by computing the
- // insert positions of `b` and `a` into `ABCD`, starting from [1, 0]. One
- // possibility, is to produce insert positions [2, 0], this would result in an
- // aAbBCD layout (i.e. offset 0). The other possibility, is to produce insert
- // positions [3, 1], this would result in an AaBbCD layout (i.e. offset 1).
- // The latter is what we expect from packing.
- int64_t offset = 1;
- for (int64_t pos : innerDimPos) {
- int64_t numInsertedBefore = llvm::count_if(
- innerDimPos, [&pos](int64_t pos2) { return pos > pos2; });
- res.insertPositions.push_back(pos + numInsertedBefore + offset);
- }
-
- DenseSet<int64_t> posSet(res.insertPositions.begin(),
- res.insertPositions.end());
- res.reassociations.reserve(packedRank);
- for (int64_t i = 1; i <= packedRank; ++i) {
- if (!posSet.contains(i)) {
- res.reassociations.push_back(ReassociationIndices{i - 1});
- continue;
- }
- res.reassociations.push_back(ReassociationIndices{i - 1, i});
- ++i;
- }
- return res;
-}
-
//===----------------------------------------------------------------------===//
// BufferizeToAllocationOp
//===----------------------------------------------------------------------===//
@@ -888,6 +813,30 @@ static FailureOr<LowerPackResult> lowerPack(RewriterBase &rewriter,
llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
DBGSNL(); DBGS() << "collapsed type: " << collapsed; DBGSNL(););
+ if (packOp.isLikePad()) {
+ // This pack is just a plain pad.
+ // Simply insert the padded value into the higher-ranked tensor.
+ auto emptyOp =
+ rewriter.create<tensor::EmptyOp>(loc, packedTensorType, ValueRange{});
+ // Offsets.
+ SmallVector<OpFoldResult> zeros(packedRank, rewriter.getIndexAttr(0));
+ // Strides.
+ SmallVector<OpFoldResult> ones(packedRank, rewriter.getIndexAttr(1));
+ SmallVector<OpFoldResult> sizes =
+ getMixedDimensions(rewriter, loc, packOp.getDest());
+
+ auto insertSliceOp = rewriter.create<tensor::InsertSliceOp>(
+ loc, /*source=*/padOp, /*dest=*/emptyOp,
+ /*offsets=*/zeros, sizes,
+ /*strides=*/ones);
+
+ LLVM_DEBUG(DBGS() << "insert_slice op: " << insertSliceOp; DBGSNL(););
+
+ rewriter.replaceOp(packOp, insertSliceOp->getResults());
+
+ return LowerPackResult{padOp, /*reshapeOp=*/nullptr,
+ /*transposeOp=*/nullptr};
+ }
// 5. Expand from the padded result to the stripMinedShape.
auto reshapeOp = rewriter.create<tensor::ExpandShapeOp>(
loc,
@@ -958,10 +907,32 @@ static FailureOr<LowerUnPackOpResult> lowerUnPack(RewriterBase &rewriter,
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(unPackOp);
- // 2. Compute the permutation vector to move the last `numPackedDims` into the
- // `innerPosDims` of a shape of rank `packedRank`.
- int64_t numPackedDims = unPackOp.getInnerDimsPos().size();
int64_t packedRank = packedTensorType.getRank();
+
+ OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
+ auto destTensorType = unPackOp.getDest().getType().cast<RankedTensorType>();
+ if (unPackOp.isLikeUnPad()) {
+ // This unpack is just a plain unpad.
+ // Simply extract the slice from the higher-ranked tensor.
+ ArrayRef<int64_t> destShape = destTensorType.getShape();
+ // The inner dimensions stay the same as the destination tensor, but the
+ // outer ones are additional 1s.
+ SmallVector<OpFoldResult> sizes(packedRank - destShape.size(), one);
+ sizes.append(getMixedDimensions(rewriter, loc, unPackOp.getDest()));
+
+ auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
+ loc, destTensorType, unPackOp.getSource(),
+ SmallVector<OpFoldResult>(packedRank, zero), sizes,
+ SmallVector<OpFoldResult>(packedRank, one));
+
+ rewriter.replaceOp(unPackOp, extractSliceOp->getResults());
+
+ return LowerUnPackOpResult{/*emptyOp=*/nullptr, /*transposeOp=*/nullptr,
+ /*reshapeOp=*/nullptr, extractSliceOp};
+ }
+ // 2. Compute the permutation vector to move the last `numPackedDims` into
+ // the `innerPosDims` of a shape of rank `packedRank`.
+ int64_t numPackedDims = unPackOp.getInnerDimsPos().size();
auto lastDims = llvm::to_vector(
llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
PackingMetadata packingMetadata =
@@ -1007,16 +978,14 @@ static FailureOr<LowerUnPackOpResult> lowerUnPack(RewriterBase &rewriter,
packingMetadata.reassociations);
// 6. ExtractSlice
- auto destTensorType = unPackOp.getDest().getType().cast<RankedTensorType>();
int64_t destRank = destTensorType.getRank();
- OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
loc, destTensorType, reshapeOp->getResult(0),
SmallVector<OpFoldResult>(destRank, zero),
tensor::getMixedSizes(rewriter, loc, unPackOp->getResult(0)),
SmallVector<OpFoldResult>(destRank, one));
- // 7. Replace unPackOp by transposeOp.
+ // 7. Replace unPackOp by extractSliceOp.
rewriter.replaceOp(unPackOp, extractSliceOp->getResults());
return LowerUnPackOpResult{emptyOp, transposeOp, reshapeOp, extractSliceOp};
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index e092c6ea0f4a1..fd59afbc44447 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3709,6 +3709,45 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
return success();
}
+template <typename PackOrUnpackOp>
+static bool isLikePadUnPad(PackOrUnpackOp packOp,
+ RankedTensorType packedTensorType) {
+ static_assert(std::is_same<PackOrUnpackOp, tensor::PackOp>::value ||
+ std::is_same<PackOrUnpackOp, tensor::UnPackOp>::value,
+ "Function meant for pack/unpack");
+ // This is a pad if packing only adds ones and we don't transpose dimensions.
+
+ // Check that we are not transposing any dimensions.
+ ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
+ int64_t numPackedDims = innerDimsPos.size();
+ auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));
+ if (orderedDims != innerDimsPos) {
+ // Dimensions don't happen in order.
+ return false;
+ }
+
+ ArrayRef<int64_t> packedShape = packedTensorType.getShape();
+ int64_t packedRank = packedTensorType.getRank();
+ // At this point we know that we are taking numPackedDims outer
+ // dimensions and pushing them all the way in as the innermost dimensions.
+ // What's left in the outermost dimensions is, in this order:
+ // - the factors of the packed dimensions, then
+ // - the untouched dimensions.
+ // This shifting inward of dimensions is a no-op (as opposed to a transpose)
+ // if all the dimensions that bubble outward are ones.
+ // Therefore check that all the dimensions but the numPackedDims innermost
+ // ones are ones.
+ return llvm::all_of(
+ llvm::seq<int64_t>(0, packedRank - numPackedDims),
+ [&packedShape](int64_t i) { return packedShape[i] == 1; });
+}
+
+bool PackOp::isLikePad() {
+ auto packedTensorType =
+ (*this)->getResultTypes().front().cast<RankedTensorType>();
+ return isLikePadUnPad(*this, packedTensorType);
+}
+
//===----------------------------------------------------------------------===//
// UnPackOp
//===----------------------------------------------------------------------===//
@@ -3822,6 +3861,10 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
return success();
}
+bool UnPackOp::isLikeUnPad() {
+ RankedTensorType packedTensorType = getSourceType();
+ return isLikePadUnPad(*this, packedTensorType);
+}
//===----------------------------------------------------------------------===//
// Common Canonicalizers and Folders.
//===----------------------------------------------------------------------===//
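As an illustration of the new predicate, here is a hand-walk of isLikePadUnPad
on the two packs from the tests added at the end of this patch:
```
// pack_as_pad: inner_dims_pos = [0, 1, 2, 3], packed shape
// 1x1x1x1x136x64x16x16 (rank 8). inner_dims_pos is the identity and the
// 8 - 4 = 4 outermost dimensions are all 1s => isLikePad() returns true.
//
// pack_not_a_pad: inner_dims_pos = [0, 1], packed shape 1x1x16x16x136x64
// (rank 6). inner_dims_pos is the identity, but the 6 - 2 = 4 outermost
// dimensions are 1, 1, 16, 16, which are not all 1s => isLikePad() returns
// false, and the generic expand_shape + transpose lowering is used instead.
```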
diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index eb86e0f782a78..e3efa9ca97e30 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -213,6 +213,27 @@ bool mlir::isPermutationVector(ArrayRef<int64_t> interchange) {
return seenVals.size() == interchange.size();
}
+SmallVector<int64_t>
+mlir::computePermutationVector(int64_t permSize, ArrayRef<int64_t> positions,
+ ArrayRef<int64_t> desiredPositions) {
+ SmallVector<int64_t> res(permSize, -1);
+ DenseSet<int64_t> seen;
+ for (auto [pos, desiredPos] : llvm::zip_equal(positions, desiredPositions)) {
+ res[desiredPos] = pos;
+ seen.insert(pos);
+ }
+ int64_t nextPos = 0;
+ for (int64_t &entry : res) {
+ if (entry != -1)
+ continue;
+ while (seen.contains(nextPos))
+ ++nextPos;
+ entry = nextPos;
+ ++nextPos;
+ }
+ return res;
+}
+
SmallVector<int64_t> mlir::getI64SubArray(ArrayAttr arrayAttr,
unsigned dropFront,
unsigned dropBack) {
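For reference, a minimal usage sketch of the moved helper (the wrapper function
name is hypothetical; the expected result mirrors the example in the header
comment):
```
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include <cassert>

void exampleComputePermutationVector() {
  // Move position 2 to index 1 and position 4 to index 0; the remaining
  // positions {0, 1, 3} fill the holes in increasing order.
  llvm::SmallVector<int64_t> perm = mlir::computePermutationVector(
      /*permSize=*/5, /*positions=*/{2, 4}, /*desiredPositions=*/{1, 0});
  assert((perm == llvm::SmallVector<int64_t>{4, 2, 0, 1, 3}));
}
```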
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 750c8f9ccb381..383c77f3b7340 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -450,3 +450,42 @@ mlir::getSimplifyCollapseShapeWithRankReducingSliceInfo(
return CollapseShapeRankReducingSliceSimplificationInfo{
sliceType, newReassociationIndices};
}
+
+PackingMetadata mlir::computePackingMetadata(int64_t packedRank,
+ ArrayRef<int64_t> innerDimPos) {
+ PackingMetadata res;
+ res.insertPositions.reserve(innerDimPos.size());
+ // The pack insert position is the position + the number of previously
+ // inserted positions + offset.
+ // The offset controls whether the packing dimension is the first or last.
+ //
+ // Example
+ // =======
+ // Consider packing from a hypothetical ABCD layout to ABCDba whose
+ // pack.inner_dims_pos is [1, 0]. The first step consists of undoing the
+ // permutation and producing AaBbCD. This is achieved purely by computing the
+ // insert positions of `b` and `a` into `ABCD`, starting from [1, 0]. One
+ // possibility is to produce insert positions [2, 0], which would result in an
+ // aAbBCD layout (i.e., offset 0). The other possibility is to produce insert
+ // positions [3, 1], which would result in an AaBbCD layout (i.e., offset 1).
+ // The latter is what we expect from packing.
+ int64_t offset = 1;
+ for (int64_t pos : innerDimPos) {
+ int64_t numInsertedBefore = llvm::count_if(
+ innerDimPos, [&pos](int64_t pos2) { return pos > pos2; });
+ res.insertPositions.push_back(pos + numInsertedBefore + offset);
+ }
+
+ DenseSet<int64_t> posSet(res.insertPositions.begin(),
+ res.insertPositions.end());
+ res.reassociations.reserve(packedRank);
+ for (int64_t i = 1; i <= packedRank; ++i) {
+ if (!posSet.contains(i)) {
+ res.reassociations.push_back(ReassociationIndices{i - 1});
+ continue;
+ }
+ res.reassociations.push_back(ReassociationIndices{i - 1, i});
+ ++i;
+ }
+ return res;
+}
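To make the offset logic concrete, here is a hand-walk of the ABCDba example
from the comment above (the wrapper function name is hypothetical):
```
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"

void examplePackingMetadata() {
  // ABCD packed to ABCDba: packedRank = 6, innerDimPos = [1, 0].
  mlir::PackingMetadata md =
      mlir::computePackingMetadata(/*packedRank=*/6, /*innerDimPos=*/{1, 0});
  // For b (pos 1): one entry of innerDimPos is smaller, so 1 + 1 + 1 = 3.
  // For a (pos 0): no entry is smaller, so 0 + 0 + 1 = 1.
  // => md.insertPositions == {3, 1}, i.e. the AaBbCD layout (offset 1).
  // => md.reassociations == {{0, 1}, {2, 3}, {4}, {5}}, which matches the
  //    collapse_shape reassociation in the unpack test below.
  (void)md;
}
```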
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index 7a89cf90214cc..f42e3e3dbbcb6 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -21,14 +21,79 @@ func.func @pack(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<17x2x16x16x32x8xf
transform.sequence failures(propagate) {
^bb1(%module_op: !pdl.operation):
- %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
+ %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
: (!pdl.operation) -> !transform.op<"tensor.pack">
- transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
+ transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
-> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
}
// -----
+// CHECK-LABEL: func.func @pack_as_pad(
+func.func @pack_as_pad(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
+ %cst_0 = arith.constant 0.0 : f32
+
+ // tensor.pack is lowered to tensor.pad + tensor.insert_slice
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[PAD:.*]] = tensor.pad {{.*}} low[%[[C0]], %[[C0]], %[[C0]], %[[C0]]]
+ // CHECK: : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
+ // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1x1x1x1x136x64x16x16xf32>
+ // CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[EMPTY]]
+ // offsets.
+ // CHECK-SAME: [0, 0, 0, 0, 0, 0, 0, 0]
+ // sizes.
+ // CHECK-SAME: [1, 1, 1, 1, 136, 64, 16, 16]
+ // strides.
+ // CHECK-SAME: [1, 1, 1, 1, 1, 1, 1, 1]
+ // CHECK-SAME: : tensor<136x64x16x16xf32> into tensor<1x1x1x1x136x64x16x16xf32>
+ // CHECK: return %[[RES]]
+ %pack = tensor.pack %arg0 padding_value(%cst_0 : f32) inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1
+ : tensor<129x47x16x16xf32> -> tensor<1x1x1x1x136x64x16x16xf32>
+ return %pack : tensor<1x1x1x1x136x64x16x16xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+ %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
+ : (!pdl.operation) -> !transform.op<"tensor.pack">
+ transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
+ -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
+}
+
+// -----
+
+// Check that we don't lower the following pack as a pad.
+// Although all the outermost dimensions in the resulting shape are 1s,
+// some of the original dimensions are not part of inner_dims_pos, hence
+// a transpose still needs to happen.
+// CHECK-LABEL: func.func @pack_not_a_pad(
+func.func @pack_not_a_pad(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x16x16x136x64xf32>) -> tensor<1x1x16x16x136x64xf32> {
+ %cst_0 = arith.constant 0.0 : f32
+
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: tensor.pad {{.*}} low[%[[C0]], %[[C0]], %[[C0]], %[[C0]]]
+ // CHECK: : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
+ // CHECK: tensor.expand_shape %{{.*}} [{{.*}}[0, 1], [2, 3], [4], [5]]
+ // CHECK-SAME: : tensor<136x64x16x16xf32> into tensor<1x136x1x64x16x16xf32>
+ // CHECK: linalg.transpose
+ // CHECK-SAME: ins(%{{.*}} : tensor<1x136x1x64x16x16xf32>)
+ // CHECK-SAME: outs(%{{.*}} : tensor<1x1x16x16x136x64xf32>)
+ // CHECK-SAME: permutation = [0, 2, 4, 5, 1, 3]
+
+ %pack = tensor.pack %arg0 padding_value(%cst_0 : f32) inner_dims_pos = [0, 1] inner_tiles = [136, 64] into %arg1
+ : tensor<129x47x16x16xf32> -> tensor<1x1x16x16x136x64xf32>
+ return %pack : tensor<1x1x16x16x136x64xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+ %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
+ : (!pdl.operation) -> !transform.op<"tensor.pack">
+ transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
+ -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
+}
+
+// -----
// CHECK-LABEL: func.func @unpack(
func.func @unpack(%arg0: tensor<17x2x16x16x32x8xf32>, %arg1: tensor<129x47x16x16xf32>) -> tensor<129x47x16x16xf32> {
%cst_0 = arith.constant 0.0 : f32
@@ -38,9 +103,9 @@ func.func @unpack(%arg0: tensor<17x2x16x16x32x8xf32>, %arg1: tensor<129x47x16x16
// CHECK-SAME: ins(%{{.*}} : tensor<17x2x16x16x32x8xf32>)
// CHECK-SAME: outs(%{{.*}} : tensor<17x8x2x32x16x16xf32>)
// CHECK-SAME: permutation = [0, 5, 1, 4, 2, 3]
- // CHECK: tensor.collapse_shape {{.*}}[0, 1], [2, 3], [4], [5]]
+ // CHECK: tensor.collapse_shape {{.*}}[0, 1], [2, 3], [4], [5]]
// CHECK-SAME: : tensor<17x8x2x32x16x16xf32> into tensor<136x64x16x16xf32>
- // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [129, 47, 16, 16] [1, 1, 1, 1]
+ // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [129, 47, 16, 16] [1, 1, 1, 1]
// CHECK-SAME: : tensor<136x64x16x16xf32> to tensor<129x47x16x16xf32>
%pack = tensor.unpack %arg0 inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %arg1
: tensor<17x2x16x16x32x8xf32> -> tensor<129x47x16x16xf32>
@@ -49,10 +114,41 @@ func.func @unpack(%arg0: tensor<17x2x16x16x32x8xf32>, %arg1: tensor<129x47x16x16
transform.sequence failures(propagate) {
^bb1(%module_op: !pdl.operation):
- %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
+ %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
+ : (!pdl.operation) -> !transform.op<"tensor.unpack">
+ transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
+ -> (!transform.op<"tensor.empty">,
+ !transform.op<"linalg.transpose">,
+ !transform.op<"tensor.collapse_shape">,
+ !transform.op<"tensor.extract_slice">)
+}
+
+// -----
+// When an unpack is a plain 'unpad', lower it to a simple extract_slice.
+// CHECK-LABEL: func.func @unpack_as_pad(
+func.func @unpack_as_pad(%arg0: tensor<1x1x1x1x136x64x16x16xf32>, %arg1: tensor<129x47x16x16xf32>) -> tensor<129x47x16x16xf32> {
+ %cst_0 = arith.constant 0.0 : f32
+
+ // CHECK-SAME: %[[ARG0:[^:]*]]: tensor<1x1x1x1x136x64x16x16xf32>
+ // CHECK: %[[RES:.*]] = tensor.extract_slice %[[ARG0]]
+ // offsets.
+ // CHECK-SAME: [0, 0, 0, 0, 0, 0, 0, 0]
+ // sizes.
+ // CHECK-SAME: [1, 1, 1, 1, 129, 47, 16, 16]
+ // strides.
+ // CHECK-SAME: [1, 1, 1, 1, 1, 1, 1, 1]
+ // CHECK-SAME: : tensor<1x1x1x1x136x64x16x16xf32> to tensor<129x47x16x16xf32>
+ %pack = tensor.unpack %arg0 inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1
+ : tensor<1x1x1x1x136x64x16x16xf32> -> tensor<129x47x16x16xf32>
+ return %pack : tensor<129x47x16x16xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+ %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
: (!pdl.operation) -> !transform.op<"tensor.unpack">
- transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
- -> (!transform.op<"tensor.empty">,
+ transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
+ -> (!transform.op<"tensor.empty">,
!transform.op<"linalg.transpose">,
!transform.op<"tensor.collapse_shape">,
!transform.op<"tensor.extract_slice">)