[Mlir-commits] [mlir] d84450b - [mlir][linalg] Canonicalize tensor.extract(linalg.fill)

Lei Zhang llvmlistbot at llvm.org
Fri Aug 4 09:34:32 PDT 2023


Author: Surya Jasper
Date: 2023-08-04T09:34:21-07:00
New Revision: d84450bc5123e662a09ea2e431765fcf457b6992

URL: https://github.com/llvm/llvm-project/commit/d84450bc5123e662a09ea2e431765fcf457b6992
DIFF: https://github.com/llvm/llvm-project/commit/d84450bc5123e662a09ea2e431765fcf457b6992.diff

LOG: [mlir][linalg] Canonicalize tensor.extract(linalg.fill)

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D156008

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index d6778ed72c7d0e..6923f55591fc8a 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -730,12 +730,34 @@ struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
   }
 };
 
+/// Fold tensor.extract(linalg.fill(<input>)) into <input>
+struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
+public:
+  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    // See if tensor input of tensor.extract op is the result of a linalg.fill op.
+    auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
+    if (!fillOp)
+      return failure();
+
+    // Get scalar input operand of linalg.fill op.
+    Value extractedScalar = fillOp.getInputs()[0];
+
+    // Replace tensor.extract op with scalar value used to fill the tensor.
+    rewriter.replaceOp(extractOp, extractedScalar);
+    return success();
+  }
+};
+
 } // namespace
 
 void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
   results
-      .add<FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
+      .add<FoldFillWithTensorExtract,
+           FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
            FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
            FoldInsertPadIntoFill>(context);
 }

diff  --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 5f45ab40875af2..d3bb48d5a442d4 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -335,6 +335,22 @@ func.func @fold_fill_reshape_dynamic(%arg0 : tensor<?x?x?x?x?xf32>) -> tensor<?x
   return %1 : tensor<?x?xf32>
 }
 
+// -----
+//       CHECK: func @fold_fill_extract
+//  CHECK-SAME:   %[[ARG0:.+]]: i1
+func.func @fold_fill_extract(%arg0 : i1) -> i1 {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+
+  %empty_dynamic = tensor.empty(%c1) : tensor<1x2x3x?xi1>
+  %filled = linalg.fill ins(%arg0 : i1) outs(%empty_dynamic : tensor<1x2x3x?xi1>) -> tensor<1x2x3x?xi1>
+
+  %extracted = tensor.extract %filled[%c0, %c0, %c0, %c0] : tensor<1x2x3x?xi1>
+
+  //  CHECK:   return %[[ARG0]]
+  return %extracted : i1
+}
+
 // -----
 
 // CHECK: func @fold_self_copy


        


More information about the Mlir-commits mailing list