[Mlir-commits] [mlir] [mlir] Add first-class support for scalability in VectorType dims (PR #74251)

Benjamin Maxwell llvmlistbot at llvm.org
Thu Dec 7 07:06:53 PST 2023


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/74251

>From ca8250c41acf806f045562468f1da02e21c004b5 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Sun, 3 Dec 2023 20:01:42 +0000
Subject: [PATCH 1/2] [mlir] Add first-class support for scalability in
 VectorType dims

Currently, the shape of a VectorType is stored in two separate lists:
the 'shape', which comes from ShapedType and has no way to represent
scalability, and the 'scalableDims', an additional list of bools
attached to VectorType. This is somewhat cumbersome to work with, and
makes it easy to ignore the scalability of a dim, producing incorrect
results.

For example, to correctly trim leading unit dims of a VectorType,
currently, you need to do something like:

```c++
  while (!newShape.empty() && newShape.front() == 1 &&
         !newScalableDims.front()) {
    newShape = newShape.drop_front(1);
    newScalableDims = newScalableDims.drop_front(1);
  }
```

This would be wrong if you (more naturally) wrote it as:

```c++
  auto newShape = vectorType.getShape().drop_while([](int64_t dim) {
    return dim == 1;
  });
```

That version would also trim scalable one dims (`[1]`), which, unlike
their fixed counterparts, are not unit dims.
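
To make the failure mode concrete, here is a standalone sketch (plain
C++, no MLIR dependencies) that models the two-list storage. The names
`Shape`, `dropFront`, `trimNaive`, and `trimScalableAware` are
hypothetical and exist only for this illustration; they are not part of
this patch.

```c++
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the current two-list storage of a VectorType.
struct Shape {
  std::vector<int64_t> sizes;
  std::vector<bool> scalable;
};

// Returns a copy of `shape` without its first `n` dims.
inline Shape dropFront(const Shape &shape, std::size_t n) {
  assert(shape.sizes.size() == shape.scalable.size());
  assert(n <= shape.sizes.size());
  Shape result;
  result.sizes.assign(shape.sizes.begin() + n, shape.sizes.end());
  result.scalable.assign(shape.scalable.begin() + n, shape.scalable.end());
  return result;
}

// Naive trim: looks only at the sizes, so it also drops `[1]` dims.
inline Shape trimNaive(const Shape &shape) {
  std::size_t n = 0;
  while (n < shape.sizes.size() && shape.sizes[n] == 1)
    ++n;
  return dropFront(shape, n);
}

// Scalable-aware trim: only a fixed dim of size 1 counts as a unit dim.
inline Shape trimScalableAware(const Shape &shape) {
  std::size_t n = 0;
  while (n < shape.sizes.size() && shape.sizes[n] == 1 && !shape.scalable[n])
    ++n;
  return dropFront(shape, n);
}
```

For a shape modelling `1x[1]x4`, `trimNaive` returns the equivalent of
`4`, while `trimScalableAware` keeps `[1]x4`.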

This patch does not change the storage of the VectorType, but instead
adds new scalability-safe accessors and iterators.

Two new methods are added to VectorType:

```
/// Returns the value of the specified dimension (including scalability)
VectorDim VectorType::getDim(unsigned idx)

/// Returns the dimensions of this vector type (including scalability)
VectorDims VectorType::getDims()
```

These are backed by two new classes: `VectorDim` and `VectorDims`.

`VectorDim` represents a single dimension of a VectorType. It can be a
fixed or scalable quantity. It cannot be implicitly converted to/from an
integer, so you must specify the kind of quantity you expect in
comparisons.
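
As a rough illustration of why the lack of implicit conversions helps,
the following standalone sketch models a simplified dimension type. The
class `Dim` is a hypothetical stand-in written for this example, not
the `VectorDim` class added by the patch.

```c++
#include <cassert>
#include <cstdint>
#include <type_traits>

// Hypothetical, simplified model of a vector dimension.
class Dim {
public:
  constexpr static Dim getFixed(int64_t quantity) {
    return Dim(quantity, false);
  }
  constexpr static Dim getScalable(int64_t quantity) {
    return Dim(quantity, true);
  }

  constexpr bool isScalable() const { return scalable; }
  constexpr int64_t getMinSize() const { return quantity; }

  // Only meaningful for fixed dims; asserts otherwise.
  int64_t getFixedSize() const {
    assert(!scalable && "expected a fixed dim");
    return quantity;
  }

  constexpr bool operator==(const Dim &rhs) const {
    return quantity == rhs.quantity && scalable == rhs.scalable;
  }
  constexpr bool operator!=(const Dim &rhs) const { return !(*this == rhs); }

private:
  explicit constexpr Dim(int64_t quantity, bool scalable)
      : quantity(quantity), scalable(scalable) {}

  int64_t quantity;
  bool scalable;
};

// No implicit conversions in either direction, so `dim == 1` cannot compile.
static_assert(!std::is_convertible<int64_t, Dim>::value,
              "integers must not convert to Dim");
static_assert(!std::is_convertible<Dim, int64_t>::value,
              "Dim must not convert to an integer");

// A scalable `[1]` is a different quantity from a fixed `1`.
static_assert(Dim::getFixed(1) == Dim::getFixed(1), "same kind, same size");
static_assert(Dim::getFixed(1) != Dim::getScalable(1), "kinds differ");
```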

`VectorDims` represents a non-owning list of vector dimensions, backed
by separate size and scalability lists (matching the storage of
VectorType). This class has an iterator and a few common helper methods
(similar to those of ArrayRef).
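
The idea of a non-owning view over two parallel lists can be sketched
in plain C++ as below. `Dim` and `DimsView` are hypothetical names for
this example only. Unlike the class in the patch, this sketch has no
iterator and does not accept an empty scalability list.

```c++
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical value type: a minimum size plus a scalability flag.
struct Dim {
  int64_t minSize;
  bool scalable;
};

inline bool operator==(Dim lhs, Dim rhs) {
  return lhs.minSize == rhs.minSize && lhs.scalable == rhs.scalable;
}

// Hypothetical non-owning view; the caller keeps both arrays alive.
class DimsView {
public:
  DimsView(const int64_t *sizes, const bool *scalable, std::size_t count)
      : sizes(sizes), scalable(scalable), count(count) {}

  std::size_t size() const { return count; }
  bool empty() const { return count == 0; }

  // Elements are materialised on access from the two underlying arrays.
  Dim operator[](std::size_t idx) const {
    assert(idx < count && "dim index out of range");
    return Dim{sizes[idx], scalable[idx]};
  }
  Dim front() const { return (*this)[0]; }
  Dim back() const { return (*this)[count - 1]; }

  // Skip the first `n` dims and keep the following `m` dims.
  DimsView slice(std::size_t n, std::size_t m) const {
    assert(n + m <= count && "slice out of range");
    return DimsView(sizes + n, scalable + n, m);
  }
  DimsView dropFront(std::size_t n = 1) const { return slice(n, count - n); }
  DimsView dropBack(std::size_t n = 1) const { return slice(0, count - n); }

  // Drop leading dims for as long as `predicate` holds.
  template <class PredicateT>
  DimsView dropWhile(PredicateT predicate) const {
    std::size_t n = 0;
    while (n < count && predicate((*this)[n]))
      ++n;
    return dropFront(n);
  }

private:
  const int64_t *sizes;
  const bool *scalable;
  std::size_t count;
};
```

Because every helper slices both lists together, the size and
scalability of a dim cannot get out of sync in this sketch.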

There are also new builders to construct VectorTypes from both
the `VectorDims` class and an `ArrayRef<VectorDim>`.

With these changes the previous example becomes:

```c++
  auto newDims = vectorType.getDims().dropWhile([](VectorDim dim) {
    return dim == VectorDim::getFixed(1);
  });
```

This (to me) is easier to read, and safer, as it is not possible to
forget to check the scalability of the dim: just comparing with `1`
would fail to build.
---
 mlir/include/mlir/IR/BuiltinTypes.h  | 234 +++++++++++++++++++++++++++
 mlir/include/mlir/IR/BuiltinTypes.td |  23 +++
 mlir/unittests/IR/ShapedTypeTest.cpp | 101 ++++++++++++
 3 files changed, 358 insertions(+)

diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 92ce053ad5c829..2039316e6ba250 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -12,6 +12,7 @@
 #include "mlir/IR/BuiltinAttributeInterfaces.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/Support/ADTExtras.h"
+#include "llvm/ADT/STLExtras.h"
 
 namespace llvm {
 class BitVector;
@@ -181,6 +182,239 @@ class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
   operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
 };
 
+//===----------------------------------------------------------------------===//
+// VectorDim
+//===----------------------------------------------------------------------===//
+
+/// This class represents a dimension of a vector type. Unlike other ShapedTypes
+/// vector dimensions can have scalable quantities, which means the dimension
+/// has a known minimum size, which is scaled by a constant that is only
+/// known at runtime.
+class VectorDim {
+public:
+  explicit constexpr VectorDim(int64_t quantity, bool scalable)
+      : quantity(quantity), scalable(scalable){};
+
+  /// Constructs a new fixed dimension.
+  constexpr static VectorDim getFixed(int64_t quantity) {
+    return VectorDim(quantity, false);
+  }
+
+  /// Constructs a new scalable dimension.
+  constexpr static VectorDim getScalable(int64_t quantity) {
+    return VectorDim(quantity, true);
+  }
+
+  /// Returns true if this dimension is scalable.
+  constexpr bool isScalable() const { return scalable; }
+
+  /// Returns true if this dimension is fixed.
+  constexpr bool isFixed() const { return !isScalable(); }
+
+  /// Returns the minimum number of elements this dimension can contain.
+  constexpr int64_t getMinSize() const { return quantity; }
+
+  /// If this dimension is fixed returns the number of elements, otherwise
+  /// aborts.
+  constexpr int64_t getFixedSize() const {
+    assert(isFixed());
+    return quantity;
+  }
+
+  constexpr bool operator==(VectorDim const &dim) const {
+    return quantity == dim.quantity && scalable == dim.scalable;
+  }
+
+  constexpr bool operator!=(VectorDim const &dim) const {
+    return !(*this == dim);
+  }
+
+  /// Print the dim.
+  void print(raw_ostream &os) {
+    if (isScalable())
+      os << '[';
+    os << getMinSize();
+    if (isScalable())
+      os << ']';
+  }
+
+  /// Helper class for indexing into a list of sizes and a (possibly empty)
+  /// list of scalable dimensions, extracting VectorDim elements.
+  struct Indexer {
+    explicit Indexer(ArrayRef<int64_t> sizes, ArrayRef<bool> scalableDims)
+        : sizes(sizes), scalableDims(scalableDims) {
+      assert(
+          scalableDims.empty() ||
+          sizes.size() == scalableDims.size() &&
+              "expected `scalableDims` to be empty or match `sizes` in length");
+    }
+
+    VectorDim operator[](size_t idx) const {
+      int64_t size = sizes[idx];
+      bool scalable = scalableDims.empty() ? false : scalableDims[idx];
+      return VectorDim(size, scalable);
+    }
+
+    ArrayRef<int64_t> sizes;
+    ArrayRef<bool> scalableDims;
+  };
+
+private:
+  int64_t quantity;
+  bool scalable;
+};
+
+inline raw_ostream &operator<<(raw_ostream &os, VectorDim dim) {
+  dim.print(os);
+  return os;
+}
+
+//===----------------------------------------------------------------------===//
+// VectorDims
+//===----------------------------------------------------------------------===//
+
+/// Represents a non-owning list of vector dimensions. The underlying dimension
+/// sizes and scalability flags are stored as two separate lists to match the
+/// storage of a VectorType.
+class VectorDims : public VectorDim::Indexer {
+public:
+  using VectorDim::Indexer::Indexer;
+
+  class Iterator : public llvm::iterator_facade_base<
+                       Iterator, std::random_access_iterator_tag, VectorDim,
+                       std::ptrdiff_t, VectorDim, VectorDim> {
+  public:
+    Iterator(VectorDim::Indexer indexer, size_t index)
+        : indexer(indexer), index(index){};
+
+    // Iterator boilerplate.
+    ptrdiff_t operator-(const Iterator &rhs) const { return index - rhs.index; }
+    bool operator==(const Iterator &rhs) const { return index == rhs.index; }
+    bool operator<(const Iterator &rhs) const { return index < rhs.index; }
+    Iterator &operator+=(ptrdiff_t offset) {
+      index += offset;
+      return *this;
+    }
+    Iterator &operator-=(ptrdiff_t offset) {
+      index -= offset;
+      return *this;
+    }
+    VectorDim operator*() const { return indexer[index]; }
+
+    VectorDim::Indexer getIndexer() const { return indexer; }
+    ptrdiff_t getIndex() const { return index; }
+
+  private:
+    VectorDim::Indexer indexer;
+    ptrdiff_t index;
+  };
+
+  // Generic definitions.
+  using value_type = VectorDim;
+  using iterator = Iterator;
+  using const_iterator = Iterator;
+  using reverse_iterator = std::reverse_iterator<iterator>;
+  using const_reverse_iterator = std::reverse_iterator<const_iterator>;
+  using size_type = size_t;
+  using difference_type = ptrdiff_t;
+
+  /// Construct from iterator pair.
+  VectorDims(Iterator begin, Iterator end)
+      : VectorDims(VectorDims(begin.getIndexer())
+                       .slice(begin.getIndex(), end - begin)) {}
+
+  VectorDims(VectorDim::Indexer indexer) : VectorDim::Indexer(indexer){};
+
+  Iterator begin() const { return Iterator(*this, 0); }
+  Iterator end() const { return Iterator(*this, size()); }
+
+  /// Check if the dims are empty.
+  bool empty() const { return sizes.empty(); }
+
+  /// Get the number of dims.
+  size_t size() const { return sizes.size(); }
+
+  /// Return the first dim.
+  VectorDim front() const { return (*this)[0]; }
+
+  /// Return the last dim.
+  VectorDim back() const { return (*this)[size() - 1]; }
+
+  /// Chop off the first \p n dims, and keep the remaining \p m
+  /// dims.
+  VectorDims slice(size_t n, size_t m) const {
+    ArrayRef<int64_t> newSizes = sizes.slice(n, m);
+    ArrayRef<bool> newScalableDims =
+        scalableDims.empty() ? ArrayRef<bool>{} : scalableDims.slice(n, m);
+    return VectorDims(newSizes, newScalableDims);
+  }
+
+  /// Drop the first \p n dims.
+  VectorDims dropFront(size_t n = 1) const { return slice(n, size() - n); }
+
+  /// Drop the last \p n dims.
+  VectorDims dropBack(size_t n = 1) const { return slice(0, size() - n); }
+
+  /// Return a copy of *this with only the first \p n elements.
+  VectorDims takeFront(size_t n = 1) const {
+    if (n >= size())
+      return *this;
+    return dropBack(size() - n);
+  }
+
+  /// Return a copy of *this with only the last \p n elements.
+  VectorDims takeBack(size_t n = 1) const {
+    if (n >= size())
+      return *this;
+    return dropFront(size() - n);
+  }
+
+  /// Return copy of *this with the first n dims matching the predicate removed.
+  template <class PredicateT>
+  VectorDims dropWhile(PredicateT predicate) const {
+    return VectorDims(llvm::find_if_not(*this, predicate), end());
+  }
+
+  /// Returns true if one or more of the dims are scalable.
+  bool hasScalableDims() const {
+    return llvm::is_contained(getScalableDims(), true);
+  }
+
+  /// Check for dim equality.
+  bool equals(VectorDims rhs) const {
+    if (size() != rhs.size())
+      return false;
+    return std::equal(begin(), end(), rhs.begin());
+  }
+
+  /// Check for dim equality.
+  bool equals(ArrayRef<VectorDim> rhs) const {
+    if (size() != rhs.size())
+      return false;
+    return std::equal(begin(), end(), rhs.begin());
+  }
+
+  /// Return the underlying sizes.
+  ArrayRef<int64_t> getSizes() const { return sizes; }
+
+  /// Return the underlying scalable dims.
+  ArrayRef<bool> getScalableDims() const { return scalableDims; }
+};
+
+inline bool operator==(VectorDims lhs, VectorDims rhs) {
+  return lhs.equals(rhs);
+}
+
+inline bool operator!=(VectorDims lhs, VectorDims rhs) { return !(lhs == rhs); }
+
+inline bool operator==(VectorDims lhs, ArrayRef<VectorDim> rhs) {
+  return lhs.equals(rhs);
+}
+
+inline bool operator!=(VectorDims lhs, ArrayRef<VectorDim> rhs) {
+  return !(lhs == rhs);
+}
+
 } // namespace mlir
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 4cade83dd3c32a..8835074efbc66e 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -1114,6 +1114,18 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector", [ShapedTypeInterface], "Ty
         scalableDims = isScalableVec;
       }
       return $_get(elementType.getContext(), shape, elementType, scalableDims);
+    }]>,
+    TypeBuilderWithInferredContext<(ins "Type":$elementType, "ArrayRef<VectorDim>": $shape), [{
+      SmallVector<int64_t> sizes;
+      SmallVector<bool> scalableDims;
+      for (VectorDim dim : shape) {
+        sizes.push_back(dim.getMinSize());
+        scalableDims.push_back(dim.isScalable());
+      }
+      return get(sizes, elementType, scalableDims);
+    }]>,
+    TypeBuilderWithInferredContext<(ins "Type":$elementType, "VectorDims": $shape), [{
+      return get(shape.getSizes(), elementType, shape.getScalableDims());
     }]>
   ];
   let extraClassDeclaration = [{
@@ -1121,6 +1133,17 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector", [ShapedTypeInterface], "Ty
     /// Arguments that are passed into the builder must outlive the builder.
     class Builder;
 
+    /// Returns the value of the specified dimension (including scalability).
+    VectorDim getDim(unsigned idx) const {
+      assert(idx < getRank() && "invalid dim index for vector type");
+      return getDims()[idx];
+    }
+
+    /// Returns the dimensions of this vector type (including scalability).
+    VectorDims getDims() const {
+      return VectorDims(getShape(), getScalableDims());
+    }
+
     /// Returns true if the given type can be used as an element of a vector
     /// type. In particular, vectors can consist of integer, index, or float
     /// primitives.
diff --git a/mlir/unittests/IR/ShapedTypeTest.cpp b/mlir/unittests/IR/ShapedTypeTest.cpp
index 61264bc523648c..07625da6ee8895 100644
--- a/mlir/unittests/IR/ShapedTypeTest.cpp
+++ b/mlir/unittests/IR/ShapedTypeTest.cpp
@@ -226,4 +226,105 @@ TEST(ShapedTypeTest, RankedTensorTypeBuilder) {
   }
 }
 
+TEST(ShapedTypeTest, VectorDims) {
+  MLIRContext context;
+  Type f32 = FloatType::getF32(&context);
+
+  SmallVector<VectorDim> dims{VectorDim::getFixed(2), VectorDim::getScalable(4),
+                              VectorDim::getFixed(8), VectorDim::getScalable(9),
+                              VectorDim::getFixed(1)};
+  VectorType vectorType = VectorType::get(f32, dims);
+
+  // Directly check values
+  {
+    auto dim0 = vectorType.getDim(0);
+    ASSERT_EQ(dim0.getMinSize(), 2);
+    ASSERT_TRUE(dim0.isFixed());
+
+    auto dim1 = vectorType.getDim(1);
+    ASSERT_EQ(dim1.getMinSize(), 4);
+    ASSERT_TRUE(dim1.isScalable());
+
+    auto dim2 = vectorType.getDim(2);
+    ASSERT_EQ(dim2.getMinSize(), 8);
+    ASSERT_TRUE(dim2.isFixed());
+
+    auto dim3 = vectorType.getDim(3);
+    ASSERT_EQ(dim3.getMinSize(), 9);
+    ASSERT_TRUE(dim3.isScalable());
+
+    auto dim4 = vectorType.getDim(4);
+    ASSERT_EQ(dim4.getMinSize(), 1);
+    ASSERT_TRUE(dim4.isFixed());
+  }
+
+  // Test indexing via getDim(idx)
+  {
+    for (unsigned i = 0; i < dims.size(); i++)
+      ASSERT_EQ(vectorType.getDim(i), dims[i]);
+  }
+
+  // Test using VectorDims::Iterator in for-each loop
+  {
+    unsigned i = 0;
+    for (VectorDim dim : vectorType.getDims())
+      ASSERT_EQ(dim, dims[i++]);
+    ASSERT_EQ(i, vectorType.getRank());
+  }
+
+  // Test using VectorDims::Iterator in LLVM iterator helper
+  {
+    for (auto [dim, expectedDim] :
+         llvm::zip_equal(vectorType.getDims(), dims)) {
+      ASSERT_EQ(dim, expectedDim);
+    }
+  }
+
+  // Test dropFront()
+  {
+    auto vectorDims = vectorType.getDims();
+    auto newDims = vectorDims.dropFront();
+
+    ASSERT_EQ(newDims.size(), vectorDims.size() - 1);
+    for (unsigned i = 0; i < newDims.size(); i++)
+      ASSERT_EQ(newDims[i], vectorDims[i + 1]);
+  }
+
+  // Test dropBack()
+  {
+    auto vectorDims = vectorType.getDims();
+    auto newDims = vectorDims.dropBack();
+
+    ASSERT_EQ(newDims.size(), vectorDims.size() - 1);
+    for (unsigned i = 0; i < newDims.size(); i++)
+      ASSERT_EQ(newDims[i], vectorDims[i]);
+  }
+
+  // Test front()
+  { ASSERT_EQ(vectorType.getDims().front(), VectorDim::getFixed(2)); }
+
+  // Test back()
+  { ASSERT_EQ(vectorType.getDims().back(), VectorDim::getFixed(1)); }
+
+  // Test dropWhile.
+  {
+    SmallVector<VectorDim> dims{
+        VectorDim::getFixed(1), VectorDim::getFixed(1), VectorDim::getFixed(1),
+        VectorDim::getScalable(1), VectorDim::getScalable(4)};
+
+    VectorType vectorTypeWithLeadingUnitDims = VectorType::get(f32, dims);
+    ASSERT_EQ(vectorTypeWithLeadingUnitDims.getDims().size(),
+              unsigned(vectorTypeWithLeadingUnitDims.getRank()));
+
+    // Drop leading unit dims.
+    auto withoutLeadingUnitDims =
+        vectorTypeWithLeadingUnitDims.getDims().dropWhile(
+            [](VectorDim dim) { return dim == VectorDim::getFixed(1); });
+
+    SmallVector<VectorDim> expectedDims{VectorDim::getScalable(1),
+                                        VectorDim::getScalable(4)};
+    ASSERT_EQ(withoutLeadingUnitDims, expectedDims);
+  }
+}
+
 } // namespace

>From f69fad1049c744d38c41bf478d37ad219c05f038 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 7 Dec 2023 13:59:59 +0000
Subject: [PATCH 2/2] Demonstrate using the new APIs in scalable-aware code :)

This is not a complete change, this just updates a few examples found
by grepping for getScalableDims().
---
 .../Conversion/LLVMCommon/TypeConverter.cpp   |  5 +-
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      |  3 +-
 .../Conversion/VectorToSCF/VectorToSCF.cpp    | 12 +--
 mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp  |  3 +-
 .../Linalg/Transforms/Vectorization.cpp       |  8 +-
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 75 +++++++------------
 .../Vector/Transforms/LowerVectorTransfer.cpp | 14 ++--
 .../Transforms/LowerVectorTranspose.cpp       |  6 +-
 .../Transforms/VectorDropLeadUnitDim.cpp      | 23 ++----
 .../Transforms/VectorTransferOpTransforms.cpp | 20 ++---
 .../Vector/Transforms/VectorTransforms.cpp    |  5 +-
 mlir/lib/IR/AsmPrinter.cpp                    | 13 +---
 mlir/lib/IR/BuiltinTypes.cpp                  |  4 +-
 13 files changed, 68 insertions(+), 123 deletions(-)

diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 3a01795ce3f53e..bfbe51d4e4e329 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -490,12 +490,11 @@ FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
     return {};
   if (type.getShape().empty())
     return VectorType::get({1}, elementType);
-  Type vectorType = VectorType::get(type.getShape().back(), elementType,
-                                    type.getScalableDims().back());
+  Type vectorType = VectorType::get(elementType, type.getDims().takeBack());
   assert(LLVM::isCompatibleVectorType(vectorType) &&
          "expected vector type compatible with the LLVM dialect");
   // Only the trailing dimension can be scalable.
-  if (llvm::is_contained(type.getScalableDims().drop_back(), true))
+  if (type.getDims().dropBack().hasScalableDims())
     return failure();
   auto shape = type.getShape();
   for (int i = shape.size() - 2; i >= 0; --i)
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index cd5df0be740b9c..881bddaf228f93 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -37,8 +37,7 @@ using namespace mlir::vector;
 // Helper to reduce vector type by *all* but one rank at back.
 static VectorType reducedVectorTypeBack(VectorType tp) {
   assert((tp.getRank() > 1) && "unlowerable vector type");
-  return VectorType::get(tp.getShape().take_back(), tp.getElementType(),
-                         tp.getScalableDims().take_back());
+  return VectorType::get(tp.getElementType(), tp.getDims().takeBack());
 }
 
 // Helper that picks the proper sequence for inserting.
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 2ee314e9fedfe3..4869cf304e75b1 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -319,12 +319,13 @@ static FailureOr<MemRefType> unpackOneDim(MemRefType type) {
   auto vectorType = dyn_cast<VectorType>(type.getElementType());
   // Vectors with leading scalable dims are not supported.
   // It may be possible to support these in future by using dynamic memref dims.
-  if (vectorType.getScalableDims().front())
+  VectorDim leadingDim = vectorType.getDims().front();
+  if (leadingDim.isScalable())
     return failure();
   auto memrefShape = type.getShape();
   SmallVector<int64_t, 8> newMemrefShape;
   newMemrefShape.append(memrefShape.begin(), memrefShape.end());
-  newMemrefShape.push_back(vectorType.getDimSize(0));
+  newMemrefShape.push_back(leadingDim.getFixedSize());
   return MemRefType::get(newMemrefShape,
                          VectorType::Builder(vectorType).dropDim(0));
 }
@@ -1091,18 +1092,17 @@ struct UnrollTransferReadConversion
     auto vecType = dyn_cast<VectorType>(vec.getType());
     auto xferVecType = xferOp.getVectorType();
 
-    if (xferVecType.getScalableDims()[0]) {
+    VectorDim dim = xferVecType.getDim(0);
+    if (dim.isScalable()) {
       // Cannot unroll a scalable dimension at compile time.
       return failure();
     }
 
     VectorType newXferVecType = VectorType::Builder(xferVecType).dropDim(0);
 
-    int64_t dimSize = xferVecType.getShape()[0];
-
     // Generate fully unrolled loop of transfer ops.
     Location loc = xferOp.getLoc();
-    for (int64_t i = 0; i < dimSize; ++i) {
+    for (int64_t i = 0; i < dim.getFixedSize(); ++i) {
       Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);
 
       vec = generateInBoundsCheck(
diff --git a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
index 594c9b4c270f21..fa49a21eafa14e 100644
--- a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
+++ b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
@@ -30,8 +30,7 @@ using namespace mlir::arm_sve;
 static Type getI1SameShape(Type type) {
   auto i1Type = IntegerType::get(type.getContext(), 1);
   if (auto sVectorType = llvm::dyn_cast<VectorType>(type))
-    return VectorType::get(sVectorType.getShape(), i1Type,
-                           sVectorType.getScalableDims());
+    return VectorType::get(i1Type, sVectorType.getDims());
   return nullptr;
 }
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index c21d007c931b9b..08a476c37b3f3e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1217,9 +1217,8 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
     assert(vecOperand && "Vector operand couldn't be found");
 
     if (firstMaxRankedType) {
-      auto vecType = VectorType::get(firstMaxRankedType.getShape(),
-                                     getElementTypeOrSelf(vecOperand.getType()),
-                                     firstMaxRankedType.getScalableDims());
+      auto vecType = VectorType::get(getElementTypeOrSelf(vecOperand.getType()),
+                                     firstMaxRankedType.getDims());
       vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
     } else {
       vecOperands.push_back(vecOperand);
@@ -1230,8 +1229,7 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
   for (Type resultType : op->getResultTypes()) {
     resultTypes.push_back(
         firstMaxRankedType
-            ? VectorType::get(firstMaxRankedType.getShape(), resultType,
-                              firstMaxRankedType.getScalableDims())
+            ? VectorType::get(resultType, firstMaxRankedType.getDims())
             : resultType);
   }
   //   d. Build and return the new op.
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index c462b23e1133fc..7c1af857800dc3 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -422,23 +422,21 @@ MultiDimReductionOp::getShapeForUnroll() {
 }
 
 LogicalResult MultiDimReductionOp::verify() {
-  SmallVector<int64_t> targetShape;
-  SmallVector<bool> scalableDims;
+  SmallVector<VectorDim> targetDims;
   Type inferredReturnType;
-  auto sourceScalableDims = getSourceVectorType().getScalableDims();
-  for (auto it : llvm::enumerate(getSourceVectorType().getShape()))
-    if (!llvm::any_of(getReductionDims().getValue(), [&](Attribute attr) {
-          return llvm::cast<IntegerAttr>(attr).getValue() == it.index();
-        })) {
-      targetShape.push_back(it.value());
-      scalableDims.push_back(sourceScalableDims[it.index()]);
+  for (auto [idx, dim] : llvm::enumerate(getSourceVectorType().getDims()))
+    if (!llvm::any_of(getReductionDims().getValue(),
+                      [idx = idx](Attribute attr) {
+                        return llvm::cast<IntegerAttr>(attr).getValue() == idx;
+                      })) {
+      targetDims.push_back(dim);
     }
   // TODO: update to also allow 0-d vectors when available.
-  if (targetShape.empty())
+  if (targetDims.empty())
     inferredReturnType = getSourceVectorType().getElementType();
   else
-    inferredReturnType = VectorType::get(
-        targetShape, getSourceVectorType().getElementType(), scalableDims);
+    inferredReturnType =
+        VectorType::get(getSourceVectorType().getElementType(), targetDims);
   if (getType() != inferredReturnType)
     return emitOpError() << "destination type " << getType()
                          << " is incompatible with source type "
@@ -450,9 +448,8 @@ LogicalResult MultiDimReductionOp::verify() {
 /// Returns the mask type expected by this operation.
 Type MultiDimReductionOp::getExpectedMaskType() {
   auto vecType = getSourceVectorType();
-  return VectorType::get(vecType.getShape(),
-                         IntegerType::get(vecType.getContext(), /*width=*/1),
-                         vecType.getScalableDims());
+  return VectorType::get(IntegerType::get(vecType.getContext(), /*width=*/1),
+                         vecType.getDims());
 }
 
 namespace {
@@ -491,8 +488,7 @@ struct ElideUnitDimsInMultiDimReduction
     if (auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
       if (mask) {
         VectorType newMaskType =
-            VectorType::get(dstVecType.getShape(), rewriter.getI1Type(),
-                            dstVecType.getScalableDims());
+            VectorType::get(rewriter.getI1Type(), dstVecType.getDims());
         mask = rewriter.create<vector::ShapeCastOp>(loc, newMaskType, mask);
       }
       cast = rewriter.create<vector::ShapeCastOp>(
@@ -559,9 +555,8 @@ LogicalResult ReductionOp::verify() {
 /// Returns the mask type expected by this operation.
 Type ReductionOp::getExpectedMaskType() {
   auto vecType = getSourceVectorType();
-  return VectorType::get(vecType.getShape(),
-                         IntegerType::get(vecType.getContext(), /*width=*/1),
-                         vecType.getScalableDims());
+  return VectorType::get(IntegerType::get(vecType.getContext(), /*width=*/1),
+                         vecType.getDims());
 }
 
 Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
@@ -1252,8 +1247,7 @@ ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
     auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
                               vectorType.getRank());
     inferredReturnTypes.push_back(VectorType::get(
-        vectorType.getShape().drop_front(n), vectorType.getElementType(),
-        vectorType.getScalableDims().drop_front(n)));
+        vectorType.getElementType(), vectorType.getDims().dropFront(n)));
   }
   return success();
 }
@@ -3040,15 +3034,11 @@ ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
 
   VectorType resType;
   if (vRHS) {
-    SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
-                                      vRHS.getScalableDims()[0]};
-    resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
-                              vLHS.getElementType(), scalableDimsRes);
+    resType = VectorType::get(vLHS.getElementType(),
+                              {vLHS.getDim(0), vRHS.getDim(0)});
   } else {
     // Scalar RHS operand
-    SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
-    resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
-                              scalableDimsRes);
+    resType = VectorType::get(vLHS.getElementType(), {vLHS.getDim(0)});
   }
 
   if (!result.attributes.get(OuterProductOp::getKindAttrName(result.name))) {
@@ -3115,9 +3105,8 @@ LogicalResult OuterProductOp::verify() {
 /// verification purposes. It requires the operation to be vectorized."
 Type OuterProductOp::getExpectedMaskType() {
   auto vecType = this->getResultVectorType();
-  return VectorType::get(vecType.getShape(),
-                         IntegerType::get(vecType.getContext(), /*width=*/1),
-                         vecType.getScalableDims());
+  return VectorType::get(IntegerType::get(vecType.getContext(), /*width=*/1),
+                         vecType.getDims());
 }
 
 //===----------------------------------------------------------------------===//
@@ -5064,25 +5053,13 @@ class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
 ///   vector<4x1x1xi1> --> vector<4x1>
 ///
 static VectorType trimTrailingOneDims(VectorType oldType) {
-  ArrayRef<int64_t> oldShape = oldType.getShape();
-  ArrayRef<int64_t> newShape = oldShape;
-
-  ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
-  ArrayRef<bool> newScalableDims = oldScalableDims;
-
-  while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
-    newShape = newShape.drop_back(1);
-    newScalableDims = newScalableDims.drop_back(1);
-  }
-
-  // Make sure we have at least 1 dimension.
+  // Note: This will always keep at least one dim (even if it's a unit dim).
   // TODO: Add support for 0-D vectors.
-  if (newShape.empty()) {
-    newShape = oldShape.take_back();
-    newScalableDims = oldScalableDims.take_back();
-  }
+  VectorDims newDims = oldType.getDims();
+  while (newDims.size() > 1 && newDims.back() == VectorDim::getFixed(1))
+    newDims = newDims.dropBack();
 
-  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
+  return VectorType::get(oldType.getElementType(), newDims);
 }
 
 /// Folds qualifying shape_cast(create_mask) into a new create_mask
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index 4a5e8fcfb6edaf..941c00fd96b561 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -113,13 +113,11 @@ struct TransferReadPermutationLowering
     permutationMap = inversePermutation(permutationMap);
     AffineMap newMap = permutationMap.compose(map);
     // Apply the reverse transpose to deduce the type of the transfer_read.
-    ArrayRef<int64_t> originalShape = op.getVectorType().getShape();
-    SmallVector<int64_t> newVectorShape(originalShape.size());
-    ArrayRef<bool> originalScalableDims = op.getVectorType().getScalableDims();
-    SmallVector<bool> newScalableDims(originalShape.size());
+    auto originalDims = op.getVectorType().getDims();
+    SmallVector<VectorDim> newVectorDims(op.getVectorType().getRank(),
+                                         VectorDim::getFixed(0));
     for (const auto &pos : llvm::enumerate(permutation)) {
-      newVectorShape[pos.value()] = originalShape[pos.index()];
-      newScalableDims[pos.value()] = originalScalableDims[pos.index()];
+      newVectorDims[pos.value()] = originalDims[pos.index()];
     }
 
     // Transpose in_bounds attribute.
@@ -129,8 +127,8 @@ struct TransferReadPermutationLowering
                          : ArrayAttr();
 
     // Generate new transfer_read operation.
-    VectorType newReadType = VectorType::get(
-        newVectorShape, op.getVectorType().getElementType(), newScalableDims);
+    VectorType newReadType =
+        VectorType::get(op.getVectorType().getElementType(), newVectorDims);
     Value newRead = rewriter.create<vector::TransferReadOp>(
         op.getLoc(), newReadType, op.getSource(), op.getIndices(),
         AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index 4d43a76c4a4efc..0ede6239a9495b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -344,10 +344,8 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
       // Source with leading unit dim (inverse) is also replaced. Unit dim must
       // be fixed. Non-unit can be scalable.
       if (resType.getRank() == 2 &&
-          ((resType.getShape().front() == 1 &&
-            !resType.getScalableDims().front()) ||
-           (resType.getShape().back() == 1 &&
-            !resType.getScalableDims().back())) &&
+          (resType.getDims().front() == VectorDim::getFixed(1) ||
+           resType.getDims().back() == VectorDim::getFixed(1)) &&
           transp == ArrayRef<int64_t>({1, 0})) {
         rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
         return success();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 84294e4552a607..5f8a9c9e9915a8 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -25,24 +25,15 @@ using namespace mlir::vector;
 // Trims leading one dimensions from `oldType` and returns the result type.
 // Returns `vector<1xT>` if `oldType` only has one element.
 static VectorType trimLeadingOneDims(VectorType oldType) {
-  ArrayRef<int64_t> oldShape = oldType.getShape();
-  ArrayRef<int64_t> newShape = oldShape;
-
-  ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
-  ArrayRef<bool> newScalableDims = oldScalableDims;
-
-  while (!newShape.empty() && newShape.front() == 1 &&
-         !newScalableDims.front()) {
-    newShape = newShape.drop_front(1);
-    newScalableDims = newScalableDims.drop_front(1);
-  }
+  VectorDims oldDims = oldType.getDims();
+  VectorDims newDims = oldDims.dropWhile(
+      [](VectorDim dim) { return dim == VectorDim::getFixed(1); });
 
   // Make sure we have at least 1 dimension per vector type requirements.
-  if (newShape.empty()) {
-    newShape = oldShape.take_back();
-    newScalableDims = oldType.getScalableDims().take_back();
-  }
-  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
+  if (newDims.empty())
+    newDims = oldDims.takeBack(1);
+
+  return VectorType::get(oldType.getElementType(), newDims);
 }
 
 /// Return a smallVector of size `rank` containing all zeros.
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index ed42e6508b4310..8eec25dbc04bad 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -316,15 +316,11 @@ static int getReducedRank(ArrayRef<int64_t> shape) {
 /// Trims non-scalable one dimensions from `oldType` and returns the result
 /// type.
 static VectorType trimNonScalableUnitDims(VectorType oldType) {
-  SmallVector<int64_t> newShape;
-  SmallVector<bool> newScalableDims;
-  for (auto [dimIdx, dimSize] : llvm::enumerate(oldType.getShape())) {
-    if (dimSize == 1 && !oldType.getScalableDims()[dimIdx])
-      continue;
-    newShape.push_back(dimSize);
-    newScalableDims.push_back(oldType.getScalableDims()[dimIdx]);
-  }
-  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
+  auto newDims = llvm::to_vector(
+      llvm::make_filter_range(oldType.getDims(), [](VectorDim dim) {
+        return dim != VectorDim::getFixed(1);
+      }));
+  return VectorType::get(oldType.getElementType(), newDims);
 }
 
 // Rewrites vector.create_mask 'op' to drop non-scalable one dimensions.
@@ -337,9 +333,9 @@ createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc,
     return failure();
 
   SmallVector<Value> reducedOperands;
-  for (auto [dim, dimIsScalable, operand] : llvm::zip_equal(
-           type.getShape(), type.getScalableDims(), op.getOperands())) {
-    if (dim == 1 && !dimIsScalable) {
+  for (auto [dim, operand] :
+       llvm::zip_equal(type.getDims(), op.getOperands())) {
+    if (dim == VectorDim::getFixed(1)) {
       // If the mask for the unit dim is not a constant of 1, do nothing.
       auto constant = operand.getDefiningOp<arith::ConstantIndexOp>();
       if (!constant || (constant.value() != 1))
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 6e7fab293d3a1c..b5bc167c1c53ac 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1039,10 +1039,7 @@ struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
         vector::createOrFoldDimOp(rewriter, loc, xferOp.getSource(), lastIndex);
     Value b = rewriter.create<arith::SubIOp>(loc, dim.getType(), dim, off);
     Value mask = rewriter.create<vector::CreateMaskOp>(
-        loc,
-        VectorType::get(vtp.getShape(), rewriter.getI1Type(),
-                        vtp.getScalableDims()),
-        b);
+        loc, VectorType::get(rewriter.getI1Type(), vtp.getDims()), b);
     if (xferOp.getMask()) {
       // Intersect the in-bounds with the mask specified as an op parameter.
       mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.getMask());
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 4b76dcf7f8a9f7..6eb187b6101a37 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2558,17 +2558,10 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
         }
       })
       .Case<VectorType>([&](VectorType vectorTy) {
-        auto scalableDims = vectorTy.getScalableDims();
         os << "vector<";
-        auto vShape = vectorTy.getShape();
-        unsigned lastDim = vShape.size();
-        unsigned dimIdx = 0;
-        for (dimIdx = 0; dimIdx < lastDim; dimIdx++) {
-          if (!scalableDims.empty() && scalableDims[dimIdx])
-            os << '[';
-          os << vShape[dimIdx];
-          if (!scalableDims.empty() && scalableDims[dimIdx])
-            os << ']';
+        auto dims = vectorTy.getDims();
+        if (!dims.empty()) {
+          llvm::interleave(dims, os, "x");
           os << 'x';
         }
         printType(vectorTy.getElementType());
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 9b8ee3d4528035..eac79ea34d655c 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -250,10 +250,10 @@ VectorType VectorType::scaleElementBitwidth(unsigned scale) {
     return VectorType();
   if (auto et = llvm::dyn_cast<IntegerType>(getElementType()))
     if (auto scaledEt = et.scaleElementBitwidth(scale))
-      return VectorType::get(getShape(), scaledEt, getScalableDims());
+      return VectorType::get(scaledEt, getDims());
   if (auto et = llvm::dyn_cast<FloatType>(getElementType()))
     if (auto scaledEt = et.scaleElementBitwidth(scale))
-      return VectorType::get(getShape(), scaledEt, getScalableDims());
+      return VectorType::get(scaledEt, getDims());
   return VectorType();
 }
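
For readers skimming the hunks above, the recurring idea is that a fixed dim of 1 and a scalable dim `[1]` compare unequal, so a single equality check replaces the paired shape/scalable-flag checks. Below is a minimal standalone sketch of that idea. It is not the MLIR implementation: `Dim` is a hypothetical stand-in for the patch's `VectorDim`, and `std::vector` stands in for `VectorDims`; only the names `getFixed` and `trimLeadingOneDims` are borrowed from the patch.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the patch's VectorDim (assumption, not MLIR code):
// a size plus a flag saying whether the size is a vscale multiple.
struct Dim {
  int64_t size;
  bool scalable;

  static Dim getFixed(int64_t n) { return {n, false}; }
  static Dim getScalable(int64_t n) { return {n, true}; }

  // Two dims are equal only if both size and scalability match, so
  // getFixed(1) != getScalable(1).
  friend bool operator==(const Dim &a, const Dim &b) {
    return a.size == b.size && a.scalable == b.scalable;
  }
  friend bool operator!=(const Dim &a, const Dim &b) { return !(a == b); }
};

// Drops leading fixed unit dims but always keeps at least one dim,
// mirroring the behaviour of trimLeadingOneDims in the diff.
std::vector<Dim> trimLeadingOneDims(const std::vector<Dim> &dims) {
  std::size_t first = 0;
  while (first < dims.size() && dims[first] == Dim::getFixed(1))
    ++first;
  // Everything was a fixed unit dim: keep the last one.
  if (first == dims.size() && !dims.empty())
    first = dims.size() - 1;
  return std::vector<Dim>(dims.begin() + first, dims.end());
}
```

With this sketch, trimming `1x[1]x4` yields `[1]x4` (the scalable one dim is kept), and trimming `1x1` yields `1`.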
 


