[llvm-branch-commits] [mlir] [mlir][LLVM] Delete `LLVMFixedVectorType` (PR #133286)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Mar 27 10:47:09 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
Since #125690, the MLIR vector type supports `!llvm.ptr` as an element type. The only remaining element type for `LLVMFixedVectorType` is now `LLVMPPCFP128Type`.
This commit turns `LLVMPPCFP128Type` into a proper FP type (by implementing `FloatTypeInterface`), so that the MLIR vector type accepts it as an element type. This makes `LLVMFixedVectorType` obsolete, so this commit deletes it.
Note: `LLVMScalableVectorType` remains for now.
Depends on #125690.
---
Patch is 23.50 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/133286.diff
10 Files Affected:
- (modified) mlir/docs/Dialects/LLVM.md (+3-5)
- (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h (-1)
- (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td (+14-32)
- (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (+22-32)
- (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp (+1-1)
- (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp (+7-11)
- (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp (+28-67)
- (modified) mlir/lib/Target/LLVMIR/TypeToLLVM.cpp (+2-8)
- (modified) mlir/test/Dialect/LLVMIR/types-invalid.mlir (-19)
- (modified) mlir/test/Dialect/LLVMIR/types.mlir (+2)
``````````diff
diff --git a/mlir/docs/Dialects/LLVM.md b/mlir/docs/Dialects/LLVM.md
index fadc81b567b4e..81c358244d96e 100644
--- a/mlir/docs/Dialects/LLVM.md
+++ b/mlir/docs/Dialects/LLVM.md
@@ -327,11 +327,9 @@ multiple of some fixed size in case of _scalable_ vectors, and the element type.
Vectors cannot be nested and only 1D vectors are supported. Scalable vectors are
still considered 1D.
-LLVM dialect uses built-in vector types for _fixed_-size vectors of built-in
-types, and provides additional types for fixed-sized vectors of LLVM dialect
-types (`LLVMFixedVectorType`) and scalable vectors of any types
-(`LLVMScalableVectorType`). These two additional types share the following
-syntax:
+The LLVM dialect uses built-in vector types for _fixed_-size vectors of built-in
+types, and provides additional types for scalable vectors of any types
+(`LLVMScalableVectorType`):
```
llvm-vec-type ::= `!llvm.vec<` (`?` `x`)? integer-literal `x` type `>`
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index bca0feb45aab2..9d238fc746b8f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -67,7 +67,6 @@ namespace LLVM {
}
DEFINE_TRIVIAL_LLVM_TYPE(LLVMVoidType, "llvm.void");
-DEFINE_TRIVIAL_LLVM_TYPE(LLVMPPCFP128Type, "llvm.ppc_fp128");
DEFINE_TRIVIAL_LLVM_TYPE(LLVMTokenType, "llvm.token");
DEFINE_TRIVIAL_LLVM_TYPE(LLVMLabelType, "llvm.label");
DEFINE_TRIVIAL_LLVM_TYPE(LLVMMetadataType, "llvm.metadata");
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
index 3386003cb61fb..fe12ab99b9141 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
@@ -11,6 +11,7 @@
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/BuiltinTypeInterfaces.td"
include "mlir/IR/BuiltinTypes.td"
include "mlir/Interfaces/DataLayoutInterfaces.td"
include "mlir/Interfaces/MemorySlotInterfaces.td"
@@ -288,38 +289,6 @@ def LLVMPointerType : LLVMType<"LLVMPointer", "ptr", [
];
}
-//===----------------------------------------------------------------------===//
-// LLVMFixedVectorType
-//===----------------------------------------------------------------------===//
-
-def LLVMFixedVectorType : LLVMType<"LLVMFixedVector", "vec"> {
- let summary = "LLVM fixed vector type";
- let description = [{
- LLVM dialect vector type that supports all element types that are supported
- in LLVM vectors but that are not supported by the builtin MLIR vector type.
- E.g., LLVMFixedVectorType supports LLVM pointers as element type.
- }];
-
- let typeName = "llvm.fixed_vec";
-
- let parameters = (ins "Type":$elementType, "unsigned":$numElements);
- let assemblyFormat = [{
- `<` $numElements `x` custom<PrettyLLVMType>($elementType) `>`
- }];
-
- let genVerifyDecl = 1;
-
- let builders = [
- TypeBuilderWithInferredContext<(ins "Type":$elementType,
- "unsigned":$numElements)>
- ];
-
- let extraClassDeclaration = [{
- /// Checks if the given type can be used in a vector type.
- static bool isValidElementType(Type type);
- }];
-}
-
//===----------------------------------------------------------------------===//
// LLVMScalableVectorType
//===----------------------------------------------------------------------===//
@@ -400,4 +369,17 @@ def LLVMX86AMXType : LLVMType<"LLVMX86AMX", "x86_amx"> {
}];
}
+//===----------------------------------------------------------------------===//
+// LLVMPPCFP128Type
+//===----------------------------------------------------------------------===//
+
+def LLVMPPCFP128Type : LLVMType<"LLVMPPCFP128", "ppc_fp128",
+ [DeclareTypeInterfaceMethods<FloatTypeInterface, ["getFloatSemantics"]>]> {
+ let summary = "128 bit FP type with IBM double-double semantics";
+ let description = [{
+ A 128 bit floating-point type with IBM double-double semantics.
+ See S_PPCDoubleDouble in APFloat.h for details.
+ }];
+}
+
#endif // LLVMTYPES_TD
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 18a70cc64628f..29701ffc89b19 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -686,8 +686,6 @@ static Type extractVectorElementType(Type type) {
return vectorType.getElementType();
if (auto scalableVectorType = llvm::dyn_cast<LLVMScalableVectorType>(type))
return scalableVectorType.getElementType();
- if (auto fixedVectorType = llvm::dyn_cast<LLVMFixedVectorType>(type))
- return fixedVectorType.getElementType();
return type;
}
@@ -724,20 +722,19 @@ static void destructureIndices(Type currType, ArrayRef<GEPArg> indices,
if (rawConstantIndices.size() == 1 || !currType)
continue;
- currType =
- TypeSwitch<Type, Type>(currType)
- .Case<VectorType, LLVMScalableVectorType, LLVMFixedVectorType,
- LLVMArrayType>([](auto containerType) {
- return containerType.getElementType();
- })
- .Case([&](LLVMStructType structType) -> Type {
- int64_t memberIndex = rawConstantIndices.back();
- if (memberIndex >= 0 && static_cast<size_t>(memberIndex) <
- structType.getBody().size())
- return structType.getBody()[memberIndex];
- return nullptr;
- })
- .Default(Type(nullptr));
+ currType = TypeSwitch<Type, Type>(currType)
+ .Case<VectorType, LLVMScalableVectorType, LLVMArrayType>(
+ [](auto containerType) {
+ return containerType.getElementType();
+ })
+ .Case([&](LLVMStructType structType) -> Type {
+ int64_t memberIndex = rawConstantIndices.back();
+ if (memberIndex >= 0 && static_cast<size_t>(memberIndex) <
+ structType.getBody().size())
+ return structType.getBody()[memberIndex];
+ return nullptr;
+ })
+ .Default(Type(nullptr));
}
}
@@ -838,11 +835,11 @@ verifyStructIndices(Type baseGEPType, unsigned indexPos,
return verifyStructIndices(elementTypes[gepIndex], indexPos + 1,
indices, emitOpError);
})
- .Case<VectorType, LLVMScalableVectorType, LLVMFixedVectorType,
- LLVMArrayType>([&](auto containerType) -> LogicalResult {
- return verifyStructIndices(containerType.getElementType(), indexPos + 1,
- indices, emitOpError);
- })
+ .Case<VectorType, LLVMScalableVectorType, LLVMArrayType>(
+ [&](auto containerType) -> LogicalResult {
+ return verifyStructIndices(containerType.getElementType(),
+ indexPos + 1, indices, emitOpError);
+ })
.Default([&](auto otherType) -> LogicalResult {
return emitOpError()
<< "type " << otherType << " cannot be indexed (index #"
@@ -3108,16 +3105,14 @@ OpFoldResult LLVM::ZeroOp::fold(FoldAdaptor) {
//===----------------------------------------------------------------------===//
/// Compute the total number of elements in the given type, also taking into
-/// account nested types. Supported types are `VectorType`, `LLVMArrayType` and
-/// `LLVMFixedVectorType`. Everything else is treated as a scalar.
+/// account nested types. Supported types are `VectorType` and `LLVMArrayType`.
+/// Everything else is treated as a scalar.
static int64_t getNumElements(Type t) {
if (auto vecType = dyn_cast<VectorType>(t))
return vecType.getNumElements() * getNumElements(vecType.getElementType());
if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(t))
return arrayType.getNumElements() *
getNumElements(arrayType.getElementType());
- if (auto vecType = dyn_cast<LLVMFixedVectorType>(t))
- return vecType.getNumElements() * getNumElements(vecType.getElementType());
assert(!isa<LLVM::LLVMScalableVectorType>(t) &&
"number of elements of a scalable vector type is unknown");
return 1;
@@ -3135,8 +3130,6 @@ static bool hasScalableVectorType(Type t) {
}
if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(t))
return hasScalableVectorType(arrayType.getElementType());
- if (auto vecType = dyn_cast<LLVMFixedVectorType>(t))
- return hasScalableVectorType(vecType.getElementType());
return false;
}
@@ -3216,8 +3209,7 @@ LogicalResult LLVM::ConstantOp::verify() {
<< "scalable vector type requires a splat attribute";
return success();
}
- if (!isa<VectorType, LLVM::LLVMArrayType, LLVM::LLVMFixedVectorType>(
- getType()))
+ if (!isa<VectorType, LLVM::LLVMArrayType>(getType()))
return emitOpError() << "expected vector or array type";
// The number of elements of the attribute and the type must match.
int64_t attrNumElements;
@@ -3466,8 +3458,7 @@ LogicalResult LLVM::BitcastOp::verify() {
if (!resultType)
return success();
- auto isVector =
- llvm::IsaPred<VectorType, LLVMScalableVectorType, LLVMFixedVectorType>;
+ auto isVector = llvm::IsaPred<VectorType, LLVMScalableVectorType>;
// Due to bitcast requiring both operands to be of the same size, it is not
// possible for only one of the two to be a pointer of vectors.
@@ -3883,7 +3874,6 @@ void LLVMDialect::initialize() {
// clang-format off
addTypes<LLVMVoidType,
- LLVMPPCFP128Type,
LLVMTokenType,
LLVMLabelType,
LLVMMetadataType>();
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index 51dcb071f9c18..c5a1502c8cbe8 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -137,7 +137,7 @@ static bool isSupportedTypeForConversion(Type type) {
// LLVM vector types are only used for either pointers or target specific
// types. These types cannot be casted in the general case, thus the memory
// optimizations do not support them.
- if (isa<LLVM::LLVMFixedVectorType, LLVM::LLVMScalableVectorType>(type))
+ if (isa<LLVM::LLVMScalableVectorType>(type))
return false;
if (auto vectorType = dyn_cast<VectorType>(type)) {
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
index d700dc52d42d2..edfc5adeb424e 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
@@ -40,8 +40,7 @@ static StringRef getTypeKeyword(Type type) {
.Case<LLVMMetadataType>([&](Type) { return "metadata"; })
.Case<LLVMFunctionType>([&](Type) { return "func"; })
.Case<LLVMPointerType>([&](Type) { return "ptr"; })
- .Case<LLVMFixedVectorType, LLVMScalableVectorType>(
- [&](Type) { return "vec"; })
+ .Case<LLVMScalableVectorType>([&](Type) { return "vec"; })
.Case<LLVMArrayType>([&](Type) { return "array"; })
.Case<LLVMStructType>([&](Type) { return "struct"; })
.Case<LLVMTargetExtType>([&](Type) { return "target"; })
@@ -104,9 +103,9 @@ void mlir::LLVM::detail::printType(Type type, AsmPrinter &printer) {
printer << getTypeKeyword(type);
llvm::TypeSwitch<Type>(type)
- .Case<LLVMPointerType, LLVMArrayType, LLVMFixedVectorType,
- LLVMScalableVectorType, LLVMFunctionType, LLVMTargetExtType,
- LLVMStructType>([&](auto type) { type.print(printer); });
+ .Case<LLVMPointerType, LLVMArrayType, LLVMScalableVectorType,
+ LLVMFunctionType, LLVMTargetExtType, LLVMStructType>(
+ [&](auto type) { type.print(printer); });
}
//===----------------------------------------------------------------------===//
@@ -143,14 +142,11 @@ static Type parseVectorType(AsmParser &parser) {
}
bool isScalable = dims.size() == 2;
- if (isScalable)
- return parser.getChecked<LLVMScalableVectorType>(loc, elementType, dims[1]);
- if (elementType.isSignlessIntOrFloat()) {
- parser.emitError(typePos)
- << "cannot use !llvm.vec for built-in primitives, use 'vector' instead";
+ if (!isScalable) {
+ parser.emitError(dimPos) << "expected scalable vector";
return Type();
}
- return parser.getChecked<LLVMFixedVectorType>(loc, elementType, dims[0]);
+ return parser.getChecked<LLVMScalableVectorType>(loc, elementType, dims[1]);
}
/// Attempts to set the body of an identified structure type. Reports a parsing
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index 403756765268e..b008659c7e958 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -658,7 +658,7 @@ LogicalResult LLVMStructType::verifyEntries(DataLayoutEntryListRef entries,
}
//===----------------------------------------------------------------------===//
-// Vector types.
+// LLVMScalableVectorType.
//===----------------------------------------------------------------------===//
/// Verifies that the type about to be constructed is well-formed.
@@ -675,35 +675,6 @@ verifyVectorConstructionInvariants(function_ref<InFlightDiagnostic()> emitError,
return success();
}
-LLVMFixedVectorType LLVMFixedVectorType::get(Type elementType,
- unsigned numElements) {
- assert(elementType && "expected non-null subtype");
- return Base::get(elementType.getContext(), elementType, numElements);
-}
-
-LLVMFixedVectorType
-LLVMFixedVectorType::getChecked(function_ref<InFlightDiagnostic()> emitError,
- Type elementType, unsigned numElements) {
- assert(elementType && "expected non-null subtype");
- return Base::getChecked(emitError, elementType.getContext(), elementType,
- numElements);
-}
-
-bool LLVMFixedVectorType::isValidElementType(Type type) {
- return llvm::isa<LLVMPPCFP128Type>(type);
-}
-
-LogicalResult
-LLVMFixedVectorType::verify(function_ref<InFlightDiagnostic()> emitError,
- Type elementType, unsigned numElements) {
- return verifyVectorConstructionInvariants<LLVMFixedVectorType>(
- emitError, elementType, numElements);
-}
-
-//===----------------------------------------------------------------------===//
-// LLVMScalableVectorType.
-//===----------------------------------------------------------------------===//
-
LLVMScalableVectorType LLVMScalableVectorType::get(Type elementType,
unsigned minNumElements) {
assert(elementType && "expected non-null subtype");
@@ -762,6 +733,14 @@ bool LLVM::LLVMTargetExtType::supportsMemOps() const {
return false;
}
+//===----------------------------------------------------------------------===//
+// LLVMPPCFP128Type
+//===----------------------------------------------------------------------===//
+
+const llvm::fltSemantics &LLVMPPCFP128Type::getFloatSemantics() const {
+ return APFloat::PPCDoubleDouble();
+}
+
//===----------------------------------------------------------------------===//
// Utility functions.
//===----------------------------------------------------------------------===//
@@ -783,7 +762,6 @@ bool mlir::LLVM::isCompatibleOuterType(Type type) {
LLVMPointerType,
LLVMStructType,
LLVMTokenType,
- LLVMFixedVectorType,
LLVMScalableVectorType,
LLVMTargetExtType,
LLVMVoidType,
@@ -832,7 +810,6 @@ static bool isCompatibleImpl(Type type, DenseSet<Type> &compatibleTypes) {
})
// clang-format off
.Case<
- LLVMFixedVectorType,
LLVMScalableVectorType,
LLVMArrayType
>([&](auto containerType) {
@@ -880,7 +857,7 @@ bool mlir::LLVM::isCompatibleFloatingPointType(Type type) {
}
bool mlir::LLVM::isCompatibleVectorType(Type type) {
- if (llvm::isa<LLVMFixedVectorType, LLVMScalableVectorType>(type))
+ if (llvm::isa<LLVMScalableVectorType>(type))
return true;
if (auto vecType = llvm::dyn_cast<VectorType>(type)) {
@@ -897,7 +874,7 @@ bool mlir::LLVM::isCompatibleVectorType(Type type) {
Type mlir::LLVM::getVectorElementType(Type type) {
return llvm::TypeSwitch<Type, Type>(type)
- .Case<LLVMFixedVectorType, LLVMScalableVectorType, VectorType>(
+ .Case<LLVMScalableVectorType, VectorType>(
[](auto ty) { return ty.getElementType(); })
.Default([](Type) -> Type {
llvm_unreachable("incompatible with LLVM vector type");
@@ -911,9 +888,6 @@ llvm::ElementCount mlir::LLVM::getVectorNumElements(Type type) {
return llvm::ElementCount::getScalable(ty.getNumElements());
return llvm::ElementCount::getFixed(ty.getNumElements());
})
- .Case([](LLVMFixedVectorType ty) {
- return llvm::ElementCount::getFixed(ty.getNumElements());
- })
.Case([](LLVMScalableVectorType ty) {
return llvm::ElementCount::getScalable(ty.getMinNumElements());
})
@@ -923,30 +897,28 @@ llvm::ElementCount mlir::LLVM::getVectorNumElements(Type type) {
}
bool mlir::LLVM::isScalableVectorType(Type vectorType) {
- assert((llvm::isa<LLVMFixedVectorType, LLVMScalableVectorType, VectorType>(
- vectorType)) &&
+ assert((llvm::isa<LLVMScalableVectorType, VectorType>(vectorType)) &&
"expected LLVM-compatible vector type");
- return !llvm::isa<LLVMFixedVectorType>(vectorType) &&
- (llvm::isa<LLVMScalableVectorType>(vectorType) ||
- llvm::cast<VectorType>(vectorType).isScalable());
+ return llvm::isa<LLVMScalableVectorType>(vectorType) ||
+ llvm::cast<VectorType>(vectorType).isScalable();
}
Type mlir::LLVM::getVectorType(Type elementType, unsigned numElements,
bool isScalable) {
- bool useLLVM = LLVMFixedVectorType::isValidElementType(elementType);
- bool useBuiltIn = VectorType::isValidElementType(elementType);
- (void)useBuiltIn;
- assert((useLLVM ^ useBuiltIn) && "expected LLVM-compatible fixed-vector type "
- "to be either builtin or LLVM dialect type");
- if (useLLVM) {
- if (isScalable)
- return LLVMScalableVectorType::get(elementType, numElements);
- return LLVMFixedVectorType::get(elementType, numElements);
+ if (!isScalable) {
+ // Non-scalable vectors always use the MLIR vector type.
+ assert(VectorType::isValidElementType(elementType) &&
+ "incompatible element type");
+ return VectorType::get(numElements, elementType, {false});
}
- // LLVM vectors are always 1-D, hence only 1 bool is required to mark it as
- // scalable/non-scalable.
- return VectorType::get(numElements, elementType, {isScalable});
+ // This is a scalable vector.
+ if (VectorType::isValidElementType(elementType))
+ return VectorType::get(numElements, elementType, {true});
+ assert(LLVMScalableVectorType::isValidElementType(elementType) &&
+ "neither the MLIR vector type nor LLVMScalableVectorType is "
+ "compatible with the specified element type");
+ return LLVMScalableVectorType::get(elementType, numElements);
}
Type mlir::LLVM::getVectorType(Type elementType,
@@ -959,13 +931,8 @@ Type mlir::LLVM::getVectorType(Type elementType,
}
Type mlir::LLVM::getFixedVectorType(Type elementType, unsigned numElements) {
- bool useLLVM = LLVMFixedVectorType::isValidElementType(elementType);
- bool useBuiltIn = VectorType::isValidElementType(elementType);
- (void)useBuiltIn;
- assert((useLLVM ^ useBuiltIn) && "expected LLVM-compatible fixed-vector type "
- "to be either builtin or LLVM dialect type");
- if (useLLVM)
- return LLVMFixedVectorType::get(elementType, numElements);
+ assert(VectorType::isValidElementType(elementType) &&
+ "incompatible element type");
return VectorType::get(numElements, elementType);
}
@@ -1000,12 +967,6 @@ llvm::TypeSize mlir::LLVM::getPrimitiveTypeSizeInBits(Type type) {
})
.Case<LLVMPPCFP128Type>(
[](Type) { return llvm::TypeSize::getFixed(128); })
- ...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/133286
More information about the llvm-branch-commits
mailing list