[Mlir-commits] [mlir] 596da62 - Add support for custom op parser/printer hooks to know about result names.

Chris Lattner llvmlistbot at llvm.org
Mon Mar 23 08:58:13 PDT 2020


Author: Chris Lattner
Date: 2020-03-23T08:58:05-07:00
New Revision: 596da62d21ede197dd95eca5146d2ddf0497275c

URL: https://github.com/llvm/llvm-project/commit/596da62d21ede197dd95eca5146d2ddf0497275c
DIFF: https://github.com/llvm/llvm-project/commit/596da62d21ede197dd95eca5146d2ddf0497275c.diff

LOG: Add support for custom op parser/printer hooks to know about result names.

Summary:
This allows the custom parser/printer hooks to do interesting things with
the SSA names.  This patch:

 - Adds a new 'getResultName' method to OpAsmParser that allows a parser
   implementation to get information about its result names, along with
   a getNumResults() method that allows op parser impls to know how many
   results are expected.
 - Adds an OpAsmPrinter::printOperand overload that takes an explicit stream.
 - Adds a test.string_attr_pretty_name operation that uses these hooks to
   do fancy things with the result name.

Reviewers: rriddle!

Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D76205

Added: 
    

Modified: 
    mlir/include/mlir/IR/OpImplementation.h
    mlir/lib/IR/AsmPrinter.cpp
    mlir/lib/Parser/Parser.cpp
    mlir/test/IR/parser.mlir
    mlir/test/lib/TestDialect/TestDialect.cpp
    mlir/test/lib/TestDialect/TestOps.td

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 74b42243f236..f471e6ca0cc5 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -37,6 +37,7 @@ class OpAsmPrinter {
 
   /// Print implementations for various things an operation contains.
   virtual void printOperand(Value value) = 0;
+  virtual void printOperand(Value value, raw_ostream &os) = 0;
 
   /// Print a comma separated list of operands.
   template <typename ContainerType>
@@ -245,6 +246,24 @@ class OpAsmParser {
     return success();
   }
 
+  /// Return the name of the specified result in the specified syntax, as well
+  /// as the sub-element in the name.  It returns an empty string and ~0U for
+  /// invalid result numbers.  For example, in this operation:
+  ///
+  ///  %x, %y:2, %z = foo.op
+  ///
+  ///    getResultName(0) == {"x", 0 }
+  ///    getResultName(1) == {"y", 0 }
+  ///    getResultName(2) == {"y", 1 }
+  ///    getResultName(3) == {"z", 0 }
+  ///    getResultName(4) == {"", ~0U }
+  virtual std::pair<StringRef, unsigned>
+  getResultName(unsigned resultNo) const = 0;
+
+  /// Return the number of declared SSA results.  This returns 4 for the foo.op
+  /// example in the comment for `getResultName`.
+  virtual size_t getNumResults() const = 0;
+
   /// Return the location of the original name token.
   virtual llvm::SMLoc getNameLoc() const = 0;
 

diff  --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index bbb2462d176d..bf7905a14781 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -765,10 +765,10 @@ void SSANameState::setValueName(Value value, StringRef name) {
 static bool isPunct(char c) {
   return c == '$' || c == '.' || c == '_' || c == '-';
 }
-  
+
 StringRef SSANameState::uniqueValueName(StringRef name) {
   assert(!name.empty() && "Shouldn't have an empty name here");
-  
+
   // Check to see if this name is valid.  If it starts with a digit, then it
   // could conflict with the autogenerated numeric ID's (we unique them in a
   // different map), so add an underscore prefix to avoid problems.
@@ -777,13 +777,13 @@ StringRef SSANameState::uniqueValueName(StringRef name) {
     tmpName += name;
     return uniqueValueName(tmpName);
   }
-  
+
   // Check to see if the name consists of all-valid identifiers.  If not, we
   // need to escape them.
   for (char ch : name) {
     if (isalpha(ch) || isPunct(ch) || isdigit(ch))
       continue;
-    
+
     SmallString<16> tmpName;
     for (char ch : name) {
       if (isalpha(ch) || isPunct(ch) || isdigit(ch))
@@ -796,7 +796,7 @@ StringRef SSANameState::uniqueValueName(StringRef name) {
     }
     return uniqueValueName(tmpName);
   }
-  
+
   // Check to see if this name is already unique.
   if (!usedNames.count(name)) {
     name = name.copy(usedNameAllocator);
@@ -1963,7 +1963,8 @@ class OperationPrinter : public ModulePrinter, private OpAsmPrinter {
              bool printBlockTerminator = true);
 
   /// Print the ID of the given value, optionally with its result number.
-  void printValueID(Value value, bool printResultNo = true) const;
+  void printValueID(Value value, bool printResultNo = true,
+                    raw_ostream *streamOverride = nullptr) const;
 
   //===--------------------------------------------------------------------===//
   // OpAsmPrinter methods
@@ -1988,6 +1989,9 @@ class OperationPrinter : public ModulePrinter, private OpAsmPrinter {
 
   /// Print the ID for the given value.
   void printOperand(Value value) override { printValueID(value); }
+  void printOperand(Value value, raw_ostream &os) override {
+    printValueID(value, /*printResultNo=*/true, &os);
+  }
 
   /// Print an optional attribute dictionary with a given set of elided values.
   void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
@@ -2195,8 +2199,10 @@ void OperationPrinter::print(Block *block, bool printBlockArgs,
   currentIndent -= indentWidth;
 }
 
-void OperationPrinter::printValueID(Value value, bool printResultNo) const {
-  state->getSSANameState().printValueID(value, printResultNo, os);
+void OperationPrinter::printValueID(Value value, bool printResultNo,
+                                    raw_ostream *streamOverride) const {
+  state->getSSANameState().printValueID(value, printResultNo,
+                                        streamOverride ? *streamOverride : os);
 }
 
 void OperationPrinter::printSuccessor(Block *successor) {

diff  --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index d987fd0d8add..a29b34b570d0 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -3322,8 +3322,13 @@ class OperationParser : public Parser {
   Operation *parseGenericOperation(Block *insertBlock,
                                    Block::iterator insertPt);
 
+  /// This is the structure of a result specifier in the assembly syntax,
+  /// including the name, number of results, and location.
+  typedef std::tuple<StringRef, unsigned, SMLoc> ResultRecord;
+
   /// Parse an operation instance that is in the op-defined custom form.
-  Operation *parseCustomOperation();
+  /// resultInfo specifies information about the "%name =" specifiers.
+  Operation *parseCustomOperation(ArrayRef<ResultRecord> resultInfo);
 
   //===--------------------------------------------------------------------===//
   // Region Parsing
@@ -3728,7 +3733,7 @@ Value OperationParser::createForwardRefPlaceholder(SMLoc loc, Type type) {
 ///
 ParseResult OperationParser::parseOperation() {
   auto loc = getToken().getLoc();
-  SmallVector<std::tuple<StringRef, unsigned, SMLoc>, 1> resultIDs;
+  SmallVector<ResultRecord, 1> resultIDs;
   size_t numExpectedResults = 0;
   if (getToken().is(Token::percent_identifier)) {
     // Parse the group of result ids.
@@ -3769,7 +3774,7 @@ ParseResult OperationParser::parseOperation() {
 
   Operation *op;
   if (getToken().is(Token::bare_identifier) || getToken().isKeyword())
-    op = parseCustomOperation();
+    op = parseCustomOperation(resultIDs);
   else if (getToken().is(Token::string))
     op = parseGenericOperation();
   else
@@ -3790,7 +3795,7 @@ ParseResult OperationParser::parseOperation() {
 
     // Add definitions for each of the result groups.
     unsigned opResI = 0;
-    for (std::tuple<StringRef, unsigned, SMLoc> &resIt : resultIDs) {
+    for (ResultRecord &resIt : resultIDs) {
       for (unsigned subRes : llvm::seq<unsigned>(0, std::get<1>(resIt))) {
         if (addDefinition({std::get<0>(resIt), subRes, std::get<2>(resIt)},
                           op->getResult(opResI++)))
@@ -3955,9 +3960,12 @@ Operation *OperationParser::parseGenericOperation(Block *insertBlock,
 namespace {
 class CustomOpAsmParser : public OpAsmParser {
 public:
-  CustomOpAsmParser(SMLoc nameLoc, const AbstractOperation *opDefinition,
+  CustomOpAsmParser(SMLoc nameLoc,
+                    ArrayRef<OperationParser::ResultRecord> resultIDs,
+                    const AbstractOperation *opDefinition,
                     OperationParser &parser)
-      : nameLoc(nameLoc), opDefinition(opDefinition), parser(parser) {}
+      : nameLoc(nameLoc), resultIDs(resultIDs), opDefinition(opDefinition),
+        parser(parser) {}
 
   /// Parse an instance of the operation described by 'opDefinition' into the
   /// provided operation state.
@@ -3992,6 +4000,41 @@ class CustomOpAsmParser : public OpAsmParser {
 
   Builder &getBuilder() const override { return parser.builder; }
 
+  /// Return the name of the specified result in the specified syntax, as well
+  /// as the subelement in the name.  For example, in this operation:
+  ///
+  ///  %x, %y:2, %z = foo.op
+  ///
+  ///    getResultName(0) == {"x", 0 }
+  ///    getResultName(1) == {"y", 0 }
+  ///    getResultName(2) == {"y", 1 }
+  ///    getResultName(3) == {"z", 0 }
+  std::pair<StringRef, unsigned>
+  getResultName(unsigned resultNo) const override {
+    // Scan for the resultID that contains this result number.
+    for (unsigned nameID = 0, e = resultIDs.size(); nameID != e; ++nameID) {
+      const auto &entry = resultIDs[nameID];
+      if (resultNo < std::get<1>(entry)) {
+        // Don't pass on the leading %.
+        StringRef name = std::get<0>(entry).drop_front();
+        return {name, resultNo};
+      }
+      resultNo -= std::get<1>(entry);
+    }
+
+    // Invalid result number.
+    return {"", ~0U};
+  }
+
+  /// Return the number of declared SSA results.  This returns 4 for the foo.op
+  /// example in the comment for getResultName.
+  size_t getNumResults() const override {
+    size_t count = 0;
+    for (auto &entry : resultIDs)
+      count += std::get<1>(entry);
+    return count;
+  }
+
   llvm::SMLoc getNameLoc() const override { return nameLoc; }
 
   //===--------------------------------------------------------------------===//
@@ -4500,6 +4543,9 @@ class CustomOpAsmParser : public OpAsmParser {
   /// The source location of the operation name.
   SMLoc nameLoc;
 
+  /// Information about the result name specifiers.
+  ArrayRef<OperationParser::ResultRecord> resultIDs;
+
   /// The abstract information of the operation.
   const AbstractOperation *opDefinition;
 
@@ -4511,7 +4557,8 @@ class CustomOpAsmParser : public OpAsmParser {
 };
 } // end anonymous namespace.
 
-Operation *OperationParser::parseCustomOperation() {
+Operation *
+OperationParser::parseCustomOperation(ArrayRef<ResultRecord> resultIDs) {
   auto opLoc = getToken().getLoc();
   auto opName = getTokenSpelling();
 
@@ -4544,7 +4591,7 @@ Operation *OperationParser::parseCustomOperation() {
   // Have the op implementation take a crack and parsing this.
   OperationState opState(srcLocation, opDefinition->name);
   CleanupOpStateRegions guard{opState};
-  CustomOpAsmParser opAsmParser(opLoc, opDefinition, *this);
+  CustomOpAsmParser opAsmParser(opLoc, resultIDs, opDefinition, *this);
   if (opAsmParser.parseOperation(opState))
     return nullptr;
 

diff  --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
index a6dc9b6617b3..253c1cc3e745 100644
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -1185,3 +1185,43 @@ func @custom_asm_names() -> (i32, i32, i32, i32, i32, i32, i32) {
   // CHECK: return %[[FIRST]], %[[MIDDLE]]#0, %[[MIDDLE]]#1, %[[LAST]], %[[FIRST_2]], %[[LAST_2]]
   return %0, %1#0, %1#1, %2, %3, %4, %5 : i32, i32, i32, i32, i32, i32, i32
 }
+
+
+// CHECK-LABEL: func @pretty_names
+
+// This tests the behavior
+func @pretty_names() {
+  // Simple case, should parse and print as %x being an implied 'name'
+  // attribute.
+  %x = test.string_attr_pretty_name
+  // CHECK: %x = test.string_attr_pretty_name
+  // CHECK-NOT: attributes
+  
+  // This specifies an explicit name, which should override the result.
+  %YY = test.string_attr_pretty_name attributes { names = ["y"] }
+  // CHECK: %y = test.string_attr_pretty_name
+  // CHECK-NOT: attributes
+  
+  // Conflicts with the 'y' name, so need an explicit attribute.
+  %0 = "test.string_attr_pretty_name"() { names = ["y"]} : () -> i32
+  // CHECK: %y_0 = test.string_attr_pretty_name attributes {names = ["y"]}
+
+  // Name contains a space.
+  %1 = "test.string_attr_pretty_name"() { names = ["space name"]} : () -> i32
+  // CHECK: %space_name = test.string_attr_pretty_name attributes {names = ["space name"]}
+
+  "unknown.use"(%x, %YY, %0, %1) : (i32, i32, i32, i32) -> ()
+
+  // Multi-result support.
+
+  %a, %b, %c = test.string_attr_pretty_name
+  // CHECK: %a, %b, %c = test.string_attr_pretty_name
+  // CHECK-NOT: attributes
+
+  %q:3, %r = test.string_attr_pretty_name
+  // CHECK: %q, %q_1, %q_2, %r = test.string_attr_pretty_name attributes {names = ["q", "q", "q", "r"]}
+
+  // CHECK: return
+  return
+}
+

diff  --git a/mlir/test/lib/TestDialect/TestDialect.cpp b/mlir/test/lib/TestDialect/TestDialect.cpp
index 166edf820206..21941f0da606 100644
--- a/mlir/test/lib/TestDialect/TestDialect.cpp
+++ b/mlir/test/lib/TestDialect/TestDialect.cpp
@@ -391,6 +391,87 @@ void SideEffectOp::getEffects(
   }
 }
 
+//===----------------------------------------------------------------------===//
+// StringAttrPrettyNameOp
+//===----------------------------------------------------------------------===//
+
+// This op has fancy handling of its SSA result name.
+static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser,
+                                               OperationState &result) {
+  // Add the result types.
+  for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
+    result.addTypes(parser.getBuilder().getIntegerType(32));
+
+  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
+    return failure();
+
+  // If the attribute dictionary contains no 'names' attribute, infer it from
+  // the SSA name (if specified).
+  bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
+    return attr.first.is("names");
+  });
+
+  // If there was no name specified, check to see if there was a useful name
+  // specified in the asm file.
+  if (hadNames || parser.getNumResults() == 0)
+    return success();
+
+  SmallVector<StringRef, 4> names;
+  auto *context = result.getContext();
+
+  for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
+    auto resultName = parser.getResultName(i);
+    StringRef nameStr;
+    if (!resultName.first.empty() && !isdigit(resultName.first[0]))
+      nameStr = resultName.first;
+
+    names.push_back(nameStr);
+  }
+
+  auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
+  result.attributes.push_back({Identifier::get("names", context), namesAttr});
+  return success();
+}
+
+static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) {
+  p << "test.string_attr_pretty_name";
+
+  // Note that we only need to print the "name" attribute if the asmprinter
+  // result name disagrees with it.  This can happen in strange cases, e.g.
+  // when there are conflicts.
+  bool namesDisagree = op.names().size() != op.getNumResults();
+
+  SmallString<32> resultNameStr;
+  for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) {
+    resultNameStr.clear();
+    llvm::raw_svector_ostream tmpStream(resultNameStr);
+    p.printOperand(op.getResult(i), tmpStream);
+
+    auto expectedName = op.names()[i].dyn_cast<StringAttr>();
+    if (!expectedName ||
+        tmpStream.str().drop_front() != expectedName.getValue()) {
+      namesDisagree = true;
+    }
+  }
+
+  if (namesDisagree)
+    p.printOptionalAttrDictWithKeyword(op.getAttrs());
+  else
+    p.printOptionalAttrDictWithKeyword(op.getAttrs(), {"names"});
+}
+
+// We set the SSA name in the asm syntax to the contents of the name
+// attribute.
+void StringAttrPrettyNameOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+
+  auto value = names();
+  for (size_t i = 0, e = value.size(); i != e; ++i)
+    if (auto str = value[i].dyn_cast<StringAttr>())
+      if (!str.getValue().empty())
+        setNameFn(getResult(i), str.getValue());
+}
+
 //===----------------------------------------------------------------------===//
 // Dialect Registration
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td
index cf0ec63fe6c5..80783001a4c9 100644
--- a/mlir/test/lib/TestDialect/TestOps.td
+++ b/mlir/test/lib/TestDialect/TestOps.td
@@ -496,6 +496,18 @@ def AttrSizedResultOp : TEST_Op<"attr_sized_results",
   );
 }
 
+// This is used to test encoding of a string attribute into an SSA name of a
+// pretty printed value name.
+def StringAttrPrettyNameOp
+ : TEST_Op<"string_attr_pretty_name",
+           [DeclareOpInterfaceMethods<OpAsmOpInterface>]> {
+  let arguments = (ins StrArrayAttr:$names);
+  let results = (outs Variadic<I32>:$r);
+
+  let printer = [{ return ::print(p, *this); }];
+  let parser = [{ return ::parse$cppClass(parser, result); }];
+}
+
 //===----------------------------------------------------------------------===//
 // Test Patterns
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list