[Mlir-commits] [mlir] [mlir][Linalg] Add a pattern to fold concats of fill. (PR #98995)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jul 30 12:41:10 PDT 2024
https://github.com/MaheshRavishankar updated https://github.com/llvm/llvm-project/pull/98995
>From 05042ed00a274fc47169a8c9885b0c16ea596b9e Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Mon, 15 Jul 2024 23:27:50 -0700
Subject: [PATCH] [mlir][Linalg] Add a pattern to fold concats of fill.
If all operands of a concat are produced by fills of the same value,
the concat can be replaced by a single fill whose `outs` operand is the
concatenation of the `outs` operands of the original fills.
Signed-off-by: MaheshRavishankar <mahesh.ravishankar at gmail.com>
---
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 61 ++++++++++++++++++++--
mlir/test/Dialect/Linalg/canonicalize.mlir | 27 ++++++++++
2 files changed, 83 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index d1db90bbe2d20..99b625d99fec2 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -879,15 +879,66 @@ struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
}
};
+/// Fold a tensor.concat whose operands are all linalg.fill ops of the same
+/// value into a single linalg.fill on the concat of the fills' outs operands.
+struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
+                                PatternRewriter &rewriter) const override {
+    auto concatOperands = concatOp.getInputs();
+    if (concatOperands.empty())
+      return failure();
+
+    // The first operand's fill supplies the reference value for the others.
+    auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
+    if (!firstFillOp)
+      return failure();
+    // Fill values are compared in their OpFoldResult form.
+    OpFoldResult firstFillVal =
+        getAsOpFoldResult(firstFillOp.getDpsInputOperand(0)->get());
+    // Collect the outs values of all the fill operations, in operand order.
+    SmallVector<Value> allOuts;
+    allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());
+
+    // Returns true, and records the fill's outs value, if `v` is produced by
+    // a fill whose value matches the reference fill value.
+    auto isDefinedByCompatibleFillOp = [&](Value v) -> bool {
+      auto fillOp = v.getDefiningOp<linalg::FillOp>();
+      if (!fillOp)
+        return false;
+      OpFoldResult fillVal =
+          getAsOpFoldResult(fillOp.getDpsInputOperand(0)->get());
+      if (fillVal != firstFillVal)
+        return false;
+      allOuts.push_back(fillOp.getDpsInitOperand(0)->get());
+      return true;
+    };
+    if (!llvm::all_of(concatOperands.drop_front(),
+                      isDefinedByCompatibleFillOp)) {
+      return rewriter.notifyMatchFailure(
+          concatOp, "not all operands are defined by a compatible fill op");
+    }
+
+    // Reuse the original result type: the type inferred from the operands
+    // may differ from it, and the replacement must keep the replaced type.
+    Value outsConcat = rewriter.create<tensor::ConcatOp>(
+        concatOp.getLoc(), concatOp.getType(), concatOp.getDim(), allOuts);
+    rewriter.replaceOpWithNewOp<linalg::FillOp>(
+        concatOp, firstFillOp.getDpsInputOperand(0)->get(), outsConcat);
+    return success();
+  }
+};
+
} // namespace
void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results
- .add<FoldFillWithCopy, FoldFillWithTensorExtract, FoldFillWithPack,
- FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
- FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
- FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
+ results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
+ FoldFillWithPack, FoldFillWithPad,
+ FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
+ FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
+ FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index d34bc8c1c54f6..b1212b8863a5d 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1169,3 +1169,30 @@ func.func @broadcast_transpose_fold_2dim(%input: tensor<2xf32>,
permutation = [1, 0]
func.return %transpose : tensor<4x2xf32>
}
+
+// -----
+// Concat of two fills whose values are distinct SSA constants with the same value (0.0).
+func.func @concats_of_fill(
+ %arg0 : index, %arg1 : index, %arg2 : index, %arg3 : index)
+ -> tensor<5x?x?xf32>
+{
+ %cst0 = arith.constant 0.0 : f32
+ %cst1 = arith.constant 0.0 : f32
+ %0 = tensor.empty(%arg0, %arg1) : tensor<5x?x?xf32>
+ %1 = linalg.fill ins(%cst0 : f32) outs(%0 : tensor<5x?x?xf32>) -> tensor<5x?x?xf32>
+ %2 = tensor.empty(%arg2, %arg3) : tensor<5x?x?xf32>
+ %3 = linalg.fill ins(%cst1 : f32) outs(%2 : tensor<5x?x?xf32>) -> tensor<5x?x?xf32>
+ %4 = tensor.concat dim(1) %1, %3 : (tensor<5x?x?xf32>, tensor<5x?x?xf32>) -> tensor<5x?x?xf32>
+ return %4 : tensor<5x?x?xf32>
+}
+// CHECK: func @concats_of_fill(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index)
+// CHECK-DAG: %[[CST:.+]] = arith.constant 0.0
+// CHECK-DAG: %[[EMPTY0:.+]] = tensor.empty(%[[ARG0]], %[[ARG1]])
+// CHECK-DAG: %[[EMPTY1:.+]] = tensor.empty(%[[ARG2]], %[[ARG3]])
+// CHECK: %[[CONCAT:.+]] = tensor.concat dim(1) %[[EMPTY0]], %[[EMPTY1]]
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[CONCAT]] :
+// CHECK: return %[[FILL]]
More information about the Mlir-commits
mailing list