[Mlir-commits] [mlir] [mlir] Canonicalize tensor.extract_slice (linalg.fill) (PR #112619)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Oct 16 14:10:22 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-linalg
Author: Nithin Meganathan (nithinsubbiah)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/112619.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+31-1)
- (modified) mlir/test/Dialect/Linalg/canonicalize.mlir (+14)
``````````diff
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 730c478c2883ef..658445f19e91aa 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -806,6 +806,36 @@ struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
}
};
+/// Fold tensor.extract_slice(linalg.fill(<input>, <init>)) into
+/// linalg.fill(<input>, tensor.empty(<slice sizes>)).
+///
+/// Every element of a filled tensor holds the same scalar, so any slice of it
+/// is equivalent to filling a fresh tensor of the slice's shape with that
+/// scalar. The original linalg.fill is left untouched (it may have other
+/// users); it is erased later if it becomes dead.
+///
+/// Rank-reducing slices are supported: the sizes of the unit dimensions that
+/// the slice drops are skipped, so the new tensor.empty (and therefore the new
+/// linalg.fill) has exactly the slice's result type.
+struct FoldFillWithTensorExtractSlice
+    : public OpRewritePattern<tensor::ExtractSliceOp> {
+  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractSliceOp,
+                                PatternRewriter &rewriter) const override {
+    // See if the tensor input of the tensor.extract_slice op is the result of
+    // a linalg.fill op.
+    auto fillOp = extractSliceOp.getSource().getDefiningOp<linalg::FillOp>();
+    if (!fillOp)
+      return failure();
+
+    // Scalar value written by the linalg.fill op.
+    Value fillInput = fillOp.getInputs()[0];
+
+    // The slice sizes are expressed in the source rank. Skip the unit
+    // dimensions removed by a rank-reducing slice; otherwise the new
+    // tensor.empty would have a higher rank than the slice result and the new
+    // linalg.fill would have mismatching init and result types.
+    RankedTensorType resultType = extractSliceOp.getResultType();
+    llvm::SmallBitVector droppedDims = extractSliceOp.getDroppedDims();
+    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
+    SmallVector<OpFoldResult> resultSizes;
+    resultSizes.reserve(resultType.getRank());
+    for (auto [idx, size] : llvm::enumerate(mixedSizes))
+      if (!droppedDims.test(idx))
+        resultSizes.push_back(size);
+
+    Location loc = extractSliceOp.getLoc();
+    auto emptyOp = rewriter.create<tensor::EmptyOp>(
+        loc, resultSizes, resultType.getElementType(),
+        resultType.getEncoding());
+
+    // Replace the tensor.extract_slice op with a new linalg.fill op that
+    // writes the same scalar into the slice-shaped tensor.
+    rewriter.replaceOpWithNewOp<linalg::FillOp>(
+        extractSliceOp, resultType, ValueRange{fillInput},
+        ValueRange{emptyOp});
+    return success();
+  }
+};
+
/// Folds pack(fill) into a single fill op if
/// 1. The pack op does not have padding value, or
/// 2. The filled value and padding value are the same.
@@ -936,7 +966,7 @@ struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> {
void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
- FoldFillWithPack, FoldFillWithPad,
+ FoldFillWithTensorExtractSlice, FoldFillWithPack, FoldFillWithPad,
FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 4bc2ed140da91a..763cf80241fcc0 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -352,6 +352,20 @@ func.func @fold_fill_extract(%arg0 : i1) -> i1 {
// -----
+// A slice of a filled tensor becomes a fill of a fresh, slice-shaped tensor.
+func.func @fold_fill_extract_slice() -> tensor<2x1920x64x66xf32> {
+  %c0 = arith.constant 0. : f32
+  %0 = tensor.empty() : tensor<2x1920x66x66xf32>
+  %1 = linalg.fill ins(%c0 : f32) outs(%0 : tensor<2x1920x66x66xf32>) -> tensor<2x1920x66x66xf32>
+  %extracted_slice = tensor.extract_slice %1[0, 0, 1, 0] [2, 1920, 64, 66] [1, 1, 1, 1] : tensor<2x1920x66x66xf32> to tensor<2x1920x64x66xf32>
+  return %extracted_slice : tensor<2x1920x64x66xf32>
+}
+// CHECK-LABEL: func.func @fold_fill_extract_slice
+// CHECK:         %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK:         %[[EMPTY_TENSOR:.+]] = tensor.empty() : tensor<2x1920x64x66xf32>
+// CHECK:         %[[FILL:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EMPTY_TENSOR]]
+// CHECK-NOT:     tensor.extract_slice
+// CHECK:         return %[[FILL]]
+
+// -----
+
func.func @fill_pack() -> tensor<24x32x16x16xf32> {
%dest = tensor.empty() : tensor<384x512xf32>
%cst = arith.constant 0.000000e+00 : f32
``````````
</details>
https://github.com/llvm/llvm-project/pull/112619
More information about the Mlir-commits mailing list