[Mlir-commits] [mlir] 92bdc5c - [mlir][ods] Add convertFromStorage field to parameters

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jun 27 15:57:25 PDT 2022


Author: Mogball
Date: 2022-06-27T15:57:21-07:00
New Revision: 92bdc5c3e55fb1d48fe3a2bfad99911e8534db92

URL: https://github.com/llvm/llvm-project/commit/92bdc5c3e55fb1d48fe3a2bfad99911e8534db92
DIFF: https://github.com/llvm/llvm-project/commit/92bdc5c3e55fb1d48fe3a2bfad99911e8534db92.diff

LOG: [mlir][ods] Add convertFromStorage field to parameters

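This patch adds a `convertFromStorage` field to attribute and type parameters. The field holds the C++ code used to convert a value of the parameter's C++ storage type (e.g. `Optional<SmallVector<T>>`) to its C++ type (e.g. `Optional<ArrayRef<T>>`), so that conversions more complex than an implicit one can be expressed. It defaults to `"$_self"`, i.e. an implicit conversion from the storage type.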

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D128293

Added: 
    

Modified: 
    mlir/docs/AttributesAndTypes.md
    mlir/include/mlir/IR/AttrTypeBase.td
    mlir/include/mlir/TableGen/AttrOrTypeDef.h
    mlir/lib/TableGen/AttrOrTypeDef.cpp
    mlir/test/mlir-tblgen/attr-or-type-format.td
    mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/AttributesAndTypes.md b/mlir/docs/AttributesAndTypes.md
index 419f5865c2e4b..bfe9e6c476604 100644
--- a/mlir/docs/AttributesAndTypes.md
+++ b/mlir/docs/AttributesAndTypes.md
@@ -71,7 +71,7 @@ Operations, etc.
 
 ```tablegen
 // Include the definition of the necessary tablegen constructs for defining
-// our types. 
+// our types.
 include "mlir/IR/AttrTypeBase.td"
 
 // It's common to define a base classes for types in the same dialect. This
@@ -108,7 +108,7 @@ Below is an example of an Attribute:
 
 ```tablegen
 // Include the definition of the necessary tablegen constructs for defining
-// our attributes. 
+// our attributes.
 include "mlir/IR/AttrTypeBase.td"
 
 // It's common to define a base classes for attributes in the same dialect. This
@@ -128,11 +128,11 @@ def My_IntegerAttr : MyDialect_Attr<"Integer", "int"> {
   }];
   /// Here we've defined two parameters, one is the `self` type of the attribute
   /// (i.e. the type of the Attribute itself), and the other is the integer value
-  /// of the attribute. 
+  /// of the attribute.
   let parameters = (ins AttributeSelfTypeParameter<"">:$type, "APInt":$value);
-  
+
   /// Here we've defined a custom builder for the type, that removes the need to pass
-  /// in an MLIRContext instance; as it can be infered from the `type`. 
+  /// in an MLIRContext instance; as it can be infered from the `type`.
   let builders = [
     AttrBuilderWithInferredContext<(ins "Type":$type,
                                         "const APInt &":$value), [{
@@ -147,7 +147,7 @@ def My_IntegerAttr : MyDialect_Attr<"Integer", "int"> {
   ///    #my.int<50> : !my.int<32> // a 32-bit integer of value 50.
   ///
   let assemblyFormat = "`<` $value `>`";
-  
+
   /// Indicate that our attribute will add additional verification to the parameters.
   let genVerifyDecl = 1;
 
@@ -612,6 +612,10 @@ example, `StringRefParameter` uses `std::string` as its storage type, whereas
 `ArrayRefParameter` uses `SmallVector` as its storage type. The parsers for
 these parameters are expected to return `FailureOr<$cppStorageType>`.
 
+To add a custom conversion between the `cppStorageType` and the C++ type of the
+parameter, parameters can override `convertFromStorage`, which by default is
+`"$_self"` (i.e., it attempts an implicit conversion from `cppStorageType`).
+
 ###### Optional Parameters
 
 Optional parameters in the assembly format can be indicated by setting
@@ -1060,7 +1064,7 @@ void MyDialect::initialize() {
 #define GET_ATTRDEF_LIST
 #include "MyDialect/Attributes.cpp.inc"
   >();
-  
+
     /// Add the defined types to the dialect.
   addTypes<
 #define GET_TYPEDEF_LIST

diff  --git a/mlir/include/mlir/IR/AttrTypeBase.td b/mlir/include/mlir/IR/AttrTypeBase.td
index ead971f01bbc1..88d35e12ef547 100644
--- a/mlir/include/mlir/IR/AttrTypeBase.td
+++ b/mlir/include/mlir/IR/AttrTypeBase.td
@@ -295,6 +295,8 @@ class AttrOrTypeParameter<string type, string desc, string accessorType = ""> {
   // The C++ storage type of of this parameter if it is a reference, e.g.
   // `std::string` for `StringRef` or `SmallVector` for `ArrayRef`.
   string cppStorageType = ?;
+  // The C++ code to convert from the storage type to the parameter type.
+  string convertFromStorage = "$_self";
   // One-line human-readable description of the argument.
   string summary = desc;
   // The format string for the asm syntax (documentation only).

diff  --git a/mlir/include/mlir/TableGen/AttrOrTypeDef.h b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
index 0a23e0fed56cc..043d63794bdf0 100644
--- a/mlir/include/mlir/TableGen/AttrOrTypeDef.h
+++ b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
@@ -76,6 +76,9 @@ class AttrOrTypeParameter {
   /// Get the C++ storage type of this parameter.
   StringRef getCppStorageType() const;
 
+  /// Get the C++ code to convert from the storage type to the parameter type.
+  StringRef getConvertFromStorage() const;
+
   /// Get an optional C++ parameter parser.
   Optional<StringRef> getParser() const;
 

diff  --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp
index 4479a64940bc0..a202cb5d47d34 100644
--- a/mlir/lib/TableGen/AttrOrTypeDef.cpp
+++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp
@@ -229,14 +229,9 @@ StringRef AttrOrTypeParameter::getComparator() const {
 }
 
 StringRef AttrOrTypeParameter::getCppType() const {
-  llvm::Init *parameterType = getDef();
-  if (auto *stringType = dyn_cast<llvm::StringInit>(parameterType))
+  if (auto *stringType = dyn_cast<llvm::StringInit>(getDef()))
     return stringType->getValue();
-  if (auto *param = dyn_cast<llvm::DefInit>(parameterType))
-    return param->getDef()->getValueAsString("cppType");
-  llvm::PrintFatalError(
-      "Parameters DAG arguments must be either strings or defs "
-      "which inherit from AttrOrTypeParameter\n");
+  return getDefValue<llvm::StringInit>("cppType").getValue();
 }
 
 StringRef AttrOrTypeParameter::getCppAccessorType() const {
@@ -248,6 +243,11 @@ StringRef AttrOrTypeParameter::getCppStorageType() const {
   return getDefValue<llvm::StringInit>("cppStorageType").value_or(getCppType());
 }
 
+StringRef AttrOrTypeParameter::getConvertFromStorage() const {
+  return getDefValue<llvm::StringInit>("convertFromStorage")
+      .getValueOr("$_self");
+}
+
 Optional<StringRef> AttrOrTypeParameter::getParser() const {
   return getDefValue<llvm::StringInit>("parser");
 }

diff  --git a/mlir/test/mlir-tblgen/attr-or-type-format.td b/mlir/test/mlir-tblgen/attr-or-type-format.td
index 7e09f223435e3..751e9c1b33e3b 100644
--- a/mlir/test/mlir-tblgen/attr-or-type-format.td
+++ b/mlir/test/mlir-tblgen/attr-or-type-format.td
@@ -55,8 +55,8 @@ def TypeParamB : TypeParameter<"TestParamD", "a type param D"> {
 // ATTR:   if (odsParser.parseRParen())
 // ATTR:     return {};
 // ATTR:   return TestAAttr::get(odsParser.getContext(),
-// ATTR:                         *_result_value,
-// ATTR:                         *_result_complex);
+// ATTR:                         (*_result_value),
+// ATTR:                         (*_result_complex));
 // ATTR: }
 
 // ATTR: void TestAAttr::print(::mlir::AsmPrinter &odsPrinter) const {
@@ -114,8 +114,8 @@ def AttrA : TestAttr<"TestA"> {
 // ATTR:       return {};
 // ATTR:   }
 // ATTR:   return TestBAttr::get(odsParser.getContext(),
-// ATTR:                         *_result_v0,
-// ATTR:                         *_result_v1);
+// ATTR:                         (*_result_v0),
+// ATTR:                         (*_result_v1));
 // ATTR: }
 
 // ATTR: void TestBAttr::print(::mlir::AsmPrinter &odsPrinter) const {
@@ -151,8 +151,8 @@ def AttrB : TestAttr<"TestB"> {
 // ATTR:   if (::mlir::failed(_result_v1))
 // ATTR:     return {};
 // ATTR:   return TestFAttr::get(odsParser.getContext(),
-// ATTR:     *_result_v0,
-// ATTR:     *_result_v1);
+// ATTR:     (*_result_v0),
+// ATTR:     (*_result_v1));
 // ATTR: }
 
 def AttrC : TestAttr<"TestF"> {
@@ -278,10 +278,10 @@ def TypeA : TestType<"TestC"> {
 // TYPE:   if (::mlir::failed(_result_v3))
 // TYPE:     return {};
 // TYPE:   return TestDType::get(odsParser.getContext(),
-// TYPE:                         *_result_v0,
-// TYPE:                         *_result_v1,
-// TYPE:                         *_result_v2,
-// TYPE:                         *_result_v3);
+// TYPE:                         (*_result_v0),
+// TYPE:                         (*_result_v1),
+// TYPE:                         (*_result_v2),
+// TYPE:                         (*_result_v3));
 // TYPE: }
 
 // TYPE: void TestDType::print(::mlir::AsmPrinter &odsPrinter) const {
@@ -369,10 +369,10 @@ def TypeB : TestType<"TestD"> {
 // TYPE:       return {};
 // TYPE:   }
 // TYPE:   return TestEType::get(odsParser.getContext(),
-// TYPE:     *_result_v0,
-// TYPE:     *_result_v1,
-// TYPE:     *_result_v2,
-// TYPE:     *_result_v3);
+// TYPE:     (*_result_v0),
+// TYPE:     (*_result_v1),
+// TYPE:     (*_result_v2),
+// TYPE:     (*_result_v3));
 // TYPE: }
 
 // TYPE: void TestEType::print(::mlir::AsmPrinter &odsPrinter) const {
@@ -535,3 +535,19 @@ def TypeJ : TestType<"TestL"> {
   let mnemonic = "type_j";
   let assemblyFormat = "custom<A>($a) custom<B>($b, ref($a))";
 }
+
+// TYPE: ::mlir::Type TestMType::parse
+// TYPE: FailureOr<float> _result_a
+// TYPE: return TestMType::get
+// TYPE: static_cast<int>((*_result_a))
+
+def ConvertFromStorageParameter : TypeParameter<"int", ""> {
+  let cppStorageType = "float";
+  let convertFromStorage = "static_cast<int>($_self)";
+}
+
+def TypeK : TestType<"TestM"> {
+  let parameters = (ins ConvertFromStorageParameter:$a);
+  let mnemonic = "type_k";
+  let assemblyFormat = "$a";
+}

diff  --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
index ba986d081e2bd..a8a63f187d0b4 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
@@ -297,18 +297,21 @@ void DefFormat::genParser(MethodBody &os) {
   }
   for (const AttrOrTypeParameter &param : params) {
     os << ",\n    ";
+    std::string paramSelfStr;
+    llvm::raw_string_ostream selfOs(paramSelfStr);
     if (param.isOptional()) {
-      os << formatv("_result_{0}.value_or(", param.getName());
+      selfOs << formatv("(_result_{0}.value_or(", param.getName());
       if (Optional<StringRef> defaultValue = param.getDefaultValue())
-        os << tgfmt(*defaultValue, &ctx);
+        selfOs << tgfmt(*defaultValue, &ctx);
       else
-        os << param.getCppStorageType() << "()";
-      os << ")";
+        selfOs << param.getCppStorageType() << "()";
+      selfOs << "))";
     } else if (isa<AttributeSelfTypeParameter>(param)) {
-      os << tgfmt("$_type", &ctx);
+      selfOs << tgfmt("$_type", &ctx);
     } else {
-      os << formatv("*_result_{0}", param.getName());
+      selfOs << formatv("(*_result_{0})", param.getName());
     }
+    os << tgfmt(param.getConvertFromStorage(), &ctx.withSelf(selfOs.str()));
   }
   os << ");";
 }


        


More information about the Mlir-commits mailing list