[Mlir-commits] [mlir] 6f5da84 - [mlir] Extended BufferPlacement to support nested region control flow.

Stephan Herhut llvmlistbot at llvm.org
Tue Jun 30 03:10:16 PDT 2020


Author: Marcel Koester
Date: 2020-06-30T12:10:01+02:00
New Revision: 6f5da84f7bb31c7c2fcb78e64d5dc3baea1c60f2

URL: https://github.com/llvm/llvm-project/commit/6f5da84f7bb31c7c2fcb78e64d5dc3baea1c60f2
DIFF: https://github.com/llvm/llvm-project/commit/6f5da84f7bb31c7c2fcb78e64d5dc3baea1c60f2.diff

LOG: [mlir] Extended BufferPlacement to support nested region control flow.

Summary: The current BufferPlacement implementation does not support
nested region control flow. This CL adds support for nested regions via
the RegionBranchOpInterface and the detection of branch-like
(ReturnLike) terminators inside nested regions.

Differential Revision: https://reviews.llvm.org/D81926

Added: 
    

Modified: 
    mlir/lib/Transforms/BufferPlacement.cpp
    mlir/test/Transforms/buffer-placement.mlir
    mlir/test/lib/Dialect/Test/TestDialect.cpp
    mlir/test/lib/Dialect/Test/TestOps.td

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Transforms/BufferPlacement.cpp b/mlir/lib/Transforms/BufferPlacement.cpp
index 74c534701536..577d52188351 100644
--- a/mlir/lib/Transforms/BufferPlacement.cpp
+++ b/mlir/lib/Transforms/BufferPlacement.cpp
@@ -65,8 +65,18 @@
 
 using namespace mlir;
 
-namespace {
+/// Walks over all immediate return-like terminators in the given region.
+template <typename FuncT>
+static void walkReturnOperations(Region *region, const FuncT &func) {
+  for (Block &block : *region)
+    for (Operation &operation : block) {
+      // Only visit return-like operations; all others are skipped.
+      if (operation.hasTrait<OpTrait::ReturnLike>())
+        func(&operation);
+    }
+}
 
+namespace {
 //===----------------------------------------------------------------------===//
 // BufferPlacementAliasAnalysis
 //===----------------------------------------------------------------------===//
@@ -82,7 +92,7 @@ class BufferPlacementAliasAnalysis {
 
 public:
   /// Constructs a new alias analysis using the op provided.
-  BufferPlacementAliasAnalysis(Operation *op) { build(op->getRegions()); }
+  BufferPlacementAliasAnalysis(Operation *op) { build(op); }
 
   /// Find all immediate aliases this value could potentially have.
   ValueMapT::const_iterator find(Value value) const {
@@ -102,7 +112,7 @@ class BufferPlacementAliasAnalysis {
   }
 
   /// Removes the given values from all alias sets.
-  void remove(const SmallPtrSetImpl<BlockArgument> &aliasValues) {
+  void remove(const SmallPtrSetImpl<Value> &aliasValues) {
     for (auto &entry : aliases)
       llvm::set_subtract(entry.second, aliasValues);
   }
@@ -123,33 +133,69 @@ class BufferPlacementAliasAnalysis {
   /// This function constructs a mapping from values to its immediate aliases.
   /// It iterates over all blocks, gets their predecessors, determines the
   /// values that will be passed to the corresponding block arguments and
-  /// inserts them into the underlying map.
-  void build(MutableArrayRef<Region> regions) {
-    for (Region &region : regions) {
-      for (Block &block : region) {
-        // Iterate over all predecessor and get the mapped values to their
-        // corresponding block arguments values.
-        for (auto it = block.pred_begin(), e = block.pred_end(); it != e;
-             ++it) {
-          unsigned successorIndex = it.getSuccessorIndex();
-          // Get the terminator and the values that will be passed to our block.
-          auto branchInterface =
-              dyn_cast<BranchOpInterface>((*it)->getTerminator());
-          if (!branchInterface)
-            continue;
-          // Query the branch op interace to get the successor operands.
-          auto successorOperands =
-              branchInterface.getSuccessorOperands(successorIndex);
-          if (successorOperands.hasValue()) {
-            // Build the actual mapping of values to their immediate aliases.
-            for (auto argPair : llvm::zip(block.getArguments(),
-                                          successorOperands.getValue())) {
-              aliases[std::get<1>(argPair)].insert(std::get<0>(argPair));
-            }
-          }
+  /// inserts them into the underlying map. Furthermore, it wires successor
+  /// regions and branch-like return operations from nested regions.
+  void build(Operation *op) {
+    // Registers all aliases of the given values.
+    auto registerAliases = [&](auto values, auto aliases) {
+      for (auto entry : llvm::zip(values, aliases))
+        this->aliases[std::get<0>(entry)].insert(std::get<1>(entry));
+    };
+
+    // Query all branch interfaces to link block argument aliases.
+    op->walk([&](BranchOpInterface branchInterface) {
+      Block *parentBlock = branchInterface.getOperation()->getBlock();
+      for (auto it = parentBlock->succ_begin(), e = parentBlock->succ_end();
+           it != e; ++it) {
+        // Query the branch op interface to get the successor operands.
+        auto successorOperands =
+            branchInterface.getSuccessorOperands(it.getIndex());
+        if (!successorOperands.hasValue())
+          continue;
+        // Build the actual mapping of values to their immediate aliases.
+        registerAliases(successorOperands.getValue(), (*it)->getArguments());
+      }
+    });
+
+    // Query the RegionBranchOpInterface to find potential successor regions.
+    op->walk([&](RegionBranchOpInterface regionInterface) {
+      // Create an empty attribute for each operand to comply with the
+      // `getSuccessorRegions` interface definition that requires a single
+      // attribute per operand.
+      SmallVector<Attribute, 2> operandAttributes(
+          regionInterface.getOperation()->getNumOperands());
+
+      // Extract all entry regions and wire all initial entry successor inputs.
+      SmallVector<RegionSuccessor, 2> entrySuccessors;
+      regionInterface.getSuccessorRegions(/*index=*/llvm::None,
+                                          operandAttributes, entrySuccessors);
+      for (RegionSuccessor &entrySuccessor : entrySuccessors) {
+        // Wire the entry region's successor arguments with the initial
+        // successor inputs.
+        assert(entrySuccessor.getSuccessor() &&
+               "Invalid entry region without an attached successor region");
+        registerAliases(regionInterface.getSuccessorEntryOperands(
+                            entrySuccessor.getSuccessor()->getRegionNumber()),
+                        entrySuccessor.getSuccessorInputs());
+      }
+
+      // Wire flow between regions and from region exits.
+      for (Region &region : regionInterface.getOperation()->getRegions()) {
+        // Iterate over all successor region entries that are reachable from the
+        // current region.
+        SmallVector<RegionSuccessor, 2> successorRegions;
+        regionInterface.getSuccessorRegions(
+            region.getRegionNumber(), operandAttributes, successorRegions);
+        for (RegionSuccessor &successorRegion : successorRegions) {
+          // Iterate over all immediate terminator operations and wire the
+          // successor inputs with the operands of each terminator.
+          walkReturnOperations(&region, [&](Operation *terminator) {
+            registerAliases(terminator->getOperands(),
+                            successorRegion.getSuccessorInputs());
+          });
         }
       }
-    }
+    });
   }
 
   /// Maps values to all immediate aliases this value can have.
@@ -235,14 +281,24 @@ class BufferPlacement {
   Block *getInitialAllocBlock(OpResult result) {
     // Get all allocation operands as these operands are important for the
     // allocation operation.
-    auto operands = result.getOwner()->getOperands();
+    Operation *owner = result.getOwner();
+    auto operands = owner->getOperands();
+    Block *dominator;
     if (operands.size() < 1)
-      return findCommonDominator(result, aliases.resolve(result), dominators);
+      dominator =
+          findCommonDominator(result, aliases.resolve(result), dominators);
+    else {
+      // If this node has dependencies, check all dependent nodes with respect
+      // to a common post dominator in which all values are available.
+      ValueSetT dependencies(++operands.begin(), operands.end());
+      dominator =
+          findCommonDominator(*operands.begin(), dependencies, postDominators);
+    }
 
-    // If this node has dependencies, check all dependent nodes with respect
-    // to a common post dominator in which all values are available.
-    ValueSetT dependencies(++operands.begin(), operands.end());
-    return findCommonDominator(*operands.begin(), dependencies, postDominators);
+    // Do not move allocs out of their parent regions to keep them local.
+    if (dominator->getParent() != owner->getParentRegion())
+      return &owner->getParentRegion()->front();
+    return dominator;
   }
 
   /// Finds correct alloc positions according to the algorithm described at
@@ -273,12 +329,12 @@ class BufferPlacement {
 
   /// Introduces required allocs and copy operations to avoid memory leaks.
   void introduceCopies() {
-    // Initialize the set of block arguments that require a dedicated memory
-    // free operation since their arguments cannot be safely deallocated in a
-    // post dominator.
-    SmallPtrSet<BlockArgument, 8> blockArgsToFree;
-    llvm::SmallDenseSet<std::tuple<BlockArgument, Block *>> visitedBlockArgs;
-    SmallVector<std::tuple<BlockArgument, Block *>, 8> toProcess;
+    // Initialize the set of values that require a dedicated memory free
+    // operation since their operands cannot be safely deallocated in a post
+    // dominator.
+    SmallPtrSet<Value, 8> valuesToFree;
+    llvm::SmallDenseSet<std::tuple<Value, Block *>> visitedValues;
+    SmallVector<std::tuple<Value, Block *>, 8> toProcess;
 
     // Check dominance relation for proper dominance properties. If the given
     // value node does not dominate an alias, we will have to create a copy in
@@ -289,17 +345,15 @@ class BufferPlacement {
       if (it == aliases.end())
         return;
       for (Value value : it->second) {
-        auto blockArg = value.cast<BlockArgument>();
-        if (blockArgsToFree.count(blockArg) > 0)
+        if (valuesToFree.count(value) > 0)
           continue;
         // Check whether we have to free this particular block argument.
-        if (!dominators.dominates(definingBlock, blockArg.getOwner())) {
-          toProcess.emplace_back(blockArg, blockArg.getParentBlock());
-          blockArgsToFree.insert(blockArg);
-        } else if (visitedBlockArgs
-                       .insert(std::make_tuple(blockArg, definingBlock))
+        if (!dominators.dominates(definingBlock, value.getParentBlock())) {
+          toProcess.emplace_back(value, value.getParentBlock());
+          valuesToFree.insert(value);
+        } else if (visitedValues.insert(std::make_tuple(value, definingBlock))
                        .second)
-          toProcess.emplace_back(blockArg, definingBlock);
+          toProcess.emplace_back(value, definingBlock);
       }
     };
 
@@ -316,62 +370,170 @@ class BufferPlacement {
 
     // Update buffer aliases to ensure that we free all buffers and block
     // arguments at the correct locations.
-    aliases.remove(blockArgsToFree);
+    aliases.remove(valuesToFree);
 
     // Add new allocs and additional copy operations.
-    for (BlockArgument blockArg : blockArgsToFree) {
-      Block *block = blockArg.getOwner();
-
-      // Allocate a buffer for the current block argument in the block of
-      // the associated value (which will be a predecessor block by
-      // definition).
-      for (auto it = block->pred_begin(), e = block->pred_end(); it != e;
-           ++it) {
-        // Get the terminator and the value that will be passed to our
-        // argument.
-        Operation *terminator = (*it)->getTerminator();
-        auto branchInterface = cast<BranchOpInterface>(terminator);
-        // Convert the mutable operand range to an immutable range and query the
-        // associated source value.
-        Value sourceValue =
-            branchInterface.getSuccessorOperands(it.getSuccessorIndex())
-                .getValue()[blockArg.getArgNumber()];
-        // Create a new alloc at the current location of the terminator.
-        auto memRefType = sourceValue.getType().cast<MemRefType>();
-        OpBuilder builder(terminator);
-
-        // Extract information about dynamically shaped types by
-        // extracting their dynamic dimensions.
-        SmallVector<Value, 4> dynamicOperands;
-        for (auto shapeElement : llvm::enumerate(memRefType.getShape())) {
-          if (!ShapedType::isDynamic(shapeElement.value()))
-            continue;
-          dynamicOperands.push_back(builder.create<DimOp>(
-              terminator->getLoc(), sourceValue, shapeElement.index()));
-        }
+    for (Value value : valuesToFree) {
+      if (auto blockArg = value.dyn_cast<BlockArgument>())
+        introduceBlockArgCopy(blockArg);
+      else
+        introduceValueCopyForRegionResult(value);
+
+      // Register the value to require a final dealloc. Note that we do not have
+      // to assign a block here since we do not want to move the allocation node
+      // to another location.
+      allocs.push_back({value, nullptr, nullptr});
+    }
+  }
 
-        // TODO: provide a generic interface to create dialect-specific
-        // Alloc and CopyOp nodes.
-        auto alloc = builder.create<AllocOp>(terminator->getLoc(), memRefType,
-                                             dynamicOperands);
-        // Wire new alloc and successor operand.
-        branchInterface.getMutableSuccessorOperands(it.getSuccessorIndex())
-            .getValue()
+  /// Introduces temporary allocs in all predecessors and copies the source
+  /// values into the newly allocated buffers.
+  void introduceBlockArgCopy(BlockArgument blockArg) {
+    // Allocate a buffer for the current block argument in the block of
+    // the associated value (which will be a predecessor block by
+    // definition).
+    Block *block = blockArg.getOwner();
+    for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) {
+      // Get the terminator and the value that will be passed to our
+      // argument.
+      Operation *terminator = (*it)->getTerminator();
+      auto branchInterface = cast<BranchOpInterface>(terminator);
+      // Query the associated source value.
+      Value sourceValue =
+          branchInterface.getSuccessorOperands(it.getSuccessorIndex())
+              .getValue()[blockArg.getArgNumber()];
+      // Create a new alloc and copy at the current location of the terminator.
+      Value alloc = introduceBufferCopy(sourceValue, terminator);
+      // Wire new alloc and successor operand.
+      auto mutableOperands =
+          branchInterface.getMutableSuccessorOperands(it.getSuccessorIndex());
+      if (!mutableOperands.hasValue())
+        terminator->emitError() << "terminators with immutable successor "
+                                   "operands are not supported";
+      else
+        mutableOperands.getValue()
             .slice(blockArg.getArgNumber(), 1)
             .assign(alloc);
-        // Create a new copy operation that copies to contents of the old
-        // allocation to the new one.
-        builder.create<linalg::CopyOp>(terminator->getLoc(), sourceValue,
-                                       alloc);
-      }
+    }
 
-      // Register the block argument to require a final dealloc. Note that
-      // we do not have to assign a block here since we do not want to
-      // move the allocation node to another location.
-      allocs.push_back({blockArg, nullptr, nullptr});
+    // Check whether the block argument has implicitly defined predecessors via
+    // the RegionBranchOpInterface. This can be the case if the current block
+    // argument belongs to the first block in a region and the parent operation
+    // implements the RegionBranchOpInterface.
+    Region *argRegion = block->getParent();
+    RegionBranchOpInterface regionInterface;
+    if (!argRegion || &argRegion->front() != block ||
+        !(regionInterface =
+              dyn_cast<RegionBranchOpInterface>(argRegion->getParentOp())))
+      return;
+
+    introduceCopiesForRegionSuccessors(
+        regionInterface, argRegion->getParentOp()->getRegions(),
+        [&](RegionSuccessor &successorRegion) {
+          // Find a predecessor of our argRegion.
+          return successorRegion.getSuccessor() == argRegion;
+        },
+        [&](RegionSuccessor &successorRegion) {
+          // The operand index will be the argument number.
+          return blockArg.getArgNumber();
+        });
+  }
+
+  /// Introduces temporary allocs in front of all associated nested-region
+  /// terminators and copies the source values into the newly allocated buffers.
+  void introduceValueCopyForRegionResult(Value value) {
+    // The value is a result of its defining RegionBranchOpInterface op.
+    Operation *operation = value.getDefiningOp();
+    auto regionInterface = cast<RegionBranchOpInterface>(operation);
+    introduceCopiesForRegionSuccessors(
+        regionInterface, operation->getRegions(),
+        [&](RegionSuccessor &successorRegion) {
+          // Determine whether this region has a successor entry that leaves
+          // this region by returning to its parent operation.
+          return !successorRegion.getSuccessor();
+        },
+        [&](RegionSuccessor &successorRegion) {
+          // Find the associated successor input index.
+          return llvm::find(successorRegion.getSuccessorInputs(), value)
+              .getIndex();
+        });
+  }
+
+  /// Introduces buffer copies for all terminators in the given regions. The
+  /// regionPredicate is applied to every successor region in order to restrict
+  /// the copies to specific regions. Thereby, the operandProvider is invoked
+  /// for each matching region successor and determines the operand index that
+  /// requires a buffer copy.
+  template <typename TPredicate, typename TOperandProvider>
+  void
+  introduceCopiesForRegionSuccessors(RegionBranchOpInterface regionInterface,
+                                     MutableArrayRef<Region> regions,
+                                     const TPredicate &regionPredicate,
+                                     const TOperandProvider &operandProvider) {
+    // Create an empty attribute for each operand to comply with the
+    // `getSuccessorRegions` interface definition that requires a single
+    // attribute per operand.
+    SmallVector<Attribute, 2> operandAttributes(
+        regionInterface.getOperation()->getNumOperands());
+    for (Region &region : regions) {
+      // Query the regionInterface to get all successor regions of the current
+      // one.
+      SmallVector<RegionSuccessor, 2> successorRegions;
+      regionInterface.getSuccessorRegions(region.getRegionNumber(),
+                                          operandAttributes, successorRegions);
+      // Try to find a matching region successor.
+      RegionSuccessor *regionSuccessor =
+          llvm::find_if(successorRegions, regionPredicate);
+      if (regionSuccessor == successorRegions.end())
+        continue;
+      // Get the operand index in the context of the current successor input
+      // bindings.
+      auto operandIndex = operandProvider(*regionSuccessor);
+
+      // Iterate over all immediate terminator operations to introduce
+      // new buffer allocations. Thereby, the appropriate terminator operand
+      // will be adjusted to point to the newly allocated buffer instead.
+      walkReturnOperations(&region, [&](Operation *terminator) {
+        // Extract the source value from the current terminator.
+        Value sourceValue = terminator->getOperand(operandIndex);
+        // Create a new alloc at the current location of the terminator.
+        Value alloc = introduceBufferCopy(sourceValue, terminator);
+        // Wire alloc and terminator operand.
+        terminator->setOperand(operandIndex, alloc);
+      });
     }
   }
 
+  /// Creates a new memory allocation for the given source value and copies
+  /// its content into the newly allocated buffer. The terminator operation is
+  /// used to insert the alloc and copy operations at the right places.
+  Value introduceBufferCopy(Value sourceValue, Operation *terminator) {
+    // Create a new alloc at the current location of the terminator.
+    auto memRefType = sourceValue.getType().cast<MemRefType>();
+    OpBuilder builder(terminator);
+
+    // Extract information about dynamically shaped types by
+    // extracting their dynamic dimensions.
+    SmallVector<Value, 4> dynamicOperands;
+    for (auto shapeElement : llvm::enumerate(memRefType.getShape())) {
+      if (!ShapedType::isDynamic(shapeElement.value()))
+        continue;
+      dynamicOperands.push_back(builder.create<DimOp>(
+          terminator->getLoc(), sourceValue, shapeElement.index()));
+    }
+
+    // TODO: provide a generic interface to create dialect-specific
+    // Alloc and CopyOp nodes.
+    auto alloc = builder.create<AllocOp>(terminator->getLoc(), memRefType,
+                                         dynamicOperands);
+
+    // Create a new copy operation that copies the contents of the old
+    // allocation to the new one.
+    builder.create<linalg::CopyOp>(terminator->getLoc(), sourceValue, alloc);
+
+    return alloc;
+  }
+
   /// Finds associated deallocs that can be linked to our allocation nodes (if
   /// any).
   void findDeallocs() {
@@ -440,8 +602,8 @@ class BufferPlacement {
       if (entry.deallocOperation) {
         entry.deallocOperation->moveAfter(endOperation);
       } else {
-        // If the Dealloc position is at the terminator operation of the block,
-        // then the value should escape from a deallocation.
+        // If the Dealloc position is at the terminator operation of the
+        // block, then the value escapes and is not deallocated here.
         Operation *nextOp = endOperation->getNextNode();
         if (!nextOp)
           continue;

diff  --git a/mlir/test/Transforms/buffer-placement.mlir b/mlir/test/Transforms/buffer-placement.mlir
index 93b6b5ade058..225a186caeb0 100644
--- a/mlir/test/Transforms/buffer-placement.mlir
+++ b/mlir/test/Transforms/buffer-placement.mlir
@@ -716,3 +716,201 @@ func @memref_in_function_results(%arg0: memref<5xf32>, %arg1: memref<10xf32>, %a
 // CHECK: dealloc %[[Y]]
 // CHECK: return %[[ARG1]], %[[X]]
 
+// -----
+
+// Test Case: nested region control flow
+// The alloc position of %1 does not need to be changed and flows through
+// both if branches until it is finally returned. Hence, it does not
+// require a specific dealloc operation. However, %3 requires a dealloc.
+
+// CHECK-LABEL: func @nested_region_control_flow
+func @nested_region_control_flow(
+  %arg0 : index,
+  %arg1 : index) -> memref<?x?xf32> {
+  %0 = cmpi "eq", %arg0, %arg1 : index
+  %1 = alloc(%arg0, %arg0) : memref<?x?xf32>
+  %2 = scf.if %0 -> (memref<?x?xf32>) {
+    scf.yield %1 : memref<?x?xf32>
+  } else {
+    %3 = alloc(%arg0, %arg1) : memref<?x?xf32>
+    scf.yield %1 : memref<?x?xf32>
+  }
+  return %2 : memref<?x?xf32>
+}
+
+//      CHECK: %[[ALLOC0:.*]] = alloc(%arg0, %arg0)
+// CHECK-NEXT: %[[ALLOC1:.*]] = scf.if
+//      CHECK: scf.yield %[[ALLOC0]]
+//      CHECK: %[[ALLOC2:.*]] = alloc(%arg0, %arg1)
+// CHECK-NEXT: dealloc %[[ALLOC2]]
+// CHECK-NEXT: scf.yield %[[ALLOC0]]
+//      CHECK: return %[[ALLOC1]]
+
+// -----
+
+// Test Case: nested region control flow with a nested buffer allocation in a
+// divergent branch.
+// The alloc positions of %1 and %3 do not need to be changed since
+// BufferPlacement does not move allocs out of nested regions at the moment.
+// However, since %3 is allocated and "returned" in a divergent branch, we have
+// to allocate a temporary buffer (like in condBranchDynamicTypeNested).
+
+// CHECK-LABEL: func @nested_region_control_flow_div
+func @nested_region_control_flow_div(
+  %arg0 : index,
+  %arg1 : index) -> memref<?x?xf32> {
+  %0 = cmpi "eq", %arg0, %arg1 : index
+  %1 = alloc(%arg0, %arg0) : memref<?x?xf32>
+  %2 = scf.if %0 -> (memref<?x?xf32>) {
+    scf.yield %1 : memref<?x?xf32>
+  } else {
+    %3 = alloc(%arg0, %arg1) : memref<?x?xf32>
+    scf.yield %3 : memref<?x?xf32>
+  }
+  return %2 : memref<?x?xf32>
+}
+
+//      CHECK: %[[ALLOC0:.*]] = alloc(%arg0, %arg0)
+// CHECK-NEXT: %[[ALLOC1:.*]] = scf.if
+//      CHECK: %[[ALLOC2:.*]] = alloc
+// CHECK-NEXT: linalg.copy(%[[ALLOC0]], %[[ALLOC2]])
+//      CHECK: scf.yield %[[ALLOC2]]
+//      CHECK: %[[ALLOC3:.*]] = alloc(%arg0, %arg1)
+//      CHECK: %[[ALLOC4:.*]] = alloc
+// CHECK-NEXT: linalg.copy(%[[ALLOC3]], %[[ALLOC4]])
+//      CHECK: dealloc %[[ALLOC3]]
+//      CHECK: scf.yield %[[ALLOC4]]
+//      CHECK: dealloc %[[ALLOC0]]
+// CHECK-NEXT: return %[[ALLOC1]]
+
+// -----
+
+// Test Case: deeply nested region control flow with a nested buffer allocation
+// in a divergent branch.
+// The alloc positions of %1, %4 and %5 do not need to be changed since
+// BufferPlacement does not move allocs out of nested regions at the moment.
+// However, since %4 is allocated and "returned" in a divergent branch, we have
+// to allocate several temporary buffers (like in condBranchDynamicTypeNested).
+
+// CHECK-LABEL: func @nested_region_control_flow_div_nested
+func @nested_region_control_flow_div_nested(
+  %arg0 : index,
+  %arg1 : index) -> memref<?x?xf32> {
+  %0 = cmpi "eq", %arg0, %arg1 : index
+  %1 = alloc(%arg0, %arg0) : memref<?x?xf32>
+  %2 = scf.if %0 -> (memref<?x?xf32>) {
+    %3 = scf.if %0 -> (memref<?x?xf32>) {
+      scf.yield %1 : memref<?x?xf32>
+    } else {
+      %4 = alloc(%arg0, %arg1) : memref<?x?xf32>
+      scf.yield %4 : memref<?x?xf32>
+    }
+    scf.yield %3 : memref<?x?xf32>
+  } else {
+    %5 = alloc(%arg1, %arg1) : memref<?x?xf32>
+    scf.yield %5 : memref<?x?xf32>
+  }
+  return %2 : memref<?x?xf32>
+}
+//      CHECK: %[[ALLOC0:.*]] = alloc(%arg0, %arg0)
+// CHECK-NEXT: %[[ALLOC1:.*]] = scf.if
+// CHECK-NEXT: %[[ALLOC2:.*]] = scf.if
+//      CHECK: %[[ALLOC3:.*]] = alloc
+// CHECK-NEXT: linalg.copy(%[[ALLOC0]], %[[ALLOC3]])
+//      CHECK: scf.yield %[[ALLOC3]]
+//      CHECK: %[[ALLOC4:.*]] = alloc(%arg0, %arg1)
+//      CHECK: %[[ALLOC5:.*]] = alloc
+// CHECK-NEXT: linalg.copy(%[[ALLOC4]], %[[ALLOC5]])
+//      CHECK: dealloc %[[ALLOC4]]
+//      CHECK: scf.yield %[[ALLOC5]]
+//      CHECK: %[[ALLOC6:.*]] = alloc
+// CHECK-NEXT: linalg.copy(%[[ALLOC2]], %[[ALLOC6]])
+//      CHECK: dealloc %[[ALLOC2]]
+//      CHECK: scf.yield %[[ALLOC6]]
+//      CHECK: %[[ALLOC7:.*]] = alloc(%arg1, %arg1)
+//      CHECK: %[[ALLOC8:.*]] = alloc
+// CHECK-NEXT: linalg.copy(%[[ALLOC7]], %[[ALLOC8]])
+//      CHECK: dealloc %[[ALLOC7]]
+//      CHECK: scf.yield %[[ALLOC8]]
+//      CHECK: dealloc %[[ALLOC0]]
+// CHECK-NEXT: return %[[ALLOC1]]
+
+// -----
+
+// Test Case: nested region control flow within a region interface.
+// The alloc position of %0 does not need to be changed and no copies are
+// required in this case since the allocation finally escapes the function.
+
+// CHECK-LABEL: func @inner_region_control_flow
+func @inner_region_control_flow(%arg0 : index) -> memref<?x?xf32> {
+  %0 = alloc(%arg0, %arg0) : memref<?x?xf32>
+  %1 = test.region_if %0 : memref<?x?xf32> -> (memref<?x?xf32>) then {
+    ^bb0(%arg1 : memref<?x?xf32>):
+      test.region_if_yield %arg1 : memref<?x?xf32>
+  } else {
+    ^bb0(%arg1 : memref<?x?xf32>):
+      test.region_if_yield %arg1 : memref<?x?xf32>
+  } join {
+    ^bb0(%arg1 : memref<?x?xf32>):
+      test.region_if_yield %arg1 : memref<?x?xf32>
+  }
+  return %1 : memref<?x?xf32>
+}
+
+//      CHECK: %[[ALLOC0:.*]] = alloc(%arg0, %arg0)
+// CHECK-NEXT: %[[ALLOC1:.*]] = test.region_if
+// CHECK-NEXT: ^bb0(%[[ALLOC2:.*]]:{{.*}}):
+// CHECK-NEXT: test.region_if_yield %[[ALLOC2]]
+//      CHECK: ^bb0(%[[ALLOC3:.*]]:{{.*}}):
+// CHECK-NEXT: test.region_if_yield %[[ALLOC3]]
+//      CHECK: ^bb0(%[[ALLOC4:.*]]:{{.*}}):
+// CHECK-NEXT: test.region_if_yield %[[ALLOC4]]
+//      CHECK: return %[[ALLOC1]]
+
+// -----
+
+// Test Case: nested region control flow within a region interface including an
+// allocation in a divergent branch.
+// The alloc positions of %1 and %2 do not need to be changed since
+// BufferPlacement does not move allocs out of nested regions at the moment.
+// However, since %2 is allocated and yielded in a divergent branch, we have
+// to allocate several temporary buffers (like in condBranchDynamicTypeNested).
+
+// CHECK-LABEL: func @inner_region_control_flow_div
+func @inner_region_control_flow_div(
+  %arg0 : index,
+  %arg1 : index) -> memref<?x?xf32> {
+  %0 = alloc(%arg0, %arg0) : memref<?x?xf32>
+  %1 = test.region_if %0 : memref<?x?xf32> -> (memref<?x?xf32>) then {
+    ^bb0(%arg2 : memref<?x?xf32>):
+      test.region_if_yield %arg2 : memref<?x?xf32>
+  } else {
+    ^bb0(%arg2 : memref<?x?xf32>):
+      %2 = alloc(%arg0, %arg1) : memref<?x?xf32>
+      test.region_if_yield %2 : memref<?x?xf32>
+  } join {
+    ^bb0(%arg2 : memref<?x?xf32>):
+      test.region_if_yield %arg2 : memref<?x?xf32>
+  }
+  return %1 : memref<?x?xf32>
+}
+
+//      CHECK: %[[ALLOC0:.*]] = alloc(%arg0, %arg0)
+// CHECK-NEXT: %[[ALLOC1:.*]] = test.region_if
+// CHECK-NEXT: ^bb0(%[[ALLOC2:.*]]:{{.*}}):
+//      CHECK: %[[ALLOC3:.*]] = alloc
+// CHECK-NEXT: linalg.copy(%[[ALLOC2]], %[[ALLOC3]])
+// CHECK-NEXT: test.region_if_yield %[[ALLOC3]]
+//      CHECK: ^bb0(%[[ALLOC4:.*]]:{{.*}}):
+//      CHECK: %[[ALLOC5:.*]] = alloc
+//      CHECK: %[[ALLOC6:.*]] = alloc
+// CHECK-NEXT: linalg.copy(%[[ALLOC5]], %[[ALLOC6]])
+// CHECK-NEXT: dealloc %[[ALLOC5]]
+// CHECK-NEXT: test.region_if_yield %[[ALLOC6]]
+//      CHECK: ^bb0(%[[ALLOC7:.*]]:{{.*}}):
+//      CHECK: %[[ALLOC8:.*]] = alloc
+// CHECK-NEXT: linalg.copy(%[[ALLOC7]], %[[ALLOC8]])
+// CHECK-NEXT: dealloc %[[ALLOC7]]
+// CHECK-NEXT: test.region_if_yield %[[ALLOC8]]
+//      CHECK: dealloc %[[ALLOC0]]
+// CHECK-NEXT: return %[[ALLOC1]]

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index e4b1793c9399..dab2ea1bff6a 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -518,6 +518,77 @@ void StringAttrPrettyNameOp::getAsmResultNames(
         setNameFn(getResult(i), str.getValue());
 }
 
+//===----------------------------------------------------------------------===//
+// RegionIfOp
+//===----------------------------------------------------------------------===//
+
+static void print(OpAsmPrinter &p, RegionIfOp op) {
+  p << RegionIfOp::getOperationName() << " ";
+  p.printOperands(op.getOperands());
+  p << ": " << op.getOperandTypes();
+  p.printArrowTypeList(op.getResultTypes());
+  p << " then";
+  p.printRegion(op.thenRegion(),
+                /*printEntryBlockArgs=*/true,
+                /*printBlockTerminators=*/true);
+  p << " else";
+  p.printRegion(op.elseRegion(),
+                /*printEntryBlockArgs=*/true,
+                /*printBlockTerminators=*/true);
+  p << " join";
+  p.printRegion(op.joinRegion(),
+                /*printEntryBlockArgs=*/true,
+                /*printBlockTerminators=*/true);
+}
+
+static ParseResult parseRegionIfOp(OpAsmParser &parser,
+                                   OperationState &result) {
+  SmallVector<OpAsmParser::OperandType, 2> operandInfos;
+  SmallVector<Type, 2> operandTypes;
+
+  result.regions.reserve(3);
+  Region *thenRegion = result.addRegion();
+  Region *elseRegion = result.addRegion();
+  Region *joinRegion = result.addRegion();
+
+  // Parse operand, type and arrow type lists.
+  if (parser.parseOperandList(operandInfos) ||
+      parser.parseColonTypeList(operandTypes) ||
+      parser.parseArrowTypeList(result.types))
+    return failure();
+
+  // Parse all attached regions.
+  if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
+      parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
+      parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
+    return failure();
+
+  return parser.resolveOperands(operandInfos, operandTypes,
+                                parser.getCurrentLocation(), result.operands);
+}
+
+OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) {
+  assert(index < 2 && "invalid region index");
+  return getOperands();
+}
+
+void RegionIfOp::getSuccessorRegions(
+    Optional<unsigned> index, ArrayRef<Attribute> operands,
+    SmallVectorImpl<RegionSuccessor> &regions) {
+  // The then/else regions branch to join; join branches back to the op.
+  if (index.hasValue()) {
+    if (index.getValue() < 2)
+      regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs()));
+    else
+      regions.push_back(RegionSuccessor(getResults()));
+    return;
+  }
+
+  // The then and else regions are the entry regions of this op.
+  regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs()));
+  regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs()));
+}
+
 //===----------------------------------------------------------------------===//
 // Dialect Registration
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index bc8f2ff3f818..ddaa27fd0ca8 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1349,4 +1349,47 @@ def SideEffectOp : TEST_Op<"side_effect_op",
   let results = (outs AnyType:$result);
 }
 
+//===----------------------------------------------------------------------===//
+// Test RegionBranchOpInterface
+//===----------------------------------------------------------------------===//
+
+def RegionIfYieldOp : TEST_Op<"region_if_yield",
+      [NoSideEffect, ReturnLike, Terminator]> {
+  let arguments = (ins Variadic<AnyType>:$results);
+  let assemblyFormat = [{
+    $results `:` type($results) attr-dict
+  }];
+}
+
+def RegionIfOp : TEST_Op<"region_if",
+      [DeclareOpInterfaceMethods<RegionBranchOpInterface>,
+       SingleBlockImplicitTerminator<"RegionIfYieldOp">,
+       RecursiveSideEffects]> {
+  let description =[{
+    Represents an abstract if-then-else-join pattern. In this context, the then
+    and else regions jump to the join region, which finally returns to its
+    parent op.
+    }];
+
+  let printer = [{ return ::print(p, *this); }];
+  let parser = [{ return ::parseRegionIfOp(parser, result); }];
+  let arguments = (ins Variadic<AnyType>);
+  let results = (outs Variadic<AnyType>:$results);
+  let regions = (region SizedRegion<1>:$thenRegion,
+                        AnyRegion:$elseRegion,
+                        AnyRegion:$joinRegion);
+  let extraClassDeclaration = [{
+    Block::BlockArgListType getThenArgs() {
+      return getBody(0)->getArguments();
+    }
+    Block::BlockArgListType getElseArgs() {
+      return getBody(1)->getArguments();
+    }
+    Block::BlockArgListType getJoinArgs() {
+      return getBody(2)->getArguments();
+    }
+    OperandRange getSuccessorEntryOperands(unsigned index);
+  }];
+}
+
 #endif // TEST_OPS


        


More information about the Mlir-commits mailing list