[Mlir-commits] [mlir] dea33c8 - [mlir][Transforms] teach CSE about recursive memory effects
Tom Eccles
llvmlistbot at llvm.org
Thu Aug 10 02:40:50 PDT 2023
Author: Tom Eccles
Date: 2023-08-10T09:40:01Z
New Revision: dea33c80d3666f1e8368ef1a3a09adf999e31723
URL: https://github.com/llvm/llvm-project/commit/dea33c80d3666f1e8368ef1a3a09adf999e31723
DIFF: https://github.com/llvm/llvm-project/commit/dea33c80d3666f1e8368ef1a3a09adf999e31723.diff
LOG: [mlir][Transforms] teach CSE about recursive memory effects
Add support for reasoning about operations with recursive memory effects
to CSE. The recursive effects are gathered by a helper function. I
decided to allow returning duplicates from the helper function because
there's no benefit to spending the computation time to remove them in
the existing use case.
Differential Revision: https://reviews.llvm.org/D156805
Added:
Modified:
mlir/include/mlir/Interfaces/SideEffectInterfaces.h
mlir/lib/Interfaces/SideEffectInterfaces.cpp
mlir/lib/Transforms/CSE.cpp
mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
mlir/test/Transforms/cse.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Interfaces/SideEffectInterfaces.h b/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
index dfbb65ee6d4881..74fb96662934a9 100644
--- a/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
+++ b/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
@@ -332,6 +332,17 @@ bool wouldOpBeTriviallyDead(Operation *op);
/// conditions are satisfied.
bool isMemoryEffectFree(Operation *op);
+/// Returns the side effects of an operation. If the operation has
+/// RecursiveMemoryEffects, include all side effects of nested operations.
+///
+/// std::nullopt indicates that an operation did not have a memory effect
+/// interface (nor RecursiveMemoryEffects) and so no result could be obtained.
+/// An empty vector indicates that there were no memory effects found (every
+/// visited operation implemented the memory effect interface or has
+/// RecursiveMemoryEffects). The returned effects may contain duplicates.
+std::optional<llvm::SmallVector<MemoryEffects::EffectInstance>>
+getEffectsRecursively(Operation *rootOp);
+
/// Returns true if the given operation is speculatable, i.e. has no undefined
/// behavior or other side effects.
///
diff --git a/mlir/lib/Interfaces/SideEffectInterfaces.cpp b/mlir/lib/Interfaces/SideEffectInterfaces.cpp
index 76ae27fab5a0c0..967225333ae048 100644
--- a/mlir/lib/Interfaces/SideEffectInterfaces.cpp
+++ b/mlir/lib/Interfaces/SideEffectInterfaces.cpp
@@ -182,6 +182,39 @@ bool mlir::isMemoryEffectFree(Operation *op) {
return true;
}
+// Worklist walk over rootOp; the returned vector may contain duplicates.
+std::optional<llvm::SmallVector<MemoryEffects::EffectInstance>>
+mlir::getEffectsRecursively(Operation *rootOp) {
+  SmallVector<MemoryEffects::EffectInstance> effects;
+  SmallVector<Operation *> effectingOps(1, rootOp);
+  while (!effectingOps.empty()) {
+    Operation *op = effectingOps.pop_back_val();
+
+    // If the operation has recursive effects, push all of the nested
+    // operations on to the stack to consider.
+    bool hasRecursiveEffects =
+        op->hasTrait<OpTrait::HasRecursiveMemoryEffects>();
+    if (hasRecursiveEffects) {
+      for (Region &region : op->getRegions()) {
+        for (Block &block : region) {
+          for (Operation &nestedOp : block) {
+            effectingOps.push_back(&nestedOp);
+          }
+        }
+      }
+    }
+
+    if (auto effectInterface = dyn_cast<MemoryEffectOpInterface>(op)) {
+      effectInterface.getEffects(effects);
+    } else if (!hasRecursiveEffects) {
+      // The operation neither has recursive memory effects nor implements
+      // the memory effect op interface, so its effects are unknown.
+      return std::nullopt;
+    }
+  }
+  return effects;
+}
+
bool mlir::isSpeculatable(Operation *op) {
auto conditionallySpeculatable = dyn_cast<ConditionallySpeculatable>(op);
if (!conditionallySpeculatable)
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
index fc0d50949282d9..3affd88d158de5 100644
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -199,17 +199,23 @@ bool CSEDriver::hasOtherSideEffectingOpInBetween(Operation *fromOp,
}
}
while (nextOp && nextOp != toOp) {
- auto nextOpMemEffects = dyn_cast<MemoryEffectOpInterface>(nextOp);
- // TODO: Do we need to handle other effects generically?
- // If the operation does not implement the MemoryEffectOpInterface we
- // conservatively assumes it writes.
- if ((nextOpMemEffects &&
- nextOpMemEffects.hasEffect<MemoryEffects::Write>()) ||
- !nextOpMemEffects) {
+ std::optional<SmallVector<MemoryEffects::EffectInstance>> effects =
+ getEffectsRecursively(nextOp);
+ if (!effects) {
+ // TODO: Do we need to handle other effects generically?
+ // If the operation does not implement the MemoryEffectOpInterface we
+ // conservatively assume it writes.
result.first->second =
std::make_pair(nextOp, MemoryEffects::Write::get());
return true;
}
+
+ for (const MemoryEffects::EffectInstance &effect : *effects) {
+ if (isa<MemoryEffects::Write>(effect.getEffect())) {
+ result.first->second = {nextOp, MemoryEffects::Write::get()};
+ return true;
+ }
+ }
nextOp = nextOp->getNextNode();
}
result.first->second = std::make_pair(toOp, nullptr);
diff --git a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
index b0c1d5469cfd29..170f851138f82a 100644
--- a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
+++ b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
@@ -332,8 +332,7 @@ func.func @sparse_push_back_inbound(%arg0: index, %arg1: memref<?xf64>, %arg2: f
// CHECK: scf.yield %[[VAL_145]]
// CHECK: }
// CHECK: %[[VAL_146:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_147:.*]]]
-// CHECK: %[[VAL_148:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_127]]]
-// CHECK: %[[VAL_149:.*]] = arith.cmpi eq, %[[VAL_146]], %[[VAL_148]]
+// CHECK: %[[VAL_149:.*]] = arith.cmpi eq, %[[VAL_146]], %[[VAL_137]]
// CHECK: %[[VAL_150:.*]] = arith.cmpi ult, %[[VAL_136]], %[[VAL_147]]
// CHECK: %[[VAL_151:.*]]:3 = scf.if %[[VAL_150]]
// CHECK: %[[VAL_152:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_136]]]
@@ -529,4 +528,4 @@ func.func @sparse_sort_coo_stable(%arg0: index, %arg1: memref<100xindex>, %arg2:
func.func @sparse_sort_coo_heap(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
sparse_tensor.sort_coo heap_sort %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
-}
\ No newline at end of file
+}
diff --git a/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir b/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
index d2b8b6a9316ae1..866b228917ad37 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
@@ -142,9 +142,7 @@
// CHECK: scf.yield %[[VAL_132]], %[[VAL_131]] : index, i32
// CHECK: }
// CHECK: %[[VAL_133:.*]] = arith.addi %[[VAL_105]], %[[VAL_7]] : index
-// CHECK: %[[VAL_134:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<11xindex>
-// CHECK: %[[VAL_135:.*]] = arith.addi %[[VAL_134]], %[[VAL_5]] : index
-// CHECK: memref.store %[[VAL_135]], %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<11xindex>
+// CHECK: memref.store %[[VAL_112]], %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<11xindex>
// CHECK: scf.yield %[[VAL_133]], %[[VAL_136:.*]]#1, %[[VAL_2]] : index, i32, i1
// CHECK: }
// CHECK: %[[VAL_137:.*]] = scf.if %[[VAL_138:.*]]#2 -> (tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>) {
diff --git a/mlir/test/Transforms/cse.mlir b/mlir/test/Transforms/cse.mlir
index f3a820f8a765be..c764d2b9bd57d8 100644
--- a/mlir/test/Transforms/cse.mlir
+++ b/mlir/test/Transforms/cse.mlir
@@ -459,3 +459,64 @@ func.func @cse_multiple_regions(%c: i1, %t: tensor<5xf32>) -> (tensor<5xf32>, te
// CHECK: }
// CHECK-NOT: scf.if
// CHECK: return %[[if]], %[[if]]
+
+// CHECK-LABEL: @cse_recursive_effects_success
+func.func @cse_recursive_effects_success() -> (i32, i32, i32) {
+ // CHECK-NEXT: %[[READ_VALUE:.*]] = "test.op_with_memread"() : () -> i32
+ %0 = "test.op_with_memread"() : () -> (i32)
+
+ // An op with recursive memory effects whose regions contain no side effects.
+ %true = arith.constant true
+ // CHECK-NEXT: %[[TRUE:.+]] = arith.constant true
+ // CHECK-NEXT: %[[IF:.+]] = scf.if %[[TRUE]] -> (i32) {
+ %1 = scf.if %true -> (i32) {
+ %c42 = arith.constant 42 : i32
+ scf.yield %c42 : i32
+ // CHECK-NEXT: %[[C42:.+]] = arith.constant 42 : i32
+ // CHECK-NEXT: scf.yield %[[C42]]
+ // CHECK-NEXT: } else {
+ } else {
+ %c24 = arith.constant 24 : i32
+ scf.yield %c24 : i32
+ // CHECK-NEXT: %[[C24:.+]] = arith.constant 24 : i32
+ // CHECK-NEXT: scf.yield %[[C24]]
+ // CHECK-NEXT: }
+ }
+
+ // %2 can be replaced by %0 because the scf.if above performs no writes.
+ // CHECK-NEXT: return %[[READ_VALUE]], %[[READ_VALUE]], %[[IF]] : i32, i32, i32
+ %2 = "test.op_with_memread"() : () -> (i32)
+ return %0, %2, %1 : i32, i32, i32
+}
+
+// CHECK-LABEL: @cse_recursive_effects_failure
+func.func @cse_recursive_effects_failure() -> (i32, i32, i32) {
+ // CHECK-NEXT: %[[READ_VALUE:.*]] = "test.op_with_memread"() : () -> i32
+ %0 = "test.op_with_memread"() : () -> (i32)
+
+ // An op with recursive memory effects whose then-region contains a write.
+ %true = arith.constant true
+ // CHECK-NEXT: %[[TRUE:.+]] = arith.constant true
+ // CHECK-NEXT: %[[IF:.+]] = scf.if %[[TRUE]] -> (i32) {
+ %1 = scf.if %true -> (i32) {
+ "test.op_with_memwrite"() : () -> ()
+ // CHECK-NEXT: "test.op_with_memwrite"() : () -> ()
+ %c42 = arith.constant 42 : i32
+ scf.yield %c42 : i32
+ // CHECK-NEXT: %[[C42:.+]] = arith.constant 42 : i32
+ // CHECK-NEXT: scf.yield %[[C42]]
+ // CHECK-NEXT: } else {
+ } else {
+ %c24 = arith.constant 24 : i32
+ scf.yield %c24 : i32
+ // CHECK-NEXT: %[[C24:.+]] = arith.constant 24 : i32
+ // CHECK-NEXT: scf.yield %[[C24]]
+ // CHECK-NEXT: }
+ }
+
+ // %2 cannot be removed because of the write nested in the scf.if above.
+ // CHECK-NEXT: %[[READ_VALUE2:.*]] = "test.op_with_memread"() : () -> i32
+ // CHECK-NEXT: return %[[READ_VALUE]], %[[READ_VALUE2]], %[[IF]] : i32, i32, i32
+ %2 = "test.op_with_memread"() : () -> (i32)
+ return %0, %2, %1 : i32, i32, i32
+}
More information about the Mlir-commits
mailing list