[Mlir-commits] [mlir] [mlir] [linalg] Add pattern to swap transpose with broadcast (PR #97063)

Diego Caballero llvmlistbot at llvm.org
Fri Jul 19 09:53:40 PDT 2024


================
@@ -1890,9 +1890,67 @@ struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> {
   }
 };
 
+/// This pattern canonicalizes transpose by swapping the order of
+/// broadcast and transpose:
+///   transpose(broadcast(input)) -> broadcast(transpose(input))
+struct SwapTransposeWithBroadcast : OpRewritePattern<linalg::TransposeOp> {
+  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    Value input = transposeOp.getInput();
+    BroadcastOp broadcastOp = input.getDefiningOp<BroadcastOp>();
+    if (!input.hasOneUse() || !broadcastOp)
+      return failure();
+
+    ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
+    ArrayRef<int64_t> perms = transposeOp.getPermutation();
+
+    // Get new perms and new dimensions.
+    SmallVector<int64_t> resultPerms = dropDims(perms, dimensions);
+    SmallVector<int64_t> invertPerm = invertPermutationVector(perms);
+    SmallVector<int64_t> resultDimensions;
+    for (unsigned i = 0; i < dimensions.size(); i++) {
----------------
dcaballe wrote:

nit: hoist the loop upper bound (`dimensions.size()`) into a variable, use pre-increment (`++i`), and remove the curly braces around the single-statement `for` body, per the LLVM coding guidelines.

https://github.com/llvm/llvm-project/pull/97063


More information about the Mlir-commits mailing list