[Mlir-commits] [mlir] 6635c12 - [mlir] Use SmallBitVector instead of SmallDenseSet for AffineMap::compressSymbols

Benjamin Kramer llvmlistbot at llvm.org
Sun Feb 6 15:24:44 PST 2022


Author: Benjamin Kramer
Date: 2022-02-07T00:21:44+01:00
New Revision: 6635c12ada0d2d4f2e3b579e818980d0a587135f

URL: https://github.com/llvm/llvm-project/commit/6635c12ada0d2d4f2e3b579e818980d0a587135f
DIFF: https://github.com/llvm/llvm-project/commit/6635c12ada0d2d4f2e3b579e818980d0a587135f.diff

LOG: [mlir] Use SmallBitVector instead of SmallDenseSet for AffineMap::compressSymbols

This is both more efficient and more ergonomic to use, as inverting a
bit vector is trivial while inverting a set is annoying.

Sadly this leaks into a bunch of APIs downstream, so adapt them as well.

This would be NFC (no functional change), but there is an ordering
dependency in MemRefOps's computeMemRefRankReductionMask. The result is now
deterministic; previously it depended on SmallDenseSet's unspecified
iteration order.

Differential Revision: https://reviews.llvm.org/D119076

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Arithmetic/Utils/Utils.h
    mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
    mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
    mlir/include/mlir/IR/AffineMap.h
    mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
    mlir/lib/Dialect/Arithmetic/Utils/Utils.cpp
    mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/lib/IR/AffineMap.cpp
    mlir/test/Dialect/MemRef/fold-subview-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Arithmetic/Utils/Utils.h b/mlir/include/mlir/Dialect/Arithmetic/Utils/Utils.h
index a4dcf9622098..1d4353f176b1 100644
--- a/mlir/include/mlir/Dialect/Arithmetic/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Arithmetic/Utils/Utils.h
@@ -31,8 +31,8 @@ detail::op_matcher<arith::ConstantIndexOp> matchConstantIndex();
 void canonicalizeSubViewPart(SmallVectorImpl<OpFoldResult> &values,
                              function_ref<bool(int64_t)> isDynamic);
 
-void getPositionsOfShapeOne(unsigned rank, ArrayRef<int64_t> shape,
-                            llvm::SmallDenseSet<unsigned> &dimsToProject);
+llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank,
+                                            ArrayRef<int64_t> shape);
 
 /// Pattern to rewrite a subview op with constant arguments.
 template <typename OpType, typename ResultTypeFunc, typename CastOpFunc>

diff  --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 6438e707f7c2..81839dfde9c5 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1639,7 +1639,7 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
 
     /// Return the dimensions of the source type that are dropped when
     /// the result is rank-reduced.
-    llvm::SmallDenseSet<unsigned> getDroppedDims();
+    llvm::SmallBitVector getDroppedDims();
   }];
 
   let hasCanonicalizer = 1;

diff  --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index ec597e9f650a..a4dc2953b201 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -316,8 +316,7 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
 
     /// Return the dimensions of the source that are dropped in the
     /// result when the result is rank-reduced.
-    llvm::SmallDenseSet<unsigned> getDroppedDims();
-
+    llvm::SmallBitVector getDroppedDims();
   }];
 
   let hasCanonicalizer = 1;

diff  --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index cd6fce22b614..82a33b065440 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -18,7 +18,10 @@
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/DenseMapInfo.h"
-#include "llvm/ADT/DenseSet.h"
+
+namespace llvm {
+class SmallBitVector;
+} // namespace llvm
 
 namespace mlir {
 
@@ -372,8 +375,7 @@ AffineMap compressUnusedDims(AffineMap map);
 SmallVector<AffineMap> compressUnusedDims(ArrayRef<AffineMap> maps);
 
 /// Drop the dims that are not listed in `unusedDims`.
-AffineMap compressDims(AffineMap map,
-                       const llvm::SmallDenseSet<unsigned> &unusedDims);
+AffineMap compressDims(AffineMap map, const llvm::SmallBitVector &unusedDims);
 
 /// Drop the symbols that are not used.
 AffineMap compressUnusedSymbols(AffineMap map);
@@ -385,7 +387,7 @@ SmallVector<AffineMap> compressUnusedSymbols(ArrayRef<AffineMap> maps);
 
 /// Drop the symbols that are not listed in `unusedSymbols`.
 AffineMap compressSymbols(AffineMap map,
-                          const llvm::SmallDenseSet<unsigned> &unusedSymbols);
+                          const llvm::SmallBitVector &unusedSymbols);
 
 /// Returns a map with the same dimension and symbol count as `map`, but whose
 /// results are the unique affine expressions of `map`.
@@ -521,9 +523,8 @@ AffineMap concatAffineMaps(ArrayRef<AffineMap> maps);
 ///    result               : affine_map<(d0, d1) -> (d0, 0)>
 ///
 /// This function also compresses unused symbols away.
-AffineMap
-getProjectedMap(AffineMap map,
-                const llvm::SmallDenseSet<unsigned> &projectedDimensions);
+AffineMap getProjectedMap(AffineMap map,
+                          const llvm::SmallBitVector &projectedDimensions);
 
 /// Apply a permutation from `map` to `source` and return the result.
 template <typename T>

diff  --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index a4a82d66b0cb..d21d21e045fb 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BlockAndValueMapping.h"
+#include "llvm/ADT/SmallBitVector.h"
 
 using namespace mlir;
 
@@ -1503,10 +1504,10 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
     SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
     assert(mixedSizes.size() == mixedStrides.size() &&
            "expected sizes and strides of equal length");
-    llvm::SmallDenseSet<unsigned> unusedDims = subViewOp.getDroppedDims();
+    llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
     for (int i = inferredShapeRank - 1, j = resultShapeRank - 1;
          i >= 0 && j >= 0; --i) {
-      if (unusedDims.contains(i))
+      if (unusedDims.test(i))
         continue;
 
       // `i` may overflow subViewOp.getMixedSizes because of trailing semantics.

diff  --git a/mlir/lib/Dialect/Arithmetic/Utils/Utils.cpp b/mlir/lib/Dialect/Arithmetic/Utils/Utils.cpp
index 74dd373f17cd..4e35c5b31924 100644
--- a/mlir/lib/Dialect/Arithmetic/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Utils/Utils.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "llvm/ADT/SmallBitVector.h"
 
 using namespace mlir;
 
@@ -37,16 +38,16 @@ void mlir::canonicalizeSubViewPart(
   }
 }
 
-void mlir::getPositionsOfShapeOne(
-    unsigned rank, ArrayRef<int64_t> shape,
-    llvm::SmallDenseSet<unsigned> &dimsToProject) {
-  dimsToProject.reserve(rank);
+llvm::SmallBitVector mlir::getPositionsOfShapeOne(unsigned rank,
+                                                  ArrayRef<int64_t> shape) {
+  llvm::SmallBitVector dimsToProject(shape.size());
   for (unsigned pos = 0, e = shape.size(); pos < e && rank > 0; ++pos) {
     if (shape[pos] == 1) {
-      dimsToProject.insert(pos);
+      dimsToProject.set(pos);
       --rank;
     }
   }
+  return dimsToProject;
 }
 
 Value mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 9eda19ae3f4b..32d8ee098bce 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -520,10 +520,10 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
 /// Prune all dimensions that are of reduction iterator type from `map`.
 static AffineMap pruneReductionDimsFromMap(ArrayRef<Attribute> iteratorTypes,
                                            AffineMap map) {
-  llvm::SmallDenseSet<unsigned> projectedDims;
+  llvm::SmallBitVector projectedDims(iteratorTypes.size());
   for (const auto &attr : llvm::enumerate(iteratorTypes)) {
     if (!isParallelIterator(attr.value()))
-      projectedDims.insert(attr.index());
+      projectedDims.set(attr.index());
   }
   return getProjectedMap(map, projectedDims);
 }

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index d8e00ca5d4ed..65a0ed8ec68f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -187,7 +187,7 @@ static LogicalResult padOperandToSmallestStaticBoundingBox(
     return failure(hasDynamicShape);
 
   // Compute the dropped dimensions if `sliceOp` is ranke-reducing.
-  llvm::SmallDenseSet<unsigned> droppedDims = sliceOp.getDroppedDims();
+  llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims();
 
   // Upper bound the `sliceOp` sizes to obtain a static bounding box.
   SmallVector<int64_t> staticSizes;
@@ -195,7 +195,7 @@ static LogicalResult padOperandToSmallestStaticBoundingBox(
   auto shapedOp = cast<OffsetSizeAndStrideOpInterface>(sliceOp.getOperation());
   for (const auto &en : enumerate(shapedOp.getMixedSizes())) {
     // Skip dropped dimensions.
-    if (droppedDims.contains(en.index()))
+    if (droppedDims.test(en.index()))
       continue;
     // If the size is an attribute add it directly to `staticSizes`.
     if (en.value().is<Attribute>()) {

diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index dad063ea30bc..21cb8d60f2d3 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -20,6 +20,7 @@
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallBitVector.h"
 
 using namespace mlir;
 using namespace mlir::memref;
@@ -590,17 +591,17 @@ static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
 /// This accounts for cases where there are multiple unit-dims, but only a
 /// subset of those are dropped. For MemRefTypes these can be disambiguated
 /// using the strides. If a dimension is dropped the stride must be dropped too.
-static llvm::Optional<llvm::SmallDenseSet<unsigned>>
+static llvm::Optional<llvm::SmallBitVector>
 computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
                                ArrayRef<OpFoldResult> sizes) {
-  llvm::SmallDenseSet<unsigned> unusedDims;
+  llvm::SmallBitVector unusedDims(originalType.getRank());
   if (originalType.getRank() == reducedType.getRank())
     return unusedDims;
 
   for (const auto &dim : llvm::enumerate(sizes))
     if (auto attr = dim.value().dyn_cast<Attribute>())
       if (attr.cast<IntegerAttr>().getInt() == 1)
-        unusedDims.insert(dim.index());
+        unusedDims.set(dim.index());
 
   SmallVector<int64_t> originalStrides, candidateStrides;
   int64_t originalOffset, candidateOffset;
@@ -623,8 +624,9 @@ computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
       getNumOccurences(originalStrides);
   std::map<int64_t, unsigned> candidateStridesNumOccurences =
       getNumOccurences(candidateStrides);
-  llvm::SmallDenseSet<unsigned> prunedUnusedDims;
-  for (unsigned dim : unusedDims) {
+  for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
+    if (!unusedDims.test(dim))
+      continue;
     int64_t originalStride = originalStrides[dim];
     if (currUnaccountedStrides[originalStride] >
         candidateStridesNumOccurences[originalStride]) {
@@ -635,7 +637,7 @@ computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
     if (currUnaccountedStrides[originalStride] ==
         candidateStridesNumOccurences[originalStride]) {
       // The stride for this is not dropped. Keep as is.
-      prunedUnusedDims.insert(dim);
+      unusedDims.reset(dim);
       continue;
     }
     if (currUnaccountedStrides[originalStride] <
@@ -646,17 +648,16 @@ computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
     }
   }
 
-  for (auto prunedDim : prunedUnusedDims)
-    unusedDims.erase(prunedDim);
-  if (unusedDims.size() + reducedType.getRank() != originalType.getRank())
+  if ((int64_t)unusedDims.count() + reducedType.getRank() !=
+      originalType.getRank())
     return llvm::None;
   return unusedDims;
 }
 
-llvm::SmallDenseSet<unsigned> SubViewOp::getDroppedDims() {
+llvm::SmallBitVector SubViewOp::getDroppedDims() {
   MemRefType sourceType = getSourceType();
   MemRefType resultType = getType();
-  llvm::Optional<llvm::SmallDenseSet<unsigned>> unusedDims =
+  llvm::Optional<llvm::SmallBitVector> unusedDims =
       computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
   assert(unusedDims && "unable to find unused dims of subview");
   return *unusedDims;
@@ -698,12 +699,12 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
              memrefType.getDynamicDimIndex(unsignedIndex));
 
   if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
-    llvm::SmallDenseSet<unsigned> unusedDims = subview.getDroppedDims();
+    llvm::SmallBitVector unusedDims = subview.getDroppedDims();
     unsigned resultIndex = 0;
     unsigned sourceRank = subview.getSourceType().getRank();
     unsigned sourceIndex = 0;
     for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
-      if (unusedDims.count(i))
+      if (unusedDims.test(i))
         continue;
       if (resultIndex == unsignedIndex) {
         sourceIndex = i;
@@ -1734,11 +1735,11 @@ Type SubViewOp::inferRankReducedResultType(unsigned resultRank,
   int rankDiff = inferredType.getRank() - resultRank;
   if (rankDiff > 0) {
     auto shape = inferredType.getShape();
-    llvm::SmallDenseSet<unsigned> dimsToProject;
-    mlir::getPositionsOfShapeOne(rankDiff, shape, dimsToProject);
+    llvm::SmallBitVector dimsToProject =
+        getPositionsOfShapeOne(rankDiff, shape);
     SmallVector<int64_t> projectedShape;
     for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
-      if (!dimsToProject.contains(pos))
+      if (!dimsToProject.test(pos))
         projectedShape.push_back(shape[pos]);
 
     AffineMap map = inferredType.getLayout().getAffineMap();
@@ -2015,7 +2016,7 @@ static MemRefType getCanonicalSubViewResultType(
   auto nonRankReducedType = SubViewOp::inferResultType(sourceType, mixedOffsets,
                                                        mixedSizes, mixedStrides)
                                 .cast<MemRefType>();
-  llvm::Optional<llvm::SmallDenseSet<unsigned>> unusedDims =
+  llvm::Optional<llvm::SmallBitVector> unusedDims =
       computeMemRefRankReductionMask(currentSourceType, currentResultType,
                                      mixedSizes);
   // Return nullptr as failure mode.
@@ -2023,7 +2024,7 @@ static MemRefType getCanonicalSubViewResultType(
     return nullptr;
   SmallVector<int64_t> shape;
   for (const auto &sizes : llvm::enumerate(nonRankReducedType.getShape())) {
-    if (unusedDims->count(sizes.index()))
+    if (unusedDims->test(sizes.index()))
       continue;
     shape.push_back(sizes.value());
   }

diff  --git a/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
index 213b442998b0..5080a69688b9 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallBitVector.h"
 
 using namespace mlir;
 
@@ -50,9 +51,9 @@ resolveSourceIndices(Location loc, PatternRewriter &rewriter,
   // Check if this is rank-reducing case. Then for every unit-dim size add a
   // zero to the indices.
   unsigned resultDim = 0;
-  llvm::SmallDenseSet<unsigned> unusedDims = subViewOp.getDroppedDims();
+  llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
   for (auto dim : llvm::seq<unsigned>(0, subViewOp.getSourceType().getRank())) {
-    if (unusedDims.count(dim))
+    if (unusedDims.test(dim))
       useIndices.push_back(rewriter.create<arith::ConstantIndexOp>(loc, 0));
     else
       useIndices.push_back(indices[resultDim++]);
@@ -106,11 +107,11 @@ static Value getMemRefOperand(vector::TransferWriteOp op) {
 static AffineMapAttr getPermutationMapAttr(MLIRContext *context,
                                            memref::SubViewOp subViewOp,
                                            AffineMap currPermutationMap) {
-  llvm::SmallDenseSet<unsigned> unusedDims = subViewOp.getDroppedDims();
+  llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
   SmallVector<AffineExpr> exprs;
   int64_t sourceRank = subViewOp.getSourceType().getRank();
   for (auto dim : llvm::seq<int64_t>(0, sourceRank)) {
-    if (unusedDims.count(dim))
+    if (unusedDims.test(dim))
       continue;
     exprs.push_back(getAffineDimExpr(dim, context));
   }

diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 58233d6dd6ff..a26ad993b495 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -19,6 +19,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallBitVector.h"
 
 using namespace mlir;
 using namespace mlir::tensor;
@@ -935,11 +936,11 @@ RankedTensorType ExtractSliceOp::inferRankReducedResultType(
   int rankDiff = inferredType.getRank() - resultRank;
   if (rankDiff > 0) {
     auto shape = inferredType.getShape();
-    llvm::SmallDenseSet<unsigned> dimsToProject;
-    mlir::getPositionsOfShapeOne(rankDiff, shape, dimsToProject);
+    llvm::SmallBitVector dimsToProject =
+        getPositionsOfShapeOne(rankDiff, shape);
     SmallVector<int64_t> projectedShape;
     for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
-      if (!dimsToProject.contains(pos))
+      if (!dimsToProject.test(pos))
         projectedShape.push_back(shape[pos]);
     inferredType =
         RankedTensorType::get(projectedShape, inferredType.getElementType());
@@ -1076,10 +1077,10 @@ getCanonicalSliceResultType(unsigned resultRank, RankedTensorType sourceType,
   return resultType;
 }
 
-llvm::SmallDenseSet<unsigned> ExtractSliceOp::getDroppedDims() {
-  llvm::SmallDenseSet<unsigned> droppedDims;
+llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
   ArrayRef<int64_t> resultShape = getType().getShape();
   SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
+  llvm::SmallBitVector droppedDims(mixedSizes.size());
   unsigned shapePos = 0;
   for (const auto &size : enumerate(mixedSizes)) {
     Optional<int64_t> sizeVal = getConstantIntValue(size.value());
@@ -1091,7 +1092,7 @@ llvm::SmallDenseSet<unsigned> ExtractSliceOp::getDroppedDims() {
       shapePos++;
       continue;
     }
-    droppedDims.insert(size.index());
+    droppedDims.set(size.index());
   }
   return droppedDims;
 }
@@ -1101,10 +1102,10 @@ LogicalResult ExtractSliceOp::reifyResultShapes(
   reifiedReturnShapes.resize(1);
   reifiedReturnShapes[0].reserve(getType().getRank());
   SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
-  llvm::SmallDenseSet<unsigned> droppedDims = getDroppedDims();
+  llvm::SmallBitVector droppedDims = getDroppedDims();
   Location loc = getLoc();
   for (const auto &size : enumerate(mixedSizes)) {
-    if (droppedDims.count(size.index()))
+    if (droppedDims.test(size.index()))
       continue;
     if (auto attr = size.value().dyn_cast<Attribute>()) {
       reifiedReturnShapes[0].push_back(builder.create<arith::ConstantIndexOp>(

diff  --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index ecdf8376b5fc..7fbe7488b97e 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -547,13 +547,13 @@ AffineMap AffineMap::getMinorSubMap(unsigned numResults) const {
 }
 
 AffineMap mlir::compressDims(AffineMap map,
-                             const llvm::SmallDenseSet<unsigned> &unusedDims) {
+                             const llvm::SmallBitVector &unusedDims) {
   unsigned numDims = 0;
   SmallVector<AffineExpr> dimReplacements;
   dimReplacements.reserve(map.getNumDims());
   MLIRContext *context = map.getContext();
   for (unsigned dim = 0, e = map.getNumDims(); dim < e; ++dim) {
-    if (unusedDims.contains(dim))
+    if (unusedDims.test(dim))
       dimReplacements.push_back(getAffineConstantExpr(0, context));
     else
       dimReplacements.push_back(getAffineDimExpr(numDims++, context));
@@ -566,15 +566,11 @@ AffineMap mlir::compressDims(AffineMap map,
 }
 
 AffineMap mlir::compressUnusedDims(AffineMap map) {
-  llvm::SmallDenseSet<unsigned> usedDims;
+  llvm::SmallBitVector unusedDims(map.getNumDims(), true);
   map.walkExprs([&](AffineExpr expr) {
     if (auto dimExpr = expr.dyn_cast<AffineDimExpr>())
-      usedDims.insert(dimExpr.getPosition());
+      unusedDims.reset(dimExpr.getPosition());
   });
-  llvm::SmallDenseSet<unsigned> unusedDims;
-  for (unsigned d = 0, e = map.getNumDims(); d != e; ++d)
-    if (!usedDims.contains(d))
-      unusedDims.insert(d);
   return compressDims(map, unusedDims);
 }
 
@@ -613,15 +609,14 @@ SmallVector<AffineMap> mlir::compressUnusedDims(ArrayRef<AffineMap> maps) {
                             [](AffineMap m) { return compressUnusedDims(m); });
 }
 
-AffineMap
-mlir::compressSymbols(AffineMap map,
-                      const llvm::SmallDenseSet<unsigned> &unusedSymbols) {
+AffineMap mlir::compressSymbols(AffineMap map,
+                                const llvm::SmallBitVector &unusedSymbols) {
   unsigned numSymbols = 0;
   SmallVector<AffineExpr> symReplacements;
   symReplacements.reserve(map.getNumSymbols());
   MLIRContext *context = map.getContext();
   for (unsigned sym = 0, e = map.getNumSymbols(); sym < e; ++sym) {
-    if (unusedSymbols.contains(sym))
+    if (unusedSymbols.test(sym))
       symReplacements.push_back(getAffineConstantExpr(0, context));
     else
       symReplacements.push_back(getAffineSymbolExpr(numSymbols++, context));
@@ -634,15 +629,11 @@ mlir::compressSymbols(AffineMap map,
 }
 
 AffineMap mlir::compressUnusedSymbols(AffineMap map) {
-  llvm::SmallDenseSet<unsigned> usedSymbols;
+  llvm::SmallBitVector unusedSymbols(map.getNumSymbols(), true);
   map.walkExprs([&](AffineExpr expr) {
     if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>())
-      usedSymbols.insert(symExpr.getPosition());
+      unusedSymbols.reset(symExpr.getPosition());
   });
-  llvm::SmallDenseSet<unsigned> unusedSymbols;
-  for (unsigned d = 0, e = map.getNumSymbols(); d != e; ++d)
-    if (!usedSymbols.contains(d))
-      unusedSymbols.insert(d);
   return compressSymbols(map, unusedSymbols);
 }
 
@@ -732,9 +723,8 @@ AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps) {
                         maps.front().getContext());
 }
 
-AffineMap
-mlir::getProjectedMap(AffineMap map,
-                      const llvm::SmallDenseSet<unsigned> &unusedDims) {
+AffineMap mlir::getProjectedMap(AffineMap map,
+                                const llvm::SmallBitVector &unusedDims) {
   return compressUnusedSymbols(compressDims(map, unusedDims));
 }
 

diff  --git a/mlir/test/Dialect/MemRef/fold-subview-ops.mlir b/mlir/test/Dialect/MemRef/fold-subview-ops.mlir
index fc06cd35cc8c..6f8e46fab444 100644
--- a/mlir/test/Dialect/MemRef/fold-subview-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-subview-ops.mlir
@@ -154,8 +154,8 @@ func @fold_rank_reducing_subview_with_load
 // CHECK-SAME:   %[[ARG16:[a-zA-Z0-9_]+]]: index
 //  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
 //  CHECK-DAG:   %[[I1:.+]] = affine.apply #[[MAP]](%[[ARG13]])[%[[ARG7]], %[[ARG1]]]
-//  CHECK-DAG:   %[[I2:.+]] = affine.apply #[[MAP]](%[[ARG14]])[%[[ARG8]], %[[ARG2]]]
-//  CHECK-DAG:   %[[I3:.+]] = affine.apply #[[MAP]](%[[C0]])[%[[ARG9]], %[[ARG3]]]
+//  CHECK-DAG:   %[[I2:.+]] = affine.apply #[[MAP]](%[[C0]])[%[[ARG8]], %[[ARG2]]]
+//  CHECK-DAG:   %[[I3:.+]] = affine.apply #[[MAP]](%[[ARG14]])[%[[ARG9]], %[[ARG3]]]
 //  CHECK-DAG:   %[[I4:.+]] = affine.apply #[[MAP]](%[[ARG15]])[%[[ARG10]], %[[ARG4]]]
 //  CHECK-DAG:   %[[I5:.+]] = affine.apply #[[MAP]](%[[ARG16]])[%[[ARG11]], %[[ARG5]]]
 //  CHECK-DAG:   %[[I6:.+]] = affine.apply #[[MAP]](%[[C0]])[%[[ARG12]], %[[ARG6]]]


        


More information about the Mlir-commits mailing list