[Mlir-commits] [mlir] 3043be9 - [IR] Add a Location to BlockArgument.

Chris Lattner llvmlistbot at llvm.org
Tue May 18 10:18:13 PDT 2021


Author: Chris Lattner
Date: 2021-05-18T10:18:04-07:00
New Revision: 3043be9d2db4d0cdf079adb5e1bdff032405e941

URL: https://github.com/llvm/llvm-project/commit/3043be9d2db4d0cdf079adb5e1bdff032405e941
DIFF: https://github.com/llvm/llvm-project/commit/3043be9d2db4d0cdf079adb5e1bdff032405e941.diff

LOG: [IR] Add a Location to BlockArgument.

This adds the ability to specify a location when creating BlockArguments.
Notably, Value::getLoc() now returns this location for block arguments, which
makes diagnostics more precise (e.g. see the example in
test-legalize-type-conversion.mlir).

The location is currently optional to avoid breaking existing code: if it is
absent, the BlockArgument defaults to the location of its enclosing operation
(preserving existing behavior), or to an unknown location if the block is not
yet attached to an operation.

The bulk of this change is plumbing location tracking through the parser
and printer to make sure it can round trip (in -mlir-print-debuginfo
mode).  This is complete for generic operations, but requires manual
adoption for custom ops.

I added support for function-like ops to round trip their argument
locations: they print correctly, but when parsing, the locations are
dropped on the floor.  I intend to fix this, but it will require more
invasive plumbing through the "function_like_impl" code, so I think it
best to split it out into its own patch.

Differential Revision: https://reviews.llvm.org/D102567

Added: 
    

Modified: 
    mlir/include/mlir/IR/Block.h
    mlir/include/mlir/IR/Builders.h
    mlir/include/mlir/IR/OpImplementation.h
    mlir/include/mlir/IR/Value.h
    mlir/lib/IR/AsmPrinter.cpp
    mlir/lib/IR/Block.cpp
    mlir/lib/IR/Builders.cpp
    mlir/lib/IR/FunctionImplementation.cpp
    mlir/lib/IR/Value.cpp
    mlir/lib/Parser/Parser.cpp
    mlir/test/IR/locations.mlir
    mlir/test/Transforms/test-legalize-type-conversion.mlir
    mlir/test/mlir-tblgen/pattern.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h
index f5436ef49d3ca..cafcf585bf2d6 100644
--- a/mlir/include/mlir/IR/Block.h
+++ b/mlir/include/mlir/IR/Block.h
@@ -22,7 +22,8 @@ class BitVector;
 
 namespace mlir {
 class TypeRange;
-template <typename ValueRangeT> class ValueTypeRange;
+template <typename ValueRangeT>
+class ValueTypeRange;
 
 /// `Block` represents an ordered list of `Operation`s.
 class Block : public IRObjectWithUseList<BlockOperand>,
@@ -87,18 +88,21 @@ class Block : public IRObjectWithUseList<BlockOperand>,
   bool args_empty() { return arguments.empty(); }
 
   /// Add one value to the argument list.
-  BlockArgument addArgument(Type type);
+  BlockArgument addArgument(Type type, Optional<Location> loc = {});
 
   /// Insert one value to the position in the argument list indicated by the
   /// given iterator. The existing arguments are shifted. The block is expected
   /// not to have predecessors.
-  BlockArgument insertArgument(args_iterator it, Type type);
+  BlockArgument insertArgument(args_iterator it, Type type,
+                               Optional<Location> loc = {});
 
   /// Add one argument to the argument list for each type specified in the list.
-  iterator_range<args_iterator> addArguments(TypeRange types);
+  iterator_range<args_iterator> addArguments(TypeRange types,
+                                             ArrayRef<Location> locs = {});
 
   /// Add one value to the argument list at the specified position.
-  BlockArgument insertArgument(unsigned index, Type type);
+  BlockArgument insertArgument(unsigned index, Type type,
+                               Optional<Location> loc = {});
 
   /// Erase the argument at 'index' and remove it from the argument list.
   void eraseArgument(unsigned index);
@@ -177,15 +181,18 @@ class Block : public IRObjectWithUseList<BlockOperand>,
 
   /// Return an iterator range over the operations within this block that are of
   /// 'OpT'.
-  template <typename OpT> iterator_range<op_iterator<OpT>> getOps() {
+  template <typename OpT>
+  iterator_range<op_iterator<OpT>> getOps() {
     auto endIt = end();
     return {detail::op_filter_iterator<OpT, iterator>(begin(), endIt),
             detail::op_filter_iterator<OpT, iterator>(endIt, endIt)};
   }
-  template <typename OpT> op_iterator<OpT> op_begin() {
+  template <typename OpT>
+  op_iterator<OpT> op_begin() {
     return detail::op_filter_iterator<OpT, iterator>(begin(), end());
   }
-  template <typename OpT> op_iterator<OpT> op_end() {
+  template <typename OpT>
+  op_iterator<OpT> op_end() {
     return detail::op_filter_iterator<OpT, iterator>(end(), end());
   }
 

diff  --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 1e0863c7a7a42..1788fa715c2fb 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -79,7 +79,8 @@ class Builder {
   NoneType getNoneType();
 
   /// Get or construct an instance of the type 'ty' with provided arguments.
-  template <typename Ty, typename... Args> Ty getType(Args... args) {
+  template <typename Ty, typename... Args>
+  Ty getType(Args... args) {
     return Ty::get(context, args...);
   }
 
@@ -372,11 +373,13 @@ class OpBuilder : public Builder {
   /// end of it. The block is inserted at the provided insertion point of
   /// 'parent'.
   Block *createBlock(Region *parent, Region::iterator insertPt = {},
-                     TypeRange argTypes = llvm::None);
+                     TypeRange argTypes = llvm::None,
+                     ArrayRef<Location> locs = {});
 
   /// Add new block with 'argTypes' arguments and set the insertion point to the
   /// end of it. The block is placed before 'insertBefore'.
-  Block *createBlock(Block *insertBefore, TypeRange argTypes = llvm::None);
+  Block *createBlock(Block *insertBefore, TypeRange argTypes = llvm::None,
+                     ArrayRef<Location> locs = {});
 
   //===--------------------------------------------------------------------===//
   // Operation Creation
@@ -472,7 +475,8 @@ class OpBuilder : public Builder {
   Operation *cloneWithoutRegions(Operation &op) {
     return insert(op.cloneWithoutRegions());
   }
-  template <typename OpT> OpT cloneWithoutRegions(OpT op) {
+  template <typename OpT>
+  OpT cloneWithoutRegions(OpT op) {
     return cast<OpT>(cloneWithoutRegions(*op.getOperation()));
   }
 

diff  --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index db1e98d95b966..a9c96473a3a1e 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -40,6 +40,15 @@ class OpAsmPrinter {
   /// operation.
   virtual void printNewline() = 0;
 
+  /// Print a block argument in the usual format of:
+  ///   %ssaName : type {attr1=42} loc("here")
+  /// where location printing is controlled by the standard internal option.
+  /// You may pass omitType=true to not print a type, and pass an empty
+  /// attribute list if you don't care for attributes.
+  virtual void printRegionArgument(BlockArgument arg,
+                                   ArrayRef<NamedAttribute> argAttrs = {},
+                                   bool omitType = false) = 0;
+
   /// Print implementations for various things an operation contains.
   virtual void printOperand(Value value) = 0;
   virtual void printOperand(Value value, raw_ostream &os) = 0;
@@ -578,6 +587,10 @@ class OpAsmParser {
                                               StringRef attrName,
                                               NamedAttrList &attrs) = 0;
 
+  /// Parse a loc(...) specifier if present, filling in result if so.
+  virtual ParseResult
+  parseOptionalLocationSpecifier(Optional<Location> &result) = 0;
+
   //===--------------------------------------------------------------------===//
   // Operand Parsing
   //===--------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h
index bd80b2b582d11..1901802e3dff9 100644
--- a/mlir/include/mlir/IR/Value.h
+++ b/mlir/include/mlir/IR/Value.h
@@ -85,7 +85,8 @@ class Value {
   Value(const Value &) = default;
   Value &operator=(const Value &) = default;
 
-  template <typename U> bool isa() const {
+  template <typename U>
+  bool isa() const {
     assert(*this && "isa<> used on a null type.");
     return U::classof(*this);
   }
@@ -94,13 +95,16 @@ class Value {
   bool isa() const {
     return isa<First>() || isa<Second, Rest...>();
   }
-  template <typename U> U dyn_cast() const {
+  template <typename U>
+  U dyn_cast() const {
     return isa<U>() ? U(impl) : U(nullptr);
   }
-  template <typename U> U dyn_cast_or_null() const {
+  template <typename U>
+  U dyn_cast_or_null() const {
     return (*this && isa<U>()) ? U(impl) : U(nullptr);
   }
-  template <typename U> U cast() const {
+  template <typename U>
+  U cast() const {
     assert(isa<U>());
     return U(impl);
   }
@@ -134,9 +138,9 @@ class Value {
     return llvm::dyn_cast_or_null<OpTy>(getDefiningOp());
   }
 
-  /// If this value is the result of an operation, use it as a location,
-  /// otherwise return an unknown location.
+  /// Return the location of this value.
   Location getLoc() const;
+  void setLoc(Location loc);
 
   /// Return the Region in which this Value is defined.
   Region *getParentRegion();
@@ -250,8 +254,9 @@ class BlockArgumentImpl : public ValueImpl {
   }
 
 private:
-  BlockArgumentImpl(Type type, Block *owner, int64_t index)
-      : ValueImpl(type, Kind::BlockArgument), owner(owner), index(index) {}
+  BlockArgumentImpl(Type type, Block *owner, int64_t index, Location loc)
+      : ValueImpl(type, Kind::BlockArgument), owner(owner), index(index),
+        loc(loc) {}
 
   /// The owner of this argument.
   Block *owner;
@@ -259,6 +264,9 @@ class BlockArgumentImpl : public ValueImpl {
   /// The position in the argument list.
   int64_t index;
 
+  /// The source location of this argument.
+  Location loc;
+
   /// Allow access to owner and constructor.
   friend BlockArgument;
 };
@@ -279,10 +287,15 @@ class BlockArgument : public Value {
   /// Returns the number of this argument.
   unsigned getArgNumber() const { return getImpl()->index; }
 
+  /// Return the location for this argument.
+  Location getLoc() const { return getImpl()->loc; }
+  void setLoc(Location loc) { getImpl()->loc = loc; }
+
 private:
   /// Allocate a new argument with the given type and owner.
-  static BlockArgument create(Type type, Block *owner, int64_t index) {
-    return new detail::BlockArgumentImpl(type, owner, index);
+  static BlockArgument create(Type type, Block *owner, int64_t index,
+                              Location loc) {
+    return new detail::BlockArgumentImpl(type, owner, index, loc);
   }
 
   /// Destroy and deallocate this argument.
@@ -426,7 +439,8 @@ inline ::llvm::hash_code hash_value(Value arg) {
 
 namespace llvm {
 
-template <> struct DenseMapInfo<mlir::Value> {
+template <>
+struct DenseMapInfo<mlir::Value> {
   static mlir::Value getEmptyKey() {
     void *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
     return mlir::Value::getFromOpaquePointer(pointer);
@@ -453,7 +467,8 @@ struct DenseMapInfo<mlir::BlockArgument> : public DenseMapInfo<mlir::Value> {
 };
 
 /// Allow stealing the low bits of a value.
-template <> struct PointerLikeTypeTraits<mlir::Value> {
+template <>
+struct PointerLikeTypeTraits<mlir::Value> {
 public:
   static inline void *getAsVoidPointer(mlir::Value value) {
     return const_cast<void *>(value.getAsOpaquePointer());

diff  --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index d36f5839d7c32..df803c889ccf9 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -485,6 +485,8 @@ class DummyAliasOperationPrinter : private OpAsmPrinter {
   void printSuccessor(Block *) override {}
   void printSuccessorAndUseList(Block *, ValueRange) override {}
   void shadowRegionArgs(Region &, ValueRange) override {}
+  void printRegionArgument(BlockArgument arg, ArrayRef<NamedAttribute> argAttrs,
+                           bool omitType) override {}
 
   /// The printer flags to use when determining potential aliases.
   const OpPrintingFlags &printerFlags;
@@ -2345,6 +2347,15 @@ class OperationPrinter : public ModulePrinter, private OpAsmPrinter {
     ModulePrinter::printAttribute(attr, AttrTypeElision::Must);
   }
 
+  /// Print a block argument in the usual format of:
+  ///   %ssaName : type {attr1=42} loc("here")
+  /// where location printing is controlled by the standard internal option.
+  /// You may pass omitType=true to not print a type, and pass an empty
+  /// attribute list if you don't care for attributes.
+  void printRegionArgument(BlockArgument arg,
+                           ArrayRef<NamedAttribute> argAttrs = {},
+                           bool omitType = false) override;
+
   /// Print the ID for the given value.
   void printOperand(Value value) override { printValueID(value); }
   void printOperand(Value value, raw_ostream &os) override {
@@ -2419,6 +2430,23 @@ void OperationPrinter::printTopLevelOperation(Operation *op) {
   state->getAliasState().printDeferredAliases(os, newLine);
 }
 
+/// Print a block argument in the usual format of:
+///   %ssaName : type {attr1=42} loc("here")
+/// where location printing is controlled by the standard internal option.
+/// You may pass omitType=true to not print a type, and pass an empty
+/// attribute list if you don't care for attributes.
+void OperationPrinter::printRegionArgument(BlockArgument arg,
+                                           ArrayRef<NamedAttribute> argAttrs,
+                                           bool omitType) {
+  printOperand(arg);
+  if (!omitType) {
+    os << ": ";
+    printType(arg.getType());
+  }
+  printOptionalAttrDict(argAttrs);
+  printTrailingLocation(arg.getLoc());
+}
+
 void OperationPrinter::print(Operation *op) {
   // Track the location of this operation.
   state->registerOperationLocation(op, newLine.curLine, currentIndent);
@@ -2529,6 +2557,7 @@ void OperationPrinter::print(Block *block, bool printBlockArgs,
         printValueID(arg);
         os << ": ";
         printType(arg.getType());
+        printTrailingLocation(arg.getLoc());
       });
       os << ')';
     }
@@ -2700,7 +2729,7 @@ void IntegerSet::print(raw_ostream &os) const {
 void Value::print(raw_ostream &os) {
   if (auto *op = getDefiningOp())
     return op->print(os);
-  // TODO: Improve this.
+  // TODO: Improve BlockArgument print'ing.
   BlockArgument arg = this->cast<BlockArgument>();
   os << "<block argument> of type '" << arg.getType()
      << "' at index: " << arg.getArgNumber() << '\n';
@@ -2709,7 +2738,7 @@ void Value::print(raw_ostream &os, AsmState &state) {
   if (auto *op = getDefiningOp())
     return op->print(os, state);
 
-  // TODO: Improve this.
+  // TODO: Improve BlockArgument print'ing.
   BlockArgument arg = this->cast<BlockArgument>();
   os << "<block argument> of type '" << arg.getType()
      << "' at index: " << arg.getArgNumber() << '\n';

diff  --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp
index d390d490b176d..d053788067763 100644
--- a/mlir/lib/IR/Block.cpp
+++ b/mlir/lib/IR/Block.cpp
@@ -138,23 +138,54 @@ auto Block::getArgumentTypes() -> ValueTypeRange<BlockArgListType> {
   return ValueTypeRange<BlockArgListType>(getArguments());
 }
 
-BlockArgument Block::addArgument(Type type) {
-  BlockArgument arg = BlockArgument::create(type, this, arguments.size());
+BlockArgument Block::addArgument(Type type, Optional<Location> loc) {
+  // TODO: Require locations for BlockArguments.
+  if (!loc.hasValue()) {
+    // Use the location of the parent operation if the block is attached.
+    if (Operation *parentOp = getParentOp())
+      loc = parentOp->getLoc();
+    else
+      loc = UnknownLoc::get(type.getContext());
+  }
+
+  BlockArgument arg = BlockArgument::create(type, this, arguments.size(), *loc);
   arguments.push_back(arg);
   return arg;
 }
 
 /// Add one argument to the argument list for each type specified in the list.
-auto Block::addArguments(TypeRange types) -> iterator_range<args_iterator> {
+auto Block::addArguments(TypeRange types, ArrayRef<Location> locs)
+    -> iterator_range<args_iterator> {
+  // TODO: Require locations for BlockArguments.
+  assert((locs.empty() || types.size() == locs.size()) &&
+         "incorrect number of block argument locations");
   size_t initialSize = arguments.size();
+
   arguments.reserve(initialSize + types.size());
-  for (auto type : types)
-    addArgument(type);
+
+  // TODO: Require locations for BlockArguments.
+  if (locs.empty()) {
+    for (auto type : types)
+      addArgument(type);
+  } else {
+    for (auto typeAndLoc : llvm::zip(types, locs))
+      addArgument(std::get<0>(typeAndLoc), std::get<1>(typeAndLoc));
+  }
   return {arguments.data() + initialSize, arguments.data() + arguments.size()};
 }
 
-BlockArgument Block::insertArgument(unsigned index, Type type) {
-  auto arg = BlockArgument::create(type, this, index);
+BlockArgument Block::insertArgument(unsigned index, Type type,
+                                    Optional<Location> loc) {
+  // TODO: Require locations for BlockArguments.
+  if (!loc.hasValue()) {
+    // Use the location of the parent operation if the block is attached.
+    if (Operation *parentOp = getParentOp())
+      loc = parentOp->getLoc();
+    else
+      loc = UnknownLoc::get(type.getContext());
+  }
+
+  auto arg = BlockArgument::create(type, this, index, *loc);
   assert(index <= arguments.size());
   arguments.insert(arguments.begin() + index, arg);
   // Update the cached position for all the arguments after the newly inserted
@@ -167,10 +198,11 @@ BlockArgument Block::insertArgument(unsigned index, Type type) {
 
 /// Insert one value to the given position of the argument list. The existing
 /// arguments are shifted. The block is expected not to have predecessors.
-BlockArgument Block::insertArgument(args_iterator it, Type type) {
+BlockArgument Block::insertArgument(args_iterator it, Type type,
+                                    Optional<Location> loc) {
   assert(llvm::empty(getPredecessors()) &&
          "cannot insert arguments to blocks with predecessors");
-  return insertArgument(it->getArgNumber(), type);
+  return insertArgument(it->getArgNumber(), type, loc);
 }
 
 void Block::eraseArgument(unsigned index) {

diff  --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 4f8aa9e820757..737ec74461993 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -371,13 +371,13 @@ Operation *OpBuilder::insert(Operation *op) {
 /// end of it. The block is inserted at the provided insertion point of
 /// 'parent'.
 Block *OpBuilder::createBlock(Region *parent, Region::iterator insertPt,
-                              TypeRange argTypes) {
+                              TypeRange argTypes, ArrayRef<Location> locs) {
   assert(parent && "expected valid parent region");
   if (insertPt == Region::iterator())
     insertPt = parent->end();
 
   Block *b = new Block();
-  b->addArguments(argTypes);
+  b->addArguments(argTypes, locs);
   parent->getBlocks().insert(insertPt, b);
   setInsertionPointToEnd(b);
 
@@ -388,10 +388,11 @@ Block *OpBuilder::createBlock(Region *parent, Region::iterator insertPt,
 
 /// Add new block with 'argTypes' arguments and set the insertion point to the
 /// end of it.  The block is placed before 'insertBefore'.
-Block *OpBuilder::createBlock(Block *insertBefore, TypeRange argTypes) {
+Block *OpBuilder::createBlock(Block *insertBefore, TypeRange argTypes,
+                              ArrayRef<Location> locs) {
   assert(insertBefore && "expected valid insertion block");
   return createBlock(insertBefore->getParent(), Region::iterator(insertBefore),
-                     argTypes);
+                     argTypes, locs);
 }
 
 /// Create an operation given the fields represented as an OperationState.

diff  --git a/mlir/lib/IR/FunctionImplementation.cpp b/mlir/lib/IR/FunctionImplementation.cpp
index aadf5456126be..a4ea0400003cb 100644
--- a/mlir/lib/IR/FunctionImplementation.cpp
+++ b/mlir/lib/IR/FunctionImplementation.cpp
@@ -59,6 +59,13 @@ ParseResult mlir::function_like_impl::parseFunctionArgumentList(
     if (!allowAttributes && !attrs.empty())
       return parser.emitError(loc, "expected arguments without attributes");
     argAttrs.push_back(attrs);
+
+    // Parse a location if specified.  TODO: Don't drop it on the floor.
+    Optional<Location> explicitLoc;
+    if (!argument.name.empty() &&
+        parser.parseOptionalLocationSpecifier(explicitLoc))
+      return failure();
+
     return success();
   };
 
@@ -298,13 +305,15 @@ void mlir::function_like_impl::printFunctionSignature(
       p << ", ";
 
     if (!isExternal) {
-      p.printOperand(body.getArgument(i));
-      p << ": ";
+      ArrayRef<NamedAttribute> attrs;
+      if (argAttrs)
+        attrs = argAttrs[i].cast<DictionaryAttr>().getValue();
+      p.printRegionArgument(body.getArgument(i), attrs);
+    } else {
+      p.printType(argTypes[i]);
+      if (argAttrs)
+        p.printOptionalAttrDict(argAttrs[i].cast<DictionaryAttr>().getValue());
     }
-
-    p.printType(argTypes[i]);
-    if (argAttrs)
-      p.printOptionalAttrDict(argAttrs[i].cast<DictionaryAttr>().getValue());
   }
 
   if (isVariadic) {

diff  --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp
index a4baa93110019..1ff0bba64fdf7 100644
--- a/mlir/lib/IR/Value.cpp
+++ b/mlir/lib/IR/Value.cpp
@@ -27,10 +27,14 @@ Location Value::getLoc() const {
   if (auto *op = getDefiningOp())
     return op->getLoc();
 
-  // Use the location of the parent operation if this is a block argument.
-  // TODO: Should we just add locations to block arguments?
-  Operation *parentOp = cast<BlockArgument>().getOwner()->getParentOp();
-  return parentOp ? parentOp->getLoc() : UnknownLoc::get(getContext());
+  return cast<BlockArgument>().getLoc();
+}
+
+void Value::setLoc(Location loc) {
+  if (auto *op = getDefiningOp())
+    return op->setLoc(loc);
+
+  return cast<BlockArgument>().setLoc(loc);
 }
 
 /// Return the Region in which this Value is defined.

diff  --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index fea26bcb2e32d..f88f94cc57d91 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -249,11 +249,17 @@ class OperationParser : public Parser {
   Operation *parseGenericOperation(Block *insertBlock,
                                    Block::iterator insertPt);
 
-  /// Parse an optional trailing location for the given operation.
+  /// This type is used to keep track of things that are either an Operation or
+  /// a BlockArgument.  We cannot use Value for this, because not all Operations
+  /// have results.
+  using OpOrArgument = llvm::PointerUnion<Operation *, BlockArgument>;
+
+  /// Parse an optional trailing location and add it to the specifier Operation
+  /// or `OperandType` if present.
   ///
   ///   trailing-location ::= (`loc` (`(` location `)` | attribute-alias))?
   ///
-  ParseResult parseTrailingOperationLocation(Operation *op);
+  ParseResult parseTrailingLocationSpecifier(OpOrArgument opOrArgument);
 
   /// This is the structure of a result specifier in the assembly syntax,
   /// including the name, number of results, and location.
@@ -385,7 +391,8 @@ class OperationParser : public Parser {
 
   /// A set of operations whose locations reference aliases that have yet to
   /// be resolved.
-  SmallVector<std::pair<Operation *, Token>, 8> opsWithDeferredLocs;
+  SmallVector<std::pair<OpOrArgument, Token>, 8>
+      opsAndArgumentsWithDeferredLocs;
 
   /// The builder used when creating parsed operation instances.
   OpBuilder opBuilder;
@@ -433,7 +440,7 @@ ParseResult OperationParser::finalize() {
 
   // Resolve the locations of any deferred operations.
   auto &attributeAliases = state.symbols.attributeAliasDefinitions;
-  for (std::pair<Operation *, Token> &it : opsWithDeferredLocs) {
+  for (std::pair<OpOrArgument, Token> &it : opsAndArgumentsWithDeferredLocs) {
     llvm::SMLoc tokLoc = it.second.getLoc();
     StringRef identifier = it.second.getSpelling().drop_front();
     Attribute attr = attributeAliases.lookup(identifier);
@@ -444,7 +451,11 @@ ParseResult OperationParser::finalize() {
     if (!locAttr)
       return emitError(tokLoc)
              << "expected location, but found '" << attr << "'";
-    it.first->setLoc(locAttr);
+    auto opOrArgument = it.first;
+    if (auto *op = opOrArgument.dyn_cast<Operation *>())
+      op->setLoc(locAttr);
+    else
+      opOrArgument.get<BlockArgument>().setLoc(locAttr);
   }
 
   // Pop the top level name scope.
@@ -963,7 +974,7 @@ Operation *OperationParser::parseGenericOperation() {
 
   // Create the operation and try to parse a location for it.
   Operation *op = opBuilder.createOperation(result);
-  if (parseTrailingOperationLocation(op))
+  if (parseTrailingLocationSpecifier(op))
     return nullptr;
   return op;
 }
@@ -1359,6 +1370,22 @@ class CustomOpAsmParser : public OpAsmParser {
     return success();
   }
 
+  /// Parse a loc(...) specifier if present, filling in result if so.
+  ParseResult
+  parseOptionalLocationSpecifier(Optional<Location> &result) override {
+    // If there is a 'loc' we parse a trailing location.
+    if (!parser.consumeIf(Token::kw_loc))
+      return success();
+    LocationAttr directLoc;
+    if (parser.parseToken(Token::l_paren, "expected '(' in location") ||
+        parser.parseLocationInstance(directLoc) ||
+        parser.parseToken(Token::r_paren, "expected ')' in location"))
+      return failure();
+
+    result = directLoc;
+    return success();
+  }
+
   //===--------------------------------------------------------------------===//
   // Operand Parsing
   //===--------------------------------------------------------------------===//
@@ -1846,12 +1873,13 @@ OperationParser::parseCustomOperation(ArrayRef<ResultRecord> resultIDs) {
 
   // Otherwise, create the operation and try to parse a location for it.
   Operation *op = opBuilder.createOperation(opState);
-  if (parseTrailingOperationLocation(op))
+  if (parseTrailingLocationSpecifier(op))
     return nullptr;
   return op;
 }
 
-ParseResult OperationParser::parseTrailingOperationLocation(Operation *op) {
+ParseResult
+OperationParser::parseTrailingLocationSpecifier(OpOrArgument opOrArgument) {
   // If there is a 'loc' we parse a trailing location.
   if (!consumeIf(Token::kw_loc))
     return success();
@@ -1879,7 +1907,7 @@ ParseResult OperationParser::parseTrailingOperationLocation(Operation *op) {
                << "expected location, but found '" << attr << "'";
     } else {
       // Otherwise, remember this operation and resolve its location later.
-      opsWithDeferredLocs.emplace_back(op, tok);
+      opsAndArgumentsWithDeferredLocs.emplace_back(opOrArgument, tok);
     }
 
     // Otherwise, we parse the location directly.
@@ -1890,8 +1918,12 @@ ParseResult OperationParser::parseTrailingOperationLocation(Operation *op) {
   if (parseToken(Token::r_paren, "expected ')' in location"))
     return failure();
 
-  if (directLoc)
-    op->setLoc(directLoc);
+  if (directLoc) {
+    if (auto *op = opOrArgument.dyn_cast<Operation *>())
+      op->setLoc(directLoc);
+    else
+      opOrArgument.get<BlockArgument>().setLoc(directLoc);
+  }
   return success();
 }
 
@@ -1942,7 +1974,8 @@ ParseResult OperationParser::parseRegion(
                    .attachNote(getEncodedSourceLocation(*defLoc))
                << "previously referenced here";
       }
-      BlockArgument arg = block->addArgument(placeholderArgPair.second);
+      auto loc = getEncodedSourceLocation(placeholderArgPair.first.loc);
+      BlockArgument arg = block->addArgument(placeholderArgPair.second, loc);
 
       // Add a definition of this arg to the assembly state if provided.
       if (state.asmState)
@@ -2122,9 +2155,15 @@ ParseResult OperationParser::parseOptionalBlockArgList(Block *owner) {
             if (arg.getType() != type)
               return emitError("argument and block argument type mismatch");
           } else {
-            arg = owner->addArgument(type);
+            auto loc = getEncodedSourceLocation(useInfo.loc);
+            arg = owner->addArgument(type, loc);
           }
 
+          // If the argument has an explicit loc(...) specifier, parse and apply
+          // it.
+          if (parseTrailingLocationSpecifier(arg))
+            return failure();
+
           // Mark this block argument definition in the parser state if it was
           // provided.
           if (state.asmState)

diff  --git a/mlir/test/IR/locations.mlir b/mlir/test/IR/locations.mlir
index 5ad854eedc10e..e4a46f84fe579 100644
--- a/mlir/test/IR/locations.mlir
+++ b/mlir/test/IR/locations.mlir
@@ -44,5 +44,25 @@ func @escape_strings() {
 // CHECK-ALIAS: "foo.op"() : () -> () loc(#[[LOC:.*]])
 "foo.op"() : () -> () loc(#loc)
 
+// CHECK-LABEL: func @argLocs(
+// CHECK-SAME:  %arg0: i32 loc({{.*}}locations.mlir":[[# @LINE+1]]:15),
+func @argLocs(%x: i32,
+// CHECK-SAME:  %arg1: i64 loc({{.*}}locations.mlir":[[# @LINE+1]]:15))
+              %y: i64 loc("hotdog")) {
+  return
+}
+
+// CHECK-LABEL: "foo.unknown_op_with_bbargs"()
+"foo.unknown_op_with_bbargs"() ({
+// CHECK-NEXT: ^bb0(%arg0: i32 loc({{.*}}locations.mlir":[[# @LINE+1]]:7),
+ ^bb0(%x: i32,
+// CHECK-SAME: %arg1: i32 loc("cheetos"),
+      %y: i32 loc("cheetos"),
+// CHECK-SAME: %arg2: i32 loc("out_of_line_location")):
+      %z: i32 loc(#loc)):
+    %1 = addi %x, %y : i32
+    "foo.yield"(%1) : (i32) -> ()
+  }) : () -> ()
+
 // CHECK-ALIAS: #[[LOC]] = loc("out_of_line_location")
 #loc = loc("out_of_line_location")

diff  --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir
index 9ce69519006a7..e7ffb7ae6a3ee 100644
--- a/mlir/test/Transforms/test-legalize-type-conversion.mlir
+++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir
@@ -1,7 +1,9 @@
 // RUN: mlir-opt %s -test-legalize-type-conversion -allow-unregistered-dialect -split-input-file -verify-diagnostics | FileCheck %s
 
-// expected-error@below {{failed to materialize conversion for block argument #0 that remained live after conversion, type was 'i16'}}
-func @test_invalid_arg_materialization(%arg0: i16) {
+
+func @test_invalid_arg_materialization(
+  // expected-error@below {{failed to materialize conversion for block argument #0 that remained live after conversion, type was 'i16'}}
+  %arg0: i16) {
   // expected-note@below {{see existing live user here}}
   "foo.return"(%arg0) : (i16) -> ()
 }

diff  --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir
index 6918f319198ca..affc3d7a93968 100644
--- a/mlir/test/mlir-tblgen/pattern.mlir
+++ b/mlir/test/mlir-tblgen/pattern.mlir
@@ -37,7 +37,7 @@ func @verifyZeroArg() -> i32 {
 }
 
 // CHECK-LABEL: testIgnoreArgMatch
-// CHECK-SAME: (%{{[a-z0-9]*}}: i32, %[[ARG1:[a-z0-9]*]]: i32
+// CHECK-SAME: (%{{[a-z0-9]*}}: i32 loc({{[^)]*}}), %[[ARG1:[a-z0-9]*]]: i32 loc({{[^)]*}}),
 func @testIgnoreArgMatch(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: f32) {
   // CHECK: "test.ignore_arg_match_dst"(%[[ARG1]]) {f = 15 : i64}
   "test.ignore_arg_match_src"(%arg0, %arg1, %arg2) {d = 42, e = 24, f = 15} : (i32, i32, i32) -> ()
@@ -53,7 +53,7 @@ func @testIgnoreArgMatch(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: f32) {
 }
 
 // CHECK-LABEL: verifyInterleavedOperandAttribute
-// CHECK-SAME:    %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32
+// CHECK-SAME:    %[[ARG0:.*]]: i32 loc({{[^)]*}}), %[[ARG1:.*]]: i32 loc({{[^)]*}})
 func @verifyInterleavedOperandAttribute(%arg0: i32, %arg1: i32) {
   // CHECK: "test.interleaved_operand_attr2"(%[[ARG0]], %[[ARG1]]) {attr1 = 15 : i64, attr2 = 42 : i64}
   "test.interleaved_operand_attr1"(%arg0, %arg1) {attr1 = 15, attr2 = 42} : (i32, i32) -> ()
@@ -114,7 +114,7 @@ func @verifyAllAttrConstraintOf() -> (i32, i32, i32) {
 }
 
 // CHECK-LABEL: verifyManyArgs
-// CHECK-SAME: (%[[ARG:.*]]: i32)
+// CHECK-SAME: (%[[ARG:.*]]: i32 loc({{[^)]*}}))
 func @verifyManyArgs(%arg: i32) {
   // CHECK: "test.many_arguments"(%[[ARG]], %[[ARG]], %[[ARG]], %[[ARG]], %[[ARG]], %[[ARG]], %[[ARG]], %[[ARG]], %[[ARG]])
   // CHECK-SAME: {attr1 = 24 : i64, attr2 = 42 : i64, attr3 = 42 : i64, attr4 = 42 : i64, attr5 = 42 : i64, attr6 = 42 : i64, attr7 = 42 : i64, attr8 = 42 : i64, attr9 = 42 : i64}


        


More information about the Mlir-commits mailing list