[Mlir-commits] [mlir] 242762c - [mlir][pdl] Restructure how results are represented.

River Riddle llvmlistbot at llvm.org
Tue Mar 16 13:20:28 PDT 2021


Author: River Riddle
Date: 2021-03-16T13:20:18-07:00
New Revision: 242762c9a3313c8aea176ca76fb77adf8edf0907

URL: https://github.com/llvm/llvm-project/commit/242762c9a3313c8aea176ca76fb77adf8edf0907
DIFF: https://github.com/llvm/llvm-project/commit/242762c9a3313c8aea176ca76fb77adf8edf0907.diff

LOG: [mlir][pdl] Restructure how results are represented.

Up until now, results have been represented as additional results of a `pdl.operation`. This is fairly clunky, as it does not match the representation of the rest of the IR constructs (e.g. `pdl.operand`), and it is also not a viable representation for operations returned by `pdl.create_native`. This representation also makes it much harder to add support for variadic result groups, optional results, etc. To resolve some of these problems, and to simplify adding support for variable-length results, this revision extracts the representation of results out of `pdl.operation` in the form of a new `pdl.result` operation. This operation returns the result of an operation at a given index, e.g.:

```
%root = pdl.operation ...
%result = pdl.result 0 of %root
```

Differential Revision: https://reviews.llvm.org/D95719

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/PDL/IR/PDLDialect.td
    mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
    mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
    mlir/lib/Conversion/PDLToPDLInterp/Predicate.h
    mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
    mlir/lib/Dialect/PDL/IR/PDL.cpp
    mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
    mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir
    mlir/test/Dialect/PDL/invalid.mlir
    mlir/test/Dialect/PDL/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/PDL/IR/PDLDialect.td b/mlir/include/mlir/Dialect/PDL/IR/PDLDialect.td
index afdf50673ed4..1c9de16af358 100644
--- a/mlir/include/mlir/Dialect/PDL/IR/PDLDialect.td
+++ b/mlir/include/mlir/Dialect/PDL/IR/PDLDialect.td
@@ -48,7 +48,7 @@ def PDL_Dialect : Dialect {
 
       %resultType = pdl.type
       %inputOperand = pdl.operand
-      %root, %results = pdl.operation "foo.op"(%inputOperand) -> %resultType
+      %root = pdl.operation "foo.op"(%inputOperand) -> %resultType
       pdl.rewrite %root {
         pdl.replace %root with (%inputOperand)
       }

diff  --git a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
index 60590b1fcd01..76e4c5d022a4 100644
--- a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
+++ b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
@@ -177,7 +177,7 @@ def PDL_OperandOp : PDL_Op<"operand", [HasParent<"pdl::PatternOp">]> {
   let description = [{
     `pdl.operand` operations capture external operand edges into an operation
     node that originate from operations or block arguments not otherwise
-    specified within the pattern (e.g. via `pdl.operation`). These operations
+    specified within the pattern (e.g. via `pdl.result`). These operations
     define individual operands of a given operation. A `pdl.operand` may
     partially constrain an operand by specifying an expected value type
     (via a `pdl.type` operation).
@@ -223,8 +223,8 @@ def PDL_OperationOp
     `pdl.operation`s are composed of a name, and a set of attribute, operand,
     and result type values, that map to what those that would be on a
     constructed instance of that operation. The results of a `pdl.operation` are
-    a handle to the operation itself, and a handle to each of the operation
-    result values.
+    a handle to the operation itself. Handles to the results of the operation
+    can be extracted via `pdl.result`.
 
     When used within a matching context, the name of the operation may be
     omitted.
@@ -241,7 +241,7 @@ def PDL_OperationOp
 
     ```mlir
     // Define an instance of a `foo.op` operation.
-    %op, %results:4 = pdl.operation "foo.op"(%arg0, %arg1) {"attrA" = %attr0} -> %type, %type, %type, %type
+    %op = pdl.operation "foo.op"(%arg0, %arg1) {"attrA" = %attr0} -> %type, %type, %type, %type
     ```
   }];
 
@@ -250,8 +250,13 @@ def PDL_OperationOp
                        Variadic<PDL_Attribute>:$attributes,
                        StrArrayAttr:$attributeNames,
                        Variadic<PDL_Type>:$types);
-  let results = (outs PDL_Operation:$op,
-                      Variadic<PDL_Value>:$results);
+  let results = (outs PDL_Operation:$op);
+  let assemblyFormat = [{
+    ($name^)? (`(` $operands^ `)`)?
+    custom<OperationOpAttributes>($attributes, $attributeNames)
+    (`->` $types^)? attr-dict
+  }];
+
   let builders = [
     OpBuilder<(ins CArg<"Optional<StringRef>", "llvm::None">:$name,
       CArg<"ValueRange", "llvm::None">:$operandValues,
@@ -259,10 +264,9 @@ def PDL_OperationOp
       CArg<"ValueRange", "llvm::None">:$attrValues,
       CArg<"ValueRange", "llvm::None">:$resultTypes), [{
-      auto nameAttr = name ? StringAttr() : $_builder.getStringAttr(*name);
+      auto nameAttr = name ? $_builder.getStringAttr(*name) : StringAttr();
-      build($_builder, $_state, $_builder.getType<OperationType>(), {}, nameAttr,
+      build($_builder, $_state, $_builder.getType<OperationType>(), nameAttr,
             operandValues, attrValues, $_builder.getStrArrayAttr(attrNames),
             resultTypes);
-      $_state.types.append(resultTypes.size(), $_builder.getType<ValueType>());
     }]>,
   ];
   let extraClassDeclaration = [{
@@ -293,7 +297,7 @@ def PDL_PatternOp : PDL_Op<"pattern", [IsolatedFromAbove, Symbol]> {
     pdl.pattern : benefit(1) {
       %resultType = pdl.type
       %inputOperand = pdl.operand
-      %root, %results = pdl.operation "foo.op"(%inputOperand) -> (%resultType)
+      %root = pdl.operation "foo.op"(%inputOperand) -> (%resultType)
       pdl.rewrite %root {
         pdl.replace %root with (%inputOperand)
       }
@@ -368,6 +372,39 @@ def PDL_ReplaceOp : PDL_Op<"replace", [
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// pdl::ResultOp
+//===----------------------------------------------------------------------===//
+
+def PDL_ResultOp : PDL_Op<"result"> {
+  let summary = "Extract a result from an operation";
+  let description = [{
+    `pdl.result` operations extract result edges from an operation node within
+    a pattern or rewrite region. The provided index is zero-based, and
+    represents the concrete result to extract, i.e. this is not the result index
+    as defined by the ODS definition of the operation.
+
+    Example:
+
+    ```mlir
+    // Extract a result:
+    %operation = pdl.operation ...
+    %result = pdl.result 1 of %operation
+
+    // Imagine the following IR being matched:
+    %result_0, %result_1 = foo.op ...
+
+    // If the example pattern snippet above were matched against `foo.op` in
+    // the IR snippet, `%result` would correspond to `%result_1`.
+    ```
+  }];
+
+  let arguments = (ins PDL_Operation:$parent, I32Attr:$index);
+  let results = (outs PDL_Value:$val);
+  let assemblyFormat = "$index `of` $parent attr-dict";
+  let verifier = ?;
+}
+
 //===----------------------------------------------------------------------===//
 // pdl::RewriteOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
index a225699e89f7..3368ceb9be88 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
@@ -85,6 +85,9 @@ struct PatternLowering {
   void generateRewriter(pdl::ReplaceOp replaceOp,
                         DenseMap<Value, Value> &rewriteValues,
                         function_ref<Value(Value)> mapRewriteValue);
+  void generateRewriter(pdl::ResultOp resultOp,
+                        DenseMap<Value, Value> &rewriteValues,
+                        function_ref<Value(Value)> mapRewriteValue);
   void generateRewriter(pdl::TypeOp typeOp,
                         DenseMap<Value, Value> &rewriteValues,
                         function_ref<Value(Value)> mapRewriteValue);
@@ -457,9 +460,10 @@ SymbolRefAttr PatternLowering::generateRewriter(
     for (Operation &rewriteOp : *rewriter.getBody()) {
       llvm::TypeSwitch<Operation *>(&rewriteOp)
           .Case<pdl::AttributeOp, pdl::CreateNativeOp, pdl::EraseOp,
-                pdl::OperationOp, pdl::ReplaceOp, pdl::TypeOp>([&](auto op) {
-            this->generateRewriter(op, rewriteValues, mapRewriteValue);
-          });
+                pdl::OperationOp, pdl::ReplaceOp, pdl::ResultOp, pdl::TypeOp>(
+              [&](auto op) {
+                this->generateRewriter(op, rewriteValues, mapRewriteValue);
+              });
     }
   }
 
@@ -511,17 +515,15 @@ void PatternLowering::generateRewriter(
       operationOp.attributeNames());
   rewriteValues[operationOp.op()] = createdOp;
 
-  // Make all of the new operation results available.
-  OperandRange resultTypes = operationOp.types();
-  for (auto it : llvm::enumerate(operationOp.results())) {
+  // Generate accesses for any results that have their types constrained.
+  for (auto it : llvm::enumerate(operationOp.types())) {
+    Value &type = rewriteValues[it.value()];
+    if (type)
+      continue;
+
     Value getResultVal = builder.create<pdl_interp::GetResultOp>(
         loc, builder.getType<pdl::ValueType>(), createdOp, it.index());
-    rewriteValues[it.value()] = getResultVal;
-
-    // If any of the types have not been resolved, make those available as well.
-    Value &type = rewriteValues[resultTypes[it.index()]];
-    if (!type)
-      type = builder.create<pdl_interp::GetValueTypeOp>(loc, getResultVal);
+    type = builder.create<pdl_interp::GetValueTypeOp>(loc, getResultVal);
   }
 }
 
@@ -540,29 +542,41 @@ void PatternLowering::generateRewriter(
 void PatternLowering::generateRewriter(
     pdl::ReplaceOp replaceOp, DenseMap<Value, Value> &rewriteValues,
     function_ref<Value(Value)> mapRewriteValue) {
+  SmallVector<Value, 4> replOperands;
+
   // If the replacement was another operation, get its results. `pdl` allows
   // for using an operation for simplicitly, but the interpreter isn't as
   // user facing.
-  ValueRange origOperands;
-  if (Value replOp = replaceOp.replOperation())
-    origOperands = cast<pdl::OperationOp>(replOp.getDefiningOp()).results();
-  else
-    origOperands = replaceOp.replValues();
+  if (Value replOp = replaceOp.replOperation()) {
+    pdl::OperationOp op = cast<pdl::OperationOp>(replOp.getDefiningOp());
+    for (unsigned i = 0, e = op.types().size(); i < e; ++i)
+      replOperands.push_back(builder.create<pdl_interp::GetResultOp>(
+          replOp.getLoc(), builder.getType<pdl::ValueType>(),
+          mapRewriteValue(replOp), i));
+  } else {
+    for (Value operand : replaceOp.replValues())
+      replOperands.push_back(mapRewriteValue(operand));
+  }
 
   // If there are no replacement values, just create an erase instead.
-  if (origOperands.empty()) {
+  if (replOperands.empty()) {
     builder.create<pdl_interp::EraseOp>(replaceOp.getLoc(),
                                         mapRewriteValue(replaceOp.operation()));
     return;
   }
 
-  SmallVector<Value, 4> replOperands;
-  for (Value operand : origOperands)
-    replOperands.push_back(mapRewriteValue(operand));
   builder.create<pdl_interp::ReplaceOp>(
       replaceOp.getLoc(), mapRewriteValue(replaceOp.operation()), replOperands);
 }
 
+void PatternLowering::generateRewriter(
+    pdl::ResultOp resultOp, DenseMap<Value, Value> &rewriteValues,
+    function_ref<Value(Value)> mapRewriteValue) {
+  rewriteValues[resultOp] = builder.create<pdl_interp::GetResultOp>(
+      resultOp.getLoc(), builder.getType<pdl::ValueType>(),
+      mapRewriteValue(resultOp.parent()), resultOp.index());
+}
+
 void PatternLowering::generateRewriter(
     pdl::TypeOp typeOp, DenseMap<Value, Value> &rewriteValues,
     function_ref<Value(Value)> mapRewriteValue) {
@@ -602,8 +616,8 @@ void PatternLowering::generateOperationResultTypeRewriter(
   bool hasTypeInference = op.hasTypeInference();
   auto resultTypeValues = op.types();
   types.reserve(resultTypeValues.size());
-  for (auto it : llvm::enumerate(op.results())) {
-    Value result = it.value(), resultType = resultTypeValues[it.index()];
+  for (auto it : llvm::enumerate(resultTypeValues)) {
+    Value resultType = it.value();
 
     // Check for an already translated value.
     if (Value existingRewriteValue = rewriteValues.lookup(resultType)) {
@@ -633,16 +647,11 @@ void PatternLowering::generateOperationResultTypeRewriter(
         if ((replacedOp = getReplacedOperationFrom(use)))
           break;
       fullReplacedOperation = replacedOp;
+      // Check the pointer itself: the Optional above is always engaged here.
+      assert(replacedOp && "expected replaced op to infer a result type from");
     } else {
       replacedOp = fullReplacedOperation.getValue();
     }
-    // Infer from the result, as there was no fully replaced op.
-    if (!replacedOp) {
-      for (OpOperand &use : result.getUses())
-        if ((replacedOp = getReplacedOperationFrom(use)))
-          break;
-      assert(replacedOp && "expected replaced op to infer a result type from");
-    }
 
     auto replOpOp = cast<pdl::OperationOp>(replacedOp);
     types.push_back(mapRewriteValue(replOpOp.types()[it.index()]));

diff  --git a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h
index b3919609a640..4d5c909465da 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h
+++ b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h
@@ -433,7 +433,7 @@ class PredicateBuilder {
   Position *getRoot() { return OperationPosition::getRoot(uniquer); }
 
   /// Returns the parent position defining the value held by the given operand.
-  Position *getParent(OperandPosition *p) {
+  OperationPosition *getParent(OperandPosition *p) {
     std::vector<unsigned> index = p->getIndex();
     index.push_back(p->getOperandNumber());
     return OperationPosition::get(uniquer, index);

diff  --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
index 22794aa4d991..0db35f050515 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
 using namespace mlir::pdl_to_pdl_interp;
@@ -20,151 +21,181 @@ using namespace mlir::pdl_to_pdl_interp;
 // Predicate List Building
 //===----------------------------------------------------------------------===//
 
+static void getTreePredicates(std::vector<PositionalPredicate> &predList,
+                              Value val, PredicateBuilder &builder,
+                              DenseMap<Value, Position *> &inputs,
+                              Position *pos);
+
 /// Compares the depths of two positions.
 static bool comparePosDepth(Position *lhs, Position *rhs) {
   return lhs->getIndex().size() < rhs->getIndex().size();
 }
 
-/// Collect the tree predicates anchored at the given value.
 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
                               Value val, PredicateBuilder &builder,
                               DenseMap<Value, Position *> &inputs,
-                              Position *pos) {
-  // Make sure this input value is accessible to the rewrite.
-  auto it = inputs.try_emplace(val, pos);
+                              AttributePosition *pos) {
+  assert(val.getType().isa<pdl::AttributeType>() && "expected attribute type");
+  pdl::AttributeOp attr = cast<pdl::AttributeOp>(val.getDefiningOp());
+  predList.emplace_back(pos, builder.getIsNotNull());
+
+  // If the attribute has a type or value, add a constraint.
+  if (Value type = attr.type())
+    getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
+  else if (Attribute value = attr.valueAttr())
+    predList.emplace_back(pos, builder.getAttributeConstraint(value));
+}
 
-  // If this is an input value that has been visited in the tree, add a
-  // constraint to ensure that both instances refer to the same value.
-  if (!it.second &&
-      isa<pdl::AttributeOp, pdl::OperandOp, pdl::TypeOp>(val.getDefiningOp())) {
-    auto minMaxPositions = std::minmax(pos, it.first->second, comparePosDepth);
-    predList.emplace_back(minMaxPositions.second,
-                          builder.getEqualTo(minMaxPositions.first));
-    return;
-  }
+static void getTreePredicates(std::vector<PositionalPredicate> &predList,
+                              Value val, PredicateBuilder &builder,
+                              DenseMap<Value, Position *> &inputs,
+                              OperandPosition *pos) {
+  assert(val.getType().isa<pdl::ValueType>() && "expected value type");
 
-  // Check for a per-position predicate to apply.
-  switch (pos->getKind()) {
-  case Predicates::AttributePos: {
-    assert(val.getType().isa<pdl::AttributeType>() &&
-           "expected attribute type");
-    pdl::AttributeOp attr = cast<pdl::AttributeOp>(val.getDefiningOp());
-    predList.emplace_back(pos, builder.getIsNotNull());
+  // Prevent traversal into a null value.
+  predList.emplace_back(pos, builder.getIsNotNull());
 
-    // If the attribute has a type, add a type constraint.
-    if (Value type = attr.type()) {
+  // If this is a typed operand, add a type constraint.
+  if (auto in = val.getDefiningOp<pdl::OperandOp>()) {
+    if (Value type = in.type())
       getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
 
-      // Check for a constant value of the attribute.
-    } else if (Optional<Attribute> value = attr.value()) {
-      predList.emplace_back(pos, builder.getAttributeConstraint(*value));
-    }
-    break;
+    // Otherwise, recurse into a result node.
+  } else if (auto resultOp = val.getDefiningOp<pdl::ResultOp>()) {
+    OperationPosition *parentPos = builder.getParent(pos);
+    Position *resultPos = builder.getResult(parentPos, resultOp.index());
+    predList.emplace_back(parentPos, builder.getIsNotNull());
+    predList.emplace_back(resultPos, builder.getEqualTo(pos));
+    getTreePredicates(predList, resultOp.parent(), builder, inputs, parentPos);
   }
-  case Predicates::OperandPos: {
-    assert(val.getType().isa<pdl::ValueType>() && "expected value type");
+}
+
+static void getTreePredicates(std::vector<PositionalPredicate> &predList,
+                              Value val, PredicateBuilder &builder,
+                              DenseMap<Value, Position *> &inputs,
+                              OperationPosition *pos) {
+  assert(val.getType().isa<pdl::OperationType>() && "expected operation");
+  pdl::OperationOp op = cast<pdl::OperationOp>(val.getDefiningOp());
+  OperationPosition *opPos = cast<OperationPosition>(pos);
 
-    // Prevent traversal into a null value.
+  // Ensure getDefiningOp returns a non-null operation.
+  if (!opPos->isRoot())
     predList.emplace_back(pos, builder.getIsNotNull());
 
-    // If this is a typed operand, add a type constraint.
-    if (auto in = val.getDefiningOp<pdl::OperandOp>()) {
-      if (Value type = in.type()) {
-        getTreePredicates(predList, type, builder, inputs,
-                          builder.getType(pos));
-      }
-
-      // Otherwise, recurse into the parent node.
-    } else if (auto parentOp = val.getDefiningOp<pdl::OperationOp>()) {
-      getTreePredicates(predList, parentOp.op(), builder, inputs,
-                        builder.getParent(cast<OperandPosition>(pos)));
-    }
-    break;
+  // Check that this is the correct root operation.
+  if (Optional<StringRef> opName = op.name())
+    predList.emplace_back(pos, builder.getOperationName(*opName));
+
+  // Check that the operation has the proper number of operands and results.
+  OperandRange operands = op.operands();
+  OperandRange types = op.types();
+  predList.emplace_back(pos, builder.getOperandCount(operands.size()));
+  predList.emplace_back(pos, builder.getResultCount(types.size()));
+
+  // Recurse into any attributes, operands, or results.
+  for (auto it : llvm::zip(op.attributeNames(), op.attributes())) {
+    getTreePredicates(
+        predList, std::get<1>(it), builder, inputs,
+        builder.getAttribute(opPos,
+                             std::get<0>(it).cast<StringAttr>().getValue()));
   }
-  case Predicates::OperationPos: {
-    assert(val.getType().isa<pdl::OperationType>() && "expected operation");
-    pdl::OperationOp op = cast<pdl::OperationOp>(val.getDefiningOp());
-    OperationPosition *opPos = cast<OperationPosition>(pos);
-
-    // Ensure getDefiningOp returns a non-null operation.
-    if (!opPos->isRoot())
-      predList.emplace_back(pos, builder.getIsNotNull());
-
-    // Check that this is the correct root operation.
-    if (Optional<StringRef> opName = op.name())
-      predList.emplace_back(pos, builder.getOperationName(*opName));
-
-    // Check that the operation has the proper number of operands and results.
-    OperandRange operands = op.operands();
-    ResultRange results = op.results();
-    predList.emplace_back(pos, builder.getOperandCount(operands.size()));
-    predList.emplace_back(pos, builder.getResultCount(results.size()));
-
-    // Recurse into any attributes, operands, or results.
-    for (auto it : llvm::zip(op.attributeNames(), op.attributes())) {
-      getTreePredicates(
-          predList, std::get<1>(it), builder, inputs,
-          builder.getAttribute(opPos,
-                               std::get<0>(it).cast<StringAttr>().getValue()));
-    }
-    for (auto operandIt : llvm::enumerate(operands))
-      getTreePredicates(predList, operandIt.value(), builder, inputs,
-                        builder.getOperand(opPos, operandIt.index()));
-
-    // Only recurse into results that are not referenced in the source tree.
-    for (auto resultIt : llvm::enumerate(results)) {
-      getTreePredicates(predList, resultIt.value(), builder, inputs,
-                        builder.getResult(opPos, resultIt.index()));
-    }
-    break;
+  for (auto operandIt : llvm::enumerate(operands)) {
+    getTreePredicates(predList, operandIt.value(), builder, inputs,
+                      builder.getOperand(opPos, operandIt.index()));
+  }
+  for (auto &resultIt : llvm::enumerate(types)) {
+    auto *resultPos = builder.getResult(pos, resultIt.index());
+    predList.emplace_back(resultPos, builder.getIsNotNull());
+    getTreePredicates(predList, resultIt.value(), builder, inputs,
+                      builder.getType(resultPos));
   }
-  case Predicates::ResultPos: {
-    assert(val.getType().isa<pdl::ValueType>() && "expected value type");
-    pdl::OperationOp parentOp = cast<pdl::OperationOp>(val.getDefiningOp());
+}
 
-    // Prevent traversing a null value.
-    predList.emplace_back(pos, builder.getIsNotNull());
+static void getTreePredicates(std::vector<PositionalPredicate> &predList,
+                              Value val, PredicateBuilder &builder,
+                              DenseMap<Value, Position *> &inputs,
+                              TypePosition *pos) {
+  assert(val.getType().isa<pdl::TypeType>() && "expected value type");
+  pdl::TypeOp typeOp = cast<pdl::TypeOp>(val.getDefiningOp());
 
-    // Traverse the type constraint.
-    unsigned resultNo = cast<ResultPosition>(pos)->getResultNumber();
-    getTreePredicates(predList, parentOp.types()[resultNo], builder, inputs,
-                      builder.getType(pos));
-    break;
-  }
-  case Predicates::TypePos: {
-    assert(val.getType().isa<pdl::TypeType>() && "expected value type");
-    pdl::TypeOp typeOp = cast<pdl::TypeOp>(val.getDefiningOp());
-
-    // Check for a constraint on a constant type.
-    if (Optional<Type> type = typeOp.type())
-      predList.emplace_back(pos, builder.getTypeConstraint(*type));
-    break;
-  }
-  default:
-    llvm_unreachable("unknown position kind");
+  // Check for a constraint on a constant type.
+  if (Optional<Type> type = typeOp.type())
+    predList.emplace_back(pos, builder.getTypeConstraint(*type));
+}
+
+/// Collect the tree predicates anchored at the given value.
+static void getTreePredicates(std::vector<PositionalPredicate> &predList,
+                              Value val, PredicateBuilder &builder,
+                              DenseMap<Value, Position *> &inputs,
+                              Position *pos) {
+  // Make sure this input value is accessible to the rewrite.
+  auto it = inputs.try_emplace(val, pos);
+  if (!it.second) {
+    // If this is an input value that has been visited in the tree, add a
+    // constraint to ensure that both instances refer to the same value.
+    if (isa<pdl::AttributeOp, pdl::OperandOp, pdl::OperationOp, pdl::TypeOp>(
+            val.getDefiningOp())) {
+      auto minMaxPositions =
+          std::minmax(pos, it.first->second, comparePosDepth);
+      predList.emplace_back(minMaxPositions.second,
+                            builder.getEqualTo(minMaxPositions.first));
+    }
+    return;
   }
+
+  TypeSwitch<Position *>(pos)
+      .Case<AttributePosition, OperandPosition, OperationPosition,
+            TypePosition>([&](auto *derivedPos) {
+        getTreePredicates(predList, val, builder, inputs, derivedPos);
+      })
+      .Default([](auto *) { llvm_unreachable("unexpected position kind"); });
 }
 
 /// Collect all of the predicates related to constraints within the given
 /// pattern operation.
-static void collectConstraintPredicates(
-    pdl::PatternOp pattern, std::vector<PositionalPredicate> &predList,
-    PredicateBuilder &builder, DenseMap<Value, Position *> &inputs) {
-  for (auto op : pattern.body().getOps<pdl::ApplyConstraintOp>()) {
-    OperandRange arguments = op.args();
-    ArrayAttr parameters = op.constParamsAttr();
-
-    std::vector<Position *> allPositions;
-    allPositions.reserve(arguments.size());
-    for (Value arg : arguments)
-      allPositions.push_back(inputs.lookup(arg));
-
-    // Push the constraint to the furthest position.
-    Position *pos = *std::max_element(allPositions.begin(), allPositions.end(),
-                                      comparePosDepth);
-    PredicateBuilder::Predicate pred =
-        builder.getConstraint(op.name(), std::move(allPositions), parameters);
-    predList.emplace_back(pos, pred);
+static void getConstraintPredicates(pdl::ApplyConstraintOp op,
+                                    std::vector<PositionalPredicate> &predList,
+                                    PredicateBuilder &builder,
+                                    DenseMap<Value, Position *> &inputs) {
+  OperandRange arguments = op.args();
+  ArrayAttr parameters = op.constParamsAttr();
+
+  std::vector<Position *> allPositions;
+  allPositions.reserve(arguments.size());
+  for (Value arg : arguments)
+    allPositions.push_back(inputs.lookup(arg));
+
+  // Push the constraint to the furthest position.
+  Position *pos = *std::max_element(allPositions.begin(), allPositions.end(),
+                                    comparePosDepth);
+  PredicateBuilder::Predicate pred =
+      builder.getConstraint(op.name(), std::move(allPositions), parameters);
+  predList.emplace_back(pos, pred);
+}
+
+static void getResultPredicates(pdl::ResultOp op,
+                                std::vector<PositionalPredicate> &predList,
+                                PredicateBuilder &builder,
+                                DenseMap<Value, Position *> &inputs) {
+  Position *&resultPos = inputs[op];
+  if (resultPos)
+    return;
+  auto *parentPos = cast<OperationPosition>(inputs.lookup(op.parent()));
+  resultPos = builder.getResult(parentPos, op.index());
+  predList.emplace_back(resultPos, builder.getIsNotNull());
+}
+
+/// Collect all of the predicates that cannot be determined via walking the
+/// tree.
+static void getNonTreePredicates(pdl::PatternOp pattern,
+                                 std::vector<PositionalPredicate> &predList,
+                                 PredicateBuilder &builder,
+                                 DenseMap<Value, Position *> &inputs) {
+  for (Operation &op : pattern.body().getOps()) {
+    if (auto constraintOp = dyn_cast<pdl::ApplyConstraintOp>(&op))
+      getConstraintPredicates(constraintOp, predList, builder, inputs);
+    else if (auto resultOp = dyn_cast<pdl::ResultOp>(&op))
+      getResultPredicates(resultOp, predList, builder, inputs);
   }
 }
 
@@ -176,7 +207,7 @@ static void buildPredicateList(pdl::PatternOp pattern,
                                DenseMap<Value, Position *> &valueToPosition) {
   getTreePredicates(predList, pattern.getRewriter().root(), builder,
                     valueToPosition, builder.getRoot());
-  collectConstraintPredicates(pattern, predList, builder, valueToPosition);
+  getNonTreePredicates(pattern, predList, builder, valueToPosition);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/PDL/IR/PDL.cpp b/mlir/lib/Dialect/PDL/IR/PDL.cpp
index beb43d7072f2..d35aab41ba8f 100644
--- a/mlir/lib/Dialect/PDL/IR/PDL.cpp
+++ b/mlir/lib/Dialect/PDL/IR/PDL.cpp
@@ -28,21 +28,36 @@ void PDLDialect::initialize() {
   registerTypes();
 }
 
+//===----------------------------------------------------------------------===//
+// PDL Operations
+//===----------------------------------------------------------------------===//
+
 /// Returns true if the given operation is used by a "binding" pdl operation
 /// within the main matcher body of a `pdl.pattern`.
+static bool hasBindingUseInMatcher(Operation *op, Block *matcherBlock) {
+  for (Operation *user : op->getUsers()) {
+    if (user->getBlock() != matcherBlock)
+      continue;
+    if (isa<AttributeOp, OperandOp, OperationOp, RewriteOp>(user))
+      return true;
+    // A result by itself is not binding, it must also be bound.
+    if (isa<ResultOp>(user) && hasBindingUseInMatcher(user, matcherBlock))
+      return true;
+  }
+  return false;
+}
+
+/// Returns success if the given operation is used by a "binding" pdl operation
+/// within the main matcher body of a `pdl.pattern`. On failure, emits an error
+/// with the given context message.
 static LogicalResult
 verifyHasBindingUseInMatcher(Operation *op,
                              StringRef bindableContextStr = "`pdl.operation`") {
   // If the pattern is not a pattern, there is nothing to do.
   if (!isa<PatternOp>(op->getParentOp()))
     return success();
-  Block *matcherBlock = op->getBlock();
-  for (Operation *user : op->getUsers()) {
-    if (user->getBlock() != matcherBlock)
-      continue;
-    if (isa<AttributeOp, OperandOp, OperationOp, RewriteOp>(user))
-      return success();
-  }
+  if (hasBindingUseInMatcher(op, op->getBlock()))
+    return success();
   return op->emitOpError()
          << "expected a bindable (i.e. " << bindableContextStr
          << ") user when defined in the matcher body of a `pdl.pattern`";
@@ -86,37 +101,12 @@ static LogicalResult verify(OperandOp op) {
 // pdl::OperationOp
 //===----------------------------------------------------------------------===//
 
-static ParseResult parseOperationOp(OpAsmParser &p, OperationState &state) {
+static ParseResult parseOperationOpAttributes(
+    OpAsmParser &p, SmallVectorImpl<OpAsmParser::OperandType> &attrOperands,
+    ArrayAttr &attrNamesAttr) {
   Builder &builder = p.getBuilder();
-
-  // Parse the optional operation name.
-  bool startsWithOperands = succeeded(p.parseOptionalLParen());
-  bool startsWithAttributes =
-      !startsWithOperands && succeeded(p.parseOptionalLBrace());
-  bool startsWithOpName = false;
-  if (!startsWithAttributes && !startsWithOperands) {
-    StringAttr opName;
-    OptionalParseResult opNameResult =
-        p.parseOptionalAttribute(opName, "name", state.attributes);
-    startsWithOpName = opNameResult.hasValue();
-    if (startsWithOpName && failed(*opNameResult))
-      return failure();
-  }
-
-  // Parse the operands.
-  SmallVector<OpAsmParser::OperandType, 4> operands;
-  if (startsWithOperands ||
-      (!startsWithAttributes && succeeded(p.parseOptionalLParen()))) {
-    if (p.parseOperandList(operands) || p.parseRParen() ||
-        p.resolveOperands(operands, builder.getType<ValueType>(),
-                          state.operands))
-      return failure();
-  }
-
-  // Parse the attributes.
   SmallVector<Attribute, 4> attrNames;
-  if (startsWithAttributes || succeeded(p.parseOptionalLBrace())) {
-    SmallVector<OpAsmParser::OperandType, 4> attrOps;
+  if (succeeded(p.parseOptionalLBrace())) {
     do {
       StringAttr nameAttr;
       OpAsmParser::OperandType operand;
@@ -124,68 +114,29 @@ static ParseResult parseOperationOp(OpAsmParser &p, OperationState &state) {
           p.parseOperand(operand))
         return failure();
       attrNames.push_back(nameAttr);
-      attrOps.push_back(operand);
+      attrOperands.push_back(operand);
     } while (succeeded(p.parseOptionalComma()));
-
-    if (p.parseRBrace() ||
-        p.resolveOperands(attrOps, builder.getType<AttributeType>(),
-                          state.operands))
-      return failure();
-  }
-  state.addAttribute("attributeNames", builder.getArrayAttr(attrNames));
-  state.addTypes(builder.getType<OperationType>());
-
-  // Parse the result types.
-  SmallVector<OpAsmParser::OperandType, 4> opResultTypes;
-  if (succeeded(p.parseOptionalArrow())) {
-    if (p.parseOperandList(opResultTypes) ||
-        p.resolveOperands(opResultTypes, builder.getType<TypeType>(),
-                          state.operands))
+    if (p.parseRBrace())
       return failure();
-    state.types.append(opResultTypes.size(), builder.getType<ValueType>());
   }
-
-  if (p.parseOptionalAttrDict(state.attributes))
-    return failure();
-
-  int32_t operandSegmentSizes[] = {static_cast<int32_t>(operands.size()),
-                                   static_cast<int32_t>(attrNames.size()),
-                                   static_cast<int32_t>(opResultTypes.size())};
-  state.addAttribute("operand_segment_sizes",
-                     builder.getI32VectorAttr(operandSegmentSizes));
+  attrNamesAttr = builder.getArrayAttr(attrNames);
   return success();
 }
 
-static void print(OpAsmPrinter &p, OperationOp op) {
-  p << "pdl.operation ";
-  if (Optional<StringRef> name = op.name())
-    p << '"' << *name << '"';
-
-  auto operandValues = op.operands();
-  if (!operandValues.empty())
-    p << '(' << operandValues << ')';
-
-  // Emit the optional attributes.
-  ArrayAttr attrNames = op.attributeNames();
-  if (!attrNames.empty()) {
-    Operation::operand_range attrArgs = op.attributes();
-    p << " {";
-    interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
-                    [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; });
-    p << '}';
-  }
-
-  // Print the result type constraints of the operation.
-  if (!op.results().empty())
-    p << " -> " << op.types();
-  p.printOptionalAttrDict(op->getAttrs(),
-                          {"attributeNames", "name", "operand_segment_sizes"});
+static void printOperationOpAttributes(OpAsmPrinter &p, OperationOp op,
+                                       OperandRange attrArgs,
+                                       ArrayAttr attrNames) {
+  if (attrNames.empty())
+    return;
+  p << " {";
+  interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
+                  [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; });
+  p << '}';
 }
 
 /// Verifies that the result types of this operation, defined within a
 /// `pdl.rewrite`, can be inferred.
 static LogicalResult verifyResultTypesAreInferrable(OperationOp op,
-                                                    ResultRange opResults,
                                                     OperandRange resultTypes) {
   // Functor that returns if the given use can be used to infer a type.
   Block *rewriterBlock = op->getBlock();
@@ -207,8 +158,8 @@ static LogicalResult verifyResultTypesAreInferrable(OperationOp op,
     return success();
 
   // Otherwise, make sure each of the types can be inferred.
-  for (int i : llvm::seq<int>(0, opResults.size())) {
-    Operation *resultTypeOp = resultTypes[i].getDefiningOp();
+  for (auto it : llvm::enumerate(resultTypes)) {
+    Operation *resultTypeOp = it.value().getDefiningOp();
     assert(resultTypeOp && "expected valid result type operation");
 
     // If the op was defined by a `create_native`, it is guaranteed to be
@@ -229,14 +180,11 @@ static LogicalResult verifyResultTypesAreInferrable(OperationOp op,
     if (llvm::any_of(typeOp.getResult().getUsers(), constrainsInputOp))
       continue;
 
-    // Otherwise, check to see if any uses of the result can infer the type.
-    if (llvm::any_of(opResults[i].getUses(), canInferTypeFromUse))
-      continue;
     return op
         .emitOpError("must have inferable or constrained result types when "
                      "nested within `pdl.rewrite`")
         .attachNote()
-        .append("result type #", i, " was not constrained");
+        .append("result type #", it.index(), " was not constrained");
   }
   return success();
 }
@@ -256,19 +204,10 @@ static LogicalResult verify(OperationOp op) {
            << " values";
   }
 
-  OperandRange resultTypes = op.types();
-  auto opResults = op.results();
-  if (resultTypes.size() != opResults.size()) {
-    return op.emitOpError() << "expected the same number of result values and "
-                               "result type constraints, got "
-                            << opResults.size() << " results and "
-                            << resultTypes.size() << " constraints";
-  }
-
   // If the operation is within a rewrite body and doesn't have type inference,
   // ensure that the result types can be resolved.
   if (isWithinRewrite && !op.hasTypeInference()) {
-    if (failed(verifyResultTypesAreInferrable(op, opResults, resultTypes)))
+    if (failed(verifyResultTypesAreInferrable(op, op.types())))
       return failure();
   }
 
@@ -341,37 +280,9 @@ Optional<StringRef> PatternOp::getRootKind() {
 //===----------------------------------------------------------------------===//
 
 static LogicalResult verify(ReplaceOp op) {
-  auto sourceOp = cast<OperationOp>(op.operation().getDefiningOp());
-  auto sourceOpResults = sourceOp.results();
-  auto replValues = op.replValues();
-
-  if (Value replOpVal = op.replOperation()) {
-    auto replOp = cast<OperationOp>(replOpVal.getDefiningOp());
-    auto replOpResults = replOp.results();
-    if (sourceOpResults.size() != replOpResults.size()) {
-      return op.emitOpError()
-             << "expected source operation to have the same number of results "
-                "as the replacement operation, replacement operation provided "
-             << replOpResults.size() << " but expected "
-             << sourceOpResults.size();
-    }
-
-    if (!replValues.empty()) {
-      return op.emitOpError() << "expected no replacement values to be provided"
-                                 " when the replacement operation is present";
-    }
-
-    return success();
-  }
-
-  if (sourceOpResults.size() != replValues.size()) {
-    return op.emitOpError()
-           << "expected source operation to have the same number of results "
-              "as the provided replacement values, found "
-           << replValues.size() << " replacement values but expected "
-           << sourceOpResults.size();
-  }
-
+  if (op.replOperation() && !op.replValues().empty())
+    return op.emitOpError() << "expected no replacement values to be provided"
+                               " when the replacement operation is present";
   return success();
 }
 

diff  --git a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
index 9d87ba5a21f0..c856ab5c9f6f 100644
--- a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
+++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
@@ -63,15 +63,16 @@ module @constraints {
   // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
   // CHECK-DAG:   %[[INPUT:.*]] = pdl_interp.get_operand 0 of %[[ROOT]]
   // CHECK-DAG:   %[[INPUT1:.*]] = pdl_interp.get_operand 1 of %[[ROOT]]
-  // CHECK:       pdl_interp.apply_constraint "multi_constraint" [true](%[[INPUT]], %[[INPUT1]] : !pdl.value, !pdl.value)
+  // CHECK-DAG:   %[[RESULT:.*]] = pdl_interp.get_result 0 of %[[ROOT]]
+  // CHECK:       pdl_interp.apply_constraint "multi_constraint" [true](%[[INPUT]], %[[INPUT1]], %[[RESULT]]
 
   pdl.pattern : benefit(1) {
     %input0 = pdl.operand
     %input1 = pdl.operand
-
-    pdl.apply_constraint "multi_constraint"[true](%input0, %input1 : !pdl.value, !pdl.value)
-
     %root = pdl.operation(%input0, %input1)
+    %result0 = pdl.result 0 of %root
+
+    pdl.apply_constraint "multi_constraint"[true](%input0, %input1, %result0 : !pdl.value, !pdl.value, !pdl.value)
     pdl.rewrite %root with "rewriter"
   }
 }
@@ -107,19 +108,52 @@ module @results {
   // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
   // CHECK:   pdl_interp.check_result_count of %[[ROOT]] is 2
 
-  // Get the input and check the type.
+  // Get the result and check the type.
   // CHECK-DAG:   %[[RESULT:.*]] = pdl_interp.get_result 0 of %[[ROOT]]
   // CHECK-DAG:   pdl_interp.is_not_null %[[RESULT]] : !pdl.value
   // CHECK-DAG:   %[[RESULT_TYPE:.*]] = pdl_interp.get_value_type of %[[RESULT]]
   // CHECK-DAG:   pdl_interp.check_type %[[RESULT_TYPE]] is i32
 
-  // Get the second operand and check that it is equal to the first.
-  // CHECK-DAG:  %[[RESULT1:.*]] = pdl_interp.get_result 1 of %[[ROOT]]
-  // CHECK-NOT: pdl_interp.get_value_type of %[[RESULT1]]
+  // The second result doesn't have any constraints, so we don't generate an
+  // access for it.
+  // CHECK-NOT:   pdl_interp.get_result 1 of %[[ROOT]]
   pdl.pattern : benefit(1) {
     %type1 = pdl.type : i32
     %type2 = pdl.type
-    %root, %results:2 = pdl.operation -> %type1, %type2
+    %root = pdl.operation -> %type1, %type2
+    pdl.rewrite %root with "rewriter"
+  }
+}
+
+// -----
+
+// CHECK-LABEL: module @results_as_operands
+module @results_as_operands {
+  // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
+
+  // Get the first result and check it matches the first operand.
+  // CHECK-DAG:   %[[OPERAND_0:.*]] = pdl_interp.get_operand 0 of %[[ROOT]]
+  // CHECK-DAG:   %[[DEF_OP_0:.*]] = pdl_interp.get_defining_op of %[[OPERAND_0]]
+  // CHECK-DAG:   %[[RESULT_0:.*]] = pdl_interp.get_result 0 of %[[DEF_OP_0]]
+  // CHECK-DAG:   pdl_interp.are_equal %[[RESULT_0]], %[[OPERAND_0]]
+
+  // Get the second result and check it matches the second operand.
+  // CHECK-DAG:   %[[OPERAND_1:.*]] = pdl_interp.get_operand 1 of %[[ROOT]]
+  // CHECK-DAG:   %[[DEF_OP_1:.*]] = pdl_interp.get_defining_op of %[[OPERAND_1]]
+  // CHECK-DAG:   %[[RESULT_1:.*]] = pdl_interp.get_result 1 of %[[DEF_OP_1]]
+  // CHECK-DAG:   pdl_interp.are_equal %[[RESULT_1]], %[[OPERAND_1]]
+
+  // Check that the parent operation of both results is the same.
+  // CHECK-DAG:   pdl_interp.are_equal %[[DEF_OP_0]], %[[DEF_OP_1]]
+
+  pdl.pattern : benefit(1) {
+    %type1 = pdl.type : i32
+    %type2 = pdl.type
+    %inputOp = pdl.operation -> %type1, %type2
+    %result1 = pdl.result 0 of %inputOp
+    %result2 = pdl.result 1 of %inputOp
+
+    %root = pdl.operation(%result1, %result2)
     pdl.rewrite %root with "rewriter"
   }
 }
@@ -134,12 +168,12 @@ module @switch_result_types {
   // CHECK:   pdl_interp.switch_type %[[RESULT_TYPE]] to [i32, i64]
   pdl.pattern : benefit(1) {
     %type = pdl.type : i32
-    %root, %result = pdl.operation -> %type
+    %root = pdl.operation -> %type
     pdl.rewrite %root with "rewriter"
   }
   pdl.pattern : benefit(1) {
     %type = pdl.type : i64
-    %root, %result = pdl.operation -> %type
+    %root = pdl.operation -> %type
     pdl.rewrite %root with "rewriter"
   }
 }
@@ -161,13 +195,13 @@ module @predicate_ordering  {
   pdl.pattern : benefit(1) {
     %resultType = pdl.type
     pdl.apply_constraint "typeConstraint"[](%resultType : !pdl.type)
-    %root, %result = pdl.operation -> %resultType
+    %root = pdl.operation -> %resultType
     pdl.rewrite %root with "rewriter"
   }
 
   pdl.pattern : benefit(1) {
     %resultType = pdl.type
-    %apply, %applyRes = pdl.operation -> %resultType
+    %apply = pdl.operation -> %resultType
     pdl.rewrite %apply with "rewriter"
   }
 }

diff  --git a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir
index 4b6b1ae75700..5652b2118afe 100644
--- a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir
+++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir
@@ -63,7 +63,8 @@ module @operation_operands {
     %root = pdl.operation "foo.op"(%operand)
     pdl.rewrite %root {
       %type = pdl.type : i32
-      %newOp, %result = pdl.operation "foo.op"(%operand) -> %type
+      %newOp = pdl.operation "foo.op"(%operand) -> %type
+      %result = pdl.result 0 of %newOp
       %newOp1 = pdl.operation "foo.op2"(%result)
       pdl.erase %root
     }
@@ -84,7 +85,8 @@ module @operation_operands {
     %root = pdl.operation "foo.op"(%operand)
     pdl.rewrite %root {
       %type = pdl.type : i32
-      %newOp, %result = pdl.operation "foo.op"(%operand) -> %type
+      %newOp = pdl.operation "foo.op"(%operand) -> %type
+      %result = pdl.result 0 of %newOp
       %newOp1 = pdl.operation "foo.op2"(%result)
       pdl.erase %root
     }
@@ -101,10 +103,10 @@ module @operation_result_types {
   pdl.pattern : benefit(1) {
     %rootType = pdl.type
     %rootType1 = pdl.type
-    %root, %results:2 = pdl.operation "foo.op" -> %rootType, %rootType1
+    %root = pdl.operation "foo.op" -> %rootType, %rootType1
     pdl.rewrite %root {
       %newType1 = pdl.type
-      %newOp, %newResults:2 = pdl.operation "foo.op" -> %rootType, %newType1
+      %newOp = pdl.operation "foo.op" -> %rootType, %newType1
       pdl.replace %root with %newOp
     }
   }
@@ -112,23 +114,6 @@ module @operation_result_types {
 
 // -----
 
-// CHECK-LABEL: module @operation_result_types_infer_from_value_replacement
-module @operation_result_types_infer_from_value_replacement {
-  // CHECK: module @rewriters
-  // CHECK:   func @pdl_generated_rewriter(%[[TYPE:.*]]: !pdl.type
-  // CHECK:     pdl_interp.create_operation "foo.op"() -> %[[TYPE]]
-  pdl.pattern : benefit(1) {
-    %rootType = pdl.type
-    %root, %result = pdl.operation "foo.op" -> %rootType
-    pdl.rewrite %root {
-      %newType = pdl.type
-      %newOp, %newResult = pdl.operation "foo.op" -> %newType
-      pdl.replace %root with (%newResult)
-    }
-  }
-}
-// -----
-
 // CHECK-LABEL: module @replace_with_op
 module @replace_with_op {
   // CHECK: module @rewriters
@@ -138,9 +123,9 @@ module @replace_with_op {
   // CHECK:     pdl_interp.replace %[[ROOT]] with(%[[OP_RESULT]])
   pdl.pattern : benefit(1) {
     %type = pdl.type : i32
-    %root, %result = pdl.operation "foo.op" -> %type
+    %root = pdl.operation "foo.op" -> %type
     pdl.rewrite %root {
-      %newOp, %newResult = pdl.operation "foo.op" -> %type
+      %newOp = pdl.operation "foo.op" -> %type
       pdl.replace %root with %newOp
     }
   }
@@ -157,9 +142,10 @@ module @replace_with_values {
   // CHECK:     pdl_interp.replace %[[ROOT]] with(%[[OP_RESULT]])
   pdl.pattern : benefit(1) {
     %type = pdl.type : i32
-    %root, %result = pdl.operation "foo.op" -> %type
+    %root = pdl.operation "foo.op" -> %type
     pdl.rewrite %root {
-      %newOp, %newResult = pdl.operation "foo.op" -> %type
+      %newOp = pdl.operation "foo.op" -> %type
+      %newResult = pdl.result 0 of %newOp
       pdl.replace %root with (%newResult)
     }
   }
@@ -192,10 +178,10 @@ module @create_native {
   // CHECK:     pdl_interp.create_operation "foo.op"() -> %[[TYPE]]
   pdl.pattern : benefit(1) {
     %type = pdl.type
-    %root, %result = pdl.operation "foo.op" -> %type
+    %root = pdl.operation "foo.op" -> %type
     pdl.rewrite %root {
       %newType = pdl.create_native "functor"[true](%root : !pdl.operation) : !pdl.type
-      %newOp, %newResult = pdl.operation "foo.op" -> %newType
+      %newOp = pdl.operation "foo.op" -> %newType
       pdl.replace %root with %newOp
     }
   }

diff  --git a/mlir/test/Dialect/PDL/invalid.mlir b/mlir/test/Dialect/PDL/invalid.mlir
index 0f4d96778277..0f900bbe3f53 100644
--- a/mlir/test/Dialect/PDL/invalid.mlir
+++ b/mlir/test/Dialect/PDL/invalid.mlir
@@ -24,7 +24,7 @@ pdl.pattern : benefit(1) {
   // expected-error@below {{expected only one of [`type`, `value`] to be set}}
   %attr = pdl.attribute : %type 10
 
-  %op, %result = pdl.operation "foo.op" {"attr" = %attr} -> %type
+  %op = pdl.operation "foo.op" {"attr" = %attr} -> %type
   pdl.rewrite %op with "rewriter"
 }
 
@@ -108,7 +108,7 @@ pdl.pattern : benefit(1) {
 
     // expected-error@below {{op must have inferable or constrained result types when nested within `pdl.rewrite`}}
     // expected-note@below {{result type #0 was not constrained}}
-    %newOp, %result = pdl.operation "foo.op" -> %type
+    %newOp = pdl.operation "foo.op" -> %type
   }
 }
 
@@ -147,28 +147,12 @@ pdl.pattern : benefit(1) {
 
 // -----
 
-//===----------------------------------------------------------------------===//
-// pdl::ReplaceOp
-//===----------------------------------------------------------------------===//
-
-pdl.pattern : benefit(1) {
-  %root = pdl.operation "foo.op"
-  pdl.rewrite %root {
-    %type = pdl.type : i32
-    %newOp, %newResult = pdl.operation "foo.op" -> %type
-
-    // expected-error@below {{to have the same number of results as the replacement operation}}
-    pdl.replace %root with %newOp
-  }
-}
-
-// -----
-
 pdl.pattern : benefit(1) {
   %type = pdl.type : i32
-  %root, %oldResult = pdl.operation "foo.op" -> %type
+  %root = pdl.operation "foo.op" -> %type
   pdl.rewrite %root {
-    %newOp, %newResult = pdl.operation "foo.op" -> %type
+    %newOp = pdl.operation "foo.op" -> %type
+    %newResult = pdl.result 0 of %newOp
 
     // expected-error@below {{expected no replacement values to be provided when the replacement operation is present}}
     "pdl.replace"(%root, %newOp, %newResult) {
@@ -179,19 +163,6 @@ pdl.pattern : benefit(1) {
 
 // -----
 
-pdl.pattern : benefit(1) {
-  %root = pdl.operation "foo.op"
-  pdl.rewrite %root {
-    %type = pdl.type : i32
-    %newOp, %newResult = pdl.operation "foo.op" -> %type
-
-    // expected-error@below {{to have the same number of results as the provided replacement values}}
-    pdl.replace %root with (%newResult)
-  }
-}
-
-// -----
-
 //===----------------------------------------------------------------------===//
 // pdl::RewriteOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/PDL/ops.mlir b/mlir/test/Dialect/PDL/ops.mlir
index 5b6a642daf83..d376f001fcfa 100644
--- a/mlir/test/Dialect/PDL/ops.mlir
+++ b/mlir/test/Dialect/PDL/ops.mlir
@@ -8,7 +8,8 @@ pdl.pattern @operations : benefit(1) {
   // Operation with attributes and results.
   %attribute = pdl.attribute
   %type = pdl.type
-  %op0, %op0_result = pdl.operation {"attr" = %attribute} -> %type
+  %op0 = pdl.operation {"attr" = %attribute} -> %type
+  %op0_result = pdl.result 0 of %op0
 
   // Operation with input.
   %input = pdl.operand
@@ -46,38 +47,23 @@ pdl.pattern @rewrite_with_args_and_params : benefit(1) {
 pdl.pattern @infer_type_from_operation_replace : benefit(1) {
   %type1 = pdl.type : i32
   %type2 = pdl.type
-  %root, %results:2 = pdl.operation -> %type1, %type2
+  %root = pdl.operation -> %type1, %type2
   pdl.rewrite %root {
     %type3 = pdl.type
-    %newOp, %newResults:2 = pdl.operation "foo.op" -> %type1, %type3
+    %newOp = pdl.operation "foo.op" -> %type1, %type3
     pdl.replace %root with %newOp
   }
 }
 
 // -----
 
-// Check that the result type of an operation within a rewrite can be inferred
-// from a pdl.replace.
-pdl.pattern @infer_type_from_result_replace : benefit(1) {
-  %type1 = pdl.type : i32
-  %type2 = pdl.type
-  %root, %results:2 = pdl.operation -> %type1, %type2
-  pdl.rewrite %root {
-    %type3 = pdl.type
-    %newOp, %newResults:2 = pdl.operation "foo.op" -> %type1, %type3
-    pdl.replace %root with (%newResults#0, %newResults#1)
-  }
-}
-
-// -----
-
 // Check that the result type of an operation within a rewrite can be inferred
 // from a pdl.replace.
 pdl.pattern @infer_type_from_type_used_in_match : benefit(1) {
   %type1 = pdl.type : i32
   %type2 = pdl.type
-  %root, %results:2 = pdl.operation -> %type1, %type2
+  %root = pdl.operation -> %type1, %type2
   pdl.rewrite %root {
-    %newOp, %newResults:2 = pdl.operation "foo.op" -> %type1, %type2
+    %newOp = pdl.operation "foo.op" -> %type1, %type2
   }
 }


        


More information about the Mlir-commits mailing list