[Mlir-commits] [mlir] [llvm] [mlir][tensor] Fold consumer linalg transpose with producer tensor pack (PR #74206)
Han-Chung Wang
llvmlistbot at llvm.org
Wed Dec 13 13:34:17 PST 2023
================
@@ -81,10 +82,77 @@ struct FoldUnpackWithExtractSliceOp : public OpRewritePattern<ExtractSliceOp> {
return success();
}
};
+
+/// Fold 'pack' -> 'transpose' into 'pack' since 'pack' already has transpose
+/// semantics.
+struct FoldProducerPackWithConsumerLinalgTransposeOp
+ : public OpRewritePattern<linalg::TransposeOp> {
+ using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+ PatternRewriter &rewriter) const override {
+ auto packOp = transposeOp.getOperand(0).getDefiningOp<PackOp>();
+
+ if (!packOp)
+ return failure();
+
+ auto innerDimsPos = packOp.getInnerDimsPos();
+ auto mixedInnerTiles = packOp.getMixedTiles();
+ auto outerDimsPerm = packOp.getOuterDimsPerm();
+ auto transposePerm = transposeOp.getPermutation();
+ SmallVector<int64_t> newOuterDimsPermVec;
+ SmallVector<int64_t> newInnerDimsPosVec;
+ SmallVector<OpFoldResult> newMixedInnerTilesVec;
+ int64_t srcRank = packOp.getSourceRank();
+
+ // Process transpose operation for non-tiled outer dimensions
+ for (unsigned int i = 0; i < srcRank; ++i) {
+ // Variable for storing remapped position after considering original
+ // outer_dims_perm and permutation attributes of tensor.pack and
+ // linalg.transpose.
+ int64_t remappedPosition;
+
+ // If tensor.pack has outer_dims_perm attribute, then consider it during
+ // index remapping.
+ if (!outerDimsPerm.empty()) {
+ if (transposePerm[i] < srcRank) {
+ remappedPosition = outerDimsPerm[transposePerm[i]];
+ } else {
+ return rewriter.notifyMatchFailure(
+ transposeOp,
+ "Cannot fold in tensor.pack if a tile dimension was transposed "
+ "with a non-tile dimension in linalg.transpose.");
+ }
+ } else {
+ remappedPosition = transposePerm[i];
+ }
----------------
hanhanW wrote:
I think we can simplify the logic a bit and save a level of nested indentation. For example:
```suggestion
int64_t remappedPosition = transposePerm[i];
// If tensor.pack has outer_dims_perm attribute, then consider it during
// index remapping.
if (!outerDimsPerm.empty()) {
if (transposePerm[i] >= srcRank) {
return rewriter.notifyMatchFailure(
transposeOp,
"Cannot fold in tensor.pack if a tile dimension was transposed "
"with a non-tile dimension in linalg.transpose.");
}
remappedPosition = outerDimsPerm[remappedPosition];
}
```
The comment for `remappedPosition` is not very helpful to me, because the variable name already spells out what it holds.
https://github.com/llvm/llvm-project/pull/74206
More information about the Mlir-commits
mailing list