[Mlir-commits] [mlir] [mlir][LLVM] Delete `getVectorElementType` (PR #134981)

Matthias Springer llvmlistbot at llvm.org
Wed Apr 9 09:41:59 PDT 2025


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/134981

>From 9a4a7e9874699ce9ad7ea01653fe22c21677b4b1 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Wed, 9 Apr 2025 11:05:31 +0200
Subject: [PATCH] [mlir][LLVM] Delete `getVectorElementType`

---
 mlir/docs/Dialects/LLVM.md                    |  2 --
 .../mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td   | 14 +++++++------
 .../include/mlir/Dialect/LLVMIR/LLVMOpBase.td | 14 ++++++++-----
 mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td   |  8 +++++---
 mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h  |  4 ----
 .../Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp    |  7 +++----
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp    | 20 ++++++++++---------
 mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp      |  6 ------
 mlir/lib/Target/LLVMIR/ModuleImport.cpp       |  6 +++---
 mlir/test/Dialect/LLVMIR/invalid.mlir         |  6 +++---
 mlir/test/Target/LLVMIR/llvmir-invalid.mlir   |  8 ++++----
 11 files changed, 46 insertions(+), 49 deletions(-)

diff --git a/mlir/docs/Dialects/LLVM.md b/mlir/docs/Dialects/LLVM.md
index d0509e036682f..468f69c419071 100644
--- a/mlir/docs/Dialects/LLVM.md
+++ b/mlir/docs/Dialects/LLVM.md
@@ -334,8 +334,6 @@ compatible with the LLVM dialect:
 
 -   `bool LLVM::isCompatibleVectorType(Type)` - checks whether a type is a
     vector type compatible with the LLVM dialect;
--   `Type LLVM::getVectorElementType(Type)` - returns the element type of any
-    vector type compatible with the LLVM dialect;
 -   `llvm::ElementCount LLVM::getVectorNumElements(Type)` - returns the number
     of elements in any vector type compatible with the LLVM dialect;
 -   `Type LLVM::getFixedVectorType(Type, unsigned)` - gets a fixed vector type
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
index 2debd09f78b34..ab928c9e2d0e7 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -874,7 +874,7 @@ def LLVM_MatrixColumnMajorLoadOp : LLVM_OneResultIntrOp<"matrix.column.major.loa
     const llvm::DataLayout &dl =
       builder.GetInsertBlock()->getModule()->getDataLayout();
     llvm::Type *ElemTy = moduleTranslation.convertType(
-        getVectorElementType(op.getType()));
+        op.getType().getElementType());
     llvm::Align align = dl.getABITypeAlign(ElemTy);
     $res = mb.CreateColumnMajorLoad(
       ElemTy, $data, align, $stride, $isVolatile, $rows,
@@ -907,7 +907,7 @@ def LLVM_MatrixColumnMajorStoreOp : LLVM_ZeroResultIntrOp<"matrix.column.major.s
     llvm::MatrixBuilder mb(builder);
     const llvm::DataLayout &dl =
       builder.GetInsertBlock()->getModule()->getDataLayout();
-    Type elementType = getVectorElementType(op.getMatrix().getType());
+    Type elementType = op.getMatrix().getType().getElementType();
     llvm::Align align = dl.getABITypeAlign(
       moduleTranslation.convertType(elementType));
     mb.CreateColumnMajorStore(
@@ -1164,7 +1164,8 @@ def LLVM_vector_insert
   let extraClassDeclaration = [{
     uint64_t getVectorBitWidth(Type vector) {
       return getVectorNumElements(vector).getKnownMinValue() *
-             getVectorElementType(vector).getIntOrFloatBitWidth();
+             ::llvm::cast<VectorType>(vector).getElementType()
+                .getIntOrFloatBitWidth();
     }
     uint64_t getSrcVectorBitWidth() {
       return getVectorBitWidth(getSrcvec().getType());
@@ -1196,7 +1197,8 @@ def LLVM_vector_extract
   let extraClassDeclaration = [{
     uint64_t getVectorBitWidth(Type vector) {
       return getVectorNumElements(vector).getKnownMinValue() *
-             getVectorElementType(vector).getIntOrFloatBitWidth();
+             ::llvm::cast<VectorType>(vector).getElementType()
+                .getIntOrFloatBitWidth();
     }
     uint64_t getSrcVectorBitWidth() {
       return getVectorBitWidth(getSrcvec().getType());
@@ -1216,8 +1218,8 @@ def LLVM_vector_interleave2
             "result has twice as many elements as 'vec1'",
             And<[CPred<"getVectorNumElements($res.getType()) == "
                        "getVectorNumElements($vec1.getType()) * 2">,
-                 CPred<"getVectorElementType($vec1.getType()) == "
-                       "getVectorElementType($res.getType())">]>>,
+                 CPred<"::llvm::cast<VectorType>($vec1.getType()).getElementType() == "
+                       "::llvm::cast<VectorType>($res.getType()).getElementType()">]>>,
         ]>,
         Arguments<(ins LLVM_AnyVector:$vec1, LLVM_AnyVector:$vec2)>;
 
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index 1fa1d3be557db..b97b5ac932c97 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -113,17 +113,20 @@ def LLVM_AnyNonAggregate : Type<And<[LLVM_Type.predicate,
 
 // Type constraint accepting any LLVM vector type.
 def LLVM_AnyVector : Type<CPred<"::mlir::LLVM::isCompatibleVectorType($_self)">,
-                         "LLVM dialect-compatible vector type">;
+                         "LLVM dialect-compatible vector type",
+                         "::mlir::VectorType">;
 
 // Type constraint accepting any LLVM fixed-length vector type.
 def LLVM_AnyFixedVector : Type<CPred<
                                 "!::mlir::LLVM::isScalableVectorType($_self)">,
-                                "LLVM dialect-compatible fixed-length vector type">;
+                                "LLVM dialect-compatible fixed-length vector type",
+                                "::mlir::VectorType">;
 
 // Type constraint accepting any LLVM scalable vector type.
 def LLVM_AnyScalableVector : Type<CPred<
                                 "::mlir::LLVM::isScalableVectorType($_self)">,
-                                "LLVM dialect-compatible scalable vector type">;
+                                "LLVM dialect-compatible scalable vector type",
+                                "::mlir::VectorType">;
 
 // Type constraint accepting an LLVM vector type with an additional constraint
 // on the vector element type.
@@ -131,9 +134,10 @@ class LLVM_VectorOf<Type element> : Type<
   And<[LLVM_AnyVector.predicate,
        SubstLeaves<
          "$_self",
-         "::mlir::LLVM::getVectorElementType($_self)",
+         "::llvm::cast<::mlir::VectorType>($_self).getElementType()",
          element.predicate>]>,
-  "LLVM dialect-compatible vector of " # element.summary>;
+  "LLVM dialect-compatible vector of " # element.summary,
+  "::mlir::VectorType">;
 
 // Type constraint accepting a constrained type, or a vector of such types.
 class LLVM_ScalarOrVectorOf<Type element> :
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index b107b64e55b46..6602318b07b85 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -820,8 +820,9 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
 //===----------------------------------------------------------------------===//
 
 def LLVM_ExtractElementOp : LLVM_Op<"extractelement", [Pure,
-    TypesMatchWith<"result type matches vector element type", "vector", "res",
-                   "LLVM::getVectorElementType($_self)">]> {
+    TypesMatchWith<
+        "result type matches vector element type", "vector", "res",
+        "::llvm::cast<::mlir::VectorType>($_self).getElementType()">]> {
   let summary = "Extract an element from an LLVM vector.";
 
   let arguments = (ins LLVM_AnyVector:$vector, AnySignlessInteger:$position);
@@ -881,7 +882,8 @@ def LLVM_ExtractValueOp : LLVM_Op<"extractvalue", [Pure]> {
 
 def LLVM_InsertElementOp : LLVM_Op<"insertelement", [Pure,
     TypesMatchWith<"argument type matches vector element type", "vector",
-                   "value", "LLVM::getVectorElementType($_self)">,
+                   "value",
+                   "::llvm::cast<::mlir::VectorType>($_self).getElementType()">,
     AllTypesMatch<["res", "vector"]>]> {
   let summary = "Insert an element into an LLVM vector.";
 
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index 03c246e589643..a2a76c49a2bda 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -111,10 +111,6 @@ bool isCompatibleFloatingPointType(Type type);
 /// dialect pointers and LLVM dialect scalable vector types.
 bool isCompatibleVectorType(Type type);
 
-/// Returns the element type of any vector type compatible with the LLVM
-/// dialect.
-Type getVectorElementType(Type type);
-
 /// Returns the element count of any LLVM-compatible vector type.
 llvm::ElementCount getVectorNumElements(Type type);
 
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index 6e0adfc1e0ff3..ef791ead8985c 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -78,10 +78,9 @@ static unsigned getBitWidth(Type type) {
 
 /// Returns the bit width of LLVMType integer or vector.
 static unsigned getLLVMTypeBitWidth(Type type) {
-  return cast<IntegerType>((LLVM::isCompatibleVectorType(type)
-                                ? LLVM::getVectorElementType(type)
-                                : type))
-      .getWidth();
+  assert((isa<IntegerType>(type) || isa<VectorType>(type)) &&
+         "expected integer or vector");
+  return cast<IntegerType>(getElementTypeOrSelf(type)).getWidth();
 }
 
 /// Creates `IntegerAttribute` with all bits set for given type
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 78eb4c9b3481f..33a1686541996 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2734,9 +2734,9 @@ void ShuffleVectorOp::build(OpBuilder &builder, OperationState &state, Value v1,
                             Value v2, DenseI32ArrayAttr mask,
                             ArrayRef<NamedAttribute> attrs) {
   auto containerType = v1.getType();
-  auto vType = LLVM::getVectorType(LLVM::getVectorElementType(containerType),
-                                   mask.size(),
-                                   LLVM::isScalableVectorType(containerType));
+  auto vType = LLVM::getVectorType(
+      cast<VectorType>(containerType).getElementType(), mask.size(),
+      LLVM::isScalableVectorType(containerType));
   build(builder, state, vType, v1, v2, mask);
   state.addAttributes(attrs);
 }
@@ -2752,8 +2752,9 @@ static ParseResult parseShuffleType(AsmParser &parser, Type v1Type,
   if (!LLVM::isCompatibleVectorType(v1Type))
     return parser.emitError(parser.getCurrentLocation(),
                             "expected an LLVM compatible vector type");
-  resType = LLVM::getVectorType(LLVM::getVectorElementType(v1Type), mask.size(),
-                                LLVM::isScalableVectorType(v1Type));
+  resType =
+      LLVM::getVectorType(cast<VectorType>(v1Type).getElementType(),
+                          mask.size(), LLVM::isScalableVectorType(v1Type));
   return success();
 }
 
@@ -3318,7 +3319,7 @@ LogicalResult AtomicRMWOp::verify() {
     if (isCompatibleVectorType(valType)) {
       if (isScalableVectorType(valType))
         return emitOpError("expected LLVM IR fixed vector type");
-      Type elemType = getVectorElementType(valType);
+      Type elemType = llvm::cast<VectorType>(valType).getElementType();
       if (!isCompatibleFloatingPointType(elemType))
         return emitOpError(
             "expected LLVM IR floating point type for vector element");
@@ -3423,9 +3424,10 @@ static LogicalResult verifyExtOp(ExtOp op) {
       return op.emitError("input and output vectors are of incompatible shape");
     // Because this is a CastOp, the element of vectors is guaranteed to be an
     // integer.
-    inputType = cast<IntegerType>(getVectorElementType(op.getArg().getType()));
-    outputType =
-        cast<IntegerType>(getVectorElementType(op.getResult().getType()));
+    inputType = cast<IntegerType>(
+        cast<VectorType>(op.getArg().getType()).getElementType());
+    outputType = cast<IntegerType>(
+        cast<VectorType>(op.getResult().getType()).getElementType());
   } else {
     // Because this is a CastOp and arg is not a vector, arg is guaranteed to be
     // an integer.
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index 663adc3c34256..b3c2a29309528 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -821,12 +821,6 @@ bool mlir::LLVM::isCompatibleVectorType(Type type) {
   return false;
 }
 
-Type mlir::LLVM::getVectorElementType(Type type) {
-  auto vecTy = dyn_cast<VectorType>(type);
-  assert(vecTy && "incompatible with LLVM vector type");
-  return vecTy.getElementType();
-}
-
 llvm::ElementCount mlir::LLVM::getVectorNumElements(Type type) {
   auto vecTy = dyn_cast<VectorType>(type);
   assert(vecTy && "incompatible with LLVM vector type");
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 2859abdb41772..0d08f15d29b7d 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -139,8 +139,8 @@ static LogicalResult convertInstructionImpl(OpBuilder &odsBuilder,
   if (iface.isConvertibleInstruction(inst->getOpcode()))
     return iface.convertInstruction(odsBuilder, inst, llvmOperands,
                                     moduleImport);
-  // TODO: Implement the `convertInstruction` hooks in the
-  // `LLVMDialectLLVMIRImportInterface` and move the following include there.
+    // TODO: Implement the `convertInstruction` hooks in the
+    // `LLVMDialectLLVMIRImportInterface` and move the following include there.
 #include "mlir/Dialect/LLVMIR/LLVMOpFromLLVMIRConversions.inc"
   return failure();
 }
@@ -826,7 +826,7 @@ static Type getVectorTypeForAttr(Type type, ArrayRef<int64_t> arrayShape = {}) {
   }
 
   // An LLVM dialect vector can only contain scalars.
-  Type elementType = LLVM::getVectorElementType(type);
+  Type elementType = cast<VectorType>(type).getElementType();
   if (!elementType.isIntOrFloat())
     return {};
 
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index db55088d812e6..0cd6b1f20a1bf 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -515,21 +515,21 @@ func.func @extractvalue_wrong_nesting() {
 // -----
 
 func.func @invalid_vector_type_1(%arg0: vector<4xf32>, %arg1: i32, %arg2: f32) {
-  // expected-error@+1 {{'vector' must be LLVM dialect-compatible vector}}
+  // expected-error@+1 {{invalid kind of type specified: expected builtin.vector, but found 'f32'}}
   %0 = llvm.extractelement %arg2[%arg1 : i32] : f32
 }
 
 // -----
 
 func.func @invalid_vector_type_2(%arg0: vector<4xf32>, %arg1: i32, %arg2: f32) {
-  // expected-error@+1 {{'vector' must be LLVM dialect-compatible vector}}
+  // expected-error@+1 {{invalid kind of type specified: expected builtin.vector, but found 'f32'}}
   %0 = llvm.insertelement %arg2, %arg2[%arg1 : i32] : f32
 }
 
 // -----
 
 func.func @invalid_vector_type_3(%arg0: vector<4xf32>, %arg1: i32, %arg2: f32) {
-  // expected-error@+2 {{expected an LLVM compatible vector type}}
+  // expected-error@+1 {{invalid kind of type specified: expected builtin.vector, but found 'f32'}}
   %0 = llvm.shufflevector %arg2, %arg2 [0, 0, 0, 0, 7] : f32
 }
 
diff --git a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
index 7bb64542accdf..90c0f5ac55cb1 100644
--- a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
@@ -211,7 +211,7 @@ llvm.func @vec_reduce_fmax_intr_wrong_type(%arg0 : vector<4xi32>) -> i32 {
 // -----
 
 llvm.func @matrix_load_intr_wrong_type(%ptr : !llvm.ptr, %stride : i32) -> f32 {
-  // expected-error @below{{op result #0 must be LLVM dialect-compatible vector type, but got 'f32'}}
+  // expected-error @+2{{invalid kind of type specified: expected builtin.vector, but found 'f32'}}
   %0 = llvm.intr.matrix.column.major.load %ptr, <stride=%stride>
     { isVolatile = 0: i1, rows = 3: i32, columns = 16: i32} : f32 from !llvm.ptr stride i32
   llvm.return %0 : f32
@@ -229,7 +229,7 @@ llvm.func @matrix_store_intr_wrong_type(%matrix : vector<48xf32>, %ptr : i32, %s
 // -----
 
 llvm.func @matrix_multiply_intr_wrong_type(%arg0 : vector<64xf32>, %arg1 : f32) -> vector<12xf32> {
-  // expected-error @below{{op operand #1 must be LLVM dialect-compatible vector type, but got 'f32'}}
+  // expected-error @+2{{invalid kind of type specified: expected builtin.vector, but found 'f32'}}
   %0 = llvm.intr.matrix.multiply %arg0, %arg1
     { lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32} : (vector<64xf32>, f32) -> vector<12xf32>
   llvm.return %0 : vector<12xf32>
@@ -238,7 +238,7 @@ llvm.func @matrix_multiply_intr_wrong_type(%arg0 : vector<64xf32>, %arg1 : f32)
 // -----
 
 llvm.func @matrix_transpose_intr_wrong_type(%matrix : f32) -> vector<48xf32> {
-  // expected-error @below{{op operand #0 must be LLVM dialect-compatible vector type, but got 'f32'}}
+  // expected-error @below{{invalid kind of type specified: expected builtin.vector, but found 'f32'}}
   %0 = llvm.intr.matrix.transpose %matrix {rows = 3: i32, columns = 16: i32} : f32 into vector<48xf32>
   llvm.return %0 : vector<48xf32>
 }
@@ -286,7 +286,7 @@ llvm.func @masked_gather_intr_wrong_type_scalable(%ptrs : vector<7x!llvm.ptr>, %
 // -----
 
 llvm.func @masked_scatter_intr_wrong_type(%vec : f32, %ptrs : vector<7x!llvm.ptr>, %mask : vector<7xi1>) {
-  // expected-error @below{{op operand #0 must be LLVM dialect-compatible vector type, but got 'f32'}}
+  // expected-error @below{{invalid kind of type specified: expected builtin.vector, but found 'f32'}}
   llvm.intr.masked.scatter %vec, %ptrs, %mask { alignment = 1: i32} : f32, vector<7xi1> into vector<7x!llvm.ptr>
   llvm.return
 }



More information about the Mlir-commits mailing list