[Mlir-commits] [mlir] 42987df - [mlir][irdl] Add `irdl.any_of` operation

Mathieu Fehr llvmlistbot at llvm.org
Wed May 17 13:54:13 PDT 2023


Author: Mathieu Fehr
Date: 2023-05-17T21:57:16+01:00
New Revision: 42987dfa3a85e0cec987b9f07a8ffe61073ddc52

URL: https://github.com/llvm/llvm-project/commit/42987dfa3a85e0cec987b9f07a8ffe61073ddc52
DIFF: https://github.com/llvm/llvm-project/commit/42987dfa3a85e0cec987b9f07a8ffe61073ddc52.diff

LOG: [mlir][irdl] Add `irdl.any_of` operation

The `irdl.any_of` operation represents a constraint that is satisfied
if any of its subconstraints is satisfied.

For instance, in the following example:
```
%0 = irdl.is f32
%1 = irdl.is f64
%2 = irdl.any_of(%0, %1)
```

`%2` can only be satisfied by `f32` or `f64`.

Note that the verification algorithm required by `irdl.any_of` is
non-trivial, since we want the order of the arguments of
`irdl.any_of` not to matter. For this reason, our registration
algorithm fails if two constraints used by `any_of` might be
satisfied by the same `Attribute`. This is approximated by checking
the possible `Attribute` bases of each constraint.

Depends on D145734

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D145735

Added: 
    mlir/test/Dialect/IRDL/cyclic-types.irdl.mlir
    mlir/test/Dialect/IRDL/cyclic-types.mlir
    mlir/test/Dialect/IRDL/test-type.irdl.mlir
    mlir/test/Dialect/IRDL/test-type.mlir

Modified: 
    mlir/include/mlir/Dialect/IRDL/IR/IRDLInterfaces.td
    mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td
    mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp
    mlir/lib/Dialect/IRDL/IRDLLoading.cpp
    mlir/test/Dialect/IRDL/cmath.irdl.mlir
    mlir/test/Dialect/IRDL/testd.irdl.mlir
    mlir/test/Dialect/IRDL/testd.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/IRDL/IR/IRDLInterfaces.td b/mlir/include/mlir/Dialect/IRDL/IR/IRDLInterfaces.td
index 0e45711ae441c..8f3f96a5677af 100644
--- a/mlir/include/mlir/Dialect/IRDL/IR/IRDLInterfaces.td
+++ b/mlir/include/mlir/Dialect/IRDL/IR/IRDLInterfaces.td
@@ -30,9 +30,9 @@ def VerifyConstraintInterface : OpInterface<"VerifyConstraintInterface"> {
       }],
       "std::unique_ptr<::mlir::irdl::Constraint>",
       "getVerifier",
-      (ins "::mlir::SmallVector<Value> const&":$valueRes,
-      "::mlir::DenseMap<::mlir::irdl::TypeOp, std::unique_ptr<::mlir::DynamicTypeDefinition>> &":$types,
-      "::mlir::DenseMap<::mlir::irdl::AttributeOp, std::unique_ptr<::mlir::DynamicAttrDefinition>> &":$attrs)
+      (ins "::mlir::ArrayRef<Value>":$valueToConstr,
+      "::mlir::DenseMap<::mlir::irdl::TypeOp, std::unique_ptr<::mlir::DynamicTypeDefinition>> const&":$types,
+      "::mlir::DenseMap<::mlir::irdl::AttributeOp, std::unique_ptr<::mlir::DynamicAttrDefinition>> const&":$attrs)
     >
   ];
 }

diff  --git a/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td b/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td
index 5cce6858be2c6..f5b4600062ea7 100644
--- a/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td
+++ b/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td
@@ -356,5 +356,74 @@ def IRDL_Any : IRDL_ConstraintOp<"any",
   let assemblyFormat = " attr-dict ";
 }
 
+def IRDL_AnyOf : IRDL_ConstraintOp<"any_of",
+                  [ParentOneOf<["TypeOp", "AttributeOp", "OperationOp"]>,
+                   SameOperandsAndResultType]> {
+  let summary = "Constraints to the union of the provided constraints";
+  let description = [{
+    `irdl.any_of` defines a constraint that accepts any type or attribute that
+    satisfies at least one of its provided constraints.
+
+    Example:
+
+    ```mlir
+    irdl.dialect cmath {
+      irdl.type complex {
+        %0 = irdl.is i32
+        %1 = irdl.is i64
+        %2 = irdl.is f32
+        %3 = irdl.is f64
+        %4 = irdl.any_of(%0, %1, %2, %3)
+        irdl.parameters(%4)
+      }
+    }
+    ```
+
+    The above program defines a type `complex` inside the dialect `cmath` that
+    has a single type parameter that can be either `i32`, `i64`, `f32` or
+    `f64`.
+
+    When dialects are loaded, the arguments must be distinguishable so that
+    their order does not matter: the loader rejects an `irdl.any_of` with an
+    argument that can match any type or attribute (such as `irdl.any`), or
+    with two `irdl.parametric` arguments referring to the same definition.
+  }];
+
+  let arguments = (ins Variadic<IRDL_AttributeType>:$args);
+  let results = (outs IRDL_AttributeType:$output);
+  let assemblyFormat = [{ `(` $args `)` ` ` attr-dict }];
+}
+
+def IRDL_AllOf : IRDL_ConstraintOp<"all_of",
+                 [ParentOneOf<["TypeOp", "AttributeOp", "OperationOp"]>,
+                  SameOperandsAndResultType]> {
+  let summary = "Constraints to the intersection of the provided constraints";
+  let description = [{
+    `irdl.all_of` defines a constraint that accepts any type or attribute that
+    satisfies all of its provided constraints.
+
+    Example:
+
+    ```mlir
+    irdl.dialect cmath {
+      irdl.type complex_f32 {
+        %0 = irdl.is i32
+        %1 = irdl.is f32
+        %2 = irdl.any_of(%0, %1) // is 32-bit
+
+        %3 = irdl.is f32
+        %4 = irdl.is f64
+        %5 = irdl.any_of(%3, %4) // is a float
+
+        %6 = irdl.all_of(%2, %5) // is a 32-bit float
+        irdl.parameters(%6)
+      }
+    }
+    ```
+
+    The above program defines a type `complex_f32` inside the dialect `cmath`
+    that has one parameter that must be 32-bit long and a float (in other
+    words, that must be `f32`).
+  }];
+
+  let arguments = (ins Variadic<IRDL_AttributeType>:$args);
+  let results = (outs IRDL_AttributeType:$output);
+  let assemblyFormat = [{ `(` $args `)` ` ` attr-dict }];
+}
 
 #endif // MLIR_DIALECT_IRDL_IR_IRDLOPS

diff  --git a/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp b/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp
index a9956cc630ccf..c0e839720200b 100644
--- a/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp
+++ b/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp
@@ -12,16 +12,18 @@ using namespace mlir;
 using namespace mlir::irdl;
 
 std::unique_ptr<Constraint> Is::getVerifier(
-    SmallVector<Value> const &valueToConstr,
-    DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types,
-    DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> &attrs) {
+    ArrayRef<Value> valueToConstr,
+    DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> const &types,
+    DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> const
+        &attrs) {
+  // `irdl.is` depends only on its expected attribute; the constraint-variable
+  // list and the type/attribute definition maps are not consulted.
   return std::make_unique<IsConstraint>(getExpectedAttr());
 }
 
 std::unique_ptr<Constraint> Parametric::getVerifier(
-    SmallVector<Value> const &valueToConstr,
-    DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types,
-    DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> &attrs) {
+    ArrayRef<Value> valueToConstr,
+    DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> const &types,
+    DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> const
+        &attrs) {
   SmallVector<unsigned> constraints;
   for (Value arg : getArgs()) {
     for (auto [i, value] : enumerate(valueToConstr)) {
@@ -42,20 +44,57 @@ std::unique_ptr<Constraint> Parametric::getVerifier(
   }
 
   if (auto typeOp = dyn_cast<TypeOp>(defOp))
-    return std::make_unique<DynParametricTypeConstraint>(types[typeOp].get(),
+    return std::make_unique<DynParametricTypeConstraint>(types.at(typeOp).get(),
                                                          constraints);
 
   if (auto attrOp = dyn_cast<AttributeOp>(defOp))
-    return std::make_unique<DynParametricAttrConstraint>(attrs[attrOp].get(),
+    return std::make_unique<DynParametricAttrConstraint>(attrs.at(attrOp).get(),
                                                          constraints);
 
   llvm_unreachable("verifier should ensure that the referenced operation is "
                    "either a type or an attribute definition");
 }
 
+/// Build the runtime verifier for an `irdl.any_of` constraint. Each operand
+/// is translated to the index of the constraint variable in `valueToConstr`
+/// that defines it; operands without a matching entry contribute no index.
+std::unique_ptr<Constraint> AnyOf::getVerifier(
+    ArrayRef<Value> valueToConstr,
+    DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> const &types,
+    DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> const
+        &attrs) {
+  SmallVector<unsigned> argIndices;
+  for (Value operand : getArgs()) {
+    // First occurrence wins, mirroring a forward linear scan.
+    auto pos = llvm::find(valueToConstr, operand);
+    if (pos != valueToConstr.end())
+      argIndices.push_back(static_cast<unsigned>(pos - valueToConstr.begin()));
+  }
+
+  return std::make_unique<AnyOfConstraint>(argIndices);
+}
+
+/// Build the runtime verifier for an `irdl.all_of` constraint. Operands are
+/// resolved to indices into `valueToConstr` (first match); operands that are
+/// not found there are skipped.
+std::unique_ptr<Constraint> AllOf::getVerifier(
+    ArrayRef<Value> valueToConstr,
+    DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> const &types,
+    DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> const
+        &attrs) {
+  SmallVector<unsigned> argIndices;
+  for (Value operand : getArgs()) {
+    auto it = llvm::find(valueToConstr, operand);
+    if (it == valueToConstr.end())
+      continue;
+    argIndices.push_back(static_cast<unsigned>(it - valueToConstr.begin()));
+  }
+
+  return std::make_unique<AllOfConstraint>(argIndices);
+}
+
 std::unique_ptr<Constraint> Any::getVerifier(
-    SmallVector<Value> const &valueToConstr,
-    DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types,
-    DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> &attrs) {
+    ArrayRef<Value> valueToConstr,
+    DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> const &types,
+    DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> const
+        &attrs) {
+  // `irdl.any` takes no operands, so none of the inputs are needed.
   return std::make_unique<AnyAttributeConstraint>();
 }

diff  --git a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
index f65d0eceb03a1..07f0e4e5e443e 100644
--- a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
+++ b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
@@ -239,6 +239,116 @@ static DynamicAttrDefinition::VerifierFn getAttrOrTypeVerifier(
   return std::move(verifier);
 }
 
+/// Get the possible bases of a constraint. Return `true` if all bases can
+/// potentially be matched, i.e. the constraint may accept attributes of any
+/// base.
+/// A base is a type or an attribute definition. For instance, the base of
+/// `irdl.parametric "!builtin.complex"(...)` is `builtin.complex`.
+/// This function returns the following information through arguments:
+/// - `paramIds`: the set of type or attribute IDs that are used as bases.
+///   NOTE(review): none of the cases below insert into this set yet.
+/// - `paramIrdlOps`: the set of IRDL operations that are used as bases.
+/// - `isIds`: the set of type or attribute IDs that are used in `irdl.is`
+///   constraints.
+static bool getBases(Operation *op, SmallPtrSet<TypeID, 4> &paramIds,
+                     SmallPtrSet<Operation *, 4> &paramIrdlOps,
+                     SmallPtrSet<TypeID, 4> &isIds) {
+  // For `irdl.any_of`, we get the bases from all its arguments. The union can
+  // match any base as soon as one of its arguments can. `|=` (rather than
+  // `||`) keeps visiting every argument so all of their bases are collected.
+  if (auto anyOf = dyn_cast<AnyOf>(op)) {
+    bool hasAny = false;
+    for (Value arg : anyOf.getArgs())
+      hasAny |= getBases(arg.getDefiningOp(), paramIds, paramIrdlOps, isIds);
+    return hasAny;
+  }
+
+  // For `irdl.all_of`, we get the bases from the first argument.
+  // This is restrictive, but we can relax it later if needed.
+  if (auto allOf = dyn_cast<AllOf>(op)) {
+    // `irdl.all_of` is variadic and may have no arguments; conservatively
+    // report that it can match any base instead of indexing out of bounds.
+    if (allOf.getArgs().empty())
+      return true;
+    return getBases(allOf.getArgs().front().getDefiningOp(), paramIds,
+                    paramIrdlOps, isIds);
+  }
+
+  // For `irdl.parametric`, we get directly the base from the operation.
+  if (auto params = dyn_cast<Parametric>(op)) {
+    SymbolRefAttr symRef = params.getBaseType();
+    Operation *defOp = SymbolTable::lookupNearestSymbolFrom(op, symRef);
+    assert(defOp && "symbol reference should refer to an existing operation");
+    paramIrdlOps.insert(defOp);
+    return false;
+  }
+
+  // For `irdl.is`, we get the base TypeID directly.
+  if (auto is = dyn_cast<Is>(op)) {
+    Attribute expected = is.getExpected();
+    isIds.insert(expected.getTypeID());
+    return false;
+  }
+
+  // For `irdl.any`, we return `true` since we can match any type or attribute
+  // base.
+  if (isa<Any>(op))
+    return true;
+
+  llvm_unreachable("unknown IRDL constraint");
+}
+
+/// Check that an any_of is in the subset IRDL can handle.
+/// IRDL uses a greedy algorithm to match constraints. This means that if we
+/// encounter an `any_of` with multiple constraints, we will match the first
+/// constraint that is satisfied. Thus, the order of constraints matter in
+/// `any_of` with our current algorithm.
+/// In order to make the order of constraints irrelevant, we require that
+/// all `any_of` constraint parameters are disjoint. For this, we check that
+/// the base parameters are all disjoint between `parametric` operations, and
+/// that they are disjoint between `parametric` and `is` operations.
+/// This restriction will be relaxed in the future, when we will change our
+/// algorithm to be non-greedy.
+static LogicalResult checkCorrectAnyOf(AnyOf anyOf) {
+  SmallPtrSet<TypeID, 4> paramIds;
+  SmallPtrSet<Operation *, 4> paramIrdlOps;
+  SmallPtrSet<TypeID, 4> isIds;
+
+  for (Value arg : anyOf.getArgs()) {
+    Operation *argOp = arg.getDefiningOp();
+    SmallPtrSet<TypeID, 4> argParamIds;
+    SmallPtrSet<Operation *, 4> argParamIrdlOps;
+    SmallPtrSet<TypeID, 4> argIsIds;
+
+    // Get the bases of this argument. If it can match any type or attribute,
+    // then our `any_of` should not be allowed.
+    if (getBases(argOp, argParamIds, argParamIrdlOps, argIsIds))
+      return failure();
+
+    // We check that the base parameters are all disjoint between `parametric`
+    // operations, and that they are disjoint between `parametric` and `is`
+    // operations.
+    for (TypeID id : argParamIds) {
+      if (isIds.count(id))
+        return failure();
+      bool inserted = paramIds.insert(id).second;
+      if (!inserted)
+        return failure();
+    }
+
+    // We check that this argument's `irdl.is` bases are disjoint from the
+    // `parametric` bases, then record them for the following arguments.
+    // `is` bases may overlap each other (e.g. `i32` and `i64` share a TypeID).
+    for (TypeID id : argIsIds) {
+      if (paramIds.count(id))
+        return failure();
+      isIds.insert(id);
+    }
+
+    // We check that all `parametric` operations are disjoint. We do not
+    // need to check that they are disjoint with `is` operations, since
+    // `is` operations cannot refer to attributes defined with `irdl.parametric`
+    // operations.
+    for (Operation *op : argParamIrdlOps) {
+      bool inserted = paramIrdlOps.insert(op).second;
+      if (!inserted)
+        return failure();
+    }
+  }
+
+  return success();
+}
+
 /// Load all dialects in the given module, without loading any operation, type
 /// or attribute definitions.
 static DenseMap<DialectOp, ExtensibleDialect *> loadEmptyDialects(ModuleOp op) {
@@ -292,6 +402,13 @@ preallocateAttrDefs(ModuleOp op,
 }
 
 LogicalResult mlir::irdl::loadDialects(ModuleOp op) {
+  // First, check that all any_of constraints are in a correct form.
+  // This is to ensure we can do the verification correctly.
+  WalkResult anyOfCorrects =
+      op.walk([](AnyOf anyOf) { return (WalkResult)checkCorrectAnyOf(anyOf); });
+  if (anyOfCorrects.wasInterrupted())
+    return op.emitError("any_of constraints are not in the correct form");
+
   // Preallocate all dialects, and type and attribute definitions.
   // In particular, this allocates TypeIDs so type and attributes can have
   // verifiers that refer to each other.

diff  --git a/mlir/test/Dialect/IRDL/cmath.irdl.mlir b/mlir/test/Dialect/IRDL/cmath.irdl.mlir
index aaa51501791c0..997af08d24733 100644
--- a/mlir/test/Dialect/IRDL/cmath.irdl.mlir
+++ b/mlir/test/Dialect/IRDL/cmath.irdl.mlir
@@ -6,11 +6,15 @@ module {
 
     // CHECK: irdl.type @complex {
     // CHECK:   %[[v0:[^ ]*]] = irdl.is f32
-    // CHECK:   irdl.parameters(%[[v0]])
+    // CHECK:   %[[v1:[^ ]*]] = irdl.is f64
+    // CHECK:   %[[v2:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]])
+    // CHECK:   irdl.parameters(%[[v2]])
     // CHECK: }
     irdl.type @complex {
       %0 = irdl.is f32
-      irdl.parameters(%0)
+      %1 = irdl.is f64
+      %2 = irdl.any_of(%0, %1)
+      irdl.parameters(%2)
     }
 
     // CHECK: irdl.operation @norm {
@@ -28,13 +32,17 @@ module {
 
     // CHECK: irdl.operation @mul {
     // CHECK:   %[[v0:[^ ]*]] = irdl.is f32
-    // CHECK:   %[[v3:[^ ]*]] = irdl.parametric @complex<%[[v0]]>
+    // CHECK:   %[[v1:[^ ]*]] = irdl.is f64
+    // CHECK:   %[[v2:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]])
+    // CHECK:   %[[v3:[^ ]*]] = irdl.parametric @complex<%[[v2]]>
     // CHECK:   irdl.operands(%[[v3]], %[[v3]])
     // CHECK:   irdl.results(%[[v3]])
     // CHECK: }
     irdl.operation @mul {
       %0 = irdl.is f32
-      %3 = irdl.parametric @complex<%0>
+      %1 = irdl.is f64
+      %2 = irdl.any_of(%0, %1)
+      %3 = irdl.parametric @complex<%2>
       irdl.operands(%3, %3)
       irdl.results(%3)
     }

diff  --git a/mlir/test/Dialect/IRDL/cyclic-types.irdl.mlir b/mlir/test/Dialect/IRDL/cyclic-types.irdl.mlir
new file mode 100644
index 0000000000000..db8dfc5cb36ca
--- /dev/null
+++ b/mlir/test/Dialect/IRDL/cyclic-types.irdl.mlir
@@ -0,0 +1,50 @@
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+
+// Types that have cyclic references.
+
+// CHECK: irdl.dialect @testd {
+irdl.dialect @testd {
+  // CHECK:   irdl.type @self_referencing {
+  // CHECK:   %[[v0:[^ ]*]] = irdl.any
+  // CHECK:   %[[v1:[^ ]*]] = irdl.parametric @self_referencing<%[[v0]]>
+  // CHECK:   %[[v2:[^ ]*]] = irdl.is i32
+  // CHECK:   %[[v3:[^ ]*]] = irdl.any_of(%[[v1]], %[[v2]])
+  // CHECK:   irdl.parameters(%[[v3]])
+  // CHECK: }
+  irdl.type @self_referencing {
+    %0 = irdl.any
+    %1 = irdl.parametric @self_referencing<%0>
+    %2 = irdl.is i32
+    %3 = irdl.any_of(%1, %2)
+    irdl.parameters(%3)
+  }
+
+
+  // CHECK:   irdl.type @type1 {
+  // CHECK:   %[[v0:[^ ]*]] = irdl.any
+  // CHECK:   %[[v1:[^ ]*]] = irdl.parametric @type2<%[[v0]]>
+  // CHECK:   %[[v2:[^ ]*]] = irdl.is i32
+  // CHECK:   %[[v3:[^ ]*]] = irdl.any_of(%[[v1]], %[[v2]])
+  // CHECK:   irdl.parameters(%[[v3]])
+  irdl.type @type1 {
+    %0 = irdl.any
+    %1 = irdl.parametric @type2<%0>
+    %2 = irdl.is i32
+    %3 = irdl.any_of(%1, %2)
+    irdl.parameters(%3)
+  }
+
+  // CHECK:   irdl.type @type2 {
+  // CHECK:   %[[v0:[^ ]*]] = irdl.any
+  // CHECK:   %[[v1:[^ ]*]] = irdl.parametric @type1<%[[v0]]>
+  // CHECK:   %[[v2:[^ ]*]] = irdl.is i32
+  // CHECK:   %[[v3:[^ ]*]] = irdl.any_of(%[[v1]], %[[v2]])
+  // CHECK:   irdl.parameters(%[[v3]])
+  irdl.type @type2 {
+      %0 = irdl.any
+      %1 = irdl.parametric @type1<%0>
+      %2 = irdl.is i32
+      %3 = irdl.any_of(%1, %2)
+      irdl.parameters(%3)
+  }
+}

diff  --git a/mlir/test/Dialect/IRDL/cyclic-types.mlir b/mlir/test/Dialect/IRDL/cyclic-types.mlir
new file mode 100644
index 0000000000000..56dc2d61787b6
--- /dev/null
+++ b/mlir/test/Dialect/IRDL/cyclic-types.mlir
@@ -0,0 +1,57 @@
+// RUN: mlir-opt %s --irdl-file=%S/cyclic-types.irdl.mlir -split-input-file -verify-diagnostics | FileCheck %s
+
+// Types that have cyclic references.
+
+// CHECK: !testd.self_referencing<i32>
+func.func @no_references(%v: !testd.self_referencing<i32>) {
+  return
+}
+
+// -----
+
+// CHECK: !testd.self_referencing<!testd.self_referencing<i32>>
+func.func @one_reference(%v: !testd.self_referencing<!testd.self_referencing<i32>>) {
+  return
+}
+
+// -----
+
+// expected-error at +1 {{'i64' does not satisfy the constraint}}
+func.func @wrong_parameter(%v: !testd.self_referencing<i64>) {
+  return
+}
+
+// -----
+
+// CHECK: !testd.type1<i32>
+func.func @type1_no_references(%v: !testd.type1<i32>) {
+  return
+}
+
+// -----
+
+// CHECK: !testd.type1<!testd.type2<i32>>
+func.func @type1_one_references(%v: !testd.type1<!testd.type2<i32>>) {
+  return
+}
+
+// -----
+
+// CHECK: !testd.type1<!testd.type2<!testd.type1<i32>>>
+func.func @type1_two_references(%v: !testd.type1<!testd.type2<!testd.type1<i32>>>) {
+  return
+}
+
+// -----
+
+// expected-error at +1 {{'i64' does not satisfy the constraint}}
+func.func @wrong_parameter_type1(%v: !testd.type1<i64>) {
+  return
+}
+
+// -----
+
+// expected-error at +1 {{'i64' does not satisfy the constraint}}
+func.func @wrong_parameter_type2(%v: !testd.type2<i64>) {
+  return
+}

diff  --git a/mlir/test/Dialect/IRDL/test-type.irdl.mlir b/mlir/test/Dialect/IRDL/test-type.irdl.mlir
new file mode 100644
index 0000000000000..1bcfb0b8e20be
--- /dev/null
+++ b/mlir/test/Dialect/IRDL/test-type.irdl.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+
+module {
+  // CHECK-LABEL: irdl.dialect @testd {
+  irdl.dialect @testd {
+    // CHECK: irdl.type @singleton
+    irdl.type @singleton
+
+    // CHECK: irdl.type @parametrized {
+    // CHECK:   %[[v0:[^ ]*]] = irdl.any
+    // CHECK:   %[[v1:[^ ]*]] = irdl.is i32
+    // CHECK:   %[[v2:[^ ]*]] = irdl.is i64
+    // CHECK:   %[[v3:[^ ]*]] = irdl.any_of(%[[v1]], %[[v2]])
+    // CHECK:   irdl.parameters(%[[v0]], %[[v3]])
+    // CHECK: }
+    irdl.type @parametrized {
+      %0 = irdl.any
+      %1 = irdl.is i32
+      %2 = irdl.is i64
+      %3 = irdl.any_of(%1, %2)
+      irdl.parameters(%0, %3)
+    }
+
+    // CHECK: irdl.operation @any {
+    // CHECK:   %[[v0:[^ ]*]] = irdl.any
+    // CHECK:   irdl.results(%[[v0]])
+    // CHECK: }
+    irdl.operation @any {
+      %0 = irdl.any
+      irdl.results(%0)
+    }
+  }
+}

diff  --git a/mlir/test/Dialect/IRDL/test-type.mlir b/mlir/test/Dialect/IRDL/test-type.mlir
new file mode 100644
index 0000000000000..9f79ebe4ba038
--- /dev/null
+++ b/mlir/test/Dialect/IRDL/test-type.mlir
@@ -0,0 +1,35 @@
+// RUN: mlir-opt %s --irdl-file=%S/test-type.irdl.mlir -split-input-file -verify-diagnostics | FileCheck %s
+
+func.func @succeededTypeVerifier() {
+    // CHECK: "testd.any"() : () -> !testd.singleton
+    "testd.any"() : () -> !testd.singleton
+
+    // CHECK-NEXT: "testd.any"() : () -> !testd.parametrized<f32, i32>
+    "testd.any"() : () -> !testd.parametrized<f32, i32>
+
+    // CHECK: "testd.any"() : () -> !testd.parametrized<i1, i64>
+    "testd.any"() : () -> !testd.parametrized<i1, i64>
+
+    return
+}
+
+// -----
+
+func.func @failedSingletonVerifier() {
+     // expected-error at +1 {{expected 0 type arguments, but had 1}}
+     "testd.any"() : () -> !testd.singleton<i32>
+}
+
+// -----
+
+func.func @failedParametrizedVerifierWrongNumOfArgs() {
+     // expected-error at +1 {{expected 2 type arguments, but had 1}}
+     "testd.any"() : () -> !testd.parametrized<i32>
+}
+
+// -----
+
+func.func @failedParametrizedVerifierWrongArgument() {
+     // expected-error at +1 {{'i1' does not satisfy the constraint}}
+     "testd.any"() : () -> !testd.parametrized<i32, i1>
+}

diff  --git a/mlir/test/Dialect/IRDL/testd.irdl.mlir b/mlir/test/Dialect/IRDL/testd.irdl.mlir
index d4a33ca38a199..939b422759561 100644
--- a/mlir/test/Dialect/IRDL/testd.irdl.mlir
+++ b/mlir/test/Dialect/IRDL/testd.irdl.mlir
@@ -29,6 +29,34 @@ irdl.dialect @testd {
     irdl.results(%0)
   }
 
+  // CHECK: irdl.operation @anyof {
+  // CHECK:   %[[v0:[^ ]*]] = irdl.is i32
+  // CHECK:   %[[v1:[^ ]*]] = irdl.is i64
+  // CHECK:   %[[v2:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]])
+  // CHECK:   irdl.results(%[[v2]])
+  // CHECK: }
+  irdl.operation @anyof {
+    %0 = irdl.is i32
+    %1 = irdl.is i64
+    %2 = irdl.any_of(%0, %1)
+    irdl.results(%2)
+  }
+
+  // CHECK: irdl.operation @all_of {
+  // CHECK:   %[[v0:[^ ]*]] = irdl.is i32
+  // CHECK:   %[[v1:[^ ]*]] = irdl.is i64
+  // CHECK:   %[[v2:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]])
+  // CHECK:   %[[v3:[^ ]*]] = irdl.all_of(%[[v2]], %[[v1]])
+  // CHECK:   irdl.results(%[[v3]])
+  // CHECK: }
+  irdl.operation @all_of {
+    %0 = irdl.is i32
+    %1 = irdl.is i64
+    %2 = irdl.any_of(%0, %1)
+    %3 = irdl.all_of(%2, %1)
+    irdl.results(%3)
+  }
+
   // CHECK: irdl.operation @any {
   // CHECK:   %[[v0:[^ ]*]] = irdl.any
   // CHECK:   irdl.results(%[[v0]])
@@ -51,21 +79,29 @@ irdl.dialect @testd {
 
   // CHECK: irdl.operation @dynparams {
   // CHECK:   %[[v0:[^ ]*]] = irdl.is i32
-  // CHECK:   %[[v3:[^ ]*]] = irdl.parametric @parametric<%[[v0]]>
+  // CHECK:   %[[v1:[^ ]*]] = irdl.is i64
+  // CHECK:   %[[v2:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]])
+  // CHECK:   %[[v3:[^ ]*]] = irdl.parametric @parametric<%[[v2]]>
   // CHECK:   irdl.results(%[[v3]])
   // CHECK: }
   irdl.operation @dynparams {
     %0 = irdl.is i32
-    %3 = irdl.parametric @parametric<%0>
+    %1 = irdl.is i64
+    %2 = irdl.any_of(%0, %1)
+    %3 = irdl.parametric @parametric<%2>
     irdl.results(%3)
   }
 
   // CHECK: irdl.operation @constraint_vars {
-  // CHECK:   %[[v0:[^ ]*]] = irdl.any
-  // CHECK:   irdl.results(%[[v0]], %[[v0]])
+  // CHECK:   %[[v0:[^ ]*]] = irdl.is i32
+  // CHECK:   %[[v1:[^ ]*]] = irdl.is i64
+  // CHECK:   %[[v2:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]])
+  // CHECK:   irdl.results(%[[v2]], %[[v2]])
   // CHECK: }
   irdl.operation @constraint_vars {
-    %0 = irdl.any
-    irdl.results(%0, %0)
+    %0 = irdl.is i32
+    %1 = irdl.is i64
+    %2 = irdl.any_of(%0, %1)
+    irdl.results(%2, %2)
   }
 }

diff  --git a/mlir/test/Dialect/IRDL/testd.mlir b/mlir/test/Dialect/IRDL/testd.mlir
index e9be54b60d0b9..abda3ccd5cbf4 100644
--- a/mlir/test/Dialect/IRDL/testd.mlir
+++ b/mlir/test/Dialect/IRDL/testd.mlir
@@ -55,6 +55,56 @@ func.func @failedEqConstraint() {
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// AnyOf constraint
+//===----------------------------------------------------------------------===//
+
+func.func @succeededAnyOfConstraint() {
+  // CHECK: "testd.anyof"() : () -> i32
+  "testd.anyof"() : () -> i32
+  // CHECK: "testd.anyof"() : () -> i64
+  "testd.anyof"() : () -> i64
+  return
+}
+
+// -----
+
+func.func @failedAnyOfConstraint() {
+  // expected-error at +1 {{'i1' does not satisfy the constraint}}
+  "testd.anyof"() : () -> i1
+  return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// AllOf constraint
+//===----------------------------------------------------------------------===//
+
+func.func @succeededAllOfConstraint() {
+  // CHECK: "testd.all_of"() : () -> i64
+  "testd.all_of"() : () -> i64
+  return
+}
+
+// -----
+
+func.func @failedAllOfConstraint1() {
+  // expected-error at +1 {{'i1' does not satisfy the constraint}}
+  "testd.all_of"() : () -> i1
+  return
+}
+
+// -----
+
+func.func @failedAllOfConstraint2() {
+  // expected-error at +1 {{expected 'i64' but got 'i32'}}
+  "testd.all_of"() : () -> i32
+  return
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // Any constraint
 //===----------------------------------------------------------------------===//
@@ -76,8 +126,10 @@ func.func @succeededAnyConstraint() {
 func.func @succeededDynBaseConstraint() {
   // CHECK: "testd.dynbase"() : () -> !testd.parametric<i32>
   "testd.dynbase"() : () -> !testd.parametric<i32>
-  // CHECK: "testd.dynbase"() : () -> !testd.parametric<!testd.parametric<i32>>
-  "testd.dynbase"() : () -> !testd.parametric<!testd.parametric<i32>>
+  // CHECK: "testd.dynbase"() : () -> !testd.parametric<i64>
+  "testd.dynbase"() : () -> !testd.parametric<i64>
+  // CHECK: "testd.dynbase"() : () -> !testd.parametric<!testd.parametric<i64>>
+  "testd.dynbase"() : () -> !testd.parametric<!testd.parametric<i64>>
   return
 }
 
@@ -98,6 +150,8 @@ func.func @failedDynBaseConstraint() {
 func.func @succeededDynParamsConstraint() {
   // CHECK: "testd.dynparams"() : () -> !testd.parametric<i32>
   "testd.dynparams"() : () -> !testd.parametric<i32>
+  // CHECK: "testd.dynparams"() : () -> !testd.parametric<i64>
+  "testd.dynparams"() : () -> !testd.parametric<i64>
   return
 }
 
@@ -112,7 +166,7 @@ func.func @failedDynParamsConstraintBase() {
 // -----
 
 func.func @failedDynParamsConstraintParam() {
-  // expected-error at +1 {{expected 'i32' but got 'i1'}}
+  // expected-error at +1 {{'i1' does not satisfy the constraint}}
   "testd.dynparams"() : () -> !testd.parametric<i1>
   return
 }


        


More information about the Mlir-commits mailing list