[Mlir-commits] [mlir] 0f04384 - [mlir] NFC: Rename LLVMOpLowering::lowering to LLVMOpLowering::typeConverter
Alex Zinenko
llvmlistbot at llvm.org
Tue Feb 18 06:57:17 PST 2020
Author: Alex Zinenko
Date: 2020-02-18T15:57:10+01:00
New Revision: 0f04384daf78e26652bae3c5ea9cc201c9099b9d
URL: https://github.com/llvm/llvm-project/commit/0f04384daf78e26652bae3c5ea9cc201c9099b9d
DIFF: https://github.com/llvm/llvm-project/commit/0f04384daf78e26652bae3c5ea9cc201c9099b9d.diff
LOG: [mlir] NFC: Rename LLVMOpLowering::lowering to LLVMOpLowering::typeConverter
The existing name is an artifact dating back to the times when we did not have
a dedicated TypeConverter infrastructure. It is also easily confused with the
names of the classes that use it (e.g., LLVMOpLowering).
Differential revision: https://reviews.llvm.org/D74707
Added:
Modified:
mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
index 8ab7b17e5458..e791430066ec 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -363,17 +363,15 @@ class UnrankedMemRefDescriptor : public StructBuilder {
static unsigned getNumUnpackedValues() { return 2; }
};
/// Base class for operation conversions targeting the LLVM IR dialect. Provides
-/// conversion patterns with an access to the containing LLVMLowering for the
-/// purpose of type conversions.
+/// conversion patterns with access to an LLVMTypeConverter.
class LLVMOpLowering : public ConversionPattern {
public:
LLVMOpLowering(StringRef rootOpName, MLIRContext *context,
- LLVMTypeConverter &lowering, PatternBenefit benefit = 1);
+ LLVMTypeConverter &typeConverter, PatternBenefit benefit = 1);
protected:
- // Back-reference to the lowering class, used to call type and function
- // conversions accounting for potential extensions.
- LLVMTypeConverter &lowering;
+ /// Reference to the type converter, with potential extensions.
+ LLVMTypeConverter &typeConverter;
};
} // namespace mlir
diff --git a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
index 7c8ce43ed16f..fff8d206e8b8 100644
--- a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
@@ -51,7 +51,7 @@ struct GPUIndexIntrinsicOpLowering : public LLVMOpLowering {
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
- auto dialect = lowering.getDialect();
+ auto dialect = typeConverter.getDialect();
Value newOp;
switch (dimensionToIndex(cast<Op>(op))) {
case X:
diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
index 059bdfb1661f..4fb131ea6abc 100644
--- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
@@ -44,7 +44,7 @@ struct OpToFuncCallLowering : public LLVMOpLowering {
std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
"expected single result op");
- LLVMType resultType = lowering.convertType(op->getResult(0).getType())
+ LLVMType resultType = typeConverter.convertType(op->getResult(0).getType())
.template cast<LLVM::LLVMType>();
LLVMType funcType = getFunctionType(resultType, operands);
StringRef funcName = getFunctionName(resultType);
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index d9f37dbf0d8c..5cc89580ddfd 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -337,7 +337,7 @@ struct GPUAllReduceOpLowering : public LLVMOpLowering {
// Clamp lane: `activeWidth - 1`
Value maskAndClamp =
rewriter.create<LLVM::SubOp>(loc, int32Type, activeWidth, one);
- auto dialect = lowering.getDialect();
+ auto dialect = typeConverter.getDialect();
auto predTy = LLVM::LLVMType::getInt1Ty(dialect);
auto shflTy = LLVM::LLVMType::getStructTy(dialect, {type, predTy});
auto returnValueAndIsValidAttr = rewriter.getUnitAttr();
@@ -490,7 +490,7 @@ struct GPUShuffleOpLowering : public LLVMOpLowering {
Location loc = op->getLoc();
gpu::ShuffleOpOperandAdaptor adaptor(operands);
- auto dialect = lowering.getDialect();
+ auto dialect = typeConverter.getDialect();
auto valueTy = adaptor.value().getType().cast<LLVM::LLVMType>();
auto int32Type = LLVM::LLVMType::getInt32Ty(dialect);
auto predTy = LLVM::LLVMType::getInt1Ty(dialect);
@@ -544,8 +544,8 @@ struct GPUFuncOpLowering : LLVMOpLowering {
uint64_t numElements = type.getNumElements();
- auto elementType =
- lowering.convertType(type.getElementType()).cast<LLVM::LLVMType>();
+ auto elementType = typeConverter.convertType(type.getElementType())
+ .cast<LLVM::LLVMType>();
auto arrayType = LLVM::LLVMType::getArrayTy(elementType, numElements);
std::string name = std::string(
llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), en.index()));
@@ -557,15 +557,15 @@ struct GPUFuncOpLowering : LLVMOpLowering {
}
// Rewrite the original GPU function to an LLVM function.
- auto funcType = lowering.convertType(gpuFuncOp.getType())
+ auto funcType = typeConverter.convertType(gpuFuncOp.getType())
.cast<LLVM::LLVMType>()
.getPointerElementTy();
// Remap proper input types.
TypeConverter::SignatureConversion signatureConversion(
gpuFuncOp.front().getNumArguments());
- lowering.convertFunctionSignature(gpuFuncOp.getType(), /*isVariadic=*/false,
- signatureConversion);
+ typeConverter.convertFunctionSignature(
+ gpuFuncOp.getType(), /*isVariadic=*/false, signatureConversion);
// Create the new function operation. Only copy those attributes that are
// not specific to function modeling.
@@ -592,7 +592,7 @@ struct GPUFuncOpLowering : LLVMOpLowering {
// Rewrite workgroup memory attributions to addresses of global buffers.
rewriter.setInsertionPointToStart(&gpuFuncOp.front());
unsigned numProperArguments = gpuFuncOp.getNumArguments();
- auto i32Type = LLVM::LLVMType::getInt32Ty(lowering.getDialect());
+ auto i32Type = LLVM::LLVMType::getInt32Ty(typeConverter.getDialect());
Value zero = nullptr;
if (!workgroupBuffers.empty())
@@ -612,15 +612,15 @@ struct GPUFuncOpLowering : LLVMOpLowering {
// and canonicalize that away later.
Value attribution = gpuFuncOp.getWorkgroupAttributions()[en.index()];
auto type = attribution.getType().cast<MemRefType>();
- auto descr = MemRefDescriptor::fromStaticShape(rewriter, loc, lowering,
- type, memory);
+ auto descr = MemRefDescriptor::fromStaticShape(
+ rewriter, loc, typeConverter, type, memory);
signatureConversion.remapInput(numProperArguments + en.index(), descr);
}
// Rewrite private memory attributions to alloca'ed buffers.
unsigned numWorkgroupAttributions =
gpuFuncOp.getNumWorkgroupAttributions();
- auto int64Ty = LLVM::LLVMType::getInt64Ty(lowering.getDialect());
+ auto int64Ty = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
for (auto en : llvm::enumerate(gpuFuncOp.getPrivateAttributions())) {
Value attribution = en.value();
auto type = attribution.getType().cast<MemRefType>();
@@ -630,7 +630,7 @@ struct GPUFuncOpLowering : LLVMOpLowering {
// Explicitly drop memory space when lowering private memory
// attributions since NVVM models it as `alloca`s in the default
// memory space and does not support `alloca`s with addrspace(5).
- auto ptrType = lowering.convertType(type.getElementType())
+ auto ptrType = typeConverter.convertType(type.getElementType())
.cast<LLVM::LLVMType>()
.getPointerTo();
Value numElements = rewriter.create<LLVM::ConstantOp>(
@@ -638,8 +638,8 @@ struct GPUFuncOpLowering : LLVMOpLowering {
rewriter.getI64IntegerAttr(type.getNumElements()));
Value allocated = rewriter.create<LLVM::AllocaOp>(
gpuFuncOp.getLoc(), ptrType, numElements, /*alignment=*/0);
- auto descr = MemRefDescriptor::fromStaticShape(rewriter, loc, lowering,
- type, allocated);
+ auto descr = MemRefDescriptor::fromStaticShape(
+ rewriter, loc, typeConverter, type, allocated);
signatureConversion.remapInput(
numProperArguments + numWorkgroupAttributions + en.index(), descr);
}
diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
index 09f7e5f71d5e..bf05726a22c7 100644
--- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
+++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
@@ -146,7 +146,7 @@ class RangeOpConversion : public LLVMOpLowering {
ConversionPatternRewriter &rewriter) const override {
auto rangeOp = cast<RangeOp>(op);
auto rangeDescriptorTy =
- convertLinalgType(rangeOp.getResult().getType(), lowering);
+ convertLinalgType(rangeOp.getResult().getType(), typeConverter);
edsc::ScopedContext context(rewriter, op->getLoc());
@@ -190,7 +190,7 @@ class ReshapeOpConversion : public LLVMOpLowering {
edsc::ScopedContext context(rewriter, op->getLoc());
ReshapeOpOperandAdaptor adaptor(operands);
BaseViewConversionHelper baseDesc(adaptor.view());
- BaseViewConversionHelper desc(lowering.convertType(dstType));
+ BaseViewConversionHelper desc(typeConverter.convertType(dstType));
desc.setAllocatedPtr(baseDesc.allocatedPtr());
desc.setAlignedPtr(baseDesc.alignedPtr());
desc.setOffset(baseDesc.offset());
@@ -225,11 +225,11 @@ class SliceOpConversion : public LLVMOpLowering {
auto sliceOp = cast<SliceOp>(op);
auto memRefType = sliceOp.getBaseViewType();
- auto int64Ty = lowering.convertType(rewriter.getIntegerType(64))
+ auto int64Ty = typeConverter.convertType(rewriter.getIntegerType(64))
.cast<LLVM::LLVMType>();
BaseViewConversionHelper desc(
- lowering.convertType(sliceOp.getShapedType()));
+ typeConverter.convertType(sliceOp.getShapedType()));
// TODO(ntv): extract sizes and emit asserts.
SmallVector<Value, 4> strides(memRefType.getRank());
@@ -322,7 +322,7 @@ class TransposeOpConversion : public LLVMOpLowering {
return rewriter.replaceOp(op, {baseDesc}), matchSuccess();
BaseViewConversionHelper desc(
- lowering.convertType(transposeOp.getShapedType()));
+ typeConverter.convertType(transposeOp.getShapedType()));
// Copy the base and aligned pointers from the old descriptor to the new
// one.
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
index f613a39a9612..87bb1fbcf306 100644
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -376,9 +376,10 @@ Type LLVMTypeConverter::convertStandardType(Type t) {
}
LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context,
- LLVMTypeConverter &lowering_,
+ LLVMTypeConverter &typeConverter_,
PatternBenefit benefit)
- : ConversionPattern(rootOpName, benefit, context), lowering(lowering_) {}
+ : ConversionPattern(rootOpName, benefit, context),
+ typeConverter(typeConverter_) {}
/*============================================================================*/
/* StructBuilder implementation */
@@ -706,9 +707,9 @@ class LLVMLegalizationPattern : public LLVMOpLowering {
public:
// Construct a conversion pattern.
explicit LLVMLegalizationPattern(LLVM::LLVMDialect &dialect_,
- LLVMTypeConverter &lowering_)
+ LLVMTypeConverter &typeConverter_)
: LLVMOpLowering(SourceOp::getOperationName(), dialect_.getContext(),
- lowering_),
+ typeConverter_),
dialect(dialect_) {}
// Get the LLVM IR dialect.
@@ -904,7 +905,7 @@ struct FuncOpConversionBase : public LLVMLegalizationPattern<FuncOp> {
// LLVMTypeConverter provided to this legalization pattern.
auto varargsAttr = funcOp.getAttrOfType<BoolAttr>("std.varargs");
TypeConverter::SignatureConversion result(funcOp.getNumArguments());
- auto llvmType = lowering.convertFunctionSignature(
+ auto llvmType = typeConverter.convertFunctionSignature(
funcOp.getType(), varargsAttr && varargsAttr.getValue(), result);
// Propagate argument attributes to all converted arguments obtained after
@@ -957,10 +958,10 @@ struct FuncOpConversion : public FuncOpConversionBase {
auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
if (emitWrappers) {
if (newFuncOp.isExternal())
- wrapExternalFunction(rewriter, op->getLoc(), lowering, funcOp,
+ wrapExternalFunction(rewriter, op->getLoc(), typeConverter, funcOp,
newFuncOp);
else
- wrapForExternalCallers(rewriter, op->getLoc(), lowering, funcOp,
+ wrapForExternalCallers(rewriter, op->getLoc(), typeConverter, funcOp,
newFuncOp);
}
@@ -1014,7 +1015,7 @@ struct BarePtrFuncOpConversion : public FuncOpConversionBase {
rewriter.create<LLVM::UndefOp>(funcLoc, arg.getType());
rewriter.replaceUsesOfBlockArgument(arg, placeHolder);
auto desc = MemRefDescriptor::fromStaticShape(
- rewriter, funcLoc, lowering, memrefType, arg);
+ rewriter, funcLoc, typeConverter, memrefType, arg);
rewriter.replaceOp(placeHolder.getDefiningOp(), {desc});
}
}
@@ -1119,7 +1120,8 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
Type packedType;
if (numResults != 0) {
- packedType = this->lowering.packFunctionResults(op->getResultTypes());
+ packedType =
+ this->typeConverter.packFunctionResults(op->getResultTypes());
if (!packedType)
return this->matchFailure();
}
@@ -1139,7 +1141,7 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
SmallVector<Value, 4> results;
results.reserve(numResults);
for (unsigned i = 0; i < numResults; ++i) {
- auto type = this->lowering.convertType(op->getResult(i).getType());
+ auto type = this->typeConverter.convertType(op->getResult(i).getType());
results.push_back(rewriter.create<LLVM::ExtractValueOp>(
op->getLoc(), type, newOp.getOperation()->getResult(0),
rewriter.getI64ArrayAttr(i)));
@@ -1206,7 +1208,8 @@ struct NaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
auto vectorType = op->getResult(0).getType().dyn_cast<VectorType>();
if (!vectorType)
return this->matchFailure();
- auto vectorTypeInfo = extractNDVectorTypeInfo(vectorType, this->lowering);
+ auto vectorTypeInfo =
+ extractNDVectorTypeInfo(vectorType, this->typeConverter);
auto llvmVectorTy = vectorTypeInfo.llvmVectorTy;
if (!llvmVectorTy || llvmArrayTy != vectorTypeInfo.llvmArrayTy)
return this->matchFailure();
@@ -1416,8 +1419,9 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
// %1 = ptrtoint %elementType* %0 to %indexType
// which is a common pattern of getting the size of a type in bytes.
auto elementType = type.getElementType();
- auto convertedPtrType =
- lowering.convertType(elementType).cast<LLVM::LLVMType>().getPointerTo();
+ auto convertedPtrType = typeConverter.convertType(elementType)
+ .cast<LLVM::LLVMType>()
+ .getPointerTo();
auto nullPtr = rewriter.create<LLVM::NullOp>(loc, convertedPtrType);
auto one = createIndexConstant(rewriter, loc, 1);
auto gep = rewriter.create<LLVM::GEPOp>(loc, convertedPtrType,
@@ -1464,7 +1468,7 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
.getResult(0);
}
- auto structElementType = lowering.convertType(elementType);
+ auto structElementType = typeConverter.convertType(elementType);
auto elementPtrType = structElementType.cast<LLVM::LLVMType>().getPointerTo(
type.getMemorySpace());
Value bitcastAllocated = rewriter.create<LLVM::BitcastOp>(
@@ -1484,7 +1488,7 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
"unexpected number of strides");
// Create the MemRef descriptor.
- auto structType = lowering.convertType(type);
+ auto structType = typeConverter.convertType(type);
auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType);
// Field 1: Allocated pointer, used for malloc/free.
memRefDescriptor.setAllocatedPtr(rewriter, loc, bitcastAllocated);
@@ -1578,11 +1582,12 @@ struct CallOpInterfaceLowering : public LLVMLegalizationPattern<CallOpType> {
}
if (numResults != 0) {
- if (!(packedResult = this->lowering.packFunctionResults(resultTypes)))
+ if (!(packedResult =
+ this->typeConverter.packFunctionResults(resultTypes)))
return this->matchFailure();
}
- auto promoted = this->lowering.promoteMemRefDescriptors(
+ auto promoted = this->typeConverter.promoteMemRefDescriptors(
op->getLoc(), /*opOperands=*/op->getOperands(), operands, rewriter);
auto newOp = rewriter.create<LLVM::CallOp>(op->getLoc(), packedResult,
promoted, op->getAttrs());
@@ -1601,7 +1606,7 @@ struct CallOpInterfaceLowering : public LLVMLegalizationPattern<CallOpType> {
SmallVector<Value, 4> results;
results.reserve(numResults);
for (unsigned i = 0; i < numResults; ++i) {
- auto type = this->lowering.convertType(op->getResult(i).getType());
+ auto type = this->typeConverter.convertType(op->getResult(i).getType());
results.push_back(rewriter.create<LLVM::ExtractValueOp>(
op->getLoc(), type, newOp.getOperation()->getResult(0),
rewriter.getI64ArrayAttr(i)));
@@ -1749,7 +1754,7 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
auto srcType = memRefCastOp.getOperand().getType();
auto dstType = memRefCastOp.getType();
- auto targetStructType = lowering.convertType(memRefCastOp.getType());
+ auto targetStructType = typeConverter.convertType(memRefCastOp.getType());
auto loc = op->getLoc();
if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>()) {
@@ -1766,15 +1771,15 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
auto srcMemRefType = srcType.cast<MemRefType>();
int64_t rank = srcMemRefType.getRank();
// ptr = AllocaOp sizeof(MemRefDescriptor)
- auto ptr = lowering.promoteOneMemRefDescriptor(loc, transformed.source(),
- rewriter);
+ auto ptr = typeConverter.promoteOneMemRefDescriptor(
+ loc, transformed.source(), rewriter);
// voidptr = BitCastOp srcType* to void*
auto voidPtr =
rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
.getResult();
// rank = ConstantOp srcRank
auto rankVal = rewriter.create<LLVM::ConstantOp>(
- loc, lowering.convertType(rewriter.getIntegerType(64)),
+ loc, typeConverter.convertType(rewriter.getIntegerType(64)),
rewriter.getI64IntegerAttr(rank));
// undef = UndefOp
UnrankedMemRefDescriptor memRefDesc =
@@ -1967,7 +1972,7 @@ struct PrefetchOpLowering : public LoadStoreOpLowering<PrefetchOp> {
transformed.indices(), rewriter, getModule());
// Replace with llvm.prefetch.
- auto llvmI32Type = lowering.convertType(rewriter.getIntegerType(32));
+ auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
auto isWrite = rewriter.create<LLVM::ConstantOp>(
op->getLoc(), llvmI32Type,
rewriter.getI32IntegerAttr(prefetchOp.isWrite()));
@@ -1998,7 +2003,7 @@ struct IndexCastOpLowering : public LLVMLegalizationPattern<IndexCastOp> {
auto indexCastOp = cast<IndexCastOp>(op);
auto targetType =
- this->lowering.convertType(indexCastOp.getResult().getType())
+ this->typeConverter.convertType(indexCastOp.getResult().getType())
.cast<LLVM::LLVMType>();
auto sourceType = transformed.in().getType().cast<LLVM::LLVMType>();
unsigned targetBits = targetType.getUnderlyingType()->getIntegerBitWidth();
@@ -2033,7 +2038,7 @@ struct CmpIOpLowering : public LLVMLegalizationPattern<CmpIOp> {
CmpIOpOperandAdaptor transformed(operands);
rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
- op, lowering.convertType(cmpiOp.getResult().getType()),
+ op, typeConverter.convertType(cmpiOp.getResult().getType()),
rewriter.getI64IntegerAttr(static_cast<int64_t>(
convertCmpPredicate<LLVM::ICmpPredicate>(cmpiOp.getPredicate()))),
transformed.lhs(), transformed.rhs());
@@ -2052,7 +2057,7 @@ struct CmpFOpLowering : public LLVMLegalizationPattern<CmpFOp> {
CmpFOpOperandAdaptor transformed(operands);
rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
- op, lowering.convertType(cmpfOp.getResult().getType()),
+ op, typeConverter.convertType(cmpfOp.getResult().getType()),
rewriter.getI64IntegerAttr(static_cast<int64_t>(
convertCmpPredicate<LLVM::FCmpPredicate>(cmpfOp.getPredicate()))),
transformed.lhs(), transformed.rhs());
@@ -2138,8 +2143,8 @@ struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> {
// Otherwise, we need to pack the arguments into an LLVM struct type before
// returning.
- auto packedType =
- lowering.packFunctionResults(llvm::to_vector<4>(op->getOperandTypes()));
+ auto packedType = typeConverter.packFunctionResults(
+ llvm::to_vector<4>(op->getOperandTypes()));
Value packed = rewriter.create<LLVM::UndefOp>(op->getLoc(), packedType);
for (unsigned i = 0; i < numArguments; ++i) {
@@ -2177,10 +2182,10 @@ struct SplatOpLowering : public LLVMLegalizationPattern<SplatOp> {
return matchFailure();
// First insert it into an undef vector so we can shuffle it.
- auto vectorType = lowering.convertType(splatOp.getType());
+ auto vectorType = typeConverter.convertType(splatOp.getType());
Value undef = rewriter.create<LLVM::UndefOp>(op->getLoc(), vectorType);
auto zero = rewriter.create<LLVM::ConstantOp>(
- op->getLoc(), lowering.convertType(rewriter.getIntegerType(32)),
+ op->getLoc(), typeConverter.convertType(rewriter.getIntegerType(32)),
rewriter.getZeroAttr(rewriter.getIntegerType(32)));
auto v = rewriter.create<LLVM::InsertElementOp>(
@@ -2213,7 +2218,7 @@ struct SplatNdOpLowering : public LLVMLegalizationPattern<SplatOp> {
// First insert it into an undef vector so we can shuffle it.
auto loc = op->getLoc();
- auto vectorTypeInfo = extractNDVectorTypeInfo(resultType, lowering);
+ auto vectorTypeInfo = extractNDVectorTypeInfo(resultType, typeConverter);
auto llvmArrayTy = vectorTypeInfo.llvmArrayTy;
auto llvmVectorTy = vectorTypeInfo.llvmVectorTy;
if (!llvmArrayTy || !llvmVectorTy)
@@ -2226,7 +2231,7 @@ struct SplatNdOpLowering : public LLVMLegalizationPattern<SplatOp> {
// places within the returned descriptor.
Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvmVectorTy);
auto zero = rewriter.create<LLVM::ConstantOp>(
- loc, lowering.convertType(rewriter.getIntegerType(32)),
+ loc, typeConverter.convertType(rewriter.getIntegerType(32)),
rewriter.getZeroAttr(rewriter.getIntegerType(32)));
Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvmVectorTy, vdesc,
adaptor.input(), zero);
@@ -2278,14 +2283,15 @@ struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> {
auto sourceMemRefType = viewOp.source().getType().cast<MemRefType>();
auto sourceElementTy =
- lowering.convertType(sourceMemRefType.getElementType())
+ typeConverter.convertType(sourceMemRefType.getElementType())
.dyn_cast_or_null<LLVM::LLVMType>();
auto viewMemRefType = viewOp.getType();
- auto targetElementTy = lowering.convertType(viewMemRefType.getElementType())
- .dyn_cast<LLVM::LLVMType>();
- auto targetDescTy =
- lowering.convertType(viewMemRefType).dyn_cast_or_null<LLVM::LLVMType>();
+ auto targetElementTy =
+ typeConverter.convertType(viewMemRefType.getElementType())
+ .dyn_cast<LLVM::LLVMType>();
+ auto targetDescTy = typeConverter.convertType(viewMemRefType)
+ .dyn_cast_or_null<LLVM::LLVMType>();
if (!sourceElementTy || !targetDescTy)
return matchFailure();
@@ -2333,7 +2339,7 @@ struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> {
strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));
// Fill in missing dynamic sizes.
- auto llvmIndexType = lowering.convertType(rewriter.getIndexType());
+ auto llvmIndexType = typeConverter.convertType(rewriter.getIndexType());
if (dynamicSizes.empty()) {
dynamicSizes.reserve(viewMemRefType.getRank());
auto shape = viewMemRefType.getShape();
@@ -2424,10 +2430,11 @@ struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> {
ViewOpOperandAdaptor adaptor(operands);
auto viewMemRefType = viewOp.getType();
- auto targetElementTy = lowering.convertType(viewMemRefType.getElementType())
- .dyn_cast<LLVM::LLVMType>();
+ auto targetElementTy =
+ typeConverter.convertType(viewMemRefType.getElementType())
+ .dyn_cast<LLVM::LLVMType>();
auto targetDescTy =
- lowering.convertType(viewMemRefType).dyn_cast<LLVM::LLVMType>();
+ typeConverter.convertType(viewMemRefType).dyn_cast<LLVM::LLVMType>();
if (!targetDescTy)
return op->emitWarning("Target descriptor type not converted to LLVM"),
matchFailure();
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index c43fc0e847ad..99b3f6894933 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -36,8 +36,8 @@ using namespace mlir::vector;
template <typename T>
static LLVM::LLVMType getPtrToElementType(T containerType,
- LLVMTypeConverter &lowering) {
- return lowering.convertType(containerType.getElementType())
+ LLVMTypeConverter &typeConverter) {
+ return typeConverter.convertType(containerType.getElementType())
.template cast<LLVM::LLVMType>()
.getPointerTo();
}
@@ -56,12 +56,13 @@ static VectorType reducedVectorTypeBack(VectorType tp) {
// Helper that picks the proper sequence for inserting.
static Value insertOne(ConversionPatternRewriter &rewriter,
- LLVMTypeConverter &lowering, Location loc, Value val1,
- Value val2, Type llvmType, int64_t rank, int64_t pos) {
+ LLVMTypeConverter &typeConverter, Location loc,
+ Value val1, Value val2, Type llvmType, int64_t rank,
+ int64_t pos) {
if (rank == 1) {
auto idxType = rewriter.getIndexType();
auto constant = rewriter.create<LLVM::ConstantOp>(
- loc, lowering.convertType(idxType),
+ loc, typeConverter.convertType(idxType),
rewriter.getIntegerAttr(idxType, pos));
return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
constant);
@@ -83,12 +84,12 @@ static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
// Helper that picks the proper sequence for extracting.
static Value extractOne(ConversionPatternRewriter &rewriter,
- LLVMTypeConverter &lowering, Location loc, Value val,
- Type llvmType, int64_t rank, int64_t pos) {
+ LLVMTypeConverter &typeConverter, Location loc,
+ Value val, Type llvmType, int64_t rank, int64_t pos) {
if (rank == 1) {
auto idxType = rewriter.getIndexType();
auto constant = rewriter.create<LLVM::ConstantOp>(
- loc, lowering.convertType(idxType),
+ loc, typeConverter.convertType(idxType),
rewriter.getIntegerAttr(idxType, pos));
return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
constant);
@@ -137,7 +138,7 @@ class VectorBroadcastOpConversion : public LLVMOpLowering {
ConversionPatternRewriter &rewriter) const override {
auto broadcastOp = cast<vector::BroadcastOp>(op);
VectorType dstVectorType = broadcastOp.getVectorType();
- if (lowering.convertType(dstVectorType) == nullptr)
+ if (typeConverter.convertType(dstVectorType) == nullptr)
return matchFailure();
// Rewrite when the full vector type can be lowered (which
// implies all 'reduced' types can be lowered too).
@@ -203,12 +204,12 @@ class VectorBroadcastOpConversion : public LLVMOpLowering {
Value duplicateOneRank(Value value, Location loc, VectorType srcVectorType,
VectorType dstVectorType, int64_t rank, int64_t dim,
ConversionPatternRewriter &rewriter) const {
- Type llvmType = lowering.convertType(dstVectorType);
+ Type llvmType = typeConverter.convertType(dstVectorType);
assert((llvmType != nullptr) && "unlowerable vector type");
if (rank == 1) {
Value undef = rewriter.create<LLVM::UndefOp>(loc, llvmType);
- Value expand =
- insertOne(rewriter, lowering, loc, undef, value, llvmType, rank, 0);
+ Value expand = insertOne(rewriter, typeConverter, loc, undef, value,
+ llvmType, rank, 0);
SmallVector<int32_t, 4> zeroValues(dim, 0);
return rewriter.create<LLVM::ShuffleVectorOp>(
loc, expand, undef, rewriter.getI32ArrayAttr(zeroValues));
@@ -217,8 +218,8 @@ class VectorBroadcastOpConversion : public LLVMOpLowering {
reducedVectorTypeFront(dstVectorType), rewriter);
Value result = rewriter.create<LLVM::UndefOp>(loc, llvmType);
for (int64_t d = 0; d < dim; ++d) {
- result =
- insertOne(rewriter, lowering, loc, result, expand, llvmType, rank, d);
+ result = insertOne(rewriter, typeConverter, loc, result, expand, llvmType,
+ rank, d);
}
return result;
}
@@ -243,31 +244,32 @@ class VectorBroadcastOpConversion : public LLVMOpLowering {
Value stretchOneRank(Value value, Location loc, VectorType srcVectorType,
VectorType dstVectorType, int64_t rank, int64_t dim,
ConversionPatternRewriter &rewriter) const {
- Type llvmType = lowering.convertType(dstVectorType);
+ Type llvmType = typeConverter.convertType(dstVectorType);
assert((llvmType != nullptr) && "unlowerable vector type");
Value result = rewriter.create<LLVM::UndefOp>(loc, llvmType);
bool atStretch = dim != srcVectorType.getDimSize(0);
if (rank == 1) {
assert(atStretch);
- Type redLlvmType = lowering.convertType(dstVectorType.getElementType());
+ Type redLlvmType =
+ typeConverter.convertType(dstVectorType.getElementType());
Value one =
- extractOne(rewriter, lowering, loc, value, redLlvmType, rank, 0);
- Value expand =
- insertOne(rewriter, lowering, loc, result, one, llvmType, rank, 0);
+ extractOne(rewriter, typeConverter, loc, value, redLlvmType, rank, 0);
+ Value expand = insertOne(rewriter, typeConverter, loc, result, one,
+ llvmType, rank, 0);
SmallVector<int32_t, 4> zeroValues(dim, 0);
return rewriter.create<LLVM::ShuffleVectorOp>(
loc, expand, result, rewriter.getI32ArrayAttr(zeroValues));
}
VectorType redSrcType = reducedVectorTypeFront(srcVectorType);
VectorType redDstType = reducedVectorTypeFront(dstVectorType);
- Type redLlvmType = lowering.convertType(redSrcType);
+ Type redLlvmType = typeConverter.convertType(redSrcType);
for (int64_t d = 0; d < dim; ++d) {
int64_t pos = atStretch ? 0 : d;
- Value one =
- extractOne(rewriter, lowering, loc, value, redLlvmType, rank, pos);
+ Value one = extractOne(rewriter, typeConverter, loc, value, redLlvmType,
+ rank, pos);
Value expand = expandRanks(one, loc, redSrcType, redDstType, rewriter);
- result =
- insertOne(rewriter, lowering, loc, result, expand, llvmType, rank, d);
+ result = insertOne(rewriter, typeConverter, loc, result, expand, llvmType,
+ rank, d);
}
return result;
}
@@ -286,7 +288,7 @@ class VectorReductionOpConversion : public LLVMOpLowering {
auto reductionOp = cast<vector::ReductionOp>(op);
auto kind = reductionOp.kind();
Type eltType = reductionOp.dest().getType();
- Type llvmType = lowering.convertType(eltType);
+ Type llvmType = typeConverter.convertType(eltType);
if (eltType.isInteger(32) || eltType.isInteger(64)) {
// Integer reductions: add/mul/min/max/and/or/xor.
if (kind == "add")
@@ -353,7 +355,7 @@ class VectorReductionV2OpConversion : public LLVMOpLowering {
auto reductionOp = cast<vector::ReductionV2Op>(op);
auto kind = reductionOp.kind();
Type eltType = reductionOp.dest().getType();
- Type llvmType = lowering.convertType(eltType);
+ Type llvmType = typeConverter.convertType(eltType);
if (kind == "add") {
rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fadd>(
op, llvmType, operands[1], operands[0]);
@@ -383,7 +385,7 @@ class VectorShuffleOpConversion : public LLVMOpLowering {
auto v1Type = shuffleOp.getV1VectorType();
auto v2Type = shuffleOp.getV2VectorType();
auto vectorType = shuffleOp.getVectorType();
- Type llvmType = lowering.convertType(vectorType);
+ Type llvmType = typeConverter.convertType(vectorType);
auto maskArrayAttr = shuffleOp.mask();
// Bail if result type cannot be lowered.
@@ -415,10 +417,10 @@ class VectorShuffleOpConversion : public LLVMOpLowering {
extPos -= v1Dim;
value = adaptor.v2();
}
- Value extract =
- extractOne(rewriter, lowering, loc, value, llvmType, rank, extPos);
- insert = insertOne(rewriter, lowering, loc, insert, extract, llvmType,
- rank, insPos++);
+ Value extract = extractOne(rewriter, typeConverter, loc, value, llvmType,
+ rank, extPos);
+ insert = insertOne(rewriter, typeConverter, loc, insert, extract,
+ llvmType, rank, insPos++);
}
rewriter.replaceOp(op, insert);
return matchSuccess();
@@ -438,7 +440,7 @@ class VectorExtractElementOpConversion : public LLVMOpLowering {
auto adaptor = vector::ExtractElementOpOperandAdaptor(operands);
auto extractEltOp = cast<vector::ExtractElementOp>(op);
auto vectorType = extractEltOp.getVectorType();
- auto llvmType = lowering.convertType(vectorType.getElementType());
+ auto llvmType = typeConverter.convertType(vectorType.getElementType());
// Bail if result type cannot be lowered.
if (!llvmType)
@@ -465,7 +467,7 @@ class VectorExtractOpConversion : public LLVMOpLowering {
auto extractOp = cast<vector::ExtractOp>(op);
auto vectorType = extractOp.getVectorType();
auto resultType = extractOp.getResult().getType();
- auto llvmResultType = lowering.convertType(resultType);
+ auto llvmResultType = typeConverter.convertType(resultType);
auto positionArrayAttr = extractOp.position();
// Bail if result type cannot be lowered.
@@ -489,13 +491,13 @@ class VectorExtractOpConversion : public LLVMOpLowering {
auto nMinusOnePositionAttrs =
ArrayAttr::get(positionAttrs.drop_back(), context);
extracted = rewriter.create<LLVM::ExtractValueOp>(
- loc, lowering.convertType(oneDVectorType), extracted,
+ loc, typeConverter.convertType(oneDVectorType), extracted,
nMinusOnePositionAttrs);
}
// Remaining extraction of element from 1-D LLVM vector
auto position = positionAttrs.back().cast<IntegerAttr>();
- auto i64Type = LLVM::LLVMType::getInt64Ty(lowering.getDialect());
+ auto i64Type = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
extracted =
rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
@@ -553,7 +555,7 @@ class VectorInsertElementOpConversion : public LLVMOpLowering {
auto adaptor = vector::InsertElementOpOperandAdaptor(operands);
auto insertEltOp = cast<vector::InsertElementOp>(op);
auto vectorType = insertEltOp.getDestVectorType();
- auto llvmType = lowering.convertType(vectorType);
+ auto llvmType = typeConverter.convertType(vectorType);
// Bail if result type cannot be lowered.
if (!llvmType)
@@ -580,7 +582,7 @@ class VectorInsertOpConversion : public LLVMOpLowering {
auto insertOp = cast<vector::InsertOp>(op);
auto sourceType = insertOp.getSourceType();
auto destVectorType = insertOp.getDestVectorType();
- auto llvmResultType = lowering.convertType(destVectorType);
+ auto llvmResultType = typeConverter.convertType(destVectorType);
auto positionArrayAttr = insertOp.position();
// Bail if result type cannot be lowered.
@@ -607,16 +609,16 @@ class VectorInsertOpConversion : public LLVMOpLowering {
auto nMinusOnePositionAttrs =
ArrayAttr::get(positionAttrs.drop_back(), context);
extracted = rewriter.create<LLVM::ExtractValueOp>(
- loc, lowering.convertType(oneDVectorType), extracted,
+ loc, typeConverter.convertType(oneDVectorType), extracted,
nMinusOnePositionAttrs);
}
// Insertion of an element into a 1-D LLVM vector.
- auto i64Type = LLVM::LLVMType::getInt64Ty(lowering.getDialect());
+ auto i64Type = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
Value inserted = rewriter.create<LLVM::InsertElementOp>(
- loc, lowering.convertType(oneDVectorType), extracted, adaptor.source(),
- constant);
+ loc, typeConverter.convertType(oneDVectorType), extracted,
+ adaptor.source(), constant);
// Potential insertion of resulting 1-D vector into array.
if (positionAttrs.size() > 1) {
@@ -830,7 +832,7 @@ class VectorOuterProductOpConversion : public LLVMOpLowering {
auto vRHS = adaptor.rhs().getType().cast<LLVM::LLVMType>();
auto rankLHS = vLHS.getUnderlyingType()->getVectorNumElements();
auto rankRHS = vRHS.getUnderlyingType()->getVectorNumElements();
- auto llvmArrayOfVectType = lowering.convertType(
+ auto llvmArrayOfVectType = typeConverter.convertType(
cast<vector::OuterProductOp>(op).getResult().getType());
Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayOfVectType);
Value a = adaptor.lhs(), b = adaptor.rhs();
@@ -893,7 +895,7 @@ class VectorTypeCastOpConversion : public LLVMOpLowering {
return matchFailure();
MemRefDescriptor sourceMemRef(operands[0]);
- auto llvmTargetDescriptorTy = lowering.convertType(targetMemRefType)
+ auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
.dyn_cast_or_null<LLVM::LLVMType>();
if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
return matchFailure();
@@ -916,7 +918,7 @@ class VectorTypeCastOpConversion : public LLVMOpLowering {
if (failed(successStrides) || !isContiguous)
return matchFailure();
- auto int64Ty = LLVM::LLVMType::getInt64Ty(lowering.getDialect());
+ auto int64Ty = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
// Create descriptor.
auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
@@ -979,7 +981,7 @@ class VectorPrintOpConversion : public LLVMOpLowering {
auto adaptor = vector::PrintOpOperandAdaptor(operands);
Type printType = printOp.getPrintType();
- if (lowering.convertType(printType) == nullptr)
+ if (typeConverter.convertType(printType) == nullptr)
return matchFailure();
// Make sure element type has runtime support (currently just Float/Double).
@@ -1021,10 +1023,10 @@ class VectorPrintOpConversion : public LLVMOpLowering {
for (int64_t d = 0; d < dim; ++d) {
auto reducedType =
rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
- auto llvmType = lowering.convertType(
+ auto llvmType = typeConverter.convertType(
rank > 1 ? reducedType : vectorType.getElementType());
Value nestedVal =
- extractOne(rewriter, lowering, loc, value, llvmType, rank, d);
+ extractOne(rewriter, typeConverter, loc, value, llvmType, rank, d);
emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1);
if (d != dim - 1)
emitCall(rewriter, loc, printComma);
@@ -1055,36 +1057,36 @@ class VectorPrintOpConversion : public LLVMOpLowering {
// Helpers for method names.
Operation *getPrintI32(Operation *op) const {
- LLVM::LLVMDialect *dialect = lowering.getDialect();
+ LLVM::LLVMDialect *dialect = typeConverter.getDialect();
return getPrint(op, dialect, "print_i32",
LLVM::LLVMType::getInt32Ty(dialect));
}
Operation *getPrintI64(Operation *op) const {
- LLVM::LLVMDialect *dialect = lowering.getDialect();
+ LLVM::LLVMDialect *dialect = typeConverter.getDialect();
return getPrint(op, dialect, "print_i64",
LLVM::LLVMType::getInt64Ty(dialect));
}
Operation *getPrintFloat(Operation *op) const {
- LLVM::LLVMDialect *dialect = lowering.getDialect();
+ LLVM::LLVMDialect *dialect = typeConverter.getDialect();
return getPrint(op, dialect, "print_f32",
LLVM::LLVMType::getFloatTy(dialect));
}
Operation *getPrintDouble(Operation *op) const {
- LLVM::LLVMDialect *dialect = lowering.getDialect();
+ LLVM::LLVMDialect *dialect = typeConverter.getDialect();
return getPrint(op, dialect, "print_f64",
LLVM::LLVMType::getDoubleTy(dialect));
}
Operation *getPrintOpen(Operation *op) const {
- return getPrint(op, lowering.getDialect(), "print_open", {});
+ return getPrint(op, typeConverter.getDialect(), "print_open", {});
}
Operation *getPrintClose(Operation *op) const {
- return getPrint(op, lowering.getDialect(), "print_close", {});
+ return getPrint(op, typeConverter.getDialect(), "print_close", {});
}
Operation *getPrintComma(Operation *op) const {
- return getPrint(op, lowering.getDialect(), "print_comma", {});
+ return getPrint(op, typeConverter.getDialect(), "print_comma", {});
}
Operation *getPrintNewline(Operation *op) const {
- return getPrint(op, lowering.getDialect(), "print_newline", {});
+ return getPrint(op, typeConverter.getDialect(), "print_newline", {});
}
};
More information about the Mlir-commits
mailing list