[Mlir-commits] [mlir] 1c2edb0 - [mlir:PDLL] Rework the C++ generation of native Constraint/Rewrite arguments and results

River Riddle llvmlistbot at llvm.org
Mon May 30 17:43:35 PDT 2022


Author: River Riddle
Date: 2022-05-30T17:35:34-07:00
New Revision: 1c2edb026ed67ddbb30ebe3e2d2f4f17a882a881

URL: https://github.com/llvm/llvm-project/commit/1c2edb026ed67ddbb30ebe3e2d2f4f17a882a881
DIFF: https://github.com/llvm/llvm-project/commit/1c2edb026ed67ddbb30ebe3e2d2f4f17a882a881.diff

LOG: [mlir:PDLL] Rework the C++ generation of native Constraint/Rewrite arguments and results

The current translation uses the old "ugly"/"raw" form which used PDLValue for the arguments
and results. This commit updates the C++ generation to use the recently added sugar that
allows for directly using the desired types for the arguments and result of PDL functions.
In addition, this commit also properly imports the C++ class for ODS operations, constraints,
and interfaces. This allows for a much more convenient C++ API than previously granted
with the raw/low-level types.

Differential Revision: https://reviews.llvm.org/D124817
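
The signature shape described above (typed arguments, and a result type of
`void`, a single native type, or a `std::tuple<...>`) can be sketched outside
of MLIR. The block below is only a string-based approximation of the rule that
the CPPGen.cpp hunk implements. `NativeVar`, `nativeResultType`, and
`nativeSignature` are hypothetical names introduced for illustration; the real
generator writes to a `raw_ostream` and resolves type names from
`ast::VariableDecl`.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for a PDLL variable whose native C++ type name has
// already been resolved (e.g. "::mlir::Value" or "my_dialect::FooOp").
struct NativeVar {
  std::string type;
  std::string name;
};

// Pick the C++ return type. Constraints always return LogicalResult; rewrites
// return void, the single result type, or a std::tuple of all result types.
std::string nativeResultType(bool isConstraint,
                             const std::vector<NativeVar> &results) {
  if (isConstraint)
    return "::mlir::LogicalResult";
  if (results.empty())
    return "void";
  if (results.size() == 1)
    return results.front().type;

  std::string out = "std::tuple<";
  for (std::size_t i = 0; i < results.size(); ++i) {
    if (i != 0)
      out += ", ";
    out += results[i].type;
  }
  return out + ">";
}

// Assemble the full signature: the rewriter is always the first parameter,
// followed by one typed parameter per PDLL input.
std::string nativeSignature(const std::string &name, bool isConstraint,
                            const std::vector<NativeVar> &inputs,
                            const std::vector<NativeVar> &results) {
  assert(!name.empty() && "a native callable needs a name");
  std::string out = "static " + nativeResultType(isConstraint, results) + " " +
                    name + "PDLFn(::mlir::PatternRewriter &rewriter";
  for (const NativeVar &input : inputs)
    out += ", " + input.type + " " + input.name;
  return out + ")";
}
```

Under these assumptions, a rewrite with one `Value` input and two named
operation results would get a `std::tuple` return type, while any constraint
would keep `::mlir::LogicalResult` regardless of its inputs.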

Added: 
    

Modified: 
    mlir/docs/PDLL.md
    mlir/include/mlir/IR/PatternMatch.h
    mlir/include/mlir/Tools/PDLL/AST/Nodes.h
    mlir/include/mlir/Tools/PDLL/AST/Types.h
    mlir/include/mlir/Tools/PDLL/ODS/Context.h
    mlir/include/mlir/Tools/PDLL/ODS/Dialect.h
    mlir/include/mlir/Tools/PDLL/ODS/Operation.h
    mlir/lib/Tools/PDLL/AST/Nodes.cpp
    mlir/lib/Tools/PDLL/AST/TypeDetail.h
    mlir/lib/Tools/PDLL/AST/Types.cpp
    mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp
    mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp
    mlir/lib/Tools/PDLL/ODS/Context.cpp
    mlir/lib/Tools/PDLL/ODS/Dialect.cpp
    mlir/lib/Tools/PDLL/ODS/Operation.cpp
    mlir/lib/Tools/PDLL/Parser/Parser.cpp
    mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
    mlir/test/mlir-pdll/CodeGen/CPP/general.pdll

Removed: 
    


################################################################################
diff --git a/mlir/docs/PDLL.md b/mlir/docs/PDLL.md
index 984ba2ed26f14..2aadbb6035d0f 100644
--- a/mlir/docs/PDLL.md
+++ b/mlir/docs/PDLL.md
@@ -1146,17 +1146,86 @@ Pattern {
 ```
 
 The arguments of the constraint are accessible within the code block via the
-same name. The type of these native variables are mapped directly to the
-corresponding MLIR type of the [core constraint](#core-constraints) used. For
-example, an `Op` corresponds to a variable of type `Operation *`.
+same name. See the ["type translation"](#native-constraint-type-translations) below for
+detailed information on how PDLL types are converted to native types. In addition to the
+PDLL arguments, the code block may also access the current `PatternRewriter` using
+`rewriter`. The result type of the native constraint function is implicitly defined
+as a `::mlir::LogicalResult`.
 
-The results of the constraint can be populated using the provided `results`
-variable. This variable is a `PDLResultList`, and expects results to be
-populated in the order that they are defined within the result list of the
-constraint declaration.
+Taking the constraints defined above as an example, these functions would roughly be
+translated into:
 
-In addition to the above, the code block may also access the current
-`PatternRewriter` using `rewriter`.
+```c++
+LogicalResult HasOneUse(PatternRewriter &rewriter, Value value) {
+  return success(value.hasOneUse());
+}
+LogicalResult HasSameElementType(Value value1, Value value2) {
+  return success(value1.getType().cast<ShapedType>().getElementType() ==
+                 value2.getType().cast<ShapedType>().getElementType());
+}
+```
+
+TODO: Native constraints should also be allowed to return values in certain cases.
+
+###### Native Constraint Type Translations
+
+The types of argument and result variables are generally mapped to the corresponding
+MLIR type of the [constraint](#constraints) used. Below is a detailed description
+of how the mapped type of a variable is determined for the various different types of
+constraints.
+
+* Attr, Op, Type, TypeRange, Value, ValueRange:
+
+These are all core constraints, and are mapped directly to the MLIR equivalent
+(that their names suggest), namely:
+
+  * `Attr`       -> "::mlir::Attribute"
+  * `Op`         -> "::mlir::Operation *"
+  * `Type`       -> "::mlir::Type"
+  * `TypeRange`  -> "::mlir::TypeRange"
+  * `Value`      -> "::mlir::Value"
+  * `ValueRange` -> "::mlir::ValueRange"
+
+* Op<dialect.name>
+
+A named operation constraint has a unique translation. If the ODS registration of the
+referenced operation has been included, the qualified C++ is used. If the ODS information
+is not available, this constraint maps to "::mlir::Operation *", similarly to the unnamed
+variant. For example, given the following:
+
+```pdll
+// `my_ops.td` provides the ODS definition of the `my_dialect` operations, such as
+// `my_dialect.bar` used below.
+#include "my_ops.td"
+
+Constraint Cst(op: Op<my_dialect.bar>) [{
+  return success(op ... );
+}];
+```
+
+The native type used for `op` may be of the form `my_dialect::BarOp`, as opposed to the
+default `::mlir::Operation *`. Below is a sample translation of the above constraint:
+
+```c++
+LogicalResult Cst(my_dialect::BarOp op) {
+  return success(op ... );
+}
+```
+
+* Imported ODS Constraints
+
+Aside from the core constraints, certain constraints imported from ODS may use a unique
+native type. How to enable this unique type depends on the ODS constraint construct that
+was imported:
+
+  * `Attr` constraints
+    - Imported `Attr` constraints utilize the `storageType` field for native type translation.
+  
+  * `Type` constraints
+    - Imported `Type` constraints utilize the `cppClassName` field for native type translation.
+
+  * `AttrInterface`/`OpInterface`/`TypeInterface` constraints
+    - Imported interfaces utilize the `cppClassName` field for native type translation.
 
 #### Defining Constraints Inline
 
@@ -1414,10 +1483,7 @@ be defined by specifying a string code block after the rewrite declaration:
 
 ```pdll
 Rewrite BuildOp(value: Value) -> (foo: Op<my_dialect.foo>, bar: Op<my_dialect.bar>) [{
-  // We push back the results into the `results` variable in the order defined
-  // by the result list of the rewrite declaration.
-  results.push_back(rewriter.create<my_dialect::FooOp>(value));
-  results.push_back(rewriter.create<my_dialect::BarOp>());
+  return {rewriter.create<my_dialect::FooOp>(value), rewriter.create<my_dialect::BarOp>()};
 }];
 
 Pattern {
@@ -1431,17 +1497,85 @@ Pattern {
 ```
 
 The arguments of the rewrite are accessible within the code block via the
-same name. The type of these native variables are mapped directly to the
-corresponding MLIR type of the [core constraint](#core-constraints) used. For
-example, an `Op` corresponds to a variable of type `Operation *`.
+same name. See the ["type translation"](#native-rewrite-type-translations) below for
+detailed information on how PDLL types are converted to native types. In addition to the
+PDLL arguments, the code block may also access the current `PatternRewriter` using
+`rewriter`. See the ["result translation"](#native-rewrite-result-translation) section
+for detailed information on how the result type of the native function is determined.
+
+Taking the rewrite defined above as an example, this function would roughly be
+translated into:
+
+```c++
+std::tuple<my_dialect::FooOp, my_dialect::BarOp> BuildOp(Value value) {
+  return {rewriter.create<my_dialect::FooOp>(value), rewriter.create<my_dialect::BarOp>()};
+}
+```
 
-The results of the rewrite can be populated using the provided `results`
-variable. This variable is a `PDLResultList`, and expects results to be
-populated in the order that they are defined within the result list of the
-rewrite declaration.
+###### Native Rewrite Type Translations
 
-In addition to the above, the code block may also access the current
-`PatternRewriter` using `rewriter`.
+The types of argument and result variables are generally mapped to the corresponding
+MLIR type of the [constraint](#constraints) used. The rules of native `Rewrite` type translation
+are identical to those of native `Constraint`s, please view the corresponding
+[native `Constraint` type translation](#native-constraint-type-translations) section for a
+detailed description of how the mapped type of a variable is determined.
+
+###### Native Rewrite Result Translation
+
+The results of a native rewrite are directly translated to the results of the native function,
+using the type translation rules [described above](#native-rewrite-type-translations). The section
+below describes the various result translation scenarios:
+
+* Zero Result
+
+```pdll
+Rewrite createOp() [{
+  rewriter.create<my_dialect::FooOp>();
+}];
+```
+
+In the case where a native `Rewrite` has no results, the native function returns `void`:
+
+```c++
+void createOp(PatternRewriter &rewriter) {
+  rewriter.create<my_dialect::FooOp>();
+}
+```
+
+* Single Result
+
+```pdll
+Rewrite createOp() -> Op<my_dialect.foo> [{
+  return rewriter.create<my_dialect::FooOp>();
+}];
+```
+
+In the case where a native `Rewrite` has a single result, the native function returns the corresponding
+native type for that single result:
+
+```c++
+my_dialect::FooOp createOp(PatternRewriter &rewriter) {
+  return rewriter.create<my_dialect::FooOp>();
+}
+```
+
+* Multi Result
+
+```pdll
+Rewrite complexRewrite(value: Value) -> (Op<my_dialect.foo>, FunctionOpInterface) [{
+  ...
+}];
+```
+
+In the case where a native `Rewrite` has multiple results, the native function returns a `std::tuple<...>`
+containing the corresponding native types for each of the results:
+
+```c++
+std::tuple<my_dialect::FooOp, FunctionOpInterface>
+complexRewrite(PatternRewriter &rewriter, Value value) {
+  ...
+}
+```
 
 #### Defining Rewrites Inline
 

diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 751ec2af2c9d2..f3236dcff7fac 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -943,9 +943,13 @@ struct ProcessDerivedPDLValue : public ProcessPDLValueBasedOn<T, BaseT> {
                          " to be of type: " + llvm::getTypeName<T>());
         });
   }
+  using ProcessPDLValueBasedOn<T, BaseT>::verifyAsArg;
+
   static T processAsArg(BaseT baseValue) {
     return baseValue.template cast<T>();
   }
+  using ProcessPDLValueBasedOn<T, BaseT>::processAsArg;
+
   static void processAsResult(PatternRewriter &, PDLResultList &results,
                               T value) {
     results.push_back(value);
@@ -967,6 +971,8 @@ template <>
 struct ProcessPDLValue<StringRef>
     : public ProcessPDLValueBasedOn<StringRef, StringAttr> {
   static StringRef processAsArg(StringAttr value) { return value.getValue(); }
+  using ProcessPDLValueBasedOn<StringRef, StringAttr>::processAsArg;
+
   static void processAsResult(PatternRewriter &rewriter, PDLResultList &results,
                               StringRef value) {
     results.push_back(rewriter.getStringAttr(value));

diff --git a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h
index 6635dc764e04c..ab1a53f90b8fc 100644
--- a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h
+++ b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h
@@ -506,6 +506,7 @@ class OperationExpr final
                                     NamedAttributeDecl *> {
 public:
   static OperationExpr *create(Context &ctx, SMRange loc,
+                               const ods::Operation *odsOp,
                                const OpNameDecl *nameDecl,
                                ArrayRef<Expr *> operands,
                                ArrayRef<Expr *> resultTypes,
@@ -830,16 +831,15 @@ class ValueRangeConstraintDecl
 ///     - This is a constraint which is defined using only PDLL constructs.
 class UserConstraintDecl final
     : public Node::NodeBase<UserConstraintDecl, ConstraintDecl>,
-      llvm::TrailingObjects<UserConstraintDecl, VariableDecl *> {
+      llvm::TrailingObjects<UserConstraintDecl, VariableDecl *, StringRef> {
 public:
   /// Create a native constraint with the given optional code block.
-  static UserConstraintDecl *createNative(Context &ctx, const Name &name,
-                                          ArrayRef<VariableDecl *> inputs,
-                                          ArrayRef<VariableDecl *> results,
-                                          Optional<StringRef> codeBlock,
-                                          Type resultType) {
-    return createImpl(ctx, name, inputs, results, codeBlock, /*body=*/nullptr,
-                      resultType);
+  static UserConstraintDecl *
+  createNative(Context &ctx, const Name &name, ArrayRef<VariableDecl *> inputs,
+               ArrayRef<VariableDecl *> results, Optional<StringRef> codeBlock,
+               Type resultType, ArrayRef<StringRef> nativeInputTypes = {}) {
+    return createImpl(ctx, name, inputs, nativeInputTypes, results, codeBlock,
+                      /*body=*/nullptr, resultType);
   }
 
   /// Create a PDLL constraint with the given body.
@@ -848,8 +848,8 @@ class UserConstraintDecl final
                                         ArrayRef<VariableDecl *> results,
                                         const CompoundStmt *body,
                                         Type resultType) {
-    return createImpl(ctx, name, inputs, results, /*codeBlock=*/llvm::None,
-                      body, resultType);
+    return createImpl(ctx, name, inputs, /*nativeInputTypes=*/llvm::None,
+                      results, /*codeBlock=*/llvm::None, body, resultType);
   }
 
   /// Return the name of the constraint.
@@ -863,6 +863,10 @@ class UserConstraintDecl final
     return const_cast<UserConstraintDecl *>(this)->getInputs();
   }
 
+  /// Return the explicit native type to use for the given input. Returns None
+  /// if no explicit type was set.
+  Optional<StringRef> getNativeInputType(unsigned index) const;
+
   /// Return the explicit results of the constraint declaration. May be empty,
   /// even if the constraint has results (e.g. in the case of inferred results).
   MutableArrayRef<VariableDecl *> getResults() {
@@ -891,10 +895,12 @@ class UserConstraintDecl final
   /// components.
   static UserConstraintDecl *
   createImpl(Context &ctx, const Name &name, ArrayRef<VariableDecl *> inputs,
+             ArrayRef<StringRef> nativeInputTypes,
              ArrayRef<VariableDecl *> results, Optional<StringRef> codeBlock,
              const CompoundStmt *body, Type resultType);
 
-  UserConstraintDecl(const Name &name, unsigned numInputs, unsigned numResults,
+  UserConstraintDecl(const Name &name, unsigned numInputs,
+                     bool hasNativeInputTypes, unsigned numResults,
                      Optional<StringRef> codeBlock, const CompoundStmt *body,
                      Type resultType)
       : Base(name.getLoc(), &name), numInputs(numInputs),
@@ -916,8 +922,14 @@ class UserConstraintDecl final
   /// The result type of the constraint.
   Type resultType;
 
+  /// Flag indicating if this constraint has explicit native input types.
+  bool hasNativeInputTypes;
+
   /// Allow access to various internals.
-  friend llvm::TrailingObjects<UserConstraintDecl, VariableDecl *>;
+  friend llvm::TrailingObjects<UserConstraintDecl, VariableDecl *, StringRef>;
+  size_t numTrailingObjects(OverloadToken<VariableDecl *>) const {
+    return numInputs + numResults;
+  }
 };
 
 //===----------------------------------------------------------------------===//
@@ -1145,6 +1157,23 @@ class CallableDecl : public Decl {
     return cast<UserRewriteDecl>(this)->getResultType();
   }
 
+  /// Return the explicit results of the declaration. Note that these may be
+  /// empty, even if the callable has results (e.g. in the case of inferred
+  /// results).
+  ArrayRef<VariableDecl *> getResults() const {
+    if (const auto *cst = dyn_cast<UserConstraintDecl>(this))
+      return cst->getResults();
+    return cast<UserRewriteDecl>(this)->getResults();
+  }
+
+  /// Return the optional code block of this callable, if this is a native
+  /// callable with a provided implementation.
+  Optional<StringRef> getCodeBlock() const {
+    if (const auto *cst = dyn_cast<UserConstraintDecl>(this))
+      return cst->getCodeBlock();
+    return cast<UserRewriteDecl>(this)->getCodeBlock();
+  }
+
   /// Support LLVM type casting facilities.
   static bool classof(const Node *decl) {
     return isa<UserConstraintDecl, UserRewriteDecl>(decl);

diff --git a/mlir/include/mlir/Tools/PDLL/AST/Types.h b/mlir/include/mlir/Tools/PDLL/AST/Types.h
index 75a80b5bf92d4..a4c0888441e2c 100644
--- a/mlir/include/mlir/Tools/PDLL/AST/Types.h
+++ b/mlir/include/mlir/Tools/PDLL/AST/Types.h
@@ -14,6 +14,10 @@
 
 namespace mlir {
 namespace pdll {
+namespace ods {
+class Operation;
+} // namespace ods
+
 namespace ast {
 class Context;
 
@@ -151,10 +155,15 @@ class OperationType : public Type::TypeBase<detail::OperationTypeStorage> {
   /// Return an instance of the Operation type with an optional operation name.
   /// If no name is provided, this type may refer to any operation.
   static OperationType get(Context &context,
-                           Optional<StringRef> name = llvm::None);
+                           Optional<StringRef> name = llvm::None,
+                           const ods::Operation *odsOp = nullptr);
 
   /// Return the name of this operation type, or None if it doesn't have on.
   Optional<StringRef> getName() const;
+
+  /// Return the ODS operation that this type refers to, or nullptr if the ODS
+  /// operation is unknown.
+  const ods::Operation *getODSOperation() const;
 };
 
 //===----------------------------------------------------------------------===//

diff --git a/mlir/include/mlir/Tools/PDLL/ODS/Context.h b/mlir/include/mlir/Tools/PDLL/ODS/Context.h
index 6baa90af44adf..8a57bb791e639 100644
--- a/mlir/include/mlir/Tools/PDLL/ODS/Context.h
+++ b/mlir/include/mlir/Tools/PDLL/ODS/Context.h
@@ -63,7 +63,8 @@ class Context {
   /// operation already existed).
   std::pair<Operation *, bool>
   insertOperation(StringRef name, StringRef summary, StringRef desc,
-                  bool supportsResultTypeInferrence, SMLoc loc);
+                  StringRef nativeClassName, bool supportsResultTypeInferrence,
+                  SMLoc loc);
 
   /// Lookup an operation registered with the given name, or null if no
   /// operation with that name is registered.

diff --git a/mlir/include/mlir/Tools/PDLL/ODS/Dialect.h b/mlir/include/mlir/Tools/PDLL/ODS/Dialect.h
index de8181843e84c..c5c60977dd2f1 100644
--- a/mlir/include/mlir/Tools/PDLL/ODS/Dialect.h
+++ b/mlir/include/mlir/Tools/PDLL/ODS/Dialect.h
@@ -35,7 +35,8 @@ class Dialect {
   /// operation already existed).
   std::pair<Operation *, bool>
   insertOperation(StringRef name, StringRef summary, StringRef desc,
-                  bool supportsResultTypeInferrence, SMLoc loc);
+                  StringRef nativeClassName, bool supportsResultTypeInferrence,
+                  SMLoc loc);
 
   /// Lookup an operation registered with the given name, or null if no
   /// operation with that name is registered.

diff --git a/mlir/include/mlir/Tools/PDLL/ODS/Operation.h b/mlir/include/mlir/Tools/PDLL/ODS/Operation.h
index 8ad36a6872b7c..f4c5a518fbfa9 100644
--- a/mlir/include/mlir/Tools/PDLL/ODS/Operation.h
+++ b/mlir/include/mlir/Tools/PDLL/ODS/Operation.h
@@ -154,6 +154,9 @@ class Operation {
   /// Returns the description of the operation.
   StringRef getDescription() const { return description; }
 
+  /// Returns the native class name of the operation.
+  StringRef getNativeClassName() const { return nativeClassName; }
+
   /// Returns the attributes of this operation.
   ArrayRef<Attribute> getAttributes() const { return attributes; }
 
@@ -168,7 +171,7 @@ class Operation {
 
 private:
   Operation(StringRef name, StringRef summary, StringRef desc,
-            bool supportsTypeInferrence, SMLoc loc);
+            StringRef nativeClassName, bool supportsTypeInferrence, SMLoc loc);
 
   /// The name of the operation.
   std::string name;
@@ -177,6 +180,9 @@ class Operation {
   std::string summary;
   std::string description;
 
+  /// The native class name of the operation, used when generating native code.
+  std::string nativeClassName;
+
   /// Flag indicating if the operation is known to support type inferrence.
   bool supportsTypeInferrence;
 

diff --git a/mlir/lib/Tools/PDLL/AST/Nodes.cpp b/mlir/lib/Tools/PDLL/AST/Nodes.cpp
index 5f1af1dc7b2eb..417483444615c 100644
--- a/mlir/lib/Tools/PDLL/AST/Nodes.cpp
+++ b/mlir/lib/Tools/PDLL/AST/Nodes.cpp
@@ -298,17 +298,18 @@ MemberAccessExpr *MemberAccessExpr::create(Context &ctx, SMRange loc,
 // OperationExpr
 //===----------------------------------------------------------------------===//
 
-OperationExpr *OperationExpr::create(
-    Context &ctx, SMRange loc, const OpNameDecl *name,
-    ArrayRef<Expr *> operands, ArrayRef<Expr *> resultTypes,
-    ArrayRef<NamedAttributeDecl *> attributes) {
+OperationExpr *
+OperationExpr::create(Context &ctx, SMRange loc, const ods::Operation *odsOp,
+                      const OpNameDecl *name, ArrayRef<Expr *> operands,
+                      ArrayRef<Expr *> resultTypes,
+                      ArrayRef<NamedAttributeDecl *> attributes) {
   unsigned allocSize =
       OperationExpr::totalSizeToAlloc<Expr *, NamedAttributeDecl *>(
           operands.size() + resultTypes.size(), attributes.size());
   void *rawData =
       ctx.getAllocator().Allocate(allocSize, alignof(OperationExpr));
 
-  Type resultType = OperationType::get(ctx, name->getName());
+  Type resultType = OperationType::get(ctx, name->getName(), odsOp);
   OperationExpr *opExpr = new (rawData)
       OperationExpr(loc, resultType, name, operands.size(), resultTypes.size(),
                     attributes.size(), name->getLoc());
@@ -426,23 +427,41 @@ ValueRangeConstraintDecl *ValueRangeConstraintDecl::create(Context &ctx,
 // UserConstraintDecl
 //===----------------------------------------------------------------------===//
 
+Optional<StringRef>
+UserConstraintDecl::getNativeInputType(unsigned index) const {
+  return hasNativeInputTypes ? getTrailingObjects<StringRef>()[index]
+                             : Optional<StringRef>();
+}
+
 UserConstraintDecl *UserConstraintDecl::createImpl(
     Context &ctx, const Name &name, ArrayRef<VariableDecl *> inputs,
-    ArrayRef<VariableDecl *> results, Optional<StringRef> codeBlock,
-    const CompoundStmt *body, Type resultType) {
-  unsigned allocSize = UserConstraintDecl::totalSizeToAlloc<VariableDecl *>(
-      inputs.size() + results.size());
+    ArrayRef<StringRef> nativeInputTypes, ArrayRef<VariableDecl *> results,
+    Optional<StringRef> codeBlock, const CompoundStmt *body, Type resultType) {
+  bool hasNativeInputTypes = !nativeInputTypes.empty();
+  assert(!hasNativeInputTypes || nativeInputTypes.size() == inputs.size());
+
+  unsigned allocSize =
+      UserConstraintDecl::totalSizeToAlloc<VariableDecl *, StringRef>(
+          inputs.size() + results.size(),
+          hasNativeInputTypes ? inputs.size() : 0);
   void *rawData =
       ctx.getAllocator().Allocate(allocSize, alignof(UserConstraintDecl));
   if (codeBlock)
     codeBlock = codeBlock->copy(ctx.getAllocator());
 
-  UserConstraintDecl *decl = new (rawData) UserConstraintDecl(
-      name, inputs.size(), results.size(), codeBlock, body, resultType);
+  UserConstraintDecl *decl = new (rawData)
+      UserConstraintDecl(name, inputs.size(), hasNativeInputTypes,
+                         results.size(), codeBlock, body, resultType);
   std::uninitialized_copy(inputs.begin(), inputs.end(),
                           decl->getInputs().begin());
   std::uninitialized_copy(results.begin(), results.end(),
                           decl->getResults().begin());
+  if (hasNativeInputTypes) {
+    StringRef *nativeInputTypesPtr = decl->getTrailingObjects<StringRef>();
+    for (unsigned i = 0, e = inputs.size(); i < e; ++i)
+      nativeInputTypesPtr[i] = nativeInputTypes[i].copy(ctx.getAllocator());
+  }
+
   return decl;
 }
 

diff --git a/mlir/lib/Tools/PDLL/AST/TypeDetail.h b/mlir/lib/Tools/PDLL/AST/TypeDetail.h
index e6719fb961216..4e2a686d704f1 100644
--- a/mlir/lib/Tools/PDLL/AST/TypeDetail.h
+++ b/mlir/lib/Tools/PDLL/AST/TypeDetail.h
@@ -75,13 +75,15 @@ struct ConstraintTypeStorage : public TypeStorageBase<ConstraintTypeStorage> {};
 //===----------------------------------------------------------------------===//
 
 struct OperationTypeStorage
-    : public TypeStorageBase<OperationTypeStorage, StringRef> {
+    : public TypeStorageBase<OperationTypeStorage,
+                             std::pair<StringRef, const ods::Operation *>> {
   using Base::Base;
 
   static OperationTypeStorage *
-  construct(StorageUniquer::StorageAllocator &alloc, StringRef key) {
-    return new (alloc.allocate<OperationTypeStorage>())
-        OperationTypeStorage(alloc.copyInto(key));
+  construct(StorageUniquer::StorageAllocator &alloc,
+            const std::pair<StringRef, const ods::Operation *> &key) {
+    return new (alloc.allocate<OperationTypeStorage>()) OperationTypeStorage(
+        std::make_pair(alloc.copyInto(key.first), key.second));
   }
 };
 

diff --git a/mlir/lib/Tools/PDLL/AST/Types.cpp b/mlir/lib/Tools/PDLL/AST/Types.cpp
index 4164cabac4dbc..02b01e8854aae 100644
--- a/mlir/lib/Tools/PDLL/AST/Types.cpp
+++ b/mlir/lib/Tools/PDLL/AST/Types.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Tools/PDLL/AST/Context.h"
 
 using namespace mlir;
+using namespace mlir::pdll;
 using namespace mlir::pdll::ast;
 
 MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::AttributeTypeStorage)
@@ -68,16 +69,22 @@ ConstraintType ConstraintType::get(Context &context) {
 // OperationType
 //===----------------------------------------------------------------------===//
 
-OperationType OperationType::get(Context &context, Optional<StringRef> name) {
+OperationType OperationType::get(Context &context, Optional<StringRef> name,
+                                 const ods::Operation *odsOp) {
   return context.getTypeUniquer().get<ImplTy>(
-      /*initFn=*/function_ref<void(ImplTy *)>(), name.getValueOr(""));
+      /*initFn=*/function_ref<void(ImplTy *)>(),
+      std::make_pair(name.getValueOr(""), odsOp));
 }
 
 Optional<StringRef> OperationType::getName() const {
-  StringRef name = getImplAs<ImplTy>()->getValue();
+  StringRef name = getImplAs<ImplTy>()->getValue().first;
   return name.empty() ? Optional<StringRef>() : Optional<StringRef>(name);
 }
 
+const ods::Operation *OperationType::getODSOperation() const {
+  return getImplAs<ImplTy>()->getValue().second;
+}
+
 //===----------------------------------------------------------------------===//
 // RangeType
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp b/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp
index c3b5c957007e8..bf73be9cbff6d 100644
--- a/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp
+++ b/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/PDL/IR/PDLOps.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/Tools/PDLL/AST/Nodes.h"
+#include "mlir/Tools/PDLL/ODS/Operation.h"
 #include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringSet.h"
@@ -51,11 +52,16 @@ class CodeGen {
   void generate(const ast::UserConstraintDecl *decl,
                 StringSet<> &nativeFunctions);
   void generate(const ast::UserRewriteDecl *decl, StringSet<> &nativeFunctions);
-  void generateConstraintOrRewrite(StringRef name, bool isConstraint,
-                                   ArrayRef<ast::VariableDecl *> inputs,
-                                   StringRef codeBlock,
+  void generateConstraintOrRewrite(const ast::CallableDecl *decl,
+                                   bool isConstraint,
                                    StringSet<> &nativeFunctions);
 
+  /// Return the native name for the type of the given type.
+  StringRef getNativeTypeName(ast::Type type);
+
+  /// Return the native name for the type of the given variable decl.
+  StringRef getNativeTypeName(ast::VariableDecl *decl);
+
   /// The stream to output to.
   raw_ostream &os;
 };
@@ -152,55 +158,88 @@ void CodeGen::generateConstraintAndRewrites(const ast::Module &astModule,
 
 void CodeGen::generate(const ast::UserConstraintDecl *decl,
                        StringSet<> &nativeFunctions) {
-  return generateConstraintOrRewrite(decl->getName().getName(),
-                                     /*isConstraint=*/true, decl->getInputs(),
-                                     *decl->getCodeBlock(), nativeFunctions);
+  return generateConstraintOrRewrite(cast<ast::CallableDecl>(decl),
+                                     /*isConstraint=*/true, nativeFunctions);
 }
 
 void CodeGen::generate(const ast::UserRewriteDecl *decl,
                        StringSet<> &nativeFunctions) {
-  return generateConstraintOrRewrite(decl->getName().getName(),
-                                     /*isConstraint=*/false, decl->getInputs(),
-                                     *decl->getCodeBlock(), nativeFunctions);
+  return generateConstraintOrRewrite(cast<ast::CallableDecl>(decl),
+                                     /*isConstraint=*/false, nativeFunctions);
+}
+
+StringRef CodeGen::getNativeTypeName(ast::Type type) {
+  return llvm::TypeSwitch<ast::Type, StringRef>(type)
+      .Case([&](ast::AttributeType) { return "::mlir::Attribute"; })
+      .Case([&](ast::OperationType opType) -> StringRef {
+        // Use the derived Op class when available.
+        if (const auto *odsOp = opType.getODSOperation())
+          return odsOp->getNativeClassName();
+        return "::mlir::Operation *";
+      })
+      .Case([&](ast::TypeType) { return "::mlir::Type"; })
+      .Case([&](ast::ValueType) { return "::mlir::Value"; })
+      .Case([&](ast::TypeRangeType) { return "::mlir::TypeRange"; })
+      .Case([&](ast::ValueRangeType) { return "::mlir::ValueRange"; });
+}
+
+StringRef CodeGen::getNativeTypeName(ast::VariableDecl *decl) {
+  // Try to extract a type name from the variable's constraints.
+  for (ast::ConstraintRef &cst : decl->getConstraints()) {
+    if (auto *userCst = dyn_cast<ast::UserConstraintDecl>(cst.constraint)) {
+      if (Optional<StringRef> name = userCst->getNativeInputType(0))
+        return *name;
+      return getNativeTypeName(userCst->getInputs()[0]);
+    }
+  }
+
+  // Otherwise, use the type of the variable.
+  return getNativeTypeName(decl->getType());
 }
 
-void CodeGen::generateConstraintOrRewrite(StringRef name, bool isConstraint,
-                                          ArrayRef<ast::VariableDecl *> inputs,
-                                          StringRef codeBlock,
+void CodeGen::generateConstraintOrRewrite(const ast::CallableDecl *decl,
+                                          bool isConstraint,
                                           StringSet<> &nativeFunctions) {
+  StringRef name = decl->getName()->getName();
   nativeFunctions.insert(name);
 
-  // TODO: Should there be something explicit for handling optionality?
-  auto getCppType = [&](ast::Type type) -> StringRef {
-    return llvm::TypeSwitch<ast::Type, StringRef>(type)
-        .Case([&](ast::AttributeType) { return "::mlir::Attribute"; })
-        .Case([&](ast::OperationType) {
-          // TODO: Allow using the derived Op class when possible.
-          return "::mlir::Operation *";
-        })
-        .Case([&](ast::TypeType) { return "::mlir::Type"; })
-        .Case([&](ast::ValueType) { return "::mlir::Value"; })
-        .Case([&](ast::TypeRangeType) { return "::mlir::TypeRange"; })
-        .Case([&](ast::ValueRangeType) { return "::mlir::ValueRange"; });
-  };
-  os << "static " << (isConstraint ? "::mlir::LogicalResult " : "void ") << name
-     << "PDLFn(::mlir::PatternRewriter &rewriter, "
-     << (isConstraint ? "" : "::mlir::PDLResultList &results, ")
-     << "::llvm::ArrayRef<::mlir::PDLValue> values) {\n";
-
-  const char *argumentInitStr = R"(
-  {0} {1} = {{};
-  if (values[{2}])
-    {1} = values[{2}].cast<{0}>();
-  (void){1};
-)";
-  for (const auto &it : llvm::enumerate(inputs)) {
-    const ast::VariableDecl *input = it.value();
-    os << llvm::formatv(argumentInitStr, getCppType(input->getType()),
-                        input->getName().getName(), it.index());
+  os << "static ";
+
+  // TODO: Work out a proper modeling for "optionality".
+
+  // Emit the result type.
+  // If this is a constraint, we always return a LogicalResult.
+  // TODO: This will need to change if we allow Constraints to return values as
+  // well.
+  if (isConstraint) {
+    os << "::mlir::LogicalResult";
+  } else {
+    // Otherwise, generate a type based on the results of the callable.
+    // If the callable has explicit results, use those to build the result.
+    // Otherwise, use the type of the callable.
+    ArrayRef<ast::VariableDecl *> results = decl->getResults();
+    if (results.empty()) {
+      os << "void";
+    } else if (results.size() == 1) {
+      os << getNativeTypeName(results[0]);
+    } else {
+      os << "std::tuple<";
+      llvm::interleaveComma(results, os, [&](ast::VariableDecl *result) {
+        os << getNativeTypeName(result);
+      });
+      os << ">";
+    }
   }
 
-  os << "  " << codeBlock.trim() << "\n}\n";
+  os << " " << name << "PDLFn(::mlir::PatternRewriter &rewriter";
+  if (!decl->getInputs().empty()) {
+    os << ", ";
+    llvm::interleaveComma(decl->getInputs(), os, [&](ast::VariableDecl *input) {
+      os << getNativeTypeName(input) << " " << input->getName().getName();
+    });
+  }
+  os << ") {\n";
+  os << "  " << decl->getCodeBlock()->trim() << "\n}\n\n";
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp b/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp
index 5678761e71805..1e4e14e3df4c9 100644
--- a/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp
+++ b/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp
@@ -316,11 +316,12 @@ Value CodeGen::genNonInitializerVar(const ast::VariableDecl *varDecl,
       Value typeValue =
           TypeSwitch<const ast::Node *, Value>(constraint.constraint)
               .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl,
-                    ast::ValueRangeConstraintDecl>([&, this](auto *cst) -> Value {
-                if (auto *typeConstraintExpr = cst->getTypeExpr())
-                  return this->genSingleExpr(typeConstraintExpr);
-                return Value();
-              })
+                    ast::ValueRangeConstraintDecl>(
+                  [&, this](auto *cst) -> Value {
+                    if (auto *typeConstraintExpr = cst->getTypeExpr())
+                      return this->genSingleExpr(typeConstraintExpr);
+                    return Value();
+                  })
               .Default(Value());
       if (typeValue)
         return typeValue;
@@ -442,16 +443,15 @@ Value CodeGen::genExprImpl(const ast::MemberAccessExpr *expr) {
       return builder.create<pdl::ResultsOp>(loc, mlirType, parentExprs[0]);
     }
 
-    assert(opType.getName() && "expected valid operation name");
-    const ods::Operation *odsOp = odsContext.lookupOperation(*opType.getName());
-
+    const ods::Operation *odsOp = opType.getODSOperation();
     if (!odsOp) {
-      assert(llvm::isDigit(name[0]) && "unregistered op only allows numeric indexing");
+      assert(llvm::isDigit(name[0]) &&
+             "unregistered op only allows numeric indexing");
       unsigned resultIndex;
       name.getAsInteger(/*Radix=*/10, resultIndex);
       IntegerAttr index = builder.getI32IntegerAttr(resultIndex);
       return builder.create<pdl::ResultOp>(loc, genType(expr->getType()),
-                                            parentExprs[0], index);
+                                           parentExprs[0], index);
     }
 
     // Find the result with the member name or by index.

diff  --git a/mlir/lib/Tools/PDLL/ODS/Context.cpp b/mlir/lib/Tools/PDLL/ODS/Context.cpp
index 00f7cb26b432d..a3933c9c73b20 100644
--- a/mlir/lib/Tools/PDLL/ODS/Context.cpp
+++ b/mlir/lib/Tools/PDLL/ODS/Context.cpp
@@ -61,10 +61,12 @@ const Dialect *Context::lookupDialect(StringRef name) const {
 
 std::pair<Operation *, bool>
 Context::insertOperation(StringRef name, StringRef summary, StringRef desc,
+                         StringRef nativeClassName,
                          bool supportsResultTypeInferrence, SMLoc loc) {
   std::pair<StringRef, StringRef> dialectAndName = name.split('.');
   return insertDialect(dialectAndName.first)
-      .insertOperation(name, summary, desc, supportsResultTypeInferrence, loc);
+      .insertOperation(name, summary, desc, nativeClassName,
+                       supportsResultTypeInferrence, loc);
 }
 
 const Operation *Context::lookupOperation(StringRef name) const {

diff  --git a/mlir/lib/Tools/PDLL/ODS/Dialect.cpp b/mlir/lib/Tools/PDLL/ODS/Dialect.cpp
index 2e084c5d6cfd6..b4654a6ad5b2e 100644
--- a/mlir/lib/Tools/PDLL/ODS/Dialect.cpp
+++ b/mlir/lib/Tools/PDLL/ODS/Dialect.cpp
@@ -23,13 +23,14 @@ Dialect::~Dialect() = default;
 
 std::pair<Operation *, bool>
 Dialect::insertOperation(StringRef name, StringRef summary, StringRef desc,
+                         StringRef nativeClassName,
                          bool supportsResultTypeInferrence, llvm::SMLoc loc) {
   std::unique_ptr<Operation> &operation = operations[name];
   if (operation)
     return std::make_pair(&*operation, /*wasInserted*/ false);
 
-  operation.reset(
-      new Operation(name, summary, desc, supportsResultTypeInferrence, loc));
+  operation.reset(new Operation(name, summary, desc, nativeClassName,
+                                supportsResultTypeInferrence, loc));
   return std::make_pair(&*operation, /*wasInserted*/ true);
 }
 

diff  --git a/mlir/lib/Tools/PDLL/ODS/Operation.cpp b/mlir/lib/Tools/PDLL/ODS/Operation.cpp
index 3b8a3a9e97333..c991c33be7d22 100644
--- a/mlir/lib/Tools/PDLL/ODS/Operation.cpp
+++ b/mlir/lib/Tools/PDLL/ODS/Operation.cpp
@@ -18,8 +18,10 @@ using namespace mlir::pdll::ods;
 //===----------------------------------------------------------------------===//
 
 Operation::Operation(StringRef name, StringRef summary, StringRef desc,
-                     bool supportsTypeInferrence, llvm::SMLoc loc)
+                     StringRef nativeClassName, bool supportsTypeInferrence,
+                     llvm::SMLoc loc)
     : name(name.str()), summary(summary.str()),
+      nativeClassName(nativeClassName.str()),
       supportsTypeInferrence(supportsTypeInferrence),
       location(loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)) {
   llvm::raw_string_ostream descOS(description);

diff  --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
index e7ec5a047b9fc..9cd3d8caf40bf 100644
--- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp
+++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
@@ -142,11 +142,13 @@ class Parser {
   template <typename ConstraintT>
   ast::Decl *createODSNativePDLLConstraintDecl(StringRef name,
                                                StringRef codeBlock, SMRange loc,
-                                               ast::Type type);
+                                               ast::Type type,
+                                               StringRef nativeType);
   template <typename ConstraintT>
   ast::Decl *
   createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
-                                    SMRange loc, ast::Type type);
+                                    SMRange loc, ast::Type type,
+                                    StringRef nativeType);
 
   //===--------------------------------------------------------------------===//
   // Decls
@@ -610,8 +612,7 @@ LogicalResult Parser::convertExpressionTo(
     if (type == valueTy) {
       // If the operation is registered, we can verify if it can ever have a
       // single result.
-      Optional<StringRef> opName = exprOpType.getName();
-      if (const ods::Operation *odsOp = lookupODSOperation(opName)) {
+      if (const ods::Operation *odsOp = exprOpType.getODSOperation()) {
         if (odsOp->getResults().empty()) {
           return emitConvertError()->attachNote(
               llvm::formatv("see the definition of `{0}`, which was defined "
@@ -821,7 +822,8 @@ void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
     ods::Operation *odsOp = nullptr;
     std::tie(odsOp, inserted) = odsContext.insertOperation(
         op.getOperationName(), op.getSummary(), op.getDescription(),
-        supportsResultTypeInferrence, op.getLoc().front());
+        op.getQualCppClassName(), supportsResultTypeInferrence,
+        op.getLoc().front());
 
     // Ignore operations that have already been added.
     if (!inserted)
@@ -846,19 +848,21 @@ void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
   /// Attr constraints.
   for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Attr")) {
     if (!def->isAnonymous() && !curDeclScope->lookup(def->getName())) {
+      tblgen::Attribute constraint(def);
       decls.push_back(
           createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>(
-              tblgen::AttrConstraint(def),
-              convertLocToRange(def->getLoc().front()), attrTy));
+              constraint, convertLocToRange(def->getLoc().front()), attrTy,
+              constraint.getStorageType()));
     }
   }
   /// Type constraints.
   for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Type")) {
     if (!def->isAnonymous() && !curDeclScope->lookup(def->getName())) {
+      tblgen::TypeConstraint constraint(def);
       decls.push_back(
           createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
-              tblgen::TypeConstraint(def),
-              convertLocToRange(def->getLoc().front()), typeTy));
+              constraint, convertLocToRange(def->getLoc().front()), typeTy,
+              constraint.getCPPClassName()));
     }
   }
   /// Interfaces.
@@ -870,24 +874,26 @@ void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
       continue;
     SMRange loc = convertLocToRange(def->getLoc().front());
 
-    StringRef className = def->getValueAsString("cppClassName");
-    StringRef cppNamespace = def->getValueAsString("cppNamespace");
+    std::string cppClassName =
+        llvm::formatv("{0}::{1}", def->getValueAsString("cppNamespace"),
+                      def->getValueAsString("cppClassName"))
+            .str();
     std::string codeBlock =
-        llvm::formatv("return ::mlir::success(llvm::isa<{0}::{1}>(self));",
-                      cppNamespace, className)
+        llvm::formatv("return ::mlir::success(llvm::isa<{0}>(self));",
+                      cppClassName)
             .str();
 
     if (def->isSubClassOf("OpInterface")) {
       decls.push_back(createODSNativePDLLConstraintDecl<ast::OpConstraintDecl>(
-          name, codeBlock, loc, opTy));
+          name, codeBlock, loc, opTy, cppClassName));
     } else if (def->isSubClassOf("AttrInterface")) {
       decls.push_back(
           createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>(
-              name, codeBlock, loc, attrTy));
+              name, codeBlock, loc, attrTy, cppClassName));
     } else if (def->isSubClassOf("TypeInterface")) {
       decls.push_back(
           createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
-              name, codeBlock, loc, typeTy));
+              name, codeBlock, loc, typeTy, cppClassName));
     }
   }
 }
@@ -895,7 +901,8 @@ void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
 template <typename ConstraintT>
 ast::Decl *
 Parser::createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock,
-                                          SMRange loc, ast::Type type) {
+                                          SMRange loc, ast::Type type,
+                                          StringRef nativeType) {
   // Build the single input parameter.
   ast::DeclScope *argScope = pushDeclScope();
   auto *paramVar = ast::VariableDecl::create(
@@ -907,7 +914,7 @@ Parser::createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock,
   // Build the native constraint.
   auto *constraintDecl = ast::UserConstraintDecl::createNative(
       ctx, ast::Name::create(ctx, name, loc), paramVar,
-      /*results=*/llvm::None, codeBlock, ast::TupleType::get(ctx));
+      /*results=*/llvm::None, codeBlock, ast::TupleType::get(ctx), nativeType);
   curDeclScope->add(constraintDecl);
   return constraintDecl;
 }
@@ -915,7 +922,8 @@ Parser::createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock,
 template <typename ConstraintT>
 ast::Decl *
 Parser::createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
-                                          SMRange loc, ast::Type type) {
+                                          SMRange loc, ast::Type type,
+                                          StringRef nativeType) {
   // Format the condition template.
   tblgen::FmtContext fmtContext;
   fmtContext.withSelf("self");
@@ -924,7 +932,7 @@ Parser::createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
       &fmtContext);
 
   return createODSNativePDLLConstraintDecl<ConstraintT>(
-      constraint.getUniqueDefName(), codeBlock, loc, type);
+      constraint.getUniqueDefName(), codeBlock, loc, type, nativeType);
 }
 
 //===----------------------------------------------------------------------===//
@@ -2534,7 +2542,8 @@ LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref,
     constraintType = ast::AttributeType::get(ctx);
   } else if (const auto *cst =
                  dyn_cast<ast::OpConstraintDecl>(ref.constraint)) {
-    constraintType = ast::OperationType::get(ctx, cst->getName());
+    constraintType = ast::OperationType::get(
+        ctx, cst->getName(), lookupODSOperation(cst->getName()));
   } else if (isa<ast::TypeConstraintDecl>(ref.constraint)) {
     constraintType = typeTy;
   } else if (isa<ast::TypeRangeConstraintDecl>(ref.constraint)) {
@@ -2710,7 +2719,7 @@ FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
       return valueRangeTy;
 
     // Verify member access based on the operation type.
-    if (const ods::Operation *odsOp = lookupODSOperation(opType.getName())) {
+    if (const ods::Operation *odsOp = opType.getODSOperation()) {
       auto results = odsOp->getResults();
 
       // Handle indexed results.
@@ -2792,7 +2801,7 @@ FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
     checkOperationResultTypeInferrence(loc, *opNameRef, odsOp);
   }
 
-  return ast::OperationExpr::create(ctx, loc, name, operands, results,
+  return ast::OperationExpr::create(ctx, loc, odsOp, name, operands, results,
                                     attributes);
 }
 

diff  --git a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
index 215a0e6745241..cbce9f8ab0cd5 100644
--- a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
+++ b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
@@ -651,9 +651,7 @@ class LSPCodeCompleteContext : public CodeCompleteContext {
   }
 
   void codeCompleteOperationMemberAccess(ast::OperationType opType) final {
-    Optional<StringRef> opName = opType.getName();
-    const ods::Operation *odsOp =
-        opName ? odsContext.lookupOperation(*opName) : nullptr;
+    const ods::Operation *odsOp = opType.getODSOperation();
     if (!odsOp)
       return;
 

diff  --git a/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll b/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll
index 802958a3872f2..4dae177f10993 100644
--- a/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll
+++ b/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll
@@ -44,43 +44,22 @@ Pattern => erase op<test.op3>;
 // Check the generation of native constraints and rewrites.
 
 // CHECK:      static ::mlir::LogicalResult TestCstPDLFn(::mlir::PatternRewriter &rewriter,
-// CHECK-SAME:                                           ::llvm::ArrayRef<::mlir::PDLValue> values) {
-// CHECK:   ::mlir::Attribute attr = {};
-// CHECK:   if (values[0])
-// CHECK:     attr = values[0].cast<::mlir::Attribute>();
-// CHECK:   ::mlir::Operation * op = {};
-// CHECK:   if (values[1])
-// CHECK:     op = values[1].cast<::mlir::Operation *>();
-// CHECK:   ::mlir::Type type = {};
-// CHECK:   if (values[2])
-// CHECK:     type = values[2].cast<::mlir::Type>();
-// CHECK:   ::mlir::Value value = {};
-// CHECK:   if (values[3])
-// CHECK:     value = values[3].cast<::mlir::Value>();
-// CHECK:   ::mlir::TypeRange typeRange = {};
-// CHECK:   if (values[4])
-// CHECK:     typeRange = values[4].cast<::mlir::TypeRange>();
-// CHECK:   ::mlir::ValueRange valueRange = {};
-// CHECK:   if (values[5])
-// CHECK:     valueRange = values[5].cast<::mlir::ValueRange>();
-
-// CHECK:   return success();
+// CHECK-SAME:     ::mlir::Attribute attr, ::mlir::Operation * op, ::mlir::Type type,
+// CHECK-SAME:     ::mlir::Value value, ::mlir::TypeRange typeRange, ::mlir::ValueRange valueRange) {
+// CHECK-NEXT:   return success();
 // CHECK: }
 
 // CHECK-NOT: TestUnusedCst
 
-// CHECK: static void TestRewritePDLFn(::mlir::PatternRewriter &rewriter, ::mlir::PDLResultList &results,
-// CHECK-SAME:                         ::llvm::ArrayRef<::mlir::PDLValue> values) {
-// CHECK:   ::mlir::Attribute attr = {};
-// CHECK:   ::mlir::Operation * op = {};
-// CHECK:   ::mlir::Type type = {};
-// CHECK:   ::mlir::Value value = {};
-// CHECK:   ::mlir::TypeRange typeRange = {};
-// CHECK:   ::mlir::ValueRange valueRange = {};
-
+// CHECK: static void TestRewritePDLFn(::mlir::PatternRewriter &rewriter,
+// CHECK-SAME:     ::mlir::Attribute attr, ::mlir::Operation * op, ::mlir::Type type,
+// CHECK-SAME:     ::mlir::Value value, ::mlir::TypeRange typeRange, ::mlir::ValueRange valueRange) {
 // CHECK: foo;
 // CHECK: }
 
+// CHECK: static ::mlir::Attribute TestRewriteSinglePDLFn(::mlir::PatternRewriter &rewriter) {
+// CHECK: std::tuple<::mlir::Attribute, ::mlir::Type> TestRewriteTuplePDLFn(::mlir::PatternRewriter &rewriter) {
+
 // CHECK-NOT: TestUnusedRewrite
 
 // CHECK: struct TestCstAndRewrite : ::mlir::PDLPatternModule {
@@ -93,6 +72,8 @@ Constraint TestCst(attr: Attr, op: Op, type: Type, value: Value, typeRange: Type
 Constraint TestUnusedCst() [{ return success(); }];
 
 Rewrite TestRewrite(attr: Attr, op: Op, type: Type, value: Value, typeRange: TypeRange, valueRange: ValueRange) [{ foo; }];
+Rewrite TestRewriteSingle() -> Attr [{}];
+Rewrite TestRewriteTuple() -> (Attr, Type) [{}];
 Rewrite TestUnusedRewrite(op: Op) [{}];
 
 Pattern TestCstAndRewrite {
@@ -100,6 +81,8 @@ Pattern TestCstAndRewrite {
   TestCst(attr<"true">, root, type, operand, types, operands);
   rewrite root with {
     TestRewrite(attr<"true">, root, type, operand, types, operands);
+    TestRewriteSingle();
+    TestRewriteTuple();
     erase root;
   };
 }
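
Note on the emitted signatures: the result-type selection in the new generateConstraintOrRewrite (LogicalResult for constraints; void, a single native type, or std::tuple<...> for rewrites) can be sketched outside of MLIR as below. This is only an illustrative approximation, not the code from the commit: buildPDLFnSignature and its string-based parameters are hypothetical stand-ins for ast::CallableDecl, getNativeTypeName, and the raw_ostream the real generator writes to.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Hypothetical helper (not part of MLIR): builds the opening line of a
// generated native function from already-resolved native type names.
// `inputs` holds (native type, argument name) pairs.
std::string buildPDLFnSignature(
    const std::string &name, bool isConstraint,
    const std::vector<std::string> &resultTypes,
    const std::vector<std::pair<std::string, std::string>> &inputs) {
  std::string out = "static ";

  if (isConstraint) {
    // Constraints always return a LogicalResult.
    out += "::mlir::LogicalResult";
  } else if (resultTypes.empty()) {
    out += "void";
  } else if (resultTypes.size() == 1) {
    out += resultTypes.front();
  } else {
    // Multiple results are packed into a std::tuple.
    out += "std::tuple<";
    for (std::size_t i = 0; i < resultTypes.size(); ++i) {
      if (i != 0)
        out += ", ";
      out += resultTypes[i];
    }
    out += ">";
  }

  // The rewriter is always the first parameter; typed inputs follow it.
  out += " " + name + "PDLFn(::mlir::PatternRewriter &rewriter";
  for (const auto &input : inputs)
    out += ", " + input.first + " " + input.second;
  out += ") {";
  return out;
}
```

Under these assumptions, calling the helper with the name "TestRewriteTuple", isConstraint set to false, the result types ::mlir::Attribute and ::mlir::Type, and no inputs yields the same declaration line that the updated general.pdll test checks for TestRewriteTuplePDLFn.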


        

