[Mlir-commits] [mlir] [mlir][Interfaces][NFC] Better documentation for `RegionBranchOpInterface` (PR #66920)

Matthias Springer llvmlistbot at llvm.org
Wed Sep 20 08:38:22 PDT 2023


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/66920

Update outdated documentation and add an example.
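
For readers new to the terminology in this patch, here is a minimal standalone sketch of the "entry successor" filtering that the updated documentation describes for the `scf.for` example. It does not use MLIR. `Point`, `Successor`, `ForBounds`, `successorRegions` and `entrySuccessorRegions` are hypothetical stand-ins invented for illustration, and `std::nullopt` plays the role of a null attribute (an operand with no known constant value). It assumes a positive step and only models the documented contract, not the actual `scf.for` implementation.

```cpp
#include <cassert>
#include <optional>
#include <vector>

// Toy stand-ins only: these are NOT the MLIR classes RegionBranchPoint or
// RegionSuccessor, just enums that mirror the two roles described in the docs.
enum class Point { Parent, Body };        // where a branch originates
enum class Successor { Body, ParentOp };  // where a branch may land

// Constant loop bounds, if known. std::nullopt models a null attribute.
struct ForBounds {
  std::optional<long> lb, ub, step;
};

// Models the documented behavior of getSuccessorRegions for a for-like op:
// from either branch point, control may enter the body or leave the op.
std::vector<Successor> successorRegions(Point) {
  return {Successor::Body, Successor::ParentOp};
}

// Models getEntrySuccessorRegions: constants may be used to filter successors.
std::vector<Successor> entrySuccessorRegions(const ForBounds &b) {
  if (b.lb && b.ub && b.step && *b.step > 0) {
    // Zero iterations: control flow goes straight to the op's results.
    if (*b.lb >= *b.ub)
      return {Successor::ParentOp};
    // At least one iteration: the body is entered first.
    return {Successor::Body};
  }
  // Not enough constant information: fall back to the unfiltered answer,
  // which is what the default implementation is documented to do.
  return successorRegions(Point::Parent);
}
```

With constant bounds `0 to 0 step 1` the sketch reports only the parent op as entry successor, with `0 to 4 step 1` only the body, and with any bound unknown it reports both.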

From bae2cd0499b08a4d58b68026a590752f0525e8d6 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Wed, 20 Sep 2023 17:36:07 +0200
Subject: [PATCH] [mlir][Interfaces][NFC] Better documentation for
 `RegionBranchOpInterface`

Update outdated documentation and add an example.
---
 .../mlir/Interfaces/ControlFlowInterfaces.td  | 137 ++++++++++++------
 mlir/lib/Dialect/Affine/IR/AffineOps.cpp      |   9 --
 mlir/lib/Dialect/SCF/IR/SCF.cpp               |  31 +---
 mlir/lib/Interfaces/ControlFlowInterfaces.cpp |  18 +--
 4 files changed, 104 insertions(+), 91 deletions(-)

diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index e52636a5ac8fcca..22cdd63f9b13eed 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -117,27 +117,58 @@ def BranchOpInterface : OpInterface<"BranchOpInterface"> {
 
 def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
   let description = [{
-    This interface provides information for region operations that contain
-    branching behavior between held regions, i.e. this interface allows for
+    This interface provides information for region operations that exhibit
+    branching behavior between held regions. I.e., this interface allows for
     expressing control flow information for region holding operations.
 
-    This interface is meant to model well-defined cases of control-flow of
+    This interface is meant to model well-defined cases of control-flow and
     value propagation, where what occurs along control-flow edges is assumed to
-    be side-effect free. For example, corresponding successor operands and
-    successor block arguments may have different types. In such cases,
-    `areTypesCompatible` can be implemented to compare types along control-flow
-    edges. By default, type equality is used.
+    be side-effect free.
+
+    A "region branch point" indicates a point from which a branch originates. It
+    can indicate either a region of this op or `RegionBranchPoint::parent()`. In
+    the latter case, the branch originates from outside of the op,
+    i.e., when first executing this op.
+
+    A "region successor" indicates the target of a branch. It can indicate
+    either a region of this op or this op. In the former case, the region
+    successor is a region pointer and a range of block arguments to which the
+    "successor operands" are forwarded. In the latter case, the control flow
+    leaves this op and the region successor is a range of results of this op to
+    which the successor operands are forwarded.
+
+    By default, successor operands and successor block arguments/successor
+    results must have the same type. `areTypesCompatible` can be implemented to
+    allow non-equal types.
+
+    Example:
+
+    ```
+    %r = scf.for %iv = %lb to %ub step %step iter_args(%a = %b)
+        -> tensor<5xf32> {
+      ...
+      scf.yield %c : tensor<5xf32>
+    }
+    ```
+
+    `scf.for` has one region. The region has two region successors: the region
+    itself and the `scf.for` op. %b is an entry successor operand. %c is a
+    successor operand. %a is a successor block argument. %r is a successor
+    result.
   }];
   let cppNamespace = "::mlir";
 
   let methods = [
     InterfaceMethod<[{
-        Returns the operands of this operation used as the entry arguments when
-        branching from `point`, which was specified as a successor of
-        this operation by `getEntrySuccessorRegions`, or the operands forwarded
-        to the operation's results when it branches back to itself. These operands
-        should correspond 1-1 with the successor inputs specified in
-        `getEntrySuccessorRegions`.
+        Returns the operands of this operation that are forwarded to the region
+        successor's block arguments or this operation's results when branching
+        to `point`. `point` is guaranteed to be among the successors that are
+        returned by `getEntrySuccessorRegions`/`getSuccessorRegions(parent())`.
+
+        Example: In the above example, this method returns the operand %b of the
+        `scf.for` op, regardless of the value of `point`. I.e., this op always
+        forwards the same operands, regardless of whether the loop has 0 or more
+        iterations.
       }],
       "::mlir::OperandRange", "getEntrySuccessorOperands",
       (ins "::mlir::RegionBranchPoint":$point), [{}],
@@ -147,32 +178,44 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
       }]
     >,
     InterfaceMethod<[{
-        Returns the viable region successors that are branched to when first
-        executing the op.
+        Returns the potential region successors when first executing the op.
+
         Unlike `getSuccessorRegions`, this method also passes along the
-        constant operands of this op. Based on these, different region
-        successors can be determined.
-        `operands` contains an entry for every operand of the implementing
-        op with a null attribute if the operand has no constant value or
-        the corresponding attribute if it is a constant.
+        constant operands of this op. Based on these, the implementation may
+        filter out certain successors. By default, simply dispatches to
+        `getSuccessorRegions`. `operands` contains an entry for every
+        operand of this op, with a null attribute if the operand has no constant
+        value.
 
-        By default, simply dispatches to `getSuccessorRegions`.
+        Example: In the above example, this method may return two region
+        successors: the single region of the `scf.for` op and this operation. If
+        %lb, %ub, %step are constants and it can be determined that the loop
+        does not have any iterations, this method may choose to return
+        only this operation. Similarly, if it can be determined that the loop
+        has at least one iteration, this method may choose to return only the
+        region of the loop.
       }],
       "void", "getEntrySuccessorRegions",
       (ins "::llvm::ArrayRef<::mlir::Attribute>":$operands,
-           "::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions),
-           [{}], [{
+           "::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions), [{}],
+      /*defaultImplementation=*/[{
         $_op.getSuccessorRegions(mlir::RegionBranchPoint::parent(), regions);
       }]
     >,
     InterfaceMethod<[{
-        Returns the viable successors of `point`. These are the regions that may
-        be selected during the flow of control. The parent operation, may
-        specify itself as successor, which indicates that the control flow may
-        not enter any region at all. This method allows for describing which
-        regions may be executed when entering an operation, and which regions
-        are executed after having executed another region of the parent op. The
-        successor region must be non-empty.
+        Returns the potential region successors when branching from `point`.
+        These are the regions that may be selected during the flow of control.
+
+        When `point = RegionBranchPoint::parent()`, this method returns the
+        region successors when entering the operation. Otherwise, this method
+        returns the successor regions when branching from the region indicated
+        by `point`.
+
+        Example: In the above example, this method returns the region of the
+        `scf.for` and this operation for either region branch point (`parent`
+        and the region of the `scf.for`). An implementation may choose to filter
+        out region successors when it is statically known (e.g., by examining
+        the operands of this op) that those successors are not branched to.
       }],
       "void", "getSuccessorRegions",
       (ins "::mlir::RegionBranchPoint":$point,
@@ -183,12 +226,12 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
         times this operation will invoke the attached regions (assuming the
         regions yield normally, i.e. do not abort or invoke an infinite loop).
         The minimum number of invocations is at least 0. If the maximum number
-        of invocations cannot be statically determined, then it will not have a
-        value (i.e., it is set to `std::nullopt`).
+        of invocations cannot be statically determined, then it will be set to
+        `InvocationBounds::getUnknown()`.
 
-        `operands` is a set of optional attributes that either correspond to
-        constant values for each operand of this operation or null if that
-        operand is not a constant.
+        This method also passes along the constant operands of this op.
+        `operands` contains an entry for every operand of this op, with a null
+        attribute if the operand has no constant value.
 
         This method may be called speculatively on operations where the provided
         operands are not necessarily the same as the operation's current
@@ -199,8 +242,10 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
       (ins "::llvm::ArrayRef<::mlir::Attribute>":$operands,
            "::llvm::SmallVectorImpl<::mlir::InvocationBounds> &"
              :$invocationBounds), [{}],
-       [{ invocationBounds.append($_op->getNumRegions(),
-                                  ::mlir::InvocationBounds::getUnknown()); }]
+      /*defaultImplementation=*/[{
+        invocationBounds.append($_op->getNumRegions(),
+                                ::mlir::InvocationBounds::getUnknown());
+      }]
     >,
     InterfaceMethod<[{
         This method is called to compare types along control-flow edges. By
@@ -208,7 +253,7 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
       }],
       "bool", "areTypesCompatible",
       (ins "::mlir::Type":$lhs, "::mlir::Type":$rhs), [{}],
-       [{ return lhs == rhs; }]
+      /*defaultImplementation=*/[{ return lhs == rhs; }]
     >,
   ];
 
@@ -235,7 +280,7 @@ def RegionBranchTerminatorOpInterface :
   OpInterface<"RegionBranchTerminatorOpInterface"> {
   let description = [{
     This interface provides information for branching terminator operations
-    in the presence of a parent RegionBranchOpInterface implementation. It
+    in the presence of a parent `RegionBranchOpInterface` implementation. It
     specifies which operands are passed to which successor region.
   }];
   let cppNamespace = "::mlir";
@@ -243,26 +288,26 @@ def RegionBranchTerminatorOpInterface :
   let methods = [
     InterfaceMethod<[{
         Returns a mutable range of operands that are semantically "returned" by
-        passing them to the region successor given by `point`.
+        passing them to the region successor indicated by `point`.
       }],
       "::mlir::MutableOperandRange", "getMutableSuccessorOperands",
       (ins "::mlir::RegionBranchPoint":$point)
     >,
     InterfaceMethod<[{
-        Returns the viable region successors that are branched to after this
+        Returns the potential region successors that are branched to after this
         terminator based on the given constant operands.
 
-        `operands` contains an entry for every operand of the
-        implementing op with a null attribute if the operand has no constant
-        value or the corresponding attribute if it is a constant.
+        This method also passes along the constant operands of this op.
+        `operands` contains an entry for every operand of this op, with a null
+        attribute if the operand has no constant value.
 
-        Default implementation simply dispatches to the parent
+        The default implementation simply dispatches to the parent
         `RegionBranchOpInterface`'s `getSuccessorRegions` implementation.
       }],
       "void", "getSuccessorRegions",
       (ins "::llvm::ArrayRef<::mlir::Attribute>":$operands,
            "::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions), [{}],
-      [{
+      /*defaultImplementation=*/[{
         ::mlir::Operation *op = $_op;
         ::llvm::cast<::mlir::RegionBranchOpInterface>(op->getParentOp())
           .getSuccessorRegions(op->getParentRegion(), regions);
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 5899c198b703b5e..f48b48f943d873c 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -2375,10 +2375,6 @@ void AffineForOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<AffineForEmptyLoopFolder>(context);
 }
 
-/// Return operands used when entering the region at 'index'. These operands
-/// correspond to the loop iterator operands, i.e., those excluding the
-/// induction variable. AffineForOp only has one region, so zero is the only
-/// valid value for `index`.
 OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
   assert((point.isParent() || point == getRegion()) && "invalid region point");
 
@@ -2387,11 +2383,6 @@ OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
   return getIterOperands();
 }
 
-/// Given the region at `index`, or the parent operation if `index` is None,
-/// return the successor regions. These are the regions that may be selected
-/// during the flow of control. `operands` is a set of optional attributes that
-/// correspond to a constant value for each operand, or null if that operand is
-/// not a constant.
 void AffineForOp::getSuccessorRegions(
     RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
   assert((point.isParent() || point == getRegion()) && "expected loop region");
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 45e68e23a71d60e..1230b3d786f6220 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -260,11 +260,6 @@ void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context);
 }
 
-/// Given the region at `index`, or the parent operation if `index` is None,
-/// return the successor regions. These are the regions that may be selected
-/// during the flow of control. `operands` is a set of optional attributes that
-/// correspond to a constant value for each operand, or null if that operand is
-/// not a constant.
 void ExecuteRegionOp::getSuccessorRegions(
     RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
   // If the predecessor is the ExecuteRegionOp, branch into the body.
@@ -541,18 +536,10 @@ ForOp mlir::scf::getForInductionVarOwner(Value val) {
   return dyn_cast_or_null<ForOp>(containingOp);
 }
 
-/// Return operands used when entering the region at 'index'. These operands
-/// correspond to the loop iterator operands, i.e., those excluding the
-/// induction variable.
 OperandRange ForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
   return getInitArgs();
 }
 
-/// Given the region at `index`, or the parent operation if `index` is None,
-/// return the successor regions. These are the regions that may be selected
-/// during the flow of control. `operands` is a set of optional attributes that
-/// correspond to a constant value for each operand, or null if that operand is
-/// not a constant.
 void ForOp::getSuccessorRegions(RegionBranchPoint point,
                                 SmallVectorImpl<RegionSuccessor> &regions) {
   // Both the operation itself and the region may be branching into the body or
@@ -1997,11 +1984,6 @@ void IfOp::print(OpAsmPrinter &p) {
   p.printOptionalAttrDict((*this)->getAttrs());
 }
 
-/// Given the region at `index`, or the parent operation if `index` is None,
-/// return the successor regions. These are the regions that may be selected
-/// during the flow of control. `operands` is a set of optional attributes that
-/// correspond to a constant value for each operand, or null if that operand is
-/// not a constant.
 void IfOp::getSuccessorRegions(RegionBranchPoint point,
                                SmallVectorImpl<RegionSuccessor> &regions) {
   // The `then` and the `else` region branch back to the parent operation.
@@ -3160,13 +3142,6 @@ void WhileOp::build(::mlir::OpBuilder &odsBuilder,
     afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
 }
 
-OperandRange WhileOp::getEntrySuccessorOperands(RegionBranchPoint point) {
-  assert(point == getBefore() &&
-         "WhileOp is expected to branch only to the first region");
-
-  return getInits();
-}
-
 ConditionOp WhileOp::getConditionOp() {
   return cast<ConditionOp>(getBeforeBody()->getTerminator());
 }
@@ -3183,6 +3158,12 @@ Block::BlockArgListType WhileOp::getAfterArguments() {
   return getAfterBody()->getArguments();
 }
 
+OperandRange WhileOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+  assert(point == getBefore() &&
+         "WhileOp is expected to branch only to the first region");
+  return getInits();
+}
+
 void WhileOp::getSuccessorRegions(RegionBranchPoint point,
                                   SmallVectorImpl<RegionSuccessor> &regions) {
   // The parent op always branches to the condition region.
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index b3166155e84f934..4ed024ddae247b0 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -102,11 +102,8 @@ static InFlightDiagnostic &printRegionEdgeName(InFlightDiagnostic &diag,
 }
 
 /// Verify that types match along all region control flow edges originating from
-/// `sourceNo` (region # if source is a region, std::nullopt if source is parent
-/// op). `getInputsTypesForRegion` is a function that returns the types of the
-/// inputs that flow from `sourceIndex' to the given region, or std::nullopt if
-/// the exact type match verification is not necessary (e.g., if the Op verifies
-/// the match itself).
+/// `sourcePoint`. `getInputsTypesForRegion` is a function that returns the
+/// types of the inputs that flow to a successor region.
 static LogicalResult
 verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint,
                          function_ref<FailureOr<TypeRange>(RegionBranchPoint)>
@@ -150,8 +147,8 @@ verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint,
 LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
   auto regionInterface = cast<RegionBranchOpInterface>(op);
 
-  auto inputTypesFromParent = [&](RegionBranchPoint regionNo) -> TypeRange {
-    return regionInterface.getEntrySuccessorOperands(regionNo).getTypes();
+  auto inputTypesFromParent = [&](RegionBranchPoint point) -> TypeRange {
+    return regionInterface.getEntrySuccessorOperands(point).getTypes();
   };
 
   // Verify types along control flow edges originating from the parent.
@@ -190,11 +187,10 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
       continue;
 
     auto inputTypesForRegion =
-        [&](RegionBranchPoint succRegionNo) -> FailureOr<TypeRange> {
+        [&](RegionBranchPoint point) -> FailureOr<TypeRange> {
       std::optional<OperandRange> regionReturnOperands;
       for (RegionBranchTerminatorOpInterface regionReturnOp : regionReturnOps) {
-        auto terminatorOperands =
-            regionReturnOp.getSuccessorOperands(succRegionNo);
+        auto terminatorOperands = regionReturnOp.getSuccessorOperands(point);
 
         if (!regionReturnOperands) {
           regionReturnOperands = terminatorOperands;
@@ -206,7 +202,7 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
         if (!areTypesCompatible(regionReturnOperands->getTypes(),
                                 terminatorOperands.getTypes())) {
           InFlightDiagnostic diag = op->emitOpError(" along control flow edge");
-          return printRegionEdgeName(diag, region, succRegionNo)
+          return printRegionEdgeName(diag, region, point)
                  << " operands mismatch between return-like terminators";
         }
       }


