[Mlir-commits] [mlir] [mlir][linalg] Add a folder for transpose(fill) -> fill (PR #83623)

Quinn Dawkins llvmlistbot at llvm.org
Fri Mar 1 13:25:52 PST 2024


https://github.com/qedawkins created https://github.com/llvm/llvm-project/pull/83623

This is similar to the existing folder for `linalg.copy`: transposing a filled tensor is equivalent to filling the destination of the transpose with the same fill value.

From 36d7acb35a0b3d875775a966d25b70df4832e50c Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Fri, 1 Mar 2024 01:51:27 -0500
Subject: [PATCH] [mlir][linalg] Add a folder for transpose(fill) -> fill

This is similar to the existing folder for a linalg.copy. Transposing
a filled tensor is the same as filling the destination of the transpose.
---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp   | 18 +++++++++++++++++-
 mlir/test/Dialect/Linalg/canonicalize.mlir | 14 ++++++++++++++
 2 files changed, 31 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 919f5130e1760f..6954eee93efd14 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -815,6 +815,22 @@ struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
   }
 };
 
+/// Fold fill with transpose.
+struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
+  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    if (auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
+      rewriter.replaceOpWithNewOp<FillOp>(
+          transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
+          transposeOp.getDpsInitOperand(0)->get());
+      return success();
+    }
+    return failure();
+  }
+};
+
 } // namespace
 
 void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -823,7 +839,7 @@ void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
       .add<FoldFillWithCopy, FoldFillWithTensorExtract, FoldFillWithPack,
            FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
            FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
-           FoldInsertPadIntoFill>(context);
+           FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 206d7e9f1ce8df..19cea6c2066c92 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -993,6 +993,20 @@ func.func @canonicalize_fill_to_copy_dest(%arg0 : tensor<?x?xf32>, %arg1 : tenso
 
 // -----
 
+// CHECK-LABEL: func @canonicalize_fill_to_transpose_input(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
+//       CHECK:   %[[ZERO:.+]] = arith.constant 0.0
+//       CHECK:   linalg.fill ins(%[[ZERO]] : f32) outs(%[[ARG1]] : tensor<?x?xf32>)
+func.func @canonicalize_fill_to_transpose_input(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %c0 = arith.constant 0.0 : f32
+  %fill = linalg.fill ins(%c0 : f32) outs(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %transpose = linalg.transpose ins(%fill : tensor<?x?xf32>) outs(%arg1 : tensor<?x?xf32>) permutation = [1, 0]
+  return %transpose : tensor<?x?xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @broadcast_same_shape(
 //  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<2x3xf32>
 //  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<2x3xf32>)



More information about the Mlir-commits mailing list