[Mlir-commits] [mlir] b1de971 - [mlir][ODS] Add support for specifying the successors of an operation.
River Riddle
llvmlistbot at llvm.org
Fri Feb 21 15:17:43 PST 2020
Author: River Riddle
Date: 2020-02-21T15:15:32-08:00
New Revision: b1de971ba8c83c82ef63077b666aaff3ba8e56b9
URL: https://github.com/llvm/llvm-project/commit/b1de971ba8c83c82ef63077b666aaff3ba8e56b9
DIFF: https://github.com/llvm/llvm-project/commit/b1de971ba8c83c82ef63077b666aaff3ba8e56b9.diff
LOG: [mlir][ODS] Add support for specifying the successors of an operation.
This revision adds support in ODS for specifying the successors of an operation. Successors are specified via the `successors` list:
```
let successors = (successor AnySuccessor:$target, AnySuccessor:$otherTarget);
```
Differential Revision: https://reviews.llvm.org/D74783
Added:
mlir/include/mlir/TableGen/Successor.h
mlir/lib/TableGen/Successor.cpp
Modified:
mlir/docs/OpDefinitions.md
mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/TableGen/Constraint.h
mlir/include/mlir/TableGen/Operator.h
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
mlir/lib/TableGen/CMakeLists.txt
mlir/lib/TableGen/Constraint.cpp
mlir/lib/TableGen/Operator.cpp
mlir/test/Dialect/SPIRV/control-flow-ops.mlir
mlir/test/lib/TestDialect/TestOps.td
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
Removed:
################################################################################
diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
index fc648362c9fb..70d718e22578 100644
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -279,6 +279,24 @@ Similar to variadic operands, `Variadic<...>` can also be used for results.
And similarly, `SameVariadicResultSize` for multiple variadic results in the
same operation.
+### Operation successors
+
+For terminator operations, the successors are specified inside the
+`dag`-typed `successors` field, led by `successor`:
+
+```tablegen
+let successors = (successor
+ <successor-constraint>:$<successor-name>,
+ ...
+);
+```
+
+#### Variadic successors
+
+Similar to the `Variadic` class used for variadic operands and results,
+`VariadicSuccessor<...>` can be used for successors. Variadic successors can
+currently only be specified as the last successor in the successor list.
+
### Operation traits and constraints
Traits are operation properties that affect syntax or semantics. MLIR C++
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index a58d9af0666a..840b4396134e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -31,6 +31,25 @@ def LLVMInt : TypeConstraint<
CPred<"$_self.cast<::mlir::LLVM::LLVMType>().isIntegerTy()">]>,
"LLVM dialect integer">;
+// Type constraint matching any LLVM dialect integer type, whatever its width.
+// NOTE(review): the predicate and description are identical to `LLVMInt`
+// defined just above, and nothing in this patch references `LLVMIntBase`;
+// consider reusing `LLVMInt` rather than keeping a duplicate definition.
+def LLVMIntBase : TypeConstraint<
+ And<[LLVM_Type.predicate,
+ CPred<"$_self.cast<::mlir::LLVM::LLVMType>().isIntegerTy()">]>,
+ "LLVM dialect integer">;
+
+// Integer type of a specific width. The predicate requires an LLVM dialect
+// type for which `isIntegerTy(width)` holds; the type is also declared
+// buildable, using `LLVMType::getIntNTy(<LLVM dialect>, width)` as its
+// builder call.
+class LLVMI<int width>
+ : Type<And<[
+ LLVM_Type.predicate,
+ CPred<
+ "$_self.cast<::mlir::LLVM::LLVMType>().isIntegerTy(" # width # ")">]>,
+ "LLVM dialect " # width # "-bit integer">,
+ BuildableType<
+ "::mlir::LLVM::LLVMType::getIntNTy("
+ "$_builder.getContext()->getRegisteredDialect<LLVM::LLVMDialect>(),"
+ # width # ")">;
+
+// 1-bit LLVM dialect integer; used as the `condition` operand type of
+// `llvm.cond_br`.
+def LLVMI1 : LLVMI<1>;
+
// Base class for LLVM operations. Defines the interface to the llvm::IRBuilder
// used to translate to LLVM IR proper.
class LLVM_OpBase<Dialect dialect, string mnemonic, list<OpTrait> traits = []> :
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index d73fd1187431..e0848c2d6d77 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -72,8 +72,7 @@ class LLVM_ZeroResultOp<string mnemonic, list<OpTrait> traits = []> :
// Base class for LLVM terminator operations. All terminator operations have
// zero results and an optional list of successors.
class LLVM_TerminatorOp<string mnemonic, list<OpTrait> traits = []> :
- LLVM_Op<mnemonic, !listconcat(traits, [Terminator])>,
- Arguments<(ins Variadic<LLVM_Type>:$args)>, Results<(outs)> {
+ LLVM_Op<mnemonic, !listconcat(traits, [Terminator])> {
let builders = [
OpBuilder<
"Builder *, OperationState &result, "
@@ -320,15 +319,10 @@ def LLVM_InvokeOp : LLVM_Op<"invoke", [Terminator]>,
Arguments<(ins OptionalAttr<FlatSymbolRefAttr>:$callee,
Variadic<LLVM_Type>)>,
Results<(outs Variadic<LLVM_Type>)> {
+ let successors = (successor AnySuccessor:$normalDest,
+ AnySuccessor:$unwindDest);
+
let builders = [OpBuilder<
- "Builder *b, OperationState &result, ArrayRef<Type> tys, "
- "FlatSymbolRefAttr callee, ValueRange ops, Block* normal, "
- "ValueRange normalOps, Block* unwind, ValueRange unwindOps",
- [{
- result.addAttribute("callee", callee);
- build(b, result, tys, ops, normal, normalOps, unwind, unwindOps);
- }]>,
- OpBuilder<
"Builder *b, OperationState &result, ArrayRef<Type> tys, "
"ValueRange ops, Block* normal, "
"ValueRange normalOps, Block* unwind, ValueRange unwindOps",
@@ -460,19 +454,19 @@ def LLVM_SelectOp
// Terminators.
def LLVM_BrOp : LLVM_TerminatorOp<"br", []> {
+ let successors = (successor AnySuccessor:$dest);
let parser = [{ return parseBrOp(parser, result); }];
let printer = [{ printBrOp(p, *this); }];
}
def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br", []> {
- let verifier = [{
- if (getNumSuccessors() != 2)
- return emitOpError("expected exactly two successors");
- return success();
- }];
+ let arguments = (ins LLVMI1:$condition);
+ let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest);
+
let parser = [{ return parseCondBrOp(parser, result); }];
let printer = [{ printCondBrOp(p, *this); }];
}
-def LLVM_ReturnOp : LLVM_TerminatorOp<"return", []> {
+def LLVM_ReturnOp : LLVM_TerminatorOp<"return", []>,
+ Arguments<(ins Variadic<LLVM_Type>:$args)> {
string llvmBuilder = [{
if ($_numOperands != 0)
builder.CreateRet($args[0]);
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
index ac95214b6184..433c1323cdee 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
@@ -41,12 +41,14 @@ def SPV_BranchOp : SPV_Op<"Branch", [InFunctionScope, Terminator]> {
```
}];
- let arguments = (ins
- Variadic<AnyType>:$block_arguments
- );
+ let arguments = (ins);
let results = (outs);
+ let successors = (successor AnySuccessor:$target);
+
+ let verifier = [{ return success(); }];
+
let builders = [
OpBuilder<
"Builder *, OperationState &state, "
@@ -60,12 +62,10 @@ def SPV_BranchOp : SPV_Op<"Branch", [InFunctionScope, Terminator]> {
let extraClassDeclaration = [{
/// Returns the branch target block.
- Block *getTarget() { return getOperation()->getSuccessor(0); }
+ Block *getTarget() { return target(); }
/// Returns the block arguments.
- operand_range getBlockArguments() {
- return getOperation()->getSuccessorOperands(0);
- }
+ operand_range getBlockArguments() { return targetOperands(); }
}];
let autogenSerialization = 0;
@@ -115,12 +115,14 @@ def SPV_BranchConditionalOp : SPV_Op<"BranchConditional",
let arguments = (ins
SPV_Bool:$condition,
- Variadic<AnyType>:$branch_arguments,
OptionalAttr<I32ArrayAttr>:$branch_weights
);
let results = (outs);
+ let successors = (successor AnySuccessor:$trueTarget,
+ AnySuccessor:$falseTarget);
+
let builders = [
OpBuilder<
"Builder *builder, OperationState &state, Value condition, "
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 1fc4330cefdd..abe92e2afb28 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -232,12 +232,10 @@ def BranchOp : Std_Op<"br", [Terminator]> {
^bb3(%3: tensor<*xf32>):
}];
- let arguments = (ins Variadic<AnyType>:$operands);
+ let successors = (successor AnySuccessor:$dest);
- let builders = [OpBuilder<
- "Builder *, OperationState &result, Block *dest,"
- "ValueRange operands = {}", [{
- result.addSuccessor(dest, operands);
+ let builders = [OpBuilder<"Builder *, OperationState &result, Block *dest", [{
+ result.addSuccessor(dest, llvm::None);
}]>];
// BranchOp is fully verified by traits.
@@ -513,16 +511,8 @@ def CondBranchOp : Std_Op<"cond_br", [Terminator]> {
...
}];
- let arguments = (ins I1:$condition, Variadic<AnyType>:$branchOperands);
-
- let builders = [OpBuilder<
- "Builder *, OperationState &result, Value condition,"
- "Block *trueDest, ValueRange trueOperands,"
- "Block *falseDest, ValueRange falseOperands", [{
- result.addOperands(condition);
- result.addSuccessor(trueDest, trueOperands);
- result.addSuccessor(falseDest, falseOperands);
- }]>];
+ let arguments = (ins I1:$condition);
+ let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest);
// CondBranchOp is fully verified by traits.
let verifier = ?;
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 3dba6a09c5a7..25c0238946a9 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -185,6 +185,10 @@ class AttrConstraint<Pred predicate, string description = ""> :
class RegionConstraint<Pred predicate, string description = ""> :
Constraint<predicate, description>;
+// Subclass for constraints on a successor.
+class SuccessorConstraint<Pred predicate, string description = ""> :
+ Constraint<predicate, description>;
+
// How to use these constraint categories:
//
// * Use TypeConstraint to specify
@@ -1341,6 +1345,21 @@ class SizedRegion<int numBlocks> : Region<
CPred<"$_self.getBlocks().size() == " # numBlocks>,
"region with " # numBlocks # " blocks">;
+//===----------------------------------------------------------------------===//
+// Successor definitions
+//===----------------------------------------------------------------------===//
+
+// Base class for the constraint attached to an operation successor.
+// `condition` is the predicate the generated op verifier checks against each
+// successor block (bound as `$_self`), and `descr` is the text reported in the
+// verifier's diagnostic when the check fails.
+class Successor<Pred condition, string descr = ""> :
+ SuccessorConstraint<condition, descr>;
+
+// Any successor. The predicate is left unset (`?`), so no per-block condition
+// is imposed.
+def AnySuccessor : Successor<?, "any successor">;
+
+// A variadic successor constraint. It expands to zero or more of the base
+// successor and reuses the base successor's predicate and description.
+// Only the last successor of an operation may be variadic.
+class VariadicSuccessor<Successor successor>
+ : Successor<successor.predicate, successor.description>;
+
//===----------------------------------------------------------------------===//
// OpTrait definitions
//===----------------------------------------------------------------------===//
@@ -1537,6 +1556,9 @@ def outs;
// Marker used to identify the region list for an op.
def region;
+// Marker used to identify the successor list for an op.
+def successor;
+
// Class for defining a custom builder.
//
// TableGen generates several generic builders for each op by default (see
@@ -1587,6 +1609,9 @@ class Op<Dialect dialect, string mnemonic, list<OpTrait> props = []> {
// The list of regions of the op. Default to 0 regions.
dag regions = (region);
+ // The list of successors of the op. Default to 0 successors.
+ dag successors = (successor);
+
// Attribute getters can be added to the op by adding an Attr member
// with the name and type of the attribute. E.g., adding int attribute
// with name "value" and type "i32":
diff --git a/mlir/include/mlir/TableGen/Constraint.h b/mlir/include/mlir/TableGen/Constraint.h
index 105fc4075647..775b3545a034 100644
--- a/mlir/include/mlir/TableGen/Constraint.h
+++ b/mlir/include/mlir/TableGen/Constraint.h
@@ -48,7 +48,7 @@ class Constraint {
StringRef getDescription() const;
// Constraint kind
- enum Kind { CK_Attr, CK_Region, CK_Type, CK_Uncategorized };
+ enum Kind { CK_Attr, CK_Region, CK_Successor, CK_Type, CK_Uncategorized };
Kind getKind() const { return kind; }
diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h
index d9a458ebdf33..e83b25231a87 100644
--- a/mlir/include/mlir/TableGen/Operator.h
+++ b/mlir/include/mlir/TableGen/Operator.h
@@ -19,6 +19,7 @@
#include "mlir/TableGen/Dialect.h"
#include "mlir/TableGen/OpTrait.h"
#include "mlir/TableGen/Region.h"
+#include "mlir/TableGen/Successor.h"
#include "mlir/TableGen/Type.h"
#include "llvm/ADT/PointerUnion.h"
#include "llvm/ADT/SmallVector.h"
@@ -138,6 +139,20 @@ class Operator {
// Returns the `index`-th region.
const NamedRegion &getRegion(unsigned index) const;
+ // Successors.
+ using const_successor_iterator = const NamedSuccessor *;
+ const_successor_iterator successor_begin() const;
+ const_successor_iterator successor_end() const;
+ llvm::iterator_range<const_successor_iterator> getSuccessors() const;
+
+ // Returns the number of successors.
+ unsigned getNumSuccessors() const;
+ // Returns the `index`-th successor.
+ const NamedSuccessor &getSuccessor(unsigned index) const;
+
+ // Returns the number of variadic successors in this operation.
+ unsigned getNumVariadicSuccessors() const;
+
// Trait.
using const_trait_iterator = const OpTrait *;
const_trait_iterator trait_begin() const;
@@ -193,6 +208,9 @@ class Operator {
// The results of the op.
SmallVector<NamedTypeConstraint, 4> results;
+ // The successors of this op.
+ SmallVector<NamedSuccessor, 0> successors;
+
// The traits of the op.
SmallVector<OpTrait, 4> traits;
diff --git a/mlir/include/mlir/TableGen/Successor.h b/mlir/include/mlir/TableGen/Successor.h
new file mode 100644
index 000000000000..0659a983f0d5
--- /dev/null
+++ b/mlir/include/mlir/TableGen/Successor.h
@@ -0,0 +1,44 @@
+//===- Successor.h - TableGen successor definitions -------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_SUCCESSOR_H_
+#define MLIR_TABLEGEN_SUCCESSOR_H_
+
+#include "mlir/Support/LLVM.h"
+#include "mlir/TableGen/Constraint.h"
+
+namespace mlir {
+namespace tblgen {
+
+// Wrapper class providing helper methods for accessing Successor defined in
+// TableGen. Wraps a record deriving from `SuccessorConstraint`, which the
+// Constraint constructor classifies as `CK_Successor`.
+class Successor : public Constraint {
+public:
+ using Constraint::Constraint;
+
+ // Matches constraints whose kind is CK_Successor.
+ static bool classof(const Constraint *c) {
+ return c->getKind() == CK_Successor;
+ }
+
+ // Returns true if this successor is variadic, i.e. its record derives from
+ // the TableGen `VariadicSuccessor` class.
+ bool isVariadic() const;
+};
+
+// A struct bundling a successor's constraint and its name.
+struct NamedSuccessor {
+ // Returns true if this successor is variadic.
+ bool isVariadic() const { return constraint.isVariadic(); }
+
+ // Name given to the successor in the op's `successors` dag (the `$name`
+ // part). May be empty, in which case no named accessors are generated.
+ StringRef name;
+ // Constraint the successor block must satisfy.
+ Successor constraint;
+};
+
+} // end namespace tblgen
+} // end namespace mlir
+
+#endif // MLIR_TABLEGEN_SUCCESSOR_H_
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index a63a593dfc1e..880c95c441a4 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -234,15 +234,14 @@ static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
static LogicalResult verify(InvokeOp op) {
if (op.getNumResults() > 1)
return op.emitOpError("must have 0 or 1 result");
- if (op.getNumSuccessors() != 2)
- return op.emitOpError("must have normal and unwind destinations");
- if (op.getSuccessor(1)->empty())
+ Block *unwindDest = op.unwindDest();
+ if (unwindDest->empty())
return op.emitError(
"must have at least one operation in unwind destination");
// In unwind destination, first operation must be LandingpadOp
- if (!isa<LandingpadOp>(op.getSuccessor(1)->front()))
+ if (!isa<LandingpadOp>(unwindDest->front()))
return op.emitError("first operation in unwind destination should be a "
"llvm.landingpad operation");
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index 89c243bc09ca..0a7f93f58367 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -1036,14 +1036,6 @@ static void print(spirv::BranchOp branchOp, OpAsmPrinter &printer) {
printer.printSuccessorAndUseList(branchOp.getOperation(), /*index=*/0);
}
-static LogicalResult verify(spirv::BranchOp branchOp) {
- auto *op = branchOp.getOperation();
- if (op->getNumSuccessors() != 1)
- branchOp.emitOpError("must have exactly one successor");
-
- return success();
-}
-
//===----------------------------------------------------------------------===//
// spv.BranchConditionalOp
//===----------------------------------------------------------------------===//
@@ -1114,10 +1106,6 @@ static void print(spirv::BranchConditionalOp branchOp, OpAsmPrinter &printer) {
}
static LogicalResult verify(spirv::BranchConditionalOp branchOp) {
- auto *op = branchOp.getOperation();
- if (op->getNumSuccessors() != 2)
- return branchOp.emitOpError("must have exactly two successors");
-
if (auto weights = branchOp.branch_weights()) {
if (weights->getValue().size() != 2) {
return branchOp.emitOpError("must have exactly two branch weights");
diff --git a/mlir/lib/TableGen/CMakeLists.txt b/mlir/lib/TableGen/CMakeLists.txt
index 08bf24029c44..6317c669192e 100644
--- a/mlir/lib/TableGen/CMakeLists.txt
+++ b/mlir/lib/TableGen/CMakeLists.txt
@@ -10,6 +10,7 @@ add_llvm_library(LLVMMLIRTableGen
OpTrait.cpp
Pattern.cpp
Predicate.cpp
+ Successor.cpp
Type.cpp
ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/TableGen/Constraint.cpp b/mlir/lib/TableGen/Constraint.cpp
index 251d15a0c806..98bb7d63d06d 100644
--- a/mlir/lib/TableGen/Constraint.cpp
+++ b/mlir/lib/TableGen/Constraint.cpp
@@ -23,6 +23,8 @@ Constraint::Constraint(const llvm::Record *record)
kind = CK_Attr;
} else if (record->isSubClassOf("RegionConstraint")) {
kind = CK_Region;
+ } else if (record->isSubClassOf("SuccessorConstraint")) {
+ kind = CK_Successor;
} else {
assert(record->isSubClassOf("Constraint"));
}
diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp
index 5e338b37a00f..2fd2997970b6 100644
--- a/mlir/lib/TableGen/Operator.cpp
+++ b/mlir/lib/TableGen/Operator.cpp
@@ -159,6 +159,31 @@ const tblgen::NamedRegion &tblgen::Operator::getRegion(unsigned index) const {
return regions[index];
}
+// Iterators over the successors declared in the op's `successors` dag, in
+// declaration order.
+tblgen::Operator::const_successor_iterator
+tblgen::Operator::successor_begin() const {
+  return successors.begin();
+}
+
+tblgen::Operator::const_successor_iterator
+tblgen::Operator::successor_end() const {
+  return successors.end();
+}
+
+llvm::iterator_range<tblgen::Operator::const_successor_iterator>
+tblgen::Operator::getSuccessors() const {
+  return llvm::make_range(successor_begin(), successor_end());
+}
+
+// Number of declared successors; a variadic successor counts as one entry.
+unsigned tblgen::Operator::getNumSuccessors() const {
+  return static_cast<unsigned>(successors.size());
+}
+
+// The `index`-th declared successor; `index` is expected to be less than
+// getNumSuccessors().
+const tblgen::NamedSuccessor &
+tblgen::Operator::getSuccessor(unsigned index) const {
+  const NamedSuccessor &selected = successors[index];
+  return selected;
+}
+
+// Counts how many of the declared successors are variadic.
+unsigned tblgen::Operator::getNumVariadicSuccessors() const {
+  unsigned numVariadic = 0;
+  for (const NamedSuccessor &successor : successors)
+    if (successor.isVariadic())
+      ++numVariadic;
+  return numVariadic;
+}
+
auto tblgen::Operator::trait_begin() const -> const_trait_iterator {
return traits.begin();
}
@@ -285,6 +310,29 @@ void tblgen::Operator::populateOpStructure() {
results.push_back({name, TypeConstraint(resultDef)});
}
+ // Handle successors
+ auto *successorsDag = def.getValueAsDag("successors");
+ auto *successorsOp = dyn_cast<DefInit>(successorsDag->getOperator());
+ if (!successorsOp || successorsOp->getDef()->getName() != "successor") {
+ PrintFatalError(def.getLoc(),
+ "'successors' must have 'successor' directive");
+ }
+
+ for (unsigned i = 0, e = successorsDag->getNumArgs(); i < e; ++i) {
+ auto name = successorsDag->getArgNameStr(i);
+ auto *successorInit = dyn_cast<DefInit>(successorsDag->getArg(i));
+ if (!successorInit) {
+ PrintFatalError(def.getLoc(),
+ Twine("undefined kind for successor #") + Twine(i));
+ }
+ Successor successor(successorInit->getDef());
+
+ // Only support variadic successors if it is the last one for now.
+ if (i != e - 1 && successor.isVariadic())
+ PrintFatalError(def.getLoc(), "only the last successor can be variadic");
+ successors.push_back({name, successor});
+ }
+
// Create list of traits, skipping over duplicates: appending to lists in
// tablegen is easy, making them unique less so, so dedupe here.
if (auto traitList = def.getValueAsListInit("traits")) {
diff --git a/mlir/lib/TableGen/Successor.cpp b/mlir/lib/TableGen/Successor.cpp
new file mode 100644
index 000000000000..80c6e9bb74d3
--- /dev/null
+++ b/mlir/lib/TableGen/Successor.cpp
@@ -0,0 +1,24 @@
+//===- Successor.cpp - Successor class ------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Successor wrapper to simplify using TableGen Record defining a MLIR
+// Successor.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/TableGen/Successor.h"
+#include "mlir/ADT/TypeSwitch.h"
+#include "llvm/TableGen/Record.h"
+
+using namespace mlir;
+using namespace mlir::tblgen;
+
+// A successor is variadic exactly when its TableGen definition derives from
+// the `VariadicSuccessor` class.
+bool Successor::isVariadic() const {
+  const auto *record = def;
+  return record->isSubClassOf("VariadicSuccessor");
+}
diff --git a/mlir/test/Dialect/SPIRV/control-flow-ops.mlir b/mlir/test/Dialect/SPIRV/control-flow-ops.mlir
index 55ca9a45c6c5..4201411783c5 100644
--- a/mlir/test/Dialect/SPIRV/control-flow-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/control-flow-ops.mlir
@@ -32,7 +32,7 @@ func @missing_accessor() -> () {
func @wrong_accessor_count() -> () {
%true = spv.constant true
- // expected-error @+1 {{must have exactly one successor}}
+ // expected-error @+1 {{incorrect number of successors: expected 1 but found 2}}
"spv.Branch"()[^one, ^two] : () -> ()
^one:
spv.Return
@@ -116,7 +116,7 @@ func @wrong_condition_type() -> () {
func @wrong_accessor_count() -> () {
%true = spv.constant true
- // expected-error @+1 {{must have exactly two successors}}
+ // expected-error @+1 {{incorrect number of successors: expected 2 but found 1}}
"spv.BranchConditional"(%true)[^one] : (i1) -> ()
^one:
spv.Return
diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td
index 756e9cd98428..a7acfaaeeada 100644
--- a/mlir/test/lib/TestDialect/TestOps.td
+++ b/mlir/test/lib/TestDialect/TestOps.td
@@ -431,7 +431,7 @@ def UpdateAttr : Pat<(I32ElementsAttrOp $attr),
[(IsNotScalar $attr)]>;
def TestBranchOp : TEST_Op<"br", [Terminator]> {
- let arguments = (ins Variadic<AnyType>:$operands);
+ let successors = (successor AnySuccessor:$target);
}
def AttrSizedOperandOp : TEST_Op<"attr_sized_operands",
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 2d6999275e9a..d8a4b91c026e 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -183,6 +183,9 @@ class OpEmitter {
// Generates getters for named regions.
void genNamedRegionGetters();
+ // Generates getters for named successors.
+ void genNamedSuccessorGetters();
+
// Generates builder methods for the operation.
void genBuilder();
@@ -266,6 +269,10 @@ class OpEmitter {
// The generated code will be attached to `body`.
void genRegionVerifier(OpMethodBody &body);
+ // Generates verify statements for successors in the operation.
+ // The generated code will be attached to `body`.
+ void genSuccessorVerifier(OpMethodBody &body);
+
// Generates the traits used by the object.
void genTraits();
@@ -302,6 +309,7 @@ OpEmitter::OpEmitter(const Operator &op)
genNamedOperandGetters();
genNamedResultGetters();
genNamedRegionGetters();
+ genNamedSuccessorGetters();
genAttrGetters();
genAttrSetters();
genBuilder();
@@ -579,6 +587,42 @@ void OpEmitter::genNamedRegionGetters() {
}
}
+// Emits named accessors for every successor that has a name in the op's
+// `successors` dag; unnamed successors get none. A variadic successor gets a
+// single `SuccessorRange` getter spanning from its position to the end of the
+// successor list. Any other successor gets a `Block *` getter plus
+// `<name>Operands()` and `<name>Operand(index)` accessors for the operands
+// forwarded to that successor.
+void OpEmitter::genNamedSuccessorGetters() {
+  for (unsigned i = 0, e = op.getNumSuccessors(); i != e; ++i) {
+    const NamedSuccessor &successor = op.getSuccessor(i);
+    StringRef name = successor.name;
+    if (name.empty())
+      continue;
+
+    if (!successor.isVariadic()) {
+      // The successor block itself.
+      auto &blockGetter = opClass.newMethod("Block *", name);
+      blockGetter.body() << formatv(
+          " return this->getOperation()->getSuccessor({0});", i);
+
+      // Every operand forwarded to this successor.
+      auto &allOperandsGetter =
+          opClass.newMethod("Operation::operand_range", (name + "Operands").str());
+      allOperandsGetter.body() << formatv(
+          " return this->getOperation()->getSuccessorOperands({0});", i);
+
+      // One forwarded operand, selected by position.
+      auto &oneOperandGetter = opClass.newMethod(
+          "Value", (name + "Operand").str(), "unsigned index");
+      oneOperandGetter.body() << formatv(
+          " return this->getOperation()->getSuccessorOperand({0}, index);", i);
+      continue;
+    }
+
+    // Variadic: the successors from position `i` through the last one.
+    auto &rangeGetter = opClass.newMethod("SuccessorRange", name);
+    rangeGetter.body() << formatv(
+        " return {std::next(this->getOperation()->successor_begin(), {0}), "
+        "this->getOperation()->successor_end()};",
+        i);
+  }
+}
+
static bool canGenerateUnwrappedBuilder(Operator &op) {
// If this op does not have native attributes at all, return directly to avoid
// redefining builders.
@@ -869,8 +913,9 @@ void OpEmitter::genCollectiveParamBuilder() {
// Generate builder that infers type too.
// TODO(jpienaar): Subsume this with general checking if type can be infered
// automatically.
- // TODO(jpienaar): Expand to handle regions.
- if (op.getTrait("InferTypeOpInterface::Trait") && op.getNumRegions() == 0)
+ // TODO(jpienaar): Expand to handle regions and successors.
+ if (op.getTrait("InferTypeOpInterface::Trait") && op.getNumRegions() == 0 &&
+ op.getNumSuccessors() == 0)
genInferedTypeCollectiveParamBuilder();
}
@@ -982,17 +1027,28 @@ void OpEmitter::buildParamList(std::string &paramList,
++numAttrs;
}
}
+
+ /// Insert parameters for the block and operands for each successor.
+ const char *variadicSuccCode =
+ ", ArrayRef<Block *> {0}, ArrayRef<ValueRange> {0}Operands";
+ const char *succCode = ", Block *{0}, ValueRange {0}Operands";
+ for (const NamedSuccessor &namedSuccessor : op.getSuccessors()) {
+ if (namedSuccessor.isVariadic())
+ paramList += llvm::formatv(variadicSuccCode, namedSuccessor.name).str();
+ else
+ paramList += llvm::formatv(succCode, namedSuccessor.name).str();
+ }
}
void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
bool isRawValueAttr) {
- // Push all operands to the result
+ // Push all operands to the result.
for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
body << " " << builderOpState << ".addOperands(" << getArgumentName(op, i)
<< ");\n";
}
- // Push all attributes to the result
+ // Push all attributes to the result.
for (const auto &namedAttr : op.getAttributes()) {
auto &attr = namedAttr.attr;
if (!attr.isDerivedAttr()) {
@@ -1030,11 +1086,24 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
}
}
- // Create the correct number of regions
+ // Create the correct number of regions.
if (int numRegions = op.getNumRegions()) {
for (int i = 0; i < numRegions; ++i)
body << " (void)" << builderOpState << ".addRegion();\n";
}
+
+ // Push all successors to the result.
+ for (const NamedSuccessor &namedSuccessor : op.getSuccessors()) {
+ if (namedSuccessor.isVariadic()) {
+ body << formatv(" for (int i = 0, e = {1}.size(); i != e; ++i)\n"
+ " {0}.addSuccessor({1}[i], {1}Operands[i]);\n",
+ builderOpState, namedSuccessor.name);
+ continue;
+ }
+
+ body << formatv(" {0}.addSuccessor({1}, {1}Operands);\n", builderOpState,
+ namedSuccessor.name);
+ }
}
void OpEmitter::genCanonicalizerDecls() {
@@ -1228,6 +1297,7 @@ void OpEmitter::genVerifier() {
}
genRegionVerifier(body);
+ genSuccessorVerifier(body);
if (hasCustomVerify) {
FmtContext fctx;
@@ -1305,6 +1375,58 @@ void OpEmitter::genRegionVerifier(OpMethodBody &body) {
}
}
+// Generates the successor verification code for the op into `body`.
+// It first emits a check on the number of successors: exact when no successor
+// is variadic, a lower bound when the (last) successor is variadic.
+// It then emits one check per successor whose constraint has a predicate.
+// Each check evaluates the predicate with `$_self` bound to the successor
+// block and reports "successor #<index> ('<name>') failed to verify
+// constraint: <description>" on failure.
+void OpEmitter::genSuccessorVerifier(OpMethodBody &body) {
+  unsigned numSuccessors = op.getNumSuccessors();
+
+  const char *checkSuccessorSizeCode = R"(
+  if (this->getOperation()->getNumSuccessors() {0} {1}) {
+    return emitOpError("has incorrect number of successors: expected{2} {1}"
+                       " but found ")
+           << this->getOperation()->getNumSuccessors();
+  }
+  )";
+
+  // Verify this op has the correct number of successors.
+  unsigned numVariadicSuccessors = op.getNumVariadicSuccessors();
+  if (numVariadicSuccessors == 0) {
+    body << formatv(checkSuccessorSizeCode, "!=", numSuccessors, "");
+  } else if (numVariadicSuccessors != numSuccessors) {
+    body << formatv(checkSuccessorSizeCode, "<",
+                    numSuccessors - numVariadicSuccessors, " at least");
+  }
+
+  // If we have no successors, there is nothing more to do.
+  if (numSuccessors == 0)
+    return;
+
+  // Per-block constraint check. `index` is the position of `successor` in the
+  // op's successor list. It must advance after every checked block so that a
+  // failure inside a variadic successor names the offending block rather than
+  // always reporting the first one.
+  const char *successorConstraintCode = R"(
+      (void)successor;
+      if (!({0})) {
+        return emitOpError("successor #") << index << " ('{1}') failed to "
+            "verify constraint: {2}";
+      }
+      ++index;
+)";
+
+  body << "  {\n";
+  body << "    unsigned index = 0; (void)index;\n";
+
+  for (unsigned i = 0; i < numSuccessors; ++i) {
+    const auto &successor = op.getSuccessor(i);
+    if (successor.constraint.getPredicate().isNull())
+      continue;
+
+    // Successors without a predicate are skipped above and never advance the
+    // running index, so re-anchor it at this successor's position.
+    body << formatv("    index = {0};\n", i);
+
+    if (successor.isVariadic()) {
+      // The named getter returns a SuccessorRange over the operation's own
+      // successor storage, so iterating it directly is safe.
+      body << formatv("    for (Block *successor : {0}()) {{\n",
+                      successor.name);
+    } else {
+      // Bind the single block to a named local. Wrapping the getter's
+      // temporary result in an ArrayRef would leave the range pointing at a
+      // destroyed temporary by the time the loop body runs.
+      body << formatv("    {{\n      Block *successor = "
+                      "this->getOperation()->getSuccessor({0});\n",
+                      i);
+    }
+
+    auto constraint = tgfmt(successor.constraint.getConditionTemplate(),
+                            &verifyCtx.withSelf("successor"))
+                          .str();
+    body << formatv(successorConstraintCode, constraint, successor.name,
+                    successor.constraint.getDescription());
+    body << "    }\n";
+  }
+  body << "  }\n";
+}
+
void OpEmitter::genTraits() {
int numResults = op.getNumResults();
int numVariadicResults = op.getNumVariadicResults();
@@ -1342,7 +1464,9 @@ void OpEmitter::genTraits() {
int numVariadicOperands = op.getNumVariadicOperands();
// Add operand size trait.
- if (numVariadicOperands != 0) {
+ // Note: Successor operands are also included in the operation's operand list,
+ // so we always need to use VariadicOperands in the presence of successors.
+ if (numVariadicOperands != 0 || op.getNumSuccessors()) {
if (numOperands == numVariadicOperands)
opClass.addTrait("OpTrait::VariadicOperands");
else
More information about the Mlir-commits
mailing list