[Mlir-commits] [mlir] aa9ae76 - [mlir][ODS] Verify type constraint in `TypeAttrOf`

Markus Böck llvmlistbot at llvm.org
Fri Apr 7 03:30:57 PDT 2023


Author: Markus Böck
Date: 2023-04-07T12:30:15+02:00
New Revision: aa9ae76cac0443b7d70b27ae2c0bf9cf92f344d3

URL: https://github.com/llvm/llvm-project/commit/aa9ae76cac0443b7d70b27ae2c0bf9cf92f344d3
DIFF: https://github.com/llvm/llvm-project/commit/aa9ae76cac0443b7d70b27ae2c0bf9cf92f344d3.diff

LOG: [mlir][ODS] Verify type constraint in `TypeAttrOf`

The current implementation does not verify the type constraint, meaning that any type that happens to be represented by the same C++ class as the constraint would pass the verifier.
E.g. a `TypeAttrOf<I64>` would happily accept an `i32`, since both satisfy `isa<IntegerType>()`.

This patch fixes that by adding an optional type predicate parameter to `TypeAttrBase` that the type within the `TypeAttr` has to satisfy. `TypeAttrOf` then simply passes the predicate of its type parameter as that argument.

Differential Revision: https://reviews.llvm.org/D147778

Added: 
    

Modified: 
    mlir/include/mlir/IR/OpBase.td
    mlir/test/IR/attribute.mlir
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/mlir-tblgen/op-attribute.td

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 98866c83b4b4e..f4aa07ff2a443 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1296,11 +1296,14 @@ class TypedStrAttr<Type ty>
 // Base class for attributes containing types. Example:
 //   def IntTypeAttr : TypeAttrBase<"IntegerType", "integer type attribute">
 // defines a type attribute containing an integer type.
-class TypeAttrBase<string retType, string summary> :
+class TypeAttrBase<string retType, string summary,
+                        Pred typePred = CPred<"true">> :
     Attr<And<[
       CPred<"$_self.isa<::mlir::TypeAttr>()">,
       CPred<"$_self.cast<::mlir::TypeAttr>().getValue().isa<"
-            # retType # ">()">]>,
+            # retType # ">()">,
+      SubstLeaves<"$_self",
+                    "$_self.cast<::mlir::TypeAttr>().getValue()", typePred>]>,
     summary> {
   let storageType = [{ ::mlir::TypeAttr }];
   let returnType = retType;
@@ -1313,7 +1316,8 @@ def TypeAttr : TypeAttrBase<"::mlir::Type", "any type attribute"> {
 }
 
 class TypeAttrOf<Type ty>
-   : TypeAttrBase<ty.cppClassName, "type attribute of " # ty.summary> {
+   : TypeAttrBase<ty.cppClassName, "type attribute of " # ty.summary,
+                    ty.predicate> {
   let constBuilderCall = "::mlir::TypeAttr::get($0)";
 }
 

diff  --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
index c296507868cbc..25d237a74f3ad 100644
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -879,3 +879,11 @@ func.func @default_value_printing(%arg0 : i32) {
   "test.default_value_print"(%arg0) {"value_with_default" = 1 : i32} : (i32) -> ()
   return
 }
+
+// -----
+
+func.func @type_attr_of_fail() {
+    // expected-error @below {{failed to satisfy constraint: type attribute of 64-bit signless integer}}
+    test.type_attr_of i32
+    return
+}

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index ff24ac94bdfb3..0306f0ed02f99 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -277,6 +277,13 @@ def TypedAttrOp : TEST_Op<"typed_attr"> {
   }];
 }
 
+def TypeAttrOfOp : TEST_Op<"type_attr_of"> {
+  let arguments = (ins TypeAttrOf<I64>:$type);
+  let assemblyFormat = [{
+    attr-dict $type
+  }];
+}
+
 def DenseArrayAttrOp : TEST_Op<"dense_array_attr"> {
   let arguments = (ins
     DenseBoolArrayAttr:$i1attr,

diff  --git a/mlir/test/mlir-tblgen/op-attribute.td b/mlir/test/mlir-tblgen/op-attribute.td
index 3dc426cbf5d7f..af1f62221fc07 100644
--- a/mlir/test/mlir-tblgen/op-attribute.td
+++ b/mlir/test/mlir-tblgen/op-attribute.td
@@ -318,10 +318,10 @@ def BOp : NS_Op<"b_op", []> {
 // DEF: if (tblgen_str_attr && !((tblgen_str_attr.isa<::mlir::StringAttr>())))
 // DEF: if (tblgen_elements_attr && !((tblgen_elements_attr.isa<::mlir::ElementsAttr>())))
 // DEF: if (tblgen_function_attr && !((tblgen_function_attr.isa<::mlir::FlatSymbolRefAttr>())))
-// DEF: if (tblgen_some_type_attr && !(((tblgen_some_type_attr.isa<::mlir::TypeAttr>())) && ((tblgen_some_type_attr.cast<::mlir::TypeAttr>().getValue().isa<SomeType>()))))
+// DEF: if (tblgen_some_type_attr && !(((tblgen_some_type_attr.isa<::mlir::TypeAttr>())) && ((tblgen_some_type_attr.cast<::mlir::TypeAttr>().getValue().isa<SomeType>())) && ((true))))
 // DEF: if (tblgen_array_attr && !((tblgen_array_attr.isa<::mlir::ArrayAttr>())))
 // DEF: if (tblgen_some_attr_array && !(((tblgen_some_attr_array.isa<::mlir::ArrayAttr>())) && (::llvm::all_of(tblgen_some_attr_array.cast<::mlir::ArrayAttr>(), [&](::mlir::Attribute attr) { return attr && ((some-condition)); }))))
-// DEF: if (tblgen_type_attr && !(((tblgen_type_attr.isa<::mlir::TypeAttr>())) && ((tblgen_type_attr.cast<::mlir::TypeAttr>().getValue().isa<::mlir::Type>()))))
+// DEF: if (tblgen_type_attr && !(((tblgen_type_attr.isa<::mlir::TypeAttr>())) && ((tblgen_type_attr.cast<::mlir::TypeAttr>().getValue().isa<::mlir::Type>())) && ((true))))
 
 // Test common attribute kind getters' return types
 // ---


        


More information about the Mlir-commits mailing list