[Mlir-commits] [mlir] 359ba0b - [mlir][CFGToSCF] Add interface changes for downstream projects

Markus Böck llvmlistbot at llvm.org
Tue Aug 15 07:40:08 PDT 2023


Author: Markus Böck
Date: 2023-08-15T16:38:16+02:00
New Revision: 359ba0b00806c6fba325733e817637522b8c6e19

URL: https://github.com/llvm/llvm-project/commit/359ba0b00806c6fba325733e817637522b8c6e19
DIFF: https://github.com/llvm/llvm-project/commit/359ba0b00806c6fba325733e817637522b8c6e19.diff

LOG: [mlir][CFGToSCF] Add interface changes for downstream projects

This is a follow-up to https://reviews.llvm.org/D156889

Downstream projects may have more complicated ops than the control flow ops upstream and therefore need a more powerful interface to support the lifting process. Use cases include the propagation of (inherent) metadata that was previously on the control flow ops and now needs to be lifted to structured control flow ops.
Since the lifting process is inherently non-local with respect to the function body, we require stronger guarantees from the interface.

This patch therefore makes two changes to the interface:
* Passes the terminator that is being replaced to `createStructuredBranchRegionTerminatorOp`
* Adds a precondition to `createCFGSwitchOp` that the predecessors of the builder's insertion block are already correctly established

Asserts have been added to verify these where it makes sense and to correctly state intent. I have not added tests purely because testing preconditions like these is not really feasible (and incredibly specific).

Differential Revision: https://reviews.llvm.org/D157981

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/ControlFlowToSCF/ControlFlowToSCF.h
    mlir/include/mlir/Transforms/CFGToSCF.h
    mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp
    mlir/lib/Transforms/Utils/CFGToSCF.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/ControlFlowToSCF/ControlFlowToSCF.h b/mlir/include/mlir/Conversion/ControlFlowToSCF/ControlFlowToSCF.h
index e8f0137d75b29b..dab8e7d51f24a3 100644
--- a/mlir/include/mlir/Conversion/ControlFlowToSCF/ControlFlowToSCF.h
+++ b/mlir/include/mlir/Conversion/ControlFlowToSCF/ControlFlowToSCF.h
@@ -32,10 +32,9 @@ class ControlFlowToSCFTransformation : public CFGToSCFInterface {
       MutableArrayRef<Region> regions) override;
 
   /// Creates an `scf.yield` op returning the given results.
-  LogicalResult
-  createStructuredBranchRegionTerminatorOp(Location loc, OpBuilder &builder,
-                                           Operation *branchRegionOp,
-                                           ValueRange results) override;
+  LogicalResult createStructuredBranchRegionTerminatorOp(
+      Location loc, OpBuilder &builder, Operation *branchRegionOp,
+      Operation *replacedControlFlowOp, ValueRange results) override;
 
   /// Creates an `scf.while` op. The loop body is made the before-region of the
   /// while op and terminated with an `scf.condition` op. The after-region does

diff  --git a/mlir/include/mlir/Transforms/CFGToSCF.h b/mlir/include/mlir/Transforms/CFGToSCF.h
index a611943bbcd031..e3d4b6f3d7e7ea 100644
--- a/mlir/include/mlir/Transforms/CFGToSCF.h
+++ b/mlir/include/mlir/Transforms/CFGToSCF.h
@@ -42,12 +42,14 @@ class CFGToSCFInterface {
 
   /// Creates a return-like terminator for a branch region of the op returned
   /// by `createStructuredBranchRegionOp`. `branchRegionOp` is the operation
-  /// returned by `createStructuredBranchRegionOp` while `results` are the
-  /// values that should be returned by the branch region.
-  virtual LogicalResult
-  createStructuredBranchRegionTerminatorOp(Location loc, OpBuilder &builder,
-                                           Operation *branchRegionOp,
-                                           ValueRange results) = 0;
+  /// returned by `createStructuredBranchRegionOp`.
+  /// `replacedControlFlowOp` is the control flow op being replaced by the
+  /// terminator or nullptr if the terminator is not replacing any existing
+  /// control flow op. `results` are the values that should be returned by the
+  /// branch region.
+  virtual LogicalResult createStructuredBranchRegionTerminatorOp(
+      Location loc, OpBuilder &builder, Operation *branchRegionOp,
+      Operation *replacedControlFlowOp, ValueRange results) = 0;
 
   /// Creates a structured control flow operation representing a do-while loop.
   /// The do-while loop is expected to have the exact same result types as the
@@ -77,8 +79,10 @@ class CFGToSCFInterface {
   /// `caseDestinations` or `defaultDest`. This is used by the transformation
   /// for intermediate transformations before lifting to structured control
   /// flow. The switch op branches based on `flag` which is guaranteed to be of
-  /// the same type as values returned by `getCFGSwitchValue`. Note:
-  /// `caseValues` and other related ranges may be empty to represent an
+  /// the same type as values returned by `getCFGSwitchValue`. The insertion
+  /// block of the builder is guaranteed to have its predecessors already set
+  /// to create an equivalent CFG after this operation.
+  /// Note: `caseValues` and other related ranges may be empty to represent an
   /// unconditional branch.
   virtual void createCFGSwitchOp(Location loc, OpBuilder &builder, Value flag,
                                  ArrayRef<unsigned> caseValues,

diff  --git a/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp b/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp
index af04ca801f2ace..c0169a024bd5e8 100644
--- a/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp
+++ b/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp
@@ -76,7 +76,7 @@ ControlFlowToSCFTransformation::createStructuredBranchRegionOp(
 LogicalResult
 ControlFlowToSCFTransformation::createStructuredBranchRegionTerminatorOp(
     Location loc, OpBuilder &builder, Operation *branchRegionOp,
-    ValueRange results) {
+    Operation *replacedControlFlowOp, ValueRange results) {
   builder.create<scf::YieldOp>(loc, results);
   return success();
 }

diff  --git a/mlir/lib/Transforms/Utils/CFGToSCF.cpp b/mlir/lib/Transforms/Utils/CFGToSCF.cpp
index 0fa6744964c4e2..84f23584e9f30e 100644
--- a/mlir/lib/Transforms/Utils/CFGToSCF.cpp
+++ b/mlir/lib/Transforms/Utils/CFGToSCF.cpp
@@ -306,7 +306,8 @@ class EdgeMultiplexer {
   /// Creates a switch op using `builder` which dispatches to the original
   /// successors of the edges passed to `create` minus the ones in `excluded`.
   /// The builder's insertion point has to be in a block dominated by the
-  /// multiplexer block.
+  /// multiplexer block. All edges to the multiplexer block must have already
+  /// been redirected using `redirectEdge`.
   void createSwitch(
       Location loc, OpBuilder &builder, CFGToSCFInterface &interface,
       const SmallPtrSetImpl<Block *> &excluded = SmallPtrSet<Block *, 1>{}) {
@@ -337,6 +338,8 @@ class EdgeMultiplexer {
     Block *defaultDest = caseDestinations.pop_back_val();
     ValueRange defaultArgs = caseArguments.pop_back_val();
 
+    assert(!builder.getInsertionBlock()->hasNoPredecessors() &&
+           "Edges need to be redirected prior to creating switch.");
     interface.createCFGSwitchOp(loc, builder, realDiscriminator, caseValues,
                                 caseDestinations, caseArguments, defaultDest,
                                 defaultArgs);
@@ -507,12 +510,14 @@ createSingleEntryBlock(Location loc, ArrayRef<Edge> entryEdges,
       loc, llvm::map_to_vector(entryEdges, std::mem_fn(&Edge::getSuccessor)),
       getSwitchValue, getUndefValue);
 
-  auto builder = OpBuilder::atBlockBegin(result.getMultiplexerBlock());
-  result.createSwitch(loc, builder, interface);
-
+  // Redirect the edges prior to creating the switch op.
+  // We guarantee that predecessors are up to date.
   for (Edge edge : entryEdges)
     result.redirectEdge(edge);
 
+  auto builder = OpBuilder::atBlockBegin(result.getMultiplexerBlock());
+  result.createSwitch(loc, builder, interface);
+
   return result;
 }
 
@@ -565,6 +570,17 @@ static FailureOr<StructuredLoopProperties> createSingleExitingLatch(
   // Since this is a loop, all back edges point to the same loop header.
   Block *loopHeader = backEdges.front().getSuccessor();
 
+  // Redirect the edges prior to creating the switch op.
+  // We guarantee that predecessors are up to date.
+
+  // Redirecting back edges with `shouldRepeat` as 1.
+  for (Edge backEdge : backEdges)
+    multiplexer.redirectEdge(backEdge, /*extraArgs=*/getSwitchValue(1));
+
+  // Redirecting exits edges with `shouldRepeat` as 0.
+  for (Edge exitEdge : exitEdges)
+    multiplexer.redirectEdge(exitEdge, /*extraArgs=*/getSwitchValue(0));
+
   // Create the new only back edge to the loop header. Branch to the
   // exit block otherwise.
   Value shouldRepeat = latchBlock->getArguments().back();
@@ -603,14 +619,6 @@ static FailureOr<StructuredLoopProperties> createSingleExitingLatch(
     }
   }
 
-  // Redirecting back edges with `shouldRepeat` as 1.
-  for (Edge backEdge : backEdges)
-    multiplexer.redirectEdge(backEdge, /*extraArgs=*/getSwitchValue(1));
-
-  // Redirecting exits edges with `shouldRepeat` as 0.
-  for (Edge exitEdge : exitEdges)
-    multiplexer.redirectEdge(exitEdge, /*extraArgs=*/getSwitchValue(0));
-
   return StructuredLoopProperties{latchBlock, /*condition=*/shouldRepeat,
                                   exitBlock};
 }
@@ -794,13 +802,14 @@ static FailureOr<SmallVector<Block *>> transformCyclesToSCFLoops(
     // First turn the cycle into a loop by creating a single entry block if
     // needed.
     if (edges.entryEdges.size() > 1) {
+      SmallVector<Edge> edgesToEntryBlocks;
+      llvm::append_range(edgesToEntryBlocks, edges.entryEdges);
+      llvm::append_range(edgesToEntryBlocks, edges.backEdges);
+
       EdgeMultiplexer multiplexer = createSingleEntryBlock(
-          loopHeader->getTerminator()->getLoc(), edges.entryEdges,
+          loopHeader->getTerminator()->getLoc(), edgesToEntryBlocks,
           getSwitchValue, getUndefValue, interface);
 
-      for (Edge edge : edges.backEdges)
-        multiplexer.redirectEdge(edge);
-
       loopHeader = multiplexer.getMultiplexerBlock();
     }
     cycleBlockSet.insert(loopHeader);
@@ -1140,7 +1149,8 @@ static FailureOr<SmallVector<Block *>> transformToStructuredCFBranches(
   for (auto &&[block, valueRange] : createdEmptyBlocks) {
     auto builder = OpBuilder::atBlockEnd(block);
     LogicalResult result = interface.createStructuredBranchRegionTerminatorOp(
-        structuredCondOp->getLoc(), builder, structuredCondOp, valueRange);
+        structuredCondOp->getLoc(), builder, structuredCondOp, nullptr,
+        valueRange);
     if (failed(result))
       return failure();
   }
@@ -1153,7 +1163,7 @@ static FailureOr<SmallVector<Block *>> transformToStructuredCFBranches(
     assert(user->getNumSuccessors() == 1);
     auto builder = OpBuilder::atBlockTerminator(user->getBlock());
     LogicalResult result = interface.createStructuredBranchRegionTerminatorOp(
-        user->getLoc(), builder, structuredCondOp,
+        user->getLoc(), builder, structuredCondOp, user,
         static_cast<OperandRange>(
             getMutableSuccessorOperands(user->getBlock(), 0)));
     if (failed(result))


        


More information about the Mlir-commits mailing list