[mlir] [llvm] [mlir][tensor] Fold linalg transpose with tensor pack (PR #74206)

lorenzo chelini via llvm-commits llvm-commits at lists.llvm.org
Mon Dec 4 03:55:02 PST 2023


================
@@ -81,10 +82,86 @@ struct FoldUnpackWithExtractSliceOp : public OpRewritePattern<ExtractSliceOp> {
     return success();
   }
 };
+
+/// Fold 'pack' -> 'transpose' into 'pack' since 'pack' already has transpose
+/// semantics.
+///
+/// The result of a `tensor.pack` has `srcRank` outer dimensions (ordered by
+/// `outer_dims_perm`, or the identity when it is absent) followed by one
+/// dimension per inner tile (ordered as in `inner_dims_pos`). A consumer
+/// `linalg.transpose` can only be absorbed into the pack when it permutes the
+/// outer dimensions among themselves and the tile dimensions among
+/// themselves; moving a tile dimension into an outer position (or vice versa)
+/// is not expressible as a single `tensor.pack`, so the pattern fails to
+/// match in that case.
+struct FoldProducerPackWithConsumerLinalgTransposeOp
+    : public OpRewritePattern<linalg::TransposeOp> {
+  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    auto packOp = transposeOp.getInput().getDefiningOp<PackOp>();
+    if (!packOp)
+      return failure();
+
+    ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
+    ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
+    SmallVector<OpFoldResult> mixedTiles = packOp.getMixedTiles();
+    ArrayRef<int64_t> transposePerm = transposeOp.getPermutation();
+    int64_t srcRank = packOp.getSourceRank();
+    int64_t numTiles = static_cast<int64_t>(innerDimsPos.size());
+
+    // The transpose acts on the packed tensor, so its permutation must cover
+    // every outer dimension and every tile dimension.
+    if (static_cast<int64_t>(transposePerm.size()) != srcRank + numTiles)
+      return rewriter.notifyMatchFailure(
+          transposeOp, "transpose rank does not match the packed rank");
+
+    // Outer dimensions: transposed dim `i` reads packed dim `transposePerm[i]`,
+    // which is source dim `outerDimsPerm[transposePerm[i]]` (or
+    // `transposePerm[i]` itself when the pack has no outer_dims_perm). This
+    // must be handled even for an empty outer_dims_perm, otherwise the
+    // transpose would be silently dropped.
+    SmallVector<int64_t> newOuterDimsPerm;
+    newOuterDimsPerm.reserve(srcRank);
+    for (int64_t i = 0; i < srcRank; ++i) {
+      int64_t packedDim = transposePerm[i];
+      // A tile dimension moved into an outer position cannot be folded; this
+      // also keeps the `outerDimsPerm[packedDim]` lookup in bounds.
+      if (packedDim >= srcRank)
+        return rewriter.notifyMatchFailure(
+            transposeOp, "cannot fold a transpose that mixes tile dimensions "
+                         "with outer dimensions into tensor.pack");
+      newOuterDimsPerm.push_back(
+          outerDimsPerm.empty() ? packedDim : outerDimsPerm[packedDim]);
+    }
+
+    // Tile dimensions: reorder the tiles (and the source dims they tile)
+    // according to the trailing part of the permutation. Because the
+    // permutation is a bijection and its leading `srcRank` entries were all
+    // checked to be outer dims, every remaining entry names a tile dim.
+    SmallVector<int64_t> newInnerDimsPos;
+    SmallVector<OpFoldResult> newMixedTiles;
+    newInnerDimsPos.reserve(numTiles);
+    newMixedTiles.reserve(numTiles);
+    for (int64_t i = srcRank; i < srcRank + numTiles; ++i) {
+      int64_t tileIdx = transposePerm[i] - srcRank;
+      newInnerDimsPos.push_back(innerDimsPos[tileIdx]);
+      newMixedTiles.push_back(mixedTiles[tileIdx]);
+    }
+
+    // Both tensor.pack and linalg.transpose overwrite their entire
+    // destination, so the transpose's init (which already has the transposed
+    // result type, including any dynamic dims) serves as the new pack's
+    // destination; the replacement then has exactly the transpose's type.
+    // The original pack's padding value (possibly null) is carried over.
+    rewriter.replaceOpWithNewOp<PackOp>(
+        transposeOp, packOp.getSource(), transposeOp.getInit(),
+        newInnerDimsPos, newMixedTiles, packOp.getPaddingValue(),
+        newOuterDimsPerm);
+    return success();
+  }
+};
+
+/// Fold 'transpose' -> 'pack' into 'pack' since 'pack' already has transpose
+/// semantics.
+struct FoldConsumerPackWithProducerLinalgTransposeOp
+    : public OpRewritePattern<PackOp> {
+  using OpRewritePattern<PackOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PackOp packOp,
+                                PatternRewriter &rewriter) const override {
+    auto packInputTensor = packOp.getOperand(0);
+    auto transposeOp = packInputTensor.getDefiningOp<linalg::TransposeOp>();
+
+    if (!transposeOp)
+      return failure();
+
+    auto packOuterDimsPerm = packOp.getOuterDimsPerm();
+    auto transposePerm = transposeOp.getPermutation();
+    llvm::SmallVector<int64_t> newPackOuterDimsPermVec;
----------------
chelini wrote:

ditto.

https://github.com/llvm/llvm-project/pull/74206


More information about the llvm-commits mailing list