[Mlir-commits] [mlir] [llvm] [mlir][tensor] Fold linalg transpose with tensor pack (PR #74206)
lorenzo chelini
llvmlistbot at llvm.org
Mon Dec 4 03:55:02 PST 2023
================
@@ -81,10 +82,86 @@ struct FoldUnpackWithExtractSliceOp : public OpRewritePattern<ExtractSliceOp> {
return success();
}
};
+
+/// Fold 'pack' -> 'transpose' into 'pack' since 'pack' already has transpose
+/// semantics.
+///
+/// The packed tensor has rank `srcRank + numTiles`:
+/// - The leading `srcRank` dims are the outer dims, permuted by
+///   `outer_dims_perm`. An empty `outer_dims_perm` means identity.
+/// - The trailing dims are the tile dims, one per entry of `inner_dims_pos`.
+///
+/// A transpose of that result can be absorbed into the pack as long as it
+/// never exchanges an outer dim with a tile dim:
+/// - The outer part is composed into `outer_dims_perm`.
+/// - The tile part reorders `inner_dims_pos` and `inner_tiles`.
+///
+/// Any padding value of the original pack is preserved.
+struct FoldProducerPackWithConsumerLinalgTransposeOp
+    : public OpRewritePattern<linalg::TransposeOp> {
+  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    auto packOp = transposeOp.getOperand(0).getDefiningOp<PackOp>();
+    if (!packOp)
+      return failure();
+
+    // The new pack writes into the transpose's init, so it must be a tensor.
+    if (!isa<RankedTensorType>(transposeOp.getInit().getType()))
+      return rewriter.notifyMatchFailure(transposeOp,
+                                         "expected a ranked tensor init");
+
+    ArrayRef<int64_t> transposePerm = transposeOp.getPermutation();
+    ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
+    ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
+    SmallVector<OpFoldResult> mixedTiles = packOp.getMixedTiles();
+    int64_t srcRank = packOp.getSourceRank();
+
+    // Outer dims: result dim `i` of the transpose is packed dim
+    // `transposePerm[i]`, which in turn is source dim
+    // `outerDimsPerm[transposePerm[i]]`. An empty `outer_dims_perm` is the
+    // identity, so the transpose must still be recorded rather than dropped.
+    SmallVector<int64_t> newOuterDimsPerm;
+    newOuterDimsPerm.reserve(srcRank);
+    for (int64_t i = 0; i < srcRank; ++i) {
+      int64_t packedDim = transposePerm[i];
+      // pack cannot express moving a tile dim into an outer position.
+      if (packedDim >= srcRank)
+        return rewriter.notifyMatchFailure(
+            transposeOp, "transpose exchanges an outer dim with a tile dim");
+      newOuterDimsPerm.push_back(
+          outerDimsPerm.empty() ? packedDim : outerDimsPerm[packedDim]);
+    }
+
+    // Tile dims: `transposePerm` is a permutation whose first `srcRank`
+    // entries were just shown to be outer dims. Therefore every remaining
+    // entry is a tile dim. Reorder `inner_dims_pos` and `inner_tiles` so tile
+    // `j` of the new pack is tile `transposePerm[srcRank + j] - srcRank` of
+    // the old one.
+    SmallVector<int64_t> newInnerDimsPos;
+    SmallVector<OpFoldResult> newMixedTiles;
+    for (int64_t i = srcRank, e = transposePerm.size(); i < e; ++i) {
+      int64_t tileIdx = transposePerm[i] - srcRank;
+      newInnerDimsPos.push_back(innerDimsPos[tileIdx]);
+      newMixedTiles.push_back(mixedTiles[tileIdx]);
+    }
+
+    // Reuse the transpose's init as destination. It already carries the
+    // transposed (possibly dynamic) shape, so no tensor.empty is needed.
+    // Forward the padding value so padded packs keep their semantics.
+    rewriter.replaceOpWithNewOp<PackOp>(
+        transposeOp, packOp.getSource(), transposeOp.getInit(),
+        newInnerDimsPos, newMixedTiles, packOp.getPaddingValue(),
+        newOuterDimsPerm);
+    return success();
+  }
+};
+
+/// Fold 'transpose' -> 'pack' into 'pack' since 'pack' already has transpose
+/// semantics.
+struct FoldConsumerPackWithProducerLinalgTransposeOp
+ : public OpRewritePattern<PackOp> {
+ using OpRewritePattern<PackOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(PackOp packOp,
+ PatternRewriter &rewriter) const override {
+ auto packInputTensor = packOp.getOperand(0);
+ auto transposeOp = packInputTensor.getDefiningOp<linalg::TransposeOp>();
+
+ if (!transposeOp)
+ return failure();
+
+ auto packOuterDimsPerm = packOp.getOuterDimsPerm();
+ auto transposePerm = transposeOp.getPermutation();
+ llvm::SmallVector<int64_t> newPackOuterDimsPermVec;
+
+ for (unsigned int i = 0; i < packOuterDimsPerm.size(); ++i)
+ newPackOuterDimsPermVec.push_back(transposePerm[packOuterDimsPerm[i]]);
+
+ // Create a new empty output tensor.
+ Type elementType = packOp.getDestType().getElementType();
+ auto packOpResultType = packOp.getResult().getType();
+ auto rankedTensorType = packOpResultType.dyn_cast<RankedTensorType>();
+ Value output = rewriter.create<EmptyOp>(
+ packOp.getLoc(), rankedTensorType.getShape(), elementType);
+
+ rewriter.replaceOpWithNewOp<PackOp>(
+ packOp, transposeOp.getOperand(0), output, packOp.getInnerDimsPos(),
+ packOp.getMixedTiles(), std::nullopt,
+ static_cast<llvm::ArrayRef<int64_t>>(newPackOuterDimsPermVec));
----------------
chelini wrote:
See above.
https://github.com/llvm/llvm-project/pull/74206
More information about the Mlir-commits
mailing list