[Mlir-commits] [mlir] f8ac313 - Add a new AsmParserState class to capture detailed source information for .mlir files

River Riddle llvmlistbot at llvm.org
Wed Apr 21 14:46:41 PDT 2021


Author: River Riddle
Date: 2021-04-21T14:44:37-07:00
New Revision: f8ac31314b4296ce8f33809154bc4ad726161e29

URL: https://github.com/llvm/llvm-project/commit/f8ac31314b4296ce8f33809154bc4ad726161e29
DIFF: https://github.com/llvm/llvm-project/commit/f8ac31314b4296ce8f33809154bc4ad726161e29.diff

LOG: Add a new AsmParserState class to capture detailed source information for .mlir files

This information isn't useful for general compilation, but is useful for building tools that process .mlir files. This class will be used in a followup to start building an LSP language server for MLIR.

Differential Revision: https://reviews.llvm.org/D100438

Added: 
    mlir/include/mlir/Parser/AsmParserState.h
    mlir/lib/Parser/AsmParserState.cpp

Modified: 
    mlir/include/mlir/Parser.h
    mlir/lib/Parser/CMakeLists.txt
    mlir/lib/Parser/DialectSymbolParser.cpp
    mlir/lib/Parser/Parser.cpp
    mlir/lib/Parser/Parser.h
    mlir/lib/Parser/ParserState.h

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Parser.h b/mlir/include/mlir/Parser.h
index 907f31824628b..a534dfcbe2771 100644
--- a/mlir/include/mlir/Parser.h
+++ b/mlir/include/mlir/Parser.h
@@ -24,6 +24,8 @@ class StringRef;
 } // end namespace llvm
 
 namespace mlir {
+class AsmParserState;
+
 namespace detail {
 
 /// Given a block containing operations that have just been parsed, if the block
@@ -77,10 +79,14 @@ inline OwningOpRef<ContainerOpT> constructContainerOpForParserIfNecessary(
 /// returned. Otherwise, an error message is emitted through the error handler
 /// registered in the context, and failure is returned. If `sourceFileLoc` is
 /// non-null, it is populated with a file location representing the start of the
-/// source file that is being parsed.
+/// source file that is being parsed. If `asmState` is non-null, it is populated
+/// with detailed information about the parsed IR (including exact locations for
+/// SSA uses and definitions). `asmState` should only be provided if this
+/// detailed information is desired.
 LogicalResult parseSourceFile(const llvm::SourceMgr &sourceMgr, Block *block,
                               MLIRContext *context,
-                              LocationAttr *sourceFileLoc = nullptr);
+                              LocationAttr *sourceFileLoc = nullptr,
+                              AsmParserState *asmState = nullptr);
 
 /// This parses the file specified by the indicated filename and appends parsed
 /// operations to the given block. If the block is non-empty, the operations are
@@ -99,11 +105,15 @@ LogicalResult parseSourceFile(llvm::StringRef filename, Block *block,
 /// parsing is successful, success is returned. Otherwise, an error message is
 /// emitted through the error handler registered in the context, and failure is
 /// returned. If `sourceFileLoc` is non-null, it is populated with a file
-/// location representing the start of the source file that is being parsed.
+/// location representing the start of the source file that is being parsed. If
+/// `asmState` is non-null, it is populated with detailed information about the
+/// parsed IR (including exact locations for SSA uses and definitions).
+/// `asmState` should only be provided if this detailed information is desired.
 LogicalResult parseSourceFile(llvm::StringRef filename,
                               llvm::SourceMgr &sourceMgr, Block *block,
                               MLIRContext *context,
-                              LocationAttr *sourceFileLoc = nullptr);
+                              LocationAttr *sourceFileLoc = nullptr,
+                              AsmParserState *asmState = nullptr);
 
 /// This parses the IR string and appends parsed operations to the given block.
 /// If the block is non-empty, the operations are placed before the current

diff  --git a/mlir/include/mlir/Parser/AsmParserState.h b/mlir/include/mlir/Parser/AsmParserState.h
new file mode 100644
index 0000000000000..318ef174fec61
--- /dev/null
+++ b/mlir/include/mlir/Parser/AsmParserState.h
@@ -0,0 +1,131 @@
+//===- AsmParserState.h -----------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_PARSER_ASMPARSERSTATE_H
+#define MLIR_PARSER_ASMPARSERSTATE_H
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/SMLoc.h"
+#include <cstddef>
+
+namespace mlir {
+class Block;
+class BlockArgument;
+class FileLineColLoc;
+class Operation;
+class Value;
+
+/// This class represents state from a parsed MLIR textual format string. It is
+/// useful for building additional analysis and language utilities on top of
+/// textual MLIR. This should generally not be used for traditional compilation.
+///
+/// The state is filled in through the "Populate State" methods and read back
+/// through the "Access State" methods; the storage lives in a private `Impl`.
+class AsmParserState {
+public:
+  /// This class represents a definition within the source manager, containing
+  /// its defining location and the locations of any uses. SMDefinitions are
+  /// only provided for entities that have uses within an input file, e.g. SSA
+  /// values, Blocks, and Symbols. NOTE(review): the populate API below
+  /// currently only covers SSA values and blocks.
+  struct SMDefinition {
+    SMDefinition() = default;
+    SMDefinition(llvm::SMRange loc) : loc(loc) {}
+
+    /// The source location of the definition.
+    llvm::SMRange loc;
+    /// The source location of all uses of the definition.
+    SmallVector<llvm::SMRange> uses;
+  };
+
+  /// This class represents the information for an operation definition within
+  /// an input file.
+  struct OperationDefinition {
+    /// NOTE(review): `resultGroups` below stores
+    /// `std::pair<unsigned, SMDefinition>` rather than this struct; confirm
+    /// whether this struct is still needed.
+    struct ResultGroupDefinition {
+      /// The result number that starts this group.
+      unsigned startIndex;
+      /// The source definition of the result group.
+      SMDefinition definition;
+    };
+
+    OperationDefinition(Operation *op, llvm::SMRange loc) : op(op), loc(loc) {}
+
+    /// The operation representing this definition.
+    Operation *op;
+
+    /// The source location for the operation, i.e. the location of its name.
+    llvm::SMRange loc;
+
+    /// Source definitions for any result groups of this operation. Each entry
+    /// pairs the number of the first result in the group with the group's
+    /// source definition; entries are ordered by increasing start index.
+    SmallVector<std::pair<unsigned, SMDefinition>> resultGroups;
+  };
+
+  /// This class represents the information for a block definition within the
+  /// input file.
+  struct BlockDefinition {
+    /// `loc` defaults to an empty range, e.g. for a block that has so far only
+    /// been referenced and not yet defined.
+    BlockDefinition(Block *block, llvm::SMRange loc = {})
+        : block(block), definition(loc) {}
+
+    /// The block representing this definition.
+    Block *block;
+
+    /// The source location for the block, i.e. the location of its name, and
+    /// any uses it has.
+    SMDefinition definition;
+
+    /// Source definitions for any arguments of this block, indexed by argument
+    /// number.
+    SmallVector<SMDefinition> arguments;
+  };
+
+  AsmParserState();
+  ~AsmParserState();
+
+  //===--------------------------------------------------------------------===//
+  // Access State
+  //===--------------------------------------------------------------------===//
+
+  using BlockDefIterator = llvm::pointee_iterator<
+      ArrayRef<std::unique_ptr<BlockDefinition>>::iterator>;
+  using OperationDefIterator = llvm::pointee_iterator<
+      ArrayRef<std::unique_ptr<OperationDefinition>>::iterator>;
+
+  /// Return a range of the BlockDefinitions held by the current parser state.
+  iterator_range<BlockDefIterator> getBlockDefs() const;
+
+  /// Return a range of the OperationDefinitions held by the current parser
+  /// state.
+  iterator_range<OperationDefIterator> getOpDefs() const;
+
+  //===--------------------------------------------------------------------===//
+  // Populate State
+  //===--------------------------------------------------------------------===//
+
+  /// Add a definition of the given operation. `location` is the source range
+  /// of the operation name. Each `resultGroups` entry pairs the number of the
+  /// first result in a group with the location of that group's SSA name.
+  void addDefinition(
+      Operation *op, llvm::SMRange location,
+      ArrayRef<std::pair<unsigned, llvm::SMLoc>> resultGroups = llvm::None);
+  /// Add a definition of the given block at `location`. If the block already
+  /// has an entry (e.g. from an earlier use), that entry is updated.
+  void addDefinition(Block *block, llvm::SMLoc location);
+  /// Add a definition of the given block argument at `location`. The owning
+  /// block must already have an entry.
+  void addDefinition(BlockArgument blockArg, llvm::SMLoc location);
+
+  /// Add source uses of the given value. Uses of an operation result whose
+  /// owner has not been recorded yet are held as placeholder uses until
+  /// `refineDefinition` is called.
+  void addUses(Value value, ArrayRef<llvm::SMLoc> locations);
+  /// Add source uses of the given block. The block does not need to have been
+  /// defined yet.
+  void addUses(Block *block, ArrayRef<llvm::SMLoc> locations);
+
+  /// Refine the `oldValue` to the `newValue`. This is used to indicate that
+  /// `oldValue` was a placeholder, and the uses of it should really refer to
+  /// `newValue`.
+  void refineDefinition(Value oldValue, Value newValue);
+
+private:
+  struct Impl;
+
+  /// A pointer to the internal implementation of this class.
+  std::unique_ptr<Impl> impl;
+};
+
+} // end namespace mlir
+
+#endif // MLIR_PARSER_ASMPARSERSTATE_H

diff  --git a/mlir/lib/Parser/AsmParserState.cpp b/mlir/lib/Parser/AsmParserState.cpp
new file mode 100644
index 0000000000000..3fdb1c2df81de
--- /dev/null
+++ b/mlir/lib/Parser/AsmParserState.cpp
@@ -0,0 +1,168 @@
+//===- AsmParserState.cpp -------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Parser/AsmParserState.h"
+#include "mlir/IR/Operation.h"
+
+using namespace mlir;
+
+/// Given a SMLoc corresponding to an identifier location, return a location
+/// representing the full range of the identifier. The character at `loc` is
+/// always included without being checked; the range then extends over every
+/// following identifier character. An invalid `loc` yields an empty range.
+/// NOTE(review): the scan assumes the buffer ends in a non-identifier
+/// character (e.g. a NUL terminator) so it cannot run off the end.
+static llvm::SMRange convertIdLocToRange(llvm::SMLoc loc) {
+  if (!loc.isValid())
+    return llvm::SMRange();
+
+  // Return true if the given character is a valid identifier character.
+  // `isalnum` is only defined for values representable as `unsigned char`
+  // (or EOF), so convert first: a non-ASCII byte is negative when `char` is
+  // signed, and passing it directly is undefined behavior.
+  auto isIdentifierChar = [](char c) {
+    return isalnum(static_cast<unsigned char>(c)) || c == '$' || c == '.' ||
+           c == '_' || c == '-';
+  };
+
+  const char *curPtr = loc.getPointer();
+  while (isIdentifierChar(*(++curPtr)))
+    continue;
+  return llvm::SMRange(loc, llvm::SMLoc::getFromPointer(curPtr));
+}
+
+//===----------------------------------------------------------------------===//
+// AsmParserState::Impl
+//===----------------------------------------------------------------------===//
+
+struct AsmParserState::Impl {
+  /// The recorded operation definitions, in the order they were added, along
+  /// with a mapping from each operation to its index in `operations`.
+  SmallVector<std::unique_ptr<OperationDefinition>> operations;
+  DenseMap<Operation *, unsigned> operationToIdx;
+
+  /// The recorded block definitions, in the order they were first seen (by
+  /// either a definition or a use), along with a mapping from each block to
+  /// its index in `blocks`.
+  SmallVector<std::unique_ptr<BlockDefinition>> blocks;
+  DenseMap<Block *, unsigned> blocksToIdx;
+
+  /// Use locations of values whose defining operation has not been recorded
+  /// yet (forward-reference placeholders), keyed by the placeholder value.
+  /// Entries are transferred to the real value by `refineDefinition`, so this
+  /// map should be empty if the parser finishes successfully.
+  DenseMap<Value, SmallVector<llvm::SMLoc>> placeholderValueUses;
+};
+
+//===----------------------------------------------------------------------===//
+// AsmParserState
+//===----------------------------------------------------------------------===//
+
+// Both special members are defined here, where `Impl` is a complete type, so
+// that the `std::unique_ptr<Impl>` member can be created and destroyed.
+AsmParserState::AsmParserState() : impl(std::make_unique<Impl>()) {}
+AsmParserState::~AsmParserState() = default;
+
+//===----------------------------------------------------------------------===//
+// Access State
+
+/// Expose the recorded block definitions by reference, hiding the owning
+/// `unique_ptr`s behind a pointee iterator.
+iterator_range<AsmParserState::BlockDefIterator>
+AsmParserState::getBlockDefs() const {
+  ArrayRef<std::unique_ptr<BlockDefinition>> blockDefs(impl->blocks);
+  return llvm::make_pointee_range(blockDefs);
+}
+
+/// Expose the recorded operation definitions by reference, hiding the owning
+/// `unique_ptr`s behind a pointee iterator.
+iterator_range<AsmParserState::OperationDefIterator>
+AsmParserState::getOpDefs() const {
+  ArrayRef<std::unique_ptr<OperationDefinition>> opDefs(impl->operations);
+  return llvm::make_pointee_range(opDefs);
+}
+
+//===----------------------------------------------------------------------===//
+// Populate State
+
+void AsmParserState::addDefinition(
+    Operation *op, llvm::SMRange location,
+    ArrayRef<std::pair<unsigned, llvm::SMLoc>> resultGroups) {
+  auto opDef = std::make_unique<OperationDefinition>(op, location);
+
+  // Expand each (start index, name location) pair into a full source range.
+  for (const std::pair<unsigned, llvm::SMLoc> &group : resultGroups) {
+    unsigned startIndex = group.first;
+    opDef->resultGroups.emplace_back(startIndex,
+                                     convertIdLocToRange(group.second));
+  }
+
+  // Index the definition by its operation; the stored index is the slot the
+  // definition is about to occupy in `operations`.
+  unsigned newIdx = impl->operations.size();
+  impl->operationToIdx.try_emplace(op, newIdx);
+  impl->operations.push_back(std::move(opDef));
+}
+
+void AsmParserState::addDefinition(Block *block, llvm::SMLoc location) {
+  llvm::SMRange defRange = convertIdLocToRange(location);
+  auto insertResult =
+      impl->blocksToIdx.try_emplace(block, impl->blocks.size());
+
+  // A pre-existing entry means the block was referenced before it was
+  // defined (a forward declaration); just fill in its definition location.
+  if (!insertResult.second) {
+    impl->blocks[insertResult.first->second]->definition.loc = defRange;
+    return;
+  }
+
+  impl->blocks.emplace_back(
+      std::make_unique<BlockDefinition>(block, defRange));
+}
+
+void AsmParserState::addDefinition(BlockArgument blockArg,
+                                   llvm::SMLoc location) {
+  // Block arguments are stored on the definition of their owning block, which
+  // must already have been recorded.
+  auto ownerIt = impl->blocksToIdx.find(blockArg.getOwner());
+  assert(ownerIt != impl->blocksToIdx.end() &&
+         "expected owner block to have an entry");
+  SmallVectorImpl<SMDefinition> &argDefs =
+      impl->blocks[ownerIt->second]->arguments;
+
+  // Grow the list if needed so that `argNumber` is a valid index.
+  unsigned argNumber = blockArg.getArgNumber();
+  if (argNumber >= argDefs.size())
+    argDefs.resize(argNumber + 1);
+  argDefs[argNumber] = SMDefinition(convertIdLocToRange(location));
+}
+
+void AsmParserState::addUses(Value value, ArrayRef<llvm::SMLoc> locations) {
+  // Operation results: attach the uses to the result group that contains the
+  // value, or buffer them if the owning operation is not recorded yet.
+  if (OpResult opResult = value.dyn_cast<OpResult>()) {
+    auto opIt = impl->operationToIdx.find(opResult.getOwner());
+    if (opIt == impl->operationToIdx.end()) {
+      // No definition for the owner yet: treat `value` as a placeholder whose
+      // uses will be transferred later by `refineDefinition`.
+      SmallVector<llvm::SMLoc> &pending = impl->placeholderValueUses[value];
+      pending.append(locations.begin(), locations.end());
+      return;
+    }
+
+    // Result groups are sorted by ascending start index, so the owning group
+    // is the last one whose start index does not exceed the result number.
+    OperationDefinition &opDef = *impl->operations[opIt->second];
+    unsigned resultNumber = opResult.getResultNumber();
+    for (auto &group : llvm::reverse(opDef.resultGroups)) {
+      if (resultNumber < group.first)
+        continue;
+      SmallVectorImpl<llvm::SMRange> &groupUses = group.second.uses;
+      for (llvm::SMLoc useLoc : locations)
+        groupUses.push_back(convertIdLocToRange(useLoc));
+      return;
+    }
+    llvm_unreachable("expected valid result group for value use");
+  }
+
+  // Block arguments: attach the uses to the argument's own definition.
+  BlockArgument blockArg = value.cast<BlockArgument>();
+  auto blockIt = impl->blocksToIdx.find(blockArg.getOwner());
+  assert(blockIt != impl->blocksToIdx.end() &&
+         "expected valid block definition for block argument");
+  SMDefinition &argDef =
+      impl->blocks[blockIt->second]->arguments[blockArg.getArgNumber()];
+  for (llvm::SMLoc useLoc : locations)
+    argDef.uses.emplace_back(convertIdLocToRange(useLoc));
+}
+
+void AsmParserState::addUses(Block *block, ArrayRef<llvm::SMLoc> locations) {
+  // A use may precede the definition, so create an entry for the block (with
+  // no definition location yet) if none exists.
+  auto insertResult =
+      impl->blocksToIdx.try_emplace(block, impl->blocks.size());
+  if (insertResult.second)
+    impl->blocks.emplace_back(std::make_unique<BlockDefinition>(block));
+
+  SMDefinition &blockDef =
+      impl->blocks[insertResult.first->second]->definition;
+  for (llvm::SMLoc useLoc : locations)
+    blockDef.uses.push_back(convertIdLocToRange(useLoc));
+}
+
+/// Transfer every use recorded against the placeholder `oldValue` to
+/// `newValue`, and drop the placeholder entry.
+void AsmParserState::refineDefinition(Value oldValue, Value newValue) {
+  auto it = impl->placeholderValueUses.find(oldValue);
+  assert(it != impl->placeholderValueUses.end() &&
+         "expected `oldValue` to be a placeholder");
+
+  // Take ownership of the uses and erase the entry *before* calling `addUses`.
+  // `addUses` may insert into `placeholderValueUses` (when the owner of
+  // `newValue` has not been recorded), and a rehash would invalidate storage
+  // still referenced by an ArrayRef into `it->second`. Erasing through the
+  // iterator also avoids a second hash lookup.
+  SmallVector<llvm::SMLoc> uses = std::move(it->second);
+  impl->placeholderValueUses.erase(it);
+  addUses(newValue, uses);
+}

diff  --git a/mlir/lib/Parser/CMakeLists.txt b/mlir/lib/Parser/CMakeLists.txt
index 4d68c5c839c9b..ad272a475dd61 100644
--- a/mlir/lib/Parser/CMakeLists.txt
+++ b/mlir/lib/Parser/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_library(MLIRParser
   AffineParser.cpp
+  AsmParserState.cpp
   AttributeParser.cpp
   DialectSymbolParser.cpp
   Lexer.cpp

diff  --git a/mlir/lib/Parser/DialectSymbolParser.cpp b/mlir/lib/Parser/DialectSymbolParser.cpp
index 5f8255fe9f3c1..e9d79c6913095 100644
--- a/mlir/lib/Parser/DialectSymbolParser.cpp
+++ b/mlir/lib/Parser/DialectSymbolParser.cpp
@@ -490,7 +490,7 @@ static T parseSymbol(StringRef inputStr, MLIRContext *context,
       inputStr, /*BufferName=*/"<mlir_parser_buffer>",
       /*RequiresNullTerminator=*/false);
   sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
-  ParserState state(sourceMgr, context, symbolState);
+  ParserState state(sourceMgr, context, symbolState, /*asmState=*/nullptr);
   Parser parser(state);
 
   Token startTok = parser.getToken();

diff  --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index 1587c0700216d..df3b01d682356 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -16,6 +16,7 @@
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Verifier.h"
 #include "mlir/Parser.h"
+#include "mlir/Parser/AsmParserState.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/ADT/bit.h"
@@ -213,8 +214,8 @@ class OperationParser : public Parser {
     auto &values = isolatedNameScopes.back().values;
     if (!values.count(name) || number >= values[name].size())
       return {};
-    if (values[name][number].first)
-      return values[name][number].second;
+    if (values[name][number].value)
+      return values[name][number].loc;
     return {};
   }
 
@@ -278,8 +279,7 @@ class OperationParser : public Parser {
   ParseResult parseBlockBody(Block *block);
 
   /// Parse a (possibly empty) list of block arguments.
-  ParseResult parseOptionalBlockArgList(SmallVectorImpl<BlockArgument> &results,
-                                        Block *owner);
+  ParseResult parseOptionalBlockArgList(Block *owner);
 
   /// Get the block with the specified name, creating it if it doesn't
   /// already exist.  The location specified is the point of use, which allows
@@ -291,8 +291,23 @@ class OperationParser : public Parser {
   Block *defineBlockNamed(StringRef name, SMLoc loc, Block *existing);
 
 private:
+  /// This class represents a definition of a Block.
+  struct BlockDefinition {
+    /// A pointer to the defined Block.
+    Block *block;
+    /// The location that the Block was defined at.
+    SMLoc loc;
+  };
+  /// This class represents a definition of a Value.
+  struct ValueDefinition {
+    /// A pointer to the defined Value.
+    Value value;
+    /// The location that the Value was defined at.
+    SMLoc loc;
+  };
+
   /// Returns the info for a block at the current scope for the given name.
-  std::pair<Block *, SMLoc> &getBlockInfoByName(StringRef name) {
+  BlockDefinition &getBlockInfoByName(StringRef name) {
     return blocksByName.back()[name];
   }
 
@@ -308,7 +323,7 @@ class OperationParser : public Parser {
   void recordDefinition(StringRef def);
 
   /// Get the value entry for the given SSA name.
-  SmallVectorImpl<std::pair<Value, SMLoc>> &getSSAValueEntry(StringRef name);
+  SmallVectorImpl<ValueDefinition> &getSSAValueEntry(StringRef name);
 
   /// Create a forward reference placeholder value with the given location and
   /// result type.
@@ -340,7 +355,7 @@ class OperationParser : public Parser {
 
     /// This keeps track of all of the SSA values we are tracking for each name
     /// scope, indexed by their name. This has one entry per result number.
-    llvm::StringMap<SmallVector<std::pair<Value, SMLoc>, 1>> values;
+    llvm::StringMap<SmallVector<ValueDefinition, 1>> values;
 
     /// This keeps track of all of the values defined by a specific name scope.
     SmallVector<llvm::StringSet<>, 2> definitionsPerScope;
@@ -352,7 +367,7 @@ class OperationParser : public Parser {
   /// This keeps track of the block names as well as the location of the first
   /// reference for each nested name scope. This is used to diagnose invalid
   /// block references and memorize them.
-  SmallVector<DenseMap<StringRef, std::pair<Block *, SMLoc>>, 2> blocksByName;
+  SmallVector<DenseMap<StringRef, BlockDefinition>, 2> blocksByName;
   SmallVector<DenseMap<Block *, SMLoc>, 2> forwardRef;
 
   /// These are all of the placeholders we've made along with the location of
@@ -408,7 +423,7 @@ ParseResult OperationParser::finalize() {
   }
 
   // Resolve the locations of any deferred operations.
-  auto &attributeAliases = getState().symbols.attributeAliasDefinitions;
+  auto &attributeAliases = state.symbols.attributeAliasDefinitions;
   for (std::pair<Operation *, Token> &it : opsWithDeferredLocs) {
     llvm::SMLoc tokLoc = it.second.getLoc();
     StringRef identifier = it.second.getSpelling().drop_front();
@@ -432,7 +447,7 @@ ParseResult OperationParser::finalize() {
 //===----------------------------------------------------------------------===//
 
 void OperationParser::pushSSANameScope(bool isIsolated) {
-  blocksByName.push_back(DenseMap<StringRef, std::pair<Block *, SMLoc>>());
+  blocksByName.push_back(DenseMap<StringRef, BlockDefinition>());
   forwardRef.push_back(DenseMap<Block *, SMLoc>());
 
   // Push back a new name definition scope.
@@ -484,11 +499,11 @@ ParseResult OperationParser::addDefinition(SSAUseInfo useInfo, Value value) {
 
   // If we already have an entry for this, check to see if it was a definition
   // or a forward reference.
-  if (auto existing = entries[useInfo.number].first) {
+  if (auto existing = entries[useInfo.number].value) {
     if (!isForwardRefPlaceholder(existing)) {
       return emitError(useInfo.loc)
           .append("redefinition of SSA value '", useInfo.name, "'")
-          .attachNote(getEncodedSourceLocation(entries[useInfo.number].second))
+          .attachNote(getEncodedSourceLocation(entries[useInfo.number].loc))
           .append("previously defined here");
     }
 
@@ -496,7 +511,7 @@ ParseResult OperationParser::addDefinition(SSAUseInfo useInfo, Value value) {
       return emitError(useInfo.loc)
           .append("definition of SSA value '", useInfo.name, "#",
                   useInfo.number, "' has type ", value.getType())
-          .attachNote(getEncodedSourceLocation(entries[useInfo.number].second))
+          .attachNote(getEncodedSourceLocation(entries[useInfo.number].loc))
           .append("previously used here with type ", existing.getType());
     }
 
@@ -506,6 +521,11 @@ ParseResult OperationParser::addDefinition(SSAUseInfo useInfo, Value value) {
     existing.replaceAllUsesWith(value);
     existing.getDefiningOp()->destroy();
     forwardRefPlaceholders.erase(existing);
+
+    // If a definition of the value already exists, replace it in the assembly
+    // state.
+    if (state.asmState)
+      state.asmState->refineDefinition(existing, value);
   }
 
   /// Record this definition for the current scope.
@@ -560,18 +580,26 @@ ParseResult OperationParser::parseSSAUse(SSAUseInfo &result) {
 Value OperationParser::resolveSSAUse(SSAUseInfo useInfo, Type type) {
   auto &entries = getSSAValueEntry(useInfo.name);
 
+  // Functor used to record the use of the given value if the assembly state
+  // field is populated.
+  auto maybeRecordUse = [&](Value value) {
+    if (state.asmState)
+      state.asmState->addUses(value, useInfo.loc);
+    return value;
+  };
+
   // If we have already seen a value of this name, return it.
-  if (useInfo.number < entries.size() && entries[useInfo.number].first) {
-    auto result = entries[useInfo.number].first;
+  if (useInfo.number < entries.size() && entries[useInfo.number].value) {
+    Value result = entries[useInfo.number].value;
     // Check that the type matches the other uses.
     if (result.getType() == type)
-      return result;
+      return maybeRecordUse(result);
 
     emitError(useInfo.loc, "use of value '")
         .append(useInfo.name,
                 "' expects different type than prior uses: ", type, " vs ",
                 result.getType())
-        .attachNote(getEncodedSourceLocation(entries[useInfo.number].second))
+        .attachNote(getEncodedSourceLocation(entries[useInfo.number].loc))
         .append("prior use here");
     return nullptr;
   }
@@ -582,16 +610,15 @@ Value OperationParser::resolveSSAUse(SSAUseInfo useInfo, Type type) {
 
   // If the value has already been defined and this is an overly large result
   // number, diagnose that.
-  if (entries[0].first && !isForwardRefPlaceholder(entries[0].first))
+  if (entries[0].value && !isForwardRefPlaceholder(entries[0].value))
     return (emitError(useInfo.loc, "reference to invalid result number"),
             nullptr);
 
   // Otherwise, this is a forward reference.  Create a placeholder and remember
   // that we did so.
   auto result = createForwardRefPlaceholder(useInfo.loc, type);
-  entries[useInfo.number].first = result;
-  entries[useInfo.number].second = useInfo.loc;
-  return result;
+  entries[useInfo.number] = {result, useInfo.loc};
+  return maybeRecordUse(result);
 }
 
 /// Parse an SSA use with an associated type.
@@ -653,8 +680,8 @@ void OperationParser::recordDefinition(StringRef def) {
 }
 
 /// Get the value entry for the given SSA name.
+/// The returned vector has one entry per result number, and is looked up in
+/// the innermost isolated name scope.
-SmallVectorImpl<std::pair<Value, SMLoc>> &
-OperationParser::getSSAValueEntry(StringRef name) {
+SmallVectorImpl<OperationParser::ValueDefinition> &
+OperationParser::getSSAValueEntry(StringRef name) {
   return isolatedNameScopes.back().values[name];
 }
 
@@ -732,9 +759,10 @@ ParseResult OperationParser::parseOperation() {
   }
 
   Operation *op;
-  if (getToken().is(Token::bare_identifier) || getToken().isKeyword())
+  Token nameTok = getToken();
+  if (nameTok.is(Token::bare_identifier) || nameTok.isKeyword())
     op = parseCustomOperation(resultIDs);
-  else if (getToken().is(Token::string))
+  else if (nameTok.is(Token::string))
     op = parseGenericOperation();
   else
     return emitError("expected operation name in quotes");
@@ -752,6 +780,18 @@ ParseResult OperationParser::parseOperation() {
              << op->getNumResults() << " results but was provided "
              << numExpectedResults << " to bind";
 
+    // Add this operation to the assembly state if it was provided to populate.
+    if (state.asmState) {
+      unsigned resultIt = 0;
+      SmallVector<std::pair<unsigned, llvm::SMLoc>> asmResultGroups;
+      asmResultGroups.reserve(resultIDs.size());
+      for (ResultRecord &record : resultIDs) {
+        asmResultGroups.emplace_back(resultIt, std::get<2>(record));
+        resultIt += std::get<1>(record);
+      }
+      state.asmState->addDefinition(op, nameTok.getLocRange(), asmResultGroups);
+    }
+
     // Add definitions for each of the result groups.
     unsigned opResI = 0;
     for (ResultRecord &resIt : resultIDs) {
@@ -761,6 +801,10 @@ ParseResult OperationParser::parseOperation() {
           return failure();
       }
     }
+
+    // Add this operation to the assembly state if it was provided to populate.
+  } else if (state.asmState) {
+    state.asmState->addDefinition(op, nameTok.getLocRange());
   }
 
   return success();
@@ -1772,8 +1816,7 @@ ParseResult OperationParser::parseTrailingOperationLocation(Operation *op) {
     }
 
     // If this alias can be resolved, do it now.
-    Attribute attr =
-        getState().symbols.attributeAliasDefinitions.lookup(identifier);
+    Attribute attr = state.symbols.attributeAliasDefinitions.lookup(identifier);
     if (attr) {
       if (!(directLoc = attr.dyn_cast<LocationAttr>()))
         return emitError(tok.getLoc())
@@ -1809,6 +1852,7 @@ ParseResult OperationParser::parseRegion(
     ArrayRef<std::pair<OperationParser::SSAUseInfo, Type>> entryArguments,
     bool isIsolatedNameScope) {
   // Parse the '{'.
+  Token lBraceTok = getToken();
   if (parseToken(Token::l_brace, "expected '{' to begin a region"))
     return failure();
 
@@ -1824,10 +1868,17 @@ ParseResult OperationParser::parseRegion(
   auto owning_block = std::make_unique<Block>();
   Block *block = owning_block.get();
 
+  // If this block is not defined in the source file, add a definition for it
+  // now in the assembly state. Blocks with a name will be defined when the name
+  // is parsed.
+  if (state.asmState && getToken().isNot(Token::caret_identifier))
+    state.asmState->addDefinition(block, lBraceTok.getLoc());
+
   // Add arguments to the entry block.
   if (!entryArguments.empty()) {
     for (auto &placeholderArgPair : entryArguments) {
       auto &argInfo = placeholderArgPair.first;
+
       // Ensure that the argument was not already defined.
       if (auto defLoc = getReferenceLoc(argInfo.name, argInfo.number)) {
         return emitError(argInfo.loc, "region entry argument '" + argInfo.name +
@@ -1835,10 +1886,15 @@ ParseResult OperationParser::parseRegion(
                    .attachNote(getEncodedSourceLocation(*defLoc))
                << "previously referenced here";
       }
-      if (addDefinition(placeholderArgPair.first,
-                        block->addArgument(placeholderArgPair.second))) {
+      BlockArgument arg = block->addArgument(placeholderArgPair.second);
+
+      // Add a definition of this arg to the assembly state if provided.
+      if (state.asmState)
+        state.asmState->addDefinition(arg, argInfo.loc);
+
+      // Record the definition for this argument.
+      if (addDefinition(argInfo, arg))
         return failure();
-      }
     }
 
     // If we had named arguments, then don't allow a block name.
@@ -1846,9 +1902,8 @@ ParseResult OperationParser::parseRegion(
       return emitError("invalid block name in region with named arguments");
   }
 
-  if (parseBlock(block)) {
+  if (parseBlock(block))
     return failure();
-  }
 
   // Verify that no other arguments were parsed.
   if (!entryArguments.empty() &&
@@ -1915,8 +1970,7 @@ ParseResult OperationParser::parseBlock(Block *&block) {
 
   // If an argument list is present, parse it.
   if (consumeIf(Token::l_paren)) {
-    SmallVector<BlockArgument, 8> bbArgs;
-    if (parseOptionalBlockArgList(bbArgs, block) ||
+    if (parseOptionalBlockArgList(block) ||
         parseToken(Token::r_paren, "expected ')' to end argument list"))
       return failure();
   }
@@ -1943,13 +1997,17 @@ ParseResult OperationParser::parseBlockBody(Block *block) {
 /// exist.  The location specified is the point of use, which allows
 /// us to diagnose references to blocks that are not defined precisely.
 Block *OperationParser::getBlockNamed(StringRef name, SMLoc loc) {
-  auto &blockAndLoc = getBlockInfoByName(name);
-  if (!blockAndLoc.first) {
-    blockAndLoc = {new Block(), loc};
-    insertForwardRef(blockAndLoc.first, loc);
+  // Look up (or default-create) the entry for this name in the current scope.
+  BlockDefinition &entry = getBlockInfoByName(name);
+  if (!entry.block) {
+    // First reference to this name: create the block now and track it as a
+    // forward reference, remembering where it was first referenced.
+    entry.block = new Block();
+    entry.loc = loc;
+    insertForwardRef(entry.block, loc);
   }
 
-  return blockAndLoc.first;
+  // Record this use in the high level assembly state, when one is requested.
+  Block *result = entry.block;
+  if (state.asmState)
+    state.asmState->addUses(result, loc);
+  return result;
 }
 
 /// Define the block with the specified name. Returns the Block* or nullptr in
@@ -1957,29 +2015,32 @@ Block *OperationParser::getBlockNamed(StringRef name, SMLoc loc) {
 Block *OperationParser::defineBlockNamed(StringRef name, SMLoc loc,
                                          Block *existing) {
   auto &blockAndLoc = getBlockInfoByName(name);
-  if (!blockAndLoc.first) {
-    // If the caller provided a block, use it.  Otherwise create a new one.
-    if (!existing)
-      existing = new Block();
-    blockAndLoc.first = existing;
-    blockAndLoc.second = loc;
-    return blockAndLoc.first;
-  }
-
-  // Forward declarations are removed once defined, so if we are defining a
-  // existing block and it is not a forward declaration, then it is a
-  // redeclaration.
-  if (!eraseForwardRef(blockAndLoc.first))
+  blockAndLoc.loc = loc;
+
+  // If a block has yet to be set, this is a new definition. If the caller
+  // provided a block, use it. Otherwise create a new one.
+  if (!blockAndLoc.block) {
+    blockAndLoc.block = existing ? existing : new Block();
+
+    // Otherwise, the block already exists. Forward declarations are removed
+    // once defined, so if we are defining an existing block and it is not a
+    // forward declaration, then it is a redeclaration.
+  } else if (!eraseForwardRef(blockAndLoc.block)) {
     return nullptr;
-  return blockAndLoc.first;
+  }
+
+  // Populate the high level assembly state if necessary.
+  if (state.asmState)
+    state.asmState->addDefinition(blockAndLoc.block, loc);
+
+  return blockAndLoc.block;
 }
 
 /// Parse a (possibly empty) list of SSA operands with types as block arguments.
 ///
 ///   ssa-id-and-type-list ::= ssa-id-and-type (`,` ssa-id-and-type)*
 ///
-ParseResult OperationParser::parseOptionalBlockArgList(
-    SmallVectorImpl<BlockArgument> &results, Block *owner) {
+ParseResult OperationParser::parseOptionalBlockArgList(Block *owner) {
   if (getToken().is(Token::r_brace))
     return success();
 
@@ -1991,18 +2052,28 @@ ParseResult OperationParser::parseOptionalBlockArgList(
   return parseCommaSeparatedList([&]() -> ParseResult {
     return parseSSADefOrUseAndType(
         [&](SSAUseInfo useInfo, Type type) -> ParseResult {
-          // If this block did not have existing arguments, define a new one.
-          if (!definingExistingArgs)
-            return addDefinition(useInfo, owner->addArgument(type));
-
-          // Otherwise, ensure that this argument has already been created.
-          if (nextArgument >= owner->getNumArguments())
-            return emitError("too many arguments specified in argument list");
-
-          // Finally, make sure the existing argument has the correct type.
-          auto arg = owner->getArgument(nextArgument++);
-          if (arg.getType() != type)
-            return emitError("argument and block argument type mismatch");
+          BlockArgument arg;
+
+          // If we are defining existing arguments, ensure that the argument
+          // has already been created with the right type.
+          if (definingExistingArgs) {
+            // Ensure that this argument has already been created.
+            if (nextArgument >= owner->getNumArguments())
+              return emitError("too many arguments specified in argument list");
+
+            // Finally, make sure the existing argument has the correct type.
+            arg = owner->getArgument(nextArgument++);
+            if (arg.getType() != type)
+              return emitError("argument and block argument type mismatch");
+          } else {
+            arg = owner->addArgument(type);
+          }
+
+          // Record this block argument definition in the assembly state, if one
+          // was provided.
+          if (state.asmState)
+            state.asmState->addDefinition(arg, useInfo.loc);
+
           return addDefinition(useInfo, arg);
         });
   });
@@ -2040,7 +2111,7 @@ ParseResult TopLevelOperationParser::parseAttributeAliasDef() {
   StringRef aliasName = getTokenSpelling().drop_front();
 
   // Check for redefinitions.
-  if (getState().symbols.attributeAliasDefinitions.count(aliasName) > 0)
+  if (state.symbols.attributeAliasDefinitions.count(aliasName) > 0)
     return emitError("redefinition of attribute alias id '" + aliasName + "'");
 
   // Make sure this isn't invading the dialect attribute namespace.
@@ -2059,7 +2130,7 @@ ParseResult TopLevelOperationParser::parseAttributeAliasDef() {
   if (!attr)
     return failure();
 
-  getState().symbols.attributeAliasDefinitions[aliasName] = attr;
+  state.symbols.attributeAliasDefinitions[aliasName] = attr;
   return success();
 }
 
@@ -2072,7 +2143,7 @@ ParseResult TopLevelOperationParser::parseTypeAliasDef() {
   StringRef aliasName = getTokenSpelling().drop_front();
 
   // Check for redefinitions.
-  if (getState().symbols.typeAliasDefinitions.count(aliasName) > 0)
+  if (state.symbols.typeAliasDefinitions.count(aliasName) > 0)
     return emitError("redefinition of type alias id '" + aliasName + "'");
 
   // Make sure this isn't invading the dialect type namespace.
@@ -2093,7 +2164,7 @@ ParseResult TopLevelOperationParser::parseTypeAliasDef() {
     return failure();
 
   // Register this alias with the parser state.
-  getState().symbols.typeAliasDefinitions.try_emplace(aliasName, aliasedType);
+  state.symbols.typeAliasDefinitions.try_emplace(aliasName, aliasedType);
   return success();
 }
 
@@ -2101,7 +2172,7 @@ ParseResult TopLevelOperationParser::parse(Block *topLevelBlock,
                                            Location parserLoc) {
   // Create a top-level operation to contain the parsed state.
   OwningOpRef<Operation *> topLevelOp(ModuleOp::create(parserLoc));
-  OperationParser opParser(getState(), topLevelOp.get());
+  OperationParser opParser(state, topLevelOp.get());
   while (true) {
     switch (getToken().getKind()) {
     default:
@@ -2153,7 +2224,8 @@ ParseResult TopLevelOperationParser::parse(Block *topLevelBlock,
 
 LogicalResult mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr,
                                     Block *block, MLIRContext *context,
-                                    LocationAttr *sourceFileLoc) {
+                                    LocationAttr *sourceFileLoc,
+                                    AsmParserState *asmState) {
   const auto *sourceBuf = sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID());
 
   Location parserLoc = FileLineColLoc::get(
@@ -2162,7 +2234,7 @@ LogicalResult mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr,
     *sourceFileLoc = parserLoc;
 
   SymbolState aliasState;
-  ParserState state(sourceMgr, context, aliasState);
+  ParserState state(sourceMgr, context, aliasState, asmState);
   return TopLevelOperationParser(state).parse(block, parserLoc);
 }
 
@@ -2176,7 +2248,8 @@ LogicalResult mlir::parseSourceFile(llvm::StringRef filename, Block *block,
 LogicalResult mlir::parseSourceFile(llvm::StringRef filename,
                                     llvm::SourceMgr &sourceMgr, Block *block,
                                     MLIRContext *context,
-                                    LocationAttr *sourceFileLoc) {
+                                    LocationAttr *sourceFileLoc,
+                                    AsmParserState *asmState) {
   if (sourceMgr.getNumBuffers() != 0) {
     // TODO: Extend to support multiple buffers.
     return emitError(mlir::UnknownLoc::get(context),
@@ -2189,7 +2262,7 @@ LogicalResult mlir::parseSourceFile(llvm::StringRef filename,
 
   // Load the MLIR source file.
   sourceMgr.AddNewSourceBuffer(std::move(*file_or_err), llvm::SMLoc());
-  return parseSourceFile(sourceMgr, block, context, sourceFileLoc);
+  return parseSourceFile(sourceMgr, block, context, sourceFileLoc, asmState);
 }
 
 LogicalResult mlir::parseSourceString(llvm::StringRef sourceStr, Block *block,

diff  --git a/mlir/lib/Parser/Parser.h b/mlir/lib/Parser/Parser.h
index 8d1910e73843b..93ea18746c841 100644
--- a/mlir/lib/Parser/Parser.h
+++ b/mlir/lib/Parser/Parser.h
@@ -268,7 +268,7 @@ class Parser {
                          function_ref<ParseResult(bool)> parseElement,
                          OpAsmParser::Delimiter delimiter);
 
-private:
+protected:
   /// The Parser is subclassed and reinstantiated.  Do not add additional
   /// non-trivial state here, add it to the ParserState class.
   ParserState &state;

diff  --git a/mlir/lib/Parser/ParserState.h b/mlir/lib/Parser/ParserState.h
index 27048d5d2f034..13d506d27e302 100644
--- a/mlir/lib/Parser/ParserState.h
+++ b/mlir/lib/Parser/ParserState.h
@@ -48,9 +48,10 @@ struct SymbolState {
 /// such as the current lexer position etc.
 struct ParserState {
   ParserState(const llvm::SourceMgr &sourceMgr, MLIRContext *ctx,
-              SymbolState &symbols)
+              SymbolState &symbols, AsmParserState *asmState)
       : context(ctx), lex(sourceMgr, ctx), curToken(lex.lexToken()),
-        symbols(symbols), parserDepth(symbols.nestedParserLocs.size()) {
+        symbols(symbols), parserDepth(symbols.nestedParserLocs.size()),
+        asmState(asmState) {
     // Set the top level lexer for the symbol state if one doesn't exist.
     if (!symbols.topLevelLexer)
       symbols.topLevelLexer = &lex;
@@ -77,6 +78,10 @@ struct ParserState {
 
   /// The depth of this parser in the nested parsing stack.
   size_t parserDepth;
+
+  /// An optional pointer to a struct containing high level parser state to be
+  /// populated during parsing.
+  AsmParserState *asmState;
 };
 
 } // end namespace detail


        


More information about the Mlir-commits mailing list