[Mlir-commits] [mlir] d84450b - [mlir][linalg] Canonicalize tensor.extract(linalg.fill)
Lei Zhang
llvmlistbot at llvm.org
Fri Aug 4 09:34:32 PDT 2023
Author: Surya Jasper
Date: 2023-08-04T09:34:21-07:00
New Revision: d84450bc5123e662a09ea2e431765fcf457b6992
URL: https://github.com/llvm/llvm-project/commit/d84450bc5123e662a09ea2e431765fcf457b6992
DIFF: https://github.com/llvm/llvm-project/commit/d84450bc5123e662a09ea2e431765fcf457b6992.diff
LOG: [mlir][linalg] Canonicalize tensor.extract(linalg.fill)
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D156008
Added:
Modified:
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index d6778ed72c7d0e..6923f55591fc8a 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -730,12 +730,34 @@ struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
}
};
+/// Fold tensor.extract(linalg.fill(<input>)) into <input>.
+///
+/// Every element of a tensor produced by linalg.fill holds the fill value, so
+/// extracting any element (the indices are never inspected) yields that value.
+/// linalg.fill is documented to numerically cast its fill value to the output
+/// element type, so the fold is only applied when no cast is involved, i.e.
+/// when the fill value already has the type of the extracted element.
+struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
+public:
+  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    // See if tensor input of tensor.extract op is the result of a linalg.fill
+    // op.
+    auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
+    if (!fillOp)
+      return failure();
+
+    // Get scalar input operand of linalg.fill op.
+    Value extractedScalar = fillOp.getInputs()[0];
+
+    // Replacing the extract with an uncast value of a different type would
+    // produce invalid IR (e.g. an f32 fill value feeding an f64 use), so only
+    // fold when the types already agree.
+    if (extractedScalar.getType() != extractOp.getResult().getType())
+      return rewriter.notifyMatchFailure(
+          extractOp, "fill value type differs from extracted element type");
+
+    // Replace tensor.extract op with scalar value used to fill the tensor.
+    rewriter.replaceOp(extractOp, extractedScalar);
+    return success();
+  }
+};
+
} // namespace
void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results
- .add<FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
+ .add<FoldFillWithTensorExtract, // matches tensor.extract ops (not linalg.fill) whose source is a linalg.fill result
+ FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
FoldInsertPadIntoFill>(context);
}
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 5f45ab40875af2..d3bb48d5a442d4 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -335,6 +335,22 @@ func.func @fold_fill_reshape_dynamic(%arg0 : tensor<?x?x?x?x?xf32>) -> tensor<?x
return %1 : tensor<?x?xf32>
}
+// -----
+
+// tensor.extract of a linalg.fill result folds to the fill value, whatever the
+// indices and even when the filled tensor has a dynamic dimension. The fill and
+// extract ops must be gone after canonicalization.
+// CHECK: func @fold_fill_extract
+// CHECK-SAME: %[[ARG0:.+]]: i1
+// CHECK-NOT: linalg.fill
+// CHECK-NOT: tensor.extract
+// CHECK: return %[[ARG0]]
+func.func @fold_fill_extract(%arg0 : i1) -> i1 {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+
+  %empty_dynamic = tensor.empty(%c1) : tensor<1x2x3x?xi1>
+  %filled = linalg.fill ins(%arg0 : i1) outs(%empty_dynamic : tensor<1x2x3x?xi1>) -> tensor<1x2x3x?xi1>
+
+  %extracted = tensor.extract %filled[%c0, %c0, %c0, %c0] : tensor<1x2x3x?xi1>
+
+  return %extracted : i1
+}
+
// -----
// CHECK: func @fold_self_copy
More information about the Mlir-commits
mailing list