[Mlir-commits] [mlir] [mlir][tblgen] Add PredTypeTrait/PredAttrTrait support (PR #169153)
Tim Noack
llvmlistbot at llvm.org
Sat Nov 22 00:16:23 PST 2025
https://github.com/timnoack created https://github.com/llvm/llvm-project/pull/169153
This patch adds support for `PredTypeTrait` and `PredAttrTrait` in type and attribute definitions, enabling declarative predicate-based verification similar to how `PredOpTrait` works for operations.
## Motivation
Previously, `PredTypeTrait`/`PredAttrTrait` were defined in TableGen but not handled by the code generator. Using them caused mlir-tblgen to crash with an assertion failure when it tried to cast a `PredTrait` to an `InterfaceTrait`. This patch fixes the crash and implements code generation for the actual verification.
## Usage
Use `$paramName` syntax in predicates to reference type/attribute parameters:
```tablegen
def MyType : MyDialect_Type<"MyType",
[PredTypeTrait<"value must be positive", CPred<"$value > 0">>]> {
let parameters = (ins "unsigned":$value);
let mnemonic = "my_type";
let assemblyFormat = "`<` $value `>`";
}
```
This generates verification code in `verifyInvariantsImpl()`:
```cpp
if (!(value > 0)) {
emitError() << "failed to verify that value must be positive";
return ::mlir::failure();
}
```
From fce09402983f69c052d75779b38122c97f500b59 Mon Sep 17 00:00:00 2001
From: tn <noack at esa.tu-darmstadt.de>
Date: Sat, 22 Nov 2025 09:11:12 +0100
Subject: [PATCH] [mlir][tblgen] Add PredTypeTrait/PredAttrTrait support for
type/attribute verification
---
mlir/lib/TableGen/AttrOrTypeDef.cpp | 8 ++--
mlir/test/IR/test-verifiers-type.mlir | 48 +++++++++++++++++++++
mlir/test/lib/Dialect/Test/TestTypeDefs.td | 30 +++++++++++++
mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp | 36 +++++++++++++---
4 files changed, 113 insertions(+), 9 deletions(-)
diff --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp
index 4659265e24bda..bf835a860cd5b 100644
--- a/mlir/lib/TableGen/AttrOrTypeDef.cpp
+++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp
@@ -190,9 +190,11 @@ bool AttrOrTypeDef::genVerifyDecl() const {
}
bool AttrOrTypeDef::genVerifyInvariantsImpl() const {
- return any_of(parameters, [](const AttrOrTypeParameter &p) {
- return p.getConstraint() != std::nullopt;
- });
+ return any_of(parameters,
+ [](const AttrOrTypeParameter &p) {
+ return p.getConstraint() != std::nullopt;
+ }) ||
+ any_of(traits, [](const Trait &t) { return isa<PredTrait>(&t); });
}
std::optional<StringRef> AttrOrTypeDef::getExtraDecls() const {
diff --git a/mlir/test/IR/test-verifiers-type.mlir b/mlir/test/IR/test-verifiers-type.mlir
index 6512a1b9c8711..a6a5fa3d4fc9f 100644
--- a/mlir/test/IR/test-verifiers-type.mlir
+++ b/mlir/test/IR/test-verifiers-type.mlir
@@ -22,3 +22,51 @@
// expected-error @below{{failed to verify 'elementType': VectorElementTypeInterface instance}}
"test.type_producer"() : () -> vector<memref<2xf32>>
+
+// -----
+
+// Test PredTypeTrait with single parameter - valid case.
+// CHECK: "test.type_producer"() : () -> !test.type_pred_trait<5>
+"test.type_producer"() : () -> !test.type_pred_trait<5>
+
+// -----
+
+// Test PredTypeTrait with single parameter - invalid case (zero is not positive).
+// expected-error @below{{failed to verify that value must be positive}}
+"test.type_producer"() : () -> !test.type_pred_trait<0>
+
+// -----
+
+// Test PredTypeTrait with multiple parameters - valid case (5 >= 3).
+// CHECK: "test.type_producer"() : () -> !test.type_pred_trait_multi<5, 3>
+"test.type_producer"() : () -> !test.type_pred_trait_multi<5, 3>
+
+// -----
+
+// Test PredTypeTrait with multiple parameters - edge case (3 >= 3).
+// CHECK: "test.type_producer"() : () -> !test.type_pred_trait_multi<3, 3>
+"test.type_producer"() : () -> !test.type_pred_trait_multi<3, 3>
+
+// -----
+
+// Test PredTypeTrait with multiple parameters - invalid case (2 < 5).
+// expected-error @below{{failed to verify that value must be at least min}}
+"test.type_producer"() : () -> !test.type_pred_trait_multi<2, 5>
+
+// -----
+
+// Test combined parameter constraint + PredTypeTrait - valid case.
+// CHECK: "test.type_producer"() : () -> !test.type_pred_trait_combined<3, [1, 2, 3], i32>
+"test.type_producer"() : () -> !test.type_pred_trait_combined<3, [1, 2, 3], i32>
+
+// -----
+
+// Test combined - parameter type constraint fails (f16 not in [I16, I32]).
+// expected-error @below{{failed to verify 'elementType': 16-bit signless integer or 32-bit signless integer}}
+"test.type_producer"() : () -> !test.type_pred_trait_combined<2, [1, 2], f16>
+
+// -----
+
+// Test combined - PredTypeTrait fails (count 2 != elements.size() 3).
+// expected-error @below{{failed to verify that count must match number of elements}}
+"test.type_producer"() : () -> !test.type_pred_trait_combined<2, [1, 2, 3], i16>
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index 9859bd06cb526..232d6354d01eb 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -406,6 +406,36 @@ def TestTypeVerification : Test_Type<"TestTypeVerification"> {
let assemblyFormat = "`<` $param `>`";
}
+// Test type with PredTypeTrait - single parameter predicate.
+def TestTypePredTrait : Test_Type<"TestTypePredTrait",
+ [PredTypeTrait<"value must be positive", CPred<"$value > 0">>]> {
+ let parameters = (ins "unsigned":$value);
+ let mnemonic = "type_pred_trait";
+ let assemblyFormat = "`<` $value `>`";
+}
+
+// Test type with PredTypeTrait - two parameter predicate.
+def TestTypePredTraitMultiParam : Test_Type<"TestTypePredTraitMultiParam",
+ [PredTypeTrait<"value must be at least min",
+ CPred<"$value >= $minValue">>]> {
+ let parameters = (ins "unsigned":$value, "unsigned":$minValue);
+ let mnemonic = "type_pred_trait_multi";
+ let assemblyFormat = "`<` $value `,` $minValue `>`";
+}
+
+// Test type combining parameter type constraints with PredTypeTrait.
+def TestTypePredTraitCombined : Test_Type<"TestTypePredTraitCombined",
+ [PredTypeTrait<"count must match number of elements",
+ CPred<"$count == $elements.size()">>]> {
+ let parameters = (ins
+ "unsigned":$count,
+ ArrayRefParameter<"int64_t">:$elements,
+ AnyTypeOf<[I16, I32]>:$elementType
+ );
+ let mnemonic = "type_pred_trait_combined";
+ let assemblyFormat = "`<` $count `,` `[` $elements `]` `,` $elementType `>`";
+}
+
def TestTypeOpAsmTypeInterface : Test_Type<"TestTypeOpAsmTypeInterface",
[DeclareTypeInterfaceMethods<OpAsmTypeInterface, ["getAsmName", "getAlias"]>]> {
let mnemonic = "op_asm_type_interface";
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index 2a513c3b8cc9b..6547cb196716c 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -245,12 +245,16 @@ void DefGen::createParentWithTraits() {
? strfmt("{0}::{1}", def.getStorageNamespace(),
def.getStorageClassName())
: strfmt("::mlir::{0}Storage", valueType));
- SmallVector<std::string> traitNames =
- llvm::to_vector(llvm::map_range(def.getTraits(), [](auto &trait) {
- return isa<NativeTrait>(&trait)
- ? cast<NativeTrait>(&trait)->getFullyQualifiedTraitName()
- : cast<InterfaceTrait>(&trait)->getFullyQualifiedTraitName();
- }));
+ SmallVector<std::string> traitNames;
+ for (auto &trait : def.getTraits()) {
+ // Skip PredTrait as it doesn't generate a C++ trait class.
+ if (isa<PredTrait>(&trait))
+ continue;
+ traitNames.push_back(
+ isa<NativeTrait>(&trait)
+ ? cast<NativeTrait>(&trait)->getFullyQualifiedTraitName()
+ : cast<InterfaceTrait>(&trait)->getFullyQualifiedTraitName());
+ }
for (auto &traitName : traitNames)
defParent.addTemplateParam(traitName);
@@ -385,6 +389,27 @@ void DefGen::emitInvariantsVerifierImpl() {
param.getName(), constraint->getSummary())
<< "\n";
}
+ {
+ // Generate verification for PredTraits; escape summaries for C++ literals.
+ FmtContext traitCtx;
+ for (auto it : llvm::enumerate(def.getParameters())) {
+ // Note: Skip over the first method parameter (`emitError`).
+ traitCtx.addSubst(it.value().getName(),
+ builderParams[it.index() + 1].getName());
+ }
+ for (const Trait &trait : def.getTraits()) {
+ if (auto *t = dyn_cast<PredTrait>(&trait)) {
+ verifier->body() << tgfmt(
+ "if (!($0)) {\n"
+ " emitError() << \"failed to verify that $1\";\n"
+ " return ::mlir::failure();\n"
+ "}\n",
+ &traitCtx, tgfmt(t->getPredTemplate(), &traitCtx),
+ escapeString(t->getSummary()));
+ }
+ }
+ }
+
verifier->body() << "return ::mlir::success();";
}
More information about the Mlir-commits
mailing list