[Mlir-commits] [mlir] [mlir][Linalg] Add a pattern to fold concats of fill. (PR #98995)
llvmlistbot at llvm.org
Wed Jul 17 10:18:33 PDT 2024
================
@@ -879,15 +879,66 @@ struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
}
};
+// Fold a concat with all elements being fills of the same value
+// into a fill of the concat result shape.
+struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
+ PatternRewriter &rewriter) const override {
+ auto concatOperands = concatOp.getInputs();
+ if (concatOperands.empty()) {
+ return failure();
+ }
+
+ auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
+ if (!firstFillOp) {
+ return failure();
+ }
+ // Prefetch the fill value.
+ OpFoldResult firstFillVal =
+ getAsOpFoldResult(firstFillOp.getDpsInputOperand(0)->get());
+ // Collect all the outs values for the fill operations.
+ SmallVector<Value> allOuts;
+ allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());
+
+ auto isDefinedByCompatibleFillOp = [&](Value v) -> bool {
----------------
MaheshRavishankar wrote:
Well, the first operand is checked to see whether the concat is even a candidate: if the first operand is not a fill, we exit early. The remaining operands are then compared against the first (I find it strange to compare an op with itself...).
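To illustrate the control flow described above, here is a minimal standalone sketch. It deliberately does not use MLIR: `ToyValue` and `commonFillConstant` are hypothetical names made up for this illustration, and an `int` stands in for the fill value. It only shows the shape of the check (early exit on the first operand, then compare the rest against it), not the actual pattern in the PR.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

// Hypothetical stand-in for an SSA value; not an MLIR type.
struct ToyValue {
  bool definedByFill; // true if the value is produced by a fill-like op
  int fillConstant;   // the scalar being filled (meaningful only if definedByFill)
};

// Returns the shared fill constant if every operand is a fill of the same
// value, and std::nullopt otherwise.
std::optional<int> commonFillConstant(const std::vector<ToyValue> &operands) {
  if (operands.empty())
    return std::nullopt;

  // The first operand decides whether this is a candidate at all.
  const ToyValue &first = operands.front();
  if (!first.definedByFill)
    return std::nullopt;

  // Only the remaining operands are compared against the first, so the
  // first operand is never compared with itself.
  for (std::size_t i = 1; i < operands.size(); ++i) {
    const ToyValue &v = operands[i];
    if (!v.definedByFill || v.fillConstant != first.fillConstant)
      return std::nullopt;
  }
  return first.fillConstant;
}
```

With this shape, a single-operand concat of a fill is accepted without any self-comparison, and a non-fill first operand rejects the match before any other operand is inspected.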
https://github.com/llvm/llvm-project/pull/98995