[Mlir-commits] [mlir] [mlir][Transforms] CSE: Add filter options to control CSE'ing (PR #115639)
Matthias Springer
llvmlistbot at llvm.org
Sat Nov 9 23:38:25 PST 2024
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/115639
This commit adds two new pass options that give users more fine-grained control over which ops are CSE'd / DCE'd.
* `barrier-op-filter` specifies ops that should act as CSE'ing barriers. I.e., ops that are nested inside such ops should not be CSE'd with ops that are outside of such ops. (Until now, the only CSE'ing barrier used to be IsolatedFromAbove ops.)
* `eliminate-op-filter` specifies ops that are subject to elimination. All non-matching ops are ignored by the CSE pass and remain in place. (If the filter is empty, all ops are subject to elimination.)
>From d17454e081e167ab391ae1b9f21115591418e614 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sun, 10 Nov 2024 08:29:29 +0100
Subject: [PATCH] [mlir][Transforms] CSE: Add filter options to control CSE'ing
This commit adds two new pass options that give users more fine-grained control over which ops are CSE'd / DCE'd.
* `barrier-op-filter` specifies ops that should act as CSE'ing barriers. I.e., ops that are nested inside such ops should not be CSE'd with ops that are outside of such ops. (Until now, the only CSE'ing barrier used to be IsolatedFromAbove ops.)
* `eliminate-op-filter` specifies ops that are subject to elimination. All non-matching ops are ignored by the CSE pass and remain in place. (If the filter is empty, all ops are subject to elimination.)
---
mlir/include/mlir/Transforms/CSE.h | 27 ++++++++-
mlir/include/mlir/Transforms/Passes.h | 2 +-
mlir/include/mlir/Transforms/Passes.td | 17 +++++-
mlir/lib/Transforms/CSE.cpp | 47 +++++++++++++---
mlir/test/Transforms/cse.mlir | 77 ++++++++++++++++++++++++--
5 files changed, 153 insertions(+), 17 deletions(-)
diff --git a/mlir/include/mlir/Transforms/CSE.h b/mlir/include/mlir/Transforms/CSE.h
index 3d01ece0780509..4edca3e3369f24 100644
--- a/mlir/include/mlir/Transforms/CSE.h
+++ b/mlir/include/mlir/Transforms/CSE.h
@@ -13,19 +13,44 @@
#ifndef MLIR_TRANSFORMS_CSE_H_
#define MLIR_TRANSFORMS_CSE_H_
+#include <functional>
+
namespace mlir {
class DominanceInfo;
class Operation;
class RewriterBase;
+/// Configuration for CSE.
+struct CSEConfig {
+ /// If set, matching ops act as a CSE'ing barrier: ops nested inside a
+ /// matching op are not CSE'd with ops outside of it.
+ ///
+ /// Note: IsolatedFromAbove ops are always a CSE'ing barrier, regardless of
+ /// this filter.
+ ///
+ /// Example:
+ /// %0 = arith.constant 0 : index
+ /// scf.for ... {
+ /// %1 = arith.constant 0 : index
+ /// ...
+ /// }
+ /// If "scf.for" is marked as a CSE'ing barrier, %0 and %1 are *not* CSE'd.
+ std::function<bool(Operation *)> barrierOpFilter = nullptr;
+
+ /// If set, only matching ops are subject to elimination (CSE'd or DCE'd).
+ /// All non-matching ops are ignored and remain in place.
+ std::function<bool(Operation *)> eliminateOpFilter = nullptr;
+};
+
/// Eliminate common subexpressions within the given operation. This transform
/// looks for and deduplicates equivalent operations.
///
/// `changed` indicates whether the IR was modified or not.
void eliminateCommonSubExpressions(RewriterBase &rewriter,
DominanceInfo &domInfo, Operation *op,
- bool *changed = nullptr);
+ bool *changed = nullptr,
+ CSEConfig config = CSEConfig());
} // namespace mlir
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 5c977055e95dc8..41f208216374fe 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -33,7 +33,7 @@ class GreedyRewriteConfig;
#define GEN_PASS_DECL_CANONICALIZER
#define GEN_PASS_DECL_CONTROLFLOWSINK
-#define GEN_PASS_DECL_CSEPASS
+#define GEN_PASS_DECL_CSE
#define GEN_PASS_DECL_INLINER
#define GEN_PASS_DECL_LOOPINVARIANTCODEMOTION
#define GEN_PASS_DECL_MEM2REG
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 000d9f697618e6..429029f21eb307 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -81,12 +81,25 @@ def CSE : Pass<"cse"> {
let summary = "Eliminate common sub-expressions";
let description = [{
This pass implements a generalized algorithm for common sub-expression
- elimination. This pass relies on information provided by the
- `Memory SideEffect` interface to identify when it is safe to eliminate
+ elimination. The pass also eliminates dead operations (DCE). The pass
+ relies on information provided by the `MemoryEffectOpInterface`
+ and on `DominanceInfo` to identify when it is safe to eliminate
operations. See [Common subexpression elimination](https://en.wikipedia.org/wiki/Common_subexpression_elimination)
for more general details on this optimization.
+
+ The types of ops that are subject to elimination can be configured with
+ `eliminate-op-filter`. If set, only those ops are CSE'd or DCE'd.
+
+ Ops are never CSE'd across IsolatedFromAbove ops. Additional CSE'ing
+ barrier ops can be specified with `barrier-op-filter`.
}];
let constructor = "mlir::createCSEPass()";
+ let options = [
+ ListOption<"barrierOpFilter", "barrier-op-filter", "std::string",
+ "Names of ops that act as CSE'ing barriers">,
+ ListOption<"eliminateOpFilter", "eliminate-op-filter", "std::string",
+ "If non-empty, list of ops that are subject to elimination">,
+ ];
let statistics = [
Statistic<"numCSE", "num-cse'd", "Number of operations CSE'd">,
Statistic<"numDCE", "num-dce'd", "Number of operations DCE'd">
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
index 3affd88d158de5..93ac35db276da0 100644
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -23,8 +23,9 @@
#include "llvm/ADT/ScopedHashTable.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/RecyclingAllocator.h"
-#include <deque>
+#include <deque>
+#include <unordered_set>
namespace mlir {
#define GEN_PASS_DEF_CSE
#include "mlir/Transforms/Passes.h.inc"
@@ -60,8 +61,9 @@ namespace {
/// Simple common sub-expression elimination.
class CSEDriver {
public:
- CSEDriver(RewriterBase &rewriter, DominanceInfo *domInfo)
- : rewriter(rewriter), domInfo(domInfo) {}
+ CSEDriver(RewriterBase &rewriter, DominanceInfo *domInfo,
+ const CSEConfig &config)
+ : rewriter(rewriter), domInfo(domInfo), config(config) {}
/// Simplify all operations within the given op.
void simplify(Operation *op, bool *changed = nullptr);
@@ -125,6 +127,9 @@ class CSEDriver {
// Various statistics.
int64_t numCSE = 0;
int64_t numDCE = 0;
+
+ /// CSE configuration.
+ CSEConfig config;
};
} // namespace
@@ -226,6 +231,10 @@ bool CSEDriver::hasOtherSideEffectingOpInBetween(Operation *fromOp,
LogicalResult CSEDriver::simplifyOperation(ScopedMapTy &knownValues,
Operation *op,
bool hasSSADominance) {
+ // Don't simplify operations that are filtered out.
+ if (config.eliminateOpFilter && !config.eliminateOpFilter(op))
+ return failure();
+
// Don't simplify terminator operations.
if (op->hasTrait<OpTrait::IsTerminator>())
return failure();
@@ -288,8 +297,11 @@ void CSEDriver::simplifyBlock(ScopedMapTy &knownValues, Block *bb,
if (op.getNumRegions() != 0) {
// If this operation is isolated above, we can't process nested regions
// with the given 'knownValues' map. This would cause the insertion of
- // implicit captures in explicit capture only regions.
- if (op.mightHaveTrait<OpTrait::IsIsolatedFromAbove>()) {
+ // implicit captures in explicit capture only regions. Additional barrier
+ // ops can be specified by the user.
+ bool isBarrier = op.mightHaveTrait<OpTrait::IsIsolatedFromAbove>() ||
+ (config.barrierOpFilter && config.barrierOpFilter(&op));
+ if (isBarrier) {
ScopedMapTy nestedKnownValues;
for (auto ®ion : op.getRegions())
simplifyRegion(nestedKnownValues, region);
@@ -381,8 +393,8 @@ void CSEDriver::simplify(Operation *op, bool *changed) {
void mlir::eliminateCommonSubExpressions(RewriterBase &rewriter,
DominanceInfo &domInfo, Operation *op,
- bool *changed) {
- CSEDriver driver(rewriter, &domInfo);
+ bool *changed, CSEConfig config) {
+ CSEDriver driver(rewriter, &domInfo, config);
driver.simplify(op, changed);
}
@@ -394,9 +406,28 @@ struct CSE : public impl::CSEBase<CSE> {
} // namespace
void CSE::runOnOperation() {
+ // Set up CSE configuration from pass options.
+ CSEConfig config;
+ std::unordered_set<std::string> barrierOpNames;
+ for (std::string opName : barrierOpFilter)
+ barrierOpNames.insert(opName);
+ std::unordered_set<std::string> eliminateOpNames;
+ for (std::string opName : eliminateOpFilter)
+ eliminateOpNames.insert(opName);
+ if (!barrierOpNames.empty()) {
+ config.barrierOpFilter = [&](Operation *op) -> bool {
+ return barrierOpNames.count(op->getName().getStringRef().str());
+ };
+ }
+ if (!eliminateOpNames.empty()) {
+ config.eliminateOpFilter = [&](Operation *op) -> bool {
+ return eliminateOpNames.count(op->getName().getStringRef().str());
+ };
+ }
+
// Simplify the IR.
IRRewriter rewriter(&getContext());
- CSEDriver driver(rewriter, &getAnalysis<DominanceInfo>());
+ CSEDriver driver(rewriter, &getAnalysis<DominanceInfo>(), config);
bool changed = false;
driver.simplify(getOperation(), &changed);
diff --git a/mlir/test/Transforms/cse.mlir b/mlir/test/Transforms/cse.mlir
index 11a33102684733..5d2da75db6ce2f 100644
--- a/mlir/test/Transforms/cse.mlir
+++ b/mlir/test/Transforms/cse.mlir
@@ -1,32 +1,47 @@
-// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(cse))' | FileCheck %s
-
-// CHECK-DAG: #[[$MAP:.*]] = affine_map<(d0) -> (d0 mod 2)>
-#map0 = affine_map<(d0) -> (d0 mod 2)>
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(cse))' -split-input-file | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(cse="eliminate-op-filter=arith.constant"))' -split-input-file | FileCheck %s --check-prefix=CHECK-ELIMINATE-FILTER
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(cse="barrier-op-filter=affine.for"))' -split-input-file | FileCheck %s --check-prefix=CHECK-BARRIER-FILTER
// CHECK-LABEL: @simple_constant
+// CHECK-ELIMINATE-FILTER-LABEL: @simple_constant
func.func @simple_constant() -> (i32, i32) {
// CHECK-NEXT: %[[VAR_c1_i32:.*]] = arith.constant 1 : i32
+ // CHECK-ELIMINATE-FILTER-NEXT: %[[VAR_c1_i32:.*]] = arith.constant 1 : i32
%0 = arith.constant 1 : i32
// CHECK-NEXT: return %[[VAR_c1_i32]], %[[VAR_c1_i32]] : i32, i32
+ // CHECK-ELIMINATE-FILTER-NEXT: return %[[VAR_c1_i32]], %[[VAR_c1_i32]] : i32, i32
%1 = arith.constant 1 : i32
return %0, %1 : i32, i32
}
+// -----
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<(d0) -> (d0 mod 2)>
+// CHECK-ELIMINATE-FILTER-DAG: #[[$MAP:.*]] = affine_map<(d0) -> (d0 mod 2)>
+#map0 = affine_map<(d0) -> (d0 mod 2)>
+
// CHECK-LABEL: @basic
+// CHECK-ELIMINATE-FILTER-LABEL: @basic
func.func @basic() -> (index, index) {
// CHECK: %[[VAR_c0:[0-9a-zA-Z_]+]] = arith.constant 0 : index
+ // CHECK-ELIMINATE-FILTER: %[[VAR_c0:[0-9a-zA-Z_]+]] = arith.constant 0 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 0 : index
// CHECK-NEXT: %[[VAR_0:[0-9a-zA-Z_]+]] = affine.apply #[[$MAP]](%[[VAR_c0]])
+ // CHECK-ELIMINATE-FILTER-NEXT: %[[VAR_0:[0-9a-zA-Z_]+]] = affine.apply #[[$MAP]](%[[VAR_c0]])
+ // CHECK-ELIMINATE-FILTER-NEXT: %[[VAR_1:[0-9a-zA-Z_]+]] = affine.apply #[[$MAP]](%[[VAR_c0]])
%0 = affine.apply #map0(%c0)
%1 = affine.apply #map0(%c1)
// CHECK-NEXT: return %[[VAR_0]], %[[VAR_0]] : index, index
+ // CHECK-ELIMINATE-FILTER-NEXT: return %[[VAR_0]], %[[VAR_1]] : index, index
return %0, %1 : index, index
}
+// -----
+
// CHECK-LABEL: @many
func.func @many(f32, f32) -> (f32) {
^bb0(%a : f32, %b : f32):
@@ -52,6 +67,8 @@ func.func @many(f32, f32) -> (f32) {
return %l : f32
}
+// -----
+
/// Check that operations are not eliminated if they have different operands.
// CHECK-LABEL: @different_ops
func.func @different_ops() -> (i32, i32) {
@@ -64,6 +81,8 @@ func.func @different_ops() -> (i32, i32) {
return %0, %1 : i32, i32
}
+// -----
+
/// Check that operations are not eliminated if they have different result
/// types.
// CHECK-LABEL: @different_results
@@ -77,6 +96,8 @@ func.func @different_results(%arg0: tensor<*xf32>) -> (tensor<?x?xf32>, tensor<4
return %0, %1 : tensor<?x?xf32>, tensor<4x?xf32>
}
+// -----
+
/// Check that operations are not eliminated if they have different attributes.
// CHECK-LABEL: @different_attributes
func.func @different_attributes(index, index) -> (i1, i1, i1) {
@@ -93,6 +114,8 @@ func.func @different_attributes(index, index) -> (i1, i1, i1) {
return %0, %1, %2 : i1, i1, i1
}
+// -----
+
/// Check that operations with side effects are not eliminated.
// CHECK-LABEL: @side_effect
func.func @side_effect() -> (memref<2x1xf32>, memref<2x1xf32>) {
@@ -106,22 +129,32 @@ func.func @side_effect() -> (memref<2x1xf32>, memref<2x1xf32>) {
return %0, %1 : memref<2x1xf32>, memref<2x1xf32>
}
+// -----
+
/// Check that operation definitions are properly propagated down the dominance
/// tree.
// CHECK-LABEL: @down_propagate_for
+// CHECK-BARRIER-FILTER-LABEL: @down_propagate_for
func.func @down_propagate_for() {
// CHECK: %[[VAR_c1_i32:[0-9a-zA-Z_]+]] = arith.constant 1 : i32
+ // CHECK-BARRIER-FILTER: %[[VAR_c1_i32:[0-9a-zA-Z_]+]] = arith.constant 1 : i32
%0 = arith.constant 1 : i32
// CHECK-NEXT: affine.for {{.*}} = 0 to 4 {
+ // CHECK-BARRIER-FILTER-NEXT: affine.for {{.*}} = 0 to 4 {
affine.for %i = 0 to 4 {
- // CHECK-NEXT: "foo"(%[[VAR_c1_i32]], %[[VAR_c1_i32]]) : (i32, i32) -> ()
+ // CHECK-BARRIER-FILTER-NEXT: %[[VAR2_c1_i32:[0-9a-zA-Z_]+]] = arith.constant 1 : i32
%1 = arith.constant 1 : i32
+
+ // CHECK-NEXT: "foo"(%[[VAR_c1_i32]], %[[VAR_c1_i32]]) : (i32, i32) -> ()
+ // CHECK-BARRIER-FILTER-NEXT: "foo"(%[[VAR_c1_i32]], %[[VAR2_c1_i32]]) : (i32, i32) -> ()
"foo"(%0, %1) : (i32, i32) -> ()
}
return
}
+// -----
+
// CHECK-LABEL: @down_propagate
func.func @down_propagate() -> i32 {
// CHECK-NEXT: %[[VAR_c1_i32:[0-9a-zA-Z_]+]] = arith.constant 1 : i32
@@ -142,6 +175,8 @@ func.func @down_propagate() -> i32 {
return %arg : i32
}
+// -----
+
/// Check that operation definitions are NOT propagated up the dominance tree.
// CHECK-LABEL: @up_propagate_for
func.func @up_propagate_for() -> i32 {
@@ -159,6 +194,8 @@ func.func @up_propagate_for() -> i32 {
return %1 : i32
}
+// -----
+
// CHECK-LABEL: func @up_propagate
func.func @up_propagate() -> i32 {
// CHECK-NEXT: %[[VAR_c0_i32:[0-9a-zA-Z_]+]] = arith.constant 0 : i32
@@ -188,6 +225,8 @@ func.func @up_propagate() -> i32 {
return %add : i32
}
+// -----
+
/// The same test as above except that we are testing on a cfg embedded within
/// an operation region.
// CHECK-LABEL: func @up_propagate_region
@@ -221,6 +260,8 @@ func.func @up_propagate_region() -> i32 {
return %0 : i32
}
+// -----
+
/// This test checks that nested regions that are isolated from above are
/// properly handled.
// CHECK-LABEL: @nested_isolated
@@ -248,6 +289,8 @@ func.func @nested_isolated() -> i32 {
return %0 : i32
}
+// -----
+
/// This test is checking that CSE gracefully handles values in graph regions
/// where the use occurs before the def, and one of the defs could be CSE'd with
/// the other.
@@ -269,6 +312,8 @@ func.func @use_before_def() {
return
}
+// -----
+
/// This test is checking that CSE is removing duplicated read op that follow
/// other.
// CHECK-LABEL: @remove_direct_duplicated_read_op
@@ -281,6 +326,8 @@ func.func @remove_direct_duplicated_read_op() -> i32 {
return %2 : i32
}
+// -----
+
/// This test is checking that CSE is removing duplicated read op that follow
/// other.
// CHECK-LABEL: @remove_multiple_duplicated_read_op
@@ -300,6 +347,8 @@ func.func @remove_multiple_duplicated_read_op() -> i64 {
return %6 : i64
}
+// -----
+
/// This test is checking that CSE is not removing duplicated read op that
/// have write op in between.
// CHECK-LABEL: @dont_remove_duplicated_read_op_with_sideeffecting
@@ -314,6 +363,8 @@ func.func @dont_remove_duplicated_read_op_with_sideeffecting() -> i32 {
return %2 : i32
}
+// -----
+
// Check that an operation with a single region can CSE.
func.func @cse_single_block_ops(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>)
-> (tensor<?x?xf32>, tensor<?x?xf32>) {
@@ -332,6 +383,8 @@ func.func @cse_single_block_ops(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>)
// CHECK-NOT: test.cse_of_single_block_op
// CHECK: return %[[OP]], %[[OP]]
+// -----
+
// Operations with different number of bbArgs dont CSE.
func.func @no_cse_varied_bbargs(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>)
-> (tensor<?x?xf32>, tensor<?x?xf32>) {
@@ -350,6 +403,8 @@ func.func @no_cse_varied_bbargs(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>)
// CHECK: %[[OP1:.+]] = test.cse_of_single_block_op
// CHECK: return %[[OP0]], %[[OP1]]
+// -----
+
// Operations with different regions dont CSE
func.func @no_cse_region_difference_simple(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>)
-> (tensor<?x?xf32>, tensor<?x?xf32>) {
@@ -368,6 +423,8 @@ func.func @no_cse_region_difference_simple(%a : tensor<?x?xf32>, %b : tensor<?x?
// CHECK: %[[OP1:.+]] = test.cse_of_single_block_op
// CHECK: return %[[OP0]], %[[OP1]]
+// -----
+
// Operation with identical region with multiple statements CSE.
func.func @cse_single_block_ops_identical_bodies(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>, %c : f32, %d : i1)
-> (tensor<?x?xf32>, tensor<?x?xf32>) {
@@ -392,6 +449,8 @@ func.func @cse_single_block_ops_identical_bodies(%a : tensor<?x?xf32>, %b : tens
// CHECK-NOT: test.cse_of_single_block_op
// CHECK: return %[[OP]], %[[OP]]
+// -----
+
// Operation with non-identical regions dont CSE.
func.func @no_cse_single_block_ops_different_bodies(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>, %c : f32, %d : i1)
-> (tensor<?x?xf32>, tensor<?x?xf32>) {
@@ -416,6 +475,8 @@ func.func @no_cse_single_block_ops_different_bodies(%a : tensor<?x?xf32>, %b : t
// CHECK: %[[OP1:.+]] = test.cse_of_single_block_op
// CHECK: return %[[OP0]], %[[OP1]]
+// -----
+
func.func @failing_issue_59135(%arg0: tensor<2x2xi1>, %arg1: f32, %arg2 : tensor<2xi1>) -> (tensor<2xi1>, tensor<2xi1>) {
%false_2 = arith.constant false
%true_5 = arith.constant true
@@ -438,6 +499,8 @@ func.func @failing_issue_59135(%arg0: tensor<2x2xi1>, %arg1: f32, %arg2 : tensor
// CHECK: test.region_yield %[[TRUE]]
// CHECK: return %[[OP]], %[[OP]]
+// -----
+
func.func @cse_multiple_regions(%c: i1, %t: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) {
%r1 = scf.if %c -> (tensor<5xf32>) {
%0 = tensor.empty() : tensor<5xf32>
@@ -463,6 +526,8 @@ func.func @cse_multiple_regions(%c: i1, %t: tensor<5xf32>) -> (tensor<5xf32>, te
// CHECK-NOT: scf.if
// CHECK: return %[[if]], %[[if]]
+// -----
+
// CHECK-LABEL: @cse_recursive_effects_success
func.func @cse_recursive_effects_success() -> (i32, i32, i32) {
// CHECK-NEXT: %[[READ_VALUE:.*]] = "test.op_with_memread"() : () -> i32
@@ -492,6 +557,8 @@ func.func @cse_recursive_effects_success() -> (i32, i32, i32) {
return %0, %2, %1 : i32, i32, i32
}
+// -----
+
// CHECK-LABEL: @cse_recursive_effects_failure
func.func @cse_recursive_effects_failure() -> (i32, i32, i32) {
// CHECK-NEXT: %[[READ_VALUE:.*]] = "test.op_with_memread"() : () -> i32
More information about the Mlir-commits
mailing list