[mlir] [llvm] [mlir][tensor] Fold consumer linalg transpose with producer tensor pack (PR #74206)

Han-Chung Wang via llvm-commits llvm-commits at lists.llvm.org
Wed Dec 13 13:34:17 PST 2023


================
@@ -81,10 +82,77 @@ struct FoldUnpackWithExtractSliceOp : public OpRewritePattern<ExtractSliceOp> {
     return success();
   }
 };
+
+/// Fold 'pack' -> 'transpose' into 'pack' since 'pack' already has transpose
+/// semantics.
+struct FoldProducerPackWithConsumerLinalgTransposeOp
+    : public OpRewritePattern<linalg::TransposeOp> {
+  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    auto packOp = transposeOp.getOperand(0).getDefiningOp<PackOp>();
+
+    if (!packOp)
+      return failure();
+
+    auto innerDimsPos = packOp.getInnerDimsPos();
+    auto mixedInnerTiles = packOp.getMixedTiles();
+    auto outerDimsPerm = packOp.getOuterDimsPerm();
+    auto transposePerm = transposeOp.getPermutation();
+    SmallVector<int64_t> newOuterDimsPermVec;
+    SmallVector<int64_t> newInnerDimsPosVec;
+    SmallVector<OpFoldResult> newMixedInnerTilesVec;
+    int64_t srcRank = packOp.getSourceRank();
+
+    // Process transpose operation for non-tiled outer dimensions
+    for (unsigned int i = 0; i < srcRank; ++i) {
+      // Variable for storing remapped position after considering original
+      // outer_dims_perm and permutation attributes of tensor.pack and
+      // linalg.transpose.
+      int64_t remappedPosition;
+
+      // If tensor.pack has outer_dims_perm attribute, then consider it during
+      // index remapping.
+      if (!outerDimsPerm.empty()) {
+        if (transposePerm[i] < srcRank) {
+          remappedPosition = outerDimsPerm[transposePerm[i]];
+        } else {
+          return rewriter.notifyMatchFailure(
+              transposeOp,
+              "Cannot fold in tensor.pack if a tile dimension was transposed "
+              "with a non-tile dimension in linalg.transpose.");
+        }
+      } else {
+        remappedPosition = transposePerm[i];
+      }
----------------
hanhanW wrote:

I think we can simplify the logic a bit and save a level of nested indentation. E.g.,

```suggestion
      int64_t remappedPosition = transposePerm[i];

      // If tensor.pack has outer_dims_perm attribute, then consider it during
      // index remapping.
      if (!outerDimsPerm.empty()) {
        if (transposePerm[i] >= srcRank) {
          return rewriter.notifyMatchFailure(
              transposeOp,
              "Cannot fold in tensor.pack if a tile dimension was transposed "
              "with a non-tile dimension in linalg.transpose.");
        }
        remappedPosition = outerDimsPerm[remappedPosition];
      }
```

The comment for `remappedPosition` is not very helpful to me because the variable name already spells it out.

https://github.com/llvm/llvm-project/pull/74206


More information about the llvm-commits mailing list