[Mlir-commits] [mlir] Fold `linalg.fill` -> `linalg.copy` along `outs` use in the consumer. (PR #72920)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Nov 20 14:24:01 PST 2023
https://github.com/MaheshRavishankar updated https://github.com/llvm/llvm-project/pull/72920
>From ca09afe76b7ce69e3ec0316d0442daf5c2f61683 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Wed, 8 Nov 2023 13:40:37 -0700
Subject: [PATCH] Fold `linalg.fill` -> `linalg.copy` along `outs` use in the
consumer.
---
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 24 ++++++++++++++++++----
mlir/test/Dialect/Linalg/canonicalize.mlir | 13 ++++++++++++
2 files changed, 33 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index d12ba8c4c59b33f..24d29d299ab2418 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -803,14 +803,30 @@ struct FoldFillWithPack : public OpRewritePattern<tensor::PackOp> {
}
};
+/// Fold fill with copy: when the first `outs` operand of a `linalg.copy` is defined by a `linalg.fill`, recreate the copy with the fill's `outs` operands as its destination.
+struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
+ using OpRewritePattern<linalg::CopyOp>::OpRewritePattern; // inherit the base pattern constructors
+
+ LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
+ PatternRewriter &rewriter) const override {
+ auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>(); // null unless a FillOp produces the destination
+ if (!fillOp)
+ return failure(); // no match: leave the copy untouched
+ rewriter.replaceOpWithNewOp<linalg::CopyOp>(copyOp, copyOp.getInputs(),
+ fillOp.getOutputs()); // same inputs, destination taken from the fill
+ return success();
+ }
+};
+
} // namespace
void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<FoldFillWithTensorExtract, FoldFillWithPack, FoldFillWithPad,
- FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
- FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
- FoldInsertPadIntoFill>(context);
+ results
+ .add<FoldFillWithCopy, FoldFillWithTensorExtract, FoldFillWithPack,
+ FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
+ FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
+ FoldInsertPadIntoFill>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 7793e435582746c..c054829a915d7ba 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -972,3 +972,16 @@ func.func @canonicalize_dim_of_dest_style_op(%arg0 : tensor<?x?xf32>) -> tensor<
%3 = linalg.copy ins(%1 : tensor<?x?xf32>) outs(%2 : tensor<?x?xf32>) -> tensor<?x?xf32>
return %3: tensor<?x?xf32>
}
+
+// -----
+
+// CHECK-LABEL: func @canonicalize_fill_to_copy_dest(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
+// CHECK: linalg.copy ins(%[[ARG1]] : tensor<?x?xf32>) outs(%[[ARG0]] : tensor<?x?xf32>)
+func.func @canonicalize_fill_to_copy_dest(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %c0 = arith.constant 0.0 : f32
+ %fill = linalg.fill ins(%c0 : f32) outs(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> // %fill is used only as the destination of the copy below
+ %copy = linalg.copy ins(%arg1 : tensor<?x?xf32>) outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32> // expected to be rewritten to write into %arg0 directly
+ return %copy : tensor<?x?xf32>
+}
More information about the Mlir-commits
mailing list