[Mlir-commits] [mlir] [mlir][Linalg]: Add rewrite pattern to fuse fill with reduce operations (PR #125401)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Feb 2 04:51:04 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Aviad Cohen (AviadCo)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/125401.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+28)
- (modified) mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt (+1)
- (modified) mlir/test/lib/Dialect/Linalg/CMakeLists.txt (+1)
- (modified) mlir/tools/mlir-opt/mlir-opt.cpp (+2)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 1dc42f71e10eff..4b325aaeab87ca 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1893,6 +1893,34 @@ void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns);
/// convert to a `linalg.dot`.
void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns);
+/// Add patterns that fold a linalg.fill into a consuming linalg.reduce by
+/// creating a single fused linalg.generic operation. Instead of filling the
+/// init tensor up front, the fused op uses the fill value in place of the
+/// accumulator on the first iteration of the reduction dimension. Currently
+/// only a single reduction dimension is supported. Given the pattern:
+/// supported. Given the pattern:
+/// %empty = tensor.empty() : tensor<i8>
+/// %filled = linalg.fill ins(%c0 : i8) outs(%empty : tensor<i8>) ->
+/// tensor<i8> %reduced = linalg.reduce ins(%0 : tensor<147456xi8>)
+/// outs(%filled : tensor<i8>) dimensions = [0]
+/// (%in: i8, %init: i8) {
+/// %3 = arith.addi %in, %init : i8
+/// linalg.yield %3 : i8
+/// }
+/// The pattern is rewritten into (indexing maps and iterator types elided):
+///   %empty = tensor.empty() : tensor<i8>
+///   %reduced = linalg.generic ins(%0 : tensor<147456xi8>)
+///       outs(%empty : tensor<i8>) {
+///   ^bb0(%in: i8, %init: i8):
+///     %zero = arith.constant 0 : index
+///     %index = linalg.index 0 : index
+///     %cmp = arith.cmpi eq, %zero, %index : index
+///     %acc = arith.select %cmp, %c0, %init : i8
+///     %res = arith.addi %in, %acc : i8
+///     linalg.yield %res : i8
+///   }
+void populateFuseFillOpWithReduceOpPatterns(RewritePatternSet &patterns);
+
} // namespace linalg
} // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 3594b084138124..cace3dcb6cbfca 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -15,6 +15,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
EraseUnusedOperandsAndResults.cpp
FoldAddIntoDest.cpp
FusePadOpWithLinalgProducer.cpp
+ FuseFillOpWithReduceOp.cpp
Fusion.cpp
Generalization.cpp
Hoisting.cpp
diff --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
index eb6f581252181a..2c2cef60428743 100644
--- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_library(MLIRLinalgTestPasses
TestLinalgElementwiseFusion.cpp
TestLinalgFusionTransforms.cpp
TestLinalgRankReduceContractionOps.cpp
+ TestLinalgFuseFillOpWithReduceOp.cpp
TestLinalgTransforms.cpp
TestPadFusion.cpp
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 74007d01347ae8..7e92095ff2fae7 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -111,6 +111,7 @@ void registerTestLinalgDropUnitDims();
void registerTestLinalgElementwiseFusion();
void registerTestLinalgGreedyFusion();
void registerTestLinalgRankReduceContractionOps();
+void registerTestLinalgFuseFillOpWithReduceOp();
void registerTestLinalgTransforms();
void registerTestLivenessAnalysisPass();
void registerTestLivenessPass();
@@ -251,6 +252,7 @@ void registerTestPasses() {
mlir::test::registerTestLinalgElementwiseFusion();
mlir::test::registerTestLinalgGreedyFusion();
mlir::test::registerTestLinalgRankReduceContractionOps();
+ mlir::test::registerTestLinalgFuseFillOpWithReduceOp();
mlir::test::registerTestLinalgTransforms();
mlir::test::registerTestLivenessAnalysisPass();
mlir::test::registerTestLivenessPass();
``````````
</details>
https://github.com/llvm/llvm-project/pull/125401
More information about the Mlir-commits
mailing list