[Mlir-commits] [mlir] 596da62 - Add support for custom op parser/printer hooks to know about result names.
Chris Lattner
llvmlistbot at llvm.org
Mon Mar 23 08:58:13 PDT 2020
Author: Chris Lattner
Date: 2020-03-23T08:58:05-07:00
New Revision: 596da62d21ede197dd95eca5146d2ddf0497275c
URL: https://github.com/llvm/llvm-project/commit/596da62d21ede197dd95eca5146d2ddf0497275c
DIFF: https://github.com/llvm/llvm-project/commit/596da62d21ede197dd95eca5146d2ddf0497275c.diff
LOG: Add support for custom op parser/printer hooks to know about result names.
Summary:
This allows the custom parser/printer hooks to do interesting things with
the SSA names. This patch:
- Adds a new 'getResultName' method to OpAsmParser that allows a parser
implementation to get information about its result names, along with
a getNumResults() method that allows op parser impls to know how many
results are expected.
- Adds an OpAsmPrinter::printOperand overload that takes an explicit stream.
- Adds a test.string_attr_pretty_name operation that uses these hooks to
do fancy things with the result name.
Reviewers: rriddle!
Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D76205
Added:
Modified:
mlir/include/mlir/IR/OpImplementation.h
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/Parser/Parser.cpp
mlir/test/IR/parser.mlir
mlir/test/lib/TestDialect/TestDialect.cpp
mlir/test/lib/TestDialect/TestOps.td
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 74b42243f236..f471e6ca0cc5 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -37,6 +37,7 @@ class OpAsmPrinter {
/// Print implementations for various things an operation contains.
virtual void printOperand(Value value) = 0;
+ virtual void printOperand(Value value, raw_ostream &os) = 0;
/// Print a comma separated list of operands.
template <typename ContainerType>
@@ -245,6 +246,24 @@ class OpAsmParser {
return success();
}
+ /// Return the name of the specified result in the specified syntax, as well
+ /// as the sub-element in the name. It returns an empty string and ~0U for
+ /// invalid result numbers. For example, in this operation:
+ ///
+ /// %x, %y:2, %z = foo.op
+ ///
+ /// getResultName(0) == {"x", 0 }
+ /// getResultName(1) == {"y", 0 }
+ /// getResultName(2) == {"y", 1 }
+ /// getResultName(3) == {"z", 0 }
+ /// getResultName(4) == {"", ~0U }
+ virtual std::pair<StringRef, unsigned>
+ getResultName(unsigned resultNo) const = 0;
+
+ /// Return the number of declared SSA results. This returns 4 for the foo.op
+ /// example in the comment for `getResultName`.
+ virtual size_t getNumResults() const = 0;
+
/// Return the location of the original name token.
virtual llvm::SMLoc getNameLoc() const = 0;
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index bbb2462d176d..bf7905a14781 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -765,10 +765,10 @@ void SSANameState::setValueName(Value value, StringRef name) {
static bool isPunct(char c) {
return c == '$' || c == '.' || c == '_' || c == '-';
}
-
+
StringRef SSANameState::uniqueValueName(StringRef name) {
assert(!name.empty() && "Shouldn't have an empty name here");
-
+
// Check to see if this name is valid. If it starts with a digit, then it
// could conflict with the autogenerated numeric ID's (we unique them in a
// different map), so add an underscore prefix to avoid problems.
@@ -777,13 +777,13 @@ StringRef SSANameState::uniqueValueName(StringRef name) {
tmpName += name;
return uniqueValueName(tmpName);
}
-
+
// Check to see if the name consists of all-valid identifiers. If not, we
// need to escape them.
for (char ch : name) {
if (isalpha(ch) || isPunct(ch) || isdigit(ch))
continue;
-
+
SmallString<16> tmpName;
for (char ch : name) {
if (isalpha(ch) || isPunct(ch) || isdigit(ch))
@@ -796,7 +796,7 @@ StringRef SSANameState::uniqueValueName(StringRef name) {
}
return uniqueValueName(tmpName);
}
-
+
// Check to see if this name is already unique.
if (!usedNames.count(name)) {
name = name.copy(usedNameAllocator);
@@ -1963,7 +1963,8 @@ class OperationPrinter : public ModulePrinter, private OpAsmPrinter {
bool printBlockTerminator = true);
/// Print the ID of the given value, optionally with its result number.
- void printValueID(Value value, bool printResultNo = true) const;
+ void printValueID(Value value, bool printResultNo = true,
+ raw_ostream *streamOverride = nullptr) const;
//===--------------------------------------------------------------------===//
// OpAsmPrinter methods
@@ -1988,6 +1989,9 @@ class OperationPrinter : public ModulePrinter, private OpAsmPrinter {
/// Print the ID for the given value.
void printOperand(Value value) override { printValueID(value); }
+ void printOperand(Value value, raw_ostream &os) override {
+ printValueID(value, /*printResultNo=*/true, &os);
+ }
/// Print an optional attribute dictionary with a given set of elided values.
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
@@ -2195,8 +2199,10 @@ void OperationPrinter::print(Block *block, bool printBlockArgs,
currentIndent -= indentWidth;
}
-void OperationPrinter::printValueID(Value value, bool printResultNo) const {
- state->getSSANameState().printValueID(value, printResultNo, os);
+void OperationPrinter::printValueID(Value value, bool printResultNo,
+ raw_ostream *streamOverride) const {
+ state->getSSANameState().printValueID(value, printResultNo,
+ streamOverride ? *streamOverride : os);
}
void OperationPrinter::printSuccessor(Block *successor) {
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index d987fd0d8add..a29b34b570d0 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -3322,8 +3322,13 @@ class OperationParser : public Parser {
Operation *parseGenericOperation(Block *insertBlock,
Block::iterator insertPt);
+ /// This is the structure of a result specifier in the assembly syntax,
+ /// including the name, number of results, and location.
+ typedef std::tuple<StringRef, unsigned, SMLoc> ResultRecord;
+
/// Parse an operation instance that is in the op-defined custom form.
- Operation *parseCustomOperation();
+ /// resultInfo specifies information about the "%name =" specifiers.
+ Operation *parseCustomOperation(ArrayRef<ResultRecord> resultInfo);
//===--------------------------------------------------------------------===//
// Region Parsing
@@ -3728,7 +3733,7 @@ Value OperationParser::createForwardRefPlaceholder(SMLoc loc, Type type) {
///
ParseResult OperationParser::parseOperation() {
auto loc = getToken().getLoc();
- SmallVector<std::tuple<StringRef, unsigned, SMLoc>, 1> resultIDs;
+ SmallVector<ResultRecord, 1> resultIDs;
size_t numExpectedResults = 0;
if (getToken().is(Token::percent_identifier)) {
// Parse the group of result ids.
@@ -3769,7 +3774,7 @@ ParseResult OperationParser::parseOperation() {
Operation *op;
if (getToken().is(Token::bare_identifier) || getToken().isKeyword())
- op = parseCustomOperation();
+ op = parseCustomOperation(resultIDs);
else if (getToken().is(Token::string))
op = parseGenericOperation();
else
@@ -3790,7 +3795,7 @@ ParseResult OperationParser::parseOperation() {
// Add definitions for each of the result groups.
unsigned opResI = 0;
- for (std::tuple<StringRef, unsigned, SMLoc> &resIt : resultIDs) {
+ for (ResultRecord &resIt : resultIDs) {
for (unsigned subRes : llvm::seq<unsigned>(0, std::get<1>(resIt))) {
if (addDefinition({std::get<0>(resIt), subRes, std::get<2>(resIt)},
op->getResult(opResI++)))
@@ -3955,9 +3960,12 @@ Operation *OperationParser::parseGenericOperation(Block *insertBlock,
namespace {
class CustomOpAsmParser : public OpAsmParser {
public:
- CustomOpAsmParser(SMLoc nameLoc, const AbstractOperation *opDefinition,
+ CustomOpAsmParser(SMLoc nameLoc,
+ ArrayRef<OperationParser::ResultRecord> resultIDs,
+ const AbstractOperation *opDefinition,
OperationParser &parser)
- : nameLoc(nameLoc), opDefinition(opDefinition), parser(parser) {}
+ : nameLoc(nameLoc), resultIDs(resultIDs), opDefinition(opDefinition),
+ parser(parser) {}
/// Parse an instance of the operation described by 'opDefinition' into the
/// provided operation state.
@@ -3992,6 +4000,41 @@ class CustomOpAsmParser : public OpAsmParser {
Builder &getBuilder() const override { return parser.builder; }
+ /// Return the name of the specified result in the specified syntax, as well
+ /// as the subelement in the name. For example, in this operation:
+ ///
+ /// %x, %y:2, %z = foo.op
+ ///
+ /// getResultName(0) == {"x", 0 }
+ /// getResultName(1) == {"y", 0 }
+ /// getResultName(2) == {"y", 1 }
+ /// getResultName(3) == {"z", 0 }
+ std::pair<StringRef, unsigned>
+ getResultName(unsigned resultNo) const override {
+ // Scan for the resultID that contains this result number.
+ for (unsigned nameID = 0, e = resultIDs.size(); nameID != e; ++nameID) {
+ const auto &entry = resultIDs[nameID];
+ if (resultNo < std::get<1>(entry)) {
+ // Don't pass on the leading %.
+ StringRef name = std::get<0>(entry).drop_front();
+ return {name, resultNo};
+ }
+ resultNo -= std::get<1>(entry);
+ }
+
+ // Invalid result number.
+ return {"", ~0U};
+ }
+
+ /// Return the number of declared SSA results. This returns 4 for the foo.op
+ /// example in the comment for getResultName.
+ size_t getNumResults() const override {
+ size_t count = 0;
+ for (auto &entry : resultIDs)
+ count += std::get<1>(entry);
+ return count;
+ }
+
llvm::SMLoc getNameLoc() const override { return nameLoc; }
//===--------------------------------------------------------------------===//
@@ -4500,6 +4543,9 @@ class CustomOpAsmParser : public OpAsmParser {
/// The source location of the operation name.
SMLoc nameLoc;
+ /// Information about the result name specifiers.
+ ArrayRef<OperationParser::ResultRecord> resultIDs;
+
/// The abstract information of the operation.
const AbstractOperation *opDefinition;
@@ -4511,7 +4557,8 @@ class CustomOpAsmParser : public OpAsmParser {
};
} // end anonymous namespace.
-Operation *OperationParser::parseCustomOperation() {
+Operation *
+OperationParser::parseCustomOperation(ArrayRef<ResultRecord> resultIDs) {
auto opLoc = getToken().getLoc();
auto opName = getTokenSpelling();
@@ -4544,7 +4591,7 @@ Operation *OperationParser::parseCustomOperation() {
// Have the op implementation take a crack and parsing this.
OperationState opState(srcLocation, opDefinition->name);
CleanupOpStateRegions guard{opState};
- CustomOpAsmParser opAsmParser(opLoc, opDefinition, *this);
+ CustomOpAsmParser opAsmParser(opLoc, resultIDs, opDefinition, *this);
if (opAsmParser.parseOperation(opState))
return nullptr;
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
index a6dc9b6617b3..253c1cc3e745 100644
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -1185,3 +1185,43 @@ func @custom_asm_names() -> (i32, i32, i32, i32, i32, i32, i32) {
// CHECK: return %[[FIRST]], %[[MIDDLE]]#0, %[[MIDDLE]]#1, %[[LAST]], %[[FIRST_2]], %[[LAST_2]]
return %0, %1#0, %1#1, %2, %3, %4, %5 : i32, i32, i32, i32, i32, i32, i32
}
+
+
+// CHECK-LABEL: func @pretty_names
+
+// This tests the behavior
+func @pretty_names() {
+ // Simple case, should parse and print as %x being an implied 'name'
+ // attribute.
+ %x = test.string_attr_pretty_name
+ // CHECK: %x = test.string_attr_pretty_name
+ // CHECK-NOT: attributes
+
+ // This specifies an explicit name, which should override the result.
+ %YY = test.string_attr_pretty_name attributes { names = ["y"] }
+ // CHECK: %y = test.string_attr_pretty_name
+ // CHECK-NOT: attributes
+
+ // Conflicts with the 'y' name, so need an explicit attribute.
+ %0 = "test.string_attr_pretty_name"() { names = ["y"]} : () -> i32
+ // CHECK: %y_0 = test.string_attr_pretty_name attributes {names = ["y"]}
+
+ // Name contains a space.
+ %1 = "test.string_attr_pretty_name"() { names = ["space name"]} : () -> i32
+ // CHECK: %space_name = test.string_attr_pretty_name attributes {names = ["space name"]}
+
+ "unknown.use"(%x, %YY, %0, %1) : (i32, i32, i32, i32) -> ()
+
+ // Multi-result support.
+
+ %a, %b, %c = test.string_attr_pretty_name
+ // CHECK: %a, %b, %c = test.string_attr_pretty_name
+ // CHECK-NOT: attributes
+
+ %q:3, %r = test.string_attr_pretty_name
+ // CHECK: %q, %q_1, %q_2, %r = test.string_attr_pretty_name attributes {names = ["q", "q", "q", "r"]}
+
+ // CHECK: return
+ return
+}
+
diff --git a/mlir/test/lib/TestDialect/TestDialect.cpp b/mlir/test/lib/TestDialect/TestDialect.cpp
index 166edf820206..21941f0da606 100644
--- a/mlir/test/lib/TestDialect/TestDialect.cpp
+++ b/mlir/test/lib/TestDialect/TestDialect.cpp
@@ -391,6 +391,87 @@ void SideEffectOp::getEffects(
}
}
+//===----------------------------------------------------------------------===//
+// StringAttrPrettyNameOp
+//===----------------------------------------------------------------------===//
+
+// This op has fancy handling of its SSA result name.
+static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser,
+ OperationState &result) {
+ // Add the result types.
+ for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
+ result.addTypes(parser.getBuilder().getIntegerType(32));
+
+ if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
+ return failure();
+
+ // If the attribute dictionary contains no 'names' attribute, infer it from
+ // the SSA name (if specified).
+ bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
+ return attr.first.is("names");
+ });
+
+ // If there was no name specified, check to see if there was a useful name
+ // specified in the asm file.
+ if (hadNames || parser.getNumResults() == 0)
+ return success();
+
+ SmallVector<StringRef, 4> names;
+ auto *context = result.getContext();
+
+ for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
+ auto resultName = parser.getResultName(i);
+ StringRef nameStr;
+ if (!resultName.first.empty() && !isdigit(resultName.first[0]))
+ nameStr = resultName.first;
+
+ names.push_back(nameStr);
+ }
+
+ auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
+ result.attributes.push_back({Identifier::get("names", context), namesAttr});
+ return success();
+}
+
+static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) {
+ p << "test.string_attr_pretty_name";
+
+ // Note that we only need to print the "name" attribute if the asmprinter
+ // result name disagrees with it. This can happen in strange cases, e.g.
+ // when there are conflicts.
+ bool namesDisagree = op.names().size() != op.getNumResults();
+
+ SmallString<32> resultNameStr;
+ for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) {
+ resultNameStr.clear();
+ llvm::raw_svector_ostream tmpStream(resultNameStr);
+ p.printOperand(op.getResult(i), tmpStream);
+
+ auto expectedName = op.names()[i].dyn_cast<StringAttr>();
+ if (!expectedName ||
+ tmpStream.str().drop_front() != expectedName.getValue()) {
+ namesDisagree = true;
+ }
+ }
+
+ if (namesDisagree)
+ p.printOptionalAttrDictWithKeyword(op.getAttrs());
+ else
+ p.printOptionalAttrDictWithKeyword(op.getAttrs(), {"names"});
+}
+
+// We set the SSA name in the asm syntax to the contents of the name
+// attribute.
+void StringAttrPrettyNameOp::getAsmResultNames(
+ function_ref<void(Value, StringRef)> setNameFn) {
+
+ auto value = names();
+ for (size_t i = 0, e = value.size(); i != e; ++i)
+ if (auto str = value[i].dyn_cast<StringAttr>())
+ if (!str.getValue().empty())
+ setNameFn(getResult(i), str.getValue());
+}
+
//===----------------------------------------------------------------------===//
// Dialect Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td
index cf0ec63fe6c5..80783001a4c9 100644
--- a/mlir/test/lib/TestDialect/TestOps.td
+++ b/mlir/test/lib/TestDialect/TestOps.td
@@ -496,6 +496,18 @@ def AttrSizedResultOp : TEST_Op<"attr_sized_results",
);
}
+// This is used to test encoding of a string attribute into an SSA name of a
+// pretty printed value name.
+def StringAttrPrettyNameOp
+ : TEST_Op<"string_attr_pretty_name",
+ [DeclareOpInterfaceMethods<OpAsmOpInterface>]> {
+ let arguments = (ins StrArrayAttr:$names);
+ let results = (outs Variadic<I32>:$r);
+
+ let printer = [{ return ::print(p, *this); }];
+ let parser = [{ return ::parse$cppClass(parser, result); }];
+}
+
//===----------------------------------------------------------------------===//
// Test Patterns
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list