[Mlir-commits] [mlir] aa9ae76 - [mlir][ODS] Verify type constraint in `TypeAttrOf`
Markus Böck
llvmlistbot at llvm.org
Fri Apr 7 03:30:57 PDT 2023
Author: Markus Böck
Date: 2023-04-07T12:30:15+02:00
New Revision: aa9ae76cac0443b7d70b27ae2c0bf9cf92f344d3
URL: https://github.com/llvm/llvm-project/commit/aa9ae76cac0443b7d70b27ae2c0bf9cf92f344d3
DIFF: https://github.com/llvm/llvm-project/commit/aa9ae76cac0443b7d70b27ae2c0bf9cf92f344d3.diff
LOG: [mlir][ODS] Verify type constraint in `TypeAttrOf`
The current implementation does not verify the type constraint. As a result, any type whose C++ class matches the expected one passes the verifier.
E.g. a `TypeAttrOf<I64>` would happily accept an `i32`, since both satisfy `isa<IntegerType>()`.
This patch fixes that by adding an optional type-predicate parameter to `TypeAttrBase`. The type held by the `TypeAttr` must satisfy this predicate. `TypeAttrOf` then simply passes the predicate of its type parameter as that argument.
Differential Revision: https://reviews.llvm.org/D147778
Added:
Modified:
mlir/include/mlir/IR/OpBase.td
mlir/test/IR/attribute.mlir
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/mlir-tblgen/op-attribute.td
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 98866c83b4b4e..f4aa07ff2a443 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1296,11 +1296,14 @@ class TypedStrAttr<Type ty>
// Base class for attributes containing types. Example:
// def IntTypeAttr : TypeAttrBase<"IntegerType", "integer type attribute">
// defines a type attribute containing an integer type.
-class TypeAttrBase<string retType, string summary> :
+class TypeAttrBase<string retType, string summary,
+ Pred typePred = CPred<"true">> :
Attr<And<[
CPred<"$_self.isa<::mlir::TypeAttr>()">,
CPred<"$_self.cast<::mlir::TypeAttr>().getValue().isa<"
- # retType # ">()">]>,
+ # retType # ">()">,
+ SubstLeaves<"$_self",
+ "$_self.cast<::mlir::TypeAttr>().getValue()", typePred>]>,
summary> {
let storageType = [{ ::mlir::TypeAttr }];
let returnType = retType;
@@ -1313,7 +1316,8 @@ def TypeAttr : TypeAttrBase<"::mlir::Type", "any type attribute"> {
}
class TypeAttrOf<Type ty>
- : TypeAttrBase<ty.cppClassName, "type attribute of " # ty.summary> {
+ : TypeAttrBase<ty.cppClassName, "type attribute of " # ty.summary,
+ ty.predicate> {
let constBuilderCall = "::mlir::TypeAttr::get($0)";
}
diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
index c296507868cbc..25d237a74f3ad 100644
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -879,3 +879,11 @@ func.func @default_value_printing(%arg0 : i32) {
"test.default_value_print"(%arg0) {"value_with_default" = 1 : i32} : (i32) -> ()
return
}
+
+// -----
+
+func.func @type_attr_of_fail() {
+ // expected-error @below {{failed to satisfy constraint: type attribute of 64-bit signless integer}}
+ test.type_attr_of i32
+ return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index ff24ac94bdfb3..0306f0ed02f99 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -277,6 +277,13 @@ def TypedAttrOp : TEST_Op<"typed_attr"> {
}];
}
+def TypeAttrOfOp : TEST_Op<"type_attr_of"> {
+ let arguments = (ins TypeAttrOf<I64>:$type);
+ let assemblyFormat = [{
+ attr-dict $type
+ }];
+}
+
def DenseArrayAttrOp : TEST_Op<"dense_array_attr"> {
let arguments = (ins
DenseBoolArrayAttr:$i1attr,
diff --git a/mlir/test/mlir-tblgen/op-attribute.td b/mlir/test/mlir-tblgen/op-attribute.td
index 3dc426cbf5d7f..af1f62221fc07 100644
--- a/mlir/test/mlir-tblgen/op-attribute.td
+++ b/mlir/test/mlir-tblgen/op-attribute.td
@@ -318,10 +318,10 @@ def BOp : NS_Op<"b_op", []> {
// DEF: if (tblgen_str_attr && !((tblgen_str_attr.isa<::mlir::StringAttr>())))
// DEF: if (tblgen_elements_attr && !((tblgen_elements_attr.isa<::mlir::ElementsAttr>())))
// DEF: if (tblgen_function_attr && !((tblgen_function_attr.isa<::mlir::FlatSymbolRefAttr>())))
-// DEF: if (tblgen_some_type_attr && !(((tblgen_some_type_attr.isa<::mlir::TypeAttr>())) && ((tblgen_some_type_attr.cast<::mlir::TypeAttr>().getValue().isa<SomeType>()))))
+// DEF: if (tblgen_some_type_attr && !(((tblgen_some_type_attr.isa<::mlir::TypeAttr>())) && ((tblgen_some_type_attr.cast<::mlir::TypeAttr>().getValue().isa<SomeType>())) && ((true))))
// DEF: if (tblgen_array_attr && !((tblgen_array_attr.isa<::mlir::ArrayAttr>())))
// DEF: if (tblgen_some_attr_array && !(((tblgen_some_attr_array.isa<::mlir::ArrayAttr>())) && (::llvm::all_of(tblgen_some_attr_array.cast<::mlir::ArrayAttr>(), [&](::mlir::Attribute attr) { return attr && ((some-condition)); }))))
-// DEF: if (tblgen_type_attr && !(((tblgen_type_attr.isa<::mlir::TypeAttr>())) && ((tblgen_type_attr.cast<::mlir::TypeAttr>().getValue().isa<::mlir::Type>()))))
+// DEF: if (tblgen_type_attr && !(((tblgen_type_attr.isa<::mlir::TypeAttr>())) && ((tblgen_type_attr.cast<::mlir::TypeAttr>().getValue().isa<::mlir::Type>())) && ((true))))
// Test common attribute kind getters' return types
// ---
More information about the Mlir-commits
mailing list