[Mlir-commits] [mlir] 732a5cb - Revert "Reapply "[mlir-query] Add function extraction feature to mlir-query""

Jacques Pienaar llvmlistbot at llvm.org
Sun Mar 3 05:23:20 PST 2024


Author: Jacques Pienaar
Date: 2024-03-03T05:22:41-08:00
New Revision: 732a5cba8c739ed40a7280b5d74ca717910c2c4c

URL: https://github.com/llvm/llvm-project/commit/732a5cba8c739ed40a7280b5d74ca717910c2c4c
DIFF: https://github.com/llvm/llvm-project/commit/732a5cba8c739ed40a7280b5d74ca717910c2c4c.diff

LOG: Revert "Reapply "[mlir-query] Add function extraction feature to mlir-query""

The commit causes failures in sanitizer builds.

This reverts commit 22f34ea3b05537235956c99fe942aa95b88762c0.

Added:
    (none)

Modified: 
    mlir/include/mlir/Query/Matcher/ErrorBuilder.h
    mlir/include/mlir/Query/Matcher/MatchersInternal.h
    mlir/lib/Query/CMakeLists.txt
    mlir/lib/Query/Matcher/Diagnostics.cpp
    mlir/lib/Query/Matcher/Parser.cpp
    mlir/lib/Query/Matcher/Parser.h
    mlir/lib/Query/Matcher/RegistryManager.cpp
    mlir/lib/Query/Matcher/RegistryManager.h
    mlir/lib/Query/Query.cpp

Removed: 
    mlir/test/mlir-query/function-extraction.mlir


################################################################################
diff  --git a/mlir/include/mlir/Query/Matcher/ErrorBuilder.h b/mlir/include/mlir/Query/Matcher/ErrorBuilder.h
index 08f1f415cbd3e5..1073daed8703f5 100644
--- a/mlir/include/mlir/Query/Matcher/ErrorBuilder.h
+++ b/mlir/include/mlir/Query/Matcher/ErrorBuilder.h
@@ -37,12 +37,8 @@ enum class ErrorType {
   None,
 
   // Parser Errors
-  ParserChainedExprInvalidArg,
-  ParserChainedExprNoCloseParen,
-  ParserChainedExprNoOpenParen,
   ParserFailedToBuildMatcher,
   ParserInvalidToken,
-  ParserMalformedChainedExpr,
   ParserNoCloseParen,
   ParserNoCode,
   ParserNoComma,
@@ -54,10 +50,9 @@ enum class ErrorType {
 
   // Registry Errors
   RegistryMatcherNotFound,
-  RegistryNotBindable,
   RegistryValueNotFound,
   RegistryWrongArgCount,
-  RegistryWrongArgType,
+  RegistryWrongArgType
 };
 
 void addError(Diagnostics *error, SourceRange range, ErrorType errorType,

diff  --git a/mlir/include/mlir/Query/Matcher/MatchersInternal.h b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
index 117f7d4edef9e3..67455be592393b 100644
--- a/mlir/include/mlir/Query/Matcher/MatchersInternal.h
+++ b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
@@ -63,15 +63,8 @@ class DynMatcher {
 
   bool match(Operation *op) const { return implementation->match(op); }
 
-  void setFunctionName(StringRef name) { functionName = name.str(); };
-
-  bool hasFunctionName() const { return !functionName.empty(); };
-
-  StringRef getFunctionName() const { return functionName; };
-
 private:
   llvm::IntrusiveRefCntPtr<MatcherInterface> implementation;
-  std::string functionName;
 };
 
 } // namespace mlir::query::matcher

diff  --git a/mlir/lib/Query/CMakeLists.txt b/mlir/lib/Query/CMakeLists.txt
index 7ecbf6e628b318..817583e94c5222 100644
--- a/mlir/lib/Query/CMakeLists.txt
+++ b/mlir/lib/Query/CMakeLists.txt
@@ -6,7 +6,6 @@ add_mlir_library(MLIRQuery
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Query
 
   LINK_LIBS PUBLIC
-  MLIRFuncDialect
   MLIRQueryMatcher
   )
 

diff  --git a/mlir/lib/Query/Matcher/Diagnostics.cpp b/mlir/lib/Query/Matcher/Diagnostics.cpp
index 2a137e8fdfab0d..10468dbcc53067 100644
--- a/mlir/lib/Query/Matcher/Diagnostics.cpp
+++ b/mlir/lib/Query/Matcher/Diagnostics.cpp
@@ -38,8 +38,6 @@ static llvm::StringRef errorTypeToFormatString(ErrorType type) {
     return "Incorrect type for arg $0. (Expected = $1) != (Actual = $2)";
   case ErrorType::RegistryValueNotFound:
     return "Value not found: $0";
-  case ErrorType::RegistryNotBindable:
-    return "Matcher does not support binding.";
 
   case ErrorType::ParserStringError:
     return "Error parsing string token: <$0>";
@@ -59,14 +57,6 @@ static llvm::StringRef errorTypeToFormatString(ErrorType type) {
     return "Unexpected end of code.";
   case ErrorType::ParserOverloadedType:
     return "Input value has unresolved overloaded type: $0";
-  case ErrorType::ParserMalformedChainedExpr:
-    return "Period not followed by valid chained call.";
-  case ErrorType::ParserChainedExprInvalidArg:
-    return "Missing/Invalid argument for the chained call.";
-  case ErrorType::ParserChainedExprNoCloseParen:
-    return "Missing ')' for the chained call.";
-  case ErrorType::ParserChainedExprNoOpenParen:
-    return "Missing '(' for the chained call.";
   case ErrorType::ParserFailedToBuildMatcher:
     return "Failed to build matcher: $0.";
 

diff  --git a/mlir/lib/Query/Matcher/Parser.cpp b/mlir/lib/Query/Matcher/Parser.cpp
index 3609e24f9939f7..30eb4801fc03c1 100644
--- a/mlir/lib/Query/Matcher/Parser.cpp
+++ b/mlir/lib/Query/Matcher/Parser.cpp
@@ -26,17 +26,12 @@ struct Parser::TokenInfo {
     text = newText;
   }
 
-  // Known identifiers.
-  static const char *const ID_Extract;
-
   llvm::StringRef text;
   TokenKind kind = TokenKind::Eof;
   SourceRange range;
   VariantValue value;
 };
 
-const char *const Parser::TokenInfo::ID_Extract = "extract";
-
 class Parser::CodeTokenizer {
 public:
   // Constructor with matcherCode and error
@@ -303,36 +298,6 @@ bool Parser::parseIdentifierPrefixImpl(VariantValue *value) {
   return parseMatcherExpressionImpl(nameToken, openToken, ctor, value);
 }
 
-bool Parser::parseChainedExpression(std::string &argument) {
-  // Parse the parenthesized argument to .extract("foo")
-  // Note: EOF is handled inside the consume functions and would fail below when
-  // checking token kind.
-  const TokenInfo openToken = tokenizer->consumeNextToken();
-  const TokenInfo argumentToken = tokenizer->consumeNextTokenIgnoreNewlines();
-  const TokenInfo closeToken = tokenizer->consumeNextTokenIgnoreNewlines();
-
-  if (openToken.kind != TokenKind::OpenParen) {
-    error->addError(openToken.range, ErrorType::ParserChainedExprNoOpenParen);
-    return false;
-  }
-
-  if (argumentToken.kind != TokenKind::Literal ||
-      !argumentToken.value.isString()) {
-    error->addError(argumentToken.range,
-                    ErrorType::ParserChainedExprInvalidArg);
-    return false;
-  }
-
-  if (closeToken.kind != TokenKind::CloseParen) {
-    error->addError(closeToken.range, ErrorType::ParserChainedExprNoCloseParen);
-    return false;
-  }
-
-  // If all checks passed, extract the argument and return true.
-  argument = argumentToken.value.getString();
-  return true;
-}
-
 // Parse the arguments of a matcher
 bool Parser::parseMatcherArgs(std::vector<ParserValue> &args, MatcherCtor ctor,
                               const TokenInfo &nameToken, TokenInfo &endToken) {
@@ -399,34 +364,13 @@ bool Parser::parseMatcherExpressionImpl(const TokenInfo &nameToken,
     return false;
   }
 
-  std::string functionName;
-  if (tokenizer->peekNextToken().kind == TokenKind::Period) {
-    tokenizer->consumeNextToken();
-    TokenInfo chainCallToken = tokenizer->consumeNextToken();
-    if (chainCallToken.kind == TokenKind::CodeCompletion) {
-      addCompletion(chainCallToken, MatcherCompletion("extract(\"", "extract"));
-      return false;
-    }
-
-    if (chainCallToken.kind != TokenKind::Ident ||
-        chainCallToken.text != TokenInfo::ID_Extract) {
-      error->addError(chainCallToken.range,
-                      ErrorType::ParserMalformedChainedExpr);
-      return false;
-    }
-
-    if (chainCallToken.text == TokenInfo::ID_Extract &&
-        !parseChainedExpression(functionName))
-      return false;
-  }
-
   if (!ctor)
     return false;
   // Merge the start and end infos.
   SourceRange matcherRange = nameToken.range;
   matcherRange.end = endToken.range.end;
-  VariantMatcher result = sema->actOnMatcherExpression(
-      *ctor, matcherRange, functionName, args, error);
+  VariantMatcher result =
+      sema->actOnMatcherExpression(*ctor, matcherRange, args, error);
   if (result.isNull())
     return false;
   *value = result;
@@ -526,10 +470,9 @@ Parser::RegistrySema::lookupMatcherCtor(llvm::StringRef matcherName) {
 }
 
 VariantMatcher Parser::RegistrySema::actOnMatcherExpression(
-    MatcherCtor ctor, SourceRange nameRange, llvm::StringRef functionName,
-    llvm::ArrayRef<ParserValue> args, Diagnostics *error) {
-  return RegistryManager::constructMatcher(ctor, nameRange, functionName, args,
-                                           error);
+    MatcherCtor ctor, SourceRange nameRange, llvm::ArrayRef<ParserValue> args,
+    Diagnostics *error) {
+  return RegistryManager::constructMatcher(ctor, nameRange, args, error);
 }
 
 std::vector<ArgKind> Parser::RegistrySema::getAcceptedCompletionTypes(

diff  --git a/mlir/lib/Query/Matcher/Parser.h b/mlir/lib/Query/Matcher/Parser.h
index 58968023022d56..f049af34e9c907 100644
--- a/mlir/lib/Query/Matcher/Parser.h
+++ b/mlir/lib/Query/Matcher/Parser.h
@@ -64,9 +64,10 @@ class Parser {
 
     // Process a matcher expression. The caller takes ownership of the Matcher
     // object returned.
-    virtual VariantMatcher actOnMatcherExpression(
-        MatcherCtor ctor, SourceRange nameRange, llvm::StringRef functionName,
-        llvm::ArrayRef<ParserValue> args, Diagnostics *error) = 0;
+    virtual VariantMatcher
+    actOnMatcherExpression(MatcherCtor ctor, SourceRange nameRange,
+                           llvm::ArrayRef<ParserValue> args,
+                           Diagnostics *error) = 0;
 
     // Look up a matcher by name in the matcher name found by the parser.
     virtual std::optional<MatcherCtor>
@@ -92,11 +93,10 @@ class Parser {
     std::optional<MatcherCtor>
     lookupMatcherCtor(llvm::StringRef matcherName) override;
 
-    VariantMatcher actOnMatcherExpression(MatcherCtor Ctor,
-                                          SourceRange NameRange,
-                                          StringRef functionName,
-                                          ArrayRef<ParserValue> Args,
-                                          Diagnostics *Error) override;
+    VariantMatcher actOnMatcherExpression(MatcherCtor ctor,
+                                          SourceRange nameRange,
+                                          llvm::ArrayRef<ParserValue> args,
+                                          Diagnostics *error) override;
 
     std::vector<ArgKind> getAcceptedCompletionTypes(
         llvm::ArrayRef<std::pair<MatcherCtor, unsigned>> context) override;
@@ -153,8 +153,6 @@ class Parser {
   Parser(CodeTokenizer *tokenizer, const Registry &matcherRegistry,
          const NamedValueMap *namedValues, Diagnostics *error);
 
-  bool parseChainedExpression(std::string &argument);
-
   bool parseExpressionImpl(VariantValue *value);
 
   bool parseMatcherArgs(std::vector<ParserValue> &args, MatcherCtor ctor,

diff  --git a/mlir/lib/Query/Matcher/RegistryManager.cpp b/mlir/lib/Query/Matcher/RegistryManager.cpp
index 8c9197f4d00981..01856aa8ffa67f 100644
--- a/mlir/lib/Query/Matcher/RegistryManager.cpp
+++ b/mlir/lib/Query/Matcher/RegistryManager.cpp
@@ -132,19 +132,8 @@ RegistryManager::getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes,
 
 VariantMatcher RegistryManager::constructMatcher(
     MatcherCtor ctor, internal::SourceRange nameRange,
-    llvm::StringRef functionName, llvm::ArrayRef<ParserValue> args,
-    internal::Diagnostics *error) {
-  VariantMatcher out = ctor->create(nameRange, args, error);
-  if (functionName.empty() || out.isNull())
-    return out;
-
-  if (std::optional<DynMatcher> result = out.getDynMatcher()) {
-    result->setFunctionName(functionName);
-    return VariantMatcher::SingleMatcher(*result);
-  }
-
-  error->addError(nameRange, internal::ErrorType::RegistryNotBindable);
-  return {};
+    llvm::ArrayRef<ParserValue> args, internal::Diagnostics *error) {
+  return ctor->create(nameRange, args, error);
 }
 
 } // namespace mlir::query::matcher

diff  --git a/mlir/lib/Query/Matcher/RegistryManager.h b/mlir/lib/Query/Matcher/RegistryManager.h
index e2026e97f83dcb..5f2867261225e7 100644
--- a/mlir/lib/Query/Matcher/RegistryManager.h
+++ b/mlir/lib/Query/Matcher/RegistryManager.h
@@ -61,7 +61,6 @@ class RegistryManager {
 
   static VariantMatcher constructMatcher(MatcherCtor ctor,
                                          internal::SourceRange nameRange,
-                                         llvm::StringRef functionName,
                                          ArrayRef<ParserValue> args,
                                          internal::Diagnostics *error);
 };

diff  --git a/mlir/lib/Query/Query.cpp b/mlir/lib/Query/Query.cpp
index 27db52b37dade0..5c42e5a5f0a116 100644
--- a/mlir/lib/Query/Query.cpp
+++ b/mlir/lib/Query/Query.cpp
@@ -8,8 +8,6 @@
 
 #include "mlir/Query/Query.h"
 #include "QueryParser.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/IR/IRMapping.h"
 #include "mlir/Query/Matcher/MatchFinder.h"
 #include "mlir/Query/QuerySession.h"
 #include "mlir/Support/LogicalResult.h"
@@ -36,70 +34,6 @@ static void printMatch(llvm::raw_ostream &os, QuerySession &qs, Operation *op,
                                      "\"" + binding + "\" binds here");
 }
 
-// TODO: Extract into a helper function that can be reused outside query
-// context.
-static Operation *extractFunction(std::vector<Operation *> &ops,
-                                  MLIRContext *context,
-                                  llvm::StringRef functionName) {
-  context->loadDialect<func::FuncDialect>();
-  OpBuilder builder(context);
-
-  // Collect data for function creation
-  std::vector<Operation *> slice;
-  std::vector<Value> values;
-  std::vector<Type> outputTypes;
-
-  for (auto *op : ops) {
-    // Return op's operands are propagated, but the op itself isn't needed.
-    if (!isa<func::ReturnOp>(op))
-      slice.push_back(op);
-
-    // All results are returned by the extracted function.
-    outputTypes.insert(outputTypes.end(), op->getResults().getTypes().begin(),
-                       op->getResults().getTypes().end());
-
-    // Track all values that need to be taken as input to function.
-    values.insert(values.end(), op->getOperands().begin(),
-                  op->getOperands().end());
-  }
-
-  // Create the function
-  FunctionType funcType =
-      builder.getFunctionType(ValueRange(values), outputTypes);
-  auto loc = builder.getUnknownLoc();
-  func::FuncOp funcOp = func::FuncOp::create(loc, functionName, funcType);
-
-  builder.setInsertionPointToEnd(funcOp.addEntryBlock());
-
-  // Map original values to function arguments
-  IRMapping mapper;
-  for (const auto &arg : llvm::enumerate(values))
-    mapper.map(arg.value(), funcOp.getArgument(arg.index()));
-
-  // Clone operations and build function body
-  std::vector<Operation *> clonedOps;
-  std::vector<Value> clonedVals;
-  for (Operation *slicedOp : slice) {
-    Operation *clonedOp =
-        clonedOps.emplace_back(builder.clone(*slicedOp, mapper));
-    clonedVals.insert(clonedVals.end(), clonedOp->result_begin(),
-                      clonedOp->result_end());
-  }
-  // Add return operation
-  builder.create<func::ReturnOp>(loc, clonedVals);
-
-  // Remove unused function arguments
-  size_t currentIndex = 0;
-  while (currentIndex < funcOp.getNumArguments()) {
-    if (funcOp.getArgument(currentIndex).use_empty())
-      funcOp.eraseArgument(currentIndex);
-    else
-      ++currentIndex;
-  }
-
-  return funcOp;
-}
-
 Query::~Query() = default;
 
 mlir::LogicalResult InvalidQuery::run(llvm::raw_ostream &os,
@@ -131,21 +65,9 @@ mlir::LogicalResult QuitQuery::run(llvm::raw_ostream &os,
 
 mlir::LogicalResult MatchQuery::run(llvm::raw_ostream &os,
                                     QuerySession &qs) const {
-  Operation *rootOp = qs.getRootOp();
   int matchCount = 0;
   std::vector<Operation *> matches =
-      matcher::MatchFinder().getMatches(rootOp, matcher);
-
-  // An extract call is recognized by considering if the matcher has a name.
-  // TODO: Consider making the extract more explicit.
-  if (matcher.hasFunctionName()) {
-    auto functionName = matcher.getFunctionName();
-    Operation *function =
-        extractFunction(matches, rootOp->getContext(), functionName);
-    os << "\n" << *function << "\n\n";
-    return mlir::success();
-  }
-
+      matcher::MatchFinder().getMatches(qs.getRootOp(), matcher);
   os << "\n";
   for (Operation *op : matches) {
     os << "Match #" << ++matchCount << ":\n\n";

diff  --git a/mlir/test/mlir-query/function-extraction.mlir b/mlir/test/mlir-query/function-extraction.mlir
deleted file mode 100644
index a783f65c6761bc..00000000000000
--- a/mlir/test/mlir-query/function-extraction.mlir
+++ /dev/null
@@ -1,19 +0,0 @@
-// RUN: mlir-query %s -c "m hasOpName(\"arith.mulf\").extract(\"testmul\")" | FileCheck %s
-
-// CHECK: func.func @testmul({{.*}}) -> (f32, f32, f32) {
-// CHECK:       %[[MUL0:.*]] = arith.mulf {{.*}} : f32
-// CHECK:       %[[MUL1:.*]] = arith.mulf {{.*}}, %[[MUL0]] : f32
-// CHECK:       %[[MUL2:.*]] = arith.mulf {{.*}} : f32
-// CHECK-NEXT:  return %[[MUL0]], %[[MUL1]], %[[MUL2]] : f32, f32, f32
-
-func.func @mixedOperations(%a: f32, %b: f32, %c: f32) -> f32 {
-  %sum0 = arith.addf %a, %b : f32
-  %sub0 = arith.subf %sum0, %c : f32
-  %mul0 = arith.mulf %a, %sub0 : f32
-  %sum1 = arith.addf %b, %c : f32
-  %mul1 = arith.mulf %sum1, %mul0 : f32
-  %sub2 = arith.subf %mul1, %a : f32
-  %sum2 = arith.addf %mul1, %b : f32
-  %mul2 = arith.mulf %sub2, %sum2 : f32
-  return %mul2 : f32
-}


        


More information about the Mlir-commits mailing list