[Mlir-commits] [mlir] [mlir][tensor] Fold producer linalg transpose with consumer unpack an… (PR #86795)

Prashant Kumar llvmlistbot at llvm.org
Thu Mar 28 03:32:01 PDT 2024


================
@@ -323,12 +322,118 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
     return success();
   }
 };
+
+/// Fold 'unpack' -> 'transpose' into 'unpack' since 'unpack' already has
+/// transpose semantics.
+struct FoldProducerUnPackWithConsumerLinalgTransposeOp
+    : public OpRewritePattern<linalg::TransposeOp> {
+  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    auto unPackOp = transposeOp.getOperand(0).getDefiningOp<UnPackOp>();
+
+    if (!unPackOp)
+      return failure();
+
+    auto transposePermutation = transposeOp.getPermutation();
+    auto outerDimsPerm = unPackOp.getOuterDimsPerm();
+    auto innerDimsPos = unPackOp.getInnerDimsPos();
+    SmallVector<int64_t> newInnerDimsPosVec;
+    SmallVector<int64_t> newOuterDimsPermVec =
+        llvm::to_vector(transposePermutation);
+
+    if (!outerDimsPerm.empty())
+      applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);
+
+    // Can't use applyPermutationToVector for newInnerDimsPosVec since input and
+    // permutation rank won't necessarily be equal in all cases.
+    for (auto dim : innerDimsPos)
+      newInnerDimsPosVec.push_back(transposePermutation[dim]);
----------------
pashu123 wrote:

I removed the function that checked for and applied a non-overlapping permutation. Other than that, I don't think there's much left to change.

https://github.com/llvm/llvm-project/pull/86795


More information about the Mlir-commits mailing list