[Mlir-commits] [mlir] [mlir] Add first-class support for scalability in VectorType dims (PR #74251)
Benjamin Maxwell
llvmlistbot at llvm.org
Thu Dec 7 07:05:13 PST 2023
https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/74251
>From ca8250c41acf806f045562468f1da02e21c004b5 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Sun, 3 Dec 2023 20:01:42 +0000
Subject: [PATCH 1/2] [mlir] Add first-class support for scalability in
VectorType dims
Currently, the shape of a VectorType is stored in two separate lists:
the 'shape', which comes from ShapedType and has no way to
represent scalability, and the 'scalableDims', an additional list of
bools attached to VectorType. This can be somewhat cumbersome to work
with, and makes it easy to ignore the scalability of a dim, producing
incorrect results.
For example, to correctly trim the leading unit dims of a VectorType,
you currently need to do something like:
```c++
while (!newShape.empty() && newShape.front() == 1 &&
!newScalableDims.front()) {
newShape = newShape.drop_front(1);
newScalableDims = newScalableDims.drop_front(1);
}
```
This would be wrong if you (more naturally) wrote it as:
```c++
auto newShape = vectorType.getShape().drop_while([](int64_t dim) {
return dim == 1;
});
```
because this would also trim scalable one dims (`[1]`), which, unlike
their fixed counterparts, are not unit dims.
This patch does not change the storage of the VectorType, but instead
adds new scalability-safe accessors and iterators.
Two new methods are added to VectorType:
```
/// Returns the value of the specified dimension (including scalability)
VectorDim VectorType::getDim(unsigned idx)
/// Returns the dimensions of this vector type (including scalability)
VectorDims VectorType::getDims()
```
These are backed by two new classes: `VectorDim` and `VectorDims`.
`VectorDim` represents a single dimension of a VectorType. It can be a
fixed or scalable quantity. It cannot be implicitly converted to/from an
integer, so you must specify the kind of quantity you expect in
comparisons.
`VectorDims` represents a non-owning list of vector dimensions, backed
by separate size and scalability lists (matching the storage of
VectorType). This class has an iterator and a few common helper methods
(similar to those of ArrayRef).
There are also new builders to construct VectorTypes from both
the `VectorDims` class and an `ArrayRef<VectorDim>`.
With these changes the previous example becomes:
```c++
auto newDims = vectorType.getDims().dropWhile([](VectorDim dim) {
return dim == VectorDim::getFixed(1);
});
```
This is (to me) easier to read, and safer, as it is not possible to
forget to check the scalability of the dim. Just comparing with `1` would
fail to build.
---
mlir/include/mlir/IR/BuiltinTypes.h | 234 +++++++++++++++++++++++++++
mlir/include/mlir/IR/BuiltinTypes.td | 23 +++
mlir/unittests/IR/ShapedTypeTest.cpp | 101 ++++++++++++
3 files changed, 358 insertions(+)
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 92ce053ad5c82..2039316e6ba25 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -12,6 +12,7 @@
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/Support/ADTExtras.h"
+#include "llvm/ADT/STLExtras.h"
namespace llvm {
class BitVector;
@@ -181,6 +182,239 @@ class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
};
+//===----------------------------------------------------------------------===//
+// VectorDim
+//===----------------------------------------------------------------------===//
+
+/// This class represents a dimension of a vector type. Unlike other ShapedTypes,
+/// vector dimensions can have scalable quantities, which means the dimension
+/// has a known minimum size, which is scaled by a constant that is only
+/// known at runtime.
+class VectorDim {
+public:
+  explicit constexpr VectorDim(int64_t quantity, bool scalable)
+      : quantity(quantity), scalable(scalable) {}
+
+  /// Constructs a new fixed dimension.
+  constexpr static VectorDim getFixed(int64_t quantity) {
+    return VectorDim(quantity, false);
+  }
+
+  /// Constructs a new scalable dimension.
+  constexpr static VectorDim getScalable(int64_t quantity) {
+    return VectorDim(quantity, true);
+  }
+
+  /// Returns true if this dimension is scalable.
+  constexpr bool isScalable() const { return scalable; }
+
+  /// Returns true if this dimension is fixed.
+  constexpr bool isFixed() const { return !isScalable(); }
+
+  /// Returns the minimum number of elements this dimension can contain.
+  constexpr int64_t getMinSize() const { return quantity; }
+
+  /// If this dimension is fixed returns the number of elements. Calling this
+  /// on a scalable dimension is a programming error (caught by an assert in
+  /// builds with assertions enabled).
+  constexpr int64_t getFixedSize() const {
+    assert(isFixed() && "expected a fixed (non-scalable) dimension");
+    return quantity;
+  }
+
+  /// Two dims are equal only if both their sizes and scalability match, so a
+  /// fixed `1` never compares equal to a scalable `[1]`.
+  constexpr bool operator==(VectorDim const &dim) const {
+    return quantity == dim.quantity && scalable == dim.scalable;
+  }
+
+  constexpr bool operator!=(VectorDim const &dim) const {
+    return !(*this == dim);
+  }
+
+  /// Print the dim: fixed dims print as `N`, scalable dims as `[N]`.
+  void print(raw_ostream &os) const {
+    if (isScalable())
+      os << '[';
+    os << getMinSize();
+    if (isScalable())
+      os << ']';
+  }
+
+  /// Helper class for indexing into a list of sizes and a (possibly empty)
+  /// list of scalability flags, extracting VectorDim elements. An empty
+  /// scalability list means every dim is treated as fixed.
+  struct Indexer {
+    explicit Indexer(ArrayRef<int64_t> sizes, ArrayRef<bool> scalableDims)
+        : sizes(sizes), scalableDims(scalableDims) {
+      // Parenthesized explicitly: the message string must apply to the whole
+      // condition, not just the length comparison.
+      assert((scalableDims.empty() || sizes.size() == scalableDims.size()) &&
+             "expected `scalableDims` to be empty or match `sizes` in length");
+    }
+
+    VectorDim operator[](size_t idx) const {
+      int64_t size = sizes[idx];
+      bool scalable = scalableDims.empty() ? false : scalableDims[idx];
+      return VectorDim(size, scalable);
+    }
+
+    ArrayRef<int64_t> sizes;
+    ArrayRef<bool> scalableDims;
+  };
+
+private:
+  int64_t quantity;
+  bool scalable;
+};
+
+/// Streams a VectorDim using VectorDim::print.
+inline raw_ostream &operator<<(raw_ostream &os, VectorDim dim) {
+  dim.print(os);
+  return os;
+}
+
+//===----------------------------------------------------------------------===//
+// VectorDims
+//===----------------------------------------------------------------------===//
+
+/// Represents a non-owning list of vector dimensions. The underlying dimension
+/// sizes and scalability flags are stored as two separate lists to match the
+/// storage of a VectorType.
+class VectorDims : public VectorDim::Indexer {
+public:
+  using VectorDim::Indexer::Indexer;
+
+  /// Random-access iterator over the dims. Each VectorDim is materialized on
+  /// dereference, so the reference type is the value type itself. The pointer
+  /// type is `const VectorDim *` so that the `operator->` supplied by
+  /// iterator_facade_base (a proxy that holds the dereferenced value and
+  /// returns its address) is well-formed. Comparisons only look at positions,
+  /// so compared iterators must come from the same VectorDims.
+  class Iterator : public llvm::iterator_facade_base<
+                       Iterator, std::random_access_iterator_tag, VectorDim,
+                       std::ptrdiff_t, const VectorDim *, VectorDim> {
+  public:
+    Iterator(VectorDim::Indexer indexer, size_t index)
+        : indexer(indexer), index(index) {}
+
+    // Iterator boilerplate.
+    ptrdiff_t operator-(const Iterator &rhs) const { return index - rhs.index; }
+    bool operator==(const Iterator &rhs) const { return index == rhs.index; }
+    bool operator<(const Iterator &rhs) const { return index < rhs.index; }
+    Iterator &operator+=(ptrdiff_t offset) {
+      index += offset;
+      return *this;
+    }
+    Iterator &operator-=(ptrdiff_t offset) {
+      index -= offset;
+      return *this;
+    }
+    VectorDim operator*() const { return indexer[index]; }
+
+    VectorDim::Indexer getIndexer() const { return indexer; }
+    ptrdiff_t getIndex() const { return index; }
+
+  private:
+    VectorDim::Indexer indexer;
+    ptrdiff_t index;
+  };
+
+  // Generic definitions.
+  using value_type = VectorDim;
+  using iterator = Iterator;
+  using const_iterator = Iterator;
+  using reverse_iterator = std::reverse_iterator<iterator>;
+  using const_reverse_iterator = std::reverse_iterator<const_iterator>;
+  using size_type = size_t;
+  using difference_type = ptrdiff_t;
+
+  /// Construct from an iterator pair. Both iterators must refer to the same
+  /// underlying VectorDims; the underlying storage is taken from \p begin.
+  VectorDims(Iterator begin, Iterator end)
+      : VectorDims(VectorDims(begin.getIndexer())
+                       .slice(begin.getIndex(), end - begin)) {}
+
+  VectorDims(VectorDim::Indexer indexer) : VectorDim::Indexer(indexer) {}
+
+  Iterator begin() const { return Iterator(*this, 0); }
+  Iterator end() const { return Iterator(*this, size()); }
+
+  /// Check if the dims are empty.
+  bool empty() const { return sizes.empty(); }
+
+  /// Get the number of dims.
+  size_t size() const { return sizes.size(); }
+
+  /// Return the first dim.
+  VectorDim front() const { return (*this)[0]; }
+
+  /// Return the last dim.
+  VectorDim back() const { return (*this)[size() - 1]; }
+
+  /// Chop off the first \p n dims, and keep the next \p m dims.
+  VectorDims slice(size_t n, size_t m) const {
+    ArrayRef<int64_t> newSizes = sizes.slice(n, m);
+    // An empty scalability list means "all fixed"; keep it empty.
+    ArrayRef<bool> newScalableDims =
+        scalableDims.empty() ? ArrayRef<bool>{} : scalableDims.slice(n, m);
+    return VectorDims(newSizes, newScalableDims);
+  }
+
+  /// Drop the first \p n dims.
+  VectorDims dropFront(size_t n = 1) const { return slice(n, size() - n); }
+
+  /// Drop the last \p n dims.
+  VectorDims dropBack(size_t n = 1) const { return slice(0, size() - n); }
+
+  /// Return a copy of *this with only the first \p n dims (or all dims if
+  /// there are fewer than \p n).
+  VectorDims takeFront(size_t n = 1) const {
+    if (n >= size())
+      return *this;
+    return dropBack(size() - n);
+  }
+
+  /// Return a copy of *this with only the last \p n dims (or all dims if
+  /// there are fewer than \p n).
+  VectorDims takeBack(size_t n = 1) const {
+    if (n >= size())
+      return *this;
+    return dropFront(size() - n);
+  }
+
+  /// Return a copy of *this with the leading dims that satisfy \p predicate
+  /// removed (stops at the first dim that does not satisfy it).
+  template <class PredicateT>
+  VectorDims dropWhile(PredicateT predicate) const {
+    return VectorDims(llvm::find_if_not(*this, predicate), end());
+  }
+
+  /// Returns true if one or more of the dims are scalable.
+  bool hasScalableDims() const {
+    return llvm::is_contained(getScalableDims(), true);
+  }
+
+  /// Check for dim equality (sizes and scalability).
+  bool equals(VectorDims rhs) const {
+    if (size() != rhs.size())
+      return false;
+    return std::equal(begin(), end(), rhs.begin());
+  }
+
+  /// Check for dim equality (sizes and scalability).
+  bool equals(ArrayRef<VectorDim> rhs) const {
+    if (size() != rhs.size())
+      return false;
+    return std::equal(begin(), end(), rhs.begin());
+  }
+
+  /// Return the underlying sizes.
+  ArrayRef<int64_t> getSizes() const { return sizes; }
+
+  /// Return the underlying scalable dims (may be empty, meaning all fixed).
+  ArrayRef<bool> getScalableDims() const { return scalableDims; }
+};
+
+inline bool operator==(VectorDims lhs, VectorDims rhs) {
+  return lhs.equals(rhs);
+}
+
+inline bool operator!=(VectorDims lhs, VectorDims rhs) { return !(lhs == rhs); }
+
+inline bool operator==(VectorDims lhs, ArrayRef<VectorDim> rhs) {
+  return lhs.equals(rhs);
+}
+
+inline bool operator!=(VectorDims lhs, ArrayRef<VectorDim> rhs) {
+  return !(lhs == rhs);
+}
+
} // namespace mlir
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 4cade83dd3c32..8835074efbc66 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -1114,6 +1114,18 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector", [ShapedTypeInterface], "Ty
scalableDims = isScalableVec;
}
return $_get(elementType.getContext(), shape, elementType, scalableDims);
+    }]>,
+    TypeBuilderWithInferredContext<(ins "Type":$elementType,
+                                        "ArrayRef<VectorDim>":$shape), [{
+      // Split the dims back into the separate size and scalability lists that
+      // VectorType stores.
+      SmallVector<int64_t> sizes;
+      SmallVector<bool> scalableDims;
+      sizes.reserve(shape.size());
+      scalableDims.reserve(shape.size());
+      for (VectorDim dim : shape) {
+        sizes.push_back(dim.getMinSize());
+        scalableDims.push_back(dim.isScalable());
+      }
+      return get(sizes, elementType, scalableDims);
+    }]>,
+    TypeBuilderWithInferredContext<(ins "Type":$elementType,
+                                        "VectorDims":$shape), [{
+      return get(shape.getSizes(), elementType, shape.getScalableDims());
}]>
];
let extraClassDeclaration = [{
@@ -1121,6 +1133,17 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector", [ShapedTypeInterface], "Ty
/// Arguments that are passed into the builder must outlive the builder.
class Builder;
+    /// Returns every dimension of this vector type as a non-owning view that
+    /// pairs each size with its scalability flag.
+    VectorDims getDims() const {
+      ArrayRef<int64_t> dimSizes = getShape();
+      ArrayRef<bool> dimScalability = getScalableDims();
+      return VectorDims(dimSizes, dimScalability);
+    }
+
+    /// Returns dimension `idx` of this vector type, including scalability.
+    VectorDim getDim(unsigned idx) const {
+      assert(idx < getRank() && "invalid dim index for vector type");
+      VectorDims allDims = getDims();
+      return allDims[idx];
+    }
+
/// Returns true if the given type can be used as an element of a vector
/// type. In particular, vectors can consist of integer, index, or float
/// primitives.
diff --git a/mlir/unittests/IR/ShapedTypeTest.cpp b/mlir/unittests/IR/ShapedTypeTest.cpp
index 61264bc523648..07625da6ee889 100644
--- a/mlir/unittests/IR/ShapedTypeTest.cpp
+++ b/mlir/unittests/IR/ShapedTypeTest.cpp
@@ -226,4 +226,105 @@ TEST(ShapedTypeTest, RankedTensorTypeBuilder) {
}
}
+TEST(ShapedTypeTest, VectorDims) {
+  MLIRContext context;
+  Type f32 = FloatType::getF32(&context);
+
+  SmallVector<VectorDim> dims{VectorDim::getFixed(2), VectorDim::getScalable(4),
+                              VectorDim::getFixed(8), VectorDim::getScalable(9),
+                              VectorDim::getFixed(1)};
+  VectorType vectorType = VectorType::get(f32, dims);
+
+  // Directly check values
+  {
+    auto dim0 = vectorType.getDim(0);
+    ASSERT_EQ(dim0.getMinSize(), 2);
+    ASSERT_TRUE(dim0.isFixed());
+
+    auto dim1 = vectorType.getDim(1);
+    ASSERT_EQ(dim1.getMinSize(), 4);
+    ASSERT_TRUE(dim1.isScalable());
+
+    auto dim2 = vectorType.getDim(2);
+    ASSERT_EQ(dim2.getMinSize(), 8);
+    ASSERT_TRUE(dim2.isFixed());
+
+    auto dim3 = vectorType.getDim(3);
+    ASSERT_EQ(dim3.getMinSize(), 9);
+    ASSERT_TRUE(dim3.isScalable());
+
+    auto dim4 = vectorType.getDim(4);
+    ASSERT_EQ(dim4.getMinSize(), 1);
+    ASSERT_TRUE(dim4.isFixed());
+  }
+
+  // Test indexing via getDim(idx)
+  {
+    for (unsigned i = 0; i < dims.size(); i++)
+      ASSERT_EQ(vectorType.getDim(i), dims[i]);
+  }
+
+  // Test using VectorDims::Iterator in for-each loop
+  {
+    unsigned i = 0;
+    for (VectorDim dim : vectorType.getDims())
+      ASSERT_EQ(dim, dims[i++]);
+    ASSERT_EQ(i, vectorType.getRank());
+  }
+
+  // Test using VectorDims::Iterator in LLVM iterator helper
+  {
+    for (auto [dim, expectedDim] :
+         llvm::zip_equal(vectorType.getDims(), dims)) {
+      ASSERT_EQ(dim, expectedDim);
+    }
+  }
+
+  // Test constructing VectorDims from an iterator pair.
+  {
+    auto vectorDims = vectorType.getDims();
+    VectorDims tail(vectorDims.begin() + 1, vectorDims.end());
+    ASSERT_EQ(tail, vectorDims.dropFront());
+  }
+
+  // Test dropFront()
+  {
+    auto vectorDims = vectorType.getDims();
+    auto newDims = vectorDims.dropFront();
+
+    ASSERT_EQ(newDims.size(), vectorDims.size() - 1);
+    for (unsigned i = 0; i < newDims.size(); i++)
+      ASSERT_EQ(newDims[i], vectorDims[i + 1]);
+  }
+
+  // Test dropBack()
+  {
+    auto vectorDims = vectorType.getDims();
+    auto newDims = vectorDims.dropBack();
+
+    ASSERT_EQ(newDims.size(), vectorDims.size() - 1);
+    for (unsigned i = 0; i < newDims.size(); i++)
+      ASSERT_EQ(newDims[i], vectorDims[i]);
+  }
+
+  // Test slice()
+  {
+    SmallVector<VectorDim> expectedDims{VectorDim::getScalable(4),
+                                        VectorDim::getFixed(8)};
+    ASSERT_EQ(vectorType.getDims().slice(1, 2), expectedDims);
+  }
+
+  // Test takeFront() / takeBack(), including over-long requests.
+  {
+    auto vectorDims = vectorType.getDims();
+
+    SmallVector<VectorDim> expectedFront{VectorDim::getFixed(2),
+                                         VectorDim::getScalable(4)};
+    ASSERT_EQ(vectorDims.takeFront(2), expectedFront);
+
+    SmallVector<VectorDim> expectedBack{VectorDim::getScalable(9),
+                                        VectorDim::getFixed(1)};
+    ASSERT_EQ(vectorDims.takeBack(2), expectedBack);
+
+    ASSERT_EQ(vectorDims.takeFront(100), vectorDims);
+    ASSERT_EQ(vectorDims.takeBack(100), vectorDims);
+    ASSERT_TRUE(vectorDims.takeFront(0).empty());
+  }
+
+  // Test hasScalableDims()
+  {
+    auto vectorDims = vectorType.getDims();
+    ASSERT_TRUE(vectorDims.hasScalableDims());
+    ASSERT_FALSE(vectorDims.takeFront(1).hasScalableDims());
+  }
+
+  // Test the VectorDims builder round-trips to the same (uniqued) type.
+  { ASSERT_EQ(VectorType::get(f32, vectorType.getDims()), vectorType); }
+
+  // Test front()
+  { ASSERT_EQ(vectorType.getDims().front(), VectorDim::getFixed(2)); }
+
+  // Test back()
+  { ASSERT_EQ(vectorType.getDims().back(), VectorDim::getFixed(1)); }
+
+  // Test dropWhile.
+  {
+    // Named distinctly to avoid shadowing the outer `dims`.
+    SmallVector<VectorDim> leadingUnitDims{
+        VectorDim::getFixed(1), VectorDim::getFixed(1), VectorDim::getFixed(1),
+        VectorDim::getScalable(1), VectorDim::getScalable(4)};
+
+    VectorType vectorTypeWithLeadingUnitDims =
+        VectorType::get(f32, leadingUnitDims);
+    ASSERT_EQ(vectorTypeWithLeadingUnitDims.getDims().size(),
+              unsigned(vectorTypeWithLeadingUnitDims.getRank()));
+
+    // Drop leading unit dims; the scalable `[1]` must be kept.
+    auto withoutLeadingUnitDims =
+        vectorTypeWithLeadingUnitDims.getDims().dropWhile(
+            [](VectorDim dim) { return dim == VectorDim::getFixed(1); });
+
+    SmallVector<VectorDim> expectedDims{VectorDim::getScalable(1),
+                                        VectorDim::getScalable(4)};
+    ASSERT_EQ(withoutLeadingUnitDims, expectedDims);
+  }
+}
+
} // namespace
>From e95d21fbf3f7176856dda3fb00989c7e1abbd86f Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 7 Dec 2023 13:59:59 +0000
Subject: [PATCH 2/2] Demonstrate using the new APIs in scalable-aware code :)
This is not a complete change; it just updates a few examples found
by grepping for getScalableDims().
---
.../Conversion/LLVMCommon/TypeConverter.cpp | 5 +-
.../VectorToLLVM/ConvertVectorToLLVM.cpp | 3 +-
.../Conversion/VectorToSCF/VectorToSCF.cpp | 12 ++---
mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp | 3 +-
.../Linalg/Transforms/Vectorization.cpp | 8 ++-
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 53 ++++++++-----------
.../Vector/Transforms/LowerVectorTransfer.cpp | 14 +++--
.../Transforms/LowerVectorTranspose.cpp | 6 +--
.../Transforms/VectorDropLeadUnitDim.cpp | 23 +++-----
.../Transforms/VectorTransferOpTransforms.cpp | 20 +++----
.../Vector/Transforms/VectorTransforms.cpp | 5 +-
mlir/lib/IR/AsmPrinter.cpp | 13 ++---
mlir/lib/IR/BuiltinTypes.cpp | 4 +-
13 files changed, 63 insertions(+), 106 deletions(-)
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 3a01795ce3f53..bfbe51d4e4e32 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -490,12 +490,11 @@ FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
return {};
if (type.getShape().empty())
return VectorType::get({1}, elementType);
- Type vectorType = VectorType::get(type.getShape().back(), elementType,
- type.getScalableDims().back());
+ Type vectorType = VectorType::get(elementType, type.getDims().takeBack());
assert(LLVM::isCompatibleVectorType(vectorType) &&
"expected vector type compatible with the LLVM dialect");
// Only the trailing dimension can be scalable.
- if (llvm::is_contained(type.getScalableDims().drop_back(), true))
+ if (type.getDims().dropBack().hasScalableDims())
return failure();
auto shape = type.getShape();
for (int i = shape.size() - 2; i >= 0; --i)
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index cd5df0be740b9..881bddaf228f9 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -37,8 +37,7 @@ using namespace mlir::vector;
// Helper to reduce vector type by *all* but one rank at back.
static VectorType reducedVectorTypeBack(VectorType tp) {
assert((tp.getRank() > 1) && "unlowerable vector type");
- return VectorType::get(tp.getShape().take_back(), tp.getElementType(),
- tp.getScalableDims().take_back());
+ return VectorType::get(tp.getElementType(), tp.getDims().takeBack());
}
// Helper that picks the proper sequence for inserting.
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 2ee314e9fedfe..4869cf304e75b 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -319,12 +319,13 @@ static FailureOr<MemRefType> unpackOneDim(MemRefType type) {
auto vectorType = dyn_cast<VectorType>(type.getElementType());
// Vectors with leading scalable dims are not supported.
// It may be possible to support these in future by using dynamic memref dims.
- if (vectorType.getScalableDims().front())
+ VectorDim leadingDim = vectorType.getDims().front();
+ if (leadingDim.isScalable())
return failure();
auto memrefShape = type.getShape();
SmallVector<int64_t, 8> newMemrefShape;
newMemrefShape.append(memrefShape.begin(), memrefShape.end());
- newMemrefShape.push_back(vectorType.getDimSize(0));
+ newMemrefShape.push_back(leadingDim.getFixedSize());
return MemRefType::get(newMemrefShape,
VectorType::Builder(vectorType).dropDim(0));
}
@@ -1091,18 +1092,17 @@ struct UnrollTransferReadConversion
auto vecType = dyn_cast<VectorType>(vec.getType());
auto xferVecType = xferOp.getVectorType();
- if (xferVecType.getScalableDims()[0]) {
+ VectorDim dim = xferVecType.getDim(0);
+ if (dim.isScalable()) {
// Cannot unroll a scalable dimension at compile time.
return failure();
}
VectorType newXferVecType = VectorType::Builder(xferVecType).dropDim(0);
- int64_t dimSize = xferVecType.getShape()[0];
-
// Generate fully unrolled loop of transfer ops.
Location loc = xferOp.getLoc();
- for (int64_t i = 0; i < dimSize; ++i) {
+ for (int64_t i = 0; i < dim.getFixedSize(); ++i) {
Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);
vec = generateInBoundsCheck(
diff --git a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
index 594c9b4c270f2..fa49a21eafa14 100644
--- a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
+++ b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
@@ -30,8 +30,7 @@ using namespace mlir::arm_sve;
static Type getI1SameShape(Type type) {
auto i1Type = IntegerType::get(type.getContext(), 1);
if (auto sVectorType = llvm::dyn_cast<VectorType>(type))
- return VectorType::get(sVectorType.getShape(), i1Type,
- sVectorType.getScalableDims());
+ return VectorType::get(i1Type, sVectorType.getDims());
return nullptr;
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index c21d007c931b9..08a476c37b3f3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1217,9 +1217,8 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
assert(vecOperand && "Vector operand couldn't be found");
if (firstMaxRankedType) {
- auto vecType = VectorType::get(firstMaxRankedType.getShape(),
- getElementTypeOrSelf(vecOperand.getType()),
- firstMaxRankedType.getScalableDims());
+ auto vecType = VectorType::get(getElementTypeOrSelf(vecOperand.getType()),
+ firstMaxRankedType.getDims());
vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
} else {
vecOperands.push_back(vecOperand);
@@ -1230,8 +1229,7 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
for (Type resultType : op->getResultTypes()) {
resultTypes.push_back(
firstMaxRankedType
- ? VectorType::get(firstMaxRankedType.getShape(), resultType,
- firstMaxRankedType.getScalableDims())
+ ? VectorType::get(resultType, firstMaxRankedType.getDims())
: resultType);
}
// d. Build and return the new op.
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index c462b23e1133f..dcadfbde627a8 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -422,23 +422,21 @@ MultiDimReductionOp::getShapeForUnroll() {
}
LogicalResult MultiDimReductionOp::verify() {
- SmallVector<int64_t> targetShape;
- SmallVector<bool> scalableDims;
+ SmallVector<VectorDim> targetDims;
Type inferredReturnType;
- auto sourceScalableDims = getSourceVectorType().getScalableDims();
- for (auto it : llvm::enumerate(getSourceVectorType().getShape()))
- if (!llvm::any_of(getReductionDims().getValue(), [&](Attribute attr) {
- return llvm::cast<IntegerAttr>(attr).getValue() == it.index();
- })) {
- targetShape.push_back(it.value());
- scalableDims.push_back(sourceScalableDims[it.index()]);
+ for (auto [idx, dim] : llvm::enumerate(getSourceVectorType().getDims()))
+ if (!llvm::any_of(getReductionDims().getValue(),
+ [idx = idx](Attribute attr) {
+ return llvm::cast<IntegerAttr>(attr).getValue() == idx;
+ })) {
+ targetDims.push_back(dim);
}
// TODO: update to also allow 0-d vectors when available.
- if (targetShape.empty())
+ if (targetDims.empty())
inferredReturnType = getSourceVectorType().getElementType();
else
- inferredReturnType = VectorType::get(
- targetShape, getSourceVectorType().getElementType(), scalableDims);
+ inferredReturnType =
+ VectorType::get(getSourceVectorType().getElementType(), targetDims);
if (getType() != inferredReturnType)
return emitOpError() << "destination type " << getType()
<< " is incompatible with source type "
@@ -450,9 +448,8 @@ LogicalResult MultiDimReductionOp::verify() {
/// Returns the mask type expected by this operation.
Type MultiDimReductionOp::getExpectedMaskType() {
auto vecType = getSourceVectorType();
- return VectorType::get(vecType.getShape(),
- IntegerType::get(vecType.getContext(), /*width=*/1),
- vecType.getScalableDims());
+ return VectorType::get(IntegerType::get(vecType.getContext(), /*width=*/1),
+ vecType.getDims());
}
namespace {
@@ -491,8 +488,7 @@ struct ElideUnitDimsInMultiDimReduction
if (auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
if (mask) {
VectorType newMaskType =
- VectorType::get(dstVecType.getShape(), rewriter.getI1Type(),
- dstVecType.getScalableDims());
+ VectorType::get(rewriter.getI1Type(), dstVecType.getDims());
mask = rewriter.create<vector::ShapeCastOp>(loc, newMaskType, mask);
}
cast = rewriter.create<vector::ShapeCastOp>(
@@ -559,9 +555,8 @@ LogicalResult ReductionOp::verify() {
/// Returns the mask type expected by this operation.
Type ReductionOp::getExpectedMaskType() {
auto vecType = getSourceVectorType();
- return VectorType::get(vecType.getShape(),
- IntegerType::get(vecType.getContext(), /*width=*/1),
- vecType.getScalableDims());
+ return VectorType::get(IntegerType::get(vecType.getContext(), /*width=*/1),
+ vecType.getDims());
}
Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
@@ -1252,8 +1247,7 @@ ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
vectorType.getRank());
inferredReturnTypes.push_back(VectorType::get(
- vectorType.getShape().drop_front(n), vectorType.getElementType(),
- vectorType.getScalableDims().drop_front(n)));
+ vectorType.getElementType(), vectorType.getDims().dropFront(n)));
}
return success();
}
@@ -3040,15 +3034,11 @@ ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
VectorType resType;
if (vRHS) {
- SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
- vRHS.getScalableDims()[0]};
- resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
- vLHS.getElementType(), scalableDimsRes);
+ resType = VectorType::get(vLHS.getElementType(),
+ {vLHS.getDim(0), vRHS.getDim(0)});
} else {
// Scalar RHS operand
- SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
- resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
- scalableDimsRes);
+ resType = VectorType::get(vLHS.getElementType(), {vLHS.getDim(0)});
}
if (!result.attributes.get(OuterProductOp::getKindAttrName(result.name))) {
@@ -3115,9 +3105,8 @@ LogicalResult OuterProductOp::verify() {
/// verification purposes. It requires the operation to be vectorized."
Type OuterProductOp::getExpectedMaskType() {
auto vecType = this->getResultVectorType();
- return VectorType::get(vecType.getShape(),
- IntegerType::get(vecType.getContext(), /*width=*/1),
- vecType.getScalableDims());
+ return VectorType::get(IntegerType::get(vecType.getContext(), /*width=*/1),
+ vecType.getDims());
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index 4a5e8fcfb6eda..941c00fd96b56 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -113,13 +113,11 @@ struct TransferReadPermutationLowering
permutationMap = inversePermutation(permutationMap);
AffineMap newMap = permutationMap.compose(map);
// Apply the reverse transpose to deduce the type of the transfer_read.
- ArrayRef<int64_t> originalShape = op.getVectorType().getShape();
- SmallVector<int64_t> newVectorShape(originalShape.size());
- ArrayRef<bool> originalScalableDims = op.getVectorType().getScalableDims();
- SmallVector<bool> newScalableDims(originalShape.size());
+ auto originalDims = op.getVectorType().getDims();
+ SmallVector<VectorDim> newVectorDims(op.getVectorType().getRank(),
+ VectorDim::getFixed(0));
for (const auto &pos : llvm::enumerate(permutation)) {
- newVectorShape[pos.value()] = originalShape[pos.index()];
- newScalableDims[pos.value()] = originalScalableDims[pos.index()];
+ newVectorDims[pos.value()] = originalDims[pos.index()];
}
// Transpose in_bounds attribute.
@@ -129,8 +127,8 @@ struct TransferReadPermutationLowering
: ArrayAttr();
// Generate new transfer_read operation.
- VectorType newReadType = VectorType::get(
- newVectorShape, op.getVectorType().getElementType(), newScalableDims);
+ VectorType newReadType =
+ VectorType::get(op.getVectorType().getElementType(), newVectorDims);
Value newRead = rewriter.create<vector::TransferReadOp>(
op.getLoc(), newReadType, op.getSource(), op.getIndices(),
AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index 4d43a76c4a4ef..0ede6239a9495 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -344,10 +344,8 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
// Source with leading unit dim (inverse) is also replaced. Unit dim must
// be fixed. Non-unit can be scalable.
if (resType.getRank() == 2 &&
- ((resType.getShape().front() == 1 &&
- !resType.getScalableDims().front()) ||
- (resType.getShape().back() == 1 &&
- !resType.getScalableDims().back())) &&
+ (resType.getDims().front() == VectorDim::getFixed(1) ||
+ resType.getDims().back() == VectorDim::getFixed(1)) &&
transp == ArrayRef<int64_t>({1, 0})) {
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
return success();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 84294e4552a60..5f8a9c9e9915a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -25,24 +25,15 @@ using namespace mlir::vector;
// Trims leading one dimensions from `oldType` and returns the result type.
// Returns `vector<1xT>` if `oldType` only has one element.
static VectorType trimLeadingOneDims(VectorType oldType) {
- ArrayRef<int64_t> oldShape = oldType.getShape();
- ArrayRef<int64_t> newShape = oldShape;
-
- ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
- ArrayRef<bool> newScalableDims = oldScalableDims;
-
- while (!newShape.empty() && newShape.front() == 1 &&
- !newScalableDims.front()) {
- newShape = newShape.drop_front(1);
- newScalableDims = newScalableDims.drop_front(1);
- }
+ VectorDims oldDims = oldType.getDims();
+ VectorDims newDims = oldDims.dropWhile(
+ [](VectorDim dim) { return dim == VectorDim::getFixed(1); });
// Make sure we have at least 1 dimension per vector type requirements.
- if (newShape.empty()) {
- newShape = oldShape.take_back();
- newScalableDims = oldType.getScalableDims().take_back();
- }
- return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
+ if (newDims.empty())
+ newDims = oldDims.takeBack(1);
+
+ return VectorType::get(oldType.getElementType(), newDims);
}
/// Return a smallVector of size `rank` containing all zeros.
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index ed42e6508b431..8eec25dbc04ba 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -316,15 +316,11 @@ static int getReducedRank(ArrayRef<int64_t> shape) {
/// Trims non-scalable one dimensions from `oldType` and returns the result
/// type.
static VectorType trimNonScalableUnitDims(VectorType oldType) {
- SmallVector<int64_t> newShape;
- SmallVector<bool> newScalableDims;
- for (auto [dimIdx, dimSize] : llvm::enumerate(oldType.getShape())) {
- if (dimSize == 1 && !oldType.getScalableDims()[dimIdx])
- continue;
- newShape.push_back(dimSize);
- newScalableDims.push_back(oldType.getScalableDims()[dimIdx]);
- }
- return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
+ auto newDims = llvm::to_vector(
+ llvm::make_filter_range(oldType.getDims(), [](VectorDim dim) {
+ return dim != VectorDim::getFixed(1);
+ }));
+ return VectorType::get(oldType.getElementType(), newDims);
}
// Rewrites vector.create_mask 'op' to drop non-scalable one dimensions.
@@ -337,9 +333,9 @@ createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc,
return failure();
SmallVector<Value> reducedOperands;
- for (auto [dim, dimIsScalable, operand] : llvm::zip_equal(
- type.getShape(), type.getScalableDims(), op.getOperands())) {
- if (dim == 1 && !dimIsScalable) {
+ for (auto [dim, operand] :
+ llvm::zip_equal(type.getDims(), op.getOperands())) {
+ if (dim == VectorDim::getFixed(1)) {
// If the mask for the unit dim is not a constant of 1, do nothing.
auto constant = operand.getDefiningOp<arith::ConstantIndexOp>();
if (!constant || (constant.value() != 1))
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 6e7fab293d3a1..b5bc167c1c53a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1039,10 +1039,7 @@ struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
vector::createOrFoldDimOp(rewriter, loc, xferOp.getSource(), lastIndex);
Value b = rewriter.create<arith::SubIOp>(loc, dim.getType(), dim, off);
Value mask = rewriter.create<vector::CreateMaskOp>(
- loc,
- VectorType::get(vtp.getShape(), rewriter.getI1Type(),
- vtp.getScalableDims()),
- b);
+ loc, VectorType::get(rewriter.getI1Type(), vtp.getDims()), b);
if (xferOp.getMask()) {
// Intersect the in-bounds with the mask specified as an op parameter.
mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.getMask());
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 4b76dcf7f8a9f..6eb187b6101a3 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2558,17 +2558,10 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
}
})
.Case<VectorType>([&](VectorType vectorTy) {
- auto scalableDims = vectorTy.getScalableDims();
os << "vector<";
- auto vShape = vectorTy.getShape();
- unsigned lastDim = vShape.size();
- unsigned dimIdx = 0;
- for (dimIdx = 0; dimIdx < lastDim; dimIdx++) {
- if (!scalableDims.empty() && scalableDims[dimIdx])
- os << '[';
- os << vShape[dimIdx];
- if (!scalableDims.empty() && scalableDims[dimIdx])
- os << ']';
+ auto dims = vectorTy.getDims();
+ if (!dims.empty()) {
+ llvm::interleave(dims, os, "x");
os << 'x';
}
printType(vectorTy.getElementType());
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 9b8ee3d452803..eac79ea34d655 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -250,10 +250,10 @@ VectorType VectorType::scaleElementBitwidth(unsigned scale) {
return VectorType();
if (auto et = llvm::dyn_cast<IntegerType>(getElementType()))
if (auto scaledEt = et.scaleElementBitwidth(scale))
- return VectorType::get(getShape(), scaledEt, getScalableDims());
+ return VectorType::get(scaledEt, getDims());
if (auto et = llvm::dyn_cast<FloatType>(getElementType()))
if (auto scaledEt = et.scaleElementBitwidth(scale))
- return VectorType::get(getShape(), scaledEt, getScalableDims());
+ return VectorType::get(scaledEt, getDims());
return VectorType();
}
More information about the Mlir-commits
mailing list