[Mlir-commits] [mlir] [mlir][ProofOfConcept] Support scalable dimensions in ShapedTypes (PR #65341)

Benjamin Maxwell llvmlistbot at llvm.org
Tue Sep 5 08:59:00 PDT 2023


https://github.com/MacDue created https://github.com/llvm/llvm-project/pull/65341:

This is a proof of concept for: 
---

This patch adds support for scalable dims to ShapedTypes by making use of negative sizes for dims (see the sketch after the list below). With this:

- Fixed/static dimensions are positive values (>= 0)
- kDynamic stays the same (std::numeric_limits<int64_t>::min())
- Scalable dimensions are negative values (e.g. [8] is -8)
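
As a rough illustration (not part of the patch itself), the raw int64_t encoding looks like this:

  // Illustrative only: how dimension sizes are stored in the shape array.
  int64_t fixedDim    = 4;                                   // fixed dim `4`
  int64_t scalableDim = -8;                                  // scalable dim `[8]`
  int64_t dynamicDim  = std::numeric_limits<int64_t>::min(); // kDynamic (`?`)
  // e.g. the shape of vector<4x[8]xf32> is stored as {4, -8}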

Associated methods/helpers to check for these dimensions have been added to ShapedType, so a shape is considered (see the example after this list):

- Dynamic if any dimension is dynamic
- Scalable if any dimension is scalable (and no dimension is dynamic)
- Static/fixed if it is not dynamic or scalable
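
For example, with the new ShapedType helpers added in this patch (values are illustrative):

  // Illustrative only: classification of a few shapes.
  ShapedType::isStaticShape({4, 8});                        // true  (all fixed)
  ShapedType::isScalableShape({4, -8});                     // true  ([8] is scalable)
  ShapedType::isDynamicShape({ShapedType::kDynamic, -8});   // true  (has a dynamic dim)
  ShapedType::isScalableShape({ShapedType::kDynamic, -8});  // false (shape is dynamic)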

Treating these dimensions as raw integers or sizes could be error-prone, so it is recommended to convert them to a ShapeDim before querying the size. Currently, int64_ts can be implicitly converted to/from an mlir::ShapeDim; however, it's hoped that this can eventually be made an explicit conversion, with ShapeDim used everywhere.

The ShapeDim class (which simply wraps an int64_t) has methods to check for fixed, dynamic, or scalable dimensions, along with methods to query the size (see the sketch after this list):

- fixedSize(): if this dimension isFixed(), returns the size; aborts otherwise
- scalableSize(): if this dimension isScalable(), returns the minimum size (i.e. the size when the runtime multiplier is 1); aborts otherwise
- minSize(): the minimum size of a dimension; 0 for dynamic, scalableSize() for scalable, fixedSize() otherwise
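
A minimal usage sketch (illustrative; uses the ShapeDim helpers added by this patch):

  ShapeDim fixed    = ShapeDim::fixed(4);
  ShapeDim scalable = ShapeDim::scalable(8);  // the dimension written `[8]`
  ShapeDim dynamic  = ShapeDim::kDynamic();

  fixed.fixedSize();        // 4
  scalable.scalableSize();  // 8 (the size when the runtime multiplier is 1)
  scalable.minSize();       // 8
  dynamic.minSize();        // 0
  int64_t raw = scalable;   // -8, via the (currently implicit) conversion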

With this, the array of booleans currently added to the VectorType to mark which dimensions are scalable can be removed. This makes the vector type act like any other ShapedType and removes a lot of the special cases needed to handle scalable dims, which fixes several issues with misinterpreting or dropping scalable dims for vector types.

For example, `vectorType.getDim(0) == 1` is now only true if dim 0 is a fixed unit dimension, and not a scalable unit dimension `[1]`.

And:

  VectorType::get(
    vectorType.getShape().drop_front(), vectorType.getElementType())

will correctly drop the leading dimension and preserve all scalable dimensions.
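
For comparison, before this change the scalable dims had to be threaded through separately, as in the ArmSME lowerings touched by this patch:

  // Before: scalability is a separate array that is easy to drop by mistake.
  VectorType::get(tileType.getShape().drop_front(), tileElementType,
                  /*scalableDims=*/{true});
  // After: the shape itself carries the scalability.
  VectorType::get(tileType.getShape().drop_front(), tileElementType);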

From 2133111da2a4a01fe35e2560d3157b322395d72f Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 1 Sep 2023 11:59:08 +0000
Subject: [PATCH] [mlir][ProofOfConcept] Support scalable dimensions in
 ShapedTypes

This patch adds support for scalable dims to ShapedTypes by making use
of negative sizes for dims. With this:

- Fixed/static dimensions are positive values (>= 0)
- kDynamic stays the same (std::numeric_limits<int64_t>::min())
- Scalable dimensions are negative values (e.g. [8] is -8)

Associated methods/helpers to check for these dimensions have been added
to ShapedType. So a shape is considered:

- Dynamic if any dimension is dynamic
- Scalable if any dimension is scalable (and no dimension is dynamic)
- Static/fixed if it is not dynamic or scalable

Treating these dimensions as raw integers or sizes could be error-prone,
so it is recommended to convert them to a ShapeDim before querying the
size. Currently, int64_ts can be implicitly converted to/from an
mlir::ShapeDim; however, it's hoped that this can eventually be made an
explicit conversion, with ShapeDim used everywhere.

The ShapeDim class (which simply wraps an int64_t) has methods to check
for fixed, dynamic, or scalable dimensions, along with methods to query
the size:

- fixedSize(): if this dimension isFixed(), returns the size;
  aborts otherwise
- scalableSize(): if this dimension isScalable(), returns the minimum
  size (i.e. the size when the runtime multiplier is 1); aborts otherwise
- minSize(): the minimum size of a dimension; 0 for dynamic,
  scalableSize() for scalable, fixedSize() otherwise

With this, the array of booleans currently added to the VectorType to
mark which dimensions are scalable can be removed. This makes the vector
type act like any other ShapedType and removes a lot of the special
cases needed to handle scalable dims, which fixes several issues with
misinterpreting or dropping scalable dims for vector types.

For example, `vectorType.getDim(0) == 1` is now only true if dim 0 is a
fixed unit dimension, and not a scalable unit dimension `[1]`.

And:

  VectorType::get(
    vectorType.getShape().drop_front(), vectorType.getElementType())

will correctly drop the leading dimension and preserve all scalable
dimensions.
---
 mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td |  30 ++---
 .../mlir/Dialect/Vector/IR/VectorOps.td       |   4 +-
 mlir/include/mlir/IR/BuiltinAttributes.td     |   2 +-
 .../include/mlir/IR/BuiltinDialectBytecode.td |   3 +-
 mlir/include/mlir/IR/BuiltinTypeInterfaces.h  |   1 +
 mlir/include/mlir/IR/BuiltinTypeInterfaces.td | 110 +++++++++++++++---
 mlir/include/mlir/IR/BuiltinTypes.h           |  33 +-----
 mlir/include/mlir/IR/BuiltinTypes.td          |  31 +++--
 mlir/include/mlir/IR/CommonTypeConstraints.td |   4 +-
 mlir/include/mlir/IR/ShapeDim.h               |  85 ++++++++++++++
 mlir/lib/AsmParser/AttributeParser.cpp        |   9 +-
 mlir/lib/AsmParser/Parser.h                   |   3 +-
 mlir/lib/AsmParser/TypeParser.cpp             |  22 ++--
 mlir/lib/CAPI/IR/BuiltinTypes.cpp             |   2 +-
 .../Conversion/LLVMCommon/TypeConverter.cpp   |   5 +-
 .../VectorToArmSME/VectorToArmSME.cpp         |   6 +-
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      |   5 +-
 .../Conversion/VectorToSCF/VectorToSCF.cpp    |  16 +--
 .../Transforms/LegalizeForLLVMExport.cpp      |   9 +-
 mlir/lib/Dialect/ArmSME/Utils/Utils.cpp       |   3 +-
 mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp  |   3 +-
 mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp      |  15 ++-
 .../Linalg/Transforms/Vectorization.cpp       |  16 +--
 .../Transforms/SparseVectorization.cpp        |   5 +-
 .../lib/Dialect/SparseTensor/Utils/Merger.cpp |   2 +-
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      |  72 ++++--------
 .../Transforms/LowerVectorMultiReduction.cpp  |  37 +++---
 .../Vector/Transforms/LowerVectorTransfer.cpp |  20 ++--
 .../Transforms/VectorDropLeadUnitDim.cpp      |  19 +--
 .../Vector/Transforms/VectorTransforms.cpp    |   5 +-
 mlir/lib/IR/AsmPrinter.cpp                    |  16 +--
 mlir/lib/IR/BuiltinAttributeInterfaces.cpp    |   2 +-
 mlir/lib/IR/BuiltinAttributes.cpp             |   6 +-
 mlir/lib/IR/BuiltinTypeInterfaces.cpp         |   6 +-
 mlir/lib/IR/BuiltinTypes.cpp                  |  27 ++---
 mlir/lib/Target/LLVMIR/TypeToLLVM.cpp         |   2 +-
 mlir/test/IR/invalid-builtin-attributes.mlir  |   2 +-
 mlir/test/IR/invalid-builtin-types.mlir       |   4 +-
 38 files changed, 358 insertions(+), 284 deletions(-)
 create mode 100644 mlir/include/mlir/IR/ShapeDim.h

diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
index 7f02e723f3d91c2..304fa31cb6ff88b 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -44,20 +44,25 @@ def ArmSME_Dialect : Dialect {
 
 class SMETileType<Type datatype, list<int> dims, string description>
   : ShapedContainerType<[datatype],
-      And<[IsVectorOfRankPred<[2]>, allDimsScalableVectorTypePred,
-           IsVectorOfShape<dims>]>,
+      And<[IsVectorOfRankPred<[2]>, IsVectorOfShape<dims>]>,
   description>;
 
-def nxnxv16i8  : SMETileType<I8,   [16, 16], "vector<[16]x[16]xi8>">;
-def nxnxv8i16  : SMETileType<I16,  [8,  8 ], "vector<[8]x[8]xi16>">;
-def nxnxv4i32  : SMETileType<I32,  [4,  4 ], "vector<[4]x[4]xi32>">;
-def nxnxv2i64  : SMETileType<I64,  [2,  2 ], "vector<[2]x[2]xi64>">;
-def nxnxv1i128 : SMETileType<I128, [1,  1 ], "vector<[1]x[1]xi128>">;
 
-def nxnxv8f16  : SMETileType<F16,  [8,  8 ], "vector<[8]x[8]xf16>">;
-def nxnxv8bf16 : SMETileType<BF16, [8,  8 ], "vector<[8]x[8]xbf16>">;
-def nxnxv4f32  : SMETileType<F32,  [4,  4 ], "vector<[4]x[4]xf32>">;
-def nxnxv2f64  : SMETileType<F64,  [2,  2 ], "vector<[2]x[2]xf64>">;
+class DimOfSize<int size> {
+  int scalable = !sub(0, size);
+  int fixed = size;
+}
+
+def nxnxv16i8  : SMETileType<I8,   [DimOfSize<16>.scalable, DimOfSize<16>.scalable], "vector<[16]x[16]xi8>">;
+def nxnxv8i16  : SMETileType<I16,  [DimOfSize<8>.scalable, DimOfSize<8>.scalable], "vector<[8]x[8]xi16>">;
+def nxnxv4i32  : SMETileType<I32,  [DimOfSize<4>.scalable, DimOfSize<4>.scalable], "vector<[4]x[4]xi32>">;
+def nxnxv2i64  : SMETileType<I64,  [DimOfSize<2>.scalable, DimOfSize<2>.scalable], "vector<[2]x[2]xi64>">;
+def nxnxv1i128 : SMETileType<I128, [DimOfSize<1>.scalable, DimOfSize<1>.scalable], "vector<[1]x[1]xi128>">;
+
+def nxnxv8f16  : SMETileType<F16,  [DimOfSize<8>.scalable, DimOfSize<8>.scalable], "vector<[8]x[8]xf16>">;
+def nxnxv8bf16 : SMETileType<BF16, [DimOfSize<8>.scalable, DimOfSize<8>.scalable], "vector<[8]x[8]xbf16>">;
+def nxnxv4f32  : SMETileType<F32,  [DimOfSize<4>.scalable, DimOfSize<4>.scalable], "vector<[4]x[4]xf32>">;
+def nxnxv2f64  : SMETileType<F64,  [DimOfSize<2>.scalable, DimOfSize<2>.scalable], "vector<[2]x[2]xf64>">;
 
 def SMETile : AnyTypeOf<[nxnxv16i8, nxnxv8i16, nxnxv4i32, nxnxv2i64, nxnxv1i128,
                          nxnxv8f16, nxnxv8bf16, nxnxv4f32, nxnxv2f64]>;
@@ -421,8 +426,7 @@ def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
       "tile", "vector",
       "VectorType::get("
         "::llvm::cast<mlir::VectorType>($_self).getShape().drop_front(),"
-        "::llvm::cast<mlir::VectorType>($_self).getElementType(),"
-        "/*scalableDims=*/{true})">,
+        "::llvm::cast<mlir::VectorType>($_self).getElementType())">,
 ]> {
   let summary = "Move 1-D scalable vector to slice of 2-D tile";
   let description = [{
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index bf42b4053ac05b8..dc0675af3c68df1 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -730,7 +730,7 @@ def Vector_ScalableInsertOp :
        AllTypesMatch<["dest", "res"]>,
        PredOpTrait<"position is a multiple of the source length.",
         CPred<
-          "(getPos() % getSourceVectorType().getNumElements()) == 0"
+          "(getPos() % getSourceVectorType().getMinNumElements()) == 0"
         >>]>,
      Arguments<(ins VectorOfRank<[1]>:$source,
                     ScalableVectorOfRank<[1]>:$dest,
@@ -786,7 +786,7 @@ def Vector_ScalableExtractOp :
        AllElementTypesMatch<["source", "res"]>,
        PredOpTrait<"position is a multiple of the result length.",
         CPred<
-          "(getPos() % getResultVectorType().getNumElements()) == 0"
+          "(getPos() % getResultVectorType().getMinNumElements()) == 0"
         >>]>,
      Arguments<(ins ScalableVectorOfRank<[1]>:$source,
                     I64Attr:$pos)>,
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index 075eee456a7b58f..038362044190613 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -850,7 +850,7 @@ def Builtin_SparseElementsAttr : Builtin_Attr<
              "expected sparse indices to be 64-bit integer values");
       assert((::llvm::isa<RankedTensorType, VectorType>(type)) &&
              "type must be ranked tensor or vector");
-      assert(type.hasStaticShape() && "type must have static shape");
+      assert(!type.hasDynamicShape() && "type cannot have dynamic shape");
       return $_get(type.getContext(), type,
                    ::llvm::cast<DenseIntElementsAttr>(indices), values);
     }]>,
diff --git a/mlir/include/mlir/IR/BuiltinDialectBytecode.td b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
index f50b5dd7ad82263..dd0296e8a628911 100644
--- a/mlir/include/mlir/IR/BuiltinDialectBytecode.td
+++ b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
@@ -283,13 +283,12 @@ def VectorType : DialectType<(type
 }
 
 def VectorTypeWithScalableDims : DialectType<(type
-  Array<BoolList>:$scalableDims,
   Array<SignedVarIntList>:$shape,
   Type:$elementType
 )> {
   let printerPredicate = "$_val.isScalable()";
   // Note: order of serialization does not match order of builder.
-  let cBuilder = "get<$_resultType>(context, shape, elementType, scalableDims)";
+  let cBuilder = "get<$_resultType>(context, shape, elementType)";
 }
 }
 
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.h b/mlir/include/mlir/IR/BuiltinTypeInterfaces.h
index ed5e5ca22c59585..6f2a015c977c096 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.h
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.h
@@ -9,6 +9,7 @@
 #ifndef MLIR_IR_BUILTINTYPEINTERFACES_H
 #define MLIR_IR_BUILTINTYPEINTERFACES_H
 
+#include "mlir/IR/ShapeDim.h"
 #include "mlir/IR/Types.h"
 
 #include "mlir/IR/BuiltinTypeInterfaces.h.inc"
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index db38e2e1bce22aa..4758cde72900e35 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -54,8 +54,8 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
     A shape is a list of sizes corresponding to the dimensions of the container.
     If the number of dimensions in the shape is unknown, the shape is "unranked".
     If the number of dimensions is known, the shape "ranked". The sizes of the
-    dimensions of the shape must be positive, or kDynamic (in which case the
-    size of the dimension is dynamic, or not statically known).
+    fixed dimensions are positive, scalable dimensions are negative, and unknown
+    dimensions are kDynamic.
   }];
   let methods = [
     InterfaceMethod<[{
@@ -89,21 +89,53 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
   ];
 
   let extraClassDeclaration = [{
-    static constexpr int64_t kDynamic =
-        std::numeric_limits<int64_t>::min();
+    static constexpr int64_t kDynamic = ShapeDim::kDynamic();
 
     /// Whether the given dimension size indicates a dynamic dimension.
     static constexpr bool isDynamic(int64_t dValue) {
-      return dValue == kDynamic;
+      return ShapeDim(dValue).isDynamic();
+    }
+
+    /// Whether the given dimension is scalable (i.e. will be scaled by a
+    /// runtime multiplier).
+    static constexpr bool isScalable(int64_t dValue) {
+      return ShapeDim(dValue).isScalable();
     }
 
     /// Whether the given shape has any size that indicates a dynamic dimension.
     static bool isDynamicShape(ArrayRef<int64_t> dSizes) {
-      return any_of(dSizes, [](int64_t dSize) { return isDynamic(dSize); });
+      return any_of(dSizes, [](ShapeDim dSize) { return dSize.isDynamic(); });
+    }
+
+    /// Whether the given shape has any size that indicates a scalable dimension.
+    /// If any size is dynamic, the overall shape is considered dynamic, not
+    /// scalable.
+    static bool isScalableShape(ArrayRef<int64_t> dSizes) {
+      bool hasScalableDims = false;
+      for (ShapeDim dim: dSizes) {
+        if (dim.isDynamic()) return false;
+        hasScalableDims |= dim.isScalable();
+      }
+      return hasScalableDims;
     }
 
+    /// Whether all dimensions of the shape have a static size (i.e.
+    /// not dynamic or scalable).
+    static bool isStaticShape(ArrayRef<int64_t> dSizes) {
+       return all_of(dSizes, [](ShapeDim dSize) { return dSize.isFixed(); });
+    }
+
+    /// Return the minimum number of elements a shape could hold. For dynamic
+    /// shapes this is always zero. For scalable shapes this is the number of
+    /// elements when the runtime multiplier is one.
+    static int64_t getMinNumElements(ArrayRef<int64_t> shape);
+
     /// Return the number of elements present in the given shape.
-    static int64_t getNumElements(ArrayRef<int64_t> shape);
+    /// Requires: isStaticShape.
+    static int64_t getNumElements(ArrayRef<int64_t> shape) {
+      assert(isStaticShape(shape));
+      return getMinNumElements(shape);
+    }
 
     /// Return a clone of this type with the given new shape and element type.
     /// The returned type is ranked, even if this type is unranked.
@@ -138,10 +170,18 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
       return $_type.getShape().size();
     }
 
+    /// Return the minimum number of elements this type can hold. If this has
+    /// a dynamic shape that will be zero. If this has a scalable shape this
+    /// is the number of elements when the runtime multiplier is one.
+    int64_t getMinNumElements() const {
+      return ::mlir::ShapedType::getMinNumElements($_type.getShape());
+    }
+
     /// If it has static shape, return the number of elements. Otherwise, abort.
     int64_t getNumElements() const {
-      assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
-      return ::mlir::ShapedType::getNumElements($_type.getShape());
+      assert(hasStaticShape()
+        && "cannot get number of elements for scalable/dynamic size");
+      return getMinNumElements();
     }
 
     /// Returns true if this dimension has a dynamic size (for ranked types);
@@ -151,17 +191,32 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
       return ::mlir::ShapedType::isDynamic($_type.getShape()[idx]);
     }
 
+    /// Returns true if this dimension has a scalable size (for ranked types);
+    /// aborts for unranked types.
+    bool isScalableDim(unsigned idx) const {
+      assert(idx < getRank() && "invalid index for shaped type");
+      return ::mlir::ShapedType::isScalable($_type.getShape()[idx]);
+    }
+
     /// Returns if this type has a static shape, i.e. if the type is ranked and
-    /// all dimensions have known size (>= 0).
+    /// all dimensions have known size at runtime.
     bool hasStaticShape() const {
       return $_type.hasRank() &&
-             !::mlir::ShapedType::isDynamicShape($_type.getShape());
+        ::mlir::ShapedType::isStaticShape($_type.getShape());
+    }
+
+    /// Returns if this type has a scalable shape, i.e. if the type is ranked
+    /// and any dimension is scalable (but not dynamic).
+    bool hasScalableShape() const {
+      return $_type.hasRank() &&
+        ::mlir::ShapedType::isScalableShape($_type.getShape());
     }
 
-    /// Returns if this type has a static shape and the shape is equal to
-    /// `shape` return true.
-    bool hasStaticShape(::llvm::ArrayRef<int64_t> shape) const {
-      return hasStaticShape() && $_type.getShape() == shape;
+    /// Returns if this type has a dynamic shape, i.e. if the type is ranked
+    /// and any dimension is dynamic.
+    bool hasDynamicShape() const {
+      return $_type.hasRank() &&
+        ::mlir::ShapedType::isDynamicShape($_type.getShape());
     }
 
     /// If this is a ranked type, return the number of dimensions with dynamic
@@ -170,9 +225,28 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
       return llvm::count_if($_type.getShape(), ::mlir::ShapedType::isDynamic);
     }
 
-    /// If this is ranked type, return the size of the specified dimension.
-    /// Otherwise, abort.
+    /// If this is a ranked type, return the number of dimensions with scalable
+    /// size. Otherwise, abort.
+    int64_t getNumScalableDims() const {
+      return llvm::count_if($_type.getShape(), ::mlir::ShapedType::isScalable);
+    }
+
+    /// Deprecated. Using getDim() and explicitly querying for dim.fixedSize(),
+    /// dim.scalableSize(), dim.minSize(), or comparing with
+    /// ShapeDim::kDynamic() is preferred.
+    /// Gets the size of a dimension. For dynamic dimensions this returns
+    /// kDynamic, for scalable dimensions the min size, or the size for fixed
+    /// dimensions.
     int64_t getDimSize(unsigned idx) const {
+      ::mlir::ShapeDim dim = getDim(idx);
+      if (dim.isDynamic())
+        return dim;
+      return dim.minSize();
+    }
+
+    /// If this is a ranked type, return the specified dimension.
+    /// Otherwise, abort.
+    ShapeDim getDim(unsigned idx) const {
       assert(idx < getRank() && "invalid index for shaped type");
       return $_type.getShape()[idx];
     }
@@ -181,7 +255,7 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
     /// dimensions, given its `index` within the shape.
     unsigned getDynamicDimIndex(unsigned index) const {
       assert(index < getRank() && "invalid index");
-      assert(::mlir::ShapedType::isDynamic(getDimSize(index)) && "invalid index");
+      assert(::mlir::ShapedType::isDynamic(getDim(index)) && "invalid index");
       return llvm::count_if($_type.getShape().take_front(index),
                             ::mlir::ShapedType::isDynamic);
     }
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index f031eb0a5c30ce9..3ae6ae2badace00 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -313,26 +313,13 @@ class VectorType::Builder {
 public:
   /// Build from another VectorType.
   explicit Builder(VectorType other)
-      : shape(other.getShape()), elementType(other.getElementType()),
-        scalableDims(other.getScalableDims()) {}
+      : shape(other.getShape()), elementType(other.getElementType()) {}
 
   /// Build from scratch.
-  Builder(ArrayRef<int64_t> shape, Type elementType,
-          unsigned numScalableDims = 0, ArrayRef<bool> scalableDims = {})
-      : shape(shape), elementType(elementType) {
-    if (scalableDims.empty())
-      scalableDims = SmallVector<bool>(shape.size(), false);
-    else
-      this->scalableDims = scalableDims;
-  }
-
-  Builder &setShape(ArrayRef<int64_t> newShape,
-                    ArrayRef<bool> newIsScalableDim = {}) {
-    if (newIsScalableDim.empty())
-      scalableDims = SmallVector<bool>(shape.size(), false);
-    else
-      scalableDims = newIsScalableDim;
+  Builder(ArrayRef<int64_t> shape, Type elementType)
+      : shape(shape), elementType(elementType) {}
 
+  Builder &setShape(ArrayRef<int64_t> newShape) {
     shape = newShape;
     return *this;
   }
@@ -347,28 +334,18 @@ class VectorType::Builder {
     assert(pos < shape.size() && "overflow");
     if (storage.empty())
       storage.append(shape.begin(), shape.end());
-    if (storageScalableDims.empty())
-      storageScalableDims.append(scalableDims.begin(), scalableDims.end());
     storage.erase(storage.begin() + pos);
-    storageScalableDims.erase(storageScalableDims.begin() + pos);
     shape = {storage.data(), storage.size()};
-    scalableDims =
-        ArrayRef<bool>(storageScalableDims.data(), storageScalableDims.size());
     return *this;
   }
 
-  operator VectorType() {
-    return VectorType::get(shape, elementType, scalableDims);
-  }
+  operator VectorType() { return VectorType::get(shape, elementType); }
 
 private:
   ArrayRef<int64_t> shape;
   // Owning shape data for copy-on-write operations.
   SmallVector<int64_t> storage;
   Type elementType;
-  ArrayRef<bool> scalableDims;
-  // Owning scalableDims data for copy-on-write operations.
-  SmallVector<bool> storageScalableDims;
 };
 
 /// Given an `originalShape` and a `reducedShape` assumed to be a subset of
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 5ec986ac26de06b..396f6b2078f0817 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -1072,22 +1072,13 @@ def Builtin_Vector : Builtin_Type<"Vector", [ShapedTypeInterface], "Type"> {
   }];
   let parameters = (ins
     ArrayRefParameter<"int64_t">:$shape,
-    "Type":$elementType,
-    ArrayRefParameter<"bool">:$scalableDims
+    "Type":$elementType
   );
   let builders = [
     TypeBuilderWithInferredContext<(ins
-      "ArrayRef<int64_t>":$shape, "Type":$elementType,
-      CArg<"ArrayRef<bool>", "{}">:$scalableDims
+      "ArrayRef<int64_t>":$shape, "Type":$elementType
     ), [{
-      // While `scalableDims` is optional, its default value should be
-      // `false` for every dim in `shape`.
-      SmallVector<bool> isScalableVec;
-      if (scalableDims.empty()) {
-        isScalableVec.resize(shape.size(), false);
-        scalableDims = isScalableVec;
-      }
-      return $_get(elementType.getContext(), shape, elementType, scalableDims);
+      return $_get(elementType.getContext(), shape, elementType);
     }]>
   ];
   let extraClassDeclaration = [{
@@ -1102,15 +1093,21 @@ def Builtin_Vector : Builtin_Type<"Vector", [ShapedTypeInterface], "Type"> {
       return ::llvm::isa<IntegerType, IndexType, FloatType>(t);
     }
 
+    /// Overload for VectorType::get() for ShapeDim. Cannot be called ::get
+    /// due to ambiguous overloads.
+    static VectorType create(ArrayRef<::mlir::ShapeDim> shape, Type elementType) {
+      SmallVector<int64_t> i64Shape(shape.size());
+      for (auto [i64dim, shapeDim] : llvm::zip(i64Shape, shape))
+        i64dim = shapeDim;
+      return get(i64Shape, elementType);
+    }
+
     /// Returns true if the vector contains scalable dimensions.
     bool isScalable() const {
-      return llvm::is_contained(getScalableDims(), true);
+      return hasScalableShape();
     }
     bool allDimsScalable() const {
-      // Treat 0-d vectors as fixed size.
-      if (getRank() == 0)
-        return false;
-      return !llvm::is_contained(getScalableDims(), false);
+      return llvm::all_of(getShape(), ShapedType::isScalable);
     }
 
     /// Get or create a new VectorType with the same shape as `this` and an
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 4fc14e30b8a10d0..a930f44153ef801 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -459,7 +459,7 @@ class VectorOfRankAndType<list<int> allowedRanks,
 class IsVectorOfLengthPred<list<int> allowedLengths> :
   And<[IsVectorTypePred,
        Or<!foreach(allowedlength, allowedLengths,
-                   CPred<[{::llvm::cast<::mlir::VectorType>($_self).getNumElements()
+                   CPred<[{::llvm::cast<::mlir::VectorType>($_self).getMinNumElements()
                            == }]
                          # allowedlength>)>]>;
 
@@ -477,7 +477,7 @@ class IsFixedVectorOfLengthPred<list<int> allowedLengths> :
 class IsScalableVectorOfLengthPred<list<int> allowedLengths> :
   And<[IsScalableVectorTypePred,
        Or<!foreach(allowedlength, allowedLengths,
-                   CPred<[{::llvm::cast<::mlir::VectorType>($_self).getNumElements()
+                   CPred<[{::llvm::cast<::mlir::VectorType>($_self).getMinNumElements()
                            == }]
                          # allowedlength>)>]>;
 
diff --git a/mlir/include/mlir/IR/ShapeDim.h b/mlir/include/mlir/IR/ShapeDim.h
new file mode 100644
index 000000000000000..a6293cb07fb930c
--- /dev/null
+++ b/mlir/include/mlir/IR/ShapeDim.h
@@ -0,0 +1,85 @@
+//===- ShapeDim.h - MLIR ShapeDim Class -------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the ShapeDim class.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_SHAPEDIM_H
+#define MLIR_IR_SHAPEDIM_H
+
+#include <cassert>
+#include <cstdint>
+#include <limits>
+
+namespace mlir {
+
+struct ShapeDim {
+  /// Deprecated. Use ShapeDim::fixed(), ShapeDim::scalable(), or
+  /// ShapeDim::kDynamic() instead.
+  constexpr ShapeDim(int64_t size) : size(size) {}
+
+  /// Deprecated. Use an explicit conversion instead.
+  constexpr operator int64_t() const { return size; }
+
+  /// Construct a scalable dimension.
+  constexpr static ShapeDim scalable(int64_t size) {
+    assert(size > 0);
+    return ShapeDim{-size};
+  }
+
+  /// Construct a fixed dimension.
+  constexpr static ShapeDim fixed(int64_t size) {
+    assert(size >= 0);
+    return ShapeDim{size};
+  }
+
+  /// Construct a dynamic dimension.
+  constexpr static ShapeDim kDynamic() { return ShapeDim{}; }
+
+  /// Returns whether this is a dynamic dimension.
+  constexpr bool isDynamic() const { return size == kDynamic(); }
+
+  /// Returns whether this is a scalable dimension.
+  constexpr bool isScalable() const { return size < 0 && !isDynamic(); }
+
+  /// Returns whether this is a fixed dimension.
+  constexpr bool isFixed() const { return size >= 0; }
+
+  /// Asserts a dimension is fixed and returns its size.
+  constexpr int64_t fixedSize() const {
+    assert(isFixed());
+    return size;
+  }
+
+  /// Asserts a dimension is scalable and returns its size.
+  constexpr int64_t scalableSize() const {
+    assert(isScalable());
+    return -size;
+  }
+
+  /// Returns the minimum (runtime) size for this dimension.
+  constexpr int64_t minSize() const {
+    if (isScalable())
+      return scalableSize();
+    if (isFixed())
+      return fixedSize();
+    return 0;
+  }
+
+private:
+  constexpr explicit ShapeDim()
+      : ShapeDim(std::numeric_limits<int64_t>::min()) {}
+
+  int64_t size;
+};
+
+} // namespace mlir
+
+#endif
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index 3437ac9addc5ff6..79d1dc74df62662 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -560,7 +560,7 @@ DenseElementsAttr TensorLiteralParser::getAttr(SMLoc loc, ShapedType type) {
   }
 
   // Handle the case where no elements were parsed.
-  if (!hexStorage && storage.empty() && type.getNumElements()) {
+  if (!hexStorage && storage.empty() && type.getMinNumElements()) {
     p.emitError(loc) << "parsed zero elements, but type (" << type
                      << ") expected at least 1";
     return nullptr;
@@ -1059,8 +1059,9 @@ ShapedType Parser::parseElementsLiteralType(Type type) {
     return nullptr;
   }
 
-  if (!sType.hasStaticShape())
-    return (emitError("elements literal type must have static shape"), nullptr);
+  if (sType.hasDynamicShape())
+    return (emitError("elements literal type cannot have dynamic dims"),
+            nullptr);
 
   return sType;
 }
@@ -1134,7 +1135,7 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) {
   auto valuesEltType = type.getElementType();
   ShapedType valuesType =
       valuesParser.getShape().empty()
-          ? RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType)
+          ? RankedTensorType::get({indicesType.getDim(0)}, valuesEltType)
           : RankedTensorType::get(valuesParser.getShape(), valuesEltType);
   auto values = valuesParser.getAttr(valuesLoc, valuesType);
 
diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h
index 01c55f97a08c2ce..13062b0484d8e19 100644
--- a/mlir/lib/AsmParser/Parser.h
+++ b/mlir/lib/AsmParser/Parser.h
@@ -210,8 +210,7 @@ class Parser {
 
   /// Parse a vector type.
   VectorType parseVectorType();
-  ParseResult parseVectorDimensionList(SmallVectorImpl<int64_t> &dimensions,
-                                       SmallVectorImpl<bool> &scalableDims);
+  ParseResult parseVectorDimensionList(SmallVectorImpl<int64_t> &dimensions);
   ParseResult parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions,
                                        bool allowDynamic = true,
                                        bool withTrailingX = true);
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 306e850af27bc58..c3bb9e094014aa1 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -444,12 +444,11 @@ VectorType Parser::parseVectorType() {
     return nullptr;
 
   SmallVector<int64_t, 4> dimensions;
-  SmallVector<bool, 4> scalableDims;
-  if (parseVectorDimensionList(dimensions, scalableDims))
+  if (parseVectorDimensionList(dimensions))
     return nullptr;
-  if (any_of(dimensions, [](int64_t i) { return i <= 0; }))
+  if (any_of(dimensions, [](int64_t i) { return i == 0; }))
     return emitError(getToken().getLoc(),
-                     "vector types must have positive constant sizes"),
+                     "vector type must have non-zero sizes"),
            nullptr;
 
   // Parse the element type.
@@ -462,33 +461,32 @@ VectorType Parser::parseVectorType() {
     return emitError(typeLoc, "vector elements must be int/index/float type"),
            nullptr;
 
-  return VectorType::get(dimensions, elementType, scalableDims);
+  return VectorType::get(dimensions, elementType);
 }
 
 /// Parse a dimension list in a vector type. This populates the dimension list.
-/// For i-th dimension, `scalableDims[i]` contains either:
-///   * `false` for a non-scalable dimension (e.g. `4`),
-///   * `true` for a scalable dimension (e.g. `[4]`).
+/// For i-th dimension, `dimensions[i]` contains either:
+///   * `dim_size` for a non-scalable dimension (e.g. `4`),
+///   * `-dim_size` for a scalable dimension (e.g. `[4]`).
 ///
 /// vector-dim-list := (static-dim-list `x`)?
 /// static-dim-list ::= static-dim (`x` static-dim)*
 /// static-dim ::= (decimal-literal | `[` decimal-literal `]`)
 ///
 ParseResult
-Parser::parseVectorDimensionList(SmallVectorImpl<int64_t> &dimensions,
-                                 SmallVectorImpl<bool> &scalableDims) {
+Parser::parseVectorDimensionList(SmallVectorImpl<int64_t> &dimensions) {
   // If there is a set of fixed-length dimensions, consume it
   while (getToken().is(Token::integer) || getToken().is(Token::l_square)) {
     int64_t value;
     bool scalable = consumeIf(Token::l_square);
     if (parseIntegerInDimensionList(value))
       return failure();
-    dimensions.push_back(value);
+    dimensions.push_back(scalable ? ShapeDim::scalable(value)
+                                  : ShapeDim::fixed(value));
     if (scalable) {
       if (!consumeIf(Token::r_square))
         return emitWrongTokenError("missing ']' closing scalable dimension");
     }
-    scalableDims.push_back(scalable);
     // Make sure we have an 'x' or something like 'xbf32'.
     if (parseXInDimensionList())
       return failure();
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index 50266b4b5233235..2c3c62935b0cdd5 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -241,7 +241,7 @@ bool mlirShapedTypeIsDynamicDim(MlirType type, intptr_t dim) {
 
 int64_t mlirShapedTypeGetDimSize(MlirType type, intptr_t dim) {
   return llvm::cast<ShapedType>(unwrap(type))
-      .getDimSize(static_cast<unsigned>(dim));
+      .getDim(static_cast<unsigned>(dim));
 }
 
 int64_t mlirShapedTypeGetDynamicSize() { return ShapedType::kDynamic; }
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index a9e7ce9d42848b5..67b8df539c78641 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -498,9 +498,8 @@ Type LLVMTypeConverter::convertVectorType(VectorType type) const {
   if (!elementType)
     return {};
   if (type.getShape().empty())
-    return VectorType::get({1}, elementType);
-  Type vectorType = VectorType::get(type.getShape().back(), elementType,
-                                    type.getScalableDims().back());
+    return VectorType::get(1, elementType);
+  Type vectorType = VectorType::get(type.getShape().back(), elementType);
   assert(LLVM::isCompatibleVectorType(vectorType) &&
          "expected vector type compatible with the LLVM dialect");
   assert(
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 0a1a087d9c8d6c7..622a7b02176b976 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -148,8 +148,7 @@ struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
 
     // Unpack 1-d vector type from 2-d vector type.
     auto tileSliceType =
-        VectorType::get(tileType.getShape().drop_front(), tileElementType,
-                        /*scalableDims=*/{true});
+        VectorType::get(tileType.getShape().drop_front(), tileElementType);
     auto denseAttr1D = DenseElementsAttr::get(
         tileSliceType, denseAttr.getSplatValue<Attribute>());
     auto constantOp1D = rewriter.create<arith::ConstantOp>(loc, denseAttr1D);
@@ -209,8 +208,7 @@ struct BroadcastOpToArmSMELowering
         (srcVectorType && (srcVectorType.getRank() == 0))) {
       // Broadcast scalar or 0-d vector to 1-d vector.
       auto tileSliceType =
-          VectorType::get(tileType.getShape().drop_front(), tileElementType,
-                          /*scalableDims=*/{true});
+          VectorType::get(tileType.getShape().drop_front(), tileElementType);
       broadcastOp1D = rewriter.create<vector::BroadcastOp>(
           loc, tileSliceType, broadcastOp.getSource());
     } else if (srcVectorType && (srcVectorType.getRank() == 1))
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 92f7aa69760395a..2db63608480c135 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -31,8 +31,7 @@ using namespace mlir::vector;
 // Helper to reduce vector type by *all* but one rank at back.
 static VectorType reducedVectorTypeBack(VectorType tp) {
   assert((tp.getRank() > 1) && "unlowerable vector type");
-  return VectorType::get(tp.getShape().take_back(), tp.getElementType(),
-                         tp.getScalableDims().take_back());
+  return VectorType::get(tp.getShape().take_back(), tp.getElementType());
 }
 
 // Helper that picks the proper sequence for inserting.
@@ -1390,7 +1389,7 @@ class VectorCreateMaskOpRewritePattern
         force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
     auto loc = op->getLoc();
     Value indices = rewriter.create<LLVM::StepVectorOp>(
-        loc, LLVM::getVectorType(idxType, dstType.getShape()[0],
+        loc, LLVM::getVectorType(idxType, dstType.getDim(0).minSize(),
                                  /*isScalable=*/true));
     auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
                                                  op.getOperand(0));
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 675bddca61a3e2d..42230d29b51634e 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -728,7 +728,6 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
       vectorType = targetVectorType;
     }
 
-    auto scalableDimensions = vectorType.getScalableDims();
     auto shape = vectorType.getShape();
     constexpr int64_t singletonShape[] = {1};
     if (vectorType.getRank() == 0)
@@ -738,10 +737,11 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
       // Flatten n-D vectors to 1D. This is done to allow indexing with a
       // non-constant value (which can currently only be done via
       // vector.extractelement for 1D vectors).
-      auto flatLength = std::accumulate(shape.begin(), shape.end(), 1,
-                                        std::multiplies<int64_t>());
-      auto flatVectorType =
-          VectorType::get({flatLength}, vectorType.getElementType());
+      auto flatLength = vectorType.getMinNumElements();
+      auto flatVectorType = VectorType::create(
+          vectorType.isScalable() ? ShapeDim::scalable(flatLength)
+                                  : ShapeDim::fixed(flatLength),
+          vectorType.getElementType());
       value = rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, value);
     }
 
@@ -749,10 +749,12 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
     SmallVector<Value, 8> loopIndices;
     for (unsigned d = 0; d < shape.size(); d++) {
       // Setup loop bounds and step.
+      ShapeDim dim = shape[d];
       Value lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-      Value upperBound = rewriter.create<arith::ConstantIndexOp>(loc, shape[d]);
+      Value upperBound =
+          rewriter.create<arith::ConstantIndexOp>(loc, dim.minSize());
       Value step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
-      if (!scalableDimensions.empty() && scalableDimensions[d]) {
+      if (dim.isScalable()) {
         auto vscale = rewriter.create<vector::VectorScaleOp>(
             loc, rewriter.getIndexType());
         upperBound = rewriter.create<arith::MulIOp>(loc, upperBound, vscale);
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index 685f8d57f76f52c..7565a9bc8b11b41 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -199,8 +199,7 @@ struct LoadTileSliceToArmSMELowering
     auto one = rewriter.create<arith::ConstantOp>(
         loc, rewriter.getI1Type(),
         rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
-    auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(),
-                                  /*scalableDims=*/{true});
+    auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type());
     auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
 
     auto tileI32 = castTileIDToI32(tile, loc, rewriter);
@@ -275,8 +274,7 @@ struct StoreTileSliceToArmSMELowering
     auto one = rewriter.create<arith::ConstantOp>(
         loc, rewriter.getI1Type(),
         rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
-    auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(),
-                                  /*scalableDims=*/{true});
+    auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type());
     auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
 
     Value tileI32 = castTileIDToI32(tile, loc, rewriter);
@@ -341,8 +339,7 @@ struct MoveVectorToTileSliceToArmSMELowering
     auto one = rewriter.create<arith::ConstantOp>(
         loc, rewriter.getI1Type(),
         rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
-    auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(),
-                                  /*scalableDims=*/{true});
+    auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type());
     auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
 
     auto tileI32 = castTileIDToI32(tile, loc, rewriter);
diff --git a/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp b/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
index 8b2be7bc1901b9a..6d21f289cc0fa55 100644
--- a/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
@@ -39,7 +39,8 @@ bool mlir::arm_sme::isValidSMETileVectorType(VectorType vType) {
     return false;
 
   unsigned minNumElts = arm_sme::getSMETileSliceMinNumElts(elemType);
-  if (vType.getShape() != ArrayRef<int64_t>({minNumElts, minNumElts}))
+  if (vType.getShape() != ArrayRef<int64_t>{ShapeDim::scalable(minNumElts),
+                                            ShapeDim::scalable(minNumElts)})
     return false;
 
   return true;
diff --git a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
index 4af836a93c2a169..9b13089d0f22945 100644
--- a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
+++ b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
@@ -29,8 +29,7 @@ using namespace mlir::arm_sve;
 static Type getI1SameShape(Type type) {
   auto i1Type = IntegerType::get(type.getContext(), 1);
   if (auto sVectorType = llvm::dyn_cast<VectorType>(type))
-    return VectorType::get(sVectorType.getShape(), i1Type,
-                           sVectorType.getScalableDims());
+    return VectorType::get(sVectorType.getShape(), i1Type);
   return nullptr;
 }
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index bc8300a8b7329ea..cc9a80567abe37f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -957,7 +957,7 @@ llvm::ElementCount mlir::LLVM::getVectorNumElements(Type type) {
   return llvm::TypeSwitch<Type, llvm::ElementCount>(type)
       .Case([](VectorType ty) {
         if (ty.isScalable())
-          return llvm::ElementCount::getScalable(ty.getNumElements());
+          return llvm::ElementCount::getScalable(ty.getMinNumElements());
         return llvm::ElementCount::getFixed(ty.getNumElements());
       })
       .Case([](LLVMFixedVectorType ty) {
@@ -995,16 +995,17 @@ Type mlir::LLVM::getVectorType(Type elementType, unsigned numElements,
 
   // LLVM vectors are always 1-D, hence only 1 bool is required to mark it as
   // scalable/non-scalable.
-  return VectorType::get(numElements, elementType, {isScalable});
+  return VectorType::create(isScalable ? ShapeDim::scalable(numElements)
+                                       : ShapeDim::fixed(numElements),
+                            elementType);
 }
 
 Type mlir::LLVM::getVectorType(Type elementType,
                                const llvm::ElementCount &numElements) {
   if (numElements.isScalable())
     return getVectorType(elementType, numElements.getKnownMinValue(),
-                         /*isScalable=*/true);
-  return getVectorType(elementType, numElements.getFixedValue(),
-                       /*isScalable=*/false);
+                         /*isScalable*/ true);
+  return getVectorType(elementType, numElements.getFixedValue());
 }
 
 Type mlir::LLVM::getFixedVectorType(Type elementType, unsigned numElements) {
@@ -1028,9 +1029,7 @@ Type mlir::LLVM::getScalableVectorType(Type elementType, unsigned numElements) {
   if (useLLVM)
     return LLVMScalableVectorType::get(elementType, numElements);
 
-  // LLVM vectors are always 1-D, hence only 1 bool is required to mark it as
-  // scalable/non-scalable.
-  return VectorType::get(numElements, elementType, /*scalableDims=*/true);
+  return VectorType::create(ShapeDim::scalable(numElements), elementType);
 }
 
 llvm::TypeSize mlir::LLVM::getPrimitiveTypeSizeInBits(Type type) {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index df814d02e0b195c..dd211d9efa7b5c0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -201,8 +201,11 @@ struct VectorizationState {
       vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
       scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
     }
-
-    return VectorType::get(vectorShape, elementType, scalableDims);
+    for (auto [dim, isScalable] : llvm::zip(vectorShape, scalableDims)) {
+      if (isScalable)
+        dim = ShapeDim::scalable(dim);
+    }
+    return VectorType::get(vectorShape, elementType);
   }
 
   /// Masks an operation with the canonical vector mask if the operation needs
@@ -1217,9 +1220,9 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
     assert(vecOperand && "Vector operand couldn't be found");
 
     if (firstMaxRankedType) {
-      auto vecType = VectorType::get(firstMaxRankedType.getShape(),
-                                     getElementTypeOrSelf(vecOperand.getType()),
-                                     firstMaxRankedType.getScalableDims());
+      auto vecType =
+          VectorType::get(firstMaxRankedType.getShape(),
+                          getElementTypeOrSelf(vecOperand.getType()));
       vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
     } else {
       vecOperands.push_back(vecOperand);
@@ -1230,8 +1233,7 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
   for (Type resultType : op->getResultTypes()) {
     resultTypes.push_back(
         firstMaxRankedType
-            ? VectorType::get(firstMaxRankedType.getShape(), resultType,
-                              firstMaxRankedType.getScalableDims())
+            ? VectorType::get(firstMaxRankedType.getShape(), resultType)
             : resultType);
   }
   //   d. Build and return the new op.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
index 93ee0647b7b5a6b..d9e4e6d566bdf0c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
@@ -56,7 +56,10 @@ static bool isInvariantArg(BlockArgument arg, Block *block) {
 
 /// Constructs vector type for element type.
 static VectorType vectorType(VL vl, Type etp) {
-  return VectorType::get(vl.vectorLength, etp, vl.enableVLAVectorization);
+  return VectorType::create(vl.enableVLAVectorization
+                                ? ShapeDim::scalable(vl.vectorLength)
+                                : ShapeDim::fixed(vl.vectorLength),
+                            etp);
 }
 
 /// Constructs vector type from a memref value.
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 4143efbd0ab28e0..3dacf7d715cd8d2 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -1215,7 +1215,7 @@ Type Merger::inferType(ExprId e, Value src) const {
   // Inspect source type. For vector types, apply the same
   // vectorization to the destination type.
   if (auto vtp = dyn_cast<VectorType>(src.getType()))
-    return VectorType::get(vtp.getNumElements(), dtp, vtp.getScalableDims());
+    return VectorType::get(vtp.getNumElements(), dtp);
   return dtp;
 }
 
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 2aaf1cb7e5878e4..52e458d49823339 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -313,22 +313,19 @@ MultiDimReductionOp::getShapeForUnroll() {
 
 LogicalResult MultiDimReductionOp::verify() {
   SmallVector<int64_t> targetShape;
-  SmallVector<bool> scalableDims;
   Type inferredReturnType;
-  auto sourceScalableDims = getSourceVectorType().getScalableDims();
   for (auto it : llvm::enumerate(getSourceVectorType().getShape()))
     if (!llvm::any_of(getReductionDims().getValue(), [&](Attribute attr) {
           return llvm::cast<IntegerAttr>(attr).getValue() == it.index();
         })) {
       targetShape.push_back(it.value());
-      scalableDims.push_back(sourceScalableDims[it.index()]);
     }
   // TODO: update to also allow 0-d vectors when available.
   if (targetShape.empty())
     inferredReturnType = getSourceVectorType().getElementType();
   else
-    inferredReturnType = VectorType::get(
-        targetShape, getSourceVectorType().getElementType(), scalableDims);
+    inferredReturnType =
+        VectorType::get(targetShape, getSourceVectorType().getElementType());
   if (getType() != inferredReturnType)
     return emitOpError() << "destination type " << getType()
                          << " is incompatible with source type "
@@ -341,8 +338,7 @@ LogicalResult MultiDimReductionOp::verify() {
 Type MultiDimReductionOp::getExpectedMaskType() {
   auto vecType = getSourceVectorType();
   return VectorType::get(vecType.getShape(),
-                         IntegerType::get(vecType.getContext(), /*width=*/1),
-                         vecType.getScalableDims());
+                         IntegerType::get(vecType.getContext(), /*width=*/1));
 }
 
 namespace {
@@ -381,8 +377,7 @@ struct ElideUnitDimsInMultiDimReduction
     if (auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
       if (mask) {
         VectorType newMaskType =
-            VectorType::get(dstVecType.getShape(), rewriter.getI1Type(),
-                            dstVecType.getScalableDims());
+            VectorType::get(dstVecType.getShape(), rewriter.getI1Type());
         mask = rewriter.create<vector::ShapeCastOp>(loc, newMaskType, mask);
       }
       cast = rewriter.create<vector::ShapeCastOp>(
@@ -480,8 +475,7 @@ void ReductionOp::print(OpAsmPrinter &p) {
 Type ReductionOp::getExpectedMaskType() {
   auto vecType = getSourceVectorType();
   return VectorType::get(vecType.getShape(),
-                         IntegerType::get(vecType.getContext(), /*width=*/1),
-                         vecType.getScalableDims());
+                         IntegerType::get(vecType.getContext(), /*width=*/1));
 }
 
 Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
@@ -1151,8 +1145,7 @@ ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
     auto n =
         std::min<size_t>(adaptor.getPosition().size(), vectorType.getRank());
     inferredReturnTypes.push_back(VectorType::get(
-        vectorType.getShape().drop_front(n), vectorType.getElementType(),
-        vectorType.getScalableDims().drop_front(n)));
+        vectorType.getShape().drop_front(n), vectorType.getElementType()));
   }
   return success();
 }
@@ -1510,7 +1503,7 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) {
 
   // Get the nth dimension size starting from lowest dimension.
   auto getDimReverse = [](VectorType type, int64_t n) {
-    return type.getShape().take_back(n + 1).front();
+    return type.getDim(type.getRank() - 1 - n).minSize();
   };
   int64_t destinationRank =
       llvm::isa<VectorType>(extractOp.getType())
@@ -2782,15 +2775,11 @@ ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
 
   VectorType resType;
   if (vRHS) {
-    SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
-                                      vRHS.getScalableDims()[0]};
-    resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
-                              vLHS.getElementType(), scalableDimsRes);
+    resType = VectorType::get({vLHS.getDim(0), vRHS.getDim(0)},
+                              vLHS.getElementType());
   } else {
     // Scalar RHS operand
-    SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
-    resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
-                              scalableDimsRes);
+    resType = VectorType::get({vLHS.getDim(0)}, vLHS.getElementType());
   }
 
   if (!result.attributes.get(OuterProductOp::getKindAttrName(result.name))) {
@@ -2855,8 +2844,7 @@ LogicalResult OuterProductOp::verify() {
 Type OuterProductOp::getExpectedMaskType() {
   auto vecType = this->getResultVectorType();
   return VectorType::get(vecType.getShape(),
-                         IntegerType::get(vecType.getContext(), /*width=*/1),
-                         vecType.getScalableDims());
+                         IntegerType::get(vecType.getContext(), /*width=*/1));
 }
 
 //===----------------------------------------------------------------------===//
@@ -3504,10 +3492,7 @@ static VectorType inferTransferOpMaskType(VectorType vecType,
   assert(invPermMap && "Inversed permutation map couldn't be computed");
   SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());
 
-  SmallVector<bool> scalableDims =
-      applyPermutationMap(invPermMap, vecType.getScalableDims());
-
-  return VectorType::get(maskShape, i1Type, scalableDims);
+  return VectorType::get(maskShape, i1Type);
 }
 
 ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -4463,8 +4448,7 @@ LogicalResult GatherOp::verify() {
 Type GatherOp::getExpectedMaskType() {
   auto vecType = this->getIndexVectorType();
   return VectorType::get(vecType.getShape(),
-                         IntegerType::get(vecType.getContext(), /*width=*/1),
-                         vecType.getScalableDims());
+                         IntegerType::get(vecType.getContext(), /*width=*/1));
 }
 
 std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
@@ -4647,7 +4631,7 @@ static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
   unsigned rankB = b.size();
   assert(rankA < rankB);
 
-  auto isOne = [](int64_t v) { return v == 1; };
+  auto isOne = [](ShapeDim dim) { return dim.minSize() == 1; };
 
   // Special-case for n-D to 0-d shape cast. 'b' must be all ones to be shape
   // casted to a 0-d vector.
@@ -4657,10 +4641,10 @@ static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
   unsigned i = 0;
   unsigned j = 0;
   while (i < rankA && j < rankB) {
-    int64_t dimA = a[i];
+    int64_t dimA = ShapeDim(a[i]).minSize();
     int64_t dimB = 1;
     while (dimB < dimA && j < rankB)
-      dimB *= b[j++];
+      dimB *= ShapeDim(b[j++]).minSize();
     if (dimA != dimB)
       break;
     ++i;
@@ -4788,22 +4772,15 @@ static VectorType trimTrailingOneDims(VectorType oldType) {
   ArrayRef<int64_t> oldShape = oldType.getShape();
   ArrayRef<int64_t> newShape = oldShape;
 
-  ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
-  ArrayRef<bool> newScalableDims = oldScalableDims;
-
-  while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
+  while (!newShape.empty() && newShape.back() == 1)
     newShape = newShape.drop_back(1);
-    newScalableDims = newScalableDims.drop_back(1);
-  }
 
   // Make sure we have at least 1 dimension.
   // TODO: Add support for 0-D vectors.
-  if (newShape.empty()) {
+  if (newShape.empty())
     newShape = oldShape.take_back();
-    newScalableDims = oldScalableDims.take_back();
-  }
 
-  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
+  return VectorType::get(newShape, oldType.getElementType());
 }
 
 /// Folds qualifying shape_cast(create_mask) into a new create_mask
@@ -5095,15 +5072,11 @@ void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
                                 Value vector, ArrayRef<int64_t> transp) {
   VectorType vt = llvm::cast<VectorType>(vector.getType());
   SmallVector<int64_t, 4> transposedShape(vt.getRank());
-  SmallVector<bool, 4> transposedScalableDims(vt.getRank());
-  for (unsigned i = 0; i < transp.size(); ++i) {
+  for (unsigned i = 0; i < transp.size(); ++i)
     transposedShape[i] = vt.getShape()[transp[i]];
-    transposedScalableDims[i] = vt.getScalableDims()[transp[i]];
-  }
 
   result.addOperands(vector);
-  result.addTypes(VectorType::get(transposedShape, vt.getElementType(),
-                                  transposedScalableDims));
+  result.addTypes(VectorType::get(transposedShape, vt.getElementType()));
   result.addAttribute(TransposeOp::getTranspAttrName(result.name),
                       builder.getI64ArrayAttr(transp));
 }
@@ -5314,11 +5287,10 @@ LogicalResult ConstantMaskOp::verify() {
         "must specify array attr of size equal vector result rank");
   // Verify that each array attr element is in bounds of corresponding vector
   // result dimension size.
-  auto resultShape = resultType.getShape();
   SmallVector<int64_t, 4> maskDimSizes;
   for (const auto &it : llvm::enumerate(getMaskDimSizes())) {
     int64_t attrValue = llvm::cast<IntegerAttr>(it.value()).getInt();
-    if (attrValue < 0 || attrValue > resultShape[it.index()])
+    if (attrValue < 0 || attrValue > resultType.getDim(it.index()).minSize())
       return emitOpError(
           "array attr of size out of bounds of vector result dimension size");
     maskDimSizes.push_back(attrValue);
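
The VectorOps.cpp hunks above all query dimensions through the encoding this
patch introduces (fixed sizes are >= 0, scalable sizes are stored negative),
e.g. `type.getDim(i).minSize()` in getDimReverse and in ConstantMaskOp::verify.
The snippet below is a minimal, self-contained sketch of those query
semantics; `Dim` is a stand-in written for illustration and is not the actual
mlir::ShapeDim class.

  // Stand-in for the patch's ShapeDim, showing the negative-size encoding.
  #include <cassert>
  #include <cstdint>
  #include <iostream>
  #include <limits>

  struct Dim {
    int64_t value;
    static constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();
    bool isDynamic() const { return value == kDynamic; }
    bool isScalable() const { return !isDynamic() && value < 0; }
    bool isFixed() const { return value >= 0; }
    // Minimum size: 0 for dynamic, the vscale == 1 size for scalable dims.
    int64_t minSize() const {
      if (isDynamic())
        return 0;
      return isScalable() ? -value : value;
    }
  };

  int main() {
    // vector<[8]x2xf32> would be stored as {-8, 2} under this encoding.
    Dim scalable{-8}, fixed{2};
    assert(scalable.isScalable() && scalable.minSize() == 8);
    assert(fixed.isFixed() && fixed.minSize() == 2);
    std::cout << scalable.minSize() * fixed.minSize() << "\n"; // prints 16
  }
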
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index bed2c2496719dd8..cf4e35e9e2e7906 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -154,8 +154,6 @@ class ReduceMultiDimReductionRank
 
     auto srcRank = multiReductionOp.getSourceVectorType().getRank();
     auto srcShape = multiReductionOp.getSourceVectorType().getShape();
-    auto srcScalableDims =
-        multiReductionOp.getSourceVectorType().getScalableDims();
     auto loc = multiReductionOp.getLoc();
 
     // If rank less than 2, nothing to do.
@@ -164,7 +162,7 @@ class ReduceMultiDimReductionRank
 
     // Allow only 1 scalable dimension. Otherwise we could end up with e.g.
     // `vscale * vscale` that's currently not modelled.
-    if (llvm::count(srcScalableDims, true) > 1)
+    if (multiReductionOp.getSourceVectorType().getNumScalableDims() > 1)
       return failure();
 
     // If already rank-2 ["parallel", "reduce"] or ["reduce", "parallel"] bail.
@@ -174,20 +172,16 @@ class ReduceMultiDimReductionRank
 
     // 1. Separate reduction and parallel dims.
     SmallVector<int64_t, 4> parallelDims, parallelShapes;
-    SmallVector<bool, 4> parallelScalableDims;
     SmallVector<int64_t, 4> reductionDims, reductionShapes;
-    bool isReductionDimScalable = false;
     for (const auto &it : llvm::enumerate(reductionMask)) {
       int64_t i = it.index();
       bool isReduction = it.value();
       if (isReduction) {
         reductionDims.push_back(i);
         reductionShapes.push_back(srcShape[i]);
-        isReductionDimScalable |= srcScalableDims[i];
       } else {
         parallelDims.push_back(i);
         parallelShapes.push_back(srcShape[i]);
-        parallelScalableDims.push_back(srcScalableDims[i]);
       }
     }
 
@@ -196,13 +190,13 @@ class ReduceMultiDimReductionRank
     int flattenedReductionDim = 0;
     if (!parallelShapes.empty()) {
       flattenedParallelDim = 1;
-      for (auto d : parallelShapes)
-        flattenedParallelDim *= d;
+      for (ShapeDim d : parallelShapes)
+        flattenedParallelDim *= d.minSize();
     }
     if (!reductionShapes.empty()) {
       flattenedReductionDim = 1;
-      for (auto d : reductionShapes)
-        flattenedReductionDim *= d;
+      for (ShapeDim d : reductionShapes)
+        flattenedReductionDim *= d.minSize();
     }
     // We must at least have some parallel or some reduction.
     assert((flattenedParallelDim || flattenedReductionDim) &&
@@ -223,23 +217,19 @@ class ReduceMultiDimReductionRank
     // 4. Shape cast to collapse consecutive parallel (resp. reduction dim) into
     // a single parallel (resp. reduction) dim.
     SmallVector<bool, 2> mask;
-    SmallVector<bool, 2> scalableDims;
     SmallVector<int64_t, 2> vectorShape;
-    bool isParallelDimScalable = llvm::is_contained(parallelScalableDims, true);
+    bool isParallelDimScalable = ShapedType::isScalableShape(parallelShapes);
     if (flattenedParallelDim) {
       mask.push_back(false);
       vectorShape.push_back(flattenedParallelDim);
-      scalableDims.push_back(isParallelDimScalable);
     }
     if (flattenedReductionDim) {
       mask.push_back(true);
       vectorShape.push_back(flattenedReductionDim);
-      scalableDims.push_back(isReductionDimScalable);
     }
     if (!useInnerDimsForReduction && vectorShape.size() == 2) {
       std::swap(mask.front(), mask.back());
       std::swap(vectorShape.front(), vectorShape.back());
-      std::swap(scalableDims.front(), scalableDims.back());
     }
 
     Value newVectorMask;
@@ -253,17 +243,16 @@ class ReduceMultiDimReductionRank
     }
 
     auto castedType = VectorType::get(
-        vectorShape, multiReductionOp.getSourceVectorType().getElementType(),
-        scalableDims);
+        vectorShape, multiReductionOp.getSourceVectorType().getElementType());
     Value cast = rewriter.create<vector::ShapeCastOp>(
         loc, castedType, multiReductionOp.getSource());
 
     Value acc = multiReductionOp.getAcc();
     if (flattenedParallelDim) {
-      auto accType = VectorType::get(
-          {flattenedParallelDim},
-          multiReductionOp.getSourceVectorType().getElementType(),
-          /*scalableDims=*/{isParallelDimScalable});
+      auto accType = VectorType::create(
+          {isParallelDimScalable ? ShapeDim::scalable(flattenedParallelDim)
+                                 : ShapeDim::fixed(flattenedParallelDim)},
+          multiReductionOp.getSourceVectorType().getElementType());
       acc = rewriter.create<vector::ShapeCastOp>(loc, accType, acc);
     }
     // 6. Creates the flattened form of vector.multi_reduction with inner/outer
@@ -282,8 +271,8 @@ class ReduceMultiDimReductionRank
 
     // 8. Creates shape cast for the output n-D -> 2-D.
     VectorType outputCastedType = VectorType::get(
-        parallelShapes, multiReductionOp.getSourceVectorType().getElementType(),
-        parallelScalableDims);
+        parallelShapes,
+        multiReductionOp.getSourceVectorType().getElementType());
     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
         rootOp, outputCastedType, newMultiDimRedOp->getResult(0));
     return success();
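
To make the flattening in ReduceMultiDimReductionRank concrete: parallel
(resp. reduction) dims are collapsed by multiplying their minSize() values,
and the collapsed dim is re-tagged as scalable when any input dim was
scalable (the earlier getNumScalableDims() > 1 bail-out guarantees at most
one). Below is a standalone sketch under the negative-size encoding; the
helper names are invented for illustration, and dynamic dims are ignored
because vector shapes cannot be dynamic.

  #include <cassert>
  #include <cstdint>
  #include <vector>

  // Negative-size encoding: fixed >= 0, scalable < 0 (e.g. [4] is stored as -4).
  static int64_t minSize(int64_t d) { return d < 0 ? -d : d; }
  static bool isScalableDim(int64_t d) { return d < 0; }

  // Collapse a group of dims into one, preserving scalability.
  static int64_t flattenDims(const std::vector<int64_t> &dims) {
    int64_t flat = 1;
    bool anyScalable = false;
    for (int64_t d : dims) {
      flat *= minSize(d);
      anyScalable |= isScalableDim(d);
    }
    return anyScalable ? -flat : flat;
  }

  int main() {
    assert(flattenDims({-4, 2}) == -8); // [4]x2 collapses to a single [8] dim
    assert(flattenDims({3, 2}) == 6);   // 3x2 collapses to 6
  }
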
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index 68160bcf59e6678..c83cb9dcc7bb58f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -115,12 +115,8 @@ struct TransferReadPermutationLowering
     // Apply the reverse transpose to deduce the type of the transfer_read.
     ArrayRef<int64_t> originalShape = op.getVectorType().getShape();
     SmallVector<int64_t> newVectorShape(originalShape.size());
-    ArrayRef<bool> originalScalableDims = op.getVectorType().getScalableDims();
-    SmallVector<bool> newScalableDims(originalShape.size());
-    for (const auto &pos : llvm::enumerate(permutation)) {
+    for (const auto &pos : llvm::enumerate(permutation))
       newVectorShape[pos.value()] = originalShape[pos.index()];
-      newScalableDims[pos.value()] = originalScalableDims[pos.index()];
-    }
 
     // Transpose in_bounds attribute.
     ArrayAttr newInBoundsAttr =
@@ -129,8 +125,8 @@ struct TransferReadPermutationLowering
                          : ArrayAttr();
 
     // Generate new transfer_read operation.
-    VectorType newReadType = VectorType::get(
-        newVectorShape, op.getVectorType().getElementType(), newScalableDims);
+    VectorType newReadType =
+        VectorType::get(newVectorShape, op.getVectorType().getElementType());
     Value newRead = rewriter.create<vector::TransferReadOp>(
         op.getLoc(), newReadType, op.getSource(), op.getIndices(),
         AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
@@ -350,14 +346,12 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
 
     SmallVector<int64_t> newShape(
         originalVecType.getShape().take_back(reducedShapeRank));
-    SmallVector<bool> newScalableDims(
-        originalVecType.getScalableDims().take_back(reducedShapeRank));
     // Vector rank cannot be zero. Handled by TransferReadToVectorLoadLowering.
     if (newShape.empty())
       return rewriter.notifyMatchFailure(op, "rank-reduced vector is 0-d");
 
-    VectorType newReadType = VectorType::get(
-        newShape, originalVecType.getElementType(), newScalableDims);
+    VectorType newReadType =
+        VectorType::get(newShape, originalVecType.getElementType());
     ArrayAttr newInBoundsAttr =
         op.getInBounds()
             ? rewriter.getArrayAttr(
@@ -495,7 +489,7 @@ struct VectorLoadToMemrefLoadLowering
   LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                 PatternRewriter &rewriter) const override {
     auto vecType = loadOp.getVectorType();
-    if (vecType.getNumElements() != 1)
+    if (vecType.getMinNumElements() != 1)
       return rewriter.notifyMatchFailure(loadOp, "not a single element vector");
 
     auto memrefLoad = rewriter.create<memref::LoadOp>(
@@ -514,7 +508,7 @@ struct VectorStoreToMemrefStoreLowering
   LogicalResult matchAndRewrite(vector::StoreOp storeOp,
                                 PatternRewriter &rewriter) const override {
     auto vecType = storeOp.getVectorType();
-    if (vecType.getNumElements() != 1)
+    if (vecType.getMinNumElements() != 1)
       return rewriter.notifyMatchFailure(storeOp, "not single element vector");
 
     Value extracted;
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 913c826dd912470..d814455e68f8ff3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -23,23 +23,14 @@ using namespace mlir::vector;
 // Returns `vector<1xT>` if `oldType` only has one element.
 static VectorType trimLeadingOneDims(VectorType oldType) {
   ArrayRef<int64_t> oldShape = oldType.getShape();
-  ArrayRef<int64_t> newShape = oldShape;
-
-  ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
-  ArrayRef<bool> newScalableDims = oldScalableDims;
-
-  while (!newShape.empty() && newShape.front() == 1 &&
-         !newScalableDims.front()) {
-    newShape = newShape.drop_front(1);
-    newScalableDims = newScalableDims.drop_front(1);
-  }
+  ArrayRef<int64_t> newShape =
+      oldShape.drop_while([](int64_t dim) { return dim == 1; });
 
   // Make sure we have at least 1 dimension per vector type requirements.
-  if (newShape.empty()) {
+  if (newShape.empty())
     newShape = oldShape.take_back();
-    newScalableDims = oldType.getScalableDims().take_back();
-  }
-  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
+
+  return VectorType::get(newShape, oldType.getElementType());
 }
 
 /// Return a smallVector of size `rank` containing all zeros.
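
The notable property of trimLeadingOneDims above is that a plain `dim == 1`
check now only matches fixed unit dims: under the negative-size encoding a
scalable `[1]` dim is stored as -1 and is therefore kept, which is exactly
what the removed `!newScalableDims.front()` guard used to ensure. A tiny
standalone sketch of that behaviour (a simplified stand-in, not the real
helper):

  #include <cassert>
  #include <cstdint>
  #include <vector>

  // Drop leading fixed unit dims, always keeping at least one dim.
  static std::vector<int64_t> trimLeadingOnes(std::vector<int64_t> shape) {
    while (shape.size() > 1 && shape.front() == 1)
      shape.erase(shape.begin());
    return shape;
  }

  int main() {
    assert((trimLeadingOnes({1, 1, 4}) == std::vector<int64_t>{4}));
    assert((trimLeadingOnes({-1, 4}) == std::vector<int64_t>{-1, 4})); // [1]x4 kept
  }
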
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 207df69929c1c9f..d9cb50e9b24c8bf 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1017,10 +1017,7 @@ struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
         vector::createOrFoldDimOp(rewriter, loc, xferOp.getSource(), lastIndex);
     Value b = rewriter.create<arith::SubIOp>(loc, dim.getType(), dim, off);
     Value mask = rewriter.create<vector::CreateMaskOp>(
-        loc,
-        VectorType::get(vtp.getShape(), rewriter.getI1Type(),
-                        vtp.getScalableDims()),
-        b);
+        loc, VectorType::get(vtp.getShape(), rewriter.getI1Type()), b);
     if (xferOp.getMask()) {
       // Intersect the in-bounds with the mask specified as an op parameter.
       mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.getMask());
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index c662edd592036ce..fb44139047ec1d0 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2341,7 +2341,7 @@ printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os,
     return printEltFn(0);
 
   // Special case for degenerate tensors.
-  auto numElements = type.getNumElements();
+  auto numElements = type.getMinNumElements();
   if (numElements == 0)
     return;
 
@@ -2398,7 +2398,7 @@ void AsmPrinter::Impl::printDenseIntOrFPElementsAttr(
   auto elementType = type.getElementType();
 
   // Check to see if we should format this attribute as a hex string.
-  auto numElements = type.getNumElements();
+  auto numElements = type.getMinNumElements();
   if (!attr.isSplat() && allowHex &&
       shouldPrintElementsAttrWithHex(numElements)) {
     ArrayRef<char> rawData = attr.getRawData();
@@ -2542,16 +2542,12 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
         }
       })
       .Case<VectorType>([&](VectorType vectorTy) {
-        auto scalableDims = vectorTy.getScalableDims();
         os << "vector<";
-        auto vShape = vectorTy.getShape();
-        unsigned lastDim = vShape.size();
-        unsigned dimIdx = 0;
-        for (dimIdx = 0; dimIdx < lastDim; dimIdx++) {
-          if (!scalableDims.empty() && scalableDims[dimIdx])
+        for (ShapeDim dim : vectorTy.getShape()) {
+          if (dim.isScalable())
             os << '[';
-          os << vShape[dimIdx];
-          if (!scalableDims.empty() && scalableDims[dimIdx])
+          os << dim.minSize();
+          if (dim.isScalable())
             os << ']';
           os << 'x';
         }
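
The new printer loop renders a scalable dim in brackets by checking
isScalable() and printing minSize(). A standalone sketch of the same
formatting under the negative-size encoding (an illustrative helper, not the
real AsmPrinter code):

  #include <cstdint>
  #include <iostream>
  #include <string>
  #include <vector>

  static void printVectorType(const std::vector<int64_t> &shape,
                              const std::string &elt) {
    std::cout << "vector<";
    for (int64_t d : shape) {
      bool scalable = d < 0; // negative sizes mark scalable dims
      if (scalable)
        std::cout << '[';
      std::cout << (scalable ? -d : d);
      if (scalable)
        std::cout << ']';
      std::cout << 'x';
    }
    std::cout << elt << ">\n";
  }

  int main() { printVectorType({-4, 2}, "f32"); } // prints: vector<[4]x2xf32>
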
diff --git a/mlir/lib/IR/BuiltinAttributeInterfaces.cpp b/mlir/lib/IR/BuiltinAttributeInterfaces.cpp
index 9b5235a6c5ceb8b..fa93ad6e5756b68 100644
--- a/mlir/lib/IR/BuiltinAttributeInterfaces.cpp
+++ b/mlir/lib/IR/BuiltinAttributeInterfaces.cpp
@@ -29,7 +29,7 @@ Type ElementsAttr::getElementType(ElementsAttr elementsAttr) {
 }
 
 int64_t ElementsAttr::getNumElements(ElementsAttr elementsAttr) {
-  return elementsAttr.getShapedType().getNumElements();
+  return elementsAttr.getShapedType().getMinNumElements();
 }
 
 bool ElementsAttr::isValidIndex(ShapedType type, ArrayRef<uint64_t> index) {
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 5f1129326f4f772..22e130dd492e13c 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -1064,7 +1064,7 @@ bool DenseElementsAttr::isValidRawBuffer(ShapedType type,
                                          bool &detectedSplat) {
   size_t storageWidth = getDenseElementStorageWidth(type.getElementType());
   size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT;
-  int64_t numElements = type.getNumElements();
+  int64_t numElements = type.getMinNumElements();
 
   // The initializer is always a splat if the result type has a single element.
   detectedSplat = numElements == 1;
@@ -1268,7 +1268,7 @@ Type DenseElementsAttr::getElementType() const {
 }
 
 int64_t DenseElementsAttr::getNumElements() const {
-  return getType().getNumElements();
+  return getType().getMinNumElements();
 }
 
 //===----------------------------------------------------------------------===//
@@ -1318,7 +1318,7 @@ DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
 
 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
                                                    ArrayRef<char> data) {
-  assert(type.hasStaticShape() && "type must have static shape");
+  assert(!type.hasDynamicShape() && "type cannot have dynamic shape");
   bool isSplat = false;
   bool isValid = isValidRawBuffer(type, data, isSplat);
   assert(isValid);
diff --git a/mlir/lib/IR/BuiltinTypeInterfaces.cpp b/mlir/lib/IR/BuiltinTypeInterfaces.cpp
index ab9e65b5edfed3f..01fc14e7c1199c0 100644
--- a/mlir/lib/IR/BuiltinTypeInterfaces.cpp
+++ b/mlir/lib/IR/BuiltinTypeInterfaces.cpp
@@ -25,10 +25,10 @@ using namespace mlir::detail;
 
 constexpr int64_t ShapedType::kDynamic;
 
-int64_t ShapedType::getNumElements(ArrayRef<int64_t> shape) {
+int64_t ShapedType::getMinNumElements(ArrayRef<int64_t> shape) {
   int64_t num = 1;
-  for (int64_t dim : shape) {
-    num *= dim;
+  for (ShapeDim dim : shape) {
+    num *= dim.minSize();
     assert(num >= 0 && "integer overflow in element count computation");
   }
   return num;
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index a9284d5714637bc..d847dc02c4472bc 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -226,21 +226,23 @@ LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
 //===----------------------------------------------------------------------===//
 
 LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
-                                 ArrayRef<int64_t> shape, Type elementType,
-                                 ArrayRef<bool> scalableDims) {
+                                 ArrayRef<int64_t> shape, Type elementType) {
   if (!isValidElementType(elementType))
     return emitError()
            << "vector elements must be int/index/float type but got "
            << elementType;
 
-  if (any_of(shape, [](int64_t i) { return i <= 0; }))
-    return emitError()
-           << "vector types must have positive constant sizes but got "
-           << shape;
+  for (auto dim : shape) {
+    if (dim == 0)
+      return emitError() << "vector type must have non-zero sizes";
+    if (dim > std::numeric_limits<std::uint32_t>::max())
+      return emitError() << "vector element count too large (possible "
+                            "misinterpretation of scalable dimension): "
+                         << dim;
+  }
 
-  if (scalableDims.size() != shape.size())
-    return emitError() << "number of dims must match, got "
-                       << scalableDims.size() << " and " << shape.size();
+  if (any_of(shape, [](ShapeDim dim) { return dim == ShapedType::kDynamic; }))
+    return emitError() << "kDynamic is not supported in vector types";
 
   return success();
 }
@@ -250,17 +252,16 @@ VectorType VectorType::scaleElementBitwidth(unsigned scale) {
     return VectorType();
   if (auto et = llvm::dyn_cast<IntegerType>(getElementType()))
     if (auto scaledEt = et.scaleElementBitwidth(scale))
-      return VectorType::get(getShape(), scaledEt, getScalableDims());
+      return VectorType::get(getShape(), scaledEt);
   if (auto et = llvm::dyn_cast<FloatType>(getElementType()))
     if (auto scaledEt = et.scaleElementBitwidth(scale))
-      return VectorType::get(getShape(), scaledEt, getScalableDims());
+      return VectorType::get(getShape(), scaledEt);
   return VectorType();
 }
 
 VectorType VectorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
                                  Type elementType) const {
-  return VectorType::get(shape.value_or(getShape()), elementType,
-                         getScalableDims());
+  return VectorType::get(shape.value_or(getShape()), elementType);
 }
 
 //===----------------------------------------------------------------------===//
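
A standalone sketch of the new VectorType verifier rules above, written as a
plain function under the negative-size encoding rather than against the real
MLIR diagnostics machinery (the structure and names here are illustrative
only):

  #include <cassert>
  #include <cstdint>
  #include <limits>
  #include <string>
  #include <vector>

  // Returns an error message, or the empty string on success.
  static std::string verifyVectorShape(const std::vector<int64_t> &shape) {
    constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();
    for (int64_t dim : shape) {
      if (dim == 0)
        return "vector type must have non-zero sizes";
      if (dim > std::numeric_limits<std::uint32_t>::max())
        return "vector element count too large (possible misinterpretation of "
               "scalable dimension)";
      if (dim == kDynamic)
        return "kDynamic is not supported in vector types";
    }
    return "";
  }

  int main() {
    assert(verifyVectorShape({-4, 2}).empty()); // vector<[4]x2x...> is accepted
    assert(verifyVectorShape({0}) == "vector type must have non-zero sizes");
  }
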
diff --git a/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp b/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp
index 6d8b415ff09dceb..59354d50615067f 100644
--- a/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp
+++ b/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp
@@ -138,7 +138,7 @@ class TypeToLLVMIRTranslatorImpl {
            "expected compatible with LLVM vector type");
     if (type.isScalable())
       return llvm::ScalableVectorType::get(translateType(type.getElementType()),
-                                           type.getNumElements());
+                                           type.getMinNumElements());
     return llvm::FixedVectorType::get(translateType(type.getElementType()),
                                       type.getNumElements());
   }
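
For the TypeToLLVM change above: a scalable MLIR vector lowers to an LLVM
scalable vector whose minimum element count is the vscale == 1 count, e.g.
vector<[4]xf32> becomes <vscale x 4 x float>, while fixed vectors keep the
exact element count. A small string-formatting sketch of that mapping
(illustrative only, not the real translation code):

  #include <cstdint>
  #include <iostream>
  #include <string>

  static std::string toLLVMVectorType(int64_t dim, const std::string &elt) {
    bool scalable = dim < 0;                 // negative-size encoding
    int64_t minElts = scalable ? -dim : dim; // the vscale == 1 element count
    if (scalable)
      return "<vscale x " + std::to_string(minElts) + " x " + elt + ">";
    return "<" + std::to_string(minElts) + " x " + elt + ">";
  }

  int main() {
    std::cout << toLLVMVectorType(-4, "float") << "\n"; // <vscale x 4 x float>
    std::cout << toLLVMVectorType(4, "float") << "\n";  // <4 x float>
  }
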
diff --git a/mlir/test/IR/invalid-builtin-attributes.mlir b/mlir/test/IR/invalid-builtin-attributes.mlir
index 1ff44605cb7ecde..eebff104c5439a7 100644
--- a/mlir/test/IR/invalid-builtin-attributes.mlir
+++ b/mlir/test/IR/invalid-builtin-attributes.mlir
@@ -7,7 +7,7 @@ func.func @elementsattr_non_tensor_type() -> () {
 // -----
 
 func.func @elementsattr_non_ranked() -> () {
-  "foo"(){bar = dense<[4]> : tensor<?xi32>} : () -> () // expected-error {{elements literal type must have static shape}}
+  "foo"(){bar = dense<[4]> : tensor<?xi32>} : () -> () // expected-error {{elements literal type cannot have dynamic dims}}
 }
 
 // -----
diff --git a/mlir/test/IR/invalid-builtin-types.mlir b/mlir/test/IR/invalid-builtin-types.mlir
index 9884212e916c1f1..3c5aae2e05b0206 100644
--- a/mlir/test/IR/invalid-builtin-types.mlir
+++ b/mlir/test/IR/invalid-builtin-types.mlir
@@ -125,12 +125,12 @@ func.func @vectors(vector<1 x vector<1xi32>>, vector<2x4xf32>)
 
 // -----
 
-// expected-error @+1 {{vector types must have positive constant sizes}}
+// expected-error @+1 {{vector type must have non-zero sizes}}
 func.func @zero_vector_type() -> vector<0xi32>
 
 // -----
 
-// expected-error @+1 {{vector types must have positive constant sizes}}
+// expected-error @+1 {{vector type must have non-zero sizes}}
 func.func @zero_in_vector_type() -> vector<1x0xi32>
 
 // -----


