[Mlir-commits] [mlir] f22af20 - [mlir][VectorType] Remove `numScalableDims` from the vector type
Andrzej Warzynski
llvmlistbot at llvm.org
Wed Jun 28 05:53:51 PDT 2023
Author: Andrzej Warzynski
Date: 2023-06-28T13:53:45+01:00
New Revision: f22af204edfd1a8f16511b2635ed41c00fc5502f
URL: https://github.com/llvm/llvm-project/commit/f22af204edfd1a8f16511b2635ed41c00fc5502f
DIFF: https://github.com/llvm/llvm-project/commit/f22af204edfd1a8f16511b2635ed41c00fc5502f.diff
LOG: [mlir][VectorType] Remove `numScalableDims` from the vector type
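This is a follow-up to https://reviews.llvm.org/D153372, in which
`numScalableDims` (a single integer) was effectively replaced with the
per-dimension `scalableDims` bitmask. Since the number of scalable
dimensions can be derived from that bitmask, the separate integer is
redundant, and this patch removes it.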
This change is part of a larger effort to enable scalable
vectorisation in Linalg. See this RFC for more context:
* https://discourse.llvm.org/t/rfc-scalable-vectorisation-in-linalg/
Differential Revision: https://reviews.llvm.org/D153412
Added:
Modified:
mlir/include/mlir/IR/BuiltinDialectBytecode.td
mlir/include/mlir/IR/BuiltinTypes.h
mlir/include/mlir/IR/BuiltinTypes.td
mlir/lib/AsmParser/Parser.h
mlir/lib/AsmParser/TypeParser.cpp
mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Dialect/Arith/IR/ArithOps.cpp
mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
mlir/lib/IR/BuiltinTypes.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/BuiltinDialectBytecode.td b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
index 40e6f04451c651..fcbb5f45acbc51 100644
--- a/mlir/include/mlir/IR/BuiltinDialectBytecode.td
+++ b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
@@ -275,18 +275,17 @@ def VectorType : DialectType<(type
Array<SignedVarIntList>:$shape,
Type:$elementType
)> {
- let printerPredicate = "!$_val.getNumScalableDims()";
+ let printerPredicate = "!$_val.isScalable()";
}
def VectorTypeWithScalableDims : DialectType<(type
Array<BoolList>:$scalableDims,
- VarInt:$numScalableDims,
Array<SignedVarIntList>:$shape,
Type:$elementType
)> {
- let printerPredicate = "$_val.getNumScalableDims()";
+ let printerPredicate = "$_val.isScalable()";
// Note: order of serialization does not match order of builder.
- let cBuilder = "get<$_resultType>(context, shape, elementType, numScalableDims, scalableDims)";
+ let cBuilder = "get<$_resultType>(context, shape, elementType, scalableDims)";
}
}
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 1fd869be76e9b1..f22421aa7a428d 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -306,23 +306,20 @@ class VectorType::Builder {
/// Build from another VectorType.
explicit Builder(VectorType other)
: shape(other.getShape()), elementType(other.getElementType()),
- numScalableDims(other.getNumScalableDims()),
scalableDims(other.getScalableDims()) {}
/// Build from scratch.
Builder(ArrayRef<int64_t> shape, Type elementType,
unsigned numScalableDims = 0, ArrayRef<bool> scalableDims = {})
- : shape(shape), elementType(elementType),
- numScalableDims(numScalableDims) {
+ : shape(shape), elementType(elementType) {
if (scalableDims.empty())
scalableDims = SmallVector<bool>(shape.size(), false);
else
this->scalableDims = scalableDims;
}
- Builder &setShape(ArrayRef<int64_t> newShape, unsigned newNumScalableDims = 0,
+ Builder &setShape(ArrayRef<int64_t> newShape,
ArrayRef<bool> newIsScalableDim = {}) {
- numScalableDims = newNumScalableDims;
if (newIsScalableDim.empty())
scalableDims = SmallVector<bool>(shape.size(), false);
else
@@ -340,8 +337,6 @@ class VectorType::Builder {
/// Erase a dim from shape @pos.
Builder &dropDim(unsigned pos) {
assert(pos < shape.size() && "overflow");
- if (pos >= shape.size() - numScalableDims)
- numScalableDims--;
if (storage.empty())
storage.append(shape.begin(), shape.end());
if (storageScalableDims.empty())
@@ -360,7 +355,7 @@ class VectorType::Builder {
operator Type() {
if (shape.empty())
return elementType;
- return VectorType::get(shape, elementType, numScalableDims, scalableDims);
+ return VectorType::get(shape, elementType, scalableDims);
}
private:
@@ -368,7 +363,6 @@ class VectorType::Builder {
// Owning shape data for copy-on-write operations.
SmallVector<int64_t> storage;
Type elementType;
- unsigned numScalableDims;
ArrayRef<bool> scalableDims;
// Owning scalableDims data for copy-on-write operations.
SmallVector<bool> storageScalableDims;
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index dead6297f379e6..900531b1953c4b 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -1066,13 +1066,11 @@ def Builtin_Vector : Builtin_Type<"Vector", [ShapedTypeInterface], "Type"> {
let parameters = (ins
ArrayRefParameter<"int64_t">:$shape,
"Type":$elementType,
- "unsigned":$numScalableDims,
ArrayRefParameter<"bool">:$scalableDims
);
let builders = [
TypeBuilderWithInferredContext<(ins
"ArrayRef<int64_t>":$shape, "Type":$elementType,
- CArg<"unsigned", "0">:$numScalableDims,
CArg<"ArrayRef<bool>", "{}">:$scalableDims
), [{
// While `scalableDims` is optional, its default value should be
@@ -1082,8 +1080,7 @@ def Builtin_Vector : Builtin_Type<"Vector", [ShapedTypeInterface], "Type"> {
isScalableVec.resize(shape.size(), false);
scalableDims = isScalableVec;
}
- return $_get(elementType.getContext(), shape, elementType,
- numScalableDims, scalableDims);
+ return $_get(elementType.getContext(), shape, elementType, scalableDims);
}]>
];
let extraClassDeclaration = [{
@@ -1100,7 +1097,13 @@ def Builtin_Vector : Builtin_Type<"Vector", [ShapedTypeInterface], "Type"> {
/// Returns true if the vector contains scalable dimensions.
bool isScalable() const {
- return getNumScalableDims() > 0;
+ return llvm::is_contained(getScalableDims(), true);
+ }
+ bool allDimsScalable() const {
+ // Treat 0-d vectors as fixed size.
+ if (getRank() == 0)
+ return false;
+ return !llvm::is_contained(getScalableDims(), false);
}
/// Get or create a new VectorType with the same shape as `this` and an
diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h
index 655412da2b742f..9704cea0e2b55c 100644
--- a/mlir/lib/AsmParser/Parser.h
+++ b/mlir/lib/AsmParser/Parser.h
@@ -211,7 +211,6 @@ class Parser {
/// Parse a vector type.
VectorType parseVectorType();
ParseResult parseVectorDimensionList(SmallVectorImpl<int64_t> &dimensions,
- unsigned &numScalableDims,
SmallVectorImpl<bool> &scalableDims);
ParseResult parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions,
bool allowDynamic = true,
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 6eeea41d97c42c..6a65dda505a1c1 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -441,8 +441,7 @@ VectorType Parser::parseVectorType() {
SmallVector<int64_t, 4> dimensions;
SmallVector<bool, 4> scalableDims;
- unsigned numScalableDims;
- if (parseVectorDimensionList(dimensions, numScalableDims, scalableDims))
+ if (parseVectorDimensionList(dimensions, scalableDims))
return nullptr;
if (any_of(dimensions, [](int64_t i) { return i <= 0; }))
return emitError(getToken().getLoc(),
@@ -459,16 +458,13 @@ VectorType Parser::parseVectorType() {
return emitError(typeLoc, "vector elements must be int/index/float type"),
nullptr;
- return VectorType::get(dimensions, elementType, numScalableDims,
- scalableDims);
+ return VectorType::get(dimensions, elementType, scalableDims);
}
/// Parse a dimension list in a vector type. This populates the dimension list.
/// For i-th dimension, `scalableDims[i]` contains either:
/// * `false` for a non-scalable dimension (e.g. `4`),
/// * `true` for a scalable dimension (e.g. `[4]`).
-/// This method also returns the number of scalable dimensions in
-/// `numScalableDims`.
///
/// vector-dim-list := (static-dim-list `x`)?
/// static-dim-list ::= static-dim (`x` static-dim)*
@@ -476,9 +472,7 @@ VectorType Parser::parseVectorType() {
///
ParseResult
Parser::parseVectorDimensionList(SmallVectorImpl<int64_t> &dimensions,
- unsigned &numScalableDims,
SmallVectorImpl<bool> &scalableDims) {
- numScalableDims = 0;
// If there is a set of fixed-length dimensions, consume it
while (getToken().is(Token::integer) || getToken().is(Token::l_square)) {
int64_t value;
@@ -489,7 +483,6 @@ Parser::parseVectorDimensionList(SmallVectorImpl<int64_t> &dimensions,
if (scalable) {
if (!consumeIf(Token::r_square))
return emitWrongTokenError("missing ']' closing scalable dimension");
- numScalableDims++;
}
scalableDims.push_back(scalable);
// Make sure we have an 'x' or something like 'xbf32'.
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 0449ba99c08178..4ca5c7510d9eda 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -463,11 +463,12 @@ Type LLVMTypeConverter::convertVectorType(VectorType type) {
return {};
if (type.getShape().empty())
return VectorType::get({1}, elementType);
- Type vectorType =
- VectorType::get(type.getShape().back(), elementType,
- type.getNumScalableDims(), type.getScalableDims().back());
+ Type vectorType = VectorType::get(type.getShape().back(), elementType,
+ type.getScalableDims().back());
assert(LLVM::isCompatibleVectorType(vectorType) &&
"expected vector type compatible with the LLVM dialect");
+ assert((type.isScalable() == type.allDimsScalable()) &&
+ "expected scalable vector with all dims scalable");
auto shape = type.getShape();
for (int i = shape.size() - 2; i >= 0; --i)
vectorType = LLVM::LLVMArrayType::get(vectorType, shape[i]);
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 4175f8fcafe72e..990138549abf4c 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -31,21 +31,15 @@ using namespace mlir::vector;
// Helper to reduce vector type by one rank at front.
static VectorType reducedVectorTypeFront(VectorType tp) {
assert((tp.getRank() > 1) && "unlowerable vector type");
- unsigned numScalableDims = tp.getNumScalableDims();
- if (tp.getShape().size() == numScalableDims)
- --numScalableDims;
return VectorType::get(tp.getShape().drop_front(), tp.getElementType(),
- numScalableDims);
+ tp.getScalableDims().drop_front());
}
// Helper to reduce vector type by *all* but one rank at back.
static VectorType reducedVectorTypeBack(VectorType tp) {
assert((tp.getRank() > 1) && "unlowerable vector type");
- unsigned numScalableDims = tp.getNumScalableDims();
- if (numScalableDims > 0)
- --numScalableDims;
return VectorType::get(tp.getShape().take_back(), tp.getElementType(),
- numScalableDims);
+ tp.getScalableDims().take_back());
}
// Helper that picks the proper sequence for inserting.
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 633f296b0702d4..5e95d16b87a635 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -123,7 +123,6 @@ static Type getI1SameShape(Type type) {
return UnrankedTensorType::get(i1Type);
if (auto vectorType = llvm::dyn_cast<VectorType>(type))
return VectorType::get(vectorType.getShape(), i1Type,
- vectorType.getNumScalableDims(),
vectorType.getScalableDims());
return i1Type;
}
diff --git a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
index cdbf45bdf7f304..4af836a93c2a16 100644
--- a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
+++ b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
@@ -30,7 +30,6 @@ static Type getI1SameShape(Type type) {
auto i1Type = IntegerType::get(type.getContext(), 1);
if (auto sVectorType = llvm::dyn_cast<VectorType>(type))
return VectorType::get(sVectorType.getShape(), i1Type,
- sVectorType.getNumScalableDims(),
sVectorType.getScalableDims());
return nullptr;
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index 1039bd23b3c159..bc8300a8b7329e 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -995,10 +995,7 @@ Type mlir::LLVM::getVectorType(Type elementType, unsigned numElements,
// LLVM vectors are always 1-D, hence only 1 bool is required to mark it as
// scalable/non-scalable.
- SmallVector<bool> scalableDims(1, isScalable);
-
- return VectorType::get(numElements, elementType,
- static_cast<unsigned>(isScalable), scalableDims);
+ return VectorType::get(numElements, elementType, {isScalable});
}
Type mlir::LLVM::getVectorType(Type elementType,
@@ -1030,7 +1027,10 @@ Type mlir::LLVM::getScalableVectorType(Type elementType, unsigned numElements) {
"type");
if (useLLVM)
return LLVMScalableVectorType::get(elementType, numElements);
- return VectorType::get(numElements, elementType, /*numScalableDims=*/1);
+
+ // LLVM vectors are always 1-D, hence only 1 bool is required to mark it as
+ // scalable/non-scalable.
+ return VectorType::get(numElements, elementType, /*scalableDims=*/true);
}
llvm::TypeSize mlir::LLVM::getPrimitiveTypeSizeInBits(Type type) {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index d0fcaada603d25..a0bfd7f7b88d5a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -223,10 +223,7 @@ struct VectorizationState {
assert(areValidScalableVecDims(scalableDims) &&
"Permuted scalable vector dimensions are not supported");
- // TODO: Extend scalable vector type to support a bit map.
- bool numScalableDims = !scalableVecDims.empty() && scalableVecDims.back();
- return VectorType::get(vectorShape, elementType, numScalableDims,
- scalableDims);
+ return VectorType::get(vectorShape, elementType, scalableDims);
}
/// Masks an operation with the canonical vector mask if the operation needs
@@ -1228,7 +1225,6 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
if (firstMaxRankedType) {
auto vecType = VectorType::get(firstMaxRankedType.getShape(),
getElementTypeOrSelf(vecOperand.getType()),
- firstMaxRankedType.getNumScalableDims(),
firstMaxRankedType.getScalableDims());
vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
} else {
@@ -1241,7 +1237,6 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
resultTypes.push_back(
firstMaxRankedType
? VectorType::get(firstMaxRankedType.getShape(), resultType,
- firstMaxRankedType.getNumScalableDims(),
firstMaxRankedType.getScalableDims())
: resultType);
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
index 77bd330ee2eef8..93ee0647b7b5a6 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
@@ -56,9 +56,7 @@ static bool isInvariantArg(BlockArgument arg, Block *block) {
/// Constructs vector type for element type.
static VectorType vectorType(VL vl, Type etp) {
- unsigned numScalableDims = vl.enableVLAVectorization;
- return VectorType::get(vl.vectorLength, etp, numScalableDims,
- vl.enableVLAVectorization);
+ return VectorType::get(vl.vectorLength, etp, vl.enableVLAVectorization);
}
/// Constructs vector type from a memref value.
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 7a39aa48d8706f..fc87c8413c36ff 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -1176,7 +1176,7 @@ Type Merger::inferType(ExprId e, Value src) const {
// Inspect source type. For vector types, apply the same
// vectorization to the destination type.
if (auto vtp = dyn_cast<VectorType>(src.getType()))
- return VectorType::get(vtp.getNumElements(), dtp, vtp.getNumScalableDims());
+ return VectorType::get(vtp.getNumElements(), dtp, vtp.getScalableDims());
return dtp;
}
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 7dd05f519bdeae..c2562af6e582ab 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -345,9 +345,9 @@ LogicalResult MultiDimReductionOp::verify() {
/// Returns the mask type expected by this operation.
Type MultiDimReductionOp::getExpectedMaskType() {
auto vecType = getSourceVectorType();
- return VectorType::get(
- vecType.getShape(), IntegerType::get(vecType.getContext(), /*width=*/1),
- vecType.getNumScalableDims(), vecType.getScalableDims());
+ return VectorType::get(vecType.getShape(),
+ IntegerType::get(vecType.getContext(), /*width=*/1),
+ vecType.getScalableDims());
}
namespace {
@@ -484,9 +484,9 @@ void ReductionOp::print(OpAsmPrinter &p) {
/// Returns the mask type expected by this operation.
Type ReductionOp::getExpectedMaskType() {
auto vecType = getSourceVectorType();
- return VectorType::get(
- vecType.getShape(), IntegerType::get(vecType.getContext(), /*width=*/1),
- vecType.getNumScalableDims(), vecType.getScalableDims());
+ return VectorType::get(vecType.getShape(),
+ IntegerType::get(vecType.getContext(), /*width=*/1),
+ vecType.getScalableDims());
}
Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
@@ -929,8 +929,7 @@ Type ContractionOp::getExpectedMaskType() {
assert(!ShapedType::isDynamicShape(maskShape) &&
"Mask shape couldn't be computed");
// TODO: Extend the scalable vector type representation with a bit map.
- assert(lhsType.getNumScalableDims() == 0 &&
- rhsType.getNumScalableDims() == 0 &&
+ assert(!lhsType.isScalable() && !rhsType.isScalable() &&
"Scalable vectors are not supported yet");
return VectorType::get(maskShape,
@@ -2792,18 +2791,13 @@ ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
if (vRHS) {
SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
vRHS.getScalableDims()[0]};
- auto numScalableDims =
- count_if(scalableDimsRes, [](bool isScalable) { return isScalable; });
resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
- vLHS.getElementType(), numScalableDims,
- scalableDimsRes);
+ vLHS.getElementType(), scalableDimsRes);
} else {
// Scalar RHS operand
SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
- auto numScalableDims =
- count_if(scalableDimsRes, [](bool isScalable) { return isScalable; });
resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
- numScalableDims, scalableDimsRes);
+ scalableDimsRes);
}
if (!result.attributes.get(OuterProductOp::getKindAttrStrName())) {
@@ -2867,9 +2861,9 @@ LogicalResult OuterProductOp::verify() {
/// verification purposes. It requires the operation to be vectorized."
Type OuterProductOp::getExpectedMaskType() {
auto vecType = this->getResultVectorType();
- return VectorType::get(
- vecType.getShape(), IntegerType::get(vecType.getContext(), /*width=*/1),
- vecType.getNumScalableDims(), vecType.getScalableDims());
+ return VectorType::get(vecType.getShape(),
+ IntegerType::get(vecType.getContext(), /*width=*/1),
+ vecType.getScalableDims());
}
//===----------------------------------------------------------------------===//
@@ -3528,8 +3522,7 @@ static VectorType inferTransferOpMaskType(VectorType vecType,
SmallVector<bool> scalableDims =
applyPermutationMap(invPermMap, vecType.getScalableDims());
- return VectorType::get(maskShape, i1Type, vecType.getNumScalableDims(),
- scalableDims);
+ return VectorType::get(maskShape, i1Type, scalableDims);
}
ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -4487,9 +4480,9 @@ LogicalResult GatherOp::verify() {
/// verification purposes. It requires the operation to be vectorized."
Type GatherOp::getExpectedMaskType() {
auto vecType = this->getIndexVectorType();
- return VectorType::get(
- vecType.getShape(), IntegerType::get(vecType.getContext(), /*width=*/1),
- vecType.getNumScalableDims(), vecType.getScalableDims());
+ return VectorType::get(vecType.getShape(),
+ IntegerType::get(vecType.getContext(), /*width=*/1),
+ vecType.getScalableDims());
}
std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index ea42d57d2fb0a6..abe6d8846a2357 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1024,7 +1024,7 @@ struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
Value mask = rewriter.create<vector::CreateMaskOp>(
loc,
VectorType::get(vtp.getShape(), rewriter.getI1Type(),
- vtp.getNumScalableDims()),
+ vtp.getScalableDims()),
b);
if (xferOp.getMask()) {
// Intersect the in-bounds with the mask specified as an op parameter.
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 62ef2c63444b90..e29555f93e9407 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -227,7 +227,6 @@ LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> shape, Type elementType,
- unsigned numScalableDims,
ArrayRef<bool> scalableDims) {
if (!isValidElementType(elementType))
return emitError()
@@ -239,21 +238,10 @@ LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
<< "vector types must have positive constant sizes but got "
<< shape;
- if (numScalableDims > shape.size())
- return emitError()
- << "number of scalable dims cannot exceed the number of dims"
- << " (" << numScalableDims << " vs " << shape.size() << ")";
-
if (scalableDims.size() != shape.size())
return emitError() << "number of dims must match, got "
<< scalableDims.size() << " and " << shape.size();
- auto numScale =
- count_if(scalableDims, [](bool isScalable) { return isScalable; });
- if (numScale != numScalableDims)
- return emitError() << "number of scalable dims must match, explicit: "
- << numScalableDims << ", and bools:" << numScale;
-
return success();
}
@@ -262,17 +250,17 @@ VectorType VectorType::scaleElementBitwidth(unsigned scale) {
return VectorType();
if (auto et = llvm::dyn_cast<IntegerType>(getElementType()))
if (auto scaledEt = et.scaleElementBitwidth(scale))
- return VectorType::get(getShape(), scaledEt, getNumScalableDims());
+ return VectorType::get(getShape(), scaledEt, getScalableDims());
if (auto et = llvm::dyn_cast<FloatType>(getElementType()))
if (auto scaledEt = et.scaleElementBitwidth(scale))
- return VectorType::get(getShape(), scaledEt, getNumScalableDims());
+ return VectorType::get(getShape(), scaledEt, getScalableDims());
return VectorType();
}
VectorType VectorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
Type elementType) const {
return VectorType::get(shape.value_or(getShape()), elementType,
- getNumScalableDims());
+ getScalableDims());
}
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list