[Mlir-commits] [mlir] 42cd9ae - Fold `linalg.fill` -> `linalg.copy` (#72920)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Nov 21 06:20:39 PST 2023


Author: MaheshRavishankar
Date: 2023-11-21T06:20:34-08:00
New Revision: 42cd9aeec286a6928da59dce1134fdced0f0462a

URL: https://github.com/llvm/llvm-project/commit/42cd9aeec286a6928da59dce1134fdced0f0462a
DIFF: https://github.com/llvm/llvm-project/commit/42cd9aeec286a6928da59dce1134fdced0f0462a.diff

LOG: Fold `linalg.fill` -> `linalg.copy` (#72920)

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index d12ba8c4c59b33f..58af9995548e939 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -803,14 +803,36 @@ struct FoldFillWithPack : public OpRewritePattern<tensor::PackOp> {
   }
 };
 
+/// Fold `linalg.copy` of a `linalg.fill`: a filled source becomes a fill of the copy's outs; a filled dest is replaced by the fill's own outs.
+struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
+  using OpRewritePattern<linalg::CopyOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
+                                PatternRewriter &rewriter) const override {
+    if (auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
+      rewriter.replaceOpWithNewOp<FillOp>(copyOp, copyOp.getResultTypes(),
+                                          fillOp.getInputs(),
+                                          copyOp.getOutputs());
+      return success();
+    }
+    if (auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
+      rewriter.replaceOpWithNewOp<linalg::CopyOp>(copyOp, copyOp.getInputs(),
+                                                  fillOp.getOutputs());
+      return success();
+    }
+    return failure();
+  }
+};
+
 } // namespace
 
 void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
-  results.add<FoldFillWithTensorExtract, FoldFillWithPack, FoldFillWithPad,
-              FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
-              FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
-              FoldInsertPadIntoFill>(context);
+  results
+      .add<FoldFillWithCopy, FoldFillWithTensorExtract, FoldFillWithPack,
+           FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
+           FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
+           FoldInsertPadIntoFill>(context);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 7793e435582746c..e875bae4730946b 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -972,3 +972,29 @@ func.func @canonicalize_dim_of_dest_style_op(%arg0 : tensor<?x?xf32>) -> tensor<
   %3 = linalg.copy ins(%1 : tensor<?x?xf32>) outs(%2 : tensor<?x?xf32>) -> tensor<?x?xf32>
   return %3: tensor<?x?xf32>
 }
+// -----
+
+// CHECK-LABEL: func @canonicalize_fill_to_copy_input(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
+//       CHECK:   %[[ZERO:.+]] = arith.constant 0.0
+//       CHECK:   linalg.fill ins(%[[ZERO]] : f32) outs(%[[ARG1]] : tensor<?x?xf32>)
+func.func @canonicalize_fill_to_copy_input(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %c0 = arith.constant 0.0 : f32
+  %fill = linalg.fill ins(%c0 : f32) outs(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %copy = linalg.copy ins(%fill : tensor<?x?xf32>) outs(%arg1 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %copy : tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @canonicalize_fill_to_copy_dest(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
+//       CHECK:   linalg.copy ins(%[[ARG1]] : tensor<?x?xf32>) outs(%[[ARG0]] : tensor<?x?xf32>)
+func.func @canonicalize_fill_to_copy_dest(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %c0 = arith.constant 0.0 : f32
+  %fill = linalg.fill ins(%c0 : f32) outs(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %copy = linalg.copy ins(%arg1 : tensor<?x?xf32>) outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %copy : tensor<?x?xf32>
+}


        


More information about the Mlir-commits mailing list