[Mlir-commits] [mlir] 6127819 - Preserve use-list orders in mlir bytecode

Mehdi Amini llvmlistbot at llvm.org
Sun May 21 17:20:17 PDT 2023


Author: Matteo Franciolini
Date: 2023-05-21T16:48:12-07:00
New Revision: 612781918fb01a2a0985a1c4c9200f5d5d1581cc

URL: https://github.com/llvm/llvm-project/commit/612781918fb01a2a0985a1c4c9200f5d5d1581cc
DIFF: https://github.com/llvm/llvm-project/commit/612781918fb01a2a0985a1c4c9200f5d5d1581cc.diff

LOG: Preserve use-list orders in mlir bytecode

This patch implements a mechanism to read/write use-list orders from/to the mlir bytecode format. When producing bytecode, use-list orders are appended to each value of the IR. When reading bytecode, use-list orders are loaded into memory and used at the end of parsing to sort the existing use-list chains.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D149755

Added: 
    mlir/include/mlir/Bytecode/Encoding.h
    mlir/test/Bytecode/uselist_orders.mlir
    mlir/test/lib/IR/TestUseListOrders.cpp

Modified: 
    mlir/docs/BytecodeFormat.md
    mlir/include/mlir/IR/UseDefLists.h
    mlir/include/mlir/IR/Value.h
    mlir/lib/Bytecode/Reader/BytecodeReader.cpp
    mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
    mlir/lib/Bytecode/Writer/IRNumbering.cpp
    mlir/lib/Bytecode/Writer/IRNumbering.h
    mlir/lib/IR/Value.cpp
    mlir/test/Bytecode/invalid/invalid-structure.mlir
    mlir/test/lib/IR/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    mlir/lib/Bytecode/Encoding.h


################################################################################
diff  --git a/mlir/docs/BytecodeFormat.md b/mlir/docs/BytecodeFormat.md
index 9586c262399a4..ca04d8ccbe267 100644
--- a/mlir/docs/BytecodeFormat.md
+++ b/mlir/docs/BytecodeFormat.md
@@ -339,11 +339,20 @@ op {
   numSuccessors: varint?,
   successors: varint[],
 
+  numUseListOrders: varint?,
+  useListOrders: uselist[],
+
   regionEncoding: varint?, // (numRegions << 1) | (isIsolatedFromAbove)
 
   // regions are stored in a section if isIsolatedFromAbove
   regions: (region | region_section)[]
 }
+
+uselist {
+  indexInRange: varint?,
+  useListEncoding: varint, // (numIndices << 1) | (isIndexPairEncoding)
+  indices: varint[]
+}
 ```
 
 The encoding of an operation is important because this is generally the most
@@ -377,6 +386,26 @@ definition of that value from the start of the first ancestor isolated region.
 If the operation has successors, the number of successors and the indexes of the
 successor blocks within the parent region are encoded.
 
+##### Use-list orders
+
+The reference use-list order is assumed to be the reverse of the global
+enumeration of all the op operands that one would obtain with a pre-order walk
+of the IR. This order is naturally obtained by building blocks of operations
+op-by-op. However, some transformations may shuffle the use-lists with respect
+to this reference ordering. If any of the results of the operation have a
+use-list order that is not sorted with respect to the reference use-list order,
+an encoding is emitted so that this order can be reconstructed after
+parsing the bytecode. The encoding represents an index map from the reference
+operand order to the current use-list order. A bit flag indicates whether
+this encoding is of type index-pair or not. When the bit flag is set to `0`,
+the element at `i` represents the position of the use `i` of the reference list
+in the current use-list. When the bit flag is set to `1`, the encoding
+represents index pairs `(i, j)`, which indicate that the use at position `i` of
+the reference list is mapped to position `j` in the current use-list. When
+fewer than half of the elements in the current use-list are shuffled with
+respect to the reference use-list, the index-pair encoding is used to reduce
+the size of the bytecode.
+
 ##### Regions
 
 If the operation has regions, the number of regions and if the regions are
@@ -410,6 +439,8 @@ block {
 block_arguments {
   numArgs: varint?,
   args: block_argument[]
+  numUseListOrders: varint?,
+  useListOrders: uselist[],
 }
 
 block_argument {
@@ -421,3 +452,6 @@ block_argument {
 A block is encoded with an array of operations and block arguments. The first
 field is an encoding that combines the number of operations in the block, with a
 flag indicating if the block has arguments.
+
+Use-list orders are attached to block arguments similarly to how they are
+attached to operation results.
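
The two encodings described above can be illustrated with a small standalone sketch. The names `EncodedUseList`, `encodeUseListOrder` and `decodeUseListOrder` are hypothetical and do not exist in MLIR; the pair layout follows the `(i, j)` wording of the format description rather than being copied from the writer, and varint emission is left out.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical container (not an MLIR type) for one value's encoded order.
struct EncodedUseList {
  bool isIndexPairEncoding = false;
  std::vector<uint64_t> indices;

  // The `useListEncoding` varint: (numIndices << 1) | isIndexPairEncoding.
  uint64_t header() const {
    return (static_cast<uint64_t>(indices.size()) << 1) |
           (isIndexPairEncoding ? 1u : 0u);
  }
};

// `order[i]` is the position of reference use `i` in the current use-list.
EncodedUseList encodeUseListOrder(const std::vector<unsigned> &order) {
  std::size_t shuffled = 0;
  for (std::size_t i = 0; i < order.size(); ++i)
    if (order[i] != i)
      ++shuffled;

  EncodedUseList enc;
  // Pairs cost two indices per moved use, so use them only for sparse shuffles.
  enc.isIndexPairEncoding = shuffled < order.size() / 2;
  for (std::size_t i = 0; i < order.size(); ++i) {
    if (!enc.isIndexPairEncoding) {
      enc.indices.push_back(order[i]); // direct: element i is order[i]
    } else if (order[i] != i) {
      enc.indices.push_back(i);        // pair (i, j): reference position i
      enc.indices.push_back(order[i]); // maps to current position j
    }
  }
  return enc;
}

// Rebuild the full permutation; unlisted uses keep their reference position.
std::vector<unsigned> decodeUseListOrder(const EncodedUseList &enc,
                                         std::size_t numUses) {
  if (!enc.isIndexPairEncoding)
    return std::vector<unsigned>(enc.indices.begin(), enc.indices.end());
  assert(enc.indices.size() % 2 == 0 && "index pairs must come in twos");
  std::vector<unsigned> order(numUses);
  for (std::size_t i = 0; i < numUses; ++i)
    order[i] = static_cast<unsigned>(i);
  for (std::size_t k = 0; k < enc.indices.size(); k += 2)
    order[enc.indices[k]] = static_cast<unsigned>(enc.indices[k + 1]);
  return order;
}
```

For a single swap among eight uses, the pair form stores four indices instead of eight, which is the size saving the format aims for.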

diff  --git a/mlir/lib/Bytecode/Encoding.h b/mlir/include/mlir/Bytecode/Encoding.h
similarity index 76%
rename from mlir/lib/Bytecode/Encoding.h
rename to mlir/include/mlir/Bytecode/Encoding.h
index 20096ec4928e1..01ad495146010 100644
--- a/mlir/lib/Bytecode/Encoding.h
+++ b/mlir/include/mlir/Bytecode/Encoding.h
@@ -11,8 +11,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef LIB_MLIR_BYTECODE_ENCODING_H
-#define LIB_MLIR_BYTECODE_ENCODING_H
+#ifndef MLIR_BYTECODE_ENCODING_H
+#define MLIR_BYTECODE_ENCODING_H
 
 #include <cstdint>
 
@@ -27,7 +27,7 @@ enum {
   kMinSupportedVersion = 0,
 
   /// The current bytecode version.
-  kVersion = 2,
+  kVersion = 3,
 
   /// An arbitrary value used to fill alignment padding.
   kAlignmentByte = 0xCB,
@@ -87,10 +87,27 @@ enum : uint8_t {
   kHasOperands      = 0b00000100,
   kHasSuccessors    = 0b00001000,
   kHasInlineRegions = 0b00010000,
+  kHasUseListOrders = 0b00100000,
   // clang-format on
 };
 } // namespace OpEncodingMask
 
+/// Get the unique ID of a value use. We encode the unique ID by combining an
+/// owner number and the argument number such that if ownerID(op1) <
+/// ownerID(op2), then useID(op1) < useID(op2). If uses have the same owner,
+/// then argNumber(op1) < argNumber(op2) implies useID(op1) < useID(op2).
+template <typename OperandT>
+static inline uint64_t getUseID(OperandT &val, unsigned ownerID) {
+  uint32_t operandNumberID;
+  if constexpr (std::is_same_v<OpOperand, OperandT>)
+    operandNumberID = val.getOperandNumber();
+  else if constexpr (std::is_same_v<BlockArgument, OperandT>)
+    operandNumberID = val.getArgNumber();
+  else
+    llvm_unreachable("unexpected operand type");
+  return (static_cast<uint64_t>(ownerID) << 32) | operandNumberID;
+}
+
 } // namespace bytecode
 } // namespace mlir
 
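
The ordering property promised by the `getUseID` comment follows from placing the owner ID in the high 32 bits: comparing two packed IDs compares owners first and operand numbers only on a tie. Below is a minimal standalone sketch, assuming both numbers fit in 32 bits. `packUseID` and `matchesReferenceOrder` are hypothetical names, and the strictly-decreasing check mirrors the `prevID > currentID` test used by the reader and writer.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Same packing as bytecode::getUseID: owner ID in the high 32 bits,
// operand (or block-argument) number in the low 32 bits.
uint64_t packUseID(uint32_t ownerID, uint32_t operandNumber) {
  return (static_cast<uint64_t>(ownerID) << 32) | operandNumber;
}

// Hypothetical helper: a use-list, given as (ownerID, operandNumber) pairs in
// current order, matches the reference order when its IDs strictly decrease.
bool matchesReferenceOrder(
    const std::vector<std::pair<uint32_t, uint32_t>> &uses) {
  for (std::size_t i = 1; i < uses.size(); ++i) {
    uint64_t prev = packUseID(uses[i - 1].first, uses[i - 1].second);
    uint64_t cur = packUseID(uses[i].first, uses[i].second);
    if (prev <= cur)
      return false;
  }
  return true;
}
```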

diff  --git a/mlir/include/mlir/IR/UseDefLists.h b/mlir/include/mlir/IR/UseDefLists.h
index 197953698efc2..95b7489f3079a 100644
--- a/mlir/include/mlir/IR/UseDefLists.h
+++ b/mlir/include/mlir/IR/UseDefLists.h
@@ -44,6 +44,21 @@ class IROperandBase {
   /// of the SSA machinery.
   IROperandBase *getNextOperandUsingThisValue() { return nextUse; }
 
+  /// Initialize the use-def chain by setting the back address to self and
+  /// nextUse to nullptr.
+  void initChainWithUse(IROperandBase **self) {
+    assert(this == *self);
+    back = self;
+    nextUse = nullptr;
+  }
+
+  /// Link the current node to next.
+  void linkTo(IROperandBase *next) {
+    nextUse = next;
+    if (nextUse)
+      nextUse->back = &nextUse;
+  }
+
 protected:
   IROperandBase(Operation *owner) : owner(owner) {}
   IROperandBase(IROperandBase &&other) : owner(other.owner) {
@@ -192,6 +207,30 @@ class IRObjectWithUseList {
       use_begin()->set(newValue);
   }
 
+  /// Shuffle the use-list chain according to the provided indices vector, which
+  /// needs to represent a valid shuffle, that is, a vector of unique integers
+  /// in the range [0, numUses - 1]. Callers of this function must guarantee
+  /// the validity of the indices vector.
+  void shuffleUseList(ArrayRef<unsigned> indices) {
+    assert((size_t)std::distance(getUses().begin(), getUses().end()) ==
+               indices.size() &&
+           "indices vector expected to have a number of elements equal to the "
+           "number of uses");
+    SmallVector<detail::IROperandBase *> shuffled(indices.size());
+    detail::IROperandBase *ptr = firstUse;
+    for (size_t idx = 0; idx < indices.size();
+         idx++, ptr = ptr->getNextOperandUsingThisValue())
+      shuffled[indices[idx]] = ptr;
+
+    initFirstUse(shuffled.front());
+    auto *current = firstUse;
+    for (auto &next : llvm::drop_begin(shuffled)) {
+      current->linkTo(next);
+      current = next;
+    }
+    current->linkTo(nullptr);
+  }
+
   //===--------------------------------------------------------------------===//
   // Uses
   //===--------------------------------------------------------------------===//
@@ -234,6 +273,12 @@ class IRObjectWithUseList {
   OperandType *getFirstUse() const { return (OperandType *)firstUse; }
 
 private:
+  /// Set use as the first use of the chain.
+  void initFirstUse(detail::IROperandBase *use) {
+    firstUse = use;
+    firstUse->initChainWithUse(&firstUse);
+  }
+
   detail::IROperandBase *firstUse = nullptr;
 
   /// Allow access to `firstUse`.
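
`shuffleUseList` works in two passes: it first places every use in its destination slot (`shuffled[indices[idx]] = ptr`, so `indices[idx]` is the new position of the use currently at `idx`), then relinks the chain and repairs each `back` pointer the way `initFirstUse` and `linkTo` do. The sketch below reproduces that on a hypothetical `UseNode` struct standing in for `detail::IROperandBase`; it is an illustration, not MLIR code.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for detail::IROperandBase: an intrusive use node with
// a `next` pointer and a `back` pointer to whichever pointer refers to it.
struct UseNode {
  int id = 0;
  UseNode *next = nullptr;
  UseNode **back = nullptr;
};

// `indices[idx]` is the new position of the use currently at position `idx`.
void shuffleChain(UseNode *&head, const std::vector<unsigned> &indices) {
  // Pass 1: drop every node into its destination slot.
  std::vector<UseNode *> shuffled(indices.size());
  UseNode *ptr = head;
  for (std::size_t idx = 0; idx < indices.size(); ++idx, ptr = ptr->next) {
    assert(ptr && "indices.size() must equal the number of uses");
    shuffled[indices[idx]] = ptr;
  }
  if (shuffled.empty())
    return;

  // Pass 2: the head pointer becomes the back slot of the first node
  // (like initFirstUse), then each node links to the next (like linkTo).
  head = shuffled.front();
  head->back = &head;
  UseNode *current = head;
  for (std::size_t i = 1; i < shuffled.size(); ++i) {
    current->next = shuffled[i];
    shuffled[i]->back = &current->next;
    current = shuffled[i];
  }
  current->next = nullptr;
}
```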

diff  --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h
index a280fbdf64bc8..780e047d15b71 100644
--- a/mlir/include/mlir/IR/Value.h
+++ b/mlir/include/mlir/IR/Value.h
@@ -187,6 +187,11 @@ class Value {
   /// Returns true if the value is used outside of the given block.
   bool isUsedOutsideOfBlock(Block *block);
 
+  /// Shuffle the use list order according to the provided indices. It is the
+  /// responsibility of the caller to make sure that the indices map the current
+  /// use-list chain to another valid use-list chain.
+  void shuffleUseList(ArrayRef<unsigned> indices);
+
   //===--------------------------------------------------------------------===//
   // Uses
 

diff  --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
index 58145fa80db3c..92584d5e688cd 100644
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -7,12 +7,11 @@
 //===----------------------------------------------------------------------===//
 
 // TODO: Support for big-endian architectures.
-// TODO: Properly preserve use lists of values.
 
 #include "mlir/Bytecode/BytecodeReader.h"
-#include "../Encoding.h"
 #include "mlir/AsmParser/AsmParser.h"
 #include "mlir/Bytecode/BytecodeImplementation.h"
+#include "mlir/Bytecode/Encoding.h"
 #include "mlir/IR/BuiltinDialect.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/OpImplementation.h"
@@ -29,6 +28,7 @@
 #include "llvm/Support/SourceMgr.h"
 #include <list>
 #include <memory>
+#include <numeric>
 #include <optional>
 
 #define DEBUG_TYPE "mlir-bytecode-reader"
@@ -1281,6 +1281,42 @@ class mlir::BytecodeReader::Impl {
   /// Create a value to use for a forward reference.
   Value createForwardRef();
 
+  //===--------------------------------------------------------------------===//
+  // Use-list order helpers
+
+  /// This struct is a simple storage that contains information required to
+  /// reorder the use-list of a value with respect to the pre-order traversal
+  /// ordering.
+  struct UseListOrderStorage {
+    UseListOrderStorage(bool isIndexPairEncoding,
+                        SmallVector<unsigned, 4> &&indices)
+        : indices(std::move(indices)),
+          isIndexPairEncoding(isIndexPairEncoding){};
+    /// The vector containing the information required to reorder the
+    /// use-list of a value.
+    SmallVector<unsigned, 4> indices;
+
+    /// Whether the indices represent pairs of type `(src, dst)` or a direct
+    /// indexing, such as `dst = order[src]`.
+    bool isIndexPairEncoding;
+  };
+
+  /// Parse use-list order from bytecode for a range of values if available. The
+  /// range is expected to be either a block argument or an op result range. On
+  /// success, return a map from the position in the range to the use-list order
+  /// encoding. The caller must provide the size of the range being
+  /// processed.
+  using UseListMapT = DenseMap<unsigned, UseListOrderStorage>;
+  FailureOr<UseListMapT> parseUseListOrderForRange(EncodingReader &reader,
+                                                   uint64_t rangeSize);
+
+  /// Shuffle the use-chain according to the order parsed.
+  LogicalResult sortUseListOrder(Value value);
+
+  /// Recursively visit all the values defined within topLevelOp and sort the
+  /// use-list orders according to the indices parsed.
+  LogicalResult processUseLists(Operation *topLevelOp);
+
   //===--------------------------------------------------------------------===//
   // Fields
 
@@ -1341,17 +1377,27 @@ class mlir::BytecodeReader::Impl {
   /// The reader used to process resources within the bytecode.
   ResourceSectionReader resourceReader;
 
+  /// Worklist of values with custom use-list orders to process before the end
+  /// of parsing.
+  DenseMap<void *, UseListOrderStorage> valueToUseListMap;
+
   /// The table of strings referenced within the bytecode file.
   StringSectionReader stringReader;
 
   /// The current set of available IR value scopes.
   std::vector<ValueScope> valueScopes;
+
+  /// The global pre-order operation ordering.
+  DenseMap<Operation *, unsigned> operationIDs;
+
   /// A block containing the set of operations defined to create forward
   /// references.
   Block forwardRefOps;
+
   /// A block containing previously created, and no longer used, forward
   /// reference operations.
   Block openForwardRefOps;
+
   /// An operation state used when instantiating forward references.
   OperationState forwardRefOpState;
 
@@ -1597,6 +1643,165 @@ LogicalResult BytecodeReader::Impl::parseResourceSection(
                                    dialectReader, bufferOwnerRef);
 }
 
+//===----------------------------------------------------------------------===//
+// UseListOrder Helpers
+
+FailureOr<BytecodeReader::Impl::UseListMapT>
+BytecodeReader::Impl::parseUseListOrderForRange(EncodingReader &reader,
+                                                uint64_t numResults) {
+  BytecodeReader::Impl::UseListMapT map;
+  uint64_t numValuesToRead = 1;
+  if (numResults > 1 && failed(reader.parseVarInt(numValuesToRead)))
+    return failure();
+
+  for (size_t valueIdx = 0; valueIdx < numValuesToRead; valueIdx++) {
+    uint64_t resultIdx = 0;
+    if (numResults > 1 && failed(reader.parseVarInt(resultIdx)))
+      return failure();
+
+    uint64_t numValues;
+    bool indexPairEncoding;
+    if (failed(reader.parseVarIntWithFlag(numValues, indexPairEncoding)))
+      return failure();
+
+    SmallVector<unsigned, 4> useListOrders;
+    for (size_t idx = 0; idx < numValues; idx++) {
+      uint64_t index;
+      if (failed(reader.parseVarInt(index)))
+        return failure();
+      useListOrders.push_back(index);
+    }
+
+    // Store the use-list order in the map, keyed by the result index.
+    map.try_emplace(resultIdx, UseListOrderStorage(indexPairEncoding,
+                                                   std::move(useListOrders)));
+  }
+
+  return map;
+}
+
+/// Sorts each use according to the order specified in the parsed use-list. If
+/// the custom use-list is not found, this means that the order needs to be
+/// consistent with the reverse pre-order walk of the IR. If multiple uses lie
+/// on the same operation, the order will follow the reverse operand number
+/// ordering.
+LogicalResult BytecodeReader::Impl::sortUseListOrder(Value value) {
+  // Early return for trivial use-lists.
+  if (value.use_empty() || value.hasOneUse())
+    return success();
+
+  bool hasIncomingOrder =
+      valueToUseListMap.contains(value.getAsOpaquePointer());
+
+  // Compute the current order of the use-list with respect to the global
+  // ordering. Detect if the order is already sorted while doing so.
+  bool alreadySorted = true;
+  auto &firstUse = *value.use_begin();
+  uint64_t prevID =
+      bytecode::getUseID(firstUse, operationIDs.at(firstUse.getOwner()));
+  llvm::SmallVector<std::pair<unsigned, uint64_t>> currentOrder = {{0, prevID}};
+  for (auto item : llvm::drop_begin(llvm::enumerate(value.getUses()))) {
+    uint64_t currentID = bytecode::getUseID(
+        item.value(), operationIDs.at(item.value().getOwner()));
+    alreadySorted &= prevID > currentID;
+    currentOrder.push_back({item.index(), currentID});
+    prevID = currentID;
+  }
+
+  // If the order is already sorted, and there wasn't a custom order to apply
+  // from the bytecode file, we are done.
+  if (alreadySorted && !hasIncomingOrder)
+    return success();
+
+  // If not already sorted, sort the indices of the current order by descending
+  // useIDs.
+  if (!alreadySorted)
+    std::sort(
+        currentOrder.begin(), currentOrder.end(),
+        [](auto elem1, auto elem2) { return elem1.second > elem2.second; });
+
+  if (!hasIncomingOrder) {
+    // If the bytecode file did not contain any custom use-list order, it means
+    // that the order was descending useID. Hence, shuffle by the first index
+    // of the `currentOrder` pair.
+    SmallVector<unsigned> shuffle = SmallVector<unsigned>(
+        llvm::map_range(currentOrder, [&](auto item) { return item.first; }));
+    value.shuffleUseList(shuffle);
+    return success();
+  }
+
+  // Pull the custom order info from the map.
+  UseListOrderStorage customOrder =
+      valueToUseListMap.at(value.getAsOpaquePointer());
+  SmallVector<unsigned, 4> shuffle = std::move(customOrder.indices);
+  uint64_t numUses =
+      std::distance(value.getUses().begin(), value.getUses().end());
+
+  // If the encoding was a pair of indices `(src, dst)` for every permutation,
+  // reconstruct the shuffle vector for every use. Initialize the shuffle vector
+  // as identity, and then apply the mapping encoded in the indices.
+  if (customOrder.isIndexPairEncoding) {
+    // Return failure if the number of indices was not representing pairs.
+    if (shuffle.size() & 1)
+      return failure();
+
+    SmallVector<unsigned, 4> newShuffle(numUses);
+    size_t idx = 0;
+    std::iota(newShuffle.begin(), newShuffle.end(), idx);
+    for (idx = 0; idx < shuffle.size(); idx += 2)
+      newShuffle[shuffle[idx]] = shuffle[idx + 1];
+
+    shuffle = std::move(newShuffle);
+  }
+
+  // Make sure that the indices represent a valid mapping. That is, the sum of
+  // all the values needs to be equal to (numUses - 1) * numUses / 2, and no
+  // duplicates are allowed in the list.
+  DenseSet<unsigned> set;
+  uint64_t accumulator = 0;
+  for (const auto &elem : shuffle) {
+    if (set.contains(elem))
+      return failure();
+    accumulator += elem;
+    set.insert(elem);
+  }
+  if (numUses != shuffle.size() ||
+      accumulator != (((numUses - 1) * numUses) >> 1))
+    return failure();
+
+  // Apply the current ordering map onto the shuffle vector to get the final
+  // use-list sorting indices before shuffling.
+  shuffle = SmallVector<unsigned, 4>(llvm::map_range(
+      currentOrder, [&](auto item) { return shuffle[item.first]; }));
+  value.shuffleUseList(shuffle);
+  return success();
+}
+
+LogicalResult BytecodeReader::Impl::processUseLists(Operation *topLevelOp) {
+  // Precompute operation IDs according to the pre-order walk of the IR. We
+  // can't do this while parsing since parseRegions ordering is not strictly
+  // equal to the pre-order walk.
+  unsigned operationID = 0;
+  topLevelOp->walk<mlir::WalkOrder::PreOrder>(
+      [&](Operation *op) { operationIDs.try_emplace(op, operationID++); });
+
+  auto blockWalk = topLevelOp->walk([this](Block *block) {
+    for (auto arg : block->getArguments())
+      if (failed(sortUseListOrder(arg)))
+        return WalkResult::interrupt();
+    return WalkResult::advance();
+  });
+
+  auto resultWalk = topLevelOp->walk([this](Operation *op) {
+    for (auto result : op->getResults())
+      if (failed(sortUseListOrder(result)))
+        return WalkResult::interrupt();
+    return WalkResult::advance();
+  });
+
+  return failure(blockWalk.wasInterrupted() || resultWalk.wasInterrupted());
+}
+
 //===----------------------------------------------------------------------===//
 // IR Section
 
@@ -1627,6 +1832,11 @@ BytecodeReader::Impl::parseIRSection(ArrayRef<uint8_t> sectionData,
         "not all forward unresolved forward operand references");
   }
 
+  // Sort use-lists according to what is specified in the bytecode.
+  if (failed(processUseLists(*moduleOp)))
+    return reader.emitError(
+        "parsed use-list orders were invalid and could not be applied");
+
   // Resolve dialect version.
   for (const BytecodeDialect &byteCodeDialect : dialects) {
     // Parsing is complete, give an opportunity to each dialect to visit the
@@ -1812,6 +2022,17 @@ BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader,
     }
   }
 
+  /// Parse the use-list orders for the results of the operation. Use-list
+  /// orders are available since version 3 of the bytecode.
+  std::optional<UseListMapT> resultIdxToUseListMap = std::nullopt;
+  if (version > 2 && (opMask & bytecode::OpEncodingMask::kHasUseListOrders)) {
+    size_t numResults = opState.types.size();
+    auto parseResult = parseUseListOrderForRange(reader, numResults);
+    if (failed(parseResult))
+      return failure();
+    resultIdxToUseListMap = std::move(*parseResult);
+  }
+
   /// Parse the regions of the operation.
   if (opMask & bytecode::OpEncodingMask::kHasInlineRegions) {
     uint64_t numRegions;
@@ -1831,6 +2052,16 @@ BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader,
   if (op->getNumResults() && failed(defineValues(reader, op->getResults())))
     return failure();
 
+  /// Store a map for every value that received a custom use-list order from the
+  /// bytecode file.
+  if (resultIdxToUseListMap.has_value()) {
+    for (size_t idx = 0; idx < op->getNumResults(); idx++) {
+      if (resultIdxToUseListMap->contains(idx)) {
+        valueToUseListMap.try_emplace(op->getResult(idx).getAsOpaquePointer(),
+                                      resultIdxToUseListMap->at(idx));
+      }
+    }
+  }
   return op;
 }
 
@@ -1880,6 +2111,28 @@ BytecodeReader::Impl::parseBlockHeader(EncodingReader &reader,
   if (hasArgs && failed(parseBlockArguments(reader, &*readState.curBlock)))
     return failure();
 
+  // Use-list orders are available since version 3 of the bytecode.
+  if (version < 3)
+    return success();
+
+  uint8_t hasUseListOrders = 0;
+  if (hasArgs && failed(reader.parseByte(hasUseListOrders)))
+    return failure();
+
+  if (!hasUseListOrders)
+    return success();
+
+  Block &blk = *readState.curBlock;
+  auto argIdxToUseListMap =
+      parseUseListOrderForRange(reader, blk.getNumArguments());
+  if (failed(argIdxToUseListMap) || argIdxToUseListMap->empty())
+    return failure();
+
+  for (size_t idx = 0; idx < blk.getNumArguments(); idx++)
+    if (argIdxToUseListMap->contains(idx))
+      valueToUseListMap.try_emplace(blk.getArgument(idx).getAsOpaquePointer(),
+                                    argIdxToUseListMap->at(idx));
+
   // We don't parse the operations of the block here, that's done elsewhere.
   return success();
 }
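
The reader expands index pairs over an identity permutation and then validates the result by rejecting duplicates and checking that the indices sum to `numUses * (numUses - 1) / 2`. That check is sufficient because `n` distinct non-negative integers sum to at least `0 + 1 + ... + (n - 1)`, with equality only for exactly `{0, ..., n - 1}`. Below is a hypothetical standalone sketch of the same idea. It uses a seen-bitmap instead of the sum and bounds-checks each pair before writing it; `expandIndexPairs` and `isValidPermutation` are not MLIR functions.

```cpp
#include <cassert>
#include <cstddef>
#include <numeric>
#include <optional>
#include <vector>

// Hypothetical helper: expand a flat (src, dst) list into a full permutation,
// identity everywhere else, mirroring `newShuffle[shuffle[idx]] = ...`.
std::optional<std::vector<unsigned>>
expandIndexPairs(const std::vector<unsigned> &pairs, std::size_t numUses) {
  if (pairs.size() % 2 != 0)
    return std::nullopt; // indices must come in (src, dst) pairs
  std::vector<unsigned> order(numUses);
  std::iota(order.begin(), order.end(), 0u);
  for (std::size_t k = 0; k < pairs.size(); k += 2) {
    if (pairs[k] >= numUses) // bounds-check before writing
      return std::nullopt;
    order[pairs[k]] = pairs[k + 1];
  }
  return order;
}

// Equivalent to "no duplicates and sum == numUses * (numUses - 1) / 2":
// every value must be in range and appear exactly once.
bool isValidPermutation(const std::vector<unsigned> &order,
                        std::size_t numUses) {
  if (order.size() != numUses)
    return false;
  std::vector<bool> seen(numUses, false);
  for (unsigned v : order) {
    if (v >= numUses || seen[v])
      return false;
    seen[v] = true;
  }
  return true;
}
```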

diff  --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
index 158dbe6d161db..c67437f317396 100644
--- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
+++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -7,9 +7,9 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Bytecode/BytecodeWriter.h"
-#include "../Encoding.h"
 #include "IRNumbering.h"
 #include "mlir/Bytecode/BytecodeImplementation.h"
+#include "mlir/Bytecode/Encoding.h"
 #include "mlir/IR/OpImplementation.h"
 #include "llvm/ADT/CachedHashString.h"
 #include "llvm/ADT/MapVector.h"
@@ -470,6 +470,12 @@ class BytecodeWriter {
 
   void writeStringSection(EncodingEmitter &emitter);
 
+  //===--------------------------------------------------------------------===//
+  // Helpers
+
+  void writeUseListOrders(EncodingEmitter &emitter, uint8_t &opEncodingMask,
+                          ValueRange range);
+
   //===--------------------------------------------------------------------===//
   // Fields
 
@@ -667,6 +673,14 @@ void BytecodeWriter::writeBlock(EncodingEmitter &emitter, Block *block) {
       emitter.emitVarInt(numberingState.getNumber(arg.getType()));
       emitter.emitVarInt(numberingState.getNumber(arg.getLoc()));
     }
+    if (config.bytecodeVersion > 2) {
+      uint64_t maskOffset = emitter.size();
+      uint8_t encodingMask = 0;
+      emitter.emitByte(0);
+      writeUseListOrders(emitter, encodingMask, args);
+      if (encodingMask)
+        emitter.patchByte(maskOffset, encodingMask);
+    }
   }
 
   // Emit the operations within the block.
@@ -718,6 +732,11 @@ void BytecodeWriter::writeOp(EncodingEmitter &emitter, Operation *op) {
       emitter.emitVarInt(numberingState.getNumber(successor));
   }
 
+  // Emit the use-list orders to bytecode, so we can reconstruct the same order
+  // at parsing.
+  if (config.bytecodeVersion > 2)
+    writeUseListOrders(emitter, opEncodingMask, ValueRange(op->getResults()));
+
   // Check for regions.
   unsigned numRegions = op->getNumRegions();
   if (numRegions)
@@ -749,6 +768,94 @@ void BytecodeWriter::writeOp(EncodingEmitter &emitter, Operation *op) {
   }
 }
 
+void BytecodeWriter::writeUseListOrders(EncodingEmitter &emitter,
+                                        uint8_t &opEncodingMask,
+                                        ValueRange range) {
+  // Loop over the values in the range and store the use-list order per index.
+  DenseMap<unsigned, llvm::SmallVector<unsigned>> map;
+  for (auto item : llvm::enumerate(range)) {
+    auto value = item.value();
+    // No need to store a custom use-list order if the result does not have
+    // multiple uses.
+    if (value.use_empty() || value.hasOneUse())
+      continue;
+
+    // For each result, assemble the list of pairs (use-list-index,
+    // global-value-index). While doing so, detect if the global-value-index is
+    // already ordered with respect to the use-list-index.
+    bool alreadyOrdered = true;
+    auto &firstUse = *value.use_begin();
+    uint64_t prevID = bytecode::getUseID(
+        firstUse, numberingState.getNumber(firstUse.getOwner()));
+    llvm::SmallVector<std::pair<unsigned, uint64_t>> useListPairs(
+        {{0, prevID}});
+
+    for (auto use : llvm::drop_begin(llvm::enumerate(value.getUses()))) {
+      uint64_t currentID = bytecode::getUseID(
+          use.value(), numberingState.getNumber(use.value().getOwner()));
+      // The use-list order achieved when building the IR at parsing always
+      // pushes new uses to the front. Hence, if the order by unique ID is
+      // monotonically decreasing, a roundtrip to bytecode preserves such order.
+      alreadyOrdered &= (prevID > currentID);
+      useListPairs.push_back({use.index(), currentID});
+      prevID = currentID;
+    }
+
+    // Do not emit if the order is already sorted.
+    if (alreadyOrdered)
+      continue;
+
+    // Sort the use indices by their unique IDs in descending order.
+    std::sort(
+        useListPairs.begin(), useListPairs.end(),
+        [](auto elem1, auto elem2) { return elem1.second > elem2.second; });
+
+    map.try_emplace(item.index(), llvm::map_range(useListPairs, [](auto elem) {
+                      return elem.first;
+                    }));
+  }
+
+  if (map.empty())
+    return;
+
+  opEncodingMask |= bytecode::OpEncodingMask::kHasUseListOrders;
+  // Emit the number of results that have a custom use-list order if the number
+  // of results is greater than one.
+  if (range.size() != 1)
+    emitter.emitVarInt(map.size());
+
+  for (const auto &item : map) {
+    auto resultIdx = item.getFirst();
+    auto useListOrder = item.getSecond();
+
+    // Compute the number of uses that are actually shuffled. If there are fewer
+    // than half of the total uses, encoding the index pairs `(src, dst)` is more
+    // space efficient.
+    size_t shuffledElements =
+        llvm::count_if(llvm::enumerate(useListOrder),
+                       [](auto item) { return item.index() != item.value(); });
+    bool indexPairEncoding = shuffledElements < (useListOrder.size() / 2);
+
+    // For a single result, we don't need to store the result index.
+    if (range.size() != 1)
+      emitter.emitVarInt(resultIdx);
+
+    if (indexPairEncoding) {
+      emitter.emitVarIntWithFlag(shuffledElements * 2, indexPairEncoding);
+      for (auto pair : llvm::enumerate(useListOrder)) {
+        if (pair.index() != pair.value()) {
+          emitter.emitVarInt(pair.value());
+          emitter.emitVarInt(pair.index());
+        }
+      }
+    } else {
+      emitter.emitVarIntWithFlag(useListOrder.size(), indexPairEncoding);
+      for (const auto &index : useListOrder)
+        emitter.emitVarInt(index);
+    }
+  }
+}
+
 void BytecodeWriter::writeRegion(EncodingEmitter &emitter, Region *region) {
   // If the region is empty, we only need to emit the number of blocks (which is
   // zero).
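
The first half of `writeUseListOrders` turns the current use-list into the permutation that gets encoded. It pairs each use-list index with its use ID, skips values whose IDs are already strictly decreasing (the order a bytecode round-trip rebuilds by pushing uses to the front), and otherwise sorts by descending ID. Below is a hypothetical standalone sketch that takes the use IDs directly; `computeUseListOrder` is not an MLIR function.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <utility>
#include <vector>

// `useIDs` holds the use IDs of one value in current use-list order. Returns
// std::nullopt when no custom order needs to be emitted; otherwise returns
// `order`, where order[k] is the current use-list index of the k-th use in
// reference (descending use ID) order.
std::optional<std::vector<unsigned>>
computeUseListOrder(const std::vector<uint64_t> &useIDs) {
  if (useIDs.size() < 2)
    return std::nullopt; // zero or one use: nothing to reorder

  bool alreadyOrdered = true;
  std::vector<std::pair<unsigned, uint64_t>> pairs;
  for (std::size_t i = 0; i < useIDs.size(); ++i) {
    if (i > 0)
      alreadyOrdered &= useIDs[i - 1] > useIDs[i];
    pairs.push_back({static_cast<unsigned>(i), useIDs[i]});
  }
  if (alreadyOrdered)
    return std::nullopt; // a round-trip already reproduces this order

  // Use IDs are unique, so sorting by descending ID is deterministic.
  std::sort(pairs.begin(), pairs.end(),
            [](const auto &a, const auto &b) { return a.second > b.second; });
  std::vector<unsigned> order;
  for (const auto &p : pairs)
    order.push_back(p.first);
  return order;
}
```

For use IDs `{10, 30, 20}` the result is `{1, 2, 0}`: the first use in reference order (ID 30) currently sits at index 1.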

diff  --git a/mlir/lib/Bytecode/Writer/IRNumbering.cpp b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
index f3a153178e8d2..129437cf0245f 100644
--- a/mlir/lib/Bytecode/Writer/IRNumbering.cpp
+++ b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
@@ -7,9 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "IRNumbering.h"
-#include "../Encoding.h"
 #include "mlir/Bytecode/BytecodeImplementation.h"
-#include "mlir/Bytecode/BytecodeWriter.h"
 #include "mlir/IR/AsmState.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
@@ -109,6 +107,12 @@ static void groupByDialectPerByte(T range) {
 }
 
 IRNumberingState::IRNumberingState(Operation *op) {
+  // Compute a global operation ID numbering according to the pre-order walk of
+  // the IR. This is used as reference to construct use-list orders.
+  unsigned operationID = 0;
+  op->walk<WalkOrder::PreOrder>(
+      [&](Operation *op) { operationIDs.try_emplace(op, operationID++); });
+
   // Number the root operation.
   number(*op);
 

diff  --git a/mlir/lib/Bytecode/Writer/IRNumbering.h b/mlir/lib/Bytecode/Writer/IRNumbering.h
index aeb624e58ba0c..91f0be05b36d3 100644
--- a/mlir/lib/Bytecode/Writer/IRNumbering.h
+++ b/mlir/lib/Bytecode/Writer/IRNumbering.h
@@ -152,6 +152,10 @@ class IRNumberingState {
     assert(blockIDs.count(block) && "block not numbered");
     return blockIDs[block];
   }
+  unsigned getNumber(Operation *op) {
+    assert(operationIDs.count(op) && "operation not numbered");
+    return operationIDs[op];
+  }
   unsigned getNumber(OperationName opName) {
     assert(opNames.count(opName) && "opName not numbered");
     return opNames[opName]->number;
@@ -224,7 +228,8 @@ class IRNumberingState {
   llvm::SpecificBumpPtrAllocator<DialectResourceNumbering> resourceAllocator;
   llvm::SpecificBumpPtrAllocator<TypeNumbering> typeAllocator;
 
-  /// The value ID for each Block and Value.
+  /// The value ID for each Operation, Block and Value.
+  DenseMap<Operation *, unsigned> operationIDs;
   DenseMap<Block *, unsigned> blockIDs;
   DenseMap<Value, unsigned> valueIDs;
 

diff  --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp
index 75976628e82a3..86b9cde76c05d 100644
--- a/mlir/lib/IR/Value.cpp
+++ b/mlir/lib/IR/Value.cpp
@@ -93,6 +93,11 @@ bool Value::isUsedOutsideOfBlock(Block *block) {
   });
 }
 
+/// Shuffles the use-list order according to the provided indices.
+void Value::shuffleUseList(ArrayRef<unsigned> indices) {
+  getImpl()->shuffleUseList(indices);
+}
+
 //===----------------------------------------------------------------------===//
 // OpResult
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Bytecode/invalid/invalid-structure.mlir b/mlir/test/Bytecode/invalid/invalid-structure.mlir
index 4668878b10560..ae18cfaff687c 100644
--- a/mlir/test/Bytecode/invalid/invalid-structure.mlir
+++ b/mlir/test/Bytecode/invalid/invalid-structure.mlir
@@ -9,7 +9,7 @@
 //===--------------------------------------------------------------------===//
 
 // RUN: not mlir-opt %S/invalid-structure-version.mlirbc 2>&1 | FileCheck %s --check-prefix=VERSION
-// VERSION: bytecode version 127 is newer than the current version 2
+// VERSION: bytecode version 127 is newer than the current version 3
 
 //===--------------------------------------------------------------------===//
 // Producer

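The expected message changes because this patch raises the current bytecode version from 2 to 3, and the test pass treats version 3 as the first one that preserves use-list orders. The reader and writer code for these checks is not in this excerpt, so the following is a hedged sketch with hypothetical constant and function names:

```cpp
#include <cstdint>
#include <string>

// Hypothetical constants: the newest version this build understands, and the
// first version whose encoding includes use-list orders.
constexpr uint64_t kCurrentVersion = 3;
constexpr uint64_t kUseListOrderingVersion = 3;

// Reader side (sketch): reject files produced by a newer writer. Returns an
// empty string if the version is readable, otherwise a diagnostic shaped like
// the one matched by the VERSION check prefix.
std::string checkReadableVersion(uint64_t version) {
  if (version > kCurrentVersion)
    return "bytecode version " + std::to_string(version) +
           " is newer than the current version " +
           std::to_string(kCurrentVersion);
  return "";
}

// Writer side (sketch): only emit use-list orders when the requested version
// can encode them.
bool shouldEmitUseListOrders(uint64_t desiredVersion) {
  return desiredVersion >= kUseListOrderingVersion;
}
```
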
diff  --git a/mlir/test/Bytecode/uselist_orders.mlir b/mlir/test/Bytecode/uselist_orders.mlir
new file mode 100644
index 0000000000000..b8f4c3df6542c
--- /dev/null
+++ b/mlir/test/Bytecode/uselist_orders.mlir
@@ -0,0 +1,63 @@
+// RUN: mlir-opt %s -split-input-file --test-verify-uselistorder -verify-diagnostics
+
+// COM: --test-verify-uselistorder will randomly shuffle the uselist of every
+//      value and do a roundtrip to bytecode. An error is emitted if the
+//      uselist orders are not preserved by the roundtrip. The test needs
+//      -verify-diagnostics to be functional.
+
+func.func @base_test(%arg0 : i32) -> i32 {
+  %0 = arith.constant 45 : i32
+  %1 = arith.constant 46 : i32
+  %2 = "test.addi"(%arg0, %arg0) : (i32, i32) -> i32
+  %3 = "test.addi"(%2, %0) : (i32, i32) -> i32
+  %4 = "test.addi"(%2, %1) : (i32, i32) -> i32
+  %5 = "test.addi"(%3, %4) : (i32, i32) -> i32
+  %6 = "test.addi"(%5, %4) : (i32, i32) -> i32
+  %7 = "test.addi"(%6, %4) : (i32, i32) -> i32
+  return %7 : i32
+}
+
+// -----
+
+func.func @test_with_multiple_uses_in_same_op(%arg0 : i32) -> i32 {
+  %0 = arith.constant 45 : i32
+  %1 = arith.constant 46 : i32
+  %2 = "test.addi"(%arg0, %arg0) : (i32, i32) -> i32
+  %3 = "test.addi"(%2, %0) : (i32, i32) -> i32
+  %4 = "test.addi"(%2, %1) : (i32, i32) -> i32
+  %5 = "test.addi"(%2, %2) : (i32, i32) -> i32
+  %6 = "test.addi"(%3, %4) : (i32, i32) -> i32
+  %7 = "test.addi"(%6, %5) : (i32, i32) -> i32
+  %8 = "test.addi"(%7, %4) : (i32, i32) -> i32
+  %9 = "test.addi"(%8, %4) : (i32, i32) -> i32
+  return %9 : i32
+}
+
+// -----
+
+func.func @test_with_multiple_block_arg_uses(%arg0 : i32) -> i32 {
+  %0 = arith.constant 45 : i32
+  %1 = arith.constant 46 : i32
+  %2 = "test.addi"(%arg0, %arg0) : (i32, i32) -> i32
+  %3 = "test.addi"(%2, %arg0) : (i32, i32) -> i32
+  %4 = "test.addi"(%2, %1) : (i32, i32) -> i32
+  %5 = "test.addi"(%2, %2) : (i32, i32) -> i32
+  %6 = "test.addi"(%3, %4) : (i32, i32) -> i32
+  %7 = "test.addi"(%6, %5) : (i32, i32) -> i32
+  %8 = "test.addi"(%7, %4) : (i32, i32) -> i32
+  %9 = "test.addi"(%8, %4) : (i32, i32) -> i32
+  return %9 : i32
+}
+
+// -----
+
+// Test that use-lists in regions without dominance are preserved.
+test.graph_region {
+  %0 = "test.foo"(%1) : (i32) -> i32
+  test.graph_region attributes {a} {
+    %a = "test.a"(%b) : (i32) -> i32
+    %b = "test.b"(%2) : (i32) -> i32
+  }
+  %1 = "test.bar"(%2) : (i32) -> i32
+  %2 = "test.baz"() : () -> i32
+}

diff  --git a/mlir/test/lib/IR/CMakeLists.txt b/mlir/test/lib/IR/CMakeLists.txt
index b7411a34c11e0..004e728b4b877 100644
--- a/mlir/test/lib/IR/CMakeLists.txt
+++ b/mlir/test/lib/IR/CMakeLists.txt
@@ -18,6 +18,7 @@ add_mlir_library(MLIRTestIR
   TestSymbolUses.cpp
   TestRegions.cpp
   TestTypes.cpp
+  TestUseListOrders.cpp
   TestVisitors.cpp
   TestVisitorsGeneric.cpp
 

diff  --git a/mlir/test/lib/IR/TestUseListOrders.cpp b/mlir/test/lib/IR/TestUseListOrders.cpp
new file mode 100644
index 0000000000000..ff195ec95f72a
--- /dev/null
+++ b/mlir/test/lib/IR/TestUseListOrders.cpp
@@ -0,0 +1,219 @@
+//===- TestUseListOrders.cpp - Test use-list order preservation -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Bytecode/BytecodeWriter.h"
+#include "mlir/Bytecode/Encoding.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Parser/Parser.h"
+#include "mlir/Pass/Pass.h"
+
+#include <numeric>
+#include <random>
+
+using namespace mlir;
+
+namespace {
+/// This pass tests that:
+/// 1) we can shuffle use-lists correctly;
+/// 2) use-list orders are preserved after a roundtrip to bytecode.
+class TestPreserveUseListOrders
+    : public PassWrapper<TestPreserveUseListOrders, OperationPass<ModuleOp>> {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPreserveUseListOrders)
+
+  TestPreserveUseListOrders() = default;
+  TestPreserveUseListOrders(const TestPreserveUseListOrders &pass)
+      : PassWrapper(pass) {}
+  StringRef getArgument() const final { return "test-verify-uselistorder"; }
+  StringRef getDescription() const final {
+    return "Verify that roundtripping the IR to bytecode preserves the order "
+           "of the uselists";
+  }
+  Option<unsigned> rngSeed{*this, "rng-seed",
+                           llvm::cl::desc("Specify an input random seed"),
+                           llvm::cl::init(1)};
+  void runOnOperation() override {
+    // Clone the module so that the original IR is left untouched and this
+    // pass can be plugged into any pipeline independently.
+    auto cloneModule = getOperation().clone();
+
+    // 1. Compute the op numbering of the module.
+    computeOpNumbering(cloneModule);
+
+    // 2. Loop over all the values and shuffle the uses. While doing so, check
+    // that each shuffle is correct.
+    if (failed(shuffleUses(cloneModule)))
+      return signalPassFailure();
+
+    // 3. Do a bytecode roundtrip to version 3, which supports use-list order
+    // preservation.
+    auto roundtripModuleOr = doRoundtripToBytecode(cloneModule, 3);
+    // If the bytecode roundtrip failed, try to roundtrip the original module
+    // to version 2, which does not support use-list orders. If this also
+    // fails, the original module had an issue unrelated to use-lists.
+    if (failed(roundtripModuleOr)) {
+      auto testModuleOr = doRoundtripToBytecode(getOperation(), 2);
+      if (failed(testModuleOr))
+        return;
+
+      return signalPassFailure();
+    }
+
+    // 4. Recompute the op numbering on the new module. The numbering should be
+    // the same as (1), but on the new operation pointers.
+    computeOpNumbering(roundtripModuleOr->get());
+
+    // 5. Loop over all the values and verify that the use-list is consistent
+    // with the post-shuffle order of step (2).
+    if (failed(verifyUseListOrders(roundtripModuleOr->get())))
+      return signalPassFailure();
+  }
+
+private:
+  FailureOr<OwningOpRef<Operation *>> doRoundtripToBytecode(Operation *module,
+                                                            uint32_t version) {
+    std::string str;
+    llvm::raw_string_ostream m(str);
+    BytecodeWriterConfig config;
+    config.setDesiredBytecodeVersion(version);
+    if (failed(writeBytecodeToFile(module, m, config)))
+      return failure();
+
+    ParserConfig parseConfig(&getContext(), /*verifyAfterParse=*/true);
+    auto newModuleOp = parseSourceString(StringRef(str), parseConfig);
+    if (!newModuleOp.get())
+      return failure();
+    return newModuleOp;
+  }
+
+  /// Compute an ordered numbering for all the operations in the IR.
+  void computeOpNumbering(Operation *topLevelOp) {
+    uint32_t operationID = 0;
+    opNumbering.clear();
+    topLevelOp->walk<mlir::WalkOrder::PreOrder>(
+        [&](Operation *op) { opNumbering.try_emplace(op, operationID++); });
+  }
+
+  template <typename ValueT>
+  SmallVector<uint64_t> getUseIDs(ValueT val) {
+    return SmallVector<uint64_t>(llvm::map_range(val.getUses(), [&](auto &use) {
+      return bytecode::getUseID(use, opNumbering.at(use.getOwner()));
+    }));
+  }
+
+  LogicalResult shuffleUses(Operation *topLevelOp) {
+    uint32_t valueID = 0;
+    /// Randomly permute the use-list of each value with at least two uses.
+    /// Note that the random permutation may leave the order unchanged.
+    auto doShuffleForRange = [&](ValueRange range) -> LogicalResult {
+      for (auto val : range) {
+        if (val.use_empty() || val.hasOneUse())
+          continue;
+
+        /// Get a valid index permutation for the uses of the value.
+        SmallVector<unsigned> permutation = getRandomPermutation(val);
+
+        /// Store the original order so that the shuffle can be verified
+        /// afterwards.
+        auto useIDs = getUseIDs(val);
+
+        /// Apply shuffle to the uselist.
+        val.shuffleUseList(permutation);
+
+        /// Get the new order and verify the shuffle happened correctly.
+        auto permutedIDs = getUseIDs(val);
+        if (permutedIDs.size() != useIDs.size())
+          return failure();
+        for (size_t idx = 0; idx < permutation.size(); idx++)
+          if (useIDs[idx] != permutedIDs[permutation[idx]])
+            return failure();
+
+        referenceUseListOrder.try_emplace(
+            valueID++, llvm::map_range(val.getUses(), [&](auto &use) {
+              return bytecode::getUseID(use, opNumbering.at(use.getOwner()));
+            }));
+      }
+      return success();
+    };
+
+    return walkOverValues(topLevelOp, doShuffleForRange);
+  }
+
+  LogicalResult verifyUseListOrders(Operation *topLevelOp) {
+    uint32_t valueID = 0;
+    /// Check that the use-list for the value range matches the one stored in
+    /// the reference.
+    auto doValidationForRange = [&](ValueRange range) -> LogicalResult {
+      for (auto val : range) {
+        if (val.use_empty() || val.hasOneUse())
+          continue;
+        auto referenceOrder = referenceUseListOrder.at(valueID++);
+        for (auto [use, referenceID] :
+             llvm::zip(val.getUses(), referenceOrder)) {
+          uint64_t uniqueID =
+              bytecode::getUseID(use, opNumbering.at(use.getOwner()));
+          if (uniqueID != referenceID) {
+            use.getOwner()->emitError()
+                << "found use-list order mismatch for value: " << val;
+            return failure();
+          }
+        }
+      }
+      return success();
+    };
+
+    return walkOverValues(topLevelOp, doValidationForRange);
+  }
+
+  /// Walk over blocks and operations and execute a callable over the ranges of
+  /// block arguments and operation results, respectively.
+  template <typename FuncT>
+  LogicalResult walkOverValues(Operation *topLevelOp, FuncT callable) {
+    auto blockWalk = topLevelOp->walk([&](Block *block) {
+      if (failed(callable(block->getArguments())))
+        return WalkResult::interrupt();
+      return WalkResult::advance();
+    });
+
+    if (blockWalk.wasInterrupted())
+      return failure();
+
+    auto resultsWalk = topLevelOp->walk([&](Operation *op) {
+      if (failed(callable(op->getResults())))
+        return WalkResult::interrupt();
+      return WalkResult::advance();
+    });
+
+    return failure(resultsWalk.wasInterrupted());
+  }
+
+  /// Creates a random permutation of the uselist order chain of the provided
+  /// value.
+  SmallVector<unsigned> getRandomPermutation(Value value) {
+    size_t numUses = std::distance(value.use_begin(), value.use_end());
+    SmallVector<unsigned> permutation(numUses);
+    unsigned zero = 0;
+    std::iota(permutation.begin(), permutation.end(), zero);
+    auto rng = std::default_random_engine(rngSeed);
+    std::shuffle(permutation.begin(), permutation.end(), rng);
+    return permutation;
+  }
+
+  /// Map each value to its use-list order encoded with unique use IDs.
+  DenseMap<uint32_t, SmallVector<uint64_t>> referenceUseListOrder;
+
+  /// Map each operation to its global ID.
+  DenseMap<Operation *, uint32_t> opNumbering;
+};
+} // namespace
+
+namespace mlir {
+void registerTestPreserveUseListOrders() {
+  PassRegistration<TestPreserveUseListOrders>();
+}
+} // namespace mlir
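
`getRandomPermutation` relies on `std::shuffle`, which can return the identity permutation; in that case a value keeps its original use-list order and the roundtrip check is weaker for that value. If every value with two or more uses should actually be reordered, one option is to swap two entries whenever the shuffle came back unchanged. A minimal sketch of that variant, independent of MLIR types; `getNonIdentityPermutation` is a hypothetical name:

```cpp
#include <algorithm>
#include <cstddef>
#include <numeric>
#include <random>
#include <utility>
#include <vector>

// Hypothetical variant of getRandomPermutation: returns a random permutation
// of [0, numUses - 1] that is never the identity when numUses >= 2.
std::vector<unsigned> getNonIdentityPermutation(std::size_t numUses,
                                                unsigned seed) {
  std::vector<unsigned> permutation(numUses);
  std::iota(permutation.begin(), permutation.end(), 0u);
  std::default_random_engine rng(seed);
  std::shuffle(permutation.begin(), permutation.end(), rng);

  // A permutation of 0..numUses-1 is the identity exactly when it is sorted.
  // In that case, swap the first two entries so that at least two uses move.
  if (numUses >= 2 && std::is_sorted(permutation.begin(), permutation.end()))
    std::swap(permutation[0], permutation[1]);
  return permutation;
}
```

The swap keeps the result a valid permutation, so it still satisfies the precondition of `shuffleUseList`.
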

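Steps 3 to 5 of the pass check that the order recorded after shuffling survives the bytecode roundtrip. Restoring a recorded order on a freshly parsed value boils down to computing, for each use in its current position, the position it should occupy, which is the shape of the `indices` argument that `shuffleUseList` takes. A minimal sketch of that computation over unique use IDs; `computeShuffleIndices` is hypothetical and is not the bytecode reader's actual code:

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <unordered_map>
#include <vector>

// Hypothetical helper: given a value's use IDs in their current order and the
// desired reference order, returns indices such that the use at current
// position idx must move to position indices[idx]. Returns std::nullopt if the
// two lists are not permutations of the same set of unique IDs.
std::optional<std::vector<unsigned>>
computeShuffleIndices(const std::vector<uint64_t> &current,
                      const std::vector<uint64_t> &reference) {
  if (current.size() != reference.size())
    return std::nullopt;

  // Map each reference ID to its target position, rejecting duplicates.
  std::unordered_map<uint64_t, unsigned> targetPos;
  for (std::size_t pos = 0; pos < reference.size(); ++pos)
    if (!targetPos.emplace(reference[pos], static_cast<unsigned>(pos)).second)
      return std::nullopt;

  // Look up the target of every current use; each target may be taken once.
  std::vector<bool> taken(reference.size(), false);
  std::vector<unsigned> indices;
  indices.reserve(current.size());
  for (uint64_t id : current) {
    auto it = targetPos.find(id);
    if (it == targetPos.end() || taken[it->second])
      return std::nullopt;
    taken[it->second] = true;
    indices.push_back(it->second);
  }
  return indices;
}
```

For a current order `{30, 10, 20}` and a reference order `{10, 20, 30}`, it returns `{2, 0, 1}`: the use at position 0 moves to position 2, and so on.
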
diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 40b9c827fa610..9bc91c2d6bcd3 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -53,6 +53,7 @@ void registerTestOperationEqualPass();
 void registerTestPrintDefUsePass();
 void registerTestPrintInvalidPass();
 void registerTestPrintNestingPass();
+void registerTestPreserveUseListOrders();
 void registerTestReducer();
 void registerTestSpirvEntryPointABIPass();
 void registerTestSpirvModuleCombinerPass();
@@ -167,6 +168,7 @@ void registerTestPasses() {
   registerTestPrintDefUsePass();
   registerTestPrintInvalidPass();
   registerTestPrintNestingPass();
+  registerTestPreserveUseListOrders();
   registerTestReducer();
   registerTestSpirvEntryPointABIPass();
   registerTestSpirvModuleCombinerPass();
