[Mlir-commits] [mlir] [mlir][Linalg] Add a pattern to fold concats of fill. (PR #98995)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jul 15 23:30:41 PDT 2024


https://github.com/MaheshRavishankar created https://github.com/llvm/llvm-project/pull/98995

If every operand of a `tensor.concat` is produced by a `linalg.fill` and all of the fill values match, the concat can be rewritten as a single `linalg.fill` whose `outs` operand is the concatenation of the original fills' `outs` operands.

>From 9b7944020a5ab5b8b7810e01ce6e8a95b093791e Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Mon, 15 Jul 2024 23:27:50 -0700
Subject: [PATCH] [mlir][Linalg] Add a pattern to fold concats of fill.

If every operand of a `tensor.concat` is produced by a `linalg.fill`
and all of the fill values match, rewrite the concat as a single
`linalg.fill` on the concatenation of the fills' `outs` operands.
---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp   | 61 ++++++++++++++++++++--
 mlir/test/Dialect/Linalg/canonicalize.mlir | 27 ++++++++++
 2 files changed, 83 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index cefaad9b22653..c26eabd811c65 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -879,15 +879,66 @@ struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
   }
 };
 
+// Rewrites a tensor.concat whose inputs are all linalg.fill ops with the same
+// fill value into one linalg.fill whose `outs` is a concat of the fills' inits.
+struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
+                                PatternRewriter &rewriter) const override {
+    auto concatOperands = concatOp.getInputs();
+    if (concatOperands.empty()) {
+      return failure();
+    }
+
+    auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
+    if (!firstFillOp) {
+      return failure();
+    }
+    // Fill value of the first operand; every other fill must use an equal one.
+    OpFoldResult firstFillVal =
+        getAsOpFoldResult(firstFillOp.getDpsInputOperand(0)->get());
+    // Init (`outs`) operands of the fills, kept in concat operand order.
+    SmallVector<Value> allOuts;
+    allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());
+
+    auto isDefinedByCompatibleFillOp = [&](Value v) -> bool {
+      auto fillOp = v.getDefiningOp<linalg::FillOp>();
+      if (!fillOp) {
+        return false;
+      }
+
+      OpFoldResult fillVal =
+          getAsOpFoldResult(fillOp.getDpsInputOperand(0)->get());
+      if (fillVal != firstFillVal)
+        return false;
+
+      allOuts.push_back(fillOp.getDpsInitOperand(0)->get());
+      return true;
+    };
+    if (!llvm::all_of(concatOperands.drop_front(),
+                      isDefinedByCompatibleFillOp)) {
+      return rewriter.notifyMatchFailure(
+          concatOp, "not all operands are defined by a compatible fill op");
+    }
+
+    Value outsConcat = rewriter.create<tensor::ConcatOp>(
+        concatOp.getLoc(), concatOp.getDim(), allOuts);
+    rewriter.replaceOpWithNewOp<linalg::FillOp>(
+        concatOp, firstFillOp.getDpsInputOperand(0)->get(), outsConcat);
+    return success();
+  }
+};
+
 } // namespace
 
 void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
-  results
-      .add<FoldFillWithCopy, FoldFillWithTensorExtract, FoldFillWithPack,
-           FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
-           FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
-           FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
+  results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
+              FoldFillWithPack, FoldFillWithPad,
+              FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
+              FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
+              FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 928030a81dc02..02b46c405e2bd 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1096,3 +1096,30 @@ func.func @transpose_transpose_fold(%input: tensor<5x4x3xf32>,
   func.return %transpose2 : tensor<3x4x5xf32>
 }
 
+
+// -----
+
+func.func @concats_of_fill(
+    %arg0 : index, %arg1 : index, %arg2 : index, %arg3 : index)
+    -> tensor<5x?x?xf32>
+{
+  %cst0 = arith.constant 0.0 : f32
+  %cst1 = arith.constant 0.0 : f32
+  %0 = tensor.empty(%arg0, %arg1) : tensor<5x?x?xf32>
+  %1 = linalg.fill ins(%cst0 : f32) outs(%0 : tensor<5x?x?xf32>) -> tensor<5x?x?xf32>
+  %2 = tensor.empty(%arg2, %arg3) : tensor<5x?x?xf32>
+  %3 = linalg.fill ins(%cst1 : f32) outs(%2 : tensor<5x?x?xf32>) -> tensor<5x?x?xf32>
+  %4 = tensor.concat dim(1) %1, %3 : (tensor<5x?x?xf32>, tensor<5x?x?xf32>) -> tensor<5x?x?xf32>
+  return %4 : tensor<5x?x?xf32>
+}
+//       CHECK: func @concats_of_fill(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index)
+//   CHECK-DAG:   %[[CST:.+]] = arith.constant 0.0
+//   CHECK-DAG:   %[[EMPTY0:.+]] = tensor.empty(%[[ARG0]], %[[ARG1]])
+//   CHECK-DAG:   %[[EMPTY1:.+]] = tensor.empty(%[[ARG2]], %[[ARG3]])
+//       CHECK:   %[[CONCAT:.+]] = tensor.concat dim(1) %[[EMPTY0]], %[[EMPTY1]]
+//       CHECK:   %[[FILL:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[CONCAT]] :
+//       CHECK:   return %[[FILL]]



More information about the Mlir-commits mailing list