[llvm-branch-commits] [mlir] 99b39d7 - [mlir:bytecode] Support lazy loading dynamically isolated regions

Tobias Hieta via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Aug 10 00:09:47 PDT 2023


Author: River Riddle
Date: 2023-08-10T09:06:20+02:00
New Revision: 99b39d7df62365d5c0c1fad776b4fd3b0e452277

URL: https://github.com/llvm/llvm-project/commit/99b39d7df62365d5c0c1fad776b4fd3b0e452277
DIFF: https://github.com/llvm/llvm-project/commit/99b39d7df62365d5c0c1fad776b4fd3b0e452277.diff

LOG: [mlir:bytecode] Support lazy loading dynamically isolated regions

We currently only support lazy loading for regions that
statically implement the IsolatedFromAbove trait, but that
limits the number of operations that can be lazily loaded. This review
lifts that restriction by computing which operations have isolated
regions when numbering, allowing any operation to be lazily loaded
as long as it doesn't use values defined above.

Differential Revision: https://reviews.llvm.org/D156199

Added: 
    

Modified: 
    mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
    mlir/lib/Bytecode/Writer/IRNumbering.cpp
    mlir/lib/Bytecode/Writer/IRNumbering.h
    mlir/test/Bytecode/bytecode-lazy-loading.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
index 401629c739652a..d8f2cb106510d9 100644
--- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
+++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -942,7 +942,7 @@ LogicalResult BytecodeWriter::writeOp(EncodingEmitter &emitter, Operation *op) {
   // emitting the regions first (e.g. if the regions are huge, backpatching the
   // op encoding mask is more annoying).
   if (numRegions) {
-    bool isIsolatedFromAbove = op->hasTrait<OpTrait::IsIsolatedFromAbove>();
+    bool isIsolatedFromAbove = numberingState.isIsolatedFromAbove(op);
     emitter.emitVarIntWithFlag(numRegions, isIsolatedFromAbove);
 
     // If the region is not isolated from above, or we are emitting bytecode

diff  --git a/mlir/lib/Bytecode/Writer/IRNumbering.cpp b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
index 284b3c02f1f2ce..788cf5b201f02b 100644
--- a/mlir/lib/Bytecode/Writer/IRNumbering.cpp
+++ b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
@@ -115,19 +115,29 @@ static void groupByDialectPerByte(T range) {
 IRNumberingState::IRNumberingState(Operation *op,
                                    const BytecodeWriterConfig &config)
     : config(config) {
-  // Compute a global operation ID numbering according to the pre-order walk of
-  // the IR. This is used as reference to construct use-list orders.
-  unsigned operationID = 0;
-  op->walk<WalkOrder::PreOrder>(
-      [&](Operation *op) { operationIDs.try_emplace(op, operationID++); });
+  computeGlobalNumberingState(op);
 
   // Number the root operation.
   number(*op);
 
-  // Push all of the regions of the root operation onto the worklist.
+  // A worklist of region contexts to number and the next value id before that
+  // region.
   SmallVector<std::pair<Region *, unsigned>, 8> numberContext;
-  for (Region &region : op->getRegions())
-    numberContext.emplace_back(&region, nextValueID);
+
+  // Functor to push the regions of the given operation onto the numbering
+  // context.
+  auto addOpRegionsToNumber = [&](Operation *op) {
+    MutableArrayRef<Region> regions = op->getRegions();
+    if (regions.empty())
+      return;
+
+    // Isolated regions don't share value numbers with their parent, so we can
+    // start numbering these regions at zero.
+    unsigned opFirstValueID = isIsolatedFromAbove(op) ? 0 : nextValueID;
+    for (Region &region : regions)
+      numberContext.emplace_back(&region, opFirstValueID);
+  };
+  addOpRegionsToNumber(op);
 
   // Iteratively process each of the nested regions.
   while (!numberContext.empty()) {
@@ -136,14 +146,8 @@ IRNumberingState::IRNumberingState(Operation *op,
     number(*region);
 
     // Traverse into nested regions.
-    for (Operation &op : region->getOps()) {
-      // Isolated regions don't share value numbers with their parent, so we can
-      // start numbering these regions at zero.
-      unsigned opFirstValueID =
-          op.hasTrait<OpTrait::IsIsolatedFromAbove>() ? 0 : nextValueID;
-      for (Region &region : op.getRegions())
-        numberContext.emplace_back(&region, opFirstValueID);
-    }
+    for (Operation &op : region->getOps())
+      addOpRegionsToNumber(&op);
   }
 
   // Number each of the dialects. For now this is just in the order they were
@@ -178,6 +182,116 @@ IRNumberingState::IRNumberingState(Operation *op,
   finalizeDialectResourceNumberings(op);
 }
 
+void IRNumberingState::computeGlobalNumberingState(Operation *rootOp) {
+  // A simple state struct tracking data used when walking operations.
+  struct StackState {
+    /// The operation currently being walked.
+    Operation *op;
+
+    /// The numbering of the operation.
+    OperationNumbering *numbering;
+
+    /// A flag indicating if the current state or one of its parents has
+    /// unresolved isolation status. This is tracked separately from the
+    /// isIsolatedFromAbove bit on `numbering` because we need to be able to
+    /// handle the given case:
+    ///   top.op {
+    ///     %value = ...
+    ///     middle.op {
+    ///       %value2 = ...
+    ///       inner.op {
+    ///         // Here we mark `inner.op` as not isolated. Note `middle.op`
+    ///         // isn't known not isolated yet.
+    ///         use.op %value2
+    ///
+    ///         // Here inner.op is already known to be non-isolated, but
+    ///         // `middle.op` is now also discovered to be non-isolated.
+    ///         use.op %value
+    ///       }
+    ///     }
+    ///   }
+    bool hasUnresolvedIsolation;
+  };
+
+  // Compute a global operation ID numbering according to the pre-order walk of
+  // the IR. This is used as reference to construct use-list orders.
+  unsigned operationID = 0;
+
+  // Walk each of the operations within the IR, tracking a stack of operations
+  // as we recurse into nested regions. This walk method hooks in at two stages
+  // during the walk:
+  //
+  //   BeforeAllRegions:
+  //     Here we generate a numbering for the operation and push it onto the
+  //     stack if it has regions. We also compute the isolation status of parent
+  //     regions at this stage. This is done by checking the parent regions of
+  //     operands used by the operation, and marking each region between the
+  //     the operand region and the current as not isolated. See
+  //     StackState::hasUnresolvedIsolation above for an example.
+  //
+  //   AfterAllRegions:
+  //     Here we pop the operation from the stack, and if it hasn't been marked
+    //     as non-isolated, we mark it as isolated from above. A non-isolated
+    //     use would have been found while walking the regions, so it is safe to
+    //     mark the operation at this point.
+  //
+  SmallVector<StackState> opStack;
+  rootOp->walk([&](Operation *op, const WalkStage &stage) {
+    // After visiting all nested regions, we pop the operation from the stack.
+    if (stage.isAfterAllRegions()) {
+      // If no non-isolated uses were found, we can safely mark this operation
+      // as isolated from above.
+      OperationNumbering *numbering = opStack.pop_back_val().numbering;
+      if (!numbering->isIsolatedFromAbove.has_value())
+        numbering->isIsolatedFromAbove = true;
+      return;
+    }
+
+    // When visiting before nested regions, we process "IsolatedFromAbove"
+    // checks and compute the number for this operation.
+    if (!stage.isBeforeAllRegions())
+      return;
+    // Update the isolation status of parent regions if any have yet to be
+    // resolved.
+    if (!opStack.empty() && opStack.back().hasUnresolvedIsolation) {
+      Region *parentRegion = op->getParentRegion();
+      for (Value operand : op->getOperands()) {
+        Region *operandRegion = operand.getParentRegion();
+        if (operandRegion == parentRegion)
+          continue;
+        // We've found a use of an operand outside of the current region,
+        // walk the operation stack searching for the parent operation,
+        // marking every region on the way as not isolated.
+        Operation *operandContainerOp = operandRegion->getParentOp();
+        auto it = std::find_if(
+            opStack.rbegin(), opStack.rend(), [=](const StackState &it) {
+              // We only need to mark up to the container region, or the first
+              // that has an unresolved status.
+              return !it.hasUnresolvedIsolation || it.op == operandContainerOp;
+            });
+        assert(it != opStack.rend() && "expected to find the container");
+        for (auto &state : llvm::make_range(opStack.rbegin(), it)) {
+          // If we stopped at a region that knows its isolation status, we can
+          // stop updating the isolation status for the parent regions.
+          state.hasUnresolvedIsolation = it->hasUnresolvedIsolation;
+          state.numbering->isIsolatedFromAbove = false;
+        }
+      }
+    }
+
+    // Compute the number for this op and push it onto the stack.
+    auto *numbering =
+        new (opAllocator.Allocate()) OperationNumbering(operationID++);
+    if (op->hasTrait<OpTrait::IsIsolatedFromAbove>())
+      numbering->isIsolatedFromAbove = true;
+    operations.try_emplace(op, numbering);
+    if (op->getNumRegions()) {
+      opStack.emplace_back(StackState{
+          op, numbering, !numbering->isIsolatedFromAbove.has_value()});
+    }
+  });
+}
+
 void IRNumberingState::number(Attribute attr) {
   auto it = attrs.insert({attr, nullptr});
   if (!it.second) {

diff  --git a/mlir/lib/Bytecode/Writer/IRNumbering.h b/mlir/lib/Bytecode/Writer/IRNumbering.h
index ca30078f3468f4..eab75f50d2ee4f 100644
--- a/mlir/lib/Bytecode/Writer/IRNumbering.h
+++ b/mlir/lib/Bytecode/Writer/IRNumbering.h
@@ -126,6 +126,22 @@ struct DialectNumbering {
   llvm::MapVector<StringRef, DialectResourceNumbering *> resourceMap;
 };
 
+//===----------------------------------------------------------------------===//
+// Operation Numbering
+//===----------------------------------------------------------------------===//
+
+/// This class represents the numbering entry of an operation.
+struct OperationNumbering {
+  OperationNumbering(unsigned number) : number(number) {}
+
+  /// The number assigned to this operation.
+  unsigned number;
+
+  /// A flag indicating if this operation's regions are isolated. If unset, the
+  /// operation isn't yet known to be isolated.
+  std::optional<bool> isIsolatedFromAbove;
+};
+
 //===----------------------------------------------------------------------===//
 // IRNumberingState
 //===----------------------------------------------------------------------===//
@@ -154,8 +170,8 @@ class IRNumberingState {
     return blockIDs[block];
   }
   unsigned getNumber(Operation *op) {
-    assert(operationIDs.count(op) && "operation not numbered");
-    return operationIDs[op];
+    assert(operations.count(op) && "operation not numbered");
+    return operations[op]->number;
   }
   unsigned getNumber(OperationName opName) {
     assert(opNames.count(opName) && "opName not numbered");
@@ -186,14 +202,23 @@ class IRNumberingState {
     return blockOperationCounts[block];
   }
 
+  /// Return if the given operation is isolated from above.
+  bool isIsolatedFromAbove(Operation *op) {
+    assert(operations.count(op) && "operation not numbered");
+    return operations[op]->isIsolatedFromAbove.value_or(false);
+  }
+
   /// Get the set desired bytecode version to emit.
   int64_t getDesiredBytecodeVersion() const;
-
+  
 private:
   /// This class is used to provide a fake dialect writer for numbering nested
   /// attributes and types.
   struct NumberingDialectWriter;
 
+  /// Compute the global numbering state for the given root operation.
+  void computeGlobalNumberingState(Operation *rootOp);
+
   /// Number the given IR unit for bytecode emission.
   void number(Attribute attr);
   void number(Block &block);
@@ -212,6 +237,7 @@ class IRNumberingState {
 
   /// Mapping from IR to the respective numbering entries.
   DenseMap<Attribute, AttributeNumbering *> attrs;
+  DenseMap<Operation *, OperationNumbering *> operations;
   DenseMap<OperationName, OpNameNumbering *> opNames;
   DenseMap<Type, TypeNumbering *> types;
   DenseMap<Dialect *, DialectNumbering *> registeredDialects;
@@ -228,12 +254,12 @@ class IRNumberingState {
   /// Allocators used for the various numbering entries.
   llvm::SpecificBumpPtrAllocator<AttributeNumbering> attrAllocator;
   llvm::SpecificBumpPtrAllocator<DialectNumbering> dialectAllocator;
+  llvm::SpecificBumpPtrAllocator<OperationNumbering> opAllocator;
   llvm::SpecificBumpPtrAllocator<OpNameNumbering> opNameAllocator;
   llvm::SpecificBumpPtrAllocator<DialectResourceNumbering> resourceAllocator;
   llvm::SpecificBumpPtrAllocator<TypeNumbering> typeAllocator;
 
-  /// The value ID for each Operation, Block and Value.
-  DenseMap<Operation *, unsigned> operationIDs;
+  /// The value ID for each Block and Value.
   DenseMap<Block *, unsigned> blockIDs;
   DenseMap<Value, unsigned> valueIDs;
 

diff  --git a/mlir/test/Bytecode/bytecode-lazy-loading.mlir b/mlir/test/Bytecode/bytecode-lazy-loading.mlir
index d439f84db17d3c..a1981fce0875fd 100644
--- a/mlir/test/Bytecode/bytecode-lazy-loading.mlir
+++ b/mlir/test/Bytecode/bytecode-lazy-loading.mlir
@@ -23,6 +23,21 @@ func.func @op_with_passthrough_region_args() {
   }, {
     "test.unknown_op"() : () -> ()
   }
+  
+  // Ensure operations that aren't tagged as IsolatedFromAbove can
+  // still be lazy loaded if they don't have references to values
+  // defined above.
+  "test.one_region_op"() ({
+    "test.unknown_op"() : () -> ()
+  }) : () -> ()
+
+  // Similar test as above, but check that if one region has a reference
+  // to a value defined above, we don't lazy load the operation.
+  "test.two_region_op"() ({
+    "test.unknown_op"() : () -> ()
+  }, {
+    "test.consumer"(%0) : (index) -> ()
+  }) : () -> ()
   return
 }
 
@@ -53,7 +68,12 @@ func.func @op_with_passthrough_region_args() {
 // CHECK: test.consumer
 // CHECK: isolated_region
 // CHECK-NOT: test.consumer
-// CHECK: Has 3 ops to materialize
+// CHECK: test.one_region_op
+// CHECK-NOT: test.op
+// CHECK: test.two_region_op
+// CHECK: test.unknown_op
+// CHECK: test.consumer
+// CHECK: Has 4 ops to materialize
 
 // CHECK: Before Materializing...
 // CHECK: test.isolated_region
@@ -62,7 +82,7 @@ func.func @op_with_passthrough_region_args() {
 // CHECK: test.isolated_region
 // CHECK: ^bb0(%arg0: index):
 // CHECK:  test.consumer
-// CHECK: Has 2 ops to materialize
+// CHECK: Has 3 ops to materialize
 
 // CHECK: Before Materializing...
 // CHECK: test.isolated_region
@@ -70,7 +90,7 @@ func.func @op_with_passthrough_region_args() {
 // CHECK: Materializing...
 // CHECK: test.isolated_region
 // CHECK: test.consumer
-// CHECK: Has 1 ops to materialize
+// CHECK: Has 2 ops to materialize
 
 // CHECK: Before Materializing...
 // CHECK: test.isolated_regions
@@ -79,4 +99,12 @@ func.func @op_with_passthrough_region_args() {
 // CHECK: test.isolated_regions
 // CHECK: test.unknown_op
 // CHECK: test.unknown_op
+// CHECK: Has 1 ops to materialize
+
+// CHECK: Before Materializing...
+// CHECK: test.one_region_op
+// CHECK-NOT: test.unknown_op
+// CHECK: Materializing...
+// CHECK: test.one_region_op
+// CHECK: test.unknown_op
 // CHECK: Has 0 ops to materialize


        


More information about the llvm-branch-commits mailing list