[Mlir-commits] [mlir] 2692eae - [MLIR][PDL] Refactor the positions for multi-root patterns.

Uday Bondhugula llvmlistbot at llvm.org
Mon Jan 3 18:39:01 PST 2022


Author: Stanislav Funiak
Date: 2022-01-04T08:03:44+05:30
New Revision: 2692eae57428e1136ab58ac4004883245d0623ca

URL: https://github.com/llvm/llvm-project/commit/2692eae57428e1136ab58ac4004883245d0623ca
DIFF: https://github.com/llvm/llvm-project/commit/2692eae57428e1136ab58ac4004883245d0623ca.diff

LOG: [MLIR][PDL] Refactor the positions for multi-root patterns.

When the original version of multi-root patterns was reviewed, several improvements were made to the pdl_interp operations during the review process. Specifically, the "get users of a value at the specified operand index" operation was split up into "get users" and "compare the users' operands with that value". The iterative execution was also cleaned up to use `pdl_interp.foreach`. However, the positions in the pdl-to-pdl_interp lowering were not refactored in the same way. This introduced several problems, including hard-to-detect bugs in the lowering and duplicate evaluation of `pdl_interp.get_users`.

This diff cleans up the positions. The "upward" `OperationPosition` was split out into `UsersPosition` and `ForEachPosition`, and the operand comparison was replaced with a simple predicate. In the process, I fixed three bugs:
1. When multiple roots had the same connector (i.e., a node that they shared with a subtree at the previously visited root), we would generate a single foreach loop rather than one foreach loop for each such root. This happened because such connectors shared the same position. The fix is to add the root index as an ID to the newly introduced `ForEachPosition`.
2. Previously, we would use `pdl_interp.get_operands` indiscriminately, whether or not the operand was variadic. We now correctly detect variadic operands and insert `pdl_interp.get_operand` when needed.
3. In certain corner cases, we would trigger the "connector has not been traversed yet" assertion. This was caused by not recording the positions of the result values visited during the upward traversal in the value-to-position map. This has now been fixed.

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D116080

Added: 
    

Modified: 
    mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
    mlir/lib/Conversion/PDLToPDLInterp/Predicate.cpp
    mlir/lib/Conversion/PDLToPDLInterp/Predicate.h
    mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
    mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
index 367bbb55ee1b2..9362a29ddb6f1 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
@@ -248,45 +248,43 @@ Value PatternLowering::getValueAt(Block *&currentBlock, Position *pos) {
   switch (pos->getKind()) {
   case Predicates::OperationPos: {
     auto *operationPos = cast<OperationPosition>(pos);
-    if (!operationPos->isUpward()) {
+    if (operationPos->isOperandDefiningOp())
       // Standard (downward) traversal which directly follows the defining op.
       value = builder.create<pdl_interp::GetDefiningOpOp>(
           loc, builder.getType<pdl::OperationType>(), parentVal);
-      break;
-    }
+    else
+      // A passthrough operation position.
+      value = parentVal;
+    break;
+  }
+  case Predicates::UsersPos: {
+    auto *usersPos = cast<UsersPosition>(pos);
 
     // The first operation retrieves the representative value of a range.
-    // This applies only when the parent is a range of values.
-    if (parentVal.getType().isa<pdl::RangeType>())
+    // This applies only when the parent is a range of values and we were
+    // requested to use a representative value (e.g., upward traversal).
+    if (parentVal.getType().isa<pdl::RangeType>() &&
+        usersPos->useRepresentative())
       value = builder.create<pdl_interp::ExtractOp>(loc, parentVal, 0);
     else
       value = parentVal;
 
     // The second operation retrieves the users.
     value = builder.create<pdl_interp::GetUsersOp>(loc, value);
-
-    // The third operation iterates over them.
+    break;
+  }
+  case Predicates::ForEachPos: {
     assert(!failureBlockStack.empty() && "expected valid failure block");
     auto foreach = builder.create<pdl_interp::ForEachOp>(
-        loc, value, failureBlockStack.back(), /*initLoop=*/true);
+        loc, parentVal, failureBlockStack.back(), /*initLoop=*/true);
     value = foreach.getLoopVariable();
 
-    // Create the success and continuation blocks.
-    Block *successBlock = builder.createBlock(&foreach.region());
-    Block *continueBlock = builder.createBlock(successBlock);
+    // Create the continuation block.
+    Block *continueBlock = builder.createBlock(&foreach.region());
     builder.create<pdl_interp::ContinueOp>(loc);
     failureBlockStack.push_back(continueBlock);
 
-    // The fourth operation extracts the operand(s) of the user at the specified
-    // index (which can be None, indicating all operands).
-    builder.setInsertionPointToStart(&foreach.region().front());
-    Value operands = builder.create<pdl_interp::GetOperandsOp>(
-        loc, parentVal.getType(), value, operationPos->getIndex());
-
-    // The fifth operation compares the operands to the parent value / range.
-    builder.create<pdl_interp::AreEqualOp>(loc, parentVal, operands,
-                                           successBlock, continueBlock);
-    currentBlock = successBlock;
+    currentBlock = &foreach.region().front();
     break;
   }
   case Predicates::OperandPos: {

diff  --git a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.cpp b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.cpp
index 07fa5c77c13fe..a12f3171e7afa 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.cpp
@@ -48,4 +48,6 @@ OperandGroupPosition::OperandGroupPosition(const KeyTy &key) : Base(key) {
 //===----------------------------------------------------------------------===//
 // OperationPosition
 
-constexpr unsigned OperationPosition::kDown;
+bool OperationPosition::isOperandDefiningOp() const {
+  return isa_and_nonnull<OperandPosition, OperandGroupPosition>(parent);
+}

diff  --git a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h
index 266580bd41f59..1d723996f8c31 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h
+++ b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h
@@ -52,6 +52,8 @@ enum Kind : unsigned {
   TypePos,
   AttributeLiteralPos,
   TypeLiteralPos,
+  UsersPos,
+  ForEachPos,
 
   // Questions, ordered by dependency and decreasing priority.
   IsNotNullQuestion,
@@ -185,6 +187,20 @@ struct AttributeLiteralPosition
   using PredicateBase::PredicateBase;
 };
 
+//===----------------------------------------------------------------------===//
+// ForEachPosition
+
+/// A position describing an iterative choice of an operation.
+struct ForEachPosition : public PredicateBase<ForEachPosition, Position,
+                                              std::pair<Position *, unsigned>,
+                                              Predicates::ForEachPos> {
+  explicit ForEachPosition(const KeyTy &key) : Base(key) { parent = key.first; }
+
+  /// Returns the ID, for differentiating various loops.
+  /// For upward traversals, this is the index of the root.
+  unsigned getID() const { return key.second; }
+};
+
 //===----------------------------------------------------------------------===//
 // OperandPosition
 
@@ -229,14 +245,11 @@ struct OperandGroupPosition
 
 /// An operation position describes an operation node in the IR. Other position
 /// kinds are formed with respect to an operation position.
-struct OperationPosition
-    : public PredicateBase<OperationPosition, Position,
-                           std::tuple<Position *, Optional<unsigned>, unsigned>,
-                           Predicates::OperationPos> {
-  static constexpr unsigned kDown = std::numeric_limits<unsigned>::max();
-
+struct OperationPosition : public PredicateBase<OperationPosition, Position,
+                                                std::pair<Position *, unsigned>,
+                                                Predicates::OperationPos> {
   explicit OperationPosition(const KeyTy &key) : Base(key) {
-    parent = std::get<0>(key);
+    parent = key.first;
   }
 
   /// Returns a hash suitable for the given keytype.
@@ -246,31 +259,22 @@ struct OperationPosition
 
   /// Gets the root position.
   static OperationPosition *getRoot(StorageUniquer &uniquer) {
-    return Base::get(uniquer, nullptr, kDown, 0);
+    return Base::get(uniquer, nullptr, 0);
   }
 
-  /// Gets an downward operation position with the given parent.
+  /// Gets an operation position with the given parent.
   static OperationPosition *get(StorageUniquer &uniquer, Position *parent) {
-    return Base::get(uniquer, parent, kDown, parent->getOperationDepth() + 1);
-  }
-
-  /// Gets an upward operation position with the given parent and operand.
-  static OperationPosition *get(StorageUniquer &uniquer, Position *parent,
-                                Optional<unsigned> operand) {
-    return Base::get(uniquer, parent, operand, parent->getOperationDepth() + 1);
+    return Base::get(uniquer, parent, parent->getOperationDepth() + 1);
   }
 
-  /// Returns the operand index for an upward operation position.
-  Optional<unsigned> getIndex() const { return std::get<1>(key); }
-
-  /// Returns if this operation position is upward, accepting an input.
-  bool isUpward() const { return getIndex().getValueOr(0) != kDown; }
-
   /// Returns the depth of this position.
-  unsigned getDepth() const { return std::get<2>(key); }
+  unsigned getDepth() const { return key.second; }
 
   /// Returns if this operation position corresponds to the root.
   bool isRoot() const { return getDepth() == 0; }
+
+  /// Returns if this operation represents an operand defining op.
+  bool isOperandDefiningOp() const;
 };
 
 //===----------------------------------------------------------------------===//
@@ -340,6 +344,26 @@ struct TypeLiteralPosition
   using PredicateBase::PredicateBase;
 };
 
+//===----------------------------------------------------------------------===//
+// UsersPosition
+
+/// A position describing the users of a value or a range of values. The second
+/// value in the key indicates whether we choose users of a representative for
+/// a range (this is true, e.g., in the upward traversals).
+struct UsersPosition
+    : public PredicateBase<UsersPosition, Position, std::pair<Position *, bool>,
+                           Predicates::UsersPos> {
+  explicit UsersPosition(const KeyTy &key) : Base(key) { parent = key.first; }
+
+  /// Returns a hash suitable for the given keytype.
+  static llvm::hash_code hashKey(const KeyTy &key) {
+    return llvm::hash_value(key);
+  }
+
+  /// Indicates whether to compute a range of a representative.
+  bool useRepresentative() const { return key.second; }
+};
+
 //===----------------------------------------------------------------------===//
 // Qualifiers
 //===----------------------------------------------------------------------===//
@@ -496,6 +520,7 @@ class PredicateUniquer : public StorageUniquer {
     // Register the types of Positions with the uniquer.
     registerParametricStorageType<AttributePosition>();
     registerParametricStorageType<AttributeLiteralPosition>();
+    registerParametricStorageType<ForEachPosition>();
     registerParametricStorageType<OperandPosition>();
     registerParametricStorageType<OperandGroupPosition>();
     registerParametricStorageType<OperationPosition>();
@@ -503,6 +528,7 @@ class PredicateUniquer : public StorageUniquer {
     registerParametricStorageType<ResultGroupPosition>();
     registerParametricStorageType<TypePosition>();
     registerParametricStorageType<TypeLiteralPosition>();
+    registerParametricStorageType<UsersPosition>();
 
     // Register the types of Questions with the uniquer.
     registerParametricStorageType<AttributeAnswer>();
@@ -550,12 +576,10 @@ class PredicateBuilder {
     return OperationPosition::get(uniquer, p);
   }
 
-  /// Returns the position of operation using the value at the given index.
-  OperationPosition *getUsersOp(Position *p, Optional<unsigned> operand) {
-    assert((isa<OperandPosition, OperandGroupPosition, ResultPosition,
-                ResultGroupPosition>(p)) &&
-           "expected result position");
-    return OperationPosition::get(uniquer, p, operand);
+  /// Returns the operation position equivalent to the given position.
+  OperationPosition *getPassthroughOp(Position *p) {
+    assert((isa<ForEachPosition>(p)) && "expected users position");
+    return OperationPosition::get(uniquer, p);
   }
 
   /// Returns an attribute position for an attribute of the given operation.
@@ -568,6 +592,10 @@ class PredicateBuilder {
     return AttributeLiteralPosition::get(uniquer, attr);
   }
 
+  Position *getForEach(Position *p, unsigned id) {
+    return ForEachPosition::get(uniquer, p, id);
+  }
+
   /// Returns an operand position for an operand of the given operation.
   Position *getOperand(OperationPosition *p, unsigned operand) {
     return OperandPosition::get(uniquer, p, operand);
@@ -605,6 +633,14 @@ class PredicateBuilder {
     return TypeLiteralPosition::get(uniquer, attr);
   }
 
+  /// Returns the users of a position using the value at the given operand.
+  UsersPosition *getUsers(Position *p, bool useRepresentative) {
+    assert((isa<OperandPosition, OperandGroupPosition, ResultPosition,
+                ResultGroupPosition>(p)) &&
+           "expected result position");
+    return UsersPosition::get(uniquer, p, useRepresentative);
+  }
+
   //===--------------------------------------------------------------------===//
   // Qualifiers
   //===--------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
index 43c57a8e6033a..24b2f19e58c2b 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
@@ -158,8 +158,11 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList,
   // group, we treat it as all of the operands/results of the operation.
   /// Operands.
   if (operands.size() == 1 && operands[0].getType().isa<pdl::RangeType>()) {
-    getTreePredicates(predList, operands.front(), builder, inputs,
-                      builder.getAllOperands(opPos));
+    // Ignore the operands if we are performing an upward traversal (in that
+    // case, they have already been visited).
+    if (opPos->isRoot() || opPos->isOperandDefiningOp())
+      getTreePredicates(predList, operands.front(), builder, inputs,
+                        builder.getAllOperands(opPos));
   } else {
     bool foundVariableLength = false;
     for (const auto &operandIt : llvm::enumerate(operands)) {
@@ -502,23 +505,47 @@ static void buildCostGraph(ArrayRef<Value> roots, RootOrderingGraph &graph,
          "the pattern contains a candidate root disconnected from the others");
 }
 
+/// Returns true if the operand at the given index needs to be queried using an
+/// operand group, i.e., if it is variadic itself or follows a variadic operand.
+static bool useOperandGroup(pdl::OperationOp op, unsigned index) {
+  OperandRange operands = op.operands();
+  assert(index < operands.size() && "operand index out of range");
+  for (unsigned i = 0; i <= index; ++i)
+    if (operands[i].getType().isa<pdl::RangeType>())
+      return true;
+  return false;
+}
+
 /// Visit a node during upward traversal.
-void visitUpward(std::vector<PositionalPredicate> &predList, OpIndex opIndex,
-                 PredicateBuilder &builder,
-                 DenseMap<Value, Position *> &valueToPosition, Position *&pos,
-                 bool &first) {
+static void visitUpward(std::vector<PositionalPredicate> &predList,
+                        OpIndex opIndex, PredicateBuilder &builder,
+                        DenseMap<Value, Position *> &valueToPosition,
+                        Position *&pos, unsigned rootID) {
   Value value = opIndex.parent;
   TypeSwitch<Operation *>(value.getDefiningOp())
       .Case<pdl::OperationOp>([&](auto operationOp) {
         LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
-        OperationPosition *opPos = builder.getUsersOp(pos, opIndex.index);
 
-        // Guard against traversing back to where we came from.
-        if (first) {
-          Position *parent = pos->getParent();
-          predList.emplace_back(opPos, builder.getNotEqualTo(parent));
-          first = false;
+        // Get users and iterate over them.
+        Position *usersPos = builder.getUsers(pos, /*useRepresentative=*/true);
+        Position *foreachPos = builder.getForEach(usersPos, rootID);
+        OperationPosition *opPos = builder.getPassthroughOp(foreachPos);
+
+        // Compare the operand(s) of the user against the input value(s).
+        Position *operandPos;
+        if (!opIndex.index) {
+          // We are querying all the operands of the operation.
+          operandPos = builder.getAllOperands(opPos);
+        } else if (useOperandGroup(operationOp, *opIndex.index)) {
+          // We are querying an operand group.
+          Type type = operationOp.operands()[*opIndex.index].getType();
+          bool variadic = type.isa<pdl::RangeType>();
+          operandPos = builder.getOperandGroup(opPos, opIndex.index, variadic);
+        } else {
+          // We are querying an individual operand.
+          operandPos = builder.getOperand(opPos, *opIndex.index);
         }
+        predList.emplace_back(operandPos, builder.getEqualTo(pos));
 
         // Guard against duplicate upward visits. These are not possible,
         // because if this value was already visited, it would have been
@@ -540,6 +567,9 @@ void visitUpward(std::vector<PositionalPredicate> &predList, OpIndex opIndex,
         auto *opPos = dyn_cast<OperationPosition>(pos);
         assert(opPos && "operations and results must be interleaved");
         pos = builder.getResult(opPos, *opIndex.index);
+
+        // Insert the result position in case we have not visited it yet.
+        valueToPosition.try_emplace(value, pos);
       })
       .Case<pdl::ResultsOp>([&](auto resultOp) {
         // Traverse up a group of results.
@@ -550,6 +580,9 @@ void visitUpward(std::vector<PositionalPredicate> &predList, OpIndex opIndex,
           pos = builder.getResultGroup(opPos, opIndex.index, isVariadic);
         else
           pos = builder.getAllResults(opPos);
+
+        // Insert the result position in case we have not visited it yet.
+        valueToPosition.try_emplace(value, pos);
       });
 }
 
@@ -568,7 +601,8 @@ static Value buildPredicateList(pdl::PatternOp pattern,
   LLVM_DEBUG({
     llvm::dbgs() << "Graph:\n";
     for (auto &target : graph) {
-      llvm::dbgs() << "  * " << target.first << "\n";
+      llvm::dbgs() << "  * " << target.first.getLoc() << " " << target.first
+                   << "\n";
       for (auto &source : target.second) {
         RootOrderingEntry &entry = source.second;
         llvm::dbgs() << "      <- " << source.first << ": " << entry.cost.first
@@ -601,6 +635,17 @@ static Value buildPredicateList(pdl::PatternOp pattern,
     bestEdges = solver.preOrderTraversal(roots);
   }
 
+  // Print the best solution.
+  LLVM_DEBUG({
+    llvm::dbgs() << "Best tree:\n";
+    for (const std::pair<Value, Value> &edge : bestEdges) {
+      llvm::dbgs() << "  * " << edge.first;
+      if (edge.second)
+        llvm::dbgs() << " <- " << edge.second;
+      llvm::dbgs() << "\n";
+    }
+  });
+
   LLVM_DEBUG(llvm::dbgs() << "Calling key getTreePredicates:\n");
   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << bestRoot << "\n");
 
@@ -612,9 +657,9 @@ static Value buildPredicateList(pdl::PatternOp pattern,
   // Traverse the selected optimal branching. For all edges in order, traverse
   // up starting from the connector, until the candidate root is reached, and
   // call getTreePredicates at every node along the way.
-  for (const std::pair<Value, Value> &edge : bestEdges) {
-    Value target = edge.first;
-    Value source = edge.second;
+  for (auto it : llvm::enumerate(bestEdges)) {
+    Value target = it.value().first;
+    Value source = it.value().second;
 
     // Check if we already visited the target root. This happens in two cases:
     // 1) the initial root (bestRoot);
@@ -629,14 +674,13 @@ static Value buildPredicateList(pdl::PatternOp pattern,
     LLVM_DEBUG(llvm::dbgs() << "  * Connector: " << connector.getLoc() << "\n");
     DenseMap<Value, OpIndex> parentMap = parentMaps.lookup(target);
     Position *pos = valueToPosition.lookup(connector);
-    assert(pos && "The value has not been traversed yet");
-    bool first = true;
+    assert(pos && "connector has not been traversed yet");
 
     // Traverse from the connector upwards towards the target root.
     for (Value value = connector; value != target;) {
       OpIndex opIndex = parentMap.lookup(value);
       assert(opIndex.parent && "missing parent");
-      visitUpward(predList, opIndex, builder, valueToPosition, pos, first);
+      visitUpward(predList, opIndex, builder, valueToPosition, pos, it.index());
       value = opIndex.parent;
     }
   }

diff  --git a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
index 984a31790a8bc..fd6cfe5fa7c5f 100644
--- a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
+++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
@@ -423,8 +423,8 @@ module @multi_root {
   // CHECK-DAG: %[[OP1:.*]] = pdl_interp.get_defining_op of %[[VAL1]]
   // CHECK-DAG: %[[OPS:.*]] = pdl_interp.get_users of %[[VAL1]] : !pdl.value
   // CHECK-DAG: pdl_interp.foreach %[[ROOT2:.*]] : !pdl.operation in %[[OPS]]
-  // CHECK-DAG:   %[[OPERANDS:.*]] = pdl_interp.get_operands 0 of %[[ROOT2]]
-  // CHECK-DAG:   pdl_interp.are_equal %[[VAL1]], %[[OPERANDS]] : !pdl.value -> ^{{.*}}, ^[[CONTINUE:.*]]
+  // CHECK-DAG:   %[[OPERANDS:.*]] = pdl_interp.get_operand 0 of %[[ROOT2]]
+  // CHECK-DAG:   pdl_interp.are_equal %[[OPERANDS]], %[[VAL1]] : !pdl.value -> ^{{.*}}, ^[[CONTINUE:.*]]
   // CHECK-DAG:   pdl_interp.continue
   // CHECK-DAG:   %[[VAL2:.*]] = pdl_interp.get_operand 1 of %[[ROOT2]]
   // CHECK-DAG:   %[[OP2:.*]] = pdl_interp.get_defining_op of %[[VAL2]]
@@ -433,7 +433,6 @@ module @multi_root {
   // CHECK-DAG:   pdl_interp.is_not_null %[[VAL1]] : !pdl.value
   // CHECK-DAG:   pdl_interp.is_not_null %[[VAL2]] : !pdl.value
   // CHECK-DAG:   pdl_interp.is_not_null %[[ROOT2]] : !pdl.operation
-  // CHECK-DAG:   pdl_interp.are_equal %[[ROOT2]], %[[ROOT1]] : !pdl.operation -> ^[[CONTINUE]]
 
   pdl.pattern @rewrite_multi_root : benefit(1) {
     %input1 = pdl.operand
@@ -556,7 +555,7 @@ module @variadic_results_at {
   // CHECK-DAG: %[[ROOTS2:.*]] = pdl_interp.get_users of %[[VAL0]] : !pdl.value
   // CHECK-DAG: pdl_interp.foreach %[[ROOT2:.*]] : !pdl.operation in %[[ROOTS2]] {
   // CHECK-DAG:   %[[OPERANDS:.*]] = pdl_interp.get_operands 1 of %[[ROOT2]]
-  // CHECK-DAG:   pdl_interp.are_equal %[[VALS]], %[[OPERANDS]] : !pdl.range<value> -> ^{{.*}}, ^[[CONTINUE:.*]]
+  // CHECK-DAG:   pdl_interp.are_equal %[[OPERANDS]], %[[VALS]] : !pdl.range<value> -> ^{{.*}}, ^[[CONTINUE:.*]]
   // CHECK-DAG:   pdl_interp.is_not_null %[[ROOT2]]
   // CHECK-DAG:   pdl_interp.check_operand_count of %[[ROOT2]] is at_least 1
   // CHECK-DAG:   pdl_interp.check_result_count of %[[ROOT2]] is 0
@@ -612,3 +611,83 @@ module @type_literal {
   }
 }
 
+// -----
+
+// CHECK-LABEL: module @common_connector
+module @common_connector {
+  // Check the correct lowering when multiple roots are using the same
+  // connector.
+
+  // CHECK: func @matcher(%[[ROOTC:.*]]: !pdl.operation)
+  // CHECK-DAG: %[[VAL2:.*]] = pdl_interp.get_operand 0 of %[[ROOTC]]
+  // CHECK-DAG: %[[INTER:.*]] = pdl_interp.get_defining_op of %[[VAL2]] : !pdl.value
+  // CHECK-DAG: pdl_interp.is_not_null %[[INTER]] : !pdl.operation -> ^bb2, ^bb1
+  // CHECK-DAG: %[[VAL1:.*]] = pdl_interp.get_operand 0 of %[[INTER]]
+  // CHECK-DAG: %[[OP:.*]] = pdl_interp.get_defining_op of %[[VAL1]] : !pdl.value
+  // CHECK-DAG: pdl_interp.is_not_null %[[OP]]
+  // CHECK-DAG: %[[VAL0:.*]] = pdl_interp.get_result 0 of %[[OP]]
+  // CHECK-DAG: %[[ROOTS:.*]] = pdl_interp.get_users of %[[VAL0]] : !pdl.value
+  // CHECK-DAG: pdl_interp.foreach %[[ROOTA:.*]] : !pdl.operation in %[[ROOTS]] {
+  // CHECK-DAG:   pdl_interp.is_not_null %[[ROOTA]] : !pdl.operation -> ^{{.*}}, ^[[CONTA:.*]]
+  // CHECK-DAG:   pdl_interp.continue
+  // CHECK-DAG:   pdl_interp.foreach %[[ROOTB:.*]] : !pdl.operation in %[[ROOTS]] {
+  // CHECK-DAG:     pdl_interp.is_not_null %[[ROOTB]] : !pdl.operation -> ^{{.*}}, ^[[CONTB:.*]]
+  // CHECK-DAG:     %[[ROOTA_OP:.*]] = pdl_interp.get_operand 0 of %[[ROOTA]]
+  // CHECK-DAG:     pdl_interp.are_equal %[[ROOTA_OP]], %[[VAL0]] : !pdl.value
+  // CHECK-DAG:     %[[ROOTB_OP:.*]] = pdl_interp.get_operand 0 of %[[ROOTB]]
+  // CHECK-DAG:     pdl_interp.are_equal %[[ROOTB_OP]], %[[VAL0]] : !pdl.value
+  // CHECK-DAG    } -> ^[[CONTA:.*]]
+  pdl.pattern @common_connector : benefit(1) {
+      %type = pdl.type
+      %op = pdl.operation -> (%type, %type : !pdl.type, !pdl.type)
+      %val0 = pdl.result 0 of %op
+      %val1 = pdl.result 1 of %op
+      %rootA = pdl.operation (%val0 : !pdl.value)
+      %rootB = pdl.operation (%val0 : !pdl.value)
+      %inter = pdl.operation (%val1 : !pdl.value) -> (%type : !pdl.type)
+      %val2 = pdl.result 0 of %inter
+      %rootC = pdl.operation (%val2 : !pdl.value)
+      pdl.rewrite with "rewriter"(%rootA, %rootB, %rootC : !pdl.operation, !pdl.operation, !pdl.operation)
+  }
+}
+
+// -----
+
+// CHECK-LABEL: module @common_connector_range
+module @common_connector_range {
+  // Check the correct lowering when multiple roots are using the same
+  // connector range.
+
+  // CHECK: func @matcher(%[[ROOTC:.*]]: !pdl.operation)
+  // CHECK-DAG: %[[VALS2:.*]] = pdl_interp.get_operands of %[[ROOTC]] : !pdl.range<value>
+  // CHECK-DAG: %[[INTER:.*]] = pdl_interp.get_defining_op of %[[VALS2]] : !pdl.range<value>
+  // CHECK-DAG: pdl_interp.is_not_null %[[INTER]] : !pdl.operation -> ^bb2, ^bb1
+  // CHECK-DAG: %[[VALS1:.*]] = pdl_interp.get_operands of %[[INTER]] : !pdl.range<value>
+  // CHECK-DAG: %[[OP:.*]] = pdl_interp.get_defining_op of %[[VALS1]] : !pdl.range<value>
+  // CHECK-DAG: pdl_interp.is_not_null %[[OP]]
+  // CHECK-DAG: %[[VALS0:.*]] = pdl_interp.get_results 0 of %[[OP]]
+  // CHECK-DAG: %[[VAL0:.*]] = pdl_interp.extract 0 of %[[VALS0]] : !pdl.value
+  // CHECK-DAG: %[[ROOTS:.*]] = pdl_interp.get_users of %[[VAL0]] : !pdl.value
+  // CHECK-DAG: pdl_interp.foreach %[[ROOTA:.*]] : !pdl.operation in %[[ROOTS]] {
+  // CHECK-DAG:   pdl_interp.is_not_null %[[ROOTA]] : !pdl.operation -> ^{{.*}}, ^[[CONTA:.*]]
+  // CHECK-DAG:   pdl_interp.continue
+  // CHECK-DAG:   pdl_interp.foreach %[[ROOTB:.*]] : !pdl.operation in %[[ROOTS]] {
+  // CHECK-DAG:     pdl_interp.is_not_null %[[ROOTB]] : !pdl.operation -> ^{{.*}}, ^[[CONTB:.*]]
+  // CHECK-DAG:     %[[ROOTA_OPS:.*]] = pdl_interp.get_operands of %[[ROOTA]]
+  // CHECK-DAG:     pdl_interp.are_equal %[[ROOTA_OPS]], %[[VALS0]] : !pdl.range<value>
+  // CHECK-DAG:     %[[ROOTB_OPS:.*]] = pdl_interp.get_operands of %[[ROOTB]]
+  // CHECK-DAG:     pdl_interp.are_equal %[[ROOTB_OPS]], %[[VALS0]] : !pdl.range<value>
+  // CHECK-DAG    } -> ^[[CONTA:.*]]
+  pdl.pattern @common_connector_range : benefit(1) {
+    %types = pdl.types
+    %op = pdl.operation -> (%types, %types : !pdl.range<type>, !pdl.range<type>)
+    %vals0 = pdl.results 0 of %op -> !pdl.range<value>
+    %vals1 = pdl.results 1 of %op -> !pdl.range<value>
+    %rootA = pdl.operation (%vals0 : !pdl.range<value>)
+    %rootB = pdl.operation (%vals0 : !pdl.range<value>)
+    %inter = pdl.operation (%vals1 : !pdl.range<value>) -> (%types : !pdl.range<type>)
+    %vals2 = pdl.results of %inter
+    %rootC = pdl.operation (%vals2 : !pdl.range<value>)
+    pdl.rewrite with "rewriter"(%rootA, %rootB, %rootC : !pdl.operation, !pdl.operation, !pdl.operation)
+  }
+}


        


More information about the Mlir-commits mailing list