[Mlir-commits] [mlir] 205dce6 - [mlir][linalg] Add a folder for transpose(fill) -> fill (#83623)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Mar 2 14:47:19 PST 2024


Author: Quinn Dawkins
Date: 2024-03-02T17:47:16-05:00
New Revision: 205dce6029bed302f354c0bde5d8c5804f214051

URL: https://github.com/llvm/llvm-project/commit/205dce6029bed302f354c0bde5d8c5804f214051
DIFF: https://github.com/llvm/llvm-project/commit/205dce6029bed302f354c0bde5d8c5804f214051.diff

LOG: [mlir][linalg] Add a folder for transpose(fill) -> fill (#83623)

This is similar to the existing folder for linalg.copy. Transposing a
filled tensor is the same as filling the destination of the transpose,
because every element of a fill holds the same value.

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 919f5130e1760f..6954eee93efd14 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -815,6 +815,22 @@ struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
   }
 };
 
+/// Fold transpose(fill(%cst, %a), %b) into fill(%cst, %b); fills are uniform.
+struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
+  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>();
+    if (!fillOp)
+      return failure();
+    rewriter.replaceOpWithNewOp<FillOp>(
+        transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
+        transposeOp.getInit());
+    return success();
+  }
+};
+
 } // namespace
 
 void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -823,7 +839,7 @@ void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
       .add<FoldFillWithCopy, FoldFillWithTensorExtract, FoldFillWithPack,
            FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
            FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
-           FoldInsertPadIntoFill>(context);
+           FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 206d7e9f1ce8df..19cea6c2066c92 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -993,6 +993,21 @@ func.func @canonicalize_fill_to_copy_dest(%arg0 : tensor<?x?xf32>, %arg1 : tenso

 // -----

+// CHECK-LABEL: func @canonicalize_fill_to_transpose_input(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
+//       CHECK:   %[[ZERO:.+]] = arith.constant 0.0
+//       CHECK:   %[[FILL:.+]] = linalg.fill ins(%[[ZERO]] : f32) outs(%[[ARG1]] : tensor<?x?xf32>)
+//  CHECK-NEXT:   return %[[FILL]]
+func.func @canonicalize_fill_to_transpose_input(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %c0 = arith.constant 0.0 : f32
+  %fill = linalg.fill ins(%c0 : f32) outs(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %transpose = linalg.transpose ins(%fill : tensor<?x?xf32>) outs(%arg1 : tensor<?x?xf32>) permutation = [1, 0]
+  return %transpose : tensor<?x?xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @broadcast_same_shape(
 //  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<2x3xf32>
 //  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<2x3xf32>)


        


More information about the Mlir-commits mailing list