[flang-commits] [flang] ce25459 - [mlir][Conversion] Store const type converter in ConversionPattern
Matthias Springer via flang-commits
flang-commits at lists.llvm.org
Mon Aug 14 00:03:36 PDT 2023
Author: Matthias Springer
Date: 2023-08-14T09:03:11+02:00
New Revision: ce254598b73b119c9463f5b7f4131559e276e844
URL: https://github.com/llvm/llvm-project/commit/ce254598b73b119c9463f5b7f4131559e276e844
DIFF: https://github.com/llvm/llvm-project/commit/ce254598b73b119c9463f5b7f4131559e276e844.diff
LOG: [mlir][Conversion] Store const type converter in ConversionPattern
ConversionPatterns do not (and should not) modify the type converter that they are using.
* Make `ConversionPattern::typeConverter` const.
* Make member functions of the `LLVMTypeConverter` const.
* Conversion patterns now take a const type converter.
* Various helper functions (that are called from patterns) now also take a const type converter.
Differential Revision: https://reviews.llvm.org/D157601
Added:
Modified:
flang/include/flang/Optimizer/CodeGen/TypeConverter.h
flang/lib/Optimizer/CodeGen/CodeGen.cpp
flang/lib/Optimizer/CodeGen/TypeConverter.cpp
mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h
mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
mlir/lib/Conversion/LLVMCommon/Pattern.cpp
mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
mlir/lib/Transforms/Utils/DialectConversion.cpp
mlir/test/lib/Dialect/Test/TestPatterns.cpp
Removed:
################################################################################
diff --git a/flang/include/flang/Optimizer/CodeGen/TypeConverter.h b/flang/include/flang/Optimizer/CodeGen/TypeConverter.h
index 4131eb53f07625..f42c40eb68902b 100644
--- a/flang/include/flang/Optimizer/CodeGen/TypeConverter.h
+++ b/flang/include/flang/Optimizer/CodeGen/TypeConverter.h
@@ -49,20 +49,20 @@ class LLVMTypeConverter : public mlir::LLVMTypeConverter {
// i32 is used here because LLVM wants i32 constants when indexing into struct
// types. Indexing into other aggregate types is more flexible.
- mlir::Type offsetType();
+ mlir::Type offsetType() const;
// i64 can be used to index into aggregates like arrays
- mlir::Type indexType();
+ mlir::Type indexType() const;
// fir.type<name(p : TY'...){f : TY...}> --> llvm<"%name = { ty... }">
std::optional<mlir::LogicalResult>
convertRecordType(fir::RecordType derived,
llvm::SmallVectorImpl<mlir::Type> &results,
- llvm::ArrayRef<mlir::Type> callStack);
+ llvm::ArrayRef<mlir::Type> callStack) const;
// Is an extended descriptor needed given the element type of a fir.box type ?
// Extended descriptors are required for derived types.
- bool requiresExtendedDesc(mlir::Type boxElementType);
+ bool requiresExtendedDesc(mlir::Type boxElementType) const;
// Magic value to indicate we do not know the rank of an entity, either
// because it is assumed rank or because we have not determined it yet.
@@ -70,35 +70,33 @@ class LLVMTypeConverter : public mlir::LLVMTypeConverter {
// This corresponds to the descriptor as defined in ISO_Fortran_binding.h and
// the addendum defined in descriptor.h.
- mlir::Type convertBoxType(BaseBoxType box, int rank = unknownRank());
+ mlir::Type convertBoxType(BaseBoxType box, int rank = unknownRank()) const;
/// Convert fir.box type to the corresponding llvm struct type instead of a
/// pointer to this struct type.
- mlir::Type convertBoxTypeAsStruct(BaseBoxType box);
+ mlir::Type convertBoxTypeAsStruct(BaseBoxType box) const;
// fir.boxproc<any> --> llvm<"{ any*, i8* }">
- mlir::Type convertBoxProcType(BoxProcType boxproc);
+ mlir::Type convertBoxProcType(BoxProcType boxproc) const;
- unsigned characterBitsize(fir::CharacterType charTy);
+ unsigned characterBitsize(fir::CharacterType charTy) const;
// fir.char<k,?> --> llvm<"ix"> where ix is scaled by kind mapping
// fir.char<k,n> --> llvm.array<n x "ix">
- mlir::Type convertCharType(fir::CharacterType charTy);
+ mlir::Type convertCharType(fir::CharacterType charTy) const;
// Use the target specifics to figure out how to map complex to LLVM IR. The
// use of complex values in function signatures is handled before conversion
// to LLVM IR dialect here.
//
// fir.complex<T> | std.complex<T> --> llvm<"{t,t}">
- template <typename C>
- mlir::Type convertComplexType(C cmplx) {
+ template <typename C> mlir::Type convertComplexType(C cmplx) const {
LLVM_DEBUG(llvm::dbgs() << "type convert: " << cmplx << '\n');
auto eleTy = cmplx.getElementType();
return convertType(specifics->complexMemoryType(eleTy));
}
- template <typename A>
- mlir::Type convertPointerLike(A &ty) {
+ template <typename A> mlir::Type convertPointerLike(A &ty) const {
mlir::Type eleTy = ty.getEleTy();
// A sequence type is a special case. A sequence of runtime size on its
// interior dimensions lowers to a memory reference. In that case, we
@@ -126,27 +124,27 @@ class LLVMTypeConverter : public mlir::LLVMTypeConverter {
// convert a front-end kind value to either a std or LLVM IR dialect type
// fir.real<n> --> llvm.anyfloat where anyfloat is a kind mapping
- mlir::Type convertRealType(fir::KindTy kind);
+ mlir::Type convertRealType(fir::KindTy kind) const;
// fir.array<c ... :any> --> llvm<"[...[c x any]]">
- mlir::Type convertSequenceType(SequenceType seq);
+ mlir::Type convertSequenceType(SequenceType seq) const;
// fir.tdesc<any> --> llvm<"i8*">
// TODO: For now use a void*, however pointer identity is not sufficient for
// the f18 object v. class distinction (F2003).
- mlir::Type convertTypeDescType(mlir::MLIRContext *ctx);
+ mlir::Type convertTypeDescType(mlir::MLIRContext *ctx) const;
- KindMapping &getKindMap() { return kindMapping; }
+ const KindMapping &getKindMap() const { return kindMapping; }
// Relay TBAA tag attachment to TBAABuilder.
void attachTBAATag(mlir::LLVM::AliasAnalysisOpInterface op,
mlir::Type baseFIRType, mlir::Type accessFIRType,
- mlir::LLVM::GEPOp gep);
+ mlir::LLVM::GEPOp gep) const;
private:
KindMapping kindMapping;
std::unique_ptr<CodeGenSpecifics> specifics;
- TBAABuilder tbaaBuilder;
+ std::unique_ptr<TBAABuilder> tbaaBuilder;
};
} // namespace fir
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 596458d37d2dcd..0fbee616ac9a51 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -117,7 +117,7 @@ namespace {
template <typename FromOp>
class FIROpConversion : public mlir::ConvertOpToLLVMPattern<FromOp> {
public:
- explicit FIROpConversion(fir::LLVMTypeConverter &lowering,
+ explicit FIROpConversion(const fir::LLVMTypeConverter &lowering,
const fir::FIRToLLVMPassOptions &options)
: mlir::ConvertOpToLLVMPattern<FromOp>(lowering), options(options) {}
@@ -359,8 +359,9 @@ class FIROpConversion : public mlir::ConvertOpToLLVMPattern<FromOp> {
return al;
}
- fir::LLVMTypeConverter &lowerTy() const {
- return *static_cast<fir::LLVMTypeConverter *>(this->getTypeConverter());
+ const fir::LLVMTypeConverter &lowerTy() const {
+ return *static_cast<const fir::LLVMTypeConverter *>(
+ this->getTypeConverter());
}
void attachTBAATag(mlir::LLVM::AliasAnalysisOpInterface op,
@@ -3191,8 +3192,8 @@ struct SelectCaseOpConversion : public FIROpConversion<fir::SelectCaseOp> {
};
template <typename OP>
-static void selectMatchAndRewrite(fir::LLVMTypeConverter &lowering, OP select,
- typename OP::Adaptor adaptor,
+static void selectMatchAndRewrite(const fir::LLVMTypeConverter &lowering,
+ OP select, typename OP::Adaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) {
unsigned conds = select.getNumConditions();
auto cases = select.getCases().getValue();
@@ -3461,7 +3462,7 @@ template <typename LLVMOP, typename OPTY>
static mlir::LLVM::InsertValueOp
complexSum(OPTY sumop, mlir::ValueRange opnds,
mlir::ConversionPatternRewriter &rewriter,
- fir::LLVMTypeConverter &lowering) {
+ const fir::LLVMTypeConverter &lowering) {
mlir::Value a = opnds[0];
mlir::Value b = opnds[1];
auto loc = sumop.getLoc();
@@ -3610,7 +3611,7 @@ struct NegcOpConversion : public FIROpConversion<fir::NegcOp> {
/// These operations are normally dead after the pre-codegen pass.
template <typename FromOp>
struct MustBeDeadConversion : public FIROpConversion<FromOp> {
- explicit MustBeDeadConversion(fir::LLVMTypeConverter &lowering,
+ explicit MustBeDeadConversion(const fir::LLVMTypeConverter &lowering,
const fir::FIRToLLVMPassOptions &options)
: FIROpConversion<FromOp>(lowering, options) {}
using OpAdaptor = typename FromOp::Adaptor;
diff --git a/flang/lib/Optimizer/CodeGen/TypeConverter.cpp b/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
index 8de2dbbca3f806..fd5f0c7135fea2 100644
--- a/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
+++ b/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
@@ -37,7 +37,8 @@ LLVMTypeConverter::LLVMTypeConverter(mlir::ModuleOp module, bool applyTBAA)
specifics(CodeGenSpecifics::get(module.getContext(),
getTargetTriple(module),
getKindMapping(module))),
- tbaaBuilder(module->getContext(), applyTBAA) {
+ tbaaBuilder(
+ std::make_unique<TBAABuilder>(module->getContext(), applyTBAA)) {
LLVM_DEBUG(llvm::dbgs() << "FIR type converter\n");
// Each conversion should return a value of type mlir::Type.
@@ -155,20 +156,19 @@ LLVMTypeConverter::LLVMTypeConverter(mlir::ModuleOp module, bool applyTBAA)
// i32 is used here because LLVM wants i32 constants when indexing into struct
// types. Indexing into other aggregate types is more flexible.
-mlir::Type LLVMTypeConverter::offsetType() {
+mlir::Type LLVMTypeConverter::offsetType() const {
return mlir::IntegerType::get(&getContext(), 32);
}
// i64 can be used to index into aggregates like arrays
-mlir::Type LLVMTypeConverter::indexType() {
+mlir::Type LLVMTypeConverter::indexType() const {
return mlir::IntegerType::get(&getContext(), 64);
}
// fir.type<name(p : TY'...){f : TY...}> --> llvm<"%name = { ty... }">
-std::optional<mlir::LogicalResult>
-LLVMTypeConverter::convertRecordType(fir::RecordType derived,
- llvm::SmallVectorImpl<mlir::Type> &results,
- llvm::ArrayRef<mlir::Type> callStack) {
+std::optional<mlir::LogicalResult> LLVMTypeConverter::convertRecordType(
+ fir::RecordType derived, llvm::SmallVectorImpl<mlir::Type> &results,
+ llvm::ArrayRef<mlir::Type> callStack) const {
auto name = derived.getName();
auto st = mlir::LLVM::LLVMStructType::getIdentified(&getContext(), name);
if (llvm::count(callStack, derived) > 1) {
@@ -192,14 +192,14 @@ LLVMTypeConverter::convertRecordType(fir::RecordType derived,
// Is an extended descriptor needed given the element type of a fir.box type ?
// Extended descriptors are required for derived types.
-bool LLVMTypeConverter::requiresExtendedDesc(mlir::Type boxElementType) {
+bool LLVMTypeConverter::requiresExtendedDesc(mlir::Type boxElementType) const {
auto eleTy = fir::unwrapSequenceType(boxElementType);
return eleTy.isa<fir::RecordType>();
}
// This corresponds to the descriptor as defined in ISO_Fortran_binding.h and
// the addendum defined in descriptor.h.
-mlir::Type LLVMTypeConverter::convertBoxType(BaseBoxType box, int rank) {
+mlir::Type LLVMTypeConverter::convertBoxType(BaseBoxType box, int rank) const {
// (base_addr*, elem_len, version, rank, type, attribute, f18Addendum, [dim]
llvm::SmallVector<mlir::Type> dataDescFields;
mlir::Type ele = box.getEleTy();
@@ -269,14 +269,14 @@ mlir::Type LLVMTypeConverter::convertBoxType(BaseBoxType box, int rank) {
/// Convert fir.box type to the corresponding llvm struct type instead of a
/// pointer to this struct type.
-mlir::Type LLVMTypeConverter::convertBoxTypeAsStruct(BaseBoxType box) {
+mlir::Type LLVMTypeConverter::convertBoxTypeAsStruct(BaseBoxType box) const {
return convertBoxType(box)
.cast<mlir::LLVM::LLVMPointerType>()
.getElementType();
}
// fir.boxproc<any> --> llvm<"{ any*, i8* }">
-mlir::Type LLVMTypeConverter::convertBoxProcType(BoxProcType boxproc) {
+mlir::Type LLVMTypeConverter::convertBoxProcType(BoxProcType boxproc) const {
auto funcTy = convertType(boxproc.getEleTy());
auto i8PtrTy = mlir::LLVM::LLVMPointerType::get(
mlir::IntegerType::get(&getContext(), 8));
@@ -285,13 +285,13 @@ mlir::Type LLVMTypeConverter::convertBoxProcType(BoxProcType boxproc) {
/*isPacked=*/false);
}
-unsigned LLVMTypeConverter::characterBitsize(fir::CharacterType charTy) {
+unsigned LLVMTypeConverter::characterBitsize(fir::CharacterType charTy) const {
return kindMapping.getCharacterBitsize(charTy.getFKind());
}
// fir.char<k,?> --> llvm<"ix"> where ix is scaled by kind mapping
// fir.char<k,n> --> llvm.array<n x "ix">
-mlir::Type LLVMTypeConverter::convertCharType(fir::CharacterType charTy) {
+mlir::Type LLVMTypeConverter::convertCharType(fir::CharacterType charTy) const {
auto iTy = mlir::IntegerType::get(&getContext(), characterBitsize(charTy));
if (charTy.getLen() == fir::CharacterType::unknownLen())
return iTy;
@@ -300,13 +300,13 @@ mlir::Type LLVMTypeConverter::convertCharType(fir::CharacterType charTy) {
// convert a front-end kind value to either a std or LLVM IR dialect type
// fir.real<n> --> llvm.anyfloat where anyfloat is a kind mapping
-mlir::Type LLVMTypeConverter::convertRealType(fir::KindTy kind) {
+mlir::Type LLVMTypeConverter::convertRealType(fir::KindTy kind) const {
return fir::fromRealTypeID(&getContext(), kindMapping.getRealTypeID(kind),
kind);
}
// fir.array<c ... :any> --> llvm<"[...[c x any]]">
-mlir::Type LLVMTypeConverter::convertSequenceType(SequenceType seq) {
+mlir::Type LLVMTypeConverter::convertSequenceType(SequenceType seq) const {
auto baseTy = convertType(seq.getEleTy());
if (characterWithDynamicLen(seq.getEleTy()))
return mlir::LLVM::LLVMPointerType::get(baseTy);
@@ -328,7 +328,8 @@ mlir::Type LLVMTypeConverter::convertSequenceType(SequenceType seq) {
// fir.tdesc<any> --> llvm<"i8*">
// TODO: For now use a void*, however pointer identity is not sufficient for
// the f18 object v. class distinction (F2003).
-mlir::Type LLVMTypeConverter::convertTypeDescType(mlir::MLIRContext *ctx) {
+mlir::Type
+LLVMTypeConverter::convertTypeDescType(mlir::MLIRContext *ctx) const {
return mlir::LLVM::LLVMPointerType::get(
mlir::IntegerType::get(&getContext(), 8));
}
@@ -337,8 +338,8 @@ mlir::Type LLVMTypeConverter::convertTypeDescType(mlir::MLIRContext *ctx) {
void LLVMTypeConverter::attachTBAATag(mlir::LLVM::AliasAnalysisOpInterface op,
mlir::Type baseFIRType,
mlir::Type accessFIRType,
- mlir::LLVM::GEPOp gep) {
- tbaaBuilder.attachTBAATag(op, baseFIRType, accessFIRType, gep);
+ mlir::LLVM::GEPOp gep) const {
+ tbaaBuilder->attachTBAATag(op, baseFIRType, accessFIRType, gep);
}
} // namespace fir
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h b/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
index 28d37a91edb80d..ef8215d332c463 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
@@ -40,13 +40,14 @@ class MemRefDescriptor : public StructBuilder {
/// Builds IR creating a MemRef descriptor that represents `type` and
/// populates it with static shape and stride information extracted from the
/// type.
- static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc,
- LLVMTypeConverter &typeConverter,
- MemRefType type, Value memory);
- static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc,
- LLVMTypeConverter &typeConverter,
- MemRefType type, Value memory,
- Value alignedMemory);
+ static MemRefDescriptor
+ fromStaticShape(OpBuilder &builder, Location loc,
+ const LLVMTypeConverter &typeConverter, MemRefType type,
+ Value memory);
+ static MemRefDescriptor
+ fromStaticShape(OpBuilder &builder, Location loc,
+ const LLVMTypeConverter &typeConverter, MemRefType type,
+ Value memory, Value alignedMemory);
/// Builds IR extracting the allocated pointer from the descriptor.
Value allocatedPtr(OpBuilder &builder, Location loc);
@@ -95,7 +96,7 @@ class MemRefDescriptor : public StructBuilder {
/// \note there is no setter for this one since it is derived from alignedPtr
/// and offset.
Value bufferPtr(OpBuilder &builder, Location loc,
- LLVMTypeConverter &converter, MemRefType type);
+ const LLVMTypeConverter &converter, MemRefType type);
/// Builds IR populating a MemRef descriptor structure from a list of
/// individual values composing that descriptor, in the following order:
@@ -106,7 +107,7 @@ class MemRefDescriptor : public StructBuilder {
/// - <rank> shapes;
/// where <rank> is the MemRef rank as provided in `type`.
static Value pack(OpBuilder &builder, Location loc,
- LLVMTypeConverter &converter, MemRefType type,
+ const LLVMTypeConverter &converter, MemRefType type,
ValueRange values);
/// Builds IR extracting individual elements of a MemRef descriptor structure
@@ -178,7 +179,7 @@ class UnrankedMemRefDescriptor : public StructBuilder {
/// - rank of the memref;
/// - pointer to the memref descriptor.
static Value pack(OpBuilder &builder, Location loc,
- LLVMTypeConverter &converter, UnrankedMemRefType type,
+ const LLVMTypeConverter &converter, UnrankedMemRefType type,
ValueRange values);
/// Builds IR extracting individual elements that compose an unranked memref
@@ -195,7 +196,7 @@ class UnrankedMemRefDescriptor : public StructBuilder {
/// which must have the same length as `values`, is needed to handle layouts
/// where sizeof(ptr addrspace(N)) != sizeof(ptr addrspace(0)).
static void computeSizes(OpBuilder &builder, Location loc,
- LLVMTypeConverter &typeConverter,
+ const LLVMTypeConverter &typeConverter,
ArrayRef<UnrankedMemRefDescriptor> values,
ArrayRef<unsigned> addressSpaces,
SmallVectorImpl<Value> &sizes);
@@ -217,11 +218,12 @@ class UnrankedMemRefDescriptor : public StructBuilder {
/// Builds IR extracting the aligned pointer from the descriptor.
static Value alignedPtr(OpBuilder &builder, Location loc,
- LLVMTypeConverter &typeConverter, Value memRefDescPtr,
+ const LLVMTypeConverter &typeConverter,
+ Value memRefDescPtr,
LLVM::LLVMPointerType elemPtrType);
/// Builds IR inserting the aligned pointer into the descriptor.
static void setAlignedPtr(OpBuilder &builder, Location loc,
- LLVMTypeConverter &typeConverter,
+ const LLVMTypeConverter &typeConverter,
Value memRefDescPtr,
LLVM::LLVMPointerType elemPtrType,
Value alignedPtr);
@@ -230,44 +232,45 @@ class UnrankedMemRefDescriptor : public StructBuilder {
/// Returns a pointer to a convertType(index), which points to the beggining
/// of a struct {index, index[rank], index[rank]}.
static Value offsetBasePtr(OpBuilder &builder, Location loc,
- LLVMTypeConverter &typeConverter,
+ const LLVMTypeConverter &typeConverter,
Value memRefDescPtr,
LLVM::LLVMPointerType elemPtrType);
/// Builds IR extracting the offset from the descriptor.
static Value offset(OpBuilder &builder, Location loc,
- LLVMTypeConverter &typeConverter, Value memRefDescPtr,
- LLVM::LLVMPointerType elemPtrType);
+ const LLVMTypeConverter &typeConverter,
+ Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType);
/// Builds IR inserting the offset into the descriptor.
static void setOffset(OpBuilder &builder, Location loc,
- LLVMTypeConverter &typeConverter, Value memRefDescPtr,
- LLVM::LLVMPointerType elemPtrType, Value offset);
+ const LLVMTypeConverter &typeConverter,
+ Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType,
+ Value offset);
/// Builds IR extracting the pointer to the first element of the size array.
static Value sizeBasePtr(OpBuilder &builder, Location loc,
- LLVMTypeConverter &typeConverter,
+ const LLVMTypeConverter &typeConverter,
Value memRefDescPtr,
LLVM::LLVMPointerType elemPtrType);
/// Builds IR extracting the size[index] from the descriptor.
static Value size(OpBuilder &builder, Location loc,
- LLVMTypeConverter &typeConverter, Value sizeBasePtr,
+ const LLVMTypeConverter &typeConverter, Value sizeBasePtr,
Value index);
/// Builds IR inserting the size[index] into the descriptor.
static void setSize(OpBuilder &builder, Location loc,
- LLVMTypeConverter &typeConverter, Value sizeBasePtr,
+ const LLVMTypeConverter &typeConverter, Value sizeBasePtr,
Value index, Value size);
/// Builds IR extracting the pointer to the first element of the stride array.
static Value strideBasePtr(OpBuilder &builder, Location loc,
- LLVMTypeConverter &typeConverter,
+ const LLVMTypeConverter &typeConverter,
Value sizeBasePtr, Value rank);
/// Builds IR extracting the stride[index] from the descriptor.
static Value stride(OpBuilder &builder, Location loc,
- LLVMTypeConverter &typeConverter, Value strideBasePtr,
- Value index, Value stride);
+ const LLVMTypeConverter &typeConverter,
+ Value strideBasePtr, Value index, Value stride);
/// Builds IR inserting the stride[index] into the descriptor.
static void setStride(OpBuilder &builder, Location loc,
- LLVMTypeConverter &typeConverter, Value strideBasePtr,
- Value index, Value stride);
+ const LLVMTypeConverter &typeConverter,
+ Value strideBasePtr, Value index, Value stride);
};
} // namespace mlir
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index 075d753ea6ed82..92f4025ffffffb 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -23,7 +23,7 @@ namespace detail {
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp,
ValueRange operands,
ArrayRef<NamedAttribute> targetAttrs,
- LLVMTypeConverter &typeConverter,
+ const LLVMTypeConverter &typeConverter,
ConversionPatternRewriter &rewriter);
} // namespace detail
@@ -37,14 +37,14 @@ LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp,
class ConvertToLLVMPattern : public ConversionPattern {
public:
ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context,
- LLVMTypeConverter &typeConverter,
+ const LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1);
protected:
/// Returns the LLVM dialect.
LLVM::LLVMDialect &getDialect() const;
- LLVMTypeConverter *getTypeConverter() const;
+ const LLVMTypeConverter *getTypeConverter() const;
/// Gets the MLIR type wrapping the LLVM integer type whose bit width is
/// defined by the used type converter.
@@ -140,7 +140,7 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
public:
using OpAdaptor = typename SourceOp::Adaptor;
- explicit ConvertOpToLLVMPattern(LLVMTypeConverter &typeConverter,
+ explicit ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: ConvertToLLVMPattern(SourceOp::getOperationName(),
&typeConverter.getContext(), typeConverter,
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index 79a68e875f045e..2097aa78ebd70e 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -33,7 +33,7 @@ class LLVMStructType;
class LLVMTypeConverter : public TypeConverter {
/// Give structFuncArgTypeConverter access to memref-specific functions.
friend LogicalResult
- structFuncArgTypeConverter(LLVMTypeConverter &converter, Type type,
+ structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
SmallVectorImpl<Type> &result);
public:
@@ -54,20 +54,20 @@ class LLVMTypeConverter : public TypeConverter {
/// is populated with argument mapping.
Type convertFunctionSignature(FunctionType funcTy, bool isVariadic,
bool useBarePtrCallConv,
- SignatureConversion &result);
+ SignatureConversion &result) const;
/// Convert a non-empty list of types to be returned from a function into an
/// LLVM-compatible type. In particular, if more than one value is returned,
/// create an LLVM dialect structure type with elements that correspond to
/// each of the types converted with `convertCallingConventionType`.
Type packFunctionResults(TypeRange types,
- bool useBarePointerCallConv = false);
+ bool useBarePointerCallConv = false) const;
/// Convert a non-empty list of types of values produced by an operation into
/// an LLVM-compatible type. In particular, if more than one value is
/// produced, create a literal structure with elements that correspond to each
/// of the LLVM-compatible types converted with `convertType`.
- Type packOperationResults(TypeRange types);
+ Type packOperationResults(TypeRange types) const;
/// Convert a type in the context of the default or bare pointer calling
/// convention. Calling convention sensitive types, such as MemRefType and
@@ -75,20 +75,20 @@ class LLVMTypeConverter : public TypeConverter {
/// calling convention. Calling convention independent types are converted
/// following the default LLVM type conversions.
Type convertCallingConventionType(Type type,
- bool useBarePointerCallConv = false);
+ bool useBarePointerCallConv = false) const;
/// Promote the bare pointers in 'values' that resulted from memrefs to
/// descriptors. 'stdTypes' holds the types of 'values' before the conversion
/// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type).
void promoteBarePtrsToDescriptors(ConversionPatternRewriter &rewriter,
Location loc, ArrayRef<Type> stdTypes,
- SmallVectorImpl<Value> &values);
+ SmallVectorImpl<Value> &values) const;
/// Returns the MLIR context.
- MLIRContext &getContext();
+ MLIRContext &getContext() const;
/// Returns the LLVM dialect.
- LLVM::LLVMDialect *getDialect() { return llvmDialect; }
+ LLVM::LLVMDialect *getDialect() const { return llvmDialect; }
const LowerToLLVMOptions &getOptions() const { return options; }
@@ -105,23 +105,23 @@ class LLVMTypeConverter : public TypeConverter {
/// passing.
SmallVector<Value, 4> promoteOperands(Location loc, ValueRange opOperands,
ValueRange operands, OpBuilder &builder,
- bool useBarePtrCallConv = false);
+ bool useBarePtrCallConv = false) const;
/// Promote the LLVM struct representation of one MemRef descriptor to stack
/// and use pointer to struct to avoid the complexity of the platform-specific
/// C/C++ ABI lowering related to struct argument passing.
Value promoteOneMemRefDescriptor(Location loc, Value operand,
- OpBuilder &builder);
+ OpBuilder &builder) const;
/// Converts the function type to a C-compatible format, in particular using
/// pointers to memref descriptors for arguments. Also converts the return
/// type to a pointer argument if it is a struct. Returns true if this
/// was the case.
std::pair<LLVM::LLVMFunctionType, LLVM::LLVMStructType>
- convertFunctionTypeCWrapper(FunctionType type);
+ convertFunctionTypeCWrapper(FunctionType type) const;
/// Returns the data layout to use during and after conversion.
- const llvm::DataLayout &getDataLayout() { return options.dataLayout; }
+ const llvm::DataLayout &getDataLayout() const { return options.dataLayout; }
/// Returns the data layout analysis to query during conversion.
const DataLayoutAnalysis *getDataLayoutAnalysis() const {
@@ -130,7 +130,7 @@ class LLVMTypeConverter : public TypeConverter {
/// Gets the LLVM representation of the index type. The returned type is an
/// integer type with the size configured for this type converter.
- Type getIndexType();
+ Type getIndexType() const;
/// Returns true if using opaque pointers was enabled in the lowering options.
bool useOpaquePointers() const { return getOptions().useOpaquePointers; }
@@ -141,25 +141,26 @@ class LLVMTypeConverter : public TypeConverter {
/// pointers, as it will create an opaque pointer with the given address space
/// if opaque pointers are enabled in the lowering options.
LLVM::LLVMPointerType getPointerType(Type elementType,
- unsigned addressSpace = 0);
+ unsigned addressSpace = 0) const;
/// Gets the bitwidth of the index type when converted to LLVM.
- unsigned getIndexTypeBitwidth() { return options.getIndexBitwidth(); }
+ unsigned getIndexTypeBitwidth() const { return options.getIndexBitwidth(); }
/// Gets the pointer bitwidth.
- unsigned getPointerBitwidth(unsigned addressSpace = 0);
+ unsigned getPointerBitwidth(unsigned addressSpace = 0) const;
/// Returns the size of the memref descriptor object in bytes.
- unsigned getMemRefDescriptorSize(MemRefType type, const DataLayout &layout);
+ unsigned getMemRefDescriptorSize(MemRefType type,
+ const DataLayout &layout) const;
/// Returns the size of the unranked memref descriptor object in bytes.
unsigned getUnrankedMemRefDescriptorSize(UnrankedMemRefType type,
- const DataLayout &layout);
+ const DataLayout &layout) const;
/// Return the LLVM address space corresponding to the memory space of the
/// memref type `type` or failure if the memory space cannot be converted to
/// an integer.
- FailureOr<unsigned> getMemRefAddressSpace(BaseMemRefType type);
+ FailureOr<unsigned> getMemRefAddressSpace(BaseMemRefType type) const;
/// Check if a memref type can be converted to a bare pointer.
static bool canConvertToBarePtr(BaseMemRefType type);
@@ -173,28 +174,28 @@ class LLVMTypeConverter : public TypeConverter {
/// one. Additionally, if the function returns more than one value, pack the
/// results into an LLVM IR structure type so that the converted function type
/// returns at most one result.
- Type convertFunctionType(FunctionType type);
+ Type convertFunctionType(FunctionType type) const;
/// Convert the index type. Uses llvmModule data layout to create an integer
/// of the pointer bitwidth.
- Type convertIndexType(IndexType type);
+ Type convertIndexType(IndexType type) const;
/// Convert an integer type `i*` to `!llvm<"i*">`.
- Type convertIntegerType(IntegerType type);
+ Type convertIntegerType(IntegerType type) const;
/// Convert a floating point type: `f16` to `f16`, `f32` to
/// `f32` and `f64` to `f64`. `bf16` is not supported
/// by LLVM. 8-bit float types are converted to 8-bit integers as this is how
/// all LLVM backends that support them currently represent them.
- Type convertFloatType(FloatType type);
+ Type convertFloatType(FloatType type) const;
/// Convert complex number type: `complex<f16>` to `!llvm<"{ half, half }">`,
/// `complex<f32>` to `!llvm<"{ float, float }">`, and `complex<f64>` to
/// `!llvm<"{ double, double }">`. `complex<bf16>` is not supported.
- Type convertComplexType(ComplexType type);
+ Type convertComplexType(ComplexType type) const;
/// Convert a memref type into an LLVM type that captures the relevant data.
- Type convertMemRefType(MemRefType type);
+ Type convertMemRefType(MemRefType type) const;
/// Convert a memref type into a list of LLVM IR types that will form the
/// memref descriptor. If `unpackAggregates` is true the `sizes` and `strides`
@@ -218,7 +219,7 @@ class LLVMTypeConverter : public TypeConverter {
/// - `i64`, `i64` (strides).
/// These types can be recomposed to a memref descriptor struct.
SmallVector<Type, 5> getMemRefDescriptorFields(MemRefType type,
- bool unpackAggregates);
+ bool unpackAggregates) const;
/// Convert an unranked memref type into a list of non-aggregate LLVM IR types
/// that will form the unranked memref descriptor. In particular, this list
@@ -229,17 +230,17 @@ class LLVMTypeConverter : public TypeConverter {
/// i64 (rank)
/// !llvm<"i8*"> (type-erased pointer).
/// These types can be recomposed to a unranked memref descriptor struct.
- SmallVector<Type, 2> getUnrankedMemRefDescriptorFields();
+ SmallVector<Type, 2> getUnrankedMemRefDescriptorFields() const;
/// Convert an unranked memref type to an LLVM type that captures the
/// runtime rank and a pointer to the static ranked memref desc
- Type convertUnrankedMemRefType(UnrankedMemRefType type);
+ Type convertUnrankedMemRefType(UnrankedMemRefType type) const;
/// Convert a memref type to a bare pointer to the memref element type.
- Type convertMemRefToBarePtr(BaseMemRefType type);
+ Type convertMemRefToBarePtr(BaseMemRefType type) const;
/// Convert a 1D vector type into an LLVM vector type.
- Type convertVectorType(VectorType type);
+ Type convertVectorType(VectorType type) const;
/// Options for customizing the llvm lowering.
LowerToLLVMOptions options;
@@ -252,13 +253,13 @@ class LLVMTypeConverter : public TypeConverter {
/// argument to a list of non-aggregate types containing descriptor
/// information, and an UnrankedmemRef function argument to a list containing
/// the rank and a pointer to a descriptor struct.
-LogicalResult structFuncArgTypeConverter(LLVMTypeConverter &converter,
+LogicalResult structFuncArgTypeConverter(const LLVMTypeConverter &converter,
Type type,
SmallVectorImpl<Type> &result);
/// Callback to convert function argument types. It converts MemRef function
/// arguments to bare pointers to the MemRef element type.
-LogicalResult barePtrFuncArgTypeConverter(LLVMTypeConverter &converter,
+LogicalResult barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter,
Type type,
SmallVectorImpl<Type> &result);
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
index d115c2d2f58fef..279175b6128fc7 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
@@ -32,7 +32,7 @@ struct NDVectorTypeInfo {
// Iterates on the llvm array type until we hit a non-array type (which is
// asserted to be an llvm vector type).
NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType,
- LLVMTypeConverter &converter);
+ const LLVMTypeConverter &converter);
// Express `linearIndex` in terms of coordinates of `basis`.
// Returns the empty vector when linearIndex is out of the range [0, P] where
@@ -50,14 +50,14 @@ void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder,
function_ref<void(ArrayRef<int64_t>)> fun);
LogicalResult handleMultidimensionalVectors(
- Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter,
+ Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter,
std::function<Value(Type, ValueRange)> createOperand,
ConversionPatternRewriter &rewriter);
LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp,
ValueRange operands,
ArrayRef<NamedAttribute> targetAttrs,
- LLVMTypeConverter &typeConverter,
+ const LLVMTypeConverter &typeConverter,
ConversionPatternRewriter &rewriter);
} // namespace detail
} // namespace LLVM
diff --git a/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h b/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h
index 495c4d63986f80..8bf04219c759ae 100644
--- a/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h
+++ b/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h
@@ -20,7 +20,7 @@ struct AllocationOpLLVMLowering : public ConvertToLLVMPattern {
using ConvertToLLVMPattern::getVoidPtrType;
explicit AllocationOpLLVMLowering(StringRef opName,
- LLVMTypeConverter &converter,
+ const LLVMTypeConverter &converter,
PatternBenefit benefit = 1)
: ConvertToLLVMPattern(opName, &converter.getContext(), converter,
benefit) {}
@@ -107,7 +107,7 @@ struct AllocationOpLLVMLowering : public ConvertToLLVMPattern {
/// Lowering for AllocOp and AllocaOp.
struct AllocLikeOpLLVMLowering : public AllocationOpLLVMLowering {
explicit AllocLikeOpLLVMLowering(StringRef opName,
- LLVMTypeConverter &converter,
+ const LLVMTypeConverter &converter,
PatternBenefit benefit = 1)
: AllocationOpLLVMLowering(opName, converter, benefit) {}
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index ba3e8ae89e1606..89ded981d38f9f 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -83,7 +83,7 @@ class SPIRVTypeConverter : public TypeConverter {
const SPIRVConversionOptions &getOptions() const { return options; }
/// Checks if the SPIR-V capability inquired is supported.
- bool allows(spirv::Capability capability);
+ bool allows(spirv::Capability capability) const;
private:
spirv::TargetEnv targetEnv;
@@ -169,17 +169,17 @@ Value linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides,
// TODO: This method assumes that the `baseType` is a MemRefType with AffineMap
// that has static strides. Extend to handle dynamic strides.
-Value getElementPtr(SPIRVTypeConverter &typeConverter, MemRefType baseType,
- Value basePtr, ValueRange indices, Location loc,
- OpBuilder &builder);
+Value getElementPtr(const SPIRVTypeConverter &typeConverter,
+ MemRefType baseType, Value basePtr, ValueRange indices,
+ Location loc, OpBuilder &builder);
// GetElementPtr implementation for Kernel/OpenCL flavored SPIR-V.
-Value getOpenCLElementPtr(SPIRVTypeConverter &typeConverter,
+Value getOpenCLElementPtr(const SPIRVTypeConverter &typeConverter,
MemRefType baseType, Value basePtr,
ValueRange indices, Location loc, OpBuilder &builder);
// GetElementPtr implementation for Vulkan/Shader flavored SPIR-V.
-Value getVulkanElementPtr(SPIRVTypeConverter &typeConverter,
+Value getVulkanElementPtr(const SPIRVTypeConverter &typeConverter,
MemRefType baseType, Value basePtr,
ValueRange indices, Location loc, OpBuilder &builder);
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index b4051093d4b0a9..6e11c3ed0a0179 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -476,13 +476,13 @@ class ConversionPattern : public RewritePattern {
/// Return the type converter held by this pattern, or nullptr if the pattern
/// does not require type conversion.
- TypeConverter *getTypeConverter() const { return typeConverter; }
+ const TypeConverter *getTypeConverter() const { return typeConverter; }
template <typename ConverterTy>
std::enable_if_t<std::is_base_of<TypeConverter, ConverterTy>::value,
- ConverterTy *>
+ const ConverterTy *>
getTypeConverter() const {
- return static_cast<ConverterTy *>(typeConverter);
+ return static_cast<const ConverterTy *>(typeConverter);
}
protected:
@@ -492,13 +492,13 @@ class ConversionPattern : public RewritePattern {
/// Construct a conversion pattern with the given converter, and forward the
/// remaining arguments to RewritePattern.
template <typename... Args>
- ConversionPattern(TypeConverter &typeConverter, Args &&...args)
+ ConversionPattern(const TypeConverter &typeConverter, Args &&...args)
: RewritePattern(std::forward<Args>(args)...),
typeConverter(&typeConverter) {}
protected:
/// An optional type converter for use by this pattern.
- TypeConverter *typeConverter = nullptr;
+ const TypeConverter *typeConverter = nullptr;
private:
using RewritePattern::rewrite;
@@ -514,7 +514,7 @@ class OpConversionPattern : public ConversionPattern {
OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
: ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
- OpConversionPattern(TypeConverter &typeConverter, MLIRContext *context,
+ OpConversionPattern(const TypeConverter &typeConverter, MLIRContext *context,
PatternBenefit benefit = 1)
: ConversionPattern(typeConverter, SourceOp::getOperationName(), benefit,
context) {}
@@ -567,7 +567,7 @@ class OpInterfaceConversionPattern : public ConversionPattern {
OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
: ConversionPattern(Pattern::MatchInterfaceOpTypeTag(),
SourceOp::getInterfaceID(), benefit, context) {}
- OpInterfaceConversionPattern(TypeConverter &typeConverter,
+ OpInterfaceConversionPattern(const TypeConverter &typeConverter,
MLIRContext *context, PatternBenefit benefit = 1)
: ConversionPattern(typeConverter, Pattern::MatchInterfaceOpTypeTag(),
SourceOp::getInterfaceID(), benefit, context) {}
@@ -608,17 +608,17 @@ class OpInterfaceConversionPattern : public ConversionPattern {
/// ops which use FunctionType to represent their type.
void populateFunctionOpInterfaceTypeConversionPattern(
StringRef functionLikeOpName, RewritePatternSet &patterns,
- TypeConverter &converter);
+ const TypeConverter &converter);
template <typename FuncOpT>
void populateFunctionOpInterfaceTypeConversionPattern(
- RewritePatternSet &patterns, TypeConverter &converter) {
+ RewritePatternSet &patterns, const TypeConverter &converter) {
populateFunctionOpInterfaceTypeConversionPattern(FuncOpT::getOperationName(),
patterns, converter);
}
void populateAnyFunctionOpInterfaceTypeConversionPattern(
- RewritePatternSet &patterns, TypeConverter &converter);
+ RewritePatternSet &patterns, const TypeConverter &converter);
//===----------------------------------------------------------------------===//
// Conversion PatternRewriter
@@ -645,7 +645,7 @@ class ConversionPatternRewriter final : public PatternRewriter,
Block *
applySignatureConversion(Region *region,
TypeConverter::SignatureConversion &conversion,
- TypeConverter *converter = nullptr);
+ const TypeConverter *converter = nullptr);
/// Convert the types of block arguments within the given region. This
/// replaces each block with a new block containing the updated signature. The
@@ -653,7 +653,7 @@ class ConversionPatternRewriter final : public PatternRewriter,
/// provided. On success, the new entry block to the region is returned for
/// convenience. Otherwise, failure is returned.
FailureOr<Block *> convertRegionTypes(
- Region *region, TypeConverter &converter,
+ Region *region, const TypeConverter &converter,
TypeConverter::SignatureConversion *entryConversion = nullptr);
/// Convert the types of block arguments within the given region except for
@@ -664,7 +664,7 @@ class ConversionPatternRewriter final : public PatternRewriter,
/// example, we need to convert only a subset of a BB arguments), such
/// behavior can be specified in blockConversions.
LogicalResult convertNonEntryRegionTypes(
- Region *region, TypeConverter &converter,
+ Region *region, const TypeConverter &converter,
ArrayRef<TypeConverter::SignatureConversion> blockConversions);
/// Replace all the uses of the block argument `from` with value `to`.
@@ -1024,12 +1024,12 @@ class ConversionTarget {
class PDLConversionConfig final
: public PDLPatternConfigBase<PDLConversionConfig> {
public:
- PDLConversionConfig(TypeConverter *converter) : converter(converter) {}
+ PDLConversionConfig(const TypeConverter *converter) : converter(converter) {}
~PDLConversionConfig() final = default;
/// Return the type converter used by this configuration, which may be nullptr
/// if no type conversions are expected.
- TypeConverter *getTypeConverter() const { return converter; }
+ const TypeConverter *getTypeConverter() const { return converter; }
/// Hooks that are invoked at the beginning and end of a rewrite of a matched
/// pattern.
@@ -1038,7 +1038,7 @@ class PDLConversionConfig final
private:
/// An optional type converter to use for the pattern.
- TypeConverter *converter;
+ const TypeConverter *converter;
};
/// Register the dialect conversion PDL functions with the given pattern set.
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index ecd4cbb25f2d5c..259b7eeb658ec3 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -42,7 +42,7 @@ namespace {
/// Define lowering patterns for raw buffer ops
template <typename GpuOp, typename Intrinsic>
struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
- RawBufferOpLowering(LLVMTypeConverter &converter, Chipset chipset)
+ RawBufferOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
: ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}
Chipset chipset;
@@ -345,7 +345,8 @@ static Value mfmaConcatIfNeeded(ConversionPatternRewriter &rewriter,
/// for unsigned) and we need to pack the input 16xi8 vector into a 4xi32
/// vector.
static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
- Location loc, TypeConverter *typeConverter,
+ Location loc,
+ const TypeConverter *typeConverter,
bool isUnsigned, Value llvmInput,
SmallVector<Value, 4> &operands) {
Type inputType = llvmInput.getType();
@@ -384,7 +385,8 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
/// result in the lower 16 bits, set subwordOffset to 1, otherwise result will
/// be stored it in the upper part
static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
- Location loc, TypeConverter *typeConverter,
+ Location loc,
+ const TypeConverter *typeConverter,
Value output, int32_t subwordOffset,
bool clamp, SmallVector<Value, 4> &operands) {
Type inputType = output.getType();
@@ -562,7 +564,7 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
namespace {
struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
- MFMAOpLowering(LLVMTypeConverter &converter, Chipset chipset)
+ MFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
: ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {}
Chipset chipset;
@@ -600,7 +602,7 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
};
struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
- WMMAOpLowering(LLVMTypeConverter &converter, Chipset chipset)
+ WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
: ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}
Chipset chipset;
diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index 7c8baee1448575..234d06c08da6dc 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -359,7 +359,7 @@ class AsyncRuntimeTypeConverter : public TypeConverter {
/// Creates an LLVM pointer type which may either be a typed pointer or an
/// opaque pointer, depending on what options the converter was constructed
/// with.
- LLVM::LLVMPointerType getPointerType(Type elementType) {
+ LLVM::LLVMPointerType getPointerType(Type elementType) const {
if (llvmOpaquePointers)
return LLVM::LLVMPointerType::get(elementType.getContext());
return LLVM::LLVMPointerType::get(elementType);
@@ -388,13 +388,14 @@ class AsyncOpConversionPattern : public OpConversionPattern<SourceOp> {
using Base = OpConversionPattern<SourceOp>;
public:
- AsyncOpConversionPattern(AsyncRuntimeTypeConverter &typeConverter,
+ AsyncOpConversionPattern(const AsyncRuntimeTypeConverter &typeConverter,
MLIRContext *context)
: Base(typeConverter, context) {}
/// Returns the 'AsyncRuntimeTypeConverter' of the pattern.
- AsyncRuntimeTypeConverter *getTypeConverter() const {
- return static_cast<AsyncRuntimeTypeConverter *>(Base::getTypeConverter());
+ const AsyncRuntimeTypeConverter *getTypeConverter() const {
+ return static_cast<const AsyncRuntimeTypeConverter *>(
+ Base::getTypeConverter());
}
};
@@ -653,7 +654,7 @@ class RuntimeCreateOpLowering : public ConvertOpToLLVMPattern<RuntimeCreateOp> {
LogicalResult
matchAndRewrite(RuntimeCreateOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- TypeConverter *converter = getTypeConverter();
+ const TypeConverter *converter = getTypeConverter();
Type resultType = op->getResultTypes()[0];
// Tokens creation maps to a simple function call.
@@ -706,7 +707,7 @@ class RuntimeCreateGroupOpLowering
LogicalResult
matchAndRewrite(RuntimeCreateGroupOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- TypeConverter *converter = getTypeConverter();
+ const TypeConverter *converter = getTypeConverter();
Type resultType = op.getResult().getType();
rewriter.replaceOpWithNewOp<func::CallOp>(
@@ -1040,8 +1041,8 @@ namespace {
template <typename RefCountingOp>
class RefCountingOpLowering : public OpConversionPattern<RefCountingOp> {
public:
- explicit RefCountingOpLowering(TypeConverter &converter, MLIRContext *ctx,
- StringRef apiFunctionName)
+ explicit RefCountingOpLowering(const TypeConverter &converter,
+ MLIRContext *ctx, StringRef apiFunctionName)
: OpConversionPattern<RefCountingOp>(converter, ctx),
apiFunctionName(apiFunctionName) {}
@@ -1065,14 +1066,16 @@ class RefCountingOpLowering : public OpConversionPattern<RefCountingOp> {
class RuntimeAddRefOpLowering : public RefCountingOpLowering<RuntimeAddRefOp> {
public:
- explicit RuntimeAddRefOpLowering(TypeConverter &converter, MLIRContext *ctx)
+ explicit RuntimeAddRefOpLowering(const TypeConverter &converter,
+ MLIRContext *ctx)
: RefCountingOpLowering(converter, ctx, kAddRef) {}
};
class RuntimeDropRefOpLowering
: public RefCountingOpLowering<RuntimeDropRefOp> {
public:
- explicit RuntimeDropRefOpLowering(TypeConverter &converter, MLIRContext *ctx)
+ explicit RuntimeDropRefOpLowering(const TypeConverter &converter,
+ MLIRContext *ctx)
: RefCountingOpLowering(converter, ctx, kDropRef) {}
};
} // namespace
diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
index d99968d78d248c..a4f146bbe475cc 100644
--- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
+++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
@@ -46,7 +46,8 @@ static std::string generateGlobalMsgSymbolName(ModuleOp moduleOp) {
/// Generate IR that prints the given string to stderr.
static void createPrintMsg(OpBuilder &builder, Location loc, ModuleOp moduleOp,
- StringRef msg, LLVMTypeConverter &typeConverter) {
+ StringRef msg,
+ const LLVMTypeConverter &typeConverter) {
auto ip = builder.saveInsertionPoint();
builder.setInsertionPointToStart(moduleOp.getBody());
MLIRContext *ctx = builder.getContext();
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 7ee0ea91827f22..1db463c0ab7163 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -63,7 +63,7 @@ static constexpr StringRef barePtrAttrName = "llvm.bareptr";
/// Return `true` if the `op` should use bare pointer calling convention.
static bool shouldUseBarePtrCallConv(Operation *op,
- LLVMTypeConverter *typeConverter) {
+ const LLVMTypeConverter *typeConverter) {
return (op && op->hasAttr(barePtrAttrName)) ||
typeConverter->getOptions().useBarePtrCallConv;
}
@@ -118,7 +118,7 @@ static void prependEmptyArgAttr(OpBuilder &builder,
/// components and forwards them to `newFuncOp` and forwards the results to
/// the extra arguments.
static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
- LLVMTypeConverter &typeConverter,
+ const LLVMTypeConverter &typeConverter,
func::FuncOp funcOp,
LLVM::LLVMFuncOp newFuncOp) {
auto type = funcOp.getFunctionType();
@@ -182,7 +182,7 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
/// compatible with functions defined in C using pointers to C structs
/// corresponding to a memref descriptor.
static void wrapExternalFunction(OpBuilder &builder, Location loc,
- LLVMTypeConverter &typeConverter,
+ const LLVMTypeConverter &typeConverter,
func::FuncOp funcOp,
LLVM::LLVMFuncOp newFuncOp) {
OpBuilder::InsertionGuard guard(builder);
@@ -281,7 +281,7 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
/// the bare pointer calling convention lowering of `memref` types.
static void modifyFuncOpToUseBarePtrCallingConv(
ConversionPatternRewriter &rewriter, Location loc,
- LLVMTypeConverter &typeConverter, LLVM::LLVMFuncOp funcOp,
+ const LLVMTypeConverter &typeConverter, LLVM::LLVMFuncOp funcOp,
TypeRange oldArgTypes) {
if (funcOp.getBody().empty())
return;
@@ -469,7 +469,7 @@ struct FuncOpConversionBase : public ConvertOpToLLVMPattern<func::FuncOp> {
/// MemRef descriptors (LLVM struct data types) containing all the MemRef type
/// information.
struct FuncOpConversion : public FuncOpConversionBase {
- FuncOpConversion(LLVMTypeConverter &converter)
+ FuncOpConversion(const LLVMTypeConverter &converter)
: FuncOpConversionBase(converter) {}
LogicalResult
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index f7caf025fb79bd..2a26587be0b412 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -464,7 +464,7 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
/// Unrolls op if it's operating on vectors.
LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
ConversionPatternRewriter &rewriter,
- LLVMTypeConverter &converter) {
+ const LLVMTypeConverter &converter) {
TypeRange operandTypes(operands);
if (llvm::none_of(operandTypes,
[](Type type) { return isa<VectorType>(type); })) {
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
index d61f22c9fc37df..bd90286494d803 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
@@ -15,8 +15,9 @@
namespace mlir {
struct GPUFuncOpLowering : ConvertOpToLLVMPattern<gpu::GPUFuncOp> {
- GPUFuncOpLowering(LLVMTypeConverter &converter, unsigned allocaAddrSpace,
- unsigned workgroupAddrSpace, StringAttr kernelAttributeName)
+ GPUFuncOpLowering(const LLVMTypeConverter &converter,
+ unsigned allocaAddrSpace, unsigned workgroupAddrSpace,
+ StringAttr kernelAttributeName)
: ConvertOpToLLVMPattern<gpu::GPUFuncOp>(converter),
allocaAddrSpace(allocaAddrSpace),
workgroupAddrSpace(workgroupAddrSpace),
@@ -57,7 +58,7 @@ struct GPUPrintfOpToHIPLowering : public ConvertOpToLLVMPattern<gpu::PrintfOp> {
/// will lower printf calls to appropriate device-side code
struct GPUPrintfOpToLLVMCallLowering
: public ConvertOpToLLVMPattern<gpu::PrintfOp> {
- GPUPrintfOpToLLVMCallLowering(LLVMTypeConverter &converter,
+ GPUPrintfOpToLLVMCallLowering(const LLVMTypeConverter &converter,
int addressSpace = 0)
: ConvertOpToLLVMPattern<gpu::PrintfOp>(converter),
addressSpace(addressSpace) {}
@@ -95,7 +96,7 @@ namespace impl {
/// Unrolls op if it's operating on vectors.
LogicalResult scalarizeVectorOp(Operation *op, ValueRange operands,
ConversionPatternRewriter &rewriter,
- LLVMTypeConverter &converter);
+ const LLVMTypeConverter &converter);
} // namespace impl
/// Rewriting that unrolls SourceOp to scalars if it's operating on vectors.
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 666dc8e27a9f7d..e0e9a7169bc6b9 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -61,7 +61,8 @@ class GpuToLLVMConversionPass
template <typename OpTy>
class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
public:
- explicit ConvertOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
+ explicit ConvertOpToGpuRuntimeCallPattern(
+ const LLVMTypeConverter &typeConverter)
: ConvertOpToLLVMPattern<OpTy>(typeConverter) {}
protected:
@@ -341,7 +342,8 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
class ConvertHostRegisterOpToGpuRuntimeCallPattern
: public ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp> {
public:
- ConvertHostRegisterOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
+ ConvertHostRegisterOpToGpuRuntimeCallPattern(
+ const LLVMTypeConverter &typeConverter)
: ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp>(typeConverter) {}
private:
@@ -354,7 +356,7 @@ class ConvertHostUnregisterOpToGpuRuntimeCallPattern
: public ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp> {
public:
ConvertHostUnregisterOpToGpuRuntimeCallPattern(
- LLVMTypeConverter &typeConverter)
+ const LLVMTypeConverter &typeConverter)
: ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp>(typeConverter) {
}
@@ -369,7 +371,7 @@ class ConvertHostUnregisterOpToGpuRuntimeCallPattern
class ConvertAllocOpToGpuRuntimeCallPattern
: public ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp> {
public:
- ConvertAllocOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
+ ConvertAllocOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
: ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp>(typeConverter) {}
private:
@@ -383,7 +385,8 @@ class ConvertAllocOpToGpuRuntimeCallPattern
class ConvertDeallocOpToGpuRuntimeCallPattern
: public ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp> {
public:
- ConvertDeallocOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
+ ConvertDeallocOpToGpuRuntimeCallPattern(
+ const LLVMTypeConverter &typeConverter)
: ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp>(typeConverter) {}
private:
@@ -395,7 +398,8 @@ class ConvertDeallocOpToGpuRuntimeCallPattern
class ConvertAsyncYieldToGpuRuntimeCallPattern
: public ConvertOpToGpuRuntimeCallPattern<async::YieldOp> {
public:
- ConvertAsyncYieldToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
+ ConvertAsyncYieldToGpuRuntimeCallPattern(
+ const LLVMTypeConverter &typeConverter)
: ConvertOpToGpuRuntimeCallPattern<async::YieldOp>(typeConverter) {}
private:
@@ -409,7 +413,7 @@ class ConvertAsyncYieldToGpuRuntimeCallPattern
class ConvertWaitOpToGpuRuntimeCallPattern
: public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
public:
- ConvertWaitOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
+ ConvertWaitOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
: ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}
private:
@@ -423,7 +427,8 @@ class ConvertWaitOpToGpuRuntimeCallPattern
class ConvertWaitAsyncOpToGpuRuntimeCallPattern
: public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
public:
- ConvertWaitAsyncOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
+ ConvertWaitAsyncOpToGpuRuntimeCallPattern(
+ const LLVMTypeConverter &typeConverter)
: ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}
private:
@@ -448,10 +453,9 @@ class ConvertWaitAsyncOpToGpuRuntimeCallPattern
class ConvertLaunchFuncOpToGpuRuntimeCallPattern
: public ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp> {
public:
- ConvertLaunchFuncOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter,
- StringRef gpuBinaryAnnotation,
- bool kernelBarePtrCallConv,
- SymbolTable *cachedModuleTable)
+ ConvertLaunchFuncOpToGpuRuntimeCallPattern(
+ const LLVMTypeConverter &typeConverter, StringRef gpuBinaryAnnotation,
+ bool kernelBarePtrCallConv, SymbolTable *cachedModuleTable)
: ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
gpuBinaryAnnotation(gpuBinaryAnnotation),
kernelBarePtrCallConv(kernelBarePtrCallConv),
@@ -489,7 +493,7 @@ class EraseGpuModuleOpPattern : public OpRewritePattern<gpu::GPUModuleOp> {
class ConvertMemcpyOpToGpuRuntimeCallPattern
: public ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp> {
public:
- ConvertMemcpyOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
+ ConvertMemcpyOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
: ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp>(typeConverter) {}
private:
@@ -503,7 +507,7 @@ class ConvertMemcpyOpToGpuRuntimeCallPattern
class ConvertMemsetOpToGpuRuntimeCallPattern
: public ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp> {
public:
- ConvertMemsetOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
+ ConvertMemsetOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
: ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp>(typeConverter) {}
private:
@@ -518,7 +522,7 @@ class ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern
: public ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp> {
public:
ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern(
- LLVMTypeConverter &typeConverter)
+ const LLVMTypeConverter &typeConverter)
: ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp>(
typeConverter) {}
@@ -534,7 +538,7 @@ class ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern
: public ConvertOpToGpuRuntimeCallPattern<gpu::op_name> { \
public: \
Convert##op_name##ToGpuRuntimeCallPattern( \
- LLVMTypeConverter &typeConverter) \
+ const LLVMTypeConverter &typeConverter) \
: ConvertOpToGpuRuntimeCallPattern<gpu::op_name>(typeConverter) {} \
\
private: \
@@ -980,15 +984,15 @@ Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateParamsArray(
SmallVector<Value, 4> arguments;
if (kernelBarePtrCallConv) {
// Hack the bare pointer value on just for the argument promotion
- LLVMTypeConverter *converter = getTypeConverter();
+ const LLVMTypeConverter *converter = getTypeConverter();
LowerToLLVMOptions options = converter->getOptions();
LowerToLLVMOptions overrideToMatchKernelOpts = options;
overrideToMatchKernelOpts.useBarePtrCallConv = true;
- converter->dangerousSetOptions(overrideToMatchKernelOpts);
- arguments = converter->promoteOperands(
+ LLVMTypeConverter newConverter = *converter;
+ newConverter.dangerousSetOptions(overrideToMatchKernelOpts);
+ arguments = newConverter.promoteOperands(
loc, launchOp.getOperands().take_back(numKernelOperands),
adaptor.getOperands().take_back(numKernelOperands), builder);
- converter->dangerousSetOptions(options);
} else {
arguments = getTypeConverter()->promoteOperands(
loc, launchOp.getOperands().take_back(numKernelOperands),
@@ -1111,15 +1115,15 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
SmallVector<Value, 4> arguments;
if (kernelBarePtrCallConv) {
// Hack the bare pointer value on just for the argument promotion
- LLVMTypeConverter *converter = getTypeConverter();
+ const LLVMTypeConverter *converter = getTypeConverter();
LowerToLLVMOptions options = converter->getOptions();
LowerToLLVMOptions overrideToMatchKernelOpts = options;
overrideToMatchKernelOpts.useBarePtrCallConv = true;
- converter->dangerousSetOptions(overrideToMatchKernelOpts);
+ LLVMTypeConverter newConverter = *converter;
+ newConverter.dangerousSetOptions(overrideToMatchKernelOpts);
arguments =
- converter->promoteOperands(loc, launchOp.getKernelOperands(),
- adaptor.getKernelOperands(), rewriter);
- converter->dangerousSetOptions(options);
+ newConverter.promoteOperands(loc, launchOp.getKernelOperands(),
+ adaptor.getKernelOperands(), rewriter);
} else {
arguments = getTypeConverter()->promoteOperands(
loc, launchOp.getKernelOperands(), adaptor.getKernelOperands(),
@@ -1200,7 +1204,7 @@ static Value bitAndAddrspaceCast(Location loc,
ConversionPatternRewriter &rewriter,
LLVM::LLVMPointerType destinationType,
Value sourcePtr,
- LLVMTypeConverter &typeConverter) {
+ const LLVMTypeConverter &typeConverter) {
auto sourceTy = cast<LLVM::LLVMPointerType>(sourcePtr.getType());
if (destinationType.getAddressSpace() != sourceTy.getAddressSpace())
sourcePtr = rewriter.create<LLVM::AddrSpaceCastOp>(
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index feea1e34f1b43b..693cc3f6236b57 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -222,7 +222,7 @@ LogicalResult WorkGroupSizeConversion::matchAndRewrite(
// Legalizes a GPU function as an entry SPIR-V function.
static spirv::FuncOp
-lowerAsEntryFunction(gpu::GPUFuncOp funcOp, TypeConverter &typeConverter,
+lowerAsEntryFunction(gpu::GPUFuncOp funcOp, const TypeConverter &typeConverter,
ConversionPatternRewriter &rewriter,
spirv::EntryPointABIAttr entryPointInfo,
ArrayRef<spirv::InterfaceVarABIAttr> argABIInfo) {
diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
index 57e21530b9da76..3851fb728b6654 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
@@ -88,7 +88,7 @@ struct WmmaLoadOpToSPIRVLowering
auto memrefType =
cast<MemRefType>(subgroupMmaLoadMatrixOp.getSrcMemref().getType());
Value bufferPtr = spirv::getElementPtr(
- *getTypeConverter<SPIRVTypeConverter>(), memrefType,
+ *getTypeConverter<const SPIRVTypeConverter>(), memrefType,
adaptor.getSrcMemref(), adaptor.getIndices(), loc, rewriter);
auto coopType = convertMMAToSPIRVType(retType);
int64_t stride = subgroupMmaLoadMatrixOp.getLeadDimension().getSExtValue();
@@ -119,7 +119,7 @@ struct WmmaStoreOpToSPIRVLowering
auto memrefType =
cast<MemRefType>(subgroupMmaStoreMatrixOp.getDstMemref().getType());
Value bufferPtr = spirv::getElementPtr(
- *getTypeConverter<SPIRVTypeConverter>(), memrefType,
+ *getTypeConverter<const SPIRVTypeConverter>(), memrefType,
adaptor.getDstMemref(), adaptor.getIndices(), loc, rewriter);
int64_t stride = subgroupMmaStoreMatrixOp.getLeadDimension().getSExtValue();
auto i32Type = rewriter.getI32Type();
diff --git a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
index 2c9580e421340a..0a3c9a57eec95d 100644
--- a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
@@ -41,13 +41,13 @@ MemRefDescriptor MemRefDescriptor::undef(OpBuilder &builder, Location loc,
/// type.
MemRefDescriptor
MemRefDescriptor::fromStaticShape(OpBuilder &builder, Location loc,
- LLVMTypeConverter &typeConverter,
+ const LLVMTypeConverter &typeConverter,
MemRefType type, Value memory) {
return fromStaticShape(builder, loc, typeConverter, type, memory, memory);
}
MemRefDescriptor MemRefDescriptor::fromStaticShape(
- OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter,
+ OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter,
MemRefType type, Value memory, Value alignedMemory) {
assert(type.hasStaticShape() && "unexpected dynamic shape");
@@ -198,7 +198,7 @@ LLVM::LLVMPointerType MemRefDescriptor::getElementPtrType() {
}
Value MemRefDescriptor::bufferPtr(OpBuilder &builder, Location loc,
- LLVMTypeConverter &converter,
+ const LLVMTypeConverter &converter,
MemRefType type) {
// When we convert to LLVM, the input memref must have been normalized
// beforehand. Hence, this call is guaranteed to work.
@@ -230,8 +230,8 @@ Value MemRefDescriptor::bufferPtr(OpBuilder &builder, Location loc,
/// - <rank> shapes;
/// where <rank> is the MemRef rank as provided in `type`.
Value MemRefDescriptor::pack(OpBuilder &builder, Location loc,
- LLVMTypeConverter &converter, MemRefType type,
- ValueRange values) {
+ const LLVMTypeConverter &converter,
+ MemRefType type, ValueRange values) {
Type llvmType = converter.convertType(type);
auto d = MemRefDescriptor::undef(builder, loc, llvmType);
@@ -340,7 +340,7 @@ void UnrankedMemRefDescriptor::setMemRefDescPtr(OpBuilder &builder,
/// - rank of the memref;
/// - pointer to the memref descriptor.
Value UnrankedMemRefDescriptor::pack(OpBuilder &builder, Location loc,
- LLVMTypeConverter &converter,
+ const LLVMTypeConverter &converter,
UnrankedMemRefType type,
ValueRange values) {
Type llvmType = converter.convertType(type);
@@ -363,7 +363,7 @@ void UnrankedMemRefDescriptor::unpack(OpBuilder &builder, Location loc,
}
void UnrankedMemRefDescriptor::computeSizes(
- OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter,
+ OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter,
ArrayRef<UnrankedMemRefDescriptor> values, ArrayRef<unsigned> addressSpaces,
SmallVectorImpl<Value> &sizes) {
if (values.empty())
@@ -453,10 +453,9 @@ castToElemPtrPtr(OpBuilder &builder, Location loc, Value memRefDescPtr,
return {elementPtrPtr, elemPtrPtrType};
}
-Value UnrankedMemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc,
- LLVMTypeConverter &typeConverter,
- Value memRefDescPtr,
- LLVM::LLVMPointerType elemPtrType) {
+Value UnrankedMemRefDescriptor::alignedPtr(
+ OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter,
+ Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType) {
auto [elementPtrPtr, elemPtrPtrType] =
castToElemPtrPtr(builder, loc, memRefDescPtr, elemPtrType);
@@ -466,11 +465,9 @@ Value UnrankedMemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc,
return builder.create<LLVM::LoadOp>(loc, elemPtrType, alignedGep);
}
-void UnrankedMemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc,
- LLVMTypeConverter &typeConverter,
- Value memRefDescPtr,
- LLVM::LLVMPointerType elemPtrType,
- Value alignedPtr) {
+void UnrankedMemRefDescriptor::setAlignedPtr(
+ OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter,
+ Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value alignedPtr) {
auto [elementPtrPtr, elemPtrPtrType] =
castToElemPtrPtr(builder, loc, memRefDescPtr, elemPtrType);
@@ -481,7 +478,7 @@ void UnrankedMemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc,
}
Value UnrankedMemRefDescriptor::offsetBasePtr(
- OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter,
+ OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter,
Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType) {
auto [elementPtrPtr, elemPtrPtrType] =
castToElemPtrPtr(builder, loc, memRefDescPtr, elemPtrType);
@@ -499,7 +496,7 @@ Value UnrankedMemRefDescriptor::offsetBasePtr(
}
Value UnrankedMemRefDescriptor::offset(OpBuilder &builder, Location loc,
- LLVMTypeConverter &typeConverter,
+ const LLVMTypeConverter &typeConverter,
Value memRefDescPtr,
LLVM::LLVMPointerType elemPtrType) {
Value offsetPtr =
@@ -509,7 +506,7 @@ Value UnrankedMemRefDescriptor::offset(OpBuilder &builder, Location loc,
}
void UnrankedMemRefDescriptor::setOffset(OpBuilder &builder, Location loc,
- LLVMTypeConverter &typeConverter,
+ const LLVMTypeConverter &typeConverter,
Value memRefDescPtr,
LLVM::LLVMPointerType elemPtrType,
Value offset) {
@@ -518,10 +515,9 @@ void UnrankedMemRefDescriptor::setOffset(OpBuilder &builder, Location loc,
builder.create<LLVM::StoreOp>(loc, offset, offsetPtr);
}
-Value UnrankedMemRefDescriptor::sizeBasePtr(OpBuilder &builder, Location loc,
- LLVMTypeConverter &typeConverter,
- Value memRefDescPtr,
- LLVM::LLVMPointerType elemPtrType) {
+Value UnrankedMemRefDescriptor::sizeBasePtr(
+ OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter,
+ Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType) {
Type indexTy = typeConverter.getIndexType();
Type structTy = LLVM::LLVMStructType::getLiteral(
indexTy.getContext(), {elemPtrType, elemPtrType, indexTy, indexTy});
@@ -542,7 +538,7 @@ Value UnrankedMemRefDescriptor::sizeBasePtr(OpBuilder &builder, Location loc,
}
Value UnrankedMemRefDescriptor::size(OpBuilder &builder, Location loc,
- LLVMTypeConverter &typeConverter,
+ const LLVMTypeConverter &typeConverter,
Value sizeBasePtr, Value index) {
Type indexTy = typeConverter.getIndexType();
@@ -554,7 +550,7 @@ Value UnrankedMemRefDescriptor::size(OpBuilder &builder, Location loc,
}
void UnrankedMemRefDescriptor::setSize(OpBuilder &builder, Location loc,
- LLVMTypeConverter &typeConverter,
+ const LLVMTypeConverter &typeConverter,
Value sizeBasePtr, Value index,
Value size) {
Type indexTy = typeConverter.getIndexType();
@@ -565,9 +561,9 @@ void UnrankedMemRefDescriptor::setSize(OpBuilder &builder, Location loc,
builder.create<LLVM::StoreOp>(loc, size, sizeStoreGep);
}
-Value UnrankedMemRefDescriptor::strideBasePtr(OpBuilder &builder, Location loc,
- LLVMTypeConverter &typeConverter,
- Value sizeBasePtr, Value rank) {
+Value UnrankedMemRefDescriptor::strideBasePtr(
+ OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter,
+ Value sizeBasePtr, Value rank) {
Type indexTy = typeConverter.getIndexType();
Type indexPtrTy = typeConverter.getPointerType(indexTy);
@@ -576,7 +572,7 @@ Value UnrankedMemRefDescriptor::strideBasePtr(OpBuilder &builder, Location loc,
}
Value UnrankedMemRefDescriptor::stride(OpBuilder &builder, Location loc,
- LLVMTypeConverter &typeConverter,
+ const LLVMTypeConverter &typeConverter,
Value strideBasePtr, Value index,
Value stride) {
Type indexTy = typeConverter.getIndexType();
@@ -588,7 +584,7 @@ Value UnrankedMemRefDescriptor::stride(OpBuilder &builder, Location loc,
}
void UnrankedMemRefDescriptor::setStride(OpBuilder &builder, Location loc,
- LLVMTypeConverter &typeConverter,
+ const LLVMTypeConverter &typeConverter,
Value strideBasePtr, Value index,
Value stride) {
Type indexTy = typeConverter.getIndexType();
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 1699172eb9dab3..e5519df9b0185f 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -19,14 +19,13 @@ using namespace mlir;
// ConvertToLLVMPattern
//===----------------------------------------------------------------------===//
-ConvertToLLVMPattern::ConvertToLLVMPattern(StringRef rootOpName,
- MLIRContext *context,
- LLVMTypeConverter &typeConverter,
- PatternBenefit benefit)
+ConvertToLLVMPattern::ConvertToLLVMPattern(
+ StringRef rootOpName, MLIRContext *context,
+ const LLVMTypeConverter &typeConverter, PatternBenefit benefit)
: ConversionPattern(typeConverter, rootOpName, benefit, context) {}
-LLVMTypeConverter *ConvertToLLVMPattern::getTypeConverter() const {
- return static_cast<LLVMTypeConverter *>(
+const LLVMTypeConverter *ConvertToLLVMPattern::getTypeConverter() const {
+ return static_cast<const LLVMTypeConverter *>(
ConversionPattern::getTypeConverter());
}
@@ -337,10 +336,12 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
/// Replaces the given operation "op" with a new operation of type "targetOp"
/// and given operands.
-LogicalResult LLVM::detail::oneToOneRewrite(
- Operation *op, StringRef targetOp, ValueRange operands,
- ArrayRef<NamedAttribute> targetAttrs, LLVMTypeConverter &typeConverter,
- ConversionPatternRewriter &rewriter) {
+LogicalResult
+LLVM::detail::oneToOneRewrite(Operation *op, StringRef targetOp,
+ ValueRange operands,
+ ArrayRef<NamedAttribute> targetAttrs,
+ const LLVMTypeConverter &typeConverter,
+ ConversionPatternRewriter &rewriter) {
unsigned numResults = op->getNumResults();
SmallVector<Type> resultTypes;
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 9e03e2ffbacf83..b0842b9972c76d 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -166,34 +166,35 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
}
/// Returns the MLIR context.
-MLIRContext &LLVMTypeConverter::getContext() {
+MLIRContext &LLVMTypeConverter::getContext() const {
return *getDialect()->getContext();
}
-Type LLVMTypeConverter::getIndexType() {
+Type LLVMTypeConverter::getIndexType() const {
return IntegerType::get(&getContext(), getIndexTypeBitwidth());
}
LLVM::LLVMPointerType
-LLVMTypeConverter::getPointerType(Type elementType, unsigned int addressSpace) {
+LLVMTypeConverter::getPointerType(Type elementType,
+ unsigned int addressSpace) const {
if (useOpaquePointers())
return LLVM::LLVMPointerType::get(&getContext(), addressSpace);
return LLVM::LLVMPointerType::get(elementType, addressSpace);
}
-unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) {
+unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) const {
return options.dataLayout.getPointerSizeInBits(addressSpace);
}
-Type LLVMTypeConverter::convertIndexType(IndexType type) {
+Type LLVMTypeConverter::convertIndexType(IndexType type) const {
return getIndexType();
}
-Type LLVMTypeConverter::convertIntegerType(IntegerType type) {
+Type LLVMTypeConverter::convertIntegerType(IntegerType type) const {
return IntegerType::get(&getContext(), type.getWidth());
}
-Type LLVMTypeConverter::convertFloatType(FloatType type) {
+Type LLVMTypeConverter::convertFloatType(FloatType type) const {
if (type.isFloat8E5M2() || type.isFloat8E4M3FN() || type.isFloat8E5M2FNUZ() ||
type.isFloat8E4M3FNUZ() || type.isFloat8E4M3B11FNUZ())
return IntegerType::get(&getContext(), type.getWidth());
@@ -204,7 +205,7 @@ Type LLVMTypeConverter::convertFloatType(FloatType type) {
// struct with entries for the
// 1. real part and for the
// 2. imaginary part.
-Type LLVMTypeConverter::convertComplexType(ComplexType type) {
+Type LLVMTypeConverter::convertComplexType(ComplexType type) const {
auto elementType = convertType(type.getElementType());
return LLVM::LLVMStructType::getLiteral(&getContext(),
{elementType, elementType});
@@ -212,7 +213,7 @@ Type LLVMTypeConverter::convertComplexType(ComplexType type) {
// Except for signatures, MLIR function types are converted into LLVM
// pointer-to-function types.
-Type LLVMTypeConverter::convertFunctionType(FunctionType type) {
+Type LLVMTypeConverter::convertFunctionType(FunctionType type) const {
SignatureConversion conversion(type.getNumInputs());
Type converted = convertFunctionSignature(
type, /*isVariadic=*/false, options.useBarePtrCallConv, conversion);
@@ -227,7 +228,7 @@ Type LLVMTypeConverter::convertFunctionType(FunctionType type) {
// they are into an LLVM StructType in their order of appearance.
Type LLVMTypeConverter::convertFunctionSignature(
FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
- LLVMTypeConverter::SignatureConversion &result) {
+ LLVMTypeConverter::SignatureConversion &result) const {
// Select the argument converter depending on the calling convention.
useBarePtrCallConv = useBarePtrCallConv || options.useBarePtrCallConv;
auto funcArgConverter = useBarePtrCallConv ? barePtrFuncArgTypeConverter
@@ -256,7 +257,7 @@ Type LLVMTypeConverter::convertFunctionSignature(
/// Converts the function type to a C-compatible format, in particular using
/// pointers to memref descriptors for arguments.
std::pair<LLVM::LLVMFunctionType, LLVM::LLVMStructType>
-LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) {
+LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) const {
SmallVector<Type, 4> inputs;
Type resultType = type.getNumResults() == 0
@@ -315,7 +316,7 @@ LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) {
/// };
SmallVector<Type, 5>
LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type,
- bool unpackAggregates) {
+ bool unpackAggregates) const {
if (!isStrided(type)) {
emitError(
UnknownLoc::get(type.getContext()),
@@ -353,8 +354,9 @@ LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type,
return results;
}
-unsigned LLVMTypeConverter::getMemRefDescriptorSize(MemRefType type,
- const DataLayout &layout) {
+unsigned
+LLVMTypeConverter::getMemRefDescriptorSize(MemRefType type,
+ const DataLayout &layout) const {
// Compute the descriptor size given that of its components indicated above.
unsigned space = *getMemRefAddressSpace(type);
return 2 * llvm::divideCeil(getPointerBitwidth(space), 8) +
@@ -363,7 +365,7 @@ unsigned LLVMTypeConverter::getMemRefDescriptorSize(MemRefType type,
/// Converts MemRefType to LLVMType. A MemRefType is converted to a struct that
/// packs the descriptor fields as defined by `getMemRefDescriptorFields`.
-Type LLVMTypeConverter::convertMemRefType(MemRefType type) {
+Type LLVMTypeConverter::convertMemRefType(MemRefType type) const {
// When converting a MemRefType to a struct with descriptor fields, do not
// unpack the `sizes` and `strides` arrays.
SmallVector<Type, 5> types =
@@ -380,20 +382,21 @@ Type LLVMTypeConverter::convertMemRefType(MemRefType type) {
/// 2. void* ptr, pointer to the static ranked MemRef descriptor. This will be
/// stack allocated (alloca) copy of a MemRef descriptor that got casted to
/// be unranked.
-SmallVector<Type, 2> LLVMTypeConverter::getUnrankedMemRefDescriptorFields() {
+SmallVector<Type, 2>
+LLVMTypeConverter::getUnrankedMemRefDescriptorFields() const {
return {getIndexType(), getPointerType(IntegerType::get(&getContext(), 8))};
}
-unsigned
-LLVMTypeConverter::getUnrankedMemRefDescriptorSize(UnrankedMemRefType type,
- const DataLayout &layout) {
+unsigned LLVMTypeConverter::getUnrankedMemRefDescriptorSize(
+ UnrankedMemRefType type, const DataLayout &layout) const {
// Compute the descriptor size given that of its components indicated above.
unsigned space = *getMemRefAddressSpace(type);
return layout.getTypeSize(getIndexType()) +
llvm::divideCeil(getPointerBitwidth(space), 8);
}
-Type LLVMTypeConverter::convertUnrankedMemRefType(UnrankedMemRefType type) {
+Type LLVMTypeConverter::convertUnrankedMemRefType(
+ UnrankedMemRefType type) const {
if (!convertType(type.getElementType()))
return {};
return LLVM::LLVMStructType::getLiteral(&getContext(),
@@ -401,7 +404,7 @@ Type LLVMTypeConverter::convertUnrankedMemRefType(UnrankedMemRefType type) {
}
FailureOr<unsigned>
-LLVMTypeConverter::getMemRefAddressSpace(BaseMemRefType type) {
+LLVMTypeConverter::getMemRefAddressSpace(BaseMemRefType type) const {
if (!type.getMemorySpace()) // Default memory space -> 0.
return 0;
std::optional<Attribute> converted =
@@ -440,7 +443,7 @@ bool LLVMTypeConverter::canConvertToBarePtr(BaseMemRefType type) {
}
/// Convert a memref type to a bare pointer to the memref element type.
-Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) {
+Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) const {
if (!canConvertToBarePtr(type))
return {};
Type elementType = convertType(type.getElementType());
@@ -460,7 +463,7 @@ Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) {
/// As LLVM does not support arrays of scalable vectors, it is assumed that
/// scalable vectors are always 1-D. This condition could be relaxed once the
/// missing functionality is added in LLVM
-Type LLVMTypeConverter::convertVectorType(VectorType type) {
+Type LLVMTypeConverter::convertVectorType(VectorType type) const {
auto elementType = convertType(type.getElementType());
if (!elementType)
return {};
@@ -484,8 +487,8 @@ Type LLVMTypeConverter::convertVectorType(VectorType type) {
/// UnrankedMemRefType, are converted following the specific rules for the
/// calling convention. Calling convention independent types are converted
/// following the default LLVM type conversions.
-Type LLVMTypeConverter::convertCallingConventionType(Type type,
- bool useBarePtrCallConv) {
+Type LLVMTypeConverter::convertCallingConventionType(
+ Type type, bool useBarePtrCallConv) const {
if (useBarePtrCallConv)
if (auto memrefTy = dyn_cast<BaseMemRefType>(type))
return convertMemRefToBarePtr(memrefTy);
@@ -498,7 +501,7 @@ Type LLVMTypeConverter::convertCallingConventionType(Type type,
/// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type).
void LLVMTypeConverter::promoteBarePtrsToDescriptors(
ConversionPatternRewriter &rewriter, Location loc, ArrayRef<Type> stdTypes,
- SmallVectorImpl<Value> &values) {
+ SmallVectorImpl<Value> &values) const {
assert(stdTypes.size() == values.size() &&
"The number of types and values doesn't match");
for (unsigned i = 0, end = values.size(); i < end; ++i)
@@ -511,7 +514,7 @@ void LLVMTypeConverter::promoteBarePtrsToDescriptors(
/// LLVM-compatible type. In particular, if more than one value is
/// produced, create a literal structure with elements that correspond to each
/// of the types converted with `convertType`.
-Type LLVMTypeConverter::packOperationResults(TypeRange types) {
+Type LLVMTypeConverter::packOperationResults(TypeRange types) const {
assert(!types.empty() && "expected non-empty list of type");
if (types.size() == 1)
return convertType(types[0]);
@@ -533,7 +536,7 @@ Type LLVMTypeConverter::packOperationResults(TypeRange types) {
/// create an LLVM dialect structure type with elements that correspond to each
/// of the types converted with `convertCallingConventionType`.
Type LLVMTypeConverter::packFunctionResults(TypeRange types,
- bool useBarePtrCallConv) {
+ bool useBarePtrCallConv) const {
assert(!types.empty() && "expected non-empty list of type");
useBarePtrCallConv |= options.useBarePtrCallConv;
@@ -553,7 +556,7 @@ Type LLVMTypeConverter::packFunctionResults(TypeRange types,
}
Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand,
- OpBuilder &builder) {
+ OpBuilder &builder) const {
// Alloca with proper alignment. We do not expect optimizations of this
// alloca op and so we omit allocating at the entry block.
auto ptrType = getPointerType(operand.getType());
@@ -569,7 +572,7 @@ Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand,
SmallVector<Value, 4>
LLVMTypeConverter::promoteOperands(Location loc, ValueRange opOperands,
ValueRange operands, OpBuilder &builder,
- bool useBarePtrCallConv) {
+ bool useBarePtrCallConv) const {
SmallVector<Value, 4> promotedOperands;
promotedOperands.reserve(operands.size());
useBarePtrCallConv |= options.useBarePtrCallConv;
@@ -608,9 +611,9 @@ LLVMTypeConverter::promoteOperands(Location loc, ValueRange opOperands,
/// argument to a list of non-aggregate types containing descriptor
/// information, and an UnrankedmemRef function argument to a list containing
/// the rank and a pointer to a descriptor struct.
-LogicalResult mlir::structFuncArgTypeConverter(LLVMTypeConverter &converter,
- Type type,
- SmallVectorImpl<Type> &result) {
+LogicalResult
+mlir::structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
+ SmallVectorImpl<Type> &result) {
if (auto memref = dyn_cast<MemRefType>(type)) {
// In signatures, Memref descriptors are expanded into lists of
// non-aggregate values.
@@ -637,9 +640,9 @@ LogicalResult mlir::structFuncArgTypeConverter(LLVMTypeConverter &converter,
/// Callback to convert function argument types. It converts MemRef function
/// arguments to bare pointers to the MemRef element type.
-LogicalResult mlir::barePtrFuncArgTypeConverter(LLVMTypeConverter &converter,
- Type type,
- SmallVectorImpl<Type> &result) {
+LogicalResult
+mlir::barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
+ SmallVectorImpl<Type> &result) {
auto llvmTy = converter.convertCallingConventionType(
type, /*useBarePointerCallConv=*/true);
if (!llvmTy)
diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
index 732f6c578c8b57..544bcc71aca1b5 100644
--- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
@@ -17,7 +17,7 @@ using namespace mlir;
// asserted to be an llvm vector type).
LLVM::detail::NDVectorTypeInfo
LLVM::detail::extractNDVectorTypeInfo(VectorType vectorType,
- LLVMTypeConverter &converter) {
+ const LLVMTypeConverter &converter) {
assert(vectorType.getRank() > 1 && "expected >1D vector type");
NDVectorTypeInfo info;
info.llvmNDVectorTy = converter.convertType(vectorType);
@@ -78,7 +78,7 @@ void LLVM::detail::nDVectorIterate(const LLVM::detail::NDVectorTypeInfo &info,
}
LogicalResult LLVM::detail::handleMultidimensionalVectors(
- Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter,
+ Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter,
std::function<Value(Type, ValueRange)> createOperand,
ConversionPatternRewriter &rewriter) {
auto resultNDVectorType = cast<VectorType>(op->getResult(0).getType());
@@ -103,10 +103,12 @@ LogicalResult LLVM::detail::handleMultidimensionalVectors(
return success();
}
-LogicalResult LLVM::detail::vectorOneToOneRewrite(
- Operation *op, StringRef targetOp, ValueRange operands,
- ArrayRef<NamedAttribute> targetAttrs, LLVMTypeConverter &typeConverter,
- ConversionPatternRewriter &rewriter) {
+LogicalResult
+LLVM::detail::vectorOneToOneRewrite(Operation *op, StringRef targetOp,
+ ValueRange operands,
+ ArrayRef<NamedAttribute> targetAttrs,
+ const LLVMTypeConverter &typeConverter,
+ ConversionPatternRewriter &rewriter) {
assert(!operands.empty());
// Cannot convert ops if their operands are not of LLVM type.
diff --git a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
index 715d00f2e215ac..a2a426e3c29317 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
@@ -18,7 +18,7 @@ namespace {
// with SymbolTable trait instead of ModuleOp and make similar change here. This
// allows call sites to use getParentWithTrait<OpTrait::SymbolTable> instead
// of getParentOfType<ModuleOp> to pass down the operation.
-LLVM::LLVMFuncOp getNotalignedAllocFn(LLVMTypeConverter *typeConverter,
+LLVM::LLVMFuncOp getNotalignedAllocFn(const LLVMTypeConverter *typeConverter,
ModuleOp module, Type indexType) {
bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
@@ -30,7 +30,7 @@ LLVM::LLVMFuncOp getNotalignedAllocFn(LLVMTypeConverter *typeConverter,
typeConverter->useOpaquePointers());
}
-LLVM::LLVMFuncOp getAlignedAllocFn(LLVMTypeConverter *typeConverter,
+LLVM::LLVMFuncOp getAlignedAllocFn(const LLVMTypeConverter *typeConverter,
ModuleOp module, Type indexType) {
bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
@@ -57,7 +57,7 @@ Value AllocationOpLLVMLowering::createAligned(
static Value castAllocFuncResult(ConversionPatternRewriter &rewriter,
Location loc, Value allocatedPtr,
MemRefType memRefType, Type elementPtrType,
- LLVMTypeConverter &typeConverter) {
+ const LLVMTypeConverter &typeConverter) {
auto allocatedPtrTy = cast<LLVM::LLVMPointerType>(allocatedPtr.getType());
unsigned memrefAddrSpace = *typeConverter.getMemRefAddressSpace(memRefType);
if (allocatedPtrTy.getAddressSpace() != memrefAddrSpace)
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 4b27dcb6cda281..8843ab78eed782 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -41,7 +41,8 @@ bool isStaticStrideOrOffset(int64_t strideOrOffset) {
return !ShapedType::isDynamic(strideOrOffset);
}
-LLVM::LLVMFuncOp getFreeFn(LLVMTypeConverter *typeConverter, ModuleOp module) {
+LLVM::LLVMFuncOp getFreeFn(const LLVMTypeConverter *typeConverter,
+ ModuleOp module) {
bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
if (useGenericFn)
@@ -52,7 +53,7 @@ LLVM::LLVMFuncOp getFreeFn(LLVMTypeConverter *typeConverter, ModuleOp module) {
}
struct AllocOpLowering : public AllocLikeOpLLVMLowering {
- AllocOpLowering(LLVMTypeConverter &converter)
+ AllocOpLowering(const LLVMTypeConverter &converter)
: AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
converter) {}
std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
@@ -65,7 +66,7 @@ struct AllocOpLowering : public AllocLikeOpLLVMLowering {
};
struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
- AlignedAllocOpLowering(LLVMTypeConverter &converter)
+ AlignedAllocOpLowering(const LLVMTypeConverter &converter)
: AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
converter) {}
std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
@@ -84,7 +85,7 @@ struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
};
struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
- AllocaOpLowering(LLVMTypeConverter &converter)
+ AllocaOpLowering(const LLVMTypeConverter &converter)
: AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
converter) {
setRequiresNumElements();
@@ -122,7 +123,7 @@ struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
struct ReallocOpLoweringBase : public AllocationOpLLVMLowering {
using OpAdaptor = typename memref::ReallocOp::Adaptor;
- ReallocOpLoweringBase(LLVMTypeConverter &converter)
+ ReallocOpLoweringBase(const LLVMTypeConverter &converter)
: AllocationOpLLVMLowering(memref::ReallocOp::getOperationName(),
converter) {}
@@ -247,7 +248,7 @@ struct ReallocOpLoweringBase : public AllocationOpLLVMLowering {
};
struct ReallocOpLowering : public ReallocOpLoweringBase {
- ReallocOpLowering(LLVMTypeConverter &converter)
+ ReallocOpLowering(const LLVMTypeConverter &converter)
: ReallocOpLoweringBase(converter) {}
std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
Location loc, Value sizeBytes,
@@ -258,7 +259,7 @@ struct ReallocOpLowering : public ReallocOpLoweringBase {
};
struct AlignedReallocOpLowering : public ReallocOpLoweringBase {
- AlignedReallocOpLowering(LLVMTypeConverter &converter)
+ AlignedReallocOpLowering(const LLVMTypeConverter &converter)
: ReallocOpLoweringBase(converter) {}
std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
Location loc, Value sizeBytes,
@@ -334,7 +335,7 @@ struct AssumeAlignmentOpLowering
: public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
using ConvertOpToLLVMPattern<
memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;
- explicit AssumeAlignmentOpLowering(LLVMTypeConverter &converter)
+ explicit AssumeAlignmentOpLowering(const LLVMTypeConverter &converter)
: ConvertOpToLLVMPattern<memref::AssumeAlignmentOp>(converter) {}
LogicalResult
@@ -376,7 +377,7 @@ struct AssumeAlignmentOpLowering
struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;
- explicit DeallocOpLowering(LLVMTypeConverter &converter)
+ explicit DeallocOpLowering(const LLVMTypeConverter &converter)
: ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}
LogicalResult
@@ -635,8 +636,9 @@ struct GenericAtomicRMWOpLowering
};
/// Returns the LLVM type of the global variable given the memref type `type`.
-static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
- LLVMTypeConverter &typeConverter) {
+static Type
+convertGlobalMemrefTypeToLLVM(MemRefType type,
+ const LLVMTypeConverter &typeConverter) {
// LLVM type for a global memref will be a multi-dimension array. For
// declarations or uninitialized global memrefs, we can potentially flatten
// this to a 1D array. However, for memref.global's with an initial value,
@@ -703,7 +705,7 @@ struct GlobalMemrefOpLowering
/// the first element stashed into the descriptor. This reuses
/// `AllocLikeOpLowering` to reuse the Memref descriptor construction.
struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
- GetGlobalMemrefOpLowering(LLVMTypeConverter &converter)
+ GetGlobalMemrefOpLowering(const LLVMTypeConverter &converter)
: AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
converter) {}
@@ -1191,7 +1193,7 @@ struct MemorySpaceCastOpLowering
/// ranked descriptor.
static void extractPointersAndOffset(Location loc,
ConversionPatternRewriter &rewriter,
- LLVMTypeConverter &typeConverter,
+ const LLVMTypeConverter &typeConverter,
Value originalOperand,
Value convertedOperand,
Value *allocatedPtr, Value *alignedPtr,
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 1d85e64bdfbfc3..f024bdfda93888 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -61,10 +61,10 @@ static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
/// 1D array (spirv.array or spirv.rt_array), the last index is modified to load
/// the bits needed. The extraction of the actual bits needed are handled
/// separately. Note that this only works for a 1-D tensor.
-static Value adjustAccessChainForBitwidth(SPIRVTypeConverter &typeConverter,
- spirv::AccessChainOp op,
- int sourceBits, int targetBits,
- OpBuilder &builder) {
+static Value
+adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter,
+ spirv::AccessChainOp op, int sourceBits,
+ int targetBits, OpBuilder &builder) {
assert(targetBits % sourceBits == 0);
const auto loc = op.getLoc();
IntegerType targetType = builder.getIntegerType(targetBits);
@@ -277,7 +277,7 @@ class CastPattern final : public OpConversionPattern<memref::CastOp> {
Value src = adaptor.getSource();
Type srcType = src.getType();
- TypeConverter *converter = getTypeConverter();
+ const TypeConverter *converter = getTypeConverter();
Type dstType = converter->convertType(op.getType());
if (srcType != dstType)
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
@@ -436,7 +436,7 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
if (!memrefType.getElementType().isSignlessInteger())
return failure();
- auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+ const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
Value accessChain =
spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
adaptor.getIndices(), loc, rewriter);
@@ -768,7 +768,7 @@ LogicalResult ReinterpretCastPattern::matchAndRewrite(
diag << "invalid src type " << src.getType();
});
- TypeConverter *converter = getTypeConverter();
+ const TypeConverter *converter = getTypeConverter();
auto dstType = converter->convertType<spirv::PointerType>(op.getType());
if (dstType != srcType)
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 64394de91d4dd7..21c6780cc7887f 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -236,7 +236,7 @@ MemRefType nvgpu::getMBarrierMemrefType(MLIRContext *context,
/// Returns the base pointer of the mbarrier object.
static Value getMbarrierPtr(ConversionPatternRewriter &rewriter,
- LLVMTypeConverter &typeConverter,
+ const LLVMTypeConverter &typeConverter,
TypedValue<nvgpu::MBarrierType> barrier,
Value barrierMemref) {
MemRefType memrefType =
diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
index fb8ad5a4c31f54..d06b7033257196 100644
--- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
+++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
@@ -58,7 +58,7 @@ struct RegionLessOpWithVarOperandsConversion
LogicalResult
matchAndRewrite(T curOp, typename T::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
+ const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
SmallVector<Type> resTypes;
if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes)))
return failure();
@@ -90,7 +90,7 @@ struct RegionOpWithVarOperandsConversion : public ConvertOpToLLVMPattern<T> {
LogicalResult
matchAndRewrite(T curOp, typename T::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
+ const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
SmallVector<Type> resTypes;
if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes)))
return failure();
@@ -128,7 +128,7 @@ struct RegionLessOpConversion : public ConvertOpToLLVMPattern<T> {
LogicalResult
matchAndRewrite(T curOp, typename T::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
+ const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
SmallVector<Type> resTypes;
if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes)))
return failure();
@@ -145,7 +145,7 @@ struct AtomicReadOpConversion
LogicalResult
matchAndRewrite(omp::AtomicReadOp curOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
+ const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
Type curElementType = curOp.getElementType();
auto newOp = rewriter.create<omp::AtomicReadOp>(
curOp.getLoc(), TypeRange(), adaptor.getOperands(), curOp->getAttrs());
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index e325348242affe..92f7aa69760395 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -37,7 +37,7 @@ static VectorType reducedVectorTypeBack(VectorType tp) {
// Helper that picks the proper sequence for inserting.
static Value insertOne(ConversionPatternRewriter &rewriter,
- LLVMTypeConverter &typeConverter, Location loc,
+ const LLVMTypeConverter &typeConverter, Location loc,
Value val1, Value val2, Type llvmType, int64_t rank,
int64_t pos) {
assert(rank > 0 && "0-D vector corner case should have been handled already");
@@ -54,7 +54,7 @@ static Value insertOne(ConversionPatternRewriter &rewriter,
// Helper that picks the proper sequence for extracting.
static Value extractOne(ConversionPatternRewriter &rewriter,
- LLVMTypeConverter &typeConverter, Location loc,
+ const LLVMTypeConverter &typeConverter, Location loc,
Value val, Type llvmType, int64_t rank, int64_t pos) {
if (rank <= 1) {
auto idxType = rewriter.getIndexType();
@@ -68,7 +68,7 @@ static Value extractOne(ConversionPatternRewriter &rewriter,
}
// Helper that returns data layout alignment of a memref.
-LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
+LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
MemRefType memrefType, unsigned &align) {
Type elementTy = typeConverter.convertType(memrefType.getElementType());
if (!elementTy)
@@ -84,7 +84,7 @@ LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
// Check if the last stride is non-unit or the memory space is not zero.
static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
- LLVMTypeConverter &converter) {
+ const LLVMTypeConverter &converter) {
if (!isLastMemrefDimUnitStride(memRefType))
return failure();
FailureOr<unsigned> addressSpace =
@@ -96,7 +96,7 @@ static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
// Add an index vector component to a base pointer.
static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
- LLVMTypeConverter &typeConverter,
+ const LLVMTypeConverter &typeConverter,
MemRefType memRefType, Value llvmMemref, Value base,
Value index, uint64_t vLen) {
assert(succeeded(isMemRefTypeSupported(memRefType, typeConverter)) &&
@@ -112,7 +112,7 @@ static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
// will be in the same address space as the incoming memref type.
static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
Value ptr, MemRefType memRefType, Type vt,
- LLVMTypeConverter &converter) {
+ const LLVMTypeConverter &converter) {
if (converter.useOpaquePointers())
return ptr;
@@ -294,7 +294,7 @@ class VectorGatherOpConversion
return success();
}
- LLVMTypeConverter &typeConverter = *this->getTypeConverter();
+ const LLVMTypeConverter &typeConverter = *this->getTypeConverter();
auto callback = [align, memRefType, base, ptr, loc, &rewriter,
&typeConverter](Type llvm1DVectorTy,
ValueRange vectorOperands) {
@@ -672,7 +672,7 @@ static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter,
class VectorReductionOpConversion
: public ConvertOpToLLVMPattern<vector::ReductionOp> {
public:
- explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv,
+ explicit VectorReductionOpConversion(const LLVMTypeConverter &typeConv,
bool reassociateFPRed)
: ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
reassociateFPReductions(reassociateFPRed) {}
diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index c19f8f182a923d..1355af14660776 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -24,7 +24,7 @@ namespace {
/// dimension directly translates into the number of rows of the tiles.
/// The second dimensions needs to be scaled by the number of bytes.
std::pair<Value, Value> getTileSizes(ConversionPatternRewriter &rewriter,
- LLVMTypeConverter &typeConverter,
+ const LLVMTypeConverter &typeConverter,
VectorType vType, Location loc) {
Type llvmInt16Type = IntegerType::get(&typeConverter.getContext(), 16);
unsigned width = vType.getElementType().getIntOrFloatBitWidth();
@@ -52,8 +52,8 @@ LogicalResult verifyStride(MemRefType mType) {
/// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer
/// shape may "envelop" the actual tile shape, and may be dynamically sized.
Value getStride(ConversionPatternRewriter &rewriter,
- LLVMTypeConverter &typeConverter, MemRefType mType, Value base,
- Location loc) {
+ const LLVMTypeConverter &typeConverter, MemRefType mType,
+ Value base, Location loc) {
assert(mType.getRank() >= 2);
int64_t last = mType.getRank() - 1;
Type llvmInt64Type = IntegerType::get(&typeConverter.getContext(), 64);
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index e768940cc27b5b..2b654db87fe4ff 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -80,7 +80,7 @@ LogicalResult EmulateFloatPattern::match(Operation *op) const {
void EmulateFloatPattern::rewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
Location loc = op->getLoc();
- TypeConverter *converter = getTypeConverter();
+ const TypeConverter *converter = getTypeConverter();
SmallVector<Type> resultTypes;
if (failed(converter->convertTypes(op->getResultTypes(), resultTypes))) {
// Note to anyone looking for this error message: this is a "can't happen".
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 94ce4ebb812947..c75d217663a9e0 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -132,7 +132,7 @@ MLIRContext *SPIRVTypeConverter::getContext() const {
return targetEnv.getAttr().getContext();
}
-bool SPIRVTypeConverter::allows(spirv::Capability capability) {
+bool SPIRVTypeConverter::allows(spirv::Capability capability) const {
return targetEnv.allows(capability);
}
@@ -992,7 +992,7 @@ Value mlir::spirv::linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides,
return linearizedIndex;
}
-Value mlir::spirv::getVulkanElementPtr(SPIRVTypeConverter &typeConverter,
+Value mlir::spirv::getVulkanElementPtr(const SPIRVTypeConverter &typeConverter,
MemRefType baseType, Value basePtr,
ValueRange indices, Location loc,
OpBuilder &builder) {
@@ -1023,7 +1023,7 @@ Value mlir::spirv::getVulkanElementPtr(SPIRVTypeConverter &typeConverter,
return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
}
-Value mlir::spirv::getOpenCLElementPtr(SPIRVTypeConverter &typeConverter,
+Value mlir::spirv::getOpenCLElementPtr(const SPIRVTypeConverter &typeConverter,
MemRefType baseType, Value basePtr,
ValueRange indices, Location loc,
OpBuilder &builder) {
@@ -1058,7 +1058,7 @@ Value mlir::spirv::getOpenCLElementPtr(SPIRVTypeConverter &typeConverter,
linearizedIndices);
}
-Value mlir::spirv::getElementPtr(SPIRVTypeConverter &typeConverter,
+Value mlir::spirv::getElementPtr(const SPIRVTypeConverter &typeConverter,
MemRefType baseType, Value basePtr,
ValueRange indices, Location loc,
OpBuilder &builder) {
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
index b36f2978d20e38..c33304c18fe48a 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
@@ -40,8 +40,8 @@ struct LowerToIntrinsic : public OpConversionPattern<OpTy> {
explicit LowerToIntrinsic(LLVMTypeConverter &converter)
: OpConversionPattern<OpTy>(converter, &converter.getContext()) {}
- LLVMTypeConverter &getTypeConverter() const {
- return *static_cast<LLVMTypeConverter *>(
+ const LLVMTypeConverter &getTypeConverter() const {
+ return *static_cast<const LLVMTypeConverter *>(
OpConversionPattern<OpTy>::getTypeConverter());
}
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index fa75d6efa15bb2..78d7b47558b553 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -226,11 +226,12 @@ class OperationTransactionState {
/// This class represents one requested operation replacement via 'replaceOp' or
/// 'eraseOp`.
struct OpReplacement {
- OpReplacement(TypeConverter *converter = nullptr) : converter(converter) {}
+ OpReplacement(const TypeConverter *converter = nullptr)
+ : converter(converter) {}
/// An optional type converter that can be used to materialize conversions
/// between the new and old values if necessary.
- TypeConverter *converter;
+ const TypeConverter *converter;
};
//===----------------------------------------------------------------------===//
@@ -333,7 +334,7 @@ class UnresolvedMaterialization {
};
UnresolvedMaterialization(UnrealizedConversionCastOp op = nullptr,
- TypeConverter *converter = nullptr,
+ const TypeConverter *converter = nullptr,
Kind kind = Target, Type origOutputType = nullptr)
: op(op), converterAndKind(converter, kind),
origOutputType(origOutputType) {}
@@ -343,7 +344,9 @@ class UnresolvedMaterialization {
UnrealizedConversionCastOp getOp() const { return op; }
/// Return the type converter of this materialization (which may be null).
- TypeConverter *getConverter() const { return converterAndKind.getPointer(); }
+ const TypeConverter *getConverter() const {
+ return converterAndKind.getPointer();
+ }
/// Return the kind of this materialization.
Kind getKind() const { return converterAndKind.getInt(); }
@@ -360,7 +363,7 @@ class UnresolvedMaterialization {
/// The corresponding type converter to use when resolving this
/// materialization, and the kind of this materialization.
- llvm::PointerIntPair<TypeConverter *, 1, Kind> converterAndKind;
+ llvm::PointerIntPair<const TypeConverter *, 1, Kind> converterAndKind;
/// The original output type. This is only used for argument conversions.
Type origOutputType;
@@ -372,7 +375,7 @@ class UnresolvedMaterialization {
static Value buildUnresolvedMaterialization(
UnresolvedMaterialization::Kind kind, Block *insertBlock,
Block::iterator insertPt, Location loc, ValueRange inputs, Type outputType,
- Type origOutputType, TypeConverter *converter,
+ Type origOutputType, const TypeConverter *converter,
SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) {
// Avoid materializing an unnecessary cast.
if (inputs.size() == 1 && inputs.front().getType() == outputType)
@@ -389,7 +392,7 @@ static Value buildUnresolvedMaterialization(
}
static Value buildUnresolvedArgumentMaterialization(
PatternRewriter &rewriter, Location loc, ValueRange inputs,
- Type origOutputType, Type outputType, TypeConverter *converter,
+ Type origOutputType, Type outputType, const TypeConverter *converter,
SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) {
return buildUnresolvedMaterialization(
UnresolvedMaterialization::Argument, rewriter.getInsertionBlock(),
@@ -397,7 +400,7 @@ static Value buildUnresolvedArgumentMaterialization(
converter, unresolvedMaterializations);
}
static Value buildUnresolvedTargetMaterialization(
- Location loc, Value input, Type outputType, TypeConverter *converter,
+ Location loc, Value input, Type outputType, const TypeConverter *converter,
SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) {
Block *insertBlock = input.getParentBlock();
Block::iterator insertPt = insertBlock->begin();
@@ -446,7 +449,7 @@ struct ArgConverter {
/// This structure contains information pertaining to a block that has had its
/// signature converted.
struct ConvertedBlockInfo {
- ConvertedBlockInfo(Block *origBlock, TypeConverter *converter)
+ ConvertedBlockInfo(Block *origBlock, const TypeConverter *converter)
: origBlock(origBlock), converter(converter) {}
/// The original block that was requested to have its signature converted.
@@ -457,7 +460,7 @@ struct ArgConverter {
SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
/// The type converter used to convert the arguments.
- TypeConverter *converter;
+ const TypeConverter *converter;
};
/// Return if the signature of the given block has already been converted.
@@ -466,14 +469,14 @@ struct ArgConverter {
}
/// Set the type converter to use for the given region.
- void setConverter(Region *region, TypeConverter *typeConverter) {
+ void setConverter(Region *region, const TypeConverter *typeConverter) {
assert(typeConverter && "expected valid type converter");
regionToConverter[region] = typeConverter;
}
/// Return the type converter to use for the given region, or null if there
/// isn't one.
- TypeConverter *getConverter(Region *region) {
+ const TypeConverter *getConverter(Region *region) {
return regionToConverter.lookup(region);
}
@@ -510,7 +513,7 @@ struct ArgConverter {
/// block is returned containing the new arguments. Returns `block` if it did
/// not require conversion.
FailureOr<Block *>
- convertSignature(Block *block, TypeConverter *converter,
+ convertSignature(Block *block, const TypeConverter *converter,
ConversionValueMapping &mapping,
SmallVectorImpl<BlockArgument> &argReplacements);
@@ -521,7 +524,7 @@ struct ArgConverter {
/// translate between the origin argument types and those specified in the
/// signature conversion.
Block *applySignatureConversion(
- Block *block, TypeConverter *converter,
+ Block *block, const TypeConverter *converter,
TypeConverter::SignatureConversion &signatureConversion,
ConversionValueMapping &mapping,
SmallVectorImpl<BlockArgument> &argReplacements);
@@ -542,7 +545,7 @@ struct ArgConverter {
/// A mapping of regions to type converters that should be used when
/// converting the arguments of blocks within that region.
- DenseMap<Region *, TypeConverter *> regionToConverter;
+ DenseMap<Region *, const TypeConverter *> regionToConverter;
/// The pattern rewriter to use when materializing conversions.
PatternRewriter &rewriter;
@@ -686,7 +689,8 @@ LogicalResult ArgConverter::materializeLiveConversions(
// Conversion
FailureOr<Block *> ArgConverter::convertSignature(
- Block *block, TypeConverter *converter, ConversionValueMapping &mapping,
+ Block *block, const TypeConverter *converter,
+ ConversionValueMapping &mapping,
SmallVectorImpl<BlockArgument> &argReplacements) {
// Check if the block was already converted. If the block is detached,
// conservatively assume it is going to be deleted.
@@ -705,7 +709,7 @@ FailureOr<Block *> ArgConverter::convertSignature(
}
Block *ArgConverter::applySignatureConversion(
- Block *block, TypeConverter *converter,
+ Block *block, const TypeConverter *converter,
TypeConverter::SignatureConversion &signatureConversion,
ConversionValueMapping &mapping,
SmallVectorImpl<BlockArgument> &argReplacements) {
@@ -865,7 +869,7 @@ struct ConversionPatternRewriterImpl {
/// Convert the signature of the given block.
FailureOr<Block *> convertBlockSignature(
- Block *block, TypeConverter *converter,
+ Block *block, const TypeConverter *converter,
TypeConverter::SignatureConversion *conversion = nullptr);
/// Apply a signature conversion on the given region, using `converter` for
@@ -873,16 +877,16 @@ struct ConversionPatternRewriterImpl {
Block *
applySignatureConversion(Region *region,
TypeConverter::SignatureConversion &conversion,
- TypeConverter *converter);
+ const TypeConverter *converter);
/// Convert the types of block arguments within the given region.
FailureOr<Block *>
- convertRegionTypes(Region *region, TypeConverter &converter,
+ convertRegionTypes(Region *region, const TypeConverter &converter,
TypeConverter::SignatureConversion *entryConversion);
/// Convert the types of non-entry block arguments within the given region.
LogicalResult convertNonEntryRegionTypes(
- Region *region, TypeConverter &converter,
+ Region *region, const TypeConverter &converter,
ArrayRef<TypeConverter::SignatureConversion> blockConversions = {});
//===--------------------------------------------------------------------===//
@@ -962,7 +966,7 @@ struct ConversionPatternRewriterImpl {
/// The current type converter, or nullptr if no type converter is currently
/// active.
- TypeConverter *currentTypeConverter = nullptr;
+ const TypeConverter *currentTypeConverter = nullptr;
/// This allows the user to collect the match failure message.
function_ref<void(Diagnostic &)> notifyCallback;
@@ -1283,7 +1287,7 @@ void ConversionPatternRewriterImpl::markNestedOpsIgnored(Operation *op) {
// Type Conversion
FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature(
- Block *block, TypeConverter *converter,
+ Block *block, const TypeConverter *converter,
TypeConverter::SignatureConversion *conversion) {
FailureOr<Block *> result =
conversion ? argConverter.applySignatureConversion(
@@ -1301,14 +1305,14 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature(
Block *ConversionPatternRewriterImpl::applySignatureConversion(
Region *region, TypeConverter::SignatureConversion &conversion,
- TypeConverter *converter) {
+ const TypeConverter *converter) {
if (!region->empty())
return *convertBlockSignature(&region->front(), converter, &conversion);
return nullptr;
}
FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
- Region *region, TypeConverter &converter,
+ Region *region, const TypeConverter &converter,
TypeConverter::SignatureConversion *entryConversion) {
argConverter.setConverter(region, &converter);
if (region->empty())
@@ -1323,7 +1327,7 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
}
LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
- Region *region, TypeConverter &converter,
+ Region *region, const TypeConverter &converter,
ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
argConverter.setConverter(region, &converter);
if (region->empty())
@@ -1492,18 +1496,18 @@ void ConversionPatternRewriter::eraseBlock(Block *block) {
Block *ConversionPatternRewriter::applySignatureConversion(
Region *region, TypeConverter::SignatureConversion &conversion,
- TypeConverter *converter) {
+ const TypeConverter *converter) {
return impl->applySignatureConversion(region, conversion, converter);
}
FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
- Region *region, TypeConverter &converter,
+ Region *region, const TypeConverter &converter,
TypeConverter::SignatureConversion *entryConversion) {
return impl->convertRegionTypes(region, converter, entryConversion);
}
LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes(
- Region *region, TypeConverter &converter,
+ Region *region, const TypeConverter &converter,
ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
return impl->convertNonEntryRegionTypes(region, converter, blockConversions);
}
@@ -2341,7 +2345,7 @@ struct OperationConverter {
/// type.
LogicalResult legalizeChangedResultType(
Operation *op, OpResult result, Value newValue,
- TypeConverter *replConverter, ConversionPatternRewriter &rewriter,
+ const TypeConverter *replConverter, ConversionPatternRewriter &rewriter,
ConversionPatternRewriterImpl &rewriterImpl,
const DenseMap<Value, SmallVector<Value>> &inverseMapping);
@@ -2717,7 +2721,7 @@ static LogicalResult legalizeUnresolvedMaterialization(
}
// Try to materialize the conversion.
- if (TypeConverter *converter = mat.getConverter()) {
+ if (const TypeConverter *converter = mat.getConverter()) {
// FIXME: Determine a suitable insertion location when there are multiple
// inputs.
if (inputOperands.size() == 1)
@@ -2836,7 +2840,7 @@ static Operation *findLiveUserOfReplaced(
LogicalResult OperationConverter::legalizeChangedResultType(
Operation *op, OpResult result, Value newValue,
- TypeConverter *replConverter, ConversionPatternRewriter &rewriter,
+ const TypeConverter *replConverter, ConversionPatternRewriter &rewriter,
ConversionPatternRewriterImpl &rewriterImpl,
const DenseMap<Value, SmallVector<Value>> &inverseMapping) {
Operation *liveUser =
@@ -3075,7 +3079,7 @@ TypeConverter::convertTypeAttribute(Type type, Attribute attr) const {
//===----------------------------------------------------------------------===//
static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
- TypeConverter &typeConverter,
+ const TypeConverter &typeConverter,
ConversionPatternRewriter &rewriter) {
FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
if (!type)
@@ -3106,7 +3110,7 @@ namespace {
struct FunctionOpInterfaceSignatureConversion : public ConversionPattern {
FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
MLIRContext *ctx,
- TypeConverter &converter)
+ const TypeConverter &converter)
: ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx) {}
LogicalResult
@@ -3131,13 +3135,13 @@ struct AnyFunctionOpInterfaceSignatureConversion
void mlir::populateFunctionOpInterfaceTypeConversionPattern(
StringRef functionLikeOpName, RewritePatternSet &patterns,
- TypeConverter &converter) {
+ const TypeConverter &converter) {
patterns.add<FunctionOpInterfaceSignatureConversion>(
functionLikeOpName, patterns.getContext(), converter);
}
void mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(
- RewritePatternSet &patterns, TypeConverter &converter) {
+ RewritePatternSet &patterns, const TypeConverter &converter) {
patterns.add<AnyFunctionOpInterfaceSignatureConversion>(
converter, patterns.getContext());
}
@@ -3338,7 +3342,8 @@ void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) {
[](PatternRewriter &rewriter, Type type) -> FailureOr<Type> {
auto &rewriterImpl =
static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
- if (TypeConverter *converter = rewriterImpl.currentTypeConverter) {
+ if (const TypeConverter *converter =
+ rewriterImpl.currentTypeConverter) {
if (Type newType = converter->convertType(type))
return newType;
return failure();
@@ -3351,7 +3356,7 @@ void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) {
TypeRange types) -> FailureOr<SmallVector<Type>> {
auto &rewriterImpl =
static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
- TypeConverter *converter = rewriterImpl.currentTypeConverter;
+ const TypeConverter *converter = rewriterImpl.currentTypeConverter;
if (!converter)
return SmallVector<Type>(types);
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 46788edcb4df58..30ed4109ad8bd1 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -668,7 +668,8 @@ struct TestUndoBlockErase : public ConversionPattern {
/// This patterns erases a region operation that has had a type conversion.
struct TestDropOpSignatureConversion : public ConversionPattern {
- TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter)
+ TestDropOpSignatureConversion(MLIRContext *ctx,
+ const TypeConverter &converter)
: ConversionPattern(converter, "test.drop_region_op", 1, ctx) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -677,7 +678,7 @@ struct TestDropOpSignatureConversion : public ConversionPattern {
Block *entry = &region.front();
// Convert the original entry arguments.
- TypeConverter &converter = *getTypeConverter();
+ const TypeConverter &converter = *getTypeConverter();
TypeConverter::SignatureConversion result(entry->getNumArguments());
if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(),
result)) ||
@@ -1307,7 +1308,7 @@ struct TestSignatureConversionUndo
/// materializations.
struct TestTestSignatureConversionNoConverter
: public OpConversionPattern<TestSignatureConversionNoConverterOp> {
- TestTestSignatureConversionNoConverter(TypeConverter &converter,
+ TestTestSignatureConversionNoConverter(const TypeConverter &converter,
MLIRContext *context)
: OpConversionPattern<TestSignatureConversionNoConverterOp>(context),
converter(converter) {}
@@ -1328,7 +1329,7 @@ struct TestTestSignatureConversionNoConverter
return success();
}
- TypeConverter &converter;
+ const TypeConverter &converter;
};
/// Just forward the operands to the root op. This is essentially a no-op
More information about the flang-commits
mailing list