[Mlir-commits] [mlir] b1de971 - [mlir][ODS] Add support for specifying the successors of an operation.

River Riddle llvmlistbot at llvm.org
Fri Feb 21 15:17:43 PST 2020


Author: River Riddle
Date: 2020-02-21T15:15:32-08:00
New Revision: b1de971ba8c83c82ef63077b666aaff3ba8e56b9

URL: https://github.com/llvm/llvm-project/commit/b1de971ba8c83c82ef63077b666aaff3ba8e56b9
DIFF: https://github.com/llvm/llvm-project/commit/b1de971ba8c83c82ef63077b666aaff3ba8e56b9.diff

LOG: [mlir][ODS] Add support for specifying the successors of an operation.

This revision adds support in ODS for specifying the successors of an operation. Successors are specified via the `successors` list:
```
let successors = (successor AnySuccessor:$target, AnySuccessor:$otherTarget);
```

Differential Revision: https://reviews.llvm.org/D74783

Added: 
    mlir/include/mlir/TableGen/Successor.h
    mlir/lib/TableGen/Successor.cpp

Modified: 
    mlir/docs/OpDefinitions.md
    mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/include/mlir/IR/OpBase.td
    mlir/include/mlir/TableGen/Constraint.h
    mlir/include/mlir/TableGen/Operator.h
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
    mlir/lib/TableGen/CMakeLists.txt
    mlir/lib/TableGen/Constraint.cpp
    mlir/lib/TableGen/Operator.cpp
    mlir/test/Dialect/SPIRV/control-flow-ops.mlir
    mlir/test/lib/TestDialect/TestOps.td
    mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
index fc648362c9fb..70d718e22578 100644
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -279,6 +279,24 @@ Similar to variadic operands, `Variadic<...>` can also be used for results.
 And similarly, `SameVariadicResultSize` for multiple variadic results in the
 same operation.
 
+### Operation successors
+
+For terminator operations, the successors are specified inside of the
+`dag`-typed `successors`, led by `successor`:
+
+```tablegen
+let successors = (successor
+  <successor-constraint>:$<successor-name>,
+  ...
+);
+```
+
+#### Variadic successors
+
+Similar to the `Variadic` class used for variadic operands and results,
+`VariadicSuccessor<...>` can be used for successors. Variadic successors can
+currently only be specified as the last successor in the successor list.
+
 ### Operation traits and constraints
 
 Traits are operation properties that affect syntax or semantics. MLIR C++

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index a58d9af0666a..840b4396134e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -31,6 +31,25 @@ def LLVMInt : TypeConstraint<
        CPred<"$_self.cast<::mlir::LLVM::LLVMType>().isIntegerTy()">]>,
   "LLVM dialect integer">;
 
+def LLVMIntBase : TypeConstraint<
+  And<[LLVM_Type.predicate,
+       CPred<"$_self.cast<::mlir::LLVM::LLVMType>().isIntegerTy()">]>,
+  "LLVM dialect integer">;
+
+// Integer type of a specific width.
+class LLVMI<int width>
+    : Type<And<[
+        LLVM_Type.predicate,
+        CPred<
+         "$_self.cast<::mlir::LLVM::LLVMType>().isIntegerTy(" # width # ")">]>,
+         "LLVM dialect " # width # "-bit integer">,
+      BuildableType<
+        "::mlir::LLVM::LLVMType::getIntNTy("
+        "$_builder.getContext()->getRegisteredDialect<LLVM::LLVMDialect>(),"
+         # width # ")">;
+
+def LLVMI1 : LLVMI<1>;
+
 // Base class for LLVM operations. Defines the interface to the llvm::IRBuilder
 // used to translate to LLVM IR proper.
 class LLVM_OpBase<Dialect dialect, string mnemonic, list<OpTrait> traits = []> :

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index d73fd1187431..e0848c2d6d77 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -72,8 +72,7 @@ class LLVM_ZeroResultOp<string mnemonic, list<OpTrait> traits = []> :
 // Base class for LLVM terminator operations.  All terminator operations have
 // zero results and an optional list of successors.
 class LLVM_TerminatorOp<string mnemonic, list<OpTrait> traits = []> :
-    LLVM_Op<mnemonic, !listconcat(traits, [Terminator])>,
-    Arguments<(ins Variadic<LLVM_Type>:$args)>, Results<(outs)> {
+    LLVM_Op<mnemonic, !listconcat(traits, [Terminator])> {
   let builders = [
     OpBuilder<
       "Builder *, OperationState &result, "
@@ -320,15 +319,10 @@ def LLVM_InvokeOp : LLVM_Op<"invoke", [Terminator]>,
                     Arguments<(ins OptionalAttr<FlatSymbolRefAttr>:$callee,
                                Variadic<LLVM_Type>)>,
                     Results<(outs Variadic<LLVM_Type>)> {
+  let successors = (successor AnySuccessor:$normalDest,
+                              AnySuccessor:$unwindDest);
+
   let builders = [OpBuilder<
-    "Builder *b, OperationState &result, ArrayRef<Type> tys, "
-    "FlatSymbolRefAttr callee, ValueRange ops, Block* normal, "
-    "ValueRange normalOps, Block* unwind, ValueRange unwindOps",
-    [{
-      result.addAttribute("callee", callee);
-      build(b, result, tys, ops, normal, normalOps, unwind, unwindOps);
-    }]>,
-    OpBuilder<
     "Builder *b, OperationState &result, ArrayRef<Type> tys, "
     "ValueRange ops, Block* normal, "
     "ValueRange normalOps, Block* unwind, ValueRange unwindOps",
@@ -460,19 +454,19 @@ def LLVM_SelectOp
 
 // Terminators.
 def LLVM_BrOp : LLVM_TerminatorOp<"br", []> {
+  let successors = (successor AnySuccessor:$dest);
   let parser = [{ return parseBrOp(parser, result); }];
   let printer = [{ printBrOp(p, *this); }];
 }
 def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br", []> {
-  let verifier = [{
-    if (getNumSuccessors() != 2)
-      return emitOpError("expected exactly two successors");
-    return success();
-  }];
+  let arguments = (ins LLVMI1:$condition);
+  let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest);
+
   let parser = [{ return parseCondBrOp(parser, result); }];
   let printer = [{ printCondBrOp(p, *this); }];
 }
-def LLVM_ReturnOp : LLVM_TerminatorOp<"return", []> {
+def LLVM_ReturnOp : LLVM_TerminatorOp<"return", []>,
+                    Arguments<(ins Variadic<LLVM_Type>:$args)> {
   string llvmBuilder = [{
     if ($_numOperands != 0)
       builder.CreateRet($args[0]);

diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
index ac95214b6184..433c1323cdee 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
@@ -41,12 +41,14 @@ def SPV_BranchOp : SPV_Op<"Branch", [InFunctionScope, Terminator]> {
     ```
   }];
 
-  let arguments = (ins
-    Variadic<AnyType>:$block_arguments
-  );
+  let arguments = (ins);
 
   let results = (outs);
 
+  let successors = (successor AnySuccessor:$target);
+
+  let verifier = [{ return success(); }];
+
   let builders = [
     OpBuilder<
       "Builder *, OperationState &state, "
@@ -60,12 +62,10 @@ def SPV_BranchOp : SPV_Op<"Branch", [InFunctionScope, Terminator]> {
 
   let extraClassDeclaration = [{
     /// Returns the branch target block.
-    Block *getTarget() { return getOperation()->getSuccessor(0); }
+    Block *getTarget() { return target(); }
 
     /// Returns the block arguments.
-    operand_range getBlockArguments() {
-      return getOperation()->getSuccessorOperands(0);
-    }
+    operand_range getBlockArguments() { return targetOperands(); }
   }];
 
   let autogenSerialization = 0;
@@ -115,12 +115,14 @@ def SPV_BranchConditionalOp : SPV_Op<"BranchConditional",
 
   let arguments = (ins
     SPV_Bool:$condition,
-    Variadic<AnyType>:$branch_arguments,
     OptionalAttr<I32ArrayAttr>:$branch_weights
   );
 
   let results = (outs);
 
+  let successors = (successor AnySuccessor:$trueTarget,
+                               AnySuccessor:$falseTarget);
+
   let builders = [
     OpBuilder<
       "Builder *builder, OperationState &state, Value condition, "

diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 1fc4330cefdd..abe92e2afb28 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -232,12 +232,10 @@ def BranchOp : Std_Op<"br", [Terminator]> {
       ^bb3(%3: tensor<*xf32>):
   }];
 
-  let arguments = (ins Variadic<AnyType>:$operands);
+  let successors = (successor AnySuccessor:$dest);
 
-  let builders = [OpBuilder<
-    "Builder *, OperationState &result, Block *dest,"
-    "ValueRange operands = {}", [{
-      result.addSuccessor(dest, operands);
+  let builders = [OpBuilder<"Builder *, OperationState &result, Block *dest", [{
+    result.addSuccessor(dest, llvm::None);
   }]>];
 
   // BranchOp is fully verified by traits.
@@ -513,16 +511,8 @@ def CondBranchOp : Std_Op<"cond_br", [Terminator]> {
          ...
   }];
 
-  let arguments = (ins I1:$condition, Variadic<AnyType>:$branchOperands);
-
-  let builders = [OpBuilder<
-    "Builder *, OperationState &result, Value condition,"
-    "Block *trueDest, ValueRange trueOperands,"
-    "Block *falseDest, ValueRange falseOperands", [{
-      result.addOperands(condition);
-      result.addSuccessor(trueDest, trueOperands);
-      result.addSuccessor(falseDest, falseOperands);
-  }]>];
+  let arguments = (ins I1:$condition);
+  let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest);
 
   // CondBranchOp is fully verified by traits.
   let verifier = ?;

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 3dba6a09c5a7..25c0238946a9 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -185,6 +185,10 @@ class AttrConstraint<Pred predicate, string description = ""> :
 class RegionConstraint<Pred predicate, string description = ""> :
     Constraint<predicate, description>;
 
+// Subclass for constraints on a successor.
+class SuccessorConstraint<Pred predicate, string description = ""> :
+    Constraint<predicate, description>;
+
 // How to use these constraint categories:
 //
 // * Use TypeConstraint to specify
@@ -1341,6 +1345,21 @@ class SizedRegion<int numBlocks> : Region<
   CPred<"$_self.getBlocks().size() == " # numBlocks>,
   "region with " # numBlocks # " blocks">;
 
+//===----------------------------------------------------------------------===//
+// Successor definitions
+//===----------------------------------------------------------------------===//
+
+class Successor<Pred condition, string descr = ""> :
+    SuccessorConstraint<condition, descr>;
+
+// Any successor.
+def AnySuccessor : Successor<?, "any successor">;
+
+// A variadic successor constraint. It expands to zero or more of the base
+// successor.
+class VariadicSuccessor<Successor successor>
+  : Successor<successor.predicate, successor.description>;
+
 //===----------------------------------------------------------------------===//
 // OpTrait definitions
 //===----------------------------------------------------------------------===//
@@ -1537,6 +1556,9 @@ def outs;
 // Marker used to identify the region list for an op.
 def region;
 
+// Marker used to identify the successor list for an op.
+def successor;
+
 // Class for defining a custom builder.
 //
 // TableGen generates several generic builders for each op by default (see
@@ -1587,6 +1609,9 @@ class Op<Dialect dialect, string mnemonic, list<OpTrait> props = []> {
   // The list of regions of the op. Default to 0 regions.
   dag regions = (region);
 
+  // The list of successors of the op. Default to 0 successors.
+  dag successors = (successor);
+
   // Attribute getters can be added to the op by adding an Attr member
   // with the name and type of the attribute. E.g., adding int attribute
   // with name "value" and type "i32":

diff  --git a/mlir/include/mlir/TableGen/Constraint.h b/mlir/include/mlir/TableGen/Constraint.h
index 105fc4075647..775b3545a034 100644
--- a/mlir/include/mlir/TableGen/Constraint.h
+++ b/mlir/include/mlir/TableGen/Constraint.h
@@ -48,7 +48,7 @@ class Constraint {
   StringRef getDescription() const;
 
   // Constraint kind
-  enum Kind { CK_Attr, CK_Region, CK_Type, CK_Uncategorized };
+  enum Kind { CK_Attr, CK_Region, CK_Successor, CK_Type, CK_Uncategorized };
 
   Kind getKind() const { return kind; }
 

diff  --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h
index d9a458ebdf33..e83b25231a87 100644
--- a/mlir/include/mlir/TableGen/Operator.h
+++ b/mlir/include/mlir/TableGen/Operator.h
@@ -19,6 +19,7 @@
 #include "mlir/TableGen/Dialect.h"
 #include "mlir/TableGen/OpTrait.h"
 #include "mlir/TableGen/Region.h"
+#include "mlir/TableGen/Successor.h"
 #include "mlir/TableGen/Type.h"
 #include "llvm/ADT/PointerUnion.h"
 #include "llvm/ADT/SmallVector.h"
@@ -138,6 +139,20 @@ class Operator {
   // Returns the `index`-th region.
   const NamedRegion &getRegion(unsigned index) const;
 
+  // Successors.
+  using const_successor_iterator = const NamedSuccessor *;
+  const_successor_iterator successor_begin() const;
+  const_successor_iterator successor_end() const;
+  llvm::iterator_range<const_successor_iterator> getSuccessors() const;
+
+  // Returns the number of successors.
+  unsigned getNumSuccessors() const;
+  // Returns the `index`-th successor.
+  const NamedSuccessor &getSuccessor(unsigned index) const;
+
+  // Returns the number of variadic successors in this operation.
+  unsigned getNumVariadicSuccessors() const;
+
   // Trait.
   using const_trait_iterator = const OpTrait *;
   const_trait_iterator trait_begin() const;
@@ -193,6 +208,9 @@ class Operator {
   // The results of the op.
   SmallVector<NamedTypeConstraint, 4> results;
 
+  // The successors of this op.
+  SmallVector<NamedSuccessor, 0> successors;
+
   // The traits of the op.
   SmallVector<OpTrait, 4> traits;
 

diff  --git a/mlir/include/mlir/TableGen/Successor.h b/mlir/include/mlir/TableGen/Successor.h
new file mode 100644
index 000000000000..0659a983f0d5
--- /dev/null
+++ b/mlir/include/mlir/TableGen/Successor.h
@@ -0,0 +1,44 @@
+//===- Successor.h - TableGen successor definitions -------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_SUCCESSOR_H_
+#define MLIR_TABLEGEN_SUCCESSOR_H_
+
+#include "mlir/Support/LLVM.h"
+#include "mlir/TableGen/Constraint.h"
+
+namespace mlir {
+namespace tblgen {
+
+// Wrapper class providing helper methods for accessing Successor defined in
+// TableGen.
+class Successor : public Constraint {
+public:
+  using Constraint::Constraint;
+
+  static bool classof(const Constraint *c) {
+    return c->getKind() == CK_Successor;
+  }
+
+  // Returns true if this successor is variadic.
+  bool isVariadic() const;
+};
+
+// A struct bundling a successor's constraint and its name.
+struct NamedSuccessor {
+  // Returns true if this successor is variadic.
+  bool isVariadic() const { return constraint.isVariadic(); }
+
+  StringRef name;
+  Successor constraint;
+};
+
+} // end namespace tblgen
+} // end namespace mlir
+
+#endif // MLIR_TABLEGEN_SUCCESSOR_H_

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index a63a593dfc1e..880c95c441a4 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -234,15 +234,14 @@ static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
 static LogicalResult verify(InvokeOp op) {
   if (op.getNumResults() > 1)
     return op.emitOpError("must have 0 or 1 result");
-  if (op.getNumSuccessors() != 2)
-    return op.emitOpError("must have normal and unwind destinations");
 
-  if (op.getSuccessor(1)->empty())
+  Block *unwindDest = op.unwindDest();
+  if (unwindDest->empty())
     return op.emitError(
         "must have at least one operation in unwind destination");
 
   // In unwind destination, first operation must be LandingpadOp
-  if (!isa<LandingpadOp>(op.getSuccessor(1)->front()))
+  if (!isa<LandingpadOp>(unwindDest->front()))
     return op.emitError("first operation in unwind destination should be a "
                         "llvm.landingpad operation");
 

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index 89c243bc09ca..0a7f93f58367 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -1036,14 +1036,6 @@ static void print(spirv::BranchOp branchOp, OpAsmPrinter &printer) {
   printer.printSuccessorAndUseList(branchOp.getOperation(), /*index=*/0);
 }
 
-static LogicalResult verify(spirv::BranchOp branchOp) {
-  auto *op = branchOp.getOperation();
-  if (op->getNumSuccessors() != 1)
-    branchOp.emitOpError("must have exactly one successor");
-
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // spv.BranchConditionalOp
 //===----------------------------------------------------------------------===//
@@ -1114,10 +1106,6 @@ static void print(spirv::BranchConditionalOp branchOp, OpAsmPrinter &printer) {
 }
 
 static LogicalResult verify(spirv::BranchConditionalOp branchOp) {
-  auto *op = branchOp.getOperation();
-  if (op->getNumSuccessors() != 2)
-    return branchOp.emitOpError("must have exactly two successors");
-
   if (auto weights = branchOp.branch_weights()) {
     if (weights->getValue().size() != 2) {
       return branchOp.emitOpError("must have exactly two branch weights");

diff  --git a/mlir/lib/TableGen/CMakeLists.txt b/mlir/lib/TableGen/CMakeLists.txt
index 08bf24029c44..6317c669192e 100644
--- a/mlir/lib/TableGen/CMakeLists.txt
+++ b/mlir/lib/TableGen/CMakeLists.txt
@@ -10,6 +10,7 @@ add_llvm_library(LLVMMLIRTableGen
   OpTrait.cpp
   Pattern.cpp
   Predicate.cpp
+  Successor.cpp
   Type.cpp
 
   ADDITIONAL_HEADER_DIRS

diff  --git a/mlir/lib/TableGen/Constraint.cpp b/mlir/lib/TableGen/Constraint.cpp
index 251d15a0c806..98bb7d63d06d 100644
--- a/mlir/lib/TableGen/Constraint.cpp
+++ b/mlir/lib/TableGen/Constraint.cpp
@@ -23,6 +23,8 @@ Constraint::Constraint(const llvm::Record *record)
     kind = CK_Attr;
   } else if (record->isSubClassOf("RegionConstraint")) {
     kind = CK_Region;
+  } else if (record->isSubClassOf("SuccessorConstraint")) {
+    kind = CK_Successor;
   } else {
     assert(record->isSubClassOf("Constraint"));
   }

diff  --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp
index 5e338b37a00f..2fd2997970b6 100644
--- a/mlir/lib/TableGen/Operator.cpp
+++ b/mlir/lib/TableGen/Operator.cpp
@@ -159,6 +159,31 @@ const tblgen::NamedRegion &tblgen::Operator::getRegion(unsigned index) const {
   return regions[index];
 }
 
+auto tblgen::Operator::successor_begin() const -> const_successor_iterator {
+  return successors.begin();
+}
+auto tblgen::Operator::successor_end() const -> const_successor_iterator {
+  return successors.end();
+}
+auto tblgen::Operator::getSuccessors() const
+    -> llvm::iterator_range<const_successor_iterator> {
+  return {successor_begin(), successor_end()};
+}
+
+unsigned tblgen::Operator::getNumSuccessors() const {
+  return successors.size();
+}
+
+const tblgen::NamedSuccessor &
+tblgen::Operator::getSuccessor(unsigned index) const {
+  return successors[index];
+}
+
+unsigned tblgen::Operator::getNumVariadicSuccessors() const {
+  return llvm::count_if(successors,
+                        [](const NamedSuccessor &c) { return c.isVariadic(); });
+}
+
 auto tblgen::Operator::trait_begin() const -> const_trait_iterator {
   return traits.begin();
 }
@@ -285,6 +310,29 @@ void tblgen::Operator::populateOpStructure() {
     results.push_back({name, TypeConstraint(resultDef)});
   }
 
+  // Handle successors
+  auto *successorsDag = def.getValueAsDag("successors");
+  auto *successorsOp = dyn_cast<DefInit>(successorsDag->getOperator());
+  if (!successorsOp || successorsOp->getDef()->getName() != "successor") {
+    PrintFatalError(def.getLoc(),
+                    "'successors' must have 'successor' directive");
+  }
+
+  for (unsigned i = 0, e = successorsDag->getNumArgs(); i < e; ++i) {
+    auto name = successorsDag->getArgNameStr(i);
+    auto *successorInit = dyn_cast<DefInit>(successorsDag->getArg(i));
+    if (!successorInit) {
+      PrintFatalError(def.getLoc(),
+                      Twine("undefined kind for successor #") + Twine(i));
+    }
+    Successor successor(successorInit->getDef());
+
+    // Only support variadic successors if it is the last one for now.
+    if (i != e - 1 && successor.isVariadic())
+      PrintFatalError(def.getLoc(), "only the last successor can be variadic");
+    successors.push_back({name, successor});
+  }
+
   // Create list of traits, skipping over duplicates: appending to lists in
   // tablegen is easy, making them unique less so, so dedupe here.
   if (auto traitList = def.getValueAsListInit("traits")) {

diff  --git a/mlir/lib/TableGen/Successor.cpp b/mlir/lib/TableGen/Successor.cpp
new file mode 100644
index 000000000000..80c6e9bb74d3
--- /dev/null
+++ b/mlir/lib/TableGen/Successor.cpp
@@ -0,0 +1,24 @@
+//===- Successor.cpp - Successor class ------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Successor wrapper to simplify using TableGen Record defining a MLIR
+// Successor.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/TableGen/Successor.h"
+#include "mlir/ADT/TypeSwitch.h"
+#include "llvm/TableGen/Record.h"
+
+using namespace mlir;
+using namespace mlir::tblgen;
+
+// Returns true if this successor is variadic.
+bool Successor::isVariadic() const {
+  return def->isSubClassOf("VariadicSuccessor");
+}

diff  --git a/mlir/test/Dialect/SPIRV/control-flow-ops.mlir b/mlir/test/Dialect/SPIRV/control-flow-ops.mlir
index 55ca9a45c6c5..4201411783c5 100644
--- a/mlir/test/Dialect/SPIRV/control-flow-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/control-flow-ops.mlir
@@ -32,7 +32,7 @@ func @missing_accessor() -> () {
 
 func @wrong_accessor_count() -> () {
   %true = spv.constant true
-  // expected-error @+1 {{must have exactly one successor}}
+  // expected-error @+1 {{incorrect number of successors: expected 1 but found 2}}
   "spv.Branch"()[^one, ^two] : () -> ()
 ^one:
   spv.Return
@@ -116,7 +116,7 @@ func @wrong_condition_type() -> () {
 
 func @wrong_accessor_count() -> () {
   %true = spv.constant true
-  // expected-error @+1 {{must have exactly two successors}}
+  // expected-error @+1 {{incorrect number of successors: expected 2 but found 1}}
   "spv.BranchConditional"(%true)[^one] : (i1) -> ()
 ^one:
   spv.Return

diff  --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td
index 756e9cd98428..a7acfaaeeada 100644
--- a/mlir/test/lib/TestDialect/TestOps.td
+++ b/mlir/test/lib/TestDialect/TestOps.td
@@ -431,7 +431,7 @@ def UpdateAttr : Pat<(I32ElementsAttrOp $attr),
                      [(IsNotScalar $attr)]>;
 
 def TestBranchOp : TEST_Op<"br", [Terminator]> {
-  let arguments = (ins Variadic<AnyType>:$operands);
+  let successors = (successor AnySuccessor:$target);
 }
 
 def AttrSizedOperandOp : TEST_Op<"attr_sized_operands",

diff  --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 2d6999275e9a..d8a4b91c026e 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -183,6 +183,9 @@ class OpEmitter {
   // Generates getters for named regions.
   void genNamedRegionGetters();
 
+  // Generates getters for named successors.
+  void genNamedSuccessorGetters();
+
   // Generates builder methods for the operation.
   void genBuilder();
 
@@ -266,6 +269,10 @@ class OpEmitter {
   // The generated code will be attached to `body`.
   void genRegionVerifier(OpMethodBody &body);
 
+  // Generates verify statements for successors in the operation.
+  // The generated code will be attached to `body`.
+  void genSuccessorVerifier(OpMethodBody &body);
+
   // Generates the traits used by the object.
   void genTraits();
 
@@ -302,6 +309,7 @@ OpEmitter::OpEmitter(const Operator &op)
   genNamedOperandGetters();
   genNamedResultGetters();
   genNamedRegionGetters();
+  genNamedSuccessorGetters();
   genAttrGetters();
   genAttrSetters();
   genBuilder();
@@ -579,6 +587,42 @@ void OpEmitter::genNamedRegionGetters() {
   }
 }
 
+void OpEmitter::genNamedSuccessorGetters() {
+  unsigned numSuccessors = op.getNumSuccessors();
+  for (unsigned i = 0; i < numSuccessors; ++i) {
+    const NamedSuccessor &successor = op.getSuccessor(i);
+    if (successor.name.empty())
+      continue;
+
+    // Generate the accessors for a variadic successor.
+    if (successor.isVariadic()) {
+      // Generate the getter.
+      auto &m = opClass.newMethod("SuccessorRange", successor.name);
+      m.body() << formatv(
+          "  return {std::next(this->getOperation()->successor_begin(), {0}), "
+          "this->getOperation()->successor_end()};",
+          i);
+      continue;
+    }
+
+    // Generate the block getter.
+    auto &m = opClass.newMethod("Block *", successor.name);
+    m.body() << formatv("  return this->getOperation()->getSuccessor({0});", i);
+
+    // Generate the all-operands getter.
+    auto &operandsMethod = opClass.newMethod(
+        "Operation::operand_range", (successor.name + "Operands").str());
+    operandsMethod.body() << formatv(
+        " return this->getOperation()->getSuccessorOperands({0});", i);
+
+    // Generate the individual-operand getter.
+    auto &operandMethod = opClass.newMethod(
+        "Value", (successor.name + "Operand").str(), "unsigned index");
+    operandMethod.body() << formatv(
+        " return this->getOperation()->getSuccessorOperand({0}, index);", i);
+  }
+}
+
 static bool canGenerateUnwrappedBuilder(Operator &op) {
   // If this op does not have native attributes at all, return directly to avoid
   // redefining builders.
@@ -869,8 +913,9 @@ void OpEmitter::genCollectiveParamBuilder() {
   // Generate builder that infers type too.
   // TODO(jpienaar): Subsume this with general checking if type can be infered
   // automatically.
-  // TODO(jpienaar): Expand to handle regions.
-  if (op.getTrait("InferTypeOpInterface::Trait") && op.getNumRegions() == 0)
+  // TODO(jpienaar): Expand to handle regions and successors.
+  if (op.getTrait("InferTypeOpInterface::Trait") && op.getNumRegions() == 0 &&
+      op.getNumSuccessors() == 0)
     genInferedTypeCollectiveParamBuilder();
 }
 
@@ -982,17 +1027,28 @@ void OpEmitter::buildParamList(std::string &paramList,
       ++numAttrs;
     }
   }
+
+  /// Insert parameters for the block and operands for each successor.
+  const char *variadicSuccCode =
+      ", ArrayRef<Block *> {0}, ArrayRef<ValueRange> {0}Operands";
+  const char *succCode = ", Block *{0}, ValueRange {0}Operands";
+  for (const NamedSuccessor &namedSuccessor : op.getSuccessors()) {
+    if (namedSuccessor.isVariadic())
+      paramList += llvm::formatv(variadicSuccCode, namedSuccessor.name).str();
+    else
+      paramList += llvm::formatv(succCode, namedSuccessor.name).str();
+  }
 }
 
 void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
                                                        bool isRawValueAttr) {
-  // Push all operands to the result
+  // Push all operands to the result.
   for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
     body << "  " << builderOpState << ".addOperands(" << getArgumentName(op, i)
          << ");\n";
   }
 
-  // Push all attributes to the result
+  // Push all attributes to the result.
   for (const auto &namedAttr : op.getAttributes()) {
     auto &attr = namedAttr.attr;
     if (!attr.isDerivedAttr()) {
@@ -1030,11 +1086,24 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
     }
   }
 
-  // Create the correct number of regions
+  // Create the correct number of regions.
   if (int numRegions = op.getNumRegions()) {
     for (int i = 0; i < numRegions; ++i)
       body << "  (void)" << builderOpState << ".addRegion();\n";
   }
+
+  // Push all successors to the result.
+  for (const NamedSuccessor &namedSuccessor : op.getSuccessors()) {
+    if (namedSuccessor.isVariadic()) {
+      body << formatv("  for (int i = 0, e = {1}.size(); i != e; ++i)\n"
+                      "    {0}.addSuccessor({1}[i], {1}Operands[i]);\n",
+                      builderOpState, namedSuccessor.name);
+      continue;
+    }
+
+    body << formatv("  {0}.addSuccessor({1}, {1}Operands);\n", builderOpState,
+                    namedSuccessor.name);
+  }
 }
 
 void OpEmitter::genCanonicalizerDecls() {
@@ -1228,6 +1297,7 @@ void OpEmitter::genVerifier() {
   }
 
   genRegionVerifier(body);
+  genSuccessorVerifier(body);
 
   if (hasCustomVerify) {
     FmtContext fctx;
@@ -1305,6 +1375,58 @@ void OpEmitter::genRegionVerifier(OpMethodBody &body) {
   }
 }
 
+void OpEmitter::genSuccessorVerifier(OpMethodBody &body) {
+  // Emits into `body` the successor-count check, followed by one verification
+  // loop for each successor whose constraint carries a predicate.
+  unsigned numSuccessors = op.getNumSuccessors();
+  const char *checkSuccessorSizeCode = R"(
+  if (this->getOperation()->getNumSuccessors() {0} {1}) {
+    return emitOpError("has incorrect number of successors: expected{2} {1}"
+                       " but found ")
+             << this->getOperation()->getNumSuccessors();
+  }
+  )";
+
+  // Verify this op has the correct number of successors.
+  unsigned numVariadicSuccessors = op.getNumVariadicSuccessors();
+  if (numVariadicSuccessors == 0) {
+    body << formatv(checkSuccessorSizeCode, "!=", numSuccessors, "");
+  } else if (numVariadicSuccessors != numSuccessors) {
+    body << formatv(checkSuccessorSizeCode, "<",
+                    numSuccessors - numVariadicSuccessors, " at least");
+  }
+
+  // If we have no successors, there is nothing more to do.
+  if (numSuccessors == 0)
+    return;
+
+  body << "{\n    unsigned index = 0; (void)index;\n";
+
+  for (unsigned i = 0; i < numSuccessors; ++i) {
+    const auto &successor = op.getSuccessor(i);
+    if (successor.constraint.getPredicate().isNull())
+      continue;
+    // Seed `index` with this successor's position; bump it per visited block.
+    body << formatv("    index = {0};\n    for (Block *successor : ", i);
+    body << formatv(successor.isVariadic() ? "{0}()"
+                                           : "ArrayRef<Block *>({0}())",
+                    successor.name);
+    body << ") {\n";
+    auto constraint = tgfmt(successor.constraint.getConditionTemplate(),
+                            &verifyCtx.withSelf("successor"))
+                          .str();
+
+    body << formatv(
+        "      (void)successor;\n"
+        "      if (!({0})) {\n        "
+        "return emitOpError(\"successor #\") << index << \"('{1}') failed to "
+        "verify constraint: {2}\";\n      }\n",
+        constraint, successor.name, successor.constraint.getDescription());
+    body << "      ++index;\n    }\n";
+  }
+  body << "  }\n";
+}
+
 void OpEmitter::genTraits() {
   int numResults = op.getNumResults();
   int numVariadicResults = op.getNumVariadicResults();
@@ -1342,7 +1464,9 @@ void OpEmitter::genTraits() {
   int numVariadicOperands = op.getNumVariadicOperands();
 
   // Add operand size trait.
-  if (numVariadicOperands != 0) {
+  // Note: Successor operands are also included in the operation's operand list,
+  // so we always need to use VariadicOperands in the presence of successors.
+  if (numVariadicOperands != 0 || op.getNumSuccessors()) {
     if (numOperands == numVariadicOperands)
       opClass.addTrait("OpTrait::VariadicOperands");
     else


        


More information about the Mlir-commits mailing list