[Mlir-commits] [mlir] [mlir][ODS] Optionally generate public C++ functions for type constraints (PR #104577)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Aug 16 03:10:53 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-core
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
Add `gen-type-constraint-decls` and `gen-type-constraint-defs`, which generate public C++ functions for type constraints. The name of the C++ function is specified in the `cppFunctionName` field of the type constraint; if the field is empty, no C++ function is generated.
Type constraints are typically used for op/type/attribute verification. The same checks are sometimes also needed in builders and transformations; until now, that required duplicating the check by hand in C++.
Note: This commit just adds the option for type constraints, but attribute constraints could be supported in the same way.
Alternatives considered:
1. The C++ functions could also be generated as part of `gen-typedef-decls/defs`, but that can be confusing because type constraints may rely on type definitions from multiple `.td` files.
2. The C++ functions could also be generated as static member functions of dialects, but they don't really belong to a dialect, because they may rely on type definitions from multiple dialects.
---
Full diff: https://github.com/llvm/llvm-project/pull/104577.diff
10 Files Affected:
- (modified) mlir/include/mlir/IR/BuiltinTypes.h (+1)
- (modified) mlir/include/mlir/IR/BuiltinTypes.td (+7-7)
- (modified) mlir/include/mlir/IR/CMakeLists.txt (+3)
- (modified) mlir/include/mlir/IR/Constraints.td (+5-1)
- (modified) mlir/include/mlir/TableGen/Constraint.h (+4)
- (modified) mlir/lib/IR/BuiltinTypes.cpp (+14-2)
- (modified) mlir/lib/IR/CMakeLists.txt (+1)
- (modified) mlir/lib/TableGen/Constraint.cpp (+9-1)
- (added) mlir/test/mlir-tblgen/type-constraints.td (+14)
- (modified) mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp (+60)
``````````diff
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index d12522ba55c96e..eefa4279df1a01 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -198,6 +198,7 @@ class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
#include "mlir/IR/BuiltinTypes.h.inc"
namespace mlir {
+#include "mlir/IR/BuiltinTypeConstraints.h.inc"
//===----------------------------------------------------------------------===//
// MemRefType
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 4b3add2035263c..1ab1bbe9bfc9b2 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -1097,6 +1097,10 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
// VectorType
//===----------------------------------------------------------------------===//
+def Builtin_VectorTypeElementType : AnyTypeOf<[AnyInteger, Index, AnyFloat]> {
+ let cppFunctionName = "isValidVectorTypeElementType";
+}
+
def Builtin_Vector : Builtin_Type<"Vector", "vector",
[ShapedTypeInterface, ValueSemantics], "Type"> {
let summary = "Multi-dimensional SIMD vector type";
@@ -1147,7 +1151,7 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector",
}];
let parameters = (ins
ArrayRefParameter<"int64_t">:$shape,
- AnyTypeOf<[AnyInteger, Index, AnyFloat]>:$elementType,
+ Builtin_VectorTypeElementType:$elementType,
ArrayRefParameter<"bool">:$scalableDims
);
let builders = [
@@ -1171,12 +1175,8 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector",
class Builder;
/// Returns true if the given type can be used as an element of a vector
- /// type. In particular, vectors can consist of integer, index, or float
- /// primitives.
- static bool isValidElementType(Type t) {
- // TODO: Auto-generate this function from $elementType.
- return ::llvm::isa<IntegerType, IndexType, FloatType>(t);
- }
+ /// type. See "Builtin_VectorTypeElementType" for allowed types.
+ static bool isValidElementType(Type t);
/// Returns true if the vector contains scalable dimensions.
bool isScalable() const {
diff --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt
index 04a57d26a068d5..b741eb18d47916 100644
--- a/mlir/include/mlir/IR/CMakeLists.txt
+++ b/mlir/include/mlir/IR/CMakeLists.txt
@@ -35,6 +35,9 @@ set(LLVM_TARGET_DEFINITIONS BuiltinTypes.td)
mlir_tablegen(BuiltinTypes.h.inc -gen-typedef-decls)
mlir_tablegen(BuiltinTypes.cpp.inc -gen-typedef-defs)
add_public_tablegen_target(MLIRBuiltinTypesIncGen)
+mlir_tablegen(BuiltinTypeConstraints.h.inc -gen-type-constraint-decls)
+mlir_tablegen(BuiltinTypeConstraints.cpp.inc -gen-type-constraint-defs)
+add_public_tablegen_target(MLIRBuiltinTypeConstraintsIncGen)
set(LLVM_TARGET_DEFINITIONS BuiltinTypeInterfaces.td)
mlir_tablegen(BuiltinTypeInterfaces.h.inc -gen-type-interface-decls)
diff --git a/mlir/include/mlir/IR/Constraints.td b/mlir/include/mlir/IR/Constraints.td
index 39bc55db63da1a..13223aa8abcdaa 100644
--- a/mlir/include/mlir/IR/Constraints.td
+++ b/mlir/include/mlir/IR/Constraints.td
@@ -149,10 +149,14 @@ class Constraint<Pred pred, string desc = ""> {
// Subclass for constraints on a type.
class TypeConstraint<Pred predicate, string summary = "",
- string cppTypeParam = "::mlir::Type"> :
+ string cppTypeParam = "::mlir::Type",
+ string cppFunctionNameParam = ""> :
Constraint<predicate, summary> {
// The name of the C++ Type class if known, or Type if not.
string cppType = cppTypeParam;
+ // The name of the C++ function that is generated for this type constraint.
+ // If empty, no C++ function is generated.
+ string cppFunctionName = cppFunctionNameParam;
}
// Subclass for constraints on an attribute.
diff --git a/mlir/include/mlir/TableGen/Constraint.h b/mlir/include/mlir/TableGen/Constraint.h
index 0d0c28e651ee99..8877daaa775145 100644
--- a/mlir/include/mlir/TableGen/Constraint.h
+++ b/mlir/include/mlir/TableGen/Constraint.h
@@ -69,6 +69,10 @@ class Constraint {
/// context on the def).
std::string getUniqueDefName() const;
+ /// Returns the name of the C++ function that should be generated for this
+ /// constraint, or std::nullopt if no C++ function should be generated.
+ std::optional<StringRef> getCppFunctionName() const;
+
Kind getKind() const { return kind; }
/// Return the underlying def.
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index a3f5ece8c17369..f3f58efa5683f3 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -32,6 +32,10 @@ using namespace mlir::detail;
#define GET_TYPEDEF_CLASSES
#include "mlir/IR/BuiltinTypes.cpp.inc"
+namespace mlir {
+#include "mlir/IR/BuiltinTypeConstraints.cpp.inc"
+} // namespace mlir
+
//===----------------------------------------------------------------------===//
// BuiltinDialect
//===----------------------------------------------------------------------===//
@@ -230,6 +234,10 @@ LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
// VectorType
//===----------------------------------------------------------------------===//
+bool VectorType::isValidElementType(Type t) {
+ return succeeded(isValidVectorTypeElementType(t));
+}
+
LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> shape, Type elementType,
ArrayRef<bool> scalableDims) {
@@ -278,7 +286,9 @@ Type TensorType::getElementType() const {
[](auto type) { return type.getElementType(); });
}
-bool TensorType::hasRank() const { return !llvm::isa<UnrankedTensorType>(*this); }
+bool TensorType::hasRank() const {
+ return !llvm::isa<UnrankedTensorType>(*this);
+}
ArrayRef<int64_t> TensorType::getShape() const {
return llvm::cast<RankedTensorType>(*this).getShape();
@@ -365,7 +375,9 @@ Type BaseMemRefType::getElementType() const {
[](auto type) { return type.getElementType(); });
}
-bool BaseMemRefType::hasRank() const { return !llvm::isa<UnrankedMemRefType>(*this); }
+bool BaseMemRefType::hasRank() const {
+ return !llvm::isa<UnrankedMemRefType>(*this);
+}
ArrayRef<int64_t> BaseMemRefType::getShape() const {
return llvm::cast<MemRefType>(*this).getShape();
diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index c38ce6c058a006..4cabac185171c2 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -55,6 +55,7 @@ add_mlir_library(MLIRIR
MLIRBuiltinLocationAttributesIncGen
MLIRBuiltinOpsIncGen
MLIRBuiltinTypesIncGen
+ MLIRBuiltinTypeConstraintsIncGen
MLIRBuiltinTypeInterfacesIncGen
MLIRCallInterfacesIncGen
MLIRCastInterfacesIncGen
diff --git a/mlir/lib/TableGen/Constraint.cpp b/mlir/lib/TableGen/Constraint.cpp
index 4ccbd0a685e09a..8cf4ed08a2d54f 100644
--- a/mlir/lib/TableGen/Constraint.cpp
+++ b/mlir/lib/TableGen/Constraint.cpp
@@ -30,7 +30,7 @@ Constraint::Constraint(const llvm::Record *record)
kind = CK_Region;
} else if (def->isSubClassOf("SuccessorConstraint")) {
kind = CK_Successor;
- } else if(!def->isSubClassOf("Constraint")) {
+ } else if (!def->isSubClassOf("Constraint")) {
llvm::errs() << "Expected a constraint but got: \n" << *def << "\n";
llvm::report_fatal_error("Abort");
}
@@ -109,6 +109,14 @@ std::optional<StringRef> Constraint::getBaseDefName() const {
}
}
+std::optional<StringRef> Constraint::getCppFunctionName() const {
+ std::optional<StringRef> name =
+ def->getValueAsOptionalString("cppFunctionName");
+ if (!name || *name == "")
+ return std::nullopt;
+ return name;
+}
+
AppliedConstraint::AppliedConstraint(Constraint &&constraint,
llvm::StringRef self,
std::vector<std::string> &&entities)
diff --git a/mlir/test/mlir-tblgen/type-constraints.td b/mlir/test/mlir-tblgen/type-constraints.td
new file mode 100644
index 00000000000000..9b0d1a97c4ab9d
--- /dev/null
+++ b/mlir/test/mlir-tblgen/type-constraints.td
@@ -0,0 +1,14 @@
+// RUN: mlir-tblgen -gen-type-constraint-decls -I %S/../../include %s | FileCheck %s --check-prefix=DECL
+// RUN: mlir-tblgen -gen-type-constraint-defs -I %S/../../include %s | FileCheck %s --check-prefix=DEF
+
+include "mlir/IR/CommonTypeConstraints.td"
+
+def DummyConstraint : AnyTypeOf<[AnyInteger, Index, AnyFloat]> {
+ let cppFunctionName = "isValidDummy";
+}
+
+// DECL: ::llvm::LogicalResult isValidDummy(::mlir::Type type);
+
+// DEF: ::llvm::LogicalResult isValidDummy(::mlir::Type type) {
+// DEF: return ::llvm::success((((::llvm::isa<::mlir::IntegerType>(type))) || ((::llvm::isa<::mlir::IndexType>(type))) || ((::llvm::isa<::mlir::FloatType>(type)))));
+// DEF: }
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index 71ba6a5c73da9e..5aabe33d479214 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -1023,6 +1023,51 @@ bool DefGenerator::emitDefs(StringRef selectedDialect) {
return false;
}
+//===----------------------------------------------------------------------===//
+// Type Constraints
+//===----------------------------------------------------------------------===//
+
+static const char *const typeConstraintDecl = R"(
+::llvm::LogicalResult {0}(::mlir::Type type);
+)";
+
+static const char *const typeConstraintDef = R"(
+::llvm::LogicalResult {0}(::mlir::Type type) {
+ return ::llvm::success(({1}));
+}
+)";
+
+/// Find all type constraints for which a C++ function should be generated.
+static std::vector<Constraint>
+getAllTypeConstraints(const llvm::RecordKeeper &records) {
+ std::vector<Constraint> result;
+ for (llvm::Record *def :
+ records.getAllDerivedDefinitionsIfDefined("TypeConstraint")) {
+ Constraint constr(def);
+ // Generate C++ function only if "cppFunctionName" is set.
+ if (!constr.getCppFunctionName())
+ continue;
+ result.push_back(constr);
+ }
+ return result;
+}
+
+static void emitTypeConstraintDecls(const llvm::RecordKeeper &records,
+ raw_ostream &os) {
+ for (Constraint constr : getAllTypeConstraints(records))
+ os << strfmt(typeConstraintDecl, *constr.getCppFunctionName());
+}
+
+static void emitTypeConstraintDefs(const llvm::RecordKeeper &records,
+ raw_ostream &os) {
+ for (Constraint constr : getAllTypeConstraints(records)) {
+ FmtContext ctx;
+ ctx.withSelf("type");
+ std::string condition = tgfmt(constr.getConditionTemplate(), &ctx);
+ os << strfmt(typeConstraintDef, *constr.getCppFunctionName(), condition);
+ }
+}
+
//===----------------------------------------------------------------------===//
// GEN: Registration hooks
//===----------------------------------------------------------------------===//
@@ -1070,3 +1115,18 @@ static mlir::GenRegistration
TypeDefGenerator generator(records, os);
return generator.emitDecls(typeDialect);
});
+
+static mlir::GenRegistration
+ genTypeConstrDefs("gen-type-constraint-defs",
+ "Generate type constraint definitions",
+ [](const llvm::RecordKeeper &records, raw_ostream &os) {
+ emitTypeConstraintDefs(records, os);
+ return false;
+ });
+static mlir::GenRegistration
+ genTypeConstrDecls("gen-type-constraint-decls",
+ "Generate type constraint declarations",
+ [](const llvm::RecordKeeper &records, raw_ostream &os) {
+ emitTypeConstraintDecls(records, os);
+ return false;
+ });
``````````
</details>
https://github.com/llvm/llvm-project/pull/104577
More information about the Mlir-commits mailing list