[Mlir-commits] [mlir] 6f87b50 - [mlir][linalg] Add support for lowering pack with outer_dims_perm.
Hanhan Wang
llvmlistbot at llvm.org
Mon Apr 24 10:39:54 PDT 2023
Author: Hanhan Wang
Date: 2023-04-24T10:39:37-07:00
New Revision: 6f87b50be67a3be9d9b55f1102c3ecd992a5ed28
URL: https://github.com/llvm/llvm-project/commit/6f87b50be67a3be9d9b55f1102c3ecd992a5ed28
DIFF: https://github.com/llvm/llvm-project/commit/6f87b50be67a3be9d9b55f1102c3ecd992a5ed28.diff
LOG: [mlir][linalg] Add support for lowering pack with outer_dims_perm.
Reviewed By: chelini, qcolombet
Differential Revision: https://reviews.llvm.org/D148845
Added:
Modified:
mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
mlir/test/Dialect/Linalg/transform-lower-pack.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 3d1ae9c7121b6..43a260427cf5e 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -522,6 +522,7 @@ getSimplifyCollapseShapeWithRankReducingSliceInfo(
struct PackingMetadata {
SmallVector<int64_t> insertPositions;
+ SmallVector<int64_t> outerPositions;
SmallVector<ReassociationIndices> reassociations;
};
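
For readers skimming the header change: below is an annotated copy of the struct, where the comments are an editorial reading of how linalg::lowerPack uses each field (the comments are not present in the header itself).

    // Annotated copy of PackingMetadata; comments are editorial, not from
    // the header.
    struct PackingMetadata {
      // For each entry of inner_dims_pos, the index at which its tile dim
      // is inserted in the strip-mined (expanded, un-permuted) shape.
      SmallVector<int64_t> insertPositions;
      // New in this patch: the strip-mined index of each outer (source) dim.
      SmallVector<int64_t> outerPositions;
      // Reassociation groups mapping the padded source shape to the
      // strip-mined shape, usable by tensor expand/collapse_shape ops.
      SmallVector<ReassociationIndices> reassociations;
    };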
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 4d5ef0edc9f8a..4f3f2dc0c734b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -480,9 +480,6 @@ struct PackedOperandsDimList {
FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
tensor::PackOp packOp) {
// 1. Filter out NYI cases.
- if (!packOp.getOuterDimsPerm().empty())
- return rewriter.notifyMatchFailure(packOp, "outer dims perm NYI");
-
auto packedTensorType =
packOp->getResultTypes().front().cast<RankedTensorType>();
if (!packedTensorType.hasStaticShape()) {
@@ -495,21 +492,37 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(packOp);
- // 2. Compute the permutation vector to move the last `numPackedDims` into the
- // `innerPosDims` of a shape of rank `packedRank`.
+ // 2. Compute the permutation vector to shuffle the packed shape into the
+ // shape before any outer or inner permutations have been applied. The
+ // permutation is obtained by composing two permutations:
+ //   a) the permutation that moves the last `numPackedDims` into the
+ //      `innerPosDims` of a shape of rank `packedRank`;
+ //   b) the permutation that moves the outer dims if the pack op has
+ //      outer_dims_perm.
+ // Applying permutation (b) to permutation (a) yields the final permutation.
int64_t numPackedDims = packOp.getInnerDimsPos().size();
int64_t packedRank = packedTensorType.getRank();
auto lastDims = llvm::to_vector(
llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
PackingMetadata packingMetadata = computePackingMetadata(
packedTensorType.getRank(), packOp.getInnerDimsPos());
- SmallVector<int64_t> lastDimsToInsertPositionsPerm = computePermutationVector(
+ SmallVector<int64_t> innerPositionsPerm = computePermutationVector(
packedRank, lastDims, packingMetadata.insertPositions);
+ SmallVector<int64_t> outerPos = packingMetadata.outerPositions;
+ ArrayRef<int64_t> outerPerm = packOp.getOuterDimsPerm();
+ if (!outerPerm.empty())
+ applyPermutationToVector(outerPos, outerPerm);
+ SmallVector<int64_t> outerPositionPerm = computePermutationVector(
+ packedRank, packingMetadata.outerPositions, outerPos);
+
+ SmallVector<int64_t> packedToStripMinedShapePerm = innerPositionsPerm;
+ applyPermutationToVector(packedToStripMinedShapePerm, outerPositionPerm);
+
// 3. Compute the stripMinedShape: this is the packed shape before any outer
// or inner permutations have been applied.
SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
- applyPermutationToVector(stripMinedShape, lastDimsToInsertPositionsPerm);
+ applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);
// 4. Pad the source of packOp to a shape we can expand into stripMinedShape.
RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
@@ -527,11 +540,17 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
LLVM_DEBUG(
DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions,
DBGS() << "insertPositions: ");
+ DBGSNL(); llvm::interleaveComma(packingMetadata.outerPositions,
+ DBGS() << "outerPositions: ");
DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
DBGS() << "packedShape: ");
DBGSNL();
- llvm::interleaveComma(lastDimsToInsertPositionsPerm,
- DBGS() << "lastDimsToInsertPositionsPerm: ");
+ llvm::interleaveComma(outerPositionPerm, DBGS() << "outerPositionPerm: ");
+ DBGSNL(); llvm::interleaveComma(innerPositionsPerm,
+ DBGS() << "innerPositionsPerm: ");
+ DBGSNL();
+ llvm::interleaveComma(packedToStripMinedShapePerm,
+ DBGS() << "packedToStripMinedShapePerm: ");
DBGSNL(); llvm::interleaveComma(
packingMetadata.reassociations, DBGS() << "reassociations: ",
[&](ReassociationIndices ri) {
@@ -572,16 +591,14 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
padOp.getResult(), packingMetadata.reassociations);
// 6. Transpose stripMinedShape to packedShape.
- SmallVector<int64_t> insertPositionsToLastDimsPerm = computePermutationVector(
- packedRank, packingMetadata.insertPositions, lastDims);
+ SmallVector<int64_t> transpPerm =
+ invertPermutationVector(packedToStripMinedShapePerm);
auto transposeOp = rewriter.create<linalg::TransposeOp>(
- loc, reshapeOp.getResult(), packOp.getDest(),
- insertPositionsToLastDimsPerm);
+ loc, reshapeOp.getResult(), packOp.getDest(), transpPerm);
LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
DBGS() << "reshape op: " << reshapeOp; DBGSNL();
- llvm::interleaveComma(insertPositionsToLastDimsPerm,
- DBGS() << "insertPositionsToLastDimsPerm: ");
+ llvm::interleaveComma(transpPerm, DBGS() << "transpPerm: ");
DBGSNL(); DBGS() << "transpose op: " << transposeOp; DBGSNL(););
// 7. Replace packOp by transposeOp.
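
To make the step-2 comment above concrete, here is a minimal standalone sketch in plain C++ (no MLIR dependency) that reproduces the permutation arithmetic for the first test case added below (outer_dims_perm = [1, 2, 3, 0], inner_dims_pos = [3, 2], packedRank = 6). The helpers are simplified stand-ins for MLIR's computePermutationVector, applyPermutationToVector, and invertPermutationVector, not the real implementations; the metadata values are the ones computePackingMetadata yields for this case (see the ReshapeOpsUtils.cpp hunk that follows).

    // Hedged sketch: simplified stand-ins for MLIR's permutation utilities,
    // specialized to the first test case in transform-lower-pack.mlir.
    #include <cassert>
    #include <cstdint>
    #include <iostream>
    #include <set>
    #include <vector>

    using Vec = std::vector<int64_t>;

    // Permutation of size permSize sending positions[i] to
    // desiredPositions[i]; remaining slots get the unused indices in
    // increasing order.
    static Vec computePermutationVector(int64_t permSize, const Vec &positions,
                                        const Vec &desiredPositions) {
      Vec res(permSize, -1);
      std::set<int64_t> seen(positions.begin(), positions.end());
      for (size_t i = 0; i < positions.size(); ++i)
        res[desiredPositions[i]] = positions[i];
      int64_t next = 0;
      for (int64_t &entry : res) {
        if (entry != -1)
          continue;
        while (seen.count(next))
          ++next;
        entry = next++;
      }
      return res;
    }

    // result[i] = v[perm[i]], matching applyPermutationToVector's convention.
    static Vec applyPermutation(const Vec &v, const Vec &perm) {
      Vec res(perm.size());
      for (size_t i = 0; i < perm.size(); ++i)
        res[i] = v[perm[i]];
      return res;
    }

    static Vec invertPermutation(const Vec &perm) {
      Vec inv(perm.size());
      for (size_t i = 0; i < perm.size(); ++i)
        inv[perm[i]] = static_cast<int64_t>(i);
      return inv;
    }

    int main() {
      // Metadata for inner_dims_pos = [3, 2] at packedRank = 6: the two tile
      // dims sit at strip-mined positions 5 and 3; the outer dims sit at
      // positions 0, 1, 2, 4.
      const Vec insertPositions = {5, 3};
      const Vec outerPositions = {0, 1, 2, 4};

      // (a) Move the last numPackedDims = 2 dims into the insert positions.
      Vec innerPositionsPerm =
          computePermutationVector(6, /*lastDims=*/{4, 5}, insertPositions);
      assert((innerPositionsPerm == Vec{0, 1, 2, 5, 3, 4}));

      // (b) Fold in outer_dims_perm = [1, 2, 3, 0].
      Vec outerPos = applyPermutation(outerPositions, {1, 2, 3, 0});
      Vec outerPositionPerm =
          computePermutationVector(6, outerPositions, outerPos);
      assert((outerPositionPerm == Vec{4, 0, 1, 3, 2, 5}));

      // Compose: apply (b) to (a).
      Vec packedToStripMinedShapePerm =
          applyPermutation(innerPositionsPerm, outerPositionPerm);
      assert((packedToStripMinedShapePerm == Vec{3, 0, 1, 5, 2, 4}));

      // Packed shape 200x4x16x100x16x32 maps back to the strip-mined shape
      // 100x200x4x32x16x16; the linalg.transpose then uses the inverse.
      Vec stripMined = applyPermutation({200, 4, 16, 100, 16, 32},
                                        packedToStripMinedShapePerm);
      assert((stripMined == Vec{100, 200, 4, 32, 16, 16}));
      Vec transpPerm = invertPermutation(packedToStripMinedShapePerm);
      assert((transpPerm == Vec{1, 2, 4, 0, 5, 3})); // matches the test CHECK.
      std::cout << "all permutation checks passed\n";
      return 0;
    }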
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 383c77f3b7340..18646f598bcee 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -480,6 +480,7 @@ PackingMetadata mlir::computePackingMetadata(int64_t packedRank,
res.insertPositions.end());
res.reassociations.reserve(packedRank);
for (int64_t i = 1; i <= packedRank; ++i) {
+ res.outerPositions.push_back(i - 1);
if (!posSet.contains(i)) {
res.reassociations.push_back(ReassociationIndices{i - 1});
continue;
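
The outerPositions bookkeeping added above can likewise be checked in isolation. Below is a hedged, standalone approximation of computePackingMetadata for the same case (packedRank = 6, innerDimPos = [3, 2]), not the MLIR function itself; it reproduces insertPositions = [5, 3], outerPositions = [0, 1, 2, 4], and the reassociation [[0], [1], [2, 3], [4, 5]] that the first test below expects from tensor.expand_shape.

    // Hedged approximation of computePackingMetadata for packedRank = 6 and
    // innerDimPos = [3, 2]; mirrors the structure of the loop patched above.
    #include <cstdint>
    #include <iostream>
    #include <set>
    #include <vector>

    int main() {
      const int64_t packedRank = 6;
      const std::vector<int64_t> innerDimPos = {3, 2};

      // Each tiled dim's tile is inserted right after it, so account for
      // tiles inserted at smaller positions:
      // insertPos = pos + numInsertedBefore + 1.
      std::vector<int64_t> insertPositions;
      for (int64_t pos : innerDimPos) {
        int64_t numInsertedBefore = 0;
        for (int64_t other : innerDimPos)
          if (other < pos)
            ++numInsertedBefore;
        insertPositions.push_back(pos + numInsertedBefore + 1);
      } // insertPositions == {5, 3}

      const std::set<int64_t> posSet(insertPositions.begin(),
                                     insertPositions.end());
      // outerPositions records where each source dim lands in the
      // strip-mined shape; a tile dim consumes the extra slot of its
      // reassociation group.
      std::vector<int64_t> outerPositions;
      std::vector<std::vector<int64_t>> reassociations;
      for (int64_t i = 1; i <= packedRank; ++i) {
        outerPositions.push_back(i - 1);
        if (!posSet.count(i)) {
          reassociations.push_back({i - 1});
          continue;
        }
        reassociations.push_back({i - 1, i});
        ++i;
      }
      // outerPositions == {0, 1, 2, 4};
      // reassociations == {{0}, {1}, {2, 3}, {4, 5}}.
      std::cout << "outerPositions:";
      for (int64_t p : outerPositions)
        std::cout << ' ' << p;
      std::cout << '\n';
      return 0;
    }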
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index 83141ec75aba1..40f9f3e0761e7 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -183,3 +183,104 @@ transform.sequence failures(propagate) {
!transform.op<"tensor.collapse_shape">,
!transform.op<"tensor.extract_slice">)
}
+
+// -----
+
+// CHECK-LABEL: func.func @pack_with_outer_dims_perm(
+func.func @pack_with_outer_dims_perm(%src: tensor<100x200x128x256xi32>,
+ %dest: tensor<200x4x16x100x16x32xi32>)
+ -> tensor<200x4x16x100x16x32xi32> {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: tensor.pad {{.*}} low[%[[C0]], %[[C0]], %[[C0]], %[[C0]]]
+ // CHECK: : tensor<100x200x128x256xi32> to tensor<100x200x128x256xi32>
+ // CHECK: tensor.expand_shape %{{.*}} [{{.*}}[0], [1], [2, 3], [4, 5]]
+ // CHECK-SAME: : tensor<100x200x128x256xi32> into tensor<100x200x4x32x16x16xi32>
+ // CHECK: linalg.transpose
+ // CHECK-SAME: ins(%{{.*}} : tensor<100x200x4x32x16x16xi32>)
+ // CHECK-SAME: outs(%{{.*}} : tensor<200x4x16x100x16x32xi32>)
+ // CHECK-SAME: permutation = [1, 2, 4, 0, 5, 3]
+ %0 = tensor.pack %src
+ outer_dims_perm = [1, 2, 3, 0]
+ inner_dims_pos = [3, 2]
+ inner_tiles = [16, 32]
+ into %dest : tensor<100x200x128x256xi32> -> tensor<200x4x16x100x16x32xi32>
+ return %0 : tensor<200x4x16x100x16x32xi32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+ %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
+ : (!pdl.operation) -> !transform.op<"tensor.pack">
+ transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
+ -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
+}
+
+// -----
+
+// CHECK-LABEL: func.func @pack_with_pad_and_outer_dims_perm(
+func.func @pack_with_pad_and_outer_dims_perm(%src: tensor<100x200x127x255xi32>,
+ %dest: tensor<200x4x16x100x16x32xi32>)
+ -> tensor<200x4x16x100x16x32xi32> {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: tensor.pad {{.*}} low[%[[C0]], %[[C0]], %[[C0]], %[[C0]]]
+ // CHECK: : tensor<100x200x127x255xi32> to tensor<100x200x128x256xi32>
+ // CHECK: tensor.expand_shape %{{.*}} [{{.*}}[0], [1], [2, 3], [4, 5]]
+ // CHECK-SAME: : tensor<100x200x128x256xi32> into tensor<100x200x4x32x16x16xi32>
+ // CHECK: linalg.transpose
+ // CHECK-SAME: ins(%{{.*}} : tensor<100x200x4x32x16x16xi32>)
+ // CHECK-SAME: outs(%{{.*}} : tensor<200x4x16x100x16x32xi32>)
+ // CHECK-SAME: permutation = [1, 2, 4, 0, 5, 3]
+ %cst_0 = arith.constant 0 : i32
+ %0 = tensor.pack %src
+ padding_value(%cst_0 : i32)
+ outer_dims_perm = [1, 2, 3, 0]
+ inner_dims_pos = [3, 2]
+ inner_tiles = [16, 32]
+ into %dest : tensor<100x200x127x255xi32> -> tensor<200x4x16x100x16x32xi32>
+ return %0 : tensor<200x4x16x100x16x32xi32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+ %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
+ : (!pdl.operation) -> !transform.op<"tensor.pack">
+ transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
+ -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
+}
+
+// -----
+
+// CHECK-LABEL: func.func @pack_as_pad_with_outer_dims_perm(
+func.func @pack_as_pad_with_outer_dims_perm(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
+ %cst_0 = arith.constant 0.0 : f32
+
+ // tensor.pack is lowered to tensor.pad + tensor.insert_slice
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[PAD:.*]] = tensor.pad {{.*}} low[%[[C0]], %[[C0]], %[[C0]], %[[C0]]]
+ // CHECK: : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
+ // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1x1x1x1x136x64x16x16xf32>
+ // CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[EMPTY]]
+ // offsets.
+ // CHECK-SAME: [0, 0, 0, 0, 0, 0, 0, 0]
+ // sizes.
+ // CHECK-SAME: [1, 1, 1, 1, 136, 64, 16, 16]
+ // stride multipliers.
+ // CHECK-SAME: [1, 1, 1, 1, 1, 1, 1, 1]
+ // CHECK-SAME: : tensor<136x64x16x16xf32> into tensor<1x1x1x1x136x64x16x16xf32>
+ // CHECK: return %[[RES]]
+ %pack = tensor.pack %arg0
+ padding_value(%cst_0 : f32)
+ outer_dims_perm = [1, 2, 3, 0]
+ inner_dims_pos = [0, 1, 2, 3]
+ inner_tiles = [136, 64, 16, 16]
+ into %arg1 : tensor<129x47x16x16xf32> -> tensor<1x1x1x1x136x64x16x16xf32>
+ return %pack : tensor<1x1x1x1x136x64x16x16xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+ %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
+ : (!pdl.operation) -> !transform.op<"tensor.pack">
+ transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
+ -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
+}
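
A note on the last test: every outer dimension of the destination tensor<1x1x1x1x136x64x16x16xf32> is 1, so outer_dims_perm only permutes unit dims and the whole source forms a single tile. As the CHECK lines above verify, the lowering takes the pack-as-pad path in this case, emitting tensor.pad followed by tensor.insert_slice rather than an expand_shape/transpose pair.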