[Mlir-commits] [mlir] [mlir][tensor] Fold producer linalg transpose with consumer tensor pack (PR #75658)

lorenzo chelini llvmlistbot at llvm.org
Tue Jan 2 09:16:50 PST 2024


================
@@ -142,11 +143,52 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
     return success();
   }
 };
+
+/// Fold 'transpose' -> 'pack' into 'pack' since 'pack' already has transpose
+/// semantics.
+struct FoldConsumerPackWithProducerLinalgTransposeOp
+    : public OpRewritePattern<PackOp> {
+  using OpRewritePattern<PackOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PackOp packOp,
+                                PatternRewriter &rewriter) const override {
+    auto transposeOp = packOp.getSource().getDefiningOp<linalg::TransposeOp>();
+
+    if (!transposeOp)
+      return failure();
+
+    auto transposePermutation = transposeOp.getPermutation();
+    auto outerDimsPerm = packOp.getOuterDimsPerm();
+    auto innerDimsPos = packOp.getInnerDimsPos();
+    SmallVector<int64_t> newInnerDimsPosVec;
+    SmallVector<int64_t> newOuterDimsPermVec = to_vector(transposePermutation);
+
+    if (!outerDimsPerm.empty())
+      applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);
+
+    for (auto dim : innerDimsPos) {
+      newInnerDimsPosVec.push_back(std::find(transposePermutation.begin(),
----------------
chelini wrote:

You can use `llvm::find`, which takes the range directly, so there is no need to pass the begin and end iterators. Something like:
`newInnerDimsPosVec.push_back(llvm::find(transposePermutation, dim) ...`.

https://github.com/llvm/llvm-project/pull/75658


More information about the Mlir-commits mailing list