[Mlir-commits] [mlir] [mlir][Linalg]: Add rewrite pattern to fuse fill with reduce operations (PR #125401)

Aviad Cohen llvmlistbot at llvm.org
Sun Feb 2 04:51:35 PST 2025


https://github.com/AviadCo updated https://github.com/llvm/llvm-project/pull/125401

>From 56afd0f41761fd4b059585beb341cb71fbd5b908 Mon Sep 17 00:00:00 2001
From: Aviad Cohen <aviad.cohen2 at mobileye.com>
Date: Sun, 2 Feb 2025 14:48:16 +0200
Subject: [PATCH] [mlir][Linalg]: Add rewrite pattern to fuse fill with reduce
 operations

---
 .../Dialect/Linalg/Transforms/Transforms.h    |  28 +++++
 .../Dialect/Linalg/Transforms/CMakeLists.txt  |   1 +
 .../Transforms/FuseFillOpWithReduceOp.cpp     | 107 ++++++++++++++++++
 .../Linalg/fuse_fill_op_with_reduce_op.mlir   |  88 ++++++++++++++
 mlir/test/lib/Dialect/Linalg/CMakeLists.txt   |   1 +
 .../TestLinalgFuseFillOpWithReduceOp.cpp      |  63 +++++++++++
 mlir/tools/mlir-opt/mlir-opt.cpp              |   2 +
 7 files changed, 290 insertions(+)
 create mode 100644 mlir/lib/Dialect/Linalg/Transforms/FuseFillOpWithReduceOp.cpp
 create mode 100644 mlir/test/Dialect/Linalg/fuse_fill_op_with_reduce_op.mlir
 create mode 100644 mlir/test/lib/Dialect/Linalg/TestLinalgFuseFillOpWithReduceOp.cpp

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 1dc42f71e10eff..4b325aaeab87ca 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1893,6 +1893,34 @@ void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns);
 /// convert to a `linalg.dot`.
 void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns);
 
+/// Add patterns to fold linalg.fill into a linalg.reduce consumer by creating
+/// a fused linalg.generic operation.
+/// After fusion, the fill value only needs to be applied at the first index
+/// of the reduction dimension, so the standalone fill is no longer needed.
+/// Currently only one reduction dimension is supported. Given the pattern:
+///   %c0 = arith.constant 0 : i8
+///   %empty = tensor.empty() : tensor<i8>
+///   %filled = linalg.fill ins(%c0 : i8) outs(%empty : tensor<i8>)
+///       -> tensor<i8>
+///   %reduced = linalg.reduce ins(%0 : tensor<147456xi8>)
+///       outs(%filled : tensor<i8>) dimensions = [0]
+///     (%in: i8, %init: i8) {
+///       %3 = arith.addi %in, %init : i8
+///       linalg.yield %3 : i8
+///   }
+/// The pattern is rewritten into:
+///   %c0 = arith.constant 0 : i8
+///   %empty = tensor.empty() : tensor<i8>
+///   %reduced = linalg.generic ins(%0 : tensor<147456xi8>)
+///       outs(%empty : tensor<i8>) {
+///     ^bb0(%in: i8, %init: i8):
+///       %cst = arith.constant 0 : index
+///       %index = linalg.index 0 : index
+///       %cmp = arith.cmpi eq, %cst, %index : index
+///       %sum = arith.select %cmp, %c0, %init : i8
+///       %res = arith.addi %in, %sum : i8
+///       linalg.yield %res : i8
+///   }
+void populateFuseFillOpWithReduceOpPatterns(RewritePatternSet &patterns);
+
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 3594b084138124..cace3dcb6cbfca 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -15,6 +15,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   EraseUnusedOperandsAndResults.cpp
   FoldAddIntoDest.cpp
+  FuseFillOpWithReduceOp.cpp
   FusePadOpWithLinalgProducer.cpp
   Fusion.cpp
   Generalization.cpp
   Hoisting.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FuseFillOpWithReduceOp.cpp b/mlir/lib/Dialect/Linalg/Transforms/FuseFillOpWithReduceOp.cpp
new file mode 100644
index 00000000000000..6811bbbb63e228
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/FuseFillOpWithReduceOp.cpp
@@ -0,0 +1,107 @@
+//===- FuseFillOpWithReduceOp.cpp - Fuse fill producer with reduce op ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pattern that folds a linalg.fill producer into a
+// linalg.reduce consumer by rewriting the pair into a single linalg.generic
+// operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Fold linalg.fill into linalg.reduce by creating a fused linalg.generic
+/// operation.
+/// After fusion, the fill value only needs to be applied at the first index
+/// of the reduction dimension, so the standalone fill is no longer needed.
+/// Currently only one reduction dimension is supported. Given the pattern:
+///   %c0 = arith.constant 0 : i8
+///   %empty = tensor.empty() : tensor<i8>
+///   %filled = linalg.fill ins(%c0 : i8) outs(%empty : tensor<i8>)
+///       -> tensor<i8>
+///   %reduced = linalg.reduce ins(%0 : tensor<147456xi8>)
+///       outs(%filled : tensor<i8>) dimensions = [0]
+///     (%in: i8, %init: i8) {
+///       %3 = arith.addi %in, %init : i8
+///       linalg.yield %3 : i8
+///   }
+/// The pattern is rewritten into:
+///   %c0 = arith.constant 0 : i8
+///   %empty = tensor.empty() : tensor<i8>
+///   %reduced = linalg.generic ins(%0 : tensor<147456xi8>)
+///       outs(%empty : tensor<i8>) {
+///     ^bb0(%in: i8, %init: i8):
+///       %cst = arith.constant 0 : index
+///       %index = linalg.index 0 : index
+///       %cmp = arith.cmpi eq, %cst, %index : index
+///       %sum = arith.select %cmp, %c0, %init : i8
+///       %res = arith.addi %in, %sum : i8
+///       linalg.yield %res : i8
+///   }
+struct FoldFillWithReduceOp : public OpRewritePattern<linalg::ReduceOp> {
+  using OpRewritePattern<linalg::ReduceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::ReduceOp reduceOp,
+                                PatternRewriter &rewriter) const override {
+    if (!reduceOp.hasPureTensorSemantics())
+      return rewriter.notifyMatchFailure(
+          reduceOp, "skip reduce op with non-pure tensor semantics");
+    if (reduceOp.getDimensions().size() != 1)
+      return rewriter.notifyMatchFailure(
+          reduceOp, "skip reduce op with non-single dimension");
+    if (reduceOp.getNumDpsInputs() != 1 || reduceOp.getNumDpsInits() != 1)
+      return rewriter.notifyMatchFailure(
+          reduceOp, "skip reduce op with multiple number of inputs/results");
+    auto fillOp = reduceOp.getInits()[0].getDefiningOp<linalg::FillOp>();
+    if (!fillOp)
+      return rewriter.notifyMatchFailure(
+          reduceOp,
+          "skip reduce op with inits not directly based on fill operation");
+
+    int64_t dim = reduceOp.getDimensions()[0];
+    // Note: on success, the `reduceOp` is replaced with a genericOp and no
+    // longer valid.
+    auto failureOrGenericOp = linalg::generalizeNamedOp(rewriter, reduceOp);
+    if (failed(failureOrGenericOp))
+      return rewriter.notifyMatchFailure(reduceOp,
+                                         "failed to generalize reduce op");
+
+    linalg::GenericOp genericReduceOp = *failureOrGenericOp;
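+    // Find which operand of the generalized op is produced by the fill so it
+    // can later be rewired to the fill's destination tensor.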
+    auto operandIdx = -1;
+    for (auto &use : genericReduceOp->getOpOperands()) {
+      if (use.get().getDefiningOp() == fillOp)
+        operandIdx = use.getOperandNumber();
+    }
+    assert(operandIdx != -1 && "fill op not found in reduce op uses");
+
+    Location loc = genericReduceOp.getLoc();
+    auto blockArg = genericReduceOp.getMatchingBlockArgument(
+        &genericReduceOp->getOpOperand(operandIdx));
+    rewriter.setInsertionPointToStart(genericReduceOp.getBody());
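+    // Guard the combiner: at the first index of the reduction dimension, use
+    // the fill value as the accumulator; otherwise keep the running value.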
+    auto constZeroIndexOp = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto linalgIndexOp = rewriter.create<linalg::IndexOp>(loc, dim);
+    auto cmpIOp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
+                                                 constZeroIndexOp.getResult(),
+                                                 linalgIndexOp.getResult());
+    auto selectOp =
+        rewriter.create<arith::SelectOp>(loc, cmpIOp, fillOp.value(), blockArg);
+    rewriter.replaceAllUsesExcept(blockArg, selectOp.getResult(), selectOp);
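+    // Bypass the fill and read its destination tensor directly; the fill then
+    // becomes dead if it has no other uses.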
+    genericReduceOp->setOperand(operandIdx, fillOp.getDpsInitOperand(0)->get());
+
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::linalg::populateFuseFillOpWithReduceOpPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<FoldFillWithReduceOp>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Linalg/fuse_fill_op_with_reduce_op.mlir b/mlir/test/Dialect/Linalg/fuse_fill_op_with_reduce_op.mlir
new file mode 100644
index 00000000000000..7721cfec72ce74
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/fuse_fill_op_with_reduce_op.mlir
@@ -0,0 +1,88 @@
+// RUN: mlir-opt -test-linalg-fuse-fill-op-with-reduce-op -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL:   func.func private @test_reduce_sum_kernel(
+// CHECK-SAME:                                              %[[VAL_0:.*]]: tensor<147456xi8>) -> tensor<i8> {
+// CHECK:           %[[VAL_1:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_2:.*]] = arith.constant 0 : i8
+// CHECK:           %[[VAL_3:.*]] = tensor.empty() : tensor<i8>
+// CHECK:           %[[VAL_4:.*]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["reduction"]} ins(%[[VAL_0]] : tensor<147456xi8>) outs(%[[VAL_3]] : tensor<i8>) {
+// CHECK:           ^bb0(%[[VAL_5:.*]]: i8, %[[VAL_6:.*]]: i8):
+// CHECK:             %[[VAL_7:.*]] = linalg.index 0 : index
+// CHECK:             %[[VAL_8:.*]] = arith.cmpi eq, %[[VAL_7]], %[[VAL_1]] : index
+// CHECK:             %[[VAL_9:.*]] = arith.select %[[VAL_8]], %[[VAL_2]], %[[VAL_6]] : i8
+// CHECK:             %[[VAL_10:.*]] = arith.addi %[[VAL_5]], %[[VAL_9]] : i8
+// CHECK:             linalg.yield %[[VAL_10]] : i8
+// CHECK:           } -> tensor<i8>
+// CHECK:           return %[[VAL_11:.*]] : tensor<i8>
+// CHECK:         }
+
+func.func private @test_reduce_sum_kernel(%arg0: tensor<147456xi8>) -> (tensor<i8>) {
+  %1 = tensor.empty() : tensor<i8>
+  %c0_i8 = arith.constant 0 : i8
+  %2 = linalg.fill ins(%c0_i8 : i8) outs(%1 : tensor<i8>) -> tensor<i8>
+  %reduced = linalg.reduce ins(%arg0 : tensor<147456xi8>) outs(%2 : tensor<i8>) dimensions = [0]
+    (%in: i8, %init: i8) {
+      %3 = arith.addi %in, %init : i8
+      linalg.yield %3 : i8
+    }
+  return %reduced : tensor<i8>
+}
+
+// -----
+
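+// The reduce init is not produced by a linalg.fill, so no fusion is expected.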
+func.func private @test_missing_fill(%arg0: tensor<147456xi8>) -> (tensor<i8>) {
+  %1 = tensor.empty() : tensor<i8>
+  // CHECK: linalg.reduce
+  %reduced = linalg.reduce ins(%arg0 : tensor<147456xi8>) outs(%1 : tensor<i8>) dimensions = [0]
+    (%in: i8, %init: i8) {
+      %3 = arith.addi %in, %init : i8
+      linalg.yield %3 : i8
+    }
+  return %reduced : tensor<i8>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func private @test_reduce_multiply_kernel(
+// CHECK-SAME:                                                   %[[VAL_0:.*]]: tensor<147456xi8>) -> tensor<i8> {
+// CHECK:           %[[VAL_1:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_2:.*]] = arith.constant 1 : i8
+// CHECK:           %[[VAL_3:.*]] = tensor.empty() : tensor<i8>
+// CHECK:           %[[VAL_4:.*]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["reduction"]} ins(%[[VAL_0]] : tensor<147456xi8>) outs(%[[VAL_3]] : tensor<i8>) {
+// CHECK:           ^bb0(%[[VAL_5:.*]]: i8, %[[VAL_6:.*]]: i8):
+// CHECK:             %[[VAL_7:.*]] = linalg.index 0 : index
+// CHECK:             %[[VAL_8:.*]] = arith.cmpi eq, %[[VAL_7]], %[[VAL_1]] : index
+// CHECK:             %[[VAL_9:.*]] = arith.select %[[VAL_8]], %[[VAL_2]], %[[VAL_6]] : i8
+// CHECK:             %[[VAL_10:.*]] = arith.muli %[[VAL_5]], %[[VAL_9]] : i8
+// CHECK:             linalg.yield %[[VAL_10]] : i8
+// CHECK:           } -> tensor<i8>
+// CHECK:           return %[[VAL_11:.*]] : tensor<i8>
+// CHECK:         }
+
+func.func private @test_reduce_multiply_kernel(%arg0: tensor<147456xi8>) -> (tensor<i8>) {
+  %1 = tensor.empty() : tensor<i8>
+  %c1_i8 = arith.constant 1 : i8
+  %2 = linalg.fill ins(%c1_i8 : i8) outs(%1 : tensor<i8>) -> tensor<i8>
+  %reduced = linalg.reduce ins(%arg0 : tensor<147456xi8>) outs(%2 : tensor<i8>) dimensions = [0]
+    (%in: i8, %init: i8) {
+      %3 = arith.muli %in, %init : i8
+      linalg.yield %3 : i8
+    }
+  return %reduced : tensor<i8>
+}
+
+// -----
+
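+// Reductions over more than one dimension are not supported, so no fusion is
+// expected.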
+func.func private @test_reduce_sum_on_multiple_dims(%arg0: tensor<2x147456xi8>) -> (tensor<i8>) {
+  %1 = tensor.empty() : tensor<i8>
+  %c0_i8 = arith.constant 0 : i8
+  // CHECK: linalg.fill
+  %2 = linalg.fill ins(%c0_i8 : i8) outs(%1 : tensor<i8>) -> tensor<i8>
+  // CHECK: linalg.reduce
+  %reduced = linalg.reduce ins(%arg0 : tensor<2x147456xi8>) outs(%2 : tensor<i8>) dimensions = [0, 1]
+    (%in: i8, %init: i8) {
+      %3 = arith.addi %in, %init : i8
+      linalg.yield %3 : i8
+    }
+  return %reduced : tensor<i8>
+}
diff --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
index eb6f581252181a..2c2cef60428743 100644
--- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_library(MLIRLinalgTestPasses
   TestLinalgElementwiseFusion.cpp
+  TestLinalgFuseFillOpWithReduceOp.cpp
   TestLinalgFusionTransforms.cpp
   TestLinalgRankReduceContractionOps.cpp
   TestLinalgTransforms.cpp
   TestPadFusion.cpp
 
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgFuseFillOpWithReduceOp.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgFuseFillOpWithReduceOp.cpp
new file mode 100644
index 00000000000000..d5506cae741617
--- /dev/null
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgFuseFillOpWithReduceOp.cpp
@@ -0,0 +1,63 @@
+//===- TestLinalgFuseFillOpWithReduceOp.cpp -----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass for testing the fusion of linalg.fill with
+// linalg.reduce into a new linalg.generic operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+
+struct TestLinalgFuseFillOpWithReduceOp
+    : public PassWrapper<TestLinalgFuseFillOpWithReduceOp,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgFuseFillOpWithReduceOp)
+
+  TestLinalgFuseFillOpWithReduceOp() = default;
+  TestLinalgFuseFillOpWithReduceOp(
+      const TestLinalgFuseFillOpWithReduceOp &pass) = default;
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<arith::ArithDialect, linalg::LinalgDialect,
+                    tensor::TensorDialect>();
+  }
+  StringRef getArgument() const final {
+    return "test-linalg-fuse-fill-op-with-reduce-op";
+  }
+  StringRef getDescription() const final {
+    return "Test fuse linalg fill with linalg reduce into a new linalg generic "
+           "operation";
+  }
+
+  void runOnOperation() override {
+    MLIRContext *context = &this->getContext();
+    func::FuncOp funcOp = this->getOperation();
+
+    RewritePatternSet patterns(context);
+    linalg::populateFuseFillOpWithReduceOpPatterns(patterns);
+    if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestLinalgFuseFillOpWithReduceOp() {
+  PassRegistration<TestLinalgFuseFillOpWithReduceOp>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 74007d01347ae8..7e92095ff2fae7 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -111,6 +111,7 @@ void registerTestLinalgDropUnitDims();
 void registerTestLinalgElementwiseFusion();
+void registerTestLinalgFuseFillOpWithReduceOp();
 void registerTestLinalgGreedyFusion();
 void registerTestLinalgRankReduceContractionOps();
 void registerTestLinalgTransforms();
 void registerTestLivenessAnalysisPass();
 void registerTestLivenessPass();
@@ -251,6 +252,7 @@ void registerTestPasses() {
   mlir::test::registerTestLinalgElementwiseFusion();
+  mlir::test::registerTestLinalgFuseFillOpWithReduceOp();
   mlir::test::registerTestLinalgGreedyFusion();
   mlir::test::registerTestLinalgRankReduceContractionOps();
   mlir::test::registerTestLinalgTransforms();
   mlir::test::registerTestLivenessAnalysisPass();
   mlir::test::registerTestLivenessPass();


