[Mlir-commits] [mlir] [llvm] [mlir][tensor] Fold consumer linalg transpose with producer tensor pack (PR #74206)

Prathamesh Tagore llvmlistbot at llvm.org
Wed Dec 13 11:51:15 PST 2023


================
@@ -81,10 +82,83 @@ struct FoldUnpackWithExtractSliceOp : public OpRewritePattern<ExtractSliceOp> {
     return success();
   }
 };
+
+/// Fold 'pack' -> 'transpose' into 'pack' since 'pack' already has transpose
+/// semantics.
+struct FoldProducerPackWithConsumerLinalgTransposeOp
+    : public OpRewritePattern<linalg::TransposeOp> {
+  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    auto packOp = transposeOp.getOperand(0).getDefiningOp<PackOp>();
+
+    if (!packOp)
+      return failure();
+
+    auto packInnerDimsPos = packOp.getInnerDimsPos();
+    auto packMixedInnerTiles = packOp.getMixedTiles();
+    auto packOuterDimsPerm = packOp.getOuterDimsPerm();
+    auto transposePerm = transposeOp.getPermutation();
+    SmallVector<int64_t> newPackOuterDimsPermVec;
+    SmallVector<int64_t> newPackInnerDimsPosVec;
+    SmallVector<OpFoldResult> newPackMixedInnerTilesVec;
+
+    // Variable for storing remapped position after considering original
+    // outer_dims_perm and permutation attributes of tensor.pack and
+    // linalg.transpose.
+    int64_t remappedPosition;
+    int64_t finalOuterDimsSize =
+        transposePerm.size() - packMixedInnerTiles.size();
----------------
meshtag wrote:

`srcRank` gives the rank of the source tensor of the original `packOp`. However, we need the rank of the new `packOp` that will be generated by folding the old `packOp` and the `transpose`; that rank is then used to remap the transpose permutation.

https://github.com/llvm/llvm-project/pull/74206


More information about the Mlir-commits mailing list