[Mlir-commits] [mlir] dea33c8 - [mlir][Transforms] teach CSE about recursive memory effects

Tom Eccles llvmlistbot at llvm.org
Thu Aug 10 02:40:50 PDT 2023


Author: Tom Eccles
Date: 2023-08-10T09:40:01Z
New Revision: dea33c80d3666f1e8368ef1a3a09adf999e31723

URL: https://github.com/llvm/llvm-project/commit/dea33c80d3666f1e8368ef1a3a09adf999e31723
DIFF: https://github.com/llvm/llvm-project/commit/dea33c80d3666f1e8368ef1a3a09adf999e31723.diff

LOG: [mlir][Transforms] teach CSE about recursive memory effects

Add support for reasoning about operations with recursive memory effects
to CSE. The recursive effects are gathered by a helper function. I
decided to allow returning duplicates from the helper function because
there's no benefit to spending the computation time to remove them in
the existing use case.

Differential Revision: https://reviews.llvm.org/D156805

Added: 
    

Modified: 
    mlir/include/mlir/Interfaces/SideEffectInterfaces.h
    mlir/lib/Interfaces/SideEffectInterfaces.cpp
    mlir/lib/Transforms/CSE.cpp
    mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
    mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
    mlir/test/Transforms/cse.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Interfaces/SideEffectInterfaces.h b/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
index dfbb65ee6d4881..74fb96662934a9 100644
--- a/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
+++ b/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
@@ -332,6 +332,17 @@ bool wouldOpBeTriviallyDead(Operation *op);
 /// conditions are satisfied.
 bool isMemoryEffectFree(Operation *op);
 
+/// Returns the side effects of an operation. If the operation has
+/// RecursiveMemoryEffects, include all side effects of child operations.
+///
+/// std::nullopt indicates that an operation did not have a memory effect interface
+/// and so no result could be obtained. An empty vector indicates that there
+/// were no memory effects found (but every operation implemented the memory
+/// effect interface or has RecursiveMemoryEffects). If the vector contains
+/// multiple effects, these effects may be duplicates.
+std::optional<llvm::SmallVector<MemoryEffects::EffectInstance>>
+getEffectsRecursively(Operation *rootOp);
+
 /// Returns true if the given operation is speculatable, i.e. has no undefined
 /// behavior or other side effects.
 ///

diff  --git a/mlir/lib/Interfaces/SideEffectInterfaces.cpp b/mlir/lib/Interfaces/SideEffectInterfaces.cpp
index 76ae27fab5a0c0..967225333ae048 100644
--- a/mlir/lib/Interfaces/SideEffectInterfaces.cpp
+++ b/mlir/lib/Interfaces/SideEffectInterfaces.cpp
@@ -182,6 +182,39 @@ bool mlir::isMemoryEffectFree(Operation *op) {
   return true;
 }
 
+// the returned vector may contain duplicate effects
+std::optional<llvm::SmallVector<MemoryEffects::EffectInstance>>
+mlir::getEffectsRecursively(Operation *rootOp) {
+  SmallVector<MemoryEffects::EffectInstance> effects;
+  SmallVector<Operation *> effectingOps(1, rootOp);
+  while (!effectingOps.empty()) {
+    Operation *op = effectingOps.pop_back_val();
+
+    // If the operation has recursive effects, push all of the nested
+    // operations on to the stack to consider.
+    bool hasRecursiveEffects =
+        op->hasTrait<OpTrait::HasRecursiveMemoryEffects>();
+    if (hasRecursiveEffects) {
+      for (Region &region : op->getRegions()) {
+        for (Block &block : region) {
+          for (Operation &nestedOp : block) {
+            effectingOps.push_back(&nestedOp);
+          }
+        }
+      }
+    }
+
+    if (auto effectInterface = dyn_cast<MemoryEffectOpInterface>(op)) {
+      effectInterface.getEffects(effects);
+    } else if (!hasRecursiveEffects) {
+      // The operation has no recursive memory effects and does not implement
+      // the memory effect op interface. Its effects are unknown.
+      return std::nullopt;
+    }
+  }
+  return effects;
+}
+
 bool mlir::isSpeculatable(Operation *op) {
   auto conditionallySpeculatable = dyn_cast<ConditionallySpeculatable>(op);
   if (!conditionallySpeculatable)

diff  --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
index fc0d50949282d9..3affd88d158de5 100644
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -199,17 +199,23 @@ bool CSEDriver::hasOtherSideEffectingOpInBetween(Operation *fromOp,
     }
   }
   while (nextOp && nextOp != toOp) {
-    auto nextOpMemEffects = dyn_cast<MemoryEffectOpInterface>(nextOp);
-    // TODO: Do we need to handle other effects generically?
-    // If the operation does not implement the MemoryEffectOpInterface we
-    // conservatively assumes it writes.
-    if ((nextOpMemEffects &&
-         nextOpMemEffects.hasEffect<MemoryEffects::Write>()) ||
-        !nextOpMemEffects) {
+    std::optional<SmallVector<MemoryEffects::EffectInstance>> effects =
+        getEffectsRecursively(nextOp);
+    if (!effects) {
+      // TODO: Do we need to handle other effects generically?
+      // If the operation does not implement the MemoryEffectOpInterface we
+      // conservatively assume it writes.
       result.first->second =
           std::make_pair(nextOp, MemoryEffects::Write::get());
       return true;
     }
+
+    for (const MemoryEffects::EffectInstance &effect : *effects) {
+      if (isa<MemoryEffects::Write>(effect.getEffect())) {
+        result.first->second = {nextOp, MemoryEffects::Write::get()};
+        return true;
+      }
+    }
     nextOp = nextOp->getNextNode();
   }
   result.first->second = std::make_pair(toOp, nullptr);

diff  --git a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
index b0c1d5469cfd29..170f851138f82a 100644
--- a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
+++ b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
@@ -332,8 +332,7 @@ func.func @sparse_push_back_inbound(%arg0: index, %arg1: memref<?xf64>, %arg2: f
 // CHECK:               scf.yield %[[VAL_145]]
 // CHECK:             }
 // CHECK:             %[[VAL_146:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_147:.*]]]
-// CHECK:             %[[VAL_148:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_127]]]
-// CHECK:             %[[VAL_149:.*]] = arith.cmpi eq, %[[VAL_146]], %[[VAL_148]]
+// CHECK:             %[[VAL_149:.*]] = arith.cmpi eq, %[[VAL_146]], %[[VAL_137]]
 // CHECK:             %[[VAL_150:.*]] = arith.cmpi ult, %[[VAL_136]], %[[VAL_147]]
 // CHECK:             %[[VAL_151:.*]]:3 = scf.if %[[VAL_150]]
 // CHECK:               %[[VAL_152:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_136]]]
@@ -529,4 +528,4 @@ func.func @sparse_sort_coo_stable(%arg0: index, %arg1: memref<100xindex>, %arg2:
 func.func @sparse_sort_coo_heap(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
   sparse_tensor.sort_coo heap_sort %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
   return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
-}
\ No newline at end of file
+}

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir b/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
index d2b8b6a9316ae1..866b228917ad37 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
@@ -142,9 +142,7 @@
 // CHECK:                   scf.yield %[[VAL_132]], %[[VAL_131]] : index, i32
 // CHECK:                 }
 // CHECK:                 %[[VAL_133:.*]] = arith.addi %[[VAL_105]], %[[VAL_7]] : index
-// CHECK:                 %[[VAL_134:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<11xindex>
-// CHECK:                 %[[VAL_135:.*]] = arith.addi %[[VAL_134]], %[[VAL_5]] : index
-// CHECK:                 memref.store %[[VAL_135]], %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<11xindex>
+// CHECK:                 memref.store %[[VAL_112]], %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<11xindex>
 // CHECK:                 scf.yield %[[VAL_133]], %[[VAL_136:.*]]#1, %[[VAL_2]] : index, i32, i1
 // CHECK:               }
 // CHECK:               %[[VAL_137:.*]] = scf.if %[[VAL_138:.*]]#2 -> (tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>) {

diff  --git a/mlir/test/Transforms/cse.mlir b/mlir/test/Transforms/cse.mlir
index f3a820f8a765be..c764d2b9bd57d8 100644
--- a/mlir/test/Transforms/cse.mlir
+++ b/mlir/test/Transforms/cse.mlir
@@ -459,3 +459,64 @@ func.func @cse_multiple_regions(%c: i1, %t: tensor<5xf32>) -> (tensor<5xf32>, te
 //       CHECK:   }
 //   CHECK-NOT:   scf.if
 //       CHECK:   return %[[if]], %[[if]]
+
+// CHECK-LABEL: @cse_recursive_effects_success
+func.func @cse_recursive_effects_success() -> (i32, i32, i32) {
+  // CHECK-NEXT: %[[READ_VALUE:.*]] = "test.op_with_memread"() : () -> i32
+  %0 = "test.op_with_memread"() : () -> (i32)
+
+  // do something with recursive effects, containing no side effects
+  %true = arith.constant true
+  // CHECK-NEXT: %[[TRUE:.+]] = arith.constant true
+  // CHECK-NEXT: %[[IF:.+]] = scf.if %[[TRUE]] -> (i32) {
+  %1 = scf.if %true -> (i32) {
+    %c42 = arith.constant 42 : i32
+    scf.yield %c42 : i32
+    // CHECK-NEXT: %[[C42:.+]] = arith.constant 42 : i32
+    // CHECK-NEXT: scf.yield %[[C42]]
+    // CHECK-NEXT: } else {
+  } else {
+    %c24 = arith.constant 24 : i32
+    scf.yield %c24 : i32
+    // CHECK-NEXT: %[[C24:.+]] = arith.constant 24 : i32
+    // CHECK-NEXT: scf.yield %[[C24]]
+    // CHECK-NEXT: }
+  }
+
+  // %2 can be removed
+  // CHECK-NEXT: return %[[READ_VALUE]], %[[READ_VALUE]], %[[IF]] : i32, i32, i32
+  %2 = "test.op_with_memread"() : () -> (i32)
+  return %0, %2, %1 : i32, i32, i32
+}
+
+// CHECK-LABEL: @cse_recursive_effects_failure
+func.func @cse_recursive_effects_failure() -> (i32, i32, i32) {
+  // CHECK-NEXT: %[[READ_VALUE:.*]] = "test.op_with_memread"() : () -> i32
+  %0 = "test.op_with_memread"() : () -> (i32)
+
+  // do something with recursive effects, containing a write effect
+  %true = arith.constant true
+  // CHECK-NEXT: %[[TRUE:.+]] = arith.constant true
+  // CHECK-NEXT: %[[IF:.+]] = scf.if %[[TRUE]] -> (i32) {
+  %1 = scf.if %true -> (i32) {
+    "test.op_with_memwrite"() : () -> ()
+    // CHECK-NEXT: "test.op_with_memwrite"() : () -> ()
+    %c42 = arith.constant 42 : i32
+    scf.yield %c42 : i32
+    // CHECK-NEXT: %[[C42:.+]] = arith.constant 42 : i32
+    // CHECK-NEXT: scf.yield %[[C42]]
+    // CHECK-NEXT: } else {
+  } else {
+    %c24 = arith.constant 24 : i32
+    scf.yield %c24 : i32
+    // CHECK-NEXT: %[[C24:.+]] = arith.constant 24 : i32
+    // CHECK-NEXT: scf.yield %[[C24]]
+    // CHECK-NEXT: }
+  }
+
+  // %2 cannot be removed because of the write
+  // CHECK-NEXT: %[[READ_VALUE2:.*]] = "test.op_with_memread"() : () -> i32
+  // CHECK-NEXT: return %[[READ_VALUE]], %[[READ_VALUE2]], %[[IF]] : i32, i32, i32
+  %2 = "test.op_with_memread"() : () -> (i32)
+  return %0, %2, %1 : i32, i32, i32
+}


        


More information about the Mlir-commits mailing list