[mlir] [clang] [llvm] [clang-tools-extra] [mlir][irdl] Add `irdl.base` op (PR #76400)

Fehr Mathieu via cfe-commits cfe-commits at lists.llvm.org
Wed Jan 17 05:20:08 PST 2024


https://github.com/math-fehr updated https://github.com/llvm/llvm-project/pull/76400

>From 4363403ffcff10844c304426cb92bc559cf0d95c Mon Sep 17 00:00:00 2001
From: Mathieu Fehr <mathieu.fehr at gmail.com>
Date: Sat, 23 Dec 2023 17:11:46 +0000
Subject: [PATCH 1/2] [mlir][irdl] Add irdl.base operation

---
 mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td  | 41 ++++++++++++
 .../include/mlir/Dialect/IRDL/IRDLVerifiers.h | 42 +++++++++++++
 mlir/lib/Dialect/IRDL/IR/IRDL.cpp             | 33 ++++++++++
 mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp          | 54 ++++++++++++++++
 mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp       | 33 ++++++++++
 mlir/test/Dialect/IRDL/invalid.irdl.mlir      | 43 +++++++++++++
 mlir/test/Dialect/IRDL/testd.irdl.mlir        | 48 +++++++++++---
 mlir/test/Dialect/IRDL/testd.mlir             | 63 ++++++++++++++++---
 8 files changed, 340 insertions(+), 17 deletions(-)
 create mode 100644 mlir/test/Dialect/IRDL/invalid.irdl.mlir

diff --git a/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td b/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td
index 681425f8174426a..c63a3a70f6703f6 100644
--- a/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td
+++ b/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td
@@ -451,6 +451,47 @@ def IRDL_IsOp : IRDL_ConstraintOp<"is",
   let assemblyFormat = " $expected ` ` attr-dict ";
 }
 
+// Constraint on the base of a type or attribute, i.e. its concrete class as
+// identified by TypeID, given either by an IRDL definition symbol or by the
+// registered name of a type or attribute.
+def IRDL_BaseOp : IRDL_ConstraintOp<"base",
+    [ParentOneOf<["TypeOp", "AttributeOp", "OperationOp"]>,
+     // verifySymbolUses checks that `base_ref` names an IRDL type/attribute.
+     DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
+  let summary = "Constraints an attribute/type base";
+  let description = [{
+    `irdl.base` defines a constraint that only accepts a single type
+    or attribute base, e.g. an `IntegerType`. The attribute base is defined
+    either by a symbolic reference to the corresponding IRDL definition,
+    or by the name of the base. Named bases are prefixed with `!` or `#`
+    respectively for types and attributes.
+
+    Example:
+
+    ```mlir
+    irdl.dialect @cmath {
+      irdl.type @complex {
+        %0 = irdl.base "!builtin.integer"
+        irdl.parameters(%0)
+      }
+
+      irdl.type @complex_wrapper {
+        %0 = irdl.base @complex
+        irdl.parameters(%0)
+      }
+    }
+    ```
+
+    The above program defines a `cmath.complex` type that expects a single
+    parameter, which is a type with base name `builtin.integer`, which is the
+    name of an `IntegerType` type.
+    It also defines a `cmath.complex_wrapper` type that expects a single
+    parameter, which is a type of base type `cmath.complex`.
+  }];
+
+  // Exactly one of `base_ref` and `base_name` must be present, and a name must
+  // start with '!' (type) or '#' (attribute); both are checked by the verifier.
+  let arguments = (ins OptionalAttr<SymbolRefAttr>:$base_ref,
+                       OptionalAttr<StrAttr>:$base_name);
+  let results = (outs IRDL_AttributeType:$output);
+  let assemblyFormat = " ($base_ref^)? ($base_name^)? ` ` attr-dict";
+  let hasVerifier = 1;
+}
+
 def IRDL_ParametricOp : IRDL_ConstraintOp<"parametric",
     [ParentOneOf<["TypeOp", "AttributeOp", "OperationOp"]>, Pure]> {
   let summary = "Constraints an attribute/type base and its parameters";
diff --git a/mlir/include/mlir/Dialect/IRDL/IRDLVerifiers.h b/mlir/include/mlir/Dialect/IRDL/IRDLVerifiers.h
index f8ce77cbc50e9ed..9ecb7c0107d7f8a 100644
--- a/mlir/include/mlir/Dialect/IRDL/IRDLVerifiers.h
+++ b/mlir/include/mlir/Dialect/IRDL/IRDLVerifiers.h
@@ -99,6 +99,48 @@ class IsConstraint : public Constraint {
   Attribute expectedAttribute;
 };
 
+/// Accepts an attribute only when its concrete attribute class (compared by
+/// TypeID) is the expected base, e.g. IntegerAttr.
+class BaseAttrConstraint : public Constraint {
+public:
+  BaseAttrConstraint(TypeID expectedTypeID, StringRef displayName)
+      : baseTypeID(expectedTypeID), baseName(displayName) {}
+
+  virtual ~BaseAttrConstraint() = default;
+
+  LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+                       Attribute attr,
+                       ConstraintVerifier &context) const override;
+
+private:
+  /// TypeID an attribute must have to satisfy this constraint.
+  TypeID baseTypeID;
+
+  /// Name of the base shown in diagnostics. Non-owning: the referenced
+  /// characters must outlive the constraint.
+  StringRef baseName;
+};
+
+/// Accepts a type (carried as a TypeAttr) only when its concrete type class
+/// (compared by TypeID) is the expected base, e.g. IntegerType.
+class BaseTypeConstraint : public Constraint {
+public:
+  BaseTypeConstraint(TypeID expectedTypeID, StringRef displayName)
+      : baseTypeID(expectedTypeID), baseName(displayName) {}
+
+  virtual ~BaseTypeConstraint() = default;
+
+  LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+                       Attribute attr,
+                       ConstraintVerifier &context) const override;
+
+private:
+  /// TypeID a type must have to satisfy this constraint.
+  TypeID baseTypeID;
+
+  /// Name of the base shown in diagnostics. Non-owning: the referenced
+  /// characters must outlive the constraint.
+  StringRef baseName;
+};
+
 /// A constraint that checks that an attribute is of a
 /// specific dynamic attribute definition, and that all of its parameters
 /// satisfy the given constraints.
diff --git a/mlir/lib/Dialect/IRDL/IR/IRDL.cpp b/mlir/lib/Dialect/IRDL/IR/IRDL.cpp
index 33c6bb869a643f3..4eae2b03024c248 100644
--- a/mlir/lib/Dialect/IRDL/IR/IRDL.cpp
+++ b/mlir/lib/Dialect/IRDL/IR/IRDL.cpp
@@ -117,6 +117,39 @@ LogicalResult AttributesOp::verify() {
   return success();
 }
 
+LogicalResult BaseOp::verify() {
+  std::optional<SymbolRefAttr> ref = getBaseRef();
+  std::optional<StringRef> name = getBaseName();
+
+  // The base must be given in exactly one way: by symbol or by name.
+  if (ref.has_value() == name.has_value())
+    return emitOpError() << "the base type or attribute should be specified by "
+                            "either a name or a reference";
+
+  // A name must carry a '!' (type) or '#' (attribute) sigil.
+  if (name && !name->starts_with("!") && !name->starts_with("#"))
+    return emitOpError() << "the base type or attribute name should start with "
+                            "'!' or '#'";
+  return success();
+}
+
+LogicalResult BaseOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  std::optional<SymbolRefAttr> ref = getBaseRef();
+  // String-named bases reference no symbol, so there is nothing to resolve.
+  if (!ref)
+    return success();
+
+  // Accept the reference if it resolves, from the nearest enclosing symbol
+  // table, to either an IRDL type definition or an IRDL attribute definition.
+  bool resolves =
+      symbolTable.lookupNearestSymbolFrom<TypeOp>(*this, *ref) ||
+      symbolTable.lookupNearestSymbolFrom<AttributeOp>(*this, *ref);
+  if (resolves)
+    return success();
+
+  return emitOpError() << "'" << *ref
+                       << "' does not refer to a type or attribute definition";
+}
+
 /// Parse a value with its variadicity first. By default, the variadicity is
 /// single.
 ///
diff --git a/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp b/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp
index e172039712f24c9..0895306b8bce1a9 100644
--- a/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp
+++ b/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp
@@ -37,6 +37,60 @@ std::unique_ptr<Constraint> IsOp::getVerifier(
   return std::make_unique<IsConstraint>(getExpectedAttr());
 }
 
+std::unique_ptr<Constraint> BaseOp::getVerifier(
+    ArrayRef<Value> valueToConstr,
+    DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> const &types,
+    DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> const
+        &attrs) {
+  MLIRContext *ctx = getContext();
+
+  // Symbolic base: the reference names an IRDL type or attribute definition
+  // (checked in verifySymbolUses), and the constraint matches on the TypeID of
+  // the dynamic definition built for it.
+  if (std::optional<SymbolRefAttr> baseRef = getBaseRef()) {
+    Operation *defOp =
+        SymbolTable::lookupNearestSymbolFrom(getOperation(), baseRef.value());
+
+    // The constraint only keeps a StringRef to the name, so the qualified
+    // "dialect.name" string is materialized as a context-owned StringAttr.
+    if (auto typeOp = dyn_cast<TypeOp>(defOp)) {
+      DynamicTypeDefinition *typeDef = types.at(typeOp).get();
+      StringAttr qualifiedName =
+          StringAttr::get(ctx, typeDef->getDialect()->getNamespace() + "." +
+                                   typeDef->getName().str());
+      return std::make_unique<BaseTypeConstraint>(typeDef->getTypeID(),
+                                                  qualifiedName);
+    }
+
+    auto attrOp = cast<AttributeOp>(defOp);
+    DynamicAttrDefinition *attrDef = attrs.at(attrOp).get();
+    StringAttr qualifiedName =
+        StringAttr::get(ctx, attrDef->getDialect()->getNamespace() + "." +
+                                 attrDef->getName().str());
+    return std::make_unique<BaseAttrConstraint>(attrDef->getTypeID(),
+                                                qualifiedName);
+  }
+
+  // Named base: a registered type ('!' prefix) or, otherwise, a registered
+  // attribute ('#' prefix, as enforced by the op verifier).
+  StringRef baseName = getBaseName().value();
+  StringRef registeredName = baseName.drop_front(1);
+
+  if (baseName.front() == '!') {
+    if (auto abstractType = AbstractType::lookup(registeredName, ctx))
+      return std::make_unique<BaseTypeConstraint>(
+          abstractType->get().getTypeID(), abstractType->get().getName());
+    emitError() << "no registered type with name " << baseName;
+    return nullptr;
+  }
+
+  if (auto abstractAttr = AbstractAttribute::lookup(registeredName, ctx))
+    return std::make_unique<BaseAttrConstraint>(
+        abstractAttr->get().getTypeID(), abstractAttr->get().getName());
+  emitError() << "no registered attribute with name " << baseName;
+  return nullptr;
+}
+
 std::unique_ptr<Constraint> ParametricOp::getVerifier(
     ArrayRef<Value> valueToConstr,
     DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> const &types,
diff --git a/mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp b/mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp
index 90b068ba35831b1..2310c11ea0e8edf 100644
--- a/mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp
+++ b/mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp
@@ -69,6 +69,39 @@ LogicalResult IsConstraint::verify(function_ref<InFlightDiagnostic()> emitError,
   return failure();
 }
 
+LogicalResult
+BaseAttrConstraint::verify(function_ref<InFlightDiagnostic()> emitError,
+                           Attribute attr, ConstraintVerifier &context) const {
+  const bool matches = attr.getTypeID() == baseTypeID;
+
+  // Succeed on a match; on a mismatch, report only if a diagnostic is wanted.
+  if (matches || !emitError)
+    return success(matches);
+  return emitError() << "expected base attribute '" << baseName
+                     << "' but got '" << attr.getAbstractAttribute().getName()
+                     << "'";
+}
+
+LogicalResult
+BaseTypeConstraint::verify(function_ref<InFlightDiagnostic()> emitError,
+                           Attribute attr, ConstraintVerifier &context) const {
+  // Only a TypeAttr (a type used as a value) can satisfy a type constraint.
+  auto typeAttr = dyn_cast<TypeAttr>(attr);
+  if (!typeAttr) {
+    if (emitError)
+      return emitError() << "expected type, got attribute '" << attr << "'";
+    return failure();
+  }
+  // Compare the concrete class of the wrapped type against the expected base.
+  Type type = typeAttr.getValue();
+  if (type.getTypeID() == baseTypeID)
+    return success();
+
+  if (emitError)
+    return emitError() << "expected base type '" << baseName << "' but got '"
+                       << type.getAbstractType().getName() << "'";
+  return failure();
+}
+
 LogicalResult DynParametricAttrConstraint::verify(
     function_ref<InFlightDiagnostic()> emitError, Attribute attr,
     ConstraintVerifier &context) const {
diff --git a/mlir/test/Dialect/IRDL/invalid.irdl.mlir b/mlir/test/Dialect/IRDL/invalid.irdl.mlir
new file mode 100644
index 000000000000000..d62bb498a7ad982
--- /dev/null
+++ b/mlir/test/Dialect/IRDL/invalid.irdl.mlir
@@ -0,0 +1,43 @@
+// RUN: mlir-opt %s -verify-diagnostics -split-input-file
+
+// Invalid uses of `irdl.base`; each case must be rejected with the
+// diagnostic given in its expected-error annotation.
+
+func.func private @foo()
+
+irdl.dialect @testd {
+  irdl.type @type {
+    // @foo does not name an IRDL type or attribute definition.
+    // expected-error@+1 {{'@foo' does not refer to a type or attribute definition}}
+    %0 = irdl.base @foo
+    irdl.parameters(%0)
+  }
+}
+
+// -----
+
+irdl.dialect @testd {
+  irdl.type @type {
+    // Names must start with '!' (type) or '#' (attribute).
+    // expected-error@+1 {{the base type or attribute name should start with '!' or '#'}}
+    %0 = irdl.base "builtin.integer"
+    irdl.parameters(%0)
+  }
+}
+
+// -----
+
+irdl.dialect @testd {
+  irdl.type @type {
+    // An empty name has no '!' or '#' prefix either.
+    // expected-error@+1 {{the base type or attribute name should start with '!' or '#'}}
+    %0 = irdl.base ""
+    irdl.parameters(%0)
+  }
+}
+
+// -----
+
+irdl.dialect @testd {
+  irdl.type @type {
+    // Neither a symbol reference nor a name is given.
+    // expected-error@+1 {{the base type or attribute should be specified by either a name}}
+    %0 = irdl.base
+    irdl.parameters(%0)
+  }
+}
diff --git a/mlir/test/Dialect/IRDL/testd.irdl.mlir b/mlir/test/Dialect/IRDL/testd.irdl.mlir
index 684286e4afeb0fe..f828d95bdb81d59 100644
--- a/mlir/test/Dialect/IRDL/testd.irdl.mlir
+++ b/mlir/test/Dialect/IRDL/testd.irdl.mlir
@@ -11,6 +11,15 @@ irdl.dialect @testd {
     irdl.parameters(%0)
   }
 
+  // Parametric attribute; @dyn_attr_base below uses it as a symbolic base.
+  // CHECK: irdl.attribute @parametric_attr {
+  // CHECK:  %[[v0:[^ ]*]] = irdl.any
+  // CHECK:  irdl.parameters(%[[v0]])
+  // CHECK: }
+  irdl.attribute @parametric_attr {
+    %0 = irdl.any
+    irdl.parameters(%0)
+  }
+
   // CHECK: irdl.type @attr_in_type_out {
   // CHECK:   %[[v0:[^ ]*]] = irdl.any
   // CHECK:   irdl.parameters(%[[v0]])
@@ -66,15 +75,40 @@ irdl.dialect @testd {
     irdl.results(%0)
   }
 
-  // CHECK: irdl.operation @dynbase {
-  // CHECK:   %[[v0:[^ ]*]] = irdl.any
-  // CHECK:   %[[v1:[^ ]*]] = irdl.parametric @parametric<%[[v0]]>
+  // Base given by a symbol reference to the IRDL type @parametric.
+  // CHECK: irdl.operation @dyn_type_base {
+  // CHECK:   %[[v1:[^ ]*]] = irdl.base @parametric
   // CHECK:   irdl.results(%[[v1]])
   // CHECK: }
-  irdl.operation @dynbase {
-    %0 = irdl.any
-    %1 = irdl.parametric @parametric<%0>
-    irdl.results(%1)
+  irdl.operation @dyn_type_base {
+    %0 = irdl.base @parametric
+    irdl.results(%0)
+  }
+
+  // Base given by a symbol reference to the IRDL attribute @parametric_attr.
+  // CHECK: irdl.operation @dyn_attr_base {
+  // CHECK:   %[[v1:[^ ]*]] = irdl.base @parametric_attr
+  // CHECK:   irdl.attributes {"attr1" = %[[v1]]}
+  // CHECK: }
+  irdl.operation @dyn_attr_base {
+    %0 = irdl.base @parametric_attr
+    irdl.attributes {"attr1" = %0}
+  }
+
+  // Base given by the registered name of a type ('!' prefix).
+  // CHECK: irdl.operation @named_type_base {
+  // CHECK:   %[[v1:[^ ]*]] = irdl.base "!builtin.integer"
+  // CHECK:   irdl.results(%[[v1]])
+  // CHECK: }
+  irdl.operation @named_type_base {
+    %0 = irdl.base "!builtin.integer"
+    irdl.results(%0)
+  }
+
+  // Base given by the registered name of an attribute ('#' prefix).
+  // CHECK: irdl.operation @named_attr_base {
+  // CHECK:   %[[v1:[^ ]*]] = irdl.base "#builtin.integer"
+  // CHECK:   irdl.attributes {"attr1" = %[[v1]]}
+  // CHECK: }
+  irdl.operation @named_attr_base {
+    %0 = irdl.base "#builtin.integer"
+    irdl.attributes {"attr1" = %0}
   }
 
   // CHECK: irdl.operation @dynparams {
diff --git a/mlir/test/Dialect/IRDL/testd.mlir b/mlir/test/Dialect/IRDL/testd.mlir
index bb1e9f46356411b..333bb96eb2e60fe 100644
--- a/mlir/test/Dialect/IRDL/testd.mlir
+++ b/mlir/test/Dialect/IRDL/testd.mlir
@@ -120,24 +120,67 @@ func.func @succeededAnyConstraint() {
 // -----
 
 //===----------------------------------------------------------------------===//
-// Dynamic base constraint
+// Base constraints
 //===----------------------------------------------------------------------===//
 
 func.func @succeededDynBaseConstraint() {
-  // CHECK: "testd.dynbase"() : () -> !testd.parametric<i32>
-  "testd.dynbase"() : () -> !testd.parametric<i32>
-  // CHECK: "testd.dynbase"() : () -> !testd.parametric<i64>
-  "testd.dynbase"() : () -> !testd.parametric<i64>
-  // CHECK: "testd.dynbase"() : () -> !testd.parametric<!testd.parametric<i64>>
-  "testd.dynbase"() : () -> !testd.parametric<!testd.parametric<i64>>
+  // Only the base is constrained, so any parameters are accepted.
+  // CHECK: "testd.dyn_type_base"() : () -> !testd.parametric<i32>
+  "testd.dyn_type_base"() : () -> !testd.parametric<i32>
+  // CHECK: "testd.dyn_type_base"() : () -> !testd.parametric<i64>
+  "testd.dyn_type_base"() : () -> !testd.parametric<i64>
+  // CHECK: "testd.dyn_type_base"() : () -> !testd.parametric<!testd.parametric<i64>>
+  "testd.dyn_type_base"() : () -> !testd.parametric<!testd.parametric<i64>>
+  // CHECK: "testd.dyn_attr_base"() {attr1 = #testd.parametric_attr<i32>} : () -> ()
+  "testd.dyn_attr_base"() {attr1 = #testd.parametric_attr<i32>} : () -> ()
+  // CHECK: "testd.dyn_attr_base"() {attr1 = #testd.parametric_attr<i64>} : () -> ()
+  "testd.dyn_attr_base"() {attr1 = #testd.parametric_attr<i64>} : () -> ()
   return
 }
 
 // -----
 
-func.func @failedDynBaseConstraint() {
-  // expected-error@+1 {{expected base type 'testd.parametric' but got 'i32'}}
-  "testd.dynbase"() : () -> i32
+func.func @failedDynTypeBaseConstraint() {
+  // i32 is a builtin integer type, not a testd.parametric type.
+  // expected-error@+1 {{expected base type 'testd.parametric' but got 'builtin.integer'}}
+  "testd.dyn_type_base"() : () -> i32
+  return
+}
+
+// -----
+
+func.func @failedDynAttrBaseConstraintNotType() {
+  // `i32` used as an attribute is a builtin.type attribute, not a
+  // testd.parametric_attr attribute.
+  // expected-error@+1 {{expected base attribute 'testd.parametric_attr' but got 'builtin.type'}}
+  "testd.dyn_attr_base"() {attr1 = i32}: () -> ()
+  return
+}
+
+// -----
+
+
+func.func @succeededNamedBaseConstraint() {
+  // Any builtin integer type or integer attribute is accepted, whatever its
+  // width.
+  // CHECK: "testd.named_type_base"() : () -> i32
+  "testd.named_type_base"() : () -> i32
+  // CHECK: "testd.named_type_base"() : () -> i64
+  "testd.named_type_base"() : () -> i64
+  // CHECK: "testd.named_attr_base"() {attr1 = 0 : i32} : () -> ()
+  "testd.named_attr_base"() {attr1 = 0 : i32} : () -> ()
+  // CHECK: "testd.named_attr_base"() {attr1 = 0 : i64} : () -> ()
+  "testd.named_attr_base"() {attr1 = 0 : i64} : () -> ()
+  return
+}
+
+// -----
+
+func.func @failedNamedTypeBaseConstraint() {
+  // vector<i32> is a builtin vector type, not a builtin integer type.
+  // expected-error@+1 {{expected base type 'builtin.integer' but got 'builtin.vector'}}
+  "testd.named_type_base"() : () -> vector<i32>
+  return
+}
+
+// -----
+
+func.func @failedNamedAttrBaseConstraintNotType() {
+  // `i32` used as an attribute is a builtin.type attribute, not an integer
+  // attribute.
+  // expected-error@+1 {{expected base attribute 'builtin.integer' but got 'builtin.type'}}
+  "testd.named_attr_base"() {attr1 = i32}: () -> ()
   return
 }
 

>From 4c1d6f58c216029049c0c6dd9a567dbae694683e Mon Sep 17 00:00:00 2001
From: Mathieu Fehr <mathieu.fehr at gmail.com>
Date: Tue, 26 Dec 2023 23:19:39 +0000
Subject: [PATCH 2/2] Add builders for `irdl.base`

---
 mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td b/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td
index c63a3a70f6703f6..aa6a8e93c0288ec 100644
--- a/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td
+++ b/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td
@@ -489,6 +489,16 @@ def IRDL_BaseOp : IRDL_ConstraintOp<"base",
                        OptionalAttr<StrAttr>:$base_name);
   let results = (outs IRDL_AttributeType:$output);
   let assemblyFormat = " ($base_ref^)? ($base_name^)? ` ` attr-dict";
+  // One builder per way of naming the base; the other one is left unset.
+  let builders = [
+    OpBuilder<(ins "SymbolRefAttr":$base_ref), [{
+      build($_builder, $_state, base_ref, StringAttr());
+    }]>,
+    OpBuilder<(ins "StringAttr":$base_name), [{
+      build($_builder, $_state, SymbolRefAttr(), base_name);
+    }]>,
+  ];
+
   let hasVerifier = 1;
 }
 



More information about the cfe-commits mailing list