[Mlir-commits] [mlir] [mlir][IR] Add `ScalarTypeInterface` and use as `VectorType` element type (PR #132400)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Mar 21 06:52:02 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-amdgpu

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

This commit adds a new builtin type interface: `ScalarTypeInterface`.

Instead of maintaining a list of valid element types for `VectorType`, restrict valid element types to types that implement `ScalarTypeInterface`.

---

Patch is 77.13 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/132400.diff


54 Files Affected:

- (modified) mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h (+1-1) 
- (modified) mlir/include/mlir/IR/BuiltinDialectBytecode.td (+3-1) 
- (modified) mlir/include/mlir/IR/BuiltinTypeInterfaces.td (+34-1) 
- (modified) mlir/include/mlir/IR/BuiltinTypes.h (+3-3) 
- (modified) mlir/include/mlir/IR/BuiltinTypes.td (+31-8) 
- (modified) mlir/include/mlir/IR/CommonTypeConstraints.td (+4) 
- (modified) mlir/lib/AsmParser/TypeParser.cpp (+8-1) 
- (modified) mlir/lib/CAPI/IR/BuiltinTypes.cpp (+7-6) 
- (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+5-4) 
- (modified) mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp (+4-2) 
- (modified) mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp (+4-2) 
- (modified) mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp (+2-1) 
- (modified) mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp (+2-1) 
- (modified) mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp (+3-2) 
- (modified) mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp (+3-2) 
- (modified) mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp (+2-1) 
- (modified) mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp (+3-2) 
- (modified) mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp (+4-1) 
- (modified) mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp (+1-1) 
- (modified) mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp (+6-4) 
- (modified) mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp (+2-1) 
- (modified) mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp (+4-4) 
- (modified) mlir/lib/Dialect/ArmSME/IR/Utils.cpp (+1-1) 
- (modified) mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp (+9-6) 
- (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (+8-6) 
- (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp (+5-3) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+43-27) 
- (modified) mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp (+6-3) 
- (modified) mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp (+2-1) 
- (modified) mlir/lib/Dialect/Quant/IR/QuantTypes.cpp (+9-6) 
- (modified) mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp (+2-1) 
- (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp (+2-1) 
- (modified) mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp (+8-5) 
- (modified) mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp (+4-2) 
- (modified) mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp (+1-1) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp (+2-1) 
- (modified) mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp (+2-1) 
- (modified) mlir/lib/Dialect/Traits.cpp (+1-1) 
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+7-5) 
- (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp (+3-3) 
- (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp (+2-1) 
- (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp (+4-2) 
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+4-3) 
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (+28-14) 
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+3-2) 
- (modified) mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp (+2-1) 
- (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp (+4-3) 
- (modified) mlir/lib/IR/BuiltinTypes.cpp (+5-3) 
- (modified) mlir/lib/Target/LLVMIR/ModuleImport.cpp (+3-3) 
- (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp (+2-1) 
- (modified) mlir/test/IR/invalid-builtin-types.mlir (+1-1) 
- (modified) mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp (+2-1) 
- (modified) mlir/unittests/Dialect/ArmSME/TileTypeConversionTest.cpp (+1-1) 
- (modified) mlir/unittests/IR/ShapedTypeTest.cpp (+5-5) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index 1f40eb6fc693c..50b419bce78e5 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -68,7 +68,7 @@ scf::ForOp createLoopOverTileSlices(
 bool isMultipleOfSMETileVectorType(VectorType vType);
 
 /// Creates a vector type for the SME tile of `elementType`.
-VectorType getSMETileTypeForElement(Type elementType);
+VectorType getSMETileTypeForElement(ScalarTypeInterface elementType);
 
 /// Erase trivially dead tile ops from a function.
 void eraseTriviallyDeadTileOps(IRRewriter &rewriter,
diff --git a/mlir/include/mlir/IR/BuiltinDialectBytecode.td b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
index 0208e8cdbf293..7e79a17119c5a 100644
--- a/mlir/include/mlir/IR/BuiltinDialectBytecode.td
+++ b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
@@ -295,6 +295,8 @@ def VectorType : DialectType<(type
   Type:$elementType
 )> {
   let printerPredicate = "!$_val.isScalable()";
+  // Note: Element type must implement ScalarTypeInterface.
+  let cBuilder = "get<$_resultType>(context, shape, llvm::cast<ScalarTypeInterface>(elementType))";
 }
 
 def VectorTypeWithScalableDims : DialectType<(type
@@ -304,7 +306,7 @@ def VectorTypeWithScalableDims : DialectType<(type
 )> {
   let printerPredicate = "$_val.isScalable()";
   // Note: order of serialization does not match order of builder.
-  let cBuilder = "get<$_resultType>(context, shape, elementType, scalableDims)";
+  let cBuilder = "get<$_resultType>(context, shape, llvm::cast<ScalarTypeInterface>(elementType), scalableDims)";
 }
 }
 
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index 8aa2c55570153..71bd4df762d2c 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -16,7 +16,40 @@
 
 include "mlir/IR/OpBase.td"
 
-def FloatTypeInterface : TypeInterface<"FloatType"> {
+//===----------------------------------------------------------------------===//
+// ScalarTypeInterface
+//===----------------------------------------------------------------------===//
+
+def ScalarTypeInterface : TypeInterface<"ScalarTypeInterface"> {
+  let cppNamespace = "::mlir";
+  let description = [{
+    Indicates that this type is a scalar type.
+
+    The bitwidth of a scalar type is a fixed constant but may be unknown in the
+    absence of data layout information.
+
+    Scalar types are POD (plain-old-data) entities that have an in-memory
+    representation: scalar values can be loaded from and stored to memory, so
+    abstract types like function types or async tokens cannot be scalar types.
+
+    Scalar types should be limited to types that can lower to something that
+    egress dialects would consider a valid vector element type.
+  }];
+
+  let methods = [
+    InterfaceMethod<[{
+      Return the bitwidth of this type, if it has an inherent bitwidth. I.e., a
+      bitwidth that is known in the absence of data layout information.
+    }],
+    "std::optional<uint64_t>", "getInherentBitwidth", (ins)>
+  ];
+}
+
+//===----------------------------------------------------------------------===//
+// FloatTypeInterface
+//===----------------------------------------------------------------------===//
+
+def FloatTypeInterface : TypeInterface<"FloatType", [ScalarTypeInterface]> {
   let cppNamespace = "::mlir";
   let description = [{
     This type interface should be implemented by all floating-point types. It
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index df1e02732617d..a1950cda6318a 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -275,7 +275,7 @@ class VectorType::Builder {
         scalableDims(other.getScalableDims()) {}
 
   /// Build from scratch.
-  Builder(ArrayRef<int64_t> shape, Type elementType,
+  Builder(ArrayRef<int64_t> shape, ScalarTypeInterface elementType,
           ArrayRef<bool> scalableDims = {})
       : elementType(elementType), shape(shape), scalableDims(scalableDims) {}
 
@@ -286,7 +286,7 @@ class VectorType::Builder {
     return *this;
   }
 
-  Builder &setElementType(Type newElementType) {
+  Builder &setElementType(ScalarTypeInterface newElementType) {
     elementType = newElementType;
     return *this;
   }
@@ -312,7 +312,7 @@ class VectorType::Builder {
   }
 
 private:
-  Type elementType;
+  ScalarTypeInterface elementType;
   CopyOnWriteArrayRef<int64_t> shape;
   CopyOnWriteArrayRef<bool> scalableDims;
 };
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index af474b3e3ec47..2f03d51913855 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -85,6 +85,14 @@ class Builtin_FloatType<string name, string mnemonic,
         DeclareTypeInterfaceMethods<
             FloatTypeInterface,
             ["getFloatSemantics"] # declaredInterfaceMethods>]> {
+
+  let extraClassDeclaration = [{
+    /// Return the bitwidth of this type. This is an interface method of
+    /// ScalarTypeInterface.
+    std::optional<uint64_t> getInherentBitwidth() {
+      return static_cast<uint64_t>(::llvm::cast<FloatType>(*this).getWidth());
+    }
+  }];
 }
 
 // Float types that are cached in MLIRContext.
@@ -93,6 +101,12 @@ class Builtin_CachedFloatType<string name, string mnemonic,
     : Builtin_FloatType<name, mnemonic, declaredInterfaceMethods> {
   let extraClassDeclaration = [{
     static }] # name # [{Type get(MLIRContext *context);
+
+    /// Return the bitwidth of this type. This is an interface method of
+    /// ScalarTypeInterface.
+    std::optional<uint64_t> getInherentBitwidth() {
+      return static_cast<uint64_t>(::llvm::cast<FloatType>(*this).getWidth());
+    }
   }];
 }
 
@@ -447,7 +461,7 @@ def Builtin_Function : Builtin_Type<"Function", "function"> {
 // IndexType
 //===----------------------------------------------------------------------===//
 
-def Builtin_Index : Builtin_Type<"Index", "index"> {
+def Builtin_Index : Builtin_Type<"Index", "index", [ScalarTypeInterface]> {
   let summary = "Integer-like type with unknown platform-dependent bit width";
   let description = [{
     Syntax:
@@ -467,6 +481,12 @@ def Builtin_Index : Builtin_Type<"Index", "index"> {
   let extraClassDeclaration = [{
     static IndexType get(MLIRContext *context);
 
+    /// IndexType has no inherent bitwidth, so this always returns
+    /// std::nullopt. This is an interface method of ScalarTypeInterface.
+    std::optional<uint64_t> getInherentBitwidth() const {
+      return std::nullopt;
+    }
+
     /// Storage bit width used for IndexType by internal compiler data
     /// structures.
     static constexpr unsigned kInternalStorageBitWidth = 64;
@@ -477,7 +497,8 @@ def Builtin_Index : Builtin_Type<"Index", "index"> {
 // IntegerType
 //===----------------------------------------------------------------------===//
 
-def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
+def Builtin_Integer
+    : Builtin_Type<"Integer", "integer", [ScalarTypeInterface]> {
   let summary = "Integer type with arbitrary precision up to a fixed limit";
   let description = [{
     Syntax:
@@ -531,6 +552,12 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
     /// Return null if the scaled element type cannot be represented.
     IntegerType scaleElementBitwidth(unsigned scale);
 
+    /// Return the bitwidth of this type. This is an interface method of
+    /// ScalarTypeInterface.
+    std::optional<uint64_t> getInherentBitwidth() const {
+      return static_cast<uint64_t>(getWidth());
+    }
+
     /// Integer representation maximal bitwidth.
     /// Note: This is aligned with the maximum width of llvm::IntegerType.
     static constexpr unsigned kMaxWidth = (1 << 24) - 1;
@@ -1249,10 +1276,6 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
 // VectorType
 //===----------------------------------------------------------------------===//
 
-def Builtin_VectorTypeElementType : AnyTypeOf<[AnyInteger, Index, AnyFloat]> {
-  let cppFunctionName = "isValidVectorTypeElementType";
-}
-
 def Builtin_Vector : Builtin_Type<"Vector", "vector",
     [ShapedTypeInterface, ValueSemantics], "Type"> {
   let summary = "Multi-dimensional SIMD vector type";
@@ -1303,12 +1326,12 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector",
   }];
   let parameters = (ins
     ArrayRefParameter<"int64_t">:$shape,
-    Builtin_VectorTypeElementType:$elementType,
+    AnyScalarType:$elementType,
     ArrayRefParameter<"bool">:$scalableDims
   );
   let builders = [
     TypeBuilderWithInferredContext<(ins
-      "ArrayRef<int64_t>":$shape, "Type":$elementType,
+      "ArrayRef<int64_t>":$shape, "ScalarTypeInterface":$elementType,
       CArg<"ArrayRef<bool>", "{}">:$scalableDims
     ), [{
       // While `scalableDims` is optional, its default value should be
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 601517717978e..709c7dc213ff4 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -203,6 +203,10 @@ class ConfinedType<Type type, list<Pred> predicates, string summary = "",
     list<Pred> predicateList = predicates;
 }
 
+def AnyScalarType : Type<
+    CPred<"::llvm::isa<::mlir::ScalarTypeInterface>($_self)">,
+        "scalar type", "::mlir::ScalarTypeInterface">;
+
 // Integer types.
 
 // Any integer type irrespective of its width and signedness semantics.
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 21bb0ec3d0d51..1ccd16e1b3abd 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -491,7 +491,14 @@ VectorType Parser::parseVectorType() {
   if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
     return nullptr;
 
-  return getChecked<VectorType>(loc, dimensions, elementType, scalableDims);
+  auto scalarElementType = dyn_cast<ScalarTypeInterface>(elementType);
+  if (!scalarElementType) {
+    emitWrongTokenError("vector type requires scalar element type");
+    return nullptr;
+  }
+
+  return getChecked<VectorType>(loc, dimensions, scalarElementType,
+                                scalableDims);
 }
 
 /// Parse a dimension list in a vector type. This populates the dimension list.
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index a080adf0f8103..80e8c239689bb 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -364,21 +364,22 @@ bool mlirTypeIsAVector(MlirType type) {
 MlirType mlirVectorTypeGet(intptr_t rank, const int64_t *shape,
                            MlirType elementType) {
   return wrap(VectorType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)),
-                              unwrap(elementType)));
+                              cast<ScalarTypeInterface>(unwrap(elementType))));
 }
 
 MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank,
                                   const int64_t *shape, MlirType elementType) {
   return wrap(VectorType::getChecked(
       unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
-      unwrap(elementType)));
+      cast<ScalarTypeInterface>(unwrap(elementType))));
 }
 
 MlirType mlirVectorTypeGetScalable(intptr_t rank, const int64_t *shape,
                                    const bool *scalable, MlirType elementType) {
-  return wrap(VectorType::get(
-      llvm::ArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
-      llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
+  return wrap(
+      VectorType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)),
+                      cast<ScalarTypeInterface>(unwrap(elementType)),
+                      llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
 }
 
 MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
@@ -387,7 +388,7 @@ MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
                                           MlirType elementType) {
   return wrap(VectorType::getChecked(
       unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
-      unwrap(elementType),
+      cast<ScalarTypeInterface>(unwrap(elementType)),
       llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
 }
 
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 949424db7c4d6..bedebabc49087 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -283,7 +283,7 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
 
     Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
 
-    Type i32 = rewriter.getI32Type();
+    auto i32 = rewriter.getI32Type();
 
     // Get the type size in bytes.
     DataLayout dataLayout = DataLayout::closest(gpuOp);
@@ -560,7 +560,7 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
 
   int64_t numBits =
       vectorType.getNumElements() * elemType.getIntOrFloatBitWidth();
-  Type i32 = rewriter.getI32Type();
+  auto i32 = rewriter.getI32Type();
   Type intrinsicInType = numBits <= 32
                              ? (Type)rewriter.getIntegerType(numBits)
                              : (Type)VectorType::get(numBits / 32, i32);
@@ -1099,8 +1099,9 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
           operand =
               rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
         }
-        auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
-            32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
+        auto llvmVecType = typeConverter->convertType(
+            mlir::VectorType::get(32 / operandType.getIntOrFloatBitWidth(),
+                                  cast<ScalarTypeInterface>(llvmSrcIntType)));
         Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType);
         operand = rewriter.create<LLVM::InsertElementOp>(
             loc, undefVec, operand, createI32Constant(rewriter, loc, 0));
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 27be54728c1a1..d17a610e2ac2a 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -250,7 +250,8 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
   if (saturateFP8)
     in = clampInput(rewriter, loc, outElemType, in);
   auto inVectorTy = dyn_cast<VectorType>(in.getType());
-  VectorType truncResType = VectorType::get(4, outElemType);
+  VectorType truncResType =
+      VectorType::get(4, cast<ScalarTypeInterface>(outElemType));
   if (!inVectorTy) {
     Value asFloat = castToF32(in, loc, rewriter);
     Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
@@ -331,7 +332,8 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
   Location loc = op.getLoc();
   Value in = op.getIn();
   Type outElemType = getElementTypeOrSelf(op.getOut().getType());
-  VectorType truncResType = VectorType::get(2, outElemType);
+  VectorType truncResType =
+      VectorType::get(2, cast<ScalarTypeInterface>(outElemType));
   auto inVectorTy = dyn_cast<VectorType>(in.getType());
 
   // Handle the case where input type is not a vector type
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 9c4dfa27b1447..13ff632c18b40 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -322,7 +322,8 @@ struct ConstantCompositeOpPattern final
         dstAttrType =
             RankedTensorType::get(dstAttrType.getShape(), dstElemType);
       else
-        dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
+        dstAttrType = VectorType::get(dstAttrType.getShape(),
+                                      cast<ScalarTypeInterface>(dstElemType));
 
       dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
     }
@@ -908,7 +909,8 @@ class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
       // cases. Extend them to 32-bit and do comparision then.
       Type type = rewriter.getI32Type();
       if (auto vectorType = dyn_cast<VectorType>(dstType))
-        type = VectorType::get(vectorType.getShape(), type);
+        type = VectorType::get(vectorType.getShape(),
+                               cast<ScalarTypeInterface>(type));
       Value extLhs =
           rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs());
       Value extRhs =
diff --git a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
index 9c6de938a7108..d984ab5d932b4 100644
--- a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
+++ b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
@@ -33,7 +33,8 @@ class Sdot2dLoweringPattern : public OpRewritePattern<Sdot2dOp> {
   /// arm.neon.intr.sdot
   LogicalResult matchAndRewrite(Sdot2dOp op,
                                 PatternRewriter &rewriter) const override {
-    Type elemType = cast<VectorType>(op.getB().getType()).getElementType();
+    ScalarTypeInterface elemType =
+        cast<VectorType>(op.getB().getType()).getElementType();
     int length = cast<VectorType>(op.getB().getType()).getShape()[0] *
                  Sdot2dOp::kReductionSize;
     VectorType flattenedVectorType = VectorType::get({length}, elemType);
diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
index 4bd94bcebf290..6a04bd39f2d8e 100644
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -303,7 +303,8 @@ static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
   auto floatType = cast<FloatType>(getElementTypeOrSelf(lhs.getType()));
   Type i1Type = builder.getI1Type();
   if (auto vecType = dyn_cast<VectorType>(lhs.getType()))
-    i1Type = VectorType::get(vecType.getShape(), i1Type);
+    i1Type =
+        VectorType::get(vecType.getShape(), cast<ScalarTypeInterface>(i1Type));
   Value cmp = builder.create<LLVM::FCmpOp>(
       loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
       lhs, rhs);
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index ea251e4564ea8..d170a1f01dada 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -638,8 +638,9 @@ FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
   if (!elementType)
     return {};
   if (type.getShape().empty())
-    return VectorType::get({1}, elementType);
-  Type vectorType = VectorType::get(type.getShape().back(), elementType,
+    return VectorType::get({1}, cast<ScalarTypeInterface>(elementType));
+  Type vectorType = VectorType::get(type.getShape().back(),
+                                    cast<ScalarTypeInterface>(elementType),
                                     type.getScalableDims().back());
   assert(LLVM::isCompatibleVectorType(vectorType) &&
          "expected vector type compatible with the LLVM dialect");
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 1b83794b5f450..6676477b9e34b 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -154,7 +154,7 @@ struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
     if (auto vectorType = dyn_cast<VectorType>(type)) {
       assert(vectorType.getRank() == 1);
       int count = vectorType.getNumElements();
-      intType = VectorType::get(count, intType);
+      intType = VectorType::get(count, cast<ScalarTypeInterface>(intType));
 
       SmallVector<Value> signSplat(count, signMask);
       signMask =
@@ -380,7 +380,8 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
     auto operandType = adaptor.getRhs().getType();
     if (auto vector...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/132400


More information about the Mlir-commits mailing list