[Mlir-commits] [mlir] dcec2ca - Remove typeConverter from ConvertToLLVMPattern and use the existing one in ConversionPattern.

Christian Sigg llvmlistbot at llvm.org
Fri Dec 4 05:27:30 PST 2020


Author: Christian Sigg
Date: 2020-12-04T14:27:16+01:00
New Revision: dcec2ca5bd3d82ebbe57d47fc2bdd742d35e8947

URL: https://github.com/llvm/llvm-project/commit/dcec2ca5bd3d82ebbe57d47fc2bdd742d35e8947
DIFF: https://github.com/llvm/llvm-project/commit/dcec2ca5bd3d82ebbe57d47fc2bdd742d35e8947.diff

LOG: Remove typeConverter from ConvertToLLVMPattern and use the existing one in ConversionPattern.

ftynse

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D92564

Added: (none)

Modified: 
    mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
    mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp
    mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
    mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
    mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
    mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
    mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp

Removed: (none)


################################################################################
diff  --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
index 7b8bcdff4deb..bf41f29749de 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -71,7 +71,7 @@ class LLVMTypeConverter : public TypeConverter {
   /// Convert a function type.  The arguments and results are converted one by
   /// one and results are packed into a wrapped LLVM IR structure type. `result`
   /// is populated with argument mapping.
-  LLVM::LLVMType convertFunctionSignature(FunctionType type, bool isVariadic,
+  LLVM::LLVMType convertFunctionSignature(FunctionType funcTy, bool isVariadic,
                                           SignatureConversion &result);
 
   /// Convert a non-empty list of types to be returned from a function into a
@@ -485,6 +485,8 @@ class ConvertToLLVMPattern : public ConversionPattern {
   /// Returns the LLVM dialect.
   LLVM::LLVMDialect &getDialect() const;
 
+  LLVMTypeConverter *getTypeConverter() const;
+
   /// Gets the MLIR type wrapping the LLVM integer type whose bit width is
   /// defined by the used type converter.
   LLVM::LLVMType getIndexType() const;
@@ -556,10 +558,6 @@ class ConvertToLLVMPattern : public ConversionPattern {
                          Value allocatedPtr, Value alignedPtr,
                          ArrayRef<Value> sizes, ArrayRef<Value> strides,
                          ConversionPatternRewriter &rewriter) const;
-
-protected:
-  /// Reference to the type converter, with potential extensions.
-  LLVMTypeConverter &typeConverter;
 };
 
 /// Utility class for operation conversions targeting the LLVM dialect that
@@ -644,7 +642,7 @@ class OneToOneConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
   matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     return LLVM::detail::oneToOneRewrite(op, TargetOp::getOperationName(),
-                                         operands, this->typeConverter,
+                                         operands, *this->getTypeConverter(),
                                          rewriter);
   }
 };
@@ -666,9 +664,9 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
     static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
                                   SourceOp>::value,
                   "expected same operands and result type");
-    return LLVM::detail::vectorOneToOneRewrite(op, TargetOp::getOperationName(),
-                                               operands, this->typeConverter,
-                                               rewriter);
+    return LLVM::detail::vectorOneToOneRewrite(
+        op, TargetOp::getOperationName(), operands, *this->getTypeConverter(),
+        rewriter);
   }
 };
 

diff  --git a/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp b/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp
index 3950562539f6..fe06e12c8f21 100644
--- a/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp
+++ b/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp
@@ -86,7 +86,7 @@ struct MaskRndScaleOpPS512Conversion : public ConvertToLLVMPattern {
       return failure();
     return matchAndRewriteOneToOne<MaskRndScaleOp,
                                    LLVM::x86_avx512_mask_rndscale_ps_512>(
-        *this, this->typeConverter, op, operands, rewriter);
+        *this, *getTypeConverter(), op, operands, rewriter);
   }
 };
 
@@ -103,7 +103,7 @@ struct MaskRndScaleOpPD512Conversion : public ConvertToLLVMPattern {
       return failure();
     return matchAndRewriteOneToOne<MaskRndScaleOp,
                                    LLVM::x86_avx512_mask_rndscale_pd_512>(
-        *this, this->typeConverter, op, operands, rewriter);
+        *this, *getTypeConverter(), op, operands, rewriter);
   }
 };
 
@@ -120,7 +120,7 @@ struct ScaleFOpPS512Conversion : public ConvertToLLVMPattern {
       return failure();
     return matchAndRewriteOneToOne<MaskScaleFOp,
                                    LLVM::x86_avx512_mask_scalef_ps_512>(
-        *this, this->typeConverter, op, operands, rewriter);
+        *this, *getTypeConverter(), op, operands, rewriter);
   }
 };
 
@@ -137,7 +137,7 @@ struct ScaleFOpPD512Conversion : public ConvertToLLVMPattern {
       return failure();
     return matchAndRewriteOneToOne<MaskScaleFOp,
                                    LLVM::x86_avx512_mask_scalef_pd_512>(
-        *this, this->typeConverter, op, operands, rewriter);
+        *this, *getTypeConverter(), op, operands, rewriter);
   }
 };
 } // namespace

diff  --git a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
index ad84216d1e3b..810511194f68 100644
--- a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
+++ b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
@@ -72,7 +72,7 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
       : ConvertOpToLLVMPattern<OpTy>(typeConverter) {}
 
 protected:
-  MLIRContext *context = &this->typeConverter.getContext();
+  MLIRContext *context = &this->getTypeConverter()->getContext();
 
   LLVM::LLVMType llvmVoidType = LLVM::LLVMType::getVoidTy(context);
   LLVM::LLVMType llvmPointerType = LLVM::LLVMType::getInt8PtrTy(context);
@@ -81,7 +81,7 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
   LLVM::LLVMType llvmInt32Type = LLVM::LLVMType::getInt32Ty(context);
   LLVM::LLVMType llvmInt64Type = LLVM::LLVMType::getInt64Ty(context);
   LLVM::LLVMType llvmIntPtrType = LLVM::LLVMType::getIntNTy(
-      context, this->typeConverter.getPointerBitwidth(0));
+      context, this->getTypeConverter()->getPointerBitwidth(0));
 
   FunctionCallBuilder moduleLoadCallBuilder = {
       "mgpuModuleLoad",
@@ -333,8 +333,8 @@ LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
   auto elementType = memRefType.cast<UnrankedMemRefType>().getElementType();
   auto elementSize = getSizeInBytes(loc, elementType, rewriter);
 
-  auto arguments =
-      typeConverter.promoteOperands(loc, op->getOperands(), operands, rewriter);
+  auto arguments = getTypeConverter()->promoteOperands(loc, op->getOperands(),
+                                                       operands, rewriter);
   arguments.push_back(elementSize);
   hostRegisterCallBuilder.create(loc, rewriter, arguments);
 
@@ -486,7 +486,7 @@ Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateParamsArray(
     OpBuilder &builder) const {
   auto loc = launchOp.getLoc();
   auto numKernelOperands = launchOp.getNumKernelOperands();
-  auto arguments = typeConverter.promoteOperands(
+  auto arguments = getTypeConverter()->promoteOperands(
       loc, launchOp.getOperands().take_back(numKernelOperands),
       operands.take_back(numKernelOperands), builder);
   auto numArguments = arguments.size();

diff  --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
index a3fad7e71c84..69ea393e5df1 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
@@ -41,7 +41,7 @@ struct GPUFuncOpLowering : ConvertToLLVMPattern {
 
       uint64_t numElements = type.getNumElements();
 
-      auto elementType = typeConverter.convertType(type.getElementType())
+      auto elementType = typeConverter->convertType(type.getElementType())
                              .template cast<LLVM::LLVMType>();
       auto arrayType = LLVM::LLVMType::getArrayTy(elementType, numElements);
       std::string name = std::string(
@@ -54,14 +54,14 @@ struct GPUFuncOpLowering : ConvertToLLVMPattern {
     }
 
     // Rewrite the original GPU function to an LLVM function.
-    auto funcType = typeConverter.convertType(gpuFuncOp.getType())
+    auto funcType = typeConverter->convertType(gpuFuncOp.getType())
                         .template cast<LLVM::LLVMType>()
                         .getPointerElementTy();
 
     // Remap proper input types.
     TypeConverter::SignatureConversion signatureConversion(
         gpuFuncOp.front().getNumArguments());
-    typeConverter.convertFunctionSignature(
+    getTypeConverter()->convertFunctionSignature(
         gpuFuncOp.getType(), /*isVariadic=*/false, signatureConversion);
 
     // Create the new function operation. Only copy those attributes that are
@@ -110,7 +110,7 @@ struct GPUFuncOpLowering : ConvertToLLVMPattern {
         Value attribution = gpuFuncOp.getWorkgroupAttributions()[en.index()];
         auto type = attribution.getType().cast<MemRefType>();
         auto descr = MemRefDescriptor::fromStaticShape(
-            rewriter, loc, typeConverter, type, memory);
+            rewriter, loc, *getTypeConverter(), type, memory);
         signatureConversion.remapInput(numProperArguments + en.index(), descr);
       }
 
@@ -127,7 +127,7 @@ struct GPUFuncOpLowering : ConvertToLLVMPattern {
         // Explicitly drop memory space when lowering private memory
         // attributions since NVVM models it as `alloca`s in the default
         // memory space and does not support `alloca`s with addrspace(5).
-        auto ptrType = typeConverter.convertType(type.getElementType())
+        auto ptrType = typeConverter->convertType(type.getElementType())
                            .template cast<LLVM::LLVMType>()
                            .getPointerTo(AllocaAddrSpace);
         Value numElements = rewriter.create<LLVM::ConstantOp>(
@@ -136,7 +136,7 @@ struct GPUFuncOpLowering : ConvertToLLVMPattern {
         Value allocated = rewriter.create<LLVM::AllocaOp>(
             gpuFuncOp.getLoc(), ptrType, numElements, /*alignment=*/0);
         auto descr = MemRefDescriptor::fromStaticShape(
-            rewriter, loc, typeConverter, type, allocated);
+            rewriter, loc, *getTypeConverter(), type, allocated);
         signatureConversion.remapInput(
             numProperArguments + numWorkgroupAttributions + en.index(), descr);
       }
@@ -145,8 +145,8 @@ struct GPUFuncOpLowering : ConvertToLLVMPattern {
     // Move the region to the new function, update the entry block signature.
     rewriter.inlineRegionBefore(gpuFuncOp.getBody(), llvmFuncOp.getBody(),
                                 llvmFuncOp.end());
-    if (failed(rewriter.convertRegionTypes(&llvmFuncOp.getBody(), typeConverter,
-                                           &signatureConversion)))
+    if (failed(rewriter.convertRegionTypes(
+            &llvmFuncOp.getBody(), *typeConverter, &signatureConversion)))
       return failure();
 
     rewriter.eraseOp(gpuFuncOp);

diff  --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
index f32c664c17c4..b907703995d8 100644
--- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
+++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
@@ -135,8 +135,8 @@ class RangeOpConversion : public ConvertToLLVMPattern {
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     auto rangeOp = cast<RangeOp>(op);
-    auto rangeDescriptorTy =
-        convertRangeType(rangeOp.getType().cast<RangeType>(), typeConverter);
+    auto rangeDescriptorTy = convertRangeType(
+        rangeOp.getType().cast<RangeType>(), *getTypeConverter());
 
     edsc::ScopedContext context(rewriter, op->getLoc());
 
@@ -181,7 +181,7 @@ class ReshapeOpConversion : public ConvertToLLVMPattern {
     edsc::ScopedContext context(rewriter, op->getLoc());
     ReshapeOpAdaptor adaptor(operands);
     BaseViewConversionHelper baseDesc(adaptor.src());
-    BaseViewConversionHelper desc(typeConverter.convertType(dstType));
+    BaseViewConversionHelper desc(typeConverter->convertType(dstType));
     desc.setAllocatedPtr(baseDesc.allocatedPtr());
     desc.setAlignedPtr(baseDesc.alignedPtr());
     desc.setOffset(baseDesc.offset());
@@ -214,11 +214,11 @@ class SliceOpConversion : public ConvertToLLVMPattern {
 
     auto sliceOp = cast<SliceOp>(op);
     auto memRefType = sliceOp.getBaseViewType();
-    auto int64Ty = typeConverter.convertType(rewriter.getIntegerType(64))
+    auto int64Ty = typeConverter->convertType(rewriter.getIntegerType(64))
                        .cast<LLVM::LLVMType>();
 
     BaseViewConversionHelper desc(
-        typeConverter.convertType(sliceOp.getShapedType()));
+        typeConverter->convertType(sliceOp.getShapedType()));
 
     // TODO: extract sizes and emit asserts.
     SmallVector<Value, 4> strides(memRefType.getRank());

diff  --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
index 91e97ca1ec50..c589ef69f2c4 100644
--- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
+++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
@@ -35,7 +35,7 @@ struct RegionOpConversion : public ConvertToLLVMPattern {
                                          curOp.getAttrs());
     rewriter.inlineRegionBefore(curOp.region(), newOp.region(),
                                 newOp.region().end());
-    if (failed(rewriter.convertRegionTypes(&newOp.region(), typeConverter)))
+    if (failed(rewriter.convertRegionTypes(&newOp.region(), *typeConverter)))
       return failure();
 
     rewriter.eraseOp(op);

diff  --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
index 525a5be24485..f83f72d1d10e 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
@@ -224,7 +224,7 @@ class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
       spirv::GlobalVariableOp spirvGlobal = globalVariableMap[operand.index()];
       auto pointeeType =
           spirvGlobal.type().cast<spirv::PointerType>().getPointeeType();
-      auto dstGlobalType = typeConverter.convertType(pointeeType);
+      auto dstGlobalType = typeConverter->convertType(pointeeType);
       if (!dstGlobalType)
         return failure();
       std::string name =

diff  --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index f54ffc1c9d6c..17a065463297 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -446,8 +446,7 @@ ConvertToLLVMPattern::ConvertToLLVMPattern(StringRef rootOpName,
                                            MLIRContext *context,
                                            LLVMTypeConverter &typeConverter,
                                            PatternBenefit benefit)
-    : ConversionPattern(rootOpName, benefit, typeConverter, context),
-      typeConverter(typeConverter) {}
+    : ConversionPattern(rootOpName, benefit, typeConverter, context) {}
 
 //===----------------------------------------------------------------------===//
 // StructBuilder implementation
@@ -1013,27 +1012,32 @@ void UnrankedMemRefDescriptor::setStride(OpBuilder &builder, Location loc,
   builder.create<LLVM::StoreOp>(loc, stride, strideStoreGep);
 }
 
+LLVMTypeConverter *ConvertToLLVMPattern::getTypeConverter() const {
+  return static_cast<LLVMTypeConverter *>(
+      ConversionPattern::getTypeConverter());
+}
+
 LLVM::LLVMDialect &ConvertToLLVMPattern::getDialect() const {
-  return *typeConverter.getDialect();
+  return *getTypeConverter()->getDialect();
 }
 
 LLVM::LLVMType ConvertToLLVMPattern::getIndexType() const {
-  return typeConverter.getIndexType();
+  return getTypeConverter()->getIndexType();
 }
 
 LLVM::LLVMType
 ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace) const {
   return LLVM::LLVMType::getIntNTy(
-      &typeConverter.getContext(),
-      typeConverter.getPointerBitwidth(addressSpace));
+      &getTypeConverter()->getContext(),
+      getTypeConverter()->getPointerBitwidth(addressSpace));
 }
 
 LLVM::LLVMType ConvertToLLVMPattern::getVoidType() const {
-  return LLVM::LLVMType::getVoidTy(&typeConverter.getContext());
+  return LLVM::LLVMType::getVoidTy(&getTypeConverter()->getContext());
 }
 
 LLVM::LLVMType ConvertToLLVMPattern::getVoidPtrType() const {
-  return LLVM::LLVMType::getInt8PtrTy(&typeConverter.getContext());
+  return LLVM::LLVMType::getInt8PtrTy(&getTypeConverter()->getContext());
 }
 
 Value ConvertToLLVMPattern::createIndexConstant(
@@ -1086,7 +1090,7 @@ Value ConvertToLLVMPattern::getDataPtr(
 // Check if the MemRefType `type` is supported by the lowering. We currently
 // only support memrefs with identity maps.
 bool ConvertToLLVMPattern::isSupportedMemRefType(MemRefType type) const {
-  if (!typeConverter.convertType(type.getElementType()))
+  if (!typeConverter->convertType(type.getElementType()))
     return false;
   return type.getAffineMaps().empty() ||
          llvm::all_of(type.getAffineMaps(),
@@ -1095,7 +1099,7 @@ bool ConvertToLLVMPattern::isSupportedMemRefType(MemRefType type) const {
 
 Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const {
   auto elementType = type.getElementType();
-  auto structElementType = unwrap(typeConverter.convertType(elementType));
+  auto structElementType = unwrap(typeConverter->convertType(elementType));
   return structElementType.getPointerTo(type.getMemorySpace());
 }
 
@@ -1155,7 +1159,7 @@ Value ConvertToLLVMPattern::getSizeInBytes(
   //   %1 = ptrtoint %elementType* %0 to %indexType
   // which is a common pattern of getting the size of a type in bytes.
   auto convertedPtrType =
-      typeConverter.convertType(type).cast<LLVM::LLVMType>().getPointerTo();
+      typeConverter->convertType(type).cast<LLVM::LLVMType>().getPointerTo();
   auto nullPtr = rewriter.create<LLVM::NullOp>(loc, convertedPtrType);
   auto gep = rewriter.create<LLVM::GEPOp>(
       loc, convertedPtrType,
@@ -1179,7 +1183,7 @@ MemRefDescriptor ConvertToLLVMPattern::createMemRefDescriptor(
     Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr,
     ArrayRef<Value> sizes, ArrayRef<Value> strides,
     ConversionPatternRewriter &rewriter) const {
-  auto structType = typeConverter.convertType(memRefType);
+  auto structType = typeConverter->convertType(memRefType);
   auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType);
 
   // Field 1: Allocated pointer, used for malloc/free.
@@ -1347,7 +1351,7 @@ struct FuncOpConversionBase : public ConvertOpToLLVMPattern<FuncOp> {
     // LLVMTypeConverter provided to this legalization pattern.
     auto varargsAttr = funcOp.getAttrOfType<BoolAttr>("std.varargs");
     TypeConverter::SignatureConversion result(funcOp.getNumArguments());
-    auto llvmType = typeConverter.convertFunctionSignature(
+    auto llvmType = getTypeConverter()->convertFunctionSignature(
         funcOp.getType(), varargsAttr && varargsAttr.getValue(), result);
     if (!llvmType)
       return nullptr;
@@ -1379,7 +1383,7 @@ struct FuncOpConversionBase : public ConvertOpToLLVMPattern<FuncOp> {
         attributes);
     rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
                                 newFuncOp.end());
-    if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
+    if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter,
                                            &result)))
       return nullptr;
 
@@ -1402,14 +1406,14 @@ struct FuncOpConversion : public FuncOpConversionBase {
     if (!newFuncOp)
       return failure();
 
-    if (typeConverter.getOptions().emitCWrappers ||
+    if (getTypeConverter()->getOptions().emitCWrappers ||
         funcOp.getAttrOfType<UnitAttr>(kEmitIfaceAttrName)) {
       if (newFuncOp.isExternal())
-        wrapExternalFunction(rewriter, funcOp.getLoc(), typeConverter, funcOp,
-                             newFuncOp);
+        wrapExternalFunction(rewriter, funcOp.getLoc(), *getTypeConverter(),
+                             funcOp, newFuncOp);
       else
-        wrapForExternalCallers(rewriter, funcOp.getLoc(), typeConverter, funcOp,
-                               newFuncOp);
+        wrapForExternalCallers(rewriter, funcOp.getLoc(), *getTypeConverter(),
+                               funcOp, newFuncOp);
     }
 
     rewriter.eraseOp(funcOp);
@@ -1472,7 +1476,7 @@ struct BarePtrFuncOpConversion : public FuncOpConversionBase {
       rewriter.replaceUsesOfBlockArgument(arg, placeholder);
 
       Value desc = MemRefDescriptor::fromStaticShape(
-          rewriter, loc, typeConverter, memrefTy, arg);
+          rewriter, loc, *getTypeConverter(), memrefTy, arg);
       rewriter.replaceOp(placeholder, {desc});
     }
 
@@ -1757,7 +1761,7 @@ struct CreateComplexOpLowering
 
     // Pack real and imaginary part in a complex number struct.
     auto loc = op.getLoc();
-    auto structType = typeConverter.convertType(complexOp.getType());
+    auto structType = typeConverter->convertType(complexOp.getType());
     auto complexStruct = ComplexStructBuilder::undef(rewriter, loc, structType);
     complexStruct.setReal(rewriter, loc, transformed.real());
     complexStruct.setImaginary(rewriter, loc, transformed.imaginary());
@@ -1836,7 +1840,7 @@ struct AddCFOpLowering : public ConvertOpToLLVMPattern<AddCFOp> {
         unpackBinaryComplexOperands<AddCFOp>(op, operands, rewriter);
 
     // Initialize complex number struct for result.
-    auto structType = this->typeConverter.convertType(op.getType());
+    auto structType = typeConverter->convertType(op.getType());
     auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
 
     // Emit IR to add complex numbers.
@@ -1863,7 +1867,7 @@ struct SubCFOpLowering : public ConvertOpToLLVMPattern<SubCFOp> {
         unpackBinaryComplexOperands<SubCFOp>(op, operands, rewriter);
 
     // Initialize complex number struct for result.
-    auto structType = this->typeConverter.convertType(op.getType());
+    auto structType = typeConverter->convertType(op.getType());
     auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
 
     // Emit IR to substract complex numbers.
@@ -1887,7 +1891,7 @@ struct ConstantOpLowering : public ConvertOpToLLVMPattern<ConstantOp> {
                   ConversionPatternRewriter &rewriter) const override {
     // If constant refers to a function, convert it to "addressof".
     if (auto symbolRef = op.getValue().dyn_cast<FlatSymbolRefAttr>()) {
-      auto type = typeConverter.convertType(op.getResult().getType())
+      auto type = typeConverter->convertType(op.getResult().getType())
                       .dyn_cast_or_null<LLVM::LLVMType>();
       if (!type)
         return rewriter.notifyMatchFailure(op, "failed to convert result type");
@@ -1905,9 +1909,9 @@ struct ConstantOpLowering : public ConvertOpToLLVMPattern<ConstantOp> {
       return rewriter.notifyMatchFailure(
           op, "referring to a symbol outside of the current module");
 
-    return LLVM::detail::oneToOneRewrite(op,
-                                         LLVM::ConstantOp::getOperationName(),
-                                         operands, typeConverter, rewriter);
+    return LLVM::detail::oneToOneRewrite(
+        op, LLVM::ConstantOp::getOperationName(), operands, *getTypeConverter(),
+        rewriter);
   }
 };
 
@@ -1916,7 +1920,6 @@ struct AllocLikeOpLowering : public ConvertToLLVMPattern {
   using ConvertToLLVMPattern::createIndexConstant;
   using ConvertToLLVMPattern::getIndexType;
   using ConvertToLLVMPattern::getVoidPtrType;
-  using ConvertToLLVMPattern::typeConverter;
 
   explicit AllocLikeOpLowering(StringRef opName, LLVMTypeConverter &converter)
       : ConvertToLLVMPattern(opName, &converter.getContext(), converter) {}
@@ -2288,11 +2291,11 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
 
     if (numResults != 0) {
       if (!(packedResult =
-                this->typeConverter.packFunctionResults(resultTypes)))
+                this->getTypeConverter()->packFunctionResults(resultTypes)))
         return failure();
     }
 
-    auto promoted = this->typeConverter.promoteOperands(
+    auto promoted = this->getTypeConverter()->promoteOperands(
         callOp.getLoc(), /*opOperands=*/callOp->getOperands(), operands,
         rewriter);
     auto newOp = rewriter.create<LLVM::CallOp>(
@@ -2309,23 +2312,23 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
       results.reserve(numResults);
       for (unsigned i = 0; i < numResults; ++i) {
         auto type =
-            this->typeConverter.convertType(callOp.getResult(i).getType());
+            this->typeConverter->convertType(callOp.getResult(i).getType());
         results.push_back(rewriter.create<LLVM::ExtractValueOp>(
             callOp.getLoc(), type, newOp->getResult(0),
             rewriter.getI64ArrayAttr(i)));
       }
     }
 
-    if (this->typeConverter.getOptions().useBarePtrCallConv) {
+    if (this->getTypeConverter()->getOptions().useBarePtrCallConv) {
       // For the bare-ptr calling convention, promote memref results to
       // descriptors.
       assert(results.size() == resultTypes.size() &&
              "The number of arguments and types doesn't match");
-      this->typeConverter.promoteBarePtrsToDescriptors(
+      this->getTypeConverter()->promoteBarePtrsToDescriptors(
           rewriter, callOp.getLoc(), resultTypes, results);
     } else if (failed(copyUnrankedDescriptors(rewriter, callOp.getLoc(),
-                                              this->typeConverter, resultTypes,
-                                              results,
+                                              *this->getTypeConverter(),
+                                              resultTypes, results,
                                               /*toDynamic=*/false))) {
       return failure();
     }
@@ -2410,7 +2413,8 @@ struct GlobalMemrefOpLowering : public ConvertOpToLLVMPattern<GlobalMemrefOp> {
     if (!isSupportedMemRefType(type))
       return failure();
 
-    LLVM::LLVMType arrayTy = convertGlobalMemrefTypeToLLVM(type, typeConverter);
+    LLVM::LLVMType arrayTy =
+        convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
 
     LLVM::Linkage linkage =
         global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;
@@ -2449,14 +2453,15 @@ struct GetGlobalMemrefOpLowering : public AllocLikeOpLowering {
     MemRefType type = getGlobalOp.result().getType().cast<MemRefType>();
     unsigned memSpace = type.getMemorySpace();
 
-    LLVM::LLVMType arrayTy = convertGlobalMemrefTypeToLLVM(type, typeConverter);
+    LLVM::LLVMType arrayTy =
+        convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
     auto addressOf = rewriter.create<LLVM::AddressOfOp>(
         loc, arrayTy.getPointerTo(memSpace), getGlobalOp.name());
 
     // Get the address of the first element in the array by creating a GEP with
     // the address of the GV as the base, and (rank + 1) number of 0 indices.
     LLVM::LLVMType elementType =
-        unwrap(typeConverter.convertType(type.getElementType()));
+        unwrap(typeConverter->convertType(type.getElementType()));
     LLVM::LLVMType elementPtrType = elementType.getPointerTo(memSpace);
 
     SmallVector<Value, 4> operands = {addressOf};
@@ -2517,7 +2522,7 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<RsqrtOp> {
       return failure();
 
     return handleMultidimensionalVectors(
-        op.getOperation(), operands, typeConverter,
+        op.getOperation(), operands, *getTypeConverter(),
         [&](LLVM::LLVMType llvmVectorTy, ValueRange operands) {
           auto splatAttr = SplatElementsAttr::get(
               mlir::VectorType::get({llvmVectorTy.getVectorNumElements()},
@@ -2546,8 +2551,8 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<MemRefCastOp> {
     // a sanity check that the underlying structs are the same. Once op
     // semantics are relaxed we can revisit.
     if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
-      return success(typeConverter.convertType(srcType) ==
-                     typeConverter.convertType(dstType));
+      return success(typeConverter->convertType(srcType) ==
+                     typeConverter->convertType(dstType));
 
     // At least one of the operands is unranked type
     assert(srcType.isa<UnrankedMemRefType>() ||
@@ -2566,7 +2571,7 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<MemRefCastOp> {
 
     auto srcType = memRefCastOp.getOperand().getType();
     auto dstType = memRefCastOp.getType();
-    auto targetStructType = typeConverter.convertType(memRefCastOp.getType());
+    auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
     auto loc = memRefCastOp.getLoc();
 
     // For ranked/ranked case, just keep the original descriptor.
@@ -2581,7 +2586,7 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<MemRefCastOp> {
       auto srcMemRefType = srcType.cast<MemRefType>();
       int64_t rank = srcMemRefType.getRank();
       // ptr = AllocaOp sizeof(MemRefDescriptor)
-      auto ptr = typeConverter.promoteOneMemRefDescriptor(
+      auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
           loc, transformed.source(), rewriter);
       // voidptr = BitCastOp srcType* to void*
       auto voidPtr =
@@ -2589,7 +2594,7 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<MemRefCastOp> {
               .getResult();
       // rank = ConstantOp srcRank
       auto rankVal = rewriter.create<LLVM::ConstantOp>(
-          loc, typeConverter.convertType(rewriter.getIntegerType(64)),
+          loc, typeConverter->convertType(rewriter.getIntegerType(64)),
           rewriter.getI64IntegerAttr(rank));
       // undef = UndefOp
       UnrankedMemRefDescriptor memRefDesc =
@@ -2693,7 +2698,7 @@ struct MemRefReinterpretCastOpLowering
                                   Value *descriptor) const {
     MemRefType targetMemRefType =
         castOp.getResult().getType().cast<MemRefType>();
-    auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
+    auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
                                       .dyn_cast_or_null<LLVM::LLVMType>();
     if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
       return failure();
@@ -2704,8 +2709,9 @@ struct MemRefReinterpretCastOpLowering
 
     // Set allocated and aligned pointers.
     Value allocatedPtr, alignedPtr;
-    extractPointersAndOffset(loc, rewriter, typeConverter, castOp.source(),
-                             adaptor.source(), &allocatedPtr, &alignedPtr);
+    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
+                             castOp.source(), adaptor.source(), &allocatedPtr,
+                             &alignedPtr);
     desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
     desc.setAlignedPtr(rewriter, loc, alignedPtr);
 
@@ -2779,10 +2785,10 @@ struct MemRefReshapeOpLowering
     // Create the unranked memref descriptor that holds the ranked one. The
     // inner descriptor is allocated on stack.
     auto targetDesc = UnrankedMemRefDescriptor::undef(
-        rewriter, loc, unwrap(typeConverter.convertType(targetType)));
+        rewriter, loc, unwrap(typeConverter->convertType(targetType)));
     targetDesc.setRank(rewriter, loc, resultRank);
     SmallVector<Value, 4> sizes;
-    UnrankedMemRefDescriptor::computeSizes(rewriter, loc, typeConverter,
+    UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
                                            targetDesc, sizes);
     Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
         loc, getVoidPtrType(), sizes.front(), llvm::None);
@@ -2790,37 +2796,38 @@ struct MemRefReshapeOpLowering
 
     // Extract pointers and offset from the source memref.
     Value allocatedPtr, alignedPtr, offset;
-    extractPointersAndOffset(loc, rewriter, typeConverter, reshapeOp.source(),
-                             adaptor.source(), &allocatedPtr, &alignedPtr,
-                             &offset);
+    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
+                             reshapeOp.source(), adaptor.source(),
+                             &allocatedPtr, &alignedPtr, &offset);
 
     // Set pointers and offset.
     LLVM::LLVMType llvmElementType =
-        unwrap(typeConverter.convertType(elementType));
+        unwrap(typeConverter->convertType(elementType));
     LLVM::LLVMType elementPtrPtrType =
         llvmElementType.getPointerTo(addressSpace).getPointerTo();
     UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
                                               elementPtrPtrType, allocatedPtr);
-    UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, typeConverter,
+    UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
                                             underlyingDescPtr,
                                             elementPtrPtrType, alignedPtr);
-    UnrankedMemRefDescriptor::setOffset(rewriter, loc, typeConverter,
+    UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
                                         underlyingDescPtr, elementPtrPtrType,
                                         offset);
 
     // Use the offset pointer as base for further addressing. Copy over the new
     // shape and compute strides. For this, we create a loop from rank-1 to 0.
     Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
-        rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
+        rewriter, loc, *getTypeConverter(), underlyingDescPtr,
+        elementPtrPtrType);
     Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
-        rewriter, loc, typeConverter, targetSizesBase, resultRank);
+        rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
     Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
     Value oneIndex = createIndexConstant(rewriter, loc, 1);
     Value resultRankMinusOne =
         rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);
 
     Block *initBlock = rewriter.getInsertionBlock();
-    LLVM::LLVMType indexType = typeConverter.getIndexType();
+    LLVM::LLVMType indexType = getTypeConverter()->getIndexType();
     Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());
 
     Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
@@ -2854,11 +2861,11 @@ struct MemRefReshapeOpLowering
     Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
         loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg});
     Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep);
-    UnrankedMemRefDescriptor::setSize(rewriter, loc, typeConverter,
+    UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
                                       targetSizesBase, indexArg, size);
 
     // Write stride value and compute next one.
-    UnrankedMemRefDescriptor::setStride(rewriter, loc, typeConverter,
+    UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
                                         targetStridesBase, indexArg, strideArg);
     Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);
 
@@ -2892,7 +2899,7 @@ struct DialectCastOpLowering
                   ConversionPatternRewriter &rewriter) const override {
     LLVM::DialectCastOp::Adaptor transformed(operands);
     if (transformed.in().getType() !=
-        typeConverter.convertType(castOp.getType())) {
+        typeConverter->convertType(castOp.getType())) {
       return failure();
     }
     rewriter.replaceOp(castOp, transformed.in());
@@ -2942,15 +2949,16 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<DimOp> {
     Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
     Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
         loc,
-        typeConverter.convertType(scalarMemRefType)
+        typeConverter->convertType(scalarMemRefType)
             .cast<LLVM::LLVMType>()
             .getPointerTo(addressSpace),
         underlyingRankedDesc);
 
     // Get pointer to offset field of memref<element_type> descriptor.
-    Type indexPtrTy = typeConverter.getIndexType().getPointerTo(addressSpace);
+    Type indexPtrTy =
+        getTypeConverter()->getIndexType().getPointerTo(addressSpace);
     Value two = rewriter.create<LLVM::ConstantOp>(
-        loc, typeConverter.convertType(rewriter.getI32Type()),
+        loc, typeConverter->convertType(rewriter.getI32Type()),
         rewriter.getI32IntegerAttr(2));
     Value offsetPtr = rewriter.create<LLVM::GEPOp>(
         loc, indexPtrTy, scalarMemRefDescPtr,
@@ -3082,7 +3090,7 @@ struct PrefetchOpLowering : public LoadStoreOpLowering<PrefetchOp> {
                                          transformed.indices(), rewriter);
 
     // Replace with llvm.prefetch.
-    auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
+    auto llvmI32Type = typeConverter->convertType(rewriter.getIntegerType(32));
     auto isWrite = rewriter.create<LLVM::ConstantOp>(
         loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isWrite()));
     auto localityHint = rewriter.create<LLVM::ConstantOp>(
@@ -3110,7 +3118,7 @@ struct IndexCastOpLowering : public ConvertOpToLLVMPattern<IndexCastOp> {
     IndexCastOpAdaptor transformed(operands);
 
     auto targetType =
-        this->typeConverter.convertType(indexCastOp.getResult().getType())
+        typeConverter->convertType(indexCastOp.getResult().getType())
             .cast<LLVM::LLVMType>();
     auto sourceType = transformed.in().getType().cast<LLVM::LLVMType>();
     unsigned targetBits = targetType.getIntegerBitWidth();
@@ -3144,7 +3152,7 @@ struct CmpIOpLowering : public ConvertOpToLLVMPattern<CmpIOp> {
     CmpIOpAdaptor transformed(operands);
 
     rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
-        cmpiOp, typeConverter.convertType(cmpiOp.getResult().getType()),
+        cmpiOp, typeConverter->convertType(cmpiOp.getResult().getType()),
         rewriter.getI64IntegerAttr(static_cast<int64_t>(
             convertCmpPredicate<LLVM::ICmpPredicate>(cmpiOp.getPredicate()))),
         transformed.lhs(), transformed.rhs());
@@ -3162,7 +3170,7 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern<CmpFOp> {
     CmpFOpAdaptor transformed(operands);
 
     rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
-        cmpfOp, typeConverter.convertType(cmpfOp.getResult().getType()),
+        cmpfOp, typeConverter->convertType(cmpfOp.getResult().getType()),
         rewriter.getI64IntegerAttr(static_cast<int64_t>(
             convertCmpPredicate<LLVM::FCmpPredicate>(cmpfOp.getPredicate()))),
         transformed.lhs(), transformed.rhs());
@@ -3248,7 +3256,7 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<ReturnOp> {
     unsigned numArguments = op.getNumOperands();
     SmallVector<Value, 4> updatedOperands;
 
-    if (typeConverter.getOptions().useBarePtrCallConv) {
+    if (getTypeConverter()->getOptions().useBarePtrCallConv) {
       // For the bare-ptr calling convention, extract the aligned pointer to
       // be returned from the memref descriptor.
       for (auto it : llvm::zip(op->getOperands(), operands)) {
@@ -3266,7 +3274,7 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<ReturnOp> {
       }
     } else {
       updatedOperands = llvm::to_vector<4>(operands);
-      copyUnrankedDescriptors(rewriter, loc, typeConverter,
+      copyUnrankedDescriptors(rewriter, loc, *getTypeConverter(),
                               op.getOperands().getTypes(), updatedOperands,
                               /*toDynamic=*/true);
     }
@@ -3285,7 +3293,7 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<ReturnOp> {
 
     // Otherwise, we need to pack the arguments into an LLVM struct type before
     // returning.
-    auto packedType = typeConverter.packFunctionResults(
+    auto packedType = getTypeConverter()->packFunctionResults(
         llvm::to_vector<4>(op.getOperandTypes()));
 
     Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType);
@@ -3323,11 +3331,11 @@ struct SplatOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
       return failure();
 
     // First insert it into an undef vector so we can shuffle it.
-    auto vectorType = typeConverter.convertType(splatOp.getType());
+    auto vectorType = typeConverter->convertType(splatOp.getType());
     Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType);
     auto zero = rewriter.create<LLVM::ConstantOp>(
         splatOp.getLoc(),
-        typeConverter.convertType(rewriter.getIntegerType(32)),
+        typeConverter->convertType(rewriter.getIntegerType(32)),
         rewriter.getZeroAttr(rewriter.getIntegerType(32)));
 
     auto v = rewriter.create<LLVM::InsertElementOp>(
@@ -3360,7 +3368,8 @@ struct SplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
 
     // First insert it into an undef vector so we can shuffle it.
     auto loc = splatOp.getLoc();
-    auto vectorTypeInfo = extractNDVectorTypeInfo(resultType, typeConverter);
+    auto vectorTypeInfo =
+        extractNDVectorTypeInfo(resultType, *getTypeConverter());
     auto llvmArrayTy = vectorTypeInfo.llvmArrayTy;
     auto llvmVectorTy = vectorTypeInfo.llvmVectorTy;
     if (!llvmArrayTy || !llvmVectorTy)
@@ -3373,7 +3382,7 @@ struct SplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
     // places within the returned descriptor.
     Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvmVectorTy);
     auto zero = rewriter.create<LLVM::ConstantOp>(
-        loc, typeConverter.convertType(rewriter.getIntegerType(32)),
+        loc, typeConverter->convertType(rewriter.getIntegerType(32)),
         rewriter.getZeroAttr(rewriter.getIntegerType(32)));
     Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvmVectorTy, vdesc,
                                                      adaptor.input(), zero);
@@ -3418,7 +3427,7 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
 
     auto sourceMemRefType = subViewOp.source().getType().cast<MemRefType>();
     auto sourceElementTy =
-        typeConverter.convertType(sourceMemRefType.getElementType())
+        typeConverter->convertType(sourceMemRefType.getElementType())
             .dyn_cast_or_null<LLVM::LLVMType>();
 
     auto viewMemRefType = subViewOp.getType();
@@ -3429,9 +3438,9 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
                             extractFromI64ArrayAttr(subViewOp.static_strides()))
                             .cast<MemRefType>();
     auto targetElementTy =
-        typeConverter.convertType(viewMemRefType.getElementType())
+        typeConverter->convertType(viewMemRefType.getElementType())
             .dyn_cast<LLVM::LLVMType>();
-    auto targetDescTy = typeConverter.convertType(viewMemRefType)
+    auto targetDescTy = typeConverter->convertType(viewMemRefType)
                             .dyn_cast_or_null<LLVM::LLVMType>();
     if (!sourceElementTy || !targetDescTy)
       return failure();
@@ -3477,7 +3486,7 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
       strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));
 
     // Offset.
-    auto llvmIndexType = typeConverter.convertType(rewriter.getIndexType());
+    auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
     if (!ShapedType::isDynamicStrideOrOffset(offset)) {
       targetMemRef.setConstantOffset(rewriter, loc, offset);
     } else {
@@ -3553,7 +3562,7 @@ class TransposeOpLowering : public ConvertOpToLLVMPattern<TransposeOp> {
       return rewriter.replaceOp(transposeOp, {viewMemRef}), success();
 
     auto targetMemRef = MemRefDescriptor::undef(
-        rewriter, loc, typeConverter.convertType(transposeOp.getShapedType()));
+        rewriter, loc, typeConverter->convertType(transposeOp.getShapedType()));
 
     // Copy the base and aligned pointers from the old descriptor to the new
     // one.
@@ -3629,10 +3638,10 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<ViewOp> {
 
     auto viewMemRefType = viewOp.getType();
     auto targetElementTy =
-        typeConverter.convertType(viewMemRefType.getElementType())
+        typeConverter->convertType(viewMemRefType.getElementType())
             .dyn_cast<LLVM::LLVMType>();
     auto targetDescTy =
-        typeConverter.convertType(viewMemRefType).dyn_cast<LLVM::LLVMType>();
+        typeConverter->convertType(viewMemRefType).dyn_cast<LLVM::LLVMType>();
     if (!targetDescTy)
       return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
              failure();
@@ -3825,7 +3834,7 @@ struct GenericAtomicRMWOpLowering
     auto loc = atomicOp.getLoc();
     GenericAtomicRMWOp::Adaptor adaptor(operands);
     LLVM::LLVMType valueType =
-        typeConverter.convertType(atomicOp.getResult().getType())
+        typeConverter->convertType(atomicOp.getResult().getType())
             .cast<LLVM::LLVMType>();
 
     // Split the block into initial, loop, and ending parts.

diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index b3fa315b75a3..85d3e2bddd66 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -309,7 +309,7 @@ class VectorMatmulOpConversion : public ConvertToLLVMPattern {
     auto matmulOp = cast<vector::MatmulOp>(op);
     auto adaptor = vector::MatmulOpAdaptor(operands);
     rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
-        op, typeConverter.convertType(matmulOp.res().getType()), adaptor.lhs(),
+        op, typeConverter->convertType(matmulOp.res().getType()), adaptor.lhs(),
         adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(),
         matmulOp.rhs_columns());
     return success();
@@ -331,7 +331,7 @@ class VectorFlatTransposeOpConversion : public ConvertToLLVMPattern {
     auto transOp = cast<vector::FlatTransposeOp>(op);
     auto adaptor = vector::FlatTransposeOpAdaptor(operands);
     rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
-        transOp, typeConverter.convertType(transOp.res().getType()),
+        transOp, typeConverter->convertType(transOp.res().getType()),
         adaptor.matrix(), transOp.rows(), transOp.columns());
     return success();
   }
@@ -354,10 +354,10 @@ class VectorMaskedLoadOpConversion : public ConvertToLLVMPattern {
 
     // Resolve alignment.
     unsigned align;
-    if (failed(getMemRefAlignment(typeConverter, load, align)))
+    if (failed(getMemRefAlignment(*getTypeConverter(), load, align)))
       return failure();
 
-    auto vtype = typeConverter.convertType(load.getResultVectorType());
+    auto vtype = typeConverter->convertType(load.getResultVectorType());
     Value ptr;
     if (failed(getBasePtr(rewriter, loc, adaptor.base(), load.getMemRefType(),
                           vtype, ptr)))
@@ -387,10 +387,10 @@ class VectorMaskedStoreOpConversion : public ConvertToLLVMPattern {
 
     // Resolve alignment.
     unsigned align;
-    if (failed(getMemRefAlignment(typeConverter, store, align)))
+    if (failed(getMemRefAlignment(*getTypeConverter(), store, align)))
       return failure();
 
-    auto vtype = typeConverter.convertType(store.getValueVectorType());
+    auto vtype = typeConverter->convertType(store.getValueVectorType());
     Value ptr;
     if (failed(getBasePtr(rewriter, loc, adaptor.base(), store.getMemRefType(),
                           vtype, ptr)))
@@ -420,7 +420,7 @@ class VectorGatherOpConversion : public ConvertToLLVMPattern {
 
     // Resolve alignment.
     unsigned align;
-    if (failed(getMemRefAlignment(typeConverter, gather, align)))
+    if (failed(getMemRefAlignment(*getTypeConverter(), gather, align)))
       return failure();
 
     // Get index ptrs.
@@ -433,7 +433,7 @@ class VectorGatherOpConversion : public ConvertToLLVMPattern {
 
     // Replace with the gather intrinsic.
     rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
-        gather, typeConverter.convertType(vType), ptrs, adaptor.mask(),
+        gather, typeConverter->convertType(vType), ptrs, adaptor.mask(),
         adaptor.pass_thru(), rewriter.getI32IntegerAttr(align));
     return success();
   }
@@ -456,7 +456,7 @@ class VectorScatterOpConversion : public ConvertToLLVMPattern {
 
     // Resolve alignment.
     unsigned align;
-    if (failed(getMemRefAlignment(typeConverter, scatter, align)))
+    if (failed(getMemRefAlignment(*getTypeConverter(), scatter, align)))
       return failure();
 
     // Get index ptrs.
@@ -497,7 +497,7 @@ class VectorExpandLoadOpConversion : public ConvertToLLVMPattern {
 
     auto vType = expand.getResultVectorType();
     rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
-        op, typeConverter.convertType(vType), ptr, adaptor.mask(),
+        op, typeConverter->convertType(vType), ptr, adaptor.mask(),
         adaptor.pass_thru());
     return success();
   }
@@ -545,7 +545,7 @@ class VectorReductionOpConversion : public ConvertToLLVMPattern {
     auto reductionOp = cast<vector::ReductionOp>(op);
     auto kind = reductionOp.kind();
     Type eltType = reductionOp.dest().getType();
-    Type llvmType = typeConverter.convertType(eltType);
+    Type llvmType = typeConverter->convertType(eltType);
     if (eltType.isIntOrIndex()) {
       // Integer reductions: add/mul/min/max/and/or/xor.
       if (kind == "add")
@@ -580,39 +580,40 @@ class VectorReductionOpConversion : public ConvertToLLVMPattern {
       else
         return failure();
       return success();
-
-    } else if (eltType.isa<FloatType>()) {
-      // Floating-point reductions: add/mul/min/max
-      if (kind == "add") {
-        // Optional accumulator (or zero).
-        Value acc = operands.size() > 1 ? operands[1]
-                                        : rewriter.create<LLVM::ConstantOp>(
-                                              op->getLoc(), llvmType,
-                                              rewriter.getZeroAttr(eltType));
-        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
-            op, llvmType, acc, operands[0],
-            rewriter.getBoolAttr(reassociateFPReductions));
-      } else if (kind == "mul") {
-        // Optional accumulator (or one).
-        Value acc = operands.size() > 1
-                        ? operands[1]
-                        : rewriter.create<LLVM::ConstantOp>(
-                              op->getLoc(), llvmType,
-                              rewriter.getFloatAttr(eltType, 1.0));
-        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
-            op, llvmType, acc, operands[0],
-            rewriter.getBoolAttr(reassociateFPReductions));
-      } else if (kind == "min")
-        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(
-            op, llvmType, operands[0]);
-      else if (kind == "max")
-        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(
-            op, llvmType, operands[0]);
-      else
-        return failure();
-      return success();
     }
-    return failure();
+
+    if (!eltType.isa<FloatType>())
+      return failure();
+
+    // Floating-point reductions: add/mul/min/max
+    if (kind == "add") {
+      // Optional accumulator (or zero).
+      Value acc = operands.size() > 1 ? operands[1]
+                                      : rewriter.create<LLVM::ConstantOp>(
+                                            op->getLoc(), llvmType,
+                                            rewriter.getZeroAttr(eltType));
+      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
+          op, llvmType, acc, operands[0],
+          rewriter.getBoolAttr(reassociateFPReductions));
+    } else if (kind == "mul") {
+      // Optional accumulator (or one).
+      Value acc = operands.size() > 1
+                      ? operands[1]
+                      : rewriter.create<LLVM::ConstantOp>(
+                            op->getLoc(), llvmType,
+                            rewriter.getFloatAttr(eltType, 1.0));
+      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
+          op, llvmType, acc, operands[0],
+          rewriter.getBoolAttr(reassociateFPReductions));
+    } else if (kind == "min")
+      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(op, llvmType,
+                                                            operands[0]);
+    else if (kind == "max")
+      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(op, llvmType,
+                                                            operands[0]);
+    else
+      return failure();
+    return success();
   }
 
 private:
@@ -663,7 +664,7 @@ class VectorShuffleOpConversion : public ConvertToLLVMPattern {
     auto v1Type = shuffleOp.getV1VectorType();
     auto v2Type = shuffleOp.getV2VectorType();
     auto vectorType = shuffleOp.getVectorType();
-    Type llvmType = typeConverter.convertType(vectorType);
+    Type llvmType = typeConverter->convertType(vectorType);
     auto maskArrayAttr = shuffleOp.mask();
 
     // Bail if result type cannot be lowered.
@@ -695,9 +696,9 @@ class VectorShuffleOpConversion : public ConvertToLLVMPattern {
         extPos -= v1Dim;
         value = adaptor.v2();
       }
-      Value extract = extractOne(rewriter, typeConverter, loc, value, llvmType,
-                                 rank, extPos);
-      insert = insertOne(rewriter, typeConverter, loc, insert, extract,
+      Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
+                                 llvmType, rank, extPos);
+      insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
                          llvmType, rank, insPos++);
     }
     rewriter.replaceOp(op, insert);
@@ -718,7 +719,7 @@ class VectorExtractElementOpConversion : public ConvertToLLVMPattern {
     auto adaptor = vector::ExtractElementOpAdaptor(operands);
     auto extractEltOp = cast<vector::ExtractElementOp>(op);
     auto vectorType = extractEltOp.getVectorType();
-    auto llvmType = typeConverter.convertType(vectorType.getElementType());
+    auto llvmType = typeConverter->convertType(vectorType.getElementType());
 
     // Bail if result type cannot be lowered.
     if (!llvmType)
@@ -745,7 +746,7 @@ class VectorExtractOpConversion : public ConvertToLLVMPattern {
     auto extractOp = cast<vector::ExtractOp>(op);
     auto vectorType = extractOp.getVectorType();
     auto resultType = extractOp.getResult().getType();
-    auto llvmResultType = typeConverter.convertType(resultType);
+    auto llvmResultType = typeConverter->convertType(resultType);
     auto positionArrayAttr = extractOp.position();
 
     // Bail if result type cannot be lowered.
@@ -769,7 +770,7 @@ class VectorExtractOpConversion : public ConvertToLLVMPattern {
       auto nMinusOnePositionAttrs =
           ArrayAttr::get(positionAttrs.drop_back(), context);
       extracted = rewriter.create<LLVM::ExtractValueOp>(
-          loc, typeConverter.convertType(oneDVectorType), extracted,
+          loc, typeConverter->convertType(oneDVectorType), extracted,
           nMinusOnePositionAttrs);
     }
 
@@ -833,7 +834,7 @@ class VectorInsertElementOpConversion : public ConvertToLLVMPattern {
     auto adaptor = vector::InsertElementOpAdaptor(operands);
     auto insertEltOp = cast<vector::InsertElementOp>(op);
     auto vectorType = insertEltOp.getDestVectorType();
-    auto llvmType = typeConverter.convertType(vectorType);
+    auto llvmType = typeConverter->convertType(vectorType);
 
     // Bail if result type cannot be lowered.
     if (!llvmType)
@@ -860,7 +861,7 @@ class VectorInsertOpConversion : public ConvertToLLVMPattern {
     auto insertOp = cast<vector::InsertOp>(op);
     auto sourceType = insertOp.getSourceType();
     auto destVectorType = insertOp.getDestVectorType();
-    auto llvmResultType = typeConverter.convertType(destVectorType);
+    auto llvmResultType = typeConverter->convertType(destVectorType);
     auto positionArrayAttr = insertOp.position();
 
     // Bail if result type cannot be lowered.
@@ -887,7 +888,7 @@ class VectorInsertOpConversion : public ConvertToLLVMPattern {
       auto nMinusOnePositionAttrs =
           ArrayAttr::get(positionAttrs.drop_back(), context);
       extracted = rewriter.create<LLVM::ExtractValueOp>(
-          loc, typeConverter.convertType(oneDVectorType), extracted,
+          loc, typeConverter->convertType(oneDVectorType), extracted,
           nMinusOnePositionAttrs);
     }
 
@@ -895,7 +896,7 @@ class VectorInsertOpConversion : public ConvertToLLVMPattern {
     auto i64Type = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
     auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
     Value inserted = rewriter.create<LLVM::InsertElementOp>(
-        loc, typeConverter.convertType(oneDVectorType), extracted,
+        loc, typeConverter->convertType(oneDVectorType), extracted,
         adaptor.source(), constant);
 
     // Potential insertion of resulting 1-D vector into array.
@@ -1000,7 +1001,7 @@ class VectorInsertStridedSliceOpDifferentRankRewritePattern
     Value extracted =
         rewriter.create<ExtractOp>(loc, op.dest(),
                                    getI64SubArray(op.offsets(), /*dropFront=*/0,
-                                                  /*dropFront=*/rankRest));
+                                                  /*dropBack=*/rankRest));
     // A different pattern will kick in for InsertStridedSlice with matching
     // ranks.
     auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
@@ -1010,7 +1011,7 @@ class VectorInsertStridedSliceOpDifferentRankRewritePattern
     rewriter.replaceOpWithNewOp<InsertOp>(
         op, stridedSliceInnerOp.getResult(), op.dest(),
         getI64SubArray(op.offsets(), /*dropFront=*/0,
-                       /*dropFront=*/rankRest));
+                       /*dropBack=*/rankRest));
     return success();
   }
 };
@@ -1144,7 +1145,7 @@ class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
       return failure();
     MemRefDescriptor sourceMemRef(operands[0]);
 
-    auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
+    auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
                                       .dyn_cast_or_null<LLVM::LLVMType>();
     if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
       return failure();
@@ -1234,7 +1235,7 @@ class VectorTransferConversion : public ConvertToLLVMPattern {
     if (!strides)
       return failure();
 
-    auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
+    auto toLLVMTy = [&](Type t) { return typeConverter->convertType(t); };
 
     Location loc = op->getLoc();
     MemRefType memRefType = xferOp.getMemRefType();
@@ -1279,8 +1280,8 @@ class VectorTransferConversion : public ConvertToLLVMPattern {
           loc, vecTy.getPointerTo(), dataPtr);
 
     if (!xferOp.isMaskedDim(0))
-      return replaceTransferOpWithLoadOrStore(rewriter, typeConverter, loc,
-                                              xferOp, operands, vectorDataPtr);
+      return replaceTransferOpWithLoadOrStore(
+          rewriter, *getTypeConverter(), loc, xferOp, operands, vectorDataPtr);
 
     // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
     // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
@@ -1297,8 +1298,8 @@ class VectorTransferConversion : public ConvertToLLVMPattern {
                                        vecWidth, dim, &off);
 
     // 5. Rewrite as a masked read / write.
-    return replaceTransferOpWithMasked(rewriter, typeConverter, loc, xferOp,
-                                       operands, vectorDataPtr, mask);
+    return replaceTransferOpWithMasked(rewriter, *getTypeConverter(), loc,
+                                       xferOp, operands, vectorDataPtr, mask);
   }
 
 private:
@@ -1331,7 +1332,7 @@ class VectorPrintOpConversion : public ConvertToLLVMPattern {
     auto adaptor = vector::PrintOpAdaptor(operands);
     Type printType = printOp.getPrintType();
 
-    if (typeConverter.convertType(printType) == nullptr)
+    if (typeConverter->convertType(printType) == nullptr)
       return failure();
 
     // Make sure element type has runtime support.
@@ -1421,10 +1422,10 @@ class VectorPrintOpConversion : public ConvertToLLVMPattern {
     for (int64_t d = 0; d < dim; ++d) {
       auto reducedType =
           rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
-      auto llvmType = typeConverter.convertType(
+      auto llvmType = typeConverter->convertType(
           rank > 1 ? reducedType : vectorType.getElementType());
-      Value nestedVal =
-          extractOne(rewriter, typeConverter, loc, value, llvmType, rank, d);
+      Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
+                                   llvmType, rank, d);
       emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
                 conversion);
       if (d != dim - 1)

diff --git a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
index 26b8bec1f3fc..61f094746a0a 100644
--- a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
+++ b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
@@ -79,7 +79,7 @@ class VectorTransferConversion : public ConvertToLLVMPattern {
     if (!xferOp.isMaskedDim(0))
       return failure();
 
-    auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
+    auto toLLVMTy = [&](Type t) { return typeConverter->convertType(t); };
     LLVM::LLVMType vecTy =
         toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
     unsigned vecWidth = vecTy.getVectorNumElements();
@@ -142,9 +142,9 @@ class VectorTransferConversion : public ConvertToLLVMPattern {
     Value int32Zero = rewriter.create<LLVM::ConstantOp>(
         loc, toLLVMTy(i32Ty),
         rewriter.getIntegerAttr(rewriter.getIntegerType(32), 0));
-    return replaceTransferOpWithMubuf(rewriter, operands, typeConverter, loc,
-                                      xferOp, vecTy, dwordConfig, int32Zero,
-                                      int32Zero, int1False, int1False);
+    return replaceTransferOpWithMubuf(
+        rewriter, operands, *getTypeConverter(), loc, xferOp, vecTy,
+        dwordConfig, int32Zero, int32Zero, int1False, int1False);
   }
 };
 } // end anonymous namespace


        


More information about the Mlir-commits mailing list