[Mlir-commits] [mlir] 621d7cc - [mlir] Add a new BranchOpInterface to allow for opaquely interfacing with branching terminator operations.

River Riddle llvmlistbot at llvm.org
Thu Mar 5 12:58:47 PST 2020


Author: River Riddle
Date: 2020-03-05T12:50:35-08:00
New Revision: 621d7cca3751f934f991e34fe0e26187c33314f4

URL: https://github.com/llvm/llvm-project/commit/621d7cca3751f934f991e34fe0e26187c33314f4
DIFF: https://github.com/llvm/llvm-project/commit/621d7cca3751f934f991e34fe0e26187c33314f4.diff

LOG: [mlir] Add a new BranchOpInterface to allow for opaquely interfacing with branching terminator operations.

This interface contains the necessary components to provide the same builtin behavior that terminators have. This will be used in future revisions to remove many of the hardcoded constraints placed on successors and successor operands. The interface initially contains three methods:

```c++
// Return a set of values corresponding to the operands for successor 'index', or None if the operands do not correspond to materialized values.
Optional<OperandRange> getSuccessorOperands(unsigned index);

// Return true if this terminator can have its successor operands erased.
bool canEraseSuccessorOperand();

// Erase the operand of a successor. This is only valid to call if 'canEraseSuccessorOperand' returns true.
void eraseSuccessorOperand(unsigned succIdx, unsigned opIdx);
```

Differential Revision: https://reviews.llvm.org/D75314

Added: 
    mlir/include/mlir/Analysis/ControlFlowInterfaces.h
    mlir/include/mlir/Analysis/ControlFlowInterfaces.td
    mlir/lib/Analysis/ControlFlowInterfaces.cpp

Modified: 
    mlir/include/mlir/Analysis/CMakeLists.txt
    mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
    mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/include/mlir/IR/OperationSupport.h
    mlir/lib/Analysis/CMakeLists.txt
    mlir/lib/Dialect/LLVMIR/CMakeLists.txt
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/lib/Dialect/SPIRV/CMakeLists.txt
    mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
    mlir/lib/Dialect/StandardOps/CMakeLists.txt
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/lib/IR/Operation.cpp
    mlir/lib/IR/OperationSupport.cpp
    mlir/test/lib/TestDialect/CMakeLists.txt
    mlir/test/lib/TestDialect/TestDialect.cpp
    mlir/test/lib/TestDialect/TestDialect.h
    mlir/test/lib/TestDialect/TestOps.td
    mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/CMakeLists.txt b/mlir/include/mlir/Analysis/CMakeLists.txt
index 3d9a7ed36979..cc8c4939b73b 100644
--- a/mlir/include/mlir/Analysis/CMakeLists.txt
+++ b/mlir/include/mlir/Analysis/CMakeLists.txt
@@ -3,6 +3,11 @@ mlir_tablegen(CallInterfaces.h.inc -gen-op-interface-decls)
 mlir_tablegen(CallInterfaces.cpp.inc -gen-op-interface-defs)
 add_public_tablegen_target(MLIRCallOpInterfacesIncGen)
 
+set(LLVM_TARGET_DEFINITIONS ControlFlowInterfaces.td)
+mlir_tablegen(ControlFlowInterfaces.h.inc -gen-op-interface-decls)
+mlir_tablegen(ControlFlowInterfaces.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(MLIRControlFlowInterfacesIncGen)
+
 set(LLVM_TARGET_DEFINITIONS InferTypeOpInterface.td)
 mlir_tablegen(InferTypeOpInterface.h.inc -gen-op-interface-decls)
 mlir_tablegen(InferTypeOpInterface.cpp.inc -gen-op-interface-defs)

diff  --git a/mlir/include/mlir/Analysis/ControlFlowInterfaces.h b/mlir/include/mlir/Analysis/ControlFlowInterfaces.h
new file mode 100644
index 000000000000..87f82040ed61
--- /dev/null
+++ b/mlir/include/mlir/Analysis/ControlFlowInterfaces.h
@@ -0,0 +1,43 @@
+//===- ControlFlowInterfaces.h - ControlFlow Interfaces ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the definitions of the branch interfaces defined in
+// `ControlFlowInterfaces.td`.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_CONTROLFLOWINTERFACES_H
+#define MLIR_ANALYSIS_CONTROLFLOWINTERFACES_H
+
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+class BranchOpInterface;
+
+namespace detail {
+/// Erase an operand from a branch operation that is used as a successor
+/// operand. `operandIndex` is the operand within `operands` to be erased.
+void eraseBranchSuccessorOperand(OperandRange operands, unsigned operandIndex,
+                                 Operation *op);
+
+/// Return the `BlockArgument` corresponding to operand `operandIndex` in some
+/// successor if `operandIndex` is within the range of `operands`, or None if
+/// `operandIndex` isn't a successor operand index.
+Optional<BlockArgument>
+getBranchSuccessorArgument(Optional<OperandRange> operands,
+                           unsigned operandIndex, Block *successor);
+
+/// Verify that the given operands match those of the given successor block.
+LogicalResult verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
+                                            Optional<OperandRange> operands);
+} // namespace detail
+
+#include "mlir/Analysis/ControlFlowInterfaces.h.inc"
+} // end namespace mlir
+
+#endif // MLIR_ANALYSIS_CONTROLFLOWINTERFACES_H

diff  --git a/mlir/include/mlir/Analysis/ControlFlowInterfaces.td b/mlir/include/mlir/Analysis/ControlFlowInterfaces.td
new file mode 100644
index 000000000000..b34b7a9d7e81
--- /dev/null
+++ b/mlir/include/mlir/Analysis/ControlFlowInterfaces.td
@@ -0,0 +1,85 @@
+//===-- ControlFlowInterfaces.td - ControlFlow Interfaces --*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains a set of interfaces that can be used to define information
+// about control flow operations, e.g. branches.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_CONTROLFLOWINTERFACES
+#define MLIR_ANALYSIS_CONTROLFLOWINTERFACES
+
+include "mlir/IR/OpBase.td"
+
+//===----------------------------------------------------------------------===//
+// BranchOpInterface
+//===----------------------------------------------------------------------===//
+
+def BranchOpInterface : OpInterface<"BranchOpInterface"> {
+  let description = [{
+    This interface provides information for branching terminator operations,
+    i.e. terminator operations with successors.
+  }];
+  let methods = [
+    InterfaceMethod<[{
+        Returns a set of values that correspond to the arguments to the
+        successor at the given index. Returns None if the operands to the
+        successor are non-materialized values, i.e. they are internal to the
+        operation.
+      }],
+      "Optional<OperandRange>", "getSuccessorOperands", (ins "unsigned":$index)
+    >,
+    InterfaceMethod<[{
+        Return true if this operation can erase an operand to a successor block.
+      }],
+      "bool", "canEraseSuccessorOperand"
+    >,
+    InterfaceMethod<[{
+        Erase the operand at `operandIndex` from the `index`-th successor. This
+        should only be called if `canEraseSuccessorOperand` returns true.
+      }],
+      "void", "eraseSuccessorOperand",
+      (ins "unsigned":$index, "unsigned":$operandIndex), [{}],
+      /*defaultImplementation=*/[{
+        ConcreteOp *op = static_cast<ConcreteOp *>(this);
+        Optional<OperandRange> operands = op->getSuccessorOperands(index);
+        assert(operands && "unable to query operands for successor");
+        detail::eraseBranchSuccessorOperand(*operands, operandIndex, *op);
+      }]
+    >,
+    InterfaceMethod<[{
+        Returns the `BlockArgument` corresponding to operand `operandIndex` in
+        some successor, or None if `operandIndex` isn't a successor operand
+        index.
+      }],
+      "Optional<BlockArgument>", "getSuccessorBlockArgument",
+      (ins "unsigned":$operandIndex), [{
+        Operation *opaqueOp = op;
+        for (unsigned i = 0, e = opaqueOp->getNumSuccessors(); i != e; ++i) {
+          if (Optional<BlockArgument> arg = detail::getBranchSuccessorArgument(
+                op.getSuccessorOperands(i), operandIndex,
+                opaqueOp->getSuccessor(i)))
+            return arg;
+        }
+        return llvm::None;
+      }]
+    >
+  ];
+
+  let verify = [{
+    auto concreteOp = cast<ConcreteOpType>($_op);
+    for (unsigned i = 0, e = $_op->getNumSuccessors(); i != e; ++i) {
+      Optional<OperandRange> operands = concreteOp.getSuccessorOperands(i);
+      if (failed(detail::verifyBranchSuccessorOperands($_op, i, operands)))
+        return failure();
+    }
+    return success();
+  }];
+}
+
+#endif // MLIR_ANALYSIS_CONTROLFLOWINTERFACES

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
index d1378b827c52..a8b0be0793d6 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
@@ -14,6 +14,7 @@
 #ifndef MLIR_DIALECT_LLVMIR_LLVMDIALECT_H_
 #define MLIR_DIALECT_LLVMIR_LLVMDIALECT_H_
 
+#include "mlir/Analysis/ControlFlowInterfaces.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Function.h"
 #include "mlir/IR/OpDefinition.h"

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index b4585407edce..2ab9041bc2a0 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -14,6 +14,7 @@
 #define LLVMIR_OPS
 
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+include "mlir/Analysis/ControlFlowInterfaces.td"
 
 class LLVM_Builder<string builder> {
   string llvmBuilder = builder;
@@ -315,7 +316,9 @@ def LLVM_FPExtOp : LLVM_CastOp<"fpext", "CreateFPExt">;
 def LLVM_FPTruncOp : LLVM_CastOp<"fptrunc", "CreateFPTrunc">;
 
 // Call-related operations.
-def LLVM_InvokeOp : LLVM_Op<"invoke", [Terminator]>,
+def LLVM_InvokeOp : LLVM_Op<"invoke", [
+                      DeclareOpInterfaceMethods<BranchOpInterface>, Terminator
+                    ]>,
                     Arguments<(ins OptionalAttr<FlatSymbolRefAttr>:$callee,
                                Variadic<LLVM_Type>)>,
                     Results<(outs Variadic<LLVM_Type>)> {
@@ -458,11 +461,13 @@ def LLVM_FreezeOp : LLVM_OneResultOp<"freeze", [SameOperandsAndResultType]>,
 }
 
 // Terminators.
-def LLVM_BrOp : LLVM_TerminatorOp<"br", []> {
+def LLVM_BrOp : LLVM_TerminatorOp<"br",
+    [DeclareOpInterfaceMethods<BranchOpInterface>]> {
   let successors = (successor AnySuccessor:$dest);
   let assemblyFormat = "$dest attr-dict";
 }
-def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br", []> {
+def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br",
+    [DeclareOpInterfaceMethods<BranchOpInterface>]> {
   let arguments = (ins LLVMI1:$condition);
   let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest);
   let assemblyFormat = "$condition `,` successors attr-dict";

diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
index 03884afe3e95..5ef825af4ea1 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
@@ -16,10 +16,13 @@
 
 include "mlir/Dialect/SPIRV/SPIRVBase.td"
 include "mlir/Analysis/CallInterfaces.td"
+include "mlir/Analysis/ControlFlowInterfaces.td"
 
 // -----
 
-def SPV_BranchOp : SPV_Op<"Branch", [InFunctionScope, Terminator]> {
+def SPV_BranchOp : SPV_Op<"Branch",[
+    DeclareOpInterfaceMethods<BranchOpInterface>, InFunctionScope,
+    Terminator]> {
   let summary = "Unconditional branch to target block.";
 
   let description = [{
@@ -75,8 +78,9 @@ def SPV_BranchOp : SPV_Op<"Branch", [InFunctionScope, Terminator]> {
 
 // -----
 
-def SPV_BranchConditionalOp : SPV_Op<"BranchConditional",
-                                     [InFunctionScope, Terminator]> {
+def SPV_BranchConditionalOp : SPV_Op<"BranchConditional", [
+    DeclareOpInterfaceMethods<BranchOpInterface>, InFunctionScope,
+    Terminator]> {
   let summary = [{
     If Condition is true, branch to true block, otherwise branch to false
     block.

diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h
index ea541c056a11..e223f17933f3 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h
@@ -13,6 +13,7 @@
 #ifndef MLIR_DIALECT_SPIRV_SPIRVOPS_H_
 #define MLIR_DIALECT_SPIRV_SPIRVOPS_H_
 
+#include "mlir/Analysis/ControlFlowInterfaces.h"
 #include "mlir/Dialect/SPIRV/SPIRVTypes.h"
 #include "mlir/IR/Function.h"
 #include "llvm/Support/PointerLikeTypeTraits.h"

diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
index 1e19c0270416..10822764ff6d 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
@@ -15,6 +15,7 @@
 #define MLIR_DIALECT_STANDARDOPS_IR_OPS_H
 
 #include "mlir/Analysis/CallInterfaces.h"
+#include "mlir/Analysis/ControlFlowInterfaces.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpImplementation.h"

diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 2fe0365408e2..e44f8ff18a3a 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -14,6 +14,7 @@
 #define STANDARD_OPS
 
 include "mlir/Analysis/CallInterfaces.td"
+include "mlir/Analysis/ControlFlowInterfaces.td"
 include "mlir/IR/OpAsmInterface.td"
 
 def Std_Dialect : Dialect {
@@ -331,7 +332,8 @@ def AtomicRMWOp : Std_Op<"atomic_rmw", [
 // BranchOp
 //===----------------------------------------------------------------------===//
 
-def BranchOp : Std_Op<"br", [Terminator]> {
+def BranchOp : Std_Op<"br",
+    [DeclareOpInterfaceMethods<BranchOpInterface>, Terminator]> {
   let summary = "branch operation";
   let description = [{
     The "br" operation represents a branch operation in a function.
@@ -668,7 +670,8 @@ def CmpIOp : Std_Op<"cmpi",
 // CondBranchOp
 //===----------------------------------------------------------------------===//
 
-def CondBranchOp : Std_Op<"cond_br", [Terminator]> {
+def CondBranchOp : Std_Op<"cond_br",
+    [DeclareOpInterfaceMethods<BranchOpInterface>, Terminator]> {
   let summary = "conditional branch operation";
   let description = [{
     The "cond_br" operation represents a conditional branch operation in a

diff  --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 96545d6be2b0..7735dd176b08 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -639,6 +639,10 @@ class OperandRange final
   type_range getTypes() const { return {begin(), end()}; }
   auto getType() const { return getTypes(); }
 
+  /// Return the operand index of the first element of this range. The range
+  /// must not be empty.
+  unsigned getBeginOperandIndex() const;
+
 private:
   /// See `detail::indexed_accessor_range_base` for details.
   static OpOperand *offset_base(OpOperand *object, ptrdiff_t index) {

diff  --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt
index 9eccde56e6ee..44ec43c96233 100644
--- a/mlir/lib/Analysis/CMakeLists.txt
+++ b/mlir/lib/Analysis/CMakeLists.txt
@@ -2,6 +2,7 @@ set(LLVM_OPTIONAL_SOURCES
   AffineAnalysis.cpp
   AffineStructures.cpp
   CallGraph.cpp
+  ControlFlowInterfaces.cpp
   Dominance.cpp
   InferTypeOpInterface.cpp
   Liveness.cpp
@@ -14,6 +15,7 @@ set(LLVM_OPTIONAL_SOURCES
 
 add_llvm_library(MLIRAnalysis
   CallGraph.cpp
+  ControlFlowInterfaces.cpp
   InferTypeOpInterface.cpp
   Liveness.cpp
   SliceAnalysis.cpp
@@ -26,6 +28,7 @@ add_llvm_library(MLIRAnalysis
 add_dependencies(MLIRAnalysis
   MLIRAffineOps
   MLIRCallOpInterfacesIncGen
+  MLIRControlFlowInterfacesIncGen
   MLIRTypeInferOpInterfaceIncGen
   MLIRLoopOps
   )
@@ -45,6 +48,7 @@ add_llvm_library(MLIRLoopAnalysis
 add_dependencies(MLIRLoopAnalysis
   MLIRAffineOps
   MLIRCallOpInterfacesIncGen
+  MLIRControlFlowInterfacesIncGen
   MLIRTypeInferOpInterfaceIncGen
   MLIRLoopOps
   )

diff  --git a/mlir/lib/Analysis/ControlFlowInterfaces.cpp b/mlir/lib/Analysis/ControlFlowInterfaces.cpp
new file mode 100644
index 000000000000..7d98f29d7cf6
--- /dev/null
+++ b/mlir/lib/Analysis/ControlFlowInterfaces.cpp
@@ -0,0 +1,101 @@
+//===- ControlFlowInterfaces.cpp - ControlFlow Interfaces -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/ControlFlowInterfaces.h"
+#include "mlir/IR/StandardTypes.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// ControlFlowInterfaces
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/ControlFlowInterfaces.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// BranchOpInterface
+//===----------------------------------------------------------------------===//
+
+/// Erase an operand from a branch operation that is used as a successor
+/// operand. 'operandIndex' is the operand within 'operands' to be erased.
+void mlir::detail::eraseBranchSuccessorOperand(OperandRange operands,
+                                               unsigned operandIndex,
+                                               Operation *op) {
+  assert(operandIndex < operands.size() &&
+         "invalid index for successor operands");
+
+  // Erase the operand from the operation.
+  size_t fullOperandIndex = operands.getBeginOperandIndex() + operandIndex;
+  op->eraseOperand(fullOperandIndex);
+
+  // If this operation has an OperandSegmentSizeAttr, keep it up to date.
+  auto operandSegmentAttr =
+      op->getAttrOfType<DenseElementsAttr>("operand_segment_sizes");
+  if (!operandSegmentAttr)
+    return;
+
+  // Find the segment containing the full operand index and decrement it.
+  // TODO: This seems like a general utility that could be added somewhere.
+  SmallVector<int32_t, 4> values(operandSegmentAttr.getValues<int32_t>());
+  unsigned currentSize = 0;
+  for (unsigned i = 0, e = values.size(); i != e; ++i) {
+    currentSize += values[i];
+    if (fullOperandIndex < currentSize) {
+      --values[i];
+      break;
+    }
+  }
+  op->setAttr("operand_segment_sizes",
+              DenseIntElementsAttr::get(operandSegmentAttr.getType(), values));
+}
+
+/// Returns the `BlockArgument` corresponding to operand `operandIndex` in some
+/// successor if 'operandIndex' is within the range of 'operands', or None if
+/// `operandIndex` isn't a successor operand index.
+Optional<BlockArgument> mlir::detail::getBranchSuccessorArgument(
+    Optional<OperandRange> operands, unsigned operandIndex, Block *successor) {
+  // Check that the operands are valid.
+  if (!operands || operands->empty())
+    return llvm::None;
+
+  // Check to ensure that this operand is within the range.
+  unsigned operandsStart = operands->getBeginOperandIndex();
+  if (operandIndex < operandsStart ||
+      operandIndex >= (operandsStart + operands->size()))
+    return llvm::None;
+
+  // Index the successor.
+  unsigned argIndex = operandIndex - operandsStart;
+  return successor->getArgument(argIndex);
+}
+
+/// Verify that the given operands match those of the given successor block.
+LogicalResult
+mlir::detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
+                                            Optional<OperandRange> operands) {
+  if (!operands)
+    return success();
+
+  // Check the count.
+  unsigned operandCount = operands->size();
+  Block *destBB = op->getSuccessor(succNo);
+  if (operandCount != destBB->getNumArguments())
+    return op->emitError() << "branch has " << operandCount
+                           << " operands for successor #" << succNo
+                           << ", but target block has "
+                           << destBB->getNumArguments();
+
+  // Check the types.
+  auto operandIt = operands->begin();
+  for (unsigned i = 0; i != operandCount; ++i, ++operandIt) {
+    if ((*operandIt).getType() != destBB->getArgument(i).getType())
+      return op->emitError() << "type mismatch for bb argument #" << i
+                             << " of successor #" << succNo;
+  }
+  return success();
+}

diff  --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
index 8eafbd8a2179..97b70cae9366 100644
--- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
@@ -4,7 +4,7 @@ add_mlir_dialect_library(MLIRLLVMIR
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LLVMIR
   )
-add_dependencies(MLIRLLVMIR MLIRLLVMOpsIncGen MLIRLLVMConversionsIncGen MLIROpenMP LLVMFrontendOpenMP LLVMAsmParser LLVMCore LLVMSupport)
+add_dependencies(MLIRLLVMIR MLIRControlFlowInterfacesIncGen MLIRLLVMOpsIncGen MLIRLLVMConversionsIncGen MLIROpenMP LLVMFrontendOpenMP LLVMAsmParser LLVMCore LLVMSupport)
 target_link_libraries(MLIRLLVMIR LLVMAsmParser LLVMCore LLVMSupport LLVMFrontendOpenMP MLIROpenMP MLIRIR)
 
 add_mlir_dialect_library(MLIRNVVMIR

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 75369142991b..567ddee94d7d 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -153,6 +153,28 @@ static ParseResult parseAllocaOp(OpAsmParser &parser, OperationState &result) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// LLVM::BrOp
+//===----------------------------------------------------------------------===//
+
+Optional<OperandRange> BrOp::getSuccessorOperands(unsigned index) {
+  assert(index == 0 && "invalid successor index");
+  return getOperands();
+}
+
+bool BrOp::canEraseSuccessorOperand() { return true; }
+
+//===----------------------------------------------------------------------===//
+// LLVM::CondBrOp
+//===----------------------------------------------------------------------===//
+
+Optional<OperandRange> CondBrOp::getSuccessorOperands(unsigned index) {
+  assert(index < getNumSuccessors() && "invalid successor index");
+  return index == 0 ? trueDestOperands() : falseDestOperands();
+}
+
+bool CondBrOp::canEraseSuccessorOperand() { return true; }
+
 //===----------------------------------------------------------------------===//
 // Printing/parsing for LLVM::LoadOp.
 //===----------------------------------------------------------------------===//
@@ -229,9 +251,16 @@ static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
 }
 
 ///===----------------------------------------------------------------------===//
-/// Verifying/Printing/Parsing for LLVM::InvokeOp.
+/// LLVM::InvokeOp
 ///===----------------------------------------------------------------------===//
 
+Optional<OperandRange> InvokeOp::getSuccessorOperands(unsigned index) {
+  assert(index < getNumSuccessors() && "invalid successor index");
+  return index == 0 ? normalDestOperands() : unwindDestOperands();
+}
+
+bool InvokeOp::canEraseSuccessorOperand() { return true; }
+
 static LogicalResult verify(InvokeOp op) {
   if (op.getNumResults() > 1)
     return op.emitOpError("must have 0 or 1 result");

diff  --git a/mlir/lib/Dialect/SPIRV/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/CMakeLists.txt
index ad1bb4df2b5b..7e77f3e7866e 100644
--- a/mlir/lib/Dialect/SPIRV/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/CMakeLists.txt
@@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRSPIRV
   )
 
 add_dependencies(MLIRSPIRV
+  MLIRControlFlowInterfacesIncGen
   MLIRSPIRVAvailabilityIncGen
   MLIRSPIRVCanonicalizationIncGen
   MLIRSPIRVEnumAvailabilityIncGen

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index 6a638673d584..907f8f82f62f 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -942,10 +942,30 @@ static LogicalResult verify(spirv::BitcastOp bitcastOp) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// spv.BranchOp
+//===----------------------------------------------------------------------===//
+
+Optional<OperandRange> spirv::BranchOp::getSuccessorOperands(unsigned index) {
+  assert(index == 0 && "invalid successor index");
+  return getOperands();
+}
+
+bool spirv::BranchOp::canEraseSuccessorOperand() { return true; }
+
 //===----------------------------------------------------------------------===//
 // spv.BranchConditionalOp
 //===----------------------------------------------------------------------===//
 
+Optional<OperandRange>
+spirv::BranchConditionalOp::getSuccessorOperands(unsigned index) {
+  assert(index < 2 && "invalid successor index");
+  return index == kTrueIndex ? getTrueBlockArguments()
+                             : getFalseBlockArguments();
+}
+
+bool spirv::BranchConditionalOp::canEraseSuccessorOperand() { return true; }
+
 static ParseResult parseBranchConditionalOp(OpAsmParser &parser,
                                             OperationState &state) {
   auto &builder = parser.getBuilder();

diff  --git a/mlir/lib/Dialect/StandardOps/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/CMakeLists.txt
index c8af4702fbc1..9b8ffcdb1980 100644
--- a/mlir/lib/Dialect/StandardOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/StandardOps/CMakeLists.txt
@@ -9,6 +9,7 @@ add_mlir_dialect_library(MLIRStandardOps
 add_dependencies(MLIRStandardOps
 
   MLIRCallOpInterfacesIncGen
+  MLIRControlFlowInterfacesIncGen
   MLIREDSC
   MLIRIR
   MLIRStandardOpsIncGen

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 1059e66d1fc5..6cb1f21ccda5 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -482,7 +482,7 @@ Block *BranchOp::getDest() { return getSuccessor(); }
 void BranchOp::setDest(Block *block) { return setSuccessor(block); }
 
 void BranchOp::eraseOperand(unsigned index) {
-  getOperation()->eraseSuccessorOperand(0, index);
+  getOperation()->eraseOperand(index);
 }
 
 void BranchOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
@@ -490,6 +490,13 @@ void BranchOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
   results.insert<SimplifyBrToBlockWithSinglePred>(context);
 }
 
+Optional<OperandRange> BranchOp::getSuccessorOperands(unsigned index) {
+  assert(index == 0 && "invalid successor index");
+  return getOperands();
+}
+
+bool BranchOp::canEraseSuccessorOperand() { return true; }
+
 //===----------------------------------------------------------------------===//
 // CallOp
 //===----------------------------------------------------------------------===//
@@ -749,6 +756,13 @@ void CondBranchOp::getCanonicalizationPatterns(
   results.insert<SimplifyConstCondBranchPred>(context);
 }
 
+Optional<OperandRange> CondBranchOp::getSuccessorOperands(unsigned index) {
+  assert(index < getNumSuccessors() && "invalid successor index");
+  return index == trueIndex ? getTrueOperands() : getFalseOperands();
+}
+
+bool CondBranchOp::canEraseSuccessorOperand() { return true; }
+
 //===----------------------------------------------------------------------===//
 // Constant*Op
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index bfd4b40b317b..2af425d65406 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -950,37 +950,13 @@ LogicalResult OpTrait::impl::verifyIsTerminator(Operation *op) {
   return success();
 }
 
-static LogicalResult verifySuccessor(Operation *op, unsigned succNo) {
-  Operation::operand_range operands = op->getSuccessorOperands(succNo);
-  unsigned operandCount = op->getNumSuccessorOperands(succNo);
-  Block *destBB = op->getSuccessor(succNo);
-  if (operandCount != destBB->getNumArguments())
-    return op->emitError() << "branch has " << operandCount
-                           << " operands for successor #" << succNo
-                           << ", but target block has "
-                           << destBB->getNumArguments();
-
-  auto operandIt = operands.begin();
-  for (unsigned i = 0, e = operandCount; i != e; ++i, ++operandIt) {
-    if ((*operandIt).getType() != destBB->getArgument(i).getType())
-      return op->emitError() << "type mismatch for bb argument #" << i
-                             << " of successor #" << succNo;
-  }
-
-  return success();
-}
-
 static LogicalResult verifyTerminatorSuccessors(Operation *op) {
   auto *parent = op->getParentRegion();
 
   // Verify that the operands lines up with the BB arguments in the successor.
-  for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) {
-    auto *succ = op->getSuccessor(i);
+  for (Block *succ : op->getSuccessors())
     if (succ->getParent() != parent)
       return op->emitError("reference to block defined in another region");
-    if (failed(verifySuccessor(op, i)))
-      return failure();
-  }
   return success();
 }
 

diff  --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 107fc483b96d..25859a562cea 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -183,6 +183,13 @@ Type TypeRange::dereference_iterator(OwnerT object, ptrdiff_t index) {
 OperandRange::OperandRange(Operation *op)
     : OperandRange(op->getOpOperands().data(), op->getNumOperands()) {}
 
+/// Return the operand index of the first element of this range. The range
+/// must not be empty.
+unsigned OperandRange::getBeginOperandIndex() const {
+  assert(!empty() && "range must not be empty");
+  return base->getOperandNumber();
+}
+
 //===----------------------------------------------------------------------===//
 // ResultRange
 

diff  --git a/mlir/test/lib/TestDialect/CMakeLists.txt b/mlir/test/lib/TestDialect/CMakeLists.txt
index 15459b9abaa1..d81500912f4d 100644
--- a/mlir/test/lib/TestDialect/CMakeLists.txt
+++ b/mlir/test/lib/TestDialect/CMakeLists.txt
@@ -16,6 +16,7 @@ add_llvm_library(MLIRTestDialect
   TestPatterns.cpp
 )
 add_dependencies(MLIRTestDialect
+  MLIRControlFlowInterfacesIncGen
   MLIRTestOpsIncGen
   MLIRTypeInferOpInterfaceIncGen
 )

diff  --git a/mlir/test/lib/TestDialect/TestDialect.cpp b/mlir/test/lib/TestDialect/TestDialect.cpp
index 3ded7b95b9f9..649b547626d9 100644
--- a/mlir/test/lib/TestDialect/TestDialect.cpp
+++ b/mlir/test/lib/TestDialect/TestDialect.cpp
@@ -163,6 +163,17 @@ TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// TestBranchOp
+//===----------------------------------------------------------------------===//
+
+Optional<OperandRange> TestBranchOp::getSuccessorOperands(unsigned index) {
+  assert(index == 0 && "invalid successor index");
+  return getOperands();
+}
+
+bool TestBranchOp::canEraseSuccessorOperand() { return true; }
+
 //===----------------------------------------------------------------------===//
 // Test IsolatedRegionOp - parse passthrough region arguments.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/lib/TestDialect/TestDialect.h b/mlir/test/lib/TestDialect/TestDialect.h
index 300327630073..8228f31434f5 100644
--- a/mlir/test/lib/TestDialect/TestDialect.h
+++ b/mlir/test/lib/TestDialect/TestDialect.h
@@ -15,6 +15,7 @@
 #define MLIR_TESTDIALECT_H
 
 #include "mlir/Analysis/CallInterfaces.h"
+#include "mlir/Analysis/ControlFlowInterfaces.h"
 #include "mlir/Analysis/InferTypeOpInterface.h"
 #include "mlir/Dialect/Traits.h"
 #include "mlir/IR/Dialect.h"

diff  --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td
index 26205e7eb855..5ee4a46eb818 100644
--- a/mlir/test/lib/TestDialect/TestOps.td
+++ b/mlir/test/lib/TestDialect/TestOps.td
@@ -11,6 +11,7 @@
 
 include "mlir/IR/OpBase.td"
 include "mlir/IR/OpAsmInterface.td"
+include "mlir/Analysis/ControlFlowInterfaces.td"
 include "mlir/Analysis/CallInterfaces.td"
 include "mlir/Analysis/InferTypeOpInterface.td"
 
@@ -446,6 +447,11 @@ def OpWithShapedTypeInferTypeInterfaceOp : TEST_Op<"op_with_shaped_type_infer_ty
    ]> {
   let arguments = (ins AnyTensor, AnyTensor);
   let results = (outs AnyTensor);
+
+  let extraClassDeclaration = [{
+    LogicalResult reifyReturnTypeShapes(OpBuilder &builder,
+                                        SmallVectorImpl<Value> &shapes);
+  }];
 }
 
 def IsNotScalar : Constraint<CPred<"$0.getType().getRank() != 0">>;
@@ -454,7 +460,8 @@ def UpdateAttr : Pat<(I32ElementsAttrOp $attr),
                      (I32ElementsAttrOp ConstantAttr<I32ElementsAttr, "0">),
                      [(IsNotScalar $attr)]>;
 
-def TestBranchOp : TEST_Op<"br", [Terminator]> {
+def TestBranchOp : TEST_Op<"br",
+    [DeclareOpInterfaceMethods<BranchOpInterface>, Terminator]> {
   let successors = (successor AnySuccessor:$target);
 }
 

diff  --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index ebd82f9fb8bf..7aa51bdbf1cf 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -1155,8 +1155,8 @@ void OpEmitter::genOpInterfaceMethods() {
       continue;
     auto interface = opTrait->getOpInterface();
     for (auto method : interface.getMethods()) {
-      // Don't declare if the method has a body.
-      if (method.getBody())
+      // Don't declare if the method has a body or a default implementation.
+      if (method.getBody() || method.getDefaultImplementation())
         continue;
       std::string args;
       llvm::raw_string_ostream os(args);


        


More information about the Mlir-commits mailing list