[Mlir-commits] [mlir] 0359b86 - [mlir][ODS] Add support for variadic regions.

River Riddle llvmlistbot at llvm.org
Sun Apr 5 01:07:12 PDT 2020


Author: River Riddle
Date: 2020-04-05T01:03:38-07:00
New Revision: 0359b86d8bb24a7bcd37dc6126baee303bc6c939

URL: https://github.com/llvm/llvm-project/commit/0359b86d8bb24a7bcd37dc6126baee303bc6c939
DIFF: https://github.com/llvm/llvm-project/commit/0359b86d8bb24a7bcd37dc6126baee303bc6c939.diff

LOG: [mlir][ODS] Add support for variadic regions.

Summary: This revision adds support for marking the last region as variadic in the ODS region list with the VariadicRegion directive.

Differential Revision: https://reviews.llvm.org/D77455

Added: 
    mlir/lib/TableGen/Region.cpp

Modified: 
    mlir/docs/OpDefinitions.md
    mlir/include/mlir/IR/OpBase.td
    mlir/include/mlir/IR/OpDefinition.h
    mlir/include/mlir/TableGen/Operator.h
    mlir/include/mlir/TableGen/Region.h
    mlir/lib/IR/Operation.cpp
    mlir/lib/TableGen/CMakeLists.txt
    mlir/lib/TableGen/Operator.cpp
    mlir/test/Dialect/LLVMIR/func.mlir
    mlir/test/Dialect/LLVMIR/global.mlir
    mlir/test/Dialect/Loops/invalid.mlir
    mlir/test/IR/region.mlir
    mlir/test/mlir-tblgen/op-decl.td
    mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
index 5d7f67907309..0187ff740bf7 100644
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -265,6 +265,24 @@ Right now, the following primitive constraints are supported:
 
 TODO: Design and implement more primitive constraints
 
+### Operation regions
+
+The regions of an operation are specified inside of the `dag`-typed `regions`,
+led by `region`:
+
+```tablegen
+let regions = (region
+  <region-constraint>:$<region-name>,
+  ...
+);
+```
+
+#### Variadic regions
+
+Similar to the `Variadic` class used for variadic operands and results,
+`VariadicRegion<...>` can be used for regions. Variadic regions can currently
+only be specified as the last region in the regions list.
+
 ### Operation results
 
 Similar to operands, results are specified inside the `dag`-typed `results`, led

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 38a402dfe9dc..25f062b02d15 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1533,6 +1533,10 @@ class SizedRegion<int numBlocks> : Region<
   CPred<"$_self.getBlocks().size() == " # numBlocks>,
   "region with " # numBlocks # " blocks">;
 
+// A variadic region constraint. It expands to zero or more of the base region.
+class VariadicRegion<Region region>
+  : Region<region.predicate, region.description>;
+
 //===----------------------------------------------------------------------===//
 // Successor definitions
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 7d663d363097..4901f5b232b7 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -368,6 +368,10 @@ LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands);
 LogicalResult verifyOperandsAreFloatLike(Operation *op);
 LogicalResult verifyOperandsAreSignlessIntegerLike(Operation *op);
 LogicalResult verifySameTypeOperands(Operation *op);
+LogicalResult verifyZeroRegion(Operation *op);
+LogicalResult verifyOneRegion(Operation *op);
+LogicalResult verifyNRegions(Operation *op, unsigned numRegions);
+LogicalResult verifyAtLeastNRegions(Operation *op, unsigned numRegions);
 LogicalResult verifyZeroResult(Operation *op);
 LogicalResult verifyOneResult(Operation *op);
 LogicalResult verifyNResults(Operation *op, unsigned numOperands);
@@ -529,6 +533,89 @@ template <typename ConcreteType>
 class VariadicOperands
     : public detail::MultiOperandTraitBase<ConcreteType, VariadicOperands> {};
 
+//===----------------------------------------------------------------------===//
+// Region Traits
+
+/// This class provides verification for ops that are known to have zero
+/// regions.
+template <typename ConcreteType>
+class ZeroRegion : public TraitBase<ConcreteType, ZeroRegion> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    return impl::verifyZeroRegion(op);
+  }
+};
+
+namespace detail {
+/// Utility trait base that provides accessors for derived traits that have
+/// multiple regions.
+template <typename ConcreteType, template <typename> class TraitType>
+struct MultiRegionTraitBase : public TraitBase<ConcreteType, TraitType> {
+  using region_iterator = MutableArrayRef<Region>;
+  using region_range = RegionRange;
+
+  /// Return the number of regions.
+  unsigned getNumRegions() { return this->getOperation()->getNumRegions(); }
+
+  /// Return the region at index `i`.
+  Region &getRegion(unsigned i) { return this->getOperation()->getRegion(i); }
+
+  /// Region iterator access.
+  region_iterator region_begin() {
+    return this->getOperation()->region_begin();
+  }
+  region_iterator region_end() { return this->getOperation()->region_end(); }
+  region_range getRegions() { return this->getOperation()->getRegions(); }
+};
+} // end namespace detail
+
+/// This class provides APIs for ops that are known to have a single region.
+template <typename ConcreteType>
+class OneRegion : public TraitBase<ConcreteType, OneRegion> {
+public:
+  Region &getRegion() { return this->getOperation()->getRegion(0); }
+
+  static LogicalResult verifyTrait(Operation *op) {
+    return impl::verifyOneRegion(op);
+  }
+};
+
+/// This class provides the API for ops that are known to have a specified
+/// number of regions.
+template <unsigned N> class NRegions {
+public:
+  static_assert(N > 1, "use ZeroRegion/OneRegion for N < 2");
+
+  template <typename ConcreteType>
+  class Impl
+      : public detail::MultiRegionTraitBase<ConcreteType, NRegions<N>::Impl> {
+  public:
+    static LogicalResult verifyTrait(Operation *op) {
+      return impl::verifyNRegions(op, N);
+    }
+  };
+};
+
+/// This class provides APIs for ops that are known to have at least a specified
+/// number of regions.
+template <unsigned N> class AtLeastNRegions {
+public:
+  template <typename ConcreteType>
+  class Impl : public detail::MultiRegionTraitBase<ConcreteType,
+                                                   AtLeastNRegions<N>::Impl> {
+  public:
+    static LogicalResult verifyTrait(Operation *op) {
+      return impl::verifyAtLeastNRegions(op, N);
+    }
+  };
+};
+
+/// This class provides the API for ops which have an unknown number of
+/// regions.
+template <typename ConcreteType>
+class VariadicRegions
+    : public detail::MultiRegionTraitBase<ConcreteType, VariadicRegions> {};
+
 //===----------------------------------------------------------------------===//
 // Result Traits
 

diff  --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h
index aaf93d3964e8..2748894fe601 100644
--- a/mlir/include/mlir/TableGen/Operator.h
+++ b/mlir/include/mlir/TableGen/Operator.h
@@ -165,11 +165,20 @@ class Operator {
   // requiring the raw MLIR trait here.
   const OpTrait *getTrait(llvm::StringRef trait) const;
 
+  // Regions.
+  using const_region_iterator = const NamedRegion *;
+  const_region_iterator region_begin() const;
+  const_region_iterator region_end() const;
+  llvm::iterator_range<const_region_iterator> getRegions() const;
+
   // Returns the number of regions.
   unsigned getNumRegions() const;
   // Returns the `index`-th region.
   const NamedRegion &getRegion(unsigned index) const;
 
+  // Returns the number of variadic regions in this operation.
+  unsigned getNumVariadicRegions() const;
+
   // Successors.
   using const_successor_iterator = const NamedSuccessor *;
   const_successor_iterator successor_begin() const;

diff  --git a/mlir/include/mlir/TableGen/Region.h b/mlir/include/mlir/TableGen/Region.h
index b2ed98cc58ca..423ef7208263 100644
--- a/mlir/include/mlir/TableGen/Region.h
+++ b/mlir/include/mlir/TableGen/Region.h
@@ -22,10 +22,16 @@ class Region : public Constraint {
   using Constraint::Constraint;
 
   static bool classof(const Constraint *c) { return c->getKind() == CK_Region; }
+
+  // Returns true if this region is variadic.
+  bool isVariadic() const;
 };
 
 // A struct bundling a region's constraint and its name.
 struct NamedRegion {
+  // Returns true if this region is variadic.
+  bool isVariadic() const { return constraint.isVariadic(); }
+
   StringRef name;
   Region constraint;
 };

diff  --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index c6f25b2ff4a0..0d8b28185eb5 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -709,6 +709,32 @@ LogicalResult OpTrait::impl::verifySameTypeOperands(Operation *op) {
   return success();
 }
 
+LogicalResult OpTrait::impl::verifyZeroRegion(Operation *op) {
+  if (op->getNumRegions() != 0)
+    return op->emitOpError() << "requires zero regions";
+  return success();
+}
+
+LogicalResult OpTrait::impl::verifyOneRegion(Operation *op) {
+  if (op->getNumRegions() != 1)
+    return op->emitOpError() << "requires one region";
+  return success();
+}
+
+LogicalResult OpTrait::impl::verifyNRegions(Operation *op,
+                                            unsigned numRegions) {
+  if (op->getNumRegions() != numRegions)
+    return op->emitOpError() << "expected " << numRegions << " regions";
+  return success();
+}
+
+LogicalResult OpTrait::impl::verifyAtLeastNRegions(Operation *op,
+                                                   unsigned numRegions) {
+  if (op->getNumRegions() < numRegions)
+    return op->emitOpError() << "expected " << numRegions << " or more regions";
+  return success();
+}
+
 LogicalResult OpTrait::impl::verifyZeroResult(Operation *op) {
   if (op->getNumResults() != 0)
     return op->emitOpError() << "requires zero results";

diff  --git a/mlir/lib/TableGen/CMakeLists.txt b/mlir/lib/TableGen/CMakeLists.txt
index 08384657f94f..a395fdb14a7a 100644
--- a/mlir/lib/TableGen/CMakeLists.txt
+++ b/mlir/lib/TableGen/CMakeLists.txt
@@ -11,6 +11,7 @@ add_llvm_library(LLVMMLIRTableGen
   Pass.cpp
   Pattern.cpp
   Predicate.cpp
+  Region.cpp
   SideEffects.cpp
   Successor.cpp
   Type.cpp

diff  --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp
index ff84456d5491..46e26af40bdf 100644
--- a/mlir/lib/TableGen/Operator.cpp
+++ b/mlir/lib/TableGen/Operator.cpp
@@ -173,12 +173,28 @@ const tblgen::OpTrait *tblgen::Operator::getTrait(StringRef trait) const {
   return nullptr;
 }
 
+auto tblgen::Operator::region_begin() const -> const_region_iterator {
+  return regions.begin();
+}
+auto tblgen::Operator::region_end() const -> const_region_iterator {
+  return regions.end();
+}
+auto tblgen::Operator::getRegions() const
+    -> llvm::iterator_range<const_region_iterator> {
+  return {region_begin(), region_end()};
+}
+
 unsigned tblgen::Operator::getNumRegions() const { return regions.size(); }
 
 const tblgen::NamedRegion &tblgen::Operator::getRegion(unsigned index) const {
   return regions[index];
 }
 
+unsigned tblgen::Operator::getNumVariadicRegions() const {
+  return llvm::count_if(regions,
+                        [](const NamedRegion &c) { return c.isVariadic(); });
+}
+
 auto tblgen::Operator::successor_begin() const -> const_successor_iterator {
   return successors.begin();
 }
@@ -388,7 +404,16 @@ void tblgen::Operator::populateOpStructure() {
       PrintFatalError(def.getLoc(),
                       Twine("undefined kind for region #") + Twine(i));
     }
-    regions.push_back({name, Region(regionInit->getDef())});
+    Region region(regionInit->getDef());
+    if (region.isVariadic()) {
+      // For now, a variadic region is only supported as the last region.
+      if (i != e - 1)
+        PrintFatalError(def.getLoc(), "only the last region can be variadic");
+      if (name.empty())
+        PrintFatalError(def.getLoc(), "variadic regions must be named");
+    }
+
+    regions.push_back({name, region});
   }
 
   LLVM_DEBUG(print(llvm::dbgs()));

diff  --git a/mlir/lib/TableGen/Region.cpp b/mlir/lib/TableGen/Region.cpp
new file mode 100644
index 000000000000..d8380fa79158
--- /dev/null
+++ b/mlir/lib/TableGen/Region.cpp
@@ -0,0 +1,20 @@
+//===- Region.cpp - Region class ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Region wrapper to simplify using TableGen Record defining a MLIR Region.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/TableGen/Region.h"
+#include "llvm/TableGen/Record.h"
+
+using namespace mlir;
+using namespace mlir::tblgen;
+
+// Returns true if this region is variadic.
+bool Region::isVariadic() const { return def->isSubClassOf("VariadicRegion"); }

diff  --git a/mlir/test/Dialect/LLVMIR/func.mlir b/mlir/test/Dialect/LLVMIR/func.mlir
index 2db5d3553e93..689e6db54065 100644
--- a/mlir/test/Dialect/LLVMIR/func.mlir
+++ b/mlir/test/Dialect/LLVMIR/func.mlir
@@ -118,7 +118,7 @@ module {
 // -----
 
 module {
-  // expected-error@+1 {{expects one region}}
+  // expected-error@+1 {{requires one region}}
   "llvm.func"() {sym_name = "no_region", type = !llvm<"void ()">} : () -> ()
 }
 

diff  --git a/mlir/test/Dialect/LLVMIR/global.mlir b/mlir/test/Dialect/LLVMIR/global.mlir
index 70944b86960d..cc4a00b5e0ec 100644
--- a/mlir/test/Dialect/LLVMIR/global.mlir
+++ b/mlir/test/Dialect/LLVMIR/global.mlir
@@ -60,12 +60,12 @@ func @references() {
 // -----
 
 // expected-error @+1 {{op requires string attribute 'sym_name'}}
-"llvm.mlir.global"() {type = !llvm.i64, constant, value = 42 : i64} : () -> ()
+"llvm.mlir.global"() ({}) {type = !llvm.i64, constant, value = 42 : i64} : () -> ()
 
 // -----
 
 // expected-error @+1 {{op requires attribute 'type'}}
-"llvm.mlir.global"() {sym_name = "foo", constant, value = 42 : i64} : () -> ()
+"llvm.mlir.global"() ({}) {sym_name = "foo", constant, value = 42 : i64} : () -> ()
 
 // -----
 
@@ -75,12 +75,12 @@ llvm.mlir.global internal constant @constant(37.0) : !llvm<"label">
 // -----
 
 // expected-error @+1 {{'addr_space' failed to satisfy constraint: 32-bit signless integer attribute whose value is non-negative}}
-"llvm.mlir.global"() {sym_name = "foo", type = !llvm.i64, value = 42 : i64, addr_space = -1 : i32, linkage = 0} : () -> ()
+"llvm.mlir.global"() ({}) {sym_name = "foo", type = !llvm.i64, value = 42 : i64, addr_space = -1 : i32, linkage = 0} : () -> ()
 
 // -----
 
 // expected-error @+1 {{'addr_space' failed to satisfy constraint: 32-bit signless integer attribute whose value is non-negative}}
-"llvm.mlir.global"() {sym_name = "foo", type = !llvm.i64, value = 42 : i64, addr_space = 1.0 : f32, linkage = 0} : () -> ()
+"llvm.mlir.global"() ({}) {sym_name = "foo", type = !llvm.i64, value = 42 : i64, addr_space = 1.0 : f32, linkage = 0} : () -> ()
 
 // -----
 

diff  --git a/mlir/test/Dialect/Loops/invalid.mlir b/mlir/test/Dialect/Loops/invalid.mlir
index e827a78f2b56..562e6031e9b7 100644
--- a/mlir/test/Dialect/Loops/invalid.mlir
+++ b/mlir/test/Dialect/Loops/invalid.mlir
@@ -2,7 +2,7 @@
 
 func @loop_for_lb(%arg0: f32, %arg1: index) {
   // expected-error@+1 {{operand #0 must be index}}
-  "loop.for"(%arg0, %arg1, %arg1) : (f32, index, index) -> ()
+  "loop.for"(%arg0, %arg1, %arg1) ({}) : (f32, index, index) -> ()
   return
 }
 
@@ -10,7 +10,7 @@ func @loop_for_lb(%arg0: f32, %arg1: index) {
 
 func @loop_for_ub(%arg0: f32, %arg1: index) {
   // expected-error@+1 {{operand #1 must be index}}
-  "loop.for"(%arg1, %arg0, %arg1) : (index, f32, index) -> ()
+  "loop.for"(%arg1, %arg0, %arg1) ({}) : (index, f32, index) -> ()
   return
 }
 
@@ -18,7 +18,7 @@ func @loop_for_ub(%arg0: f32, %arg1: index) {
 
 func @loop_for_step(%arg0: f32, %arg1: index) {
   // expected-error@+1 {{operand #2 must be index}}
-  "loop.for"(%arg1, %arg1, %arg0) : (index, index, f32) -> ()
+  "loop.for"(%arg1, %arg1, %arg0) ({}) : (index, index, f32) -> ()
   return
 }
 
@@ -37,7 +37,7 @@ func @loop_for_step_positive(%arg0: index) {
 // -----
 
 func @loop_for_one_region(%arg0: index) {
-  // expected-error@+1 {{incorrect number of regions: expected 1 but found 2}}
+  // expected-error@+1 {{requires one region}}
   "loop.for"(%arg0, %arg0, %arg0) (
     {loop.yield},
     {loop.yield}
@@ -77,14 +77,14 @@ func @loop_for_single_index_argument(%arg0: index) {
 
 func @loop_if_not_i1(%arg0: index) {
   // expected-error@+1 {{operand #0 must be 1-bit signless integer}}
-  "loop.if"(%arg0) : (index) -> ()
+  "loop.if"(%arg0) ({}, {}) : (index) -> ()
   return
 }
 
 // -----
 
 func @loop_if_more_than_2_regions(%arg0: i1) {
-  // expected-error@+1 {{op has incorrect number of regions: expected 2}}
+  // expected-error@+1 {{expected 2 regions}}
   "loop.if"(%arg0) ({}, {}, {}): (i1) -> ()
   return
 }

diff  --git a/mlir/test/IR/region.mlir b/mlir/test/IR/region.mlir
index a32371cb7155..465ae511aad2 100644
--- a/mlir/test/IR/region.mlir
+++ b/mlir/test/IR/region.mlir
@@ -16,7 +16,7 @@ func @correct_number_of_regions() {
 // -----
 
 func @missing_regions() {
-    // expected-error@+1 {{op has incorrect number of regions: expected 2 but found 1}}
+    // expected-error@+1 {{expected 2 regions}}
     "test.two_region_op"()(
       {"work"() : () -> ()}
     ) : () -> ()
@@ -26,7 +26,7 @@ func @missing_regions() {
 // -----
 
 func @extra_regions() {
-    // expected-error@+1 {{op has incorrect number of regions: expected 2 but found 3}}
+    // expected-error@+1 {{expected 2 regions}}
     "test.two_region_op"()(
       {"work"() : () -> ()},
       {"work"() : () -> ()},

diff  --git a/mlir/test/mlir-tblgen/op-decl.td b/mlir/test/mlir-tblgen/op-decl.td
index 7606c3356140..4ccbd04ef221 100644
--- a/mlir/test/mlir-tblgen/op-decl.td
+++ b/mlir/test/mlir-tblgen/op-decl.td
@@ -26,7 +26,10 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> {
     Variadic<F32>:$s
   );
 
-  let regions = (region AnyRegion:$someRegion);
+  let regions = (region
+    AnyRegion:$someRegion,
+    VariadicRegion<AnyRegion>:$someRegions
+  );
   let builders = [OpBuilder<"Value val">];
   let parser = [{ foo }];
   let printer = [{ bar }];
@@ -55,7 +58,7 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> {
 // CHECK:   ArrayRef<Value> tblgen_operands;
 // CHECK: };
 
-// CHECK: class AOp : public Op<AOp, OpTrait::AtLeastNResults<1>::Impl, OpTrait::ZeroSuccessor, OpTrait::AtLeastNOperands<1>::Impl, OpTrait::IsIsolatedFromAbove
+// CHECK: class AOp : public Op<AOp, OpTrait::AtLeastNRegions<1>::Impl, OpTrait::AtLeastNResults<1>::Impl, OpTrait::ZeroSuccessor, OpTrait::AtLeastNOperands<1>::Impl, OpTrait::IsIsolatedFromAbove
 // CHECK-NOT: OpTrait::IsIsolatedFromAbove
 // CHECK: public:
 // CHECK:   using Op::Op;
@@ -67,14 +70,15 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> {
 // CHECK:   Operation::result_range getODSResults(unsigned index);
 // CHECK:   Value r();
 // CHECK:   Region &someRegion();
+// CHECK:   MutableArrayRef<Region> someRegions();
 // CHECK:   IntegerAttr attr1Attr()
 // CHECK:   APInt attr1();
 // CHECK:   FloatAttr attr2Attr()
 // CHECK:   Optional< APFloat > attr2();
 // CHECK:   static void build(Value val);
-// CHECK:   static void build(Builder *odsBuilder, OperationState &odsState, Type r, ArrayRef<Type> s, Value a, ValueRange b, IntegerAttr attr1, /*optional*/FloatAttr attr2)
-// CHECK:   static void build(Builder *odsBuilder, OperationState &odsState, Type r, ArrayRef<Type> s, Value a, ValueRange b, APInt attr1, /*optional*/FloatAttr attr2)
-// CHECK:   static void build(Builder *, OperationState &odsState, ArrayRef<Type> resultTypes, ValueRange operands, ArrayRef<NamedAttribute> attributes)
+// CHECK:   static void build(Builder *odsBuilder, OperationState &odsState, Type r, ArrayRef<Type> s, Value a, ValueRange b, IntegerAttr attr1, /*optional*/FloatAttr attr2, unsigned someRegionsCount)
+// CHECK:   static void build(Builder *odsBuilder, OperationState &odsState, Type r, ArrayRef<Type> s, Value a, ValueRange b, APInt attr1, /*optional*/FloatAttr attr2, unsigned someRegionsCount)
+// CHECK:   static void build(Builder *, OperationState &odsState, ArrayRef<Type> resultTypes, ValueRange operands, ArrayRef<NamedAttribute> attributes, unsigned numRegions)
 // CHECK:   static ParseResult parse(OpAsmParser &parser, OperationState &result);
 // CHECK:   void print(OpAsmPrinter &p);
 // CHECK:   LogicalResult verify();

diff  --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index f262f12ddf2d..1d2ee2f0efe9 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -603,10 +603,19 @@ void OpEmitter::genNamedRegionGetters() {
   unsigned numRegions = op.getNumRegions();
   for (unsigned i = 0; i < numRegions; ++i) {
     const auto &region = op.getRegion(i);
-    if (!region.name.empty()) {
-      auto &m = opClass.newMethod("Region &", region.name);
-      m.body() << formatv("  return this->getOperation()->getRegion({0});", i);
+    if (region.name.empty())
+      continue;
+
+    // Generate the accessors for a variadic region.
+    if (region.isVariadic()) {
+      auto &m = opClass.newMethod("MutableArrayRef<Region>", region.name);
+      m.body() << formatv(
+          "  return this->getOperation()->getRegions().drop_front({0});", i);
+      continue;
     }
+
+    auto &m = opClass.newMethod("Region &", region.name);
+    m.body() << formatv("  return this->getOperation()->getRegion({0});", i);
   }
 }
 
@@ -739,6 +748,8 @@ void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
   std::string params =
       std::string("Builder *odsBuilder, OperationState &") + builderOpState +
       ", ValueRange operands, ArrayRef<NamedAttribute> attributes";
+  if (op.getNumVariadicRegions())
+    params += ", unsigned numRegions";
   auto &m = opClass.newMethod("void", "build", params, OpMethod::MP_Static);
   auto &body = m.body();
 
@@ -750,8 +761,10 @@ void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
 
   // Create the correct number of regions
   if (int numRegions = op.getNumRegions()) {
-    for (int i = 0; i < numRegions; ++i)
-      m.body() << "  (void)" << builderOpState << ".addRegion();\n";
+    body << llvm::formatv(
+        "  for (unsigned i = 0; i != {0}; ++i)\n",
+        (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
+    body << "    (void)" << builderOpState << ".addRegion();\n";
   }
 
   // Result types
@@ -897,6 +910,8 @@ void OpEmitter::genCollectiveParamBuilder() {
                        builderOpState +
                        ", ArrayRef<Type> resultTypes, ValueRange operands, "
                        "ArrayRef<NamedAttribute> attributes";
+  if (op.getNumVariadicRegions())
+    params += ", unsigned numRegions";
   auto &m = opClass.newMethod("void", "build", params, OpMethod::MP_Static);
   auto &body = m.body();
 
@@ -913,8 +928,10 @@ void OpEmitter::genCollectiveParamBuilder() {
 
   // Create the correct number of regions
   if (int numRegions = op.getNumRegions()) {
-    for (int i = 0; i < numRegions; ++i)
-      m.body() << "  (void)" << builderOpState << ".addRegion();\n";
+    body << llvm::formatv(
+        "  for (unsigned i = 0; i != {0}; ++i)\n",
+        (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
+    body << "    (void)" << builderOpState << ".addRegion();\n";
   }
 
   // Result types
@@ -1042,11 +1059,17 @@ void OpEmitter::buildParamList(std::string &paramList,
     }
   }
 
-  /// Insert parameters for the block and operands for each successor.
+  /// Insert parameters for each successor.
   for (const NamedSuccessor &succ : op.getSuccessors()) {
     paramList += (succ.isVariadic() ? ", ArrayRef<Block *> " : ", Block *");
     paramList += succ.name;
   }
+
+  /// Insert parameters for variadic regions.
+  for (const NamedRegion &region : op.getRegions()) {
+    if (region.isVariadic())
+      paramList += llvm::formatv(", unsigned {0}Count", region.name).str();
+  }
 }
 
 void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
@@ -1110,9 +1133,12 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
   }
 
   // Create the correct number of regions.
-  if (int numRegions = op.getNumRegions()) {
-    for (int i = 0; i < numRegions; ++i)
-      body << "  (void)" << builderOpState << ".addRegion();\n";
+  for (const NamedRegion &region : op.getRegions()) {
+    if (region.isVariadic())
+      body << formatv("  for (unsigned i = 0; i < {0}Count; ++i)\n  ",
+                      region.name);
+
+    body << "  (void)" << builderOpState << ".addRegion();\n";
   }
 
   // Push all successors to the result.
@@ -1436,33 +1462,42 @@ void OpEmitter::genOperandResultVerifier(OpMethodBody &body,
 }
 
 void OpEmitter::genRegionVerifier(OpMethodBody &body) {
+  // If we have no regions, there is nothing more to do.
   unsigned numRegions = op.getNumRegions();
+  if (numRegions == 0)
+    return;
 
-  // Verify this op has the correct number of regions
-  body << formatv(
-      "  if (this->getOperation()->getNumRegions() != {0}) {\n    "
-      "return emitOpError(\"has incorrect number of regions: expected {0} but "
-      "found \") << this->getOperation()->getNumRegions();\n  }\n",
-      numRegions);
+  body << "{\n";
+  body << "    unsigned index = 0; (void)index;\n";
 
   for (unsigned i = 0; i < numRegions; ++i) {
     const auto &region = op.getRegion(i);
+    if (region.constraint.getPredicate().isNull())
+      continue;
 
-    std::string name = std::string(formatv("#{0}", i));
-    if (!region.name.empty()) {
-      name += std::string(formatv(" ('{0}')", region.name));
-    }
-
-    auto getRegion = formatv("this->getOperation()->getRegion({0})", i).str();
+    body << "    for (Region &region : ";
+    body << formatv(
+        region.isVariadic()
+            ? "{0}()"
+            : "MutableArrayRef<Region>(this->getOperation()->getRegion({1}))",
+        region.name, i);
+    body << ") {\n";
     auto constraint = tgfmt(region.constraint.getConditionTemplate(),
-                            &verifyCtx.withSelf(getRegion))
+                            &verifyCtx.withSelf("region"))
                           .str();
 
-    body << formatv("  if (!({0})) {\n    "
-                    "return emitOpError(\"region {1} failed to verify "
-                    "constraint: {2}\");\n  }\n",
-                    constraint, name, region.constraint.getDescription());
+    body << formatv("      (void)region;\n"
+                    "      if (!({0})) {\n        "
+                    "return emitOpError(\"region #\") << index << \" {1}"
+                    "failed to "
+                    "verify constraint: {2}\";\n      }\n",
+                    constraint,
+                    region.name.empty() ? "" : "('" + region.name + "') ",
+                    region.constraint.getDescription())
+         << "      ++index;\n"
+         << "    }\n";
   }
+  body << "  }\n";
 }
 
 void OpEmitter::genSuccessorVerifier(OpMethodBody &body) {
@@ -1488,29 +1523,31 @@ void OpEmitter::genSuccessorVerifier(OpMethodBody &body) {
                             &verifyCtx.withSelf("successor"))
                           .str();
 
-    body << formatv(
-        "      (void)successor;\n"
-        "      if (!({0})) {\n        "
-        "return emitOpError(\"successor #\") << index << \"('{2}') failed to "
-        "verify constraint: {3}\";\n      }\n",
-        constraint, i, successor.name, successor.constraint.getDescription());
-    body << "    }\n";
+    body << formatv("      (void)successor;\n"
+                    "      if (!({0})) {\n        "
+                    "return emitOpError(\"successor #\") << index << \"('{1}') "
+                    "failed to "
+                    "verify constraint: {2}\";\n      }\n",
+                    constraint, successor.name,
+                    successor.constraint.getDescription())
+         << "      ++index;\n"
+         << "    }\n";
   }
   body << "  }\n";
 }
 
 /// Add a size count trait to the given operation class.
 static void addSizeCountTrait(OpClass &opClass, StringRef traitKind,
-                              int numNonVariadic, int numVariadic) {
+                              int numTotal, int numVariadic) {
   if (numVariadic != 0) {
-    if (numNonVariadic == numVariadic)
+    if (numTotal == numVariadic)
       opClass.addTrait("OpTrait::Variadic" + traitKind + "s");
     else
       opClass.addTrait("OpTrait::AtLeastN" + traitKind + "s<" +
-                       Twine(numNonVariadic - numVariadic) + ">::Impl");
+                       Twine(numTotal - numVariadic) + ">::Impl");
     return;
   }
-  switch (numNonVariadic) {
+  switch (numTotal) {
   case 0:
     opClass.addTrait("OpTrait::Zero" + traitKind);
     break;
@@ -1518,17 +1555,21 @@ static void addSizeCountTrait(OpClass &opClass, StringRef traitKind,
     opClass.addTrait("OpTrait::One" + traitKind);
     break;
   default:
-    opClass.addTrait("OpTrait::N" + traitKind + "s<" + Twine(numNonVariadic) +
+    opClass.addTrait("OpTrait::N" + traitKind + "s<" + Twine(numTotal) +
                      ">::Impl");
     break;
   }
 }
 
 void OpEmitter::genTraits() {
+  // Add region size trait.
+  unsigned numRegions = op.getNumRegions();
+  unsigned numVariadicRegions = op.getNumVariadicRegions();
+  addSizeCountTrait(opClass, "Region", numRegions, numVariadicRegions);
+
+  // Add result size trait.
   int numResults = op.getNumResults();
   int numVariadicResults = op.getNumVariadicResults();
-
-  // Add return size trait.
   addSizeCountTrait(opClass, "Result", numResults, numVariadicResults);
 
   // Add successor size trait.


        


More information about the Mlir-commits mailing list