[Mlir-commits] [mlir] ee90bb3 - Store (cache) the Argument number (index in the argument list) inside the BlockArgumentImpl

Mehdi Amini llvmlistbot at llvm.org
Sat Feb 27 09:21:20 PST 2021


Author: Mehdi Amini
Date: 2021-02-27T17:21:08Z
New Revision: ee90bb3486948c472a67ec3ca0f0d64927f6643d

URL: https://github.com/llvm/llvm-project/commit/ee90bb3486948c472a67ec3ca0f0d64927f6643d
DIFF: https://github.com/llvm/llvm-project/commit/ee90bb3486948c472a67ec3ca0f0d64927f6643d.diff

LOG: Store (cache) the Argument number (index in the argument list) inside the BlockArgumentImpl

This avoids a linear search in BlockArgument::getArgNumber().

Differential Revision: https://reviews.llvm.org/D97596

Added: 
    

Modified: 
    mlir/include/mlir/IR/Value.h
    mlir/lib/IR/Block.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h
index c22741ee6cb6..1e98891ff945 100644
--- a/mlir/include/mlir/IR/Value.h
+++ b/mlir/include/mlir/IR/Value.h
@@ -249,7 +249,8 @@ inline raw_ostream &operator<<(raw_ostream &os, Value value) {
 namespace detail {
 /// The internal implementation of a BlockArgument.
 class BlockArgumentImpl : public IRObjectWithUseList<OpOperand> {
-  BlockArgumentImpl(Type type, Block *owner) : type(type), owner(owner) {}
+  BlockArgumentImpl(Type type, Block *owner, int64_t index)
+      : type(type), owner(owner), index(index) {}
 
   /// The type of this argument.
   Type type;
@@ -257,6 +258,9 @@ class BlockArgumentImpl : public IRObjectWithUseList<OpOperand> {
   /// The owner of this argument.
   Block *owner;
 
+  /// Cached position in the owner's argument list (kept current by Block).
+  int64_t index;
+
   /// Allow access to owner and constructor.
   friend BlockArgument;
 };
@@ -281,12 +285,12 @@ class BlockArgument : public Value {
   void setType(Type newType) { getImpl()->type = newType; }
 
   /// Returns the number of this argument.
-  unsigned getArgNumber() const;
+  unsigned getArgNumber() const { return getImpl()->index; }
 
 private:
-  /// Allocate a new argument with the given type and owner.
-  static BlockArgument create(Type type, Block *owner) {
-    return new detail::BlockArgumentImpl(type, owner);
+  /// Allocate a new argument with the given type, owner and position.
+  static BlockArgument create(Type type, Block *owner, int64_t index) {
+    return new detail::BlockArgumentImpl(type, owner, index);
   }
 
   /// Destroy and deallocate this argument.
@@ -298,7 +302,10 @@ class BlockArgument : public Value {
         ownerAndKind.getPointer());
   }
 
-  /// Allow access to `create` and `destroy`.
+  /// Cache the position in the block argument list.
+  void setArgNumber(int64_t index) { getImpl()->index = index; }
+
+  /// Allow access to `create`, `destroy` and `setArgNumber`.
   friend Block;
 
   /// Allow access to 'getImpl'.

diff  --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp
index 74209bd0884b..2ae983d9c51d 100644
--- a/mlir/lib/IR/Block.cpp
+++ b/mlir/lib/IR/Block.cpp
@@ -12,17 +12,6 @@
 #include "llvm/ADT/BitVector.h"
 using namespace mlir;
 
-//===----------------------------------------------------------------------===//
-// BlockArgument
-//===----------------------------------------------------------------------===//
-
-/// Returns the number of this argument.
-unsigned BlockArgument::getArgNumber() const {
-  // Arguments are not stored in place, so we have to find it within the list.
-  auto argList = getOwner()->getArguments();
-  return std::distance(argList.begin(), llvm::find(argList, *this));
-}
-
 //===----------------------------------------------------------------------===//
 // Block
 //===----------------------------------------------------------------------===//
@@ -150,7 +139,7 @@ auto Block::getArgumentTypes() -> ValueTypeRange<BlockArgListType> {
 }
 
 BlockArgument Block::addArgument(Type type) {
-  BlockArgument arg = BlockArgument::create(type, this);
+  BlockArgument arg = BlockArgument::create(type, this, arguments.size());
   arguments.push_back(arg);
   return arg;
 }
@@ -165,16 +154,31 @@ auto Block::addArguments(TypeRange types) -> iterator_range<args_iterator> {
 }
 
 BlockArgument Block::insertArgument(unsigned index, Type type) {
-  auto arg = BlockArgument::create(type, this);
+  auto arg = BlockArgument::create(type, this, index);
   assert(index <= arguments.size());
   arguments.insert(arguments.begin() + index, arg);
+  // Every argument after the newly inserted one moved up by one position, so
+  // refresh its cached number. The loop variable is named `next` so it does
+  // not shadow `arg`, which is what this function returns.
+  for (BlockArgument next : llvm::drop_begin(arguments, index + 1))
+    next.setArgNumber(++index);
   return arg;
 }
+
+/// Insert one value before `it` in the argument list (`args_end()` appends);
+/// later arguments are shifted. The block must not have predecessors.
+BlockArgument Block::insertArgument(args_iterator it, Type type) {
+  assert(llvm::empty(getPredecessors()) &&
+         "cannot insert arguments to blocks with predecessors");
+  return insertArgument(std::distance(args_begin(), it), type);
+}
+
 void Block::eraseArgument(unsigned index) {
   assert(index < arguments.size());
   arguments[index].destroy();
   arguments.erase(arguments.begin() + index);
+  for (BlockArgument arg : llvm::drop_begin(arguments, index))
+    arg.setArgNumber(index++);
 }
 
 void Block::eraseArguments(ArrayRef<unsigned> argIndices) {
@@ -188,23 +192,19 @@ void Block::eraseArguments(llvm::BitVector eraseIndices) {
   // We do this in reverse so that we erase later indices before earlier
   // indices, to avoid shifting the later indices.
   unsigned originalNumArgs = getNumArguments();
-  for (unsigned i = 0; i < originalNumArgs; ++i)
-    if (eraseIndices.test(originalNumArgs - i - 1))
-      eraseArgument(originalNumArgs - i - 1);
-}
-
-/// Insert one value to the given position of the argument list. The existing
-/// arguments are shifted. The block is expected not to have predecessors.
-BlockArgument Block::insertArgument(args_iterator it, Type type) {
-  assert(llvm::empty(getPredecessors()) &&
-         "cannot insert arguments to blocks with predecessors");
-
-  // Use the args_iterator (on the BlockArgListType) to compute the insertion
-  // iterator in the underlying argument storage.
-  size_t distance = std::distance(args_begin(), it);
-  auto arg = BlockArgument::create(type, this);
-  arguments.insert(std::next(arguments.begin(), distance), arg);
-  return arg;
+  for (unsigned i = 0; i < originalNumArgs; ++i) {
+    int64_t currentPos = originalNumArgs - i - 1;
+    if (eraseIndices.test(currentPos)) {
+      arguments[currentPos].destroy();
+      arguments.erase(arguments.begin() + currentPos);
+    }
+  }
+  // Every argument from the first erased position onward shifted down, so give
+  // it its new index. find_first() returns -1 when nothing was erased.
+  int firstErased = eraseIndices.find_first();
+  if (firstErased >= 0)
+    for (unsigned i = firstErased, e = arguments.size(); i < e; ++i)
+      arguments[i].setArgNumber(i);
 }
 
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list