[PATCH] D124750: [MLIR] Add a utility to sort the operands of commutative ops

Jeff Niu via Phabricator via cfe-commits cfe-commits at lists.llvm.org
Thu Jun 30 00:11:22 PDT 2022


Mogball added a comment.

I'm glad the `DenseSet`s are gone, but my three-ish biggest gripes are:

- The algorithm is conceptually simple, but there is way more code than is necessary to achieve it.
- More comments (excluding "doc" comments) than code is generally not a good sign.
- The implementation is still inefficient in a lot of obvious ways.



================
Comment at: mlir/lib/Transforms/Utils/CommutativityUtils.cpp:10-11
+// This file implements a commutativity utility pattern and a function to
+// populate this pattern. The function is intended to be used inside passes to
+// simplify the matching of commutative operations.
+//
----------------



================
Comment at: mlir/lib/Transforms/Utils/CommutativityUtils.cpp:21-38
+/// Stores the "ancestor" of an operand of some op. The operand of any op is
+/// produced by a set of ops and block arguments. Each of these ops and block
+/// arguments is called an "ancestor" of this operand.
+struct Ancestor {
+  /// Stores true when the "ancestor" is an op and false when the "ancestor" is
+  /// a block argument.
+  bool isOp;
----------------
This class isn't necessary. `Ancestor` can just be `Operation *`. If it's null, then we know it's a block argument (the bool flag is redundant).
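
A minimal sketch of this suggestion, not the patch's actual code. `Operation` here is a hypothetical stand-in for `mlir::Operation` so that the snippet compiles on its own, and `isBlockArgument` and `ancestorName` are illustrative helper names that do not appear in the patch.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-in for mlir::Operation, only here to keep the sketch
// self-contained.
struct Operation {
  std::string name;
};

// An ancestor is just a (possibly null) op pointer. A null pointer means the
// ancestor is a block argument, so no separate bool flag is needed.
using Ancestor = Operation *;

inline bool isBlockArgument(Ancestor ancestor) { return ancestor == nullptr; }

// Returns the op name, or "" when the ancestor is a block argument.
inline std::string ancestorName(Ancestor ancestor) {
  return ancestor ? ancestor->name : std::string();
}
```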


================
Comment at: mlir/lib/Transforms/Utils/CommutativityUtils.cpp:40
+
+/// Declares various "types" of ancestors.
+enum AncestorType {
----------------



================
Comment at: mlir/lib/Transforms/Utils/CommutativityUtils.cpp:65-66
+    if (!ancestor.isOp) {
+      // When `ancestor` is a block argument, we assign `type` as
+      // `BLOCK_ARGUMENT` and `opName` remains "".
+      type = BLOCK_ARGUMENT;
----------------
My biggest complaint with respect to readability is that there are more comments than code. This is fine if the comment has a big explanation about the algorithm and how keys are represented, especially with a nice ASCII diagram as you have below. But if this constructor had 0 comments except maybe "Only non-constant ops are sorted by name", it would be perfectly fine.


================
Comment at: mlir/lib/Transforms/Utils/CommutativityUtils.cpp:76
+      // `CONSTANT_OP` and `opName` remains "".
+      type = CONSTANT_OP;
+    }
----------------
Constant ops could be sorted by name as well.


================
Comment at: mlir/lib/Transforms/Utils/CommutativityUtils.cpp:88-96
+  bool operator<(const AncestorKey &key) const {
+    if ((type == BLOCK_ARGUMENT && key.type != BLOCK_ARGUMENT) ||
+        (type == NON_CONSTANT_OP && key.type == CONSTANT_OP))
+      return true;
+    if ((key.type == BLOCK_ARGUMENT && type != BLOCK_ARGUMENT) ||
+        (key.type == NON_CONSTANT_OP && type == CONSTANT_OP))
+      return false;
----------------
Comparing the fields as a `std::tuple` should behave the same as the manually written comparator. `operator<` of `std::tuple` compares the first elements and moves on to the next only if the first are equal.
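
A sketch of a tuple-based comparator. It rests on two assumptions that the quoted snippet does not confirm: that the enumerators are declared in the intended sort order (`BLOCK_ARGUMENT`, then `NON_CONSTANT_OP`, then `CONSTANT_OP`), and that ties on `type` are broken by `opName`.

```cpp
#include <cassert>
#include <string>
#include <tuple>

// Assumption: enumerators are declared in the intended sort order.
enum AncestorType { BLOCK_ARGUMENT, NON_CONSTANT_OP, CONSTANT_OP };

struct AncestorKey {
  AncestorType type;
  std::string opName;

  // std::tie builds a tuple of references; tuple operator< compares `type`
  // first and falls back to `opName` only when the types are equal.
  bool operator<(const AncestorKey &key) const {
    return std::tie(type, opName) < std::tie(key.type, key.opName);
  }
};
```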


================
Comment at: mlir/lib/Transforms/Utils/CommutativityUtils.cpp:110
+  /// operand at a particular point in time.
+  DenseSet<Operation *> visitedAncestors;
+
----------------
Since this is a set of pointers expected to be small, you can use `SmallPtrSet` for a big efficiency boost (linear scan instead of hashing when small).


================
Comment at: mlir/lib/Transforms/Utils/CommutativityUtils.cpp:148-149
+
+  /// Stores true iff the operand has been assigned a sorted position yet.
+  bool isSorted = false;
+
----------------
Since you are moving sorted operands into their sorted position and tracking the unsorted range, you shouldn't even need this flag because you will always know the sorted and unsorted subranges.

There are multiple loops in which you iterate over the entire operand list but skip those where this flag is set/unset. In those cases, you can always just iterate the subrange of interest.


================
Comment at: mlir/lib/Transforms/Utils/CommutativityUtils.cpp:155-159
+    Ancestor ancestor(op);
+    ancestorQueue.push(ancestor);
+    if (ancestor.isOp)
+      visitedAncestors.insert(ancestor.op);
+    return;
----------------



================
Comment at: mlir/lib/Transforms/Utils/CommutativityUtils.cpp:162-177
+  /// Pop the ancestor from the front of the queue.
+  void popAncestor() {
+    assert(!ancestorQueue.empty() &&
+           "to pop the ancestor from the front of the queue, the ancestor "
+           "queue should be non-empty");
+    ancestorQueue.pop();
+    return;
----------------
I would drop these helpers. `std::queue` will already assert if `pop` or `front` is called on an empty queue.


================
Comment at: mlir/lib/Transforms/Utils/CommutativityUtils.cpp:200-217
+  unsigned keyASize = keyA.size();
+  unsigned keyBSize = keyB.size();
+  unsigned smallestSize = keyASize;
+  if (keyBSize < smallestSize)
+    smallestSize = keyBSize;
+
+  for (unsigned i = 0; i < smallestSize; i++) {
----------------



================
Comment at: mlir/lib/Transforms/Utils/CommutativityUtils.cpp:267
+    ArrayRef<AncestorKey> key,
+    SmallVectorImpl<std::unique_ptr<OperandBFS>> &bfsOfOperandsWithKey,
+    ArrayRef<std::unique_ptr<OperandBFS>> bfsOfOperands,
----------------



================
Comment at: mlir/lib/Transforms/Utils/CommutativityUtils.cpp:269
+    ArrayRef<std::unique_ptr<OperandBFS>> bfsOfOperands,
+    bool &hasOneOperandWithKey) {
+  bool keyFound = false;
----------------
This flag is not necessary because you can just check `bfsOfOperandsWithKey.size() == 1`.


================
Comment at: mlir/lib/Transforms/Utils/CommutativityUtils.cpp:277
+    ArrayRef<AncestorKey> currentKey = bfsOfOperand->key;
+    if (compareKeys(key, currentKey) == 0) {
+      bfsOfOperandsWithKey.push_back(
----------------



================
Comment at: mlir/lib/Transforms/Utils/CommutativityUtils.cpp:278-279
+    if (compareKeys(key, currentKey) == 0) {
+      bfsOfOperandsWithKey.push_back(
+          std::make_unique<OperandBFS>(*bfsOfOperand));
+      if (keyFound)
----------------
You don't need to make a copy. In fact, I think you should just track the indices.
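
A rough sketch of tracking indices instead of copies, which also makes the `hasOneOperandWithKey` out-parameter unnecessary. `indicesWithKey` and `unsortedBegin` are hypothetical names, and plain `std::vector` with a generic `Key` stands in for the patch's `OperandBFS` containers.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Collects the positions in [unsortedBegin, keys.size()) whose key equals
// `key`. Nothing is copied; callers index back into the original list.
template <typename Key>
std::vector<std::size_t> indicesWithKey(const std::vector<Key> &keys,
                                        std::size_t unsortedBegin,
                                        const Key &key) {
  std::vector<std::size_t> indices;
  for (std::size_t i = unsortedBegin; i < keys.size(); ++i)
    if (keys[i] == key)
      indices.push_back(i);
  return indices;
}
```

The caller can then test `indices.size() == 1` rather than maintaining a separate flag.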


================
Comment at: mlir/lib/Transforms/Utils/CommutativityUtils.cpp:298-304
+
+  // If there exists no unsorted operand, return false.
+  if (llvm::all_of(bfsOfOperands,
+                   [](const std::unique_ptr<OperandBFS> &bfsOfOperand) {
+                     return bfsOfOperand->isSorted;
+                   }))
+    return false;
----------------
This shouldn't be necessary if you're tracking the unsorted subrange. Just quit when it gets to 1 element.


================
Comment at: mlir/lib/Transforms/Utils/CommutativityUtils.cpp:307-308
+  // Get the smallest key present among the unsorted operands.
+  ArrayRef<AncestorKey> smallestKey =
+      computeTheSmallestUnsortedKey(bfsOfOperands);
+
----------------



================
Comment at: mlir/lib/Transforms/Utils/CommutativityUtils.cpp:312
+  getBFSOfOperandsWithKey(
+      /*key=*/smallestKey,
+      /*bfsOfOperandsWithKey=*/bfsOfSmallestUnsortedOperands, bfsOfOperands,
----------------
Argument names in comments are not necessary when the passed variable has a descriptive name.


================
Comment at: mlir/lib/Transforms/Utils/CommutativityUtils.cpp:329-353
+  assert(frontPosition >= 0 && frontPosition < bfsOfOperands.size() &&
+         "`frontPosition` should be valid");
+  unsigned positionOfOperandToShift;
+  bool foundOperandToShift = false;
+  for (auto &indexedBfsOfOperand : llvm::enumerate(bfsOfOperands)) {
+    std::unique_ptr<OperandBFS> &bfsOfOperand = indexedBfsOfOperand.value();
+    if (bfsOfOperand->isSorted)
----------------
There is no way you need this much code. A `std::swap` between the current operand and the first unsorted position should be enough.
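
A sketch of the swap-based approach, assuming the sorted operands occupy `[0, unsortedBegin)` and everything after that is unsorted, so no per-operand `isSorted` flag is consulted. `moveToUnsortedFront` is a hypothetical name and the element type is left generic. Note that swapping does not preserve the relative order of the operands that stay unsorted, which this sketch assumes is acceptable.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Swaps every element equal to `smallest` in the unsorted range into the
// first unsorted position and returns the new start of the unsorted range.
template <typename T>
std::size_t moveToUnsortedFront(std::vector<T> &operands,
                                std::size_t unsortedBegin, T smallest) {
  for (std::size_t i = unsortedBegin; i < operands.size(); ++i) {
    if (operands[i] == smallest) {
      std::swap(operands[i], operands[unsortedBegin]);
      ++unsortedBegin;
    }
  }
  return unsortedBegin;
}
```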


================
Comment at: mlir/lib/Transforms/Utils/CommutativityUtils.cpp:360
+/// the smallest position containing an unsorted operand).
+static void shiftTheSmallestUnsortedOperandsToTheSmallestUnsortedPositions(
+    SmallVectorImpl<std::unique_ptr<OperandBFS>> &bfsOfOperands,
----------------
This is possibly the longest function name I've ever seen. Please make it more concise.


================
Comment at: mlir/lib/Transforms/Utils/CommutativityUtils.cpp:400
+      if (!operandDefOp ||
+          !bfsOfOperand->visitedAncestors.contains(operandDefOp))
+        bfsOfOperand->pushAncestor(operandDefOp);
----------------
This check could be moved into `pushAncestor`.
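
A sketch of `pushAncestor` with the visited check folded in. `Operation` is a hypothetical empty stand-in for `mlir::Operation`, and `std::unordered_set` is used only to keep the snippet free of LLVM headers; it is not a recommendation over `SmallPtrSet`.

```cpp
#include <cassert>
#include <queue>
#include <unordered_set>

// Hypothetical stand-in for mlir::Operation.
struct Operation {};

struct OperandBFS {
  std::queue<Operation *> ancestorQueue;
  std::unordered_set<Operation *> visitedAncestors;

  // Block arguments (null) are always queued. An op is queued only the first
  // time it is seen; insert() reports whether the pointer was newly added.
  void pushAncestor(Operation *op) {
    if (op && !visitedAncestors.insert(op).second)
      return;
    ancestorQueue.push(op);
  }
};
```

Call sites can then call `pushAncestor(operandDefOp)` unconditionally.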


================
Comment at: mlir/lib/Transforms/Utils/CommutativityUtils.cpp:535-538
+
+    // If the operands were already sorted, return failure.
+    if (unsortedOperands == sortedOperands)
+      return failure();
----------------
You could just be returning a flag to indicate whether any swapping occurred so that you don't have to track before and after.
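
A sketch of reporting whether anything moved, using a plain selection sort over a generic vector rather than the patch's BFS-key machinery. `sortAndReportChange` is a hypothetical name. The loop also stops once the unsorted range is down to one element.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Sorts `operands` in place and returns true iff at least one swap happened,
// so the caller does not need to keep a "before" copy to compare against.
template <typename T>
bool sortAndReportChange(std::vector<T> &operands) {
  bool changed = false;
  for (std::size_t unsortedBegin = 0; unsortedBegin + 1 < operands.size();
       ++unsortedBegin) {
    auto first = operands.begin() + unsortedBegin;
    // min_element returns the first minimum, so an element that is already
    // in place is never swapped.
    auto smallest = std::min_element(first, operands.end());
    if (smallest != first) {
      std::iter_swap(smallest, first);
      changed = true;
    }
  }
  return changed;
}
```

A pattern built on this could return `failure()` when the function returns false.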


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D124750/new/

https://reviews.llvm.org/D124750


