[Mlir-commits] [mlir] 9fcd14d - [MLIR][ODS] Optionally generate public C++ functions for attribute constraints (#144275)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jun 16 00:21:09 PDT 2025


Author: Henrich Lauko
Date: 2025-06-16T09:21:05+02:00
New Revision: 9fcd14d9b013d0c4b8ec245772b3be3d5c31b885

URL: https://github.com/llvm/llvm-project/commit/9fcd14d9b013d0c4b8ec245772b3be3d5c31b885
DIFF: https://github.com/llvm/llvm-project/commit/9fcd14d9b013d0c4b8ec245772b3be3d5c31b885.diff

LOG: [MLIR][ODS] Optionally generate public C++ functions for attribute constraints (#144275)

Add `gen-attr-constraint-decls` and `gen-attr-constraint-defs`, which
generate public C++ functions for attribute constraints. The name of the C++
function is specified in the `cppFunctionName` field.

This generalizes `cppFunctionName` from `TypeConstraint`, introduced in
https://github.com/llvm/llvm-project/pull/104577, so that it can also be used in `AttrConstraint`.

Added: 
    mlir/test/mlir-tblgen/attr-constraints.td

Modified: 
    mlir/docs/DefiningDialects/Constraints.md
    mlir/include/mlir/IR/Constraints.td
    mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/DefiningDialects/Constraints.md b/mlir/docs/DefiningDialects/Constraints.md
index 52a4283d6084c..40863e7aecf4a 100644
--- a/mlir/docs/DefiningDialects/Constraints.md
+++ b/mlir/docs/DefiningDialects/Constraints.md
@@ -24,8 +24,8 @@ code is generated for type/attribute constraints. Type constraints can not only
 be used when defining operation arguments, but also when defining type
 parameters.
 
-Optionally, C++ functions can be generated, so that type constraints can be
-checked from C++. The name of the C++ function must be specified in the
+Optionally, C++ functions can be generated, so that type/attribute constraints
+can be checked from C++. The name of the C++ function must be specified in the
 `cppFunctionName` field. If no function name is specified, no C++ function is
 emitted.
 
@@ -43,17 +43,20 @@ bool isValidVectorTypeElementType(::mlir::Type type) {
 }
 ```
 
-An extra TableGen rule is needed to emit C++ code for type constraints. This
-will generate only the declarations/definitions of the type constaraints that
-are defined in the specified `.td` file, but not those that are in included
-`.td` files.
+An extra TableGen rule is needed to emit C++ code for type/attribute
+constraints. This will generate only the declarations/definitions of the
+type/attribute constraints that are defined in the specified `.td` file, but
+not those that are in included `.td` files.
 
 ```cmake
 mlir_tablegen(<Your Dialect>TypeConstraints.h.inc -gen-type-constraint-decls)
 mlir_tablegen(<Your Dialect>TypeConstraints.cpp.inc -gen-type-constraint-defs)
+mlir_tablegen(<Your Dialect>AttrConstraints.h.inc -gen-attr-constraint-decls)
+mlir_tablegen(<Your Dialect>AttrConstraints.cpp.inc -gen-attr-constraint-defs)
 ```
 
-The generated `<Your Dialect>TypeConstraints.h.inc` will need to be included
-whereever you are referencing the type constraint in C++. Note that no C++
-namespace will be emitted by the code generator. The `#include` statements of
-the `.h.inc`/`.cpp.inc` files should be wrapped in C++ namespaces by the user.
+The generated `<Your Dialect>TypeConstraints.h.inc` (respectively
+`<Your Dialect>AttrConstraints.h.inc`) will need to be included wherever you
+are referencing the type/attribute constraints in C++. Note that no C++ namespace
+will be emitted by the code generator. The `#include` statements of the
+`.h.inc`/`.cpp.inc` files should be wrapped in C++ namespaces by the user.

diff  --git a/mlir/include/mlir/IR/Constraints.td b/mlir/include/mlir/IR/Constraints.td
index 33e8581ecd356..0d59fffce9df9 100644
--- a/mlir/include/mlir/IR/Constraints.td
+++ b/mlir/include/mlir/IR/Constraints.td
@@ -148,6 +148,15 @@ class Constraint<Pred pred, string desc = ""> {
   string summary = desc;
 }
 
+// Base class for type and attribute constraints; adds `cppFunctionName`.
+class AttrTypeConstraint<Pred pred, string summary = "",
+                         string cppFunctionNameParam = ""> :
+    Constraint<pred, summary> {
+  // The name of the C++ function that is generated for this constraint.
+  // If empty, no C++ function is generated.
+  string cppFunctionName = cppFunctionNameParam;
+}
+
 // Subclasses used to differentiate different constraint kinds. These are used
 // as markers for the TableGen backend to handle different constraint kinds
 // differently if needed. Constraints not deriving from the following subclasses
@@ -157,17 +166,15 @@ class Constraint<Pred pred, string desc = ""> {
 class TypeConstraint<Pred predicate, string summary = "",
                      string cppTypeParam = "::mlir::Type",
                      string cppFunctionNameParam = ""> :
-    Constraint<predicate, summary> {
+    AttrTypeConstraint<predicate, summary, cppFunctionNameParam> {
   // The name of the C++ Type class if known, or Type if not.
   string cppType = cppTypeParam;
-  // The name of the C++ function that is generated for this type constraint.
-  // If empty, no C++ function is generated.
-  string cppFunctionName = cppFunctionNameParam;
 }
 
 // Subclass for constraints on an attribute.
-class AttrConstraint<Pred predicate, string summary = ""> :
-    Constraint<predicate, summary>;
+class AttrConstraint<Pred predicate, string summary = "",
+                     string cppFunctionNameParam = ""> :
+    AttrTypeConstraint<predicate, summary, cppFunctionNameParam>;
 
 // Subclass for constraints on a property.
 class PropConstraint<Pred predicate, string summary = "", string interfaceTypeParam = ""> :

diff  --git a/mlir/test/mlir-tblgen/attr-constraints.td b/mlir/test/mlir-tblgen/attr-constraints.td
new file mode 100644
index 0000000000000..59bc5f2526603
--- /dev/null
+++ b/mlir/test/mlir-tblgen/attr-constraints.td
@@ -0,0 +1,14 @@
+// RUN: mlir-tblgen -gen-attr-constraint-decls -I %S/../../include %s | FileCheck %s --check-prefix=DECL
+// RUN: mlir-tblgen -gen-attr-constraint-defs -I %S/../../include %s | FileCheck %s --check-prefix=DEF
+
+include "mlir/IR/CommonAttrConstraints.td"
+
+def DummyConstraint : AnyAttrOf<[APIntAttr, ArrayAttr, UnitAttr]> {
+  let cppFunctionName = "isValidDummy";
+}
+
+// DECL: bool isValidDummy(::mlir::Attribute attr);
+
+// DEF: bool isValidDummy(::mlir::Attribute attr) {
+// DEF-NEXT:   return (((::llvm::isa<::mlir::IntegerAttr>(attr))) || ((::llvm::isa<::mlir::ArrayAttr>(attr))) || ((::llvm::isa<::mlir::UnitAttr>(attr))));
+// DEF-NEXT: }

diff  --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index 2a6071602fa49..defd1fa12ca1a 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -1083,15 +1083,15 @@ bool DefGenerator::emitDefs(StringRef selectedDialect) {
 }
 
 //===----------------------------------------------------------------------===//
-// Type Constraints
+// Constraints
 //===----------------------------------------------------------------------===//
 
-/// Find all type constraints for which a C++ function should be generated.
-static std::vector<Constraint>
-getAllTypeConstraints(const RecordKeeper &records) {
+/// Find all constraints of class `constraintKind` needing a C++ function.
+static std::vector<Constraint> getAllCppConstraints(const RecordKeeper &records,
+                                                    StringRef constraintKind) {
   std::vector<Constraint> result;
   for (const Record *def :
-       records.getAllDerivedDefinitionsIfDefined("TypeConstraint")) {
+       records.getAllDerivedDefinitionsIfDefined(constraintKind)) {
     // Ignore constraints defined outside of the top-level file.
     if (llvm::SrcMgr.FindBufferContainingLoc(def->getLoc()[0]) !=
         llvm::SrcMgr.getMainFileID())
@@ -1105,32 +1105,74 @@ getAllTypeConstraints(const RecordKeeper &records) {
   return result;
 }
 
+static std::vector<Constraint>
+getAllCppTypeConstraints(const RecordKeeper &records) {
+  return getAllCppConstraints(records, "TypeConstraint");
+}
+
+static std::vector<Constraint>
+getAllCppAttrConstraints(const RecordKeeper &records) {
+  return getAllCppConstraints(records, "AttrConstraint");
+}
+
+/// Emit the declarations for the given constraints, of the form:
+/// `bool <constraintCppFunctionName>(<parameterTypeName> <parameterName>);`
+static void emitConstraintDecls(const std::vector<Constraint> &constraints,
+                                raw_ostream &os, StringRef parameterTypeName,
+                                StringRef parameterName) {
+  static const char *const constraintDecl = "bool {0}({1} {2});\n";
+  for (Constraint constr : constraints)
+    os << strfmt(constraintDecl, *constr.getCppFunctionName(),
+                 parameterTypeName, parameterName);
+}
+
 static void emitTypeConstraintDecls(const RecordKeeper &records,
                                     raw_ostream &os) {
-  static const char *const typeConstraintDecl = R"(
-bool {0}(::mlir::Type type);
-)";
+  emitConstraintDecls(getAllCppTypeConstraints(records), os, "::mlir::Type",
+                      "type");
+}
 
-  for (Constraint constr : getAllTypeConstraints(records))
-    os << strfmt(typeConstraintDecl, *constr.getCppFunctionName());
+static void emitAttrConstraintDecls(const RecordKeeper &records,
+                                    raw_ostream &os) {
+  emitConstraintDecls(getAllCppAttrConstraints(records), os,
+                      "::mlir::Attribute", "attr");
 }
 
-static void emitTypeConstraintDefs(const RecordKeeper &records,
-                                   raw_ostream &os) {
-  static const char *const typeConstraintDef = R"(
-bool {0}(::mlir::Type type) {
-  return ({1});
+/// Emit the definitions for the given constraints, of the form:
+/// `bool <constraintCppFunctionName>(<parameterTypeName> <selfName>) {
+///   return (<condition>); }`
+/// where `<condition>` is the condition template with the `self` variable
+/// replaced with the `selfName` parameter.
+static void emitConstraintDefs(const std::vector<Constraint> &constraints,
+                               raw_ostream &os, StringRef parameterTypeName,
+                               StringRef selfName) {
+  static const char *const constraintDef = R"(
+bool {0}({1} {2}) {
+  return ({3});
 }
 )";
 
-  for (Constraint constr : getAllTypeConstraints(records)) {
+  for (Constraint constr : constraints) {
     FmtContext ctx;
-    ctx.withSelf("type");
+    ctx.withSelf(selfName);
     std::string condition = tgfmt(constr.getConditionTemplate(), &ctx);
-    os << strfmt(typeConstraintDef, *constr.getCppFunctionName(), condition);
+    os << strfmt(constraintDef, *constr.getCppFunctionName(), parameterTypeName,
+                 selfName, condition);
   }
 }
 
+static void emitTypeConstraintDefs(const RecordKeeper &records,
+                                   raw_ostream &os) {
+  emitConstraintDefs(getAllCppTypeConstraints(records), os, "::mlir::Type",
+                     "type");
+}
+
+static void emitAttrConstraintDefs(const RecordKeeper &records,
+                                   raw_ostream &os) {
+  emitConstraintDefs(getAllCppAttrConstraints(records), os, "::mlir::Attribute",
+                     "attr");
+}
+
 //===----------------------------------------------------------------------===//
 // GEN: Registration hooks
 //===----------------------------------------------------------------------===//
@@ -1158,6 +1200,21 @@ static mlir::GenRegistration
                    return generator.emitDecls(attrDialect);
                  });
 
+static mlir::GenRegistration
+    genAttrConstrDefs("gen-attr-constraint-defs",
+                      "Generate attribute constraint definitions",
+                      [](const RecordKeeper &records, raw_ostream &os) {
+                        emitAttrConstraintDefs(records, os);
+                        return false;
+                      });
+static mlir::GenRegistration
+    genAttrConstrDecls("gen-attr-constraint-decls",
+                       "Generate attribute constraint declarations",
+                       [](const RecordKeeper &records, raw_ostream &os) {
+                         emitAttrConstraintDecls(records, os);
+                         return false;
+                       });
+
 //===----------------------------------------------------------------------===//
 // TypeDef
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list