[Mlir-commits] [mlir] [mlir] [linalg] Add canonicalize pattern to swap transpose with broadcast (PR #97063)

donald chen llvmlistbot at llvm.org
Thu Jul 4 05:12:06 PDT 2024


================
@@ -1890,9 +1890,68 @@ struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> {
   }
 };
 
+/// This pattern reduces the cost of transpose by swapping the order of
+/// broadcast and transpose:
+///   transpose(broadcast(input)) -> broadcast(transpose(input))
+struct SwapTransposeWithBroadcast : OpRewritePattern<linalg::TransposeOp> {
+  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    Value input = transposeOp.getInput();
+    BroadcastOp broadcastOp = input.getDefiningOp<BroadcastOp>();
+    if (!input.hasOneUse() || !broadcastOp)
+      return failure();
+
+    ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
+    ArrayRef<int64_t> perms = transposeOp.getPermutation();
+
+    // Get new perms and new dimensions.
+    SmallVector<int64_t> resultPerms = removePermutation(perms, dimensions);
+    SmallVector<int64_t> resultDimensions;
+    SmallVector<int64_t> invertPerm = invertPermutationVector(perms);
+    for (unsigned i = 0; i < dimensions.size(); i++) {
+      resultDimensions.push_back(invertPerm[dimensions[i]]);
+    }
+    llvm::sort(resultDimensions);
+
----------------
cxy-1993 wrote:

Nice catch! I must have confused the constraints of other ops with those of broadcast, and mistakenly assumed that the dimensions of a broadcast must be an increasing sequence. The code has been updated, and a test with non-increasing broadcast dimensions has been added.

https://github.com/llvm/llvm-project/pull/97063


More information about the Mlir-commits mailing list