[Mlir-commits] [mlir] 3ffd13b - [mlir][irdl] Add IRDL verification constraint classes

Mathieu Fehr llvmlistbot at llvm.org
Fri May 5 03:17:32 PDT 2023


Author: Mathieu Fehr
Date: 2023-05-05T11:17:23+01:00
New Revision: 3ffd13b183ed5472d4dc84d15769a511a28e7680

URL: https://github.com/llvm/llvm-project/commit/3ffd13b183ed5472d4dc84d15769a511a28e7680
DIFF: https://github.com/llvm/llvm-project/commit/3ffd13b183ed5472d4dc84d15769a511a28e7680.diff

LOG: [mlir][irdl] Add IRDL verification constraint classes

This patch adds the necessary constraint classes that are used
by IRDL to define Operation, Type, and Attribute verifiers.

A constraint is a class inheriting from the `irdl::Constraint` class,
which may call other constraints that are indexed by `unsigned`.
A constraint represents an invariant over an Attribute.

The `ConstraintVerifier` class groups these constraints together,
and makes sure that a constraint can only identify a single
attribute. So, once a constraint is used to check the
satisfiability of an `Attribute`, the `Attribute` will be
remembered for this constraint. This ensures that in IRDL, a
single `!irdl.attribute` value only corresponds to a single
`Attribute`.

Depends on D144693

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D145733

Added: 
    mlir/include/mlir/Dialect/IRDL/IRDLVerifiers.h
    mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp

Modified: 
    mlir/lib/Dialect/IRDL/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/IRDL/IRDLVerifiers.h b/mlir/include/mlir/Dialect/IRDL/IRDLVerifiers.h
new file mode 100644
index 0000000000000..8f0628e37f1eb
--- /dev/null
+++ b/mlir/include/mlir/Dialect/IRDL/IRDLVerifiers.h
@@ -0,0 +1,184 @@
+//===- IRDLVerifiers.h - IRDL verifiers --------------------------- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Verifiers for objects declared by IRDL.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_IRDL_IRDLVERIFIERS_H
+#define MLIR_DIALECT_IRDL_IRDLVERIFIERS_H
+
+#include "mlir/IR/Attributes.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/ArrayRef.h"
+#include <optional>
+
+namespace mlir {
+struct LogicalResult;
+class InFlightDiagnostic;
+class DynamicAttrDefinition;
+class DynamicTypeDefinition;
+} // namespace mlir
+
+namespace mlir {
+namespace irdl {
+
+class Constraint;
+
+/// Provides context to the verification of constraints.
+/// It contains the assignment of variables to attributes, and the assignment
+/// of variables to constraints.
+class ConstraintVerifier {
+public:
+  ConstraintVerifier(ArrayRef<std::unique_ptr<Constraint>> constraints);
+
+  /// Check that `attr` satisfies the constraint variable `variable`, which is
+  /// an index into the constraints given at construction.
+  ///
+  /// If the variable is already assigned, verification succeeds only when
+  /// `attr` is the attribute that was assigned. Otherwise the constraint is
+  /// checked (it may recursively verify other variables) and, on success,
+  /// `attr` is assigned to the variable. `emitError` may be null, in which
+  /// case failures are reported without emitting a diagnostic.
+  LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+                       Attribute attr, unsigned variable);
+
+private:
+  /// The constraints that can be used for verification, indexed by variable.
+  ArrayRef<std::unique_ptr<Constraint>> constraints;
+
+  /// The assignment of variables to attributes. Variables that are not assigned
+  /// are represented by nullopt. Null attributes need to be supported here as
+  /// some attributes or types might use the null attribute to represent
+  /// optional parameters.
+  SmallVector<std::optional<Attribute>> assigned;
+};
+
+/// Once turned into IRDL verifiers, all constraints are
+/// attribute constraints. Type constraints are represented
+/// as `TypeAttr` attribute constraints to simplify verification.
+/// Verification that a type constraint must yield a
+/// `TypeAttr` attribute happens before conversion, at the MLIR level.
+class Constraint {
+public:
+  virtual ~Constraint() = default;
+
+  /// Check that an attribute satisfies the constraint.
+  ///
+  /// Implementations that depend on other constraints verify them through
+  /// `context`, passing the index of the constraint variable to check. The
+  /// context then either compares the attribute with the one already
+  /// assigned to that variable, or checks the variable's constraint and
+  /// records the attribute on success. `emitError` may be null, in which
+  /// case no diagnostic is emitted on failure.
+  virtual LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+                               Attribute attr,
+                               ConstraintVerifier &context) const = 0;
+};
+
+/// A constraint satisfied only by an attribute equal to the expected one.
+class IsConstraint : public Constraint {
+public:
+  IsConstraint(Attribute expectedAttribute)
+      : expectedAttribute(expectedAttribute) {}
+  ~IsConstraint() override = default;
+
+  LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+                       Attribute attr,
+                       ConstraintVerifier &context) const override;
+
+private:
+  /// The only attribute accepted by this constraint.
+  Attribute expectedAttribute;
+};
+
+/// A constraint that checks that an attribute is an instance of a specific
+/// dynamic attribute definition, and that each of its parameters satisfies
+/// the constraint variable given at the same position.
+class DynParametricAttrConstraint : public Constraint {
+public:
+  DynParametricAttrConstraint(DynamicAttrDefinition *attrDef,
+                              SmallVector<unsigned> constraints)
+      : attrDef(attrDef), constraints(std::move(constraints)) {}
+  ~DynParametricAttrConstraint() override = default;
+
+  LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+                       Attribute attr,
+                       ConstraintVerifier &context) const override;
+
+private:
+  DynamicAttrDefinition *attrDef;
+  /// Constraint variables (verifier indices), one per attribute parameter.
+  SmallVector<unsigned> constraints;
+};
+
+/// A constraint that checks that an attribute is a `TypeAttr` holding a type
+/// of a given dynamic type definition whose parameters satisfy constraints.
+class DynParametricTypeConstraint : public Constraint {
+public:
+  DynParametricTypeConstraint(DynamicTypeDefinition *typeDef,
+                              SmallVector<unsigned> constraints)
+      : typeDef(typeDef), constraints(std::move(constraints)) {}
+  ~DynParametricTypeConstraint() override = default;
+
+  LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+                       Attribute attr,
+                       ConstraintVerifier &context) const override;
+
+private:
+  DynamicTypeDefinition *typeDef;
+  /// Constraint variables (verifier indices), one per type parameter.
+  SmallVector<unsigned> constraints;
+};
+
+/// A constraint checking that at least one of the given constraints holds.
+class AnyOfConstraint : public Constraint {
+public:
+  AnyOfConstraint(SmallVector<unsigned> constraints)
+      : constraints(std::move(constraints)) {}
+  ~AnyOfConstraint() override = default;
+
+  LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+                       Attribute attr,
+                       ConstraintVerifier &context) const override;
+
+private:
+  /// Constraint variables (verifier indices) of the alternatives.
+  SmallVector<unsigned> constraints;
+};
+
+/// A constraint checking that every one of the given constraints holds.
+class AllOfConstraint : public Constraint {
+public:
+  AllOfConstraint(SmallVector<unsigned> constraints)
+      : constraints(std::move(constraints)) {}
+  ~AllOfConstraint() override = default;
+
+  LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+                       Attribute attr,
+                       ConstraintVerifier &context) const override;
+
+private:
+  /// Constraint variables (verifier indices) that must all be satisfied.
+  SmallVector<unsigned> constraints;
+};
+
+/// A constraint that every attribute satisfies.
+class AnyAttributeConstraint : public Constraint {
+public:
+  ~AnyAttributeConstraint() override = default;
+
+  LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+                       Attribute attr,
+                       ConstraintVerifier &context) const override;
+};
+
+} // namespace irdl
+} // namespace mlir
+
+#endif // MLIR_DIALECT_IRDL_IRDLVERIFIERS_H

diff  --git a/mlir/lib/Dialect/IRDL/CMakeLists.txt b/mlir/lib/Dialect/IRDL/CMakeLists.txt
index 31efd5e37a665..7af0e42293573 100644
--- a/mlir/lib/Dialect/IRDL/CMakeLists.txt
+++ b/mlir/lib/Dialect/IRDL/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRIRDL
   IR/IRDL.cpp
   IRDLLoading.cpp
+  IRDLVerifiers.cpp
 
   DEPENDS
   MLIRIRDLIncGen

diff  --git a/mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp b/mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp
new file mode 100644
index 0000000000000..71d27764f437d
--- /dev/null
+++ b/mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp
@@ -0,0 +1,177 @@
+//===- IRDLVerifiers.cpp - IRDL verifiers ------------------------- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Verifiers for objects declared by IRDL.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/IRDL/IRDLVerifiers.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/ExtensibleDialect.h"
+#include "mlir/Support/LogicalResult.h"
+
+using namespace mlir;
+using namespace mlir::irdl;
+
+// Keep a view of `constraints` and create one slot per constraint variable;
+// every slot starts out unassigned (nullopt).
+ConstraintVerifier::ConstraintVerifier(
+    ArrayRef<std::unique_ptr<Constraint>> constraints)
+    : constraints(constraints), assigned(constraints.size()) {}
+
+LogicalResult
+ConstraintVerifier::verify(function_ref<InFlightDiagnostic()> emitError,
+                           Attribute attr, unsigned variable) {
+  assert(variable < constraints.size() && "invalid constraint variable");
+
+  // A variable that already holds an attribute accepts only that same
+  // attribute again; its constraint is not re-checked in that case.
+  std::optional<Attribute> &slot = assigned[variable];
+  if (slot.has_value()) {
+    if (attr == slot.value())
+      return success();
+    if (!emitError)
+      return failure();
+    return emitError() << "expected '" << slot.value() << "' but got '"
+                       << attr << "'";
+  }
+
+  // First use of the variable: run its constraint (which may verify other
+  // variables through `*this`) and record the attribute only if it holds.
+  const LogicalResult outcome =
+      constraints[variable]->verify(emitError, attr, *this);
+  if (succeeded(outcome))
+    slot = attr;
+  return outcome;
+}
+
+LogicalResult IsConstraint::verify(function_ref<InFlightDiagnostic()> emitError,
+                                   Attribute attr,
+                                   ConstraintVerifier &context) const {
+  // The constraint holds only for the exact expected attribute; `context` is
+  // unused since no other constraint is consulted.
+  const bool matches = attr == expectedAttribute;
+  if (!matches && emitError)
+    return emitError() << "expected '" << expectedAttribute << "' but got '"
+                       << attr << "'";
+  return success(matches);
+}
+
+LogicalResult DynParametricAttrConstraint::verify(
+    function_ref<InFlightDiagnostic()> emitError, Attribute attr,
+    ConstraintVerifier &context) const {
+  // `attr` must be a dynamic attribute created from exactly `attrDef`.
+  // The diagnostic names the definition as `dialect.name`.
+  auto dynAttr = attr.dyn_cast<DynamicAttr>();
+  if (!dynAttr || dynAttr.getAttrDef() != attrDef) {
+    if (emitError) {
+      StringRef dialectName = attrDef->getDialect()->getNamespace();
+      StringRef attrName = attrDef->getName();
+      return emitError() << "expected base attribute '" << dialectName << '.'
+                         << attrName << "' but got '" << attr << "'";
+    }
+    return failure();
+  }
+
+  // The parameter count must match the number of constraint variables.
+  ArrayRef<Attribute> params = dynAttr.getParams();
+  if (params.size() != constraints.size()) {
+    if (emitError) {
+      StringRef dialectName = attrDef->getDialect()->getNamespace();
+      StringRef attrName = attrDef->getName();
+      emitError() << "attribute '" << dialectName << "." << attrName
+                  << "' expects " << params.size() << " parameters but got "
+                  << constraints.size();
+    }
+    return failure();
+  }
+
+  // Verify each parameter against the constraint variable at its position.
+  for (size_t i = 0, s = params.size(); i < s; i++)
+    if (failed(context.verify(emitError, params[i], constraints[i])))
+      return failure();
+  return success();
+}
+
+LogicalResult DynParametricTypeConstraint::verify(
+    function_ref<InFlightDiagnostic()> emitError, Attribute attr,
+    ConstraintVerifier &context) const {
+  // Types are verified wrapped in a `TypeAttr`; reject any other attribute.
+  auto typeAttr = attr.dyn_cast<TypeAttr>();
+  if (!typeAttr) {
+    if (emitError)
+      return emitError() << "expected type, got attribute '" << attr << "'";
+    return failure();
+  }
+
+  // The wrapped type must be a dynamic type created from exactly `typeDef`.
+  auto dynType = typeAttr.getValue().dyn_cast<DynamicType>();
+  if (!dynType || dynType.getTypeDef() != typeDef) {
+    if (emitError) {
+      StringRef dialectName = typeDef->getDialect()->getNamespace();
+      StringRef typeName = typeDef->getName();
+      return emitError() << "expected base type '" << dialectName << '.'
+                         << typeName << "' but got '" << attr << "'";
+    }
+    return failure();
+  }
+
+  // The parameter count must match the number of constraint variables.
+  ArrayRef<Attribute> params = dynType.getParams();
+  if (params.size() != constraints.size()) {
+    if (emitError) {
+      StringRef dialectName = typeDef->getDialect()->getNamespace();
+      StringRef typeName = typeDef->getName();
+      emitError() << "type '" << dialectName << "." << typeName
+                  << "' expects " << params.size() << " parameters but got "
+                  << constraints.size();
+    }
+    return failure();
+  }
+
+  // Verify each parameter against the constraint variable at its position.
+  for (size_t i = 0, s = params.size(); i < s; i++)
+    if (failed(context.verify(emitError, params[i], constraints[i])))
+      return failure();
+  return success();
+}
+
+LogicalResult
+AnyOfConstraint::verify(function_ref<InFlightDiagnostic()> emitError,
+                        Attribute attr, ConstraintVerifier &context) const {
+  // Try the alternatives in order with diagnostics disabled: an alternative
+  // that fails is not an error as long as a later one succeeds.
+  // NOTE(review): variables bound while checking a failing alternative are
+  // not visibly reset here -- confirm this is intended.
+  for (unsigned alternative : constraints)
+    if (succeeded(context.verify({}, attr, alternative)))
+      return success();
+
+  if (!emitError)
+    return failure();
+  return emitError() << "'" << attr << "' does not satisfy the constraint";
+}
+
+LogicalResult
+AllOfConstraint::verify(function_ref<InFlightDiagnostic()> emitError,
+                        Attribute attr, ConstraintVerifier &context) const {
+  // Stop at the first violated constraint; `emitError` is forwarded so that
+  // the failing constraint reports its own diagnostic.
+  for (unsigned required : constraints)
+    if (failed(context.verify(emitError, attr, required)))
+      return failure();
+
+  return success();
+}
+
+LogicalResult AnyAttributeConstraint::verify(
+    function_ref<InFlightDiagnostic()> emitError, Attribute attr,
+    ConstraintVerifier &context) const {
+  // Every attribute is accepted, so none of the arguments is inspected.
+  return success();
+}


        


More information about the Mlir-commits mailing list