[Mlir-commits] [mlir] [mlir] Add first-class support for scalability in VectorType dims (PR #74251)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Dec 4 03:55:06 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Benjamin Maxwell (MacDue)

<details>
<summary>Changes</summary>

Currently, the shape of a VectorType is stored in two separate lists: the 'shape', which comes from ShapedType and has no way to represent scalability, and the 'scalableDims', an additional list of bools attached to VectorType. This can be somewhat cumbersome to work with, and it makes it easy to ignore the scalability of a dim, producing incorrect results.

For example, to correctly trim leading unit dims of a VectorType, currently, you need to do something like:

```c++
while (!newShape.empty() && newShape.front() == 1 &&
        !newScalableDims.front()) {
  newShape = newShape.drop_front(1);
  newScalableDims = newScalableDims.drop_front(1);
}
```

It would be wrong if you (more naturally) wrote it as:

```c++
auto newShape = vectorType.getShape().drop_while([](int64_t dim) {
  return dim == 1;
});
```

This would trim scalable one dims (`[1]`), which are not unit dims like their fixed counterparts.
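
To make the pitfall concrete, here is a minimal standalone sketch that does not use MLIR at all. `Shape`, `countLeadingUnitDimsNaive`, and `countLeadingUnitDims` are hypothetical names invented for illustration; the two vectors only mimic the parallel size and scalability lists described above.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the two parallel lists kept by a vector type.
struct Shape {
  std::vector<int64_t> sizes;
  std::vector<bool> scalable;
};

// Naive version: inspects sizes only, so a scalable `[1]` dim is also
// counted as a unit dim.
inline size_t countLeadingUnitDimsNaive(const Shape &shape) {
  size_t n = 0;
  while (n < shape.sizes.size() && shape.sizes[n] == 1)
    ++n;
  return n;
}

// Scalability-aware version: a dim is a unit dim only if it is 1 and fixed.
inline size_t countLeadingUnitDims(const Shape &shape) {
  assert(shape.sizes.size() == shape.scalable.size() &&
         "expected one scalability flag per dim");
  size_t n = 0;
  while (n < shape.sizes.size() && shape.sizes[n] == 1 && !shape.scalable[n])
    ++n;
  return n;
}
```

For a shape like `1x[1]x4`, the naive version counts two leading unit dims, while the scalability-aware version counts one.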

---

This patch does not change the storage of the VectorType, but instead adds new scalability-safe accessors and iterators.

Two new methods are added to VectorType:

```c++
/// Returns the value of the specified dimension (including scalability)
VectorDim VectorType::getDim(unsigned idx);

/// Returns the dimensions of this vector type (including scalability)
VectorDims VectorType::getDims();
```

These are backed by two new classes: `VectorDim` and `VectorDims`.

`VectorDim` represents a single dimension of a VectorType. It can be a fixed or scalable quantity. It cannot be implicitly converted to/from an integer, so you must specify the kind of quantity you expect in comparisons.
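
The "no implicit conversion" property can be illustrated with a small standalone sketch. `SketchDim` is a hypothetical, cut-down imitation of the idea (it is not the class from the patch, and it needs no MLIR headers); the `static_assert`s check that an integer cannot silently become a dim, or the other way round.

```cpp
#include <cassert>
#include <cstdint>
#include <type_traits>

// Hypothetical stand-in for a dim that is either fixed or scalable.
class SketchDim {
public:
  explicit constexpr SketchDim(int64_t quantity, bool scalable)
      : quantity(quantity), scalable(scalable) {}

  static constexpr SketchDim getFixed(int64_t q) { return SketchDim(q, false); }
  static constexpr SketchDim getScalable(int64_t q) { return SketchDim(q, true); }

  constexpr bool isScalable() const { return scalable; }
  constexpr int64_t getMinSize() const { return quantity; }

  // Two dims are equal only if both the size and the scalability match.
  constexpr bool operator==(const SketchDim &other) const {
    return quantity == other.quantity && scalable == other.scalable;
  }
  constexpr bool operator!=(const SketchDim &other) const {
    return !(*this == other);
  }

private:
  int64_t quantity;
  bool scalable;
};

// No implicit conversions in either direction, so `dim == 1` cannot compile.
static_assert(!std::is_convertible<int64_t, SketchDim>::value,
              "an integer must not implicitly become a dim");
static_assert(!std::is_convertible<SketchDim, int64_t>::value,
              "a dim must not implicitly become an integer");
static_assert(SketchDim::getFixed(1) != SketchDim::getScalable(1),
              "fixed 1 and scalable [1] are different dims");
```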

`VectorDims` represents a non-owning list of vector dimensions, backed by separate size and scalability lists (matching the storage of VectorType). This class has an iterator and a few common helper methods (similar to those of ArrayRef).
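
As a rough illustration of what "non-owning list backed by two separate lists" means, the following standalone sketch builds a tiny view over two parallel arrays. `DimEntry` and `DimsView` are hypothetical names, and the raw-pointer layout is an assumption made to keep the example free of LLVM headers; it is not how the patch implements `VectorDims`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>

// Hypothetical dim value: (minimum size, scalable flag).
using DimEntry = std::pair<int64_t, bool>;

// Non-owning view over two parallel arrays. The caller must keep both
// arrays alive for as long as the view is used.
struct DimsView {
  const int64_t *sizes;
  const bool *scalable;
  size_t count;

  size_t size() const { return count; }

  // Build a dim value on the fly from the two underlying arrays.
  DimEntry operator[](size_t i) const { return DimEntry(sizes[i], scalable[i]); }

  // Skip the first `n` dims and keep the following `m` dims.
  DimsView slice(size_t n, size_t m) const {
    assert(n + m <= count && "slice out of range");
    return DimsView{sizes + n, scalable + n, m};
  }

  DimsView dropFront(size_t n = 1) const { return slice(n, count - n); }

  // Drop leading dims for as long as the predicate holds.
  template <class Pred> DimsView dropWhile(Pred pred) const {
    size_t n = 0;
    while (n < count && pred((*this)[n]))
      ++n;
    return dropFront(n);
  }
};
```

Because slicing only adjusts pointers and a count, both lists always move together, which is the property that makes forgetting the scalability list impossible.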

There are also new builders to construct VectorTypes from both the `VectorDims` class and an `ArrayRef<VectorDim>`.

With these changes the previous example becomes:

```c++
auto newDims = vectorType.getDims().dropWhile([](VectorDim dim) {
  return dim == VectorDim::getFixed(1);
});
```

This is (to me) easier to read, and safer, as it is not possible to forget to check the scalability of the dim. Just comparing with `1` would fail to build.

---
Full diff: https://github.com/llvm/llvm-project/pull/74251.diff


3 Files Affected:

- (modified) mlir/include/mlir/IR/BuiltinTypes.h (+192) 
- (modified) mlir/include/mlir/IR/BuiltinTypes.td (+23) 
- (modified) mlir/unittests/IR/ShapedTypeTest.cpp (+101) 


``````````diff
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 92ce053ad5c82..b468fd42f374e 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -12,6 +12,7 @@
 #include "mlir/IR/BuiltinAttributeInterfaces.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/Support/ADTExtras.h"
+#include "llvm/ADT/STLExtras.h"
 
 namespace llvm {
 class BitVector;
@@ -181,6 +182,197 @@ class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
   operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
 };
 
+//===----------------------------------------------------------------------===//
+// VectorDim
+//===----------------------------------------------------------------------===//
+
+/// This class represents a dimension of a vector type. Unlike other ShapedTypes
+/// vector dimensions can have scalable quantities, which means the dimension
+/// has a known minimum size, which is scaled by a constant that is only
+/// known at runtime.
+class VectorDim {
+public:
+  explicit constexpr VectorDim(int64_t quantity, bool scalable)
+      : quantity(quantity), scalable(scalable){};
+
+  /// Constructs a new fixed dimension.
+  constexpr static VectorDim getFixed(int64_t quantity) {
+    return VectorDim(quantity, false);
+  }
+
+  /// Constructs a new scalable dimension.
+  constexpr static VectorDim getScalable(int64_t quantity) {
+    return VectorDim(quantity, true);
+  }
+
+  /// Returns true if this dimension is scalable;
+  constexpr bool isScalable() const { return scalable; }
+
+  /// Returns true if this dimension is fixed.
+  constexpr bool isFixed() const { return !isScalable(); }
+
+  /// Returns the minimum number of elements this dimension can contain.
+  constexpr int64_t getMinSize() const { return quantity; }
+
+  /// If this dimension is fixed returns the number of elements, otherwise
+  /// aborts.
+  constexpr int64_t getFixedSize() const {
+    assert(isFixed());
+    return quantity;
+  }
+
+  constexpr bool operator==(VectorDim const &dim) const {
+    return quantity == dim.quantity && scalable == dim.scalable;
+  }
+
+  constexpr bool operator!=(VectorDim const &dim) const {
+    return !(*this == dim);
+  }
+
+  /// Helper class for indexing into a list of sizes (and possibly empty) list
+  /// of scalable dimensions, extracting VectorDim elements.
+  struct Indexer {
+    explicit Indexer(ArrayRef<int64_t> sizes, ArrayRef<bool> scalableDims)
+        : sizes(sizes), scalableDims(scalableDims) {
+      assert(
+          scalableDims.empty() ||
+          sizes.size() == scalableDims.size() &&
+              "expected `scalableDims` to be empty or match `sizes` in length");
+    }
+
+    VectorDim operator[](size_t idx) const {
+      int64_t size = sizes[idx];
+      bool scalable = scalableDims.empty() ? false : scalableDims[idx];
+      return VectorDim(size, scalable);
+    }
+
+    ArrayRef<int64_t> sizes;
+    ArrayRef<bool> scalableDims;
+  };
+
+private:
+  int64_t quantity;
+  bool scalable;
+};
+
+//===----------------------------------------------------------------------===//
+// VectorDims
+//===----------------------------------------------------------------------===//
+
+/// Represents a non-owning list of vector dimensions. The underlying dimension
+/// sizes and scalability flags are stored a two seperate lists to match the
+/// storage of a VectorType.
+class VectorDims : public VectorDim::Indexer {
+public:
+  using VectorDim::Indexer::Indexer;
+
+  class Iterator : public llvm::iterator_facade_base<
+                       Iterator, std::random_access_iterator_tag, VectorDim,
+                       std::ptrdiff_t, VectorDim, VectorDim> {
+  public:
+    Iterator(VectorDim::Indexer indexer, size_t index)
+        : indexer(indexer), index(index){};
+
+    // Iterator boilerplate.
+    ptrdiff_t operator-(const Iterator &rhs) const { return index - rhs.index; }
+    bool operator==(const Iterator &rhs) const { return index == rhs.index; }
+    bool operator<(const Iterator &rhs) const { return index < rhs.index; }
+    Iterator &operator+=(ptrdiff_t offset) {
+      index += offset;
+      return *this;
+    }
+    Iterator &operator-=(ptrdiff_t offset) {
+      index -= offset;
+      return *this;
+    }
+    VectorDim operator*() const { return indexer[index]; }
+
+    VectorDim::Indexer getIndexer() const { return indexer; }
+    ptrdiff_t getIndex() const { return index; }
+
+  private:
+    VectorDim::Indexer indexer;
+    ptrdiff_t index;
+  };
+
+  /// Construct from iterator pair.
+  VectorDims(Iterator begin, Iterator end)
+      : VectorDims(VectorDims(begin.getIndexer())
+                       .slice(begin.getIndex(), end - begin)) {}
+
+  VectorDims(VectorDim::Indexer indexer) : VectorDim::Indexer(indexer){};
+
+  Iterator begin() const { return Iterator(*this, 0); }
+  Iterator end() const { return Iterator(*this, size()); }
+
+  /// Check if the dims are empty.
+  bool empty() const { return sizes.empty(); }
+
+  /// Get the number of dims.
+  size_t size() const { return sizes.size(); }
+
+  /// Return the first dim.
+  VectorDim front() const { return (*this)[0]; }
+
+  /// Return the last dim.
+  VectorDim back() const { return (*this)[size() - 1]; }
+
+  /// Chop of thie first \p n dims, and keep the remaining \p m
+  /// dims.
+  VectorDims slice(size_t n, size_t m) const {
+    ArrayRef<int64_t> newSizes = sizes.slice(n, m);
+    ArrayRef<bool> newScalableDims =
+        scalableDims.empty() ? ArrayRef<bool>{} : scalableDims.slice(n, m);
+    return VectorDims(newSizes, newScalableDims);
+  }
+
+  /// Drop the first \p n dims.
+  VectorDims dropFront(size_t n = 1) const { return slice(n, size() - n); }
+
+  /// Drop the last \p n dims.
+  VectorDims dropBack(size_t n = 1) const { return slice(0, size() - n); }
+
+  /// Return copy of *this with the first n dims matching the predicate removed.
+  template <class PredicateT>
+  VectorDims dropWhile(PredicateT predicate) const {
+    return VectorDims(llvm::find_if_not(*this, predicate), end());
+  }
+
+  /// Return the underlying sizes.
+  ArrayRef<int64_t> getSizes() const { return sizes; }
+
+  /// Return the underlying scalable dims.
+  ArrayRef<bool> getScalableDims() const { return scalableDims; }
+
+  /// Check for dim equality.
+  bool equals(VectorDims rhs) const {
+    if (size() != rhs.size())
+      return false;
+    return std::equal(begin(), end(), rhs.begin());
+  }
+
+  /// Check for dim equality.
+  bool equals(ArrayRef<VectorDim> rhs) const {
+    if (size() != rhs.size())
+      return false;
+    return std::equal(begin(), end(), rhs.begin());
+  }
+};
+
+inline bool operator==(VectorDims lhs, VectorDims rhs) {
+  return lhs.equals(rhs);
+}
+
+inline bool operator!=(VectorDims lhs, VectorDims rhs) { return !(lhs == rhs); }
+
+inline bool operator==(VectorDims lhs, ArrayRef<VectorDim> rhs) {
+  return lhs.equals(rhs);
+}
+
+inline bool operator!=(VectorDims lhs, ArrayRef<VectorDim> rhs) {
+  return !(lhs == rhs);
+}
+
 } // namespace mlir
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 1d7772810ae6e..d6cd14079fab8 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -1089,6 +1089,18 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector", [ShapedTypeInterface], "Ty
         scalableDims = isScalableVec;
       }
       return $_get(elementType.getContext(), shape, elementType, scalableDims);
+    }]>,
+    TypeBuilderWithInferredContext<(ins "Type":$elementType, "ArrayRef<VectorDim>": $shape), [{
+      SmallVector<int64_t> sizes;
+      SmallVector<bool> scalableDims;
+      for (VectorDim dim : shape) {
+        sizes.push_back(dim.getMinSize());
+        scalableDims.push_back(dim.isScalable());
+      }
+      return get(sizes, elementType, scalableDims);
+    }]>,
+    TypeBuilderWithInferredContext<(ins "Type":$elementType, "VectorDims": $shape), [{
+      return get(shape.getSizes(), elementType, shape.getScalableDims());
     }]>
   ];
   let extraClassDeclaration = [{
@@ -1096,6 +1108,17 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector", [ShapedTypeInterface], "Ty
     /// Arguments that are passed into the builder must outlive the builder.
     class Builder;
 
+    /// Returns the value of the specified dimension (including scalability).
+    VectorDim getDim(unsigned idx) const {
+      assert(idx < getRank() && "invalid dim index for vector type");
+      return getDims()[idx];
+    }
+
+    /// Returns the dimensions of this vector type (including scalability).
+    VectorDims getDims() const {
+      return VectorDims(getShape(), getScalableDims());
+    }
+
     /// Returns true if the given type can be used as an element of a vector
     /// type. In particular, vectors can consist of integer, index, or float
     /// primitives.
diff --git a/mlir/unittests/IR/ShapedTypeTest.cpp b/mlir/unittests/IR/ShapedTypeTest.cpp
index 61264bc523648..07625da6ee889 100644
--- a/mlir/unittests/IR/ShapedTypeTest.cpp
+++ b/mlir/unittests/IR/ShapedTypeTest.cpp
@@ -226,4 +226,105 @@ TEST(ShapedTypeTest, RankedTensorTypeBuilder) {
   }
 }
 
+TEST(ShapedTypeTest, VectorDims) {
+  MLIRContext context;
+  Type f32 = FloatType::getF32(&context);
+
+  SmallVector<VectorDim> dims{VectorDim::getFixed(2), VectorDim::getScalable(4),
+                              VectorDim::getFixed(8), VectorDim::getScalable(9),
+                              VectorDim::getFixed(1)};
+  VectorType vectorType = VectorType::get(f32, dims);
+
+  // Directly check values
+  {
+    auto dim0 = vectorType.getDim(0);
+    ASSERT_EQ(dim0.getMinSize(), 2);
+    ASSERT_TRUE(dim0.isFixed());
+
+    auto dim1 = vectorType.getDim(1);
+    ASSERT_EQ(dim1.getMinSize(), 4);
+    ASSERT_TRUE(dim1.isScalable());
+
+    auto dim2 = vectorType.getDim(2);
+    ASSERT_EQ(dim2.getMinSize(), 8);
+    ASSERT_TRUE(dim2.isFixed());
+
+    auto dim3 = vectorType.getDim(3);
+    ASSERT_EQ(dim3.getMinSize(), 9);
+    ASSERT_TRUE(dim3.isScalable());
+
+    auto dim4 = vectorType.getDim(4);
+    ASSERT_EQ(dim4.getMinSize(), 1);
+    ASSERT_TRUE(dim4.isFixed());
+  }
+
+  // Test indexing via getDim(idx)
+  {
+    for (unsigned i = 0; i < dims.size(); i++)
+      ASSERT_EQ(vectorType.getDim(i), dims[i]);
+  }
+
+  // Test using VectorDims::Iterator in for-each loop
+  {
+    unsigned i = 0;
+    for (VectorDim dim : vectorType.getDims())
+      ASSERT_EQ(dim, dims[i++]);
+    ASSERT_EQ(i, vectorType.getRank());
+  }
+
+  // Test using VectorDims::Iterator in LLVM iterator helper
+  {
+    for (auto [dim, expectedDim] :
+         llvm::zip_equal(vectorType.getDims(), dims)) {
+      ASSERT_EQ(dim, expectedDim);
+    }
+  }
+
+  // Test dropFront()
+  {
+    auto vectorDims = vectorType.getDims();
+    auto newDims = vectorDims.dropFront();
+
+    ASSERT_EQ(newDims.size(), vectorDims.size() - 1);
+    for (unsigned i = 0; i < newDims.size(); i++)
+      ASSERT_EQ(newDims[i], vectorDims[i + 1]);
+  }
+
+  // Test dropBack()
+  {
+    auto vectorDims = vectorType.getDims();
+    auto newDims = vectorDims.dropBack();
+
+    ASSERT_EQ(newDims.size(), vectorDims.size() - 1);
+    for (unsigned i = 0; i < newDims.size(); i++)
+      ASSERT_EQ(newDims[i], vectorDims[i]);
+  }
+
+  // Test front()
+  { ASSERT_EQ(vectorType.getDims().front(), VectorDim::getFixed(2)); }
+
+  // Test back()
+  { ASSERT_EQ(vectorType.getDims().back(), VectorDim::getFixed(1)); }
+
+  // Test dropWhile.
+  {
+    SmallVector<VectorDim> dims{
+        VectorDim::getFixed(1), VectorDim::getFixed(1), VectorDim::getFixed(1),
+        VectorDim::getScalable(1), VectorDim::getScalable(4)};
+
+    VectorType vectorTypeWithLeadingUnitDims = VectorType::get(f32, dims);
+    ASSERT_EQ(vectorTypeWithLeadingUnitDims.getDims().size(),
+              unsigned(vectorTypeWithLeadingUnitDims.getRank()));
+
+    // Drop leading unit dims.
+    auto withoutLeadingUnitDims =
+        vectorTypeWithLeadingUnitDims.getDims().dropWhile(
+            [](VectorDim dim) { return dim == VectorDim::getFixed(1); });
+
+    SmallVector<VectorDim> expectedDims{VectorDim::getScalable(1),
+                                        VectorDim::getScalable(4)};
+    ASSERT_EQ(withoutLeadingUnitDims, expectedDims);
+  }
+}
+
 } // namespace

``````````

</details>


https://github.com/llvm/llvm-project/pull/74251

