[Mlir-commits] [mlir] a776ecb - [mlir][IR] Add an Operation::eraseOperands that supports batch erasure

River Riddle llvmlistbot at llvm.org
Tue Mar 9 15:08:09 PST 2021


Author: River Riddle
Date: 2021-03-09T15:07:53-08:00
New Revision: a776ecb6c2b83454f4cd22a0e302aef558381dab

URL: https://github.com/llvm/llvm-project/commit/a776ecb6c2b83454f4cd22a0e302aef558381dab
DIFF: https://github.com/llvm/llvm-project/commit/a776ecb6c2b83454f4cd22a0e302aef558381dab.diff

LOG: [mlir][IR] Add an Operation::eraseOperands that supports batch erasure

This method allows multiple disjoint operands to be removed at once, reducing the need to erase operands one at a time (each such erasure shifts all of the operands that follow it).

Differential Revision: https://reviews.llvm.org/D98290

Added: 
    

Modified: 
    mlir/include/mlir/IR/Operation.h
    mlir/include/mlir/IR/OperationSupport.h
    mlir/lib/IR/OperationSupport.cpp
    mlir/unittests/IR/OperationSupportTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index f76ea2657811..00eb2c5d973f 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -243,6 +243,12 @@ class alignas(8) Operation final
     getOperandStorage().eraseOperands(idx, length);
   }
 
+  /// Erases the operands whose bit is set in `eraseIndices`, keeping the rest
+  /// in order. `eraseIndices.size()` must equal getNumOperands().
+  void eraseOperands(const llvm::BitVector &eraseIndices) {
+    getOperandStorage().eraseOperands(eraseIndices);
+  }
+
   // Support operand iteration.
   using operand_range = OperandRange;
   using operand_iterator = operand_range::iterator;

diff  --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 3aa710e5893c..60af4b09e0e1 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -28,6 +28,10 @@
 #include "llvm/Support/TrailingObjects.h"
 #include <memory>
 
+namespace llvm {
+class BitVector;
+} // end namespace llvm
+
 namespace mlir {
 class Dialect;
 class DictionaryAttr;
@@ -495,6 +499,10 @@ class OperandStorage final
   /// Erase the operands held by the storage within the given range.
   void eraseOperands(unsigned start, unsigned length);
 
+  /// Erase the operands whose bit is set in `eraseIndices` (one bit per
+  /// operand). The remaining operands keep their relative order.
+  void eraseOperands(const llvm::BitVector &eraseIndices);
+
   /// Get the operation operands held by the storage.
   MutableArrayRef<OpOperand> getOperands() {
     return getStorage().getOperands();

diff  --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 28acdc2fcb29..c60e05665142 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -12,10 +12,10 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/IR/OperationSupport.h"
-#include "mlir/IR/Block.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
-#include "mlir/IR/Operation.h"
+#include "llvm/ADT/BitVector.h"
+
 using namespace mlir;
 
 //===----------------------------------------------------------------------===//
@@ -300,6 +300,26 @@ void detail::OperandStorage::eraseOperands(unsigned start, unsigned length) {
     operands[storage.numOperands + i].~OpOperand();
 }
 
+void detail::OperandStorage::eraseOperands(
+    const llvm::BitVector &eraseIndices) {
+  TrailingOperandStorage &storage = getStorage();
+  MutableArrayRef<OpOperand> operands = storage.getOperands();
+  assert(eraseIndices.size() == operands.size());
+
+  // Nothing to do if no operand is marked for erasure.
+  int firstErasedIndice = eraseIndices.find_first();
+  if (firstErasedIndice == -1)
+    return;
+
+  // Move each kept operand down over the erased slots, then destroy the tail.
+  storage.numOperands = firstErasedIndice;
+  for (unsigned i = firstErasedIndice + 1, e = operands.size(); i < e; ++i)
+    if (!eraseIndices.test(i))
+      operands[storage.numOperands++] = std::move(operands[i]);
+  for (OpOperand &operand : operands.drop_front(storage.numOperands))
+    operand.~OpOperand();
+}
+
 /// Resize the storage to the given size. Returns the array containing the new
 /// operands.
 MutableArrayRef<OpOperand> detail::OperandStorage::resize(Operation *owner,

diff  --git a/mlir/unittests/IR/OperationSupportTest.cpp b/mlir/unittests/IR/OperationSupportTest.cpp
index c48e8623c609..b4fbc4be4b89 100644
--- a/mlir/unittests/IR/OperationSupportTest.cpp
+++ b/mlir/unittests/IR/OperationSupportTest.cpp
@@ -9,6 +9,7 @@
 #include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "llvm/ADT/BitVector.h"
 #include "gtest/gtest.h"
 
 using namespace mlir;
@@ -150,6 +151,37 @@ TEST(OperandStorageTest, MutableRange) {
   useOp->destroy();
 }
 
+TEST(OperandStorageTest, RangeErase) {
+  MLIRContext context;
+  Builder builder(&context);
+
+  Type type = builder.getNoneType();
+  Operation *useOp = createOp(&context, /*operands=*/llvm::None, {type, type});
+  Value operand1 = useOp->getResult(0);
+  Value operand2 = useOp->getResult(1);
+
+  // Create a user with operands {operand2, operand1, operand2, operand1}.
+  Operation *user =
+      createOp(&context, {operand2, operand1, operand2, operand1});
+  llvm::BitVector eraseIndices(user->getNumOperands());
+
+  // An all-clear bit vector must leave every operand in place.
+  user->eraseOperands(eraseIndices);
+  EXPECT_EQ(user->getNumOperands(), 4u);
+
+  // Erase indices 0 and 3; the survivors must keep their relative order.
+  eraseIndices.set(0);
+  eraseIndices.set(3);
+  user->eraseOperands(eraseIndices);
+  EXPECT_EQ(user->getNumOperands(), 2u);
+  EXPECT_EQ(user->getOperand(0), operand1);
+  EXPECT_EQ(user->getOperand(1), operand2);
+
+  // Destroy the operations.
+  user->destroy();
+  useOp->destroy();
+}
+
 TEST(OperationOrderTest, OrderIsAlwaysValid) {
   MLIRContext context;
   Builder builder(&context);


        


More information about the Mlir-commits mailing list