[Mlir-commits] [mlir] e3cd80e - [mlir:Function][NFC] Use BitVector instead of ArrayRef for indices when removing arguments/results

River Riddle llvmlistbot at llvm.org
Wed Jan 26 21:38:03 PST 2022


Author: River Riddle
Date: 2022-01-26T21:37:22-08:00
New Revision: e3cd80ea9f0ac0d04f537feb70d8f9a1c7875863

URL: https://github.com/llvm/llvm-project/commit/e3cd80ea9f0ac0d04f537feb70d8f9a1c7875863
DIFF: https://github.com/llvm/llvm-project/commit/e3cd80ea9f0ac0d04f537feb70d8f9a1c7875863.diff

LOG: [mlir:Function][NFC] Use BitVector instead of ArrayRef for indices when removing arguments/results

We already convert to BitVector internally, and other APIs (namely Operation::eraseOperands)
already use BitVector as well. Switching over provides a common index format across
these APIs and also reduces the number of format conversions necessary.

Fixes #53325

Differential Revision: https://reviews.llvm.org/D118083

Added: 
    

Modified: 
    mlir/include/mlir/IR/BuiltinTypes.h
    mlir/include/mlir/IR/BuiltinTypes.td
    mlir/include/mlir/IR/FunctionInterfaces.h
    mlir/include/mlir/IR/FunctionInterfaces.td
    mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
    mlir/lib/IR/BuiltinTypes.cpp
    mlir/lib/IR/FunctionInterfaces.cpp
    mlir/test/lib/IR/TestFunc.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index e087e6bf55d8c..b03d9ea9f575d 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -13,6 +13,7 @@
 #include "SubElementInterfaces.h"
 
 namespace llvm {
+class BitVector;
 struct fltSemantics;
 } // namespace llvm
 

diff  --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index b6f90f92190a8..a78413548ba82 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -166,8 +166,8 @@ def Builtin_Function : Builtin_Type<"Function", [
                                        TypeRange resultTypes);
 
     /// Returns a new function type without the specified arguments and results.
-    FunctionType getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
-                                          ArrayRef<unsigned> resultIndices);
+    FunctionType getWithoutArgsAndResults(const llvm::BitVector &argIndices,
+                                          const llvm::BitVector &resultIndices);
   }];
 }
 

diff  --git a/mlir/include/mlir/IR/FunctionInterfaces.h b/mlir/include/mlir/IR/FunctionInterfaces.h
index b81abceef6149..b6d6a9515ffe0 100644
--- a/mlir/include/mlir/IR/FunctionInterfaces.h
+++ b/mlir/include/mlir/IR/FunctionInterfaces.h
@@ -16,6 +16,7 @@
 
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
+#include "llvm/ADT/BitVector.h"
 #include "llvm/ADT/SmallString.h"
 
 namespace mlir {
@@ -82,12 +83,12 @@ void insertFunctionResults(Operation *op, ArrayRef<unsigned> resultIndices,
                            unsigned originalNumResults, Type newType);
 
 /// Erase the specified arguments and update the function type attribute.
-void eraseFunctionArguments(Operation *op, ArrayRef<unsigned> argIndices,
-                            unsigned originalNumArgs, Type newType);
+void eraseFunctionArguments(Operation *op, const llvm::BitVector &argIndices,
+                            Type newType);
 
 /// Erase the specified results and update the function type attribute.
-void eraseFunctionResults(Operation *op, ArrayRef<unsigned> resultIndices,
-                          unsigned originalNumResults, Type newType);
+void eraseFunctionResults(Operation *op, const llvm::BitVector &resultIndices,
+                          Type newType);
 
 /// Set a FunctionOpInterface operation's type signature.
 void setFunctionType(Operation *op, Type newType);
@@ -100,7 +101,7 @@ TypeRange insertTypesInto(TypeRange oldTypes, ArrayRef<unsigned> indices,
 
 /// Filters out any elements referenced by `indices`. If any types are removed,
 /// `storage` is used to hold the new type list. Returns the new type list.
-TypeRange filterTypesOut(TypeRange types, ArrayRef<unsigned> indices,
+TypeRange filterTypesOut(TypeRange types, const llvm::BitVector &indices,
                          SmallVectorImpl<Type> &storage);
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/IR/FunctionInterfaces.td b/mlir/include/mlir/IR/FunctionInterfaces.td
index 20c7d7bbd51b7..124f6594081f7 100644
--- a/mlir/include/mlir/IR/FunctionInterfaces.td
+++ b/mlir/include/mlir/IR/FunctionInterfaces.td
@@ -280,27 +280,31 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> {
     }
 
     /// Erase a single argument at `argIndex`.
-    void eraseArgument(unsigned argIndex) { eraseArguments({argIndex}); }
+    void eraseArgument(unsigned argIndex) {
+      llvm::BitVector argsToErase($_op.getNumArguments());
+      argsToErase.set(argIndex);
+      eraseArguments(argsToErase);
+    }
 
     /// Erases the arguments listed in `argIndices`.
-    /// `argIndices` is allowed to have duplicates and can be in any order.
-    void eraseArguments(ArrayRef<unsigned> argIndices) {
-      unsigned originalNumArgs = $_op.getNumArguments();
-      Type newType = $_op.getTypeWithoutArgsAndResults(argIndices, {});
-      function_interface_impl::eraseFunctionArguments(this->getOperation(), argIndices,
-                                                originalNumArgs, newType);
+    void eraseArguments(const llvm::BitVector &argIndices) {
+      Type newType = $_op.getTypeWithoutArgs(argIndices);
+      function_interface_impl::eraseFunctionArguments(
+        this->getOperation(), argIndices, newType);
     }
 
     /// Erase a single result at `resultIndex`.
-    void eraseResult(unsigned resultIndex) { eraseResults({resultIndex}); }
+    void eraseResult(unsigned resultIndex) {
+      llvm::BitVector resultsToErase($_op.getNumResults());
+      resultsToErase.set(resultIndex);
+      eraseResults(resultsToErase);
+    }
 
     /// Erases the results listed in `resultIndices`.
-    /// `resultIndices` is allowed to have duplicates and can be in any order.
-    void eraseResults(ArrayRef<unsigned> resultIndices) {
-      unsigned originalNumResults = $_op.getNumResults();
-      Type newType = $_op.getTypeWithoutArgsAndResults({}, resultIndices);
+    void eraseResults(const llvm::BitVector &resultIndices) {
+      Type newType = $_op.getTypeWithoutResults(resultIndices);
       function_interface_impl::eraseFunctionResults(
-          this->getOperation(), resultIndices, originalNumResults, newType);
+          this->getOperation(), resultIndices, newType);
     }
 
     /// Return the type of this function with the specified arguments and
@@ -320,10 +324,9 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> {
 
     /// Return the type of this function without the specified arguments and
     /// results. This is used to update the function's signature in the
-    /// `eraseArguments` and `eraseResults` methods. The arrays of indices are
-    /// allowed to have duplicates and can be in any order.
+    /// `eraseArguments` and `eraseResults` methods.
     Type getTypeWithoutArgsAndResults(
-      ArrayRef<unsigned> argIndices, ArrayRef<unsigned> resultIndices) {
+      const llvm::BitVector &argIndices, const llvm::BitVector &resultIndices) {
       SmallVector<Type> argStorage, resultStorage;
       TypeRange newArgTypes = function_interface_impl::filterTypesOut(
           $_op.getArgumentTypes(), argIndices, argStorage);
@@ -331,6 +334,18 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> {
           $_op.getResultTypes(), resultIndices, resultStorage);
       return $_op.cloneTypeWith(newArgTypes, newResultTypes);
     }
+    Type getTypeWithoutArgs(const llvm::BitVector &argIndices) {
+      SmallVector<Type> argStorage;
+      TypeRange newArgTypes = function_interface_impl::filterTypesOut(
+          $_op.getArgumentTypes(), argIndices, argStorage);
+      return $_op.cloneTypeWith(newArgTypes, $_op.getResultTypes());
+    }
+    Type getTypeWithoutResults(const llvm::BitVector &resultIndices) {
+      SmallVector<Type> resultStorage;
+      TypeRange newResultTypes = function_interface_impl::filterTypesOut(
+          $_op.getResultTypes(), resultIndices, resultStorage);
+      return $_op.cloneTypeWith($_op.getArgumentTypes(), newResultTypes);
+    }
 
     //===------------------------------------------------------------------===//
     // Argument Attributes

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index 08780db5f94df..585e873b41881 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -24,10 +24,10 @@ static void updateFuncOp(FuncOp func,
 
   // Collect information about the results will become appended arguments.
   SmallVector<Type, 6> erasedResultTypes;
-  SmallVector<unsigned, 6> erasedResultIndices;
+  llvm::BitVector erasedResultIndices(functionType.getNumResults());
   for (const auto &resultType : llvm::enumerate(functionType.getResults())) {
     if (resultType.value().isa<BaseMemRefType>()) {
-      erasedResultIndices.push_back(resultType.index());
+      erasedResultIndices.set(resultType.index());
       erasedResultTypes.push_back(resultType.value());
     }
   }
@@ -40,9 +40,11 @@ static void updateFuncOp(FuncOp func,
   func.setType(newFunctionType);
 
   // Transfer the result attributes to arg attributes.
-  for (int i = 0, e = erasedResultTypes.size(); i < e; i++)
+  auto erasedIndicesIt = erasedResultIndices.set_bits_begin();
+  for (int i = 0, e = erasedResultTypes.size(); i < e; ++i, ++erasedIndicesIt) {
     func.setArgAttrs(functionType.getNumInputs() + i,
-                     func.getResultAttrs(erasedResultIndices[i]));
+                     func.getResultAttrs(*erasedIndicesIt));
+  }
 
   // Erase the results.
   func.eraseResults(erasedResultIndices);

diff  --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 46166f16f1686..63a5962803900 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -172,8 +172,8 @@ FunctionType FunctionType::getWithArgsAndResults(
 
 /// Returns a new function type without the specified arguments and results.
 FunctionType
-FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
-                                       ArrayRef<unsigned> resultIndices) {
+FunctionType::getWithoutArgsAndResults(const llvm::BitVector &argIndices,
+                                       const llvm::BitVector &resultIndices) {
   SmallVector<Type> argStorage, resultStorage;
   TypeRange newArgTypes = function_interface_impl::filterTypesOut(
       getInputs(), argIndices, argStorage);

diff  --git a/mlir/lib/IR/FunctionInterfaces.cpp b/mlir/lib/IR/FunctionInterfaces.cpp
index 07da5ce1716f1..4f31c59f3f69e 100644
--- a/mlir/lib/IR/FunctionInterfaces.cpp
+++ b/mlir/lib/IR/FunctionInterfaces.cpp
@@ -7,26 +7,9 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/IR/FunctionInterfaces.h"
-#include "mlir/Support/LLVM.h"
-#include "llvm/ADT/BitVector.h"
 
 using namespace mlir;
 
-/// Helper to call a callback once on each index in the range
-/// [0, `totalIndices`), *except* for the indices given in `indices`.
-/// `indices` is allowed to have duplicates and can be in any order.
-inline static void iterateIndicesExcept(unsigned totalIndices,
-                                        ArrayRef<unsigned> indices,
-                                        function_ref<void(unsigned)> callback) {
-  llvm::BitVector skipIndices(totalIndices);
-  for (unsigned i : indices)
-    skipIndices.set(i);
-
-  for (unsigned i = 0; i < totalIndices; ++i)
-    if (!skipIndices.test(i))
-      callback(i);
-}
-
 //===----------------------------------------------------------------------===//
 // Tablegen Interface Definitions
 //===----------------------------------------------------------------------===//
@@ -217,8 +200,7 @@ void mlir::function_interface_impl::insertFunctionResults(
 }
 
 void mlir::function_interface_impl::eraseFunctionArguments(
-    Operation *op, ArrayRef<unsigned> argIndices, unsigned originalNumArgs,
-    Type newType) {
+    Operation *op, const llvm::BitVector &argIndices, Type newType) {
   // There are 3 things that need to be updated:
   // - Function type.
   // - Arg attrs.
@@ -229,9 +211,9 @@ void mlir::function_interface_impl::eraseFunctionArguments(
   if (auto argAttrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName())) {
     SmallVector<DictionaryAttr, 4> newArgAttrs;
     newArgAttrs.reserve(argAttrs.size());
-    iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) {
-      newArgAttrs.emplace_back(argAttrs[i].cast<DictionaryAttr>());
-    });
+    for (unsigned i = 0, e = argIndices.size(); i < e; ++i)
+      if (!argIndices[i])
+        newArgAttrs.emplace_back(argAttrs[i].cast<DictionaryAttr>());
     setAllArgAttrDicts(op, newArgAttrs);
   }
 
@@ -241,8 +223,7 @@ void mlir::function_interface_impl::eraseFunctionArguments(
 }
 
 void mlir::function_interface_impl::eraseFunctionResults(
-    Operation *op, ArrayRef<unsigned> resultIndices,
-    unsigned originalNumResults, Type newType) {
+    Operation *op, const llvm::BitVector &resultIndices, Type newType) {
   // There are 2 things that need to be updated:
   // - Function type.
   // - Result attrs.
@@ -251,9 +232,9 @@ void mlir::function_interface_impl::eraseFunctionResults(
   if (auto resAttrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName())) {
     SmallVector<DictionaryAttr, 4> newResultAttrs;
     newResultAttrs.reserve(resAttrs.size());
-    iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) {
-      newResultAttrs.emplace_back(resAttrs[i].cast<DictionaryAttr>());
-    });
+    for (unsigned i = 0, e = resultIndices.size(); i < e; ++i)
+      if (!resultIndices[i])
+        newResultAttrs.emplace_back(resAttrs[i].cast<DictionaryAttr>());
     setAllResultAttrDicts(op, newResultAttrs);
   }
 
@@ -282,12 +263,14 @@ TypeRange mlir::function_interface_impl::insertTypesInto(
 
 TypeRange
 mlir::function_interface_impl::filterTypesOut(TypeRange types,
-                                              ArrayRef<unsigned> indices,
+                                              const llvm::BitVector &indices,
                                               SmallVectorImpl<Type> &storage) {
-  if (indices.empty())
+  if (indices.none())
     return types;
-  iterateIndicesExcept(types.size(), indices,
-                       [&](unsigned i) { storage.emplace_back(types[i]); });
+
+  for (unsigned i = 0, e = types.size(); i < e; ++i)
+    if (!indices[i])
+      storage.emplace_back(types[i]);
   return storage;
 }
 

diff  --git a/mlir/test/lib/IR/TestFunc.cpp b/mlir/test/lib/IR/TestFunc.cpp
index dee9f8a5b2e52..0b4ce1d05992e 100644
--- a/mlir/test/lib/IR/TestFunc.cpp
+++ b/mlir/test/lib/IR/TestFunc.cpp
@@ -87,18 +87,10 @@ struct TestFuncEraseArg
     auto module = getOperation();
 
     for (FuncOp func : module.getOps<FuncOp>()) {
-      SmallVector<unsigned, 4> indicesToErase;
-      for (auto argIndex : llvm::seq<int>(0, func.getNumArguments())) {
-        if (func.getArgAttr(argIndex, "test.erase_this_arg")) {
-          // Push back twice to test that duplicate arg indices are handled
-          // correctly.
-          indicesToErase.push_back(argIndex);
-          indicesToErase.push_back(argIndex);
-        }
-      }
-      // Reverse the order to test that unsorted index lists are handled
-      // correctly.
-      std::reverse(indicesToErase.begin(), indicesToErase.end());
+      llvm::BitVector indicesToErase(func.getNumArguments());
+      for (auto argIndex : llvm::seq<int>(0, func.getNumArguments()))
+        if (func.getArgAttr(argIndex, "test.erase_this_arg"))
+          indicesToErase.set(argIndex);
       func.eraseArguments(indicesToErase);
     }
   }
@@ -115,18 +107,10 @@ struct TestFuncEraseResult
     auto module = getOperation();
 
     for (FuncOp func : module.getOps<FuncOp>()) {
-      SmallVector<unsigned, 4> indicesToErase;
-      for (auto resultIndex : llvm::seq<int>(0, func.getNumResults())) {
-        if (func.getResultAttr(resultIndex, "test.erase_this_result")) {
-          // Push back twice to test that duplicate indices are handled
-          // correctly.
-          indicesToErase.push_back(resultIndex);
-          indicesToErase.push_back(resultIndex);
-        }
-      }
-      // Reverse the order to test that unsorted index lists are handled
-      // correctly.
-      std::reverse(indicesToErase.begin(), indicesToErase.end());
+      llvm::BitVector indicesToErase(func.getNumResults());
+      for (auto resultIndex : llvm::seq<int>(0, func.getNumResults()))
+        if (func.getResultAttr(resultIndex, "test.erase_this_result"))
+          indicesToErase.set(resultIndex);
       func.eraseResults(indicesToErase);
     }
   }


        


More information about the Mlir-commits mailing list