[Mlir-commits] [mlir] [mlir] Canonicalize tensor.extract_slice (linalg.fill) (PR #112619)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Oct 16 14:10:21 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Nithin Meganathan (nithinsubbiah)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/112619.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+31-1) 
- (modified) mlir/test/Dialect/Linalg/canonicalize.mlir (+14) 


``````````diff
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 730c478c2883ef..658445f19e91aa 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -806,6 +806,50 @@ struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
   }
 };
 
+/// Fold tensor.extract_slice(linalg.fill(%value, %init)) into
+/// linalg.fill(%value, tensor.empty()), where the new init tensor has exactly
+/// the type of the extract_slice result.
+///
+/// Every element of a filled tensor equals the fill value, so any slice of it
+/// is the same fill at the slice's size; the slice offsets and strides do not
+/// matter. Rank-reducing slices are handled by building the init tensor from
+/// the slice's result type rather than from its (source-rank) sizes.
+struct FoldFillWithTensorExtractSlice
+    : public OpRewritePattern<tensor::ExtractSliceOp> {
+  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractSliceOp,
+                                PatternRewriter &rewriter) const override {
+    // Only applies when the sliced tensor is produced by a linalg.fill.
+    auto fillOp = extractSliceOp.getSource().getDefiningOp<linalg::FillOp>();
+    if (!fillOp)
+      return failure();
+
+    Value fillInput = fillOp.getInputs()[0];
+    RankedTensorType resultType = extractSliceOp.getResultType();
+
+    // getMixedSizes() has one entry per *source* dimension, so an init built
+    // from it has the wrong rank for a rank-reducing slice. A slice can only
+    // drop static unit dimensions, so the dynamic size operands correspond,
+    // in order, to the dynamic dimensions of the result type. Bail out if
+    // that does not hold instead of creating an invalid tensor.empty.
+    ValueRange dynamicSizes = extractSliceOp.getSizes();
+    if (static_cast<int64_t>(dynamicSizes.size()) !=
+        resultType.getNumDynamicDims())
+      return failure();
+
+    // Building from the result type also preserves its encoding, so the init
+    // type (and hence the new fill's result type) matches exactly.
+    auto emptyOp = rewriter.create<tensor::EmptyOp>(
+        extractSliceOp.getLoc(), resultType, dynamicSizes);
+
+    rewriter.replaceOpWithNewOp<linalg::FillOp>(
+        extractSliceOp, resultType, ValueRange{fillInput},
+        ValueRange{emptyOp});
+    return success();
+  }
+};
+
 /// Folds pack(fill) into a single fill op if
 ///   1. The pack op does not have padding value, or
 ///   2. The filled value and padding value are the same.
@@ -936,7 +966,7 @@ struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> {
 void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
   results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
-              FoldFillWithPack, FoldFillWithPad,
+              FoldFillWithTensorExtractSlice, FoldFillWithPack, FoldFillWithPad,
               FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
               FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
               FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 4bc2ed140da91a..763cf80241fcc0 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -352,6 +352,20 @@ func.func @fold_fill_extract(%arg0 : i1) -> i1 {
 
 // -----
 
+func.func @fold_fill_extract_slice() -> tensor<2x1920x64x66xf32> {
+  %cst = arith.constant 0.0 : f32
+  %init = tensor.empty() : tensor<2x1920x66x66xf32>
+  %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<2x1920x66x66xf32>) -> tensor<2x1920x66x66xf32>
+  %slice = tensor.extract_slice %fill[0, 0, 1, 0] [2, 1920, 64, 66] [1, 1, 1, 1] : tensor<2x1920x66x66xf32> to tensor<2x1920x64x66xf32>
+  return %slice : tensor<2x1920x64x66xf32>
+}
+// CHECK-LABEL: func.func @fold_fill_extract_slice
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<2x1920x64x66xf32>
+// CHECK:         %[[RESULT:.+]] = linalg.fill ins(%{{.+}}) outs(%[[EMPTY]]
+// CHECK:         return %[[RESULT]]
+
+// -----
+
 func.func @fill_pack() -> tensor<24x32x16x16xf32> {
   %dest = tensor.empty() : tensor<384x512xf32>
   %cst = arith.constant 0.000000e+00 : f32

``````````

</details>


https://github.com/llvm/llvm-project/pull/112619


More information about the Mlir-commits mailing list