[Mlir-commits] [mlir] [mlir] Add pack/unpack transpose foldings for linalg.generic ops, fix bugs (PR #93055)

Han-Chung Wang llvmlistbot at llvm.org
Thu May 30 14:47:21 PDT 2024


================
@@ -349,34 +384,41 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
 /// Fold 'unpack' -> 'transpose' into 'unpack' since 'unpack' already has
 /// transpose semantics.
 struct FoldProducerUnPackWithConsumerLinalgTransposeOp
-    : public OpRewritePattern<linalg::TransposeOp> {
-  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+    : public OpInterfaceRewritePattern<linalg::LinalgOp> {
+  using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
 
-  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+  LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
                                 PatternRewriter &rewriter) const override {
-    auto unPackOp = transposeOp.getOperand(0).getDefiningOp<UnPackOp>();
+    auto unPackOp = linalgOp->getOperand(0).getDefiningOp<UnPackOp>();
 
     if (!unPackOp)
       return failure();
 
-    auto transposePermutation = transposeOp.getPermutation();
+    FailureOr<SmallVector<int64_t>> maybePerm =
+        getTransposeOpPermutation(linalgOp);
+    if (failed(maybePerm)) {
+      return failure();
+    }
+
+    auto transposePermutation = maybePerm.value();
+    SmallVector<int64_t> inverseTransposePerm =
+        invertPermutationVector(transposePermutation);
     auto outerDimsPerm = unPackOp.getOuterDimsPerm();
     auto innerDimsPos = unPackOp.getInnerDimsPos();
     SmallVector<int64_t> newInnerDimsPosVec;
-    SmallVector<int64_t> newOuterDimsPermVec =
-        llvm::to_vector(transposePermutation);
+    SmallVector<int64_t> newOuterDimsPermVec = inverseTransposePerm;
----------------
hanhanW wrote:

I think we can collapse your changes into a single line. The intermediate variables don't really help readability, because the function names already document what each step does.

```cpp
SmallVector<int64_t> newOuterDimsPermVec = invertPermutationVector(maybePerm.value());
```

https://github.com/llvm/llvm-project/pull/93055


More information about the Mlir-commits mailing list