[Mlir-commits] [mlir] de55c2f - Revert "[mlir-query] Add function extraction feature to mlir-query"
Mehdi Amini
llvmlistbot at llvm.org
Thu Feb 29 13:14:20 PST 2024
Author: Mehdi Amini
Date: 2024-02-29T13:14:00-08:00
New Revision: de55c2f869925a3ed7f26e168424021c6bc46799
URL: https://github.com/llvm/llvm-project/commit/de55c2f869925a3ed7f26e168424021c6bc46799
DIFF: https://github.com/llvm/llvm-project/commit/de55c2f869925a3ed7f26e168424021c6bc46799.diff
LOG: Revert "[mlir-query] Add function extraction feature to mlir-query"
This reverts commit c66f2d0c4a46ba66fb98a2cab4e63ad90888a261.
The bot is broken.
Added:
Modified:
mlir/include/mlir/Query/Matcher/ErrorBuilder.h
mlir/include/mlir/Query/Matcher/MatchersInternal.h
mlir/lib/Query/Matcher/Diagnostics.cpp
mlir/lib/Query/Matcher/Parser.cpp
mlir/lib/Query/Matcher/Parser.h
mlir/lib/Query/Matcher/RegistryManager.cpp
mlir/lib/Query/Matcher/RegistryManager.h
mlir/lib/Query/Query.cpp
Removed:
mlir/test/mlir-query/function-extraction.mlir
################################################################################
diff --git a/mlir/include/mlir/Query/Matcher/ErrorBuilder.h b/mlir/include/mlir/Query/Matcher/ErrorBuilder.h
index 08f1f415cbd3e5..1073daed8703f5 100644
--- a/mlir/include/mlir/Query/Matcher/ErrorBuilder.h
+++ b/mlir/include/mlir/Query/Matcher/ErrorBuilder.h
@@ -37,12 +37,8 @@ enum class ErrorType {
None,
// Parser Errors
- ParserChainedExprInvalidArg,
- ParserChainedExprNoCloseParen,
- ParserChainedExprNoOpenParen,
ParserFailedToBuildMatcher,
ParserInvalidToken,
- ParserMalformedChainedExpr,
ParserNoCloseParen,
ParserNoCode,
ParserNoComma,
@@ -54,10 +50,9 @@ enum class ErrorType {
// Registry Errors
RegistryMatcherNotFound,
- RegistryNotBindable,
RegistryValueNotFound,
RegistryWrongArgCount,
- RegistryWrongArgType,
+ RegistryWrongArgType
};
void addError(Diagnostics *error, SourceRange range, ErrorType errorType,
diff --git a/mlir/include/mlir/Query/Matcher/MatchersInternal.h b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
index 117f7d4edef9e3..67455be592393b 100644
--- a/mlir/include/mlir/Query/Matcher/MatchersInternal.h
+++ b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
@@ -63,15 +63,8 @@ class DynMatcher {
bool match(Operation *op) const { return implementation->match(op); }
- void setFunctionName(StringRef name) { functionName = name.str(); };
-
- bool hasFunctionName() const { return !functionName.empty(); };
-
- StringRef getFunctionName() const { return functionName; };
-
private:
llvm::IntrusiveRefCntPtr<MatcherInterface> implementation;
- std::string functionName;
};
} // namespace mlir::query::matcher
diff --git a/mlir/lib/Query/Matcher/Diagnostics.cpp b/mlir/lib/Query/Matcher/Diagnostics.cpp
index 2a137e8fdfab0d..10468dbcc53067 100644
--- a/mlir/lib/Query/Matcher/Diagnostics.cpp
+++ b/mlir/lib/Query/Matcher/Diagnostics.cpp
@@ -38,8 +38,6 @@ static llvm::StringRef errorTypeToFormatString(ErrorType type) {
return "Incorrect type for arg $0. (Expected = $1) != (Actual = $2)";
case ErrorType::RegistryValueNotFound:
return "Value not found: $0";
- case ErrorType::RegistryNotBindable:
- return "Matcher does not support binding.";
case ErrorType::ParserStringError:
return "Error parsing string token: <$0>";
@@ -59,14 +57,6 @@ static llvm::StringRef errorTypeToFormatString(ErrorType type) {
return "Unexpected end of code.";
case ErrorType::ParserOverloadedType:
return "Input value has unresolved overloaded type: $0";
- case ErrorType::ParserMalformedChainedExpr:
- return "Period not followed by valid chained call.";
- case ErrorType::ParserChainedExprInvalidArg:
- return "Missing/Invalid argument for the chained call.";
- case ErrorType::ParserChainedExprNoCloseParen:
- return "Missing ')' for the chained call.";
- case ErrorType::ParserChainedExprNoOpenParen:
- return "Missing '(' for the chained call.";
case ErrorType::ParserFailedToBuildMatcher:
return "Failed to build matcher: $0.";
diff --git a/mlir/lib/Query/Matcher/Parser.cpp b/mlir/lib/Query/Matcher/Parser.cpp
index 3609e24f9939f7..30eb4801fc03c1 100644
--- a/mlir/lib/Query/Matcher/Parser.cpp
+++ b/mlir/lib/Query/Matcher/Parser.cpp
@@ -26,17 +26,12 @@ struct Parser::TokenInfo {
text = newText;
}
- // Known identifiers.
- static const char *const ID_Extract;
-
llvm::StringRef text;
TokenKind kind = TokenKind::Eof;
SourceRange range;
VariantValue value;
};
-const char *const Parser::TokenInfo::ID_Extract = "extract";
-
class Parser::CodeTokenizer {
public:
// Constructor with matcherCode and error
@@ -303,36 +298,6 @@ bool Parser::parseIdentifierPrefixImpl(VariantValue *value) {
return parseMatcherExpressionImpl(nameToken, openToken, ctor, value);
}
-bool Parser::parseChainedExpression(std::string &argument) {
- // Parse the parenthesized argument to .extract("foo")
- // Note: EOF is handled inside the consume functions and would fail below when
- // checking token kind.
- const TokenInfo openToken = tokenizer->consumeNextToken();
- const TokenInfo argumentToken = tokenizer->consumeNextTokenIgnoreNewlines();
- const TokenInfo closeToken = tokenizer->consumeNextTokenIgnoreNewlines();
-
- if (openToken.kind != TokenKind::OpenParen) {
- error->addError(openToken.range, ErrorType::ParserChainedExprNoOpenParen);
- return false;
- }
-
- if (argumentToken.kind != TokenKind::Literal ||
- !argumentToken.value.isString()) {
- error->addError(argumentToken.range,
- ErrorType::ParserChainedExprInvalidArg);
- return false;
- }
-
- if (closeToken.kind != TokenKind::CloseParen) {
- error->addError(closeToken.range, ErrorType::ParserChainedExprNoCloseParen);
- return false;
- }
-
- // If all checks passed, extract the argument and return true.
- argument = argumentToken.value.getString();
- return true;
-}
-
// Parse the arguments of a matcher
bool Parser::parseMatcherArgs(std::vector<ParserValue> &args, MatcherCtor ctor,
const TokenInfo &nameToken, TokenInfo &endToken) {
@@ -399,34 +364,13 @@ bool Parser::parseMatcherExpressionImpl(const TokenInfo &nameToken,
return false;
}
- std::string functionName;
- if (tokenizer->peekNextToken().kind == TokenKind::Period) {
- tokenizer->consumeNextToken();
- TokenInfo chainCallToken = tokenizer->consumeNextToken();
- if (chainCallToken.kind == TokenKind::CodeCompletion) {
- addCompletion(chainCallToken, MatcherCompletion("extract(\"", "extract"));
- return false;
- }
-
- if (chainCallToken.kind != TokenKind::Ident ||
- chainCallToken.text != TokenInfo::ID_Extract) {
- error->addError(chainCallToken.range,
- ErrorType::ParserMalformedChainedExpr);
- return false;
- }
-
- if (chainCallToken.text == TokenInfo::ID_Extract &&
- !parseChainedExpression(functionName))
- return false;
- }
-
if (!ctor)
return false;
// Merge the start and end infos.
SourceRange matcherRange = nameToken.range;
matcherRange.end = endToken.range.end;
- VariantMatcher result = sema->actOnMatcherExpression(
- *ctor, matcherRange, functionName, args, error);
+ VariantMatcher result =
+ sema->actOnMatcherExpression(*ctor, matcherRange, args, error);
if (result.isNull())
return false;
*value = result;
@@ -526,10 +470,9 @@ Parser::RegistrySema::lookupMatcherCtor(llvm::StringRef matcherName) {
}
VariantMatcher Parser::RegistrySema::actOnMatcherExpression(
- MatcherCtor ctor, SourceRange nameRange, llvm::StringRef functionName,
- llvm::ArrayRef<ParserValue> args, Diagnostics *error) {
- return RegistryManager::constructMatcher(ctor, nameRange, functionName, args,
- error);
+ MatcherCtor ctor, SourceRange nameRange, llvm::ArrayRef<ParserValue> args,
+ Diagnostics *error) {
+ return RegistryManager::constructMatcher(ctor, nameRange, args, error);
}
std::vector<ArgKind> Parser::RegistrySema::getAcceptedCompletionTypes(
diff --git a/mlir/lib/Query/Matcher/Parser.h b/mlir/lib/Query/Matcher/Parser.h
index 58968023022d56..f049af34e9c907 100644
--- a/mlir/lib/Query/Matcher/Parser.h
+++ b/mlir/lib/Query/Matcher/Parser.h
@@ -64,9 +64,10 @@ class Parser {
// Process a matcher expression. The caller takes ownership of the Matcher
// object returned.
- virtual VariantMatcher actOnMatcherExpression(
- MatcherCtor ctor, SourceRange nameRange, llvm::StringRef functionName,
- llvm::ArrayRef<ParserValue> args, Diagnostics *error) = 0;
+ virtual VariantMatcher
+ actOnMatcherExpression(MatcherCtor ctor, SourceRange nameRange,
+ llvm::ArrayRef<ParserValue> args,
+ Diagnostics *error) = 0;
// Look up a matcher by name in the matcher name found by the parser.
virtual std::optional<MatcherCtor>
@@ -92,11 +93,10 @@ class Parser {
std::optional<MatcherCtor>
lookupMatcherCtor(llvm::StringRef matcherName) override;
- VariantMatcher actOnMatcherExpression(MatcherCtor Ctor,
- SourceRange NameRange,
- StringRef functionName,
- ArrayRef<ParserValue> Args,
- Diagnostics *Error) override;
+ VariantMatcher actOnMatcherExpression(MatcherCtor ctor,
+ SourceRange nameRange,
+ llvm::ArrayRef<ParserValue> args,
+ Diagnostics *error) override;
std::vector<ArgKind> getAcceptedCompletionTypes(
llvm::ArrayRef<std::pair<MatcherCtor, unsigned>> context) override;
@@ -153,8 +153,6 @@ class Parser {
Parser(CodeTokenizer *tokenizer, const Registry &matcherRegistry,
const NamedValueMap *namedValues, Diagnostics *error);
- bool parseChainedExpression(std::string &argument);
-
bool parseExpressionImpl(VariantValue *value);
bool parseMatcherArgs(std::vector<ParserValue> &args, MatcherCtor ctor,
diff --git a/mlir/lib/Query/Matcher/RegistryManager.cpp b/mlir/lib/Query/Matcher/RegistryManager.cpp
index 8c9197f4d00981..01856aa8ffa67f 100644
--- a/mlir/lib/Query/Matcher/RegistryManager.cpp
+++ b/mlir/lib/Query/Matcher/RegistryManager.cpp
@@ -132,19 +132,8 @@ RegistryManager::getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes,
VariantMatcher RegistryManager::constructMatcher(
MatcherCtor ctor, internal::SourceRange nameRange,
- llvm::StringRef functionName, llvm::ArrayRef<ParserValue> args,
- internal::Diagnostics *error) {
- VariantMatcher out = ctor->create(nameRange, args, error);
- if (functionName.empty() || out.isNull())
- return out;
-
- if (std::optional<DynMatcher> result = out.getDynMatcher()) {
- result->setFunctionName(functionName);
- return VariantMatcher::SingleMatcher(*result);
- }
-
- error->addError(nameRange, internal::ErrorType::RegistryNotBindable);
- return {};
+ llvm::ArrayRef<ParserValue> args, internal::Diagnostics *error) {
+ return ctor->create(nameRange, args, error);
}
} // namespace mlir::query::matcher
diff --git a/mlir/lib/Query/Matcher/RegistryManager.h b/mlir/lib/Query/Matcher/RegistryManager.h
index e2026e97f83dcb..5f2867261225e7 100644
--- a/mlir/lib/Query/Matcher/RegistryManager.h
+++ b/mlir/lib/Query/Matcher/RegistryManager.h
@@ -61,7 +61,6 @@ class RegistryManager {
static VariantMatcher constructMatcher(MatcherCtor ctor,
internal::SourceRange nameRange,
- llvm::StringRef functionName,
ArrayRef<ParserValue> args,
internal::Diagnostics *error);
};
diff --git a/mlir/lib/Query/Query.cpp b/mlir/lib/Query/Query.cpp
index 27db52b37dade0..5c42e5a5f0a116 100644
--- a/mlir/lib/Query/Query.cpp
+++ b/mlir/lib/Query/Query.cpp
@@ -8,8 +8,6 @@
#include "mlir/Query/Query.h"
#include "QueryParser.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/IR/IRMapping.h"
#include "mlir/Query/Matcher/MatchFinder.h"
#include "mlir/Query/QuerySession.h"
#include "mlir/Support/LogicalResult.h"
@@ -36,70 +34,6 @@ static void printMatch(llvm::raw_ostream &os, QuerySession &qs, Operation *op,
"\"" + binding + "\" binds here");
}
-// TODO: Extract into a helper function that can be reused outside query
-// context.
-static Operation *extractFunction(std::vector<Operation *> &ops,
- MLIRContext *context,
- llvm::StringRef functionName) {
- context->loadDialect<func::FuncDialect>();
- OpBuilder builder(context);
-
- // Collect data for function creation
- std::vector<Operation *> slice;
- std::vector<Value> values;
- std::vector<Type> outputTypes;
-
- for (auto *op : ops) {
- // Return op's operands are propagated, but the op itself isn't needed.
- if (!isa<func::ReturnOp>(op))
- slice.push_back(op);
-
- // All results are returned by the extracted function.
- outputTypes.insert(outputTypes.end(), op->getResults().getTypes().begin(),
- op->getResults().getTypes().end());
-
- // Track all values that need to be taken as input to function.
- values.insert(values.end(), op->getOperands().begin(),
- op->getOperands().end());
- }
-
- // Create the function
- FunctionType funcType =
- builder.getFunctionType(ValueRange(values), outputTypes);
- auto loc = builder.getUnknownLoc();
- func::FuncOp funcOp = func::FuncOp::create(loc, functionName, funcType);
-
- builder.setInsertionPointToEnd(funcOp.addEntryBlock());
-
- // Map original values to function arguments
- IRMapping mapper;
- for (const auto &arg : llvm::enumerate(values))
- mapper.map(arg.value(), funcOp.getArgument(arg.index()));
-
- // Clone operations and build function body
- std::vector<Operation *> clonedOps;
- std::vector<Value> clonedVals;
- for (Operation *slicedOp : slice) {
- Operation *clonedOp =
- clonedOps.emplace_back(builder.clone(*slicedOp, mapper));
- clonedVals.insert(clonedVals.end(), clonedOp->result_begin(),
- clonedOp->result_end());
- }
- // Add return operation
- builder.create<func::ReturnOp>(loc, clonedVals);
-
- // Remove unused function arguments
- size_t currentIndex = 0;
- while (currentIndex < funcOp.getNumArguments()) {
- if (funcOp.getArgument(currentIndex).use_empty())
- funcOp.eraseArgument(currentIndex);
- else
- ++currentIndex;
- }
-
- return funcOp;
-}
-
Query::~Query() = default;
mlir::LogicalResult InvalidQuery::run(llvm::raw_ostream &os,
@@ -131,21 +65,9 @@ mlir::LogicalResult QuitQuery::run(llvm::raw_ostream &os,
mlir::LogicalResult MatchQuery::run(llvm::raw_ostream &os,
QuerySession &qs) const {
- Operation *rootOp = qs.getRootOp();
int matchCount = 0;
std::vector<Operation *> matches =
- matcher::MatchFinder().getMatches(rootOp, matcher);
-
- // An extract call is recognized by considering if the matcher has a name.
- // TODO: Consider making the extract more explicit.
- if (matcher.hasFunctionName()) {
- auto functionName = matcher.getFunctionName();
- Operation *function =
- extractFunction(matches, rootOp->getContext(), functionName);
- os << "\n" << *function << "\n\n";
- return mlir::success();
- }
-
+ matcher::MatchFinder().getMatches(qs.getRootOp(), matcher);
os << "\n";
for (Operation *op : matches) {
os << "Match #" << ++matchCount << ":\n\n";
diff --git a/mlir/test/mlir-query/function-extraction.mlir b/mlir/test/mlir-query/function-extraction.mlir
deleted file mode 100644
index a783f65c6761bc..00000000000000
--- a/mlir/test/mlir-query/function-extraction.mlir
+++ /dev/null
@@ -1,19 +0,0 @@
-// RUN: mlir-query %s -c "m hasOpName(\"arith.mulf\").extract(\"testmul\")" | FileCheck %s
-
-// CHECK: func.func @testmul({{.*}}) -> (f32, f32, f32) {
-// CHECK: %[[MUL0:.*]] = arith.mulf {{.*}} : f32
-// CHECK: %[[MUL1:.*]] = arith.mulf {{.*}}, %[[MUL0]] : f32
-// CHECK: %[[MUL2:.*]] = arith.mulf {{.*}} : f32
-// CHECK-NEXT: return %[[MUL0]], %[[MUL1]], %[[MUL2]] : f32, f32, f32
-
-func.func @mixedOperations(%a: f32, %b: f32, %c: f32) -> f32 {
- %sum0 = arith.addf %a, %b : f32
- %sub0 = arith.subf %sum0, %c : f32
- %mul0 = arith.mulf %a, %sub0 : f32
- %sum1 = arith.addf %b, %c : f32
- %mul1 = arith.mulf %sum1, %mul0 : f32
- %sub2 = arith.subf %mul1, %a : f32
- %sum2 = arith.addf %mul1, %b : f32
- %mul2 = arith.mulf %sub2, %sum2 : f32
- return %mul2 : f32
-}
More information about the Mlir-commits mailing list