[Mlir-commits] [mlir] 122e685 - [mlir] do not elide dialect prefix for ops with dots in the name

Alex Zinenko llvmlistbot at llvm.org
Fri May 20 03:55:39 PDT 2022


Author: Alex Zinenko
Date: 2022-05-20T12:55:32+02:00
New Revision: 122e685878991b63e414733094e9876e715d08e0

URL: https://github.com/llvm/llvm-project/commit/122e685878991b63e414733094e9876e715d08e0
DIFF: https://github.com/llvm/llvm-project/commit/122e685878991b63e414733094e9876e715d08e0.diff

LOG: [mlir] do not elide dialect prefix for ops with dots in the name

For the hypothetical "a.b.c" op printed within a region that declares "a" as
the default dialect, MLIR would currently elide the "a." prefix and only print
"b.c". However, this is ambiguous when parsing, as "b.c" may also exist as
the "c" op in the "b" dialect. If it does not, parsing currently fails. Do
not elide the default dialect if the op name contains further dots to avoid the
ambiguity.

See https://discourse.llvm.org/t/dropping-dialect-prefix-for-ops-with-multiple-dots-in-the-name/62562

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D125975

Added: 
    

Modified: 
    mlir/lib/IR/AsmPrinter.cpp
    mlir/lib/IR/Operation.cpp
    mlir/test/IR/parser.mlir
    mlir/test/lib/Dialect/Test/TestDialect.cpp
    mlir/test/lib/Dialect/Test/TestOps.td

Removed: 
    


################################################################################
diff  --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 0836055494310..62aab6dd9306d 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2718,7 +2718,10 @@ void OperationPrinter::printOperation(Operation *op) {
       if (auto opPrinter = dialect->getOperationPrinter(op)) {
         // Print the op name first.
         StringRef name = op->getName().getStringRef();
-        name.consume_front((defaultDialectStack.back() + ".").str());
+        // Only drop the default dialect prefix when it cannot lead to
+        // ambiguities.
+        if (name.count('.') == 1)
+          name.consume_front((defaultDialectStack.back() + ".").str());
         printEscapedString(name, os);
         // Print the rest of the op now.
         opPrinter(op, *this);

diff  --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index d7d59055479ce..d529385c25e16 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -624,11 +624,12 @@ void OpState::print(Operation *op, OpAsmPrinter &p, StringRef defaultDialect) {
   }
 }
 
-/// Print an operation name, eliding the dialect prefix if necessary.
+/// Print an operation name, eliding the default dialect prefix when doing so
+/// does not lead to ambiguities.
 void OpState::printOpName(Operation *op, OpAsmPrinter &p,
                           StringRef defaultDialect) {
   StringRef name = op->getName().getStringRef();
-  if (name.startswith((defaultDialect + ".").str()))
+  if (name.startswith((defaultDialect + ".").str()) && name.count('.') == 1)
     name = name.drop_front(defaultDialect.size() + 1);
   p.getStream() << name;
 }

diff  --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
index 4187ae1ffc5d7..a28e63d5e498d 100644
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -1282,6 +1282,14 @@ func.func @default_dialect(%bool : i1) {
     // example.
     // CHECK:  "test.op_with_attr"() {test.attr = "test.value"} : () -> ()
     "test.op_with_attr"() {test.attr = "test.value"} : () -> ()
+    // Verify that the prefix is not stripped when it can lead to ambiguity.
+    // CHECK: test.op.with_dot_in_name
+    test.op.with_dot_in_name
+    // This is an unregistered operation whose printing/parsing is handled by
+    // the dialect; the dialect prefix should not be stripped while printing
+    // because of potential ambiguity.
+    // CHECK: test.dialect_custom_printer.with.dot
+    test.dialect_custom_printer.with.dot
     "test.terminator"() : ()->()
   }
   // The same operation outside of the region does not have an func. prefix.

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index c0948020cf1c0..24e5bc6123603 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -371,6 +371,11 @@ TestDialect::getParseOperationHook(StringRef opName) const {
       return parser.parseKeyword("custom_format_fallback");
     }};
   }
+  if (opName == "test.dialect_custom_printer.with.dot") {
+    return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
+      return ParseResult::success();
+    }};
+  }
   return None;
 }
 

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 7826f842bcba8..0fd0566803df6 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -810,6 +810,12 @@ def DefaultDialectOp : TEST_Op<"default_dialect", [OpAsmOpInterface]> {
   let assemblyFormat = "regions attr-dict-with-keyword";
 }
 
+// This is used to test that the default dialect is not elided when printing an
+// op with dots in the name to avoid parsing ambiguity.
+def OpWithDotInNameOp : TEST_Op<"op.with_dot_in_name"> {
+  let assemblyFormat = "attr-dict";
+}
+
 // This is used to test the OpAsmOpInterface::getAsmBlockName() feature:
 // blocks nested in a region under this op will have a name defined by the
 // interface.


        


More information about the Mlir-commits mailing list