[llvm-branch-commits] [mlir] dcec2ca - Remove typeConverter from ConvertToLLVMPattern and use the existing one in ConversionPattern.
Christian Sigg via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Fri Dec 4 05:33:00 PST 2020
Author: Christian Sigg
Date: 2020-12-04T14:27:16+01:00
New Revision: dcec2ca5bd3d82ebbe57d47fc2bdd742d35e8947
URL: https://github.com/llvm/llvm-project/commit/dcec2ca5bd3d82ebbe57d47fc2bdd742d35e8947
DIFF: https://github.com/llvm/llvm-project/commit/dcec2ca5bd3d82ebbe57d47fc2bdd742d35e8947.diff
LOG: Remove typeConverter from ConvertToLLVMPattern and use the existing one in ConversionPattern.
ftynse
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D92564
Added:
Modified:
mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp
mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
index 7b8bcdff4deb..bf41f29749de 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -71,7 +71,7 @@ class LLVMTypeConverter : public TypeConverter {
/// Convert a function type. The arguments and results are converted one by
/// one and results are packed into a wrapped LLVM IR structure type. `result`
/// is populated with argument mapping.
- LLVM::LLVMType convertFunctionSignature(FunctionType type, bool isVariadic,
+ LLVM::LLVMType convertFunctionSignature(FunctionType funcTy, bool isVariadic,
SignatureConversion &result);
/// Convert a non-empty list of types to be returned from a function into a
@@ -485,6 +485,8 @@ class ConvertToLLVMPattern : public ConversionPattern {
/// Returns the LLVM dialect.
LLVM::LLVMDialect &getDialect() const;
+ LLVMTypeConverter *getTypeConverter() const;
+
/// Gets the MLIR type wrapping the LLVM integer type whose bit width is
/// defined by the used type converter.
LLVM::LLVMType getIndexType() const;
@@ -556,10 +558,6 @@ class ConvertToLLVMPattern : public ConversionPattern {
Value allocatedPtr, Value alignedPtr,
ArrayRef<Value> sizes, ArrayRef<Value> strides,
ConversionPatternRewriter &rewriter) const;
-
-protected:
- /// Reference to the type converter, with potential extensions.
- LLVMTypeConverter &typeConverter;
};
/// Utility class for operation conversions targeting the LLVM dialect that
@@ -644,7 +642,7 @@ class OneToOneConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
return LLVM::detail::oneToOneRewrite(op, TargetOp::getOperationName(),
- operands, this->typeConverter,
+ operands, *this->getTypeConverter(),
rewriter);
}
};
@@ -666,9 +664,9 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
SourceOp>::value,
"expected same operands and result type");
- return LLVM::detail::vectorOneToOneRewrite(op, TargetOp::getOperationName(),
- operands, this->typeConverter,
- rewriter);
+ return LLVM::detail::vectorOneToOneRewrite(
+ op, TargetOp::getOperationName(), operands, *this->getTypeConverter(),
+ rewriter);
}
};
diff --git a/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp b/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp
index 3950562539f6..fe06e12c8f21 100644
--- a/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp
+++ b/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp
@@ -86,7 +86,7 @@ struct MaskRndScaleOpPS512Conversion : public ConvertToLLVMPattern {
return failure();
return matchAndRewriteOneToOne<MaskRndScaleOp,
LLVM::x86_avx512_mask_rndscale_ps_512>(
- *this, this->typeConverter, op, operands, rewriter);
+ *this, *getTypeConverter(), op, operands, rewriter);
}
};
@@ -103,7 +103,7 @@ struct MaskRndScaleOpPD512Conversion : public ConvertToLLVMPattern {
return failure();
return matchAndRewriteOneToOne<MaskRndScaleOp,
LLVM::x86_avx512_mask_rndscale_pd_512>(
- *this, this->typeConverter, op, operands, rewriter);
+ *this, *getTypeConverter(), op, operands, rewriter);
}
};
@@ -120,7 +120,7 @@ struct ScaleFOpPS512Conversion : public ConvertToLLVMPattern {
return failure();
return matchAndRewriteOneToOne<MaskScaleFOp,
LLVM::x86_avx512_mask_scalef_ps_512>(
- *this, this->typeConverter, op, operands, rewriter);
+ *this, *getTypeConverter(), op, operands, rewriter);
}
};
@@ -137,7 +137,7 @@ struct ScaleFOpPD512Conversion : public ConvertToLLVMPattern {
return failure();
return matchAndRewriteOneToOne<MaskScaleFOp,
LLVM::x86_avx512_mask_scalef_pd_512>(
- *this, this->typeConverter, op, operands, rewriter);
+ *this, *getTypeConverter(), op, operands, rewriter);
}
};
} // namespace
diff --git a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
index ad84216d1e3b..810511194f68 100644
--- a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
+++ b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
@@ -72,7 +72,7 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
: ConvertOpToLLVMPattern<OpTy>(typeConverter) {}
protected:
- MLIRContext *context = &this->typeConverter.getContext();
+ MLIRContext *context = &this->getTypeConverter()->getContext();
LLVM::LLVMType llvmVoidType = LLVM::LLVMType::getVoidTy(context);
LLVM::LLVMType llvmPointerType = LLVM::LLVMType::getInt8PtrTy(context);
@@ -81,7 +81,7 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
LLVM::LLVMType llvmInt32Type = LLVM::LLVMType::getInt32Ty(context);
LLVM::LLVMType llvmInt64Type = LLVM::LLVMType::getInt64Ty(context);
LLVM::LLVMType llvmIntPtrType = LLVM::LLVMType::getIntNTy(
- context, this->typeConverter.getPointerBitwidth(0));
+ context, this->getTypeConverter()->getPointerBitwidth(0));
FunctionCallBuilder moduleLoadCallBuilder = {
"mgpuModuleLoad",
@@ -333,8 +333,8 @@ LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
auto elementType = memRefType.cast<UnrankedMemRefType>().getElementType();
auto elementSize = getSizeInBytes(loc, elementType, rewriter);
- auto arguments =
- typeConverter.promoteOperands(loc, op->getOperands(), operands, rewriter);
+ auto arguments = getTypeConverter()->promoteOperands(loc, op->getOperands(),
+ operands, rewriter);
arguments.push_back(elementSize);
hostRegisterCallBuilder.create(loc, rewriter, arguments);
@@ -486,7 +486,7 @@ Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateParamsArray(
OpBuilder &builder) const {
auto loc = launchOp.getLoc();
auto numKernelOperands = launchOp.getNumKernelOperands();
- auto arguments = typeConverter.promoteOperands(
+ auto arguments = getTypeConverter()->promoteOperands(
loc, launchOp.getOperands().take_back(numKernelOperands),
operands.take_back(numKernelOperands), builder);
auto numArguments = arguments.size();
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
index a3fad7e71c84..69ea393e5df1 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
@@ -41,7 +41,7 @@ struct GPUFuncOpLowering : ConvertToLLVMPattern {
uint64_t numElements = type.getNumElements();
- auto elementType = typeConverter.convertType(type.getElementType())
+ auto elementType = typeConverter->convertType(type.getElementType())
.template cast<LLVM::LLVMType>();
auto arrayType = LLVM::LLVMType::getArrayTy(elementType, numElements);
std::string name = std::string(
@@ -54,14 +54,14 @@ struct GPUFuncOpLowering : ConvertToLLVMPattern {
}
// Rewrite the original GPU function to an LLVM function.
- auto funcType = typeConverter.convertType(gpuFuncOp.getType())
+ auto funcType = typeConverter->convertType(gpuFuncOp.getType())
.template cast<LLVM::LLVMType>()
.getPointerElementTy();
// Remap proper input types.
TypeConverter::SignatureConversion signatureConversion(
gpuFuncOp.front().getNumArguments());
- typeConverter.convertFunctionSignature(
+ getTypeConverter()->convertFunctionSignature(
gpuFuncOp.getType(), /*isVariadic=*/false, signatureConversion);
// Create the new function operation. Only copy those attributes that are
@@ -110,7 +110,7 @@ struct GPUFuncOpLowering : ConvertToLLVMPattern {
Value attribution = gpuFuncOp.getWorkgroupAttributions()[en.index()];
auto type = attribution.getType().cast<MemRefType>();
auto descr = MemRefDescriptor::fromStaticShape(
- rewriter, loc, typeConverter, type, memory);
+ rewriter, loc, *getTypeConverter(), type, memory);
signatureConversion.remapInput(numProperArguments + en.index(), descr);
}
@@ -127,7 +127,7 @@ struct GPUFuncOpLowering : ConvertToLLVMPattern {
// Explicitly drop memory space when lowering private memory
// attributions since NVVM models it as `alloca`s in the default
// memory space and does not support `alloca`s with addrspace(5).
- auto ptrType = typeConverter.convertType(type.getElementType())
+ auto ptrType = typeConverter->convertType(type.getElementType())
.template cast<LLVM::LLVMType>()
.getPointerTo(AllocaAddrSpace);
Value numElements = rewriter.create<LLVM::ConstantOp>(
@@ -136,7 +136,7 @@ struct GPUFuncOpLowering : ConvertToLLVMPattern {
Value allocated = rewriter.create<LLVM::AllocaOp>(
gpuFuncOp.getLoc(), ptrType, numElements, /*alignment=*/0);
auto descr = MemRefDescriptor::fromStaticShape(
- rewriter, loc, typeConverter, type, allocated);
+ rewriter, loc, *getTypeConverter(), type, allocated);
signatureConversion.remapInput(
numProperArguments + numWorkgroupAttributions + en.index(), descr);
}
@@ -145,8 +145,8 @@ struct GPUFuncOpLowering : ConvertToLLVMPattern {
// Move the region to the new function, update the entry block signature.
rewriter.inlineRegionBefore(gpuFuncOp.getBody(), llvmFuncOp.getBody(),
llvmFuncOp.end());
- if (failed(rewriter.convertRegionTypes(&llvmFuncOp.getBody(), typeConverter,
- &signatureConversion)))
+ if (failed(rewriter.convertRegionTypes(
+ &llvmFuncOp.getBody(), *typeConverter, &signatureConversion)))
return failure();
rewriter.eraseOp(gpuFuncOp);
diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
index f32c664c17c4..b907703995d8 100644
--- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
+++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
@@ -135,8 +135,8 @@ class RangeOpConversion : public ConvertToLLVMPattern {
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto rangeOp = cast<RangeOp>(op);
- auto rangeDescriptorTy =
- convertRangeType(rangeOp.getType().cast<RangeType>(), typeConverter);
+ auto rangeDescriptorTy = convertRangeType(
+ rangeOp.getType().cast<RangeType>(), *getTypeConverter());
edsc::ScopedContext context(rewriter, op->getLoc());
@@ -181,7 +181,7 @@ class ReshapeOpConversion : public ConvertToLLVMPattern {
edsc::ScopedContext context(rewriter, op->getLoc());
ReshapeOpAdaptor adaptor(operands);
BaseViewConversionHelper baseDesc(adaptor.src());
- BaseViewConversionHelper desc(typeConverter.convertType(dstType));
+ BaseViewConversionHelper desc(typeConverter->convertType(dstType));
desc.setAllocatedPtr(baseDesc.allocatedPtr());
desc.setAlignedPtr(baseDesc.alignedPtr());
desc.setOffset(baseDesc.offset());
@@ -214,11 +214,11 @@ class SliceOpConversion : public ConvertToLLVMPattern {
auto sliceOp = cast<SliceOp>(op);
auto memRefType = sliceOp.getBaseViewType();
- auto int64Ty = typeConverter.convertType(rewriter.getIntegerType(64))
+ auto int64Ty = typeConverter->convertType(rewriter.getIntegerType(64))
.cast<LLVM::LLVMType>();
BaseViewConversionHelper desc(
- typeConverter.convertType(sliceOp.getShapedType()));
+ typeConverter->convertType(sliceOp.getShapedType()));
// TODO: extract sizes and emit asserts.
SmallVector<Value, 4> strides(memRefType.getRank());
diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
index 91e97ca1ec50..c589ef69f2c4 100644
--- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
+++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
@@ -35,7 +35,7 @@ struct RegionOpConversion : public ConvertToLLVMPattern {
curOp.getAttrs());
rewriter.inlineRegionBefore(curOp.region(), newOp.region(),
newOp.region().end());
- if (failed(rewriter.convertRegionTypes(&newOp.region(), typeConverter)))
+ if (failed(rewriter.convertRegionTypes(&newOp.region(), *typeConverter)))
return failure();
rewriter.eraseOp(op);
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
index 525a5be24485..f83f72d1d10e 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
@@ -224,7 +224,7 @@ class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
spirv::GlobalVariableOp spirvGlobal = globalVariableMap[operand.index()];
auto pointeeType =
spirvGlobal.type().cast<spirv::PointerType>().getPointeeType();
- auto dstGlobalType = typeConverter.convertType(pointeeType);
+ auto dstGlobalType = typeConverter->convertType(pointeeType);
if (!dstGlobalType)
return failure();
std::string name =
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index f54ffc1c9d6c..17a065463297 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -446,8 +446,7 @@ ConvertToLLVMPattern::ConvertToLLVMPattern(StringRef rootOpName,
MLIRContext *context,
LLVMTypeConverter &typeConverter,
PatternBenefit benefit)
- : ConversionPattern(rootOpName, benefit, typeConverter, context),
- typeConverter(typeConverter) {}
+ : ConversionPattern(rootOpName, benefit, typeConverter, context) {}
//===----------------------------------------------------------------------===//
// StructBuilder implementation
@@ -1013,27 +1012,32 @@ void UnrankedMemRefDescriptor::setStride(OpBuilder &builder, Location loc,
builder.create<LLVM::StoreOp>(loc, stride, strideStoreGep);
}
+LLVMTypeConverter *ConvertToLLVMPattern::getTypeConverter() const {
+ return static_cast<LLVMTypeConverter *>(
+ ConversionPattern::getTypeConverter());
+}
+
LLVM::LLVMDialect &ConvertToLLVMPattern::getDialect() const {
- return *typeConverter.getDialect();
+ return *getTypeConverter()->getDialect();
}
LLVM::LLVMType ConvertToLLVMPattern::getIndexType() const {
- return typeConverter.getIndexType();
+ return getTypeConverter()->getIndexType();
}
LLVM::LLVMType
ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace) const {
return LLVM::LLVMType::getIntNTy(
- &typeConverter.getContext(),
- typeConverter.getPointerBitwidth(addressSpace));
+ &getTypeConverter()->getContext(),
+ getTypeConverter()->getPointerBitwidth(addressSpace));
}
LLVM::LLVMType ConvertToLLVMPattern::getVoidType() const {
- return LLVM::LLVMType::getVoidTy(&typeConverter.getContext());
+ return LLVM::LLVMType::getVoidTy(&getTypeConverter()->getContext());
}
LLVM::LLVMType ConvertToLLVMPattern::getVoidPtrType() const {
- return LLVM::LLVMType::getInt8PtrTy(&typeConverter.getContext());
+ return LLVM::LLVMType::getInt8PtrTy(&getTypeConverter()->getContext());
}
Value ConvertToLLVMPattern::createIndexConstant(
@@ -1086,7 +1090,7 @@ Value ConvertToLLVMPattern::getDataPtr(
// Check if the MemRefType `type` is supported by the lowering. We currently
// only support memrefs with identity maps.
bool ConvertToLLVMPattern::isSupportedMemRefType(MemRefType type) const {
- if (!typeConverter.convertType(type.getElementType()))
+ if (!typeConverter->convertType(type.getElementType()))
return false;
return type.getAffineMaps().empty() ||
llvm::all_of(type.getAffineMaps(),
@@ -1095,7 +1099,7 @@ bool ConvertToLLVMPattern::isSupportedMemRefType(MemRefType type) const {
Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const {
auto elementType = type.getElementType();
- auto structElementType = unwrap(typeConverter.convertType(elementType));
+ auto structElementType = unwrap(typeConverter->convertType(elementType));
return structElementType.getPointerTo(type.getMemorySpace());
}
@@ -1155,7 +1159,7 @@ Value ConvertToLLVMPattern::getSizeInBytes(
// %1 = ptrtoint %elementType* %0 to %indexType
// which is a common pattern of getting the size of a type in bytes.
auto convertedPtrType =
- typeConverter.convertType(type).cast<LLVM::LLVMType>().getPointerTo();
+ typeConverter->convertType(type).cast<LLVM::LLVMType>().getPointerTo();
auto nullPtr = rewriter.create<LLVM::NullOp>(loc, convertedPtrType);
auto gep = rewriter.create<LLVM::GEPOp>(
loc, convertedPtrType,
@@ -1179,7 +1183,7 @@ MemRefDescriptor ConvertToLLVMPattern::createMemRefDescriptor(
Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr,
ArrayRef<Value> sizes, ArrayRef<Value> strides,
ConversionPatternRewriter &rewriter) const {
- auto structType = typeConverter.convertType(memRefType);
+ auto structType = typeConverter->convertType(memRefType);
auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType);
// Field 1: Allocated pointer, used for malloc/free.
@@ -1347,7 +1351,7 @@ struct FuncOpConversionBase : public ConvertOpToLLVMPattern<FuncOp> {
// LLVMTypeConverter provided to this legalization pattern.
auto varargsAttr = funcOp.getAttrOfType<BoolAttr>("std.varargs");
TypeConverter::SignatureConversion result(funcOp.getNumArguments());
- auto llvmType = typeConverter.convertFunctionSignature(
+ auto llvmType = getTypeConverter()->convertFunctionSignature(
funcOp.getType(), varargsAttr && varargsAttr.getValue(), result);
if (!llvmType)
return nullptr;
@@ -1379,7 +1383,7 @@ struct FuncOpConversionBase : public ConvertOpToLLVMPattern<FuncOp> {
attributes);
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
newFuncOp.end());
- if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
+ if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter,
&result)))
return nullptr;
@@ -1402,14 +1406,14 @@ struct FuncOpConversion : public FuncOpConversionBase {
if (!newFuncOp)
return failure();
- if (typeConverter.getOptions().emitCWrappers ||
+ if (getTypeConverter()->getOptions().emitCWrappers ||
funcOp.getAttrOfType<UnitAttr>(kEmitIfaceAttrName)) {
if (newFuncOp.isExternal())
- wrapExternalFunction(rewriter, funcOp.getLoc(), typeConverter, funcOp,
- newFuncOp);
+ wrapExternalFunction(rewriter, funcOp.getLoc(), *getTypeConverter(),
+ funcOp, newFuncOp);
else
- wrapForExternalCallers(rewriter, funcOp.getLoc(), typeConverter, funcOp,
- newFuncOp);
+ wrapForExternalCallers(rewriter, funcOp.getLoc(), *getTypeConverter(),
+ funcOp, newFuncOp);
}
rewriter.eraseOp(funcOp);
@@ -1472,7 +1476,7 @@ struct BarePtrFuncOpConversion : public FuncOpConversionBase {
rewriter.replaceUsesOfBlockArgument(arg, placeholder);
Value desc = MemRefDescriptor::fromStaticShape(
- rewriter, loc, typeConverter, memrefTy, arg);
+ rewriter, loc, *getTypeConverter(), memrefTy, arg);
rewriter.replaceOp(placeholder, {desc});
}
@@ -1757,7 +1761,7 @@ struct CreateComplexOpLowering
// Pack real and imaginary part in a complex number struct.
auto loc = op.getLoc();
- auto structType = typeConverter.convertType(complexOp.getType());
+ auto structType = typeConverter->convertType(complexOp.getType());
auto complexStruct = ComplexStructBuilder::undef(rewriter, loc, structType);
complexStruct.setReal(rewriter, loc, transformed.real());
complexStruct.setImaginary(rewriter, loc, transformed.imaginary());
@@ -1836,7 +1840,7 @@ struct AddCFOpLowering : public ConvertOpToLLVMPattern<AddCFOp> {
unpackBinaryComplexOperands<AddCFOp>(op, operands, rewriter);
// Initialize complex number struct for result.
- auto structType = this->typeConverter.convertType(op.getType());
+ auto structType = typeConverter->convertType(op.getType());
auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
// Emit IR to add complex numbers.
@@ -1863,7 +1867,7 @@ struct SubCFOpLowering : public ConvertOpToLLVMPattern<SubCFOp> {
unpackBinaryComplexOperands<SubCFOp>(op, operands, rewriter);
// Initialize complex number struct for result.
- auto structType = this->typeConverter.convertType(op.getType());
+ auto structType = typeConverter->convertType(op.getType());
auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
// Emit IR to substract complex numbers.
@@ -1887,7 +1891,7 @@ struct ConstantOpLowering : public ConvertOpToLLVMPattern<ConstantOp> {
ConversionPatternRewriter &rewriter) const override {
// If constant refers to a function, convert it to "addressof".
if (auto symbolRef = op.getValue().dyn_cast<FlatSymbolRefAttr>()) {
- auto type = typeConverter.convertType(op.getResult().getType())
+ auto type = typeConverter->convertType(op.getResult().getType())
.dyn_cast_or_null<LLVM::LLVMType>();
if (!type)
return rewriter.notifyMatchFailure(op, "failed to convert result type");
@@ -1905,9 +1909,9 @@ struct ConstantOpLowering : public ConvertOpToLLVMPattern<ConstantOp> {
return rewriter.notifyMatchFailure(
op, "referring to a symbol outside of the current module");
- return LLVM::detail::oneToOneRewrite(op,
- LLVM::ConstantOp::getOperationName(),
- operands, typeConverter, rewriter);
+ return LLVM::detail::oneToOneRewrite(
+ op, LLVM::ConstantOp::getOperationName(), operands, *getTypeConverter(),
+ rewriter);
}
};
@@ -1916,7 +1920,6 @@ struct AllocLikeOpLowering : public ConvertToLLVMPattern {
using ConvertToLLVMPattern::createIndexConstant;
using ConvertToLLVMPattern::getIndexType;
using ConvertToLLVMPattern::getVoidPtrType;
- using ConvertToLLVMPattern::typeConverter;
explicit AllocLikeOpLowering(StringRef opName, LLVMTypeConverter &converter)
: ConvertToLLVMPattern(opName, &converter.getContext(), converter) {}
@@ -2288,11 +2291,11 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
if (numResults != 0) {
if (!(packedResult =
- this->typeConverter.packFunctionResults(resultTypes)))
+ this->getTypeConverter()->packFunctionResults(resultTypes)))
return failure();
}
- auto promoted = this->typeConverter.promoteOperands(
+ auto promoted = this->getTypeConverter()->promoteOperands(
callOp.getLoc(), /*opOperands=*/callOp->getOperands(), operands,
rewriter);
auto newOp = rewriter.create<LLVM::CallOp>(
@@ -2309,23 +2312,23 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
results.reserve(numResults);
for (unsigned i = 0; i < numResults; ++i) {
auto type =
- this->typeConverter.convertType(callOp.getResult(i).getType());
+ this->typeConverter->convertType(callOp.getResult(i).getType());
results.push_back(rewriter.create<LLVM::ExtractValueOp>(
callOp.getLoc(), type, newOp->getResult(0),
rewriter.getI64ArrayAttr(i)));
}
}
- if (this->typeConverter.getOptions().useBarePtrCallConv) {
+ if (this->getTypeConverter()->getOptions().useBarePtrCallConv) {
// For the bare-ptr calling convention, promote memref results to
// descriptors.
assert(results.size() == resultTypes.size() &&
"The number of arguments and types doesn't match");
- this->typeConverter.promoteBarePtrsToDescriptors(
+ this->getTypeConverter()->promoteBarePtrsToDescriptors(
rewriter, callOp.getLoc(), resultTypes, results);
} else if (failed(copyUnrankedDescriptors(rewriter, callOp.getLoc(),
- this->typeConverter, resultTypes,
- results,
+ *this->getTypeConverter(),
+ resultTypes, results,
/*toDynamic=*/false))) {
return failure();
}
@@ -2410,7 +2413,8 @@ struct GlobalMemrefOpLowering : public ConvertOpToLLVMPattern<GlobalMemrefOp> {
if (!isSupportedMemRefType(type))
return failure();
- LLVM::LLVMType arrayTy = convertGlobalMemrefTypeToLLVM(type, typeConverter);
+ LLVM::LLVMType arrayTy =
+ convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
LLVM::Linkage linkage =
global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;
@@ -2449,14 +2453,15 @@ struct GetGlobalMemrefOpLowering : public AllocLikeOpLowering {
MemRefType type = getGlobalOp.result().getType().cast<MemRefType>();
unsigned memSpace = type.getMemorySpace();
- LLVM::LLVMType arrayTy = convertGlobalMemrefTypeToLLVM(type, typeConverter);
+ LLVM::LLVMType arrayTy =
+ convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
auto addressOf = rewriter.create<LLVM::AddressOfOp>(
loc, arrayTy.getPointerTo(memSpace), getGlobalOp.name());
// Get the address of the first element in the array by creating a GEP with
// the address of the GV as the base, and (rank + 1) number of 0 indices.
LLVM::LLVMType elementType =
- unwrap(typeConverter.convertType(type.getElementType()));
+ unwrap(typeConverter->convertType(type.getElementType()));
LLVM::LLVMType elementPtrType = elementType.getPointerTo(memSpace);
SmallVector<Value, 4> operands = {addressOf};
@@ -2517,7 +2522,7 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<RsqrtOp> {
return failure();
return handleMultidimensionalVectors(
- op.getOperation(), operands, typeConverter,
+ op.getOperation(), operands, *getTypeConverter(),
[&](LLVM::LLVMType llvmVectorTy, ValueRange operands) {
auto splatAttr = SplatElementsAttr::get(
mlir::VectorType::get({llvmVectorTy.getVectorNumElements()},
@@ -2546,8 +2551,8 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<MemRefCastOp> {
// a sanity check that the underlying structs are the same. Once op
// semantics are relaxed we can revisit.
if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
- return success(typeConverter.convertType(srcType) ==
- typeConverter.convertType(dstType));
+ return success(typeConverter->convertType(srcType) ==
+ typeConverter->convertType(dstType));
// At least one of the operands is unranked type
assert(srcType.isa<UnrankedMemRefType>() ||
@@ -2566,7 +2571,7 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<MemRefCastOp> {
auto srcType = memRefCastOp.getOperand().getType();
auto dstType = memRefCastOp.getType();
- auto targetStructType = typeConverter.convertType(memRefCastOp.getType());
+ auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
auto loc = memRefCastOp.getLoc();
// For ranked/ranked case, just keep the original descriptor.
@@ -2581,7 +2586,7 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<MemRefCastOp> {
auto srcMemRefType = srcType.cast<MemRefType>();
int64_t rank = srcMemRefType.getRank();
// ptr = AllocaOp sizeof(MemRefDescriptor)
- auto ptr = typeConverter.promoteOneMemRefDescriptor(
+ auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
loc, transformed.source(), rewriter);
// voidptr = BitCastOp srcType* to void*
auto voidPtr =
@@ -2589,7 +2594,7 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<MemRefCastOp> {
.getResult();
// rank = ConstantOp srcRank
auto rankVal = rewriter.create<LLVM::ConstantOp>(
- loc, typeConverter.convertType(rewriter.getIntegerType(64)),
+ loc, typeConverter->convertType(rewriter.getIntegerType(64)),
rewriter.getI64IntegerAttr(rank));
// undef = UndefOp
UnrankedMemRefDescriptor memRefDesc =
@@ -2693,7 +2698,7 @@ struct MemRefReinterpretCastOpLowering
Value *descriptor) const {
MemRefType targetMemRefType =
castOp.getResult().getType().cast<MemRefType>();
- auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
+ auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
.dyn_cast_or_null<LLVM::LLVMType>();
if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
return failure();
@@ -2704,8 +2709,9 @@ struct MemRefReinterpretCastOpLowering
// Set allocated and aligned pointers.
Value allocatedPtr, alignedPtr;
- extractPointersAndOffset(loc, rewriter, typeConverter, castOp.source(),
- adaptor.source(), &allocatedPtr, &alignedPtr);
+ extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
+ castOp.source(), adaptor.source(), &allocatedPtr,
+ &alignedPtr);
desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
desc.setAlignedPtr(rewriter, loc, alignedPtr);
@@ -2779,10 +2785,10 @@ struct MemRefReshapeOpLowering
// Create the unranked memref descriptor that holds the ranked one. The
// inner descriptor is allocated on stack.
auto targetDesc = UnrankedMemRefDescriptor::undef(
- rewriter, loc, unwrap(typeConverter.convertType(targetType)));
+ rewriter, loc, unwrap(typeConverter->convertType(targetType)));
targetDesc.setRank(rewriter, loc, resultRank);
SmallVector<Value, 4> sizes;
- UnrankedMemRefDescriptor::computeSizes(rewriter, loc, typeConverter,
+ UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
targetDesc, sizes);
Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
loc, getVoidPtrType(), sizes.front(), llvm::None);
@@ -2790,37 +2796,38 @@ struct MemRefReshapeOpLowering
// Extract pointers and offset from the source memref.
Value allocatedPtr, alignedPtr, offset;
- extractPointersAndOffset(loc, rewriter, typeConverter, reshapeOp.source(),
- adaptor.source(), &allocatedPtr, &alignedPtr,
- &offset);
+ extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
+ reshapeOp.source(), adaptor.source(),
+ &allocatedPtr, &alignedPtr, &offset);
// Set pointers and offset.
LLVM::LLVMType llvmElementType =
- unwrap(typeConverter.convertType(elementType));
+ unwrap(typeConverter->convertType(elementType));
LLVM::LLVMType elementPtrPtrType =
llvmElementType.getPointerTo(addressSpace).getPointerTo();
UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
elementPtrPtrType, allocatedPtr);
- UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, typeConverter,
+ UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
underlyingDescPtr,
elementPtrPtrType, alignedPtr);
- UnrankedMemRefDescriptor::setOffset(rewriter, loc, typeConverter,
+ UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
underlyingDescPtr, elementPtrPtrType,
offset);
// Use the offset pointer as base for further addressing. Copy over the new
// shape and compute strides. For this, we create a loop from rank-1 to 0.
Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
- rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
+ rewriter, loc, *getTypeConverter(), underlyingDescPtr,
+ elementPtrPtrType);
Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
- rewriter, loc, typeConverter, targetSizesBase, resultRank);
+ rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
Value oneIndex = createIndexConstant(rewriter, loc, 1);
Value resultRankMinusOne =
rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);
Block *initBlock = rewriter.getInsertionBlock();
- LLVM::LLVMType indexType = typeConverter.getIndexType();
+ LLVM::LLVMType indexType = getTypeConverter()->getIndexType();
Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());
Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
@@ -2854,11 +2861,11 @@ struct MemRefReshapeOpLowering
Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg});
Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep);
- UnrankedMemRefDescriptor::setSize(rewriter, loc, typeConverter,
+ UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
targetSizesBase, indexArg, size);
// Write stride value and compute next one.
- UnrankedMemRefDescriptor::setStride(rewriter, loc, typeConverter,
+ UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
targetStridesBase, indexArg, strideArg);
Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);
@@ -2892,7 +2899,7 @@ struct DialectCastOpLowering
ConversionPatternRewriter &rewriter) const override {
LLVM::DialectCastOp::Adaptor transformed(operands);
if (transformed.in().getType() !=
- typeConverter.convertType(castOp.getType())) {
+ typeConverter->convertType(castOp.getType())) {
return failure();
}
rewriter.replaceOp(castOp, transformed.in());
@@ -2942,15 +2949,16 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<DimOp> {
Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
loc,
- typeConverter.convertType(scalarMemRefType)
+ typeConverter->convertType(scalarMemRefType)
.cast<LLVM::LLVMType>()
.getPointerTo(addressSpace),
underlyingRankedDesc);
// Get pointer to offset field of memref<element_type> descriptor.
- Type indexPtrTy = typeConverter.getIndexType().getPointerTo(addressSpace);
+ Type indexPtrTy =
+ getTypeConverter()->getIndexType().getPointerTo(addressSpace);
Value two = rewriter.create<LLVM::ConstantOp>(
- loc, typeConverter.convertType(rewriter.getI32Type()),
+ loc, typeConverter->convertType(rewriter.getI32Type()),
rewriter.getI32IntegerAttr(2));
Value offsetPtr = rewriter.create<LLVM::GEPOp>(
loc, indexPtrTy, scalarMemRefDescPtr,
@@ -3082,7 +3090,7 @@ struct PrefetchOpLowering : public LoadStoreOpLowering<PrefetchOp> {
transformed.indices(), rewriter);
// Replace with llvm.prefetch.
- auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
+ auto llvmI32Type = typeConverter->convertType(rewriter.getIntegerType(32));
auto isWrite = rewriter.create<LLVM::ConstantOp>(
loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isWrite()));
auto localityHint = rewriter.create<LLVM::ConstantOp>(
@@ -3110,7 +3118,7 @@ struct IndexCastOpLowering : public ConvertOpToLLVMPattern<IndexCastOp> {
IndexCastOpAdaptor transformed(operands);
auto targetType =
- this->typeConverter.convertType(indexCastOp.getResult().getType())
+ typeConverter->convertType(indexCastOp.getResult().getType())
.cast<LLVM::LLVMType>();
auto sourceType = transformed.in().getType().cast<LLVM::LLVMType>();
unsigned targetBits = targetType.getIntegerBitWidth();
@@ -3144,7 +3152,7 @@ struct CmpIOpLowering : public ConvertOpToLLVMPattern<CmpIOp> {
CmpIOpAdaptor transformed(operands);
rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
- cmpiOp, typeConverter.convertType(cmpiOp.getResult().getType()),
+ cmpiOp, typeConverter->convertType(cmpiOp.getResult().getType()),
rewriter.getI64IntegerAttr(static_cast<int64_t>(
convertCmpPredicate<LLVM::ICmpPredicate>(cmpiOp.getPredicate()))),
transformed.lhs(), transformed.rhs());
@@ -3162,7 +3170,7 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern<CmpFOp> {
CmpFOpAdaptor transformed(operands);
rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
- cmpfOp, typeConverter.convertType(cmpfOp.getResult().getType()),
+ cmpfOp, typeConverter->convertType(cmpfOp.getResult().getType()),
rewriter.getI64IntegerAttr(static_cast<int64_t>(
convertCmpPredicate<LLVM::FCmpPredicate>(cmpfOp.getPredicate()))),
transformed.lhs(), transformed.rhs());
@@ -3248,7 +3256,7 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<ReturnOp> {
unsigned numArguments = op.getNumOperands();
SmallVector<Value, 4> updatedOperands;
- if (typeConverter.getOptions().useBarePtrCallConv) {
+ if (getTypeConverter()->getOptions().useBarePtrCallConv) {
// For the bare-ptr calling convention, extract the aligned pointer to
// be returned from the memref descriptor.
for (auto it : llvm::zip(op->getOperands(), operands)) {
@@ -3266,7 +3274,7 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<ReturnOp> {
}
} else {
updatedOperands = llvm::to_vector<4>(operands);
- copyUnrankedDescriptors(rewriter, loc, typeConverter,
+ copyUnrankedDescriptors(rewriter, loc, *getTypeConverter(),
op.getOperands().getTypes(), updatedOperands,
/*toDynamic=*/true);
}
@@ -3285,7 +3293,7 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<ReturnOp> {
// Otherwise, we need to pack the arguments into an LLVM struct type before
// returning.
- auto packedType = typeConverter.packFunctionResults(
+ auto packedType = getTypeConverter()->packFunctionResults(
llvm::to_vector<4>(op.getOperandTypes()));
Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType);
@@ -3323,11 +3331,11 @@ struct SplatOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
return failure();
// First insert it into an undef vector so we can shuffle it.
- auto vectorType = typeConverter.convertType(splatOp.getType());
+ auto vectorType = typeConverter->convertType(splatOp.getType());
Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType);
auto zero = rewriter.create<LLVM::ConstantOp>(
splatOp.getLoc(),
- typeConverter.convertType(rewriter.getIntegerType(32)),
+ typeConverter->convertType(rewriter.getIntegerType(32)),
rewriter.getZeroAttr(rewriter.getIntegerType(32)));
auto v = rewriter.create<LLVM::InsertElementOp>(
@@ -3360,7 +3368,8 @@ struct SplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
// First insert it into an undef vector so we can shuffle it.
auto loc = splatOp.getLoc();
- auto vectorTypeInfo = extractNDVectorTypeInfo(resultType, typeConverter);
+ auto vectorTypeInfo =
+ extractNDVectorTypeInfo(resultType, *getTypeConverter());
auto llvmArrayTy = vectorTypeInfo.llvmArrayTy;
auto llvmVectorTy = vectorTypeInfo.llvmVectorTy;
if (!llvmArrayTy || !llvmVectorTy)
@@ -3373,7 +3382,7 @@ struct SplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
// places within the returned descriptor.
Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvmVectorTy);
auto zero = rewriter.create<LLVM::ConstantOp>(
- loc, typeConverter.convertType(rewriter.getIntegerType(32)),
+ loc, typeConverter->convertType(rewriter.getIntegerType(32)),
rewriter.getZeroAttr(rewriter.getIntegerType(32)));
Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvmVectorTy, vdesc,
adaptor.input(), zero);
@@ -3418,7 +3427,7 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
auto sourceMemRefType = subViewOp.source().getType().cast<MemRefType>();
auto sourceElementTy =
- typeConverter.convertType(sourceMemRefType.getElementType())
+ typeConverter->convertType(sourceMemRefType.getElementType())
.dyn_cast_or_null<LLVM::LLVMType>();
auto viewMemRefType = subViewOp.getType();
@@ -3429,9 +3438,9 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
extractFromI64ArrayAttr(subViewOp.static_strides()))
.cast<MemRefType>();
auto targetElementTy =
- typeConverter.convertType(viewMemRefType.getElementType())
+ typeConverter->convertType(viewMemRefType.getElementType())
.dyn_cast<LLVM::LLVMType>();
- auto targetDescTy = typeConverter.convertType(viewMemRefType)
+ auto targetDescTy = typeConverter->convertType(viewMemRefType)
.dyn_cast_or_null<LLVM::LLVMType>();
if (!sourceElementTy || !targetDescTy)
return failure();
@@ -3477,7 +3486,7 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));
// Offset.
- auto llvmIndexType = typeConverter.convertType(rewriter.getIndexType());
+ auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
if (!ShapedType::isDynamicStrideOrOffset(offset)) {
targetMemRef.setConstantOffset(rewriter, loc, offset);
} else {
@@ -3553,7 +3562,7 @@ class TransposeOpLowering : public ConvertOpToLLVMPattern<TransposeOp> {
return rewriter.replaceOp(transposeOp, {viewMemRef}), success();
auto targetMemRef = MemRefDescriptor::undef(
- rewriter, loc, typeConverter.convertType(transposeOp.getShapedType()));
+ rewriter, loc, typeConverter->convertType(transposeOp.getShapedType()));
// Copy the base and aligned pointers from the old descriptor to the new
// one.
@@ -3629,10 +3638,10 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<ViewOp> {
auto viewMemRefType = viewOp.getType();
auto targetElementTy =
- typeConverter.convertType(viewMemRefType.getElementType())
+ typeConverter->convertType(viewMemRefType.getElementType())
.dyn_cast<LLVM::LLVMType>();
auto targetDescTy =
- typeConverter.convertType(viewMemRefType).dyn_cast<LLVM::LLVMType>();
+ typeConverter->convertType(viewMemRefType).dyn_cast<LLVM::LLVMType>();
if (!targetDescTy)
return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
failure();
@@ -3825,7 +3834,7 @@ struct GenericAtomicRMWOpLowering
auto loc = atomicOp.getLoc();
GenericAtomicRMWOp::Adaptor adaptor(operands);
LLVM::LLVMType valueType =
- typeConverter.convertType(atomicOp.getResult().getType())
+ typeConverter->convertType(atomicOp.getResult().getType())
.cast<LLVM::LLVMType>();
// Split the block into initial, loop, and ending parts.
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index b3fa315b75a3..85d3e2bddd66 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -309,7 +309,7 @@ class VectorMatmulOpConversion : public ConvertToLLVMPattern {
auto matmulOp = cast<vector::MatmulOp>(op);
auto adaptor = vector::MatmulOpAdaptor(operands);
rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
- op, typeConverter.convertType(matmulOp.res().getType()), adaptor.lhs(),
+ op, typeConverter->convertType(matmulOp.res().getType()), adaptor.lhs(),
adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(),
matmulOp.rhs_columns());
return success();
@@ -331,7 +331,7 @@ class VectorFlatTransposeOpConversion : public ConvertToLLVMPattern {
auto transOp = cast<vector::FlatTransposeOp>(op);
auto adaptor = vector::FlatTransposeOpAdaptor(operands);
rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
- transOp, typeConverter.convertType(transOp.res().getType()),
+ transOp, typeConverter->convertType(transOp.res().getType()),
adaptor.matrix(), transOp.rows(), transOp.columns());
return success();
}
@@ -354,10 +354,10 @@ class VectorMaskedLoadOpConversion : public ConvertToLLVMPattern {
// Resolve alignment.
unsigned align;
- if (failed(getMemRefAlignment(typeConverter, load, align)))
+ if (failed(getMemRefAlignment(*getTypeConverter(), load, align)))
return failure();
- auto vtype = typeConverter.convertType(load.getResultVectorType());
+ auto vtype = typeConverter->convertType(load.getResultVectorType());
Value ptr;
if (failed(getBasePtr(rewriter, loc, adaptor.base(), load.getMemRefType(),
vtype, ptr)))
@@ -387,10 +387,10 @@ class VectorMaskedStoreOpConversion : public ConvertToLLVMPattern {
// Resolve alignment.
unsigned align;
- if (failed(getMemRefAlignment(typeConverter, store, align)))
+ if (failed(getMemRefAlignment(*getTypeConverter(), store, align)))
return failure();
- auto vtype = typeConverter.convertType(store.getValueVectorType());
+ auto vtype = typeConverter->convertType(store.getValueVectorType());
Value ptr;
if (failed(getBasePtr(rewriter, loc, adaptor.base(), store.getMemRefType(),
vtype, ptr)))
@@ -420,7 +420,7 @@ class VectorGatherOpConversion : public ConvertToLLVMPattern {
// Resolve alignment.
unsigned align;
- if (failed(getMemRefAlignment(typeConverter, gather, align)))
+ if (failed(getMemRefAlignment(*getTypeConverter(), gather, align)))
return failure();
// Get index ptrs.
@@ -433,7 +433,7 @@ class VectorGatherOpConversion : public ConvertToLLVMPattern {
// Replace with the gather intrinsic.
rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
- gather, typeConverter.convertType(vType), ptrs, adaptor.mask(),
+ gather, typeConverter->convertType(vType), ptrs, adaptor.mask(),
adaptor.pass_thru(), rewriter.getI32IntegerAttr(align));
return success();
}
@@ -456,7 +456,7 @@ class VectorScatterOpConversion : public ConvertToLLVMPattern {
// Resolve alignment.
unsigned align;
- if (failed(getMemRefAlignment(typeConverter, scatter, align)))
+ if (failed(getMemRefAlignment(*getTypeConverter(), scatter, align)))
return failure();
// Get index ptrs.
@@ -497,7 +497,7 @@ class VectorExpandLoadOpConversion : public ConvertToLLVMPattern {
auto vType = expand.getResultVectorType();
rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
- op, typeConverter.convertType(vType), ptr, adaptor.mask(),
+ op, typeConverter->convertType(vType), ptr, adaptor.mask(),
adaptor.pass_thru());
return success();
}
@@ -545,7 +545,7 @@ class VectorReductionOpConversion : public ConvertToLLVMPattern {
auto reductionOp = cast<vector::ReductionOp>(op);
auto kind = reductionOp.kind();
Type eltType = reductionOp.dest().getType();
- Type llvmType = typeConverter.convertType(eltType);
+ Type llvmType = typeConverter->convertType(eltType);
if (eltType.isIntOrIndex()) {
// Integer reductions: add/mul/min/max/and/or/xor.
if (kind == "add")
@@ -580,39 +580,40 @@ class VectorReductionOpConversion : public ConvertToLLVMPattern {
else
return failure();
return success();
-
- } else if (eltType.isa<FloatType>()) {
- // Floating-point reductions: add/mul/min/max
- if (kind == "add") {
- // Optional accumulator (or zero).
- Value acc = operands.size() > 1 ? operands[1]
- : rewriter.create<LLVM::ConstantOp>(
- op->getLoc(), llvmType,
- rewriter.getZeroAttr(eltType));
- rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
- op, llvmType, acc, operands[0],
- rewriter.getBoolAttr(reassociateFPReductions));
- } else if (kind == "mul") {
- // Optional accumulator (or one).
- Value acc = operands.size() > 1
- ? operands[1]
- : rewriter.create<LLVM::ConstantOp>(
- op->getLoc(), llvmType,
- rewriter.getFloatAttr(eltType, 1.0));
- rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
- op, llvmType, acc, operands[0],
- rewriter.getBoolAttr(reassociateFPReductions));
- } else if (kind == "min")
- rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(
- op, llvmType, operands[0]);
- else if (kind == "max")
- rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(
- op, llvmType, operands[0]);
- else
- return failure();
- return success();
}
- return failure();
+
+ if (!eltType.isa<FloatType>())
+ return failure();
+
+ // Floating-point reductions: add/mul/min/max
+ if (kind == "add") {
+ // Optional accumulator (or zero).
+ Value acc = operands.size() > 1 ? operands[1]
+ : rewriter.create<LLVM::ConstantOp>(
+ op->getLoc(), llvmType,
+ rewriter.getZeroAttr(eltType));
+ rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
+ op, llvmType, acc, operands[0],
+ rewriter.getBoolAttr(reassociateFPReductions));
+ } else if (kind == "mul") {
+ // Optional accumulator (or one).
+ Value acc = operands.size() > 1
+ ? operands[1]
+ : rewriter.create<LLVM::ConstantOp>(
+ op->getLoc(), llvmType,
+ rewriter.getFloatAttr(eltType, 1.0));
+ rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
+ op, llvmType, acc, operands[0],
+ rewriter.getBoolAttr(reassociateFPReductions));
+ } else if (kind == "min")
+ rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(op, llvmType,
+ operands[0]);
+ else if (kind == "max")
+ rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(op, llvmType,
+ operands[0]);
+ else
+ return failure();
+ return success();
}
private:
@@ -663,7 +664,7 @@ class VectorShuffleOpConversion : public ConvertToLLVMPattern {
auto v1Type = shuffleOp.getV1VectorType();
auto v2Type = shuffleOp.getV2VectorType();
auto vectorType = shuffleOp.getVectorType();
- Type llvmType = typeConverter.convertType(vectorType);
+ Type llvmType = typeConverter->convertType(vectorType);
auto maskArrayAttr = shuffleOp.mask();
// Bail if result type cannot be lowered.
@@ -695,9 +696,9 @@ class VectorShuffleOpConversion : public ConvertToLLVMPattern {
extPos -= v1Dim;
value = adaptor.v2();
}
- Value extract = extractOne(rewriter, typeConverter, loc, value, llvmType,
- rank, extPos);
- insert = insertOne(rewriter, typeConverter, loc, insert, extract,
+ Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
+ llvmType, rank, extPos);
+ insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
llvmType, rank, insPos++);
}
rewriter.replaceOp(op, insert);
@@ -718,7 +719,7 @@ class VectorExtractElementOpConversion : public ConvertToLLVMPattern {
auto adaptor = vector::ExtractElementOpAdaptor(operands);
auto extractEltOp = cast<vector::ExtractElementOp>(op);
auto vectorType = extractEltOp.getVectorType();
- auto llvmType = typeConverter.convertType(vectorType.getElementType());
+ auto llvmType = typeConverter->convertType(vectorType.getElementType());
// Bail if result type cannot be lowered.
if (!llvmType)
@@ -745,7 +746,7 @@ class VectorExtractOpConversion : public ConvertToLLVMPattern {
auto extractOp = cast<vector::ExtractOp>(op);
auto vectorType = extractOp.getVectorType();
auto resultType = extractOp.getResult().getType();
- auto llvmResultType = typeConverter.convertType(resultType);
+ auto llvmResultType = typeConverter->convertType(resultType);
auto positionArrayAttr = extractOp.position();
// Bail if result type cannot be lowered.
@@ -769,7 +770,7 @@ class VectorExtractOpConversion : public ConvertToLLVMPattern {
auto nMinusOnePositionAttrs =
ArrayAttr::get(positionAttrs.drop_back(), context);
extracted = rewriter.create<LLVM::ExtractValueOp>(
- loc, typeConverter.convertType(oneDVectorType), extracted,
+ loc, typeConverter->convertType(oneDVectorType), extracted,
nMinusOnePositionAttrs);
}
@@ -833,7 +834,7 @@ class VectorInsertElementOpConversion : public ConvertToLLVMPattern {
auto adaptor = vector::InsertElementOpAdaptor(operands);
auto insertEltOp = cast<vector::InsertElementOp>(op);
auto vectorType = insertEltOp.getDestVectorType();
- auto llvmType = typeConverter.convertType(vectorType);
+ auto llvmType = typeConverter->convertType(vectorType);
// Bail if result type cannot be lowered.
if (!llvmType)
@@ -860,7 +861,7 @@ class VectorInsertOpConversion : public ConvertToLLVMPattern {
auto insertOp = cast<vector::InsertOp>(op);
auto sourceType = insertOp.getSourceType();
auto destVectorType = insertOp.getDestVectorType();
- auto llvmResultType = typeConverter.convertType(destVectorType);
+ auto llvmResultType = typeConverter->convertType(destVectorType);
auto positionArrayAttr = insertOp.position();
// Bail if result type cannot be lowered.
@@ -887,7 +888,7 @@ class VectorInsertOpConversion : public ConvertToLLVMPattern {
auto nMinusOnePositionAttrs =
ArrayAttr::get(positionAttrs.drop_back(), context);
extracted = rewriter.create<LLVM::ExtractValueOp>(
- loc, typeConverter.convertType(oneDVectorType), extracted,
+ loc, typeConverter->convertType(oneDVectorType), extracted,
nMinusOnePositionAttrs);
}
@@ -895,7 +896,7 @@ class VectorInsertOpConversion : public ConvertToLLVMPattern {
auto i64Type = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
Value inserted = rewriter.create<LLVM::InsertElementOp>(
- loc, typeConverter.convertType(oneDVectorType), extracted,
+ loc, typeConverter->convertType(oneDVectorType), extracted,
adaptor.source(), constant);
// Potential insertion of resulting 1-D vector into array.
@@ -1000,7 +1001,7 @@ class VectorInsertStridedSliceOpDifferentRankRewritePattern
Value extracted =
rewriter.create<ExtractOp>(loc, op.dest(),
getI64SubArray(op.offsets(), /*dropFront=*/0,
- /*dropFront=*/rankRest));
+ /*dropBack=*/rankRest));
// A different pattern will kick in for InsertStridedSlice with matching
// ranks.
auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
@@ -1010,7 +1011,7 @@ class VectorInsertStridedSliceOpDifferentRankRewritePattern
rewriter.replaceOpWithNewOp<InsertOp>(
op, stridedSliceInnerOp.getResult(), op.dest(),
getI64SubArray(op.offsets(), /*dropFront=*/0,
- /*dropFront=*/rankRest));
+ /*dropBack=*/rankRest));
return success();
}
};
@@ -1144,7 +1145,7 @@ class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
return failure();
MemRefDescriptor sourceMemRef(operands[0]);
- auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
+ auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
.dyn_cast_or_null<LLVM::LLVMType>();
if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
return failure();
@@ -1234,7 +1235,7 @@ class VectorTransferConversion : public ConvertToLLVMPattern {
if (!strides)
return failure();
- auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
+ auto toLLVMTy = [&](Type t) { return typeConverter->convertType(t); };
Location loc = op->getLoc();
MemRefType memRefType = xferOp.getMemRefType();
@@ -1279,8 +1280,8 @@ class VectorTransferConversion : public ConvertToLLVMPattern {
loc, vecTy.getPointerTo(), dataPtr);
if (!xferOp.isMaskedDim(0))
- return replaceTransferOpWithLoadOrStore(rewriter, typeConverter, loc,
- xferOp, operands, vectorDataPtr);
+ return replaceTransferOpWithLoadOrStore(
+ rewriter, *getTypeConverter(), loc, xferOp, operands, vectorDataPtr);
// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
// 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
@@ -1297,8 +1298,8 @@ class VectorTransferConversion : public ConvertToLLVMPattern {
vecWidth, dim, &off);
// 5. Rewrite as a masked read / write.
- return replaceTransferOpWithMasked(rewriter, typeConverter, loc, xferOp,
- operands, vectorDataPtr, mask);
+ return replaceTransferOpWithMasked(rewriter, *getTypeConverter(), loc,
+ xferOp, operands, vectorDataPtr, mask);
}
private:
@@ -1331,7 +1332,7 @@ class VectorPrintOpConversion : public ConvertToLLVMPattern {
auto adaptor = vector::PrintOpAdaptor(operands);
Type printType = printOp.getPrintType();
- if (typeConverter.convertType(printType) == nullptr)
+ if (typeConverter->convertType(printType) == nullptr)
return failure();
// Make sure element type has runtime support.
@@ -1421,10 +1422,10 @@ class VectorPrintOpConversion : public ConvertToLLVMPattern {
for (int64_t d = 0; d < dim; ++d) {
auto reducedType =
rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
- auto llvmType = typeConverter.convertType(
+ auto llvmType = typeConverter->convertType(
rank > 1 ? reducedType : vectorType.getElementType());
- Value nestedVal =
- extractOne(rewriter, typeConverter, loc, value, llvmType, rank, d);
+ Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
+ llvmType, rank, d);
emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
conversion);
if (d != dim - 1)
diff --git a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
index 26b8bec1f3fc..61f094746a0a 100644
--- a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
+++ b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
@@ -79,7 +79,7 @@ class VectorTransferConversion : public ConvertToLLVMPattern {
if (!xferOp.isMaskedDim(0))
return failure();
- auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
+ auto toLLVMTy = [&](Type t) { return typeConverter->convertType(t); };
LLVM::LLVMType vecTy =
toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
unsigned vecWidth = vecTy.getVectorNumElements();
@@ -142,9 +142,9 @@ class VectorTransferConversion : public ConvertToLLVMPattern {
Value int32Zero = rewriter.create<LLVM::ConstantOp>(
loc, toLLVMTy(i32Ty),
rewriter.getIntegerAttr(rewriter.getIntegerType(32), 0));
- return replaceTransferOpWithMubuf(rewriter, operands, typeConverter, loc,
- xferOp, vecTy, dwordConfig, int32Zero,
- int32Zero, int1False, int1False);
+ return replaceTransferOpWithMubuf(
+ rewriter, operands, *getTypeConverter(), loc, xferOp, vecTy,
+ dwordConfig, int32Zero, int32Zero, int1False, int1False);
}
};
} // end anonymous namespace
More information about the llvm-branch-commits
mailing list