[Mlir-commits] [mlir] [mlir][SPIRV] Fix build (2) (PR #111265)

Matthias Springer llvmlistbot at llvm.org
Sat Oct 5 12:56:54 PDT 2024


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/111265

(No PR description was provided.)

From c0bbf0c40a32123d275f02d6ec8fbd24b99b3ba5 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sat, 5 Oct 2024 21:56:08 +0200
Subject: [PATCH] [mlir][SPIRV] Fix build (2)

---
 .../Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp    | 31 ++++++++++---------
 1 file changed, 16 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index e36e3951a31ecc..74c169c9a7e76a 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -149,7 +149,7 @@ static Value optionallyTruncateOrExtend(Location loc, Value value,
 
 /// Broadcasts the value to vector with `numElements` number of elements.
 static Value broadcast(Location loc, Value toBroadcast, unsigned numElements,
-                       const LLVMTypeConverter &typeConverter,
+                       const TypeConverter &typeConverter,
                        ConversionPatternRewriter &rewriter) {
   auto vectorType = VectorType::get(numElements, toBroadcast.getType());
   auto llvmVectorType = typeConverter.convertType(vectorType);
@@ -166,7 +166,7 @@ static Value broadcast(Location loc, Value toBroadcast, unsigned numElements,
 
 /// Broadcasts the value. If `srcType` is a scalar, the value remains unchanged.
 static Value optionallyBroadcast(Location loc, Value value, Type srcType,
-                                 const LLVMTypeConverter &typeConverter,
+                                 const TypeConverter &typeConverter,
                                  ConversionPatternRewriter &rewriter) {
   if (auto vectorType = dyn_cast<VectorType>(srcType)) {
     unsigned numElements = vectorType.getNumElements();
@@ -186,8 +186,7 @@ static Value optionallyBroadcast(Location loc, Value value, Type srcType,
 /// Then cast `Offset` and `Count` if their bit width is different
 /// from `Base` bit width.
 static Value processCountOrOffset(Location loc, Value value, Type srcType,
-                                  Type dstType,
-                                  const LLVMTypeConverter &converter,
+                                  Type dstType, const TypeConverter &converter,
                                   ConversionPatternRewriter &rewriter) {
   Value broadcasted =
       optionallyBroadcast(loc, value, srcType, converter, rewriter);
@@ -197,7 +196,7 @@ static Value processCountOrOffset(Location loc, Value value, Type srcType,
 /// Converts SPIR-V struct with a regular (according to `VulkanLayoutUtils`)
 /// offset to LLVM struct. Otherwise, the conversion is not supported.
 static Type convertStructTypeWithOffset(spirv::StructType type,
-                                        const LLVMTypeConverter &converter) {
+                                        const TypeConverter &converter) {
   if (type != VulkanLayoutUtils::decorateType(type))
     return nullptr;
 
@@ -210,7 +209,7 @@ static Type convertStructTypeWithOffset(spirv::StructType type,
 
 /// Converts SPIR-V struct with no offset to packed LLVM struct.
 static Type convertStructTypePacked(spirv::StructType type,
-                                    const LLVMTypeConverter &converter) {
+                                    const TypeConverter &converter) {
   SmallVector<Type> elementsVector;
   if (failed(converter.convertTypes(type.getElementTypes(), elementsVector)))
     return nullptr;
@@ -227,10 +226,11 @@ static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter,
 }
 
 /// Utility for `spirv.Load` and `spirv.Store` conversion.
-static LogicalResult replaceWithLoadOrStore(
-    Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter,
-    const LLVMTypeConverter &typeConverter, unsigned alignment, bool isVolatile,
-    bool isNonTemporal) {
+static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands,
+                                            ConversionPatternRewriter &rewriter,
+                                            const TypeConverter &typeConverter,
+                                            unsigned alignment, bool isVolatile,
+                                            bool isNonTemporal) {
   if (auto loadOp = dyn_cast<spirv::LoadOp>(op)) {
     auto dstType = typeConverter.convertType(loadOp.getType());
     if (!dstType)
@@ -271,7 +271,7 @@ static std::optional<Type> convertArrayType(spirv::ArrayType type,
 /// Converts SPIR-V pointer type to LLVM pointer. Pointer's storage class is not
 /// modelled at the moment.
 static Type convertPointerType(spirv::PointerType type,
-                               const LLVMTypeConverter &converter,
+                               const TypeConverter &converter,
                                spirv::ClientAPI clientAPI) {
   unsigned addressSpace =
       storageClassToAddressSpace(clientAPI, type.getStorageClass());
@@ -292,7 +292,7 @@ static std::optional<Type> convertRuntimeArrayType(spirv::RuntimeArrayType type,
 /// Converts SPIR-V struct to LLVM struct. There is no support of structs with
 /// member decorations. Also, only natural offset is supported.
 static Type convertStructType(spirv::StructType type,
-                              const LLVMTypeConverter &converter) {
+                              const TypeConverter &converter) {
   SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
   type.getMemberDecorations(memberDecorations);
   if (!memberDecorations.empty())
@@ -1378,9 +1378,10 @@ class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> {
     auto funcType = funcOp.getFunctionType();
     TypeConverter::SignatureConversion signatureConverter(
         funcType.getNumInputs());
-    auto llvmType = getTypeConverter()->convertFunctionSignature(
-        funcType, /*isVariadic=*/false, /*useBarePtrCallConv=*/false,
-        signatureConverter);
+    auto llvmType = static_cast<const LLVMTypeConverter *>(getTypeConverter())
+                        ->convertFunctionSignature(
+                            funcType, /*isVariadic=*/false,
+                            /*useBarePtrCallConv=*/false, signatureConverter);
     if (!llvmType)
       return failure();
 



More information about the Mlir-commits mailing list