[Mlir-commits] [mlir] d6af89b - [mlir-lsp-server] Add support for tracking the use/def chains of symbols
River Riddle
llvmlistbot at llvm.org
Thu Jun 3 16:12:43 PDT 2021
Author: River Riddle
Date: 2021-06-03T16:12:27-07:00
New Revision: d6af89beb26df549d5b9e9041dac3b205b44e512
URL: https://github.com/llvm/llvm-project/commit/d6af89beb26df549d5b9e9041dac3b205b44e512
DIFF: https://github.com/llvm/llvm-project/commit/d6af89beb26df549d5b9e9041dac3b205b44e512.diff
LOG: [mlir-lsp-server] Add support for tracking the use/def chains of symbols
This revision adds assembly state tracking for uses of symbols, allowing for go-to-definition and references support for SymbolRefAttrs.
Differential Revision: https://reviews.llvm.org/D103585
Added:
Modified:
mlir/include/mlir/Parser/AsmParserState.h
mlir/lib/Parser/AsmParserState.cpp
mlir/lib/Parser/AttributeParser.cpp
mlir/lib/Parser/Parser.cpp
mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
mlir/test/mlir-lsp-server/definition-split-file.test
mlir/test/mlir-lsp-server/definition.test
mlir/test/mlir-lsp-server/hover.test
mlir/test/mlir-lsp-server/references.test
Removed:
################################################################################
diff --git a/mlir/include/mlir/Parser/AsmParserState.h b/mlir/include/mlir/Parser/AsmParserState.h
index 86625e6afe88..a519731d97dd 100644
--- a/mlir/include/mlir/Parser/AsmParserState.h
+++ b/mlir/include/mlir/Parser/AsmParserState.h
@@ -20,6 +20,8 @@ class Block;
class BlockArgument;
class FileLineColLoc;
class Operation;
+class OperationName;
+class SymbolRefAttr;
class Value;
/// This class represents state from a parsed MLIR textual format string. It is
@@ -61,6 +63,10 @@ class AsmParserState {
/// Source definitions for any result groups of this operation.
SmallVector<std::pair<unsigned, SMDefinition>> resultGroups;
+
+ /// If this operation is a symbol operation, this vector contains symbol
+ /// uses of this operation.
+ SmallVector<llvm::SMRange> symbolUses;
};
/// This class represents the information for a block definition within the
@@ -112,10 +118,28 @@ class AsmParserState {
// Populate State
//===--------------------------------------------------------------------===//
- /// Add a definition of the given operation.
- void addDefinition(
- Operation *op, llvm::SMRange location,
+ /// Initialize the state in preparation for populating more parser state under
+ /// the given top-level operation.
+ void initialize(Operation *topLevelOp);
+
+ /// Finalize any in-progress parser state under the given top-level operation.
+ void finalize(Operation *topLevelOp);
+
+ /// Start a definition for an operation with the given name.
+ void startOperationDefinition(const OperationName &opName);
+
+ /// Finalize the most recently started operation definition.
+ void finalizeOperationDefinition(
+ Operation *op, llvm::SMRange nameLoc,
ArrayRef<std::pair<unsigned, llvm::SMLoc>> resultGroups = llvm::None);
+
+ /// Start a definition for a region nested under the current operation.
+ void startRegionDefinition();
+
+ /// Finalize the most recently started region definition.
+ void finalizeRegionDefinition();
+
+ /// Add a definition of the given entity.
void addDefinition(Block *block, llvm::SMLoc location);
void addDefinition(BlockArgument blockArg, llvm::SMLoc location);
@@ -123,6 +147,12 @@ class AsmParserState {
void addUses(Value value, ArrayRef<llvm::SMLoc> locations);
void addUses(Block *block, ArrayRef<llvm::SMLoc> locations);
+ /// Add source uses for all the references nested under `refAttr`. The
+ /// provided `locations` should match 1-1 with the number of references in
+ /// `refAttr`, i.e.:
+ /// nestedReferences.size() + /*leafReference=*/1 == refLocations.size()
+ void addUses(SymbolRefAttr refAttr, ArrayRef<llvm::SMRange> refLocations);
+
/// Refine the `oldValue` to the `newValue`. This is used to indicate that
/// `oldValue` was a placeholder, and the uses of it should really refer to
/// `newValue`.
diff --git a/mlir/lib/Parser/AsmParserState.cpp b/mlir/lib/Parser/AsmParserState.cpp
index 71b13c945205..d85c068d493f 100644
--- a/mlir/lib/Parser/AsmParserState.cpp
+++ b/mlir/lib/Parser/AsmParserState.cpp
@@ -8,6 +8,7 @@
#include "mlir/Parser/AsmParserState.h"
#include "mlir/IR/Operation.h"
+#include "mlir/IR/SymbolTable.h"
using namespace mlir;
@@ -16,6 +17,27 @@ using namespace mlir;
//===----------------------------------------------------------------------===//
struct AsmParserState::Impl {
+ /// A map from a SymbolRefAttr to a range of uses.
+ using SymbolUseMap = DenseMap<Attribute, SmallVector<llvm::SMRange>>;
+
+ struct PartialOpDef {
+ explicit PartialOpDef(const OperationName &opName) {
+ const auto *abstractOp = opName.getAbstractOperation();
+ if (abstractOp && abstractOp->hasTrait<OpTrait::SymbolTable>())
+ symbolTable = std::make_unique<SymbolUseMap>();
+ }
+
+ /// Return if this operation is a symbol table.
+ bool isSymbolTable() const { return symbolTable.get(); }
+
+ /// If this operation is a symbol table, the following contains symbol uses
+ /// within this operation.
+ std::unique_ptr<SymbolUseMap> symbolTable;
+ };
+
+ /// Resolve any symbol table uses under the given partial operation.
+ void resolveSymbolUses(Operation *op, PartialOpDef &opDef);
+
/// A mapping from operations in the input source file to their parser state.
SmallVector<std::unique_ptr<OperationDefinition>> operations;
DenseMap<Operation *, unsigned> operationToIdx;
@@ -27,8 +49,38 @@ struct AsmParserState::Impl {
/// A set of value definitions that are placeholders for forward references.
/// This map should be empty if the parser finishes successfully.
DenseMap<Value, SmallVector<llvm::SMLoc>> placeholderValueUses;
+
+ /// A stack of partial operation definitions that have been started but not
+ /// yet finalized.
+ SmallVector<PartialOpDef> partialOperations;
+
+ /// A stack of symbol use scopes. This is used when collecting symbol table
+ /// uses during parsing.
+ SmallVector<SymbolUseMap *> symbolUseScopes;
+
+ /// A symbol table containing all of the symbol table operations in the IR.
+ SymbolTableCollection symbolTable;
};
+void AsmParserState::Impl::resolveSymbolUses(Operation *op,
+ PartialOpDef &opDef) {
+ assert(opDef.isSymbolTable() && "expected op to be a symbol table");
+
+ SmallVector<Operation *> symbolOps;
+ for (auto &it : *opDef.symbolTable) {
+ symbolOps.clear();
+ if (failed(symbolTable.lookupSymbolIn(op, it.first.cast<SymbolRefAttr>(),
+ symbolOps)))
+ continue;
+
+ for (const auto &symIt : llvm::zip(symbolOps, it.second)) {
+ auto opIt = operationToIdx.find(std::get<0>(symIt));
+ if (opIt != operationToIdx.end())
+ operations[opIt->second]->symbolUses.push_back(std::get<1>(symIt));
+ }
+ }
+}
+
//===----------------------------------------------------------------------===//
// AsmParserState
//===----------------------------------------------------------------------===//
@@ -77,17 +129,70 @@ llvm::SMRange AsmParserState::convertIdLocToRange(llvm::SMLoc loc) {
//===----------------------------------------------------------------------===//
// Populate State
-void AsmParserState::addDefinition(
- Operation *op, llvm::SMRange location,
+void AsmParserState::initialize(Operation *topLevelOp) {
+ startOperationDefinition(topLevelOp->getName());
+
+ // If the top-level operation is a symbol table, push a new symbol scope.
+ Impl::PartialOpDef &partialOpDef = impl->partialOperations.back();
+ if (partialOpDef.isSymbolTable())
+ impl->symbolUseScopes.push_back(partialOpDef.symbolTable.get());
+}
+
+void AsmParserState::finalize(Operation *topLevelOp) {
+ assert(!impl->partialOperations.empty() &&
+ "expected valid partial operation definition");
+ Impl::PartialOpDef partialOpDef = impl->partialOperations.pop_back_val();
+
+ // If this operation is a symbol table, resolve any symbol uses.
+ if (partialOpDef.isSymbolTable())
+ impl->resolveSymbolUses(topLevelOp, partialOpDef);
+}
+
+void AsmParserState::startOperationDefinition(const OperationName &opName) {
+ impl->partialOperations.emplace_back(opName);
+}
+
+void AsmParserState::finalizeOperationDefinition(
+ Operation *op, llvm::SMRange nameLoc,
ArrayRef<std::pair<unsigned, llvm::SMLoc>> resultGroups) {
+ assert(!impl->partialOperations.empty() &&
+ "expected valid partial operation definition");
+ Impl::PartialOpDef partialOpDef = impl->partialOperations.pop_back_val();
+
+ // Build the full operation definition.
std::unique_ptr<OperationDefinition> def =
- std::make_unique<OperationDefinition>(op, location);
+ std::make_unique<OperationDefinition>(op, nameLoc);
for (auto &resultGroup : resultGroups)
def->resultGroups.emplace_back(resultGroup.first,
convertIdLocToRange(resultGroup.second));
-
impl->operationToIdx.try_emplace(op, impl->operations.size());
impl->operations.emplace_back(std::move(def));
+
+ // If this operation is a symbol table, resolve any symbol uses.
+ if (partialOpDef.isSymbolTable())
+ impl->resolveSymbolUses(op, partialOpDef);
+}
+
+void AsmParserState::startRegionDefinition() {
+ assert(!impl->partialOperations.empty() &&
+ "expected valid partial operation definition");
+
+ // If the parent operation of this region is a symbol table, we also push a
+ // new symbol scope.
+ Impl::PartialOpDef &partialOpDef = impl->partialOperations.back();
+ if (partialOpDef.isSymbolTable())
+ impl->symbolUseScopes.push_back(partialOpDef.symbolTable.get());
+}
+
+void AsmParserState::finalizeRegionDefinition() {
+ assert(!impl->partialOperations.empty() &&
+ "expected valid partial operation definition");
+
+ // If the parent operation of this region is a symbol table, pop the symbol
+ // scope for this region.
+ Impl::PartialOpDef &partialOpDef = impl->partialOperations.back();
+ if (partialOpDef.isSymbolTable())
+ impl->symbolUseScopes.pop_back();
}
void AsmParserState::addDefinition(Block *block, llvm::SMLoc location) {
@@ -169,6 +274,18 @@ void AsmParserState::addUses(Block *block, ArrayRef<llvm::SMLoc> locations) {
def.definition.uses.push_back(convertIdLocToRange(loc));
}
+void AsmParserState::addUses(SymbolRefAttr refAttr,
+ ArrayRef<llvm::SMRange> locations) {
+ // Ignore this symbol if no scopes are active.
+ if (impl->symbolUseScopes.empty())
+ return;
+
+ assert((refAttr.getNestedReferences().size() + 1) == locations.size() &&
+ "expected the same number of references as provided locations");
+ (*impl->symbolUseScopes.back())[refAttr].append(locations.begin(),
+ locations.end());
+}
+
void AsmParserState::refineDefinition(Value oldValue, Value newValue) {
auto it = impl->placeholderValueUses.find(oldValue);
assert(it != impl->placeholderValueUses.end() &&
diff --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp
index a90c65ff1fb3..38f86155d4b7 100644
--- a/mlir/lib/Parser/AttributeParser.cpp
+++ b/mlir/lib/Parser/AttributeParser.cpp
@@ -15,6 +15,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/IntegerSet.h"
+#include "mlir/Parser/AsmParserState.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Endian.h"
@@ -153,6 +154,13 @@ Attribute Parser::parseAttribute(Type type) {
// Parse a symbol reference attribute.
case Token::at_identifier: {
+ // When populating the parser state, this is a list of locations for all of
+ // the nested references.
+ SmallVector<llvm::SMRange> referenceLocations;
+ if (state.asmState)
+ referenceLocations.push_back(getToken().getLocRange());
+
+ // Parse the top-level reference.
std::string nameStr = getToken().getSymbolReference();
consumeToken(Token::at_identifier);
@@ -174,12 +182,21 @@ Attribute Parser::parseAttribute(Type type) {
return Attribute();
}
+ // If we are populating the assembly state, add the location for this
+ // reference.
+ if (state.asmState)
+ referenceLocations.push_back(getToken().getLocRange());
+
std::string nameStr = getToken().getSymbolReference();
consumeToken(Token::at_identifier);
nestedRefs.push_back(SymbolRefAttr::get(getContext(), nameStr));
}
+ SymbolRefAttr symbolRefAttr = builder.getSymbolRefAttr(nameStr, nestedRefs);
- return builder.getSymbolRefAttr(nameStr, nestedRefs);
+ // If we are populating the assembly state, record this symbol reference.
+ if (state.asmState)
+ state.asmState->addUses(symbolRefAttr, referenceLocations);
+ return symbolRefAttr;
}
// Parse a 'unit' attribute.
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index aa0d12f7568f..4b80158f04a2 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -166,18 +166,12 @@ namespace {
/// operations.
class OperationParser : public Parser {
public:
- OperationParser(ParserState &state, Operation *topLevelOp)
- : Parser(state), opBuilder(topLevelOp->getRegion(0)),
- topLevelOp(topLevelOp) {
- // The top level operation starts a new name scope.
- pushSSANameScope(/*isIsolated=*/true);
- }
-
+ OperationParser(ParserState &state, Operation *topLevelOp);
~OperationParser();
/// After parsing is finished, this function must be called to see if there
/// are any remaining issues.
- ParseResult finalize();
+ ParseResult finalize(Operation *topLevelOp);
//===--------------------------------------------------------------------===//
// SSA Value Handling
@@ -281,7 +275,10 @@ class OperationParser : public Parser {
bool isIsolatedNameScope = false);
/// Parse a region body into 'region'.
- ParseResult parseRegionBody(Region &region);
+ ParseResult
+ parseRegionBody(Region &region, llvm::SMLoc startLoc,
+ ArrayRef<std::pair<SSAUseInfo, Type>> entryArguments,
+ bool isIsolatedNameScope);
//===--------------------------------------------------------------------===//
// Block Parsing
@@ -402,6 +399,17 @@ class OperationParser : public Parser {
};
} // end anonymous namespace
+OperationParser::OperationParser(ParserState &state, Operation *topLevelOp)
+ : Parser(state), opBuilder(topLevelOp->getRegion(0)),
+ topLevelOp(topLevelOp) {
+ // The top level operation starts a new name scope.
+ pushSSANameScope(/*isIsolated=*/true);
+
+ // If we are populating the parser state, prepare it for parsing.
+ if (state.asmState)
+ state.asmState->initialize(topLevelOp);
+}
+
OperationParser::~OperationParser() {
for (auto &fwd : forwardRefPlaceholders) {
// Drop all uses of undefined forward declared reference and destroy
@@ -421,7 +429,7 @@ OperationParser::~OperationParser() {
/// After parsing is finished, this function must be called to see if there are
/// any remaining issues.
-ParseResult OperationParser::finalize() {
+ParseResult OperationParser::finalize(Operation *topLevelOp) {
// Check for any forward references that are left. If we find any, error
// out.
if (!forwardRefPlaceholders.empty()) {
@@ -458,6 +466,10 @@ ParseResult OperationParser::finalize() {
opOrArgument.get<BlockArgument>().setLoc(locAttr);
}
+ // If we are populating the parser state, finalize the top-level operation.
+ if (state.asmState)
+ state.asmState->finalize(topLevelOp);
+
// Pop the top level name scope.
return popSSANameScope();
}
@@ -809,7 +821,8 @@ ParseResult OperationParser::parseOperation() {
asmResultGroups.emplace_back(resultIt, std::get<2>(record));
resultIt += std::get<1>(record);
}
- state.asmState->addDefinition(op, nameTok.getLocRange(), asmResultGroups);
+ state.asmState->finalizeOperationDefinition(op, nameTok.getLocRange(),
+ asmResultGroups);
}
// Add definitions for each of the result groups.
@@ -824,7 +837,7 @@ ParseResult OperationParser::parseOperation() {
// Add this operation to the assembly state if it was provided to populate.
} else if (state.asmState) {
- state.asmState->addDefinition(op, nameTok.getLocRange());
+ state.asmState->finalizeOperationDefinition(op, nameTok.getLocRange());
}
return success();
@@ -903,6 +916,10 @@ Operation *OperationParser::parseGenericOperation() {
}
}
+ // If we are populating the parser state, start a new operation definition.
+ if (state.asmState)
+ state.asmState->startOperationDefinition(result.name);
+
// Parse the operand list.
SmallVector<SSAUseInfo, 8> operandInfos;
if (parseToken(Token::l_paren, "expected '(' to start operand list") ||
@@ -981,9 +998,19 @@ Operation *OperationParser::parseGenericOperation() {
Operation *OperationParser::parseGenericOperation(Block *insertBlock,
Block::iterator insertPt) {
+ Token nameToken = getToken();
+
OpBuilder::InsertionGuard restoreInsertionPoint(opBuilder);
opBuilder.setInsertionPoint(insertBlock, insertPt);
- return parseGenericOperation();
+ Operation *op = parseGenericOperation();
+ if (!op)
+ return nullptr;
+
+ // If we are populating the parser asm state, finalize this operation
+ // definition.
+ if (state.asmState)
+ state.asmState->finalizeOperationDefinition(op, nameToken.getLocRange());
+ return op;
}
namespace {
@@ -1367,6 +1394,14 @@ class CustomOpAsmParser : public OpAsmParser {
result = getBuilder().getStringAttr(atToken.getSymbolReference());
attrs.push_back(getBuilder().getNamedAttr(attrName, result));
parser.consumeToken();
+
+ // If we are populating the assembly parser state, record this as a symbol
+ // reference.
+ if (parser.getState().asmState) {
+ parser.getState().asmState->addUses(
+ getBuilder().getSymbolRefAttr(result.getValue()),
+ atToken.getLocRange());
+ }
return success();
}
@@ -1858,9 +1893,13 @@ OperationParser::parseCustomOperation(ArrayRef<ResultRecord> resultIDs) {
// Get location information for the operation.
auto srcLocation = getEncodedSourceLocation(opLoc);
+ OperationState opState(srcLocation, opName);
+
+ // If we are populating the parser state, start a new operation definition.
+ if (state.asmState)
+ state.asmState->startOperationDefinition(opState.name);
// Have the op implementation take a crack and parsing this.
- OperationState opState(srcLocation, opName);
CleanupOpStateRegions guard{opState};
CustomOpAsmParser opAsmParser(opLoc, resultIDs, parseAssemblyFn,
isIsolatedFromAbove, opName, *this);
@@ -1931,10 +1970,6 @@ OperationParser::parseTrailingLocationSpecifier(OpOrArgument opOrArgument) {
// Region Parsing
//===----------------------------------------------------------------------===//
-/// Region.
-///
-/// region ::= '{' region-body
-///
ParseResult OperationParser::parseRegion(
Region &region,
ArrayRef<std::pair<OperationParser::SSAUseInfo, Type>> entryArguments,
@@ -1944,9 +1979,29 @@ ParseResult OperationParser::parseRegion(
if (parseToken(Token::l_brace, "expected '{' to begin a region"))
return failure();
- // Check for an empty region.
- if (entryArguments.empty() && consumeIf(Token::r_brace))
- return success();
+ // If we are populating the parser state, start a new region definition.
+ if (state.asmState)
+ state.asmState->startRegionDefinition();
+
+ // Parse the region body.
+ if ((!entryArguments.empty() || getToken().isNot(Token::r_brace)) &&
+ parseRegionBody(region, lBraceTok.getLoc(), entryArguments,
+ isIsolatedNameScope)) {
+ return failure();
+ }
+ consumeToken(Token::r_brace);
+
+ // If we are populating the parser state, finalize this region.
+ if (state.asmState)
+ state.asmState->finalizeRegionDefinition();
+
+ return success();
+}
+
+ParseResult OperationParser::parseRegionBody(
+ Region &region, llvm::SMLoc startLoc,
+ ArrayRef<std::pair<OperationParser::SSAUseInfo, Type>> entryArguments,
+ bool isIsolatedNameScope) {
auto currentPt = opBuilder.saveInsertionPoint();
// Push a new named value scope.
@@ -1960,7 +2015,7 @@ ParseResult OperationParser::parseRegion(
// now in the assembly state. Blocks with a name will be defined when the name
// is parsed.
if (state.asmState && getToken().isNot(Token::caret_identifier))
- state.asmState->addDefinition(block, lBraceTok.getLoc());
+ state.asmState->addDefinition(block, startLoc);
// Add arguments to the entry block.
if (!entryArguments.empty()) {
@@ -2002,8 +2057,12 @@ ParseResult OperationParser::parseRegion(
// Parse the rest of the region.
region.push_back(owning_block.release());
- if (parseRegionBody(region))
- return failure();
+ while (getToken().isNot(Token::r_brace)) {
+ Block *newBlock = nullptr;
+ if (parseBlock(newBlock))
+ return failure();
+ region.push_back(newBlock);
+ }
// Pop the SSA value scope for this region.
if (popSSANameScope())
@@ -2014,21 +2073,6 @@ ParseResult OperationParser::parseRegion(
return success();
}
-/// Region.
-///
-/// region-body ::= block* '}'
-///
-ParseResult OperationParser::parseRegionBody(Region &region) {
- // Parse the list of blocks.
- while (!consumeIf(Token::r_brace)) {
- Block *newBlock = nullptr;
- if (parseBlock(newBlock))
- return failure();
- region.push_back(newBlock);
- }
- return success();
-}
-
//===----------------------------------------------------------------------===//
// Block Parsing
//===----------------------------------------------------------------------===//
@@ -2278,7 +2322,7 @@ ParseResult TopLevelOperationParser::parse(Block *topLevelBlock,
// If we got to the end of the file, then we're done.
case Token::eof: {
- if (opParser.finalize())
+ if (opParser.finalize(topLevelOp.get()))
return failure();
// Verify that the parsed operations are valid.
diff --git a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
index 3b1dfcc6195e..d4b61372552d 100644
--- a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
+++ b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
@@ -21,7 +21,7 @@ static lsp::Position getPosFromLoc(llvm::SourceMgr &mgr, llvm::SMLoc loc) {
std::pair<unsigned, unsigned> lineAndCol = mgr.getLineAndColumn(loc);
lsp::Position pos;
pos.line = lineAndCol.first - 1;
- pos.character = lineAndCol.second;
+ pos.character = lineAndCol.second - 1;
return pos;
}
@@ -33,10 +33,7 @@ static llvm::SMLoc getPosFromLoc(llvm::SourceMgr &mgr, lsp::Position pos) {
/// Returns a language server range for the given source range.
static lsp::Range getRangeFromLoc(llvm::SourceMgr &mgr, llvm::SMRange range) {
- // lsp::Range is an inclusive range, SMRange is half-open.
- llvm::SMLoc inclusiveEnd =
- llvm::SMLoc::getFromPointer(range.End.getPointer() - 1);
- return {getPosFromLoc(mgr, range.Start), getPosFromLoc(mgr, inclusiveEnd)};
+ return {getPosFromLoc(mgr, range.Start), getPosFromLoc(mgr, range.End)};
}
/// Returns a language server location from the given source range.
@@ -365,6 +362,12 @@ void MLIRDocument::getLocationsOf(const lsp::URIForFile &uri,
for (const auto &result : op.resultGroups)
if (containsPosition(result.second))
return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
+ for (const auto &symUse : op.symbolUses) {
+ if (contains(symUse, posLoc)) {
+ locations.push_back(getLocationFromLoc(sourceMgr, op.loc, uri));
+ return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
+ }
+ }
}
// Check all definitions related to blocks.
@@ -395,11 +398,21 @@ void MLIRDocument::findReferencesOf(const lsp::URIForFile &uri,
if (contains(op.loc, posLoc)) {
for (const auto &result : op.resultGroups)
appendSMDef(result.second);
+ for (const auto &symUse : op.symbolUses)
+ if (contains(symUse, posLoc))
+ references.push_back(getLocationFromLoc(sourceMgr, symUse, uri));
return;
}
for (const auto &result : op.resultGroups)
if (isDefOrUse(result.second, posLoc))
return appendSMDef(result.second);
+ for (const auto &symUse : op.symbolUses) {
+ if (!contains(symUse, posLoc))
+ continue;
+ for (const auto &symUse : op.symbolUses)
+ references.push_back(getLocationFromLoc(sourceMgr, symUse, uri));
+ return;
+ }
}
// Check all definitions related to blocks.
diff --git a/mlir/test/mlir-lsp-server/definition-split-file.test b/mlir/test/mlir-lsp-server/definition-split-file.test
index c32f8be396a7..14bd2e2da3f4 100644
--- a/mlir/test/mlir-lsp-server/definition-split-file.test
+++ b/mlir/test/mlir-lsp-server/definition-split-file.test
@@ -25,7 +25,7 @@
// CHECK-NEXT: "line": 3
// CHECK-NEXT: },
// CHECK-NEXT: "start": {
-// CHECK-NEXT: "character": 1,
+// CHECK-NEXT: "character": 0,
// CHECK-NEXT: "line": 3
// CHECK-NEXT: }
// CHECK-NEXT: },
diff --git a/mlir/test/mlir-lsp-server/definition.test b/mlir/test/mlir-lsp-server/definition.test
index ddcebcfa28f3..5911f9c2f0ca 100644
--- a/mlir/test/mlir-lsp-server/definition.test
+++ b/mlir/test/mlir-lsp-server/definition.test
@@ -22,13 +22,34 @@
// CHECK-NEXT: "line": 1
// CHECK-NEXT: },
// CHECK-NEXT: "start": {
-// CHECK-NEXT: "character": 1,
+// CHECK-NEXT: "character": 0,
// CHECK-NEXT: "line": 1
// CHECK-NEXT: }
// CHECK-NEXT: },
// CHECK-NEXT: "uri": "{{.*}}/foo.mlir"
// CHECK-NEXT: }
// -----
+{"jsonrpc":"2.0","id":2,"method":"textDocument/definition","params":{
+ "textDocument":{"uri":"test:///foo.mlir"},
+ "position":{"line":0,"character":7}
+}}
+// CHECK: "id": 2
+// CHECK-NEXT: "jsonrpc": "2.0",
+// CHECK-NEXT: "result": [
+// CHECK-NEXT: {
+// CHECK-NEXT: "range": {
+// CHECK-NEXT: "end": {
+// CHECK-NEXT: "character": 4,
+// CHECK-NEXT: "line": 0
+// CHECK-NEXT: },
+// CHECK-NEXT: "start": {
+// CHECK-NEXT: "character": 0,
+// CHECK-NEXT: "line": 0
+// CHECK-NEXT: }
+// CHECK-NEXT: },
+// CHECK-NEXT: "uri": "{{.*}}/foo.mlir"
+// CHECK-NEXT: }
+// -----
{"jsonrpc":"2.0","id":3,"method":"shutdown"}
// -----
{"jsonrpc":"2.0","method":"exit"}
diff --git a/mlir/test/mlir-lsp-server/hover.test b/mlir/test/mlir-lsp-server/hover.test
index 77fff2a79f89..71008c8d9ab2 100644
--- a/mlir/test/mlir-lsp-server/hover.test
+++ b/mlir/test/mlir-lsp-server/hover.test
@@ -26,7 +26,7 @@
// CHECK-NEXT: "line": 1
// CHECK-NEXT: },
// CHECK-NEXT: "start": {
-// CHECK-NEXT: "character": 10,
+// CHECK-NEXT: "character": 9,
// CHECK-NEXT: "line": 1
// CHECK-NEXT: }
// CHECK-NEXT: }
@@ -50,7 +50,7 @@
// CHECK-NEXT: "line": 1
// CHECK-NEXT: },
// CHECK-NEXT: "start": {
-// CHECK-NEXT: "character": 1,
+// CHECK-NEXT: "character": 0,
// CHECK-NEXT: "line": 1
// CHECK-NEXT: }
// CHECK-NEXT: }
@@ -74,7 +74,7 @@
// CHECK-NEXT: "line": 3
// CHECK-NEXT: },
// CHECK-NEXT: "start": {
-// CHECK-NEXT: "character": 1,
+// CHECK-NEXT: "character": 0,
// CHECK-NEXT: "line": 3
// CHECK-NEXT: }
// CHECK-NEXT: }
@@ -98,7 +98,7 @@
// CHECK-NEXT: "line": 0
// CHECK-NEXT: },
// CHECK-NEXT: "start": {
-// CHECK-NEXT: "character": 11,
+// CHECK-NEXT: "character": 10,
// CHECK-NEXT: "line": 0
// CHECK-NEXT: }
// CHECK-NEXT: }
diff --git a/mlir/test/mlir-lsp-server/references.test b/mlir/test/mlir-lsp-server/references.test
index 333329ddf341..8d3c1b6f25c5 100644
--- a/mlir/test/mlir-lsp-server/references.test
+++ b/mlir/test/mlir-lsp-server/references.test
@@ -23,7 +23,7 @@
// CHECK-NEXT: "line": 1
// CHECK-NEXT: },
// CHECK-NEXT: "start": {
-// CHECK-NEXT: "character": 1,
+// CHECK-NEXT: "character": 0,
// CHECK-NEXT: "line": 1
// CHECK-NEXT: }
// CHECK-NEXT: },
@@ -36,7 +36,7 @@
// CHECK-NEXT: "line": 2
// CHECK-NEXT: },
// CHECK-NEXT: "start": {
-// CHECK-NEXT: "character": 8,
+// CHECK-NEXT: "character": 7,
// CHECK-NEXT: "line": 2
// CHECK-NEXT: }
// CHECK-NEXT: },
@@ -44,6 +44,29 @@
// CHECK-NEXT: }
// CHECK-NEXT: ]
// -----
+{"jsonrpc":"2.0","id":2,"method":"textDocument/references","params":{
+ "textDocument":{"uri":"test:///foo.mlir"},
+ "position":{"line":0,"character":7},
+ "context":{"includeDeclaration": false}
+}}
+// CHECK: "id": 2
+// CHECK-NEXT: "jsonrpc": "2.0",
+// CHECK-NEXT: "result": [
+// CHECK-NEXT: {
+// CHECK-NEXT: "range": {
+// CHECK-NEXT: "end": {
+// CHECK-NEXT: "character": 9,
+// CHECK-NEXT: "line": 0
+// CHECK-NEXT: },
+// CHECK-NEXT: "start": {
+// CHECK-NEXT: "character": 5,
+// CHECK-NEXT: "line": 0
+// CHECK-NEXT: }
+// CHECK-NEXT: },
+// CHECK-NEXT: "uri": "{{.*}}/foo.mlir"
+// CHECK-NEXT: }
+// CHECK-NEXT: ]
+// -----
{"jsonrpc":"2.0","id":3,"method":"shutdown"}
// -----
{"jsonrpc":"2.0","method":"exit"}
More information about the Mlir-commits
mailing list