[Mlir-commits] [mlir] [mlir][Linalg] Add a pattern to fold concats of fill. (PR #98995)

Renato Golin llvmlistbot at llvm.org
Tue Jul 16 02:30:29 PDT 2024


================
@@ -879,15 +879,66 @@ struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
   }
 };
 
+// Fold a concat with all elements being fills of the same value
+// into a fill of the concat result shape.
+struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
+                                PatternRewriter &rewriter) const override {
+    auto concatOperands = concatOp.getInputs();
+    if (concatOperands.empty()) {
----------------
rengolin wrote:

You really want a concat with 2 or more inputs, so I'd change this to:
```
  if (concatOperands.size() < 2) {
```
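To make the suggested guard concrete, here is a minimal sketch of the match logic under simplifying assumptions. It is not the MLIR API. `ConcatInput` and `foldableFillValue` are hypothetical names. Each concat input is modelled as an optional `int64_t` that holds the fill scalar when the input comes from a fill. Real code would compare the fills' SSA values instead. The size check is the reviewer's `< 2` condition. My assumed rationale is that zero inputs is degenerate and a single-input concat is trivial, so neither should produce a new fill here.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for one concat input (NOT the MLIR API): if the input
// is produced by a fill, `fillValue` holds the scalar it was filled with;
// otherwise it is empty.
struct ConcatInput {
  std::optional<int64_t> fillValue;
};

// Returns the common fill value if the concat could be folded into a single
// fill of the result shape, or std::nullopt if the pattern should not match.
std::optional<int64_t> foldableFillValue(const std::vector<ConcatInput> &inputs) {
  // Suggested guard: require at least two inputs.
  if (inputs.size() < 2)
    return std::nullopt;
  // The first input must be a fill; its value is the reference value.
  if (!inputs.front().fillValue)
    return std::nullopt;
  int64_t value = *inputs.front().fillValue;
  for (const ConcatInput &input : inputs) {
    // Every input must be a fill, and all fills must use the same value.
    if (!input.fillValue || *input.fillValue != value)
      return std::nullopt;
  }
  return value;
}
```

With this check, an empty input list and a single fill are both rejected. Two fills of the same value fold, while mismatched values or a non-fill input block the fold.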

https://github.com/llvm/llvm-project/pull/98995

