[Mlir-commits] [mlir] [mlir][tensor] Fold producer linalg transpose with consumer unpack an… (PR #86795)

Han-Chung Wang llvmlistbot at llvm.org
Thu Mar 28 09:46:52 PDT 2024


================
@@ -323,12 +340,106 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
     return success();
   }
 };
+
+/// Fold 'unpack' -> 'transpose' into 'unpack' since 'unpack' already has
+/// transpose semantics.
+///
+/// `linalg.transpose` defines result dim `i` as input dim `permutation[i]`, so
+/// dim `d` of the unpacked tensor ends up at position `inversePerm[d]` of the
+/// transposed result. The folded unpack therefore has to remap both
+/// `outer_dims_perm` and `inner_dims_pos` through the *inverse* of the
+/// transpose permutation; using the permutation itself is only correct when it
+/// is self-inverse (e.g. a single swap of two dims).
+struct FoldProducerUnPackWithConsumerLinalgTransposeOp
+    : public OpRewritePattern<linalg::TransposeOp> {
+  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    auto unPackOp = transposeOp.getOperand(0).getDefiningOp<UnPackOp>();
+
+    if (!unPackOp)
+      return failure();
+
+    // Old unpacked dim `d` becomes dim `inverseTransposePerm[d]` of the
+    // transposed result.
+    SmallVector<int64_t> inverseTransposePerm =
+        invertPermutationVector(transposeOp.getPermutation());
+    auto outerDimsPerm = unPackOp.getOuterDimsPerm();
+    auto innerDimsPos = unPackOp.getInnerDimsPos();
+    SmallVector<int64_t> newInnerDimsPosVec;
+    // An empty `outer_dims_perm` means identity, in which case the new outer
+    // permutation is exactly the inverse transpose permutation.
+    SmallVector<int64_t> newOuterDimsPermVec = inverseTransposePerm;
+
+    // Source outer dim `j` maps to old dest dim `outerDimsPerm[j]`, i.e. to new
+    // dest dim `inverseTransposePerm[outerDimsPerm[j]]`.
+    if (!outerDimsPerm.empty())
+      applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);
+
+    // Can't use applyPermutationToVector for newInnerDimsPosVec since input and
+    // permutation rank won't necessarily be equal in all cases.
+    newInnerDimsPosVec.reserve(innerDimsPos.size());
+    for (auto dim : innerDimsPos)
+      newInnerDimsPosVec.push_back(inverseTransposePerm[dim]);
+
+    Value output = unPackOp.createDestinationTensor(
+        rewriter, transposeOp.getLoc(), unPackOp.getSource(),
+        unPackOp.getMixedTiles(), newInnerDimsPosVec, newOuterDimsPermVec);
+
+    rewriter.replaceOpWithNewOp<UnPackOp>(
+        transposeOp, unPackOp.getSource(), output, newInnerDimsPosVec,
+        unPackOp.getMixedTiles(), newOuterDimsPermVec);
+
+    return success();
+  }
+};
+
+/// Fold 'transpose' -> 'unpack' into 'unpack' since 'unpack' already has
+/// transpose semantics.
+struct FoldConsumerUnPackWithProducerLinalgTransposeOp
+    : public OpRewritePattern<UnPackOp> {
+  using OpRewritePattern<UnPackOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(UnPackOp unPackOp,
+                                PatternRewriter &rewriter) const override {
+    auto transposeOp =
+        unPackOp.getSource().getDefiningOp<linalg::TransposeOp>();
+
+    if (!transposeOp)
+      return failure();
+
+    auto transposePermutation = transposeOp.getPermutation();
+    auto outerDimsPerm = unPackOp.getOuterDimsPerm();
+    auto innerDimsPos = unPackOp.getInnerDimsPos();
+    int64_t destRank = unPackOp.getSourceRank() - innerDimsPos.size();
+    auto mixedInnerTilesVec = unPackOp.getMixedTiles();
+    SmallVector<int64_t> newOuterDimsPermVec;
+    SmallVector<int64_t> newInnerDimsPosVec;
+    SmallVector<OpFoldResult> newMixedInnerTilesVec;
+
+    // Check whether there is no transpose from the outer dimension to inner
+    // tile dimension. For e.g., 4d tensor with permutation {0,2,1,3} is not
+    // folded for `destRank` 2.
----------------
hanhanW wrote:

I think this comment is outdated. We can drop it because it is also documented in the `checkAndPermute` method.

https://github.com/llvm/llvm-project/pull/86795


More information about the Mlir-commits mailing list