[mlir] [llvm] [mlir][tensor] Fold consumer linalg transpose with producer tensor pack (PR #74206)
Han-Chung Wang via llvm-commits
llvm-commits at lists.llvm.org
Wed Dec 13 12:03:43 PST 2023
================
@@ -81,10 +82,83 @@ struct FoldUnpackWithExtractSliceOp : public OpRewritePattern<ExtractSliceOp> {
     return success();
   }
 };
+
+/// Fold 'pack' -> 'transpose' into 'pack' since 'pack' already has transpose
+/// semantics.
+struct FoldProducerPackWithConsumerLinalgTransposeOp
+    : public OpRewritePattern<linalg::TransposeOp> {
+  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    auto packOp = transposeOp.getOperand(0).getDefiningOp<PackOp>();
+
+    if (!packOp)
+      return failure();
+
+    auto packInnerDimsPos = packOp.getInnerDimsPos();
+    auto packMixedInnerTiles = packOp.getMixedTiles();
+    auto packOuterDimsPerm = packOp.getOuterDimsPerm();
+    auto transposePerm = transposeOp.getPermutation();
+    SmallVector<int64_t> newPackOuterDimsPermVec;
+    SmallVector<int64_t> newPackInnerDimsPosVec;
+    SmallVector<OpFoldResult> newPackMixedInnerTilesVec;
+
+    // Variable for storing remapped position after considering original
+    // outer_dims_perm and permutation attributes of tensor.pack and
+    // linalg.transpose.
+    int64_t remappedPosition;
+    int64_t finalOuterDimsSize =
+        transposePerm.size() - packMixedInnerTiles.size();
----------------
hanhanW wrote:
`transposePerm.size() == destRank` and `innerTiles.size() + srcRank == destRank`, so `finalOuterDimsSize` is always equal to `srcRank`. I think we don't need the variable; we can use `srcRank` directly. Am I misunderstanding something?
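The rank argument can be checked with a small, self-contained sketch. It is not the PR's implementation. It models `tensor.pack` attributes with a hypothetical `PackAttrs` struct, which uses plain integer tile sizes instead of `OpFoldResult`. The hypothetical helpers `isPermutation`, `packDestRank` and `foldTransposeIntoPack` fold a transpose permutation into those attributes.

The sketch assumes the usual pack result layout: the first `srcRank` dims are the (optionally permuted) outer dims, and the trailing `innerTiles.size()` dims are the tiles. Because `destRank = srcRank + innerTiles.size()` and the permutation has `destRank` entries, `transposePerm.size() - innerTiles.size()` reduces to `srcRank`.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical, simplified stand-in for tensor.pack's attributes. Tile sizes
// are plain integers here; the real op also allows dynamic tiles.
struct PackAttrs {
  int64_t srcRank;                    // rank of the pack source
  std::vector<int64_t> innerDimsPos;  // source dims that get tiled
  std::vector<int64_t> innerTiles;    // one tile size per innerDimsPos entry
  std::vector<int64_t> outerDimsPerm; // empty means identity
};

// True if perm contains each of 0..perm.size()-1 exactly once.
bool isPermutation(const std::vector<int64_t> &perm) {
  std::vector<bool> seen(perm.size(), false);
  for (int64_t v : perm) {
    if (v < 0 || v >= static_cast<int64_t>(perm.size()) || seen[v])
      return false;
    seen[v] = true;
  }
  return true;
}

// Packed result rank: one outer dim per source dim plus one dim per tile.
int64_t packDestRank(const PackAttrs &p) {
  return p.srcRank + static_cast<int64_t>(p.innerTiles.size());
}

// Sketch: fold transpose(perm) applied to the pack result into new pack
// attributes. Returns std::nullopt if perm is not a valid permutation of the
// pack result or if it moves a dim across the outer/inner boundary, which a
// single pack cannot express.
std::optional<PackAttrs> foldTransposeIntoPack(const PackAttrs &pack,
                                               const std::vector<int64_t> &perm) {
  assert(pack.innerDimsPos.size() == pack.innerTiles.size());
  const int64_t destRank = packDestRank(pack);
  if (static_cast<int64_t>(perm.size()) != destRank || !isPermutation(perm))
    return std::nullopt;
  // perm.size() - innerTiles.size() == srcRank, so no extra variable needed.
  const int64_t srcRank = pack.srcRank;

  PackAttrs result{srcRank, {}, {}, {}};
  // Outer dims: result dim i reads pack-result dim perm[i], which must also
  // be an outer dim; compose it with the existing outer_dims_perm.
  for (int64_t i = 0; i < srcRank; ++i) {
    int64_t pos = perm[i];
    if (pos >= srcRank)
      return std::nullopt;
    if (!pack.outerDimsPerm.empty())
      pos = pack.outerDimsPerm[pos];
    result.outerDimsPerm.push_back(pos);
  }
  // Inner dims: since perm is a permutation and the outer block maps to
  // itself, perm[i] >= srcRank here and perm[i] - srcRank indexes the tiles.
  for (int64_t i = srcRank; i < destRank; ++i) {
    int64_t pos = perm[i] - srcRank;
    result.innerDimsPos.push_back(pack.innerDimsPos[pos]);
    result.innerTiles.push_back(pack.innerTiles[pos]);
  }
  return result;
}
```

Consider a 2-D source packed with `inner_dims_pos = [1]` and `inner_tiles = [8]`. The permutation `[1, 0, 2]` folds into `outer_dims_perm = [1, 0]` and leaves the inner attributes unchanged. The permutation `[2, 0, 1]` is rejected, because it would move the tile dim into the outer block.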
https://github.com/llvm/llvm-project/pull/74206