[Mlir-commits] [mlir] c66f2d0 - [mlir-query] Add function extraction feature to mlir-query

Jacques Pienaar llvmlistbot at llvm.org
Thu Feb 29 07:46:57 PST 2024


Author: Devajith Valaparambil Sreeramaswamy
Date: 2024-02-29T07:46:49-08:00
New Revision: c66f2d0c4a46ba66fb98a2cab4e63ad90888a261

URL: https://github.com/llvm/llvm-project/commit/c66f2d0c4a46ba66fb98a2cab4e63ad90888a261
DIFF: https://github.com/llvm/llvm-project/commit/c66f2d0c4a46ba66fb98a2cab4e63ad90888a261.diff

LOG: [mlir-query] Add function extraction feature to mlir-query

This adds an "extract" modifier that extracts all matches into a new
function. The current implementation is very direct: every operand becomes
a function argument (arguments for operands produced by other matched ops
are dropped) and every result becomes a return value.

Differential Revision: https://reviews.llvm.org/D158693

Added: 
    mlir/test/mlir-query/function-extraction.mlir

Modified: 
    mlir/include/mlir/Query/Matcher/ErrorBuilder.h
    mlir/include/mlir/Query/Matcher/MatchersInternal.h
    mlir/lib/Query/Matcher/Diagnostics.cpp
    mlir/lib/Query/Matcher/Parser.cpp
    mlir/lib/Query/Matcher/Parser.h
    mlir/lib/Query/Matcher/RegistryManager.cpp
    mlir/lib/Query/Matcher/RegistryManager.h
    mlir/lib/Query/Query.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Query/Matcher/ErrorBuilder.h b/mlir/include/mlir/Query/Matcher/ErrorBuilder.h
index 1073daed8703f5..08f1f415cbd3e5 100644
--- a/mlir/include/mlir/Query/Matcher/ErrorBuilder.h
+++ b/mlir/include/mlir/Query/Matcher/ErrorBuilder.h
@@ -37,8 +37,12 @@ enum class ErrorType {
   None,
 
   // Parser Errors
+  ParserChainedExprInvalidArg,
+  ParserChainedExprNoCloseParen,
+  ParserChainedExprNoOpenParen,
   ParserFailedToBuildMatcher,
   ParserInvalidToken,
+  ParserMalformedChainedExpr,
   ParserNoCloseParen,
   ParserNoCode,
   ParserNoComma,
@@ -50,9 +54,10 @@ enum class ErrorType {
 
   // Registry Errors
   RegistryMatcherNotFound,
+  RegistryNotBindable,
   RegistryValueNotFound,
   RegistryWrongArgCount,
-  RegistryWrongArgType
+  RegistryWrongArgType,
 };
 
 void addError(Diagnostics *error, SourceRange range, ErrorType errorType,

diff  --git a/mlir/include/mlir/Query/Matcher/MatchersInternal.h b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
index 67455be592393b..117f7d4edef9e3 100644
--- a/mlir/include/mlir/Query/Matcher/MatchersInternal.h
+++ b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
@@ -63,8 +63,15 @@ class DynMatcher {
 
   bool match(Operation *op) const { return implementation->match(op); }
 
+  /// Records the name of the function that matches should be extracted into.
+  void setFunctionName(StringRef name) { functionName = name.str(); }
+  /// True when a non-empty name was set, e.g. via `.extract("name")`.
+  bool hasFunctionName() const { return !functionName.empty(); }
+  StringRef getFunctionName() const { return functionName; }
+
 private:
   llvm::IntrusiveRefCntPtr<MatcherInterface> implementation;
+  std::string functionName;
 };
 
 } // namespace mlir::query::matcher

diff  --git a/mlir/lib/Query/Matcher/Diagnostics.cpp b/mlir/lib/Query/Matcher/Diagnostics.cpp
index 10468dbcc53067..2a137e8fdfab0d 100644
--- a/mlir/lib/Query/Matcher/Diagnostics.cpp
+++ b/mlir/lib/Query/Matcher/Diagnostics.cpp
@@ -38,6 +38,8 @@ static llvm::StringRef errorTypeToFormatString(ErrorType type) {
     return "Incorrect type for arg $0. (Expected = $1) != (Actual = $2)";
   case ErrorType::RegistryValueNotFound:
     return "Value not found: $0";
+  case ErrorType::RegistryNotBindable:
+    return "Matcher does not support binding.";
 
   case ErrorType::ParserStringError:
     return "Error parsing string token: <$0>";
@@ -57,6 +59,14 @@ static llvm::StringRef errorTypeToFormatString(ErrorType type) {
     return "Unexpected end of code.";
   case ErrorType::ParserOverloadedType:
     return "Input value has unresolved overloaded type: $0";
+  case ErrorType::ParserMalformedChainedExpr:
+    return "Period not followed by valid chained call.";
+  case ErrorType::ParserChainedExprInvalidArg:
+    return "Missing/Invalid argument for the chained call.";
+  case ErrorType::ParserChainedExprNoCloseParen:
+    return "Missing ')' for the chained call.";
+  case ErrorType::ParserChainedExprNoOpenParen:
+    return "Missing '(' for the chained call.";
   case ErrorType::ParserFailedToBuildMatcher:
     return "Failed to build matcher: $0.";
 

diff  --git a/mlir/lib/Query/Matcher/Parser.cpp b/mlir/lib/Query/Matcher/Parser.cpp
index 30eb4801fc03c1..3609e24f9939f7 100644
--- a/mlir/lib/Query/Matcher/Parser.cpp
+++ b/mlir/lib/Query/Matcher/Parser.cpp
@@ -26,12 +26,17 @@ struct Parser::TokenInfo {
     text = newText;
   }
 
+  // Known identifiers.
+  static const char *const ID_Extract;
+
   llvm::StringRef text;
   TokenKind kind = TokenKind::Eof;
   SourceRange range;
   VariantValue value;
 };
 
+const char *const Parser::TokenInfo::ID_Extract{"extract"};
+
 class Parser::CodeTokenizer {
 public:
   // Constructor with matcherCode and error
@@ -298,6 +303,36 @@ bool Parser::parseIdentifierPrefixImpl(VariantValue *value) {
   return parseMatcherExpressionImpl(nameToken, openToken, ctor, value);
 }
 
+bool Parser::parseChainedExpression(std::string &argument) {
+  // Expected shape: '(' <string literal> ')', as in .extract("foo").
+  // All three tokens are consumed before validation; a premature EOF simply
+  // yields a token of the wrong kind and is reported by the checks below.
+  const TokenInfo open = tokenizer->consumeNextToken();
+  const TokenInfo arg = tokenizer->consumeNextTokenIgnoreNewlines();
+  const TokenInfo close = tokenizer->consumeNextTokenIgnoreNewlines();
+
+  if (open.kind != TokenKind::OpenParen) {
+    error->addError(open.range, ErrorType::ParserChainedExprNoOpenParen);
+    return false;
+  }
+
+  const bool argIsString =
+      arg.kind == TokenKind::Literal && arg.value.isString();
+  if (!argIsString) {
+    error->addError(arg.range, ErrorType::ParserChainedExprInvalidArg);
+    return false;
+  }
+
+  if (close.kind != TokenKind::CloseParen) {
+    error->addError(close.range, ErrorType::ParserChainedExprNoCloseParen);
+    return false;
+  }
+
+  // Well-formed call: hand the string literal back to the caller.
+  argument = arg.value.getString();
+  return true;
+}
+
 // Parse the arguments of a matcher
 bool Parser::parseMatcherArgs(std::vector<ParserValue> &args, MatcherCtor ctor,
                               const TokenInfo &nameToken, TokenInfo &endToken) {
@@ -364,13 +399,34 @@ bool Parser::parseMatcherExpressionImpl(const TokenInfo &nameToken,
     return false;
   }
 
+  std::string functionName;
+  if (tokenizer->peekNextToken().kind == TokenKind::Period) {
+    tokenizer->consumeNextToken();
+    TokenInfo chainCallToken = tokenizer->consumeNextToken();
+    if (chainCallToken.kind == TokenKind::CodeCompletion) {
+      addCompletion(chainCallToken, MatcherCompletion("extract(\"", "extract"));
+      return false;
+    }
+
+    if (chainCallToken.kind != TokenKind::Ident ||
+        chainCallToken.text != TokenInfo::ID_Extract) {
+      error->addError(chainCallToken.range,
+                      ErrorType::ParserMalformedChainedExpr);
+      return false;
+    }
+
+    if (chainCallToken.text == TokenInfo::ID_Extract &&
+        !parseChainedExpression(functionName))
+      return false;
+  }
+
   if (!ctor)
     return false;
   // Merge the start and end infos.
   SourceRange matcherRange = nameToken.range;
   matcherRange.end = endToken.range.end;
-  VariantMatcher result =
-      sema->actOnMatcherExpression(*ctor, matcherRange, args, error);
+  VariantMatcher result = sema->actOnMatcherExpression(
+      *ctor, matcherRange, functionName, args, error);
   if (result.isNull())
     return false;
   *value = result;
@@ -470,9 +526,10 @@ Parser::RegistrySema::lookupMatcherCtor(llvm::StringRef matcherName) {
 }
 
 VariantMatcher Parser::RegistrySema::actOnMatcherExpression(
-    MatcherCtor ctor, SourceRange nameRange, llvm::ArrayRef<ParserValue> args,
-    Diagnostics *error) {
-  return RegistryManager::constructMatcher(ctor, nameRange, args, error);
+    MatcherCtor ctor, SourceRange nameRange, llvm::StringRef functionName,
+    llvm::ArrayRef<ParserValue> args, Diagnostics *error) {
+  return RegistryManager::constructMatcher(
+      ctor, nameRange, functionName, args, error);
 }
 
 std::vector<ArgKind> Parser::RegistrySema::getAcceptedCompletionTypes(

diff  --git a/mlir/lib/Query/Matcher/Parser.h b/mlir/lib/Query/Matcher/Parser.h
index f049af34e9c907..58968023022d56 100644
--- a/mlir/lib/Query/Matcher/Parser.h
+++ b/mlir/lib/Query/Matcher/Parser.h
@@ -64,10 +64,9 @@ class Parser {
 
     // Process a matcher expression. The caller takes ownership of the Matcher
     // object returned.
-    virtual VariantMatcher
-    actOnMatcherExpression(MatcherCtor ctor, SourceRange nameRange,
-                           llvm::ArrayRef<ParserValue> args,
-                           Diagnostics *error) = 0;
+    virtual VariantMatcher actOnMatcherExpression(
+        MatcherCtor ctor, SourceRange nameRange, llvm::StringRef functionName,
+        llvm::ArrayRef<ParserValue> args, Diagnostics *error) = 0;
 
     // Look up a matcher by name in the matcher name found by the parser.
     virtual std::optional<MatcherCtor>
@@ -93,10 +92,11 @@ class Parser {
     std::optional<MatcherCtor>
     lookupMatcherCtor(llvm::StringRef matcherName) override;
 
-    VariantMatcher actOnMatcherExpression(MatcherCtor ctor,
-                                          SourceRange nameRange,
-                                          llvm::ArrayRef<ParserValue> args,
-                                          Diagnostics *error) override;
+    VariantMatcher actOnMatcherExpression(MatcherCtor ctor,
+                                          SourceRange nameRange,
+                                          llvm::StringRef functionName,
+                                          llvm::ArrayRef<ParserValue> args,
+                                          Diagnostics *error) override;
 
     std::vector<ArgKind> getAcceptedCompletionTypes(
         llvm::ArrayRef<std::pair<MatcherCtor, unsigned>> context) override;
@@ -153,6 +153,8 @@ class Parser {
   Parser(CodeTokenizer *tokenizer, const Registry &matcherRegistry,
          const NamedValueMap *namedValues, Diagnostics *error);
 
+  bool parseChainedExpression(std::string &argument);
+
   bool parseExpressionImpl(VariantValue *value);
 
   bool parseMatcherArgs(std::vector<ParserValue> &args, MatcherCtor ctor,

diff  --git a/mlir/lib/Query/Matcher/RegistryManager.cpp b/mlir/lib/Query/Matcher/RegistryManager.cpp
index 01856aa8ffa67f..8c9197f4d00981 100644
--- a/mlir/lib/Query/Matcher/RegistryManager.cpp
+++ b/mlir/lib/Query/Matcher/RegistryManager.cpp
@@ -132,8 +132,19 @@ RegistryManager::getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes,
 
 VariantMatcher RegistryManager::constructMatcher(
     MatcherCtor ctor, internal::SourceRange nameRange,
-    llvm::ArrayRef<ParserValue> args, internal::Diagnostics *error) {
-  return ctor->create(nameRange, args, error);
+    llvm::StringRef functionName, llvm::ArrayRef<ParserValue> args,
+    internal::Diagnostics *error) {
+  VariantMatcher matcher = ctor->create(nameRange, args, error);
+  // Attach the .extract() name only to a successfully built single matcher.
+  if (matcher.isNull() || functionName.empty())
+    return matcher;
+  std::optional<DynMatcher> single = matcher.getDynMatcher();
+  if (!single) {
+    error->addError(nameRange, internal::ErrorType::RegistryNotBindable);
+    return {};
+  }
+  single->setFunctionName(functionName);
+  return VariantMatcher::SingleMatcher(*single);
 }
 
 } // namespace mlir::query::matcher

diff  --git a/mlir/lib/Query/Matcher/RegistryManager.h b/mlir/lib/Query/Matcher/RegistryManager.h
index 5f2867261225e7..e2026e97f83dcb 100644
--- a/mlir/lib/Query/Matcher/RegistryManager.h
+++ b/mlir/lib/Query/Matcher/RegistryManager.h
@@ -61,6 +61,7 @@ class RegistryManager {
 
   static VariantMatcher constructMatcher(MatcherCtor ctor,
                                          internal::SourceRange nameRange,
+                                         llvm::StringRef functionName,
                                          ArrayRef<ParserValue> args,
                                          internal::Diagnostics *error);
 };

diff  --git a/mlir/lib/Query/Query.cpp b/mlir/lib/Query/Query.cpp
index 5c42e5a5f0a116..27db52b37dade0 100644
--- a/mlir/lib/Query/Query.cpp
+++ b/mlir/lib/Query/Query.cpp
@@ -8,6 +8,8 @@
 
 #include "mlir/Query/Query.h"
 #include "QueryParser.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/IRMapping.h"
 #include "mlir/Query/Matcher/MatchFinder.h"
 #include "mlir/Query/QuerySession.h"
 #include "mlir/Support/LogicalResult.h"
@@ -34,6 +36,70 @@ static void printMatch(llvm::raw_ostream &os, QuerySession &qs, Operation *op,
                                      "\"" + binding + "\" binds here");
 }
 
+// Builds a detached func.func named `functionName` whose body clones `ops`.
+// Operands become function arguments (unused ones are erased at the end) and
+// every result of every cloned op is returned. The caller owns the returned
+// op and must erase it. TODO: Reuse outside the query context.
+static Operation *extractFunction(std::vector<Operation *> &ops,
+                                  MLIRContext *context,
+                                  llvm::StringRef functionName) {
+  context->loadDialect<func::FuncDialect>();
+  OpBuilder builder(context);
+
+  // Ops to clone, candidate inputs (one per operand use), and result types.
+  std::vector<Operation *> slice;
+  std::vector<Value> values;
+  std::vector<Type> outputTypes;
+  for (Operation *op : ops) {
+    // A matched func.return is not cloned; its operands still become
+    // candidate arguments and are erased below if nothing ends up using them.
+    if (!isa<func::ReturnOp>(op))
+      slice.push_back(op);
+
+    // All results are returned by the extracted function.
+    llvm::append_range(outputTypes, op->getResultTypes());
+
+    // Every operand is a candidate input to the function.
+    llvm::append_range(values, op->getOperands());
+  }
+
+  // Create the detached function with one argument per collected operand.
+  FunctionType funcType =
+      builder.getFunctionType(ValueRange(values), outputTypes);
+  Location loc = builder.getUnknownLoc();
+  func::FuncOp funcOp = func::FuncOp::create(loc, functionName, funcType);
+  builder.setInsertionPointToEnd(funcOp.addEntryBlock());
+
+  // Map each original operand to its function argument. Duplicated operands
+  // keep only the last mapping; the shadowed arguments end up unused.
+  IRMapping mapper;
+  for (const auto &arg : llvm::enumerate(values))
+    mapper.map(arg.value(), funcOp.getArgument(arg.index()));
+
+  // Clone the ops in match order. Cloning records each original result ->
+  // cloned result in `mapper`, so later clones consume the cloned result
+  // rather than the argument created for it above.
+  std::vector<Value> clonedVals;
+  for (Operation *slicedOp : slice) {
+    Operation *clonedOp = builder.clone(*slicedOp, mapper);
+    llvm::append_range(clonedVals, clonedOp->getResults());
+  }
+  builder.create<func::ReturnOp>(loc, clonedVals);
+
+  // Erase arguments left without uses (operands produced by other matched
+  // ops, duplicates, and operands only used by a matched func.return).
+  // Indices shift down after an erase, so only advance past kept arguments.
+  size_t currentIndex = 0;
+  while (currentIndex < funcOp.getNumArguments()) {
+    if (funcOp.getArgument(currentIndex).use_empty())
+      funcOp.eraseArgument(currentIndex);
+    else
+      ++currentIndex;
+  }
+
+  return funcOp;
+}
+
 Query::~Query() = default;
 
 mlir::LogicalResult InvalidQuery::run(llvm::raw_ostream &os,
@@ -65,9 +131,21 @@ mlir::LogicalResult QuitQuery::run(llvm::raw_ostream &os,
 
 mlir::LogicalResult MatchQuery::run(llvm::raw_ostream &os,
                                     QuerySession &qs) const {
+  Operation *rootOp = qs.getRootOp();
   int matchCount = 0;
   std::vector<Operation *> matches =
-      matcher::MatchFinder().getMatches(qs.getRootOp(), matcher);
+      matcher::MatchFinder().getMatches(rootOp, matcher);
+
+  // An extract call is recognized by considering if the matcher has a name.
+  // TODO: Consider making the extract more explicit.
+  if (matcher.hasFunctionName()) {
+    auto functionName = matcher.getFunctionName();
+    Operation *function =
+        extractFunction(matches, rootOp->getContext(), functionName);
+    os << "\n" << *function << "\n\n";
+    return mlir::success();
+  }
+
   os << "\n";
   for (Operation *op : matches) {
     os << "Match #" << ++matchCount << ":\n\n";

diff  --git a/mlir/test/mlir-query/function-extraction.mlir b/mlir/test/mlir-query/function-extraction.mlir
new file mode 100644
index 00000000000000..a783f65c6761bc
--- /dev/null
+++ b/mlir/test/mlir-query/function-extraction.mlir
@@ -0,0 +1,19 @@
+// RUN: mlir-query %s -c "m hasOpName(\"arith.mulf\").extract(\"testmul\")" | FileCheck %s
+
+// CHECK: func.func @testmul(%{{[a-z0-9_]+}}: f32, %{{[a-z0-9_]+}}: f32, %{{[a-z0-9_]+}}: f32, %{{[a-z0-9_]+}}: f32, %{{[a-z0-9_]+}}: f32) -> (f32, f32, f32) {
+// CHECK-NEXT:  %[[MUL0:.*]] = arith.mulf {{.*}} : f32
+// CHECK-NEXT:  %[[MUL1:.*]] = arith.mulf {{.*}}, %[[MUL0]] : f32
+// CHECK-NEXT:  %[[MUL2:.*]] = arith.mulf {{.*}} : f32
+// CHECK-NEXT:  return %[[MUL0]], %[[MUL1]], %[[MUL2]] : f32, f32, f32
+
+func.func @mixedOperations(%a: f32, %b: f32, %c: f32) -> f32 {
+  %sum0 = arith.addf %a, %b : f32
+  %sub0 = arith.subf %sum0, %c : f32
+  %mul0 = arith.mulf %a, %sub0 : f32
+  %sum1 = arith.addf %b, %c : f32
+  %mul1 = arith.mulf %sum1, %mul0 : f32
+  %sub2 = arith.subf %mul1, %a : f32
+  %sum2 = arith.addf %mul1, %b : f32
+  %mul2 = arith.mulf %sub2, %sum2 : f32
+  return %mul2 : f32
+}


        


More information about the Mlir-commits mailing list