[Mlir-commits] [mlir] [mlir] Canonicalize tensor.extract_slice (linalg.fill) (PR #112619)

Nithin Meganathan llvmlistbot at llvm.org
Fri Oct 25 15:57:30 PDT 2024


https://github.com/nithinsubbiah updated https://github.com/llvm/llvm-project/pull/112619

>From 758228d48d3550308172bcdb965e3e75585975a3 Mon Sep 17 00:00:00 2001
From: nithinsubbiah <nithinsubbiah at gmail.com>
Date: Wed, 16 Oct 2024 14:08:25 -0700
Subject: [PATCH 1/2] [mlir] Canonicalize tensor.extract_slice (linalg.fill)

Signed-off-by: nithinsubbiah <nithinsubbiah at gmail.com>
---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp   | 32 +++++++++++++++++++++-
 mlir/test/Dialect/Linalg/canonicalize.mlir | 14 ++++++++++
 2 files changed, 45 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 730c478c2883ef..658445f19e91aa 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -806,6 +806,36 @@ struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
   }
 };
 
+/// Fold tensor.extract_slice(linalg.fill(<input>)) into <input>
+struct FoldFillWithTensorExtractSlice
+    : public OpRewritePattern<tensor::ExtractSliceOp> {
+public:
+  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractSliceOp,
+                                PatternRewriter &rewriter) const override {
+    // See if tensor input of tensor.extract_slice op is the result of a
+    // linalg.fill op.
+    auto fillOp = extractSliceOp.getSource().getDefiningOp<linalg::FillOp>();
+    if (!fillOp)
+      return failure();
+
+    Value fillInput = fillOp.getInputs()[0];
+
+    Location loc = extractSliceOp.getLoc();
+    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
+    auto emptyOp = rewriter.create<tensor::EmptyOp>(
+        loc, mixedSizes, extractSliceOp.getType().getElementType());
+
+    // Replace the slice with a fill of a new tensor.empty. NOTE(review): sizes
+    // are per source dim; a rank-reducing slice may not match the result type.
+    rewriter.replaceOpWithNewOp<linalg::FillOp>(
+        extractSliceOp, extractSliceOp.getResultType(), ValueRange{fillInput},
+        ValueRange{emptyOp});
+    return success();
+  }
+};
+
 /// Folds pack(fill) into a single fill op if
 ///   1. The pack op does not have padding value, or
 ///   2. The filled value and padding value are the same.
@@ -936,7 +966,7 @@ struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> {
 void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
   results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
-              FoldFillWithPack, FoldFillWithPad,
+              FoldFillWithTensorExtractSlice, FoldFillWithPack, FoldFillWithPad,
               FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
               FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
               FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 4bc2ed140da91a..763cf80241fcc0 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -352,6 +352,20 @@ func.func @fold_fill_extract(%arg0 : i1) -> i1 {
 
 // -----
 
+func.func @fold_fill_extract_slice() -> tensor<2x1920x64x66xf32> {
+  %c0 = arith.constant 0. : f32
+  %0 = tensor.empty() : tensor<2x1920x66x66xf32>
+  %1 = linalg.fill ins(%c0 : f32) outs(%0 : tensor<2x1920x66x66xf32>) -> tensor<2x1920x66x66xf32>
+  %extracted_slice = tensor.extract_slice %1[0, 0, 1, 0] [2, 1920, 64, 66] [1, 1, 1, 1] : tensor<2x1920x66x66xf32> to tensor<2x1920x64x66xf32>
+  return %extracted_slice : tensor<2x1920x64x66xf32>
+}
+// CHECK-LABEL: func.func @fold_fill_extract_slice
+// CHECK:         %[[EMPTY_TENSOR:.+]] = tensor.empty() : tensor<2x1920x64x66xf32>
+// CHECK:         %[[FILL:.+]] = linalg.fill ins(%{{.+}}) outs(%[[EMPTY_TENSOR]]
+// CHECK:         return %[[FILL]]
+
+// -----
+
 func.func @fill_pack() -> tensor<24x32x16x16xf32> {
   %dest = tensor.empty() : tensor<384x512xf32>
   %cst = arith.constant 0.000000e+00 : f32

>From 52e7b5324c6f9e040f2bcbe5f40f51ab093c45ff Mon Sep 17 00:00:00 2001
From: nithinsubbiah <nithinsubbiah at gmail.com>
Date: Sun, 20 Oct 2024 17:57:43 -0700
Subject: [PATCH 2/2] [mlir] Check if ExtractSliceOp is the only consumer while
 folding into FillOp

Signed-off-by: nithinsubbiah <nithinsubbiah at gmail.com>
---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 658445f19e91aa..2f2b6fed2add4e 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -806,7 +806,7 @@ struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
   }
 };
 
-/// Fold tensor.extract_slice(linalg.fill(<input>)) into <input>
+/// Fold extract_slice of a single-use fill(v, t) into fill(v, tensor.empty).
 struct FoldFillWithTensorExtractSlice
     : public OpRewritePattern<tensor::ExtractSliceOp> {
 public:
@@ -817,10 +817,10 @@ struct FoldFillWithTensorExtractSlice
     // See if tensor input of tensor.extract_slice op is the result of a
     // linalg.fill op.
     auto fillOp = extractSliceOp.getSource().getDefiningOp<linalg::FillOp>();
-    if (!fillOp)
+    if (!fillOp || !fillOp->hasOneUse())
       return failure();
 
-    Value fillInput = fillOp.getInputs()[0];
+    Value fillInput = fillOp.value();
 
     Location loc = extractSliceOp.getLoc();
     SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
@@ -965,11 +965,12 @@ struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> {
 
 void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
-  results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
-              FoldFillWithTensorExtractSlice, FoldFillWithPack, FoldFillWithPad,
+  results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithPack,
+              FoldFillWithPad, FoldFillWithTensorExtract,
+              FoldFillWithTensorExtractSlice,
               FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
               FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
-              FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
+              FoldFillWithTranspose, FoldInsertPadIntoFill>(context);
 }
 
 //===----------------------------------------------------------------------===//



More information about the Mlir-commits mailing list