[Mlir-commits] [mlir] [mlir] Add first-class support for scalability in VectorType dims (PR #74251)
Benjamin Maxwell
llvmlistbot at llvm.org
Sun Dec 3 12:54:16 PST 2023
https://github.com/MacDue created https://github.com/llvm/llvm-project/pull/74251
Currently, the shape of a VectorType is stored in two separate lists: the 'shape', which comes from ShapedType and has no way to represent scalability, and the 'scalableDims', an additional list of bools attached to VectorType. This is somewhat cumbersome to work with, and makes it easy to ignore the scalability of a dim, producing incorrect results.
For example, to correctly trim leading unit dims of a VectorType, currently, you need to do something like:
```c++
while (!newShape.empty() && newShape.front() == 1 &&
       !newScalableDims.front()) {
  newShape = newShape.drop_front(1);
  newScalableDims = newScalableDims.drop_front(1);
}
```
It would be wrong to write this (more naturally) as:
```c++
auto newShape = vectorType.getShape().drop_while([](int64_t dim) {
  return dim == 1;
});
```
This would trim scalable one dims (`[1]`), which, unlike their fixed counterparts, are not unit dims.
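To make the pitfall concrete without pulling in MLIR, here is a minimal standalone sketch. `ShapeLists` and both helper functions are hypothetical names for illustration only; they merely mimic the two parallel lists described above, using `std::vector` in place of `ArrayRef`.

```c++
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a vector type's storage: sizes and scalability
// flags are kept in two parallel lists (this is not an MLIR class).
struct ShapeLists {
  std::vector<int64_t> sizes;
  std::vector<bool> scalable;
};

// Naive count: looks only at the sizes, so a scalable `[1]` dim is counted
// as if it were a unit dim.
inline size_t countLeadingOnesNaive(const ShapeLists &shape) {
  size_t n = 0;
  while (n < shape.sizes.size() && shape.sizes[n] == 1)
    ++n;
  return n;
}

// Scalability-aware count: a dim is a unit dim only if its size is 1 and it
// is not scalable.
inline size_t countLeadingUnitDims(const ShapeLists &shape) {
  size_t n = 0;
  while (n < shape.sizes.size() && shape.sizes[n] == 1 && !shape.scalable[n])
    ++n;
  return n;
}
```

For a shape like `1x1x[1]x4`, the naive version reports three leading dims to trim, while the scalability-aware version reports two.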
---
This patch does not change the storage of the VectorType, but instead adds new scalability-safe accessors and iterators.
Two new methods are added to VectorType:
```c++
/// Returns the value of the specified dimension (including scalability)
VectorDim VectorType::getDim(unsigned idx);
/// Returns the dimensions of this vector type (including scalability)
VectorDims VectorType::getDims();
```
These are backed by two new classes: `VectorDim` and `VectorDims`.
`VectorDim` represents a single dimension of a VectorType. It can be a fixed or scalable quantity. It cannot be implicitly converted to/from an integer, so you must specify the kind of quantity you expect in comparisons.
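The "no implicit conversion" property can be sketched in isolation. The class below is a simplified stand-in under the hypothetical name `Dim`, not the class added by the patch; it only shows how an explicit constructor plus named factories forces callers to say whether they mean a fixed or scalable quantity.

```c++
#include <cassert>
#include <cstdint>
#include <type_traits>

// Simplified, hypothetical stand-in for the VectorDim class described above.
class Dim {
public:
  // Explicit constructor: an integer cannot silently become a Dim.
  explicit constexpr Dim(int64_t quantity, bool scalable)
      : quantity(quantity), scalable(scalable) {}

  static constexpr Dim getFixed(int64_t quantity) {
    return Dim(quantity, false);
  }
  static constexpr Dim getScalable(int64_t quantity) {
    return Dim(quantity, true);
  }

  constexpr bool isScalable() const { return scalable; }
  constexpr int64_t getMinSize() const { return quantity; }

  // Two dims are equal only if both the size and the scalability match.
  constexpr bool operator==(const Dim &other) const {
    return quantity == other.quantity && scalable == other.scalable;
  }
  constexpr bool operator!=(const Dim &other) const {
    return !(*this == other);
  }

private:
  int64_t quantity;
  bool scalable;
};

// Compile-time checks of the properties the description relies on.
static_assert(!std::is_convertible<int, Dim>::value,
              "no implicit int -> Dim conversion");
static_assert(!std::is_convertible<Dim, int64_t>::value,
              "no implicit Dim -> integer conversion");
static_assert(Dim::getFixed(1) != Dim::getScalable(1),
              "a fixed 1 and a scalable [1] are different dims");
```

With this shape of API, an expression such as `dim == 1` has no viable `operator==` and is rejected by the compiler, which is the safety property the description is after.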
`VectorDims` represents a non-owning list of vector dimensions, backed by separate size and scalability lists (matching the storage of VectorType). This class has an iterator and a few common helper methods (similar to those of ArrayRef).
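The idea of a non-owning view over two parallel lists can be illustrated with a small standalone sketch. `DimsView` is a hypothetical name, and raw pointers stand in for `ArrayRef`. Unlike the class in the patch, this sketch has no iterator and always expects a scalability flag for every dim.

```c++
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>

// Hypothetical non-owning view over parallel size/scalability arrays.
// The caller must keep both arrays alive while the view is in use.
class DimsView {
public:
  DimsView(const int64_t *sizes, const bool *scalable, size_t count)
      : sizes(sizes), scalable(scalable), count(count) {}

  size_t size() const { return count; }
  bool empty() const { return count == 0; }

  // Returns {min size, is scalable} for dimension `idx`.
  std::pair<int64_t, bool> operator[](size_t idx) const {
    assert(idx < count && "dim index out of range");
    return {sizes[idx], scalable[idx]};
  }

  // Drops the first `n` dims; both lists are advanced together, so sizes and
  // scalability flags cannot get out of sync.
  DimsView dropFront(size_t n = 1) const {
    assert(n <= count && "cannot drop more dims than the view holds");
    return DimsView(sizes + n, scalable + n, count - n);
  }

  // Drops leading dims for as long as `predicate(size, isScalable)` holds.
  template <class PredicateT>
  DimsView dropWhile(PredicateT predicate) const {
    size_t n = 0;
    while (n < count && predicate(sizes[n], scalable[n]))
      ++n;
    return dropFront(n);
  }

private:
  const int64_t *sizes;
  const bool *scalable;
  size_t count;
};
```

Because every slicing helper moves both pointers at once, a caller cannot trim the sizes while forgetting the scalability flags.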
There are also new builders to construct VectorTypes from both the `VectorDims` class and an `ArrayRef<VectorDim>`.
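Since the storage is unchanged, a builder that takes a list of dimensions has to split it back into the two parallel lists. A rough standalone sketch of that step follows; `DimPair`, `SplitDims` and `splitDims` are hypothetical names, and `std::vector` is used in place of `SmallVector` and `ArrayRef`.

```c++
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical stand-in for one dimension: {min size, is scalable}.
using DimPair = std::pair<int64_t, bool>;

// The two parallel lists, in the form the existing storage expects.
struct SplitDims {
  std::vector<int64_t> sizes;
  std::vector<bool> scalableDims;
};

// Splits a list of dimensions into separate size and scalability lists,
// preserving the order of the dims.
inline SplitDims splitDims(const std::vector<DimPair> &dims) {
  SplitDims result;
  result.sizes.reserve(dims.size());
  result.scalableDims.reserve(dims.size());
  for (const DimPair &dim : dims) {
    result.sizes.push_back(dim.first);
    result.scalableDims.push_back(dim.second);
  }
  return result;
}
```

A builder taking the view type does not need this copy, since the view already holds the two lists separately.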
With these changes the previous example becomes:
```c++
auto newDims = vectorType.getDims().dropWhile([](VectorDim dim) {
  return dim == VectorDim::getFixed(1);
});
```
This is (to me) easier to read, and safer, as it is not possible to forget to check the scalability of the dim. Just comparing with `1` would fail to build.
>From 666dd288cbcc1fa5ba9a0e780baf25c3f3dd1c28 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Sun, 3 Dec 2023 20:01:42 +0000
Subject: [PATCH] [mlir] Add first-class support for scalability in VectorType
dims
Currently, the shape of a VectorType is stored in two separate lists:
the 'shape', which comes from ShapedType and has no way to represent
scalability, and the 'scalableDims', an additional list of bools
attached to VectorType. This is somewhat cumbersome to work with, and
makes it easy to ignore the scalability of a dim, producing incorrect
results.
For example, to correctly trim leading unit dims of a VectorType,
currently, you need to do something like:
```c++
while (!newShape.empty() && newShape.front() == 1 &&
       !newScalableDims.front()) {
  newShape = newShape.drop_front(1);
  newScalableDims = newScalableDims.drop_front(1);
}
```
It would be wrong to write this (more naturally) as:
```c++
auto newShape = vectorType.getShape().drop_while([](int64_t dim) {
  return dim == 1;
});
```
This would trim scalable one dims (`[1]`), which, unlike their fixed
counterparts, are not unit dims.
This patch does not change the storage of the VectorType, but instead
adds new scalability-safe accessors and iterators.
Two new methods are added to VectorType:
```
/// Returns the value of the specified dimension (including scalability)
VectorDim VectorType::getDim(unsigned idx)
/// Returns the dimensions of this vector type (including scalability)
VectorDims VectorType::getDims()
```
These are backed by two new classes: `VectorDim` and `VectorDims`.
`VectorDim` represents a single dimension of a VectorType. It can be a
fixed or scalable quantity. It cannot be implicitly converted to/from an
integer, so you must specify the kind of quantity you expect in
comparisons.
`VectorDims` represents a non-owning list of vector dimensions, backed
by separate size and scalability lists (matching the storage of
VectorType). This class has an iterator and a few common helper methods
(similar to those of ArrayRef).
There are also new builders to construct VectorTypes from both
the `VectorDims` class and an `ArrayRef<VectorDim>`.
With these changes the previous example becomes:
```c++
auto newDims = vectorType.getDims().dropWhile([](VectorDim dim) {
  return dim == VectorDim::getFixed(1);
});
```
This is (to me) easier to read, and safer, as it is not possible to
forget to check the scalability of the dim. Just comparing with `1` would
fail to build.
---
mlir/include/mlir/IR/BuiltinTypes.h | 189 ++++++++++++++++++
mlir/include/mlir/IR/BuiltinTypes.td | 23 +++
.../Conversion/VectorToSCF/VectorToSCF.cpp | 3 +-
.../Vector/Transforms/LowerVectorMask.cpp | 4 +-
mlir/unittests/IR/ShapedTypeTest.cpp | 93 +++++++++
5 files changed, 309 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 92ce053ad5c82..b377849567d18 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -12,6 +12,7 @@
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/Support/ADTExtras.h"
+#include "llvm/ADT/STLExtras.h"
namespace llvm {
class BitVector;
@@ -181,6 +182,194 @@ class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
};
+//===----------------------------------------------------------------------===//
+// VectorDim
+//===----------------------------------------------------------------------===//
+
+/// This class represents a dimension of a vector type. Unlike other ShapedTypes
+/// vector dimensions can have scalable quantities, which means the dimension
+/// has a known minimum size, which is scaled by a constant that is only
+/// known at runtime.
+class VectorDim {
+public:
+ explicit constexpr VectorDim(int64_t quantity, bool scalable)
+ : quantity(quantity), scalable(scalable){};
+
+ /// Constructs a new fixed dimension.
+ constexpr static VectorDim getFixed(int64_t quantity) {
+ return VectorDim(quantity, false);
+ }
+
+ /// Constructs a new scalable dimension.
+ constexpr static VectorDim getScalable(int64_t quantity) {
+ return VectorDim(quantity, true);
+ }
+
+ /// Returns true if this dimension is scalable;
+ constexpr bool isScalable() const { return scalable; }
+
+ /// Returns true if this dimension is fixed.
+ constexpr bool isFixed() const { return !isScalable(); }
+
+ /// Returns the minimum number of elements this dimension can contain.
+ constexpr int64_t getMinSize() const { return quantity; }
+
+ /// If this dimension is fixed returns the number of elements, otherwise
+ /// aborts.
+ constexpr int64_t getFixedSize() const {
+ assert(isFixed());
+ return quantity;
+ }
+
+ constexpr bool operator==(VectorDim const &dim) const {
+ return quantity == dim.quantity && scalable == dim.scalable;
+ }
+
+ constexpr bool operator!=(VectorDim const &dim) const {
+ return !(*this == dim);
+ }
+
+ /// Helper class for indexing into a list of sizes and a (possibly empty) list
+ /// of scalable dimensions, extracting VectorDims.
+ struct Indexer {
+ Indexer() = default;
+ explicit constexpr Indexer(ArrayRef<int64_t> sizes,
+ ArrayRef<bool> scalableDims)
+ : sizes(sizes), scalableDims(scalableDims){};
+
+ VectorDim operator[](size_t idx) const {
+ int64_t size = sizes[idx];
+ bool scalable = scalableDims.empty() ? false : scalableDims[idx];
+ return VectorDim(size, scalable);
+ }
+
+ ArrayRef<int64_t> sizes;
+ ArrayRef<bool> scalableDims;
+ };
+
+private:
+ int64_t quantity;
+ bool scalable;
+};
+
+//===----------------------------------------------------------------------===//
+// VectorDims
+//===----------------------------------------------------------------------===//
+
+/// Represents a non-owning list of vector dimensions. The underlying dimension
+/// sizes and scalability flags are stored as two separate lists to match the
+/// storage of a VectorType.
+class VectorDims : public VectorDim::Indexer {
+public:
+ using VectorDim::Indexer::Indexer;
+
+ class Iterator : public llvm::iterator_facade_base<
+ Iterator, std::random_access_iterator_tag, VectorDim,
+ std::ptrdiff_t, VectorDim, VectorDim> {
+ public:
+ Iterator(VectorDim::Indexer indexer, size_t index)
+ : indexer(indexer), index(index){};
+
+ // Iterator boilerplate.
+ ptrdiff_t operator-(const Iterator &rhs) const { return index - rhs.index; }
+ bool operator==(const Iterator &rhs) const { return index == rhs.index; }
+ bool operator<(const Iterator &rhs) const { return index < rhs.index; }
+ Iterator &operator+=(ptrdiff_t offset) {
+ index += offset;
+ return *this;
+ }
+ Iterator &operator-=(ptrdiff_t offset) {
+ index -= offset;
+ return *this;
+ }
+ VectorDim operator*() const { return indexer[index]; }
+
+ VectorDim::Indexer getIndexer() const { return indexer; }
+ ptrdiff_t getIndex() const { return index; }
+
+ private:
+ VectorDim::Indexer indexer;
+ ptrdiff_t index;
+ };
+
+ /// Construct from iterator pair.
+ VectorDims(Iterator begin, Iterator end)
+ : VectorDims(
+ VectorDims(begin.getIndexer().sizes,
+ begin.getIndexer().scalableDims)
+ .slice(begin.getIndex(), end.getIndex() - begin.getIndex())) {}
+
+ Iterator begin() const { return Iterator(*this, 0); }
+ Iterator end() const { return Iterator(*this, size()); }
+
+ /// Check if the dims are empty.
+ bool empty() const { return sizes.empty(); }
+
+ /// Get the number of dims.
+ size_t size() const { return sizes.size(); }
+
+ /// Return the first dim.
+ VectorDim front() const { return (*this)[0]; }
+
+ /// Return the last dim.
+ VectorDim back() const { return (*this)[size() - 1]; }
+
+ /// Chop off the first \p n dims, and keep the remaining \p m
+ /// dims.
+ VectorDims slice(size_t n, size_t m) const {
+ ArrayRef<int64_t> newSizes = sizes.slice(n, m);
+ ArrayRef<bool> newScalableDims =
+ scalableDims.empty() ? ArrayRef<bool>{} : scalableDims.slice(n, m);
+ return VectorDims(newSizes, newScalableDims);
+ }
+
+ /// Drop the first \p n dims.
+ VectorDims dropFront(size_t n = 1) const { return slice(n, size() - n); }
+
+ /// Drop the last \p n dims.
+ VectorDims dropBack(size_t n = 1) const { return slice(0, size() - n); }
+
+ /// Return copy of *this with the first n dims matching the predicate removed.
+ template <class PredicateT>
+ VectorDims dropWhile(PredicateT predicate) const {
+ return VectorDims(llvm::find_if_not(*this, predicate), end());
+ }
+
+ /// Return the underlying sizes.
+ ArrayRef<int64_t> getSizes() const { return sizes; }
+
+ /// Return the underlying scalable dims.
+ ArrayRef<bool> getScalableDims() const { return scalableDims; }
+
+ /// Check for dim equality.
+ bool equals(VectorDims rhs) const {
+ if (size() != rhs.size())
+ return false;
+ return std::equal(begin(), end(), rhs.begin());
+ }
+
+ /// Check for dim equality.
+ bool equals(ArrayRef<VectorDim> rhs) const {
+ if (size() != rhs.size())
+ return false;
+ return std::equal(begin(), end(), rhs.begin());
+ }
+};
+
+inline bool operator==(VectorDims lhs, VectorDims rhs) {
+ return lhs.equals(rhs);
+}
+
+inline bool operator!=(VectorDims lhs, VectorDims rhs) { return !(lhs == rhs); }
+
+inline bool operator==(VectorDims lhs, ArrayRef<VectorDim> rhs) {
+ return lhs.equals(rhs);
+}
+
+inline bool operator!=(VectorDims lhs, ArrayRef<VectorDim> rhs) {
+ return !(lhs == rhs);
+}
+
} // namespace mlir
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 1d7772810ae6e..a77a32946bc21 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -1089,6 +1089,18 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector", [ShapedTypeInterface], "Ty
scalableDims = isScalableVec;
}
return $_get(elementType.getContext(), shape, elementType, scalableDims);
+ }]>,
+ TypeBuilderWithInferredContext<(ins "ArrayRef<VectorDim>": $shape, "Type":$elementType), [{
+ SmallVector<int64_t> sizes;
+ SmallVector<bool> scalableDims;
+ for (VectorDim dim : shape) {
+ sizes.push_back(dim.getMinSize());
+ scalableDims.push_back(dim.isScalable());
+ }
+ return get(sizes, elementType, scalableDims);
+ }]>,
+ TypeBuilderWithInferredContext<(ins "VectorDims": $shape, "Type":$elementType), [{
+ return get(shape.getSizes(), elementType, shape.getScalableDims());
}]>
];
let extraClassDeclaration = [{
@@ -1096,6 +1108,17 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector", [ShapedTypeInterface], "Ty
/// Arguments that are passed into the builder must outlive the builder.
class Builder;
+ /// Returns the value of the specified dimension (including scalability).
+ VectorDim getDim(unsigned idx) const {
+ assert(idx < getRank() && "invalid dim index for vector type");
+ return getDims()[idx];
+ }
+
+ /// Returns the dimensions of this vector type (including scalability).
+ VectorDims getDims() const {
+ return VectorDims(getShape(), getScalableDims());
+ }
+
/// Returns true if the given type can be used as an element of a vector
/// type. In particular, vectors can consist of integer, index, or float
/// primitives.
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 33a77d7576ba7..b0f94901b6d2c 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -1258,7 +1258,8 @@ struct UnrollTransferWriteConversion
// argument into `transfer_write` to become a scalar. We solve
// this by broadcasting the scalar to a 0D vector.
xferVec = b.create<vector::BroadcastOp>(
- loc, VectorType::get({}, extracted.getType()), extracted);
+ loc, VectorType::get(VectorDims{}, extracted.getType()),
+ extracted);
} else {
xferVec = extracted;
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
index f53bb5157eb37..56509aaf44343 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
@@ -114,8 +114,8 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
bool value = cast<IntegerAttr>(dimSizes[0]).getInt() == 1;
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
op, dstType,
- DenseIntElementsAttr::get(VectorType::get({}, rewriter.getI1Type()),
- value));
+ DenseIntElementsAttr::get(
+ VectorType::get(VectorDims{}, rewriter.getI1Type()), value));
return success();
}
diff --git a/mlir/unittests/IR/ShapedTypeTest.cpp b/mlir/unittests/IR/ShapedTypeTest.cpp
index 61264bc523648..a27ddc9492936 100644
--- a/mlir/unittests/IR/ShapedTypeTest.cpp
+++ b/mlir/unittests/IR/ShapedTypeTest.cpp
@@ -226,4 +226,97 @@ TEST(ShapedTypeTest, RankedTensorTypeBuilder) {
}
}
+TEST(ShapedTypeTest, VectorDims) {
+ MLIRContext context;
+ Type f32 = FloatType::getF32(&context);
+
+ SmallVector<VectorDim> dims{VectorDim::getFixed(2), VectorDim::getScalable(4),
+ VectorDim::getFixed(8), VectorDim::getScalable(9),
+ VectorDim::getFixed(1)};
+ VectorType vectorType = VectorType::get(dims, f32);
+
+ // Directly check values
+ {
+ auto dim0 = vectorType.getDim(0);
+ ASSERT_EQ(dim0.getMinSize(), 2);
+ ASSERT_TRUE(dim0.isFixed());
+
+ auto dim1 = vectorType.getDim(1);
+ ASSERT_EQ(dim1.getMinSize(), 4);
+ ASSERT_TRUE(dim1.isScalable());
+
+ auto dim2 = vectorType.getDim(2);
+ ASSERT_EQ(dim2.getMinSize(), 8);
+ ASSERT_TRUE(dim2.isFixed());
+
+ auto dim3 = vectorType.getDim(3);
+ ASSERT_EQ(dim3.getMinSize(), 9);
+ ASSERT_TRUE(dim3.isScalable());
+
+ auto dim4 = vectorType.getDim(4);
+ ASSERT_EQ(dim4.getMinSize(), 1);
+ ASSERT_TRUE(dim4.isFixed());
+ }
+
+ // Test indexing via getDim(idx)
+ {
+ for (unsigned i = 0; i < dims.size(); i++)
+ ASSERT_EQ(vectorType.getDim(i), dims[i]);
+ }
+
+ // Test using VectorDims::Iterator
+ {
+ for (auto [dim, expectedDim] :
+ llvm::zip_equal(vectorType.getDims(), dims)) {
+ ASSERT_EQ(dim, expectedDim);
+ }
+ }
+
+ // Test dropFront()
+ {
+ auto vectorDims = vectorType.getDims();
+ auto newDims = vectorDims.dropFront();
+
+ ASSERT_EQ(newDims.size(), vectorDims.size() - 1);
+ for (unsigned i = 0; i < newDims.size(); i++)
+ ASSERT_EQ(newDims[i], vectorDims[i + 1]);
+ }
+
+ // Test dropBack()
+ {
+ auto vectorDims = vectorType.getDims();
+ auto newDims = vectorDims.dropBack();
+
+ ASSERT_EQ(newDims.size(), vectorDims.size() - 1);
+ for (unsigned i = 0; i < newDims.size(); i++)
+ ASSERT_EQ(newDims[i], vectorDims[i]);
+ }
+
+ // Test front()
+ { ASSERT_EQ(vectorType.getDims().front(), VectorDim::getFixed(2)); }
+
+ // Test back()
+ { ASSERT_EQ(vectorType.getDims().back(), VectorDim::getFixed(1)); }
+
+ // Test dropWhile.
+ {
+ SmallVector<VectorDim> dims{
+ VectorDim::getFixed(1), VectorDim::getFixed(1), VectorDim::getFixed(1),
+ VectorDim::getScalable(1), VectorDim::getScalable(4)};
+
+ VectorType vectorTypeWithLeadingUnitDims = VectorType::get(dims, f32);
+ ASSERT_EQ(vectorTypeWithLeadingUnitDims.getDims().size(), 5U);
+
+ // Drop leading unit dims.
+ auto withoutLeadingUnitDims =
+ vectorTypeWithLeadingUnitDims.getDims().dropWhile(
+ [](VectorDim dim) { return dim == VectorDim::getFixed(1); });
+
+ SmallVector<VectorDim> expectedDims{VectorDim::getScalable(1),
+ VectorDim::getScalable(4)};
+
+ ASSERT_EQ(withoutLeadingUnitDims, expectedDims);
+ }
+}
+
} // namespace