[Mlir-commits] [mlir] 9eb436f - [mlir][DeclarativeParser] Add support for formatting the successors of an operation.
River Riddle
llvmlistbot at llvm.org
Fri Feb 21 15:17:45 PST 2020
Author: River Riddle
Date: 2020-02-21T15:15:32-08:00
New Revision: 9eb436feaa7f5f01dc4852396647a5b46311c8eb
URL: https://github.com/llvm/llvm-project/commit/9eb436feaa7f5f01dc4852396647a5b46311c8eb
DIFF: https://github.com/llvm/llvm-project/commit/9eb436feaa7f5f01dc4852396647a5b46311c8eb.diff
LOG: [mlir][DeclarativeParser] Add support for formatting the successors of an operation.
This revision adds support for formatting successor variables in a similar way to operands, attributes, etc.
Differential Revision: https://reviews.llvm.org/D74789
Added:
Modified:
mlir/docs/OpDefinitions.md
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/include/mlir/IR/OpImplementation.h
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/lib/Parser/Parser.cpp
mlir/test/Dialect/SPIRV/control-flow-ops.mlir
mlir/test/IR/invalid.mlir
mlir/test/lib/TestDialect/TestOps.td
mlir/test/mlir-tblgen/op-format-spec.td
mlir/test/mlir-tblgen/op-format.mlir
mlir/tools/mlir-tblgen/OpFormatGen.cpp
Removed:
################################################################################
diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
index 70d718e22578..30406b7e4e83 100644
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -625,6 +625,10 @@ The available directives are as follows:
- Represents all of the results of an operation.
+* `successors`
+
+ - Represents all of the successors of an operation.
+
* `type` ( input )
- Represents the type of the given input.
@@ -641,8 +645,8 @@ The following are the set of valid punctuation:
#### Variables
A variable is an entity that has been registered on the operation itself, i.e.
-an argument(attribute or operand), result, etc. In the `CallOp` example above,
-the variables would be `$callee` and `$args`.
+an argument (attribute or operand), result, successor, etc. In the `CallOp`
+example above, the variables would be `$callee` and `$args`.
Attribute variables are printed with their respective value type, unless that
value type is buildable. In those cases, the type of the attribute is elided.
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index e0848c2d6d77..5be06a8bfe72 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -455,15 +455,12 @@ def LLVM_SelectOp
// Terminators.
def LLVM_BrOp : LLVM_TerminatorOp<"br", []> {
let successors = (successor AnySuccessor:$dest);
- let parser = [{ return parseBrOp(parser, result); }];
- let printer = [{ printBrOp(p, *this); }];
+ let assemblyFormat = "$dest attr-dict";
}
def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br", []> {
let arguments = (ins LLVMI1:$condition);
let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest);
-
- let parser = [{ return parseCondBrOp(parser, result); }];
- let printer = [{ printCondBrOp(p, *this); }];
+ let assemblyFormat = "$condition `,` successors attr-dict";
}
def LLVM_ReturnOp : LLVM_TerminatorOp<"return", []>,
Arguments<(ins Variadic<LLVM_Type>:$args)> {
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
index 433c1323cdee..03884afe3e95 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
@@ -69,6 +69,8 @@ def SPV_BranchOp : SPV_Op<"Branch", [InFunctionScope, Terminator]> {
}];
let autogenSerialization = 0;
+
+ let assemblyFormat = "successors attr-dict";
}
// -----
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index abe92e2afb28..83c781a19b18 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -250,6 +250,7 @@ def BranchOp : Std_Op<"br", [Terminator]> {
}];
let hasCanonicalizer = 1;
+ let assemblyFormat = "$dest attr-dict";
}
def CallOp : Std_Op<"call", [CallOpInterface]> {
@@ -602,6 +603,7 @@ def CondBranchOp : Std_Op<"cond_br", [Terminator]> {
}];
let hasCanonicalizer = 1;
+ let assemblyFormat = "$condition `,` successors attr-dict";
}
def ConstantOp : Std_Op<"constant",
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 08fac5ea19ef..54ecf8cde39c 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -578,6 +578,11 @@ class OpAsmParser {
virtual ParseResult
parseSuccessorAndUseList(Block *&dest, SmallVectorImpl<Value> &operands) = 0;
+ /// Parse an optional operation successor and its operand list.
+ /// Returns llvm::None when the upcoming token does not begin a successor;
+ /// otherwise returns the success/failure of parsing the successor into
+ /// `dest` and its operands into `operands`. (The CustomOpAsmParser
+ /// implementation treats a caret identifier, e.g. `^bb1`, as the start.)
+ virtual OptionalParseResult
+ parseOptionalSuccessorAndUseList(Block *&dest,
+ SmallVectorImpl<Value> &operands) = 0;
+
//===--------------------------------------------------------------------===//
// Type Parsing
//===--------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 880c95c441a4..c1773bbd8120 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -780,69 +780,6 @@ static ParseResult parseInsertValueOp(OpAsmParser &parser,
return success();
}
-//===----------------------------------------------------------------------===//
-// Printing/parsing for LLVM::BrOp.
-//===----------------------------------------------------------------------===//
-
-static void printBrOp(OpAsmPrinter &p, BrOp &op) {
- p << op.getOperationName() << ' ';
- p.printSuccessorAndUseList(op.getOperation(), 0);
- p.printOptionalAttrDict(op.getAttrs());
-}
-
-// <operation> ::= `llvm.br` bb-id (`[` ssa-use-and-type-list `]`)?
-// attribute-dict?
-static ParseResult parseBrOp(OpAsmParser &parser, OperationState &result) {
- Block *dest;
- SmallVector<Value, 4> operands;
- if (parser.parseSuccessorAndUseList(dest, operands) ||
- parser.parseOptionalAttrDict(result.attributes))
- return failure();
-
- result.addSuccessor(dest, operands);
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// Printing/parsing for LLVM::CondBrOp.
-//===----------------------------------------------------------------------===//
-
-static void printCondBrOp(OpAsmPrinter &p, CondBrOp &op) {
- p << op.getOperationName() << ' ' << op.getOperand(0) << ", ";
- p.printSuccessorAndUseList(op.getOperation(), 0);
- p << ", ";
- p.printSuccessorAndUseList(op.getOperation(), 1);
- p.printOptionalAttrDict(op.getAttrs());
-}
-
-// <operation> ::= `llvm.cond_br` ssa-use `,`
-// bb-id (`[` ssa-use-and-type-list `]`)? `,`
-// bb-id (`[` ssa-use-and-type-list `]`)? attribute-dict?
-static ParseResult parseCondBrOp(OpAsmParser &parser, OperationState &result) {
- Block *trueDest;
- Block *falseDest;
- SmallVector<Value, 4> trueOperands;
- SmallVector<Value, 4> falseOperands;
- OpAsmParser::OperandType condition;
-
- Builder &builder = parser.getBuilder();
- auto *llvmDialect =
- builder.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
- auto i1Type = LLVM::LLVMType::getInt1Ty(llvmDialect);
-
- if (parser.parseOperand(condition) || parser.parseComma() ||
- parser.parseSuccessorAndUseList(trueDest, trueOperands) ||
- parser.parseComma() ||
- parser.parseSuccessorAndUseList(falseDest, falseOperands) ||
- parser.parseOptionalAttrDict(result.attributes) ||
- parser.resolveOperand(condition, i1Type, result.operands))
- return failure();
-
- result.addSuccessor(trueDest, trueOperands);
- result.addSuccessor(falseDest, falseOperands);
- return success();
-}
-
//===----------------------------------------------------------------------===//
// Printing/parsing for LLVM::ReturnOp.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index 0a7f93f58367..01197498a704 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -1018,24 +1018,6 @@ void spirv::BitcastOp::getCanonicalizationPatterns(
results.insert<ConvertChainedBitcast>(context);
}
-//===----------------------------------------------------------------------===//
-// spv.BranchOp
-//===----------------------------------------------------------------------===//
-
-static ParseResult parseBranchOp(OpAsmParser &parser, OperationState &state) {
- Block *dest;
- SmallVector<Value, 4> destOperands;
- if (parser.parseSuccessorAndUseList(dest, destOperands))
- return failure();
- state.addSuccessor(dest, destOperands);
- return success();
-}
-
-static void print(spirv::BranchOp branchOp, OpAsmPrinter &printer) {
- printer << spirv::BranchOp::getOperationName() << ' ';
- printer.printSuccessorAndUseList(branchOp.getOperation(), /*index=*/0);
-}
-
//===----------------------------------------------------------------------===//
// spv.BranchConditionalOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 80f85e02289b..5c5fcfc47c11 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -414,20 +414,6 @@ struct SimplifyBrToBlockWithSinglePred : public OpRewritePattern<BranchOp> {
};
} // end anonymous namespace.
-static ParseResult parseBranchOp(OpAsmParser &parser, OperationState &result) {
- Block *dest;
- SmallVector<Value, 4> destOperands;
- if (parser.parseSuccessorAndUseList(dest, destOperands))
- return failure();
- result.addSuccessor(dest, destOperands);
- return success();
-}
-
-static void print(OpAsmPrinter &p, BranchOp op) {
- p << "br ";
- p.printSuccessorAndUseList(op.getOperation(), 0);
-}
-
Block *BranchOp::getDest() { return getSuccessor(0); }
void BranchOp::setDest(Block *block) { return setSuccessor(block, 0); }
@@ -810,42 +796,6 @@ struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
};
} // end anonymous namespace.
-static ParseResult parseCondBranchOp(OpAsmParser &parser,
- OperationState &result) {
- SmallVector<Value, 4> destOperands;
- Block *dest;
- OpAsmParser::OperandType condInfo;
-
- // Parse the condition.
- Type int1Ty = parser.getBuilder().getI1Type();
- if (parser.parseOperand(condInfo) || parser.parseComma() ||
- parser.resolveOperand(condInfo, int1Ty, result.operands)) {
- return parser.emitError(parser.getNameLoc(),
- "expected condition type was boolean (i1)");
- }
-
- // Parse the true successor.
- if (parser.parseSuccessorAndUseList(dest, destOperands))
- return failure();
- result.addSuccessor(dest, destOperands);
-
- // Parse the false successor.
- destOperands.clear();
- if (parser.parseComma() ||
- parser.parseSuccessorAndUseList(dest, destOperands))
- return failure();
- result.addSuccessor(dest, destOperands);
-
- return success();
-}
-
-static void print(OpAsmPrinter &p, CondBranchOp op) {
- p << "cond_br " << op.getCondition() << ", ";
- p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::trueIndex);
- p << ", ";
- p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::falseIndex);
-}
-
void CondBranchOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<SimplifyConstCondBranchPred>(context);
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index a542b117b423..3929cdd10a97 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -4423,6 +4423,15 @@ class CustomOpAsmParser : public OpAsmParser {
return parser.parseSuccessorAndUseList(dest, operands);
}
+  /// Parse an optional operation successor and its operand list. Nothing is
+  /// consumed and llvm::None is returned unless the current token is a block
+  /// name (a caret identifier such as `^bb1`).
+  OptionalParseResult
+  parseOptionalSuccessorAndUseList(Block *&dest,
+                                   SmallVectorImpl<Value> &operands) override {
+    const bool startsSuccessor = parser.getToken().is(Token::caret_identifier);
+    if (startsSuccessor)
+      return parseSuccessorAndUseList(dest, operands);
+    return llvm::None;
+  }
+
//===--------------------------------------------------------------------===//
// Type Parsing
//===--------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/control-flow-ops.mlir b/mlir/test/Dialect/SPIRV/control-flow-ops.mlir
index 4201411783c5..cc7e09fcd1e1 100644
--- a/mlir/test/Dialect/SPIRV/control-flow-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/control-flow-ops.mlir
@@ -24,8 +24,8 @@ func @branch_argument() -> () {
// -----
func @missing_accessor() -> () {
+ // expected-error @+1 {{has incorrect number of successors: expected 1 but found 0}}
spv.Branch
- // expected-error @+1 {{expected block name}}
}
// -----
diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir
index a8a72fd87f12..12e3bfdcdec5 100644
--- a/mlir/test/IR/invalid.mlir
+++ b/mlir/test/IR/invalid.mlir
@@ -402,7 +402,6 @@ func @condbr_notbool() {
^bb0:
%a = "foo"() : () -> i32 // expected-note {{prior use here}}
cond_br %a, ^bb0, ^bb0 // expected-error {{use of value '%a' expects different type than prior uses: 'i1' vs 'i32'}}
-// expected-error @-1 {{expected condition type was boolean (i1)}}
}
// -----
diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td
index a7acfaaeeada..135dcdea6fd9 100644
--- a/mlir/test/lib/TestDialect/TestOps.td
+++ b/mlir/test/lib/TestDialect/TestOps.td
@@ -1139,4 +1139,9 @@ def FormatOperandEOp : FormatOperandBase<"format_operand_e_op", [{
$buildable `,` $operand `:` type($buildable) `,` type($operand) attr-dict
}]>;
+// Terminator whose variadic successor list is bound to the `$targets` format
+// variable; round-tripped by test/mlir-tblgen/op-format.mlir with zero, one,
+// and two successors.
+def FormatSuccessorAOp : TEST_Op<"format_successor_a_op", [Terminator]> {
+ let successors = (successor VariadicSuccessor<AnySuccessor>:$targets);
+ let assemblyFormat = "$targets attr-dict";
+}
+
#endif // TEST_OPS
diff --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td
index e4483cc2a638..ac5aa259a907 100644
--- a/mlir/test/mlir-tblgen/op-format-spec.td
+++ b/mlir/test/mlir-tblgen/op-format-spec.td
@@ -94,10 +94,18 @@ def DirectiveOperandsValid : TestFormat_Op<"operands_valid", [{
// results
// CHECK: error: 'results' directive can not be used as a top-level directive
-def DirectiveResultsInvalidA : TestFormat_Op<"operands_invalid_a", [{
+def DirectiveResultsInvalidA : TestFormat_Op<"results_invalid_a", [{
results
}]>;
+//===----------------------------------------------------------------------===//
+// successors
+
+// CHECK: error: 'successors' is only valid as a top-level directive
+// The `successors` directive nested inside another directive must be rejected.
+def DirectiveSuccessorsInvalidA : TestFormat_Op<"successors_invalid_a", [{
+ type(successors)
+}]>;
+
//===----------------------------------------------------------------------===//
// type
@@ -235,7 +243,7 @@ def OptionalInvalidK : TestFormat_Op<"optional_invalid_k", [{
// Variables
//===----------------------------------------------------------------------===//
-// CHECK: error: expected variable to refer to a argument or result
+// CHECK: error: expected variable to refer to a argument, result, or successor
def VariableInvalidA : TestFormat_Op<"variable_invalid_a", [{
$unknown_arg attr-dict
}]>;
@@ -255,6 +263,18 @@ def VariableInvalidD : TestFormat_Op<"variable_invalid_d", [{
def VariableInvalidE : TestFormat_Op<"variable_invalid_e", [{
$result attr-dict
}]>, Results<(outs I64:$result)>;
+// CHECK: error: successor 'successor' is already bound
+// Binding the same named successor twice must be rejected.
+def VariableInvalidF : TestFormat_Op<"variable_invalid_f", [{
+ $successor $successor attr-dict
+}]> {
+ let successors = (successor AnySuccessor:$successor);
+}
+// CHECK: error: successor 'successor' is already bound
+// Naming a successor after `successors` already bound all of them must be
+// rejected.
+def VariableInvalidG : TestFormat_Op<"variable_invalid_g", [{
+ successors $successor attr-dict
+}]> {
+ let successors = (successor AnySuccessor:$successor);
+}
//===----------------------------------------------------------------------===//
// Coverage Checks
diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir
index 42ddc201f6d5..6c53527d649f 100644
--- a/mlir/test/mlir-tblgen/op-format.mlir
+++ b/mlir/test/mlir-tblgen/op-format.mlir
@@ -41,3 +41,19 @@ test.format_operand_d_op %i64, %memref : memref<1xf64>
// CHECK: test.format_operand_e_op %[[I64]], %[[MEMREF]] : i64, memref<1xf64>
test.format_operand_e_op %i64, %memref : i64, memref<1xf64>
+
+// Round-trip test.format_successor_a_op with one, two, and zero successors.
+// The ops are placed in a region so that the referenced ^bb blocks exist.
+"foo.successor_test_region"() ( {
+ ^bb0:
+ // CHECK: test.format_successor_a_op ^bb1 {attr}
+ test.format_successor_a_op ^bb1 {attr}
+
+ ^bb1:
+ // CHECK: test.format_successor_a_op ^bb1, ^bb2 {attr}
+ test.format_successor_a_op ^bb1, ^bb2 {attr}
+
+ ^bb2:
+ // CHECK: test.format_successor_a_op {attr}
+ test.format_successor_a_op {attr}
+
+}) { arg_names = ["i", "j", "k"] } : () -> ()
+
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 77b7b33615d6..3655bf795c8b 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -49,6 +49,7 @@ class Element {
FunctionalTypeDirective,
OperandsDirective,
ResultsDirective,
+ SuccessorsDirective,
TypeDirective,
/// This element is a literal.
@@ -58,6 +59,7 @@ class Element {
AttributeVariable,
OperandVariable,
ResultVariable,
+ SuccessorVariable,
/// This element is an optional element.
Optional,
@@ -105,6 +107,10 @@ using OperandVariable =
/// This class represents a variable that refers to a result.
using ResultVariable =
VariableElement<NamedTypeConstraint, Element::Kind::ResultVariable>;
+
+/// This class represents a variable that refers to a successor. It wraps the
+/// `NamedSuccessor` from the ODS definition, e.g. `$dest` in
+/// `(successor AnySuccessor:$dest)`; it may be single or variadic.
+using SuccessorVariable =
+ VariableElement<NamedSuccessor, Element::Kind::SuccessorVariable>;
} // end anonymous namespace
//===----------------------------------------------------------------------===//
@@ -126,6 +132,11 @@ using OperandsDirective = DirectiveElement<Element::Kind::OperandsDirective>;
/// all of the results of an operation.
using ResultsDirective = DirectiveElement<Element::Kind::ResultsDirective>;
+/// This class represents the `successors` directive. This directive represents
+/// all of the successors of an operation. It is only accepted at the top level
+/// of a format and cannot be combined with named successor variables.
+using SuccessorsDirective =
+ DirectiveElement<Element::Kind::SuccessorsDirective>;
+
/// This class represents the `attr-dict` directive. This directive represents
/// the attribute dictionary of the operation.
class AttrDictDirective
@@ -294,6 +305,8 @@ struct OperationFormat {
/// Generate the c++ to resolve the types of operands and results during
/// parsing.
void genParserTypeResolution(Operator &op, OpMethodBody &body);
+ /// Generate the c++ to resolve successors during parsing, i.e. emit the
+ /// `result.addSuccessor` calls for every successor once all format elements
+ /// have been parsed.
+ void genParserSuccessorResolution(Operator &op, OpMethodBody &body);
/// Generate the operation printer from this format.
void genPrinter(Operator &op, OpClass &opClass);
@@ -403,6 +416,51 @@ const char *const functionalTypeParserCode = R"(
{1}Types = {0}__{1}_functionType.getResults();
)";
+/// The code snippet used to generate a parser call for a successor list.
+/// The list may be empty; successors are separated by commas.
+///
+/// Declared `const char *const` (like the other snippets) so the pointer is
+/// immutable and has internal linkage.
+///
+/// {0}: The name for the successor list.
+const char *const successorListParserCode = R"(
+ SmallVector<std::pair<Block *, SmallVector<Value, 4>>, 2> {0}Successors;
+ {
+ Block *succ = nullptr;
+ SmallVector<Value, 4> succOperands;
+ // Parse the first successor.
+ auto firstSucc = parser.parseOptionalSuccessorAndUseList(succ,
+ succOperands);
+ if (firstSucc.hasValue()) {
+ if (failed(*firstSucc))
+ return failure();
+ {0}Successors.emplace_back(succ, std::move(succOperands));
+
+ // Parse any trailing successors.
+ while (succeeded(parser.parseOptionalComma())) {
+ // Reset the (moved-from) operand buffer before reusing it.
+ succOperands.clear();
+ if (parser.parseSuccessorAndUseList(succ, succOperands))
+ return failure();
+ {0}Successors.emplace_back(succ, std::move(succOperands));
+ }
+ }
+ }
+)";
+
+/// The code snippet used to generate a parser call for a successor.
+/// The parsed `{0}Successor`/`{0}Operands` pair is later registered on the
+/// OperationState by genParserSuccessorResolution.
+///
+/// {0}: The name of the successor.
+const char *const successorParserCode = R"(
+ Block *{0}Successor = nullptr;
+ SmallVector<Value, 4> {0}Operands;
+ if (parser.parseSuccessorAndUseList({0}Successor, {0}Operands))
+ return failure();
+)";
+
+/// The code snippet used to resolve a list of parsed successors, adding each
+/// (block, operands) pair produced by successorListParserCode to `result`.
+///
+/// {0}: The name of the successor list.
+const char *const resolveSuccessorListParserCode = R"(
+ for (auto &succAndArgs : {0}Successors)
+ result.addSuccessor(succAndArgs.first, succAndArgs.second);
+)";
+
/// Get the name used for the type list for the given type directive operand.
/// 'isVariadic' is set to true if the operand has variadic types.
static StringRef getTypeListName(Element *arg, bool &isVariadic) {
@@ -539,6 +597,10 @@ static void genElementParser(Element *element, OpMethodBody &body,
bool isVariadic = operand->getVar()->isVariadic();
body << formatv(isVariadic ? variadicOperandParserCode : operandParserCode,
operand->getVar()->name);
+ } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
+ bool isVariadic = successor->getVar()->isVariadic();
+ body << formatv(isVariadic ? successorListParserCode : successorParserCode,
+ successor->getVar()->name);
/// Directives.
} else if (auto *attrDict = dyn_cast<AttrDictDirective>(element)) {
@@ -551,6 +613,8 @@ static void genElementParser(Element *element, OpMethodBody &body,
<< " SmallVector<OpAsmParser::OperandType, 4> allOperands;\n"
<< " if (parser.parseOperandList(allOperands))\n"
<< " return failure();\n";
+ } else if (isa<SuccessorsDirective>(element)) {
+ body << llvm::formatv(successorListParserCode, "full");
} else if (auto *dir = dyn_cast<TypeDirective>(element)) {
bool isVariadic = false;
StringRef listName = getTypeListName(dir->getOperand(), isVariadic);
@@ -586,9 +650,10 @@ void OperationFormat::genParser(Operator &op, OpClass &opClass) {
for (auto &element : elements)
genElementParser(element.get(), body, attrTypeCtx);
- // Generate the code to resolve the operand and result types now that they
- // have been parsed.
+ // Generate the code to resolve the operand/result types and successors now
+ // that they have been parsed.
genParserTypeResolution(op, body);
+ genParserSuccessorResolution(op, body);
body << " return success();\n";
}
@@ -730,6 +795,28 @@ void OperationFormat::genParserTypeResolution(Operator &op,
}
}
+/// Emit the code that registers every parsed successor on the OperationState.
+void OperationFormat::genParserSuccessorResolution(Operator &op,
+                                                   OpMethodBody &body) {
+  // When the `successors` directive is present, everything was parsed into a
+  // single "full" list, so it is resolved in one go.
+  auto isAllSuccessorsDirective = [](auto &elt) {
+    return isa<SuccessorsDirective>(elt.get());
+  };
+  if (llvm::any_of(elements, isAllSuccessorsDirective)) {
+    body << llvm::formatv(resolveSuccessorListParserCode, "full");
+    return;
+  }
+
+  // Otherwise every named successor has its own parsed state: a list for
+  // variadic successors, or a block plus operand vector for single ones.
+  for (const NamedSuccessor &succ : op.getSuccessors()) {
+    if (succ.isVariadic())
+      body << llvm::formatv(resolveSuccessorListParserCode, succ.name);
+    else
+      body << llvm::formatv(" result.addSuccessor({0}Successor, {0}Operands);\n",
+                            succ.name);
+  }
+}
+
//===----------------------------------------------------------------------===//
// PrinterGen
@@ -790,8 +877,8 @@ static OpMethodBody &genTypeOperandPrinter(Element *arg, OpMethodBody &body) {
/// Generate the code for printing the given element.
static void genElementPrinter(Element *element, OpMethodBody &body,
- OperationFormat &fmt, bool &shouldEmitSpace,
- bool &lastWasPunctuation) {
+ OperationFormat &fmt, Operator &op,
+ bool &shouldEmitSpace, bool &lastWasPunctuation) {
if (LiteralElement *literal = dyn_cast<LiteralElement>(element))
return genLiteralPrinter(literal->getLiteral(), body, shouldEmitSpace,
lastWasPunctuation);
@@ -808,7 +895,7 @@ static void genElementPrinter(Element *element, OpMethodBody &body,
// Emit each of the elements.
for (Element &childElement : optional->getElements())
- genElementPrinter(&childElement, body, fmt, shouldEmitSpace,
+ genElementPrinter(&childElement, body, fmt, op, shouldEmitSpace,
lastWasPunctuation);
body << " }\n";
return;
@@ -847,8 +934,30 @@ static void genElementPrinter(Element *element, OpMethodBody &body,
body << " p.printAttribute(" << var->name << "Attr());\n";
} else if (auto *operand = dyn_cast<OperandVariable>(element)) {
body << " p << " << operand->getVar()->name << "();\n";
+ } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
+ const NamedSuccessor *var = successor->getVar();
+ if (var->isVariadic()) {
+ body << " {\n"
+ << " auto succRange = " << var->name << "();\n"
+ << " auto opSuccBegin = getOperation()->successor_begin();\n"
+ << " int i = succRange.begin() - opSuccBegin;\n"
+ << " int e = i + succRange.size();\n"
+ << " interleaveComma(llvm::seq<int>(i, e), p, [&](int i) {\n"
+ << " p.printSuccessorAndUseList(*this, i);\n"
+ << " });\n"
+ << " }\n";
+ return;
+ }
+
+ unsigned index = successor->getVar() - op.successor_begin();
+ body << " p.printSuccessorAndUseList(*this, " << index << ");\n";
} else if (isa<OperandsDirective>(element)) {
body << " p << getOperation()->getOperands();\n";
+ } else if (isa<SuccessorsDirective>(element)) {
+ body << " interleaveComma(llvm::seq<int>(0, "
+ "getOperation()->getNumSuccessors()), p, [&](int i) {"
+ << " p.printSuccessorAndUseList(*this, i);"
+ << " });\n";
} else if (auto *dir = dyn_cast<TypeDirective>(element)) {
body << " p << ";
genTypeOperandPrinter(dir->getOperand(), body) << ";\n";
@@ -879,7 +988,7 @@ void OperationFormat::genPrinter(Operator &op, OpClass &opClass) {
// punctuation.
bool shouldEmitSpace = true, lastWasPunctuation = false;
for (auto &element : elements)
- genElementPrinter(element.get(), body, *this, shouldEmitSpace,
+ genElementPrinter(element.get(), body, *this, op, shouldEmitSpace,
lastWasPunctuation);
}
@@ -911,6 +1020,7 @@ class Token {
kw_functional_type,
kw_operands,
kw_results,
+ kw_successors,
kw_type,
keyword_end,
@@ -1094,6 +1204,7 @@ Token FormatLexer::lexIdentifier(const char *tokStart) {
.Case("functional-type", Token::kw_functional_type)
.Case("operands", Token::kw_operands)
.Case("results", Token::kw_results)
+ .Case("successors", Token::kw_successors)
.Case("type", Token::kw_type)
.Default(Token::identifier);
return Token(kind, str);
@@ -1173,6 +1284,8 @@ class FormatParser {
llvm::SMLoc loc, bool isTopLevel);
LogicalResult parseResultsDirective(std::unique_ptr<Element> &element,
llvm::SMLoc loc, bool isTopLevel);
+ /// Parse the `successors` directive, which is only valid at the top level of
+ /// a format.
+ LogicalResult parseSuccessorsDirective(std::unique_ptr<Element> &element,
+ llvm::SMLoc loc, bool isTopLevel);
LogicalResult parseTypeDirective(std::unique_ptr<Element> &element, Token tok,
bool isTopLevel);
LogicalResult parseTypeDirectiveOperand(std::unique_ptr<Element> &element);
@@ -1211,9 +1324,11 @@ class FormatParser {
// The following are various bits of format state used for verification
// during parsing.
bool hasAllOperands = false, hasAttrDict = false;
+ bool hasAllSuccessors = false;
llvm::SmallBitVector seenOperandTypes, seenResultTypes;
llvm::DenseSet<const NamedTypeConstraint *> seenOperands;
llvm::DenseSet<const NamedAttribute *> seenAttrs;
+ llvm::DenseSet<const NamedSuccessor *> seenSuccessors;
llvm::DenseSet<const NamedTypeConstraint *> optionalVariables;
};
} // end anonymous namespace
@@ -1313,6 +1428,17 @@ LogicalResult FormatParser::parse() {
auto it = buildableTypes.insert({*builder, buildableTypes.size()});
fmt.operandTypes[i].setBuilderIdx(it.first->second);
}
+
+ // Check that all of the successors are within the format.
+ if (!hasAllSuccessors) {
+ for (unsigned i = 0, e = op.getNumSuccessors(); i != e; ++i) {
+ const NamedSuccessor &successor = op.getSuccessor(i);
+ if (!seenSuccessors.count(&successor)) {
+ return emitError(loc, "format missing instance of successor #" +
+ Twine(i) + "('" + successor.name + "')");
+ }
+ }
+ }
return success();
}
@@ -1417,7 +1543,17 @@ LogicalResult FormatParser::parseVariable(std::unique_ptr<Element> &element,
element = std::make_unique<ResultVariable>(result);
return success();
}
- return emitError(loc, "expected variable to refer to a argument or result");
+ /// Successors.
+ if (const auto *successor = findArg(op.getSuccessors(), name)) {
+ if (!isTopLevel)
+ return emitError(loc, "successors can only be used at the top level");
+ if (hasAllSuccessors || !seenSuccessors.insert(successor).second)
+ return emitError(loc, "successor '" + name + "' is already bound");
+ element = std::make_unique<SuccessorVariable>(successor);
+ return success();
+ }
+ return emitError(
+ loc, "expected variable to refer to a argument, result, or successor");
}
LogicalResult FormatParser::parseDirective(std::unique_ptr<Element> &element,
@@ -1438,6 +1574,8 @@ LogicalResult FormatParser::parseDirective(std::unique_ptr<Element> &element,
return parseOperandsDirective(element, dirTok.getLoc(), isTopLevel);
case Token::kw_results:
return parseResultsDirective(element, dirTok.getLoc(), isTopLevel);
+ case Token::kw_successors:
+ return parseSuccessorsDirective(element, dirTok.getLoc(), isTopLevel);
case Token::kw_type:
return parseTypeDirective(element, dirTok, isTopLevel);
@@ -1624,6 +1762,19 @@ FormatParser::parseResultsDirective(std::unique_ptr<Element> &element,
return success();
}
+/// Parse the `successors` directive, which binds all of the op's successors.
+LogicalResult
+FormatParser::parseSuccessorsDirective(std::unique_ptr<Element> &element,
+                                       llvm::SMLoc loc, bool isTopLevel) {
+  // The directive may not be nested, e.g. inside `type(...)`.
+  if (isTopLevel) {
+    // It binds every successor at once, so it conflicts with an earlier
+    // `successors` directive or any successor variable already seen.
+    const bool overlapsExisting = hasAllSuccessors || !seenSuccessors.empty();
+    if (overlapsExisting)
+      return emitError(loc, "'successors' directive creates overlap in format");
+
+    hasAllSuccessors = true;
+    element = std::make_unique<SuccessorsDirective>();
+    return success();
+  }
+  return emitError(loc,
+                   "'successors' is only valid as a top-level directive");
+}
+
LogicalResult
FormatParser::parseTypeDirective(std::unique_ptr<Element> &element, Token tok,
bool isTopLevel) {
More information about the Mlir-commits
mailing list