[Mlir-commits] [mlir] [mlir] Make `StringRefParameter` roundtrippable (PR #65813)

Markus Böck llvmlistbot at llvm.org
Fri Sep 8 14:40:51 PDT 2023


https://github.com/zero9178 created https://github.com/llvm/llvm-project/pull/65813:

The current printer of `StringRefParameter` simply prints the contents of the string as-is, without escaping it in any way. Whenever the string contains quotes or other special characters, this generates invalid syntax, causing parse errors when the output is read back in.

This PR fixes that by adding `printString` to `AsmPrinter`, which prints a string that can be parsed back with `parseString`, using the same escaping syntax as `StringAttr`.

>From 1fc7cddad413f7a3e714fb2dea67b5d01e9e615f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Markus=20B=C3=B6ck?= <markus.boeck02 at gmail.com>
Date: Fri, 8 Sep 2023 23:38:38 +0200
Subject: [PATCH] [mlir] Make `StringRefParameter` roundtrippable
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The current printer of `StringRefParameter` simply prints the contents of the string as-is, without escaping it in any way. Whenever the string contains quotes or other special characters, this generates invalid syntax, causing parse errors when the output is read back in.

This PR fixes that by adding `printString` to `AsmPrinter`, which prints a string that can be parsed back with `parseString`, using the same escaping syntax as `StringAttr`.
---
 mlir/include/mlir/IR/AttrTypeBase.td                     | 2 +-
 mlir/include/mlir/IR/OpImplementation.h                  | 4 ++++
 mlir/lib/IR/AsmPrinter.cpp                               | 9 +++++++++
 mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir | 4 +++-
 4 files changed, 17 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/IR/AttrTypeBase.td b/mlir/include/mlir/IR/AttrTypeBase.td
index 3e356373cbd7353..42a611ee8e42205 100644
--- a/mlir/include/mlir/IR/AttrTypeBase.td
+++ b/mlir/include/mlir/IR/AttrTypeBase.td
@@ -363,7 +363,7 @@ class DefaultValuedParameter<string type, string value, string desc = ""> :
 class StringRefParameter<string desc = "", string value = ""> :
     AttrOrTypeParameter<"::llvm::StringRef", desc> {
   let allocator = [{$_dst = $_allocator.copyInto($_self);}];
-  let printer = [{$_printer << '"' << $_self << '"';}];
+  let printer = [{$_printer.printString($_self);}]; // escaped for parseString
   let cppStorageType = "std::string";
   let defaultValue = value;
 }
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index f894ee64a27b0cf..8864ef02cd3cbba 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -184,6 +184,10 @@ class AsmPrinter {
   /// has any special or non-printable characters in it.
   virtual void printKeywordOrString(StringRef keyword);
 
+  /// Print the given string surrounded by double quotes, escaping special and
+  /// non-printable characters so that `parseString` can read it back.
+  virtual void printString(StringRef string);
+
   /// Print the given string as a symbol reference, i.e. a form representable by
   /// a SymbolRefAttr. A symbol reference is represented as a string prefixed
   /// with '@'. The reference is surrounded with ""'s and escaped if it has any
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index c662edd592036ce..7b0da30541b16a4 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -779,6 +779,7 @@ class DummyAliasOperationPrinter : private OpAsmPrinter {
     os << "%";
   }
   void printKeywordOrString(StringRef) override {}
+  void printString(StringRef) override {} // Nothing to collect.
   void printResourceHandle(const AsmDialectResourceHandle &) override {}
   void printSymbolName(StringRef) override {}
   void printSuccessor(Block *) override {}
@@ -919,6 +920,7 @@ class DummyAliasDialectAsmPrinter : public DialectAsmPrinter {
   /// determining potential aliases.
   void printFloat(const APFloat &) override {}
   void printKeywordOrString(StringRef) override {}
+  void printString(StringRef) override {} // Nothing to collect.
   void printSymbolName(StringRef) override {}
   void printResourceHandle(const AsmDialectResourceHandle &) override {}
 
@@ -2767,6 +2769,13 @@ void AsmPrinter::printKeywordOrString(StringRef keyword) {
   ::printKeywordOrString(keyword, impl->getStream());
 }
 
+void AsmPrinter::printString(StringRef string) {
+  assert(impl && "expected AsmPrinter::printString to be overriden");
+  *this << '"';
+  printEscapedString(string, getStream());
+  *this << '"';
+}
+
 void AsmPrinter::printSymbolName(StringRef symbolRef) {
   assert(impl && "expected AsmPrinter::printSymbolName to be overriden");
   ::printSymbolReference(symbolRef, impl->getStream());
diff --git a/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir b/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir
index 12289b4d732593b..160c388cedf7565 100644
--- a/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir
+++ b/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir
@@ -70,6 +70,7 @@ attributes {
 // CHECK: !test.optional_type_string
 // CHECK: !test.optional_type_string
 // CHECK: !test.optional_type_string<"non default">
+// CHECK: !test.optional_type_string<"containing\0A \22escape\22 characters\0F">
 
 func.func private @test_roundtrip_default_parsers_struct(
   !test.no_parser<255, [1, 2, 3, 4, 5], "foobar", 4>
@@ -111,5 +112,6 @@ func.func private @test_roundtrip_default_parsers_struct(
   !test.custom_type_string<"bar" bar>,
   !test.optional_type_string,
   !test.optional_type_string<"default">,
-  !test.optional_type_string<"non default">
+  !test.optional_type_string<"non default">,
+  !test.optional_type_string<"containing\n \"escape\" characters\0f">
 )



More information about the Mlir-commits mailing list