[Mlir-commits] [mlir] [mlir] [linalg] Add canonicalize pattern to swap transpose with broadcast (PR #97063)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jul 3 21:27:35 PDT 2024
================
@@ -1890,9 +1890,68 @@ struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> {
}
};
+/// This pattern reduces the cost of transpose by swapping the order of
+/// broadcast and transpose:
+/// transpose(broadcast(input)) -> broadcast(transpose(input))
+struct SwapTransposeWithBroadcast : OpRewritePattern<linalg::TransposeOp> {
+ using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+ PatternRewriter &rewriter) const override {
+ Value input = transposeOp.getInput();
+ BroadcastOp broadcastOp = input.getDefiningOp<BroadcastOp>();
+ if (!input.hasOneUse() || !broadcastOp)
+ return failure();
+
+ ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
+ ArrayRef<int64_t> perms = transposeOp.getPermutation();
+
+ // Get new perms and new dimensions.
+ SmallVector<int64_t> resultPerms = removePermutation(perms, dimensions);
+ SmallVector<int64_t> resultDimensions;
+ SmallVector<int64_t> invertPerm = invertPermutationVector(perms);
+ for (unsigned i = 0; i < dimensions.size(); i++) {
+ resultDimensions.push_back(invertPerm[dimensions[i]]);
+ }
+ llvm::sort(resultDimensions);
+
----------------
MaheshRavishankar wrote:
It is unclear to me why the sort is needed here.
https://github.com/llvm/llvm-project/pull/97063
More information about the Mlir-commits
mailing list