[Mlir-commits] [mlir] ef97633 - [mlir:OpConversion] Remove the remaining usages of the deprecated matchAndRewrite methods
River Riddle
llvmlistbot at llvm.org
Fri Sep 24 10:58:51 PDT 2021
Author: River Riddle
Date: 2021-09-24T17:51:41Z
New Revision: ef976337f581dd8a80820a8b14b4bbd70670b7fc
URL: https://github.com/llvm/llvm-project/commit/ef976337f581dd8a80820a8b14b4bbd70670b7fc
DIFF: https://github.com/llvm/llvm-project/commit/ef976337f581dd8a80820a8b14b4bbd70670b7fc.diff
LOG: [mlir:OpConversion] Remove the remaining usages of the deprecated matchAndRewrite methods
This commit updates the remaining usages of the ArrayRef<Value>-based
matchAndRewrite/rewrite methods in favor of the new OpAdaptor
overload.
Differential Revision: https://reviews.llvm.org/D110360
Added:
Modified:
mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
mlir/lib/Conversion/OpenACCToLLVM/OpenACCToLLVM.cpp
mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
mlir/test/lib/Conversion/StandardToLLVM/TestConvertCallOp.cpp
mlir/test/lib/Dialect/Test/TestPatterns.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index 81358dc3fcbfd..4ffd135538b79 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -217,11 +217,11 @@ class OneToOneConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
/// Converts the type of the result to an LLVM type, pass operands as is,
/// preserve attributes.
LogicalResult
- matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
+ matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
return LLVM::detail::oneToOneRewrite(op, TargetOp::getOperationName(),
- operands, *this->getTypeConverter(),
- rewriter);
+ adaptor.getOperands(),
+ *this->getTypeConverter(), rewriter);
}
};
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
index 383516ac3cd6d..7eba83c3b95a9 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
@@ -70,14 +70,14 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
using Super = VectorConvertToLLVMPattern<SourceOp, TargetOp>;
LogicalResult
- matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
+ matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
static_assert(
std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
"expected single result op");
return LLVM::detail::vectorOneToOneRewrite(
- op, TargetOp::getOperationName(), operands, *this->getTypeConverter(),
- rewriter);
+ op, TargetOp::getOperationName(), adaptor.getOperands(),
+ *this->getTypeConverter(), rewriter);
}
};
} // namespace mlir
diff --git a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
index a2ba4c507c1ac..129b9e624879d 100644
--- a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
+++ b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
@@ -58,12 +58,11 @@ struct AbsOpConversion : public ConvertOpToLLVMPattern<complex::AbsOp> {
using ConvertOpToLLVMPattern<complex::AbsOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(complex::AbsOp op, ArrayRef<Value> operands,
+ matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- complex::AbsOp::Adaptor transformed(operands);
auto loc = op.getLoc();
- ComplexStructBuilder complexStruct(transformed.complex());
+ ComplexStructBuilder complexStruct(adaptor.complex());
Value real = complexStruct.real(rewriter, op.getLoc());
Value imag = complexStruct.imaginary(rewriter, op.getLoc());
@@ -81,16 +80,14 @@ struct CreateOpConversion : public ConvertOpToLLVMPattern<complex::CreateOp> {
using ConvertOpToLLVMPattern<complex::CreateOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(complex::CreateOp complexOp, ArrayRef<Value> operands,
+ matchAndRewrite(complex::CreateOp complexOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- complex::CreateOp::Adaptor transformed(operands);
-
// Pack real and imaginary part in a complex number struct.
auto loc = complexOp.getLoc();
auto structType = typeConverter->convertType(complexOp.getType());
auto complexStruct = ComplexStructBuilder::undef(rewriter, loc, structType);
- complexStruct.setReal(rewriter, loc, transformed.real());
- complexStruct.setImaginary(rewriter, loc, transformed.imaginary());
+ complexStruct.setReal(rewriter, loc, adaptor.real());
+ complexStruct.setImaginary(rewriter, loc, adaptor.imaginary());
rewriter.replaceOp(complexOp, {complexStruct});
return success();
@@ -101,12 +98,10 @@ struct ReOpConversion : public ConvertOpToLLVMPattern<complex::ReOp> {
using ConvertOpToLLVMPattern<complex::ReOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(complex::ReOp op, ArrayRef<Value> operands,
+ matchAndRewrite(complex::ReOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- complex::ReOp::Adaptor transformed(operands);
-
// Extract real part from the complex number struct.
- ComplexStructBuilder complexStruct(transformed.complex());
+ ComplexStructBuilder complexStruct(adaptor.complex());
Value real = complexStruct.real(rewriter, op.getLoc());
rewriter.replaceOp(op, real);
@@ -118,12 +113,10 @@ struct ImOpConversion : public ConvertOpToLLVMPattern<complex::ImOp> {
using ConvertOpToLLVMPattern<complex::ImOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(complex::ImOp op, ArrayRef<Value> operands,
+ matchAndRewrite(complex::ImOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- complex::ImOp::Adaptor transformed(operands);
-
// Extract imaginary part from the complex number struct.
- ComplexStructBuilder complexStruct(transformed.complex());
+ ComplexStructBuilder complexStruct(adaptor.complex());
Value imaginary = complexStruct.imaginary(rewriter, op.getLoc());
rewriter.replaceOp(op, imaginary);
@@ -138,17 +131,16 @@ struct BinaryComplexOperands {
template <typename OpTy>
BinaryComplexOperands
-unpackBinaryComplexOperands(OpTy op, ArrayRef<Value> operands,
+unpackBinaryComplexOperands(OpTy op, typename OpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter) {
auto loc = op.getLoc();
- typename OpTy::Adaptor transformed(operands);
// Extract real and imaginary values from operands.
BinaryComplexOperands unpacked;
- ComplexStructBuilder lhs(transformed.lhs());
+ ComplexStructBuilder lhs(adaptor.lhs());
unpacked.lhs.real(lhs.real(rewriter, loc));
unpacked.lhs.imag(lhs.imaginary(rewriter, loc));
- ComplexStructBuilder rhs(transformed.rhs());
+ ComplexStructBuilder rhs(adaptor.rhs());
unpacked.rhs.real(rhs.real(rewriter, loc));
unpacked.rhs.imag(rhs.imaginary(rewriter, loc));
@@ -159,11 +151,11 @@ struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> {
using ConvertOpToLLVMPattern<complex::AddOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(complex::AddOp op, ArrayRef<Value> operands,
+ matchAndRewrite(complex::AddOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
BinaryComplexOperands arg =
- unpackBinaryComplexOperands<complex::AddOp>(op, operands, rewriter);
+ unpackBinaryComplexOperands<complex::AddOp>(op, adaptor, rewriter);
// Initialize complex number struct for result.
auto structType = typeConverter->convertType(op.getType());
@@ -187,11 +179,11 @@ struct DivOpConversion : public ConvertOpToLLVMPattern<complex::DivOp> {
using ConvertOpToLLVMPattern<complex::DivOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(complex::DivOp op, ArrayRef<Value> operands,
+ matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
BinaryComplexOperands arg =
- unpackBinaryComplexOperands<complex::DivOp>(op, operands, rewriter);
+ unpackBinaryComplexOperands<complex::DivOp>(op, adaptor, rewriter);
// Initialize complex number struct for result.
auto structType = typeConverter->convertType(op.getType());
@@ -232,11 +224,11 @@ struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> {
using ConvertOpToLLVMPattern<complex::MulOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(complex::MulOp op, ArrayRef<Value> operands,
+ matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
BinaryComplexOperands arg =
- unpackBinaryComplexOperands<complex::MulOp>(op, operands, rewriter);
+ unpackBinaryComplexOperands<complex::MulOp>(op, adaptor, rewriter);
// Initialize complex number struct for result.
auto structType = typeConverter->convertType(op.getType());
@@ -269,11 +261,11 @@ struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> {
using ConvertOpToLLVMPattern<complex::SubOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(complex::SubOp op, ArrayRef<Value> operands,
+ matchAndRewrite(complex::SubOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
BinaryComplexOperands arg =
- unpackBinaryComplexOperands<complex::SubOp>(op, operands, rewriter);
+ unpackBinaryComplexOperands<complex::SubOp>(op, adaptor, rewriter);
// Initialize complex number struct for result.
auto structType = typeConverter->convertType(op.getType());
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 98aa0e0aea1e0..c8fc7b2346496 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -14,10 +14,8 @@
using namespace mlir;
LogicalResult
-GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp,
- ArrayRef<Value> operands,
+GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
- assert(operands.empty() && "func op is not expected to have operands");
Location loc = gpuFuncOp.getLoc();
SmallVector<LLVM::GlobalOp, 3> workgroupBuffers;
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
index 9d54001d10292..72805dcb8191b 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
@@ -22,7 +22,7 @@ struct GPUFuncOpLowering : ConvertOpToLLVMPattern<gpu::GPUFuncOp> {
kernelAttributeName(kernelAttributeName) {}
LogicalResult
- matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, ArrayRef<Value> operands,
+ matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
private:
@@ -37,9 +37,9 @@ struct GPUReturnOpLowering : public ConvertOpToLLVMPattern<gpu::ReturnOp> {
using ConvertOpToLLVMPattern<gpu::ReturnOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(gpu::ReturnOp op, ArrayRef<Value> operands,
+ matchAndRewrite(gpu::ReturnOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, operands);
+ rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, adaptor.getOperands());
return success();
}
};
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 0b33ab49fc46a..40a4463f7dbe6 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -195,7 +195,7 @@ class ConvertHostRegisterOpToGpuRuntimeCallPattern
private:
LogicalResult
- matchAndRewrite(gpu::HostRegisterOp hostRegisterOp, ArrayRef<Value> operands,
+ matchAndRewrite(gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
@@ -209,7 +209,7 @@ class ConvertAllocOpToGpuRuntimeCallPattern
private:
LogicalResult
- matchAndRewrite(gpu::AllocOp allocOp, ArrayRef<Value> operands,
+ matchAndRewrite(gpu::AllocOp allocOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
@@ -223,7 +223,7 @@ class ConvertDeallocOpToGpuRuntimeCallPattern
private:
LogicalResult
- matchAndRewrite(gpu::DeallocOp deallocOp, ArrayRef<Value> operands,
+ matchAndRewrite(gpu::DeallocOp deallocOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
@@ -235,7 +235,7 @@ class ConvertAsyncYieldToGpuRuntimeCallPattern
private:
LogicalResult
- matchAndRewrite(async::YieldOp yieldOp, ArrayRef<Value> operands,
+ matchAndRewrite(async::YieldOp yieldOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
@@ -249,7 +249,7 @@ class ConvertWaitOpToGpuRuntimeCallPattern
private:
LogicalResult
- matchAndRewrite(gpu::WaitOp waitOp, ArrayRef<Value> operands,
+ matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
@@ -263,7 +263,7 @@ class ConvertWaitAsyncOpToGpuRuntimeCallPattern
private:
LogicalResult
- matchAndRewrite(gpu::WaitOp waitOp, ArrayRef<Value> operands,
+ matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
@@ -289,13 +289,13 @@ class ConvertLaunchFuncOpToGpuRuntimeCallPattern
gpuBinaryAnnotation(gpuBinaryAnnotation) {}
private:
- Value generateParamsArray(gpu::LaunchFuncOp launchOp,
- ArrayRef<Value> operands, OpBuilder &builder) const;
+ Value generateParamsArray(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
+ OpBuilder &builder) const;
Value generateKernelNameConstant(StringRef moduleName, StringRef name,
Location loc, OpBuilder &builder) const;
LogicalResult
- matchAndRewrite(gpu::LaunchFuncOp launchOp, ArrayRef<Value> operands,
+ matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
llvm::SmallString<32> gpuBinaryAnnotation;
@@ -323,7 +323,7 @@ class ConvertMemcpyOpToGpuRuntimeCallPattern
private:
LogicalResult
- matchAndRewrite(gpu::MemcpyOp memcpyOp, ArrayRef<Value> operands,
+ matchAndRewrite(gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
@@ -337,7 +337,7 @@ class ConvertMemsetOpToGpuRuntimeCallPattern
private:
LogicalResult
- matchAndRewrite(gpu::MemsetOp memsetOp, ArrayRef<Value> operands,
+ matchAndRewrite(gpu::MemsetOp memsetOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
} // namespace
@@ -398,10 +398,10 @@ isAsyncWithOneDependency(ConversionPatternRewriter &rewriter,
}
LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
- gpu::HostRegisterOp hostRegisterOp, ArrayRef<Value> operands,
+ gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto *op = hostRegisterOp.getOperation();
- if (failed(areAllLLVMTypes(op, operands, rewriter)))
+ if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
return failure();
Location loc = op->getLoc();
@@ -410,8 +410,8 @@ LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
auto elementType = memRefType.cast<UnrankedMemRefType>().getElementType();
auto elementSize = getSizeInBytes(loc, elementType, rewriter);
- auto arguments = getTypeConverter()->promoteOperands(loc, op->getOperands(),
- operands, rewriter);
+ auto arguments = getTypeConverter()->promoteOperands(
+ loc, op->getOperands(), adaptor.getOperands(), rewriter);
arguments.push_back(elementSize);
hostRegisterCallBuilder.create(loc, rewriter, arguments);
@@ -420,17 +420,16 @@ LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
}
LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
- gpu::AllocOp allocOp, ArrayRef<Value> operands,
+ gpu::AllocOp allocOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
MemRefType memRefType = allocOp.getType();
- if (failed(areAllLLVMTypes(allocOp, operands, rewriter)) ||
+ if (failed(areAllLLVMTypes(allocOp, adaptor.getOperands(), rewriter)) ||
!isConvertibleAndHasIdentityMaps(memRefType) ||
failed(isAsyncWithOneDependency(rewriter, allocOp)))
return failure();
auto loc = allocOp.getLoc();
- auto adaptor = gpu::AllocOpAdaptor(operands, allocOp->getAttrDictionary());
// Get shape of the memref as values: static sizes are constant
// values and dynamic sizes are passed to 'alloc' as operands.
@@ -462,16 +461,14 @@ LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
}
LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
- gpu::DeallocOp deallocOp, ArrayRef<Value> operands,
+ gpu::DeallocOp deallocOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
- if (failed(areAllLLVMTypes(deallocOp, operands, rewriter)) ||
+ if (failed(areAllLLVMTypes(deallocOp, adaptor.getOperands(), rewriter)) ||
failed(isAsyncWithOneDependency(rewriter, deallocOp)))
return failure();
Location loc = deallocOp.getLoc();
- auto adaptor =
- gpu::DeallocOpAdaptor(operands, deallocOp->getAttrDictionary());
Value pointer =
MemRefDescriptor(adaptor.memref()).allocatedPtr(rewriter, loc);
auto casted = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pointer);
@@ -491,19 +488,19 @@ static bool isGpuAsyncTokenType(Value value) {
// are passed as events between them. For each !gpu.async.token operand, we
// create an event and record it on the stream.
LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
- async::YieldOp yieldOp, ArrayRef<Value> operands,
+ async::YieldOp yieldOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
if (llvm::none_of(yieldOp.operands(), isGpuAsyncTokenType))
return rewriter.notifyMatchFailure(yieldOp, "no gpu async token operand");
Location loc = yieldOp.getLoc();
- SmallVector<Value, 4> newOperands(operands.begin(), operands.end());
+ SmallVector<Value, 4> newOperands(adaptor.getOperands());
llvm::SmallDenseSet<Value> streams;
for (auto &operand : yieldOp->getOpOperands()) {
if (!isGpuAsyncTokenType(operand.get()))
continue;
auto idx = operand.getOperandNumber();
- auto stream = operands[idx];
+ auto stream = adaptor.getOperands()[idx];
auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult(0);
eventRecordCallBuilder.create(loc, rewriter, {event, stream});
newOperands[idx] = event;
@@ -530,14 +527,14 @@ static bool isDefinedByCallTo(Value value, StringRef functionName) {
// assumes that it is not used afterwards or elsewhere. Otherwise we will get a
// runtime error. Eventually, we should guarantee this property.
LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
- gpu::WaitOp waitOp, ArrayRef<Value> operands,
+ gpu::WaitOp waitOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
if (waitOp.asyncToken())
return rewriter.notifyMatchFailure(waitOp, "Cannot convert async op.");
Location loc = waitOp.getLoc();
- for (auto operand : operands) {
+ for (auto operand : adaptor.getOperands()) {
if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
// The converted operand's definition created a stream.
streamSynchronizeCallBuilder.create(loc, rewriter, {operand});
@@ -560,7 +557,7 @@ LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
// Otherwise we will get a runtime error. Eventually, we should guarantee this
// property.
LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
- gpu::WaitOp waitOp, ArrayRef<Value> operands,
+ gpu::WaitOp waitOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
if (!waitOp.asyncToken())
return rewriter.notifyMatchFailure(waitOp, "Can only convert async op.");
@@ -569,7 +566,8 @@ LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
auto insertionPoint = rewriter.saveInsertionPoint();
SmallVector<Value, 1> events;
- for (auto pair : llvm::zip(waitOp.asyncDependencies(), operands)) {
+ for (auto pair :
+ llvm::zip(waitOp.asyncDependencies(), adaptor.getOperands())) {
auto operand = std::get<1>(pair);
if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
// The converted operand's definition created a stream. Insert an event
@@ -611,13 +609,12 @@ LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
// llvm.store %fieldPtr, %elementPtr
// return %array
Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateParamsArray(
- gpu::LaunchFuncOp launchOp, ArrayRef<Value> operands,
- OpBuilder &builder) const {
+ gpu::LaunchFuncOp launchOp, OpAdaptor adaptor, OpBuilder &builder) const {
auto loc = launchOp.getLoc();
auto numKernelOperands = launchOp.getNumKernelOperands();
auto arguments = getTypeConverter()->promoteOperands(
loc, launchOp.getOperands().take_back(numKernelOperands),
- operands.take_back(numKernelOperands), builder);
+ adaptor.getOperands().take_back(numKernelOperands), builder);
auto numArguments = arguments.size();
SmallVector<Type, 4> argumentTypes;
argumentTypes.reserve(numArguments);
@@ -693,9 +690,9 @@ Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateKernelNameConstant(
// If the op is async, the stream corresponds to the (single) async dependency
// as well as the async token the op produces.
LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
- gpu::LaunchFuncOp launchOp, ArrayRef<Value> operands,
+ gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
- if (failed(areAllLLVMTypes(launchOp, operands, rewriter)))
+ if (failed(areAllLLVMTypes(launchOp, adaptor.getOperands(), rewriter)))
return failure();
if (launchOp.asyncDependencies().size() > 1)
@@ -741,14 +738,12 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
loc, rewriter, {module.getResult(0), kernelName});
auto zero = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
rewriter.getI32IntegerAttr(0));
- auto adaptor =
- gpu::LaunchFuncOpAdaptor(operands, launchOp->getAttrDictionary());
Value stream =
adaptor.asyncDependencies().empty()
? streamCreateCallBuilder.create(loc, rewriter, {}).getResult(0)
: adaptor.asyncDependencies().front();
// Create array of pointers to kernel arguments.
- auto kernelParams = generateParamsArray(launchOp, operands, rewriter);
+ auto kernelParams = generateParamsArray(launchOp, adaptor, rewriter);
auto nullpointer = rewriter.create<LLVM::NullOp>(loc, llvmPointerPointerType);
launchKernelCallBuilder.create(loc, rewriter,
{function.getResult(0), adaptor.gridSizeX(),
@@ -775,17 +770,16 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
}
LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
- gpu::MemcpyOp memcpyOp, ArrayRef<Value> operands,
+ gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto memRefType = memcpyOp.src().getType().cast<MemRefType>();
- if (failed(areAllLLVMTypes(memcpyOp, operands, rewriter)) ||
+ if (failed(areAllLLVMTypes(memcpyOp, adaptor.getOperands(), rewriter)) ||
!isConvertibleAndHasIdentityMaps(memRefType) ||
failed(isAsyncWithOneDependency(rewriter, memcpyOp)))
return failure();
auto loc = memcpyOp.getLoc();
- auto adaptor = gpu::MemcpyOpAdaptor(operands, memcpyOp->getAttrDictionary());
MemRefDescriptor srcDesc(adaptor.src());
Value numElements = getNumElements(rewriter, loc, memRefType, srcDesc);
@@ -812,17 +806,16 @@ LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
}
LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite(
- gpu::MemsetOp memsetOp, ArrayRef<Value> operands,
+ gpu::MemsetOp memsetOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto memRefType = memsetOp.dst().getType().cast<MemRefType>();
- if (failed(areAllLLVMTypes(memsetOp, operands, rewriter)) ||
+ if (failed(areAllLLVMTypes(memsetOp, adaptor.getOperands(), rewriter)) ||
!isConvertibleAndHasIdentityMaps(memRefType) ||
failed(isAsyncWithOneDependency(rewriter, memsetOp)))
return failure();
auto loc = memsetOp.getLoc();
- auto adaptor = gpu::MemsetOpAdaptor(operands, memsetOp->getAttrDictionary());
Type valueType = adaptor.value().getType();
if (!valueType.isIntOrFloat() || valueType.getIntOrFloatBitWidth() != 32) {
diff --git a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
index 1f8012272fb90..416964d4e3db6 100644
--- a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
@@ -41,7 +41,7 @@ struct GPUIndexIntrinsicOpLowering : public ConvertOpToLLVMPattern<Op> {
// Convert the kernel arguments to an LLVM type, preserve the rest.
LogicalResult
- matchAndRewrite(Op op, ArrayRef<Value> operands,
+ matchAndRewrite(Op op, typename Op::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
MLIRContext *context = rewriter.getContext();
diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
index b8781fc68b346..2c1c0e107a578 100644
--- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
@@ -37,7 +37,7 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
f64Func(f64Func) {}
LogicalResult
- matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
+ matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
using LLVM::LLVMFuncOp;
@@ -50,7 +50,7 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
"expected op with same operand and result types");
SmallVector<Value, 1> castedOperands;
- for (Value operand : operands)
+ for (Value operand : adaptor.getOperands())
castedOperands.push_back(maybeCast(operand, rewriter));
Type resultType = castedOperands.front().getType();
@@ -64,13 +64,14 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
auto callOp = rewriter.create<LLVM::CallOp>(
op->getLoc(), resultType, SymbolRefAttr::get(funcOp), castedOperands);
- if (resultType == operands.front().getType()) {
+ if (resultType == adaptor.getOperands().front().getType()) {
rewriter.replaceOp(op, {callOp.getResult(0)});
return success();
}
Value truncated = rewriter.create<LLVM::FPTruncOp>(
- op->getLoc(), operands.front().getType(), callOp.getResult(0));
+ op->getLoc(), adaptor.getOperands().front().getType(),
+ callOp.getResult(0));
rewriter.replaceOp(op, {truncated});
return success();
}
@@ -85,11 +86,8 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
operand.getLoc(), Float32Type::get(rewriter.getContext()), operand);
}
- Type getFunctionType(Type resultType, ArrayRef<Value> operands) const {
- SmallVector<Type, 1> operandTypes;
- for (Value operand : operands) {
- operandTypes.push_back(operand.getType());
- }
+ Type getFunctionType(Type resultType, ValueRange operands) const {
+ SmallVector<Type> operandTypes(operands.getTypes());
return LLVM::LLVMFunctionType::get(resultType, operandTypes);
}
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 4c8ad15b098b9..69a9feabf5904 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -57,10 +57,9 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
/// %shfl_pred = llvm.extractvalue %shfl[1 : index] :
/// !llvm<"{ float, i1 }">
LogicalResult
- matchAndRewrite(gpu::ShuffleOp op, ArrayRef<Value> operands,
+ matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
- gpu::ShuffleOpAdaptor adaptor(operands);
auto valueTy = adaptor.value().getType();
auto int32Type = IntegerType::get(rewriter.getContext(), 32);
diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
index 3a86e2e96bb96..0296390d9c082 100644
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -69,10 +69,10 @@ struct WmmaLoadOpToNVVMLowering
LogicalResult
matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp subgroupMmaLoadMatrixOp,
- ArrayRef<Value> operands,
+ OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Operation *op = subgroupMmaLoadMatrixOp.getOperation();
- if (failed(areAllLLVMTypes(op, operands, rewriter)))
+ if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
return failure();
unsigned indexTypeBitwidth =
@@ -88,7 +88,6 @@ struct WmmaLoadOpToNVVMLowering
auto leadDimension = subgroupMmaLoadMatrixOp.leadDimensionAttr();
- gpu::SubgroupMmaLoadMatrixOpAdaptor adaptor(operands);
// MemRefDescriptor to extract alignedPtr and offset.
MemRefDescriptor promotedSrcOp(adaptor.srcMemref());
@@ -177,10 +176,10 @@ struct WmmaStoreOpToNVVMLowering
LogicalResult
matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp subgroupMmaStoreMatrixOp,
- ArrayRef<Value> operands,
+ OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Operation *op = subgroupMmaStoreMatrixOp.getOperation();
- if (failed(areAllLLVMTypes(op, operands, rewriter)))
+ if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
return failure();
unsigned indexTypeBitwidth =
@@ -194,7 +193,6 @@ struct WmmaStoreOpToNVVMLowering
Location loc = op->getLoc();
- gpu::SubgroupMmaStoreMatrixOpAdaptor adaptor(operands);
// MemRefDescriptor to extract alignedPtr and offset.
MemRefDescriptor promotedDstOp(adaptor.dstMemref());
@@ -282,10 +280,10 @@ struct WmmaMmaOpToNVVMLowering
LogicalResult
matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
- ArrayRef<Value> operands,
+ OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Operation *op = subgroupMmaComputeOp.getOperation();
- if (failed(areAllLLVMTypes(op, operands, rewriter)))
+ if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
return failure();
Location loc = op->getLoc();
@@ -317,17 +315,16 @@ struct WmmaMmaOpToNVVMLowering
subgroupMmaComputeOp.opC().getType().cast<gpu::MMAMatrixType>();
ArrayRef<int64_t> cTypeShape = cType.getShape();
- gpu::SubgroupMmaComputeOpAdaptor transformedOperands(operands);
- unpackOp(transformedOperands.opA());
- unpackOp(transformedOperands.opB());
- unpackOp(transformedOperands.opC());
+ unpackOp(adaptor.opA());
+ unpackOp(adaptor.opB());
+ unpackOp(adaptor.opC());
if (cType.getElementType().isF16()) {
if (aTypeShape[0] == 16 && aTypeShape[1] == 16 && bTypeShape[0] == 16 &&
bTypeShape[1] == 16 && cTypeShape[0] == 16 && cTypeShape[1] == 16) {
// Create nvvm.wmma.mma op.
rewriter.replaceOpWithNewOp<NVVM::WMMAMmaF16F16M16N16K16Op>(
- op, transformedOperands.opC().getType(), unpackedOps);
+ op, adaptor.opC().getType(), unpackedOps);
return success();
}
@@ -338,7 +335,7 @@ struct WmmaMmaOpToNVVMLowering
bTypeShape[1] == 16 && cTypeShape[0] == 16 && cTypeShape[1] == 16) {
// Create nvvm.wmma.mma op.
rewriter.replaceOpWithNewOp<NVVM::WMMAMmaF32F32M16N16K16Op>(
- op, transformedOperands.opC().getType(), unpackedOps);
+ op, adaptor.opC().getType(), unpackedOps);
return success();
}
@@ -356,13 +353,13 @@ struct WmmaConstantOpToNVVMLowering
LogicalResult
matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp subgroupMmaConstantOp,
- ArrayRef<Value> operands,
+ OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- if (failed(areAllLLVMTypes(subgroupMmaConstantOp.getOperation(), operands,
- rewriter)))
+ if (failed(areAllLLVMTypes(subgroupMmaConstantOp.getOperation(),
+ adaptor.getOperands(), rewriter)))
return failure();
Location loc = subgroupMmaConstantOp.getLoc();
- Value cst = operands[0];
+ Value cst = adaptor.getOperands()[0];
LLVM::LLVMStructType type = convertMMAToLLVMType(
subgroupMmaConstantOp.getType().cast<gpu::MMAMatrixType>());
// If the element type is a vector create a vector from the operand.
diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
index 205b2bfd6f652..713890425acc6 100644
--- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
+++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
@@ -73,7 +73,7 @@ class RangeOpConversion : public ConvertOpToLLVMPattern<RangeOp> {
using ConvertOpToLLVMPattern<RangeOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(RangeOp rangeOp, ArrayRef<Value> operands,
+ matchAndRewrite(RangeOp rangeOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto rangeDescriptorTy = convertRangeType(
rangeOp.getType().cast<RangeType>(), *getTypeConverter());
@@ -81,7 +81,6 @@ class RangeOpConversion : public ConvertOpToLLVMPattern<RangeOp> {
ImplicitLocOpBuilder b(rangeOp->getLoc(), rewriter);
// Fill in an aggregate value of the descriptor.
- RangeOpAdaptor adaptor(operands);
Value desc = b.create<LLVM::UndefOp>(rangeDescriptorTy);
desc = b.create<LLVM::InsertValueOp>(desc, adaptor.min(),
rewriter.getI64ArrayAttr(0));
@@ -101,9 +100,9 @@ class YieldOpConversion : public ConvertOpToLLVMPattern<linalg::YieldOp> {
using ConvertOpToLLVMPattern<linalg::YieldOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(linalg::YieldOp op, ArrayRef<Value> operands,
+ matchAndRewrite(linalg::YieldOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, operands);
+ rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, adaptor.getOperands());
return success();
}
};
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index 07364c0817b01..3c476f25a0ed9 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -34,10 +34,9 @@ struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(math::ExpM1Op op, ArrayRef<Value> operands,
+ matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- math::ExpM1Op::Adaptor transformed(operands);
- auto operandType = transformed.operand().getType();
+ auto operandType = adaptor.operand().getType();
if (!operandType || !LLVM::isCompatibleType(operandType))
return failure();
@@ -56,7 +55,7 @@ struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
} else {
one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
}
- auto exp = rewriter.create<LLVM::ExpOp>(loc, transformed.operand());
+ auto exp = rewriter.create<LLVM::ExpOp>(loc, adaptor.operand());
rewriter.replaceOpWithNewOp<LLVM::FSubOp>(op, operandType, exp, one);
return success();
}
@@ -66,7 +65,7 @@ struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
return rewriter.notifyMatchFailure(op, "expected vector result type");
return LLVM::detail::handleMultidimensionalVectors(
- op.getOperation(), operands, *getTypeConverter(),
+ op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
[&](Type llvm1DVectorTy, ValueRange operands) {
auto splatAttr = SplatElementsAttr::get(
mlir::VectorType::get(
@@ -88,10 +87,9 @@ struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(math::Log1pOp op, ArrayRef<Value> operands,
+ matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- math::Log1pOp::Adaptor transformed(operands);
- auto operandType = transformed.operand().getType();
+ auto operandType = adaptor.operand().getType();
if (!operandType || !LLVM::isCompatibleType(operandType))
return rewriter.notifyMatchFailure(op, "unsupported operand type");
@@ -111,7 +109,7 @@ struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
: rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
auto add = rewriter.create<LLVM::FAddOp>(loc, operandType, one,
- transformed.operand());
+ adaptor.operand());
rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, add);
return success();
}
@@ -121,7 +119,7 @@ struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
return rewriter.notifyMatchFailure(op, "expected vector result type");
return LLVM::detail::handleMultidimensionalVectors(
- op.getOperation(), operands, *getTypeConverter(),
+ op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
[&](Type llvm1DVectorTy, ValueRange operands) {
auto splatAttr = SplatElementsAttr::get(
mlir::VectorType::get(
@@ -143,10 +141,9 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(math::RsqrtOp op, ArrayRef<Value> operands,
+ matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- math::RsqrtOp::Adaptor transformed(operands);
- auto operandType = transformed.operand().getType();
+ auto operandType = adaptor.operand().getType();
if (!operandType || !LLVM::isCompatibleType(operandType))
return failure();
@@ -165,7 +162,7 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
} else {
one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
}
- auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, transformed.operand());
+ auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, adaptor.operand());
rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt);
return success();
}
@@ -175,7 +172,7 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
return failure();
return LLVM::detail::handleMultidimensionalVectors(
- op.getOperation(), operands, *getTypeConverter(),
+ op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
[&](Type llvm1DVectorTy, ValueRange operands) {
auto splatAttr = SplatElementsAttr::get(
mlir::VectorType::get(
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 85c05f04b07f3..e43be68fb15d2 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -194,7 +194,7 @@ struct AllocaScopeOpLowering
using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, ArrayRef<Value> operands,
+ matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
OpBuilder::InsertionGuard guard(rewriter);
Location loc = allocaScopeOp.getLoc();
@@ -249,10 +249,9 @@ struct AssumeAlignmentOpLowering
memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(memref::AssumeAlignmentOp op, ArrayRef<Value> operands,
+ matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- memref::AssumeAlignmentOp::Adaptor transformed(operands);
- Value memref = transformed.memref();
+ Value memref = adaptor.memref();
unsigned alignment = op.alignment();
auto loc = op.getLoc();
@@ -293,14 +292,11 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
: ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}
LogicalResult
- matchAndRewrite(memref::DeallocOp op, ArrayRef<Value> operands,
+ matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- assert(operands.size() == 1 && "dealloc takes one operand");
- memref::DeallocOp::Adaptor transformed(operands);
-
// Insert the `free` declaration if it is not already present.
auto freeFunc = LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
- MemRefDescriptor memref(transformed.memref());
+ MemRefDescriptor memref(adaptor.memref());
Value casted = rewriter.create<LLVM::BitcastOp>(
op.getLoc(), getVoidPtrType(),
memref.allocatedPtr(rewriter, op.getLoc()));
@@ -316,18 +312,20 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(memref::DimOp dimOp, ArrayRef<Value> operands,
+ matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type operandType = dimOp.source().getType();
if (operandType.isa<UnrankedMemRefType>()) {
- rewriter.replaceOp(dimOp, {extractSizeOfUnrankedMemRef(
- operandType, dimOp, operands, rewriter)});
+ rewriter.replaceOp(
+ dimOp, {extractSizeOfUnrankedMemRef(
+ operandType, dimOp, adaptor.getOperands(), rewriter)});
return success();
}
if (operandType.isa<MemRefType>()) {
- rewriter.replaceOp(dimOp, {extractSizeOfRankedMemRef(
- operandType, dimOp, operands, rewriter)});
+ rewriter.replaceOp(
+ dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
+ adaptor.getOperands(), rewriter)});
return success();
}
llvm_unreachable("expected MemRefType or UnrankedMemRefType");
@@ -335,10 +333,9 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
private:
Value extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
- ArrayRef<Value> operands,
+ OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = dimOp.getLoc();
- memref::DimOp::Adaptor transformed(operands);
auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>();
auto scalarMemRefType =
@@ -348,7 +345,7 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
// Extract pointer to the underlying ranked descriptor and bitcast it to a
// memref<element_type> descriptor pointer to minimize the number of GEP
// operations.
- UnrankedMemRefDescriptor unrankedDesc(transformed.source());
+ UnrankedMemRefDescriptor unrankedDesc(adaptor.source());
Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
loc,
@@ -369,7 +366,7 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
// The size value that we have to extract can be obtained using GEPop with
// `dimOp.index() + 1` index argument.
Value idxPlusOne = rewriter.create<LLVM::AddOp>(
- loc, createIndexConstant(rewriter, loc, 1), transformed.index());
+ loc, createIndexConstant(rewriter, loc, 1), adaptor.index());
Value sizePtr = rewriter.create<LLVM::GEPOp>(loc, indexPtrTy, offsetPtr,
ValueRange({idxPlusOne}));
return rewriter.create<LLVM::LoadOp>(loc, sizePtr);
@@ -386,26 +383,26 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
}
Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
- ArrayRef<Value> operands,
+ OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = dimOp.getLoc();
- memref::DimOp::Adaptor transformed(operands);
+
// Take advantage if index is constant.
MemRefType memRefType = operandType.cast<MemRefType>();
if (Optional<int64_t> index = getConstantDimIndex(dimOp)) {
int64_t i = index.getValue();
if (memRefType.isDynamicDim(i)) {
// extract dynamic size from the memref descriptor.
- MemRefDescriptor descriptor(transformed.source());
+ MemRefDescriptor descriptor(adaptor.source());
return descriptor.size(rewriter, loc, i);
}
// Use constant for static size.
int64_t dimSize = memRefType.getDimSize(i);
return createIndexConstant(rewriter, loc, dimSize);
}
- Value index = transformed.index();
+ Value index = adaptor.index();
int64_t rank = memRefType.getRank();
- MemRefDescriptor memrefDescriptor(transformed.source());
+ MemRefDescriptor memrefDescriptor(adaptor.source());
return memrefDescriptor.size(rewriter, loc, index, rank);
}
};
@@ -432,7 +429,7 @@ struct GlobalMemrefOpLowering
using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(memref::GlobalOp global, ArrayRef<Value> operands,
+ matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
MemRefType type = global.type();
if (!isConvertibleAndHasIdentityMaps(type))
@@ -536,14 +533,12 @@ struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
using Base::Base;
LogicalResult
- matchAndRewrite(memref::LoadOp loadOp, ArrayRef<Value> operands,
+ matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- memref::LoadOp::Adaptor transformed(operands);
auto type = loadOp.getMemRefType();
- Value dataPtr =
- getStridedElementPtr(loadOp.getLoc(), type, transformed.memref(),
- transformed.indices(), rewriter);
+ Value dataPtr = getStridedElementPtr(
+ loadOp.getLoc(), type, adaptor.memref(), adaptor.indices(), rewriter);
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, dataPtr);
return success();
}
@@ -555,16 +550,13 @@ struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
using Base::Base;
LogicalResult
- matchAndRewrite(memref::StoreOp op, ArrayRef<Value> operands,
+ matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto type = op.getMemRefType();
- memref::StoreOp::Adaptor transformed(operands);
- Value dataPtr =
- getStridedElementPtr(op.getLoc(), type, transformed.memref(),
- transformed.indices(), rewriter);
- rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, transformed.value(),
- dataPtr);
+ Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.memref(),
+ adaptor.indices(), rewriter);
+ rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.value(), dataPtr);
return success();
}
};
@@ -575,14 +567,13 @@ struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
using Base::Base;
LogicalResult
- matchAndRewrite(memref::PrefetchOp prefetchOp, ArrayRef<Value> operands,
+ matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- memref::PrefetchOp::Adaptor transformed(operands);
auto type = prefetchOp.getMemRefType();
auto loc = prefetchOp.getLoc();
- Value dataPtr = getStridedElementPtr(loc, type, transformed.memref(),
- transformed.indices(), rewriter);
+ Value dataPtr = getStridedElementPtr(loc, type, adaptor.memref(),
+ adaptor.indices(), rewriter);
// Replace with llvm.prefetch.
auto llvmI32Type = typeConverter->convertType(rewriter.getIntegerType(32));
@@ -627,10 +618,8 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
: failure();
}
- void rewrite(memref::CastOp memRefCastOp, ArrayRef<Value> operands,
+ void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- memref::CastOp::Adaptor transformed(operands);
-
auto srcType = memRefCastOp.getOperand().getType();
auto dstType = memRefCastOp.getType();
auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
@@ -638,7 +627,7 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
// For ranked/ranked case, just keep the original descriptor.
if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
- return rewriter.replaceOp(memRefCastOp, {transformed.source()});
+ return rewriter.replaceOp(memRefCastOp, {adaptor.source()});
if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
// Casting ranked to unranked memref type
@@ -649,7 +638,7 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
int64_t rank = srcMemRefType.getRank();
// ptr = AllocaOp sizeof(MemRefDescriptor)
auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
- loc, transformed.source(), rewriter);
+ loc, adaptor.source(), rewriter);
// voidptr = BitCastOp srcType* to void*
auto voidPtr =
rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
@@ -671,7 +660,7 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
// Casting from unranked type to ranked.
// The operation is assumed to be doing a correct cast. If the destination
// type mismatches the unranked the type, it is undefined behavior.
- UnrankedMemRefDescriptor memRefDesc(transformed.source());
+ UnrankedMemRefDescriptor memRefDesc(adaptor.source());
// ptr = ExtractValueOp src, 1
auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
// castPtr = BitCastOp i8* to structTy*
@@ -693,10 +682,9 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(memref::CopyOp op, ArrayRef<Value> operands,
+ matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
- memref::CopyOp::Adaptor adaptor(operands);
auto srcType = op.source().getType().cast<BaseMemRefType>();
auto targetType = op.target().getType().cast<BaseMemRefType>();
@@ -799,10 +787,8 @@ struct MemRefReinterpretCastOpLowering
memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(memref::ReinterpretCastOp castOp, ArrayRef<Value> operands,
+ matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- memref::ReinterpretCastOp::Adaptor adaptor(operands,
- castOp->getAttrDictionary());
Type srcType = castOp.source().getType();
Value descriptor;
@@ -867,17 +853,15 @@ struct MemRefReshapeOpLowering
using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(memref::ReshapeOp reshapeOp, ArrayRef<Value> operands,
+ matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto *op = reshapeOp.getOperation();
- memref::ReshapeOp::Adaptor adaptor(operands, op->getAttrDictionary());
Type srcType = reshapeOp.source().getType();
Value descriptor;
if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
adaptor, &descriptor)))
return failure();
- rewriter.replaceOp(op, {descriptor});
+ rewriter.replaceOp(reshapeOp, {descriptor});
return success();
}
@@ -1152,7 +1136,7 @@ class ReassociatingReshapeOpConversion
using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;
LogicalResult
- matchAndRewrite(ReshapeOp reshapeOp, ArrayRef<Value> operands,
+ matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
MemRefType dstType = reshapeOp.getResultType();
MemRefType srcType = reshapeOp.getSrcType();
@@ -1168,7 +1152,6 @@ class ReassociatingReshapeOpConversion
reshapeOp, "failed to get stride and offset exprs");
}
- ReshapeOpAdaptor adaptor(operands);
MemRefDescriptor srcDesc(adaptor.src());
Location loc = reshapeOp->getLoc();
auto dstDesc = MemRefDescriptor::undef(
@@ -1217,7 +1200,7 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(memref::SubViewOp subViewOp, ArrayRef<Value> operands,
+ matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = subViewOp.getLoc();
@@ -1249,9 +1232,9 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
return failure();
// Create the descriptor.
- if (!LLVM::isCompatibleType(operands.front().getType()))
+ if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
return failure();
- MemRefDescriptor sourceMemRef(operands.front());
+ MemRefDescriptor sourceMemRef(adaptor.getOperands().front());
auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
// Copy the buffer pointer from the old descriptor to the new one.
@@ -1296,7 +1279,7 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
Value offset =
// TODO: need OpFoldResult ODS adaptor to clean this up.
subViewOp.isDynamicOffset(i)
- ? operands[subViewOp.getIndexOfDynamicOffset(i)]
+ ? adaptor.getOperands()[subViewOp.getIndexOfDynamicOffset(i)]
: rewriter.create<LLVM::ConstantOp>(
loc, llvmIndexType,
rewriter.getI64IntegerAttr(subViewOp.getStaticOffset(i)));
@@ -1346,7 +1329,7 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
// TODO: need OpFoldResult ODS adaptor to clean this up.
size =
subViewOp.isDynamicSize(i)
- ? operands[subViewOp.getIndexOfDynamicSize(i)]
+ ? adaptor.getOperands()[subViewOp.getIndexOfDynamicSize(i)]
: rewriter.create<LLVM::ConstantOp>(
loc, llvmIndexType,
rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i)));
@@ -1354,12 +1337,13 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
stride = rewriter.create<LLVM::ConstantOp>(
loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i]));
} else {
- stride = subViewOp.isDynamicStride(i)
- ? operands[subViewOp.getIndexOfDynamicStride(i)]
- : rewriter.create<LLVM::ConstantOp>(
- loc, llvmIndexType,
- rewriter.getI64IntegerAttr(
- subViewOp.getStaticStride(i)));
+ stride =
+ subViewOp.isDynamicStride(i)
+ ? adaptor.getOperands()[subViewOp.getIndexOfDynamicStride(i)]
+ : rewriter.create<LLVM::ConstantOp>(
+ loc, llvmIndexType,
+ rewriter.getI64IntegerAttr(
+ subViewOp.getStaticStride(i)));
stride = rewriter.create<LLVM::MulOp>(loc, stride, strideValues[i]);
}
}
@@ -1385,10 +1369,9 @@ class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(memref::TransposeOp transposeOp, ArrayRef<Value> operands,
+ matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = transposeOp.getLoc();
- memref::TransposeOpAdaptor adaptor(operands);
MemRefDescriptor viewMemRef(adaptor.in());
// No permutation, early exit.
@@ -1465,10 +1448,9 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
}
LogicalResult
- matchAndRewrite(memref::ViewOp viewOp, ArrayRef<Value> operands,
+ matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = viewOp.getLoc();
- memref::ViewOpAdaptor adaptor(operands);
auto viewMemRefType = viewOp.getType();
auto targetElementTy =
diff --git a/mlir/lib/Conversion/OpenACCToLLVM/OpenACCToLLVM.cpp b/mlir/lib/Conversion/OpenACCToLLVM/OpenACCToLLVM.cpp
index 511af47dd5389..fdff881a57211 100644
--- a/mlir/lib/Conversion/OpenACCToLLVM/OpenACCToLLVM.cpp
+++ b/mlir/lib/Conversion/OpenACCToLLVM/OpenACCToLLVM.cpp
@@ -79,7 +79,7 @@ class LegalizeDataOpForLLVMTranslation : public ConvertOpToLLVMPattern<Op> {
using ConvertOpToLLVMPattern<Op>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Op op, ArrayRef<Value> operands,
+ matchAndRewrite(Op op, typename Op::Adaptor adaptor,
ConversionPatternRewriter &builder) const override {
Location loc = op.getLoc();
TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
@@ -87,8 +87,8 @@ class LegalizeDataOpForLLVMTranslation : public ConvertOpToLLVMPattern<Op> {
unsigned numDataOperand = op.getNumDataOperands();
// Keep the non data operands without modification.
- auto nonDataOperands =
- operands.take_front(operands.size() - numDataOperand);
+ auto nonDataOperands = adaptor.getOperands().take_front(
+ adaptor.getOperands().size() - numDataOperand);
SmallVector<Value> convertedOperands;
convertedOperands.append(nonDataOperands.begin(), nonDataOperands.end());
diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
index e0b3ed8f1fc45..0e6010c2ddb6e 100644
--- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
+++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
@@ -29,10 +29,10 @@ struct RegionOpConversion : public ConvertOpToLLVMPattern<OpType> {
using ConvertOpToLLVMPattern<OpType>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(OpType curOp, ArrayRef<Value> operands,
+ matchAndRewrite(OpType curOp, typename OpType::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto newOp = rewriter.create<OpType>(curOp.getLoc(), TypeRange(), operands,
- curOp->getAttrs());
+ auto newOp = rewriter.create<OpType>(
+ curOp.getLoc(), TypeRange(), adaptor.getOperands(), curOp->getAttrs());
rewriter.inlineRegionBefore(curOp.region(), newOp.region(),
newOp.region().end());
if (failed(rewriter.convertRegionTypes(&newOp.region(),
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
index eef6f2c0b6f7b..2f54a38c50f52 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
@@ -157,7 +157,7 @@ class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
using ConvertOpToLLVMPattern<gpu::LaunchFuncOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(gpu::LaunchFuncOp launchOp, ArrayRef<Value> operands,
+ matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto *op = launchOp.getOperation();
MLIRContext *context = rewriter.getContext();
@@ -206,7 +206,7 @@ class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
Location loc = launchOp.getLoc();
SmallVector<CopyInfo, 4> copyInfo;
auto numKernelOperands = launchOp.getNumKernelOperands();
- auto kernelOperands = operands.take_back(numKernelOperands);
+ auto kernelOperands = adaptor.getOperands().take_back(numKernelOperands);
for (auto operand : llvm::enumerate(kernelOperands)) {
// Check if the kernel's operand is a ranked memref.
auto memRefType = launchOp.getKernelOperand(operand.index())
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 67583f9a74795..ba942bbc1846d 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -178,7 +178,7 @@ class VectorBitCastOpConversion
using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(vector::BitCastOp bitCastOp, ArrayRef<Value> operands,
+ matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Only 1-D vectors can be lowered to LLVM.
VectorType resultTy = bitCastOp.getType();
@@ -186,7 +186,7 @@ class VectorBitCastOpConversion
return failure();
Type newResultTy = typeConverter->convertType(resultTy);
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
- operands[0]);
+ adaptor.getOperands()[0]);
return success();
}
};
@@ -199,9 +199,8 @@ class VectorMatmulOpConversion
using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(vector::MatmulOp matmulOp, ArrayRef<Value> operands,
+ matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto adaptor = vector::MatmulOpAdaptor(operands);
rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
matmulOp, typeConverter->convertType(matmulOp.res().getType()),
adaptor.lhs(), adaptor.rhs(), matmulOp.lhs_rows(),
@@ -218,9 +217,8 @@ class VectorFlatTransposeOpConversion
using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(vector::FlatTransposeOp transOp, ArrayRef<Value> operands,
+ matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto adaptor = vector::FlatTransposeOpAdaptor(operands);
rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
transOp, typeConverter->convertType(transOp.res().getType()),
adaptor.matrix(), transOp.rows(), transOp.columns());
@@ -270,7 +268,8 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(LoadOrStoreOp loadOrStoreOp, ArrayRef<Value> operands,
+ matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
+ typename LoadOrStoreOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Only 1-D vectors can be lowered to LLVM.
VectorType vectorTy = loadOrStoreOp.getVectorType();
@@ -278,7 +277,6 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
return failure();
auto loc = loadOrStoreOp->getLoc();
- auto adaptor = LoadOrStoreOpAdaptor(operands);
MemRefType memRefTy = loadOrStoreOp.getMemRefType();
// Resolve alignment.
@@ -306,10 +304,9 @@ class VectorGatherOpConversion
using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(vector::GatherOp gather, ArrayRef<Value> operands,
+ matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = gather->getLoc();
- auto adaptor = vector::GatherOpAdaptor(operands);
MemRefType memRefType = gather.getMemRefType();
// Resolve alignment.
@@ -341,10 +338,9 @@ class VectorScatterOpConversion
using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(vector::ScatterOp scatter, ArrayRef<Value> operands,
+ matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = scatter->getLoc();
- auto adaptor = vector::ScatterOpAdaptor(operands);
MemRefType memRefType = scatter.getMemRefType();
// Resolve alignment.
@@ -376,10 +372,9 @@ class VectorExpandLoadOpConversion
using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(vector::ExpandLoadOp expand, ArrayRef<Value> operands,
+ matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = expand->getLoc();
- auto adaptor = vector::ExpandLoadOpAdaptor(operands);
MemRefType memRefType = expand.getMemRefType();
// Resolve address.
@@ -400,10 +395,9 @@ class VectorCompressStoreOpConversion
using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(vector::CompressStoreOp compress, ArrayRef<Value> operands,
+ matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = compress->getLoc();
- auto adaptor = vector::CompressStoreOpAdaptor(operands);
MemRefType memRefType = compress.getMemRefType();
// Resolve address.
@@ -426,42 +420,43 @@ class VectorReductionOpConversion
reassociateFPReductions(reassociateFPRed) {}
LogicalResult
- matchAndRewrite(vector::ReductionOp reductionOp, ArrayRef<Value> operands,
+ matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto kind = reductionOp.kind();
Type eltType = reductionOp.dest().getType();
Type llvmType = typeConverter->convertType(eltType);
+ Value operand = adaptor.getOperands()[0];
if (eltType.isIntOrIndex()) {
// Integer reductions: add/mul/min/max/and/or/xor.
if (kind == "add")
- rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>(
- reductionOp, llvmType, operands[0]);
+ rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>(reductionOp,
+ llvmType, operand);
else if (kind == "mul")
- rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(
- reductionOp, llvmType, operands[0]);
+ rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(reductionOp,
+ llvmType, operand);
else if (kind == "min" &&
(eltType.isIndex() || eltType.isUnsignedInteger()))
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>(
- reductionOp, llvmType, operands[0]);
+ reductionOp, llvmType, operand);
else if (kind == "min")
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>(
- reductionOp, llvmType, operands[0]);
+ reductionOp, llvmType, operand);
else if (kind == "max" &&
(eltType.isIndex() || eltType.isUnsignedInteger()))
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>(
- reductionOp, llvmType, operands[0]);
+ reductionOp, llvmType, operand);
else if (kind == "max")
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>(
- reductionOp, llvmType, operands[0]);
+ reductionOp, llvmType, operand);
else if (kind == "and")
- rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>(
- reductionOp, llvmType, operands[0]);
+ rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>(reductionOp,
+ llvmType, operand);
else if (kind == "or")
- rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>(
- reductionOp, llvmType, operands[0]);
+ rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>(reductionOp,
+ llvmType, operand);
else if (kind == "xor")
- rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>(
- reductionOp, llvmType, operands[0]);
+ rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>(reductionOp,
+ llvmType, operand);
else
return failure();
return success();
@@ -473,29 +468,30 @@ class VectorReductionOpConversion
// Floating-point reductions: add/mul/min/max
if (kind == "add") {
// Optional accumulator (or zero).
- Value acc = operands.size() > 1 ? operands[1]
- : rewriter.create<LLVM::ConstantOp>(
- reductionOp->getLoc(), llvmType,
- rewriter.getZeroAttr(eltType));
+ Value acc = adaptor.getOperands().size() > 1
+ ? adaptor.getOperands()[1]
+ : rewriter.create<LLVM::ConstantOp>(
+ reductionOp->getLoc(), llvmType,
+ rewriter.getZeroAttr(eltType));
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
- reductionOp, llvmType, acc, operands[0],
+ reductionOp, llvmType, acc, operand,
rewriter.getBoolAttr(reassociateFPReductions));
} else if (kind == "mul") {
// Optional accumulator (or one).
- Value acc = operands.size() > 1
- ? operands[1]
+ Value acc = adaptor.getOperands().size() > 1
+ ? adaptor.getOperands()[1]
: rewriter.create<LLVM::ConstantOp>(
reductionOp->getLoc(), llvmType,
rewriter.getFloatAttr(eltType, 1.0));
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
- reductionOp, llvmType, acc, operands[0],
+ reductionOp, llvmType, acc, operand,
rewriter.getBoolAttr(reassociateFPReductions));
} else if (kind == "min")
- rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(
- reductionOp, llvmType, operands[0]);
+ rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(reductionOp,
+ llvmType, operand);
else if (kind == "max")
- rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(
- reductionOp, llvmType, operands[0]);
+ rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(reductionOp,
+ llvmType, operand);
else
return failure();
return success();
@@ -511,10 +507,9 @@ class VectorShuffleOpConversion
using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(vector::ShuffleOp shuffleOp, ArrayRef<Value> operands,
+ matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = shuffleOp->getLoc();
- auto adaptor = vector::ShuffleOpAdaptor(operands);
auto v1Type = shuffleOp.getV1VectorType();
auto v2Type = shuffleOp.getV2VectorType();
auto vectorType = shuffleOp.getVectorType();
@@ -573,10 +568,8 @@ class VectorExtractElementOpConversion
vector::ExtractElementOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(vector::ExtractElementOp extractEltOp,
- ArrayRef<Value> operands,
+ matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto adaptor = vector::ExtractElementOpAdaptor(operands);
auto vectorType = extractEltOp.getVectorType();
auto llvmType = typeConverter->convertType(vectorType.getElementType());
@@ -596,10 +589,9 @@ class VectorExtractOpConversion
using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands,
+ matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = extractOp->getLoc();
- auto adaptor = vector::ExtractOpAdaptor(operands);
auto vectorType = extractOp.getVectorType();
auto resultType = extractOp.getResult().getType();
auto llvmResultType = typeConverter->convertType(resultType);
@@ -667,9 +659,8 @@ class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(vector::FMAOp fmaOp, ArrayRef<Value> operands,
+ matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto adaptor = vector::FMAOpAdaptor(operands);
VectorType vType = fmaOp.getVectorType();
if (vType.getRank() != 1)
return failure();
@@ -685,9 +676,8 @@ class VectorInsertElementOpConversion
using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(vector::InsertElementOp insertEltOp, ArrayRef<Value> operands,
+ matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto adaptor = vector::InsertElementOpAdaptor(operands);
auto vectorType = insertEltOp.getDestVectorType();
auto llvmType = typeConverter->convertType(vectorType);
@@ -708,10 +698,9 @@ class VectorInsertOpConversion
using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(vector::InsertOp insertOp, ArrayRef<Value> operands,
+ matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = insertOp->getLoc();
- auto adaptor = vector::InsertOpAdaptor(operands);
auto sourceType = insertOp.getSourceType();
auto destVectorType = insertOp.getDestVectorType();
auto llvmResultType = typeConverter->convertType(destVectorType);
@@ -984,7 +973,7 @@ class VectorTypeCastOpConversion
using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(vector::TypeCastOp castOp, ArrayRef<Value> operands,
+ matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = castOp->getLoc();
MemRefType sourceMemRefType =
@@ -997,10 +986,10 @@ class VectorTypeCastOpConversion
return failure();
auto llvmSourceDescriptorTy =
- operands[0].getType().dyn_cast<LLVM::LLVMStructType>();
+ adaptor.getOperands()[0].getType().dyn_cast<LLVM::LLVMStructType>();
if (!llvmSourceDescriptorTy)
return failure();
- MemRefDescriptor sourceMemRef(operands[0]);
+ MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]);
auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
.dyn_cast_or_null<LLVM::LLVMStructType>();
@@ -1074,9 +1063,8 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
// TODO: rely solely on libc in future? something else?
//
LogicalResult
- matchAndRewrite(vector::PrintOp printOp, ArrayRef<Value> operands,
+ matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto adaptor = vector::PrintOpAdaptor(operands);
Type printType = printOp.getPrintType();
if (typeConverter->convertType(printType) == nullptr)
diff --git a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
index 9685faf649c72..cc54b7f8bd2ed 100644
--- a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
+++ b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
@@ -30,7 +30,7 @@ using namespace mlir;
using namespace mlir::vector;
static LogicalResult replaceTransferOpWithMubuf(
- ConversionPatternRewriter &rewriter, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter, ValueRange operands,
LLVMTypeConverter &typeConverter, Location loc, TransferReadOp xferOp,
Type &vecTy, Value &dwordConfig, Value &vindex, Value &offsetSizeInBytes,
Value &glc, Value &slc) {
@@ -40,7 +40,7 @@ static LogicalResult replaceTransferOpWithMubuf(
}
static LogicalResult replaceTransferOpWithMubuf(
- ConversionPatternRewriter &rewriter, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter, ValueRange operands,
LLVMTypeConverter &typeConverter, Location loc, TransferWriteOp xferOp,
Type &vecTy, Value &dwordConfig, Value &vindex, Value &offsetSizeInBytes,
Value &glc, Value &slc) {
@@ -62,10 +62,8 @@ class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
using ConvertOpToLLVMPattern<ConcreteOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(ConcreteOp xferOp, ArrayRef<Value> operands,
+ matchAndRewrite(ConcreteOp xferOp, typename ConcreteOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- typename ConcreteOp::Adaptor adaptor(operands, xferOp->getAttrDictionary());
-
if (xferOp.getVectorType().getRank() > 1 ||
llvm::size(xferOp.indices()) == 0)
return failure();
@@ -139,8 +137,8 @@ class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
loc, toLLVMTy(i32Ty),
rewriter.getIntegerAttr(rewriter.getIntegerType(32), 0));
return replaceTransferOpWithMubuf(
- rewriter, operands, *this->getTypeConverter(), loc, xferOp, vecTy,
- dwordConfig, int32Zero, int32Zero, int1False, int1False);
+ rewriter, adaptor.getOperands(), *this->getTypeConverter(), loc, xferOp,
+ vecTy, dwordConfig, int32Zero, int32Zero, int1False, int1False);
}
};
} // end anonymous namespace
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
index a8cf8c17c6ca8..007a1c01ccc7d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
@@ -244,9 +244,8 @@ class InsertSliceOpConverter
using OpConversionPattern<tensor::InsertSliceOp>::OpConversionPattern;
LogicalResult
- matchAndRewrite(tensor::InsertSliceOp op, ArrayRef<Value> operands,
+ matchAndRewrite(tensor::InsertSliceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
- tensor::InsertSliceOpAdaptor adaptor(operands, op->getAttrDictionary());
Value sourceMemRef = adaptor.source();
assert(sourceMemRef.getType().isa<MemRefType>());
@@ -273,12 +272,10 @@ class VectorTransferReadOpConverter
using OpConversionPattern<vector::TransferReadOp>::OpConversionPattern;
LogicalResult
- matchAndRewrite(vector::TransferReadOp readOp, ArrayRef<Value> operands,
+ matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
if (readOp.getShapedType().isa<MemRefType>())
return failure();
- vector::TransferReadOp::Adaptor adaptor(operands,
- readOp->getAttrDictionary());
rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
readOp, readOp.getType(), adaptor.source(), adaptor.indices(),
adaptor.permutation_map(), adaptor.padding(), adaptor.mask(),
@@ -293,12 +290,10 @@ class VectorTransferWriteOpConverter
using OpConversionPattern<vector::TransferWriteOp>::OpConversionPattern;
LogicalResult
- matchAndRewrite(vector::TransferWriteOp writeOp, ArrayRef<Value> operands,
+ matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
if (writeOp.getShapedType().isa<MemRefType>())
return failure();
- vector::TransferWriteOp::Adaptor adaptor(operands,
- writeOp->getAttrDictionary());
rewriter.create<vector::TransferWriteOp>(
writeOp.getLoc(), adaptor.vector(), adaptor.source(), adaptor.indices(),
adaptor.permutation_map(),
diff --git a/mlir/test/lib/Conversion/StandardToLLVM/TestConvertCallOp.cpp b/mlir/test/lib/Conversion/StandardToLLVM/TestConvertCallOp.cpp
index 873a9a8cebb74..43f57fecd8e9e 100644
--- a/mlir/test/lib/Conversion/StandardToLLVM/TestConvertCallOp.cpp
+++ b/mlir/test/lib/Conversion/StandardToLLVM/TestConvertCallOp.cpp
@@ -25,7 +25,7 @@ class TestTypeProducerOpConverter
test::TestTypeProducerOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(test::TestTypeProducerOp op, ArrayRef<Value> operands,
+ matchAndRewrite(test::TestTypeProducerOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<LLVM::NullOp>(op, getVoidPtrType());
return success();
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index d51cf5ea1824c..4be8cb1c0345a 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -783,7 +783,7 @@ struct OneVResOneVOperandOp1Converter
using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern;
LogicalResult
- matchAndRewrite(OneVResOneVOperandOp1 op, ArrayRef<Value> operands,
+ matchAndRewrite(OneVResOneVOperandOp1 op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto origOps = op.getOperands();
assert(std::distance(origOps.begin(), origOps.end()) == 1 &&
@@ -878,7 +878,7 @@ struct TestTypeConversionProducer
: public OpConversionPattern<TestTypeProducerOp> {
using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern;
LogicalResult
- matchAndRewrite(TestTypeProducerOp op, ArrayRef<Value> operands,
+ matchAndRewrite(TestTypeProducerOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
Type resultType = op.getType();
if (resultType.isa<FloatType>())
@@ -900,7 +900,7 @@ struct TestSignatureConversionUndo
using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern;
LogicalResult
- matchAndRewrite(TestSignatureConversionUndoOp op, ArrayRef<Value> operands,
+ matchAndRewrite(TestSignatureConversionUndoOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
(void)rewriter.convertRegionTypes(&op->getRegion(0), *getTypeConverter());
return failure();
@@ -914,9 +914,10 @@ struct TestTypeConsumerForward
using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern;
LogicalResult
- matchAndRewrite(TestTypeConsumerOp op, ArrayRef<Value> operands,
+ matchAndRewrite(TestTypeConsumerOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
- rewriter.updateRootInPlace(op, [&] { op->setOperands(operands); });
+ rewriter.updateRootInPlace(op,
+ [&] { op->setOperands(adaptor.getOperands()); });
return success();
}
};
@@ -1022,7 +1023,7 @@ struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> {
using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern;
LogicalResult
- matchAndRewrite(TestMergeBlocksOp op, ArrayRef<Value> operands,
+ matchAndRewrite(TestMergeBlocksOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
Block &firstBlock = op.body().front();
Operation *branchOp = firstBlock.getTerminator();
@@ -1065,7 +1066,7 @@ struct TestMergeSingleBlockOps
SingleBlockImplicitTerminatorOp>::OpConversionPattern;
LogicalResult
- matchAndRewrite(SingleBlockImplicitTerminatorOp op, ArrayRef<Value> operands,
+ matchAndRewrite(SingleBlockImplicitTerminatorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
SingleBlockImplicitTerminatorOp parentOp =
op->getParentOfType<SingleBlockImplicitTerminatorOp>();
More information about the Mlir-commits
mailing list