[Mlir-commits] [mlir] 58b44c8 - Reapply "Reapply "[mlir-query] Add function extraction feature to mlir-query""
Jacques Pienaar
llvmlistbot at llvm.org
Sun Mar 3 05:57:04 PST 2024
Author: Jacques Pienaar
Date: 2024-03-03T05:56:56-08:00
New Revision: 58b44c8102afb0e76d1cb70d4a5d089f70d2f657
URL: https://github.com/llvm/llvm-project/commit/58b44c8102afb0e76d1cb70d4a5d089f70d2f657
DIFF: https://github.com/llvm/llvm-project/commit/58b44c8102afb0e76d1cb70d4a5d089f70d2f657.diff
LOG: Reapply "Reapply "[mlir-query] Add function extraction feature to mlir-query""
Fix the ASAN failure by erasing the extracted op after it has been printed.
This reverts commit 732a5cba8c739ed40a7280b5d74ca717910c2c4c.
Added:
mlir/test/mlir-query/function-extraction.mlir
Modified:
mlir/include/mlir/Query/Matcher/ErrorBuilder.h
mlir/include/mlir/Query/Matcher/MatchersInternal.h
mlir/lib/Query/CMakeLists.txt
mlir/lib/Query/Matcher/Diagnostics.cpp
mlir/lib/Query/Matcher/Parser.cpp
mlir/lib/Query/Matcher/Parser.h
mlir/lib/Query/Matcher/RegistryManager.cpp
mlir/lib/Query/Matcher/RegistryManager.h
mlir/lib/Query/Query.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Query/Matcher/ErrorBuilder.h b/mlir/include/mlir/Query/Matcher/ErrorBuilder.h
index 1073daed8703f5..08f1f415cbd3e5 100644
--- a/mlir/include/mlir/Query/Matcher/ErrorBuilder.h
+++ b/mlir/include/mlir/Query/Matcher/ErrorBuilder.h
@@ -37,8 +37,12 @@ enum class ErrorType {
None,
// Parser Errors
+ ParserChainedExprInvalidArg,
+ ParserChainedExprNoCloseParen,
+ ParserChainedExprNoOpenParen,
ParserFailedToBuildMatcher,
ParserInvalidToken,
+ ParserMalformedChainedExpr,
ParserNoCloseParen,
ParserNoCode,
ParserNoComma,
@@ -50,9 +54,10 @@ enum class ErrorType {
// Registry Errors
RegistryMatcherNotFound,
+ RegistryNotBindable,
RegistryValueNotFound,
RegistryWrongArgCount,
- RegistryWrongArgType
+ RegistryWrongArgType,
};
void addError(Diagnostics *error, SourceRange range, ErrorType errorType,
diff --git a/mlir/include/mlir/Query/Matcher/MatchersInternal.h b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
index 67455be592393b..117f7d4edef9e3 100644
--- a/mlir/include/mlir/Query/Matcher/MatchersInternal.h
+++ b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
@@ -63,8 +63,15 @@ class DynMatcher {
bool match(Operation *op) const { return implementation->match(op); }
+ void setFunctionName(StringRef name) { functionName = name.str(); }; // Stores an owned copy of `name`.
+
+ bool hasFunctionName() const { return !functionName.empty(); }; // True when the stored name is non-empty.
+
+ StringRef getFunctionName() const { return functionName; }; // View of the stored name; empty if none was set.
+
private:
llvm::IntrusiveRefCntPtr<MatcherInterface> implementation;
+ std::string functionName;
};
} // namespace mlir::query::matcher
diff --git a/mlir/lib/Query/CMakeLists.txt b/mlir/lib/Query/CMakeLists.txt
index 817583e94c5222..7ecbf6e628b318 100644
--- a/mlir/lib/Query/CMakeLists.txt
+++ b/mlir/lib/Query/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_library(MLIRQuery
${MLIR_MAIN_INCLUDE_DIR}/mlir/Query
LINK_LIBS PUBLIC
+ MLIRFuncDialect
MLIRQueryMatcher
)
diff --git a/mlir/lib/Query/Matcher/Diagnostics.cpp b/mlir/lib/Query/Matcher/Diagnostics.cpp
index 10468dbcc53067..2a137e8fdfab0d 100644
--- a/mlir/lib/Query/Matcher/Diagnostics.cpp
+++ b/mlir/lib/Query/Matcher/Diagnostics.cpp
@@ -38,6 +38,8 @@ static llvm::StringRef errorTypeToFormatString(ErrorType type) {
return "Incorrect type for arg $0. (Expected = $1) != (Actual = $2)";
case ErrorType::RegistryValueNotFound:
return "Value not found: $0";
+ case ErrorType::RegistryNotBindable:
+ return "Matcher does not support binding.";
case ErrorType::ParserStringError:
return "Error parsing string token: <$0>";
@@ -57,6 +59,14 @@ static llvm::StringRef errorTypeToFormatString(ErrorType type) {
return "Unexpected end of code.";
case ErrorType::ParserOverloadedType:
return "Input value has unresolved overloaded type: $0";
+ case ErrorType::ParserMalformedChainedExpr:
+ return "Period not followed by valid chained call.";
+ case ErrorType::ParserChainedExprInvalidArg:
+ return "Missing/Invalid argument for the chained call.";
+ case ErrorType::ParserChainedExprNoCloseParen:
+ return "Missing ')' for the chained call.";
+ case ErrorType::ParserChainedExprNoOpenParen:
+ return "Missing '(' for the chained call.";
case ErrorType::ParserFailedToBuildMatcher:
return "Failed to build matcher: $0.";
diff --git a/mlir/lib/Query/Matcher/Parser.cpp b/mlir/lib/Query/Matcher/Parser.cpp
index 30eb4801fc03c1..3609e24f9939f7 100644
--- a/mlir/lib/Query/Matcher/Parser.cpp
+++ b/mlir/lib/Query/Matcher/Parser.cpp
@@ -26,12 +26,17 @@ struct Parser::TokenInfo {
text = newText;
}
+ // Known identifiers.
+ static const char *const ID_Extract;
+
llvm::StringRef text;
TokenKind kind = TokenKind::Eof;
SourceRange range;
VariantValue value;
};
+const char *const Parser::TokenInfo::ID_Extract = "extract";
+
class Parser::CodeTokenizer {
public:
// Constructor with matcherCode and error
@@ -298,6 +303,36 @@ bool Parser::parseIdentifierPrefixImpl(VariantValue *value) {
return parseMatcherExpressionImpl(nameToken, openToken, ctor, value);
}
+bool Parser::parseChainedExpression(std::string &argument) {
+ // Parse the parenthesized string argument of a chained call, e.g. ("foo") in .extract("foo").
+ // All three tokens are consumed before any is validated. Note: EOF is handled inside the
+ // consume functions and would fail below when checking token kind.
+ const TokenInfo openToken = tokenizer->consumeNextToken();
+ const TokenInfo argumentToken = tokenizer->consumeNextTokenIgnoreNewlines();
+ const TokenInfo closeToken = tokenizer->consumeNextTokenIgnoreNewlines();
+
+ if (openToken.kind != TokenKind::OpenParen) {
+ error->addError(openToken.range, ErrorType::ParserChainedExprNoOpenParen);
+ return false;
+ }
+
+ if (argumentToken.kind != TokenKind::Literal ||
+ !argumentToken.value.isString()) {
+ error->addError(argumentToken.range,
+ ErrorType::ParserChainedExprInvalidArg);
+ return false;
+ }
+
+ if (closeToken.kind != TokenKind::CloseParen) {
+ error->addError(closeToken.range, ErrorType::ParserChainedExprNoCloseParen);
+ return false;
+ }
+
+ // Every check passed: hand the string literal back through `argument` and report success.
+ argument = argumentToken.value.getString();
+ return true;
+}
+
// Parse the arguments of a matcher
bool Parser::parseMatcherArgs(std::vector<ParserValue> &args, MatcherCtor ctor,
const TokenInfo &nameToken, TokenInfo &endToken) {
@@ -364,13 +399,34 @@ bool Parser::parseMatcherExpressionImpl(const TokenInfo &nameToken,
return false;
}
+ std::string functionName;
+ if (tokenizer->peekNextToken().kind == TokenKind::Period) {
+ tokenizer->consumeNextToken();
+ TokenInfo chainCallToken = tokenizer->consumeNextToken();
+ if (chainCallToken.kind == TokenKind::CodeCompletion) {
+ addCompletion(chainCallToken, MatcherCompletion("extract(\"", "extract"));
+ return false;
+ }
+
+ if (chainCallToken.kind != TokenKind::Ident ||
+ chainCallToken.text != TokenInfo::ID_Extract) {
+ error->addError(chainCallToken.range,
+ ErrorType::ParserMalformedChainedExpr);
+ return false;
+ }
+
+ if (chainCallToken.text == TokenInfo::ID_Extract &&
+ !parseChainedExpression(functionName))
+ return false;
+ }
+
if (!ctor)
return false;
// Merge the start and end infos.
SourceRange matcherRange = nameToken.range;
matcherRange.end = endToken.range.end;
- VariantMatcher result =
- sema->actOnMatcherExpression(*ctor, matcherRange, args, error);
+ VariantMatcher result = sema->actOnMatcherExpression(
+ *ctor, matcherRange, functionName, args, error);
if (result.isNull())
return false;
*value = result;
@@ -470,9 +526,10 @@ Parser::RegistrySema::lookupMatcherCtor(llvm::StringRef matcherName) {
}
VariantMatcher Parser::RegistrySema::actOnMatcherExpression(
- MatcherCtor ctor, SourceRange nameRange, llvm::ArrayRef<ParserValue> args,
- Diagnostics *error) {
- return RegistryManager::constructMatcher(ctor, nameRange, args, error);
+ MatcherCtor ctor, SourceRange nameRange, llvm::StringRef functionName,
+ llvm::ArrayRef<ParserValue> args, Diagnostics *error) {
+ return RegistryManager::constructMatcher(ctor, nameRange, functionName, args,
+ error); // Forwards every argument unchanged, including the new functionName.
}
std::vector<ArgKind> Parser::RegistrySema::getAcceptedCompletionTypes(
diff --git a/mlir/lib/Query/Matcher/Parser.h b/mlir/lib/Query/Matcher/Parser.h
index f049af34e9c907..58968023022d56 100644
--- a/mlir/lib/Query/Matcher/Parser.h
+++ b/mlir/lib/Query/Matcher/Parser.h
@@ -64,10 +64,9 @@ class Parser {
// Process a matcher expression. The caller takes ownership of the Matcher
// object returned.
- virtual VariantMatcher
- actOnMatcherExpression(MatcherCtor ctor, SourceRange nameRange,
- llvm::ArrayRef<ParserValue> args,
- Diagnostics *error) = 0;
+ virtual VariantMatcher actOnMatcherExpression(
+ MatcherCtor ctor, SourceRange nameRange, llvm::StringRef functionName,
+ llvm::ArrayRef<ParserValue> args, Diagnostics *error) = 0;
// Look up a matcher by name in the matcher name found by the parser.
virtual std::optional<MatcherCtor>
@@ -93,10 +92,11 @@ class Parser {
std::optional<MatcherCtor>
lookupMatcherCtor(llvm::StringRef matcherName) override;
- VariantMatcher actOnMatcherExpression(MatcherCtor ctor,
- SourceRange nameRange,
- llvm::ArrayRef<ParserValue> args,
- Diagnostics *error) override;
+ VariantMatcher actOnMatcherExpression(MatcherCtor Ctor,
+ SourceRange NameRange,
+ StringRef functionName,
+ ArrayRef<ParserValue> Args,
+ Diagnostics *Error) override;
std::vector<ArgKind> getAcceptedCompletionTypes(
llvm::ArrayRef<std::pair<MatcherCtor, unsigned>> context) override;
@@ -153,6 +153,8 @@ class Parser {
Parser(CodeTokenizer *tokenizer, const Registry &matcherRegistry,
const NamedValueMap *namedValues, Diagnostics *error);
+ bool parseChainedExpression(std::string &argument);
+
bool parseExpressionImpl(VariantValue *value);
bool parseMatcherArgs(std::vector<ParserValue> &args, MatcherCtor ctor,
diff --git a/mlir/lib/Query/Matcher/RegistryManager.cpp b/mlir/lib/Query/Matcher/RegistryManager.cpp
index 01856aa8ffa67f..8c9197f4d00981 100644
--- a/mlir/lib/Query/Matcher/RegistryManager.cpp
+++ b/mlir/lib/Query/Matcher/RegistryManager.cpp
@@ -132,8 +132,19 @@ RegistryManager::getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes,
VariantMatcher RegistryManager::constructMatcher(
MatcherCtor ctor, internal::SourceRange nameRange,
- llvm::ArrayRef<ParserValue> args, internal::Diagnostics *error) {
- return ctor->create(nameRange, args, error);
+ llvm::StringRef functionName, llvm::ArrayRef<ParserValue> args,
+ internal::Diagnostics *error) {
+ VariantMatcher out = ctor->create(nameRange, args, error);
+ if (functionName.empty() || out.isNull()) // No name requested, or construction failed: return as built.
+ return out;
+
+ if (std::optional<DynMatcher> result = out.getDynMatcher()) { // Tag a copy of the matcher with the name and re-wrap it.
+ result->setFunctionName(functionName);
+ return VariantMatcher::SingleMatcher(*result);
+ }
+
+ error->addError(nameRange, internal::ErrorType::RegistryNotBindable); // No DynMatcher to carry the name.
+ return {};
}
} // namespace mlir::query::matcher
diff --git a/mlir/lib/Query/Matcher/RegistryManager.h b/mlir/lib/Query/Matcher/RegistryManager.h
index 5f2867261225e7..e2026e97f83dcb 100644
--- a/mlir/lib/Query/Matcher/RegistryManager.h
+++ b/mlir/lib/Query/Matcher/RegistryManager.h
@@ -61,6 +61,7 @@ class RegistryManager {
static VariantMatcher constructMatcher(MatcherCtor ctor,
internal::SourceRange nameRange,
+ llvm::StringRef functionName,
ArrayRef<ParserValue> args,
internal::Diagnostics *error);
};
diff --git a/mlir/lib/Query/Query.cpp b/mlir/lib/Query/Query.cpp
index 5c42e5a5f0a116..7fdbbd181234b6 100644
--- a/mlir/lib/Query/Query.cpp
+++ b/mlir/lib/Query/Query.cpp
@@ -8,6 +8,8 @@
#include "mlir/Query/Query.h"
#include "QueryParser.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/IRMapping.h"
#include "mlir/Query/Matcher/MatchFinder.h"
#include "mlir/Query/QuerySession.h"
#include "mlir/Support/LogicalResult.h"
@@ -34,6 +36,70 @@ static void printMatch(llvm::raw_ostream &os, QuerySession &qs, Operation *op,
"\"" + binding + "\" binds here");
}
+// Clones `ops` into a new func::FuncOp named `functionName` that returns every cloned result.
+// TODO: Extract into a helper function that can be reused outside query context.
+static Operation *extractFunction(std::vector<Operation *> &ops,
+ MLIRContext *context,
+ llvm::StringRef functionName) {
+ context->loadDialect<func::FuncDialect>();
+ OpBuilder builder(context);
+
+ // Gather the ops to clone, the candidate input values, and the result types.
+ std::vector<Operation *> slice;
+ std::vector<Value> values;
+ std::vector<Type> outputTypes;
+
+ for (auto *op : ops) {
+ // A func::ReturnOp is not cloned, but its operands are still recorded below.
+ if (!isa<func::ReturnOp>(op))
+ slice.push_back(op);
+
+ // Every result type of every op becomes a result type of the function.
+ outputTypes.insert(outputTypes.end(), op->getResults().getTypes().begin(),
+ op->getResults().getTypes().end());
+
+ // Every operand is recorded as a candidate argument, duplicates included.
+ values.insert(values.end(), op->getOperands().begin(),
+ op->getOperands().end());
+ }
+
+ // Create the function at an unknown location, one argument per recorded operand.
+ FunctionType funcType =
+ builder.getFunctionType(ValueRange(values), outputTypes);
+ auto loc = builder.getUnknownLoc();
+ func::FuncOp funcOp = func::FuncOp::create(loc, functionName, funcType);
+
+ builder.setInsertionPointToEnd(funcOp.addEntryBlock());
+
+ // Map each recorded operand to the function argument at the same index.
+ IRMapping mapper;
+ for (const auto &arg : llvm::enumerate(values))
+ mapper.map(arg.value(), funcOp.getArgument(arg.index()));
+
+ // Clone the sliced ops into the body through `mapper`, collecting their results.
+ std::vector<Operation *> clonedOps;
+ std::vector<Value> clonedVals;
+ for (Operation *slicedOp : slice) {
+ Operation *clonedOp =
+ clonedOps.emplace_back(builder.clone(*slicedOp, mapper));
+ clonedVals.insert(clonedVals.end(), clonedOp->result_begin(),
+ clonedOp->result_end());
+ }
+ // Terminate the body by returning all results of the cloned ops.
+ builder.create<func::ReturnOp>(loc, clonedVals);
+
+ // Erase arguments without uses; the index only advances past arguments that are kept.
+ size_t currentIndex = 0;
+ while (currentIndex < funcOp.getNumArguments()) {
+ if (funcOp.getArgument(currentIndex).use_empty())
+ funcOp.eraseArgument(currentIndex);
+ else
+ ++currentIndex;
+ }
+
+ return funcOp;
+}
+
Query::~Query() = default;
mlir::LogicalResult InvalidQuery::run(llvm::raw_ostream &os,
@@ -65,9 +131,22 @@ mlir::LogicalResult QuitQuery::run(llvm::raw_ostream &os,
mlir::LogicalResult MatchQuery::run(llvm::raw_ostream &os,
QuerySession &qs) const {
+ Operation *rootOp = qs.getRootOp();
int matchCount = 0;
std::vector<Operation *> matches =
- matcher::MatchFinder().getMatches(qs.getRootOp(), matcher);
+ matcher::MatchFinder().getMatches(rootOp, matcher);
+
+ // An extract call is recognized by considering if the matcher has a name.
+ // TODO: Consider making the extract more explicit.
+ if (matcher.hasFunctionName()) {
+ auto functionName = matcher.getFunctionName();
+ Operation *function =
+ extractFunction(matches, rootOp->getContext(), functionName);
+ os << "\n" << *function << "\n\n";
+ function->erase();
+ return mlir::success();
+ }
+
os << "\n";
for (Operation *op : matches) {
os << "Match #" << ++matchCount << ":\n\n";
diff --git a/mlir/test/mlir-query/function-extraction.mlir b/mlir/test/mlir-query/function-extraction.mlir
new file mode 100644
index 00000000000000..a783f65c6761bc
--- /dev/null
+++ b/mlir/test/mlir-query/function-extraction.mlir
@@ -0,0 +1,19 @@
+// RUN: mlir-query %s -c "m hasOpName(\"arith.mulf\").extract(\"testmul\")" | FileCheck %s
+
+// CHECK: func.func @testmul({{.*}}) -> (f32, f32, f32) {
+// CHECK: %[[MUL0:.*]] = arith.mulf {{.*}} : f32
+// CHECK: %[[MUL1:.*]] = arith.mulf {{.*}}, %[[MUL0]] : f32
+// CHECK: %[[MUL2:.*]] = arith.mulf {{.*}} : f32
+// CHECK-NEXT: return %[[MUL0]], %[[MUL1]], %[[MUL2]] : f32, f32, f32
+
+func.func @mixedOperations(%a: f32, %b: f32, %c: f32) -> f32 {
+ %sum0 = arith.addf %a, %b : f32
+ %sub0 = arith.subf %sum0, %c : f32
+ %mul0 = arith.mulf %a, %sub0 : f32
+ %sum1 = arith.addf %b, %c : f32
+ %mul1 = arith.mulf %sum1, %mul0 : f32
+ %sub2 = arith.subf %mul1, %a : f32
+ %sum2 = arith.addf %mul1, %b : f32
+ %mul2 = arith.mulf %sub2, %sum2 : f32
+ return %mul2 : f32
+}
More information about the Mlir-commits
mailing list