[Mlir-commits] [mlir] 91b8d96 - [mlir:PDLL] Add proper support for operation result type inference

River Riddle llvmlistbot at llvm.org
Mon May 30 17:43:29 PDT 2022


Author: River Riddle
Date: 2022-05-30T17:35:33-07:00
New Revision: 91b8d96fd12af6120e2988639455ba35d89a8d3e

URL: https://github.com/llvm/llvm-project/commit/91b8d96fd12af6120e2988639455ba35d89a8d3e
DIFF: https://github.com/llvm/llvm-project/commit/91b8d96fd12af6120e2988639455ba35d89a8d3e.diff

LOG: [mlir:PDLL] Add proper support for operation result type inference

This allows the result types of operations to be inferred in certain contexts,
matching the support in PDL for result type inference. The two initial
circumstances are when the operation is used as the replacement for another
operation, or when the operation being created implements InferTypeOpInterface.

Differential Revision: https://reviews.llvm.org/D124782
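
As a rough illustration of the two circumstances described in the log, the
following standalone sketch models how an operation expression might end up in
one of the three result type contexts. It is not code from the patch: the
helper `resolveResultTypeContext` and its two boolean parameters are
hypothetical, and the handling of the non-rewrite case is an assumption based
on the documentation change rather than on the parser source.

```cpp
// Simplified stand-in for the `OpResultTypeContext` enum added to the parser.
enum class OpResultTypeContext { Explicit, Replacement, Interface };

// Hypothetical helper (not part of the patch) that picks the context for a
// single operation expression.
//   inputContext:          context supplied by the caller, e.g. `Replacement`
//                          when parsing the replacement op of a `replace`.
//   hasExplicitResultList: true if the expression has a `-> (...)` list.
//   inRewriteSection:      true if the expression creates an operation.
constexpr OpResultTypeContext
resolveResultTypeContext(OpResultTypeContext inputContext,
                         bool hasExplicitResultList, bool inRewriteSection) {
  // An explicit result list is validated as written, with no inference.
  if (hasExplicitResultList)
    return OpResultTypeContext::Explicit;

  // Assumption: when matching, an elided list means "unconstrained results",
  // so there is nothing to infer.
  if (!inRewriteSection)
    return OpResultTypeContext::Explicit;

  // In a rewrite with no other source of types, fall back to
  // `InferTypeOpInterface`.
  if (inputContext == OpResultTypeContext::Explicit)
    return OpResultTypeContext::Interface;

  // Otherwise keep the caller-provided context (e.g. `Replacement`).
  return inputContext;
}

// `replace op<a> with op<b>;` takes the result types from the replaced op.
static_assert(resolveResultTypeContext(OpResultTypeContext::Replacement,
                                       /*hasExplicitResultList=*/false,
                                       /*inRewriteSection=*/true) ==
                  OpResultTypeContext::Replacement,
              "replacement ops infer from the root operation");

// `rewrite root: Op with { op<b>; };` relies on `InferTypeOpInterface`.
static_assert(resolveResultTypeContext(OpResultTypeContext::Explicit,
                                       /*hasExplicitResultList=*/false,
                                       /*inRewriteSection=*/true) ==
                  OpResultTypeContext::Interface,
              "elided results in a rewrite use the interface");
```

In this model an explicit result list always takes precedence, which mirrors
the patch resetting the context to `Explicit` once `->` has been consumed.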

Added: 
    

Modified: 
    mlir/docs/PDLL.md
    mlir/include/mlir/Tools/PDLL/ODS/Context.h
    mlir/include/mlir/Tools/PDLL/ODS/Dialect.h
    mlir/include/mlir/Tools/PDLL/ODS/Operation.h
    mlir/lib/Tools/PDLL/ODS/Context.cpp
    mlir/lib/Tools/PDLL/ODS/Dialect.cpp
    mlir/lib/Tools/PDLL/ODS/Operation.cpp
    mlir/lib/Tools/PDLL/Parser/Parser.cpp
    mlir/test/mlir-pdll/Parser/expr-failure.pdll

Removed: 
    


################################################################################
diff  --git a/mlir/docs/PDLL.md b/mlir/docs/PDLL.md
index 11ab5831d2b6a..984ba2ed26f14 100644
--- a/mlir/docs/PDLL.md
+++ b/mlir/docs/PDLL.md
@@ -577,9 +577,10 @@ let root = op<>;
 #### Operands
 
 The operands section corresponds to the operands of the operation. This section
-of an operation expression may be elided, in which case the operands are not
-constrained in any way. When present, the operands of an operation expression
-are interpreted in the following ways:
+of an operation expression may be elided, which within a `match` section means
+that the operands are not constrained in any way. If elided within a `rewrite`
+section, the operation is treated as having no operands. When present, the
+operands of an operation expression are interpreted in the following ways:
 
 1) A single instance of type `ValueRange`:
 
@@ -612,10 +613,11 @@ let root = op<my_dialect.indirect_call>(call: Value, args: ValueRange);
 
 #### Results
 
-The results section corresponds to the result types of the operation. This
-section of an operation expression may be elided, in which case the result types
-are not constrained in any way. When present, the result types of an operation
-expression are interpreted in the following ways:
+The results section corresponds to the result types of the operation. This section
+of an operation expression may be elided, which within a `match` section means
+that the result types are not constrained in any way. If elided within a `rewrite`
+section, the results of the operation are [inferred](#inferred-results). When present,
+the result types of an operation expression are interpreted in the following ways:
 
 1) A single instance of type `TypeRange`:
 
@@ -646,6 +648,87 @@ We can match the result types as so:
 let root = op<my_dialect.op> -> (result: Type, otherResults: TypeRange);
 ```
 
+#### Inferred Results
+
+Within the `rewrite` section of a pattern, the result types of an
+operation are inferred if they are elided or otherwise not
+previously bound. The ["variable binding"](#variable-binding) section above
+discusses the concept of "binding" in more detail. Below are various examples
+that build upon this to help showcase how a result type may be "bound":
+
+* Binding to a [constant](#type-expression):
+
+```pdll
+op<my_dialect.op> -> (type<"i32">);
+```
+
+* Binding to types within the `match` section:
+
+```pdll
+Pattern {
+  replace op<dialect.inputOp> -> (resultTypes: TypeRange)
+    with op<dialect.outputOp> -> (resultTypes);
+}
+```
+
+* Binding to previously inferred types:
+
+```pdll
+Pattern {
+  rewrite root: Op with {
+    // `resultTypes` here is *not* yet bound, and will be inferred when
+    // creating `dialect.op`. Any uses of `resultTypes` after this expression,
+    // will use the types inferred when creating this operation.
+    op<dialect.op> -> (resultTypes: TypeRange);
+
+    // `resultTypes` here is bound to the types inferred when creating `dialect.op`.
+    op<dialect.bar> -> (resultTypes);
+  };
+}
+```
+
+* Binding to a [`Native Rewrite`](#native-rewriters) method result:
+
+```pdll
+Rewrite BuildTypes() -> TypeRange;
+
+Pattern {
+  rewrite root: Op with {
+    op<dialect.op> -> (BuildTypes());
+  };
+}
+```
+
+Below are the set of contexts in which result type inferrence is supported:
+
+##### Inferred Results of Replacement Operation
+
+Replacements have the invariant that the types of the replacement values must
+match the result types of the input operation. This means that when replacing
+one operation with another, the result types of the replacement operation may
+be inferred from the result types of the operation being replaced. For example,
+consider the following pattern:
+
+```pdll
+Pattern => replace op<dialect.inputOp> with op<dialect.outputOp>;
+```
+
+This pattern could be written in a more explicit way as:
+
+```pdll
+Pattern {
+  replace op<dialect.inputOp> -> (resultTypes: TypeRange)
+    with op<dialect.outputOp> -> (resultTypes);
+}
+```
+
+##### Inferred Results with InferTypeOpInterface
+
+`InferTypeOpInterface` is an interface that enables operations to infer its result
+types from its input attributes, operands, regions, etc. When the result types of
+an operation cannot be inferred from any other context, this interface is invoked
+to infer the result types of the operation.
+
 #### Attributes
 
 The attributes section of the operation expression corresponds to the attribute

diff  --git a/mlir/include/mlir/Tools/PDLL/ODS/Context.h b/mlir/include/mlir/Tools/PDLL/ODS/Context.h
index ea3751eac3541..6baa90af44adf 100644
--- a/mlir/include/mlir/Tools/PDLL/ODS/Context.h
+++ b/mlir/include/mlir/Tools/PDLL/ODS/Context.h
@@ -62,7 +62,8 @@ class Context {
   /// and a boolean indicating if the operation newly inserted (false if the
   /// operation already existed).
   std::pair<Operation *, bool>
-  insertOperation(StringRef name, StringRef summary, StringRef desc, SMLoc loc);
+  insertOperation(StringRef name, StringRef summary, StringRef desc,
+                  bool supportsResultTypeInferrence, SMLoc loc);
 
   /// Lookup an operation registered with the given name, or null if no
   /// operation with that name is registered.

diff  --git a/mlir/include/mlir/Tools/PDLL/ODS/Dialect.h b/mlir/include/mlir/Tools/PDLL/ODS/Dialect.h
index f75d497867b8c..de8181843e84c 100644
--- a/mlir/include/mlir/Tools/PDLL/ODS/Dialect.h
+++ b/mlir/include/mlir/Tools/PDLL/ODS/Dialect.h
@@ -34,7 +34,8 @@ class Dialect {
   /// and a boolean indicating if the operation newly inserted (false if the
   /// operation already existed).
   std::pair<Operation *, bool>
-  insertOperation(StringRef name, StringRef summary, StringRef desc, SMLoc loc);
+  insertOperation(StringRef name, StringRef summary, StringRef desc,
+                  bool supportsResultTypeInferrence, SMLoc loc);
 
   /// Lookup an operation registered with the given name, or null if no
   /// operation with that name is registered.

diff  --git a/mlir/include/mlir/Tools/PDLL/ODS/Operation.h b/mlir/include/mlir/Tools/PDLL/ODS/Operation.h
index c5b86e1733d0b..8ad36a6872b7c 100644
--- a/mlir/include/mlir/Tools/PDLL/ODS/Operation.h
+++ b/mlir/include/mlir/Tools/PDLL/ODS/Operation.h
@@ -75,6 +75,12 @@ class OperandOrResult {
   /// Return the name of this value.
   StringRef getName() const { return name; }
 
+  /// Returns true if this value is variable length, i.e. if it is Variadic or
+  /// Optional.
+  bool isVariableLength() const {
+    return variableLengthKind != VariableLengthKind::Single;
+  }
+
   /// Returns true if this value is variadic (Note this is false if the value is
   /// Optional).
   bool isVariadic() const {
@@ -157,8 +163,12 @@ class Operation {
   /// Returns the results of this operation.
   ArrayRef<OperandOrResult> getResults() const { return results; }
 
+  /// Return if the operation is known to support result type inferrence.
+  bool hasResultTypeInferrence() const { return supportsTypeInferrence; }
+
 private:
-  Operation(StringRef name, StringRef summary, StringRef desc, SMLoc loc);
+  Operation(StringRef name, StringRef summary, StringRef desc,
+            bool supportsTypeInferrence, SMLoc loc);
 
   /// The name of the operation.
   std::string name;
@@ -167,6 +177,9 @@ class Operation {
   std::string summary;
   std::string description;
 
+  /// Flag indicating if the operation is known to support type inferrence.
+  bool supportsTypeInferrence;
+
   /// The source location of this operation.
   SMRange location;
 

diff  --git a/mlir/lib/Tools/PDLL/ODS/Context.cpp b/mlir/lib/Tools/PDLL/ODS/Context.cpp
index 8186af9fe1b26..00f7cb26b432d 100644
--- a/mlir/lib/Tools/PDLL/ODS/Context.cpp
+++ b/mlir/lib/Tools/PDLL/ODS/Context.cpp
@@ -59,13 +59,12 @@ const Dialect *Context::lookupDialect(StringRef name) const {
   return it == dialects.end() ? nullptr : &*it->second;
 }
 
-std::pair<Operation *, bool> Context::insertOperation(StringRef name,
-                                                      StringRef summary,
-                                                      StringRef desc,
-                                                      SMLoc loc) {
+std::pair<Operation *, bool>
+Context::insertOperation(StringRef name, StringRef summary, StringRef desc,
+                         bool supportsResultTypeInferrence, SMLoc loc) {
   std::pair<StringRef, StringRef> dialectAndName = name.split('.');
   return insertDialect(dialectAndName.first)
-      .insertOperation(name, summary, desc, loc);
+      .insertOperation(name, summary, desc, supportsResultTypeInferrence, loc);
 }
 
 const Operation *Context::lookupOperation(StringRef name) const {

diff  --git a/mlir/lib/Tools/PDLL/ODS/Dialect.cpp b/mlir/lib/Tools/PDLL/ODS/Dialect.cpp
index ce9c23421c0e9..2e084c5d6cfd6 100644
--- a/mlir/lib/Tools/PDLL/ODS/Dialect.cpp
+++ b/mlir/lib/Tools/PDLL/ODS/Dialect.cpp
@@ -21,15 +21,15 @@ using namespace mlir::pdll::ods;
 Dialect::Dialect(StringRef name) : name(name.str()) {}
 Dialect::~Dialect() = default;
 
-std::pair<Operation *, bool> Dialect::insertOperation(StringRef name,
-                                                      StringRef summary,
-                                                      StringRef desc,
-                                                      llvm::SMLoc loc) {
+std::pair<Operation *, bool>
+Dialect::insertOperation(StringRef name, StringRef summary, StringRef desc,
+                         bool supportsResultTypeInferrence, llvm::SMLoc loc) {
   std::unique_ptr<Operation> &operation = operations[name];
   if (operation)
     return std::make_pair(&*operation, /*wasInserted*/ false);
 
-  operation.reset(new Operation(name, summary, desc, loc));
+  operation.reset(
+      new Operation(name, summary, desc, supportsResultTypeInferrence, loc));
   return std::make_pair(&*operation, /*wasInserted*/ true);
 }
 

diff  --git a/mlir/lib/Tools/PDLL/ODS/Operation.cpp b/mlir/lib/Tools/PDLL/ODS/Operation.cpp
index 121c6c8c4c886..3b8a3a9e97333 100644
--- a/mlir/lib/Tools/PDLL/ODS/Operation.cpp
+++ b/mlir/lib/Tools/PDLL/ODS/Operation.cpp
@@ -18,8 +18,9 @@ using namespace mlir::pdll::ods;
 //===----------------------------------------------------------------------===//
 
 Operation::Operation(StringRef name, StringRef summary, StringRef desc,
-                     llvm::SMLoc loc)
+                     bool supportsTypeInferrence, llvm::SMLoc loc)
     : name(name.str()), summary(summary.str()),
+      supportsTypeInferrence(supportsTypeInferrence),
       location(loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)) {
   llvm::raw_string_ostream descOS(description);
   raw_indented_ostream(descOS).printReindented(desc.rtrim(" \t"));

diff  --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
index 1672c11ca7294..49eaa41672f35 100644
--- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp
+++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
@@ -76,6 +76,19 @@ class Parser {
     Rewrite,
   };
 
+  /// The current specification context of an operations result type. This
+  /// indicates how the result types of an operation may be inferred.
+  enum class OpResultTypeContext {
+    /// The result types of the operation are not known to be inferred.
+    Explicit,
+    /// The result types of the operation are inferred from the root input of a
+    /// `replace` statement.
+    Replacement,
+    /// The result types of the operation are inferred by using the
+    /// `InferTypeOpInterface` interface provided by the operation.
+    Interface,
+  };
+
   //===--------------------------------------------------------------------===//
   // Parsing
   //===--------------------------------------------------------------------===//
@@ -280,7 +293,9 @@ class Parser {
   FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr);
   FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false);
   FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName);
-  FailureOr<ast::Expr *> parseOperationExpr();
+  FailureOr<ast::Expr *>
+  parseOperationExpr(OpResultTypeContext inputResultTypeContext =
+                         OpResultTypeContext::Explicit);
   FailureOr<ast::Expr *> parseTupleExpr();
   FailureOr<ast::Expr *> parseTypeExpr();
   FailureOr<ast::Expr *> parseUnderscoreExpr();
@@ -378,6 +393,7 @@ class Parser {
                                             StringRef name, SMRange loc);
   FailureOr<ast::OperationExpr *>
   createOperationExpr(SMRange loc, const ast::OpNameDecl *name,
+                      OpResultTypeContext resultTypeContext,
                       MutableArrayRef<ast::Expr *> operands,
                       MutableArrayRef<ast::NamedAttributeDecl *> attributes,
                       MutableArrayRef<ast::Expr *> results);
@@ -388,6 +404,8 @@ class Parser {
   LogicalResult validateOperationResults(SMRange loc, Optional<StringRef> name,
                                          const ods::Operation *odsOp,
                                          MutableArrayRef<ast::Expr *> results);
+  void checkOperationResultTypeInferrence(SMRange loc, StringRef name,
+                                          const ods::Operation *odsOp);
   LogicalResult validateOperationOperandsOrResults(
       StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc,
       Optional<StringRef> name, MutableArrayRef<ast::Expr *> values,
@@ -795,11 +813,15 @@ void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
   for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Op")) {
     tblgen::Operator op(def);
 
+    // Check to see if this operation is known to support type inferrence.
+    bool supportsResultTypeInferrence =
+        op.getTrait("::mlir::InferTypeOpInterface::Trait");
+
     bool inserted = false;
     ods::Operation *odsOp = nullptr;
-    std::tie(odsOp, inserted) =
-        odsContext.insertOperation(op.getOperationName(), op.getSummary(),
-                                   op.getDescription(), op.getLoc().front());
+    std::tie(odsOp, inserted) = odsContext.insertOperation(
+        op.getOperationName(), op.getSummary(), op.getDescription(),
+        supportsResultTypeInferrence, op.getLoc().front());
 
     // Ignore operations that have already been added.
     if (!inserted)
@@ -1917,7 +1939,8 @@ Parser::parseWrappedOperationName(bool allowEmptyName) {
   return opNameDecl;
 }
 
-FailureOr<ast::Expr *> Parser::parseOperationExpr() {
+FailureOr<ast::Expr *>
+Parser::parseOperationExpr(OpResultTypeContext inputResultTypeContext) {
   SMRange loc = curToken.getLoc();
   consumeToken(Token::kw_op);
 
@@ -1994,13 +2017,23 @@ FailureOr<ast::Expr *> Parser::parseOperationExpr() {
       return failure();
   }
 
-  // Check for the optional list of result types.
+  // Handle the result types of the operation.
   SmallVector<ast::Expr *> resultTypes;
+  OpResultTypeContext resultTypeContext = inputResultTypeContext;
+
+  // Check for an explicit list of result types.
   if (consumeIf(Token::arrow)) {
     if (failed(parseToken(Token::l_paren,
                           "expected `(` before operation result type list")))
       return failure();
 
+    // If result types are provided, initially assume that the operation does
+    // not rely on type inferrence. We don't assert that it isn't, because we
+    // may be inferring the value of some type/type range variables, but given
+    // that these variables may be defined in calls we can't always discern when
+    // this is the case.
+    resultTypeContext = OpResultTypeContext::Explicit;
+
     // Handle the case of an empty result list.
     if (!consumeIf(Token::r_paren)) {
       do {
@@ -2027,10 +2060,14 @@ FailureOr<ast::Expr *> Parser::parseOperationExpr() {
     // "unconstrained results".
     resultTypes.push_back(createImplicitRangeVar(
         ast::TypeRangeConstraintDecl::create(ctx, loc), typeRangeTy));
+  } else if (resultTypeContext == OpResultTypeContext::Explicit) {
+    // If the result list isn't specified and we are in a rewrite, try to infer
+    // them at runtime instead.
+    resultTypeContext = OpResultTypeContext::Interface;
   }
 
-  return createOperationExpr(loc, *opNameDecl, operands, attributes,
-                             resultTypes);
+  return createOperationExpr(loc, *opNameDecl, resultTypeContext, operands,
+                             attributes, resultTypes);
 }
 
 FailureOr<ast::Expr *> Parser::parseTupleExpr() {
@@ -2294,7 +2331,13 @@ FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() {
                           "expected `)` after replacement values")))
       return failure();
   } else {
-    FailureOr<ast::Expr *> replExpr = parseExpr();
+    // Handle replacement with an operation uniquely, as the replacement
+    // operation supports type inferrence from the root operation.
+    FailureOr<ast::Expr *> replExpr;
+    if (curToken.is(Token::kw_op))
+      replExpr = parseOperationExpr(OpResultTypeContext::Replacement);
+    else
+      replExpr = parseExpr();
     if (failed(replExpr))
       return failure();
     replValues.emplace_back(*replExpr);
@@ -2710,6 +2753,7 @@ FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
 
 FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
     SMRange loc, const ast::OpNameDecl *name,
+    OpResultTypeContext resultTypeContext,
     MutableArrayRef<ast::Expr *> operands,
     MutableArrayRef<ast::NamedAttributeDecl *> attributes,
     MutableArrayRef<ast::Expr *> results) {
@@ -2731,9 +2775,22 @@ FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
     }
   }
 
-  // Verify the result types.
-  if (failed(validateOperationResults(loc, opNameRef, odsOp, results)))
-    return failure();
+  assert(
+      (resultTypeContext == OpResultTypeContext::Explicit || results.empty()) &&
+      "unexpected inferrence when results were explicitly specified");
+
+  // If we aren't relying on type inferrence, or explicit results were provided,
+  // validate them.
+  if (resultTypeContext == OpResultTypeContext::Explicit) {
+    if (failed(validateOperationResults(loc, opNameRef, odsOp, results)))
+      return failure();
+
+    // Validate the use of interface based type inferrence for this operation.
+  } else if (resultTypeContext == OpResultTypeContext::Interface) {
+    assert(opNameRef &&
+           "expected valid operation name when inferring operation results");
+    checkOperationResultTypeInferrence(loc, *opNameRef, odsOp);
+  }
 
   return ast::OperationExpr::create(ctx, loc, name, operands, results,
                                     attributes);
@@ -2758,6 +2815,48 @@ Parser::validateOperationResults(SMRange loc, Optional<StringRef> name,
       results, odsOp ? odsOp->getResults() : llvm::None, typeTy, typeRangeTy);
 }
 
+void Parser::checkOperationResultTypeInferrence(SMRange loc, StringRef opName,
+                                                const ods::Operation *odsOp) {
+  // If the operation might not have inferrence support, emit a warning to the
+  // user. We don't emit an error because the interface might be added to the
+  // operation at runtime. It's rare, but it could still happen. We emit a
+  // warning here instead.
+
+  // Handle inferrence warnings for unknown operations.
+  if (!odsOp) {
+    ctx.getDiagEngine().emitWarning(
+        loc, llvm::formatv(
+                 "operation result types are marked to be inferred, but "
+                 "`{0}` is unknown. Ensure that `{0}` supports zero "
+                 "results or implements `InferTypeOpInterface`. Include "
+                 "the ODS definition of this operation to remove this warning.",
+                 opName));
+    return;
+  }
+
+  // Handle inferrence warnings for known operations that expected at least one
+  // result, but don't have inference support. An elided results list can mean
+  // "zero-results", and we don't want to warn when that is the expected
+  // behavior.
+  bool requiresInferrence =
+      llvm::any_of(odsOp->getResults(), [](const ods::OperandOrResult &result) {
+        return !result.isVariableLength();
+      });
+  if (requiresInferrence && !odsOp->hasResultTypeInferrence()) {
+    ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitWarning(
+        loc,
+        llvm::formatv("operation result types are marked to be inferred, but "
+                      "`{0}` does not provide an implementation of "
+                      "`InferTypeOpInterface`. Ensure that `{0}` attaches "
+                      "`InferTypeOpInterface` at runtime, or add support to "
+                      "the ODS definition to remove this warning.",
+                      opName));
+    diag->attachNote(llvm::formatv("see the definition of `{0}` here", opName),
+                     odsOp->getLoc());
+    return;
+  }
+}
+
 LogicalResult Parser::validateOperationOperandsOrResults(
     StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc,
     Optional<StringRef> name, MutableArrayRef<ast::Expr *> values,

diff  --git a/mlir/test/mlir-pdll/Parser/expr-failure.pdll b/mlir/test/mlir-pdll/Parser/expr-failure.pdll
index 4a766c817d23f..de8c6163895e3 100644
--- a/mlir/test/mlir-pdll/Parser/expr-failure.pdll
+++ b/mlir/test/mlir-pdll/Parser/expr-failure.pdll
@@ -296,6 +296,36 @@ Pattern {
 
 // -----
 
+Pattern {
+  // CHECK: warning: operation result types are marked to be inferred, but
+  // CHECK-SAME: `test.unknown_inferred_result_op` is unknown.
+  // CHECK-SAME: Ensure that `test.unknown_inferred_result_op` supports zero
+  // CHECK-SAME: results or implements `InferTypeOpInterface`.
+  // CHECK-SAME: Include the ODS definition of this operation to remove this
+  // CHECK-SAME: warning.
+  rewrite _: Op with {
+    op<test.unknown_inferred_result_op>;
+  };
+}
+
+// -----
+
+#include "include/ops.td"
+
+Pattern {
+  // CHECK: warning: operation result types are marked to be inferred, but
+  // CHECK-SAME: `test.multiple_single_result` does not provide an implementation
+  // CHECK-SAME: of `InferTypeOpInterface`. Ensure that `test.multiple_single_result`
+  // CHECK-SAME: attaches `InferTypeOpInterface` at runtime, or add support
+  // CHECK-SAME: to the ODS definition to remove this warning.
+  // CHECK: see the definition of `test.multiple_single_result` here
+  rewrite _: Op with {
+    op<test.multiple_single_result>;
+  };
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // `type` Expr
 //===----------------------------------------------------------------------===//


        

