[Mlir-commits] [mlir] 6f5da84 - [mlir] Extended BufferPlacement to support nested region control flow.
Stephan Herhut
llvmlistbot at llvm.org
Tue Jun 30 03:10:16 PDT 2020
Author: Marcel Koester
Date: 2020-06-30T12:10:01+02:00
New Revision: 6f5da84f7bb31c7c2fcb78e64d5dc3baea1c60f2
URL: https://github.com/llvm/llvm-project/commit/6f5da84f7bb31c7c2fcb78e64d5dc3baea1c60f2
DIFF: https://github.com/llvm/llvm-project/commit/6f5da84f7bb31c7c2fcb78e64d5dc3baea1c60f2.diff
LOG: [mlir] Extended BufferPlacement to support nested region control flow.
Summary: The current BufferPlacement implementation does not support
nested region control flow. This CL adds support for nested regions via
the RegionBranchOpInterface and the detection of branch-like
(ReturnLike) terminators inside nested regions.
Differential Revision: https://reviews.llvm.org/D81926
Added:
Modified:
mlir/lib/Transforms/BufferPlacement.cpp
mlir/test/Transforms/buffer-placement.mlir
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestOps.td
Removed:
################################################################################
diff --git a/mlir/lib/Transforms/BufferPlacement.cpp b/mlir/lib/Transforms/BufferPlacement.cpp
index 74c534701536..577d52188351 100644
--- a/mlir/lib/Transforms/BufferPlacement.cpp
+++ b/mlir/lib/Transforms/BufferPlacement.cpp
@@ -65,8 +65,18 @@
using namespace mlir;
-namespace {
+/// Walks over all immediate return-like terminators in the given region.
+template <typename FuncT>
+static void walkReturnOperations(Region *region, const FuncT &func) {
+ for (Block &block : *region)
+ for (Operation &operation : block) {
+ // Skip non-return-like terminators.
+ if (operation.hasTrait<OpTrait::ReturnLike>())
+ func(&operation);
+ }
+}
+namespace {
//===----------------------------------------------------------------------===//
// BufferPlacementAliasAnalysis
//===----------------------------------------------------------------------===//
@@ -82,7 +92,7 @@ class BufferPlacementAliasAnalysis {
public:
/// Constructs a new alias analysis using the op provided.
- BufferPlacementAliasAnalysis(Operation *op) { build(op->getRegions()); }
+ BufferPlacementAliasAnalysis(Operation *op) { build(op); }
/// Find all immediate aliases this value could potentially have.
ValueMapT::const_iterator find(Value value) const {
@@ -102,7 +112,7 @@ class BufferPlacementAliasAnalysis {
}
/// Removes the given values from all alias sets.
- void remove(const SmallPtrSetImpl<BlockArgument> &aliasValues) {
+ void remove(const SmallPtrSetImpl<Value> &aliasValues) {
for (auto &entry : aliases)
llvm::set_subtract(entry.second, aliasValues);
}
@@ -123,33 +133,69 @@ class BufferPlacementAliasAnalysis {
/// This function constructs a mapping from values to its immediate aliases.
/// It iterates over all blocks, gets their predecessors, determines the
/// values that will be passed to the corresponding block arguments and
- /// inserts them into the underlying map.
- void build(MutableArrayRef<Region> regions) {
- for (Region &region : regions) {
- for (Block &block : region) {
- // Iterate over all predecessor and get the mapped values to their
- // corresponding block arguments values.
- for (auto it = block.pred_begin(), e = block.pred_end(); it != e;
- ++it) {
- unsigned successorIndex = it.getSuccessorIndex();
- // Get the terminator and the values that will be passed to our block.
- auto branchInterface =
- dyn_cast<BranchOpInterface>((*it)->getTerminator());
- if (!branchInterface)
- continue;
- // Query the branch op interace to get the successor operands.
- auto successorOperands =
- branchInterface.getSuccessorOperands(successorIndex);
- if (successorOperands.hasValue()) {
- // Build the actual mapping of values to their immediate aliases.
- for (auto argPair : llvm::zip(block.getArguments(),
- successorOperands.getValue())) {
- aliases[std::get<1>(argPair)].insert(std::get<0>(argPair));
- }
- }
+ /// inserts them into the underlying map. Furthermore, it wires successor
+ /// regions and branch-like return operations from nested regions.
+ void build(Operation *op) {
+ // Registers all aliases of the given values.
+ auto registerAliases = [&](auto values, auto aliases) {
+ for (auto entry : llvm::zip(values, aliases))
+ this->aliases[std::get<0>(entry)].insert(std::get<1>(entry));
+ };
+
+ // Query all branch interfaces to link block argument aliases.
+ op->walk([&](BranchOpInterface branchInterface) {
+ Block *parentBlock = branchInterface.getOperation()->getBlock();
+ for (auto it = parentBlock->succ_begin(), e = parentBlock->succ_end();
+ it != e; ++it) {
+ // Query the branch op interface to get the successor operands.
+ auto successorOperands =
+ branchInterface.getSuccessorOperands(it.getIndex());
+ if (!successorOperands.hasValue())
+ continue;
+ // Build the actual mapping of values to their immediate aliases.
+ registerAliases(successorOperands.getValue(), (*it)->getArguments());
+ }
+ });
+
+ // Query the RegionBranchOpInterface to find potential successor regions.
+ op->walk([&](RegionBranchOpInterface regionInterface) {
+ // Create an empty attribute for each operand to comply with the
+ // `getSuccessorRegions` interface definition that requires a single
+ // attribute per operand.
+ SmallVector<Attribute, 2> operandAttributes(
+ regionInterface.getOperation()->getNumOperands());
+
+ // Extract all entry regions and wire all initial entry successor inputs.
+ SmallVector<RegionSuccessor, 2> entrySuccessors;
+ regionInterface.getSuccessorRegions(/*index=*/llvm::None,
+ operandAttributes, entrySuccessors);
+ for (RegionSuccessor &entrySuccessor : entrySuccessors) {
+ // Wire the entry region's successor arguments with the initial
+ // successor inputs.
+ assert(entrySuccessor.getSuccessor() &&
+ "Invalid entry region without an attached successor region");
+ registerAliases(regionInterface.getSuccessorEntryOperands(
+ entrySuccessor.getSuccessor()->getRegionNumber()),
+ entrySuccessor.getSuccessorInputs());
+ }
+
+ // Wire flow between regions and from region exits.
+ for (Region &region : regionInterface.getOperation()->getRegions()) {
+ // Iterate over all successor region entries that are reachable from the
+ // current region.
+ SmallVector<RegionSuccessor, 2> successorRegions;
+ regionInterface.getSuccessorRegions(
+ region.getRegionNumber(), operandAttributes, successorRegions);
+ for (RegionSuccessor &successorRegion : successorRegions) {
+ // Iterate over all immediate terminator operations and wire the
+ // successor inputs with the operands of each terminator.
+ walkReturnOperations(&region, [&](Operation *terminator) {
+ registerAliases(terminator->getOperands(),
+ successorRegion.getSuccessorInputs());
+ });
}
}
- }
+ });
}
/// Maps values to all immediate aliases this value can have.
@@ -235,14 +281,24 @@ class BufferPlacement {
Block *getInitialAllocBlock(OpResult result) {
// Get all allocation operands as these operands are important for the
// allocation operation.
- auto operands = result.getOwner()->getOperands();
+ Operation *owner = result.getOwner();
+ auto operands = owner->getOperands();
+ Block *dominator;
if (operands.size() < 1)
- return findCommonDominator(result, aliases.resolve(result), dominators);
+ dominator =
+ findCommonDominator(result, aliases.resolve(result), dominators);
+ else {
+ // If this node has dependencies, check all dependent nodes with respect
+ // to a common post dominator in which all values are available.
+ ValueSetT dependencies(++operands.begin(), operands.end());
+ dominator =
+ findCommonDominator(*operands.begin(), dependencies, postDominators);
+ }
- // If this node has dependencies, check all dependent nodes with respect
- // to a common post dominator in which all values are available.
- ValueSetT dependencies(++operands.begin(), operands.end());
- return findCommonDominator(*operands.begin(), dependencies, postDominators);
+ // Do not move allocs out of their parent regions to keep them local.
+ if (dominator->getParent() != owner->getParentRegion())
+ return &owner->getParentRegion()->front();
+ return dominator;
}
/// Finds correct alloc positions according to the algorithm described at
@@ -273,12 +329,12 @@ class BufferPlacement {
/// Introduces required allocs and copy operations to avoid memory leaks.
void introduceCopies() {
- // Initialize the set of block arguments that require a dedicated memory
- // free operation since their arguments cannot be safely deallocated in a
- // post dominator.
- SmallPtrSet<BlockArgument, 8> blockArgsToFree;
- llvm::SmallDenseSet<std::tuple<BlockArgument, Block *>> visitedBlockArgs;
- SmallVector<std::tuple<BlockArgument, Block *>, 8> toProcess;
+ // Initialize the set of values that require a dedicated memory free
+ // operation since their operands cannot be safely deallocated in a post
+ // dominator.
+ SmallPtrSet<Value, 8> valuesToFree;
+ llvm::SmallDenseSet<std::tuple<Value, Block *>> visitedValues;
+ SmallVector<std::tuple<Value, Block *>, 8> toProcess;
// Check dominance relation for proper dominance properties. If the given
// value node does not dominate an alias, we will have to create a copy in
@@ -289,17 +345,15 @@ class BufferPlacement {
if (it == aliases.end())
return;
for (Value value : it->second) {
- auto blockArg = value.cast<BlockArgument>();
- if (blockArgsToFree.count(blockArg) > 0)
+ if (valuesToFree.count(value) > 0)
continue;
// Check whether we have to free this particular block argument.
- if (!dominators.dominates(definingBlock, blockArg.getOwner())) {
- toProcess.emplace_back(blockArg, blockArg.getParentBlock());
- blockArgsToFree.insert(blockArg);
- } else if (visitedBlockArgs
- .insert(std::make_tuple(blockArg, definingBlock))
+ if (!dominators.dominates(definingBlock, value.getParentBlock())) {
+ toProcess.emplace_back(value, value.getParentBlock());
+ valuesToFree.insert(value);
+ } else if (visitedValues.insert(std::make_tuple(value, definingBlock))
.second)
- toProcess.emplace_back(blockArg, definingBlock);
+ toProcess.emplace_back(value, definingBlock);
}
};
@@ -316,62 +370,170 @@ class BufferPlacement {
// Update buffer aliases to ensure that we free all buffers and block
// arguments at the correct locations.
- aliases.remove(blockArgsToFree);
+ aliases.remove(valuesToFree);
// Add new allocs and additional copy operations.
- for (BlockArgument blockArg : blockArgsToFree) {
- Block *block = blockArg.getOwner();
-
- // Allocate a buffer for the current block argument in the block of
- // the associated value (which will be a predecessor block by
- // definition).
- for (auto it = block->pred_begin(), e = block->pred_end(); it != e;
- ++it) {
- // Get the terminator and the value that will be passed to our
- // argument.
- Operation *terminator = (*it)->getTerminator();
- auto branchInterface = cast<BranchOpInterface>(terminator);
- // Convert the mutable operand range to an immutable range and query the
- // associated source value.
- Value sourceValue =
- branchInterface.getSuccessorOperands(it.getSuccessorIndex())
- .getValue()[blockArg.getArgNumber()];
- // Create a new alloc at the current location of the terminator.
- auto memRefType = sourceValue.getType().cast<MemRefType>();
- OpBuilder builder(terminator);
-
- // Extract information about dynamically shaped types by
- // extracting their dynamic dimensions.
- SmallVector<Value, 4> dynamicOperands;
- for (auto shapeElement : llvm::enumerate(memRefType.getShape())) {
- if (!ShapedType::isDynamic(shapeElement.value()))
- continue;
- dynamicOperands.push_back(builder.create<DimOp>(
- terminator->getLoc(), sourceValue, shapeElement.index()));
- }
+ for (Value value : valuesToFree) {
+ if (auto blockArg = value.dyn_cast<BlockArgument>())
+ introduceBlockArgCopy(blockArg);
+ else
+ introduceValueCopyForRegionResult(value);
+
+ // Register the value to require a final dealloc. Note that we do not have
+ // to assign a block here since we do not want to move the allocation node
+ // to another location.
+ allocs.push_back({value, nullptr, nullptr});
+ }
+ }
- // TODO: provide a generic interface to create dialect-specific
- // Alloc and CopyOp nodes.
- auto alloc = builder.create<AllocOp>(terminator->getLoc(), memRefType,
- dynamicOperands);
- // Wire new alloc and successor operand.
- branchInterface.getMutableSuccessorOperands(it.getSuccessorIndex())
- .getValue()
+ /// Introduces temporary allocs in all predecessors and copies the source
+ /// values into the newly allocated buffers.
+ void introduceBlockArgCopy(BlockArgument blockArg) {
+ // Allocate a buffer for the current block argument in the block of
+ // the associated value (which will be a predecessor block by
+ // definition).
+ Block *block = blockArg.getOwner();
+ for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) {
+ // Get the terminator and the value that will be passed to our
+ // argument.
+ Operation *terminator = (*it)->getTerminator();
+ auto branchInterface = cast<BranchOpInterface>(terminator);
+ // Query the associated source value.
+ Value sourceValue =
+ branchInterface.getSuccessorOperands(it.getSuccessorIndex())
+ .getValue()[blockArg.getArgNumber()];
+ // Create a new alloc and copy at the current location of the terminator.
+ Value alloc = introduceBufferCopy(sourceValue, terminator);
+ // Wire new alloc and successor operand.
+ auto mutableOperands =
+ branchInterface.getMutableSuccessorOperands(it.getSuccessorIndex());
+ if (!mutableOperands.hasValue())
+ terminator->emitError() << "terminators with immutable successor "
+ "operands are not supported";
+ else
+ mutableOperands.getValue()
.slice(blockArg.getArgNumber(), 1)
.assign(alloc);
- // Create a new copy operation that copies to contents of the old
- // allocation to the new one.
- builder.create<linalg::CopyOp>(terminator->getLoc(), sourceValue,
- alloc);
- }
+ }
- // Register the block argument to require a final dealloc. Note that
- // we do not have to assign a block here since we do not want to
- // move the allocation node to another location.
- allocs.push_back({blockArg, nullptr, nullptr});
+ // Check whether the block argument has implicitly defined predecessors via
+ // the RegionBranchOpInterface. This can be the case if the current block
+ // argument belongs to the first block in a region and the parent operation
+ // implements the RegionBranchOpInterface.
+ Region *argRegion = block->getParent();
+ RegionBranchOpInterface regionInterface;
+ if (!argRegion || &argRegion->front() != block ||
+ !(regionInterface =
+ dyn_cast<RegionBranchOpInterface>(argRegion->getParentOp())))
+ return;
+
+ introduceCopiesForRegionSuccessors(
+ regionInterface, argRegion->getParentOp()->getRegions(),
+ [&](RegionSuccessor &successorRegion) {
+ // Find a predecessor of our argRegion.
+ return successorRegion.getSuccessor() == argRegion;
+ },
+ [&](RegionSuccessor &successorRegion) {
+ // The operand index will be the argument number.
+ return blockArg.getArgNumber();
+ });
+ }
+
+ /// Introduces temporary allocs in front of all associated nested-region
+ /// terminators and copies the source values into the newly allocated buffers.
+ void introduceValueCopyForRegionResult(Value value) {
+ // Get the actual result index in the scope of the parent terminator.
+ Operation *operation = value.getDefiningOp();
+ auto regionInterface = cast<RegionBranchOpInterface>(operation);
+ introduceCopiesForRegionSuccessors(
+ regionInterface, operation->getRegions(),
+ [&](RegionSuccessor &successorRegion) {
+ // Determine whether this region has a successor entry that leaves
+ // this region by returning to its parent operation.
+ return !successorRegion.getSuccessor();
+ },
+ [&](RegionSuccessor &successorRegion) {
+ // Find the associated success input index.
+ return llvm::find(successorRegion.getSuccessorInputs(), value)
+ .getIndex();
+ });
+ }
+
+ /// Introduces buffer copies for all terminators in the given regions. The
+ /// regionPredicate is applied to every successor region in order to restrict
+ /// the copies to specific regions. Thereby, the operandProvider is invoked
+ /// for each matching region successor and determines the operand index that
+ /// requires a buffer copy.
+ template <typename TPredicate, typename TOperandProvider>
+ void
+ introduceCopiesForRegionSuccessors(RegionBranchOpInterface regionInterface,
+ MutableArrayRef<Region> regions,
+ const TPredicate &regionPredicate,
+ const TOperandProvider &operandProvider) {
+ // Create an empty attribute for each operand to comply with the
+ // `getSuccessorRegions` interface definition that requires a single
+ // attribute per operand.
+ SmallVector<Attribute, 2> operandAttributes(
+ regionInterface.getOperation()->getNumOperands());
+ for (Region &region : regions) {
+ // Query the regionInterface to get all successor regions of the current
+ // one.
+ SmallVector<RegionSuccessor, 2> successorRegions;
+ regionInterface.getSuccessorRegions(region.getRegionNumber(),
+ operandAttributes, successorRegions);
+ // Try to find a matching region successor.
+ RegionSuccessor *regionSuccessor =
+ llvm::find_if(successorRegions, regionPredicate);
+ if (regionSuccessor == successorRegions.end())
+ continue;
+ // Get the operand index in the context of the current successor input
+ // bindings.
+ auto operandIndex = operandProvider(*regionSuccessor);
+
+ // Iterate over all immediate terminator operations to introduce
+ // new buffer allocations. Thereby, the appropriate terminator operand
+ // will be adjusted to point to the newly allocated buffer instead.
+ walkReturnOperations(&region, [&](Operation *terminator) {
+ // Extract the source value from the current terminator.
+ Value sourceValue = terminator->getOperand(operandIndex);
+ // Create a new alloc at the current location of the terminator.
+ Value alloc = introduceBufferCopy(sourceValue, terminator);
+ // Wire alloc and terminator operand.
+ terminator->setOperand(operandIndex, alloc);
+ });
}
}
+ /// Creates a new memory allocation for the given source value and copies
+ /// its content into the newly allocated buffer. The terminator operation is
+ /// used to insert the alloc and copy operations at the right places.
+ Value introduceBufferCopy(Value sourceValue, Operation *terminator) {
+ // Create a new alloc at the current location of the terminator.
+ auto memRefType = sourceValue.getType().cast<MemRefType>();
+ OpBuilder builder(terminator);
+
+ // Extract information about dynamically shaped types by
+ // extracting their dynamic dimensions.
+ SmallVector<Value, 4> dynamicOperands;
+ for (auto shapeElement : llvm::enumerate(memRefType.getShape())) {
+ if (!ShapedType::isDynamic(shapeElement.value()))
+ continue;
+ dynamicOperands.push_back(builder.create<DimOp>(
+ terminator->getLoc(), sourceValue, shapeElement.index()));
+ }
+
+ // TODO: provide a generic interface to create dialect-specific
+ // Alloc and CopyOp nodes.
+ auto alloc = builder.create<AllocOp>(terminator->getLoc(), memRefType,
+ dynamicOperands);
+
+ // Create a new copy operation that copies to contents of the old
+ // allocation to the new one.
+ builder.create<linalg::CopyOp>(terminator->getLoc(), sourceValue, alloc);
+
+ return alloc;
+ }
+
/// Finds associated deallocs that can be linked to our allocation nodes (if
/// any).
void findDeallocs() {
@@ -440,8 +602,8 @@ class BufferPlacement {
if (entry.deallocOperation) {
entry.deallocOperation->moveAfter(endOperation);
} else {
- // If the Dealloc position is at the terminator operation of the block,
- // then the value should escape from a deallocation.
+ // If the Dealloc position is at the terminator operation of the
+ // block, then the value should escape from a deallocation.
Operation *nextOp = endOperation->getNextNode();
if (!nextOp)
continue;
diff --git a/mlir/test/Transforms/buffer-placement.mlir b/mlir/test/Transforms/buffer-placement.mlir
index 93b6b5ade058..225a186caeb0 100644
--- a/mlir/test/Transforms/buffer-placement.mlir
+++ b/mlir/test/Transforms/buffer-placement.mlir
@@ -716,3 +716,201 @@ func @memref_in_function_results(%arg0: memref<5xf32>, %arg1: memref<10xf32>, %a
// CHECK: dealloc %[[Y]]
// CHECK: return %[[ARG1]], %[[X]]
+// -----
+
+// Test Case: nested region control flow
+// The alloc position of %1 does not need to be changed and flows through
+// both if branches until it is finally returned. Hence, it does not
+// require a specific dealloc operation. However, %3 requires a dealloc.
+
+// CHECK-LABEL: func @nested_region_control_flow
+func @nested_region_control_flow(
+ %arg0 : index,
+ %arg1 : index) -> memref<?x?xf32> {
+ %0 = cmpi "eq", %arg0, %arg1 : index
+ %1 = alloc(%arg0, %arg0) : memref<?x?xf32>
+ %2 = scf.if %0 -> (memref<?x?xf32>) {
+ scf.yield %1 : memref<?x?xf32>
+ } else {
+ %3 = alloc(%arg0, %arg1) : memref<?x?xf32>
+ scf.yield %1 : memref<?x?xf32>
+ }
+ return %2 : memref<?x?xf32>
+}
+
+// CHECK: %[[ALLOC0:.*]] = alloc(%arg0, %arg0)
+// CHECK-NEXT: %[[ALLOC1:.*]] = scf.if
+// CHECK: scf.yield %[[ALLOC0]]
+// CHECK: %[[ALLOC2:.*]] = alloc(%arg0, %arg1)
+// CHECK-NEXT: dealloc %[[ALLOC2]]
+// CHECK-NEXT: scf.yield %[[ALLOC0]]
+// CHECK: return %[[ALLOC1]]
+
+// -----
+
+// Test Case: nested region control flow with a nested buffer allocation in a
+// divergent branch.
+// The alloc positions of %1, %3 does not need to be changed since
+// BufferPlacement does not move allocs out of nested regions at the moment.
+// However, since %3 is allocated and "returned" in a divergent branch, we have
+// to allocate a temporary buffer (like in condBranchDynamicTypeNested).
+
+// CHECK-LABEL: func @nested_region_control_flow_div
+func @nested_region_control_flow_div(
+ %arg0 : index,
+ %arg1 : index) -> memref<?x?xf32> {
+ %0 = cmpi "eq", %arg0, %arg1 : index
+ %1 = alloc(%arg0, %arg0) : memref<?x?xf32>
+ %2 = scf.if %0 -> (memref<?x?xf32>) {
+ scf.yield %1 : memref<?x?xf32>
+ } else {
+ %3 = alloc(%arg0, %arg1) : memref<?x?xf32>
+ scf.yield %3 : memref<?x?xf32>
+ }
+ return %2 : memref<?x?xf32>
+}
+
+// CHECK: %[[ALLOC0:.*]] = alloc(%arg0, %arg0)
+// CHECK-NEXT: %[[ALLOC1:.*]] = scf.if
+// CHECK: %[[ALLOC2:.*]] = alloc
+// CHECK-NEXT: linalg.copy(%[[ALLOC0]], %[[ALLOC2]])
+// CHECK: scf.yield %[[ALLOC2]]
+// CHECK: %[[ALLOC3:.*]] = alloc(%arg0, %arg1)
+// CHECK: %[[ALLOC4:.*]] = alloc
+// CHECK-NEXT: linalg.copy(%[[ALLOC3]], %[[ALLOC4]])
+// CHECK: dealloc %[[ALLOC3]]
+// CHECK: scf.yield %[[ALLOC4]]
+// CHECK: dealloc %[[ALLOC0]]
+// CHECK-NEXT: return %[[ALLOC1]]
+
+// -----
+
+// Test Case: deeply nested region control flow with a nested buffer allocation
+// in a divergent branch.
+// The alloc positions of %1, %4 and %5 does not need to be changed since
+// BufferPlacement does not move allocs out of nested regions at the moment.
+// However, since %4 is allocated and "returned" in a divergent branch, we have
+// to allocate several temporary buffers (like in condBranchDynamicTypeNested).
+
+// CHECK-LABEL: func @nested_region_control_flow_div_nested
+func @nested_region_control_flow_div_nested(
+ %arg0 : index,
+ %arg1 : index) -> memref<?x?xf32> {
+ %0 = cmpi "eq", %arg0, %arg1 : index
+ %1 = alloc(%arg0, %arg0) : memref<?x?xf32>
+ %2 = scf.if %0 -> (memref<?x?xf32>) {
+ %3 = scf.if %0 -> (memref<?x?xf32>) {
+ scf.yield %1 : memref<?x?xf32>
+ } else {
+ %4 = alloc(%arg0, %arg1) : memref<?x?xf32>
+ scf.yield %4 : memref<?x?xf32>
+ }
+ scf.yield %3 : memref<?x?xf32>
+ } else {
+ %5 = alloc(%arg1, %arg1) : memref<?x?xf32>
+ scf.yield %5 : memref<?x?xf32>
+ }
+ return %2 : memref<?x?xf32>
+}
+// CHECK: %[[ALLOC0:.*]] = alloc(%arg0, %arg0)
+// CHECK-NEXT: %[[ALLOC1:.*]] = scf.if
+// CHECK-NEXT: %[[ALLOC2:.*]] = scf.if
+// CHECK: %[[ALLOC3:.*]] = alloc
+// CHECK-NEXT: linalg.copy(%[[ALLOC0]], %[[ALLOC3]])
+// CHECK: scf.yield %[[ALLOC3]]
+// CHECK: %[[ALLOC4:.*]] = alloc(%arg0, %arg1)
+// CHECK: %[[ALLOC5:.*]] = alloc
+// CHECK-NEXT: linalg.copy(%[[ALLOC4]], %[[ALLOC5]])
+// CHECK: dealloc %[[ALLOC4]]
+// CHECK: scf.yield %[[ALLOC5]]
+// CHECK: %[[ALLOC6:.*]] = alloc
+// CHECK-NEXT: linalg.copy(%[[ALLOC2]], %[[ALLOC6]])
+// CHECK: dealloc %[[ALLOC2]]
+// CHECK: scf.yield %[[ALLOC6]]
+// CHECK: %[[ALLOC7:.*]] = alloc(%arg1, %arg1)
+// CHECK: %[[ALLOC8:.*]] = alloc
+// CHECK-NEXT: linalg.copy(%[[ALLOC7]], %[[ALLOC8]])
+// CHECK: dealloc %[[ALLOC7]]
+// CHECK: scf.yield %[[ALLOC8]]
+// CHECK: dealloc %[[ALLOC0]]
+// CHECK-NEXT: return %[[ALLOC1]]
+
+// -----
+
+// Test Case: nested region control flow within a region interface.
+// The alloc positions of %0 does not need to be changed and no copies are
+// required in this case since the allocation finally escapes the method.
+
+// CHECK-LABEL: func @inner_region_control_flow
+func @inner_region_control_flow(%arg0 : index) -> memref<?x?xf32> {
+ %0 = alloc(%arg0, %arg0) : memref<?x?xf32>
+ %1 = test.region_if %0 : memref<?x?xf32> -> (memref<?x?xf32>) then {
+ ^bb0(%arg1 : memref<?x?xf32>):
+ test.region_if_yield %arg1 : memref<?x?xf32>
+ } else {
+ ^bb0(%arg1 : memref<?x?xf32>):
+ test.region_if_yield %arg1 : memref<?x?xf32>
+ } join {
+ ^bb0(%arg1 : memref<?x?xf32>):
+ test.region_if_yield %arg1 : memref<?x?xf32>
+ }
+ return %1 : memref<?x?xf32>
+}
+
+// CHECK: %[[ALLOC0:.*]] = alloc(%arg0, %arg0)
+// CHECK-NEXT: %[[ALLOC1:.*]] = test.region_if
+// CHECK-NEXT: ^bb0(%[[ALLOC2:.*]]:{{.*}}):
+// CHECK-NEXT: test.region_if_yield %[[ALLOC2]]
+// CHECK: ^bb0(%[[ALLOC3:.*]]:{{.*}}):
+// CHECK-NEXT: test.region_if_yield %[[ALLOC3]]
+// CHECK: ^bb0(%[[ALLOC4:.*]]:{{.*}}):
+// CHECK-NEXT: test.region_if_yield %[[ALLOC4]]
+// CHECK: return %[[ALLOC1]]
+
+// -----
+
+// Test Case: nested region control flow within a region interface including an
+// allocation in a divergent branch.
+// The alloc positions of %1 and %2 does not need to be changed since
+// BufferPlacement does not move allocs out of nested regions at the moment.
+// However, since %2 is allocated and yielded in a divergent branch, we have
+// to allocate several temporary buffers (like in condBranchDynamicTypeNested).
+
+// CHECK-LABEL: func @inner_region_control_flow_div
+func @inner_region_control_flow_div(
+ %arg0 : index,
+ %arg1 : index) -> memref<?x?xf32> {
+ %0 = alloc(%arg0, %arg0) : memref<?x?xf32>
+ %1 = test.region_if %0 : memref<?x?xf32> -> (memref<?x?xf32>) then {
+ ^bb0(%arg2 : memref<?x?xf32>):
+ test.region_if_yield %arg2 : memref<?x?xf32>
+ } else {
+ ^bb0(%arg2 : memref<?x?xf32>):
+ %2 = alloc(%arg0, %arg1) : memref<?x?xf32>
+ test.region_if_yield %2 : memref<?x?xf32>
+ } join {
+ ^bb0(%arg2 : memref<?x?xf32>):
+ test.region_if_yield %arg2 : memref<?x?xf32>
+ }
+ return %1 : memref<?x?xf32>
+}
+
+// CHECK: %[[ALLOC0:.*]] = alloc(%arg0, %arg0)
+// CHECK-NEXT: %[[ALLOC1:.*]] = test.region_if
+// CHECK-NEXT: ^bb0(%[[ALLOC2:.*]]:{{.*}}):
+// CHECK: %[[ALLOC3:.*]] = alloc
+// CHECK-NEXT: linalg.copy(%[[ALLOC2]], %[[ALLOC3]])
+// CHECK-NEXT: test.region_if_yield %[[ALLOC3]]
+// CHECK: ^bb0(%[[ALLOC4:.*]]:{{.*}}):
+// CHECK: %[[ALLOC5:.*]] = alloc
+// CHECK: %[[ALLOC6:.*]] = alloc
+// CHECK-NEXT: linalg.copy(%[[ALLOC5]], %[[ALLOC6]])
+// CHECK-NEXT: dealloc %[[ALLOC5]]
+// CHECK-NEXT: test.region_if_yield %[[ALLOC6]]
+// CHECK: ^bb0(%[[ALLOC7:.*]]:{{.*}}):
+// CHECK: %[[ALLOC8:.*]] = alloc
+// CHECK-NEXT: linalg.copy(%[[ALLOC7]], %[[ALLOC8]])
+// CHECK-NEXT: dealloc %[[ALLOC7]]
+// CHECK-NEXT: test.region_if_yield %[[ALLOC8]]
+// CHECK: dealloc %[[ALLOC0]]
+// CHECK-NEXT: return %[[ALLOC1]]
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index e4b1793c9399..dab2ea1bff6a 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -518,6 +518,77 @@ void StringAttrPrettyNameOp::getAsmResultNames(
setNameFn(getResult(i), str.getValue());
}
+//===----------------------------------------------------------------------===//
+// RegionIfOp
+//===----------------------------------------------------------------------===//
+
+static void print(OpAsmPrinter &p, RegionIfOp op) {
+ p << RegionIfOp::getOperationName() << " ";
+ p.printOperands(op.getOperands());
+ p << ": " << op.getOperandTypes();
+ p.printArrowTypeList(op.getResultTypes());
+ p << " then";
+ p.printRegion(op.thenRegion(),
+ /*printEntryBlockArgs=*/true,
+ /*printBlockTerminators=*/true);
+ p << " else";
+ p.printRegion(op.elseRegion(),
+ /*printEntryBlockArgs=*/true,
+ /*printBlockTerminators=*/true);
+ p << " join";
+ p.printRegion(op.joinRegion(),
+ /*printEntryBlockArgs=*/true,
+ /*printBlockTerminators=*/true);
+}
+
+static ParseResult parseRegionIfOp(OpAsmParser &parser,
+ OperationState &result) {
+ SmallVector<OpAsmParser::OperandType, 2> operandInfos;
+ SmallVector<Type, 2> operandTypes;
+
+ result.regions.reserve(3);
+ Region *thenRegion = result.addRegion();
+ Region *elseRegion = result.addRegion();
+ Region *joinRegion = result.addRegion();
+
+ // Parse operand, type and arrow type lists.
+ if (parser.parseOperandList(operandInfos) ||
+ parser.parseColonTypeList(operandTypes) ||
+ parser.parseArrowTypeList(result.types))
+ return failure();
+
+ // Parse all attached regions.
+ if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
+ parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
+ parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
+ return failure();
+
+ return parser.resolveOperands(operandInfos, operandTypes,
+ parser.getCurrentLocation(), result.operands);
+}
+
+OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) {
+ assert(index < 2 && "invalid region index");
+ return getOperands();
+}
+
+void RegionIfOp::getSuccessorRegions(
+ Optional<unsigned> index, ArrayRef<Attribute> operands,
+ SmallVectorImpl<RegionSuccessor> &regions) {
+ // We always branch to the join region.
+ if (index.hasValue()) {
+ if (index.getValue() < 2)
+ regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs()));
+ else
+ regions.push_back(RegionSuccessor(getResults()));
+ return;
+ }
+
+ // The then and else regions are the entry regions of this op.
+ regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs()));
+ regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs()));
+}
+
//===----------------------------------------------------------------------===//
// Dialect Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index bc8f2ff3f818..ddaa27fd0ca8 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1349,4 +1349,47 @@ def SideEffectOp : TEST_Op<"side_effect_op",
let results = (outs AnyType:$result);
}
+//===----------------------------------------------------------------------===//
+// Test RegionBranchOpInterface
+//===----------------------------------------------------------------------===//
+
+def RegionIfYieldOp : TEST_Op<"region_if_yield",
+ [NoSideEffect, ReturnLike, Terminator]> {
+ let arguments = (ins Variadic<AnyType>:$results);
+ let assemblyFormat = [{
+ $results `:` type($results) attr-dict
+ }];
+}
+
+def RegionIfOp : TEST_Op<"region_if",
+ [DeclareOpInterfaceMethods<RegionBranchOpInterface>,
+ SingleBlockImplicitTerminator<"RegionIfYieldOp">,
+ RecursiveSideEffects]> {
+ let description =[{
+ Represents an abstract if-then-else-join pattern. In this context, the then
+ and else regions jump to the join region, which finally returns to its
+ parent op.
+ }];
+
+ let printer = [{ return ::print(p, *this); }];
+ let parser = [{ return ::parseRegionIfOp(parser, result); }];
+ let arguments = (ins Variadic<AnyType>);
+ let results = (outs Variadic<AnyType>:$results);
+ let regions = (region SizedRegion<1>:$thenRegion,
+ AnyRegion:$elseRegion,
+ AnyRegion:$joinRegion);
+ let extraClassDeclaration = [{
+ Block::BlockArgListType getThenArgs() {
+ return getBody(0)->getArguments();
+ }
+ Block::BlockArgListType getElseArgs() {
+ return getBody(1)->getArguments();
+ }
+ Block::BlockArgListType getJoinArgs() {
+ return getBody(2)->getArguments();
+ }
+ OperandRange getSuccessorEntryOperands(unsigned index);
+ }];
+}
+
#endif // TEST_OPS
More information about the Mlir-commits
mailing list