[Mlir-commits] [mlir] 9eb436f - [mlir][DeclarativeParser] Add support for formatting the successors of an operation.

River Riddle llvmlistbot at llvm.org
Fri Feb 21 15:17:45 PST 2020


Author: River Riddle
Date: 2020-02-21T15:15:32-08:00
New Revision: 9eb436feaa7f5f01dc4852396647a5b46311c8eb

URL: https://github.com/llvm/llvm-project/commit/9eb436feaa7f5f01dc4852396647a5b46311c8eb
DIFF: https://github.com/llvm/llvm-project/commit/9eb436feaa7f5f01dc4852396647a5b46311c8eb.diff

LOG: [mlir][DeclarativeParser] Add support for formatting the successors of an operation.

This revision adds support for formatting successor variables in a similar way to operands, attributes, etc.

Differential Revision: https://reviews.llvm.org/D74789

Added: 
    

Modified: 
    mlir/docs/OpDefinitions.md
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/include/mlir/IR/OpImplementation.h
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/lib/Parser/Parser.cpp
    mlir/test/Dialect/SPIRV/control-flow-ops.mlir
    mlir/test/IR/invalid.mlir
    mlir/test/lib/TestDialect/TestOps.td
    mlir/test/mlir-tblgen/op-format-spec.td
    mlir/test/mlir-tblgen/op-format.mlir
    mlir/tools/mlir-tblgen/OpFormatGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
index 70d718e22578..30406b7e4e83 100644
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -625,6 +625,10 @@ The available directives are as follows:
 
     -   Represents all of the results of an operation.
 
+*   `successors`
+
+    -   Represents all of the successors of an operation.
+
 *   `type` ( input )
 
     -   Represents the type of the given input.
@@ -641,8 +645,8 @@ The following are the set of valid punctuation:
 #### Variables
 
 A variable is an entity that has been registered on the operation itself, i.e.
-an argument(attribute or operand), result, etc. In the `CallOp` example above,
-the variables would be `$callee`  and `$args`.
+an argument (attribute or operand), result, successor, etc. In the `CallOp`
+example above, the variables would be `$callee` and `$args`.
 
 Attribute variables are printed with their respective value type, unless that
 value type is buildable. In those cases, the type of the attribute is elided.

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index e0848c2d6d77..5be06a8bfe72 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -455,15 +455,12 @@ def LLVM_SelectOp
 // Terminators.
 def LLVM_BrOp : LLVM_TerminatorOp<"br", []> {
   let successors = (successor AnySuccessor:$dest);
-  let parser = [{ return parseBrOp(parser, result); }];
-  let printer = [{ printBrOp(p, *this); }];
+  let assemblyFormat = "$dest attr-dict";
 }
 def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br", []> {
   let arguments = (ins LLVMI1:$condition);
   let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest);
-
-  let parser = [{ return parseCondBrOp(parser, result); }];
-  let printer = [{ printCondBrOp(p, *this); }];
+  let assemblyFormat = "$condition `,` successors attr-dict";
 }
 def LLVM_ReturnOp : LLVM_TerminatorOp<"return", []>,
                     Arguments<(ins Variadic<LLVM_Type>:$args)> {

diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
index 433c1323cdee..03884afe3e95 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
@@ -69,6 +69,8 @@ def SPV_BranchOp : SPV_Op<"Branch", [InFunctionScope, Terminator]> {
   }];
 
   let autogenSerialization = 0;
+
+  let assemblyFormat = "successors attr-dict";
 }
 
 // -----

diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index abe92e2afb28..83c781a19b18 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -250,6 +250,7 @@ def BranchOp : Std_Op<"br", [Terminator]> {
   }];
 
   let hasCanonicalizer = 1;
+  let assemblyFormat = "$dest attr-dict";
 }
 
 def CallOp : Std_Op<"call", [CallOpInterface]> {
@@ -602,6 +603,7 @@ def CondBranchOp : Std_Op<"cond_br", [Terminator]> {
   }];
 
   let hasCanonicalizer = 1;
+  let assemblyFormat = "$condition `,` successors attr-dict";
 }
 
 def ConstantOp : Std_Op<"constant",

diff  --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 08fac5ea19ef..54ecf8cde39c 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -578,6 +578,11 @@ class OpAsmParser {
   virtual ParseResult
   parseSuccessorAndUseList(Block *&dest, SmallVectorImpl<Value> &operands) = 0;
 
+  /// Parse an optional operation successor and its operand list.
+  virtual OptionalParseResult
+  parseOptionalSuccessorAndUseList(Block *&dest,
+                                   SmallVectorImpl<Value> &operands) = 0;
+
   //===--------------------------------------------------------------------===//
   // Type Parsing
   //===--------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 880c95c441a4..c1773bbd8120 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -780,69 +780,6 @@ static ParseResult parseInsertValueOp(OpAsmParser &parser,
   return success();
 }
 
-//===----------------------------------------------------------------------===//
-// Printing/parsing for LLVM::BrOp.
-//===----------------------------------------------------------------------===//
-
-static void printBrOp(OpAsmPrinter &p, BrOp &op) {
-  p << op.getOperationName() << ' ';
-  p.printSuccessorAndUseList(op.getOperation(), 0);
-  p.printOptionalAttrDict(op.getAttrs());
-}
-
-// <operation> ::= `llvm.br` bb-id (`[` ssa-use-and-type-list `]`)?
-// attribute-dict?
-static ParseResult parseBrOp(OpAsmParser &parser, OperationState &result) {
-  Block *dest;
-  SmallVector<Value, 4> operands;
-  if (parser.parseSuccessorAndUseList(dest, operands) ||
-      parser.parseOptionalAttrDict(result.attributes))
-    return failure();
-
-  result.addSuccessor(dest, operands);
-  return success();
-}
-
-//===----------------------------------------------------------------------===//
-// Printing/parsing for LLVM::CondBrOp.
-//===----------------------------------------------------------------------===//
-
-static void printCondBrOp(OpAsmPrinter &p, CondBrOp &op) {
-  p << op.getOperationName() << ' ' << op.getOperand(0) << ", ";
-  p.printSuccessorAndUseList(op.getOperation(), 0);
-  p << ", ";
-  p.printSuccessorAndUseList(op.getOperation(), 1);
-  p.printOptionalAttrDict(op.getAttrs());
-}
-
-// <operation> ::= `llvm.cond_br` ssa-use `,`
-//                  bb-id (`[` ssa-use-and-type-list `]`)? `,`
-//                  bb-id (`[` ssa-use-and-type-list `]`)? attribute-dict?
-static ParseResult parseCondBrOp(OpAsmParser &parser, OperationState &result) {
-  Block *trueDest;
-  Block *falseDest;
-  SmallVector<Value, 4> trueOperands;
-  SmallVector<Value, 4> falseOperands;
-  OpAsmParser::OperandType condition;
-
-  Builder &builder = parser.getBuilder();
-  auto *llvmDialect =
-      builder.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
-  auto i1Type = LLVM::LLVMType::getInt1Ty(llvmDialect);
-
-  if (parser.parseOperand(condition) || parser.parseComma() ||
-      parser.parseSuccessorAndUseList(trueDest, trueOperands) ||
-      parser.parseComma() ||
-      parser.parseSuccessorAndUseList(falseDest, falseOperands) ||
-      parser.parseOptionalAttrDict(result.attributes) ||
-      parser.resolveOperand(condition, i1Type, result.operands))
-    return failure();
-
-  result.addSuccessor(trueDest, trueOperands);
-  result.addSuccessor(falseDest, falseOperands);
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // Printing/parsing for LLVM::ReturnOp.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index 0a7f93f58367..01197498a704 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -1018,24 +1018,6 @@ void spirv::BitcastOp::getCanonicalizationPatterns(
   results.insert<ConvertChainedBitcast>(context);
 }
 
-//===----------------------------------------------------------------------===//
-// spv.BranchOp
-//===----------------------------------------------------------------------===//
-
-static ParseResult parseBranchOp(OpAsmParser &parser, OperationState &state) {
-  Block *dest;
-  SmallVector<Value, 4> destOperands;
-  if (parser.parseSuccessorAndUseList(dest, destOperands))
-    return failure();
-  state.addSuccessor(dest, destOperands);
-  return success();
-}
-
-static void print(spirv::BranchOp branchOp, OpAsmPrinter &printer) {
-  printer << spirv::BranchOp::getOperationName() << ' ';
-  printer.printSuccessorAndUseList(branchOp.getOperation(), /*index=*/0);
-}
-
 //===----------------------------------------------------------------------===//
 // spv.BranchConditionalOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 80f85e02289b..5c5fcfc47c11 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -414,20 +414,6 @@ struct SimplifyBrToBlockWithSinglePred : public OpRewritePattern<BranchOp> {
 };
 } // end anonymous namespace.
 
-static ParseResult parseBranchOp(OpAsmParser &parser, OperationState &result) {
-  Block *dest;
-  SmallVector<Value, 4> destOperands;
-  if (parser.parseSuccessorAndUseList(dest, destOperands))
-    return failure();
-  result.addSuccessor(dest, destOperands);
-  return success();
-}
-
-static void print(OpAsmPrinter &p, BranchOp op) {
-  p << "br ";
-  p.printSuccessorAndUseList(op.getOperation(), 0);
-}
-
 Block *BranchOp::getDest() { return getSuccessor(0); }
 
 void BranchOp::setDest(Block *block) { return setSuccessor(block, 0); }
@@ -810,42 +796,6 @@ struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
 };
 } // end anonymous namespace.
 
-static ParseResult parseCondBranchOp(OpAsmParser &parser,
-                                     OperationState &result) {
-  SmallVector<Value, 4> destOperands;
-  Block *dest;
-  OpAsmParser::OperandType condInfo;
-
-  // Parse the condition.
-  Type int1Ty = parser.getBuilder().getI1Type();
-  if (parser.parseOperand(condInfo) || parser.parseComma() ||
-      parser.resolveOperand(condInfo, int1Ty, result.operands)) {
-    return parser.emitError(parser.getNameLoc(),
-                            "expected condition type was boolean (i1)");
-  }
-
-  // Parse the true successor.
-  if (parser.parseSuccessorAndUseList(dest, destOperands))
-    return failure();
-  result.addSuccessor(dest, destOperands);
-
-  // Parse the false successor.
-  destOperands.clear();
-  if (parser.parseComma() ||
-      parser.parseSuccessorAndUseList(dest, destOperands))
-    return failure();
-  result.addSuccessor(dest, destOperands);
-
-  return success();
-}
-
-static void print(OpAsmPrinter &p, CondBranchOp op) {
-  p << "cond_br " << op.getCondition() << ", ";
-  p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::trueIndex);
-  p << ", ";
-  p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::falseIndex);
-}
-
 void CondBranchOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
   results.insert<SimplifyConstCondBranchPred>(context);

diff  --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index a542b117b423..3929cdd10a97 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -4423,6 +4423,15 @@ class CustomOpAsmParser : public OpAsmParser {
     return parser.parseSuccessorAndUseList(dest, operands);
   }
 
+  /// Parse an optional operation successor and its operand list.
+  OptionalParseResult
+  parseOptionalSuccessorAndUseList(Block *&dest,
+                                   SmallVectorImpl<Value> &operands) override {
+    if (parser.getToken().isNot(Token::caret_identifier))
+      return llvm::None;
+    return parseSuccessorAndUseList(dest, operands);
+  }
+
   //===--------------------------------------------------------------------===//
   // Type Parsing
   //===--------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SPIRV/control-flow-ops.mlir b/mlir/test/Dialect/SPIRV/control-flow-ops.mlir
index 4201411783c5..cc7e09fcd1e1 100644
--- a/mlir/test/Dialect/SPIRV/control-flow-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/control-flow-ops.mlir
@@ -24,8 +24,8 @@ func @branch_argument() -> () {
 // -----
 
 func @missing_accessor() -> () {
+  // expected-error @+1 {{has incorrect number of successors: expected 1 but found 0}}
   spv.Branch
-  // expected-error @+1 {{expected block name}}
 }
 
 // -----

diff  --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir
index a8a72fd87f12..12e3bfdcdec5 100644
--- a/mlir/test/IR/invalid.mlir
+++ b/mlir/test/IR/invalid.mlir
@@ -402,7 +402,6 @@ func @condbr_notbool() {
 ^bb0:
   %a = "foo"() : () -> i32 // expected-note {{prior use here}}
   cond_br %a, ^bb0, ^bb0 // expected-error {{use of value '%a' expects different type than prior uses: 'i1' vs 'i32'}}
-// expected-error@-1 {{expected condition type was boolean (i1)}}
 }
 
 // -----

diff  --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td
index a7acfaaeeada..135dcdea6fd9 100644
--- a/mlir/test/lib/TestDialect/TestOps.td
+++ b/mlir/test/lib/TestDialect/TestOps.td
@@ -1139,4 +1139,9 @@ def FormatOperandEOp : FormatOperandBase<"format_operand_e_op", [{
   $buildable `,` $operand `:` type($buildable) `,` type($operand) attr-dict
 }]>;
 
+def FormatSuccessorAOp : TEST_Op<"format_successor_a_op", [Terminator]> {
+  let successors = (successor VariadicSuccessor<AnySuccessor>:$targets);
+  let assemblyFormat = "$targets attr-dict";
+}
+
 #endif // TEST_OPS

diff  --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td
index e4483cc2a638..ac5aa259a907 100644
--- a/mlir/test/mlir-tblgen/op-format-spec.td
+++ b/mlir/test/mlir-tblgen/op-format-spec.td
@@ -94,10 +94,18 @@ def DirectiveOperandsValid : TestFormat_Op<"operands_valid", [{
 // results
 
 // CHECK: error: 'results' directive can not be used as a top-level directive
-def DirectiveResultsInvalidA : TestFormat_Op<"operands_invalid_a", [{
+def DirectiveResultsInvalidA : TestFormat_Op<"results_invalid_a", [{
   results
 }]>;
 
+//===----------------------------------------------------------------------===//
+// successors
+
+// CHECK: error: 'successors' is only valid as a top-level directive
+def DirectiveSuccessorsInvalidA : TestFormat_Op<"successors_invalid_a", [{
+  type(successors)
+}]>;
+
 //===----------------------------------------------------------------------===//
 // type
 
@@ -235,7 +243,7 @@ def OptionalInvalidK : TestFormat_Op<"optional_invalid_k", [{
 // Variables
 //===----------------------------------------------------------------------===//
 
-// CHECK: error: expected variable to refer to a argument or result
+// CHECK: error: expected variable to refer to a argument, result, or successor
 def VariableInvalidA : TestFormat_Op<"variable_invalid_a", [{
   $unknown_arg attr-dict
 }]>;
@@ -255,6 +263,18 @@ def VariableInvalidD : TestFormat_Op<"variable_invalid_d", [{
 def VariableInvalidE : TestFormat_Op<"variable_invalid_e", [{
   $result attr-dict
 }]>, Results<(outs I64:$result)>;
+// CHECK: error: successor 'successor' is already bound
+def VariableInvalidF : TestFormat_Op<"variable_invalid_f", [{
+  $successor $successor attr-dict
+}]> {
+  let successors = (successor AnySuccessor:$successor);
+}
+// CHECK: error: successor 'successor' is already bound
+def VariableInvalidG : TestFormat_Op<"variable_invalid_g", [{
+  successors $successor attr-dict
+}]> {
+  let successors = (successor AnySuccessor:$successor);
+}
 
 //===----------------------------------------------------------------------===//
 // Coverage Checks

diff  --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir
index 42ddc201f6d5..6c53527d649f 100644
--- a/mlir/test/mlir-tblgen/op-format.mlir
+++ b/mlir/test/mlir-tblgen/op-format.mlir
@@ -41,3 +41,19 @@ test.format_operand_d_op %i64, %memref : memref<1xf64>
 
 // CHECK: test.format_operand_e_op %[[I64]], %[[MEMREF]] : i64, memref<1xf64>
 test.format_operand_e_op %i64, %memref : i64, memref<1xf64>
+
+"foo.successor_test_region"() ( {
+  ^bb0:
+    // CHECK: test.format_successor_a_op ^bb1 {attr}
+    test.format_successor_a_op ^bb1 {attr}
+
+  ^bb1:
+    // CHECK: test.format_successor_a_op ^bb1, ^bb2 {attr}
+    test.format_successor_a_op ^bb1, ^bb2 {attr}
+
+  ^bb2:
+    // CHECK: test.format_successor_a_op {attr}
+    test.format_successor_a_op {attr}
+
+}) { arg_names = ["i", "j", "k"] } : () -> ()
+

diff  --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 77b7b33615d6..3655bf795c8b 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -49,6 +49,7 @@ class Element {
     FunctionalTypeDirective,
     OperandsDirective,
     ResultsDirective,
+    SuccessorsDirective,
     TypeDirective,
 
     /// This element is a literal.
@@ -58,6 +59,7 @@ class Element {
     AttributeVariable,
     OperandVariable,
     ResultVariable,
+    SuccessorVariable,
 
     /// This element is an optional element.
     Optional,
@@ -105,6 +107,10 @@ using OperandVariable =
 /// This class represents a variable that refers to a result.
 using ResultVariable =
     VariableElement<NamedTypeConstraint, Element::Kind::ResultVariable>;
+
+/// This class represents a variable that refers to a successor.
+using SuccessorVariable =
+    VariableElement<NamedSuccessor, Element::Kind::SuccessorVariable>;
 } // end anonymous namespace
 
 //===----------------------------------------------------------------------===//
@@ -126,6 +132,11 @@ using OperandsDirective = DirectiveElement<Element::Kind::OperandsDirective>;
 /// all of the results of an operation.
 using ResultsDirective = DirectiveElement<Element::Kind::ResultsDirective>;
 
+/// This class represents the `successors` directive. This directive represents
+/// all of the successors of an operation.
+using SuccessorsDirective =
+    DirectiveElement<Element::Kind::SuccessorsDirective>;
+
 /// This class represents the `attr-dict` directive. This directive represents
 /// the attribute dictionary of the operation.
 class AttrDictDirective
@@ -294,6 +305,8 @@ struct OperationFormat {
   /// Generate the c++ to resolve the types of operands and results during
   /// parsing.
   void genParserTypeResolution(Operator &op, OpMethodBody &body);
+  /// Generate the c++ to resolve successors during parsing.
+  void genParserSuccessorResolution(Operator &op, OpMethodBody &body);
 
   /// Generate the operation printer from this format.
   void genPrinter(Operator &op, OpClass &opClass);
@@ -403,6 +416,51 @@ const char *const functionalTypeParserCode = R"(
   {1}Types = {0}__{1}_functionType.getResults();
 )";
 
+/// The code snippet used to generate a parser call for a successor list.
+///
+/// {0}: The name for the successor list.
+const char *successorListParserCode = R"(
+  SmallVector<std::pair<Block *, SmallVector<Value, 4>>, 2> {0}Successors;
+  {
+    Block *succ;
+    SmallVector<Value, 4> succOperands;
+    // Parse the first successor.
+    auto firstSucc = parser.parseOptionalSuccessorAndUseList(succ,
+                                                             succOperands);
+    if (firstSucc.hasValue()) {
+      if (failed(*firstSucc))
+        return failure();
+      {0}Successors.emplace_back(succ, succOperands);
+
+      // Parse any trailing successors.
+      while (succeeded(parser.parseOptionalComma())) {
+        succOperands.clear();
+        if (parser.parseSuccessorAndUseList(succ, succOperands))
+          return failure();
+        {0}Successors.emplace_back(succ, succOperands);
+      }
+    }
+  }
+)";
+
+/// The code snippet used to generate a parser call for a successor.
+///
+/// {0}: The name of the successor.
+const char *successorParserCode = R"(
+  Block *{0}Successor = nullptr;
+  SmallVector<Value, 4> {0}Operands;
+  if (parser.parseSuccessorAndUseList({0}Successor, {0}Operands))
+    return failure();
+)";
+
+/// The code snippet used to resolve a list of parsed successors.
+///
+/// {0}: The name of the successor list.
+const char *resolveSuccessorListParserCode = R"(
+  for (auto &succAndArgs : {0}Successors)
+    result.addSuccessor(succAndArgs.first, succAndArgs.second);
+)";
+
 /// Get the name used for the type list for the given type directive operand.
 /// 'isVariadic' is set to true if the operand has variadic types.
 static StringRef getTypeListName(Element *arg, bool &isVariadic) {
@@ -539,6 +597,10 @@ static void genElementParser(Element *element, OpMethodBody &body,
     bool isVariadic = operand->getVar()->isVariadic();
     body << formatv(isVariadic ? variadicOperandParserCode : operandParserCode,
                     operand->getVar()->name);
+  } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
+    bool isVariadic = successor->getVar()->isVariadic();
+    body << formatv(isVariadic ? successorListParserCode : successorParserCode,
+                    successor->getVar()->name);
 
     /// Directives.
   } else if (auto *attrDict = dyn_cast<AttrDictDirective>(element)) {
@@ -551,6 +613,8 @@ static void genElementParser(Element *element, OpMethodBody &body,
          << "  SmallVector<OpAsmParser::OperandType, 4> allOperands;\n"
          << "  if (parser.parseOperandList(allOperands))\n"
          << "    return failure();\n";
+  } else if (isa<SuccessorsDirective>(element)) {
+    body << llvm::formatv(successorListParserCode, "full");
   } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
     bool isVariadic = false;
     StringRef listName = getTypeListName(dir->getOperand(), isVariadic);
@@ -586,9 +650,10 @@ void OperationFormat::genParser(Operator &op, OpClass &opClass) {
   for (auto &element : elements)
     genElementParser(element.get(), body, attrTypeCtx);
 
-  // Generate the code to resolve the operand and result types now that they
-  // have been parsed.
+  // Generate the code to resolve the operand/result types and successors now
+  // that they have been parsed.
   genParserTypeResolution(op, body);
+  genParserSuccessorResolution(op, body);
   body << "  return success();\n";
 }
 
@@ -730,6 +795,28 @@ void OperationFormat::genParserTypeResolution(Operator &op,
   }
 }
 
+void OperationFormat::genParserSuccessorResolution(Operator &op,
+                                                   OpMethodBody &body) {
+  // Check for the case where all successors were parsed.
+  bool hasAllSuccessors = llvm::any_of(
+      elements, [](auto &elt) { return isa<SuccessorsDirective>(elt.get()); });
+  if (hasAllSuccessors) {
+    body << llvm::formatv(resolveSuccessorListParserCode, "full");
+    return;
+  }
+
+  // Otherwise, handle each successor individually.
+  for (const NamedSuccessor &successor : op.getSuccessors()) {
+    if (successor.isVariadic()) {
+      body << llvm::formatv(resolveSuccessorListParserCode, successor.name);
+      continue;
+    }
+
+    body << llvm::formatv("  result.addSuccessor({0}Successor, {0}Operands);\n",
+                          successor.name);
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // PrinterGen
 
@@ -790,8 +877,8 @@ static OpMethodBody &genTypeOperandPrinter(Element *arg, OpMethodBody &body) {
 
 /// Generate the code for printing the given element.
 static void genElementPrinter(Element *element, OpMethodBody &body,
-                              OperationFormat &fmt, bool &shouldEmitSpace,
-                              bool &lastWasPunctuation) {
+                              OperationFormat &fmt, Operator &op,
+                              bool &shouldEmitSpace, bool &lastWasPunctuation) {
   if (LiteralElement *literal = dyn_cast<LiteralElement>(element))
     return genLiteralPrinter(literal->getLiteral(), body, shouldEmitSpace,
                              lastWasPunctuation);
@@ -808,7 +895,7 @@ static void genElementPrinter(Element *element, OpMethodBody &body,
 
     // Emit each of the elements.
     for (Element &childElement : optional->getElements())
-      genElementPrinter(&childElement, body, fmt, shouldEmitSpace,
+      genElementPrinter(&childElement, body, fmt, op, shouldEmitSpace,
                         lastWasPunctuation);
     body << "  }\n";
     return;
@@ -847,8 +934,30 @@ static void genElementPrinter(Element *element, OpMethodBody &body,
       body << "  p.printAttribute(" << var->name << "Attr());\n";
   } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
     body << "  p << " << operand->getVar()->name << "();\n";
+  } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
+    const NamedSuccessor *var = successor->getVar();
+    if (var->isVariadic()) {
+      body << "  {\n"
+           << "    auto succRange = " << var->name << "();\n"
+           << "    auto opSuccBegin = getOperation()->successor_begin();\n"
+           << "    int i = succRange.begin() - opSuccBegin;\n"
+           << "    int e = i + succRange.size();\n"
+           << "    interleaveComma(llvm::seq<int>(i, e), p, [&](int i) {\n"
+           << "      p.printSuccessorAndUseList(*this, i);\n"
+           << "    });\n"
+           << "  }\n";
+      return;
+    }
+
+    unsigned index = successor->getVar() - op.successor_begin();
+    body << "  p.printSuccessorAndUseList(*this, " << index << ");\n";
   } else if (isa<OperandsDirective>(element)) {
     body << "  p << getOperation()->getOperands();\n";
+  } else if (isa<SuccessorsDirective>(element)) {
+    body << "  interleaveComma(llvm::seq<int>(0, "
+            "getOperation()->getNumSuccessors()), p, [&](int i) {"
+         << "    p.printSuccessorAndUseList(*this, i);"
+         << "  });\n";
   } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
     body << "  p << ";
     genTypeOperandPrinter(dir->getOperand(), body) << ";\n";
@@ -879,7 +988,7 @@ void OperationFormat::genPrinter(Operator &op, OpClass &opClass) {
   // punctuation.
   bool shouldEmitSpace = true, lastWasPunctuation = false;
   for (auto &element : elements)
-    genElementPrinter(element.get(), body, *this, shouldEmitSpace,
+    genElementPrinter(element.get(), body, *this, op, shouldEmitSpace,
                       lastWasPunctuation);
 }
 
@@ -911,6 +1020,7 @@ class Token {
     kw_functional_type,
     kw_operands,
     kw_results,
+    kw_successors,
     kw_type,
     keyword_end,
 
@@ -1094,6 +1204,7 @@ Token FormatLexer::lexIdentifier(const char *tokStart) {
           .Case("functional-type", Token::kw_functional_type)
           .Case("operands", Token::kw_operands)
           .Case("results", Token::kw_results)
+          .Case("successors", Token::kw_successors)
           .Case("type", Token::kw_type)
           .Default(Token::identifier);
   return Token(kind, str);
@@ -1173,6 +1284,8 @@ class FormatParser {
                                        llvm::SMLoc loc, bool isTopLevel);
   LogicalResult parseResultsDirective(std::unique_ptr<Element> &element,
                                       llvm::SMLoc loc, bool isTopLevel);
+  LogicalResult parseSuccessorsDirective(std::unique_ptr<Element> &element,
+                                         llvm::SMLoc loc, bool isTopLevel);
   LogicalResult parseTypeDirective(std::unique_ptr<Element> &element, Token tok,
                                    bool isTopLevel);
   LogicalResult parseTypeDirectiveOperand(std::unique_ptr<Element> &element);
@@ -1211,9 +1324,11 @@ class FormatParser {
   // The following are various bits of format state used for verification
   // during parsing.
   bool hasAllOperands = false, hasAttrDict = false;
+  bool hasAllSuccessors = false;
   llvm::SmallBitVector seenOperandTypes, seenResultTypes;
   llvm::DenseSet<const NamedTypeConstraint *> seenOperands;
   llvm::DenseSet<const NamedAttribute *> seenAttrs;
+  llvm::DenseSet<const NamedSuccessor *> seenSuccessors;
   llvm::DenseSet<const NamedTypeConstraint *> optionalVariables;
 };
 } // end anonymous namespace
@@ -1313,6 +1428,17 @@ LogicalResult FormatParser::parse() {
     auto it = buildableTypes.insert({*builder, buildableTypes.size()});
     fmt.operandTypes[i].setBuilderIdx(it.first->second);
   }
+
+  // Check that all of the successors are within the format.
+  if (!hasAllSuccessors) {
+    for (unsigned i = 0, e = op.getNumSuccessors(); i != e; ++i) {
+      const NamedSuccessor &successor = op.getSuccessor(i);
+      if (!seenSuccessors.count(&successor)) {
+        return emitError(loc, "format missing instance of successor #" +
+                                  Twine(i) + "('" + successor.name + "')");
+      }
+    }
+  }
   return success();
 }
 
@@ -1417,7 +1543,17 @@ LogicalResult FormatParser::parseVariable(std::unique_ptr<Element> &element,
     element = std::make_unique<ResultVariable>(result);
     return success();
   }
-  return emitError(loc, "expected variable to refer to a argument or result");
+  /// Successors.
+  if (const auto *successor = findArg(op.getSuccessors(), name)) {
+    if (!isTopLevel)
+      return emitError(loc, "successors can only be used at the top level");
+    if (hasAllSuccessors || !seenSuccessors.insert(successor).second)
+      return emitError(loc, "successor '" + name + "' is already bound");
+    element = std::make_unique<SuccessorVariable>(successor);
+    return success();
+  }
+  return emitError(
+      loc, "expected variable to refer to a argument, result, or successor");
 }
 
 LogicalResult FormatParser::parseDirective(std::unique_ptr<Element> &element,
@@ -1438,6 +1574,8 @@ LogicalResult FormatParser::parseDirective(std::unique_ptr<Element> &element,
     return parseOperandsDirective(element, dirTok.getLoc(), isTopLevel);
   case Token::kw_results:
     return parseResultsDirective(element, dirTok.getLoc(), isTopLevel);
+  case Token::kw_successors:
+    return parseSuccessorsDirective(element, dirTok.getLoc(), isTopLevel);
   case Token::kw_type:
     return parseTypeDirective(element, dirTok, isTopLevel);
 
@@ -1624,6 +1762,19 @@ FormatParser::parseResultsDirective(std::unique_ptr<Element> &element,
   return success();
 }
 
+LogicalResult
+FormatParser::parseSuccessorsDirective(std::unique_ptr<Element> &element,
+                                       llvm::SMLoc loc, bool isTopLevel) {
+  if (!isTopLevel)
+    return emitError(loc,
+                     "'successors' is only valid as a top-level directive");
+  if (hasAllSuccessors || !seenSuccessors.empty())
+    return emitError(loc, "'successors' directive creates overlap in format");
+  hasAllSuccessors = true;
+  element = std::make_unique<SuccessorsDirective>();
+  return success();
+}
+
 LogicalResult
 FormatParser::parseTypeDirective(std::unique_ptr<Element> &element, Token tok,
                                  bool isTopLevel) {


        


More information about the Mlir-commits mailing list