[Mlir-commits] [mlir] 0359b86 - [mlir][ODS] Add support for variadic regions.
River Riddle
llvmlistbot at llvm.org
Sun Apr 5 01:07:12 PDT 2020
Author: River Riddle
Date: 2020-04-05T01:03:38-07:00
New Revision: 0359b86d8bb24a7bcd37dc6126baee303bc6c939
URL: https://github.com/llvm/llvm-project/commit/0359b86d8bb24a7bcd37dc6126baee303bc6c939
DIFF: https://github.com/llvm/llvm-project/commit/0359b86d8bb24a7bcd37dc6126baee303bc6c939.diff
LOG: [mlir][ODS] Add support for variadic regions.
Summary: This revision adds support for marking the last region as variadic in the ODS region list with the VariadicRegion directive.
Differential Revision: https://reviews.llvm.org/D77455
Added:
mlir/lib/TableGen/Region.cpp
Modified:
mlir/docs/OpDefinitions.md
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/IR/OpDefinition.h
mlir/include/mlir/TableGen/Operator.h
mlir/include/mlir/TableGen/Region.h
mlir/lib/IR/Operation.cpp
mlir/lib/TableGen/CMakeLists.txt
mlir/lib/TableGen/Operator.cpp
mlir/test/Dialect/LLVMIR/func.mlir
mlir/test/Dialect/LLVMIR/global.mlir
mlir/test/Dialect/Loops/invalid.mlir
mlir/test/IR/region.mlir
mlir/test/mlir-tblgen/op-decl.td
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
Removed:
################################################################################
diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
index 5d7f67907309..0187ff740bf7 100644
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -265,6 +265,24 @@ Right now, the following primitive constraints are supported:
TODO: Design and implement more primitive constraints
+### Operation regions
+
+The regions of an operation are specified inside of the `dag`-typed `regions`,
+led by `region`:
+
+```tablegen
+let regions = (region
+ <region-constraint>:$<region-name>,
+ ...
+);
+```
+
+#### Variadic regions
+
+Similar to the `Variadic` class used for variadic operands and results,
+`VariadicRegion<...>` can be used for regions. Variadic regions can currently
+only be specified as the last region in the regions list.
+
### Operation results
Similar to operands, results are specified inside the `dag`-typed `results`, led
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 38a402dfe9dc..25f062b02d15 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1533,6 +1533,10 @@ class SizedRegion<int numBlocks> : Region<
CPred<"$_self.getBlocks().size() == " # numBlocks>,
"region with " # numBlocks # " blocks">;
+// A variadic region constraint. It expands to zero or more of the base region.
+class VariadicRegion<Region region>
+ : Region<region.predicate, region.description>;
+
//===----------------------------------------------------------------------===//
// Successor definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 7d663d363097..4901f5b232b7 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -368,6 +368,10 @@ LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands);
LogicalResult verifyOperandsAreFloatLike(Operation *op);
LogicalResult verifyOperandsAreSignlessIntegerLike(Operation *op);
LogicalResult verifySameTypeOperands(Operation *op);
+LogicalResult verifyZeroRegion(Operation *op);
+LogicalResult verifyOneRegion(Operation *op);
+LogicalResult verifyNRegions(Operation *op, unsigned numRegions);
+LogicalResult verifyAtLeastNRegions(Operation *op, unsigned numRegions);
LogicalResult verifyZeroResult(Operation *op);
LogicalResult verifyOneResult(Operation *op);
LogicalResult verifyNResults(Operation *op, unsigned numOperands);
@@ -529,6 +533,89 @@ template <typename ConcreteType>
class VariadicOperands
: public detail::MultiOperandTraitBase<ConcreteType, VariadicOperands> {};
+//===----------------------------------------------------------------------===//
+// Region Traits
+
+/// This class provides verification for ops that are known to have zero
+/// regions.
+template <typename ConcreteType>
+class ZeroRegion : public TraitBase<ConcreteType, ZeroRegion> {
+public:
+ static LogicalResult verifyTrait(Operation *op) {
+ return impl::verifyZeroRegion(op);
+ }
+};
+
+namespace detail {
+/// Utility trait base that provides accessors for derived traits that have
+/// multiple regions.
+template <typename ConcreteType, template <typename> class TraitType>
+struct MultiRegionTraitBase : public TraitBase<ConcreteType, TraitType> {
+ using region_iterator = MutableArrayRef<Region>;
+ using region_range = RegionRange;
+
+ /// Return the number of regions.
+ unsigned getNumRegions() { return this->getOperation()->getNumRegions(); }
+
+ /// Return the region at `index`.
+ Region &getRegion(unsigned i) { return this->getOperation()->getRegion(i); }
+
+ /// Region iterator access.
+ region_iterator region_begin() {
+ return this->getOperation()->region_begin();
+ }
+ region_iterator region_end() { return this->getOperation()->region_end(); }
+ region_range getRegions() { return this->getOperation()->getRegions(); }
+};
+} // end namespace detail
+
+/// This class provides APIs for ops that are known to have a single region.
+template <typename ConcreteType>
+class OneRegion : public TraitBase<ConcreteType, OneRegion> {
+public:
+ Region &getRegion() { return this->getOperation()->getRegion(0); }
+
+ static LogicalResult verifyTrait(Operation *op) {
+ return impl::verifyOneRegion(op);
+ }
+};
+
+/// This class provides the API for ops that are known to have a specified
+/// number of regions.
+template <unsigned N> class NRegions {
+public:
+ static_assert(N > 1, "use ZeroRegion/OneRegion for N < 2");
+
+ template <typename ConcreteType>
+ class Impl
+ : public detail::MultiRegionTraitBase<ConcreteType, NRegions<N>::Impl> {
+ public:
+ static LogicalResult verifyTrait(Operation *op) {
+ return impl::verifyNRegions(op, N);
+ }
+ };
+};
+
+/// This class provides APIs for ops that are known to have at least a specified
+/// number of regions.
+template <unsigned N> class AtLeastNRegions {
+public:
+ template <typename ConcreteType>
+ class Impl : public detail::MultiRegionTraitBase<ConcreteType,
+ AtLeastNRegions<N>::Impl> {
+ public:
+ static LogicalResult verifyTrait(Operation *op) {
+ return impl::verifyAtLeastNRegions(op, N);
+ }
+ };
+};
+
+/// This class provides the API for ops which have an unknown number of
+/// regions.
+template <typename ConcreteType>
+class VariadicRegions
+ : public detail::MultiRegionTraitBase<ConcreteType, VariadicRegions> {};
+
//===----------------------------------------------------------------------===//
// Result Traits
diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h
index aaf93d3964e8..2748894fe601 100644
--- a/mlir/include/mlir/TableGen/Operator.h
+++ b/mlir/include/mlir/TableGen/Operator.h
@@ -165,11 +165,20 @@ class Operator {
// requiring the raw MLIR trait here.
const OpTrait *getTrait(llvm::StringRef trait) const;
+ // Regions.
+ using const_region_iterator = const NamedRegion *;
+ const_region_iterator region_begin() const;
+ const_region_iterator region_end() const;
+ llvm::iterator_range<const_region_iterator> getRegions() const;
+
// Returns the number of regions.
unsigned getNumRegions() const;
// Returns the `index`-th region.
const NamedRegion &getRegion(unsigned index) const;
+ // Returns the number of variadic regions in this operation.
+ unsigned getNumVariadicRegions() const;
+
// Successors.
using const_successor_iterator = const NamedSuccessor *;
const_successor_iterator successor_begin() const;
diff --git a/mlir/include/mlir/TableGen/Region.h b/mlir/include/mlir/TableGen/Region.h
index b2ed98cc58ca..423ef7208263 100644
--- a/mlir/include/mlir/TableGen/Region.h
+++ b/mlir/include/mlir/TableGen/Region.h
@@ -22,10 +22,16 @@ class Region : public Constraint {
using Constraint::Constraint;
static bool classof(const Constraint *c) { return c->getKind() == CK_Region; }
+
+ // Returns true if this region is variadic.
+ bool isVariadic() const;
};
// A struct bundling a region's constraint and its name.
struct NamedRegion {
+ // Returns true if this region is variadic.
+ bool isVariadic() const { return constraint.isVariadic(); }
+
StringRef name;
Region constraint;
};
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index c6f25b2ff4a0..0d8b28185eb5 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -709,6 +709,32 @@ LogicalResult OpTrait::impl::verifySameTypeOperands(Operation *op) {
return success();
}
+LogicalResult OpTrait::impl::verifyZeroRegion(Operation *op) {
+ if (op->getNumRegions() != 0)
+ return op->emitOpError() << "requires zero regions";
+ return success();
+}
+
+LogicalResult OpTrait::impl::verifyOneRegion(Operation *op) {
+ if (op->getNumRegions() != 1)
+ return op->emitOpError() << "requires one region";
+ return success();
+}
+
+LogicalResult OpTrait::impl::verifyNRegions(Operation *op,
+ unsigned numRegions) {
+ if (op->getNumRegions() != numRegions)
+ return op->emitOpError() << "expected " << numRegions << " regions";
+ return success();
+}
+
+LogicalResult OpTrait::impl::verifyAtLeastNRegions(Operation *op,
+ unsigned numRegions) {
+ if (op->getNumRegions() < numRegions)
+ return op->emitOpError() << "expected " << numRegions << " or more regions";
+ return success();
+}
+
LogicalResult OpTrait::impl::verifyZeroResult(Operation *op) {
if (op->getNumResults() != 0)
return op->emitOpError() << "requires zero results";
diff --git a/mlir/lib/TableGen/CMakeLists.txt b/mlir/lib/TableGen/CMakeLists.txt
index 08384657f94f..a395fdb14a7a 100644
--- a/mlir/lib/TableGen/CMakeLists.txt
+++ b/mlir/lib/TableGen/CMakeLists.txt
@@ -11,6 +11,7 @@ add_llvm_library(LLVMMLIRTableGen
Pass.cpp
Pattern.cpp
Predicate.cpp
+ Region.cpp
SideEffects.cpp
Successor.cpp
Type.cpp
diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp
index ff84456d5491..46e26af40bdf 100644
--- a/mlir/lib/TableGen/Operator.cpp
+++ b/mlir/lib/TableGen/Operator.cpp
@@ -173,12 +173,28 @@ const tblgen::OpTrait *tblgen::Operator::getTrait(StringRef trait) const {
return nullptr;
}
+auto tblgen::Operator::region_begin() const -> const_region_iterator {
+ return regions.begin();
+}
+auto tblgen::Operator::region_end() const -> const_region_iterator {
+ return regions.end();
+}
+auto tblgen::Operator::getRegions() const
+ -> llvm::iterator_range<const_region_iterator> {
+ return {region_begin(), region_end()};
+}
+
unsigned tblgen::Operator::getNumRegions() const { return regions.size(); }
const tblgen::NamedRegion &tblgen::Operator::getRegion(unsigned index) const {
return regions[index];
}
+unsigned tblgen::Operator::getNumVariadicRegions() const {
+ return llvm::count_if(regions,
+ [](const NamedRegion &c) { return c.isVariadic(); });
+}
+
auto tblgen::Operator::successor_begin() const -> const_successor_iterator {
return successors.begin();
}
@@ -388,7 +404,16 @@ void tblgen::Operator::populateOpStructure() {
PrintFatalError(def.getLoc(),
Twine("undefined kind for region #") + Twine(i));
}
- regions.push_back({name, Region(regionInit->getDef())});
+ Region region(regionInit->getDef());
+ if (region.isVariadic()) {
+ // Only support variadic regions if it is the last one for now.
+ if (i != e - 1)
+ PrintFatalError(def.getLoc(), "only the last region can be variadic");
+ if (name.empty())
+ PrintFatalError(def.getLoc(), "variadic regions must be named");
+ }
+
+ regions.push_back({name, region});
}
LLVM_DEBUG(print(llvm::dbgs()));
diff --git a/mlir/lib/TableGen/Region.cpp b/mlir/lib/TableGen/Region.cpp
new file mode 100644
index 000000000000..d8380fa79158
--- /dev/null
+++ b/mlir/lib/TableGen/Region.cpp
@@ -0,0 +1,20 @@
+//===- Region.cpp - Region class ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Region wrapper to simplify using TableGen Record defining a MLIR Region.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/TableGen/Region.h"
+#include "llvm/TableGen/Record.h"
+
+using namespace mlir;
+using namespace mlir::tblgen;
+
+// Returns true if this region is variadic.
+bool Region::isVariadic() const { return def->isSubClassOf("VariadicRegion"); }
diff --git a/mlir/test/Dialect/LLVMIR/func.mlir b/mlir/test/Dialect/LLVMIR/func.mlir
index 2db5d3553e93..689e6db54065 100644
--- a/mlir/test/Dialect/LLVMIR/func.mlir
+++ b/mlir/test/Dialect/LLVMIR/func.mlir
@@ -118,7 +118,7 @@ module {
// -----
module {
- // expected-error@+1 {{expects one region}}
+ // expected-error@+1 {{requires one region}}
"llvm.func"() {sym_name = "no_region", type = !llvm<"void ()">} : () -> ()
}
diff --git a/mlir/test/Dialect/LLVMIR/global.mlir b/mlir/test/Dialect/LLVMIR/global.mlir
index 70944b86960d..cc4a00b5e0ec 100644
--- a/mlir/test/Dialect/LLVMIR/global.mlir
+++ b/mlir/test/Dialect/LLVMIR/global.mlir
@@ -60,12 +60,12 @@ func @references() {
// -----
// expected-error @+1 {{op requires string attribute 'sym_name'}}
-"llvm.mlir.global"() {type = !llvm.i64, constant, value = 42 : i64} : () -> ()
+"llvm.mlir.global"() ({}) {type = !llvm.i64, constant, value = 42 : i64} : () -> ()
// -----
// expected-error @+1 {{op requires attribute 'type'}}
-"llvm.mlir.global"() {sym_name = "foo", constant, value = 42 : i64} : () -> ()
+"llvm.mlir.global"() ({}) {sym_name = "foo", constant, value = 42 : i64} : () -> ()
// -----
@@ -75,12 +75,12 @@ llvm.mlir.global internal constant @constant(37.0) : !llvm<"label">
// -----
// expected-error @+1 {{'addr_space' failed to satisfy constraint: 32-bit signless integer attribute whose value is non-negative}}
-"llvm.mlir.global"() {sym_name = "foo", type = !llvm.i64, value = 42 : i64, addr_space = -1 : i32, linkage = 0} : () -> ()
+"llvm.mlir.global"() ({}) {sym_name = "foo", type = !llvm.i64, value = 42 : i64, addr_space = -1 : i32, linkage = 0} : () -> ()
// -----
// expected-error @+1 {{'addr_space' failed to satisfy constraint: 32-bit signless integer attribute whose value is non-negative}}
-"llvm.mlir.global"() {sym_name = "foo", type = !llvm.i64, value = 42 : i64, addr_space = 1.0 : f32, linkage = 0} : () -> ()
+"llvm.mlir.global"() ({}) {sym_name = "foo", type = !llvm.i64, value = 42 : i64, addr_space = 1.0 : f32, linkage = 0} : () -> ()
// -----
diff --git a/mlir/test/Dialect/Loops/invalid.mlir b/mlir/test/Dialect/Loops/invalid.mlir
index e827a78f2b56..562e6031e9b7 100644
--- a/mlir/test/Dialect/Loops/invalid.mlir
+++ b/mlir/test/Dialect/Loops/invalid.mlir
@@ -2,7 +2,7 @@
func @loop_for_lb(%arg0: f32, %arg1: index) {
// expected-error@+1 {{operand #0 must be index}}
- "loop.for"(%arg0, %arg1, %arg1) : (f32, index, index) -> ()
+ "loop.for"(%arg0, %arg1, %arg1) ({}) : (f32, index, index) -> ()
return
}
@@ -10,7 +10,7 @@ func @loop_for_lb(%arg0: f32, %arg1: index) {
func @loop_for_ub(%arg0: f32, %arg1: index) {
// expected-error@+1 {{operand #1 must be index}}
- "loop.for"(%arg1, %arg0, %arg1) : (index, f32, index) -> ()
+ "loop.for"(%arg1, %arg0, %arg1) ({}) : (index, f32, index) -> ()
return
}
@@ -18,7 +18,7 @@ func @loop_for_ub(%arg0: f32, %arg1: index) {
func @loop_for_step(%arg0: f32, %arg1: index) {
// expected-error@+1 {{operand #2 must be index}}
- "loop.for"(%arg1, %arg1, %arg0) : (index, index, f32) -> ()
+ "loop.for"(%arg1, %arg1, %arg0) ({}) : (index, index, f32) -> ()
return
}
@@ -37,7 +37,7 @@ func @loop_for_step_positive(%arg0: index) {
// -----
func @loop_for_one_region(%arg0: index) {
- // expected-error@+1 {{incorrect number of regions: expected 1 but found 2}}
+ // expected-error@+1 {{requires one region}}
"loop.for"(%arg0, %arg0, %arg0) (
{loop.yield},
{loop.yield}
@@ -77,14 +77,14 @@ func @loop_for_single_index_argument(%arg0: index) {
func @loop_if_not_i1(%arg0: index) {
// expected-error@+1 {{operand #0 must be 1-bit signless integer}}
- "loop.if"(%arg0) : (index) -> ()
+ "loop.if"(%arg0) ({}, {}) : (index) -> ()
return
}
// -----
func @loop_if_more_than_2_regions(%arg0: i1) {
- // expected-error@+1 {{op has incorrect number of regions: expected 2}}
+ // expected-error@+1 {{expected 2 regions}}
"loop.if"(%arg0) ({}, {}, {}): (i1) -> ()
return
}
diff --git a/mlir/test/IR/region.mlir b/mlir/test/IR/region.mlir
index a32371cb7155..465ae511aad2 100644
--- a/mlir/test/IR/region.mlir
+++ b/mlir/test/IR/region.mlir
@@ -16,7 +16,7 @@ func @correct_number_of_regions() {
// -----
func @missing_regions() {
- // expected-error@+1 {{op has incorrect number of regions: expected 2 but found 1}}
+ // expected-error@+1 {{expected 2 regions}}
"test.two_region_op"()(
{"work"() : () -> ()}
) : () -> ()
@@ -26,7 +26,7 @@ func @missing_regions() {
// -----
func @extra_regions() {
- // expected-error@+1 {{op has incorrect number of regions: expected 2 but found 3}}
+ // expected-error@+1 {{expected 2 regions}}
"test.two_region_op"()(
{"work"() : () -> ()},
{"work"() : () -> ()},
diff --git a/mlir/test/mlir-tblgen/op-decl.td b/mlir/test/mlir-tblgen/op-decl.td
index 7606c3356140..4ccbd04ef221 100644
--- a/mlir/test/mlir-tblgen/op-decl.td
+++ b/mlir/test/mlir-tblgen/op-decl.td
@@ -26,7 +26,10 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> {
Variadic<F32>:$s
);
- let regions = (region AnyRegion:$someRegion);
+ let regions = (region
+ AnyRegion:$someRegion,
+ VariadicRegion<AnyRegion>:$someRegions
+ );
let builders = [OpBuilder<"Value val">];
let parser = [{ foo }];
let printer = [{ bar }];
@@ -55,7 +58,7 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> {
// CHECK: ArrayRef<Value> tblgen_operands;
// CHECK: };
-// CHECK: class AOp : public Op<AOp, OpTrait::AtLeastNResults<1>::Impl, OpTrait::ZeroSuccessor, OpTrait::AtLeastNOperands<1>::Impl, OpTrait::IsIsolatedFromAbove
+// CHECK: class AOp : public Op<AOp, OpTrait::AtLeastNRegions<1>::Impl, OpTrait::AtLeastNResults<1>::Impl, OpTrait::ZeroSuccessor, OpTrait::AtLeastNOperands<1>::Impl, OpTrait::IsIsolatedFromAbove
// CHECK-NOT: OpTrait::IsIsolatedFromAbove
// CHECK: public:
// CHECK: using Op::Op;
@@ -67,14 +70,15 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> {
// CHECK: Operation::result_range getODSResults(unsigned index);
// CHECK: Value r();
// CHECK: Region &someRegion();
+// CHECK: MutableArrayRef<Region> someRegions();
// CHECK: IntegerAttr attr1Attr()
// CHECK: APInt attr1();
// CHECK: FloatAttr attr2Attr()
// CHECK: Optional< APFloat > attr2();
// CHECK: static void build(Value val);
-// CHECK: static void build(Builder *odsBuilder, OperationState &odsState, Type r, ArrayRef<Type> s, Value a, ValueRange b, IntegerAttr attr1, /*optional*/FloatAttr attr2)
-// CHECK: static void build(Builder *odsBuilder, OperationState &odsState, Type r, ArrayRef<Type> s, Value a, ValueRange b, APInt attr1, /*optional*/FloatAttr attr2)
-// CHECK: static void build(Builder *, OperationState &odsState, ArrayRef<Type> resultTypes, ValueRange operands, ArrayRef<NamedAttribute> attributes)
+// CHECK: static void build(Builder *odsBuilder, OperationState &odsState, Type r, ArrayRef<Type> s, Value a, ValueRange b, IntegerAttr attr1, /*optional*/FloatAttr attr2, unsigned someRegionsCount)
+// CHECK: static void build(Builder *odsBuilder, OperationState &odsState, Type r, ArrayRef<Type> s, Value a, ValueRange b, APInt attr1, /*optional*/FloatAttr attr2, unsigned someRegionsCount)
+// CHECK: static void build(Builder *, OperationState &odsState, ArrayRef<Type> resultTypes, ValueRange operands, ArrayRef<NamedAttribute> attributes, unsigned numRegions)
// CHECK: static ParseResult parse(OpAsmParser &parser, OperationState &result);
// CHECK: void print(OpAsmPrinter &p);
// CHECK: LogicalResult verify();
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index f262f12ddf2d..1d2ee2f0efe9 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -603,10 +603,19 @@ void OpEmitter::genNamedRegionGetters() {
unsigned numRegions = op.getNumRegions();
for (unsigned i = 0; i < numRegions; ++i) {
const auto &region = op.getRegion(i);
- if (!region.name.empty()) {
- auto &m = opClass.newMethod("Region &", region.name);
- m.body() << formatv(" return this->getOperation()->getRegion({0});", i);
+ if (region.name.empty())
+ continue;
+
+ // Generate the accessors for a variadic region.
+ if (region.isVariadic()) {
+ auto &m = opClass.newMethod("MutableArrayRef<Region>", region.name);
+ m.body() << formatv(
+ " return this->getOperation()->getRegions().drop_front({0});", i);
+ continue;
}
+
+ auto &m = opClass.newMethod("Region &", region.name);
+ m.body() << formatv(" return this->getOperation()->getRegion({0});", i);
}
}
@@ -739,6 +748,8 @@ void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
std::string params =
std::string("Builder *odsBuilder, OperationState &") + builderOpState +
", ValueRange operands, ArrayRef<NamedAttribute> attributes";
+ if (op.getNumVariadicRegions())
+ params += ", unsigned numRegions";
auto &m = opClass.newMethod("void", "build", params, OpMethod::MP_Static);
auto &body = m.body();
@@ -750,8 +761,10 @@ void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
// Create the correct number of regions
if (int numRegions = op.getNumRegions()) {
- for (int i = 0; i < numRegions; ++i)
- m.body() << " (void)" << builderOpState << ".addRegion();\n";
+ body << llvm::formatv(
+ " for (unsigned i = 0; i != {0}; ++i)\n",
+ (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
+ body << " (void)" << builderOpState << ".addRegion();\n";
}
// Result types
@@ -897,6 +910,8 @@ void OpEmitter::genCollectiveParamBuilder() {
builderOpState +
", ArrayRef<Type> resultTypes, ValueRange operands, "
"ArrayRef<NamedAttribute> attributes";
+ if (op.getNumVariadicRegions())
+ params += ", unsigned numRegions";
auto &m = opClass.newMethod("void", "build", params, OpMethod::MP_Static);
auto &body = m.body();
@@ -913,8 +928,10 @@ void OpEmitter::genCollectiveParamBuilder() {
// Create the correct number of regions
if (int numRegions = op.getNumRegions()) {
- for (int i = 0; i < numRegions; ++i)
- m.body() << " (void)" << builderOpState << ".addRegion();\n";
+ body << llvm::formatv(
+ " for (unsigned i = 0; i != {0}; ++i)\n",
+ (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
+ body << " (void)" << builderOpState << ".addRegion();\n";
}
// Result types
@@ -1042,11 +1059,17 @@ void OpEmitter::buildParamList(std::string &paramList,
}
}
- /// Insert parameters for the block and operands for each successor.
+ /// Insert parameters for each successor.
for (const NamedSuccessor &succ : op.getSuccessors()) {
paramList += (succ.isVariadic() ? ", ArrayRef<Block *> " : ", Block *");
paramList += succ.name;
}
+
+ /// Insert parameters for variadic regions.
+ for (const NamedRegion &region : op.getRegions()) {
+ if (region.isVariadic())
+ paramList += llvm::formatv(", unsigned {0}Count", region.name).str();
+ }
}
void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
@@ -1110,9 +1133,12 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
}
// Create the correct number of regions.
- if (int numRegions = op.getNumRegions()) {
- for (int i = 0; i < numRegions; ++i)
- body << " (void)" << builderOpState << ".addRegion();\n";
+ for (const NamedRegion &region : op.getRegions()) {
+ if (region.isVariadic())
+ body << formatv(" for (unsigned i = 0; i < {0}Count; ++i)\n ",
+ region.name);
+
+ body << " (void)" << builderOpState << ".addRegion();\n";
}
// Push all successors to the result.
@@ -1436,33 +1462,42 @@ void OpEmitter::genOperandResultVerifier(OpMethodBody &body,
}
void OpEmitter::genRegionVerifier(OpMethodBody &body) {
+ // If we have no regions, there is nothing more to do.
unsigned numRegions = op.getNumRegions();
+ if (numRegions == 0)
+ return;
- // Verify this op has the correct number of regions
- body << formatv(
- " if (this->getOperation()->getNumRegions() != {0}) {\n "
- "return emitOpError(\"has incorrect number of regions: expected {0} but "
- "found \") << this->getOperation()->getNumRegions();\n }\n",
- numRegions);
+ body << "{\n";
+ body << " unsigned index = 0; (void)index;\n";
for (unsigned i = 0; i < numRegions; ++i) {
const auto &region = op.getRegion(i);
+ if (region.constraint.getPredicate().isNull())
+ continue;
- std::string name = std::string(formatv("#{0}", i));
- if (!region.name.empty()) {
- name += std::string(formatv(" ('{0}')", region.name));
- }
-
- auto getRegion = formatv("this->getOperation()->getRegion({0})", i).str();
+ body << " for (Region &region : ";
+ body << formatv(
+ region.isVariadic()
+ ? "{0}()"
+ : "MutableArrayRef<Region>(this->getOperation()->getRegion({1}))",
+ region.name, i);
+ body << ") {\n";
auto constraint = tgfmt(region.constraint.getConditionTemplate(),
- &verifyCtx.withSelf(getRegion))
+ &verifyCtx.withSelf("region"))
.str();
- body << formatv(" if (!({0})) {\n "
- "return emitOpError(\"region {1} failed to verify "
- "constraint: {2}\");\n }\n",
- constraint, name, region.constraint.getDescription());
+ body << formatv(" (void)region;\n"
+ " if (!({0})) {\n "
+ "return emitOpError(\"region #\") << index << \" {1}"
+ "failed to "
+ "verify constraint: {2}\";\n }\n",
+ constraint,
+ region.name.empty() ? "" : "('" + region.name + "') ",
+ region.constraint.getDescription())
+ << " ++index;\n"
+ << " }\n";
}
+ body << " }\n";
}
void OpEmitter::genSuccessorVerifier(OpMethodBody &body) {
@@ -1488,29 +1523,31 @@ void OpEmitter::genSuccessorVerifier(OpMethodBody &body) {
&verifyCtx.withSelf("successor"))
.str();
- body << formatv(
- " (void)successor;\n"
- " if (!({0})) {\n "
- "return emitOpError(\"successor #\") << index << \"('{2}') failed to "
- "verify constraint: {3}\";\n }\n",
- constraint, i, successor.name, successor.constraint.getDescription());
- body << " }\n";
+ body << formatv(" (void)successor;\n"
+ " if (!({0})) {\n "
+ "return emitOpError(\"successor #\") << index << \"('{1}') "
+ "failed to "
+ "verify constraint: {2}\";\n }\n",
+ constraint, successor.name,
+ successor.constraint.getDescription())
+ << " ++index;\n"
+ << " }\n";
}
body << " }\n";
}
/// Add a size count trait to the given operation class.
static void addSizeCountTrait(OpClass &opClass, StringRef traitKind,
- int numNonVariadic, int numVariadic) {
+ int numTotal, int numVariadic) {
if (numVariadic != 0) {
- if (numNonVariadic == numVariadic)
+ if (numTotal == numVariadic)
opClass.addTrait("OpTrait::Variadic" + traitKind + "s");
else
opClass.addTrait("OpTrait::AtLeastN" + traitKind + "s<" +
- Twine(numNonVariadic - numVariadic) + ">::Impl");
+ Twine(numTotal - numVariadic) + ">::Impl");
return;
}
- switch (numNonVariadic) {
+ switch (numTotal) {
case 0:
opClass.addTrait("OpTrait::Zero" + traitKind);
break;
@@ -1518,17 +1555,21 @@ static void addSizeCountTrait(OpClass &opClass, StringRef traitKind,
opClass.addTrait("OpTrait::One" + traitKind);
break;
default:
- opClass.addTrait("OpTrait::N" + traitKind + "s<" + Twine(numNonVariadic) +
+ opClass.addTrait("OpTrait::N" + traitKind + "s<" + Twine(numTotal) +
">::Impl");
break;
}
}
void OpEmitter::genTraits() {
+ // Add region size trait.
+ unsigned numRegions = op.getNumRegions();
+ unsigned numVariadicRegions = op.getNumVariadicRegions();
+ addSizeCountTrait(opClass, "Region", numRegions, numVariadicRegions);
+
+ // Add result size trait.
int numResults = op.getNumResults();
int numVariadicResults = op.getNumVariadicResults();
-
- // Add return size trait.
addSizeCountTrait(opClass, "Result", numResults, numVariadicResults);
// Add successor size trait.
More information about the Mlir-commits
mailing list