[Mlir-commits] [mlir] d6af89b - [mlir-lsp-server] Add support for tracking the use/def chains of symbols

River Riddle llvmlistbot at llvm.org
Thu Jun 3 16:12:43 PDT 2021


Author: River Riddle
Date: 2021-06-03T16:12:27-07:00
New Revision: d6af89beb26df549d5b9e9041dac3b205b44e512

URL: https://github.com/llvm/llvm-project/commit/d6af89beb26df549d5b9e9041dac3b205b44e512
DIFF: https://github.com/llvm/llvm-project/commit/d6af89beb26df549d5b9e9041dac3b205b44e512.diff

LOG: [mlir-lsp-server] Add support for tracking the use/def chains of symbols

This revision adds assembly state tracking for uses of symbols, enabling go-to-definition and find-references support for SymbolRefAttrs.

Differential Revision: https://reviews.llvm.org/D103585

Added: 
    

Modified: 
    mlir/include/mlir/Parser/AsmParserState.h
    mlir/lib/Parser/AsmParserState.cpp
    mlir/lib/Parser/AttributeParser.cpp
    mlir/lib/Parser/Parser.cpp
    mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
    mlir/test/mlir-lsp-server/definition-split-file.test
    mlir/test/mlir-lsp-server/definition.test
    mlir/test/mlir-lsp-server/hover.test
    mlir/test/mlir-lsp-server/references.test

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Parser/AsmParserState.h b/mlir/include/mlir/Parser/AsmParserState.h
index 86625e6afe88..a519731d97dd 100644
--- a/mlir/include/mlir/Parser/AsmParserState.h
+++ b/mlir/include/mlir/Parser/AsmParserState.h
@@ -20,6 +20,8 @@ class Block;
 class BlockArgument;
 class FileLineColLoc;
 class Operation;
+class OperationName;
+class SymbolRefAttr;
 class Value;
 
 /// This class represents state from a parsed MLIR textual format string. It is
@@ -61,6 +63,10 @@ class AsmParserState {
 
     /// Source definitions for any result groups of this operation.
     SmallVector<std::pair<unsigned, SMDefinition>> resultGroups;
+
+    /// If this operation is a symbol operation, this vector contains symbol
+    /// uses of this operation.
+    SmallVector<llvm::SMRange> symbolUses;
   };
 
   /// This class represents the information for a block definition within the
@@ -112,10 +118,28 @@ class AsmParserState {
   // Populate State
   //===--------------------------------------------------------------------===//
 
-  /// Add a definition of the given operation.
-  void addDefinition(
-      Operation *op, llvm::SMRange location,
+  /// Initialize the state in preparation for populating more parser state under
+  /// the given top-level operation.
+  void initialize(Operation *topLevelOp);
+
+  /// Finalize any in-progress parser state under the given top-level operation.
+  void finalize(Operation *topLevelOp);
+
+  /// Start a definition for an operation with the given name.
+  void startOperationDefinition(const OperationName &opName);
+
+  /// Finalize the most recently started operation definition.
+  void finalizeOperationDefinition(
+      Operation *op, llvm::SMRange nameLoc,
       ArrayRef<std::pair<unsigned, llvm::SMLoc>> resultGroups = llvm::None);
+
+  /// Start a definition for a region nested under the current operation.
+  void startRegionDefinition();
+
+  /// Finalize the most recently started region definition.
+  void finalizeRegionDefinition();
+
+  /// Add a definition of the given entity.
   void addDefinition(Block *block, llvm::SMLoc location);
   void addDefinition(BlockArgument blockArg, llvm::SMLoc location);
 
@@ -123,6 +147,12 @@ class AsmParserState {
   void addUses(Value value, ArrayRef<llvm::SMLoc> locations);
   void addUses(Block *block, ArrayRef<llvm::SMLoc> locations);
 
+  /// Add source uses for all the references nested under `refAttr`. The
+  /// provided `locations` should match 1-1 with the number of references in
+  /// `refAttr`, i.e.:
+  ///   nestedReferences.size() + /*leafReference=*/1 == refLocations.size()
+  void addUses(SymbolRefAttr refAttr, ArrayRef<llvm::SMRange> refLocations);
+
   /// Refine the `oldValue` to the `newValue`. This is used to indicate that
   /// `oldValue` was a placeholder, and the uses of it should really refer to
   /// `newValue`.

diff  --git a/mlir/lib/Parser/AsmParserState.cpp b/mlir/lib/Parser/AsmParserState.cpp
index 71b13c945205..d85c068d493f 100644
--- a/mlir/lib/Parser/AsmParserState.cpp
+++ b/mlir/lib/Parser/AsmParserState.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Parser/AsmParserState.h"
 #include "mlir/IR/Operation.h"
+#include "mlir/IR/SymbolTable.h"
 
 using namespace mlir;
 
@@ -16,6 +17,27 @@ using namespace mlir;
 //===----------------------------------------------------------------------===//
 
 struct AsmParserState::Impl {
+  /// A map from a SymbolRefAttr to a range of uses.
+  using SymbolUseMap = DenseMap<Attribute, SmallVector<llvm::SMRange>>;
+
+  struct PartialOpDef {
+    explicit PartialOpDef(const OperationName &opName) {
+      const auto *abstractOp = opName.getAbstractOperation();
+      if (abstractOp && abstractOp->hasTrait<OpTrait::SymbolTable>())
+        symbolTable = std::make_unique<SymbolUseMap>();
+    }
+
+    /// Return if this operation is a symbol table.
+    bool isSymbolTable() const { return symbolTable.get(); }
+
+    /// If this operation is a symbol table, the following contains symbol uses
+    /// within this operation.
+    std::unique_ptr<SymbolUseMap> symbolTable;
+  };
+
+  /// Resolve any symbol table uses under the given partial operation.
+  void resolveSymbolUses(Operation *op, PartialOpDef &opDef);
+
   /// A mapping from operations in the input source file to their parser state.
   SmallVector<std::unique_ptr<OperationDefinition>> operations;
   DenseMap<Operation *, unsigned> operationToIdx;
@@ -27,8 +49,38 @@ struct AsmParserState::Impl {
   /// A set of value definitions that are placeholders for forward references.
   /// This map should be empty if the parser finishes successfully.
   DenseMap<Value, SmallVector<llvm::SMLoc>> placeholderValueUses;
+
+  /// A stack of partial operation definitions that have been started but not
+  /// yet finalized.
+  SmallVector<PartialOpDef> partialOperations;
+
+  /// A stack of symbol use scopes. This is used when collecting symbol table
+  /// uses during parsing.
+  SmallVector<SymbolUseMap *> symbolUseScopes;
+
+  /// A symbol table containing all of the symbol table operations in the IR.
+  SymbolTableCollection symbolTable;
 };
 
+void AsmParserState::Impl::resolveSymbolUses(Operation *op,
+                                             PartialOpDef &opDef) {
+  assert(opDef.isSymbolTable() && "expected op to be a symbol table");
+
+  SmallVector<Operation *> symbolOps;
+  for (auto &it : *opDef.symbolTable) {
+    symbolOps.clear();
+    if (failed(symbolTable.lookupSymbolIn(op, it.first.cast<SymbolRefAttr>(),
+                                          symbolOps)))
+      continue;
+
+    for (const auto &symIt : llvm::zip(symbolOps, it.second)) {
+      auto opIt = operationToIdx.find(std::get<0>(symIt));
+      if (opIt != operationToIdx.end())
+        operations[opIt->second]->symbolUses.push_back(std::get<1>(symIt));
+    }
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // AsmParserState
 //===----------------------------------------------------------------------===//
@@ -77,17 +129,70 @@ llvm::SMRange AsmParserState::convertIdLocToRange(llvm::SMLoc loc) {
 //===----------------------------------------------------------------------===//
 // Populate State
 
-void AsmParserState::addDefinition(
-    Operation *op, llvm::SMRange location,
+void AsmParserState::initialize(Operation *topLevelOp) {
+  startOperationDefinition(topLevelOp->getName());
+
+  // If the top-level operation is a symbol table, push a new symbol scope.
+  Impl::PartialOpDef &partialOpDef = impl->partialOperations.back();
+  if (partialOpDef.isSymbolTable())
+    impl->symbolUseScopes.push_back(partialOpDef.symbolTable.get());
+}
+
+void AsmParserState::finalize(Operation *topLevelOp) {
+  assert(!impl->partialOperations.empty() &&
+         "expected valid partial operation definition");
+  Impl::PartialOpDef partialOpDef = impl->partialOperations.pop_back_val();
+
+  // If this operation is a symbol table, resolve any symbol uses.
+  if (partialOpDef.isSymbolTable())
+    impl->resolveSymbolUses(topLevelOp, partialOpDef);
+}
+
+void AsmParserState::startOperationDefinition(const OperationName &opName) {
+  impl->partialOperations.emplace_back(opName);
+}
+
+void AsmParserState::finalizeOperationDefinition(
+    Operation *op, llvm::SMRange nameLoc,
     ArrayRef<std::pair<unsigned, llvm::SMLoc>> resultGroups) {
+  assert(!impl->partialOperations.empty() &&
+         "expected valid partial operation definition");
+  Impl::PartialOpDef partialOpDef = impl->partialOperations.pop_back_val();
+
+  // Build the full operation definition.
   std::unique_ptr<OperationDefinition> def =
-      std::make_unique<OperationDefinition>(op, location);
+      std::make_unique<OperationDefinition>(op, nameLoc);
   for (auto &resultGroup : resultGroups)
     def->resultGroups.emplace_back(resultGroup.first,
                                    convertIdLocToRange(resultGroup.second));
-
   impl->operationToIdx.try_emplace(op, impl->operations.size());
   impl->operations.emplace_back(std::move(def));
+
+  // If this operation is a symbol table, resolve any symbol uses.
+  if (partialOpDef.isSymbolTable())
+    impl->resolveSymbolUses(op, partialOpDef);
+}
+
+void AsmParserState::startRegionDefinition() {
+  assert(!impl->partialOperations.empty() &&
+         "expected valid partial operation definition");
+
+  // If the parent operation of this region is a symbol table, we also push a
+  // new symbol scope.
+  Impl::PartialOpDef &partialOpDef = impl->partialOperations.back();
+  if (partialOpDef.isSymbolTable())
+    impl->symbolUseScopes.push_back(partialOpDef.symbolTable.get());
+}
+
+void AsmParserState::finalizeRegionDefinition() {
+  assert(!impl->partialOperations.empty() &&
+         "expected valid partial operation definition");
+
+  // If the parent operation of this region is a symbol table, pop the symbol
+  // scope for this region.
+  Impl::PartialOpDef &partialOpDef = impl->partialOperations.back();
+  if (partialOpDef.isSymbolTable())
+    impl->symbolUseScopes.pop_back();
 }
 
 void AsmParserState::addDefinition(Block *block, llvm::SMLoc location) {
@@ -169,6 +274,18 @@ void AsmParserState::addUses(Block *block, ArrayRef<llvm::SMLoc> locations) {
     def.definition.uses.push_back(convertIdLocToRange(loc));
 }
 
+void AsmParserState::addUses(SymbolRefAttr refAttr,
+                             ArrayRef<llvm::SMRange> locations) {
+  // Ignore this symbol if no scopes are active.
+  if (impl->symbolUseScopes.empty())
+    return;
+
+  assert((refAttr.getNestedReferences().size() + 1) == locations.size() &&
+         "expected the same number of references as provided locations");
+  (*impl->symbolUseScopes.back())[refAttr].append(locations.begin(),
+                                                  locations.end());
+}
+
 void AsmParserState::refineDefinition(Value oldValue, Value newValue) {
   auto it = impl->placeholderValueUses.find(oldValue);
   assert(it != impl->placeholderValueUses.end() &&

diff  --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp
index a90c65ff1fb3..38f86155d4b7 100644
--- a/mlir/lib/Parser/AttributeParser.cpp
+++ b/mlir/lib/Parser/AttributeParser.cpp
@@ -15,6 +15,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/IntegerSet.h"
+#include "mlir/Parser/AsmParserState.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/Endian.h"
 
@@ -153,6 +154,13 @@ Attribute Parser::parseAttribute(Type type) {
 
   // Parse a symbol reference attribute.
   case Token::at_identifier: {
+    // When populating the parser state, this is a list of locations for all of
+    // the nested references.
+    SmallVector<llvm::SMRange> referenceLocations;
+    if (state.asmState)
+      referenceLocations.push_back(getToken().getLocRange());
+
+    // Parse the top-level reference.
     std::string nameStr = getToken().getSymbolReference();
     consumeToken(Token::at_identifier);
 
@@ -174,12 +182,21 @@ Attribute Parser::parseAttribute(Type type) {
         return Attribute();
       }
 
+      // If we are populating the assembly state, add the location for this
+      // reference.
+      if (state.asmState)
+        referenceLocations.push_back(getToken().getLocRange());
+
       std::string nameStr = getToken().getSymbolReference();
       consumeToken(Token::at_identifier);
       nestedRefs.push_back(SymbolRefAttr::get(getContext(), nameStr));
     }
+    SymbolRefAttr symbolRefAttr = builder.getSymbolRefAttr(nameStr, nestedRefs);
 
-    return builder.getSymbolRefAttr(nameStr, nestedRefs);
+    // If we are populating the assembly state, record this symbol reference.
+    if (state.asmState)
+      state.asmState->addUses(symbolRefAttr, referenceLocations);
+    return symbolRefAttr;
   }
 
   // Parse a 'unit' attribute.

diff  --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index aa0d12f7568f..4b80158f04a2 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -166,18 +166,12 @@ namespace {
 /// operations.
 class OperationParser : public Parser {
 public:
-  OperationParser(ParserState &state, Operation *topLevelOp)
-      : Parser(state), opBuilder(topLevelOp->getRegion(0)),
-        topLevelOp(topLevelOp) {
-    // The top level operation starts a new name scope.
-    pushSSANameScope(/*isIsolated=*/true);
-  }
-
+  OperationParser(ParserState &state, Operation *topLevelOp);
   ~OperationParser();
 
   /// After parsing is finished, this function must be called to see if there
   /// are any remaining issues.
-  ParseResult finalize();
+  ParseResult finalize(Operation *topLevelOp);
 
   //===--------------------------------------------------------------------===//
   // SSA Value Handling
@@ -281,7 +275,10 @@ class OperationParser : public Parser {
                           bool isIsolatedNameScope = false);
 
   /// Parse a region body into 'region'.
-  ParseResult parseRegionBody(Region &region);
+  ParseResult
+  parseRegionBody(Region &region, llvm::SMLoc startLoc,
+                  ArrayRef<std::pair<SSAUseInfo, Type>> entryArguments,
+                  bool isIsolatedNameScope);
 
   //===--------------------------------------------------------------------===//
   // Block Parsing
@@ -402,6 +399,17 @@ class OperationParser : public Parser {
 };
 } // end anonymous namespace
 
+OperationParser::OperationParser(ParserState &state, Operation *topLevelOp)
+    : Parser(state), opBuilder(topLevelOp->getRegion(0)),
+      topLevelOp(topLevelOp) {
+  // The top level operation starts a new name scope.
+  pushSSANameScope(/*isIsolated=*/true);
+
+  // If we are populating the parser state, prepare it for parsing.
+  if (state.asmState)
+    state.asmState->initialize(topLevelOp);
+}
+
 OperationParser::~OperationParser() {
   for (auto &fwd : forwardRefPlaceholders) {
     // Drop all uses of undefined forward declared reference and destroy
@@ -421,7 +429,7 @@ OperationParser::~OperationParser() {
 
 /// After parsing is finished, this function must be called to see if there are
 /// any remaining issues.
-ParseResult OperationParser::finalize() {
+ParseResult OperationParser::finalize(Operation *topLevelOp) {
   // Check for any forward references that are left.  If we find any, error
   // out.
   if (!forwardRefPlaceholders.empty()) {
@@ -458,6 +466,10 @@ ParseResult OperationParser::finalize() {
       opOrArgument.get<BlockArgument>().setLoc(locAttr);
   }
 
+  // If we are populating the parser state, finalize the top-level operation.
+  if (state.asmState)
+    state.asmState->finalize(topLevelOp);
+
   // Pop the top level name scope.
   return popSSANameScope();
 }
@@ -809,7 +821,8 @@ ParseResult OperationParser::parseOperation() {
         asmResultGroups.emplace_back(resultIt, std::get<2>(record));
         resultIt += std::get<1>(record);
       }
-      state.asmState->addDefinition(op, nameTok.getLocRange(), asmResultGroups);
+      state.asmState->finalizeOperationDefinition(op, nameTok.getLocRange(),
+                                                  asmResultGroups);
     }
 
     // Add definitions for each of the result groups.
@@ -824,7 +837,7 @@ ParseResult OperationParser::parseOperation() {
 
     // Add this operation to the assembly state if it was provided to populate.
   } else if (state.asmState) {
-    state.asmState->addDefinition(op, nameTok.getLocRange());
+    state.asmState->finalizeOperationDefinition(op, nameTok.getLocRange());
   }
 
   return success();
@@ -903,6 +916,10 @@ Operation *OperationParser::parseGenericOperation() {
     }
   }
 
+  // If we are populating the parser state, start a new operation definition.
+  if (state.asmState)
+    state.asmState->startOperationDefinition(result.name);
+
   // Parse the operand list.
   SmallVector<SSAUseInfo, 8> operandInfos;
   if (parseToken(Token::l_paren, "expected '(' to start operand list") ||
@@ -981,9 +998,19 @@ Operation *OperationParser::parseGenericOperation() {
 
 Operation *OperationParser::parseGenericOperation(Block *insertBlock,
                                                   Block::iterator insertPt) {
+  Token nameToken = getToken();
+
   OpBuilder::InsertionGuard restoreInsertionPoint(opBuilder);
   opBuilder.setInsertionPoint(insertBlock, insertPt);
-  return parseGenericOperation();
+  Operation *op = parseGenericOperation();
+  if (!op)
+    return nullptr;
+
+  // If we are populating the parser asm state, finalize this operation
+  // definition.
+  if (state.asmState)
+    state.asmState->finalizeOperationDefinition(op, nameToken.getLocRange());
+  return op;
 }
 
 namespace {
@@ -1367,6 +1394,14 @@ class CustomOpAsmParser : public OpAsmParser {
     result = getBuilder().getStringAttr(atToken.getSymbolReference());
     attrs.push_back(getBuilder().getNamedAttr(attrName, result));
     parser.consumeToken();
+
+    // If we are populating the assembly parser state, record this as a symbol
+    // reference.
+    if (parser.getState().asmState) {
+      parser.getState().asmState->addUses(
+          getBuilder().getSymbolRefAttr(result.getValue()),
+          atToken.getLocRange());
+    }
     return success();
   }
 
@@ -1858,9 +1893,13 @@ OperationParser::parseCustomOperation(ArrayRef<ResultRecord> resultIDs) {
 
   // Get location information for the operation.
   auto srcLocation = getEncodedSourceLocation(opLoc);
+  OperationState opState(srcLocation, opName);
+
+  // If we are populating the parser state, start a new operation definition.
+  if (state.asmState)
+    state.asmState->startOperationDefinition(opState.name);
 
   // Have the op implementation take a crack and parsing this.
-  OperationState opState(srcLocation, opName);
   CleanupOpStateRegions guard{opState};
   CustomOpAsmParser opAsmParser(opLoc, resultIDs, parseAssemblyFn,
                                 isIsolatedFromAbove, opName, *this);
@@ -1931,10 +1970,6 @@ OperationParser::parseTrailingLocationSpecifier(OpOrArgument opOrArgument) {
 // Region Parsing
 //===----------------------------------------------------------------------===//
 
-/// Region.
-///
-///   region ::= '{' region-body
-///
 ParseResult OperationParser::parseRegion(
     Region &region,
     ArrayRef<std::pair<OperationParser::SSAUseInfo, Type>> entryArguments,
@@ -1944,9 +1979,29 @@ ParseResult OperationParser::parseRegion(
   if (parseToken(Token::l_brace, "expected '{' to begin a region"))
     return failure();
 
-  // Check for an empty region.
-  if (entryArguments.empty() && consumeIf(Token::r_brace))
-    return success();
+  // If we are populating the parser state, start a new region definition.
+  if (state.asmState)
+    state.asmState->startRegionDefinition();
+
+  // Parse the region body.
+  if ((!entryArguments.empty() || getToken().isNot(Token::r_brace)) &&
+      parseRegionBody(region, lBraceTok.getLoc(), entryArguments,
+                      isIsolatedNameScope)) {
+    return failure();
+  }
+  consumeToken(Token::r_brace);
+
+  // If we are populating the parser state, finalize this region.
+  if (state.asmState)
+    state.asmState->finalizeRegionDefinition();
+
+  return success();
+}
+
+ParseResult OperationParser::parseRegionBody(
+    Region &region, llvm::SMLoc startLoc,
+    ArrayRef<std::pair<OperationParser::SSAUseInfo, Type>> entryArguments,
+    bool isIsolatedNameScope) {
   auto currentPt = opBuilder.saveInsertionPoint();
 
   // Push a new named value scope.
@@ -1960,7 +2015,7 @@ ParseResult OperationParser::parseRegion(
   // now in the assembly state. Blocks with a name will be defined when the name
   // is parsed.
   if (state.asmState && getToken().isNot(Token::caret_identifier))
-    state.asmState->addDefinition(block, lBraceTok.getLoc());
+    state.asmState->addDefinition(block, startLoc);
 
   // Add arguments to the entry block.
   if (!entryArguments.empty()) {
@@ -2002,8 +2057,12 @@ ParseResult OperationParser::parseRegion(
 
   // Parse the rest of the region.
   region.push_back(owning_block.release());
-  if (parseRegionBody(region))
-    return failure();
+  while (getToken().isNot(Token::r_brace)) {
+    Block *newBlock = nullptr;
+    if (parseBlock(newBlock))
+      return failure();
+    region.push_back(newBlock);
+  }
 
   // Pop the SSA value scope for this region.
   if (popSSANameScope())
@@ -2014,21 +2073,6 @@ ParseResult OperationParser::parseRegion(
   return success();
 }
 
-/// Region.
-///
-///   region-body ::= block* '}'
-///
-ParseResult OperationParser::parseRegionBody(Region &region) {
-  // Parse the list of blocks.
-  while (!consumeIf(Token::r_brace)) {
-    Block *newBlock = nullptr;
-    if (parseBlock(newBlock))
-      return failure();
-    region.push_back(newBlock);
-  }
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // Block Parsing
 //===----------------------------------------------------------------------===//
@@ -2278,7 +2322,7 @@ ParseResult TopLevelOperationParser::parse(Block *topLevelBlock,
 
     // If we got to the end of the file, then we're done.
     case Token::eof: {
-      if (opParser.finalize())
+      if (opParser.finalize(topLevelOp.get()))
         return failure();
 
       // Verify that the parsed operations are valid.

diff  --git a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
index 3b1dfcc6195e..d4b61372552d 100644
--- a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
+++ b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
@@ -21,7 +21,7 @@ static lsp::Position getPosFromLoc(llvm::SourceMgr &mgr, llvm::SMLoc loc) {
   std::pair<unsigned, unsigned> lineAndCol = mgr.getLineAndColumn(loc);
   lsp::Position pos;
   pos.line = lineAndCol.first - 1;
-  pos.character = lineAndCol.second;
+  pos.character = lineAndCol.second - 1;
   return pos;
 }
 
@@ -33,10 +33,7 @@ static llvm::SMLoc getPosFromLoc(llvm::SourceMgr &mgr, lsp::Position pos) {
 
 /// Returns a language server range for the given source range.
 static lsp::Range getRangeFromLoc(llvm::SourceMgr &mgr, llvm::SMRange range) {
-  // lsp::Range is an inclusive range, SMRange is half-open.
-  llvm::SMLoc inclusiveEnd =
-      llvm::SMLoc::getFromPointer(range.End.getPointer() - 1);
-  return {getPosFromLoc(mgr, range.Start), getPosFromLoc(mgr, inclusiveEnd)};
+  return {getPosFromLoc(mgr, range.Start), getPosFromLoc(mgr, range.End)};
 }
 
 /// Returns a language server location from the given source range.
@@ -365,6 +362,12 @@ void MLIRDocument::getLocationsOf(const lsp::URIForFile &uri,
     for (const auto &result : op.resultGroups)
       if (containsPosition(result.second))
         return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
+    for (const auto &symUse : op.symbolUses) {
+      if (contains(symUse, posLoc)) {
+        locations.push_back(getLocationFromLoc(sourceMgr, op.loc, uri));
+        return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
+      }
+    }
   }
 
   // Check all definitions related to blocks.
@@ -395,11 +398,21 @@ void MLIRDocument::findReferencesOf(const lsp::URIForFile &uri,
     if (contains(op.loc, posLoc)) {
       for (const auto &result : op.resultGroups)
         appendSMDef(result.second);
+      for (const auto &symUse : op.symbolUses)
+        if (contains(symUse, posLoc))
+          references.push_back(getLocationFromLoc(sourceMgr, symUse, uri));
       return;
     }
     for (const auto &result : op.resultGroups)
       if (isDefOrUse(result.second, posLoc))
         return appendSMDef(result.second);
+    for (const auto &symUse : op.symbolUses) {
+      if (!contains(symUse, posLoc))
+        continue;
+      for (const auto &symUse : op.symbolUses)
+        references.push_back(getLocationFromLoc(sourceMgr, symUse, uri));
+      return;
+    }
   }
 
   // Check all definitions related to blocks.

diff  --git a/mlir/test/mlir-lsp-server/definition-split-file.test b/mlir/test/mlir-lsp-server/definition-split-file.test
index c32f8be396a7..14bd2e2da3f4 100644
--- a/mlir/test/mlir-lsp-server/definition-split-file.test
+++ b/mlir/test/mlir-lsp-server/definition-split-file.test
@@ -25,7 +25,7 @@
 // CHECK-NEXT:          "line": 3
 // CHECK-NEXT:        },
 // CHECK-NEXT:        "start": {
-// CHECK-NEXT:          "character": 1,
+// CHECK-NEXT:          "character": 0,
 // CHECK-NEXT:          "line": 3
 // CHECK-NEXT:        }
 // CHECK-NEXT:      },

diff  --git a/mlir/test/mlir-lsp-server/definition.test b/mlir/test/mlir-lsp-server/definition.test
index ddcebcfa28f3..5911f9c2f0ca 100644
--- a/mlir/test/mlir-lsp-server/definition.test
+++ b/mlir/test/mlir-lsp-server/definition.test
@@ -22,13 +22,34 @@
 // CHECK-NEXT:          "line": 1
 // CHECK-NEXT:        },
 // CHECK-NEXT:        "start": {
-// CHECK-NEXT:          "character": 1,
+// CHECK-NEXT:          "character": 0,
 // CHECK-NEXT:          "line": 1
 // CHECK-NEXT:        }
 // CHECK-NEXT:      },
 // CHECK-NEXT:      "uri": "{{.*}}/foo.mlir"
 // CHECK-NEXT:    }
 // -----
+{"jsonrpc":"2.0","id":2,"method":"textDocument/definition","params":{
+  "textDocument":{"uri":"test:///foo.mlir"},
+  "position":{"line":0,"character":7}
+}}
+//      CHECK:  "id": 2
+// CHECK-NEXT:  "jsonrpc": "2.0",
+// CHECK-NEXT:  "result": [
+// CHECK-NEXT:    {
+// CHECK-NEXT:      "range": {
+// CHECK-NEXT:        "end": {
+// CHECK-NEXT:          "character": 4,
+// CHECK-NEXT:          "line": 0
+// CHECK-NEXT:        },
+// CHECK-NEXT:        "start": {
+// CHECK-NEXT:          "character": 0,
+// CHECK-NEXT:          "line": 0
+// CHECK-NEXT:        }
+// CHECK-NEXT:      },
+// CHECK-NEXT:      "uri": "{{.*}}/foo.mlir"
+// CHECK-NEXT:    }
+// -----
 {"jsonrpc":"2.0","id":3,"method":"shutdown"}
 // -----
 {"jsonrpc":"2.0","method":"exit"}

diff  --git a/mlir/test/mlir-lsp-server/hover.test b/mlir/test/mlir-lsp-server/hover.test
index 77fff2a79f89..71008c8d9ab2 100644
--- a/mlir/test/mlir-lsp-server/hover.test
+++ b/mlir/test/mlir-lsp-server/hover.test
@@ -26,7 +26,7 @@
 // CHECK-NEXT:        "line": 1
 // CHECK-NEXT:      },
 // CHECK-NEXT:      "start": {
-// CHECK-NEXT:        "character": 10,
+// CHECK-NEXT:        "character": 9,
 // CHECK-NEXT:        "line": 1
 // CHECK-NEXT:      }
 // CHECK-NEXT:    }
@@ -50,7 +50,7 @@
 // CHECK-NEXT:        "line": 1
 // CHECK-NEXT:      },
 // CHECK-NEXT:      "start": {
-// CHECK-NEXT:        "character": 1,
+// CHECK-NEXT:        "character": 0,
 // CHECK-NEXT:        "line": 1
 // CHECK-NEXT:      }
 // CHECK-NEXT:    }
@@ -74,7 +74,7 @@
 // CHECK-NEXT:        "line": 3
 // CHECK-NEXT:      },
 // CHECK-NEXT:      "start": {
-// CHECK-NEXT:        "character": 1,
+// CHECK-NEXT:        "character": 0,
 // CHECK-NEXT:        "line": 3
 // CHECK-NEXT:      }
 // CHECK-NEXT:    }
@@ -98,7 +98,7 @@
 // CHECK-NEXT:        "line": 0
 // CHECK-NEXT:      },
 // CHECK-NEXT:      "start": {
-// CHECK-NEXT:        "character": 11,
+// CHECK-NEXT:        "character": 10,
 // CHECK-NEXT:        "line": 0
 // CHECK-NEXT:      }
 // CHECK-NEXT:    }

diff  --git a/mlir/test/mlir-lsp-server/references.test b/mlir/test/mlir-lsp-server/references.test
index 333329ddf341..8d3c1b6f25c5 100644
--- a/mlir/test/mlir-lsp-server/references.test
+++ b/mlir/test/mlir-lsp-server/references.test
@@ -23,7 +23,7 @@
 // CHECK-NEXT:          "line": 1
 // CHECK-NEXT:        },
 // CHECK-NEXT:        "start": {
-// CHECK-NEXT:          "character": 1,
+// CHECK-NEXT:          "character": 0,
 // CHECK-NEXT:          "line": 1
 // CHECK-NEXT:        }
 // CHECK-NEXT:      },
@@ -36,7 +36,7 @@
 // CHECK-NEXT:          "line": 2
 // CHECK-NEXT:        },
 // CHECK-NEXT:        "start": {
-// CHECK-NEXT:          "character": 8,
+// CHECK-NEXT:          "character": 7,
 // CHECK-NEXT:          "line": 2
 // CHECK-NEXT:        }
 // CHECK-NEXT:      },
@@ -44,6 +44,29 @@
 // CHECK-NEXT:    }
 // CHECK-NEXT:  ]
 // -----
+{"jsonrpc":"2.0","id":2,"method":"textDocument/references","params":{
+  "textDocument":{"uri":"test:///foo.mlir"},
+  "position":{"line":0,"character":7},
+  "context":{"includeDeclaration": false}
+}}
+//      CHECK:  "id": 2
+// CHECK-NEXT:  "jsonrpc": "2.0",
+// CHECK-NEXT:  "result": [
+// CHECK-NEXT:    {
+// CHECK-NEXT:      "range": {
+// CHECK-NEXT:        "end": {
+// CHECK-NEXT:          "character": 9,
+// CHECK-NEXT:          "line": 0
+// CHECK-NEXT:        },
+// CHECK-NEXT:        "start": {
+// CHECK-NEXT:          "character": 5,
+// CHECK-NEXT:          "line": 0
+// CHECK-NEXT:        }
+// CHECK-NEXT:      },
+// CHECK-NEXT:      "uri": "{{.*}}/foo.mlir"
+// CHECK-NEXT:    }
+// CHECK-NEXT:  ]
+// -----
 {"jsonrpc":"2.0","id":3,"method":"shutdown"}
 // -----
 {"jsonrpc":"2.0","method":"exit"}


        


More information about the Mlir-commits mailing list