[Mlir-commits] [mlir] [mlir] [linalg] Add pattern to swap transpose with broadcast (PR #97063)
Diego Caballero
llvmlistbot at llvm.org
Fri Jul 19 09:53:40 PDT 2024
================
@@ -1890,9 +1890,67 @@ struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> {
}
};
+/// This pattern canonicalizes transpose by swapping the order of
+/// broadcast and transpose:
+/// transpose(broadcast(input)) -> broadcast(transpose(input))
+struct SwapTransposeWithBroadcast : OpRewritePattern<linalg::TransposeOp> {
+ using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+ PatternRewriter &rewriter) const override {
+ Value input = transposeOp.getInput();
+ BroadcastOp broadcastOp = input.getDefiningOp<BroadcastOp>();
+ if (!input.hasOneUse() || !broadcastOp)
+ return failure();
+
+ ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
+ ArrayRef<int64_t> perms = transposeOp.getPermutation();
+
+ // Get new perms and new dimensions.
+ SmallVector<int64_t> resultPerms = dropDims(perms, dimensions);
+ SmallVector<int64_t> invertPerm = invertPermutationVector(perms);
+ SmallVector<int64_t> resultDimensions;
+ for (unsigned i = 0; i < dimensions.size(); i++) {
----------------
dcaballe wrote:
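nit: per the coding guidelines, hoist the loop's upper bound into a variable (e.g. `for (unsigned i = 0, e = dimensions.size(); i < e; ++i)`), use pre-increment, and drop the curly braces around the single-statement `for` body.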
https://github.com/llvm/llvm-project/pull/97063
More information about the Mlir-commits
mailing list