[Mlir-commits] [mlir] [mlir] Canonicalize tensor.extract_slice (linalg.fill) (PR #112619)
Nithin Meganathan
llvmlistbot at llvm.org
Sun Oct 20 18:04:24 PDT 2024
https://github.com/nithinsubbiah updated https://github.com/llvm/llvm-project/pull/112619
>From d78209e6cc26983e94ec2b7529e4ac59717ba7c3 Mon Sep 17 00:00:00 2001
From: nithinsubbiah <nithinsubbiah at gmail.com>
Date: Wed, 16 Oct 2024 14:08:25 -0700
Subject: [PATCH 1/2] [mlir] Canonicalize tensor.extract_slice (linalg.fill)
Signed-off-by: nithinsubbiah <nithinsubbiah at gmail.com>
---
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 32 +++++++++++++++++++++-
mlir/test/Dialect/Linalg/canonicalize.mlir | 14 ++++++++++
2 files changed, 45 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 730c478c2883ef..658445f19e91aa 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -806,6 +806,36 @@ struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
}
};
+/// Fold tensor.extract_slice(linalg.fill(<input>)) into <input>
+struct FoldFillWithTensorExtractSlice
+ : public OpRewritePattern<tensor::ExtractSliceOp> {
+public:
+ using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractSliceOp,
+ PatternRewriter &rewriter) const override {
+ // Only match slices whose source tensor is produced directly by a
+ // linalg.fill op.
+ auto fillOp = extractSliceOp.getSource().getDefiningOp<linalg::FillOp>();
+ if (!fillOp)
+ return failure();
+
+ Value fillInput = fillOp.getInputs()[0];
+
+ Location loc = extractSliceOp.getLoc();
+ SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
+ auto emptyOp = rewriter.create<tensor::EmptyOp>(
+ loc, mixedSizes, extractSliceOp.getType().getElementType());
+
+ // Replace the slice with a fill of the same value into emptyOp, which has
+ // one dim per slice size (assumes the slice is not rank-reducing).
+ rewriter.replaceOpWithNewOp<linalg::FillOp>(
+ extractSliceOp, extractSliceOp.getResultType(), ValueRange{fillInput},
+ ValueRange{emptyOp});
+ return success();
+ }
+};
+
/// Folds pack(fill) into a single fill op if
/// 1. The pack op does not have padding value, or
/// 2. The filled value and padding value are the same.
@@ -936,7 +966,7 @@ struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> {
void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
- FoldFillWithPack, FoldFillWithPad,
+ FoldFillWithTensorExtractSlice, FoldFillWithPack, FoldFillWithPad,
FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 4bc2ed140da91a..763cf80241fcc0 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -352,6 +352,22 @@ func.func @fold_fill_extract(%arg0 : i1) -> i1 {
// -----
+func.func @fold_fill_extract_slice() -> tensor<2x1920x64x66xf32> {
+ %c0 = arith.constant 0. : f32
+ %0 = tensor.empty() : tensor<2x1920x66x66xf32>
+ %1 = linalg.fill ins(%c0 : f32) outs(%0 : tensor<2x1920x66x66xf32>) -> tensor<2x1920x66x66xf32>
+ %extracted_slice = tensor.extract_slice %1[0, 0, 1, 0] [2, 1920, 64, 66] [1, 1, 1, 1] : tensor<2x1920x66x66xf32> to tensor<2x1920x64x66xf32>
+ return %extracted_slice : tensor<2x1920x64x66xf32>
+}
+// CHECK-LABEL: func.func @fold_fill_extract_slice
+// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[EMPTY_TENSOR:.+]] = tensor.empty() : tensor<2x1920x64x66xf32>
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EMPTY_TENSOR]]
+// CHECK-NOT: tensor.extract_slice
+// CHECK: return %[[FILL]]
+
+// -----
+
func.func @fill_pack() -> tensor<24x32x16x16xf32> {
%dest = tensor.empty() : tensor<384x512xf32>
%cst = arith.constant 0.000000e+00 : f32
>From 1c93e3dfcdfbf96578de3715ef94a20691997ade Mon Sep 17 00:00:00 2001
From: nithinsubbiah <nithinsubbiah at gmail.com>
Date: Sun, 20 Oct 2024 17:57:43 -0700
Subject: [PATCH 2/2] [mlir] Check if ExtractSliceOp is the only consumer while
folding into FillOp
---
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 19 +++++++++++++++----
1 file changed, 15 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 658445f19e91aa..6c26cdeff30e1b 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -806,7 +806,7 @@ struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
}
};
-/// Fold tensor.extract_slice(linalg.fill(<input>)) into <input>
+/// Fold tensor.extract_slice(linalg.fill(<input>)) into linalg.fill(<input>)
struct FoldFillWithTensorExtractSlice
: public OpRewritePattern<tensor::ExtractSliceOp> {
public:
@@ -820,6 +820,16 @@ struct FoldFillWithTensorExtractSlice
if (!fillOp)
return failure();
+ // Fold only if tensor.extract_slice is the sole user of fillOp; otherwise
+ // the original fill is still needed and a second fill would be created.
+ if (!fillOp->hasOneUse())
+ return failure();
+ // Bail on rank-reducing slices: the tensor.empty built below has one dim
+ // per slice size, so its rank would not match the slice result type.
+ if (extractSliceOp.getSourceType().getRank() !=
+ extractSliceOp.getResultType().getRank())
+ return failure();
+
Value fillInput = fillOp.getInputs()[0];
Location loc = extractSliceOp.getLoc();
@@ -965,11 +975,12 @@ struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> {
void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
- FoldFillWithTensorExtractSlice, FoldFillWithPack, FoldFillWithPad,
+ results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithPack,
+ FoldFillWithPad, FoldFillWithTensorExtract,
+ FoldFillWithTensorExtractSlice,
FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
- FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
+ FoldFillWithTranspose, FoldInsertPadIntoFill>(context);
}
//===----------------------------------------------------------------------===//
More information about the Mlir-commits mailing list