[Mlir-commits] [mlir] [mlir][ODS] Optionally generate public C++ functions for type constraints (PR #104577)
Matthias Springer
llvmlistbot at llvm.org
Mon Aug 19 01:15:12 PDT 2024
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/104577
>From 5ba236fd6728a5c654991053e34668a24828c8fd Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Fri, 16 Aug 2024 12:00:14 +0200
Subject: [PATCH 1/2] [mlir][ODS] Optionally generate public C++ functions for
type constraints
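Add two mlir-tblgen generators, `gen-type-constraint-decls` and `gen-type-constraint-defs`, which generate public C++ functions for type constraints. The name of the generated C++ function is specified in the constraint's `cppFunctionName` field; if that field is empty, no function is generated.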
Type constraints are typically used for op/type/attribute verification. The same checks are sometimes also needed in builders and transformations; until now, that required duplicating the check by hand in C++.
Note: This commit just adds the option for type constraints, but attribute constraints could be supported in the same way.
Alternatives considered:
1. The C++ functions could also be generated as part of `gen-typedef-decls`/`gen-typedef-defs`, but that can be confusing because type constraints may rely on type definitions from multiple `.td` files.
2. The C++ functions could also be generated as static member functions of dialects, but they don't really belong to a dialect, because they may rely on type definitions from multiple dialects.
---
mlir/include/mlir/IR/BuiltinTypes.h | 1 +
mlir/include/mlir/IR/BuiltinTypes.td | 14 ++---
mlir/include/mlir/IR/CMakeLists.txt | 3 ++
mlir/include/mlir/IR/Constraints.td | 6 ++-
mlir/include/mlir/TableGen/Constraint.h | 4 ++
mlir/lib/IR/BuiltinTypes.cpp | 16 +++++-
mlir/lib/IR/CMakeLists.txt | 1 +
mlir/lib/TableGen/Constraint.cpp | 10 +++-
mlir/test/mlir-tblgen/type-constraints.td | 14 +++++
mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp | 60 +++++++++++++++++++++
10 files changed, 118 insertions(+), 11 deletions(-)
create mode 100644 mlir/test/mlir-tblgen/type-constraints.td
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index d12522ba55c96e..eefa4279df1a01 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -198,6 +198,7 @@ class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
#include "mlir/IR/BuiltinTypes.h.inc"
namespace mlir {
+#include "mlir/IR/BuiltinTypeConstraints.h.inc"
//===----------------------------------------------------------------------===//
// MemRefType
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 4b3add2035263c..1ab1bbe9bfc9b2 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -1097,6 +1097,10 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
// VectorType
//===----------------------------------------------------------------------===//
+def Builtin_VectorTypeElementType : AnyTypeOf<[AnyInteger, Index, AnyFloat]> {
+ let cppFunctionName = "isValidVectorTypeElementType";
+}
+
def Builtin_Vector : Builtin_Type<"Vector", "vector",
[ShapedTypeInterface, ValueSemantics], "Type"> {
let summary = "Multi-dimensional SIMD vector type";
@@ -1147,7 +1151,7 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector",
}];
let parameters = (ins
ArrayRefParameter<"int64_t">:$shape,
- AnyTypeOf<[AnyInteger, Index, AnyFloat]>:$elementType,
+ Builtin_VectorTypeElementType:$elementType,
ArrayRefParameter<"bool">:$scalableDims
);
let builders = [
@@ -1171,12 +1175,8 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector",
class Builder;
/// Returns true if the given type can be used as an element of a vector
- /// type. In particular, vectors can consist of integer, index, or float
- /// primitives.
- static bool isValidElementType(Type t) {
- // TODO: Auto-generate this function from $elementType.
- return ::llvm::isa<IntegerType, IndexType, FloatType>(t);
- }
+ /// type. See "Builtin_VectorTypeElementType" for allowed types.
+ static bool isValidElementType(Type t);
/// Returns true if the vector contains scalable dimensions.
bool isScalable() const {
diff --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt
index 04a57d26a068d5..b741eb18d47916 100644
--- a/mlir/include/mlir/IR/CMakeLists.txt
+++ b/mlir/include/mlir/IR/CMakeLists.txt
@@ -35,6 +35,9 @@ set(LLVM_TARGET_DEFINITIONS BuiltinTypes.td)
mlir_tablegen(BuiltinTypes.h.inc -gen-typedef-decls)
mlir_tablegen(BuiltinTypes.cpp.inc -gen-typedef-defs)
add_public_tablegen_target(MLIRBuiltinTypesIncGen)
+mlir_tablegen(BuiltinTypeConstraints.h.inc -gen-type-constraint-decls)
+mlir_tablegen(BuiltinTypeConstraints.cpp.inc -gen-type-constraint-defs)
+add_public_tablegen_target(MLIRBuiltinTypeConstraintsIncGen)
set(LLVM_TARGET_DEFINITIONS BuiltinTypeInterfaces.td)
mlir_tablegen(BuiltinTypeInterfaces.h.inc -gen-type-interface-decls)
diff --git a/mlir/include/mlir/IR/Constraints.td b/mlir/include/mlir/IR/Constraints.td
index 39bc55db63da1a..13223aa8abcdaa 100644
--- a/mlir/include/mlir/IR/Constraints.td
+++ b/mlir/include/mlir/IR/Constraints.td
@@ -149,10 +149,14 @@ class Constraint<Pred pred, string desc = ""> {
// Subclass for constraints on a type.
class TypeConstraint<Pred predicate, string summary = "",
- string cppTypeParam = "::mlir::Type"> :
+ string cppTypeParam = "::mlir::Type",
+ string cppFunctionNameParam = ""> :
Constraint<predicate, summary> {
// The name of the C++ Type class if known, or Type if not.
string cppType = cppTypeParam;
+ // The name of the C++ function that is generated for this type constraint.
+ // If empty, no C++ function is generated.
+ string cppFunctionName = cppFunctionNameParam;
}
// Subclass for constraints on an attribute.
diff --git a/mlir/include/mlir/TableGen/Constraint.h b/mlir/include/mlir/TableGen/Constraint.h
index 0d0c28e651ee99..8877daaa775145 100644
--- a/mlir/include/mlir/TableGen/Constraint.h
+++ b/mlir/include/mlir/TableGen/Constraint.h
@@ -69,6 +69,10 @@ class Constraint {
/// context on the def).
std::string getUniqueDefName() const;
+ /// Returns the name of the C++ function that should be generated for this
+ /// constraint, or std::nullopt if no C++ function should be generated.
+ std::optional<StringRef> getCppFunctionName() const;
+
Kind getKind() const { return kind; }
/// Return the underlying def.
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index a3f5ece8c17369..f3f58efa5683f3 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -32,6 +32,10 @@ using namespace mlir::detail;
#define GET_TYPEDEF_CLASSES
#include "mlir/IR/BuiltinTypes.cpp.inc"
+namespace mlir {
+#include "mlir/IR/BuiltinTypeConstraints.cpp.inc"
+} // namespace mlir
+
//===----------------------------------------------------------------------===//
// BuiltinDialect
//===----------------------------------------------------------------------===//
@@ -230,6 +234,10 @@ LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
// VectorType
//===----------------------------------------------------------------------===//
+bool VectorType::isValidElementType(Type t) {
+ return succeeded(isValidVectorTypeElementType(t));
+}
+
LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> shape, Type elementType,
ArrayRef<bool> scalableDims) {
@@ -278,7 +286,9 @@ Type TensorType::getElementType() const {
[](auto type) { return type.getElementType(); });
}
-bool TensorType::hasRank() const { return !llvm::isa<UnrankedTensorType>(*this); }
+bool TensorType::hasRank() const {
+ return !llvm::isa<UnrankedTensorType>(*this);
+}
ArrayRef<int64_t> TensorType::getShape() const {
return llvm::cast<RankedTensorType>(*this).getShape();
@@ -365,7 +375,9 @@ Type BaseMemRefType::getElementType() const {
[](auto type) { return type.getElementType(); });
}
-bool BaseMemRefType::hasRank() const { return !llvm::isa<UnrankedMemRefType>(*this); }
+bool BaseMemRefType::hasRank() const {
+ return !llvm::isa<UnrankedMemRefType>(*this);
+}
ArrayRef<int64_t> BaseMemRefType::getShape() const {
return llvm::cast<MemRefType>(*this).getShape();
diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index c38ce6c058a006..4cabac185171c2 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -55,6 +55,7 @@ add_mlir_library(MLIRIR
MLIRBuiltinLocationAttributesIncGen
MLIRBuiltinOpsIncGen
MLIRBuiltinTypesIncGen
+ MLIRBuiltinTypeConstraintsIncGen
MLIRBuiltinTypeInterfacesIncGen
MLIRCallInterfacesIncGen
MLIRCastInterfacesIncGen
diff --git a/mlir/lib/TableGen/Constraint.cpp b/mlir/lib/TableGen/Constraint.cpp
index 4ccbd0a685e09a..8cf4ed08a2d54f 100644
--- a/mlir/lib/TableGen/Constraint.cpp
+++ b/mlir/lib/TableGen/Constraint.cpp
@@ -30,7 +30,7 @@ Constraint::Constraint(const llvm::Record *record)
kind = CK_Region;
} else if (def->isSubClassOf("SuccessorConstraint")) {
kind = CK_Successor;
- } else if(!def->isSubClassOf("Constraint")) {
+ } else if (!def->isSubClassOf("Constraint")) {
llvm::errs() << "Expected a constraint but got: \n" << *def << "\n";
llvm::report_fatal_error("Abort");
}
@@ -109,6 +109,14 @@ std::optional<StringRef> Constraint::getBaseDefName() const {
}
}
+std::optional<StringRef> Constraint::getCppFunctionName() const {
+ std::optional<StringRef> name =
+ def->getValueAsOptionalString("cppFunctionName");
+ if (!name || *name == "")
+ return std::nullopt;
+ return name;
+}
+
AppliedConstraint::AppliedConstraint(Constraint &&constraint,
llvm::StringRef self,
std::vector<std::string> &&entities)
diff --git a/mlir/test/mlir-tblgen/type-constraints.td b/mlir/test/mlir-tblgen/type-constraints.td
new file mode 100644
index 00000000000000..9b0d1a97c4ab9d
--- /dev/null
+++ b/mlir/test/mlir-tblgen/type-constraints.td
@@ -0,0 +1,14 @@
+// RUN: mlir-tblgen -gen-type-constraint-decls -I %S/../../include %s | FileCheck %s --check-prefix=DECL
+// RUN: mlir-tblgen -gen-type-constraint-defs -I %S/../../include %s | FileCheck %s --check-prefix=DEF
+
+include "mlir/IR/CommonTypeConstraints.td"
+
+def DummyConstraint : AnyTypeOf<[AnyInteger, Index, AnyFloat]> {
+ let cppFunctionName = "isValidDummy";
+}
+
+// DECL: ::llvm::LogicalResult isValidDummy(::mlir::Type type);
+
+// DEF: ::llvm::LogicalResult isValidDummy(::mlir::Type type) {
+// DEF: return ::llvm::success((((::llvm::isa<::mlir::IntegerType>(type))) || ((::llvm::isa<::mlir::IndexType>(type))) || ((::llvm::isa<::mlir::FloatType>(type)))));
+// DEF: }
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index 71ba6a5c73da9e..5aabe33d479214 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -1023,6 +1023,51 @@ bool DefGenerator::emitDefs(StringRef selectedDialect) {
return false;
}
+//===----------------------------------------------------------------------===//
+// Type Constraints
+//===----------------------------------------------------------------------===//
+
+static const char *const typeConstraintDecl = R"(
+::llvm::LogicalResult {0}(::mlir::Type type);
+)";
+
+static const char *const typeConstraintDef = R"(
+::llvm::LogicalResult {0}(::mlir::Type type) {
+ return ::llvm::success(({1}));
+}
+)";
+
+/// Find all type constraints for which a C++ function should be generated.
+static std::vector<Constraint>
+getAllTypeConstraints(const llvm::RecordKeeper &records) {
+ std::vector<Constraint> result;
+ for (llvm::Record *def :
+ records.getAllDerivedDefinitionsIfDefined("TypeConstraint")) {
+ Constraint constr(def);
+ // Generate C++ function only if "cppFunctionName" is set.
+ if (!constr.getCppFunctionName())
+ continue;
+ result.push_back(constr);
+ }
+ return result;
+}
+
+static void emitTypeConstraintDecls(const llvm::RecordKeeper &records,
+ raw_ostream &os) {
+ for (Constraint constr : getAllTypeConstraints(records))
+ os << strfmt(typeConstraintDecl, *constr.getCppFunctionName());
+}
+
+static void emitTypeConstraintDefs(const llvm::RecordKeeper &records,
+ raw_ostream &os) {
+ for (Constraint constr : getAllTypeConstraints(records)) {
+ FmtContext ctx;
+ ctx.withSelf("type");
+ std::string condition = tgfmt(constr.getConditionTemplate(), &ctx);
+ os << strfmt(typeConstraintDef, *constr.getCppFunctionName(), condition);
+ }
+}
+
//===----------------------------------------------------------------------===//
// GEN: Registration hooks
//===----------------------------------------------------------------------===//
@@ -1070,3 +1115,18 @@ static mlir::GenRegistration
TypeDefGenerator generator(records, os);
return generator.emitDecls(typeDialect);
});
+
+static mlir::GenRegistration
+ genTypeConstrDefs("gen-type-constraint-defs",
+ "Generate type constraint definitions",
+ [](const llvm::RecordKeeper &records, raw_ostream &os) {
+ emitTypeConstraintDefs(records, os);
+ return false;
+ });
+static mlir::GenRegistration
+ genTypeConstrDecls("gen-type-constraint-decls",
+ "Generate type constraint declarations",
+ [](const llvm::RecordKeeper &records, raw_ostream &os) {
+ emitTypeConstraintDecls(records, os);
+ return false;
+ });
>From 82d69f29e82cc999aeada0e0f6f126e80d75888a Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Mon, 19 Aug 2024 10:14:39 +0200
Subject: [PATCH 2/2] address comments
---
mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index 5aabe33d479214..43b5791788ec26 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -1043,6 +1043,10 @@ getAllTypeConstraints(const llvm::RecordKeeper &records) {
std::vector<Constraint> result;
for (llvm::Record *def :
records.getAllDerivedDefinitionsIfDefined("TypeConstraint")) {
+ // Ignore constraints defined outside of the top-level file.
+ if (llvm::SrcMgr.FindBufferContainingLoc(def->getLoc()[0]) !=
+ llvm::SrcMgr.getMainFileID())
+ continue;
Constraint constr(def);
// Generate C++ function only if "cppFunctionName" is set.
if (!constr.getCppFunctionName())
More information about the Mlir-commits
mailing list