[Mlir-commits] [mlir] [mlir] Canonicalize tensor.extract_slice (linalg.fill) (PR #112619)
Nithin Meganathan
llvmlistbot at llvm.org
Wed Oct 16 14:09:28 PDT 2024
https://github.com/nithinsubbiah created https://github.com/llvm/llvm-project/pull/112619
None
>From 9d24d35c3cfd3716de970c682b585a73548a0c99 Mon Sep 17 00:00:00 2001
From: nithinsubbiah <nithinsubbiah at gmail.com>
Date: Wed, 16 Oct 2024 14:08:25 -0700
Subject: [PATCH] [mlir] Canonicalize tensor.extract_slice (linalg.fill)
Signed-off-by: nithinsubbiah <nithinsubbiah at gmail.com>
---
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 32 +++++++++++++++++++++-
mlir/test/Dialect/Linalg/canonicalize.mlir | 14 ++++++++++
2 files changed, 45 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 730c478c2883ef..658445f19e91aa 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -806,6 +806,36 @@ struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
}
};
+/// Fold tensor.extract_slice(linalg.fill(%v, %init)) into linalg.fill(%v,
+/// tensor.empty), where the empty tensor is shaped like the slice result.
+struct FoldFillWithTensorExtractSlice
+    : public OpRewritePattern<tensor::ExtractSliceOp> {
+public:
+  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractSliceOp,
+                                PatternRewriter &rewriter) const override {
+    auto fillOp = extractSliceOp.getSource().getDefiningOp<linalg::FillOp>();
+    if (!fillOp)
+      return failure();
+
+    // A rank-reducing slice drops unit dims; skip their sizes so the new
+    // init (and thus the new fill) has exactly the slice's result type.
+    llvm::SmallBitVector droppedDims = extractSliceOp.getDroppedDims();
+    SmallVector<OpFoldResult> resultSizes;
+    for (auto [idx, size] : llvm::enumerate(extractSliceOp.getMixedSizes()))
+      if (!droppedDims.test(idx))
+        resultSizes.push_back(size);
+    RankedTensorType resultType = extractSliceOp.getResultType();
+    auto emptyOp = rewriter.create<tensor::EmptyOp>(
+        extractSliceOp.getLoc(), resultSizes, resultType.getElementType(),
+        resultType.getEncoding());
+    rewriter.replaceOpWithNewOp<linalg::FillOp>(
+        extractSliceOp, fillOp.getInputs(), ValueRange{emptyOp.getResult()});
+    return success();
+  }
+};
+
/// Folds pack(fill) into a single fill op if
/// 1. The pack op does not have padding value, or
/// 2. The filled value and padding value are the same.
@@ -936,7 +966,7 @@ struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> {
void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
- FoldFillWithPack, FoldFillWithPad,
+ FoldFillWithTensorExtractSlice, FoldFillWithPack, FoldFillWithPad,
FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 4bc2ed140da91a..763cf80241fcc0 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -352,6 +352,20 @@ func.func @fold_fill_extract(%arg0 : i1) -> i1 {
// -----
+func.func @fold_fill_extract_slice() -> tensor<2x1920x64x66xf32> {
+  %zero = arith.constant 0. : f32
+  %init = tensor.empty() : tensor<2x1920x66x66xf32>
+  %filled = linalg.fill ins(%zero : f32) outs(%init : tensor<2x1920x66x66xf32>) -> tensor<2x1920x66x66xf32>
+  %slice = tensor.extract_slice %filled[0, 0, 1, 0] [2, 1920, 64, 66] [1, 1, 1, 1] : tensor<2x1920x66x66xf32> to tensor<2x1920x64x66xf32>
+  return %slice : tensor<2x1920x64x66xf32>
+}
+// CHECK-LABEL: func.func @fold_fill_extract_slice
+// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<2x1920x64x66xf32>
+// CHECK: %[[RESULT:.+]] = linalg.fill ins(%{{.+}}) outs(%[[INIT]]
+// CHECK: return %[[RESULT]]
+
+// -----
+
func.func @fill_pack() -> tensor<24x32x16x16xf32> {
%dest = tensor.empty() : tensor<384x512xf32>
%cst = arith.constant 0.000000e+00 : f32
More information about the Mlir-commits
mailing list