[Mlir-commits] [mlir] 3ffd13b - [mlir][irdl] Add IRDL verification constraint classes
Mathieu Fehr
llvmlistbot at llvm.org
Fri May 5 03:17:32 PDT 2023
Author: Mathieu Fehr
Date: 2023-05-05T11:17:23+01:00
New Revision: 3ffd13b183ed5472d4dc84d15769a511a28e7680
URL: https://github.com/llvm/llvm-project/commit/3ffd13b183ed5472d4dc84d15769a511a28e7680
DIFF: https://github.com/llvm/llvm-project/commit/3ffd13b183ed5472d4dc84d15769a511a28e7680.diff
LOG: [mlir][irdl] Add IRDL verification constraint classes
This patch adds the necessary constraint classes that are used
by IRDL to define Operation, Type, and Attribute verifiers.
A constraint is a class inheriting the `irdl::Constraint` class,
which may call other constraints that are indexed by `unsigned`.
A constraint represents an invariant over an Attribute.
The `ConstraintVerifier` class groups these constraints together,
and makes sure that a constraint can only identify a single
attribute. So, once a constraint is used to check the
satisfiability of an `Attribute`, the `Attribute` will be
memorized for this constraint. This ensures that in IRDL, a
single `!irdl.attribute` value only corresponds to a single
`Attribute`.
Depends on D144693
Reviewed By: Mogball
Differential Revision: https://reviews.llvm.org/D145733
Added:
mlir/include/mlir/Dialect/IRDL/IRDLVerifiers.h
mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp
Modified:
mlir/lib/Dialect/IRDL/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/IRDL/IRDLVerifiers.h b/mlir/include/mlir/Dialect/IRDL/IRDLVerifiers.h
new file mode 100644
index 0000000000000..8f0628e37f1eb
--- /dev/null
+++ b/mlir/include/mlir/Dialect/IRDL/IRDLVerifiers.h
@@ -0,0 +1,184 @@
+//===- IRDLVerifiers.h - IRDL verifiers --------------------------- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Verifiers for objects declared by IRDL.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_IRDL_IRDLVERIFIERS_H
+#define MLIR_DIALECT_IRDL_IRDLVERIFIERS_H
+
+#include "mlir/IR/Attributes.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/ArrayRef.h"
+#include <optional>
+
+namespace mlir {
+struct LogicalResult;
+class InFlightDiagnostic;
+class DynamicAttrDefinition;
+class DynamicTypeDefinition;
+} // namespace mlir
+
+namespace mlir {
+namespace irdl {
+
+class Constraint;
+
+/// Verification context: holds the constraints (one per variable index) and,
+/// for each variable, the attribute it has been assigned so far, if any.
+/// The constraints are referenced through an `ArrayRef`; they are not copied.
+class ConstraintVerifier {
+public:
+ ConstraintVerifier(ArrayRef<std::unique_ptr<Constraint>> constraints);
+
+ /// Check that `attr` satisfies the constraint of variable `variable`.
+ ///
+ /// If the variable is already assigned, this succeeds only when `attr` is
+ /// equal to the assigned attribute. Otherwise the variable's constraint is
+ /// checked (it may in turn verify other variables through this object)
+ /// and, on success, `attr` is assigned to the variable.
+ /// `emitError` may be null, in which case no diagnostic is emitted.
+ /// `variable` must be a valid index into the constraints.
+ LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+ Attribute attr, unsigned variable);
+
+private:
+ /// The constraints that can be used for verification, indexed by variable.
+ ArrayRef<std::unique_ptr<Constraint>> constraints;
+
+ /// The assignment of variables to attributes. Variables that are not assigned
+ /// are represented by nullopt. Null attributes need to be supported here as
+ /// some attributes or types might use the null attribute to represent
+ /// optional parameters.
+ SmallVector<std::optional<Attribute>> assigned;
+};
+
+/// Base class of all IRDL verification constraints.
+/// Once turned into IRDL verifiers, all constraints are attribute
+/// constraints. Type constraints are represented as `TypeAttr` attribute
+/// constraints to simplify verification. Verification that a type constraint
+/// must yield a `TypeAttr` attribute happens before conversion, at the MLIR level.
+class Constraint {
+public:
+ virtual ~Constraint() = default;
+
+ /// Check that `attr` satisfies this constraint.
+ ///
+ /// Implementations that depend on other constraints verify them through
+ /// `context`, using their variable index; `context` then checks that an
+ /// already-assigned variable is given the same attribute again, or
+ /// assigns the attribute to the variable once its constraint holds.
+ /// `emitError` may be null (see `AnyOfConstraint`), so implementations
+ /// must test it before emitting a diagnostic.
+ virtual LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+ Attribute attr,
+ ConstraintVerifier &context) const = 0;
+};
+
+/// Constraint accepting a single attribute, the one given at construction.
+class IsConstraint : public Constraint {
+public:
+  IsConstraint(Attribute expectedAttribute)
+      : expectedAttribute(expectedAttribute) {}
+  ~IsConstraint() override = default;
+
+  /// Succeeds only when `attr` is equal to the expected attribute.
+  LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+                       Attribute attr,
+                       ConstraintVerifier &context) const override;
+
+private:
+  Attribute expectedAttribute;
+};
+
+/// Constraint accepting attributes that are instances of one specific
+/// dynamic attribute definition and whose parameters each satisfy the
+/// constraint variable listed at the same position.
+class DynParametricAttrConstraint : public Constraint {
+public:
+  DynParametricAttrConstraint(DynamicAttrDefinition *attrDef,
+                              SmallVector<unsigned> constraints)
+      : attrDef(attrDef), constraints(std::move(constraints)) {}
+  ~DynParametricAttrConstraint() override = default;
+
+  LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+                       Attribute attr,
+                       ConstraintVerifier &context) const override;
+
+private:
+  /// Expected definition, then one constraint variable per parameter.
+  DynamicAttrDefinition *attrDef;
+  SmallVector<unsigned> constraints;
+};
+
+/// Constraint accepting types that are instances of one specific dynamic type
+/// definition and whose parameters each satisfy the listed constraint variable.
+class DynParametricTypeConstraint : public Constraint {
+public:
+  DynParametricTypeConstraint(DynamicTypeDefinition *typeDef,
+                              SmallVector<unsigned> constraints)
+      : typeDef(typeDef), constraints(std::move(constraints)) {}
+  ~DynParametricTypeConstraint() override = default;
+
+  LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+                       Attribute attr,
+                       ConstraintVerifier &context) const override;
+
+private:
+  /// Expected definition, then one constraint variable per parameter.
+  DynamicTypeDefinition *typeDef;
+  SmallVector<unsigned> constraints;
+};
+
+/// Constraint satisfied as soon as one of the listed constraints is satisfied.
+class AnyOfConstraint : public Constraint {
+public:
+  AnyOfConstraint(SmallVector<unsigned> constraints)
+      : constraints(std::move(constraints)) {}
+  ~AnyOfConstraint() override = default;
+
+  LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+                       Attribute attr,
+                       ConstraintVerifier &context) const override;
+
+private:
+  /// Variable indices of the alternatives, tried in this order.
+  SmallVector<unsigned> constraints;
+};
+
+/// Constraint satisfied only when every listed constraint is satisfied.
+class AllOfConstraint : public Constraint {
+public:
+  AllOfConstraint(SmallVector<unsigned> constraints)
+      : constraints(std::move(constraints)) {}
+  ~AllOfConstraint() override = default;
+
+  LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+                       Attribute attr,
+                       ConstraintVerifier &context) const override;
+
+private:
+  /// Variable indices of the required constraints, checked in this order.
+  SmallVector<unsigned> constraints;
+};
+
+/// Constraint that accepts any attribute: verification unconditionally
+/// succeeds.
+class AnyAttributeConstraint : public Constraint {
+public:
+  ~AnyAttributeConstraint() override = default;
+  LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+                       Attribute attr,
+                       ConstraintVerifier &context) const override;
+};
+
+} // namespace irdl
+} // namespace mlir
+
+#endif // MLIR_DIALECT_IRDL_IRDLVERIFIERS_H
diff --git a/mlir/lib/Dialect/IRDL/CMakeLists.txt b/mlir/lib/Dialect/IRDL/CMakeLists.txt
index 31efd5e37a665..7af0e42293573 100644
--- a/mlir/lib/Dialect/IRDL/CMakeLists.txt
+++ b/mlir/lib/Dialect/IRDL/CMakeLists.txt
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRIRDL
IR/IRDL.cpp
IRDLLoading.cpp
+ IRDLVerifiers.cpp
DEPENDS
MLIRIRDLIncGen
diff --git a/mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp b/mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp
new file mode 100644
index 0000000000000..71d27764f437d
--- /dev/null
+++ b/mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp
@@ -0,0 +1,177 @@
+//===- IRDLVerifiers.cpp - IRDL verifiers ------------------------- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Verifiers for objects declared by IRDL.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/IRDL/IRDLVerifiers.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/ExtensibleDialect.h"
+#include "mlir/Support/LogicalResult.h"
+
+using namespace mlir;
+using namespace mlir::irdl;
+
+/// One assignment slot per constraint; every slot starts out unassigned.
+ConstraintVerifier::ConstraintVerifier(
+    ArrayRef<std::unique_ptr<Constraint>> constraints)
+    : constraints(constraints),
+      assigned(constraints.size()) {}
+
+LogicalResult
+ConstraintVerifier::verify(function_ref<InFlightDiagnostic()> emitError,
+                           Attribute attr, unsigned variable) {
+  assert(variable < constraints.size() && "invalid constraint variable");
+
+  // First use of the variable: run its constraint, and remember the attribute
+  // only if the constraint accepted it.
+  if (!assigned[variable].has_value()) {
+    LogicalResult status =
+        constraints[variable]->verify(emitError, attr, *this);
+    if (succeeded(status))
+      assigned[variable] = attr;
+    return status;
+  }
+
+  // Later uses: the variable is assigned, so `attr` must equal that value.
+  Attribute bound = assigned[variable].value();
+  if (bound == attr)
+    return success();
+
+  // A null `emitError` asks for a silent failure.
+  if (!emitError)
+    return failure();
+  return emitError() << "expected '" << bound << "' but got '" << attr << "'";
+}
+
+LogicalResult IsConstraint::verify(function_ref<InFlightDiagnostic()> emitError,
+                                   Attribute attr,
+                                   ConstraintVerifier &context) const {
+  // Only the expected attribute is accepted; `context` is not consulted.
+  if (attr == expectedAttribute)
+    return success();
+  if (!emitError)
+    return failure();
+  return emitError() << "expected '" << expectedAttribute << "' but got '"
+                     << attr << "'";
+}
+
+LogicalResult DynParametricAttrConstraint::verify(
+    function_ref<InFlightDiagnostic()> emitError, Attribute attr,
+    ConstraintVerifier &context) const {
+  // `emitError` may be null (silent mode), hence the guards around diagnostics.
+  // Check that the base is the expected one (printed as `dialect.name`).
+  auto dynAttr = attr.dyn_cast<DynamicAttr>();
+  if (!dynAttr || dynAttr.getAttrDef() != attrDef) {
+    if (emitError) {
+      StringRef dialectName = attrDef->getDialect()->getNamespace();
+      StringRef attrName = attrDef->getName();
+      return emitError() << "expected base attribute '" << dialectName << '.'
+                         << attrName << "' but got '" << attr << "'";
+    }
+    return failure();
+  }
+
+  // The attribute must carry exactly one parameter per constraint variable.
+  ArrayRef<Attribute> params = dynAttr.getParams();
+  if (params.size() != constraints.size()) {
+    if (emitError) {
+      StringRef dialectName = attrDef->getDialect()->getNamespace();
+      StringRef attrName = attrDef->getName();
+      emitError() << "attribute '" << dialectName << "." << attrName
+                  << "' expects " << params.size() << " parameters but got "
+                  << constraints.size();
+    }
+    return failure();
+  }
+  // Parameter `i` must satisfy the constraint variable `constraints[i]`.
+  for (size_t i = 0, s = params.size(); i < s; i++)
+    if (failed(context.verify(emitError, params[i], constraints[i])))
+      return failure();
+
+  return success();
+}
+
+LogicalResult DynParametricTypeConstraint::verify(
+    function_ref<InFlightDiagnostic()> emitError, Attribute attr,
+    ConstraintVerifier &context) const {
+  // A type is passed as a TypeAttr; `emitError` may be null (silent mode).
+  auto typeAttr = attr.dyn_cast<TypeAttr>();
+  if (!typeAttr) {
+    if (emitError)
+      return emitError() << "expected type, got attribute '" << attr << "'";
+    return failure();
+  }
+
+  // Check that the type base is the expected one.
+  auto dynType = typeAttr.getValue().dyn_cast<DynamicType>();
+  if (!dynType || dynType.getTypeDef() != typeDef) {
+    if (emitError) {
+      StringRef dialectName = typeDef->getDialect()->getNamespace();
+      StringRef attrName = typeDef->getName();
+      return emitError() << "expected base type '" << dialectName << '.'
+                         << attrName << "' but got '" << attr << "'";
+    }
+    return failure();
+  }
+
+  // The type must carry exactly one parameter per constraint variable.
+  ArrayRef<Attribute> params = dynType.getParams();
+  if (params.size() != constraints.size()) {
+    if (emitError) {
+      StringRef dialectName = typeDef->getDialect()->getNamespace();
+      StringRef attrName = typeDef->getName();
+      emitError() << "attribute '" << dialectName << "." << attrName
+                  << "' expects " << params.size() << " parameters but got "
+                  << constraints.size();
+    }
+    return failure();
+  }
+  // Parameter `i` must satisfy the constraint variable `constraints[i]`.
+  for (size_t i = 0, s = params.size(); i < s; i++)
+    if (failed(context.verify(emitError, params[i], constraints[i])))
+      return failure();
+
+  return success();
+}
+
+LogicalResult
+AnyOfConstraint::verify(function_ref<InFlightDiagnostic()> emitError,
+                        Attribute attr, ConstraintVerifier &context) const {
+  // Each alternative is tried with a null diagnostic callback: a failing
+  // alternative must stay silent, because an error is only warranted when
+  // every alternative has been rejected.
+  for (size_t i = 0, e = constraints.size(); i != e; ++i)
+    if (succeeded(context.verify({}, attr, constraints[i])))
+      return success();
+
+  // Nothing matched: report once, if the caller asked for diagnostics.
+  if (!emitError)
+    return failure();
+  return emitError() << "'" << attr << "' does not satisfy the constraint";
+}
+
+LogicalResult
+AllOfConstraint::verify(function_ref<InFlightDiagnostic()> emitError,
+                        Attribute attr, ConstraintVerifier &context) const {
+  // Stop at the first rejected sub-constraint. `emitError` is forwarded, so
+  // any diagnostic comes from that sub-constraint's verification.
+  for (size_t i = 0, e = constraints.size(); i != e; ++i)
+    if (failed(context.verify(emitError, attr, constraints[i])))
+      return failure();
+  // Every sub-constraint accepted `attr` (trivially so if there are none).
+  return success();
+}
+
+LogicalResult AnyAttributeConstraint::verify(
+    function_ref<InFlightDiagnostic()> emitError, Attribute attr,
+    ConstraintVerifier &context) const {
+  // Every attribute is accepted: none of the arguments is inspected.
+  return success();
+}
More information about the Mlir-commits
mailing list