[Mlir-commits] [mlir] 7ef83f5 - [mlir] Add pack/unpack transpose foldings for linalg.generic ops, fix bugs (#93055)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jun 6 07:54:31 PDT 2024
Author: Max191
Date: 2024-06-06T10:54:27-04:00
New Revision: 7ef83f5561b34ca07fdef23ca2b3c01c583dbbf5
URL: https://github.com/llvm/llvm-project/commit/7ef83f5561b34ca07fdef23ca2b3c01c583dbbf5
DIFF: https://github.com/llvm/llvm-project/commit/7ef83f5561b34ca07fdef23ca2b3c01c583dbbf5.diff
LOG: [mlir] Add pack/unpack transpose foldings for linalg.generic ops, fix bugs (#93055)
This PR adds transpose + pack/unpack folding support for transpose ops
in the form of `linalg.generic` ops. There were also some bugs with the
permutation composing in the previous patterns, so this PR fixes these
bugs and adds tests for them as well.
Added:
Modified:
mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
index 5d6e3ec9756af..c681cadcb27cb 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
@@ -48,6 +48,34 @@ static LogicalResult isPackOn1D(RewriterBase &rewriter, Operation *op,
return success();
}
+// If the `linalgOp` represents a transpose, return the permutation vector for
+// the transpose. Otherwise, return failure.
+static FailureOr<SmallVector<int64_t>>
+getTransposeOpPermutation(linalg::LinalgOp linalgOp) {
+ if (auto transposeOp = dyn_cast<linalg::TransposeOp>(linalgOp.getOperation()))
+ return SmallVector<int64_t>(transposeOp.getPermutation());
+ if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
+ return failure();
+
+ if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
+ return failure();
+ auto mapRange = linalgOp.getIndexingMapsArray();
+ if (!mapRange.front().isPermutation() || !mapRange.back().isPermutation() ||
+ mapRange.front() == mapRange.back()) {
+ return failure();
+ }
+ if (!llvm::hasSingleElement(linalgOp.getBlock()->getOperations()))
+ return failure();
+ AffineMap outMap = mapRange.back();
+ AffineMap inMap = mapRange.front();
+ // To get the permutation, look at each output index and find which
+ // dimension in the input we're reading from for that index.
+ return llvm::map_to_vector(outMap.getResults(),
+ [&](AffineExpr expr) -> int64_t {
+ return *inMap.getResultPosition(expr);
+ });
+}
+
/// Packing one-dimensional tensor can be expressed as an expand shape op.
struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
using OpRewritePattern<PackOp>::OpRewritePattern;
@@ -246,14 +274,10 @@ static bool checkAndPermute(ArrayRef<int64_t> permutation,
for (unsigned int i = 0; i < rank; ++i) {
int64_t remappedPosition = permutation[i];
-
- if (!inVec.empty()) {
- if (remappedPosition >= rank) {
- return false;
- }
+ if (remappedPosition >= rank)
+ return false;
+ if (!inVec.empty())
remappedPosition = inVec[remappedPosition];
- }
-
resVec.push_back(remappedPosition);
}
@@ -263,20 +287,25 @@ static bool checkAndPermute(ArrayRef<int64_t> permutation,
/// Fold 'pack' -> 'transpose' into 'pack' since 'pack' already has transpose
/// semantics.
struct FoldProducerPackWithConsumerLinalgTransposeOp
- : public OpRewritePattern<linalg::TransposeOp> {
- using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+ : public OpInterfaceRewritePattern<linalg::LinalgOp> {
+ using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
- LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+ LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
PatternRewriter &rewriter) const override {
- auto packOp = transposeOp.getOperand(0).getDefiningOp<PackOp>();
+ auto packOp = linalgOp->getOperand(0).getDefiningOp<PackOp>();
if (!packOp)
return failure();
+ FailureOr<SmallVector<int64_t>> maybePerm =
+ getTransposeOpPermutation(linalgOp);
+ if (failed(maybePerm))
+ return failure();
+
auto innerDimsPos = packOp.getInnerDimsPos();
auto mixedInnerTiles = packOp.getMixedTiles();
auto outerDimsPerm = packOp.getOuterDimsPerm();
- auto transposePerm = transposeOp.getPermutation();
+ auto transposePerm = maybePerm.value();
SmallVector<int64_t> newOuterDimsPermVec;
SmallVector<int64_t> newInnerDimsPosVec;
SmallVector<OpFoldResult> newMixedInnerTilesVec;
@@ -285,7 +314,7 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
if (!checkAndPermute(transposePerm, outerDimsPerm, newOuterDimsPermVec,
srcRank))
return rewriter.notifyMatchFailure(
- transposeOp,
+ linalgOp,
"Cannot fold in tensor.pack if a tile dimension was transposed "
"with a non-tile dimension in linalg.transpose.");
@@ -297,11 +326,11 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
}
Value output = packOp.createDestinationTensor(
- rewriter, transposeOp.getLoc(), packOp.getSource(),
- newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec);
+ rewriter, linalgOp.getLoc(), packOp.getSource(), newMixedInnerTilesVec,
+ newInnerDimsPosVec, newOuterDimsPermVec);
rewriter.replaceOpWithNewOp<PackOp>(
- transposeOp, packOp.getSource(), output, newInnerDimsPosVec,
+ linalgOp, packOp.getSource(), output, newInnerDimsPosVec,
newMixedInnerTilesVec, packOp.getPaddingValue(), newOuterDimsPermVec);
return success();
@@ -316,12 +345,16 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
LogicalResult matchAndRewrite(PackOp packOp,
PatternRewriter &rewriter) const override {
- auto transposeOp = packOp.getSource().getDefiningOp<linalg::TransposeOp>();
+ auto linalgOp = packOp.getSource().getDefiningOp<linalg::LinalgOp>();
+ if (!linalgOp)
+ return failure();
- if (!transposeOp)
+ FailureOr<SmallVector<int64_t>> maybePerm =
+ getTransposeOpPermutation(linalgOp);
+ if (failed(maybePerm))
return failure();
- auto transposePermutation = transposeOp.getPermutation();
+ auto transposePermutation = maybePerm.value();
auto outerDimsPerm = packOp.getOuterDimsPerm();
auto innerDimsPos = packOp.getInnerDimsPos();
SmallVector<int64_t> newInnerDimsPosVec;
@@ -337,11 +370,11 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
newInnerDimsPosVec.push_back(transposePermutation[dim]);
Value output = packOp.createDestinationTensor(
- rewriter, packOp.getLoc(), transposeOp.getOperand(0),
+ rewriter, packOp.getLoc(), linalgOp->getOperand(0),
packOp.getMixedTiles(), newInnerDimsPosVec, newOuterDimsPermVec);
rewriter.replaceOpWithNewOp<PackOp>(
- packOp, transposeOp.getOperand(0), output, newInnerDimsPosVec,
+ packOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPermVec);
return success();
@@ -351,34 +384,38 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
/// Fold 'unpack' -> 'transpose' into 'unpack' since 'unpack' already has
/// transpose semantics.
struct FoldProducerUnPackWithConsumerLinalgTransposeOp
- : public OpRewritePattern<linalg::TransposeOp> {
- using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+ : public OpInterfaceRewritePattern<linalg::LinalgOp> {
+ using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
- LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+ LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
PatternRewriter &rewriter) const override {
- auto unPackOp = transposeOp.getOperand(0).getDefiningOp<UnPackOp>();
+ auto unPackOp = linalgOp->getOperand(0).getDefiningOp<UnPackOp>();
if (!unPackOp)
return failure();
- auto transposePermutation = transposeOp.getPermutation();
+ FailureOr<SmallVector<int64_t>> maybePerm =
+ getTransposeOpPermutation(linalgOp);
+ if (failed(maybePerm))
+ return failure();
+
auto outerDimsPerm = unPackOp.getOuterDimsPerm();
auto innerDimsPos = unPackOp.getInnerDimsPos();
SmallVector<int64_t> newInnerDimsPosVec;
SmallVector<int64_t> newOuterDimsPermVec =
- llvm::to_vector(transposePermutation);
-
- if (!outerDimsPerm.empty())
- applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);
+ invertPermutationVector(maybePerm.value());
// Can't use applyPermutationToVector for newInnerDimsPosVec since input and
// permutation rank won't necessarily be equal in all cases.
for (auto dim : innerDimsPos)
- newInnerDimsPosVec.push_back(transposePermutation[dim]);
+ newInnerDimsPosVec.push_back(newOuterDimsPermVec[dim]);
+
+ if (!outerDimsPerm.empty())
+ applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);
// Reuse the destination of the transpose op.
rewriter.replaceOpWithNewOp<UnPackOp>(
- transposeOp, unPackOp.getSource(), transposeOp.getDpsInits()[0],
+ linalgOp, unPackOp.getSource(), linalgOp.getDpsInits()[0],
newInnerDimsPosVec, unPackOp.getMixedTiles(), newOuterDimsPermVec);
return success();
@@ -393,13 +430,17 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
LogicalResult matchAndRewrite(UnPackOp unPackOp,
PatternRewriter &rewriter) const override {
- auto transposeOp =
- unPackOp.getSource().getDefiningOp<linalg::TransposeOp>();
+ auto linalgOp = unPackOp.getSource().getDefiningOp<linalg::LinalgOp>();
+ if (!linalgOp)
+ return failure();
- if (!transposeOp)
+ FailureOr<SmallVector<int64_t>> maybePerm =
+ getTransposeOpPermutation(linalgOp);
+ if (failed(maybePerm))
return failure();
- auto transposePermutation = transposeOp.getPermutation();
+ SmallVector<int64_t> inverseTransposePerm =
+ invertPermutationVector(maybePerm.value());
auto outerDimsPerm = unPackOp.getOuterDimsPerm();
auto innerDimsPos = unPackOp.getInnerDimsPos();
int64_t destRank = unPackOp.getSourceRank() - innerDimsPos.size();
@@ -408,7 +449,7 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
SmallVector<int64_t> newInnerDimsPosVec;
SmallVector<OpFoldResult> newMixedInnerTilesVec;
- if (!checkAndPermute(transposePermutation, outerDimsPerm,
+ if (!checkAndPermute(inverseTransposePerm, outerDimsPerm,
newOuterDimsPermVec, destRank))
return rewriter.notifyMatchFailure(
unPackOp,
@@ -416,18 +457,18 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
"with a non-tile dimension in linalg.transpose.");
// Process transpose operation for tiled inner dimensions
- for (unsigned int i = destRank; i < transposePermutation.size(); ++i) {
- int64_t remappedPosition = transposePermutation[i] - destRank;
+ for (unsigned int i = destRank; i < inverseTransposePerm.size(); ++i) {
+ int64_t remappedPosition = inverseTransposePerm[i] - destRank;
newMixedInnerTilesVec.push_back(mixedInnerTilesVec[remappedPosition]);
newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
}
Value output = unPackOp.createDestinationTensor(
- rewriter, unPackOp.getLoc(), transposeOp.getOperand(0),
+ rewriter, unPackOp.getLoc(), linalgOp->getOperand(0),
newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec);
rewriter.replaceOpWithNewOp<UnPackOp>(
- unPackOp, transposeOp.getOperand(0), output, newInnerDimsPosVec,
+ unPackOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
newMixedInnerTilesVec, newOuterDimsPermVec);
return success();
diff --git a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
index 9a3143f5e550e..629a4c2135720 100644
--- a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
@@ -636,3 +636,142 @@ func.func @tensor_padded_unpack_linalg_transpose_fold(%arg0: tensor<71x7x4x16x16
// CHECK-SAME: into %[[OUT:.+]] : tensor<71x7x4x16x16xf32> -> tensor<100x71x64xf32>
// CHECK: return %[[UNPACK]] : tensor<100x71x64xf32>
// CHECK: }
+
+// -----
+
+func.func @non_involution_transpose_unpack_fold(%arg0: tensor<2x3x5x4x16xi32>) -> tensor<5x48x8xi32> {
+ %0 = tensor.empty() : tensor<5x2x3x16x4xi32>
+ %transposed = linalg.transpose ins(%arg0 : tensor<2x3x5x4x16xi32>)
+ outs(%0 : tensor<5x2x3x16x4xi32>)
+ permutation = [2, 0, 1, 4, 3]
+ %1 = tensor.empty() : tensor<5x48x8xi32>
+ %unpack = tensor.unpack %transposed
+ outer_dims_perm = [0, 2, 1]
+ inner_dims_pos = [1, 2]
+ inner_tiles = [16, 4] into
+ %1 : tensor<5x2x3x16x4xi32> -> tensor<5x48x8xi32>
+ return %unpack : tensor<5x48x8xi32>
+}
+//CHECK-LABEL: func.func @non_involution_transpose_unpack_fold(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x5x4x16xi32>) -> tensor<5x48x8xi32> {
+// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<5x48x8xi32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [2, 1, 0]
+// CHECK-SAME: inner_dims_pos = [2, 1]
+// CHECK-SAME: inner_tiles = [4, 16]
+// CHECK-SAME: into %[[OUT]] : tensor<2x3x5x4x16xi32> -> tensor<5x48x8xi32>
+// CHECK: return %[[UNPACK]] : tensor<5x48x8xi32>
+// CHECK: }
+
+// -----
+
+func.func @unpack_non_involution_transpose_fold(%arg0: tensor<57x3x56x1x64xf32>) -> tensor<3648x3x56xf32> {
+ %0 = tensor.empty() : tensor<3x56x3648xf32>
+ %unpack = tensor.unpack %arg0
+ outer_dims_perm = [2, 0, 1]
+ inner_dims_pos = [1, 2]
+ inner_tiles = [1, 64]
+ into %0 : tensor<57x3x56x1x64xf32> -> tensor<3x56x3648xf32>
+
+ %1 = tensor.empty() : tensor<3648x3x56xf32>
+ %transposed = linalg.transpose
+ ins(%unpack : tensor<3x56x3648xf32>)
+ outs(%1 : tensor<3648x3x56xf32>)
+ permutation = [2, 0, 1]
+ return %transposed : tensor<3648x3x56xf32>
+}
+// CHECK-LABEL: func.func @unpack_non_involution_transpose_fold(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<57x3x56x1x64xf32>) -> tensor<3648x3x56xf32> {
+// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<3648x3x56xf32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 1, 2]
+// CHECK-SAME: inner_dims_pos = [2, 0]
+// CHECK-SAME: inner_tiles = [1, 64]
+// CHECK-SAME: into %[[OUT:.+]] : tensor<57x3x56x1x64xf32> -> tensor<3648x3x56xf32>
+// CHECK: return %[[UNPACK]] : tensor<3648x3x56xf32>
+// CHECK: }
+
+// -----
+
+func.func @transpose_unpacked_dims_no_fold(%arg0: tensor<2x16x5x4x3xi32>) -> tensor<5x32x12xi32> {
+ %0 = tensor.empty() : tensor<5x2x3x16x4xi32>
+ %transposed = linalg.transpose ins(%arg0 : tensor<2x16x5x4x3xi32>)
+ outs(%0 : tensor<5x2x3x16x4xi32>)
+ permutation = [2, 0, 4, 1, 3]
+ %1 = tensor.empty() : tensor<5x32x12xi32>
+ %unpack = tensor.unpack %transposed
+ inner_dims_pos = [1, 2]
+ inner_tiles = [16, 4] into
+ %1 : tensor<5x2x3x16x4xi32> -> tensor<5x32x12xi32>
+ return %unpack : tensor<5x32x12xi32>
+}
+//CHECK-LABEL: func.func @transpose_unpacked_dims_no_fold(
+// CHECK: linalg.transpose
+// CHECK: tensor.unpack
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4)->(d1, d2, d0, d4, d3)>
+#map1 = affine_map<(d0, d1, d2, d3, d4)->(d0, d1, d2, d3, d4)>
+func.func @generic_transpose_unpack_fold(%arg0: tensor<2x3x5x4x16xi32>) -> tensor<5x48x8xi32> {
+ %0 = tensor.empty() : tensor<5x2x3x16x4xi32>
+ %transposed = linalg.generic {
+ iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"],
+ indexing_maps = [#map, #map1]}
+ ins(%arg0 : tensor<2x3x5x4x16xi32>)
+ outs(%0 : tensor<5x2x3x16x4xi32>) {
+ ^bb0(%in : i32, %out : i32):
+ linalg.yield %in : i32
+ } -> tensor<5x2x3x16x4xi32>
+ %1 = tensor.empty() : tensor<5x48x8xi32>
+ %unpack = tensor.unpack %transposed
+ outer_dims_perm = [0, 2, 1]
+ inner_dims_pos = [1, 2]
+ inner_tiles = [16, 4] into
+ %1 : tensor<5x2x3x16x4xi32> -> tensor<5x48x8xi32>
+ return %unpack : tensor<5x48x8xi32>
+}
+//CHECK-LABEL: func.func @generic_transpose_unpack_fold(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x5x4x16xi32>) -> tensor<5x48x8xi32> {
+// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<5x48x8xi32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [2, 1, 0]
+// CHECK-SAME: inner_dims_pos = [2, 1]
+// CHECK-SAME: inner_tiles = [4, 16]
+// CHECK-SAME: into %[[OUT]] : tensor<2x3x5x4x16xi32> -> tensor<5x48x8xi32>
+// CHECK: return %[[UNPACK]] : tensor<5x48x8xi32>
+// CHECK: }
+
+// -----
+
+#map = affine_map<(d0, d1, d2)->(d1, d2, d0)>
+#map1 = affine_map<(d0, d1, d2)->(d0, d1, d2)>
+func.func @unpack_generic_transpose_fold(%arg0: tensor<57x3x56x1x64xf32>) -> tensor<3648x3x56xf32> {
+ %0 = tensor.empty() : tensor<3x56x3648xf32>
+ %unpack = tensor.unpack %arg0
+ outer_dims_perm = [2, 0, 1]
+ inner_dims_pos = [1, 2]
+ inner_tiles = [1, 64]
+ into %0 : tensor<57x3x56x1x64xf32> -> tensor<3x56x3648xf32>
+
+ %1 = tensor.empty() : tensor<3648x3x56xf32>
+ %transposed = linalg.generic {
+ iterator_types = ["parallel", "parallel", "parallel"],
+ indexing_maps = [#map, #map1]}
+ ins(%unpack : tensor<3x56x3648xf32>)
+ outs(%1 : tensor<3648x3x56xf32>) {
+ ^bb0(%in : f32, %out : f32):
+ linalg.yield %in : f32
+ } -> tensor<3648x3x56xf32>
+ return %transposed : tensor<3648x3x56xf32>
+}
+// CHECK-LABEL: func.func @unpack_generic_transpose_fold(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<57x3x56x1x64xf32>) -> tensor<3648x3x56xf32> {
+// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<3648x3x56xf32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 1, 2]
+// CHECK-SAME: inner_dims_pos = [2, 0]
+// CHECK-SAME: inner_tiles = [1, 64]
+// CHECK-SAME: into %[[OUT:.+]] : tensor<57x3x56x1x64xf32> -> tensor<3648x3x56xf32>
+// CHECK: return %[[UNPACK]] : tensor<3648x3x56xf32>
+// CHECK: }
More information about the Mlir-commits
mailing list