[Mlir-commits] [mlir] [mlir][tensor] Fold producer linalg transpose with consumer tensor pack (PR #75658)

lorenzo chelini llvmlistbot at llvm.org
Mon Dec 18 01:10:04 PST 2023


================
@@ -96,39 +147,19 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
     if (!packOp)
       return failure();
 
-    auto innerDimsPos = packOp.getInnerDimsPos();
-    auto mixedInnerTiles = packOp.getMixedTiles();
-    auto outerDimsPerm = packOp.getOuterDimsPerm();
-    auto transposePerm = transposeOp.getPermutation();
     SmallVector<int64_t> newOuterDimsPermVec;
     SmallVector<int64_t> newInnerDimsPosVec;
     SmallVector<OpFoldResult> newMixedInnerTilesVec;
-    int64_t srcRank = packOp.getSourceRank();
-
-    // Process transpose operation for non-tiled outer dimensions
-    for (unsigned int i = 0; i < srcRank; ++i) {
-      int64_t remappedPosition = transposePerm[i];
-
-      // If tensor.pack has outer_dims_perm attribute, then consider it during
-      // index remapping.
-      if (!outerDimsPerm.empty()) {
-        if (transposePerm[i] >= srcRank) {
-          return rewriter.notifyMatchFailure(
-              transposeOp,
-              "Cannot fold in tensor.pack if a tile dimension was transposed "
-              "with a non-tile dimension in linalg.transpose.");
-        }
-        remappedPosition = outerDimsPerm[remappedPosition];
-      }
-
-      newOuterDimsPermVec.push_back(remappedPosition);
-    }
 
-    // Process transpose operation for tiled inner dimensions
-    for (unsigned int i = srcRank; i < transposePerm.size(); ++i) {
-      int64_t remappedPosition = transposePerm[i] - srcRank;
-      newMixedInnerTilesVec.push_back(mixedInnerTiles[remappedPosition]);
-      newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
+    bool foldingPossible = getRemappedPermutationForTransposeAndPack(
+        packOp, transposeOp, newOuterDimsPermVec, newInnerDimsPosVec,
+        newMixedInnerTilesVec, /*isTransposeProducer*/ false);
----------------
chelini wrote:

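nit: /*isTransposeProducer=*/ (i.e. add the trailing `=` to the argument comment, per the LLVM `/*ParamName=*/value` convention that clang-tidy's bugprone-argument-comment check recognizes)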

https://github.com/llvm/llvm-project/pull/75658


More information about the Mlir-commits mailing list