[Mlir-commits] [mlir] [mlir] Add `ScalableVectorType` support class (PR #96236)
Benjamin Maxwell
llvmlistbot at llvm.org
Thu Jun 20 13:52:09 PDT 2024
https://github.com/MacDue created https://github.com/llvm/llvm-project/pull/96236
This adds a pseudo-type that wraps a VectorType and aims to provide safe APIs for working with scalable vectors. Slightly contrary to the name, this class can represent both fixed and scalable vectors; however, if you are only dealing with fixed vectors, the plain VectorType is likely more convenient.
The main difference from the regular VectorType is that vector dimensions are _not_ represented as `int64_t`, which cannot encode the scalability of a dimension. Instead, vector dimensions are represented by a VectorDim class, which stores both the size and the scalability of a dimension. This makes common errors, such as checking only the size (but not the scalability), impossible unless you are explicit about your intention.
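To illustrate the difference, here is a minimal standalone sketch that does not depend on MLIR. The `sketch::VectorDim` class and the `isUnitDim`/`isUnitDimSizeOnly` helpers are hypothetical simplifications written for this example, loosely modeled on the proposed VectorDim; they are not the actual header. Because equality compares both the size and the scalability, a scalable `[1]` is not mistaken for a fixed unit dimension:

```cpp
#include <cassert>
#include <cstdint>

namespace sketch {

// Hypothetical, simplified stand-in for the proposed VectorDim: a dimension
// stores its (minimum) size together with its scalability.
class VectorDim {
public:
  constexpr VectorDim(int64_t minSize, bool scalable)
      : minSize(minSize), scalable(scalable) {}

  static constexpr VectorDim getFixed(int64_t n) { return VectorDim(n, false); }
  static constexpr VectorDim getScalable(int64_t n) {
    return VectorDim(n, true);
  }

  constexpr bool isScalable() const { return scalable; }
  constexpr int64_t getMinSize() const { return minSize; }

  // Only meaningful for fixed dims (mirrors getFixedSize() in the patch).
  int64_t getFixedSize() const {
    assert(!scalable && "dim is scalable");
    return minSize;
  }

  // Equality compares both fields, so a fixed 4 never equals a scalable [4].
  constexpr bool operator==(const VectorDim &other) const {
    return minSize == other.minSize && scalable == other.scalable;
  }
  constexpr bool operator!=(const VectorDim &other) const {
    return !(*this == other);
  }

private:
  int64_t minSize;
  bool scalable;
};

// Error-prone: with a raw int64_t size, a scalable [1] looks like a unit dim.
inline bool isUnitDimSizeOnly(int64_t size) { return size == 1; }

// Safer: comparing against a fixed 1 also checks the scalability.
inline bool isUnitDim(VectorDim dim) { return dim == VectorDim::getFixed(1); }

} // namespace sketch
```

With raw sizes, scalability has to be checked as a separate step; with the combined value, it is part of every comparison.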
To make this convenient to work with, there is VectorDimList, which provides ArrayRef-like helper methods along with an iterator over VectorDims.
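A minimal sketch of the idea behind VectorDimList, assuming plain pointers instead of `ArrayRef` so that it compiles without LLVM. `sketch::VectorDimList` is a hypothetical, cut-down view: it reads the parallel size and flag arrays and yields one combined dim per index, treating a missing flag array as "all fixed" (the patch does the same with an empty `ArrayRef<bool>`):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

namespace sketch {

struct VectorDim {
  int64_t minSize;
  bool scalable;
  bool operator==(const VectorDim &o) const {
    return minSize == o.minSize && scalable == o.scalable;
  }
};

// Hypothetical non-owning view over two parallel arrays (sizes and scalable
// flags), mirroring how VectorType stores its shape.
class VectorDimList {
public:
  // `flags` may be null, in which case every dim is treated as fixed.
  VectorDimList(const int64_t *sizes, const bool *flags, size_t n)
      : sizes(sizes), flags(flags), n(n) {}

  size_t size() const { return n; }
  bool empty() const { return n == 0; }

  // Combine the size and the flag at `i` into a single value.
  VectorDim operator[](size_t i) const {
    assert(i < n && "index out of range");
    return {sizes[i], flags ? flags[i] : false};
  }
  VectorDim back() const { return (*this)[n - 1]; }

  // Drop the last `k` dims; sizes and flags are always sliced together.
  VectorDimList dropBack(size_t k = 1) const {
    assert(k <= n && "cannot drop more dims than exist");
    return VectorDimList(sizes, flags, n - k);
  }

  // Minimal iterator so the list works in range-based for loops.
  class Iterator {
  public:
    Iterator(const VectorDimList *list, size_t i) : list(list), i(i) {}
    VectorDim operator*() const { return (*list)[i]; }
    Iterator &operator++() {
      ++i;
      return *this;
    }
    bool operator!=(const Iterator &o) const { return i != o.i; }

  private:
    const VectorDimList *list;
    size_t i;
  };
  Iterator begin() const { return Iterator(this, 0); }
  Iterator end() const { return Iterator(this, n); }

private:
  const int64_t *sizes;
  const bool *flags;
  size_t n;
};

} // namespace sketch
```

Because the view never hands out sizes and flags separately, slicing one without the other cannot happen by accident.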
ScalableVectorType can be freely converted to VectorType (and vice versa), and there are two main ways to acquire a ScalableVectorType.
Assignment:
This does not check the scalability of `myVectorType`. This is valid, and the helpers on ScalableVectorType will function as normal.
```c++
VectorType myVectorType = ...;
ScalableVectorType scalableVector = myVectorType;
```
Casting:
This checks the scalability of myVectorType. In this case, `scalableVector` will be falsy if `myVectorType` contains no scalable dims.
```c++
VectorType myVectorType = ...;
auto scalableVector = dyn_cast<ScalableVectorType>(myVectorType);
```
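The difference between the two acquisition paths can be sketched without MLIR as follows. `PlainVector`, `ScalableVector`, and `dynCastScalable` are hypothetical stand-ins for VectorType, ScalableVectorType, and `dyn_cast<ScalableVectorType>`, and `std::optional` stands in for a null (falsy) type:

```cpp
#include <cassert>
#include <optional>
#include <utility>
#include <vector>

namespace sketch {

// Hypothetical stand-in for a VectorType: only the scalable flags of its shape.
struct PlainVector {
  std::vector<bool> scalableDims;
  bool isScalable() const {
    for (bool s : scalableDims)
      if (s)
        return true;
    return false;
  }
};

// Analogue of ScalableVectorType: the implicit constructor performs no check,
// so fixed vectors are accepted too (the "assignment" path).
class ScalableVector {
public:
  ScalableVector(PlainVector v) : v(std::move(v)) {}
  const PlainVector &get() const { return v; }

private:
  PlainVector v;
};

// Analogue of dyn_cast<ScalableVectorType> (the "casting" path): succeeds only
// if at least one dimension is scalable, like classof() in the patch.
inline std::optional<ScalableVector> dynCastScalable(const PlainVector &v) {
  if (!v.isScalable())
    return std::nullopt;
  return ScalableVector(v);
}

} // namespace sketch
```

In short, assignment is a plain wrapper conversion, while the cast doubles as a scalability test.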
Note: The use of this class is entirely optional! It only aims to make writing scalable-aware patterns safer and easier.
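As an example of the kind of mistake this targets, consider trimming trailing unit dimensions (loosely modeled on `trimTrailingOneDims` in the third patch below). The sketch is standalone and hypothetical: it contrasts a size-only loop, which would drop the scalable `[1]` in `vector<4x[1]xf32>`, with a dim-based loop that keeps it:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

namespace sketch {

struct VectorDim {
  int64_t minSize;
  bool scalable;
  bool operator==(const VectorDim &o) const {
    return minSize == o.minSize && scalable == o.scalable;
  }
};

// Size-only version: easy to write, but it silently drops a scalable [1].
inline std::vector<int64_t>
trimTrailingOnesSizeOnly(std::vector<int64_t> sizes) {
  assert(!sizes.empty() && "0-D vectors are not handled");
  while (sizes.size() > 1 && sizes.back() == 1)
    sizes.pop_back();
  return sizes;
}

// Dim-based version: comparing against a *fixed* 1 makes the scalability
// check part of the equality test, so [1] is kept. At least one dim remains.
inline std::vector<VectorDim> trimTrailingOnes(std::vector<VectorDim> dims) {
  assert(!dims.empty() && "0-D vectors are not handled");
  const VectorDim fixedOne{1, false};
  while (dims.size() > 1 && dims.back() == fixedOne)
    dims.pop_back();
  return dims;
}

} // namespace sketch
```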
From 5e3038a18265028315cbd0d319e830606571c9a9 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 20 Jun 2024 20:09:25 +0000
Subject: [PATCH 1/3] [mlir] Add `ScalableVectorType` support class
This adds a pseudo-type that wraps a VectorType and aims to provide
safe APIs for working with scalable vectors. Slightly contrary to the
name, this class can represent both fixed and scalable vectors;
however, if you are only dealing with fixed vectors, the plain
VectorType is likely more convenient.
The main difference from the regular VectorType is that vector
dimensions are _not_ represented as `int64_t`, which cannot encode the
scalability of a dimension. Instead, vector dimensions are represented
by a VectorDim class. A VectorDim stores both the size and scalability
of a dimension. This makes common errors, like only checking the size
(but not the scalability), impossible (without being explicit with
your intention).
To make this convenient to work with, there is VectorDimList, which
provides ArrayRef-like helper methods along with an iterator for
VectorDims.
ScalableVectorType can be freely converted to VectorType (and vice
versa), and there are two main ways to acquire a ScalableVectorType.
Assignment:
This does not check the scalability of myVectorType. This is valid and
the helpers on ScalableVectorType will function as normal.
```c++
VectorType myVectorType = ...;
ScalableVectorType scalableVector = myVectorType;
```
Casting:
This checks the scalability of myVectorType. In this case,
scalableVector will be falsy if myVectorType contains no scalable dims.
```c++
VectorType myVectorType = ...;
auto scalableVector = dyn_cast<ScalableVectorType>(myVectorType);
```
Note: The use of this class is entirely optional! It only aims to make
writing scalable-aware patterns safer and easier.
---
.../include/mlir/Support/ScalableVectorType.h | 360 ++++++++++++++++++
1 file changed, 360 insertions(+)
create mode 100644 mlir/include/mlir/Support/ScalableVectorType.h
diff --git a/mlir/include/mlir/Support/ScalableVectorType.h b/mlir/include/mlir/Support/ScalableVectorType.h
new file mode 100644
index 0000000000000..0fa7716ea2bcb
--- /dev/null
+++ b/mlir/include/mlir/Support/ScalableVectorType.h
@@ -0,0 +1,360 @@
+//===- ScalableVectorType.h - Scalable Vector Helpers -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_SUPPORT_SCALABLEVECTORTYPE_H
+#define MLIR_SUPPORT_SCALABLEVECTORTYPE_H
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+
+//===----------------------------------------------------------------------===//
+// VectorDim
+//===----------------------------------------------------------------------===//
+
+/// This class represents a dimension of a vector type. Unlike the dimensions
+/// of other ShapedTypes, vector dimensions can be scalable, which means the
+/// dimension has a known minimum size that is scaled by a constant that is
+/// only known at runtime.
+class VectorDim {
+public:
+ explicit constexpr VectorDim(int64_t quantity, bool scalable)
+ : quantity(quantity), scalable(scalable) {};
+
+ /// Constructs a new fixed dimension.
+ constexpr static VectorDim getFixed(int64_t quantity) {
+ return VectorDim(quantity, false);
+ }
+
+ /// Constructs a new scalable dimension.
+ constexpr static VectorDim getScalable(int64_t quantity) {
+ return VectorDim(quantity, true);
+ }
+
+ /// Returns true if this dimension is scalable.
+ constexpr bool isScalable() const { return scalable; }
+
+ /// Returns true if this dimension is fixed.
+ constexpr bool isFixed() const { return !isScalable(); }
+
+ /// Returns the minimum number of elements this dimension can contain.
+ constexpr int64_t getMinSize() const { return quantity; }
+
+ /// If this dimension is fixed, returns the number of elements; otherwise,
+ /// asserts.
+ constexpr int64_t getFixedSize() const {
+ assert(isFixed());
+ return quantity;
+ }
+
+ constexpr bool operator==(VectorDim const &dim) const {
+ return quantity == dim.quantity && scalable == dim.scalable;
+ }
+
+ constexpr bool operator!=(VectorDim const &dim) const {
+ return !(*this == dim);
+ }
+
+ /// Print the dim.
+ void print(raw_ostream &os) {
+ if (isScalable())
+ os << '[';
+ os << getMinSize();
+ if (isScalable())
+ os << ']';
+ }
+
+ /// Helper class for indexing into a list of sizes and a (possibly empty) list
+ /// of scalable flags, extracting VectorDim elements.
+ struct Indexer {
+ explicit Indexer(ArrayRef<int64_t> sizes, ArrayRef<bool> scalableDims)
+ : sizes(sizes), scalableDims(scalableDims) {
+ assert(
+ (scalableDims.empty() ||
+ sizes.size() == scalableDims.size()) &&
+ "expected `scalableDims` to be empty or match `sizes` in length");
+ }
+
+ VectorDim operator[](size_t idx) const {
+ int64_t size = sizes[idx];
+ bool scalable = scalableDims.empty() ? false : scalableDims[idx];
+ return VectorDim(size, scalable);
+ }
+
+ ArrayRef<int64_t> sizes;
+ ArrayRef<bool> scalableDims;
+ };
+
+private:
+ int64_t quantity;
+ bool scalable;
+};
+
+inline raw_ostream &operator<<(raw_ostream &os, VectorDim dim) {
+ dim.print(os);
+ return os;
+}
+
+//===----------------------------------------------------------------------===//
+// VectorDimList
+//===----------------------------------------------------------------------===//
+
+/// Represents a non-owning list of vector dimensions. The underlying dimension
+/// sizes and scalability flags are stored as two separate lists to match the
+/// storage of a VectorType.
+class VectorDimList : public VectorDim::Indexer {
+public:
+ using VectorDim::Indexer::Indexer;
+
+ class Iterator : public llvm::iterator_facade_base<
+ Iterator, std::random_access_iterator_tag, VectorDim,
+ std::ptrdiff_t, VectorDim, VectorDim> {
+ public:
+ Iterator(VectorDim::Indexer indexer, size_t index)
+ : indexer(indexer), index(index) {};
+
+ // Iterator boilerplate.
+ ptrdiff_t operator-(const Iterator &rhs) const { return index - rhs.index; }
+ bool operator==(const Iterator &rhs) const { return index == rhs.index; }
+ bool operator<(const Iterator &rhs) const { return index < rhs.index; }
+ Iterator &operator+=(ptrdiff_t offset) {
+ index += offset;
+ return *this;
+ }
+ Iterator &operator-=(ptrdiff_t offset) {
+ index -= offset;
+ return *this;
+ }
+ VectorDim operator*() const { return indexer[index]; }
+
+ VectorDim::Indexer getIndexer() const { return indexer; }
+ ptrdiff_t getIndex() const { return index; }
+
+ private:
+ VectorDim::Indexer indexer;
+ ptrdiff_t index;
+ };
+
+ // Generic definitions.
+ using value_type = VectorDim;
+ using iterator = Iterator;
+ using const_iterator = Iterator;
+ using reverse_iterator = std::reverse_iterator<iterator>;
+ using const_reverse_iterator = std::reverse_iterator<const_iterator>;
+ using size_type = size_t;
+ using difference_type = ptrdiff_t;
+
+ /// Construct from iterator pair.
+ VectorDimList(Iterator begin, Iterator end)
+ : VectorDimList(VectorDimList(begin.getIndexer())
+ .slice(begin.getIndex(), end - begin)) {}
+
+ VectorDimList(VectorDim::Indexer indexer) : VectorDim::Indexer(indexer) {};
+
+ /// Construct from a VectorType.
+ static VectorDimList from(VectorType vectorType) {
+ if (!vectorType)
+ return VectorDimList({}, {});
+ return VectorDimList(vectorType.getShape(), vectorType.getScalableDims());
+ }
+
+ Iterator begin() const { return Iterator(*this, 0); }
+ Iterator end() const { return Iterator(*this, size()); }
+
+ /// Check if the dims are empty.
+ bool empty() const { return sizes.empty(); }
+
+ /// Get the number of dims.
+ size_t size() const { return sizes.size(); }
+
+ /// Return the first dim.
+ VectorDim front() const { return (*this)[0]; }
+
+ /// Return the last dim.
+ VectorDim back() const { return (*this)[size() - 1]; }
+
+ /// Chop off the first \p n dims, and keep the next \p m
+ /// dims.
+ VectorDimList slice(size_t n, size_t m) const {
+ ArrayRef<int64_t> newSizes = sizes.slice(n, m);
+ ArrayRef<bool> newScalableDims =
+ scalableDims.empty() ? ArrayRef<bool>{} : scalableDims.slice(n, m);
+ return VectorDimList(newSizes, newScalableDims);
+ }
+
+ /// Drop the first \p n dims.
+ VectorDimList dropFront(size_t n = 1) const { return slice(n, size() - n); }
+
+ /// Drop the last \p n dims.
+ VectorDimList dropBack(size_t n = 1) const { return slice(0, size() - n); }
+
+ /// Return a copy of *this with only the first \p n elements.
+ VectorDimList takeFront(size_t n = 1) const {
+ if (n >= size())
+ return *this;
+ return dropBack(size() - n);
+ }
+
+ /// Return a copy of *this with only the last \p n elements.
+ VectorDimList takeBack(size_t n = 1) const {
+ if (n >= size())
+ return *this;
+ return dropFront(size() - n);
+ }
+
+ /// Return a copy of *this with the leading dims matching the predicate removed.
+ template <class PredicateT>
+ VectorDimList dropWhile(PredicateT predicate) const {
+ return VectorDimList(llvm::find_if_not(*this, predicate), end());
+ }
+
+ /// Returns true if one or more of the dims are scalable.
+ bool hasScalableDims() const {
+ return llvm::is_contained(getScalableDims(), true);
+ }
+
+ /// Check for dim equality.
+ bool equals(VectorDimList rhs) const {
+ if (size() != rhs.size())
+ return false;
+ return std::equal(begin(), end(), rhs.begin());
+ }
+
+ /// Check for dim equality.
+ bool equals(ArrayRef<VectorDim> rhs) const {
+ if (size() != rhs.size())
+ return false;
+ return std::equal(begin(), end(), rhs.begin());
+ }
+
+ /// Return the underlying sizes.
+ ArrayRef<int64_t> getSizes() const { return sizes; }
+
+ /// Return the underlying scalable dims.
+ ArrayRef<bool> getScalableDims() const { return scalableDims; }
+};
+
+inline bool operator==(VectorDimList lhs, VectorDimList rhs) {
+ return lhs.equals(rhs);
+}
+
+inline bool operator!=(VectorDimList lhs, VectorDimList rhs) {
+ return !(lhs == rhs);
+}
+
+inline bool operator==(VectorDimList lhs, ArrayRef<VectorDim> rhs) {
+ return lhs.equals(rhs);
+}
+
+inline bool operator!=(VectorDimList lhs, ArrayRef<VectorDim> rhs) {
+ return !(lhs == rhs);
+}
+
+//===----------------------------------------------------------------------===//
+// ScalableVectorType
+//===----------------------------------------------------------------------===//
+
+/// A pseudo-type that wraps a VectorType and aims to provide safe APIs for
+/// working with scalable vectors. Slightly contrary to the name, this class can
+/// represent both fixed and scalable vectors; however, if you are only dealing
+/// with fixed vectors, the plain VectorType is likely more convenient.
+///
+/// The main difference from the regular VectorType is that vector dimensions
+/// are _not_ represented as `int64_t`, which does not allow encoding the
+/// scalability into the dimension. Instead, vector dimensions are represented
+/// by a VectorDim class. A VectorDim stores both the size and scalability of a
+/// dimension. This makes common errors like only checking the size (but not
+/// the scalability) impossible (without being explicit with your intention).
+///
+/// To make this convenient to work with, there is VectorDimList, which provides
+/// ArrayRef-like helper methods along with an iterator for VectorDims.
+///
+/// ScalableVectorType and VectorType can be freely converted between. However,
+/// there is one thing to note:
+///
+/// Assignment from a VectorType always succeeds (scalability is not checked):
+/// ```
+/// VectorType someVectorType = ...;
+/// ScalableVectorType vector = someVectorType;
+/// ```
+///
+/// Casting from a Type/VectorType via dyn_cast (or cast) checks scalability:
+/// ```
+/// if (auto scalableVector = dyn_cast<ScalableVectorType>(someVectorType)) {
+/// // The vector type has scalable dims.
+/// }
+/// ```
+class ScalableVectorType {
+public:
+ using Dim = VectorDim;
+ using DimList = VectorDimList;
+
+ ScalableVectorType(VectorType vectorType) : vectorType(vectorType) {};
+
+ /// Construct a new ScalableVectorType.
+ static ScalableVectorType get(DimList shape, Type elementType) {
+ return VectorType::get(shape.getSizes(), elementType,
+ shape.getScalableDims());
+ }
+
+ /// Construct a new ScalableVectorType.
+ static ScalableVectorType get(ArrayRef<Dim> shape, Type elementType) {
+ SmallVector<int64_t> sizes;
+ SmallVector<bool> scalableDims;
+ sizes.reserve(shape.size());
+ scalableDims.reserve(shape.size());
+ for (Dim dim : shape) {
+ sizes.push_back(dim.getMinSize());
+ scalableDims.push_back(dim.isScalable());
+ }
+ return VectorType::get(sizes, elementType, scalableDims);
+ }
+
+ inline static bool classof(Type type) {
+ auto vectorType = dyn_cast_if_present<VectorType>(type);
+ return vectorType && vectorType.isScalable();
+ }
+
+ /// Returns the value of the specified dimension (including scalability).
+ Dim getDim(unsigned idx) const {
+ assert(idx < getRank() && "invalid dim index for vector type");
+ return getDims()[idx];
+ }
+
+ /// Returns the dimensions of this vector type (including scalability).
+ DimList getDims() const {
+ return DimList(vectorType.getShape(), vectorType.getScalableDims());
+ }
+
+ /// Returns the rank of this vector type.
+ int64_t getRank() const { return vectorType.getRank(); }
+
+ /// Returns true if the vector contains scalable dimensions.
+ bool isScalable() const { return vectorType.isScalable(); }
+ bool allDimsScalable() const { return vectorType.allDimsScalable(); }
+
+ /// Returns the element type of this vector type.
+ Type getElementType() const { return vectorType.getElementType(); }
+
+ /// Clones this vector type with a new element type.
+ ScalableVectorType clone(Type elementType) {
+ return vectorType.clone(elementType);
+ }
+
+ operator VectorType() const { return vectorType; }
+
+ explicit operator bool() const { return bool(vectorType); }
+
+private:
+ VectorType vectorType;
+};
+
+} // namespace mlir
+
+#endif
From b60dfbc347eac29a18d992667a3e2b9422c84fa5 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 20 Jun 2024 20:27:04 +0000
Subject: [PATCH 2/3] [mlir][test] Add basic unit test for `ScalableVectorType`
---
mlir/unittests/Support/CMakeLists.txt | 3 +-
.../Support/ScalableVectorTypeTest.cpp | 76 +++++++++++++++++++
2 files changed, 78 insertions(+), 1 deletion(-)
create mode 100644 mlir/unittests/Support/ScalableVectorTypeTest.cpp
diff --git a/mlir/unittests/Support/CMakeLists.txt b/mlir/unittests/Support/CMakeLists.txt
index 1dbf072bcbbfd..680572bbb0cbe 100644
--- a/mlir/unittests/Support/CMakeLists.txt
+++ b/mlir/unittests/Support/CMakeLists.txt
@@ -1,7 +1,8 @@
add_mlir_unittest(MLIRSupportTests
IndentedOstreamTest.cpp
StorageUniquerTest.cpp
+ ScalableVectorTypeTest.cpp
)
target_link_libraries(MLIRSupportTests
- PRIVATE MLIRSupport)
+ PRIVATE MLIRSupport MLIRIR)
diff --git a/mlir/unittests/Support/ScalableVectorTypeTest.cpp b/mlir/unittests/Support/ScalableVectorTypeTest.cpp
new file mode 100644
index 0000000000000..5f0237ac68414
--- /dev/null
+++ b/mlir/unittests/Support/ScalableVectorTypeTest.cpp
@@ -0,0 +1,76 @@
+//===- ScalableVectorTypeTest.cpp - ScalableVectorType Tests --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Support/ScalableVectorType.h"
+#include "mlir/IR/Dialect.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+
+TEST(ScalableVectorTypeTest, TestVectorDim) {
+ auto fixedDim = VectorDim::getFixed(4);
+ ASSERT_FALSE(fixedDim.isScalable());
+ ASSERT_TRUE(fixedDim.isFixed());
+ ASSERT_EQ(fixedDim.getFixedSize(), 4);
+
+ auto scalableDim = VectorDim::getScalable(8);
+ ASSERT_TRUE(scalableDim.isScalable());
+ ASSERT_FALSE(scalableDim.isFixed());
+ ASSERT_EQ(scalableDim.getMinSize(), 8);
+}
+
+TEST(ScalableVectorTypeTest, BasicFunctionality) {
+ MLIRContext context;
+
+ Type f32 = FloatType::getF32(&context);
+
+ // Construct n-D scalable vector.
+ VectorType scalableVector = ScalableVectorType::get(
+ {VectorDim::getFixed(1), VectorDim::getFixed(2),
+ VectorDim::getScalable(3), VectorDim::getFixed(4),
+ VectorDim::getScalable(5)},
+ f32);
+ // Construct fixed vector.
+ VectorType fixedVector = ScalableVectorType::get(VectorDim::getFixed(1), f32);
+
+ // Check casts.
+ ASSERT_TRUE(isa<ScalableVectorType>(scalableVector));
+ ASSERT_FALSE(isa<ScalableVectorType>(fixedVector));
+ ASSERT_FALSE(VectorDimList::from(fixedVector).hasScalableDims());
+
+ // Check rank/size.
+ auto vType = cast<ScalableVectorType>(scalableVector);
+ ASSERT_EQ(vType.getDims().size(), unsigned(scalableVector.getRank()));
+ ASSERT_TRUE(vType.getDims().hasScalableDims());
+
+ // Check iterating over dimensions.
+ std::array expectedDims{VectorDim::getFixed(1), VectorDim::getFixed(2),
+ VectorDim::getScalable(3), VectorDim::getFixed(4),
+ VectorDim::getScalable(5)};
+ unsigned i = 0;
+ for (VectorDim dim : vType.getDims()) {
+ ASSERT_EQ(dim, expectedDims[i]);
+ i++;
+ }
+}
+
+TEST(ScalableVectorTypeTest, VectorDimListHelpers) {
+ std::array<int64_t, 4> sizes{42, 10, 3, 1};
+ std::array<bool, 4> scalableFlags{false, true, false, true};
+
+ // Manually construct from sizes + flags.
+ VectorDimList dimList(sizes, scalableFlags);
+
+ ASSERT_EQ(dimList.size(), 4U);
+
+ ASSERT_EQ(dimList.front(), VectorDim::getFixed(42));
+ ASSERT_EQ(dimList.back(), VectorDim::getScalable(1));
+
+ std::array innerDims{VectorDim::getScalable(10), VectorDim::getFixed(3)};
+ ASSERT_EQ(dimList.slice(1, 2), innerDims);
+}
From e29097fbe125ec52e6f462c95d04e6c85455a0d9 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 20 Jun 2024 20:28:59 +0000
Subject: [PATCH 3/3] Demonstrate using `ScalableVectorType` and `VectorDim`s
This updates a few places to make use of the new support classes. This
hopefully shows (at least a little) how these classes make handling
scalability easier.
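For instance, the `MultiDimReductionOp::verify()` change below keeps each non-reduced dim as a single value instead of pushing sizes and flags into two parallel vectors. A hypothetical standalone sketch of that filtering step (the `reductionResultDims` helper and the `sketch::VectorDim` struct are illustrative only, not part of the patch):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <set>
#include <vector>

namespace sketch {

struct VectorDim {
  int64_t minSize;
  bool scalable;
  bool operator==(const VectorDim &o) const {
    return minSize == o.minSize && scalable == o.scalable;
  }
};

// Keeps every source dim whose index is not reduced. Size and scalability
// travel together, so they cannot get out of sync.
inline std::vector<VectorDim>
reductionResultDims(const std::vector<VectorDim> &source,
                    const std::set<size_t> &reducedDims) {
  for (size_t idx : reducedDims)
    assert(idx < source.size() && "reduction dim out of range");
  std::vector<VectorDim> result;
  for (size_t idx = 0; idx < source.size(); ++idx)
    if (reducedDims.count(idx) == 0)
      result.push_back(source[idx]);
  return result;
}

} // namespace sketch
```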
---
.../Transforms/PolynomialApproximation.cpp | 57 +++++++----------
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 63 +++++++------------
.../Vector/Transforms/LowerVectorTransfer.cpp | 17 +++--
.../Transforms/LowerVectorTranspose.cpp | 12 ++--
.../Transforms/VectorDropLeadUnitDim.cpp | 25 +++-----
.../Transforms/VectorTransferOpTransforms.cpp | 25 ++++----
.../Vector/Transforms/VectorTransforms.cpp | 11 ++--
mlir/lib/IR/AsmPrinter.cpp | 14 ++---
8 files changed, 87 insertions(+), 137 deletions(-)
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index f4fae68da63b3..7c694ca7d55c8 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -29,6 +29,7 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Support/ScalableVectorType.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ArrayRef.h"
@@ -39,24 +40,14 @@ using namespace mlir;
using namespace mlir::math;
using namespace mlir::vector;
-// Helper to encapsulate a vector's shape (including scalable dims).
-struct VectorShape {
- ArrayRef<int64_t> sizes;
- ArrayRef<bool> scalableFlags;
-
- bool empty() const { return sizes.empty(); }
-};
-
// Returns vector shape if the type is a vector. Returns an empty shape if it is
// not a vector.
-static VectorShape vectorShape(Type type) {
+static VectorDimList vectorShape(Type type) {
auto vectorType = dyn_cast<VectorType>(type);
- return vectorType
- ? VectorShape{vectorType.getShape(), vectorType.getScalableDims()}
- : VectorShape{};
+ return VectorDimList::from(vectorType);
}
-static VectorShape vectorShape(Value value) {
+static VectorDimList vectorShape(Value value) {
return vectorShape(value.getType());
}
@@ -65,16 +56,14 @@ static VectorShape vectorShape(Value value) {
//----------------------------------------------------------------------------//
// Broadcasts scalar type into vector type (iff shape is non-scalar).
-static Type broadcast(Type type, VectorShape shape) {
+static Type broadcast(Type type, VectorDimList shape) {
assert(!isa<VectorType>(type) && "must be scalar type");
- return !shape.empty()
- ? VectorType::get(shape.sizes, type, shape.scalableFlags)
- : type;
+ return !shape.empty() ? ScalableVectorType::get(shape, type) : type;
}
// Broadcasts scalar value into vector (iff shape is non-scalar).
static Value broadcast(ImplicitLocOpBuilder &builder, Value value,
- VectorShape shape) {
+ VectorDimList shape) {
assert(!isa<VectorType>(value.getType()) && "must be scalar value");
auto type = broadcast(value.getType(), shape);
return !shape.empty() ? builder.create<BroadcastOp>(type, value) : value;
@@ -227,7 +216,7 @@ static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound,
static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
bool isPositive = false) {
assert(getElementTypeOrSelf(arg).isF32() && "arg must be f32 type");
- VectorShape shape = vectorShape(arg);
+ VectorDimList shape = vectorShape(arg);
auto bcast = [&](Value value) -> Value {
return broadcast(builder, value, shape);
@@ -267,7 +256,7 @@ static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
// Computes exp2 for an i32 argument.
static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) {
assert(getElementTypeOrSelf(arg).isInteger(32) && "arg must be i32 type");
- VectorShape shape = vectorShape(arg);
+ VectorDimList shape = vectorShape(arg);
auto bcast = [&](Value value) -> Value {
return broadcast(builder, value, shape);
@@ -293,7 +282,7 @@ Value makePolynomialCalculation(ImplicitLocOpBuilder &builder,
Type elementType = getElementTypeOrSelf(x);
assert((elementType.isF32() || elementType.isF16()) &&
"x must be f32 or f16 type");
- VectorShape shape = vectorShape(x);
+ VectorDimList shape = vectorShape(x);
if (coeffs.empty())
return broadcast(builder, floatCst(builder, 0.0f, elementType), shape);
@@ -391,7 +380,7 @@ AtanApproximation::matchAndRewrite(math::AtanOp op,
if (!getElementTypeOrSelf(operand).isF32())
return rewriter.notifyMatchFailure(op, "unsupported operand type");
- VectorShape shape = vectorShape(op.getOperand());
+ VectorDimList shape = vectorShape(op.getOperand());
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
Value abs = builder.create<math::AbsFOp>(operand);
@@ -490,7 +479,7 @@ Atan2Approximation::matchAndRewrite(math::Atan2Op op,
return rewriter.notifyMatchFailure(op, "unsupported operand type");
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
- VectorShape shape = vectorShape(op.getResult());
+ VectorDimList shape = vectorShape(op.getResult());
// Compute atan in the valid range.
auto div = builder.create<arith::DivFOp>(y, x);
@@ -556,7 +545,7 @@ TanhApproximation::matchAndRewrite(math::TanhOp op,
if (!getElementTypeOrSelf(op.getOperand()).isF32())
return rewriter.notifyMatchFailure(op, "unsupported operand type");
- VectorShape shape = vectorShape(op.getOperand());
+ VectorDimList shape = vectorShape(op.getOperand());
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
auto bcast = [&](Value value) -> Value {
@@ -644,7 +633,7 @@ LogApproximationBase<Op>::logMatchAndRewrite(Op op, PatternRewriter &rewriter,
if (!getElementTypeOrSelf(op.getOperand()).isF32())
return rewriter.notifyMatchFailure(op, "unsupported operand type");
- VectorShape shape = vectorShape(op.getOperand());
+ VectorDimList shape = vectorShape(op.getOperand());
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
auto bcast = [&](Value value) -> Value {
@@ -791,7 +780,7 @@ Log1pApproximation::matchAndRewrite(math::Log1pOp op,
if (!getElementTypeOrSelf(op.getOperand()).isF32())
return rewriter.notifyMatchFailure(op, "unsupported operand type");
- VectorShape shape = vectorShape(op.getOperand());
+ VectorDimList shape = vectorShape(op.getOperand());
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
auto bcast = [&](Value value) -> Value {
@@ -846,7 +835,7 @@ AsinPolynomialApproximation::matchAndRewrite(math::AsinOp op,
if (!(elementType.isF32() || elementType.isF16()))
return rewriter.notifyMatchFailure(op,
"only f32 and f16 type is supported.");
- VectorShape shape = vectorShape(operand);
+ VectorDimList shape = vectorShape(operand);
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
auto bcast = [&](Value value) -> Value {
@@ -910,7 +899,7 @@ AcosPolynomialApproximation::matchAndRewrite(math::AcosOp op,
if (!(elementType.isF32() || elementType.isF16()))
return rewriter.notifyMatchFailure(op,
"only f32 and f16 type is supported.");
- VectorShape shape = vectorShape(operand);
+ VectorDimList shape = vectorShape(operand);
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
auto bcast = [&](Value value) -> Value {
@@ -988,7 +977,7 @@ ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
if (!(elementType.isF32() || elementType.isF16()))
return rewriter.notifyMatchFailure(op,
"only f32 and f16 type is supported.");
- VectorShape shape = vectorShape(operand);
+ VectorDimList shape = vectorShape(operand);
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
auto bcast = [&](Value value) -> Value {
@@ -1097,7 +1086,7 @@ ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
namespace {
-Value clampWithNormals(ImplicitLocOpBuilder &builder, const VectorShape shape,
+Value clampWithNormals(ImplicitLocOpBuilder &builder, const VectorDimList shape,
Value value, float lowerBound, float upperBound) {
assert(!std::isnan(lowerBound));
assert(!std::isnan(upperBound));
@@ -1289,7 +1278,7 @@ ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
if (!getElementTypeOrSelf(op.getOperand()).isF32())
return rewriter.notifyMatchFailure(op, "unsupported operand type");
- VectorShape shape = vectorShape(op.getOperand());
+ VectorDimList shape = vectorShape(op.getOperand());
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
auto bcast = [&](Value value) -> Value {
@@ -1359,7 +1348,7 @@ LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
if (!getElementTypeOrSelf(op.getOperand()).isF32())
return rewriter.notifyMatchFailure(op, "unsupported operand type");
- VectorShape shape = vectorShape(op.getOperand());
+ VectorDimList shape = vectorShape(op.getOperand());
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
auto bcast = [&](Value value) -> Value {
@@ -1486,7 +1475,7 @@ CbrtApproximation::matchAndRewrite(math::CbrtOp op,
return rewriter.notifyMatchFailure(op, "unsupported operand type");
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
- VectorShape shape = vectorShape(operand);
+ VectorDimList shape = vectorShape(operand);
Type floatTy = getElementTypeOrSelf(operand.getType());
Type intTy = b.getIntegerType(floatTy.getIntOrFloatBitWidth());
@@ -1575,7 +1564,7 @@ RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
if (!getElementTypeOrSelf(op.getOperand()).isF32())
return rewriter.notifyMatchFailure(op, "unsupported operand type");
- VectorShape shape = vectorShape(op.getOperand());
+ VectorDimList shape = vectorShape(op.getOperand());
// Only support already-vectorized rsqrt's.
if (shape.empty() || shape.sizes.back() % 8 != 0)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 6734c80f2760d..e2ce56e9e188a 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -35,6 +35,7 @@
#include "mlir/Interfaces/SubsetOpInterface.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Support/LLVM.h"
+#include "mlir/Support/ScalableVectorType.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
@@ -463,23 +464,22 @@ MultiDimReductionOp::getShapeForUnroll() {
}
LogicalResult MultiDimReductionOp::verify() {
- SmallVector<int64_t> targetShape;
- SmallVector<bool> scalableDims;
+ SmallVector<VectorDim> targetDims;
Type inferredReturnType;
- auto sourceScalableDims = getSourceVectorType().getScalableDims();
- for (auto it : llvm::enumerate(getSourceVectorType().getShape()))
- if (!llvm::any_of(getReductionDims().getValue(), [&](Attribute attr) {
- return llvm::cast<IntegerAttr>(attr).getValue() == it.index();
- })) {
- targetShape.push_back(it.value());
- scalableDims.push_back(sourceScalableDims[it.index()]);
+ auto sourceDims = VectorDimList::from(getSourceVectorType());
+ for (auto [idx, dim] : llvm::enumerate(sourceDims))
+ if (!llvm::any_of(getReductionDims().getValue(),
+ [idx = idx](Attribute attr) {
+ return llvm::cast<IntegerAttr>(attr).getValue() == idx;
+ })) {
+ targetDims.push_back(dim);
}
// TODO: update to also allow 0-d vectors when available.
- if (targetShape.empty())
+ if (targetDims.empty())
inferredReturnType = getSourceVectorType().getElementType();
else
- inferredReturnType = VectorType::get(
- targetShape, getSourceVectorType().getElementType(), scalableDims);
+ inferredReturnType = ScalableVectorType::get(
+ targetDims, getSourceVectorType().getElementType());
if (getType() != inferredReturnType)
return emitOpError() << "destination type " << getType()
<< " is incompatible with source type "
@@ -3247,23 +3247,19 @@ ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
if (operandsInfo.size() < 2)
return parser.emitError(parser.getNameLoc(),
"expected at least 2 operands");
- VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
- VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
+ ScalableVectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
+ ScalableVectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
if (!vLHS)
return parser.emitError(parser.getNameLoc(),
"expected vector type for operand #1");
VectorType resType;
if (vRHS) {
- SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
- vRHS.getScalableDims()[0]};
- resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
- vLHS.getElementType(), scalableDimsRes);
+ resType = ScalableVectorType::get({vLHS.getDim(0), vRHS.getDim(0)},
+ vLHS.getElementType());
} else {
// Scalar RHS operand
- SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
- resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
- scalableDimsRes);
+ resType = ScalableVectorType::get(vLHS.getDim(0), vLHS.getElementType());
}
if (!result.attributes.get(OuterProductOp::getKindAttrName(result.name))) {
@@ -5308,26 +5304,11 @@ class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
///
/// vector<4x1x1xi1> --> vector<4x1>
///
-static VectorType trimTrailingOneDims(VectorType oldType) {
- ArrayRef<int64_t> oldShape = oldType.getShape();
- ArrayRef<int64_t> newShape = oldShape;
-
- ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
- ArrayRef<bool> newScalableDims = oldScalableDims;
-
- while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
- newShape = newShape.drop_back(1);
- newScalableDims = newScalableDims.drop_back(1);
- }
-
- // Make sure we have at least 1 dimension.
- // TODO: Add support for 0-D vectors.
- if (newShape.empty()) {
- newShape = oldShape.take_back();
- newScalableDims = oldScalableDims.take_back();
- }
-
- return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
+static ScalableVectorType trimTrailingOneDims(ScalableVectorType oldType) {
+ VectorDimList newDims = oldType.getDims();
+ while (newDims.size() > 1 && newDims.back() == VectorDim::getFixed(1))
+ newDims = newDims.dropBack();
+ return ScalableVectorType::get(newDims, oldType.getElementType());
}
/// Folds qualifying shape_cast(create_mask) into a new create_mask
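For readers skimming the patch, here is a minimal standalone sketch of what the new `trimTrailingOneDims` loop does, written without MLIR. `Dim` and `trimTrailingFixedOnes` are hypothetical names standing in for `VectorDim` and the real helper; they are not part of the PR. Equality compares both size and scalability, so a trailing scalable `[1]` is never dropped. The `size() > 1` guard keeps at least one dimension.

```c++
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for VectorDim: a size plus a scalability flag.
struct Dim {
  int64_t size;
  bool isScalable;
  static Dim fixed(int64_t n) { return {n, false}; }
  static Dim scalable(int64_t n) { return {n, true}; }
  // Equality compares both fields, so scalable [1] != fixed 1.
  bool operator==(const Dim &other) const {
    return size == other.size && isScalable == other.isScalable;
  }
  bool operator!=(const Dim &other) const { return !(*this == other); }
};

// Drops trailing fixed unit dims but never the last remaining dim,
// mirroring the `newDims.size() > 1` loop guard in trimTrailingOneDims.
std::vector<Dim> trimTrailingFixedOnes(std::vector<Dim> dims) {
  assert(!dims.empty() && "expected at least one dim");
  while (dims.size() > 1 && dims.back() == Dim::fixed(1))
    dims.pop_back();
  return dims;
}
```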
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index c31c51489ecc9..b4dd274914edb 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Interfaces/VectorInterfaces.h"
+#include "mlir/Support/ScalableVectorType.h"
using namespace mlir;
using namespace mlir::vector;
@@ -122,13 +123,11 @@ struct TransferReadPermutationLowering
permutationMap = inversePermutation(permutationMap);
AffineMap newMap = permutationMap.compose(map);
// Apply the reverse transpose to deduce the type of the transfer_read.
- ArrayRef<int64_t> originalShape = op.getVectorType().getShape();
- SmallVector<int64_t> newVectorShape(originalShape.size());
- ArrayRef<bool> originalScalableDims = op.getVectorType().getScalableDims();
- SmallVector<bool> newScalableDims(originalShape.size());
- for (const auto &pos : llvm::enumerate(permutation)) {
- newVectorShape[pos.value()] = originalShape[pos.index()];
- newScalableDims[pos.value()] = originalScalableDims[pos.index()];
+ auto originalDims = VectorDimList::from(op.getVectorType());
+ SmallVector<VectorDim> newDims(op.getVectorType().getRank(),
+ VectorDim::getFixed(0));
+ for (auto [originalIdx, newIdx] : llvm::enumerate(permutation)) {
+ newDims[newIdx] = originalDims[originalIdx];
}
// Transpose in_bounds attribute.
@@ -138,8 +137,8 @@ struct TransferReadPermutationLowering
: ArrayAttr();
// Generate new transfer_read operation.
- VectorType newReadType = VectorType::get(
- newVectorShape, op.getVectorType().getElementType(), newScalableDims);
+ VectorType newReadType =
+ ScalableVectorType::get(newDims, op.getVectorType().getElementType());
Value newRead = rewriter.create<vector::TransferReadOp>(
op.getLoc(), newReadType, op.getSource(), op.getIndices(),
AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
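The permutation loop in `TransferReadPermutationLowering` can be sketched in plain C++ as follows. `Dim` and `scatterDims` are hypothetical names and are not the PR's API. Each source dim is copied to position `permutation[i]` as a whole value, so a size cannot drift apart from its scalable flag. The old pair of parallel `shape`/`scalableDims` arrays could allow that.

```c++
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for VectorDim (size + scalability).
struct Dim {
  int64_t size;
  bool isScalable;
  bool operator==(const Dim &other) const {
    return size == other.size && isScalable == other.isScalable;
  }
};

// Places dims[i] at position permutation[i]. Whole Dim values are copied,
// so the scalable flag always travels with its size.
std::vector<Dim> scatterDims(const std::vector<Dim> &dims,
                             const std::vector<int64_t> &permutation) {
  assert(dims.size() == permutation.size() && "rank mismatch");
  std::vector<Dim> result(dims.size(), Dim{0, false});
  for (std::size_t i = 0; i < permutation.size(); ++i)
    result[static_cast<std::size_t>(permutation[i])] = dims[i];
  return result;
}
```

For example, `[4]x8` with permutation `{1, 0}` becomes `8x[4]`.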
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index ca8a6f6d82a6e..fa259d7bf1449 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -32,6 +32,7 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "mlir/Support/LogicalResult.h"
+#include "mlir/Support/ScalableVectorType.h"
#define DEBUG_TYPE "lower-vector-transpose"
@@ -432,18 +433,17 @@ class Transpose2DWithUnitDimToShapeCast
LogicalResult matchAndRewrite(vector::TransposeOp op,
PatternRewriter &rewriter) const override {
Value input = op.getVector();
- VectorType resType = op.getResultVectorType();
+ ScalableVectorType resType = op.getResultVectorType();
// Set up convenience transposition table.
ArrayRef<int64_t> transp = op.getPermutation();
if (resType.getRank() == 2 &&
- ((resType.getShape().front() == 1 &&
- !resType.getScalableDims().front()) ||
- (resType.getShape().back() == 1 &&
- !resType.getScalableDims().back())) &&
+ (resType.getDims().front() == VectorDim::getFixed(1) ||
+ resType.getDims().back() == VectorDim::getFixed(1)) &&
transp == ArrayRef<int64_t>({1, 0})) {
- rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
+ rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, Type(resType),
+ input);
return success();
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 7ed3dea42b771..ca61fe81a3d16 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Support/ScalableVectorType.h"
#define DEBUG_TYPE "vector-drop-unit-dim"
@@ -24,25 +25,15 @@ using namespace mlir::vector;
// Trims leading one dimensions from `oldType` and returns the result type.
// Returns `vector<1xT>` if `oldType` only has one element.
-static VectorType trimLeadingOneDims(VectorType oldType) {
- ArrayRef<int64_t> oldShape = oldType.getShape();
- ArrayRef<int64_t> newShape = oldShape;
-
- ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
- ArrayRef<bool> newScalableDims = oldScalableDims;
-
- while (!newShape.empty() && newShape.front() == 1 &&
- !newScalableDims.front()) {
- newShape = newShape.drop_front(1);
- newScalableDims = newScalableDims.drop_front(1);
- }
+static ScalableVectorType trimLeadingOneDims(ScalableVectorType oldType) {
+ VectorDimList oldDims = oldType.getDims();
+ VectorDimList newDims = oldDims.dropWhile(
+ [](VectorDim dim) { return dim == VectorDim::getFixed(1); });
// Make sure we have at least 1 dimension per vector type requirements.
- if (newShape.empty()) {
- newShape = oldShape.take_back();
- newScalableDims = oldType.getScalableDims().take_back();
- }
- return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
+ if (newDims.empty())
+ newDims = oldDims.takeBack(1);
+ return ScalableVectorType::get(newDims, oldType.getElementType());
}
/// Return a smallVector of size `rank` containing all zeros.
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index c131fde517f80..18b6882ac37f2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -23,6 +23,7 @@
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Support/ScalableVectorType.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"
@@ -312,31 +313,27 @@ static int getReducedRank(ArrayRef<int64_t> shape) {
/// Trims non-scalable one dimensions from `oldType` and returns the result
/// type.
-static VectorType trimNonScalableUnitDims(VectorType oldType) {
- SmallVector<int64_t> newShape;
- SmallVector<bool> newScalableDims;
- for (auto [dimIdx, dimSize] : llvm::enumerate(oldType.getShape())) {
- if (dimSize == 1 && !oldType.getScalableDims()[dimIdx])
- continue;
- newShape.push_back(dimSize);
- newScalableDims.push_back(oldType.getScalableDims()[dimIdx]);
- }
- return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
+static ScalableVectorType trimNonScalableUnitDims(ScalableVectorType oldType) {
+ auto newDims = llvm::to_vector(
+ llvm::make_filter_range(oldType.getDims(), [](VectorDim dim) {
+ return dim != VectorDim::getFixed(1);
+ }));
+ return ScalableVectorType::get(newDims, oldType.getElementType());
}
// Rewrites vector.create_mask 'op' to drop non-scalable one dimensions.
static FailureOr<Value>
createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc,
vector::CreateMaskOp op) {
- auto type = op.getType();
+ ScalableVectorType type = op.getType();
VectorType reducedType = trimNonScalableUnitDims(type);
if (reducedType.getRank() == type.getRank())
return failure();
SmallVector<Value> reducedOperands;
- for (auto [dim, dimIsScalable, operand] : llvm::zip_equal(
- type.getShape(), type.getScalableDims(), op.getOperands())) {
- if (dim == 1 && !dimIsScalable) {
+ for (auto [dim, operand] :
+ llvm::zip_equal(type.getDims(), op.getOperands())) {
+ if (dim == VectorDim::getFixed(1)) {
// If the mask for the unit dim is not a constant of 1, do nothing.
auto constant = operand.getDefiningOp<arith::ConstantIndexOp>();
if (!constant || (constant.value() != 1))
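The filter in `trimNonScalableUnitDims` can also be sketched in plain C++. This version uses `std::copy_if` in place of `llvm::make_filter_range`, and `Dim` and `dropFixedUnitDims` are hypothetical names. Only fixed unit dims are removed. Scalable `[1]` dims are kept because they do not compare equal to a fixed 1.

```c++
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iterator>
#include <vector>

// Hypothetical stand-in for VectorDim (size + scalability).
struct Dim {
  int64_t size;
  bool isScalable;
  bool operator==(const Dim &other) const {
    return size == other.size && isScalable == other.isScalable;
  }
};

// Keeps every dim except fixed (non-scalable) unit dims.
std::vector<Dim> dropFixedUnitDims(const std::vector<Dim> &dims) {
  std::vector<Dim> result;
  std::copy_if(dims.begin(), dims.end(), std::back_inserter(result),
               [](const Dim &d) { return !(d == Dim{1, false}); });
  return result;
}
```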
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index ea4a02f2f2e77..2ef0659d83326 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -39,6 +39,7 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "mlir/Support/LogicalResult.h"
+#include "mlir/Support/ScalableVectorType.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
@@ -1231,16 +1232,13 @@ struct FoldI1Select : public OpRewritePattern<arith::SelectOp> {
/// Example 2: it returns "1" if `srcType` is the same memref type with
/// [8192, 16, 8, 1] strides.
static FailureOr<size_t>
-getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
+getTransferFoldableInnerUnitDims(MemRefType srcType,
+ ScalableVectorType vectorType) {
SmallVector<int64_t> srcStrides;
int64_t srcOffset;
if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
return failure();
- auto isUnitDim = [](VectorType type, int dim) {
- return type.getDimSize(dim) == 1 && !type.getScalableDims()[dim];
- };
-
// According to vector.transfer_read/write semantics, the vector can be a
// slice. Thus, we have to offset the check index with `rankDiff` in
// `srcStrides` and source dim sizes.
@@ -1251,7 +1249,8 @@ getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
// It can be folded only if they are 1 and the stride is 1.
int dim = vectorType.getRank() - i - 1;
if (srcStrides[dim + rankDiff] != 1 ||
- srcType.getDimSize(dim + rankDiff) != 1 || !isUnitDim(vectorType, dim))
+ srcType.getDimSize(dim + rankDiff) != 1 ||
+ vectorType.getDim(dim) != VectorDim::getFixed(1))
break;
result++;
}
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 2c43a6f15aa83..ef1b1812de7c7 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -28,6 +28,7 @@
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Verifier.h"
+#include "mlir/Support/ScalableVectorType.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
@@ -2607,17 +2608,10 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
}
})
.Case<VectorType>([&](VectorType vectorTy) {
- auto scalableDims = vectorTy.getScalableDims();
os << "vector<";
- auto vShape = vectorTy.getShape();
- unsigned lastDim = vShape.size();
- unsigned dimIdx = 0;
- for (dimIdx = 0; dimIdx < lastDim; dimIdx++) {
- if (!scalableDims.empty() && scalableDims[dimIdx])
- os << '[';
- os << vShape[dimIdx];
- if (!scalableDims.empty() && scalableDims[dimIdx])
- os << ']';
+ auto dims = VectorDimList::from(vectorTy);
+ if (!dims.empty()) {
+ llvm::interleave(dims, os, "x");
os << 'x';
}
printType(vectorTy.getElementType());
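The new `AsmPrinter` code relies on a stream operator for `VectorDim`. Below is a standalone sketch of the equivalent printing logic. It assumes, as the removed code did, that scalable dims are wrapped in brackets. `Dim` and `printVectorType` are hypothetical names, and the code does not use MLIR.

```c++
#include <cassert>
#include <cstdint>
#include <ostream>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical stand-in for VectorDim.
struct Dim {
  int64_t size;
  bool isScalable;
};

// Scalable dims print in brackets and fixed dims print as a bare number.
std::ostream &operator<<(std::ostream &os, const Dim &d) {
  if (d.isScalable)
    return os << '[' << d.size << ']';
  return os << d.size;
}

// Writing 'x' after every dim produces the same text as interleaving the
// dims with "x" and then emitting one more 'x' when there is at least one dim.
std::string printVectorType(const std::vector<Dim> &dims,
                            const std::string &elementType) {
  std::ostringstream os;
  os << "vector<";
  for (const Dim &d : dims)
    os << d << 'x';
  os << elementType << '>';
  return os.str();
}
```

With this sketch, `printVectorType({{4, true}, {8, false}}, "f32")` produces `vector<[4]x8xf32>`.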