[Mlir-commits] [mlir] [mlir][tensor] Fold producer linalg transpose with consumer unpack an… (PR #86795)
Han-Chung Wang
llvmlistbot at llvm.org
Thu Mar 28 09:46:52 PDT 2024
================
@@ -323,12 +340,106 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
return success();
}
};
+
+/// Fold 'unpack' -> 'transpose' into 'unpack' since 'unpack' already has
+/// transpose semantics.
+struct FoldProducerUnPackWithConsumerLinalgTransposeOp
+ : public OpRewritePattern<linalg::TransposeOp> {
+ using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+ PatternRewriter &rewriter) const override {
+ auto unPackOp = transposeOp.getOperand(0).getDefiningOp<UnPackOp>();
+
+ if (!unPackOp)
+ return failure();
+
+ auto transposePermutation = transposeOp.getPermutation();
+ auto outerDimsPerm = unPackOp.getOuterDimsPerm();
+ auto innerDimsPos = unPackOp.getInnerDimsPos();
+ SmallVector<int64_t> newInnerDimsPosVec;
+ SmallVector<int64_t> newOuterDimsPermVec =
+ llvm::to_vector(transposePermutation);
+
+ if (!outerDimsPerm.empty())
+ applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);
+
+ // Can't use applyPermutationToVector for newInnerDimsPosVec since input and
+ // permutation rank won't necessarily be equal in all cases.
+ for (auto dim : innerDimsPos)
+ newInnerDimsPosVec.push_back(transposePermutation[dim]);
+
+ Value output = unPackOp.createDestinationTensor(
+ rewriter, transposeOp.getLoc(), unPackOp.getSource(),
+ unPackOp.getMixedTiles(), newInnerDimsPosVec, newOuterDimsPermVec);
+
+ rewriter.replaceOpWithNewOp<UnPackOp>(
+ transposeOp, unPackOp.getSource(), output, newInnerDimsPosVec,
+ unPackOp.getMixedTiles(), newOuterDimsPermVec);
+
+ return success();
+ }
+};
+
+/// Fold 'transpose' -> 'unpack' into 'unpack' since 'unpack' already has
+/// transpose semantics.
+struct FoldConsumerUnPackWithProducerLinalgTransposeOp
+ : public OpRewritePattern<UnPackOp> {
+ using OpRewritePattern<UnPackOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(UnPackOp unPackOp,
+ PatternRewriter &rewriter) const override {
+ auto transposeOp =
+ unPackOp.getSource().getDefiningOp<linalg::TransposeOp>();
+
+ if (!transposeOp)
+ return failure();
+
+ auto transposePermutation = transposeOp.getPermutation();
+ auto outerDimsPerm = unPackOp.getOuterDimsPerm();
+ auto innerDimsPos = unPackOp.getInnerDimsPos();
+ int64_t destRank = unPackOp.getSourceRank() - innerDimsPos.size();
+ auto mixedInnerTilesVec = unPackOp.getMixedTiles();
+ SmallVector<int64_t> newOuterDimsPermVec;
+ SmallVector<int64_t> newInnerDimsPosVec;
+ SmallVector<OpFoldResult> newMixedInnerTilesVec;
+
+ // Check whether there is no transpose from the outer dimension to inner
+ // tile dimension. For e.g., 4d tensor with permutation {0,2,1,3} is not
+ // folded for `destRank` 2.
----------------
hanhanW wrote:
I think this comment is outdated. We can drop it because it is also documented in the `checkAndPermute` method.
https://github.com/llvm/llvm-project/pull/86795
More information about the Mlir-commits
mailing list