[Mlir-commits] [mlir] d4bde69 - [mlir][irdl] Introduce a way to define regions

Daniil Dudkin llvmlistbot at llvm.org
Wed Aug 23 07:55:40 PDT 2023


Author: Daniil Dudkin
Date: 2023-08-23T17:55:10+03:00
New Revision: d4bde6968e11a6b8ff9ecd5fbca1d98d8f6feeca

URL: https://github.com/llvm/llvm-project/commit/d4bde6968e11a6b8ff9ecd5fbca1d98d8f6feeca
DIFF: https://github.com/llvm/llvm-project/commit/d4bde6968e11a6b8ff9ecd5fbca1d98d8f6feeca.diff

LOG: [mlir][irdl] Introduce a way to define regions

This patch introduces new operations:
`irdl.region` and `irdl.regions`.
The former lets us specify the characteristics of a region,
such as the arguments for the entry block and the number of blocks.
The latter accepts the results of `irdl.region` operations
to define the set of regions of the operation.

Example:

```
    irdl.dialect @example {
      irdl.operation @op_with_regions {
          %r0 = irdl.region
          %r1 = irdl.region()
          %v0 = irdl.is i32
          %v1 = irdl.is i64
          %r2 = irdl.region(%v0, %v1)
          %r3 = irdl.region with size 3

          irdl.regions(%r0, %r1, %r2, %r3)
      }
    }
```

The above snippet demonstrates an operation named `@op_with_regions`,
which is constrained to have four regions.

* Region `%r0` doesn't have any constraints on the arguments or the number of blocks.
* Region `%r1` should have an empty set of arguments.
* Region `%r2` should have two arguments of types `i32` and `i64`.
* Region `%r3` should contain exactly three blocks.

In the future, the block count constraint may be expanded to support a range of possible block counts.

Reviewed By: math-fehr, Mogball

Differential Revision: https://reviews.llvm.org/D155112

Added: 
    mlir/test/Dialect/IRDL/regions-ops.irdl.mlir

Modified: 
    mlir/include/mlir/Dialect/IRDL/IR/IRDLInterfaces.td
    mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td
    mlir/include/mlir/Dialect/IRDL/IR/IRDLTypes.td
    mlir/include/mlir/Dialect/IRDL/IRDLVerifiers.h
    mlir/lib/Dialect/IRDL/IR/IRDL.cpp
    mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp
    mlir/lib/Dialect/IRDL/IRDLLoading.cpp
    mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp
    mlir/test/Dialect/IRDL/testd.irdl.mlir
    mlir/test/Dialect/IRDL/testd.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/IRDL/IR/IRDLInterfaces.td b/mlir/include/mlir/Dialect/IRDL/IR/IRDLInterfaces.td
index 8f3f96a5677af8..d1545630dc337c 100644
--- a/mlir/include/mlir/Dialect/IRDL/IR/IRDLInterfaces.td
+++ b/mlir/include/mlir/Dialect/IRDL/IR/IRDLInterfaces.td
@@ -15,12 +15,12 @@
 
 include "mlir/IR/OpBase.td"
 
-def VerifyConstraintInterface : OpInterface<"VerifyConstraintInterface"> {
+class VerifyInterface<string name, string return_type>
+      : OpInterface<"Verify" # name # "Interface"> {
   let cppNamespace = "::mlir::irdl";
 
-  let description = [{
-    Interface to get an IRDL constraint verifier from an operation. 
-  }];
+  let description = "Interface to get an IRDL " # name
+                  # " verifier from an operation.";
 
   let methods = [
     InterfaceMethod<
@@ -28,13 +28,19 @@ def VerifyConstraintInterface : OpInterface<"VerifyConstraintInterface"> {
         Get an instance of a constraint verifier for the associated operation."
         Returns `nullptr` upon failure.
       }],
-      "std::unique_ptr<::mlir::irdl::Constraint>",
+      "std::unique_ptr<::mlir::irdl::" # return_type # ">",
       "getVerifier",
       (ins "::mlir::ArrayRef<Value>":$valueToConstr,
-      "::mlir::DenseMap<::mlir::irdl::TypeOp, std::unique_ptr<::mlir::DynamicTypeDefinition>> const&":$types,
-      "::mlir::DenseMap<::mlir::irdl::AttributeOp, std::unique_ptr<::mlir::DynamicAttrDefinition>> const&":$attrs)
+      [{::mlir::DenseMap<::mlir::irdl::TypeOp,
+        std::unique_ptr<::mlir::DynamicTypeDefinition>> const&}]:$types,
+      [{::mlir::DenseMap<::mlir::irdl::AttributeOp,
+        std::unique_ptr<::mlir::DynamicAttrDefinition>> const&}]:$attrs)
     >
   ];
 }
 
+def VerifyConstraintInterface : VerifyInterface<"Constraint", "Constraint"> {}
+
+def VerifyRegionInterface : VerifyInterface<"Region", "RegionConstraint"> {}
+
 #endif // MLIR_DIALECT_IRDL_IR_IRDLINTERFACES

diff  --git a/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td b/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td
index 250119c89d3a7f..681425f8174426 100644
--- a/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td
+++ b/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td
@@ -162,7 +162,8 @@ def IRDL_ParametersOp : IRDL_Op<"parameters",
 
 def IRDL_OperationOp : IRDL_Op<"operation",
     [HasParent<"DialectOp">, NoTerminator, NoRegionArguments,
-    AtMostOneChildOf<"OperandsOp, ResultsOp, AttributesOp">, Symbol]> {
+    AtMostOneChildOf<"OperandsOp, ResultsOp, AttributesOp, RegionsOp">,
+    Symbol]> {
   let summary = "Define a new operation";
   let description = [{
     `irdl.operation` defines a new operation belonging to the `irdl.dialect`
@@ -328,6 +329,92 @@ def IRDL_AttributesOp : IRDL_Op<"attributes", [HasParent<"OperationOp">]> {
   let hasVerifier = true;
 }
 
+def IRDL_RegionOp : IRDL_Op<"region",
+    [HasParent<"OperationOp">, VerifyRegionInterface,
+    DeclareOpInterfaceMethods<VerifyRegionInterface>]> {
+  let summary = "Define a region of an operation";
+  let description = [{
+    The `irdl.region` construct defines a set of characteristics
+    that a region of an operation should satisfy.
+
+    These characteristics include constraints for the entry block arguments
+    of the region and the total number of blocks it contains.
+    The number of blocks is optional; when it is specified, it must be a
+    positive integer.
+    The set of constraints for the entry block arguments may be omitted or
+    empty. If no parentheses are provided, the arguments are not constrained
+    in any way. If parentheses are provided with no arguments, the region
+    must have no entry block arguments. Otherwise, the region must have
+    exactly as many entry block arguments as there are constraints, and each
+    argument type must satisfy the corresponding constraint.
+
+    Example:
+
+    ```mlir
+    irdl.dialect @example {
+      irdl.operation @op_with_regions {
+          %r0 = irdl.region
+          %r1 = irdl.region()
+          %v0 = irdl.is i32
+          %v1 = irdl.is i64
+          %r2 = irdl.region(%v0, %v1)
+          %r3 = irdl.region with size 3
+
+          irdl.regions(%r0, %r1, %r2, %r3)
+      }
+    }
+    ```
+
+    The above snippet demonstrates an operation named `@op_with_regions`,
+    which is constrained to have four regions.
+
+    * Region `%r0` doesn't have any constraints on the arguments
+      or the number of blocks.
+    * Region `%r1` should have an empty set of arguments.
+    * Region `%r2` should have two arguments of types `i32` and `i64`.
+    * Region `%r3` should contain exactly three blocks.
+  }];
+  let arguments = (ins Variadic<IRDL_AttributeType>:$entryBlockArgs,
+                    OptionalAttr<I32Attr>:$numberOfBlocks,
+                    UnitAttr:$constrainedArguments);
+  let results = (outs IRDL_RegionType:$output);
+
+  let assemblyFormat = [{
+    ``(`(` $entryBlockArgs $constrainedArguments^ `)`)?
+    ``(` ` `with` `size` $numberOfBlocks^)? attr-dict
+  }];
+
+  let hasVerifier = true;
+}
+
+def IRDL_RegionsOp : IRDL_Op<"regions", [HasParent<"OperationOp">]> {
+  let summary = "Define the regions of an operation";
+  let description = [{
+    `irdl.regions` defines the regions of an operation by accepting
+    values produced by `irdl.region` operations as arguments.
+
+    Example:
+
+    ```mlir
+    irdl.dialect @example {
+      irdl.operation @op_with_regions {
+        %r1 = irdl.region with size 3
+        %0 = irdl.any
+        %r2 = irdl.region(%0)
+        irdl.regions(%r1, %r2)
+      }
+    }
+    ```
+
+    In the snippet above, the operation is constrained to have two regions.
+    The first region should contain three blocks.
+    The second region should have exactly one entry block argument.
+  }];
+
+  let arguments = (ins Variadic<IRDL_RegionType>:$args);
+  let assemblyFormat = " `(` $args `)` attr-dict ";
+}
+
 //===----------------------------------------------------------------------===//
 // IRDL Constraint operations
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/IRDL/IR/IRDLTypes.td b/mlir/include/mlir/Dialect/IRDL/IR/IRDLTypes.td
index 7073eb81664aa3..2fcf1b41ffd78f 100644
--- a/mlir/include/mlir/Dialect/IRDL/IR/IRDLTypes.td
+++ b/mlir/include/mlir/Dialect/IRDL/IR/IRDLTypes.td
@@ -51,4 +51,28 @@ def IRDL_AttributeType : IRDL_Type<"Attribute", "attribute"> {
   }];
 }
 
+def IRDL_RegionType : IRDL_Type<"Region", "region"> {
+  let summary = "IRDL handle to a region definition";
+  let description = [{
+    This type represents a region constraint. It is produced by
+    the `irdl.region` operation and consumed by the `irdl.regions` operation.
+    The region can be constrained on its entry block arguments
+    and on its number of blocks.
+
+    Example:
+    ```mlir
+    irdl.dialect @example {
+      irdl.operation @op_with_regions {
+        %r1 = irdl.region with size 3
+        %0 = irdl.any
+        %r2 = irdl.region(%0)
+        irdl.regions(%r1, %r2)
+      }
+    }
+    ```
+
+    Here we have `%r1` and `%r2`, both of which have the type `!irdl.region`.
+  }];
+}
+
 #endif // MLIR_DIALECT_IRDL_IR_IRDLTYPES

diff  --git a/mlir/include/mlir/Dialect/IRDL/IRDLVerifiers.h b/mlir/include/mlir/Dialect/IRDL/IRDLVerifiers.h
index 8f0628e37f1eb4..f8ce77cbc50e9e 100644
--- a/mlir/include/mlir/Dialect/IRDL/IRDLVerifiers.h
+++ b/mlir/include/mlir/Dialect/IRDL/IRDLVerifiers.h
@@ -14,8 +14,10 @@
 #define MLIR_DIALECT_IRDL_IRDLVERIFIERS_H
 
 #include "mlir/IR/Attributes.h"
+#include "mlir/IR/Region.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
 #include <optional>
 
 namespace mlir {
@@ -178,6 +180,30 @@ class AnyAttributeConstraint : public Constraint {
                        ConstraintVerifier &context) const override;
 };
 
+/// A constraint checking that a region satisfies `irdl.region` requirements.
+struct RegionConstraint {
+  /// The constructor accepts constrained entities from the `irdl.region`
+  /// operation, such as slots of constraints for the region's arguments and the
+  /// block count.
+  ///
+  /// Both entities are optional: an entity that is not present is not
+  /// constrained.
+  RegionConstraint(std::optional<SmallVector<unsigned>> argumentConstraints,
+                   std::optional<size_t> blockCount)
+      : argumentConstraints(std::move(argumentConstraints)),
+        blockCount(blockCount) {}
+
+  /// Check that the `region` satisfies the constraint.
+  ///
+  /// `constraintContext` is needed to verify the region's arguments
+  /// constraints.
+  LogicalResult verify(mlir::Region &region,
+                       ConstraintVerifier &constraintContext);
+
+private:
+  std::optional<SmallVector<unsigned>> argumentConstraints;
+  std::optional<size_t> blockCount;
+};
 } // namespace irdl
 } // namespace mlir
 

diff  --git a/mlir/lib/Dialect/IRDL/IR/IRDL.cpp b/mlir/lib/Dialect/IRDL/IR/IRDL.cpp
index 009fc0d144a81e..33c6bb869a643f 100644
--- a/mlir/lib/Dialect/IRDL/IR/IRDL.cpp
+++ b/mlir/lib/Dialect/IRDL/IR/IRDL.cpp
@@ -9,10 +9,13 @@
 #include "mlir/Dialect/IRDL/IR/IRDL.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/ExtensibleDialect.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
@@ -221,6 +224,15 @@ static void printAttributesOp(OpAsmPrinter &p, AttributesOp op,
   p << '}';
 }
 
+LogicalResult RegionOp::verify() {
+  // The block count is optional; when present it must be strictly positive.
+  IntegerAttr blockCountAttr = getNumberOfBlocksAttr();
+  if (!blockCountAttr || blockCountAttr.getInt() > 0)
+    return success();
+  return emitOpError("the number of blocks is expected to be >= 1 but got ")
+         << blockCountAttr.getInt();
+}
+
 #include "mlir/Dialect/IRDL/IR/IRDLInterfaces.cpp.inc"
 
 #define GET_TYPEDEF_CLASSES

diff  --git a/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp b/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp
index 9a79f9fa55a212..e172039712f24c 100644
--- a/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp
+++ b/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp
@@ -7,10 +7,28 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/IRDL/IR/IRDL.h"
+#include "mlir/IR/ValueRange.h"
+#include <optional>
 
 using namespace mlir;
 using namespace mlir::irdl;
 
+/// Maps every value of `args` to its slot index in `valueToConstr`.
+///
+/// Only the first occurrence of each value is considered, and values that
+/// have no slot in `valueToConstr` are skipped.
+static SmallVector<unsigned>
+getConstraintIndicesForArgs(mlir::OperandRange args,
+                            ArrayRef<Value> valueToConstr) {
+  SmallVector<unsigned> slots;
+  for (Value arg : args) {
+    const Value *match = llvm::find(valueToConstr, arg);
+    if (match != valueToConstr.end())
+      slots.push_back(static_cast<unsigned>(match - valueToConstr.begin()));
+  }
+  return slots;
+}
+
 std::unique_ptr<Constraint> IsOp::getVerifier(
     ArrayRef<Value> valueToConstr,
     DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> const &types,
@@ -24,15 +42,8 @@ std::unique_ptr<Constraint> ParametricOp::getVerifier(
     DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> const &types,
     DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> const
         &attrs) {
-  SmallVector<unsigned> constraints;
-  for (Value arg : getArgs()) {
-    for (auto [i, value] : enumerate(valueToConstr)) {
-      if (value == arg) {
-        constraints.push_back(i);
-        break;
-      }
-    }
-  }
+  SmallVector<unsigned> constraints =
+      getConstraintIndicesForArgs(getArgs(), valueToConstr);
 
   // Symbol reference case for the base
   SymbolRefAttr symRef = getBaseType();
@@ -60,17 +71,8 @@ std::unique_ptr<Constraint> AnyOfOp::getVerifier(
     DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> const &types,
     DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> const
         &attrs) {
-  SmallVector<unsigned> constraints;
-  for (Value arg : getArgs()) {
-    for (auto [i, value] : enumerate(valueToConstr)) {
-      if (value == arg) {
-        constraints.push_back(i);
-        break;
-      }
-    }
-  }
-
-  return std::make_unique<AnyOfConstraint>(constraints);
+  return std::make_unique<AnyOfConstraint>(
+      getConstraintIndicesForArgs(getArgs(), valueToConstr));
 }
 
 std::unique_ptr<Constraint> AllOfOp::getVerifier(
@@ -78,17 +80,8 @@ std::unique_ptr<Constraint> AllOfOp::getVerifier(
     DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> const &types,
     DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> const
         &attrs) {
-  SmallVector<unsigned> constraints;
-  for (Value arg : getArgs()) {
-    for (auto [i, value] : enumerate(valueToConstr)) {
-      if (value == arg) {
-        constraints.push_back(i);
-        break;
-      }
-    }
-  }
-
-  return std::make_unique<AllOfConstraint>(constraints);
+  return std::make_unique<AllOfConstraint>(
+      getConstraintIndicesForArgs(getArgs(), valueToConstr));
 }
 
 std::unique_ptr<Constraint> AnyOp::getVerifier(
@@ -98,3 +91,15 @@ std::unique_ptr<Constraint> AnyOp::getVerifier(
         &attrs) {
   return std::make_unique<AnyAttributeConstraint>();
 }
+
+std::unique_ptr<RegionConstraint> RegionOp::getVerifier(
+    ArrayRef<Value> valueToConstr,
+    DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> const &types,
+    DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> const
+        &attrs) {
+  std::optional<SmallVector<unsigned>> argSlots;
+  if (getConstrainedArguments())
+    argSlots = getConstraintIndicesForArgs(getEntryBlockArgs(), valueToConstr);
+  return std::make_unique<RegionConstraint>(std::move(argSlots),
+                                            getNumberOfBlocks());
+}

diff  --git a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
index 54869913b262cb..a95235f407b0c8 100644
--- a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
+++ b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/IRDL/IRDLLoading.h"
 #include "mlir/Dialect/IRDL/IR/IRDL.h"
 #include "mlir/Dialect/IRDL/IR/IRDLInterfaces.h"
+#include "mlir/Dialect/IRDL/IRDLVerifiers.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/ExtensibleDialect.h"
@@ -190,7 +191,7 @@ LogicalResult getResultSegmentSizes(Operation *op,
 /// This encodes the logic of the verification method for operations defined
 /// with IRDL.
 static LogicalResult irdlOpVerifier(
-    Operation *op, ArrayRef<std::unique_ptr<Constraint>> constraints,
+    Operation *op, ConstraintVerifier &verifier,
     ArrayRef<size_t> operandConstrs, ArrayRef<Variadicity> operandVariadicity,
     ArrayRef<size_t> resultConstrs, ArrayRef<Variadicity> resultVariadicity,
     const DenseMap<StringAttr, size_t> &attributeConstrs) {
@@ -209,8 +210,6 @@ static LogicalResult irdlOpVerifier(
 
   auto emitError = [op] { return op->emitError(); };
 
-  ConstraintVerifier verifier(constraints);
-
   /// Сheck that we have all needed attributes passed
   /// and they satisfy the constraints.
   DictionaryAttr actualAttrs = op->getAttrDictionary();
@@ -254,6 +253,23 @@ static LogicalResult irdlOpVerifier(
   return success();
 }
 
+/// Checks that `op` has exactly one region per entry of `regionsConstraints`
+/// and that every region satisfies the constraint at the same position.
+static LogicalResult irdlRegionVerifier(
+    Operation *op, ConstraintVerifier &verifier,
+    ArrayRef<std::unique_ptr<RegionConstraint>> regionsConstraints) {
+  const size_t expected = regionsConstraints.size();
+  if (op->getNumRegions() != expected)
+    return op->emitOpError() << "unexpected number of regions: expected "
+                             << expected << " but got " << op->getNumRegions();
+
+  for (unsigned i = 0, e = op->getNumRegions(); i != e; ++i)
+    if (failed(regionsConstraints[i]->verify(op->getRegion(i), verifier)))
+      return failure();
+
+  return success();
+}
+
 /// Define and load an operation represented by a `irdl.operation`
 /// operation.
 static WalkResult loadOperation(
@@ -262,6 +278,7 @@ static WalkResult loadOperation(
     DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> &attrs) {
   // Resolve SSA values to verifier constraint slots
   SmallVector<Value> constrToValue;
+  SmallVector<Value> regionToValue;
   for (Operation &op : op->getRegion(0).getOps()) {
     if (isa<VerifyConstraintInterface>(op)) {
       if (op.getNumResults() != 1)
@@ -269,6 +286,12 @@ static WalkResult loadOperation(
                << "IRDL constraint operations must have exactly one result";
       constrToValue.push_back(op.getResult(0));
     }
+    if (isa<VerifyRegionInterface>(op)) {
+      if (op.getNumResults() != 1)
+        return op.emitError()
+               << "IRDL constraint operations must have exactly one result";
+      regionToValue.push_back(op.getResult(0));
+    }
   }
 
   // Build the verifiers for each constraint slot
@@ -283,6 +306,15 @@ static WalkResult loadOperation(
     constraints.push_back(std::move(verifier));
   }
 
+  // Build region constraints
+  SmallVector<std::unique_ptr<RegionConstraint>> regionConstraints;
+  for (Value v : regionToValue) {
+    VerifyRegionInterface op = cast<VerifyRegionInterface>(v.getDefiningOp());
+    std::unique_ptr<RegionConstraint> verifier =
+        op.getVerifier(constrToValue, types, attrs);
+    regionConstraints.push_back(std::move(verifier));
+  }
+
   SmallVector<size_t> operandConstraints;
   SmallVector<Variadicity> operandVariadicity;
 
@@ -352,18 +384,25 @@ static WalkResult loadOperation(
 
   auto verifier =
       [constraints{std::move(constraints)},
+       regionConstraints{std::move(regionConstraints)},
        operandConstraints{std::move(operandConstraints)},
        operandVariadicity{std::move(operandVariadicity)},
        resultConstraints{std::move(resultConstraints)},
        resultVariadicity{std::move(resultVariadicity)},
        attributesContraints{std::move(attributesContraints)}](Operation *op) {
-        return irdlOpVerifier(op, constraints, operandConstraints,
-                              operandVariadicity, resultConstraints,
-                              resultVariadicity, attributesContraints);
+        ConstraintVerifier verifier(constraints);
+        const LogicalResult opVerifierResult = irdlOpVerifier(
+            op, verifier, operandConstraints, operandVariadicity,
+            resultConstraints, resultVariadicity, attributesContraints);
+        const LogicalResult opRegionVerifierResult =
+            irdlRegionVerifier(op, verifier, regionConstraints);
+        return LogicalResult::success(opVerifierResult.succeeded() &&
+                                      opRegionVerifierResult.succeeded());
       };
 
-  // IRDL does not support defining regions.
-  auto regionVerifier = [](Operation *op) { return success(); };
+  // IRDL supports only checking number of blocks and argument contraints
+  // It is done in the main verifier to reuse `ConstraintVerifier` context
+  auto regionVerifier = [](Operation *op) { return LogicalResult::success(); };
 
   auto opDef = DynamicOpDefinition::get(
       op.getName(), dialect, std::move(verifier), std::move(regionVerifier),

diff  --git a/mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp b/mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp
index 8b09f441a4bd10..90b068ba35831b 100644
--- a/mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp
+++ b/mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp
@@ -11,9 +11,16 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/IRDL/IRDLVerifiers.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Block.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/ExtensibleDialect.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/Region.h"
+#include "mlir/IR/Value.h"
 #include "mlir/Support/LogicalResult.h"
+#include "llvm/Support/FormatVariadic.h"
 
 using namespace mlir;
 using namespace mlir::irdl;
@@ -175,3 +182,45 @@ AnyAttributeConstraint::verify(function_ref<InFlightDiagnostic()> emitError,
                                ConstraintVerifier &context) const {
   return success();
 }
+
+LogicalResult RegionConstraint::verify(mlir::Region &region,
+                                       ConstraintVerifier &constraintContext) {
+  auto *parentOp = region.getParentOp();
+  // Builds a diagnostic emitter anchored at `loc`. A "see the operation" note
+  // pointing at the parent operation is attached unless `loc` already is the
+  // parent operation's location (e.g. when the region's location is used).
+  auto errorAt = [parentOp](mlir::Location loc) {
+    return [loc, parentOp]() -> InFlightDiagnostic {
+      InFlightDiagnostic diag = mlir::emitError(loc);
+      if (loc != parentOp->getLoc())
+        diag.attachNote(parentOp->getLoc()).append("see the operation");
+      return diag;
+    };
+  };
+
+  const size_t numBlocks = region.getBlocks().size();
+  if (blockCount && *blockCount != numBlocks)
+    return errorAt(region.getLoc())()
+           << "expected region " << region.getRegionNumber() << " to have "
+           << *blockCount << " block(s) but got " << numBlocks;
+
+  // Without argument constraints the entry block arguments are unchecked.
+  if (!argumentConstraints)
+    return success();
+
+  auto actualArgs = region.getArguments();
+  if (actualArgs.size() != argumentConstraints->size()) {
+    mlir::Location argsLoc =
+        actualArgs.empty() ? region.getLoc() : actualArgs.front().getLoc();
+    return errorAt(argsLoc)()
+           << "expected region " << region.getRegionNumber() << " to have "
+           << argumentConstraints->size() << " arguments but got "
+           << actualArgs.size();
+  }
+  for (auto [arg, slot] : llvm::zip(actualArgs, *argumentConstraints)) {
+    mlir::Attribute type = TypeAttr::get(arg.getType());
+    if (failed(constraintContext.verify(errorAt(arg.getLoc()), type, slot)))
+      return failure();
+  }
+  return success();
+}

diff  --git a/mlir/test/Dialect/IRDL/regions-ops.irdl.mlir b/mlir/test/Dialect/IRDL/regions-ops.irdl.mlir
new file mode 100644
index 00000000000000..762f992e786d48
--- /dev/null
+++ b/mlir/test/Dialect/IRDL/regions-ops.irdl.mlir
@@ -0,0 +1,19 @@
+// RUN: mlir-opt %s -verify-diagnostics -split-input-file
+
+irdl.dialect @testRegionOpNegativeNumber {
+    irdl.operation @op {
+        // expected-error @below {{'irdl.region' op the number of blocks is expected to be >= 1 but got -42}}
+        %r1 = irdl.region with size -42
+    }
+}
+
+// -----
+
+irdl.dialect @testRegionsOpWrongOperation {
+    irdl.operation @op {
+        // expected-note @below {{prior use here}}
+        %r1 = irdl.any
+        // expected-error @below {{use of value '%r1' expects different type than prior uses: '!irdl.region' vs '!irdl.attribute'}}
+        irdl.regions(%r1)
+    }
+}

diff  --git a/mlir/test/Dialect/IRDL/testd.irdl.mlir b/mlir/test/Dialect/IRDL/testd.irdl.mlir
index 485a6aedaa660b..684286e4afeb0f 100644
--- a/mlir/test/Dialect/IRDL/testd.irdl.mlir
+++ b/mlir/test/Dialect/IRDL/testd.irdl.mlir
@@ -119,4 +119,37 @@ irdl.dialect @testd {
       "attr2" = %1
     }
   }
+  // CHECK: irdl.operation @regions {
+  // CHECK:   %[[r0:[^ ]*]] = irdl.region
+  // CHECK:   %[[v0:[^ ]*]] = irdl.is i32
+  // CHECK:   %[[v1:[^ ]*]] = irdl.is i64
+  // CHECK:   %[[r1:[^ ]*]] = irdl.region(%[[v0]], %[[v1]])
+  // CHECK:   %[[r2:[^ ]*]] = irdl.region with size 3
+  // CHECK:   %[[r3:[^ ]*]] = irdl.region()
+  // CHECK:   irdl.regions(%[[r0]], %[[r1]], %[[r2]], %[[r3]])
+  // CHECK: }
+  irdl.operation @regions {
+    %r0 = irdl.region
+    %v0 = irdl.is i32
+    %v1 = irdl.is i64
+    %r1 = irdl.region(%v0, %v1)
+    %r2 = irdl.region with size 3
+    %r3 = irdl.region()
+
+    irdl.regions(%r0, %r1, %r2, %r3)
+  }
+
+  // CHECK: irdl.operation @region_and_operand {
+  // CHECK:   %[[v0:[^ ]*]] = irdl.any
+  // CHECK:   %[[r0:[^ ]*]] = irdl.region(%[[v0]])
+  // CHECK:   irdl.operands(%[[v0]])
+  // CHECK:   irdl.regions(%[[r0]])
+  // CHECK: }
+  irdl.operation @region_and_operand {
+    %v0 = irdl.any
+    %r0 = irdl.region(%v0)
+
+    irdl.operands(%v0)
+    irdl.regions(%r0)
+  }
 }

diff  --git a/mlir/test/Dialect/IRDL/testd.mlir b/mlir/test/Dialect/IRDL/testd.mlir
index 1aa539d8b57741..bb1e9f46356411 100644
--- a/mlir/test/Dialect/IRDL/testd.mlir
+++ b/mlir/test/Dialect/IRDL/testd.mlir
@@ -234,3 +234,363 @@ func.func @failedAttrsConstraint2() {
   "testd.attrs"() {attr1 = i32, attr2 = i32} : () -> ()
   return
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Regions
+//===----------------------------------------------------------------------===//
+
+// CHECK: func.func @succeededRegions
+func.func @succeededRegions() {
+  "testd.regions"() (
+  {
+    ^bb1:
+      llvm.unreachable
+  },
+  {
+    ^bb1(%arg0: i32, %arg1: i64):
+      llvm.unreachable
+  },
+  {
+    ^bb1:
+      cf.br ^bb3
+    ^bb2:
+      cf.br ^bb3
+    ^bb3:
+      llvm.unreachable
+  },
+  {
+    ^bb1:
+      llvm.unreachable
+  }) : () -> ()
+
+  return
+}
+// -----
+
+// CHECK: func.func @succeededRegionWithNoConstraints
+func.func @succeededRegionWithNoConstraints() {
+  "testd.regions"() (
+  {
+    ^bb1(%arg0: i32, %arg1: i64, %arg2 : f64):
+      llvm.unreachable
+    ^bb2(%arg3: i32, %arg4: i64, %arg5 : f64):
+      llvm.unreachable
+    ^bb3(%arg6: i32, %arg7: i64, %arg8 : f64):
+      llvm.unreachable
+    ^bb4(%arg9: i32, %arg10: i64, %arg11 : f64):
+      llvm.unreachable
+    ^bb5(%arg12: i32, %arg13: i64, %arg14 : f64):
+      llvm.unreachable
+  },
+  {
+    ^bb1(%arg0: i32, %arg1: i64):
+      llvm.unreachable
+  },
+  {
+    ^bb1:
+      cf.br ^bb3
+    ^bb2:
+      cf.br ^bb3
+    ^bb3:
+      llvm.unreachable
+  },
+  {
+    ^bb1:
+      llvm.unreachable
+  }) : () -> ()
+
+  return
+}
+
+// -----
+
+func.func @failedRegionArgsLessThanNeeded() {
+  // expected-note@+1 {{see the operation}}
+  "testd.regions"() (
+  {
+    ^bb1:
+      llvm.unreachable
+  },
+  {
+    // expected-error@+1 {{expected region 1 to have 2 arguments but got 1}}
+    ^bb1(%arg0: i32):
+      llvm.unreachable
+  },
+  {
+    ^bb1:
+      cf.br ^bb3
+    ^bb2:
+      cf.br ^bb3
+    ^bb3:
+      llvm.unreachable
+  },
+  {
+    ^bb1:
+      llvm.unreachable
+  }) : () -> ()
+
+  return
+}
+
+// -----
+
+func.func @failedRegionArgsMoreThanNeeded() {
+  // expected-note@+1 {{see the operation}}
+  "testd.regions"() (
+  {
+    ^bb1:
+      llvm.unreachable
+  },
+  {
+    // expected-error@+1 {{expected region 1 to have 2 arguments but got 3}}
+    ^bb1(%arg0: i32, %arg1: i64, %arg2 : f64):
+      llvm.unreachable
+  },
+  {
+    ^bb1:
+      cf.br ^bb3
+    ^bb2:
+      cf.br ^bb3
+    ^bb3:
+      llvm.unreachable
+  },
+  {
+    ^bb1:
+      llvm.unreachable
+  }) : () -> ()
+
+  return
+}
+
+// -----
+
+func.func @failedRegionArgsEmptyButRequired() {
+  // expected-error@+1 {{expected region 1 to have 2 arguments but got 0}}
+  "testd.regions"() (
+  {
+    ^bb1:
+      llvm.unreachable
+  },
+  {
+    ^bb1():
+      llvm.unreachable
+  },
+  {
+    ^bb1:
+      cf.br ^bb3
+    ^bb2:
+      cf.br ^bb3
+    ^bb3:
+      llvm.unreachable
+  },
+  {
+    ^bb1:
+      llvm.unreachable
+  }) : () -> ()
+
+  return
+}
+
+// -----
+
+func.func @failedRegionArgsConstraint() {
+  // expected-note@+1 {{see the operation}}
+  "testd.regions"() (
+  {
+    ^bb1:
+      llvm.unreachable
+  },
+  {
+    // expected-error@+1 {{expected 'i64' but got 'f64'}}
+    ^bb1(%arg0: i32, %arg1: f64):
+      llvm.unreachable
+  },
+  {
+    ^bb1:
+      cf.br ^bb3
+    ^bb2:
+      cf.br ^bb3
+    ^bb3:
+      llvm.unreachable
+  },
+  {
+    ^bb1:
+      llvm.unreachable
+  }) : () -> ()
+
+  return
+}
+
+// -----
+
+func.func @failedRegionBlocksCountLessThanNeeded() {
+  // expected-error@+1 {{expected region 2 to have 3 block(s) but got 2}}
+  "testd.regions"() (
+  {
+    ^bb1:
+      llvm.unreachable
+  },
+  {
+    ^bb1(%arg0: i32, %arg1: i64):
+      llvm.unreachable
+  },
+  {
+    ^bb1:
+      cf.br ^bb3
+    ^bb3:
+      llvm.unreachable
+  },
+  {
+    ^bb1:
+      llvm.unreachable
+  }) : () -> ()
+
+  return
+}
+
+// -----
+
+func.func @failedRegionBlocksCountMoreThanNeeded() {
+  // expected-error@+1 {{expected region 2 to have 3 block(s) but got 4}}
+  "testd.regions"() (
+  {
+    ^bb1:
+      llvm.unreachable
+  },
+  {
+    ^bb1(%arg0: i32, %arg1: i64):
+      llvm.unreachable
+  },
+  {
+    ^bb1:
+      cf.br ^bb3
+    ^bb2:
+      cf.br ^bb3
+    ^bb4:
+      cf.br ^bb3
+    ^bb3:
+      llvm.unreachable
+  },
+  {
+    ^bb1:
+      llvm.unreachable
+  }) : () -> ()
+
+  return
+}
+
+// -----
+
+func.func @failedRegionWithEmptyArgs() {
+  // expected-note@+1 {{see the operation}}
+  "testd.regions"() (
+  {
+    ^bb1:
+      llvm.unreachable
+  },
+  {
+    ^bb1(%arg0: i32, %arg1: i64):
+      llvm.unreachable
+  },
+  {
+    ^bb1:
+      cf.br ^bb3
+    ^bb2:
+      cf.br ^bb3
+    ^bb3:
+      llvm.unreachable
+  },
+  {
+    // expected-error@+1 {{expected region 3 to have 0 arguments but got 2}}
+    ^bb1(%arg0: i32, %arg1: i64):
+      llvm.unreachable
+  }) : () -> ()
+
+  return
+}
+
+// -----
+
+func.func @failedRegionWithLessRegionsThanNeeded() {
+  // expected-error@+1 {{'testd.regions' op unexpected number of regions: expected 4 but got 3}}
+  "testd.regions"() (
+  {
+    ^bb1:
+      llvm.unreachable
+  },
+  {
+    ^bb1:
+      cf.br ^bb3
+    ^bb2:
+      cf.br ^bb3
+    ^bb3:
+      llvm.unreachable
+  },
+  {
+    ^bb1:
+      llvm.unreachable
+  }) : () -> ()
+
+  return
+}
+
+// -----
+
+func.func @failedRegionWithMoreRegionsThanNeeded() {
+  // expected-error@+1 {{'testd.regions' op unexpected number of regions: expected 4 but got 5}}
+  "testd.regions"() (
+  {
+    ^bb1:
+      llvm.unreachable
+  },
+  {
+    ^bb1(%arg0: i32, %arg1: i64):
+      llvm.unreachable
+  },
+  {
+    ^bb1:
+      cf.br ^bb3
+    ^bb2:
+      cf.br ^bb3
+    ^bb3:
+      llvm.unreachable
+  },
+  {
+    ^bb1(%arg0: i32, %arg1: i64):
+      llvm.unreachable
+  },
+  {
+    ^bb1(%arg0: i32, %arg1: i64):
+      llvm.unreachable
+  }) : () -> ()
+
+  return
+}
+
+// -----
+
+func.func @successReuseConstraintBetweenRegionAndOperand() {
+  %0 = arith.constant 42 : i32
+  "testd.region_and_operand"(%0) ({
+    ^bb(%1: i32):
+      llvm.unreachable
+  }) : (i32) -> ()
+
+  return
+}
+
+// -----
+
+func.func @failedReuseConstraintBetweenRegionAndOperand() {
+  %0 = arith.constant 42 : i32
+  // expected-note@+1 {{see the operation}}
+  "testd.region_and_operand"(%0) ({
+    // expected-error@+1 {{expected 'i32' but got 'i64'}}
+    ^bb(%1: i64):
+      llvm.unreachable
+  }) : (i32) -> ()
+
+  return
+}


        


More information about the Mlir-commits mailing list