[Mlir-commits] [mlir] [mlir] [linalg] Add canonicalize pattern to swap transpose with broadcast (PR #97063)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jul 3 21:27:35 PDT 2024


================
@@ -1890,9 +1890,68 @@ struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> {
   }
 };
 
+/// This pattern reduces the cost of transpose by swapping the order of
+/// broadcast and transpose:
+///   transpose(broadcast(input)) -> broadcast(transpose(input))
+struct SwapTransposeWithBroadcast : OpRewritePattern<linalg::TransposeOp> {
+  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    Value input = transposeOp.getInput();
+    BroadcastOp broadcastOp = input.getDefiningOp<BroadcastOp>();
+    if (!input.hasOneUse() || !broadcastOp)
+      return failure();
+
+    ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
+    ArrayRef<int64_t> perms = transposeOp.getPermutation();
+
+    // Get new perms and new dimensions.
+    SmallVector<int64_t> resultPerms = removePermutation(perms, dimensions);
+    SmallVector<int64_t> resultDimensions;
+    SmallVector<int64_t> invertPerm = invertPermutationVector(perms);
+    for (unsigned i = 0; i < dimensions.size(); i++) {
+      resultDimensions.push_back(invertPerm[dimensions[i]]);
+    }
+    llvm::sort(resultDimensions);
+
----------------
MaheshRavishankar wrote:

Unclear to me why there is a sort

https://github.com/llvm/llvm-project/pull/97063


More information about the Mlir-commits mailing list