[Mlir-commits] [mlir] [mlir][Linalg] Add a pattern to fold concats of fill. (PR #98995)
llvmlistbot at llvm.org
Wed Jul 17 10:18:33 PDT 2024
================
@@ -879,15 +879,66 @@ struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
}
};
+// Fold a concat with all elements being fills of the same value
+// into a fill of the concat result shape.
+struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
+ PatternRewriter &rewriter) const override {
+ auto concatOperands = concatOp.getInputs();
+ if (concatOperands.empty()) {
----------------
MaheshRavishankar wrote:
I could. So this is
```
%0 = linalg.fill ins(%cst : f32) outs(%init : tensor<...xf32>) -> tensor<...xf32>
%1 = tensor.concat dim(...) %0 : ...
```
I would expect to canonicalize to just
```
%1 = linalg.fill ...
```
but irrespective of that, this pattern would rewrite it to
```
%0 = tensor.concat dim(...) %init : ...
%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<...xf32>) -> tensor<...xf32>
```
which would still be a good thing to do, and would give the best result once that canonicalization pattern exists. So I am saying it might be worth handling even the corner case of a 1-operand concat, even though that shouldn't happen.
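To make the suggested rewrite concrete, here is a toy C++ model of the fold. It is not MLIR code: `FillOp`, `FoldedFill` and `foldConcatOfFills` are hypothetical names, and a fill is modeled as just a scalar plus a static shape. It shows `concat(fill(v, init_0), ..., fill(v, init_n))` along `dim` becoming `fill(v, concat(init_0, ..., init_n))`, and it deliberately accepts the 1-operand case discussed above.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical toy stand-in for a `linalg.fill` result: a scalar broadcast
// into an init tensor with a static shape.
struct FillOp {
  float value;                 // scalar from ins(...)
  std::vector<int64_t> shape;  // static shape of the outs(...) init tensor
};

// Hypothetical result of the rewrite: one fill over the concat of the inits.
struct FoldedFill {
  std::vector<int64_t> concatShape;  // shape of tensor.concat over the inits
  float value;                       // shared fill value
};

// Folds concat-of-fills along `dim` into a single fill. Returns nullopt when
// the operands do not all fill the same value (or shapes are incompatible).
// A single operand is accepted, mirroring the 1-operand corner case.
std::optional<FoldedFill> foldConcatOfFills(const std::vector<FillOp> &operands,
                                            std::size_t dim) {
  if (operands.empty())
    return std::nullopt;
  const FillOp &first = operands.front();
  assert(dim < first.shape.size() && "concat dim out of range");
  std::vector<int64_t> result = first.shape;
  for (std::size_t i = 1; i < operands.size(); ++i) {
    const FillOp &op = operands[i];
    // Every operand must fill the same value...
    if (op.value != first.value)
      return std::nullopt;
    // ...and match the running shape on every dimension except `dim`.
    if (op.shape.size() != result.size())
      return std::nullopt;
    for (std::size_t d = 0; d < result.size(); ++d)
      if (d != dim && op.shape[d] != result[d])
        return std::nullopt;
    // The concatenated dimension grows by this operand's extent.
    result[dim] += op.shape[dim];
  }
  return FoldedFill{result, first.value};
}
```

In the actual pattern the "same value" check would presumably be done on the fills' scalar operands in the IR rather than with a float `==`, and the `tensor.concat` verifier already guarantees compatible shapes, so the shape checks here only keep the toy self-contained.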
https://github.com/llvm/llvm-project/pull/98995