[Mlir-commits] [mlir] [mlir] Add pack/unpack transpose foldings for linalg.generic ops, fix bugs (PR #93055)
Han-Chung Wang
llvmlistbot at llvm.org
Thu May 30 14:47:21 PDT 2024
================
@@ -349,34 +384,41 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
/// Fold 'unpack' -> 'transpose' into 'unpack' since 'unpack' already has
/// transpose semantics.
struct FoldProducerUnPackWithConsumerLinalgTransposeOp
- : public OpRewritePattern<linalg::TransposeOp> {
- using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+ : public OpInterfaceRewritePattern<linalg::LinalgOp> {
+ using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
- LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+ LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
PatternRewriter &rewriter) const override {
- auto unPackOp = transposeOp.getOperand(0).getDefiningOp<UnPackOp>();
+ auto unPackOp = linalgOp->getOperand(0).getDefiningOp<UnPackOp>();
if (!unPackOp)
return failure();
- auto transposePermutation = transposeOp.getPermutation();
+ FailureOr<SmallVector<int64_t>> maybePerm =
+ getTransposeOpPermutation(linalgOp);
+ if (failed(maybePerm)) {
+ return failure();
+ }
+
+ auto transposePermutation = maybePerm.value();
+ SmallVector<int64_t> inverseTransposePerm =
+ invertPermutationVector(transposePermutation);
auto outerDimsPerm = unPackOp.getOuterDimsPerm();
auto innerDimsPos = unPackOp.getInnerDimsPos();
SmallVector<int64_t> newInnerDimsPosVec;
- SmallVector<int64_t> newOuterDimsPermVec =
- llvm::to_vector(transposePermutation);
+ SmallVector<int64_t> newOuterDimsPermVec = inverseTransposePerm;
----------------
hanhanW wrote:
I think we can collapse your changes into a single line. The intermediate variables do not really help readability, because the function names already document what each step does.
```cpp
SmallVector<int64_t> newOuterDimsPermVec = invertPermutationVector(maybePerm.value());
```
https://github.com/llvm/llvm-project/pull/93055
More information about the Mlir-commits mailing list