[Mlir-commits] [mlir] [mlir][Linalg] Add a pattern to fold concats of fill. (PR #98995)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jul 30 12:27:47 PDT 2024


================
@@ -879,15 +879,66 @@ struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
   }
 };
 
+// Fold a concat with all elements being fills of the same value
+// into a fill of the concat result shape.
+struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
+                                PatternRewriter &rewriter) const override {
+    auto concatOperands = concatOp.getInputs();
+    if (concatOperands.empty()) {
+      return failure();
+    }
+
+    auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
+    if (!firstFillOp) {
+      return failure();
+    }
+    // Prefetch the fill value.
+    OpFoldResult firstFillVal =
+        getAsOpFoldResult(firstFillOp.getDpsInputOperand(0)->get());
+    // Collect all the outs values for the fill operations.
+    SmallVector<Value> allOuts;
+    allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());
+
+    auto isDefinedByCompatibleFillOp = [&](Value v) -> bool {
+      auto fillOp = v.getDefiningOp<linalg::FillOp>();
+      if (!fillOp) {
+        return false;
+      }
+
+      OpFoldResult fillVal =
+          getAsOpFoldResult(fillOp.getDpsInputOperand(0)->get());
+      if (fillVal != firstFillVal)
+        return false;
----------------
Max191 wrote:

Oh wait, I just misinterpreted what was happening here. Never mind.

https://github.com/llvm/llvm-project/pull/98995


More information about the Mlir-commits mailing list