[Mlir-commits] [mlir] bcfd32a - [mlir][linalg] Swap extract_slice(fill(x)) ops
Matthias Springer
llvmlistbot at llvm.org
Fri Jan 6 03:33:03 PST 2023
Author: Matthias Springer
Date: 2023-01-06T12:28:29+01:00
New Revision: bcfd32adc4b658dc45aa8c338d5dd03837e2a0e4
URL: https://github.com/llvm/llvm-project/commit/bcfd32adc4b658dc45aa8c338d5dd03837e2a0e4
DIFF: https://github.com/llvm/llvm-project/commit/bcfd32adc4b658dc45aa8c338d5dd03837e2a0e4.diff
LOG: [mlir][linalg] Swap extract_slice(fill(x)) ops
This pattern is similar to `FoldFillWithTensorReshape`, which performs the same swapping with reshapes.
Fill the smaller extracted tensor slice instead of all of `x` (the fill's destination tensor). This allows for additional simplifications in case `x` is the result of another extract_slice.
Differential Revision: https://reviews.llvm.org/D141117
Added:
Modified:
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index b74537e95de2..48e7cfd64319 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -463,6 +463,36 @@ struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
}
};
+/// Swap extract_slice(fill) to fill(extract_slice).
+///
+/// Only swap the two ops if the extract_slice is the only user of the fill.
+/// Filling the smaller slice can enable folds with a producer extract_slice.
+struct SwapExtractSliceOfFill : OpRewritePattern<tensor::ExtractSliceOp> {
+  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractSliceOp,
+                                PatternRewriter &rewriter) const override {
+    auto oldFill = extractSliceOp.getSource().getDefiningOp<FillOp>();
+    // Bail out if the fill has other users; it must stay alive for them.
+    if (!oldFill || !extractSliceOp.getSource().hasOneUse())
+      return failure();
+    // Extract the same slice from the old fill's destination. Creating new ops
+    // (instead of mutating `extractSliceOp` in place and rewiring its uses by
+    // hand) keeps every change visible to the rewriter. The result type is
+    // passed explicitly so that rank-reducing slices stay rank-reducing.
+    auto newExtract = rewriter.create<tensor::ExtractSliceOp>(
+        extractSliceOp.getLoc(),
+        extractSliceOp.getType().cast<RankedTensorType>(), oldFill.output(),
+        extractSliceOp.getMixedOffsets(), extractSliceOp.getMixedSizes(),
+        extractSliceOp.getMixedStrides());
+    // Fill the extracted slice and use it in place of the old extract_slice.
+    auto newFill = rewriter.create<FillOp>(oldFill.getLoc(),
+                                           ValueRange{oldFill.value()},
+                                           ValueRange{newExtract.getResult()});
+    rewriter.replaceOp(extractSliceOp, newFill->getResults());
+    // The old fill lost its only user, so it can be erased.
+    rewriter.eraseOp(oldFill);
+    return success();
+  }
+};
+
/// Fold tensor.pad(linalg.fill) into linalg.fill if the padding value and the
/// filling value are the same.
struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
@@ -607,7 +637,7 @@ void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
results
.add<FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
- FoldInsertPadIntoFill>(context);
+ FoldInsertPadIntoFill, SwapExtractSliceOfFill>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 9e4d886c5b0a..ee94d073aa7d 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -312,6 +312,20 @@ func.func @fold_fill_reshape() -> tensor<6x4xf32> {
// -----
+// CHECK-LABEL: func @fold_fill_extract_slice(
+//  CHECK-SAME:     %[[t:.*]]: tensor<1x1xf32>
+func.func @fold_fill_extract_slice(%t: tensor<1x1xf32>) -> (tensor<f32>) {
+  %cst = arith.constant 0.000000e+00 : f32
+  // The fill of the full tensor must be gone; only the slice is filled.
+  //  CHECK-NOT: linalg.fill
+  //      CHECK: %[[e:.*]] = tensor.extract_slice %[[t]][0, 0] [1, 1] [1, 1] : tensor<1x1xf32> to tensor<f32>
+  //      CHECK: %[[f:.*]] = linalg.fill {{.*}} outs(%[[e]] : tensor<f32>)
+  %0 = linalg.fill ins(%cst : f32) outs(%t : tensor<1x1xf32>) -> tensor<1x1xf32>
+  %1 = tensor.extract_slice %0[0, 0] [1, 1] [1, 1] : tensor<1x1xf32> to tensor<f32>
+  //  CHECK-NOT: linalg.fill
+  //      CHECK: return %[[f]]
+  return %1 : tensor<f32>
+}
+
+// -----
+
+// The swap must not fire when the fill has users other than the extract_slice.
+// CHECK-LABEL: func @no_swap_fill_extract_slice_multi_use(
+//  CHECK-SAME:     %[[t:.*]]: tensor<2x2xf32>
+//       CHECK:   %[[f:.*]] = linalg.fill {{.*}} outs(%[[t]] : tensor<2x2xf32>)
+//       CHECK:   %[[e:.*]] = tensor.extract_slice %[[f]][0, 0] [1, 1] [1, 1]
+//       CHECK:   return %[[f]], %[[e]]
+func.func @no_swap_fill_extract_slice_multi_use(%t: tensor<2x2xf32>)
+    -> (tensor<2x2xf32>, tensor<1x1xf32>) {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = linalg.fill ins(%cst : f32) outs(%t : tensor<2x2xf32>) -> tensor<2x2xf32>
+  %1 = tensor.extract_slice %0[0, 0] [1, 1] [1, 1] : tensor<2x2xf32> to tensor<1x1xf32>
+  return %0, %1 : tensor<2x2xf32>, tensor<1x1xf32>
+}
+
+// -----
+
// CHECK: func @fold_fill_reshape_dynamic
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?x?xf32>
func.func @fold_fill_reshape_dynamic(%arg0 : tensor<?x?x?x?x?xf32>) -> tensor<?x?xf32> {
More information about the Mlir-commits mailing list