[Mlir-commits] [mlir] 914cfa4 - [mlir][irdl] Add `irdl.base` op (#76400)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jan 18 08:31:44 PST 2024
Author: Fehr Mathieu
Date: 2024-01-18T16:31:40Z
New Revision: 914cfa41385606fe81c3afd296a6ca3ab975a97d
URL: https://github.com/llvm/llvm-project/commit/914cfa41385606fe81c3afd296a6ca3ab975a97d
DIFF: https://github.com/llvm/llvm-project/commit/914cfa41385606fe81c3afd296a6ca3ab975a97d.diff
LOG: [mlir][irdl] Add `irdl.base` op (#76400)
The `irdl.base` op represents an attribute constraint that checks that the
base of a type or attribute is the expected one (e.g. `IntegerType`).
Example:
```mlir
irdl.dialect @cmath {
  irdl.type @complex {
    %0 = irdl.base "!builtin.integer"
    irdl.parameters(%0)
  }
  irdl.type @complex_wrapper {
    %0 = irdl.base @complex
    irdl.parameters(%0)
  }
}
```
The above program defines a `cmath.complex` type that expects a single
parameter: a type whose base name is `builtin.integer` (the name of the
`IntegerType` type).
It also defines a `cmath.complex_wrapper` type that expects a single
parameter: a type whose base is `cmath.complex`.
Added:
mlir/test/Dialect/IRDL/invalid.irdl.mlir
Modified:
mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td
mlir/include/mlir/Dialect/IRDL/IRDLVerifiers.h
mlir/lib/Dialect/IRDL/IR/IRDL.cpp
mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp
mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp
mlir/test/Dialect/IRDL/testd.irdl.mlir
mlir/test/Dialect/IRDL/testd.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td b/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td
index 681425f8174426..aa6a8e93c0288e 100644
--- a/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td
+++ b/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td
@@ -451,6 +451,57 @@ def IRDL_IsOp : IRDL_ConstraintOp<"is",
let assemblyFormat = " $expected ` ` attr-dict ";
}
+// Matches a type or attribute by its base only. BaseOp::getVerifier lowers it
+// to a BaseTypeConstraint or BaseAttrConstraint, which compare TypeIDs.
+def IRDL_BaseOp : IRDL_ConstraintOp<"base",
+ [ParentOneOf<["TypeOp", "AttributeOp", "OperationOp"]>,
+ DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
+ let summary = "Constraints an attribute/type base";
+ let description = [{
+ `irdl.base` defines a constraint that only accepts a single type
+ or attribute base, e.g. an `IntegerType`. The attribute base is defined
+ either by a symbolic reference to the corresponding IRDL definition,
+ or by the name of the base. Named bases are prefixed with `!` or `#`
+ respectively for types and attributes.
+
+ Example:
+
+ ```mlir
+ irdl.dialect @cmath {
+ irdl.type @complex {
+ %0 = irdl.base "!builtin.integer"
+ irdl.parameters(%0)
+ }
+
+ irdl.type @complex_wrapper {
+ %0 = irdl.base @complex
+ irdl.parameters(%0)
+ }
+ }
+ ```
+
+ The above program defines a `cmath.complex` type that expects a single
+ parameter, which is a type with base name `builtin.integer`, which is the
+ name of an `IntegerType` type.
+ It also defines a `cmath.complex_wrapper` type that expects a single
+ parameter, which is a type of base type `cmath.complex`.
+ }];
+
+ // Exactly one of $base_ref / $base_name must be present; BaseOp::verify
+ // rejects ops that set both or neither.
+ let arguments = (ins OptionalAttr<SymbolRefAttr>:$base_ref,
+ OptionalAttr<StrAttr>:$base_name);
+ let results = (outs IRDL_AttributeType:$output);
+ let assemblyFormat = " ($base_ref^)? ($base_name^)? ` ` attr-dict";
+
+ // Each builder sets one of the two attributes and leaves the other unset.
+ let builders = [
+ OpBuilder<(ins "SymbolRefAttr":$base_ref), [{
+ build($_builder, $_state, base_ref, {});
+ }]>,
+ OpBuilder<(ins "StringAttr":$base_name), [{
+ build($_builder, $_state, {}, base_name);
+ }]>,
+ ];
+
+ // BaseOp::verify and BaseOp::verifySymbolUses are implemented in IRDL.cpp.
+ let hasVerifier = 1;
+}
+
def IRDL_ParametricOp : IRDL_ConstraintOp<"parametric",
[ParentOneOf<["TypeOp", "AttributeOp", "OperationOp"]>, Pure]> {
let summary = "Constraints an attribute/type base and its parameters";
diff --git a/mlir/include/mlir/Dialect/IRDL/IRDLVerifiers.h b/mlir/include/mlir/Dialect/IRDL/IRDLVerifiers.h
index f8ce77cbc50e9e..9ecb7c0107d7f8 100644
--- a/mlir/include/mlir/Dialect/IRDL/IRDLVerifiers.h
+++ b/mlir/include/mlir/Dialect/IRDL/IRDLVerifiers.h
@@ -99,6 +99,48 @@ class IsConstraint : public Constraint {
Attribute expectedAttribute;
};
+/// A constraint that checks that an attribute is of a given attribute base
+/// (e.g. IntegerAttr).
+/// Only the base is checked (by TypeID); the attribute's parameters are not
+/// inspected.
+class BaseAttrConstraint : public Constraint {
+public:
+ /// `baseName` is held as a non-owning StringRef, so the referenced string
+ /// must outlive this constraint.
+ BaseAttrConstraint(TypeID baseTypeID, StringRef baseName)
+ : baseTypeID(baseTypeID), baseName(baseName) {}
+
+ virtual ~BaseAttrConstraint() = default;
+
+ /// Succeeds iff `attr` has TypeID `baseTypeID`. Otherwise, if `emitError`
+ /// is non-null, reports the expected base name and the actual one.
+ LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+ Attribute attr,
+ ConstraintVerifier &context) const override;
+
+private:
+ /// The expected base attribute typeID.
+ TypeID baseTypeID;
+
+ /// The base attribute name, only used for error reporting.
+ StringRef baseName;
+};
+
+/// A constraint that checks that a type is of a given type base (e.g.
+/// IntegerType).
+/// The checked value is an Attribute; it must be a TypeAttr wrapping the type.
+class BaseTypeConstraint : public Constraint {
+public:
+ /// `baseName` is held as a non-owning StringRef, so the referenced string
+ /// must outlive this constraint.
+ BaseTypeConstraint(TypeID baseTypeID, StringRef baseName)
+ : baseTypeID(baseTypeID), baseName(baseName) {}
+
+ virtual ~BaseTypeConstraint() = default;
+
+ /// Fails unless `attr` is a TypeAttr whose wrapped type has TypeID
+ /// `baseTypeID`. Diagnostics are emitted only if `emitError` is non-null.
+ LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+ Attribute attr,
+ ConstraintVerifier &context) const override;
+
+private:
+ /// The expected base type typeID.
+ TypeID baseTypeID;
+
+ /// The base type name, only used for error reporting.
+ StringRef baseName;
+};
+
/// A constraint that checks that an attribute is of a
/// specific dynamic attribute definition, and that all of its parameters
/// satisfy the given constraints.
diff --git a/mlir/lib/Dialect/IRDL/IR/IRDL.cpp b/mlir/lib/Dialect/IRDL/IR/IRDL.cpp
index 33c6bb869a643f..4eae2b03024c24 100644
--- a/mlir/lib/Dialect/IRDL/IR/IRDL.cpp
+++ b/mlir/lib/Dialect/IRDL/IR/IRDL.cpp
@@ -117,6 +117,39 @@ LogicalResult AttributesOp::verify() {
return success();
}
+/// Checks that exactly one of `base_name` and `base_ref` is set. If
+/// `base_name` is set, it must be non-empty and start with '!' (type) or '#'
+/// (attribute).
+LogicalResult BaseOp::verify() {
+ std::optional<StringRef> baseName = getBaseName();
+ std::optional<SymbolRefAttr> baseRef = getBaseRef();
+ // Both set, or neither set, is an error.
+ if (baseName.has_value() == baseRef.has_value())
+ return emitOpError() << "the base type or attribute should be specified by "
+ "either a name or a reference";
+
+ if (baseName &&
+ (baseName->empty() || ((*baseName)[0] != '!' && (*baseName)[0] != '#')))
+ return emitOpError() << "the base type or attribute name should start with "
+ "'!' or '#'";
+
+ return success();
+}
+
+/// When `base_ref` is present, checks that it resolves (nearest symbol table
+/// lookup) to a TypeOp or an AttributeOp. Ops using `base_name` trivially pass.
+LogicalResult BaseOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+ std::optional<SymbolRefAttr> baseRef = getBaseRef();
+ if (!baseRef)
+ return success();
+
+ TypeOp typeOp = symbolTable.lookupNearestSymbolFrom<TypeOp>(*this, *baseRef);
+ if (typeOp)
+ return success();
+
+ AttributeOp attrOp =
+ symbolTable.lookupNearestSymbolFrom<AttributeOp>(*this, *baseRef);
+ if (attrOp)
+ return success();
+
+ // The symbol is missing or names some other kind of op (e.g. a func.func).
+ return emitOpError() << "'" << *baseRef
+ << "' does not refer to a type or attribute definition";
+}
+
/// Parse a value with its variadicity first. By default, the variadicity is
/// single.
///
diff --git a/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp b/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp
index e172039712f24c..0895306b8bce1a 100644
--- a/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp
+++ b/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp
@@ -37,6 +37,60 @@ std::unique_ptr<Constraint> IsOp::getVerifier(
return std::make_unique<IsConstraint>(getExpectedAttr());
}
+/// Builds the runtime constraint for this `irdl.base` op. The result is a
+/// BaseTypeConstraint or BaseAttrConstraint keyed on the TypeID of either the
+/// referenced IRDL definition or the named registered type/attribute.
+/// Returns nullptr, after emitting an error, if a named base is not
+/// registered in the context.
+std::unique_ptr<Constraint> BaseOp::getVerifier(
+ ArrayRef<Value> valueToConstr,
+ DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> const &types,
+ DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> const
+ &attrs) {
+ MLIRContext *ctx = getContext();
+
+ // Case where the input is a symbol reference.
+ // This corresponds to the case where the base is an IRDL type or attribute.
+ if (auto baseRef = getBaseRef()) {
+ // NOTE(review): assumes verifySymbolUses already ensured that base_ref
+ // resolves to a TypeOp or AttributeOp; a null defOp is not handled here.
+ Operation *defOp =
+ SymbolTable::lookupNearestSymbolFrom(getOperation(), baseRef.value());
+
+ // Type case.
+ if (auto typeOp = dyn_cast<TypeOp>(defOp)) {
+ // Presumably every TypeOp has an entry in `types`; `.at` does not handle
+ // a missing key.
+ DynamicTypeDefinition *typeDef = types.at(typeOp).get();
+ // The name is interned as a StringAttr, so the StringRef stored in the
+ // constraint stays valid for the context's lifetime.
+ auto name = StringAttr::get(ctx, typeDef->getDialect()->getNamespace() +
+ "." + typeDef->getName().str());
+ return std::make_unique<BaseTypeConstraint>(typeDef->getTypeID(), name);
+ }
+
+ // Attribute case.
+ auto attrOp = cast<AttributeOp>(defOp);
+ DynamicAttrDefinition *attrDef = attrs.at(attrOp).get();
+ auto name = StringAttr::get(ctx, attrDef->getDialect()->getNamespace() +
+ "." + attrDef->getName().str());
+ return std::make_unique<BaseAttrConstraint>(attrDef->getTypeID(), name);
+ }
+
+ // Case where the input is string literal.
+ // This corresponds to the case where the base is a registered type or
+ // attribute.
+ // Assuming the op passed BaseOp::verify, base_name is set and non-empty
+ // here, so indexing [0] below is safe.
+ StringRef baseName = getBaseName().value();
+
+ // Type case.
+ if (baseName[0] == '!') {
+ auto abstractType = AbstractType::lookup(baseName.drop_front(1), ctx);
+ if (!abstractType) {
+ emitError() << "no registered type with name " << baseName;
+ return nullptr;
+ }
+ return std::make_unique<BaseTypeConstraint>(abstractType->get().getTypeID(),
+ abstractType->get().getName());
+ }
+
+ // Otherwise the name starts with '#': attribute case.
+ auto abstractAttr = AbstractAttribute::lookup(baseName.drop_front(1), ctx);
+ if (!abstractAttr) {
+ emitError() << "no registered attribute with name " << baseName;
+ return nullptr;
+ }
+ return std::make_unique<BaseAttrConstraint>(abstractAttr->get().getTypeID(),
+ abstractAttr->get().getName());
+}
+
std::unique_ptr<Constraint> ParametricOp::getVerifier(
ArrayRef<Value> valueToConstr,
DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> const &types,
diff --git a/mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp b/mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp
index fab05d1ffb92fa..05dc154eb5b4c6 100644
--- a/mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp
+++ b/mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp
@@ -68,6 +68,39 @@ LogicalResult IsConstraint::verify(function_ref<InFlightDiagnostic()> emitError,
return failure();
}
+/// Succeeds iff `attr`'s TypeID equals the expected base TypeID. On mismatch,
+/// reports the expected and actual base names when `emitError` is provided.
+LogicalResult
+BaseAttrConstraint::verify(function_ref<InFlightDiagnostic()> emitError,
+ Attribute attr, ConstraintVerifier &context) const {
+ if (attr.getTypeID() == baseTypeID)
+ return success();
+
+ if (emitError)
+ return emitError() << "expected base attribute '" << baseName
+ << "' but got '" << attr.getAbstractAttribute().getName()
+ << "'";
+ return failure();
+}
+
+LogicalResult
+BaseTypeConstraint::verify(function_ref<InFlightDiagnostic()> emitError,
+ Attribute attr, ConstraintVerifier &context) const {
+ auto typeAttr = dyn_cast<TypeAttr>(attr);
+ if (!typeAttr) {
+ if (emitError) // Close the quote opened around the offending attribute.
+ return emitError() << "expected type, got attribute '" << attr << "'";
+ return failure();
+ }
+
+ Type type = typeAttr.getValue();
+ if (type.getTypeID() == baseTypeID)
+ return success();
+
+ if (emitError)
+ return emitError() << "expected base type '" << baseName << "' but got '"
+ << type.getAbstractType().getName() << "'";
+ return failure();
+}
+
LogicalResult DynParametricAttrConstraint::verify(
function_ref<InFlightDiagnostic()> emitError, Attribute attr,
ConstraintVerifier &context) const {
diff --git a/mlir/test/Dialect/IRDL/invalid.irdl.mlir b/mlir/test/Dialect/IRDL/invalid.irdl.mlir
new file mode 100644
index 00000000000000..d62bb498a7ad98
--- /dev/null
+++ b/mlir/test/Dialect/IRDL/invalid.irdl.mlir
@@ -0,0 +1,43 @@
+// RUN: mlir-opt %s -verify-diagnostics -split-input-file
+
+// Testing invalid IRDL IRs
+
+func.func private @foo()
+
+irdl.dialect @testd {
+ irdl.type @type {
+ // expected-error@+1 {{'@foo' does not refer to a type or attribute definition}}
+ %0 = irdl.base @foo
+ irdl.parameters(%0)
+ }
+}
+
+// -----
+
+irdl.dialect @testd {
+ irdl.type @type {
+ // expected-error@+1 {{the base type or attribute name should start with '!' or '#'}}
+ %0 = irdl.base "builtin.integer"
+ irdl.parameters(%0)
+ }
+}
+
+// -----
+
+irdl.dialect @testd {
+ irdl.type @type {
+ // expected-error@+1 {{the base type or attribute name should start with '!' or '#'}}
+ %0 = irdl.base ""
+ irdl.parameters(%0)
+ }
+}
+
+// -----
+
+irdl.dialect @testd {
+ irdl.type @type {
+ // expected-error@+1 {{the base type or attribute should be specified by either a name or a reference}}
+ %0 = irdl.base
+ irdl.parameters(%0)
+ }
+}
diff --git a/mlir/test/Dialect/IRDL/testd.irdl.mlir b/mlir/test/Dialect/IRDL/testd.irdl.mlir
index 684286e4afeb0f..f828d95bdb81d5 100644
--- a/mlir/test/Dialect/IRDL/testd.irdl.mlir
+++ b/mlir/test/Dialect/IRDL/testd.irdl.mlir
@@ -11,6 +11,15 @@ irdl.dialect @testd {
irdl.parameters(%0)
}
+ // Dynamic attribute definition, used below as the base of @dyn_attr_base.
+ // CHECK: irdl.attribute @parametric_attr {
+ // CHECK: %[[v0:[^ ]*]] = irdl.any
+ // CHECK: irdl.parameters(%[[v0]])
+ // CHECK: }
+ irdl.attribute @parametric_attr {
+ %0 = irdl.any
+ irdl.parameters(%0)
+ }
+
// CHECK: irdl.type @attr_in_type_out {
// CHECK: %[[v0:[^ ]*]] = irdl.any
// CHECK: irdl.parameters(%[[v0]])
@@ -66,15 +75,40 @@ irdl.dialect @testd {
irdl.results(%0)
}
- // CHECK: irdl.operation @dynbase {
- // CHECK: %[[v0:[^ ]*]] = irdl.any
- // CHECK: %[[v1:[^ ]*]] = irdl.parametric @parametric<%[[v0]]>
+ // irdl.base referencing an IRDL type definition by symbol.
+ // CHECK: irdl.operation @dyn_type_base {
+ // CHECK: %[[v1:[^ ]*]] = irdl.base @parametric
// CHECK: irdl.results(%[[v1]])
// CHECK: }
- irdl.operation @dynbase {
- %0 = irdl.any
- %1 = irdl.parametric @parametric<%0>
- irdl.results(%1)
+ irdl.operation @dyn_type_base {
+ %0 = irdl.base @parametric
+ irdl.results(%0)
+ }
+
+ // irdl.base referencing an IRDL attribute definition by symbol.
+ // CHECK: irdl.operation @dyn_attr_base {
+ // CHECK: %[[v1:[^ ]*]] = irdl.base @parametric_attr
+ // CHECK: irdl.attributes {"attr1" = %[[v1]]}
+ // CHECK: }
+ irdl.operation @dyn_attr_base {
+ %0 = irdl.base @parametric_attr
+ irdl.attributes {"attr1" = %0}
+ }
+
+ // irdl.base naming a registered type ('!' prefix).
+ // CHECK: irdl.operation @named_type_base {
+ // CHECK: %[[v1:[^ ]*]] = irdl.base "!builtin.integer"
+ // CHECK: irdl.results(%[[v1]])
+ // CHECK: }
+ irdl.operation @named_type_base {
+ %0 = irdl.base "!builtin.integer"
+ irdl.results(%0)
+ }
+
+ // irdl.base naming a registered attribute ('#' prefix).
+ // CHECK: irdl.operation @named_attr_base {
+ // CHECK: %[[v1:[^ ]*]] = irdl.base "#builtin.integer"
+ // CHECK: irdl.attributes {"attr1" = %[[v1]]}
+ // CHECK: }
+ irdl.operation @named_attr_base {
+ %0 = irdl.base "#builtin.integer"
+ irdl.attributes {"attr1" = %0}
}
// CHECK: irdl.operation @dynparams {
diff --git a/mlir/test/Dialect/IRDL/testd.mlir b/mlir/test/Dialect/IRDL/testd.mlir
index bb1e9f46356411..333bb96eb2e60f 100644
--- a/mlir/test/Dialect/IRDL/testd.mlir
+++ b/mlir/test/Dialect/IRDL/testd.mlir
@@ -120,24 +120,67 @@ func.func @succeededAnyConstraint() {
// -----
//===----------------------------------------------------------------------===//
-// Dynamic base constraint
+// Base constraints
//===----------------------------------------------------------------------===//
func.func @succeededDynBaseConstraint() {
- // CHECK: "testd.dynbase"() : () -> !testd.parametric<i32>
- "testd.dynbase"() : () -> !testd.parametric<i32>
- // CHECK: "testd.dynbase"() : () -> !testd.parametric<i64>
- "testd.dynbase"() : () -> !testd.parametric<i64>
- // CHECK: "testd.dynbase"() : () -> !testd.parametric<!testd.parametric<i64>>
- "testd.dynbase"() : () -> !testd.parametric<!testd.parametric<i64>>
+ // CHECK: "testd.dyn_type_base"() : () -> !testd.parametric<i32>
+ "testd.dyn_type_base"() : () -> !testd.parametric<i32>
+ // CHECK: "testd.dyn_type_base"() : () -> !testd.parametric<i64>
+ "testd.dyn_type_base"() : () -> !testd.parametric<i64>
+ // CHECK: "testd.dyn_type_base"() : () -> !testd.parametric<!testd.parametric<i64>>
+ "testd.dyn_type_base"() : () -> !testd.parametric<!testd.parametric<i64>>
+ // CHECK: "testd.dyn_attr_base"() {attr1 = #testd.parametric_attr<i32>} : () -> ()
+ "testd.dyn_attr_base"() {attr1 = #testd.parametric_attr<i32>} : () -> ()
+ // CHECK: "testd.dyn_attr_base"() {attr1 = #testd.parametric_attr<i64>} : () -> ()
+ "testd.dyn_attr_base"() {attr1 = #testd.parametric_attr<i64>} : () -> ()
return
}
// -----
-func.func @failedDynBaseConstraint() {
- // expected-error@+1 {{expected base type 'testd.parametric' but got 'i32'}}
- "testd.dynbase"() : () -> i32
+func.func @failedDynTypeBaseConstraint() {
+ // expected-error@+1 {{expected base type 'testd.parametric' but got 'builtin.integer'}}
+ "testd.dyn_type_base"() : () -> i32
+ return
+}
+
+// -----
+
+func.func @failedDynAttrBaseConstraintNotType() {
+ // expected-error@+1 {{expected base attribute 'testd.parametric_attr' but got 'builtin.type'}}
+ "testd.dyn_attr_base"() {attr1 = i32}: () -> ()
+ return
+}
+
+// -----
+
+
+func.func @succeededNamedBaseConstraint() {
+ // CHECK: "testd.named_type_base"() : () -> i32
+ "testd.named_type_base"() : () -> i32
+ // CHECK: "testd.named_type_base"() : () -> i64
+ "testd.named_type_base"() : () -> i64
+ // CHECK: "testd.named_attr_base"() {attr1 = 0 : i32} : () -> ()
+ "testd.named_attr_base"() {attr1 = 0 : i32} : () -> ()
+ // CHECK: "testd.named_attr_base"() {attr1 = 0 : i64} : () -> ()
+ "testd.named_attr_base"() {attr1 = 0 : i64} : () -> ()
+ return
+}
+
+// -----
+
+func.func @failedNamedTypeBaseConstraint() {
+ // expected-error@+1 {{expected base type 'builtin.integer' but got 'builtin.vector'}}
+ "testd.named_type_base"() : () -> vector<i32>
+ return
+}
+
+// -----
+
+func.func @failedNamedAttrBaseConstraintNotType() {
+ // expected-error@+1 {{expected base attribute 'builtin.integer' but got 'builtin.type'}}
+ "testd.named_attr_base"() {attr1 = i32}: () -> ()
return
}
More information about the Mlir-commits
mailing list