[llvm] [mlir] [mlir][tensor] Fold linalg transpose with tensor pack (PR #74206)
lorenzo chelini via llvm-commits
llvm-commits at lists.llvm.org
Mon Dec 11 04:35:35 PST 2023
================
@@ -81,10 +82,91 @@ struct FoldUnpackWithExtractSliceOp : public OpRewritePattern<ExtractSliceOp> {
return success();
}
};
+
+/// Fold 'pack' -> 'transpose' into 'pack' since 'pack' already has transpose
+/// semantics.
+struct FoldProducerPackWithConsumerLinalgTransposeOp
+ : public OpRewritePattern<linalg::TransposeOp> {
+ using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+ PatternRewriter &rewriter) const override {
+ auto transposeInputTensor = transposeOp.getOperand(0);
+ auto packOp = transposeInputTensor.getDefiningOp<PackOp>();
+
+ if (!packOp)
+ return failure();
+
+ auto packInnerDimsPos = packOp.getInnerDimsPos();
+ auto packInnerTiles = packOp.getStaticInnerTiles();
+ auto packOuterDimsPerm = packOp.getOuterDimsPerm();
+ auto transposePerm = transposeOp.getPermutation();
+ SmallVector<int64_t> newPackOuterDimsPermVec;
+ SmallVector<int64_t> newPackInnerDimsPosVec;
+ SmallVector<int64_t> newPackInnerTilesVec;
+
+ // Variable for storing translated position after considering original
+ // outer_dims_perm and permutation attributes of tensor.pack and
+ // linalg.transpose.
+ int64_t translatedPosition;
+
+ // Process transpose operation for non-tiled outer dimensions of the tensor.
+ for (unsigned int i = 0; i < transposePerm.size() - packInnerTiles.size();
+ ++i) {
+ // If tensor.pack has outer_dims_perm attribute, then consider it during
+ // index translation.
+ if (packOuterDimsPerm.size()) {
+ // Note: static_cast is added around transposePerm[i] to suppress the
+ // compiler warning of comparison between variables of different types.
+ if (static_cast<unsigned long>(transposePerm[i]) <
+ packOuterDimsPerm.size())
+ translatedPosition = packOuterDimsPerm[transposePerm[i]];
+ else
+ return rewriter.notifyMatchFailure(
+ transposeOp,
+ "Cannot fold in tensor.pack if a tile dimension was transposed "
+ "with a non-tile dimension in linalg.transpose.");
+ } else
+ translatedPosition = transposePerm[i];
+
+ newPackOuterDimsPermVec.push_back(translatedPosition);
+ }
+
+ // Process transpose operation for tiled inner dimensions of the tensor.
----------------
chelini wrote:
nit: we could drop `of the tensor`; I think it is clear from context.
https://github.com/llvm/llvm-project/pull/74206
More information about the llvm-commits
mailing list