[Mlir-commits] [mlir] [mlir][Linalg]: Add rewrite pattern to fuse fill with reduce operations (PR #125401)
Aviad Cohen
llvmlistbot at llvm.org
Sun Feb 2 04:51:35 PST 2025
https://github.com/AviadCo updated https://github.com/llvm/llvm-project/pull/125401
From 56afd0f41761fd4b059585beb341cb71fbd5b908 Mon Sep 17 00:00:00 2001
From: Aviad Cohen <aviad.cohen2 at mobileye.com>
Date: Sun, 2 Feb 2025 14:48:16 +0200
Subject: [PATCH] [mlir][Linalg]: Add rewrite pattern to fuse fill with reduce
operations
---
.../Dialect/Linalg/Transforms/Transforms.h | 28 +++++
.../Dialect/Linalg/Transforms/CMakeLists.txt | 1 +
.../Transforms/FuseFillOpWithReduceOp.cpp | 107 ++++++++++++++++++
.../Linalg/fuse_fill_op_with_reduce_op.mlir | 88 ++++++++++++++
mlir/test/lib/Dialect/Linalg/CMakeLists.txt | 1 +
.../TestLinalgFuseFillOpWithReduceOp.cpp | 63 +++++++++++
mlir/tools/mlir-opt/mlir-opt.cpp | 2 +
7 files changed, 290 insertions(+)
create mode 100644 mlir/lib/Dialect/Linalg/Transforms/FuseFillOpWithReduceOp.cpp
create mode 100644 mlir/test/Dialect/Linalg/fuse_fill_op_with_reduce_op.mlir
create mode 100644 mlir/test/lib/Dialect/Linalg/TestLinalgFuseFillOpWithReduceOp.cpp
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 1dc42f71e10eff..4b325aaeab87ca 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1893,6 +1893,34 @@ void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns);
/// convert to a `linalg.dot`.
void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns);
+/// Add patterns to fold a `linalg.fill` into a consuming `linalg.reduce` by
+/// creating a fused `linalg.generic` operation. Instead of pre-filling the
+/// init tensor, the fill value is injected into the reduction body and used
+/// only at the first index of the reduction dimension. Currently only a
+/// single reduction dimension is supported. Given the pattern:
+/// %empty = tensor.empty() : tensor<i8>
+/// %filled = linalg.fill ins(%c0_i8 : i8)
+/// outs(%empty : tensor<i8>) -> tensor<i8>
+/// %reduced = linalg.reduce ins(%0 : tensor<147456xi8>)
+/// outs(%filled : tensor<i8>) dimensions = [0]
+/// (%in: i8, %init: i8) {
+/// %3 = arith.addi %in, %init : i8
+/// linalg.yield %3 : i8
+/// }
+/// it is rewritten into:
+/// %empty = tensor.empty() : tensor<i8>
+/// %reduced = linalg.generic ins(%0 : tensor<147456xi8>)
+/// outs(%empty : tensor<i8>) {
+/// ^bb0(%in: i8, %init: i8):
+/// %c0 = arith.constant 0 : index
+/// %index = linalg.index 0 : index
+/// %cmp = arith.cmpi eq, %c0, %index : index
+/// %sel = arith.select %cmp, %c0_i8, %init : i8
+/// %res = arith.addi %in, %sel : i8
+/// linalg.yield %res : i8
+/// }
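+///
+/// A minimal usage sketch (`ctx` and `funcOp` are placeholders; the test pass
+/// `test-linalg-fuse-fill-op-with-reduce-op` applies the patterns this way):
+/// RewritePatternSet patterns(ctx);
+/// linalg::populateFuseFillOpWithReduceOpPatterns(patterns);
+/// if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns))))
+/// return signalPassFailure();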
+void populateFuseFillOpWithReduceOpPatterns(RewritePatternSet &patterns);
+
} // namespace linalg
} // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 3594b084138124..cace3dcb6cbfca 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -15,6 +15,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
EraseUnusedOperandsAndResults.cpp
FoldAddIntoDest.cpp
FusePadOpWithLinalgProducer.cpp
+ FuseFillOpWithReduceOp.cpp
Fusion.cpp
Generalization.cpp
Hoisting.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FuseFillOpWithReduceOp.cpp b/mlir/lib/Dialect/Linalg/Transforms/FuseFillOpWithReduceOp.cpp
new file mode 100644
index 00000000000000..6811bbbb63e228
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/FuseFillOpWithReduceOp.cpp
@@ -0,0 +1,107 @@
+//===- FuseFillOpWithReduceOp.cpp - Fuse linalg.fill into linalg.reduce ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pattern that folds a linalg.fill producer into its
+// linalg.reduce consumer by rewriting the pair into a single linalg.generic
+// whose body applies the fill value only at the first reduction index.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Fold a `linalg.fill` into a consuming `linalg.reduce` by creating a fused
+/// `linalg.generic` operation. Instead of pre-filling the init tensor, the
+/// fill value is injected into the reduction body and used only at the first
+/// index of the reduction dimension. Currently only a single reduction
+/// dimension is supported. Given the pattern:
+/// %empty = tensor.empty() : tensor<i8>
+/// %filled = linalg.fill ins(%c0_i8 : i8)
+/// outs(%empty : tensor<i8>) -> tensor<i8>
+/// %reduced = linalg.reduce ins(%0 : tensor<147456xi8>)
+/// outs(%filled : tensor<i8>) dimensions = [0]
+/// (%in: i8, %init: i8) {
+/// %3 = arith.addi %in, %init : i8
+/// linalg.yield %3 : i8
+/// }
+/// it is rewritten into:
+/// %empty = tensor.empty() : tensor<i8>
+/// %reduced = linalg.generic ins(%0 : tensor<147456xi8>)
+/// outs(%empty : tensor<i8>) {
+/// ^bb0(%in: i8, %init: i8):
+/// %c0 = arith.constant 0 : index
+/// %index = linalg.index 0 : index
+/// %cmp = arith.cmpi eq, %c0, %index : index
+/// %sel = arith.select %cmp, %c0_i8, %init : i8
+/// %res = arith.addi %in, %sel : i8
+/// linalg.yield %res : i8
+/// }
+struct FoldFillWithReduceOp : public OpRewritePattern<linalg::ReduceOp> {
+ using OpRewritePattern<linalg::ReduceOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(linalg::ReduceOp reduceOp,
+ PatternRewriter &rewriter) const override {
+ if (!reduceOp.hasPureTensorSemantics())
+ return rewriter.notifyMatchFailure(
+ reduceOp, "skip reduce op with non-pure tensor semantics");
+ if (reduceOp.getDimensions().size() != 1)
+ return rewriter.notifyMatchFailure(
+ reduceOp, "skip reduce op with more than one reduction dimension");
+ if (reduceOp.getNumDpsInputs() != 1 || reduceOp.getNumDpsInits() != 1)
+ return rewriter.notifyMatchFailure(
+ reduceOp, "skip reduce op with more than one input or init");
+ auto fillOp = reduceOp.getInits()[0].getDefiningOp<linalg::FillOp>();
+ if (!fillOp)
+ return rewriter.notifyMatchFailure(
+ reduceOp,
+ "skip reduce op with inits not directly based on fill operation");
+
+ int64_t dim = reduceOp.getDimensions()[0];
+ // Note: on success, `reduceOp` has been replaced by a linalg.generic and is
+ // no longer valid.
+ auto failureOrGenericOp = linalg::generalizeNamedOp(rewriter, reduceOp);
+ if (failed(failureOrGenericOp))
+ return rewriter.notifyMatchFailure(reduceOp,
+ "failed to generalize reduce op");
+
+ linalg::GenericOp genericReduceOp = *failureOrGenericOp;
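+ // Locate the init operand of the generalized op that is produced by the
+ // fill; it is rewired below to the fill's own destination tensor.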
+ int operandIdx = -1;
+ for (auto &use : genericReduceOp->getOpOperands()) {
+ if (use.get().getDefiningOp() == fillOp)
+ operandIdx = use.getOperandNumber();
+ }
+ assert(operandIdx != -1 && "fill op not found in reduce op uses");
+
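+ // Inject the fill value into the reduction body: at the first index of the
+ // reduction dimension the running value is the fill value; at every other
+ // index it is the block argument carrying the accumulated result.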
+ Location loc = genericReduceOp.getLoc();
+ auto blockArg = genericReduceOp.getMatchingBlockArgument(
+ &genericReduceOp->getOpOperand(operandIdx));
+ rewriter.setInsertionPointToStart(genericReduceOp.getBody());
+ auto constZeroIndexOp = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ auto linalgIndexOp = rewriter.create<linalg::IndexOp>(loc, dim);
+ auto cmpIOp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
+ constZeroIndexOp.getResult(),
+ linalgIndexOp.getResult());
+ auto selectOp =
+ rewriter.create<arith::SelectOp>(loc, cmpIOp, fillOp.value(), blockArg);
+ rewriter.replaceAllUsesExcept(blockArg, selectOp.getResult(), selectOp);
+ genericReduceOp->setOperand(operandIdx, fillOp.getDpsInitOperand(0)->get());
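+ // The fused generic no longer uses the fill result; if the fill has no
+ // other uses, the rewrite driver is expected to erase it as dead code (see
+ // the FileCheck expectations in the test file).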
+
+ return success();
+ }
+};
+
+} // namespace
+
+void mlir::linalg::populateFuseFillOpWithReduceOpPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<FoldFillWithReduceOp>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Linalg/fuse_fill_op_with_reduce_op.mlir b/mlir/test/Dialect/Linalg/fuse_fill_op_with_reduce_op.mlir
new file mode 100644
index 00000000000000..7721cfec72ce74
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/fuse_fill_op_with_reduce_op.mlir
@@ -0,0 +1,88 @@
+// RUN: mlir-opt -test-linalg-fuse-fill-op-with-reduce-op -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: func.func private @test_reduce_sum_kernel(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<147456xi8>) -> tensor<i8> {
+// CHECK: %[[VAL_1:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_2:.*]] = arith.constant 0 : i8
+// CHECK: %[[VAL_3:.*]] = tensor.empty() : tensor<i8>
+// CHECK: %[[VAL_4:.*]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["reduction"]} ins(%[[VAL_0]] : tensor<147456xi8>) outs(%[[VAL_3]] : tensor<i8>) {
+// CHECK: ^bb0(%[[VAL_5:.*]]: i8, %[[VAL_6:.*]]: i8):
+// CHECK: %[[VAL_7:.*]] = linalg.index 0 : index
+// CHECK: %[[VAL_8:.*]] = arith.cmpi eq, %[[VAL_7]], %[[VAL_1]] : index
+// CHECK: %[[VAL_9:.*]] = arith.select %[[VAL_8]], %[[VAL_2]], %[[VAL_6]] : i8
+// CHECK: %[[VAL_10:.*]] = arith.addi %[[VAL_5]], %[[VAL_9]] : i8
+// CHECK: linalg.yield %[[VAL_10]] : i8
+// CHECK: } -> tensor<i8>
+// CHECK: return %[[VAL_11:.*]] : tensor<i8>
+// CHECK: }
+
+func.func private @test_reduce_sum_kernel(%arg0: tensor<147456xi8>) -> (tensor<i8>) {
+ %1 = tensor.empty() : tensor<i8>
+ %c0_i8 = arith.constant 0 : i8
+ %2 = linalg.fill ins(%c0_i8 : i8) outs(%1 : tensor<i8>) -> tensor<i8>
+ %reduced = linalg.reduce ins(%arg0 : tensor<147456xi8>) outs(%2 : tensor<i8>) dimensions = [0]
+ (%in: i8, %init: i8) {
+ %3 = arith.addi %in, %init : i8
+ linalg.yield %3 : i8
+ }
+ return %reduced : tensor<i8>
+}
+
+// -----
+
+func.func private @test_missing_fill(%arg0: tensor<147456xi8>) -> (tensor<i8>) {
+ %1 = tensor.empty() : tensor<i8>
+ // CHECK: linalg.reduce
+ %reduced = linalg.reduce ins(%arg0 : tensor<147456xi8>) outs(%1 : tensor<i8>) dimensions = [0]
+ (%in: i8, %init: i8) {
+ %3 = arith.addi %in, %init : i8
+ linalg.yield %3 : i8
+ }
+ return %reduced : tensor<i8>
+}
+
+// -----
+
+// CHECK-LABEL: func.func private @test_reduce_multiply_kernel(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<147456xi8>) -> tensor<i8> {
+// CHECK: %[[VAL_1:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_2:.*]] = arith.constant 1 : i8
+// CHECK: %[[VAL_3:.*]] = tensor.empty() : tensor<i8>
+// CHECK: %[[VAL_4:.*]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["reduction"]} ins(%[[VAL_0]] : tensor<147456xi8>) outs(%[[VAL_3]] : tensor<i8>) {
+// CHECK: ^bb0(%[[VAL_5:.*]]: i8, %[[VAL_6:.*]]: i8):
+// CHECK: %[[VAL_7:.*]] = linalg.index 0 : index
+// CHECK: %[[VAL_8:.*]] = arith.cmpi eq, %[[VAL_7]], %[[VAL_1]] : index
+// CHECK: %[[VAL_9:.*]] = arith.select %[[VAL_8]], %[[VAL_2]], %[[VAL_6]] : i8
+// CHECK: %[[VAL_10:.*]] = arith.muli %[[VAL_5]], %[[VAL_9]] : i8
+// CHECK: linalg.yield %[[VAL_10]] : i8
+// CHECK: } -> tensor<i8>
+// CHECK: return %[[VAL_11:.*]] : tensor<i8>
+// CHECK: }
+
+func.func private @test_reduce_multiply_kernel(%arg0: tensor<147456xi8>) -> (tensor<i8>) {
+ %1 = tensor.empty() : tensor<i8>
+ %c1_i8 = arith.constant 1 : i8
+ %2 = linalg.fill ins(%c1_i8 : i8) outs(%1 : tensor<i8>) -> tensor<i8>
+ %reduced = linalg.reduce ins(%arg0 : tensor<147456xi8>) outs(%2 : tensor<i8>) dimensions = [0]
+ (%in: i8, %init: i8) {
+ %3 = arith.muli %in, %init : i8
+ linalg.yield %3 : i8
+ }
+ return %reduced : tensor<i8>
+}
+
+// -----
+
+func.func private @test_reduce_sum_on_multiple_dims(%arg0: tensor<2x147456xi8>) -> (tensor<i8>) {
+ %1 = tensor.empty() : tensor<i8>
+ %c0_i8 = arith.constant 0 : i8
+ // CHECK: linalg.fill
+ %2 = linalg.fill ins(%c0_i8 : i8) outs(%1 : tensor<i8>) -> tensor<i8>
+ // CHECK: linalg.reduce
+ %reduced = linalg.reduce ins(%arg0 : tensor<2x147456xi8>) outs(%2 : tensor<i8>) dimensions = [0, 1]
+ (%in: i8, %init: i8) {
+ %3 = arith.addi %in, %init : i8
+ linalg.yield %3 : i8
+ }
+ return %reduced : tensor<i8>
+}
diff --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
index eb6f581252181a..2c2cef60428743 100644
--- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_library(MLIRLinalgTestPasses
TestLinalgElementwiseFusion.cpp
TestLinalgFusionTransforms.cpp
TestLinalgRankReduceContractionOps.cpp
+ TestLinalgFuseFillOpWithReduceOp.cpp
TestLinalgTransforms.cpp
TestPadFusion.cpp
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgFuseFillOpWithReduceOp.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgFuseFillOpWithReduceOp.cpp
new file mode 100644
index 00000000000000..d5506cae741617
--- /dev/null
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgFuseFillOpWithReduceOp.cpp
@@ -0,0 +1,63 @@
+//===- TestLinalgFuseFillOpWithReduceOp.cpp -----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass for testing the fusion of linalg.fill with
+// linalg.reduce into a new linalg.generic operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+
+struct TestLinalgFuseFillOpWithReduceOp
+ : public PassWrapper<TestLinalgFuseFillOpWithReduceOp,
+ OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgFuseFillOpWithReduceOp)
+
+ TestLinalgFuseFillOpWithReduceOp() = default;
+ TestLinalgFuseFillOpWithReduceOp(
+ const TestLinalgFuseFillOpWithReduceOp &pass) = default;
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<arith::ArithDialect, linalg::LinalgDialect,
+ tensor::TensorDialect>();
+ }
+ StringRef getArgument() const final {
+ return "test-linalg-fuse-fill-op-with-reduce-op";
+ }
+ StringRef getDescription() const final {
+ return "Test fuse linalg fill with linalg reduce into a new linalg generic "
+ "operation";
+ }
+
+ void runOnOperation() override {
+ MLIRContext *context = &this->getContext();
+ func::FuncOp funcOp = this->getOperation();
+
+ RewritePatternSet patterns(context);
+ linalg::populateFuseFillOpWithReduceOpPatterns(patterns);
+ if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns))))
+ return signalPassFailure();
+ }
+};
+
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestLinalgFuseFillOpWithReduceOp() {
+ PassRegistration<TestLinalgFuseFillOpWithReduceOp>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 74007d01347ae8..7e92095ff2fae7 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -111,6 +111,7 @@ void registerTestLinalgDropUnitDims();
void registerTestLinalgElementwiseFusion();
void registerTestLinalgGreedyFusion();
void registerTestLinalgRankReduceContractionOps();
+void registerTestLinalgFuseFillOpWithReduceOp();
void registerTestLinalgTransforms();
void registerTestLivenessAnalysisPass();
void registerTestLivenessPass();
@@ -251,6 +252,7 @@ void registerTestPasses() {
mlir::test::registerTestLinalgElementwiseFusion();
mlir::test::registerTestLinalgGreedyFusion();
mlir::test::registerTestLinalgRankReduceContractionOps();
+ mlir::test::registerTestLinalgFuseFillOpWithReduceOp();
mlir::test::registerTestLinalgTransforms();
mlir::test::registerTestLivenessAnalysisPass();
mlir::test::registerTestLivenessPass();