[mlir] [llvm] [mlir][tensor] Fold consumer linalg transpose with producer tensor pack (PR #74206)

Han-Chung Wang via llvm-commits llvm-commits at lists.llvm.org
Mon Dec 11 11:24:50 PST 2023


================
@@ -81,10 +82,91 @@ struct FoldUnpackWithExtractSliceOp : public OpRewritePattern<ExtractSliceOp> {
     return success();
   }
 };
+
+/// Fold 'pack' -> 'transpose' into 'pack' since 'pack' already has transpose
+/// semantics.
+struct FoldProducerPackWithConsumerLinalgTransposeOp
+    : public OpRewritePattern<linalg::TransposeOp> {
+  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    auto transposeInputTensor = transposeOp.getOperand(0);
+    auto packOp = transposeInputTensor.getDefiningOp<PackOp>();
+
+    if (!packOp)
+      return failure();
+
+    auto packInnerDimsPos = packOp.getInnerDimsPos();
+    auto packInnerTiles = packOp.getStaticInnerTiles();
+    auto packOuterDimsPerm = packOp.getOuterDimsPerm();
+    auto transposePerm = transposeOp.getPermutation();
+    SmallVector<int64_t> newPackOuterDimsPermVec;
+    SmallVector<int64_t> newPackInnerDimsPosVec;
+    SmallVector<int64_t> newPackInnerTilesVec;
+
+    // Variable for storing remapped position after considering original
+    // outer_dims_perm and permutation attributes of tensor.pack and
+    // linalg.transpose.
+    int64_t remappedPosition;
+    int64_t finalOuterDimsSize = transposePerm.size() - packInnerTiles.size();
+
+    // Process transpose operation for non-tiled outer dimensions
+    for (unsigned int i = 0; i < finalOuterDimsSize; ++i) {
+      // If tensor.pack has outer_dims_perm attribute, then consider it during
+      // index translation.
+      if (!packOuterDimsPerm.empty()) {
+        // Note: static_cast is added around transposePerm[i] to suppress the
+        // compiler warning of comparison between variables of different types.
+        if (static_cast<unsigned long>(transposePerm[i]) <
+            packOuterDimsPerm.size()) {
+          remappedPosition = packOuterDimsPerm[transposePerm[i]];
+        } else {
+          return rewriter.notifyMatchFailure(
+              transposeOp,
+              "Cannot fold in tensor.pack if a tile dimension was transposed "
+              "with a non-tile dimension in linalg.transpose.");
+        }
+      } else {
+        remappedPosition = transposePerm[i];
+      }
+
+      newPackOuterDimsPermVec.push_back(remappedPosition);
+    }
+
+    // Process transpose operation for tiled inner dimensions
+    for (unsigned int i = finalOuterDimsSize; i < transposePerm.size(); ++i) {
+      remappedPosition = transposePerm[i] - finalOuterDimsSize;
+
+      newPackInnerTilesVec.push_back(packInnerTiles[remappedPosition]);
+      newPackInnerDimsPosVec.push_back(packInnerDimsPos[remappedPosition]);
+    }
+
+    SmallVector<OpFoldResult> opFoldResultsTiles;
+    opFoldResultsTiles.reserve(newPackInnerTilesVec.size());
+
+    transform(newPackInnerTilesVec, std::back_inserter(opFoldResultsTiles),
+              [&rewriter](int64_t value) {
+                return IntegerAttr::get(IndexType::get(rewriter.getContext()),
+                                        value);
+              });
----------------
hanhanW wrote:

If you get the innerTiles through `getMixedTiles`, they will already be of `OpFoldResult` type, so you won't need this conversion.

The other point is that using `transform` is not common in the MLIR codebase, based on my review experience. Also, this can be simplified by using the `getAsIndexOpFoldResult` helper.

https://github.com/llvm/llvm-project/pull/74206


More information about the llvm-commits mailing list