[Mlir-commits] [mlir] 79c83e1 - [mlir][VectorType] Allow arbitrary dimensions to be scalable

Andrzej Warzynski llvmlistbot at llvm.org
Tue Jun 27 11:27:08 PDT 2023


Author: Andrzej Warzynski
Date: 2023-06-27T19:21:59+01:00
New Revision: 79c83e12c8884fa46f2f2594836af93474f6ca5a

URL: https://github.com/llvm/llvm-project/commit/79c83e12c8884fa46f2f2594836af93474f6ca5a
DIFF: https://github.com/llvm/llvm-project/commit/79c83e12c8884fa46f2f2594836af93474f6ca5a.diff

LOG: [mlir][VectorType] Allow arbitrary dimensions to be scalable

At the moment, only the trailing dimensions in the vector type can be
scalable, i.e. this is supported:

    vector<2x[4]xf32>

and this is not allowed:

    vector<[2]x4xf32>

This patch extends the vector type so that arbitrary dimensions can be
scalable. To this end, an array of bool values is added to every vector
type to denote whether the corresponding dimensions are scalable or not.
For example, for this vector:

  vector<[2]x[3]x4xf32>

the following array would be created:

  {true, true, false}.

Additionally, the current syntax:

  vector<[2x3]x4xf32>

is replaced with:

  vector<[2]x[3]x4xf32>

This is primarily to simplify parsing (this way, the parser can easily
process one dimension at a time rather than e.g. tracking whether a
"scalable block" has been entered or left).

NOTE: The `scalableDims` parameter of `VectorType` (introduced in this
patch) makes `numScalableDims` redundant. For the time being,
`numScalableDims` is preserved to facilitate the transition between the
two parameters. `numScalableDims` will be removed in one of the
subsequent patches.

This change is a part of a larger effort to enable scalable
vectorisation in Linalg. See this RFC for more context:
  * https://discourse.llvm.org/t/rfc-scalable-vectorisation-in-linalg/

Differential Revision: https://reviews.llvm.org/D153372

Added: 
    

Modified: 
    mlir/include/mlir/Bytecode/BytecodeImplementation.h
    mlir/include/mlir/IR/BuiltinDialectBytecode.td
    mlir/include/mlir/IR/BuiltinTypes.h
    mlir/include/mlir/IR/BuiltinTypes.td
    mlir/include/mlir/IR/BytecodeBase.td
    mlir/lib/AsmParser/Parser.h
    mlir/lib/AsmParser/TypeParser.cpp
    mlir/lib/Bytecode/Reader/BytecodeReader.cpp
    mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
    mlir/lib/Bytecode/Writer/IRNumbering.cpp
    mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
    mlir/lib/Dialect/Arith/IR/ArithOps.cpp
    mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/lib/IR/AsmPrinter.cpp
    mlir/lib/IR/BuiltinTypes.cpp
    mlir/test/Dialect/Builtin/invalid.mlir
    mlir/test/Dialect/Builtin/ops.mlir
    mlir/test/Dialect/Vector/vector-scalable-outerproduct.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Bytecode/BytecodeImplementation.h b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
index 1cc96e4c764e3..4e74c124adcde 100644
--- a/mlir/include/mlir/Bytecode/BytecodeImplementation.h
+++ b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
@@ -162,6 +162,9 @@ class DialectBytecodeReader {
   /// Read a blob from the bytecode.
   virtual LogicalResult readBlob(ArrayRef<char> &result) = 0;
 
+  /// Read a bool from the bytecode.
+  virtual LogicalResult readBool(bool &result) = 0;
+
 private:
   /// Read a handle to a dialect resource.
   virtual FailureOr<AsmDialectResourceHandle> readResourceHandle() = 0;
@@ -251,6 +254,9 @@ class DialectBytecodeWriter {
   /// written as-is, with no additional compression or compaction.
   virtual void writeOwnedBlob(ArrayRef<char> blob) = 0;
 
+  /// Write a bool to the output stream.
+  virtual void writeOwnedBool(bool value) = 0;
+
   /// Return the bytecode version being emitted for.
   virtual int64_t getBytecodeVersion() const = 0;
 };

diff  --git a/mlir/include/mlir/IR/BuiltinDialectBytecode.td b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
index 47d6c0df55485..40e6f04451c65 100644
--- a/mlir/include/mlir/IR/BuiltinDialectBytecode.td
+++ b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
@@ -279,13 +279,14 @@ def VectorType : DialectType<(type
 }
 
 def VectorTypeWithScalableDims : DialectType<(type
+  Array<BoolList>:$scalableDims,
   VarInt:$numScalableDims,
   Array<SignedVarIntList>:$shape,
   Type:$elementType
 )> {
   let printerPredicate = "$_val.getNumScalableDims()";
   // Note: order of serialization does not match order of builder.
-  let cBuilder = "get<$_resultType>(context, shape, elementType, numScalableDims)";
+  let cBuilder = "get<$_resultType>(context, shape, elementType, numScalableDims, scalableDims)";
 }
 }
 

diff  --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index acb355654ef71..1fd869be76e9b 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -306,17 +306,35 @@ class VectorType::Builder {
   /// Build from another VectorType.
   explicit Builder(VectorType other)
       : shape(other.getShape()), elementType(other.getElementType()),
-        numScalableDims(other.getNumScalableDims()) {}
+        numScalableDims(other.getNumScalableDims()),
+        scalableDims(other.getScalableDims()) {}
 
   /// Build from scratch.
   Builder(ArrayRef<int64_t> shape, Type elementType,
-          unsigned numScalableDims = 0)
+          unsigned numScalableDims = 0, ArrayRef<bool> scalableDims = {})
       : shape(shape), elementType(elementType),
-        numScalableDims(numScalableDims) {}
+        numScalableDims(numScalableDims) {
+    // Keep an owned copy of the flags (defaulting to `false` for every dim)
+    // so that `this->scalableDims` never refers to a destroyed temporary.
+    if (scalableDims.empty())
+      storageScalableDims.assign(shape.size(), false);
+    else
+      storageScalableDims.assign(scalableDims.begin(), scalableDims.end());
+    this->scalableDims = storageScalableDims;
+  }
 
-  Builder &setShape(ArrayRef<int64_t> newShape,
-                    unsigned newNumScalableDims = 0) {
+  Builder &setShape(ArrayRef<int64_t> newShape, unsigned newNumScalableDims = 0,
+                    ArrayRef<bool> newIsScalableDim = {}) {
     numScalableDims = newNumScalableDims;
+    // Size the default flags from `newShape` (not the previous `shape`) and
+    // keep an owned copy, as in the constructor.
+    if (newIsScalableDim.empty())
+      storageScalableDims.assign(newShape.size(), false);
+    else
+      storageScalableDims.assign(newIsScalableDim.begin(),
+                                 newIsScalableDim.end());
+    scalableDims = storageScalableDims;
+
     shape = newShape;
     return *this;
   }
@@ -333,8 +351,17 @@ class VectorType::Builder {
       numScalableDims--;
     if (storage.empty())
       storage.append(shape.begin(), shape.end());
+    if (storageScalableDims.empty())
+      storageScalableDims.append(scalableDims.begin(), scalableDims.end());
     storage.erase(storage.begin() + pos);
+    storageScalableDims.erase(storageScalableDims.begin() + pos);
     shape = {storage.data(), storage.size()};
+    scalableDims =
+        ArrayRef<bool>(storageScalableDims.data(), storageScalableDims.size());
+    // Scalable dims may now appear anywhere, so recount them from the flags
+    // instead of relying on the positional decrement above.
+    numScalableDims =
+        static_cast<unsigned>(llvm::count(storageScalableDims, true));
     return *this;
   }
 
@@ -344,7 +371,7 @@ class VectorType::Builder {
   operator Type() {
     if (shape.empty())
       return elementType;
-    return VectorType::get(shape, elementType, numScalableDims);
+    return VectorType::get(shape, elementType, numScalableDims, scalableDims);
   }
 
 private:
@@ -353,6 +380,9 @@ class VectorType::Builder {
   SmallVector<int64_t> storage;
   Type elementType;
   unsigned numScalableDims;
+  ArrayRef<bool> scalableDims;
+  // Owning scalableDims data for copy-on-write operations.
+  SmallVector<bool> storageScalableDims;
 };
 
 /// Given an `originalShape` and a `reducedShape` assumed to be a subset of

diff  --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 58a0156d54a1f..dead6297f379e 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -1024,8 +1024,9 @@ def Builtin_Vector : Builtin_Type<"Vector", [ShapedTypeInterface], "Type"> {
     ```
     vector-type ::= `vector` `<` vector-dim-list vector-element-type `>`
     vector-element-type ::= float-type | integer-type | index-type
-    vector-dim-list := (static-dim-list `x`)? (`[` static-dim-list `]` `x`)?
-    static-dim-list ::= decimal-literal (`x` decimal-literal)*
+    vector-dim-list := (static-dim-list `x`)?
+    static-dim-list ::= static-dim (`x` static-dim)*
+    static-dim ::= (decimal-literal | `[` decimal-literal `]`)
     ```
 
     The vector type represents a SIMD style vector used by target-specific
@@ -1033,10 +1034,7 @@ def Builtin_Vector : Builtin_Type<"Vector", [ShapedTypeInterface], "Type"> {
     vectors (e.g. vector<16 x f32>) we also support multidimensional registers
     on targets that support them (like TPUs). The dimensions of a vector type
     can be fixed-length, scalable, or a combination of the two. The scalable
-    dimensions in a vector are indicated between square brackets ([ ]), and
-    all fixed-length dimensions, if present, must precede the set of scalable
-    dimensions. That is, a `vector<2x[4]xf32>` is valid, but `vector<[4]x2xf32>`
-    is not.
+    dimensions in a vector are indicated between square brackets ([ ]).
 
     Vector shapes must be positive decimal integers. 0D vectors are allowed by
     omitting the dimension: `vector<f32>`.
@@ -1055,24 +1053,37 @@ def Builtin_Vector : Builtin_Type<"Vector", [ShapedTypeInterface], "Type"> {
     vector<[4]xf32>
 
     // A 2D scalable-length vector that contains a multiple of 2x8 f32 elements.
-    vector<[2x8]xf32>
+    vector<[2]x[8]xf32>
 
     // A 2D mixed fixed/scalable vector that contains 4 scalable vectors of 4 f32 elements.
     vector<4x[4]xf32>
+
+    // A 3D mixed fixed/scalable vector in which only the middle dimension is
+    // scalable.
+    vector<2x[4]x8xf32>
     ```
   }];
   let parameters = (ins
     ArrayRefParameter<"int64_t">:$shape,
     "Type":$elementType,
-    "unsigned":$numScalableDims
+    "unsigned":$numScalableDims,
+    ArrayRefParameter<"bool">:$scalableDims
   );
   let builders = [
     TypeBuilderWithInferredContext<(ins
       "ArrayRef<int64_t>":$shape, "Type":$elementType,
-      CArg<"unsigned", "0">:$numScalableDims
+      CArg<"unsigned", "0">:$numScalableDims,
+      CArg<"ArrayRef<bool>", "{}">:$scalableDims
     ), [{
+      // An empty `scalableDims` means that no dim is scalable; expand it to
+      // one `false` flag per dim in `shape` (the verifier needs equal sizes).
+      SmallVector<bool> isScalableVec;
+      if (scalableDims.empty()) {
+        isScalableVec.resize(shape.size(), false);
+        scalableDims = isScalableVec;
+      }
       return $_get(elementType.getContext(), shape, elementType,
-                   numScalableDims);
+                   numScalableDims, scalableDims);
     }]>
   ];
   let extraClassDeclaration = [{

diff  --git a/mlir/include/mlir/IR/BytecodeBase.td b/mlir/include/mlir/IR/BytecodeBase.td
index 2f9b1c1efcffe..07f1e284156c3 100644
--- a/mlir/include/mlir/IR/BytecodeBase.td
+++ b/mlir/include/mlir/IR/BytecodeBase.td
@@ -92,6 +92,11 @@ def Blob :
   WithBuilder<"$_args",
   WithPrinter<"$_writer.writeOwnedBlob($_getter)",
   WithType   <"ArrayRef<char>">>>>;
+def Bool :
+  WithParser <"succeeded($_reader.readBool($_var))",
+  WithBuilder<"$_args",
+  WithPrinter<"$_writer.writeOwnedBool($_getter)",
+  WithType   <"bool">>>>;
 class KnownWidthAPInt<string s> :
   WithParser <"succeeded(readAPIntWithKnownWidth($_reader, " # s # ", $_var))",
   WithBuilder<"$_args",
@@ -125,6 +130,7 @@ class Array<Bytecode t> {
 //   for the list print/parsing.
 class List<Bytecode t> : WithGetter<"$_member", t>;
 def SignedVarIntList : List<SignedVarInt>;
+def BoolList : List<Bool>;
 
 // Define dialect attribute or type.
 class DialectAttrOrType<dag d> {

diff  --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h
index 749b82c2ed4c6..655412da2b742 100644
--- a/mlir/lib/AsmParser/Parser.h
+++ b/mlir/lib/AsmParser/Parser.h
@@ -211,7 +211,8 @@ class Parser {
   /// Parse a vector type.
   VectorType parseVectorType();
   ParseResult parseVectorDimensionList(SmallVectorImpl<int64_t> &dimensions,
-                                       unsigned &numScalableDims);
+                                       unsigned &numScalableDims,
+                                       SmallVectorImpl<bool> &scalableDims);
   ParseResult parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions,
                                        bool allowDynamic = true,
                                        bool withTrailingX = true);

diff  --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 211049204b268..6eeea41d97c42 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -440,8 +440,9 @@ VectorType Parser::parseVectorType() {
     return nullptr;
 
   SmallVector<int64_t, 4> dimensions;
+  SmallVector<bool, 4> scalableDims;
   unsigned numScalableDims;
-  if (parseVectorDimensionList(dimensions, numScalableDims))
+  if (parseVectorDimensionList(dimensions, numScalableDims, scalableDims))
     return nullptr;
   if (any_of(dimensions, [](int64_t i) { return i <= 0; }))
     return emitError(getToken().getLoc(),
@@ -458,51 +459,43 @@ VectorType Parser::parseVectorType() {
     return emitError(typeLoc, "vector elements must be int/index/float type"),
            nullptr;
 
-  return VectorType::get(dimensions, elementType, numScalableDims);
+  return VectorType::get(dimensions, elementType, numScalableDims,
+                         scalableDims);
 }
 
-/// Parse a dimension list in a vector type. This populates the dimension list,
-/// and returns the number of scalable dimensions in `numScalableDims`.
+/// Parse a dimension list in a vector type. This populates the dimension list.
+/// For i-th dimension, `scalableDims[i]` contains either:
+///   * `false` for a non-scalable dimension (e.g. `4`),
+///   * `true` for a scalable dimension (e.g. `[4]`).
+/// This method also returns the number of scalable dimensions in
+/// `numScalableDims`.
 ///
-/// vector-dim-list := (static-dim-list `x`)? (`[` static-dim-list `]` `x`)?
-/// static-dim-list ::= decimal-literal (`x` decimal-literal)*
+/// vector-dim-list := (static-dim-list `x`)?
+/// static-dim-list ::= static-dim (`x` static-dim)*
+/// static-dim ::= (decimal-literal | `[` decimal-literal `]`)
 ///
 ParseResult
 Parser::parseVectorDimensionList(SmallVectorImpl<int64_t> &dimensions,
-                                 unsigned &numScalableDims) {
+                                 unsigned &numScalableDims,
+                                 SmallVectorImpl<bool> &scalableDims) {
   numScalableDims = 0;
   // If there is a set of fixed-length dimensions, consume it
-  while (getToken().is(Token::integer)) {
+  while (getToken().is(Token::integer) || getToken().is(Token::l_square)) {
     int64_t value;
+    bool scalable = consumeIf(Token::l_square);
     if (parseIntegerInDimensionList(value))
       return failure();
     dimensions.push_back(value);
+    if (scalable) {
+      if (!consumeIf(Token::r_square))
+        return emitWrongTokenError("missing ']' closing scalable dimension");
+      numScalableDims++;
+    }
+    scalableDims.push_back(scalable);
     // Make sure we have an 'x' or something like 'xbf32'.
     if (parseXInDimensionList())
       return failure();
   }
-  // If there is a set of scalable dimensions, consume it
-  if (consumeIf(Token::l_square)) {
-    while (getToken().is(Token::integer)) {
-      int64_t value;
-      if (parseIntegerInDimensionList(value))
-        return failure();
-      dimensions.push_back(value);
-      numScalableDims++;
-      // Check if we have reached the end of the scalable dimension list
-      if (consumeIf(Token::r_square)) {
-        // Make sure we have something like 'xbf32'.
-        return parseXInDimensionList();
-      }
-      // Make sure we have an 'x'
-      if (parseXInDimensionList())
-        return failure();
-    }
-    // If we make it here, we've finished parsing the dimension list
-    // without finding ']' closing the set of scalable dimensions
-    return emitWrongTokenError(
-        "missing ']' closing set of scalable dimensions");
-  }
 
   return success();
 }

diff  --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
index 5b313af8cee33..8269546d98365 100644
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -994,6 +994,10 @@ class DialectReader : public DialectBytecodeReader {
     return success();
   }
 
+  LogicalResult readBool(bool &result) override {
+    return reader.parseByte(result);
+  }
+
 private:
   AttrTypeReader &attrTypeReader;
   StringSectionReader &stringReader;

diff  --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
index 936117aa2b8fc..f02389927019c 100644
--- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
+++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -396,6 +396,8 @@ class DialectWriter : public DialectBytecodeWriter {
         reinterpret_cast<const uint8_t *>(blob.data()), blob.size()));
   }
 
+  void writeOwnedBool(bool value) override { emitter.emitByte(value); }
+
   int64_t getBytecodeVersion() const override { return bytecodeVersion; }
 
 private:

diff  --git a/mlir/lib/Bytecode/Writer/IRNumbering.cpp b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
index 36f7a268a6a13..2547d815c12b1 100644
--- a/mlir/lib/Bytecode/Writer/IRNumbering.cpp
+++ b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
@@ -45,6 +45,7 @@ struct IRNumberingState::NumberingDialectWriter : public DialectBytecodeWriter {
     // file locations.
   }
   void writeOwnedBlob(ArrayRef<char> blob) override {}
+  void writeOwnedBool(bool value) override {}
 
   int64_t getBytecodeVersion() const override {
     llvm_unreachable("unexpected querying of version in IRNumbering");

diff  --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 21ef3076a5a47..0449ba99c0817 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -463,8 +463,16 @@ Type LLVMTypeConverter::convertVectorType(VectorType type) {
     return {};
   if (type.getShape().empty())
     return VectorType::get({1}, elementType);
-  Type vectorType = VectorType::get(type.getShape().back(), elementType,
-                                    type.getNumScalableDims());
+  // Only the trailing dim can be lowered as scalable: there is no LLVM
+  // "scalable array" that could hold a non-trailing scalable dim.
+  if (llvm::is_contained(type.getScalableDims().drop_back(), true))
+    return {};
+  // The converted vector is 1-D, so its scalable-dim count must come from the
+  // trailing flag alone (not from `type.getNumScalableDims()`).
+  bool isLastDimScalable = type.getScalableDims().back();
+  Type vectorType = VectorType::get(type.getShape().back(), elementType,
+                                    static_cast<unsigned>(isLastDimScalable),
+                                    isLastDimScalable);
   assert(LLVM::isCompatibleVectorType(vectorType) &&
          "expected vector type compatible with the LLVM dialect");
   auto shape = type.getShape();

diff  --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index e0dd2d6fbc03b..633f296b0702d 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -123,7 +123,8 @@ static Type getI1SameShape(Type type) {
     return UnrankedTensorType::get(i1Type);
   if (auto vectorType = llvm::dyn_cast<VectorType>(type))
     return VectorType::get(vectorType.getShape(), i1Type,
-                           vectorType.getNumScalableDims());
+                           vectorType.getNumScalableDims(),
+                           vectorType.getScalableDims());
   return i1Type;
 }
 

diff  --git a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
index 765242b5416af..cdbf45bdf7f30 100644
--- a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
+++ b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
@@ -30,7 +30,8 @@ static Type getI1SameShape(Type type) {
   auto i1Type = IntegerType::get(type.getContext(), 1);
   if (auto sVectorType = llvm::dyn_cast<VectorType>(type))
     return VectorType::get(sVectorType.getShape(), i1Type,
-                           sVectorType.getNumScalableDims());
+                           sVectorType.getNumScalableDims(),
+                           sVectorType.getScalableDims());
   return nullptr;
 }
 

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index 95d76a14d2bd3..1039bd23b3c15 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -992,7 +992,13 @@ Type mlir::LLVM::getVectorType(Type elementType, unsigned numElements,
       return LLVMScalableVectorType::get(elementType, numElements);
     return LLVMFixedVectorType::get(elementType, numElements);
   }
-  return VectorType::get(numElements, elementType, (unsigned)isScalable);
+
+  // LLVM vectors are always 1-D, hence only 1 bool is required to mark it as
+  // scalable/non-scalable.
+  SmallVector<bool> scalableDims(1, isScalable);
+
+  return VectorType::get(numElements, elementType,
+                         static_cast<unsigned>(isScalable), scalableDims);
 }
 
 Type mlir::LLVM::getVectorType(Type elementType,

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 45b35d3e01a6a..d0fcaada603d2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -225,7 +225,8 @@ struct VectorizationState {
 
     // TODO: Extend scalable vector type to support a bit map.
     bool numScalableDims = !scalableVecDims.empty() && scalableVecDims.back();
-    return VectorType::get(vectorShape, elementType, numScalableDims);
+    return VectorType::get(vectorShape, elementType, numScalableDims,
+                           scalableDims);
   }
 
   /// Masks an operation with the canonical vector mask if the operation needs
@@ -1227,7 +1228,8 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
     if (firstMaxRankedType) {
       auto vecType = VectorType::get(firstMaxRankedType.getShape(),
                                      getElementTypeOrSelf(vecOperand.getType()),
-                                     firstMaxRankedType.getNumScalableDims());
+                                     firstMaxRankedType.getNumScalableDims(),
+                                     firstMaxRankedType.getScalableDims());
       vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
     } else {
       vecOperands.push_back(vecOperand);
@@ -1239,7 +1241,8 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
     resultTypes.push_back(
         firstMaxRankedType
             ? VectorType::get(firstMaxRankedType.getShape(), resultType,
-                              firstMaxRankedType.getNumScalableDims())
+                              firstMaxRankedType.getNumScalableDims(),
+                              firstMaxRankedType.getScalableDims())
             : resultType);
   }
   //   d. Build and return the new op.

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
index caef60eb1ab7b..77bd330ee2eef 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
@@ -57,7 +57,8 @@ static bool isInvariantArg(BlockArgument arg, Block *block) {
 /// Constructs vector type for element type.
 static VectorType vectorType(VL vl, Type etp) {
   unsigned numScalableDims = vl.enableVLAVectorization;
-  return VectorType::get(vl.vectorLength, etp, numScalableDims);
+  return VectorType::get(vl.vectorLength, etp, numScalableDims,
+                         vl.enableVLAVectorization);
 }
 
 /// Constructs vector type from a memref value.

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a3220ef85b6f9..7dd05f519bdea 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -345,9 +345,9 @@ LogicalResult MultiDimReductionOp::verify() {
 /// Returns the mask type expected by this operation.
 Type MultiDimReductionOp::getExpectedMaskType() {
   auto vecType = getSourceVectorType();
-  return VectorType::get(vecType.getShape(),
-                         IntegerType::get(vecType.getContext(), /*width=*/1),
-                         vecType.getNumScalableDims());
+  return VectorType::get(
+      vecType.getShape(), IntegerType::get(vecType.getContext(), /*width=*/1),
+      vecType.getNumScalableDims(), vecType.getScalableDims());
 }
 
 namespace {
@@ -484,9 +484,9 @@ void ReductionOp::print(OpAsmPrinter &p) {
 /// Returns the mask type expected by this operation.
 Type ReductionOp::getExpectedMaskType() {
   auto vecType = getSourceVectorType();
-  return VectorType::get(vecType.getShape(),
-                         IntegerType::get(vecType.getContext(), /*width=*/1),
-                         vecType.getNumScalableDims());
+  return VectorType::get(
+      vecType.getShape(), IntegerType::get(vecType.getContext(), /*width=*/1),
+      vecType.getNumScalableDims(), vecType.getScalableDims());
 }
 
 Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
@@ -2788,16 +2788,22 @@ ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
     return parser.emitError(parser.getNameLoc(),
                             "expected vector type for operand #1");
 
-  unsigned numScalableDims = vLHS.getNumScalableDims();
   VectorType resType;
   if (vRHS) {
-    numScalableDims += vRHS.getNumScalableDims();
+    SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
+                                      vRHS.getScalableDims()[0]};
+    auto numScalableDims =
+        count_if(scalableDimsRes, [](bool isScalable) { return isScalable; });
     resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
-                              vLHS.getElementType(), numScalableDims);
+                              vLHS.getElementType(), numScalableDims,
+                              scalableDimsRes);
   } else {
     // Scalar RHS operand
+    SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
+    auto numScalableDims =
+        count_if(scalableDimsRes, [](bool isScalable) { return isScalable; });
     resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
-                              numScalableDims);
+                              numScalableDims, scalableDimsRes);
   }
 
   if (!result.attributes.get(OuterProductOp::getKindAttrStrName())) {
@@ -2861,9 +2867,9 @@ LogicalResult OuterProductOp::verify() {
 /// verification purposes. It requires the operation to be vectorized."
 Type OuterProductOp::getExpectedMaskType() {
   auto vecType = this->getResultVectorType();
-  return VectorType::get(vecType.getShape(),
-                         IntegerType::get(vecType.getContext(), /*width=*/1),
-                         vecType.getNumScalableDims());
+  return VectorType::get(
+      vecType.getShape(), IntegerType::get(vecType.getContext(), /*width=*/1),
+      vecType.getNumScalableDims(), vecType.getScalableDims());
 }
 
 //===----------------------------------------------------------------------===//
@@ -3516,12 +3522,14 @@ static VectorType inferTransferOpMaskType(VectorType vecType,
                                           AffineMap permMap) {
   auto i1Type = IntegerType::get(permMap.getContext(), 1);
   AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
-  // TODO: Extend the scalable vector type representation with a bit map.
-  assert((permMap.isMinorIdentity() || vecType.getNumScalableDims() == 0) &&
-         "Scalable vectors are not supported yet");
   assert(invPermMap && "Inversed permutation map couldn't be computed");
   SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());
-  return VectorType::get(maskShape, i1Type, vecType.getNumScalableDims());
+
+  SmallVector<bool> scalableDims =
+      applyPermutationMap(invPermMap, vecType.getScalableDims());
+  // Recount: the permuted mask may not keep every (scalable) dim of `vecType`.
+  unsigned numScalableDims = llvm::count(scalableDims, true);
+  return VectorType::get(maskShape, i1Type, numScalableDims, scalableDims);
 }
 
 ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -4479,9 +4487,9 @@ LogicalResult GatherOp::verify() {
 /// verification purposes. It requires the operation to be vectorized."
 Type GatherOp::getExpectedMaskType() {
   auto vecType = this->getIndexVectorType();
-  return VectorType::get(vecType.getShape(),
-                         IntegerType::get(vecType.getContext(), /*width=*/1),
-                         vecType.getNumScalableDims());
+  return VectorType::get(
+      vecType.getShape(), IntegerType::get(vecType.getContext(), /*width=*/1),
+      vecType.getNumScalableDims(), vecType.getScalableDims());
 }
 
 std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {

diff  --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index ef0bf75a9cd67..c0975a6be4be9 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2458,19 +2458,18 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
         }
       })
       .Case<VectorType>([&](VectorType vectorTy) {
+        auto scalableDims = vectorTy.getScalableDims();
         os << "vector<";
         auto vShape = vectorTy.getShape();
         unsigned lastDim = vShape.size();
-        unsigned lastFixedDim = lastDim - vectorTy.getNumScalableDims();
         unsigned dimIdx = 0;
-        for (dimIdx = 0; dimIdx < lastFixedDim; dimIdx++)
-          os << vShape[dimIdx] << 'x';
-        if (vectorTy.isScalable()) {
-          os << '[';
-          unsigned secondToLastDim = lastDim - 1;
-          for (; dimIdx < secondToLastDim; dimIdx++)
-            os << vShape[dimIdx] << 'x';
-          os << vShape[dimIdx] << "]x";
+        for (dimIdx = 0; dimIdx < lastDim; dimIdx++) {
+          if (!scalableDims.empty() && scalableDims[dimIdx])
+            os << '[';
+          os << vShape[dimIdx];
+          if (!scalableDims.empty() && scalableDims[dimIdx])
+            os << ']';
+          os << 'x';
         }
         printType(vectorTy.getElementType());
         os << '>';

diff  --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index eea07edfdab3c..62ef2c63444b9 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -227,7 +227,8 @@ LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
 
 LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
                                  ArrayRef<int64_t> shape, Type elementType,
-                                 unsigned numScalableDims) {
+                                 unsigned numScalableDims,
+                                 ArrayRef<bool> scalableDims) {
   if (!isValidElementType(elementType))
     return emitError()
            << "vector elements must be int/index/float type but got "
@@ -238,6 +239,20 @@ LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
            << "vector types must have positive constant sizes but got "
            << shape;
 
+  if (numScalableDims > shape.size())
+    return emitError()
+           << "number of scalable dims cannot exceed the number of dims"
+           << " (" << numScalableDims << " vs " << shape.size() << ")";
+
+  if (scalableDims.size() != shape.size())
+    return emitError() << "number of dims must match, got "
+                       << scalableDims.size() << " and " << shape.size();
+
+  unsigned numScale = llvm::count(scalableDims, true);
+  if (numScale != numScalableDims)
+    return emitError() << "number of scalable dims must match, explicit: "
+                       << numScalableDims << ", and bools: " << numScale;
+
   return success();
 }
 

diff  --git a/mlir/test/Dialect/Builtin/invalid.mlir b/mlir/test/Dialect/Builtin/invalid.mlir
index 79c8b8337af9d..74fffff86a584 100644
--- a/mlir/test/Dialect/Builtin/invalid.mlir
+++ b/mlir/test/Dialect/Builtin/invalid.mlir
@@ -13,7 +13,10 @@
 // VectorType
 //===----------------------------------------------------------------------===//
 
-// expected-error at +1 {{missing ']' closing set of scalable dimensions}}
+// expected-error at +1 {{missing ']' closing scalable dimension}}
 func.func @scalable_vector_arg(%arg0: vector<[4xf32>) { }
 
 // -----
+
+// expected-error at +1 {{missing ']' closing scalable dimension}}
+func.func @scalable_vector_arg(%arg0: vector<[4x4]xf32>) { }

diff  --git a/mlir/test/Dialect/Builtin/ops.mlir b/mlir/test/Dialect/Builtin/ops.mlir
index 5e0ea413ab62d..4e8b1f7efde23 100644
--- a/mlir/test/Dialect/Builtin/ops.mlir
+++ b/mlir/test/Dialect/Builtin/ops.mlir
@@ -27,10 +27,10 @@
 %scalable_vector_1d = "foo.op"() : () -> vector<[4]xi32>
 
 // A 2D scalable vector
-%scalable_vector_2d = "foo.op"() : () -> vector<[2x2]xf64>
+%scalable_vector_2d = "foo.op"() : () -> vector<[2]x[2]xf64>
 
 // A 2D scalable vector with fixed-length dimensions
 %scalable_vector_2d_mixed = "foo.op"() : () -> vector<2x[4]xbf16>
 
 // A multi-dimensional vector with mixed scalable and fixed-length dimensions
-%scalable_vector_multi_mixed = "foo.op"() : () -> vector<2x2x[4x4]xi8>
+%scalable_vector_multi_mixed = "foo.op"() : () -> vector<2x2x[4]x[4]xi8>

diff  --git a/mlir/test/Dialect/Vector/vector-scalable-outerproduct.mlir b/mlir/test/Dialect/Vector/vector-scalable-outerproduct.mlir
index 8e010d2183fd7..7d9923e036660 100644
--- a/mlir/test/Dialect/Vector/vector-scalable-outerproduct.mlir
+++ b/mlir/test/Dialect/Vector/vector-scalable-outerproduct.mlir
@@ -7,7 +7,7 @@ func.func @scalable_outerproduct(%src : memref<?xf32>) {
   %1 = vector.load %src[%idx] : memref<?xf32>, vector<[4]xf32>
 
   %op = vector.outerproduct %0, %1 : vector<[4]xf32>, vector<[4]xf32>
-  vector.store %op, %src[%idx] : memref<?xf32>, vector<[4x4]xf32>
+  vector.store %op, %src[%idx] : memref<?xf32>, vector<[4]x[4]xf32>
 
   %op2 = vector.outerproduct %0, %cst : vector<[4]xf32>, f32
   vector.store %op2, %src[%idx] : memref<?xf32>, vector<[4]xf32>
@@ -28,9 +28,9 @@ func.func @invalid_outerproduct(%src : memref<?xf32>) {
 
 func.func @invalid_outerproduct1(%src : memref<?xf32>) {
   %idx = arith.constant 0 : index
-  %0 = vector.load %src[%idx] : memref<?xf32>, vector<[4x4]xf32>
+  %0 = vector.load %src[%idx] : memref<?xf32>, vector<[4]x[4]xf32>
   %1 = vector.load %src[%idx] : memref<?xf32>, vector<[4]xf32>
 
-  // expected-error @+1 {{expected 1-d vector for operand #1}}
-  %op = vector.outerproduct %0, %1 : vector<[4x4]xf32>, vector<[4]xf32>
+  // expected-error @+1 {{'vector.outerproduct' op expected 1-d vector for operand #1}}
+  %op = vector.outerproduct %0, %1 : vector<[4]x[4]xf32>, vector<[4]xf32>
 }


        


More information about the Mlir-commits mailing list