[Mlir-commits] [mlir] [mlir][LLVM] Delete `getVectorElementType` (PR #134981)
Matthias Springer
llvmlistbot at llvm.org
Wed Apr 9 09:47:17 PDT 2025
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/134981
>From 2ae68ba0ceafef57383ea1f395f451aa052335ec Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Wed, 9 Apr 2025 11:05:31 +0200
Subject: [PATCH] [mlir][LLVM] Delete `getVectorElementType`
---
mlir/docs/Dialects/LLVM.md | 2 --
.../mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td | 14 +++++++------
.../include/mlir/Dialect/LLVMIR/LLVMOpBase.td | 14 ++++++++-----
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 8 +++++---
mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h | 4 ----
.../Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp | 7 +++----
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 20 ++++++++++---------
mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp | 6 ------
mlir/lib/Target/LLVMIR/ModuleImport.cpp | 6 +++---
mlir/test/Dialect/LLVMIR/invalid.mlir | 6 +++---
mlir/test/Target/LLVMIR/llvmir-invalid.mlir | 8 ++++----
11 files changed, 46 insertions(+), 49 deletions(-)
diff --git a/mlir/docs/Dialects/LLVM.md b/mlir/docs/Dialects/LLVM.md
index d0509e036682f..468f69c419071 100644
--- a/mlir/docs/Dialects/LLVM.md
+++ b/mlir/docs/Dialects/LLVM.md
@@ -334,8 +334,6 @@ compatible with the LLVM dialect:
- `bool LLVM::isCompatibleVectorType(Type)` - checks whether a type is a
vector type compatible with the LLVM dialect;
-- `Type LLVM::getVectorElementType(Type)` - returns the element type of any
- vector type compatible with the LLVM dialect;
- `llvm::ElementCount LLVM::getVectorNumElements(Type)` - returns the number
of elements in any vector type compatible with the LLVM dialect;
- `Type LLVM::getFixedVectorType(Type, unsigned)` - gets a fixed vector type
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
index 2debd09f78b34..ab928c9e2d0e7 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -874,7 +874,7 @@ def LLVM_MatrixColumnMajorLoadOp : LLVM_OneResultIntrOp<"matrix.column.major.loa
const llvm::DataLayout &dl =
builder.GetInsertBlock()->getModule()->getDataLayout();
llvm::Type *ElemTy = moduleTranslation.convertType(
- getVectorElementType(op.getType()));
+ op.getType().getElementType());
llvm::Align align = dl.getABITypeAlign(ElemTy);
$res = mb.CreateColumnMajorLoad(
ElemTy, $data, align, $stride, $isVolatile, $rows,
@@ -907,7 +907,7 @@ def LLVM_MatrixColumnMajorStoreOp : LLVM_ZeroResultIntrOp<"matrix.column.major.s
llvm::MatrixBuilder mb(builder);
const llvm::DataLayout &dl =
builder.GetInsertBlock()->getModule()->getDataLayout();
- Type elementType = getVectorElementType(op.getMatrix().getType());
+ Type elementType = op.getMatrix().getType().getElementType();
llvm::Align align = dl.getABITypeAlign(
moduleTranslation.convertType(elementType));
mb.CreateColumnMajorStore(
@@ -1164,7 +1164,8 @@ def LLVM_vector_insert
let extraClassDeclaration = [{
uint64_t getVectorBitWidth(Type vector) {
return getVectorNumElements(vector).getKnownMinValue() *
- getVectorElementType(vector).getIntOrFloatBitWidth();
+ ::llvm::cast<VectorType>(vector).getElementType()
+ .getIntOrFloatBitWidth();
}
uint64_t getSrcVectorBitWidth() {
return getVectorBitWidth(getSrcvec().getType());
@@ -1196,7 +1197,8 @@ def LLVM_vector_extract
let extraClassDeclaration = [{
uint64_t getVectorBitWidth(Type vector) {
return getVectorNumElements(vector).getKnownMinValue() *
- getVectorElementType(vector).getIntOrFloatBitWidth();
+ ::llvm::cast<VectorType>(vector).getElementType()
+ .getIntOrFloatBitWidth();
}
uint64_t getSrcVectorBitWidth() {
return getVectorBitWidth(getSrcvec().getType());
@@ -1216,8 +1218,8 @@ def LLVM_vector_interleave2
"result has twice as many elements as 'vec1'",
And<[CPred<"getVectorNumElements($res.getType()) == "
"getVectorNumElements($vec1.getType()) * 2">,
- CPred<"getVectorElementType($vec1.getType()) == "
- "getVectorElementType($res.getType())">]>>,
+ CPred<"::llvm::cast<VectorType>($vec1.getType()).getElementType() == "
+ "::llvm::cast<VectorType>($res.getType()).getElementType()">]>>,
]>,
Arguments<(ins LLVM_AnyVector:$vec1, LLVM_AnyVector:$vec2)>;
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index 1fa1d3be557db..b97b5ac932c97 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -113,17 +113,20 @@ def LLVM_AnyNonAggregate : Type<And<[LLVM_Type.predicate,
// Type constraint accepting any LLVM vector type.
def LLVM_AnyVector : Type<CPred<"::mlir::LLVM::isCompatibleVectorType($_self)">,
- "LLVM dialect-compatible vector type">;
+ "LLVM dialect-compatible vector type",
+ "::mlir::VectorType">;
// Type constraint accepting any LLVM fixed-length vector type.
def LLVM_AnyFixedVector : Type<CPred<
"!::mlir::LLVM::isScalableVectorType($_self)">,
- "LLVM dialect-compatible fixed-length vector type">;
+ "LLVM dialect-compatible fixed-length vector type",
+ "::mlir::VectorType">;
// Type constraint accepting any LLVM scalable vector type.
def LLVM_AnyScalableVector : Type<CPred<
"::mlir::LLVM::isScalableVectorType($_self)">,
- "LLVM dialect-compatible scalable vector type">;
+ "LLVM dialect-compatible scalable vector type",
+ "::mlir::VectorType">;
// Type constraint accepting an LLVM vector type with an additional constraint
// on the vector element type.
@@ -131,9 +134,10 @@ class LLVM_VectorOf<Type element> : Type<
And<[LLVM_AnyVector.predicate,
SubstLeaves<
"$_self",
- "::mlir::LLVM::getVectorElementType($_self)",
+ "::llvm::cast<::mlir::VectorType>($_self).getElementType()",
element.predicate>]>,
- "LLVM dialect-compatible vector of " # element.summary>;
+ "LLVM dialect-compatible vector of " # element.summary,
+ "::mlir::VectorType">;
// Type constraint accepting a constrained type, or a vector of such types.
class LLVM_ScalarOrVectorOf<Type element> :
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index b107b64e55b46..6602318b07b85 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -820,8 +820,9 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
//===----------------------------------------------------------------------===//
def LLVM_ExtractElementOp : LLVM_Op<"extractelement", [Pure,
- TypesMatchWith<"result type matches vector element type", "vector", "res",
- "LLVM::getVectorElementType($_self)">]> {
+ TypesMatchWith<
+ "result type matches vector element type", "vector", "res",
+ "::llvm::cast<::mlir::VectorType>($_self).getElementType()">]> {
let summary = "Extract an element from an LLVM vector.";
let arguments = (ins LLVM_AnyVector:$vector, AnySignlessInteger:$position);
@@ -881,7 +882,8 @@ def LLVM_ExtractValueOp : LLVM_Op<"extractvalue", [Pure]> {
def LLVM_InsertElementOp : LLVM_Op<"insertelement", [Pure,
TypesMatchWith<"argument type matches vector element type", "vector",
- "value", "LLVM::getVectorElementType($_self)">,
+ "value",
+ "::llvm::cast<::mlir::VectorType>($_self).getElementType()">,
AllTypesMatch<["res", "vector"]>]> {
let summary = "Insert an element into an LLVM vector.";
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index 03c246e589643..a2a76c49a2bda 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -111,10 +111,6 @@ bool isCompatibleFloatingPointType(Type type);
/// dialect pointers and LLVM dialect scalable vector types.
bool isCompatibleVectorType(Type type);
-/// Returns the element type of any vector type compatible with the LLVM
-/// dialect.
-Type getVectorElementType(Type type);
-
/// Returns the element count of any LLVM-compatible vector type.
llvm::ElementCount getVectorNumElements(Type type);
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index 6e0adfc1e0ff3..93979e0f73324 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -78,10 +78,9 @@ static unsigned getBitWidth(Type type) {
/// Returns the bit width of LLVMType integer or vector.
static unsigned getLLVMTypeBitWidth(Type type) {
- return cast<IntegerType>((LLVM::isCompatibleVectorType(type)
- ? LLVM::getVectorElementType(type)
- : type))
- .getWidth();
+ if (auto vecTy = dyn_cast<VectorType>(type))
+ type = vecTy.getElementType();
+ return cast<IntegerType>(type).getWidth();
}
/// Creates `IntegerAttribute` with all bits set for given type
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 78eb4c9b3481f..33a1686541996 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2734,9 +2734,9 @@ void ShuffleVectorOp::build(OpBuilder &builder, OperationState &state, Value v1,
Value v2, DenseI32ArrayAttr mask,
ArrayRef<NamedAttribute> attrs) {
auto containerType = v1.getType();
- auto vType = LLVM::getVectorType(LLVM::getVectorElementType(containerType),
- mask.size(),
- LLVM::isScalableVectorType(containerType));
+ auto vType = LLVM::getVectorType(
+ cast<VectorType>(containerType).getElementType(), mask.size(),
+ LLVM::isScalableVectorType(containerType));
build(builder, state, vType, v1, v2, mask);
state.addAttributes(attrs);
}
@@ -2752,8 +2752,9 @@ static ParseResult parseShuffleType(AsmParser &parser, Type v1Type,
if (!LLVM::isCompatibleVectorType(v1Type))
return parser.emitError(parser.getCurrentLocation(),
"expected an LLVM compatible vector type");
- resType = LLVM::getVectorType(LLVM::getVectorElementType(v1Type), mask.size(),
- LLVM::isScalableVectorType(v1Type));
+ resType =
+ LLVM::getVectorType(cast<VectorType>(v1Type).getElementType(),
+ mask.size(), LLVM::isScalableVectorType(v1Type));
return success();
}
@@ -3318,7 +3319,7 @@ LogicalResult AtomicRMWOp::verify() {
if (isCompatibleVectorType(valType)) {
if (isScalableVectorType(valType))
return emitOpError("expected LLVM IR fixed vector type");
- Type elemType = getVectorElementType(valType);
+ Type elemType = llvm::cast<VectorType>(valType).getElementType();
if (!isCompatibleFloatingPointType(elemType))
return emitOpError(
"expected LLVM IR floating point type for vector element");
@@ -3423,9 +3424,10 @@ static LogicalResult verifyExtOp(ExtOp op) {
return op.emitError("input and output vectors are of incompatible shape");
// Because this is a CastOp, the element of vectors is guaranteed to be an
// integer.
- inputType = cast<IntegerType>(getVectorElementType(op.getArg().getType()));
- outputType =
- cast<IntegerType>(getVectorElementType(op.getResult().getType()));
+ inputType = cast<IntegerType>(
+ cast<VectorType>(op.getArg().getType()).getElementType());
+ outputType = cast<IntegerType>(
+ cast<VectorType>(op.getResult().getType()).getElementType());
} else {
// Because this is a CastOp and arg is not a vector, arg is guaranteed to be
// an integer.
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index 663adc3c34256..b3c2a29309528 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -821,12 +821,6 @@ bool mlir::LLVM::isCompatibleVectorType(Type type) {
return false;
}
-Type mlir::LLVM::getVectorElementType(Type type) {
- auto vecTy = dyn_cast<VectorType>(type);
- assert(vecTy && "incompatible with LLVM vector type");
- return vecTy.getElementType();
-}
-
llvm::ElementCount mlir::LLVM::getVectorNumElements(Type type) {
auto vecTy = dyn_cast<VectorType>(type);
assert(vecTy && "incompatible with LLVM vector type");
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 2859abdb41772..0d08f15d29b7d 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -139,8 +139,8 @@ static LogicalResult convertInstructionImpl(OpBuilder &odsBuilder,
if (iface.isConvertibleInstruction(inst->getOpcode()))
return iface.convertInstruction(odsBuilder, inst, llvmOperands,
moduleImport);
- // TODO: Implement the `convertInstruction` hooks in the
- // `LLVMDialectLLVMIRImportInterface` and move the following include there.
+ // TODO: Implement the `convertInstruction` hooks in the
+ // `LLVMDialectLLVMIRImportInterface` and move the following include there.
#include "mlir/Dialect/LLVMIR/LLVMOpFromLLVMIRConversions.inc"
return failure();
}
@@ -826,7 +826,7 @@ static Type getVectorTypeForAttr(Type type, ArrayRef<int64_t> arrayShape = {}) {
}
// An LLVM dialect vector can only contain scalars.
- Type elementType = LLVM::getVectorElementType(type);
+ Type elementType = cast<VectorType>(type).getElementType();
if (!elementType.isIntOrFloat())
return {};
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index db55088d812e6..0cd6b1f20a1bf 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -515,21 +515,21 @@ func.func @extractvalue_wrong_nesting() {
// -----
func.func @invalid_vector_type_1(%arg0: vector<4xf32>, %arg1: i32, %arg2: f32) {
- // expected-error@+1 {{'vector' must be LLVM dialect-compatible vector}}
+ // expected-error@+1 {{invalid kind of type specified: expected builtin.vector, but found 'f32'}}
%0 = llvm.extractelement %arg2[%arg1 : i32] : f32
}
// -----
func.func @invalid_vector_type_2(%arg0: vector<4xf32>, %arg1: i32, %arg2: f32) {
- // expected-error@+1 {{'vector' must be LLVM dialect-compatible vector}}
+ // expected-error@+1 {{invalid kind of type specified: expected builtin.vector, but found 'f32'}}
%0 = llvm.insertelement %arg2, %arg2[%arg1 : i32] : f32
}
// -----
func.func @invalid_vector_type_3(%arg0: vector<4xf32>, %arg1: i32, %arg2: f32) {
- // expected-error@+2 {{expected an LLVM compatible vector type}}
+ // expected-error@+1 {{invalid kind of type specified: expected builtin.vector, but found 'f32'}}
%0 = llvm.shufflevector %arg2, %arg2 [0, 0, 0, 0, 7] : f32
}
diff --git a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
index 7bb64542accdf..90c0f5ac55cb1 100644
--- a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
@@ -211,7 +211,7 @@ llvm.func @vec_reduce_fmax_intr_wrong_type(%arg0 : vector<4xi32>) -> i32 {
// -----
llvm.func @matrix_load_intr_wrong_type(%ptr : !llvm.ptr, %stride : i32) -> f32 {
- // expected-error @below{{op result #0 must be LLVM dialect-compatible vector type, but got 'f32'}}
+ // expected-error @+2{{invalid kind of type specified: expected builtin.vector, but found 'f32'}}
%0 = llvm.intr.matrix.column.major.load %ptr, <stride=%stride>
{ isVolatile = 0: i1, rows = 3: i32, columns = 16: i32} : f32 from !llvm.ptr stride i32
llvm.return %0 : f32
@@ -229,7 +229,7 @@ llvm.func @matrix_store_intr_wrong_type(%matrix : vector<48xf32>, %ptr : i32, %s
// -----
llvm.func @matrix_multiply_intr_wrong_type(%arg0 : vector<64xf32>, %arg1 : f32) -> vector<12xf32> {
- // expected-error @below{{op operand #1 must be LLVM dialect-compatible vector type, but got 'f32'}}
+ // expected-error @+2{{invalid kind of type specified: expected builtin.vector, but found 'f32'}}
%0 = llvm.intr.matrix.multiply %arg0, %arg1
{ lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32} : (vector<64xf32>, f32) -> vector<12xf32>
llvm.return %0 : vector<12xf32>
@@ -238,7 +238,7 @@ llvm.func @matrix_multiply_intr_wrong_type(%arg0 : vector<64xf32>, %arg1 : f32)
// -----
llvm.func @matrix_transpose_intr_wrong_type(%matrix : f32) -> vector<48xf32> {
- // expected-error @below{{op operand #0 must be LLVM dialect-compatible vector type, but got 'f32'}}
+ // expected-error @below{{invalid kind of type specified: expected builtin.vector, but found 'f32'}}
%0 = llvm.intr.matrix.transpose %matrix {rows = 3: i32, columns = 16: i32} : f32 into vector<48xf32>
llvm.return %0 : vector<48xf32>
}
@@ -286,7 +286,7 @@ llvm.func @masked_gather_intr_wrong_type_scalable(%ptrs : vector<7x!llvm.ptr>, %
// -----
llvm.func @masked_scatter_intr_wrong_type(%vec : f32, %ptrs : vector<7x!llvm.ptr>, %mask : vector<7xi1>) {
- // expected-error @below{{op operand #0 must be LLVM dialect-compatible vector type, but got 'f32'}}
+ // expected-error @below{{invalid kind of type specified: expected builtin.vector, but found 'f32'}}
llvm.intr.masked.scatter %vec, %ptrs, %mask { alignment = 1: i32} : f32, vector<7xi1> into vector<7x!llvm.ptr>
llvm.return
}
More information about the Mlir-commits
mailing list