[Mlir-commits] [mlir] [mlir][ODS] Optionally generate public C++ functions for type constraints (PR #104577)

Matthias Springer llvmlistbot at llvm.org
Fri Aug 16 03:10:24 PDT 2024


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/104577

Add `gen-type-constraint-decls` and `gen-type-constraint-defs`, which generate public C++ functions for type constraints. The name of the C++ function is specified in the `cppFunctionName` field; no function is generated for constraints that leave this field empty.

Type constraints are typically used for op/type/attribute verification. The same checks are sometimes also needed in builders and transformations; until now, they had to be duplicated by hand in C++.

Note: This commit just adds the option for type constraints, but attribute constraints could be supported in the same way.

Alternatives considered:
1. The C++ functions could also be generated as part of `gen-typedef-decls/defs`, but that can be confusing because type constraints may rely on type definitions from multiple `.td` files.
2. The C++ functions could also be generated as static member functions of dialects, but they do not really belong to any single dialect, because they may rely on type definitions from multiple dialects.

>From 83e7d34c2fe08c96056987316fea309e7d5928e6 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Fri, 16 Aug 2024 12:00:14 +0200
Subject: [PATCH] [mlir][ODS] Optionally generate public C++ functions for type
 constraints

Add `gen-type-constraint-decls` and `gen-type-constraint-defs`, which generate public C++ functions for type constraints. The name of the C++ function is specified in the `cppFunctionName` field; no function is generated for constraints that leave this field empty.

Type constraints are typically used for op/type/attribute verification. The same checks are sometimes also needed in builders and transformations; until now, they had to be duplicated by hand in C++.

Note: This commit just adds the option for type constraints, but attribute constraints could be supported in the same way.

Alternatives considered:
1. The C++ functions could also be generated as part of `gen-typedef-decls/defs`, but that can be confusing because type constraints may rely on type definitions from multiple `.td` files.
2. The C++ functions could also be generated as static member functions of dialects, but they do not really belong to any single dialect, because they may rely on type definitions from multiple dialects.
---
 mlir/include/mlir/IR/BuiltinTypes.h         |  1 +
 mlir/include/mlir/IR/BuiltinTypes.td        | 14 ++---
 mlir/include/mlir/IR/CMakeLists.txt         |  3 ++
 mlir/include/mlir/IR/Constraints.td         |  6 ++-
 mlir/include/mlir/TableGen/Constraint.h     |  4 ++
 mlir/lib/IR/BuiltinTypes.cpp                | 16 +++++-
 mlir/lib/IR/CMakeLists.txt                  |  1 +
 mlir/lib/TableGen/Constraint.cpp            | 10 +++-
 mlir/test/mlir-tblgen/type-constraints.td   | 14 +++++
 mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp | 60 +++++++++++++++++++++
 10 files changed, 118 insertions(+), 11 deletions(-)
 create mode 100644 mlir/test/mlir-tblgen/type-constraints.td

diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index d12522ba55c96e..eefa4279df1a01 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -198,6 +198,7 @@ class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
 #include "mlir/IR/BuiltinTypes.h.inc"
 
 namespace mlir {
+#include "mlir/IR/BuiltinTypeConstraints.h.inc"
 
 //===----------------------------------------------------------------------===//
 // MemRefType
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 4b3add2035263c..1ab1bbe9bfc9b2 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -1097,6 +1097,10 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
 // VectorType
 //===----------------------------------------------------------------------===//
 
+def Builtin_VectorTypeElementType : AnyTypeOf<[AnyInteger, Index, AnyFloat]> {
+  let cppFunctionName = "isValidVectorTypeElementType";
+}
+
 def Builtin_Vector : Builtin_Type<"Vector", "vector",
     [ShapedTypeInterface, ValueSemantics], "Type"> {
   let summary = "Multi-dimensional SIMD vector type";
@@ -1147,7 +1151,7 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector",
   }];
   let parameters = (ins
     ArrayRefParameter<"int64_t">:$shape,
-    AnyTypeOf<[AnyInteger, Index, AnyFloat]>:$elementType,
+    Builtin_VectorTypeElementType:$elementType,
     ArrayRefParameter<"bool">:$scalableDims
   );
   let builders = [
@@ -1171,12 +1175,8 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector",
     class Builder;
 
     /// Returns true if the given type can be used as an element of a vector
-    /// type. In particular, vectors can consist of integer, index, or float
-    /// primitives.
-    static bool isValidElementType(Type t) {
-      // TODO: Auto-generate this function from $elementType.
-      return ::llvm::isa<IntegerType, IndexType, FloatType>(t);
-    }
+    /// type. See "Builtin_VectorTypeElementType" for allowed types.
+    static bool isValidElementType(Type t);
 
     /// Returns true if the vector contains scalable dimensions.
     bool isScalable() const {
diff --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt
index 04a57d26a068d5..b741eb18d47916 100644
--- a/mlir/include/mlir/IR/CMakeLists.txt
+++ b/mlir/include/mlir/IR/CMakeLists.txt
@@ -35,6 +35,9 @@ set(LLVM_TARGET_DEFINITIONS BuiltinTypes.td)
 mlir_tablegen(BuiltinTypes.h.inc -gen-typedef-decls)
 mlir_tablegen(BuiltinTypes.cpp.inc -gen-typedef-defs)
 add_public_tablegen_target(MLIRBuiltinTypesIncGen)
+mlir_tablegen(BuiltinTypeConstraints.h.inc -gen-type-constraint-decls)
+mlir_tablegen(BuiltinTypeConstraints.cpp.inc -gen-type-constraint-defs)
+add_public_tablegen_target(MLIRBuiltinTypeConstraintsIncGen)
 
 set(LLVM_TARGET_DEFINITIONS BuiltinTypeInterfaces.td)
 mlir_tablegen(BuiltinTypeInterfaces.h.inc -gen-type-interface-decls)
diff --git a/mlir/include/mlir/IR/Constraints.td b/mlir/include/mlir/IR/Constraints.td
index 39bc55db63da1a..13223aa8abcdaa 100644
--- a/mlir/include/mlir/IR/Constraints.td
+++ b/mlir/include/mlir/IR/Constraints.td
@@ -149,10 +149,14 @@ class Constraint<Pred pred, string desc = ""> {
 
 // Subclass for constraints on a type.
 class TypeConstraint<Pred predicate, string summary = "",
-                     string cppTypeParam = "::mlir::Type"> :
+                     string cppTypeParam = "::mlir::Type",
+                     string cppFunctionNameParam = ""> :
     Constraint<predicate, summary> {
   // The name of the C++ Type class if known, or Type if not.
   string cppType = cppTypeParam;
+  // The name of the C++ function that is generated for this type constraint.
+  // If empty, no C++ function is generated.
+  string cppFunctionName = cppFunctionNameParam;
 }
 
 // Subclass for constraints on an attribute.
diff --git a/mlir/include/mlir/TableGen/Constraint.h b/mlir/include/mlir/TableGen/Constraint.h
index 0d0c28e651ee99..8877daaa775145 100644
--- a/mlir/include/mlir/TableGen/Constraint.h
+++ b/mlir/include/mlir/TableGen/Constraint.h
@@ -69,6 +69,10 @@ class Constraint {
   /// context on the def).
   std::string getUniqueDefName() const;
 
+  /// Returns the name of the C++ function that should be generated for this
+  /// constraint, or std::nullopt if no C++ function should be generated.
+  std::optional<StringRef> getCppFunctionName() const;
+
   Kind getKind() const { return kind; }
 
   /// Return the underlying def.
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index a3f5ece8c17369..f3f58efa5683f3 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -32,6 +32,10 @@ using namespace mlir::detail;
 #define GET_TYPEDEF_CLASSES
 #include "mlir/IR/BuiltinTypes.cpp.inc"
 
+namespace mlir {
+#include "mlir/IR/BuiltinTypeConstraints.cpp.inc"
+} // namespace mlir
+
 //===----------------------------------------------------------------------===//
 // BuiltinDialect
 //===----------------------------------------------------------------------===//
@@ -230,6 +234,10 @@ LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
 // VectorType
 //===----------------------------------------------------------------------===//
 
+bool VectorType::isValidElementType(Type t) {
+  return !failed(isValidVectorTypeElementType(t));
+}
+
 LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
                                  ArrayRef<int64_t> shape, Type elementType,
                                  ArrayRef<bool> scalableDims) {
@@ -278,7 +286,9 @@ Type TensorType::getElementType() const {
           [](auto type) { return type.getElementType(); });
 }
 
-bool TensorType::hasRank() const { return !llvm::isa<UnrankedTensorType>(*this); }
+bool TensorType::hasRank() const {
+  return !llvm::isa<UnrankedTensorType>(*this); // false only if unranked
+}
 
 ArrayRef<int64_t> TensorType::getShape() const {
   return llvm::cast<RankedTensorType>(*this).getShape();
@@ -365,7 +375,9 @@ Type BaseMemRefType::getElementType() const {
           [](auto type) { return type.getElementType(); });
 }
 
-bool BaseMemRefType::hasRank() const { return !llvm::isa<UnrankedMemRefType>(*this); }
+bool BaseMemRefType::hasRank() const {
+  return !llvm::isa<UnrankedMemRefType>(*this); // false only if unranked
+}
 
 ArrayRef<int64_t> BaseMemRefType::getShape() const {
   return llvm::cast<MemRefType>(*this).getShape();
diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index c38ce6c058a006..4cabac185171c2 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -55,6 +55,7 @@ add_mlir_library(MLIRIR
   MLIRBuiltinLocationAttributesIncGen
   MLIRBuiltinOpsIncGen
   MLIRBuiltinTypesIncGen
+  MLIRBuiltinTypeConstraintsIncGen
   MLIRBuiltinTypeInterfacesIncGen
   MLIRCallInterfacesIncGen
   MLIRCastInterfacesIncGen
diff --git a/mlir/lib/TableGen/Constraint.cpp b/mlir/lib/TableGen/Constraint.cpp
index 4ccbd0a685e09a..8cf4ed08a2d54f 100644
--- a/mlir/lib/TableGen/Constraint.cpp
+++ b/mlir/lib/TableGen/Constraint.cpp
@@ -30,7 +30,7 @@ Constraint::Constraint(const llvm::Record *record)
     kind = CK_Region;
   } else if (def->isSubClassOf("SuccessorConstraint")) {
     kind = CK_Successor;
-  } else if(!def->isSubClassOf("Constraint")) {
+  } else if (!def->isSubClassOf("Constraint")) {
     llvm::errs() << "Expected a constraint but got: \n" << *def << "\n";
     llvm::report_fatal_error("Abort");
   }
@@ -109,6 +109,14 @@ std::optional<StringRef> Constraint::getBaseDefName() const {
   }
 }
 
+std::optional<StringRef> Constraint::getCppFunctionName() const {
+  std::optional<StringRef> fnName =
+      def->getValueAsOptionalString("cppFunctionName");
+  if (fnName && !fnName->empty())
+    return fnName;
+  return std::nullopt;
+}
+
 AppliedConstraint::AppliedConstraint(Constraint &&constraint,
                                      llvm::StringRef self,
                                      std::vector<std::string> &&entities)
diff --git a/mlir/test/mlir-tblgen/type-constraints.td b/mlir/test/mlir-tblgen/type-constraints.td
new file mode 100644
index 00000000000000..9b0d1a97c4ab9d
--- /dev/null
+++ b/mlir/test/mlir-tblgen/type-constraints.td
@@ -0,0 +1,14 @@
+// RUN: mlir-tblgen -gen-type-constraint-decls -I %S/../../include %s | FileCheck %s --check-prefix=DECL
+// RUN: mlir-tblgen -gen-type-constraint-defs -I %S/../../include %s | FileCheck %s --check-prefix=DEF
+
+include "mlir/IR/CommonTypeConstraints.td"
+
+def DummyConstraint : AnyTypeOf<[AnyInteger, Index, AnyFloat]> {
+  let cppFunctionName = "isValidDummy";
+}
+
+// DECL: ::llvm::LogicalResult isValidDummy(::mlir::Type type);
+
+// DEF: ::llvm::LogicalResult isValidDummy(::mlir::Type type) {
+// DEF:   return ::llvm::success((((::llvm::isa<::mlir::IntegerType>(type))) || ((::llvm::isa<::mlir::IndexType>(type))) || ((::llvm::isa<::mlir::FloatType>(type)))));
+// DEF: }
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index 71ba6a5c73da9e..5aabe33d479214 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -1023,6 +1023,51 @@ bool DefGenerator::emitDefs(StringRef selectedDialect) {
   return false;
 }
 
+//===----------------------------------------------------------------------===//
+// Type Constraints
+//===----------------------------------------------------------------------===//
+// Declaration emitted per constraint; {0} is the "cppFunctionName" value.
+static const char *const typeConstraintDecl = R"(
+::llvm::LogicalResult {0}(::mlir::Type type);
+)";
+// Definition: succeeds iff condition {1} (self bound to "type") holds.
+static const char *const typeConstraintDef = R"(
+::llvm::LogicalResult {0}(::mlir::Type type) {
+  return ::llvm::success(({1}));
+}
+)";
+
+/// Find type constraints of the main .td file that set "cppFunctionName".
+static std::vector<Constraint>
+getAllTypeConstraints(const llvm::RecordKeeper &records) {
+  std::vector<Constraint> result;
+  for (llvm::Record *def :
+       records.getAllDerivedDefinitionsIfDefined("TypeConstraint")) {
+    // Skip included files; they would emit duplicate C++ definitions.
+    if (llvm::SrcMgr.FindBufferContainingLoc(def->getLoc()[0]) ==
+            llvm::SrcMgr.getMainFileID() &&
+        Constraint(def).getCppFunctionName())
+      result.emplace_back(def);
+  }
+  return result;
+}
+
+static void emitTypeConstraintDecls(const llvm::RecordKeeper &records,
+                                    raw_ostream &os) {
+  for (const Constraint &typeConstr : getAllTypeConstraints(records))
+    os << strfmt(typeConstraintDecl, *typeConstr.getCppFunctionName());
+}
+
+static void emitTypeConstraintDefs(const llvm::RecordKeeper &records,
+                                   raw_ostream &os) {
+  FmtContext ctx;
+  ctx.withSelf("type");
+  for (const Constraint &constr : getAllTypeConstraints(records)) {
+    std::string condition = tgfmt(constr.getConditionTemplate(), &ctx);
+    os << strfmt(typeConstraintDef, *constr.getCppFunctionName(), condition);
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // GEN: Registration hooks
 //===----------------------------------------------------------------------===//
@@ -1070,3 +1115,18 @@ static mlir::GenRegistration
                    TypeDefGenerator generator(records, os);
                    return generator.emitDecls(typeDialect);
                  });
+
+static mlir::GenRegistration
+    genTypeConstrDefs("gen-type-constraint-defs",
+                      "Generate type constraint definitions",
+                      [](const llvm::RecordKeeper &records, raw_ostream &os) {
+                        emitTypeConstraintDefs(records, os);
+                        return false; // emitter returns void: no error
+                      });
+static mlir::GenRegistration
+    genTypeConstrDecls("gen-type-constraint-decls",
+                       "Generate type constraint declarations",
+                       [](const llvm::RecordKeeper &records, raw_ostream &os) {
+                         emitTypeConstraintDecls(records, os);
+                         return false; // emitter returns void: no error
+                       });



More information about the Mlir-commits mailing list