[Mlir-commits] [mlir] 38bef47 - [mlir][bufferization] Fix unknown ops in BufferViewFlowAnalysis
Matthias Springer
llvmlistbot at llvm.org
Mon May 15 05:38:36 PDT 2023
Author: Matthias Springer
Date: 2023-05-15T14:33:06+02:00
New Revision: 38bef476552021b7ad45d1aa989d250bcd0a38ff
URL: https://github.com/llvm/llvm-project/commit/38bef476552021b7ad45d1aa989d250bcd0a38ff
DIFF: https://github.com/llvm/llvm-project/commit/38bef476552021b7ad45d1aa989d250bcd0a38ff.diff
LOG: [mlir][bufferization] Fix unknown ops in BufferViewFlowAnalysis
If an op is unknown to the analysis, it must be treated conservatively: assume that every memref-typed operand aliases with every memref-typed result.
Differential Revision: https://reviews.llvm.org/D150546
Added:
Modified:
mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
index b4cfe89d7ced6..d964f801668f9 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
@@ -8,7 +8,6 @@
#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/SetOperations.h"
@@ -58,74 +57,89 @@ void BufferViewFlowAnalysis::build(Operation *op) {
this->dependencies[value].insert(dep);
};
- // Add additional dependencies created by view changes to the alias list.
- op->walk([&](ViewLikeOpInterface viewInterface) {
- dependencies[viewInterface.getViewSource()].insert(
- viewInterface->getResult(0));
- });
+ op->walk([&](Operation *op) {
+ // TODO: We should have an op interface instead of a hard-coded list of
+ // interfaces/ops.
- // Query all branch interfaces to link block argument dependencies.
- op->walk([&](BranchOpInterface branchInterface) {
- Block *parentBlock = branchInterface->getBlock();
- for (auto it = parentBlock->succ_begin(), e = parentBlock->succ_end();
- it != e; ++it) {
- // Query the branch op interface to get the successor operands.
- auto successorOperands =
- branchInterface.getSuccessorOperands(it.getIndex());
- // Build the actual mapping of values to their immediate dependencies.
- registerDependencies(successorOperands.getForwardedOperands(),
- (*it)->getArguments().drop_front(
- successorOperands.getProducedOperandCount()));
+ // Add additional dependencies created by view changes to the alias list.
+ if (auto viewInterface = dyn_cast<ViewLikeOpInterface>(op)) {
+ dependencies[viewInterface.getViewSource()].insert(
+ viewInterface->getResult(0));
+ return WalkResult::advance();
}
- });
- // Query the RegionBranchOpInterface to find potential successor regions.
- op->walk([&](RegionBranchOpInterface regionInterface) {
- // Extract all entry regions and wire all initial entry successor inputs.
- SmallVector<RegionSuccessor, 2> entrySuccessors;
- regionInterface.getSuccessorRegions(/*index=*/std::nullopt,
- entrySuccessors);
- for (RegionSuccessor &entrySuccessor : entrySuccessors) {
- // Wire the entry region's successor arguments with the initial
- // successor inputs.
- assert(entrySuccessor.getSuccessor() &&
- "Invalid entry region without an attached successor region");
- registerDependencies(
- regionInterface.getSuccessorEntryOperands(
- entrySuccessor.getSuccessor()->getRegionNumber()),
- entrySuccessor.getSuccessorInputs());
+ if (auto branchInterface = dyn_cast<BranchOpInterface>(op)) {
+ // Query all branch interfaces to link block argument dependencies.
+ Block *parentBlock = branchInterface->getBlock();
+ for (auto it = parentBlock->succ_begin(), e = parentBlock->succ_end();
+ it != e; ++it) {
+ // Query the branch op interface to get the successor operands.
+ auto successorOperands =
+ branchInterface.getSuccessorOperands(it.getIndex());
+ // Build the actual mapping of values to their immediate dependencies.
+ registerDependencies(successorOperands.getForwardedOperands(),
+ (*it)->getArguments().drop_front(
+ successorOperands.getProducedOperandCount()));
+ }
+ return WalkResult::advance();
}
- // Wire flow between regions and from region exits.
- for (Region &region : regionInterface->getRegions()) {
- // Iterate over all successor region entries that are reachable from the
- // current region.
- SmallVector<RegionSuccessor, 2> successorRegions;
- regionInterface.getSuccessorRegions(region.getRegionNumber(),
- successorRegions);
- for (RegionSuccessor &successorRegion : successorRegions) {
- // Determine the current region index (if any).
- std::optional<unsigned> regionIndex;
- Region *regionSuccessor = successorRegion.getSuccessor();
- if (regionSuccessor)
- regionIndex = regionSuccessor->getRegionNumber();
- // Iterate over all immediate terminator operations and wire the
- // successor inputs with the successor operands of each terminator.
- for (Block &block : region) {
- auto successorOperands = getRegionBranchSuccessorOperands(
- block.getTerminator(), regionIndex);
- if (successorOperands) {
- registerDependencies(*successorOperands,
- successorRegion.getSuccessorInputs());
+ if (auto regionInterface = dyn_cast<RegionBranchOpInterface>(op)) {
+ // Query the RegionBranchOpInterface to find potential successor regions.
+ // Extract all entry regions and wire all initial entry successor inputs.
+ SmallVector<RegionSuccessor, 2> entrySuccessors;
+ regionInterface.getSuccessorRegions(/*index=*/std::nullopt,
+ entrySuccessors);
+ for (RegionSuccessor &entrySuccessor : entrySuccessors) {
+ // Wire the entry region's successor arguments with the initial
+ // successor inputs.
+ assert(entrySuccessor.getSuccessor() &&
+ "Invalid entry region without an attached successor region");
+ registerDependencies(
+ regionInterface.getSuccessorEntryOperands(
+ entrySuccessor.getSuccessor()->getRegionNumber()),
+ entrySuccessor.getSuccessorInputs());
+ }
+
+ // Wire flow between regions and from region exits.
+ for (Region &region : regionInterface->getRegions()) {
+ // Iterate over all successor region entries that are reachable from the
+ // current region.
+ SmallVector<RegionSuccessor, 2> successorRegions;
+ regionInterface.getSuccessorRegions(region.getRegionNumber(),
+ successorRegions);
+ for (RegionSuccessor &successorRegion : successorRegions) {
+ // Determine the current region index (if any).
+ std::optional<unsigned> regionIndex;
+ Region *regionSuccessor = successorRegion.getSuccessor();
+ if (regionSuccessor)
+ regionIndex = regionSuccessor->getRegionNumber();
+ // Iterate over all immediate terminator operations and wire the
+ // successor inputs with the successor operands of each terminator.
+ for (Block &block : region) {
+ auto successorOperands = getRegionBranchSuccessorOperands(
+ block.getTerminator(), regionIndex);
+ if (successorOperands) {
+ registerDependencies(*successorOperands,
+ successorRegion.getSuccessorInputs());
+ }
}
}
}
+
+ return WalkResult::advance();
}
- });
- // TODO: This should be an interface.
- op->walk([&](arith::SelectOp selectOp) {
- registerDependencies({selectOp.getOperand(1)}, {selectOp.getResult()});
- registerDependencies({selectOp.getOperand(2)}, {selectOp.getResult()});
+ // Unknown op: Assume that all operands alias with all results.
+ for (Value operand : op->getOperands()) {
+ if (!isa<BaseMemRefType>(operand.getType()))
+ continue;
+ for (Value result : op->getResults()) {
+ if (!isa<BaseMemRefType>(result.getType()))
+ continue;
+ registerDependencies({operand}, {result});
+ }
+ }
+ return WalkResult::advance();
});
}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir
index 384657222725a..3fbe3913c6549 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir
@@ -1317,6 +1317,27 @@ func.func @select_aliases(%arg0: index, %arg1: memref<?xi8>, %arg2: i1) {
// -----
+func.func @f(%arg0: memref<f64>) -> memref<f64> {
+ return %arg0 : memref<f64>
+}
+
+// CHECK-LABEL: func @function_call
+// CHECK: memref.alloc
+// CHECK: memref.alloc
+// CHECK: call
+// CHECK: test.copy
+// CHECK: memref.dealloc
+// CHECK: memref.dealloc
+func.func @function_call() {
+ %alloc = memref.alloc() : memref<f64>
+ %alloc2 = memref.alloc() : memref<f64>
+ %ret = call @f(%alloc) : (memref<f64>) -> memref<f64>
+ test.copy(%ret, %alloc2) : (memref<f64>, memref<f64>)
+ return
+}
+
+// -----
+
// Memref allocated in `then` region and passed back to the parent if op.
#set = affine_set<() : (0 >= 0)>
// CHECK-LABEL: func @test_affine_if_1
More information about the Mlir-commits mailing list