[Mlir-commits] [mlir] [mlir][Linalg]: Add rewrite pattern to fuse fill with reduce operations (PR #125401)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Feb 2 04:51:04 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Aviad Cohen (AviadCo)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/125401.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+28) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt (+1) 
- (modified) mlir/test/lib/Dialect/Linalg/CMakeLists.txt (+1) 
- (modified) mlir/tools/mlir-opt/mlir-opt.cpp (+2) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 1dc42f71e10eff..4b325aaeab87ca 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1893,6 +1893,34 @@ void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns);
 /// convert to a `linalg.dot`.
 void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns);
 
+/// Add patterns that fold a linalg.fill producing the init operand of a
+/// linalg.reduce into a single fused linalg.generic operation. In the fused
+/// op the fill value is used as the accumulator only on the first iteration
+/// (index 0) of the reduction dimension; later iterations use the running
+/// accumulator. Currently only one reduction dimension is
+/// supported. Given the pattern:
+///   %empty = tensor.empty() : tensor<i8>
+///   %fill = linalg.fill ins(%c0 : i8) outs(%empty : tensor<i8>) -> tensor<i8>
+///   %reduced = linalg.reduce ins(%0 : tensor<147456xi8>)
+///     outs(%fill : tensor<i8>) dimensions = [0]
+///     (%in: i8, %init: i8) {
+///       %3 = arith.addi %in, %init : i8
+///       linalg.yield %3 : i8
+///   }
+/// The pattern is rewritten into:
+///   %empty = tensor.empty() : tensor<i8>
+///   %reduced = linalg.generic {indexing_maps = [...], iterator_types = [...]}
+///     ins(%0 : tensor<147456xi8>) outs(%empty : tensor<i8>) {
+///     ^bb0(%in: i8, %init: i8):
+///       %cst = arith.constant 0 : index
+///       %index = linalg.index 0 : index
+///       %cmp = arith.cmpi eq, %index, %cst : index
+///       %sum = arith.select %cmp, %c0, %init : i8
+///       %res = arith.addi %in, %sum : i8
+///       linalg.yield %res : i8
+///   }
+void populateFuseFillOpWithReduceOpPatterns(RewritePatternSet &patterns);
+
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 3594b084138124..cace3dcb6cbfca 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -15,6 +15,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   EraseUnusedOperandsAndResults.cpp
   FoldAddIntoDest.cpp
   FusePadOpWithLinalgProducer.cpp
+  FuseFillOpWithReduceOp.cpp
   Fusion.cpp
   Generalization.cpp
   Hoisting.cpp
diff --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
index eb6f581252181a..2c2cef60428743 100644
--- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_library(MLIRLinalgTestPasses
   TestLinalgElementwiseFusion.cpp
   TestLinalgFusionTransforms.cpp
   TestLinalgRankReduceContractionOps.cpp
+  TestLinalgFuseFillOpWithReduceOp.cpp
   TestLinalgTransforms.cpp
   TestPadFusion.cpp
 
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 74007d01347ae8..7e92095ff2fae7 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -111,6 +111,7 @@ void registerTestLinalgDropUnitDims();
 void registerTestLinalgElementwiseFusion();
 void registerTestLinalgGreedyFusion();
 void registerTestLinalgRankReduceContractionOps();
+void registerTestLinalgFuseFillOpWithReduceOp();
 void registerTestLinalgTransforms();
 void registerTestLivenessAnalysisPass();
 void registerTestLivenessPass();
@@ -251,6 +252,7 @@ void registerTestPasses() {
   mlir::test::registerTestLinalgElementwiseFusion();
   mlir::test::registerTestLinalgGreedyFusion();
   mlir::test::registerTestLinalgRankReduceContractionOps();
+  mlir::test::registerTestLinalgFuseFillOpWithReduceOp();
   mlir::test::registerTestLinalgTransforms();
   mlir::test::registerTestLivenessAnalysisPass();
   mlir::test::registerTestLivenessPass();

``````````

</details>


https://github.com/llvm/llvm-project/pull/125401


More information about the Mlir-commits mailing list