[Mlir-commits] [mlir] [mlir][tblgen] Add PredTypeTrait/PredAttrTrait support (PR #169153)

Tim Noack llvmlistbot at llvm.org
Sat Nov 22 00:16:23 PST 2025


https://github.com/timnoack created https://github.com/llvm/llvm-project/pull/169153

This patch adds support for `PredTypeTrait` and `PredAttrTrait` in type and attribute definitions, enabling declarative predicate-based verification similar to how `PredOpTrait` works for operations.

  ## Motivation

Previously, `PredTypeTrait`/`PredAttrTrait` were defined in TableGen but not implemented in the code generator. Using them caused mlir-tblgen to crash with an assertion failure when it tried to cast a `PredTrait` to `InterfaceTrait`. This patch fixes the crash and implements the actual verification code generation.

  ## Usage

 Use `$paramName` syntax in predicates to reference type/attribute parameters:

  ```tablegen
  def MyType : MyDialect_Type<"MyType",
      [PredTypeTrait<"value must be positive", CPred<"$value > 0">>]> {
    let parameters = (ins "unsigned":$value);
    let mnemonic = "my_type";
    let assemblyFormat = "`<` $value `>`";
  }
  ```

  This generates verification code in `verifyInvariantsImpl()`:
  ```cpp
  if (!(value > 0)) {
    emitError() << "failed to verify that value must be positive";
    return ::mlir::failure();
  }
  ```

>From fce09402983f69c052d75779b38122c97f500b59 Mon Sep 17 00:00:00 2001
From: tn <noack at esa.tu-darmstadt.de>
Date: Sat, 22 Nov 2025 09:11:12 +0100
Subject: [PATCH] [mlir][tblgen] Add PredTypeTrait/PredAttrTrait support for
 type/attribute verification

---
 mlir/lib/TableGen/AttrOrTypeDef.cpp         |  8 ++--
 mlir/test/IR/test-verifiers-type.mlir       | 48 +++++++++++++++++++++
 mlir/test/lib/Dialect/Test/TestTypeDefs.td  | 30 +++++++++++++
 mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp | 36 +++++++++++++---
 4 files changed, 113 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp
index 4659265e24bda..bf835a860cd5b 100644
--- a/mlir/lib/TableGen/AttrOrTypeDef.cpp
+++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp
@@ -190,9 +190,11 @@ bool AttrOrTypeDef::genVerifyDecl() const {
 }
 
 bool AttrOrTypeDef::genVerifyInvariantsImpl() const {
-  return any_of(parameters, [](const AttrOrTypeParameter &p) {
-    return p.getConstraint() != std::nullopt;
-  });
+  return any_of(parameters,
+                [](const AttrOrTypeParameter &p) {
+                  return p.getConstraint() != std::nullopt;
+                }) ||
+         any_of(traits, [](const Trait &t) { return isa<PredTrait>(&t); });
 }
 
 std::optional<StringRef> AttrOrTypeDef::getExtraDecls() const {
diff --git a/mlir/test/IR/test-verifiers-type.mlir b/mlir/test/IR/test-verifiers-type.mlir
index 6512a1b9c8711..a6a5fa3d4fc9f 100644
--- a/mlir/test/IR/test-verifiers-type.mlir
+++ b/mlir/test/IR/test-verifiers-type.mlir
@@ -22,3 +22,51 @@
 
 // expected-error @below{{failed to verify 'elementType': VectorElementTypeInterface instance}}
 "test.type_producer"() : () -> vector<memref<2xf32>>
+
+// -----
+
+// Test PredTypeTrait with single parameter - valid case.
+// CHECK: "test.type_producer"() : () -> !test.type_pred_trait<5>
+"test.type_producer"() : () -> !test.type_pred_trait<5>
+
+// -----
+
+// Test PredTypeTrait with single parameter - invalid case (zero is not positive).
+// expected-error @below{{failed to verify that value must be positive}}
+"test.type_producer"() : () -> !test.type_pred_trait<0>
+
+// -----
+
+// Test PredTypeTrait with multiple parameters - valid case (5 >= 3).
+// CHECK: "test.type_producer"() : () -> !test.type_pred_trait_multi<5, 3>
+"test.type_producer"() : () -> !test.type_pred_trait_multi<5, 3>
+
+// -----
+
+// Test PredTypeTrait with multiple parameters - edge case (3 >= 3).
+// CHECK: "test.type_producer"() : () -> !test.type_pred_trait_multi<3, 3>
+"test.type_producer"() : () -> !test.type_pred_trait_multi<3, 3>
+
+// -----
+
+// Test PredTypeTrait with multiple parameters - invalid case (2 < 5).
+// expected-error @below{{failed to verify that value must be at least min}}
+"test.type_producer"() : () -> !test.type_pred_trait_multi<2, 5>
+
+// -----
+
+// Test combined parameter constraint + PredTypeTrait - valid case.
+// CHECK: "test.type_producer"() : () -> !test.type_pred_trait_combined<3, [1, 2, 3], i32>
+"test.type_producer"() : () -> !test.type_pred_trait_combined<3, [1, 2, 3], i32>
+
+// -----
+
+// Test combined - parameter type constraint fails (f16 not in [I16, I32]).
+// expected-error @below{{failed to verify 'elementType': 16-bit signless integer or 32-bit signless integer}}
+"test.type_producer"() : () -> !test.type_pred_trait_combined<2, [1, 2], f16>
+
+// -----
+
+// Test combined - PredTypeTrait fails (count 2 != elements.size() 3).
+// expected-error @below{{failed to verify that count must match number of elements}}
+"test.type_producer"() : () -> !test.type_pred_trait_combined<2, [1, 2, 3], i16>
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index 9859bd06cb526..232d6354d01eb 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -406,6 +406,36 @@ def TestTypeVerification : Test_Type<"TestTypeVerification"> {
   let assemblyFormat = "`<` $param `>`";
 }
 
+// Test type with PredTypeTrait - single parameter predicate.
+def TestTypePredTrait : Test_Type<"TestTypePredTrait",
+    [PredTypeTrait<"value must be positive", CPred<"$value > 0">>]> {
+  let parameters = (ins "unsigned":$value);
+  let mnemonic = "type_pred_trait";
+  let assemblyFormat = "`<` $value `>`";
+}
+
+// Test type with PredTypeTrait - two parameter predicate.
+def TestTypePredTraitMultiParam : Test_Type<"TestTypePredTraitMultiParam",
+    [PredTypeTrait<"value must be at least min",
+                   CPred<"$value >= $minValue">>]> {
+  let parameters = (ins "unsigned":$value, "unsigned":$minValue);
+  let mnemonic = "type_pred_trait_multi";
+  let assemblyFormat = "`<` $value `,` $minValue `>`";
+}
+
+// Test type combining parameter type constraints with PredTypeTrait.
+def TestTypePredTraitCombined : Test_Type<"TestTypePredTraitCombined",
+    [PredTypeTrait<"count must match number of elements",
+                   CPred<"$count == $elements.size()">>]> {
+  let parameters = (ins
+    "unsigned":$count,
+    ArrayRefParameter<"int64_t">:$elements,
+    AnyTypeOf<[I16, I32]>:$elementType
+  );
+  let mnemonic = "type_pred_trait_combined";
+  let assemblyFormat = "`<` $count `,` `[` $elements `]` `,` $elementType `>`";
+}
+
 def TestTypeOpAsmTypeInterface : Test_Type<"TestTypeOpAsmTypeInterface",
     [DeclareTypeInterfaceMethods<OpAsmTypeInterface, ["getAsmName", "getAlias"]>]> {
   let mnemonic = "op_asm_type_interface";
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index 2a513c3b8cc9b..6547cb196716c 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -245,12 +245,16 @@ void DefGen::createParentWithTraits() {
                                  ? strfmt("{0}::{1}", def.getStorageNamespace(),
                                           def.getStorageClassName())
                                  : strfmt("::mlir::{0}Storage", valueType));
-  SmallVector<std::string> traitNames =
-      llvm::to_vector(llvm::map_range(def.getTraits(), [](auto &trait) {
-        return isa<NativeTrait>(&trait)
-                   ? cast<NativeTrait>(&trait)->getFullyQualifiedTraitName()
-                   : cast<InterfaceTrait>(&trait)->getFullyQualifiedTraitName();
-      }));
+  SmallVector<std::string> traitNames;
+  for (auto &trait : def.getTraits()) {
+    // Skip PredTrait as it doesn't generate a C++ trait class.
+    if (isa<PredTrait>(&trait))
+      continue;
+    traitNames.push_back(
+        isa<NativeTrait>(&trait)
+            ? cast<NativeTrait>(&trait)->getFullyQualifiedTraitName()
+            : cast<InterfaceTrait>(&trait)->getFullyQualifiedTraitName());
+  }
   for (auto &traitName : traitNames)
     defParent.addTemplateParam(traitName);
 
@@ -385,6 +389,26 @@ void DefGen::emitInvariantsVerifierImpl() {
                                 param.getName(), constraint->getSummary())
                      << "\n";
   }
+  {
+    // Generate verification for PredTraits.
+    FmtContext traitCtx;
+    for (auto it : llvm::enumerate(def.getParameters())) {
+      // Note: Skip over the first method parameter (`emitError`).
+      traitCtx.addSubst(it.value().getName(),
+                        builderParams[it.index() + 1].getName());
+    }
+    for (const Trait &trait : def.getTraits()) {
+      if (auto *t = dyn_cast<PredTrait>(&trait)) {
+        verifier->body() << tgfmt(
+            "if (!($0)) {\n"
+            "  emitError() << \"failed to verify that $1\";\n"
+            "  return ::mlir::failure();\n"
+            "}\n",
+            &traitCtx, tgfmt(t->getPredTemplate(), &traitCtx), t->getSummary());
+      }
+    }
+  }
+
   verifier->body() << "return ::mlir::success();";
 }
 



More information about the Mlir-commits mailing list