[Mlir-commits] [mlir] 58b44c8 - Reapply "Reapply "[mlir-query] Add function extraction feature to mlir-query""

Jacques Pienaar llvmlistbot at llvm.org
Sun Mar 3 05:57:04 PST 2024


Author: Jacques Pienaar
Date: 2024-03-03T05:56:56-08:00
New Revision: 58b44c8102afb0e76d1cb70d4a5d089f70d2f657

URL: https://github.com/llvm/llvm-project/commit/58b44c8102afb0e76d1cb70d4a5d089f70d2f657
DIFF: https://github.com/llvm/llvm-project/commit/58b44c8102afb0e76d1cb70d4a5d089f70d2f657.diff

LOG: Reapply "Reapply "[mlir-query] Add function extraction feature to mlir-query""

Fix the ASAN failure by erasing the extracted op after it has been printed.

This reverts commit 732a5cba8c739ed40a7280b5d74ca717910c2c4c.

Added: 
    mlir/test/mlir-query/function-extraction.mlir

Modified: 
    mlir/include/mlir/Query/Matcher/ErrorBuilder.h
    mlir/include/mlir/Query/Matcher/MatchersInternal.h
    mlir/lib/Query/CMakeLists.txt
    mlir/lib/Query/Matcher/Diagnostics.cpp
    mlir/lib/Query/Matcher/Parser.cpp
    mlir/lib/Query/Matcher/Parser.h
    mlir/lib/Query/Matcher/RegistryManager.cpp
    mlir/lib/Query/Matcher/RegistryManager.h
    mlir/lib/Query/Query.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Query/Matcher/ErrorBuilder.h b/mlir/include/mlir/Query/Matcher/ErrorBuilder.h
index 1073daed8703f5..08f1f415cbd3e5 100644
--- a/mlir/include/mlir/Query/Matcher/ErrorBuilder.h
+++ b/mlir/include/mlir/Query/Matcher/ErrorBuilder.h
@@ -37,8 +37,12 @@ enum class ErrorType {
   None,
 
   // Parser Errors
+  ParserChainedExprInvalidArg,
+  ParserChainedExprNoCloseParen,
+  ParserChainedExprNoOpenParen,
   ParserFailedToBuildMatcher,
   ParserInvalidToken,
+  ParserMalformedChainedExpr,
   ParserNoCloseParen,
   ParserNoCode,
   ParserNoComma,
@@ -50,9 +54,10 @@ enum class ErrorType {
 
   // Registry Errors
   RegistryMatcherNotFound,
+  RegistryNotBindable,
   RegistryValueNotFound,
   RegistryWrongArgCount,
-  RegistryWrongArgType
+  RegistryWrongArgType,
 };
 
 void addError(Diagnostics *error, SourceRange range, ErrorType errorType,

diff  --git a/mlir/include/mlir/Query/Matcher/MatchersInternal.h b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
index 67455be592393b..117f7d4edef9e3 100644
--- a/mlir/include/mlir/Query/Matcher/MatchersInternal.h
+++ b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
@@ -63,8 +63,15 @@ class DynMatcher {
 
   bool match(Operation *op) const { return implementation->match(op); }
 
+  void setFunctionName(StringRef name) { functionName = name.str(); };
+
+  bool hasFunctionName() const { return !functionName.empty(); };
+
+  StringRef getFunctionName() const { return functionName; };
+
 private:
   llvm::IntrusiveRefCntPtr<MatcherInterface> implementation;
+  std::string functionName;
 };
 
 } // namespace mlir::query::matcher

diff  --git a/mlir/lib/Query/CMakeLists.txt b/mlir/lib/Query/CMakeLists.txt
index 817583e94c5222..7ecbf6e628b318 100644
--- a/mlir/lib/Query/CMakeLists.txt
+++ b/mlir/lib/Query/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_library(MLIRQuery
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Query
 
   LINK_LIBS PUBLIC
+  MLIRFuncDialect
   MLIRQueryMatcher
   )
 

diff  --git a/mlir/lib/Query/Matcher/Diagnostics.cpp b/mlir/lib/Query/Matcher/Diagnostics.cpp
index 10468dbcc53067..2a137e8fdfab0d 100644
--- a/mlir/lib/Query/Matcher/Diagnostics.cpp
+++ b/mlir/lib/Query/Matcher/Diagnostics.cpp
@@ -38,6 +38,8 @@ static llvm::StringRef errorTypeToFormatString(ErrorType type) {
     return "Incorrect type for arg $0. (Expected = $1) != (Actual = $2)";
   case ErrorType::RegistryValueNotFound:
     return "Value not found: $0";
+  case ErrorType::RegistryNotBindable:
+    return "Matcher does not support binding.";
 
   case ErrorType::ParserStringError:
     return "Error parsing string token: <$0>";
@@ -57,6 +59,14 @@ static llvm::StringRef errorTypeToFormatString(ErrorType type) {
     return "Unexpected end of code.";
   case ErrorType::ParserOverloadedType:
     return "Input value has unresolved overloaded type: $0";
+  case ErrorType::ParserMalformedChainedExpr:
+    return "Period not followed by valid chained call.";
+  case ErrorType::ParserChainedExprInvalidArg:
+    return "Missing/Invalid argument for the chained call.";
+  case ErrorType::ParserChainedExprNoCloseParen:
+    return "Missing ')' for the chained call.";
+  case ErrorType::ParserChainedExprNoOpenParen:
+    return "Missing '(' for the chained call.";
   case ErrorType::ParserFailedToBuildMatcher:
     return "Failed to build matcher: $0.";
 

diff  --git a/mlir/lib/Query/Matcher/Parser.cpp b/mlir/lib/Query/Matcher/Parser.cpp
index 30eb4801fc03c1..3609e24f9939f7 100644
--- a/mlir/lib/Query/Matcher/Parser.cpp
+++ b/mlir/lib/Query/Matcher/Parser.cpp
@@ -26,12 +26,17 @@ struct Parser::TokenInfo {
     text = newText;
   }
 
+  // Known identifiers.
+  static const char *const ID_Extract;
+
   llvm::StringRef text;
   TokenKind kind = TokenKind::Eof;
   SourceRange range;
   VariantValue value;
 };
 
+const char *const Parser::TokenInfo::ID_Extract = "extract";
+
 class Parser::CodeTokenizer {
 public:
   // Constructor with matcherCode and error
@@ -298,6 +303,36 @@ bool Parser::parseIdentifierPrefixImpl(VariantValue *value) {
   return parseMatcherExpressionImpl(nameToken, openToken, ctor, value);
 }
 
+bool Parser::parseChainedExpression(std::string &argument) {
+  // Parse the parenthesized argument to .extract("foo")
+  // Note: EOF is handled inside the consume functions and would fail below when
+  // checking token kind.
+  const TokenInfo openToken = tokenizer->consumeNextToken();
+  const TokenInfo argumentToken = tokenizer->consumeNextTokenIgnoreNewlines();
+  const TokenInfo closeToken = tokenizer->consumeNextTokenIgnoreNewlines();
+
+  if (openToken.kind != TokenKind::OpenParen) {
+    error->addError(openToken.range, ErrorType::ParserChainedExprNoOpenParen);
+    return false;
+  }
+
+  if (argumentToken.kind != TokenKind::Literal ||
+      !argumentToken.value.isString()) {
+    error->addError(argumentToken.range,
+                    ErrorType::ParserChainedExprInvalidArg);
+    return false;
+  }
+
+  if (closeToken.kind != TokenKind::CloseParen) {
+    error->addError(closeToken.range, ErrorType::ParserChainedExprNoCloseParen);
+    return false;
+  }
+
+  // If all checks passed, extract the argument and return true.
+  argument = argumentToken.value.getString();
+  return true;
+}
+
 // Parse the arguments of a matcher
 bool Parser::parseMatcherArgs(std::vector<ParserValue> &args, MatcherCtor ctor,
                               const TokenInfo &nameToken, TokenInfo &endToken) {
@@ -364,13 +399,34 @@ bool Parser::parseMatcherExpressionImpl(const TokenInfo &nameToken,
     return false;
   }
 
+  std::string functionName;
+  if (tokenizer->peekNextToken().kind == TokenKind::Period) {
+    tokenizer->consumeNextToken();
+    TokenInfo chainCallToken = tokenizer->consumeNextToken();
+    if (chainCallToken.kind == TokenKind::CodeCompletion) {
+      addCompletion(chainCallToken, MatcherCompletion("extract(\"", "extract"));
+      return false;
+    }
+
+    if (chainCallToken.kind != TokenKind::Ident ||
+        chainCallToken.text != TokenInfo::ID_Extract) {
+      error->addError(chainCallToken.range,
+                      ErrorType::ParserMalformedChainedExpr);
+      return false;
+    }
+
+    if (chainCallToken.text == TokenInfo::ID_Extract &&
+        !parseChainedExpression(functionName))
+      return false;
+  }
+
   if (!ctor)
     return false;
   // Merge the start and end infos.
   SourceRange matcherRange = nameToken.range;
   matcherRange.end = endToken.range.end;
-  VariantMatcher result =
-      sema->actOnMatcherExpression(*ctor, matcherRange, args, error);
+  VariantMatcher result = sema->actOnMatcherExpression(
+      *ctor, matcherRange, functionName, args, error);
   if (result.isNull())
     return false;
   *value = result;
@@ -470,9 +526,10 @@ Parser::RegistrySema::lookupMatcherCtor(llvm::StringRef matcherName) {
 }
 
 VariantMatcher Parser::RegistrySema::actOnMatcherExpression(
-    MatcherCtor ctor, SourceRange nameRange, llvm::ArrayRef<ParserValue> args,
-    Diagnostics *error) {
-  return RegistryManager::constructMatcher(ctor, nameRange, args, error);
+    MatcherCtor ctor, SourceRange nameRange, llvm::StringRef functionName,
+    llvm::ArrayRef<ParserValue> args, Diagnostics *error) {
+  return RegistryManager::constructMatcher(ctor, nameRange, functionName, args,
+                                           error);
 }
 
 std::vector<ArgKind> Parser::RegistrySema::getAcceptedCompletionTypes(

diff  --git a/mlir/lib/Query/Matcher/Parser.h b/mlir/lib/Query/Matcher/Parser.h
index f049af34e9c907..58968023022d56 100644
--- a/mlir/lib/Query/Matcher/Parser.h
+++ b/mlir/lib/Query/Matcher/Parser.h
@@ -64,10 +64,9 @@ class Parser {
 
     // Process a matcher expression. The caller takes ownership of the Matcher
     // object returned.
-    virtual VariantMatcher
-    actOnMatcherExpression(MatcherCtor ctor, SourceRange nameRange,
-                           llvm::ArrayRef<ParserValue> args,
-                           Diagnostics *error) = 0;
+    virtual VariantMatcher actOnMatcherExpression(
+        MatcherCtor ctor, SourceRange nameRange, llvm::StringRef functionName,
+        llvm::ArrayRef<ParserValue> args, Diagnostics *error) = 0;
 
     // Look up a matcher by name in the matcher name found by the parser.
     virtual std::optional<MatcherCtor>
@@ -93,10 +92,11 @@ class Parser {
     std::optional<MatcherCtor>
     lookupMatcherCtor(llvm::StringRef matcherName) override;
 
-    VariantMatcher actOnMatcherExpression(MatcherCtor ctor,
-                                          SourceRange nameRange,
-                                          llvm::ArrayRef<ParserValue> args,
-                                          Diagnostics *error) override;
+    VariantMatcher actOnMatcherExpression(MatcherCtor Ctor,
+                                          SourceRange NameRange,
+                                          StringRef functionName,
+                                          ArrayRef<ParserValue> Args,
+                                          Diagnostics *Error) override;
 
     std::vector<ArgKind> getAcceptedCompletionTypes(
         llvm::ArrayRef<std::pair<MatcherCtor, unsigned>> context) override;
@@ -153,6 +153,8 @@ class Parser {
   Parser(CodeTokenizer *tokenizer, const Registry &matcherRegistry,
          const NamedValueMap *namedValues, Diagnostics *error);
 
+  bool parseChainedExpression(std::string &argument);
+
   bool parseExpressionImpl(VariantValue *value);
 
   bool parseMatcherArgs(std::vector<ParserValue> &args, MatcherCtor ctor,

diff  --git a/mlir/lib/Query/Matcher/RegistryManager.cpp b/mlir/lib/Query/Matcher/RegistryManager.cpp
index 01856aa8ffa67f..8c9197f4d00981 100644
--- a/mlir/lib/Query/Matcher/RegistryManager.cpp
+++ b/mlir/lib/Query/Matcher/RegistryManager.cpp
@@ -132,8 +132,19 @@ RegistryManager::getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes,
 
 VariantMatcher RegistryManager::constructMatcher(
     MatcherCtor ctor, internal::SourceRange nameRange,
-    llvm::ArrayRef<ParserValue> args, internal::Diagnostics *error) {
-  return ctor->create(nameRange, args, error);
+    llvm::StringRef functionName, llvm::ArrayRef<ParserValue> args,
+    internal::Diagnostics *error) {
+  VariantMatcher out = ctor->create(nameRange, args, error);
+  if (functionName.empty() || out.isNull())
+    return out;
+
+  if (std::optional<DynMatcher> result = out.getDynMatcher()) {
+    result->setFunctionName(functionName);
+    return VariantMatcher::SingleMatcher(*result);
+  }
+
+  error->addError(nameRange, internal::ErrorType::RegistryNotBindable);
+  return {};
 }
 
 } // namespace mlir::query::matcher

diff  --git a/mlir/lib/Query/Matcher/RegistryManager.h b/mlir/lib/Query/Matcher/RegistryManager.h
index 5f2867261225e7..e2026e97f83dcb 100644
--- a/mlir/lib/Query/Matcher/RegistryManager.h
+++ b/mlir/lib/Query/Matcher/RegistryManager.h
@@ -61,6 +61,7 @@ class RegistryManager {
 
   static VariantMatcher constructMatcher(MatcherCtor ctor,
                                          internal::SourceRange nameRange,
+                                         llvm::StringRef functionName,
                                          ArrayRef<ParserValue> args,
                                          internal::Diagnostics *error);
 };

diff  --git a/mlir/lib/Query/Query.cpp b/mlir/lib/Query/Query.cpp
index 5c42e5a5f0a116..7fdbbd181234b6 100644
--- a/mlir/lib/Query/Query.cpp
+++ b/mlir/lib/Query/Query.cpp
@@ -8,6 +8,8 @@
 
 #include "mlir/Query/Query.h"
 #include "QueryParser.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/IRMapping.h"
 #include "mlir/Query/Matcher/MatchFinder.h"
 #include "mlir/Query/QuerySession.h"
 #include "mlir/Support/LogicalResult.h"
@@ -34,6 +36,70 @@ static void printMatch(llvm::raw_ostream &os, QuerySession &qs, Operation *op,
                                      "\"" + binding + "\" binds here");
 }
 
+// TODO: Extract into a helper function that can be reused outside query
+// context.
+static Operation *extractFunction(std::vector<Operation *> &ops,
+                                  MLIRContext *context,
+                                  llvm::StringRef functionName) {
+  context->loadDialect<func::FuncDialect>();
+  OpBuilder builder(context);
+
+  // Collect data for function creation
+  std::vector<Operation *> slice;
+  std::vector<Value> values;
+  std::vector<Type> outputTypes;
+
+  for (auto *op : ops) {
+    // Return op's operands are propagated, but the op itself isn't needed.
+    if (!isa<func::ReturnOp>(op))
+      slice.push_back(op);
+
+    // All results are returned by the extracted function.
+    outputTypes.insert(outputTypes.end(), op->getResults().getTypes().begin(),
+                       op->getResults().getTypes().end());
+
+    // Track all values that need to be taken as input to function.
+    values.insert(values.end(), op->getOperands().begin(),
+                  op->getOperands().end());
+  }
+
+  // Create the function
+  FunctionType funcType =
+      builder.getFunctionType(ValueRange(values), outputTypes);
+  auto loc = builder.getUnknownLoc();
+  func::FuncOp funcOp = func::FuncOp::create(loc, functionName, funcType);
+
+  builder.setInsertionPointToEnd(funcOp.addEntryBlock());
+
+  // Map original values to function arguments
+  IRMapping mapper;
+  for (const auto &arg : llvm::enumerate(values))
+    mapper.map(arg.value(), funcOp.getArgument(arg.index()));
+
+  // Clone operations and build function body
+  std::vector<Operation *> clonedOps;
+  std::vector<Value> clonedVals;
+  for (Operation *slicedOp : slice) {
+    Operation *clonedOp =
+        clonedOps.emplace_back(builder.clone(*slicedOp, mapper));
+    clonedVals.insert(clonedVals.end(), clonedOp->result_begin(),
+                      clonedOp->result_end());
+  }
+  // Add return operation
+  builder.create<func::ReturnOp>(loc, clonedVals);
+
+  // Remove unused function arguments
+  size_t currentIndex = 0;
+  while (currentIndex < funcOp.getNumArguments()) {
+    if (funcOp.getArgument(currentIndex).use_empty())
+      funcOp.eraseArgument(currentIndex);
+    else
+      ++currentIndex;
+  }
+
+  return funcOp;
+}
+
 Query::~Query() = default;
 
 mlir::LogicalResult InvalidQuery::run(llvm::raw_ostream &os,
@@ -65,9 +131,22 @@ mlir::LogicalResult QuitQuery::run(llvm::raw_ostream &os,
 
 mlir::LogicalResult MatchQuery::run(llvm::raw_ostream &os,
                                     QuerySession &qs) const {
+  Operation *rootOp = qs.getRootOp();
   int matchCount = 0;
   std::vector<Operation *> matches =
-      matcher::MatchFinder().getMatches(qs.getRootOp(), matcher);
+      matcher::MatchFinder().getMatches(rootOp, matcher);
+
+  // An extract call is recognized by considering if the matcher has a name.
+  // TODO: Consider making the extract more explicit.
+  if (matcher.hasFunctionName()) {
+    auto functionName = matcher.getFunctionName();
+    Operation *function =
+        extractFunction(matches, rootOp->getContext(), functionName);
+    os << "\n" << *function << "\n\n";
+    function->erase();
+    return mlir::success();
+  }
+
   os << "\n";
   for (Operation *op : matches) {
     os << "Match #" << ++matchCount << ":\n\n";

diff  --git a/mlir/test/mlir-query/function-extraction.mlir b/mlir/test/mlir-query/function-extraction.mlir
new file mode 100644
index 00000000000000..a783f65c6761bc
--- /dev/null
+++ b/mlir/test/mlir-query/function-extraction.mlir
@@ -0,0 +1,19 @@
+// RUN: mlir-query %s -c "m hasOpName(\"arith.mulf\").extract(\"testmul\")" | FileCheck %s
+
+// CHECK: func.func @testmul({{.*}}) -> (f32, f32, f32) {
+// CHECK:       %[[MUL0:.*]] = arith.mulf {{.*}} : f32
+// CHECK:       %[[MUL1:.*]] = arith.mulf {{.*}}, %[[MUL0]] : f32
+// CHECK:       %[[MUL2:.*]] = arith.mulf {{.*}} : f32
+// CHECK-NEXT:  return %[[MUL0]], %[[MUL1]], %[[MUL2]] : f32, f32, f32
+
+func.func @mixedOperations(%a: f32, %b: f32, %c: f32) -> f32 {
+  %sum0 = arith.addf %a, %b : f32
+  %sub0 = arith.subf %sum0, %c : f32
+  %mul0 = arith.mulf %a, %sub0 : f32
+  %sum1 = arith.addf %b, %c : f32
+  %mul1 = arith.mulf %sum1, %mul0 : f32
+  %sub2 = arith.subf %mul1, %a : f32
+  %sum2 = arith.addf %mul1, %b : f32
+  %mul2 = arith.mulf %sub2, %sum2 : f32
+  return %mul2 : f32
+}


        


More information about the Mlir-commits mailing list