[Mlir-commits] [mlir] [mlir][tensor] Fold producer linalg transpose with consumer unpack an… (PR #86795)
Han-Chung Wang
llvmlistbot at llvm.org
Wed Mar 27 20:58:35 PDT 2024
================
@@ -323,12 +322,118 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
return success();
}
};
+
+/// Fold 'unpack' -> 'transpose' into 'unpack' since 'unpack' already has
+/// transpose semantics.
+struct FoldProducerUnPackWithConsumerLinalgTransposeOp
+ : public OpRewritePattern<linalg::TransposeOp> {
+ using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+ PatternRewriter &rewriter) const override {
+ auto unPackOp = transposeOp.getOperand(0).getDefiningOp<UnPackOp>();
+
+ if (!unPackOp)
+ return failure();
+
+ auto transposePermutation = transposeOp.getPermutation();
+ auto outerDimsPerm = unPackOp.getOuterDimsPerm();
+ auto innerDimsPos = unPackOp.getInnerDimsPos();
+ SmallVector<int64_t> newInnerDimsPosVec;
+ SmallVector<int64_t> newOuterDimsPermVec =
+ llvm::to_vector(transposePermutation);
+
+ if (!outerDimsPerm.empty())
+ applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);
+
+ // Can't use applyPermutationToVector for newInnerDimsPosVec since input and
+ // permutation rank won't necessarily be equal in all cases.
+ for (auto dim : innerDimsPos)
+ newInnerDimsPosVec.push_back(transposePermutation[dim]);
----------------
hanhanW wrote:
This and the code above are similar to what we have in the FoldConsumerPackWithProducerLinalgTransposeOp pattern. I wonder if we can use some template trick so that both patterns can reuse the same method.
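Below is a minimal, hypothetical sketch of the shared helper suggested above. It uses plain `std::vector` instead of MLIR's `ArrayRef`/`SmallVector` so the index arithmetic can run on its own. The names `Perm`, `PackMetadata`, `invert`, and `foldTransposeIntoPackMetadata` are illustrative only and are not existing MLIR APIs.

The sketch assumes the documented `linalg.transpose` semantics, `dim(result, i) = dim(input, permutation[i])`. Under that reading, the two folds rename dimensions differently:

- Folding a producer transpose into a pack renames dimensions through the permutation itself.
- Folding a consumer transpose into an unpack renames them through its inverse. MLIR provides `invertPermutationVector` for this.

The template parameter selects which mapping to apply.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for SmallVector<int64_t> / ArrayRef<int64_t>.
using Perm = std::vector<int64_t>;

// Hypothetical bundle of the pack/unpack attributes that the folds rewrite.
struct PackMetadata {
  Perm innerDimsPos;  // dimension tiled by each inner tile
  Perm outerDimsPerm; // empty means identity order
};

// Returns `inverse` such that inverse[perm[i]] == i
// (same contract as MLIR's invertPermutationVector).
Perm invert(const Perm &perm) {
  Perm inverse(perm.size());
  for (std::size_t i = 0; i < perm.size(); ++i)
    inverse[perm[i]] = static_cast<int64_t>(i);
  return inverse;
}

// Rewrites pack/unpack metadata so that the transpose can be dropped.
//   UseInverse = false: pack(transpose(x, perm)), dims renamed via perm.
//   UseInverse = true:  transpose(unpack(x), perm), dims renamed via the
//                       inverse of perm.
template <bool UseInverse>
PackMetadata foldTransposeIntoPackMetadata(const Perm &transposePerm,
                                           const PackMetadata &old) {
  const Perm map = UseInverse ? invert(transposePerm) : transposePerm;
  PackMetadata folded;
  // Each tiled dimension keeps its tile but gets the renamed index.
  for (int64_t dim : old.innerDimsPos)
    folded.innerDimsPos.push_back(map[dim]);
  // folded[i] = map[outerDimsPerm[i]], which is what
  // applyPermutationToVector(map, outerDimsPerm) computes; an empty
  // outerDimsPerm is the identity, so the result is `map` itself.
  if (old.outerDimsPerm.empty()) {
    folded.outerDimsPerm = map;
  } else {
    for (int64_t dim : old.outerDimsPerm)
      folded.outerDimsPerm.push_back(map[dim]);
  }
  return folded;
}
```

Take `permutation = [1, 2, 0]` and `inner_dims_pos = [0]` as an example. The pack fold yields `inner_dims_pos = [1]`, while the unpack fold yields `[2]`. The two results only coincide when the permutation is its own inverse, such as `[1, 0]`.

In the real patterns, the outer-dims step could keep using `applyPermutationToVector` on the (possibly inverted) permutation.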
https://github.com/llvm/llvm-project/pull/86795