[Mlir-commits] [mlir] 38bef47 - [mlir][bufferization] Fix unknown ops in BufferViewFlowAnalysis

Matthias Springer llvmlistbot at llvm.org
Mon May 15 05:38:36 PDT 2023


Author: Matthias Springer
Date: 2023-05-15T14:33:06+02:00
New Revision: 38bef476552021b7ad45d1aa989d250bcd0a38ff

URL: https://github.com/llvm/llvm-project/commit/38bef476552021b7ad45d1aa989d250bcd0a38ff
DIFF: https://github.com/llvm/llvm-project/commit/38bef476552021b7ad45d1aa989d250bcd0a38ff.diff

LOG: [mlir][bufferization] Fix unknown ops in BufferViewFlowAnalysis

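If an op is unknown to the analysis, it must be treated conservatively: assume that every memref operand may alias with every memref result. Previously, such ops (e.g., `func.call`) were ignored, so aliasing through them was missed. The hard-coded special case for `arith.select` is removed because this fallback now covers it.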

Differential Revision: https://reviews.llvm.org/D150546

Added: 
    

Modified: 
    mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
    mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
index b4cfe89d7ced6..d964f801668f9 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
@@ -8,7 +8,6 @@
 
 #include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
 
-#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
 #include "llvm/ADT/SetOperations.h"
@@ -58,74 +57,89 @@ void BufferViewFlowAnalysis::build(Operation *op) {
       this->dependencies[value].insert(dep);
   };
 
-  // Add additional dependencies created by view changes to the alias list.
-  op->walk([&](ViewLikeOpInterface viewInterface) {
-    dependencies[viewInterface.getViewSource()].insert(
-        viewInterface->getResult(0));
-  });
+  op->walk([&](Operation *op) {
+    // TODO: We should have an op interface instead of a hard-coded list of
+    // interfaces/ops.
 
-  // Query all branch interfaces to link block argument dependencies.
-  op->walk([&](BranchOpInterface branchInterface) {
-    Block *parentBlock = branchInterface->getBlock();
-    for (auto it = parentBlock->succ_begin(), e = parentBlock->succ_end();
-         it != e; ++it) {
-      // Query the branch op interface to get the successor operands.
-      auto successorOperands =
-          branchInterface.getSuccessorOperands(it.getIndex());
-      // Build the actual mapping of values to their immediate dependencies.
-      registerDependencies(successorOperands.getForwardedOperands(),
-                           (*it)->getArguments().drop_front(
-                               successorOperands.getProducedOperandCount()));
+    // Add additional dependencies created by view changes to the alias list.
+    if (auto viewInterface = dyn_cast<ViewLikeOpInterface>(op)) {
+      dependencies[viewInterface.getViewSource()].insert(
+          viewInterface->getResult(0));
+      return WalkResult::advance();
     }
-  });
 
-  // Query the RegionBranchOpInterface to find potential successor regions.
-  op->walk([&](RegionBranchOpInterface regionInterface) {
-    // Extract all entry regions and wire all initial entry successor inputs.
-    SmallVector<RegionSuccessor, 2> entrySuccessors;
-    regionInterface.getSuccessorRegions(/*index=*/std::nullopt,
-                                        entrySuccessors);
-    for (RegionSuccessor &entrySuccessor : entrySuccessors) {
-      // Wire the entry region's successor arguments with the initial
-      // successor inputs.
-      assert(entrySuccessor.getSuccessor() &&
-             "Invalid entry region without an attached successor region");
-      registerDependencies(
-          regionInterface.getSuccessorEntryOperands(
-              entrySuccessor.getSuccessor()->getRegionNumber()),
-          entrySuccessor.getSuccessorInputs());
+    if (auto branchInterface = dyn_cast<BranchOpInterface>(op)) {
+      // Query all branch interfaces to link block argument dependencies.
+      Block *parentBlock = branchInterface->getBlock();
+      for (auto it = parentBlock->succ_begin(), e = parentBlock->succ_end();
+           it != e; ++it) {
+        // Query the branch op interface to get the successor operands.
+        auto successorOperands =
+            branchInterface.getSuccessorOperands(it.getIndex());
+        // Build the actual mapping of values to their immediate dependencies.
+        registerDependencies(successorOperands.getForwardedOperands(),
+                             (*it)->getArguments().drop_front(
+                                 successorOperands.getProducedOperandCount()));
+      }
+      return WalkResult::advance();
     }
 
-    // Wire flow between regions and from region exits.
-    for (Region &region : regionInterface->getRegions()) {
-      // Iterate over all successor region entries that are reachable from the
-      // current region.
-      SmallVector<RegionSuccessor, 2> successorRegions;
-      regionInterface.getSuccessorRegions(region.getRegionNumber(),
-                                          successorRegions);
-      for (RegionSuccessor &successorRegion : successorRegions) {
-        // Determine the current region index (if any).
-        std::optional<unsigned> regionIndex;
-        Region *regionSuccessor = successorRegion.getSuccessor();
-        if (regionSuccessor)
-          regionIndex = regionSuccessor->getRegionNumber();
-        // Iterate over all immediate terminator operations and wire the
-        // successor inputs with the successor operands of each terminator.
-        for (Block &block : region) {
-          auto successorOperands = getRegionBranchSuccessorOperands(
-              block.getTerminator(), regionIndex);
-          if (successorOperands) {
-            registerDependencies(*successorOperands,
-                                 successorRegion.getSuccessorInputs());
+    if (auto regionInterface = dyn_cast<RegionBranchOpInterface>(op)) {
+      // Query the RegionBranchOpInterface to find potential successor regions.
+      // Extract all entry regions and wire all initial entry successor inputs.
+      SmallVector<RegionSuccessor, 2> entrySuccessors;
+      regionInterface.getSuccessorRegions(/*index=*/std::nullopt,
+                                          entrySuccessors);
+      for (RegionSuccessor &entrySuccessor : entrySuccessors) {
+        // Wire the entry region's successor arguments with the initial
+        // successor inputs.
+        assert(entrySuccessor.getSuccessor() &&
+               "Invalid entry region without an attached successor region");
+        registerDependencies(
+            regionInterface.getSuccessorEntryOperands(
+                entrySuccessor.getSuccessor()->getRegionNumber()),
+            entrySuccessor.getSuccessorInputs());
+      }
+
+      // Wire flow between regions and from region exits.
+      for (Region &region : regionInterface->getRegions()) {
+        // Iterate over all successor region entries that are reachable from the
+        // current region.
+        SmallVector<RegionSuccessor, 2> successorRegions;
+        regionInterface.getSuccessorRegions(region.getRegionNumber(),
+                                            successorRegions);
+        for (RegionSuccessor &successorRegion : successorRegions) {
+          // Determine the current region index (if any).
+          std::optional<unsigned> regionIndex;
+          Region *regionSuccessor = successorRegion.getSuccessor();
+          if (regionSuccessor)
+            regionIndex = regionSuccessor->getRegionNumber();
+          // Iterate over all immediate terminator operations and wire the
+          // successor inputs with the successor operands of each terminator.
+          for (Block &block : region) {
+            auto successorOperands = getRegionBranchSuccessorOperands(
+                block.getTerminator(), regionIndex);
+            if (successorOperands) {
+              registerDependencies(*successorOperands,
+                                   successorRegion.getSuccessorInputs());
+            }
           }
         }
       }
+
+      return WalkResult::advance();
     }
-  });
 
-  // TODO: This should be an interface.
-  op->walk([&](arith::SelectOp selectOp) {
-    registerDependencies({selectOp.getOperand(1)}, {selectOp.getResult()});
-    registerDependencies({selectOp.getOperand(2)}, {selectOp.getResult()});
+    // Unknown op: Assume that all operands alias with all results.
+    for (Value operand : op->getOperands()) {
+      if (!isa<BaseMemRefType>(operand.getType()))
+        continue;
+      for (Value result : op->getResults()) {
+        if (!isa<BaseMemRefType>(result.getType()))
+          continue;
+        registerDependencies({operand}, {result});
+      }
+    }
+    return WalkResult::advance();
   });
 }

diff  --git a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir
index 384657222725a..3fbe3913c6549 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir
@@ -1317,6 +1317,27 @@ func.func @select_aliases(%arg0: index, %arg1: memref<?xi8>, %arg2: i1) {
 
 // -----
 
+func.func @f(%arg0: memref<f64>) -> memref<f64> {
+  return %arg0 : memref<f64>
+}
+
+// CHECK-LABEL: func @function_call
+//       CHECK:   memref.alloc
+//       CHECK:   memref.alloc
+//       CHECK:   call
+//       CHECK:   test.copy
+//       CHECK:   memref.dealloc
+//       CHECK:   memref.dealloc
+func.func @function_call() {
+  %alloc = memref.alloc() : memref<f64>
+  %alloc2 = memref.alloc() : memref<f64>
+  %ret = call @f(%alloc) : (memref<f64>) -> memref<f64>
+  test.copy(%ret, %alloc2) : (memref<f64>, memref<f64>)
+  return
+}
+
+// -----
+
 // Memref allocated in `then` region and passed back to the parent if op.
 #set = affine_set<() : (0 >= 0)>
 // CHECK-LABEL:  func @test_affine_if_1


        


More information about the Mlir-commits mailing list