[Mlir-commits] [mlir] [mlir][tensor] Fold producer linalg transpose with consumer unpack an… (PR #86795)
Prashant Kumar
llvmlistbot at llvm.org
Thu Mar 28 00:00:18 PDT 2024
================
@@ -323,12 +322,118 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
     return success();
   }
 };
+
+/// Fold 'unpack' -> 'transpose' into 'unpack' since 'unpack' already has
+/// transpose semantics.
+struct FoldProducerUnPackWithConsumerLinalgTransposeOp
+    : public OpRewritePattern<linalg::TransposeOp> {
+  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    auto unPackOp = transposeOp.getOperand(0).getDefiningOp<UnPackOp>();
+
+    if (!unPackOp)
+      return failure();
+
+    auto transposePermutation = transposeOp.getPermutation();
+    auto outerDimsPerm = unPackOp.getOuterDimsPerm();
+    auto innerDimsPos = unPackOp.getInnerDimsPos();
+    SmallVector<int64_t> newInnerDimsPosVec;
+    SmallVector<int64_t> newOuterDimsPermVec =
+        llvm::to_vector(transposePermutation);
+
+    if (!outerDimsPerm.empty())
+      applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);
+
+    // Can't use applyPermutationToVector for newInnerDimsPosVec since input and
+    // permutation rank won't necessarily be equal in all cases.
+    for (auto dim : innerDimsPos)
+      newInnerDimsPosVec.push_back(transposePermutation[dim]);
----------------
pashu123 wrote:
The problem seems to be that the two patterns match on different root ops (transpose vs. unpack), so merging them into one pattern would require extra control flow. However, I can move the common part into a single function that both patterns share.
https://github.com/llvm/llvm-project/pull/86795
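The reply proposes factoring the shared logic into one function that both patterns call. Below is a minimal sketch of what such a helper might look like, under stated assumptions: it uses plain `std::vector<int64_t>` instead of MLIR's `SmallVector`/`ArrayRef`; the names `computeFoldedPerms`, `FoldedPerms`, and `applyPerm` are hypothetical (not part of MLIR or this PR); and `applyPerm` assumes the `result[i] = input[perm[i]]` convention of MLIR's `applyPermutationToVector`. It only reproduces the index arithmetic of the unpack pattern quoted in the hunk; whether the pack-side pattern could reuse it unchanged is an assumption not shown here.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for applyPermutationToVector:
// result[i] = vec[perm[i]]. Like the MLIR utility, it requires equal ranks.
std::vector<int64_t> applyPerm(const std::vector<int64_t> &vec,
                               const std::vector<int64_t> &perm) {
  assert(vec.size() == perm.size() &&
         "expected input rank to equal permutation rank");
  std::vector<int64_t> result(perm.size());
  for (std::size_t i = 0; i < perm.size(); ++i)
    result[i] = vec[perm[i]];
  return result;
}

// Hypothetical result type holding the new unpack attributes.
struct FoldedPerms {
  std::vector<int64_t> outerDimsPerm;
  std::vector<int64_t> innerDimsPos;
};

// Hypothetical shared helper: mirrors the arithmetic in the diff.
// It takes no op types, so each pattern keeps its own matching logic.
FoldedPerms computeFoldedPerms(const std::vector<int64_t> &transposePerm,
                               const std::vector<int64_t> &outerDimsPerm,
                               const std::vector<int64_t> &innerDimsPos) {
  FoldedPerms out;
  // Start from the transpose permutation; compose with outer_dims_perm if set.
  out.outerDimsPerm = transposePerm;
  if (!outerDimsPerm.empty())
    out.outerDimsPerm = applyPerm(transposePerm, outerDimsPerm);
  // inner_dims_pos can be shorter than the permutation, so it is mapped
  // element-wise rather than permuted as a whole.
  for (int64_t dim : innerDimsPos)
    out.innerDimsPos.push_back(transposePerm[dim]);
  return out;
}
```

For example, with a transpose permutation of `[1, 0]`, no `outer_dims_perm`, and `inner_dims_pos = [0]`, the helper returns an outer permutation of `[1, 0]` and inner positions of `[1]`. Keeping the helper free of op types means the op-specific matching (the control flow the reply mentions) stays in each pattern's `matchAndRewrite`, and only the permutation bookkeeping is shared.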