[Mlir-commits] [mlir] acb69f3 - [mlir] Change ConvertOpToLLVMPattern::matchAndRewrite argument to concrete operand type.
Christian Sigg
llvmlistbot at llvm.org
Sat Nov 28 04:09:36 PST 2020
Author: Christian Sigg
Date: 2020-11-28T13:09:25+01:00
New Revision: acb69f3b7c83f411c08b77d75f2e812faf3cb83f
URL: https://github.com/llvm/llvm-project/commit/acb69f3b7c83f411c08b77d75f2e812faf3cb83f
DIFF: https://github.com/llvm/llvm-project/commit/acb69f3b7c83f411c08b77d75f2e812faf3cb83f.diff
LOG: [mlir] Change ConvertOpToLLVMPattern::matchAndRewrite argument to concrete operand type.
Reviewed By: herhut, ftynse
Differential Revision: https://reviews.llvm.org/D92111
Added:
Modified:
mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/test/lib/Transforms/TestConvertCallOp.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
index 919a93ac84a2..70db4c1510bf 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -564,14 +564,47 @@ class ConvertToLLVMPattern : public ConversionPattern {
/// Utility class for operation conversions targeting the LLVM dialect that
/// match exactly one source operation.
-template <typename OpTy>
+template <typename SourceOp>
class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
public:
ConvertOpToLLVMPattern(LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
- : ConvertToLLVMPattern(OpTy::getOperationName(),
+ : ConvertToLLVMPattern(SourceOp::getOperationName(),
&typeConverter.getContext(), typeConverter,
benefit) {}
+
+ /// Final Operation*-based hooks: cast `op` to SourceOp, forward to overloads.
+ void rewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ rewrite(cast<SourceOp>(op), operands, rewriter);
+ }
+ LogicalResult match(Operation *op) const final {
+ return match(cast<SourceOp>(op));
+ }
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
+ }
+
+ /// Typed hooks on SourceOp. Derived patterns must override matchAndRewrite,
+ /// or both match and rewrite (the defaults below hit llvm_unreachable).
+ virtual void rewrite(SourceOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ llvm_unreachable("must override rewrite or matchAndRewrite");
+ }
+ virtual LogicalResult match(SourceOp op) const {
+ llvm_unreachable("must override match or matchAndRewrite");
+ }
+ virtual LogicalResult
+ matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ if (succeeded(match(op))) {
+ rewrite(op, operands, rewriter);
+ return success();
+ }
+ return failure();
+ }
};
namespace LLVM {
@@ -604,7 +637,7 @@ class OneToOneConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
/// Converts the type of the result to an LLVM type, pass operands as is,
/// preserve attributes.
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
return LLVM::detail::oneToOneRewrite(op, TargetOp::getOperationName(),
operands, this->typeConverter,
@@ -621,7 +654,7 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
using Super = VectorConvertToLLVMPattern<SourceOp, TargetOp>;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
static_assert(
std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
diff --git a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
index d625db95e976..cb7644cb7202 100644
--- a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
+++ b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
@@ -163,7 +163,7 @@ class ConvertHostRegisterOpToGpuRuntimeCallPattern
private:
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(gpu::HostRegisterOp hostRegisterOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
@@ -205,7 +205,7 @@ class ConvertWaitOpToGpuRuntimeCallPattern
private:
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(gpu::WaitOp waitOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
@@ -219,7 +219,7 @@ class ConvertWaitAsyncOpToGpuRuntimeCallPattern
private:
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(gpu::WaitOp waitOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
@@ -251,7 +251,7 @@ class ConvertLaunchFuncOpToGpuRuntimeCallPattern
Location loc, OpBuilder &builder) const;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(gpu::LaunchFuncOp launchOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
llvm::SmallString<32> gpuBinaryAnnotation;
@@ -321,14 +321,15 @@ isAsyncWithOneDependency(ConversionPatternRewriter &rewriter,
}
LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
- Operation *op, ArrayRef<Value> operands,
+ gpu::HostRegisterOp hostRegisterOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
+ auto *op = hostRegisterOp.getOperation();
if (failed(areAllLLVMTypes(op, operands, rewriter)))
return failure();
Location loc = op->getLoc();
- auto memRefType = cast<gpu::HostRegisterOp>(op).value().getType();
+ auto memRefType = hostRegisterOp.value().getType();
auto elementType = memRefType.cast<UnrankedMemRefType>().getElementType();
auto elementSize = getSizeInBytes(loc, elementType, rewriter);
@@ -412,19 +413,19 @@ LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
// afterwards. In case this isn't correct, we will get a runtime error.
// Eventually, we will have a pass that guarantees this property.
LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
- Operation *op, ArrayRef<Value> operands,
+ gpu::WaitOp waitOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
- if (cast<gpu::WaitOp>(op).asyncToken())
- return rewriter.notifyMatchFailure(op, "Cannot convert async op.");
+ if (waitOp.asyncToken())
+ return rewriter.notifyMatchFailure(waitOp, "Cannot convert async op.");
- Location loc = op->getLoc();
+ Location loc = waitOp.getLoc();
for (auto asyncDependency : operands)
streamSynchronizeCallBuilder.create(loc, rewriter, {asyncDependency});
for (auto asyncDependency : operands)
streamDestroyCallBuilder.create(loc, rewriter, {asyncDependency});
- rewriter.eraseOp(op);
+ rewriter.eraseOp(waitOp);
return success();
}
@@ -435,23 +436,23 @@ LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
// assumes that there is no other use between the definition and this op, and
// the plan is to have a pass that guarantees this property.
LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
- Operation *op, ArrayRef<Value> operands,
+ gpu::WaitOp waitOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
- if (!cast<gpu::WaitOp>(op).asyncToken())
- return rewriter.notifyMatchFailure(op, "Can only convert async op.");
+ if (!waitOp.asyncToken())
+ return rewriter.notifyMatchFailure(waitOp, "Can only convert async op.");
- Location loc = op->getLoc();
+ Location loc = waitOp.getLoc();
auto insertionPoint = rewriter.saveInsertionPoint();
SmallVector<Value, 1> events;
- for (auto pair : llvm::zip(op->getOperands(), operands)) {
+ for (auto pair : llvm::zip(waitOp.asyncDependencies(), operands)) {
auto token = std::get<0>(pair);
if (auto *defOp = token.getDefiningOp()) {
rewriter.setInsertionPointAfter(defOp);
} else {
// If we can't find the defining op, we record the event at block start,
// which is late and therefore misses parallelism, but still valid.
- rewriter.setInsertionPointToStart(op->getBlock());
+ rewriter.setInsertionPointToStart(waitOp.getOperation()->getBlock());
}
auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult(0);
auto stream = std::get<1>(pair);
@@ -464,7 +465,7 @@ LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
streamWaitEventCallBuilder.create(loc, rewriter, {stream, event});
for (auto event : events)
eventDestroyCallBuilder.create(loc, rewriter, {event});
- rewriter.replaceOp(op, {stream});
+ rewriter.replaceOp(waitOp, {stream});
return success();
}
@@ -564,23 +565,21 @@ Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateKernelNameConstant(
// If the op is async, the stream corresponds to the (single) async dependency
// as well as the async token the op produces.
LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
- Operation *op, ArrayRef<Value> operands,
+ gpu::LaunchFuncOp launchOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
- if (failed(areAllLLVMTypes(op, operands, rewriter)))
+ if (failed(areAllLLVMTypes(launchOp, operands, rewriter)))
return failure();
- auto launchOp = cast<gpu::LaunchFuncOp>(op);
-
if (launchOp.asyncDependencies().size() > 1)
return rewriter.notifyMatchFailure(
- op, "Cannot convert with more than one async dependency.");
+ launchOp, "Cannot convert with more than one async dependency.");
// Fail when the synchronous version of the op has async dependencies. The
// lowering destroys the stream, and we do not want to check that there is no
// use of the stream after this op.
if (!launchOp.asyncToken() && !launchOp.asyncDependencies().empty())
return rewriter.notifyMatchFailure(
- op, "Cannot convert non-async op with async dependencies.");
+ launchOp, "Cannot convert non-async op with async dependencies.");
Location loc = launchOp.getLoc();
@@ -612,7 +611,8 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
loc, rewriter, {module.getResult(0), kernelName});
auto zero = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
rewriter.getI32IntegerAttr(0));
- auto adaptor = gpu::LaunchFuncOpAdaptor(operands, op->getAttrDictionary());
+ auto adaptor = gpu::LaunchFuncOpAdaptor(
+ operands, launchOp.getOperation()->getAttrDictionary());
Value stream =
adaptor.asyncDependencies().empty()
? streamCreateCallBuilder.create(loc, rewriter, {}).getResult(0)
@@ -620,23 +620,24 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
// Create array of pointers to kernel arguments.
auto kernelParams = generateParamsArray(launchOp, operands, rewriter);
auto nullpointer = rewriter.create<LLVM::NullOp>(loc, llvmPointerPointerType);
- launchKernelCallBuilder.create(
- loc, rewriter,
- {function.getResult(0), launchOp.gridSizeX(), launchOp.gridSizeY(),
- launchOp.gridSizeZ(), launchOp.blockSizeX(), launchOp.blockSizeY(),
- launchOp.blockSizeZ(), /*sharedMemBytes=*/zero, stream, kernelParams,
- /*extra=*/nullpointer});
+ launchKernelCallBuilder.create(loc, rewriter,
+ {function.getResult(0), launchOp.gridSizeX(),
+ launchOp.gridSizeY(), launchOp.gridSizeZ(),
+ launchOp.blockSizeX(), launchOp.blockSizeY(),
+ launchOp.blockSizeZ(),
+ /*sharedMemBytes=*/zero, stream, kernelParams,
+ /*extra=*/nullpointer});
if (launchOp.asyncToken()) {
// Async launch: make dependent ops use the same stream.
- rewriter.replaceOp(op, {stream});
+ rewriter.replaceOp(launchOp, {stream});
} else {
// Synchronize with host and destroy stream. This must be the stream created
// above (with no other uses) because we check that the synchronous version
// does not have any async dependencies.
streamSynchronizeCallBuilder.create(loc, rewriter, stream);
streamDestroyCallBuilder.create(loc, rewriter, stream);
- rewriter.eraseOp(op);
+ rewriter.eraseOp(launchOp);
}
moduleUnloadCallBuilder.create(loc, rewriter, module.getResult(0));
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
index c34198e48d6f..525a5be24485 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
@@ -151,9 +151,9 @@ class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
using ConvertOpToLLVMPattern<gpu::LaunchFuncOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(gpu::LaunchFuncOp launchOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- gpu::LaunchFuncOp launchOp = cast<gpu::LaunchFuncOp>(op);
+ auto *op = launchOp.getOperation();
MLIRContext *context = rewriter.getContext();
auto module = launchOp.getParentOfType<ModuleOp>();
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 49942995fc78..c19f53c4e999 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -1396,10 +1396,8 @@ struct FuncOpConversion : public FuncOpConversionBase {
: FuncOpConversionBase(converter) {}
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto funcOp = cast<FuncOp>(op);
-
auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
if (!newFuncOp)
return failure();
@@ -1407,14 +1405,14 @@ struct FuncOpConversion : public FuncOpConversionBase {
if (typeConverter.getOptions().emitCWrappers ||
funcOp.getAttrOfType<UnitAttr>(kEmitIfaceAttrName)) {
if (newFuncOp.isExternal())
- wrapExternalFunction(rewriter, op->getLoc(), typeConverter, funcOp,
+ wrapExternalFunction(rewriter, funcOp.getLoc(), typeConverter, funcOp,
newFuncOp);
else
- wrapForExternalCallers(rewriter, op->getLoc(), typeConverter, funcOp,
+ wrapForExternalCallers(rewriter, funcOp.getLoc(), typeConverter, funcOp,
newFuncOp);
}
- rewriter.eraseOp(op);
+ rewriter.eraseOp(funcOp);
return success();
}
};
@@ -1425,10 +1423,8 @@ struct BarePtrFuncOpConversion : public FuncOpConversionBase {
using FuncOpConversionBase::FuncOpConversionBase;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto funcOp = cast<FuncOp>(op);
-
// Store the type of memref-typed arguments before the conversion so that we
// can promote them to MemRef descriptor at the beginning of the function.
SmallVector<Type, 8> oldArgTypes =
@@ -1438,7 +1434,7 @@ struct BarePtrFuncOpConversion : public FuncOpConversionBase {
if (!newFuncOp)
return failure();
if (newFuncOp.getBody().empty()) {
- rewriter.eraseOp(op);
+ rewriter.eraseOp(funcOp);
return success();
}
@@ -1471,7 +1467,7 @@ struct BarePtrFuncOpConversion : public FuncOpConversionBase {
// TODO: The placeholder is needed to avoid replacing barePtr uses in the
// MemRef descriptor instructions. We may want to have a utility in the
// rewriter to properly handle this use case.
- Location loc = op->getLoc();
+ Location loc = funcOp.getLoc();
auto placeholder = rewriter.create<LLVM::UndefOp>(loc, memrefTy);
rewriter.replaceUsesOfBlockArgument(arg, placeholder);
@@ -1480,7 +1476,7 @@ struct BarePtrFuncOpConversion : public FuncOpConversionBase {
rewriter.replaceOp(placeholder, {desc});
}
- rewriter.eraseOp(op);
+ rewriter.eraseOp(funcOp);
return success();
}
};
@@ -1711,13 +1707,13 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<AssertOp> {
using ConvertOpToLLVMPattern<AssertOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(AssertOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto loc = op->getLoc();
+ auto loc = op.getLoc();
AssertOp::Adaptor transformed(operands);
// Insert the `abort` declaration if necessary.
- auto module = op->getParentOfType<ModuleOp>();
+ auto module = op.getParentOfType<ModuleOp>();
auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
if (!abortFunc) {
OpBuilder::InsertionGuard guard(rewriter);
@@ -1754,13 +1750,13 @@ struct CreateComplexOpLowering
using ConvertOpToLLVMPattern<CreateComplexOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(CreateComplexOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto complexOp = cast<CreateComplexOp>(op);
CreateComplexOp::Adaptor transformed(operands);
// Pack real and imaginary part in a complex number struct.
- auto loc = op->getLoc();
+ auto loc = op.getLoc();
auto structType = typeConverter.convertType(complexOp.getType());
auto complexStruct = ComplexStructBuilder::undef(rewriter, loc, structType);
complexStruct.setReal(rewriter, loc, transformed.real());
@@ -1775,13 +1771,13 @@ struct ReOpLowering : public ConvertOpToLLVMPattern<ReOp> {
using ConvertOpToLLVMPattern<ReOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(ReOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
ReOp::Adaptor transformed(operands);
// Extract real part from the complex number struct.
ComplexStructBuilder complexStruct(transformed.complex());
- Value real = complexStruct.real(rewriter, op->getLoc());
+ Value real = complexStruct.real(rewriter, op.getLoc());
rewriter.replaceOp(op, real);
return success();
@@ -1792,13 +1788,13 @@ struct ImOpLowering : public ConvertOpToLLVMPattern<ImOp> {
using ConvertOpToLLVMPattern<ImOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(ImOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
ImOp::Adaptor transformed(operands);
// Extract imaginary part from the complex number struct.
ComplexStructBuilder complexStruct(transformed.complex());
- Value imaginary = complexStruct.imaginary(rewriter, op->getLoc());
+ Value imaginary = complexStruct.imaginary(rewriter, op.getLoc());
rewriter.replaceOp(op, imaginary);
return success();
@@ -1833,9 +1829,8 @@ struct AddCFOpLowering : public ConvertOpToLLVMPattern<AddCFOp> {
using ConvertOpToLLVMPattern<AddCFOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *operation, ArrayRef<Value> operands,
+ matchAndRewrite(AddCFOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto op = cast<AddCFOp>(operation);
auto loc = op.getLoc();
BinaryComplexOperands arg =
unpackBinaryComplexOperands<AddCFOp>(op, operands, rewriter);
@@ -1861,9 +1856,8 @@ struct SubCFOpLowering : public ConvertOpToLLVMPattern<SubCFOp> {
using ConvertOpToLLVMPattern<SubCFOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *operation, ArrayRef<Value> operands,
+ matchAndRewrite(SubCFOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto op = cast<SubCFOp>(operation);
auto loc = op.getLoc();
BinaryComplexOperands arg =
unpackBinaryComplexOperands<SubCFOp>(op, operands, rewriter);
@@ -1889,9 +1883,8 @@ struct ConstantOpLowering : public ConvertOpToLLVMPattern<ConstantOp> {
using ConvertOpToLLVMPattern<ConstantOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *operation, ArrayRef<Value> operands,
+ matchAndRewrite(ConstantOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto op = cast<ConstantOp>(operation);
// If constant refers to a function, convert it to "addressof".
if (auto symbolRef = op.getValue().dyn_cast<FlatSymbolRefAttr>()) {
auto type = typeConverter.convertType(op.getResult().getType())
@@ -2284,10 +2277,9 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
using Base = ConvertOpToLLVMPattern<CallOpType>;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(CallOpType callOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
typename CallOpType::Adaptor transformed(operands);
- auto callOp = cast<CallOpType>(op);
// Pack the result types into a struct.
Type packedResult = nullptr;
@@ -2301,10 +2293,11 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
}
auto promoted = this->typeConverter.promoteOperands(
- op->getLoc(), /*opOperands=*/op->getOperands(), operands, rewriter);
+ callOp.getLoc(), /*opOperands=*/callOp.getOperation()->getOperands(),
+ operands, rewriter);
auto newOp = rewriter.create<LLVM::CallOp>(
- op->getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
- promoted, op->getAttrs());
+ callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
+ promoted, callOp.getAttrs());
SmallVector<Value, 4> results;
if (numResults < 2) {
@@ -2315,9 +2308,10 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
// Extract individual results from the structure and return them as list.
results.reserve(numResults);
for (unsigned i = 0; i < numResults; ++i) {
- auto type = this->typeConverter.convertType(op->getResult(i).getType());
+ auto type =
+ this->typeConverter.convertType(callOp.getResult(i).getType());
results.push_back(rewriter.create<LLVM::ExtractValueOp>(
- op->getLoc(), type, newOp.getOperation()->getResult(0),
+ callOp.getLoc(), type, newOp.getOperation()->getResult(0),
rewriter.getI64ArrayAttr(i)));
}
}
@@ -2327,16 +2321,16 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
// descriptors.
assert(results.size() == resultTypes.size() &&
"The number of arguments and types doesn't match");
- this->typeConverter.promoteBarePtrsToDescriptors(rewriter, op->getLoc(),
- resultTypes, results);
- } else if (failed(copyUnrankedDescriptors(rewriter, op->getLoc(),
+ this->typeConverter.promoteBarePtrsToDescriptors(
+ rewriter, callOp.getLoc(), resultTypes, results);
+ } else if (failed(copyUnrankedDescriptors(rewriter, callOp.getLoc(),
this->typeConverter, resultTypes,
results,
/*toDynamic=*/false))) {
return failure();
}
- rewriter.replaceOp(op, results);
+ rewriter.replaceOp(callOp, results);
return success();
}
};
@@ -2359,18 +2353,18 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern<DeallocOp> {
: ConvertOpToLLVMPattern<DeallocOp>(converter) {}
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(DeallocOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
assert(operands.size() == 1 && "dealloc takes one operand");
DeallocOp::Adaptor transformed(operands);
// Insert the `free` declaration if it is not already present.
auto freeFunc =
- op->getParentOfType<ModuleOp>().lookupSymbol<LLVM::LLVMFuncOp>("free");
+ op.getParentOfType<ModuleOp>().lookupSymbol<LLVM::LLVMFuncOp>("free");
if (!freeFunc) {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(
- op->getParentOfType<ModuleOp>().getBody());
+ op.getParentOfType<ModuleOp>().getBody());
freeFunc = rewriter.create<LLVM::LLVMFuncOp>(
rewriter.getUnknownLoc(), "free",
LLVM::LLVMType::getFunctionTy(getVoidType(), getVoidPtrType(),
@@ -2379,8 +2373,8 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern<DeallocOp> {
MemRefDescriptor memref(transformed.memref());
Value casted = rewriter.create<LLVM::BitcastOp>(
- op->getLoc(), getVoidPtrType(),
- memref.allocatedPtr(rewriter, op->getLoc()));
+ op.getLoc(), getVoidPtrType(),
+ memref.allocatedPtr(rewriter, op.getLoc()));
rewriter.replaceOpWithNewOp<LLVM::CallOp>(
op, TypeRange(), rewriter.getSymbolRefAttr(freeFunc), casted);
return success();
@@ -2410,9 +2404,8 @@ struct GlobalMemrefOpLowering : public ConvertOpToLLVMPattern<GlobalMemrefOp> {
using ConvertOpToLLVMPattern<GlobalMemrefOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(GlobalMemrefOp global, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto global = cast<GlobalMemrefOp>(op);
MemRefType type = global.type().cast<MemRefType>();
if (!isSupportedMemRefType(type))
return failure();
@@ -2434,7 +2427,7 @@ struct GlobalMemrefOpLowering : public ConvertOpToLLVMPattern<GlobalMemrefOp> {
}
rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
- op, arrayTy, global.constant(), linkage, global.sym_name(),
+ global, arrayTy, global.constant(), linkage, global.sym_name(),
initialValue, type.getMemorySpace());
return success();
}
@@ -2491,7 +2484,7 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<RsqrtOp> {
using ConvertOpToLLVMPattern<RsqrtOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(RsqrtOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
RsqrtOp::Adaptor transformed(operands);
auto operandType =
@@ -2500,8 +2493,8 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<RsqrtOp> {
if (!operandType)
return failure();
- auto loc = op->getLoc();
- auto resultType = *op->result_type_begin();
+ auto loc = op.getLoc();
+ auto resultType = op.getResult().getType();
auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
@@ -2524,7 +2517,7 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<RsqrtOp> {
return failure();
return handleMultidimensionalVectors(
- op, operands, typeConverter,
+ op.getOperation(), operands, typeConverter,
[&](LLVM::LLVMType llvmVectorTy, ValueRange operands) {
auto splatAttr = SplatElementsAttr::get(
mlir::VectorType::get({llvmVectorTy.getVectorNumElements()},
@@ -2543,8 +2536,7 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<RsqrtOp> {
struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<MemRefCastOp> {
using ConvertOpToLLVMPattern<MemRefCastOp>::ConvertOpToLLVMPattern;
- LogicalResult match(Operation *op) const override {
- auto memRefCastOp = cast<MemRefCastOp>(op);
+ LogicalResult match(MemRefCastOp memRefCastOp) const override {
Type srcType = memRefCastOp.getOperand().getType();
Type dstType = memRefCastOp.getType();
@@ -2568,19 +2560,18 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<MemRefCastOp> {
: failure();
}
- void rewrite(Operation *op, ArrayRef<Value> operands,
+ void rewrite(MemRefCastOp memRefCastOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto memRefCastOp = cast<MemRefCastOp>(op);
MemRefCastOp::Adaptor transformed(operands);
auto srcType = memRefCastOp.getOperand().getType();
auto dstType = memRefCastOp.getType();
auto targetStructType = typeConverter.convertType(memRefCastOp.getType());
- auto loc = op->getLoc();
+ auto loc = memRefCastOp.getLoc();
// For ranked/ranked case, just keep the original descriptor.
if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
- return rewriter.replaceOp(op, {transformed.source()});
+ return rewriter.replaceOp(memRefCastOp, {transformed.source()});
if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
// Casting ranked to unranked memref type
@@ -2607,7 +2598,7 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<MemRefCastOp> {
memRefDesc.setRank(rewriter, loc, rankVal);
// d2 = InsertValueOp d1, voidptr, 1
memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
- rewriter.replaceOp(op, (Value)memRefDesc);
+ rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);
} else if (srcType.isa<UnrankedMemRefType>() && dstType.isa<MemRefType>()) {
// Casting from unranked type to ranked.
@@ -2625,7 +2616,7 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<MemRefCastOp> {
.getResult();
// struct = LoadOp castPtr
auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
- rewriter.replaceOp(op, loadOp.getResult());
+ rewriter.replaceOp(memRefCastOp, loadOp.getResult());
} else {
llvm_unreachable("Unsupported unranked memref to unranked memref cast");
}
@@ -2680,17 +2671,17 @@ struct MemRefReinterpretCastOpLowering
using ConvertOpToLLVMPattern<MemRefReinterpretCastOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(MemRefReinterpretCastOp castOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto castOp = cast<MemRefReinterpretCastOp>(op);
- MemRefReinterpretCastOp::Adaptor adaptor(operands, op->getAttrDictionary());
+ MemRefReinterpretCastOp::Adaptor adaptor(
+ operands, castOp.getOperation()->getAttrDictionary());
Type srcType = castOp.source().getType();
Value descriptor;
if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
adaptor, &descriptor)))
return failure();
- rewriter.replaceOp(op, {descriptor});
+ rewriter.replaceOp(castOp, {descriptor});
return success();
}
@@ -2748,10 +2739,9 @@ struct MemRefReshapeOpLowering
using ConvertOpToLLVMPattern<MemRefReshapeOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(MemRefReshapeOp reshapeOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto reshapeOp = cast<MemRefReshapeOp>(op);
-
+ auto *op = reshapeOp.getOperation();
MemRefReshapeOp::Adaptor adaptor(operands, op->getAttrDictionary());
Type srcType = reshapeOp.source().getType();
@@ -2898,15 +2888,14 @@ struct DialectCastOpLowering
using ConvertOpToLLVMPattern<LLVM::DialectCastOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(LLVM::DialectCastOp castOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto castOp = cast<LLVM::DialectCastOp>(op);
LLVM::DialectCastOp::Adaptor transformed(operands);
if (transformed.in().getType() !=
typeConverter.convertType(castOp.getType())) {
return failure();
}
- rewriter.replaceOp(op, transformed.in());
+ rewriter.replaceOp(castOp, transformed.in());
return success();
}
};
@@ -2917,19 +2906,18 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<DimOp> {
using ConvertOpToLLVMPattern<DimOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(DimOp dimOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto dimOp = cast<DimOp>(op);
Type operandType = dimOp.memrefOrTensor().getType();
if (operandType.isa<UnrankedMemRefType>()) {
- rewriter.replaceOp(op, {extractSizeOfUnrankedMemRef(operandType, dimOp,
- operands, rewriter)});
+ rewriter.replaceOp(dimOp, {extractSizeOfUnrankedMemRef(
+ operandType, dimOp, operands, rewriter)});
return success();
}
if (operandType.isa<MemRefType>()) {
- rewriter.replaceOp(op, {extractSizeOfRankedMemRef(operandType, dimOp,
- operands, rewriter)});
+ rewriter.replaceOp(dimOp, {extractSizeOfRankedMemRef(
+ operandType, dimOp, operands, rewriter)});
return success();
}
return failure();
@@ -3006,10 +2994,10 @@ struct RankOpLowering : public ConvertOpToLLVMPattern<RankOp> {
using ConvertOpToLLVMPattern<RankOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(RankOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- Location loc = op->getLoc();
- Type operandType = cast<RankOp>(op).memrefOrTensor().getType();
+ Location loc = op.getLoc();
+ Type operandType = op.memrefOrTensor().getType();
if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) {
UnrankedMemRefDescriptor desc(RankOp::Adaptor(operands).memrefOrTensor());
rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
@@ -3033,8 +3021,8 @@ struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
using ConvertOpToLLVMPattern<Derived>::isSupportedMemRefType;
using Base = LoadStoreOpLowering<Derived>;
- LogicalResult match(Operation *op) const override {
- MemRefType type = cast<Derived>(op).getMemRefType();
+ LogicalResult match(Derived op) const override {
+ MemRefType type = op.getMemRefType();
return isSupportedMemRefType(type) ? success() : failure();
}
};
@@ -3045,16 +3033,15 @@ struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> {
using Base::Base;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto loadOp = cast<LoadOp>(op);
LoadOp::Adaptor transformed(operands);
auto type = loadOp.getMemRefType();
Value dataPtr =
- getStridedElementPtr(op->getLoc(), type, transformed.memref(),
+ getStridedElementPtr(loadOp.getLoc(), type, transformed.memref(),
transformed.indices(), rewriter);
- rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, dataPtr);
+ rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, dataPtr);
return success();
}
};
@@ -3065,13 +3052,13 @@ struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> {
using Base::Base;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(StoreOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto type = cast<StoreOp>(op).getMemRefType();
+ auto type = op.getMemRefType();
StoreOp::Adaptor transformed(operands);
Value dataPtr =
- getStridedElementPtr(op->getLoc(), type, transformed.memref(),
+ getStridedElementPtr(op.getLoc(), type, transformed.memref(),
transformed.indices(), rewriter);
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, transformed.value(),
dataPtr);
@@ -3085,29 +3072,26 @@ struct PrefetchOpLowering : public LoadStoreOpLowering<PrefetchOp> {
using Base::Base;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(PrefetchOp prefetchOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto prefetchOp = cast<PrefetchOp>(op);
PrefetchOp::Adaptor transformed(operands);
auto type = prefetchOp.getMemRefType();
+ auto loc = prefetchOp.getLoc();
- Value dataPtr =
- getStridedElementPtr(op->getLoc(), type, transformed.memref(),
- transformed.indices(), rewriter);
+ Value dataPtr = getStridedElementPtr(loc, type, transformed.memref(),
+ transformed.indices(), rewriter);
// Replace with llvm.prefetch.
auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
auto isWrite = rewriter.create<LLVM::ConstantOp>(
- op->getLoc(), llvmI32Type,
- rewriter.getI32IntegerAttr(prefetchOp.isWrite()));
+ loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isWrite()));
auto localityHint = rewriter.create<LLVM::ConstantOp>(
- op->getLoc(), llvmI32Type,
+ loc, llvmI32Type,
rewriter.getI32IntegerAttr(prefetchOp.localityHint()));
auto isData = rewriter.create<LLVM::ConstantOp>(
- op->getLoc(), llvmI32Type,
- rewriter.getI32IntegerAttr(prefetchOp.isDataCache()));
+ loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isDataCache()));
- rewriter.replaceOpWithNewOp<LLVM::Prefetch>(op, dataPtr, isWrite,
+ rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
localityHint, isData);
return success();
}
@@ -3121,10 +3105,9 @@ struct IndexCastOpLowering : public ConvertOpToLLVMPattern<IndexCastOp> {
using ConvertOpToLLVMPattern<IndexCastOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(IndexCastOp indexCastOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
IndexCastOpAdaptor transformed(operands);
- auto indexCastOp = cast<IndexCastOp>(op);
auto targetType =
this->typeConverter.convertType(indexCastOp.getResult().getType())
@@ -3134,12 +3117,12 @@ struct IndexCastOpLowering : public ConvertOpToLLVMPattern<IndexCastOp> {
unsigned sourceBits = sourceType.getIntegerBitWidth();
if (targetBits == sourceBits)
- rewriter.replaceOp(op, transformed.in());
+ rewriter.replaceOp(indexCastOp, transformed.in());
else if (targetBits < sourceBits)
- rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType,
+ rewriter.replaceOpWithNewOp<LLVM::TruncOp>(indexCastOp, targetType,
transformed.in());
else
- rewriter.replaceOpWithNewOp<LLVM::SExtOp>(op, targetType,
+ rewriter.replaceOpWithNewOp<LLVM::SExtOp>(indexCastOp, targetType,
transformed.in());
return success();
}
@@ -3156,13 +3139,12 @@ struct CmpIOpLowering : public ConvertOpToLLVMPattern<CmpIOp> {
using ConvertOpToLLVMPattern<CmpIOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(CmpIOp cmpiOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto cmpiOp = cast<CmpIOp>(op);
CmpIOpAdaptor transformed(operands);
rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
- op, typeConverter.convertType(cmpiOp.getResult().getType()),
+ cmpiOp, typeConverter.convertType(cmpiOp.getResult().getType()),
rewriter.getI64IntegerAttr(static_cast<int64_t>(
convertCmpPredicate<LLVM::ICmpPredicate>(cmpiOp.getPredicate()))),
transformed.lhs(), transformed.rhs());
@@ -3175,13 +3157,12 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern<CmpFOp> {
using ConvertOpToLLVMPattern<CmpFOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(CmpFOp cmpfOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto cmpfOp = cast<CmpFOp>(op);
CmpFOpAdaptor transformed(operands);
rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
- op, typeConverter.convertType(cmpfOp.getResult().getType()),
+ cmpfOp, typeConverter.convertType(cmpfOp.getResult().getType()),
rewriter.getI64IntegerAttr(static_cast<int64_t>(
convertCmpPredicate<LLVM::FCmpPredicate>(cmpfOp.getPredicate()))),
transformed.lhs(), transformed.rhs());
@@ -3243,10 +3224,10 @@ struct OneToOneLLVMTerminatorLowering
using Super = OneToOneLLVMTerminatorLowering<SourceOp, TargetOp>;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<TargetOp>(op, operands, op->getSuccessors(),
- op->getAttrs());
+ rewriter.replaceOpWithNewOp<TargetOp>(
+ op, operands, op.getOperation()->getSuccessors(), op.getAttrs());
return success();
}
};
@@ -3261,16 +3242,16 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<ReturnOp> {
using ConvertOpToLLVMPattern<ReturnOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- Location loc = op->getLoc();
- unsigned numArguments = op->getNumOperands();
+ Location loc = op.getLoc();
+ unsigned numArguments = op.getNumOperands();
SmallVector<Value, 4> updatedOperands;
if (typeConverter.getOptions().useBarePtrCallConv) {
// For the bare-ptr calling convention, extract the aligned pointer to
// be returned from the memref descriptor.
- for (auto it : llvm::zip(op->getOperands(), operands)) {
+ for (auto it : llvm::zip(op.getOperation()->getOperands(), operands)) {
Type oldTy = std::get<0>(it).getType();
Value newOperand = std::get<1>(it);
if (oldTy.isa<MemRefType>()) {
@@ -3286,26 +3267,26 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<ReturnOp> {
} else {
updatedOperands = llvm::to_vector<4>(operands);
copyUnrankedDescriptors(rewriter, loc, typeConverter,
- op->getOperands().getTypes(), updatedOperands,
+ op.getOperands().getTypes(), updatedOperands,
/*toDynamic=*/true);
}
// If ReturnOp has 0 or 1 operand, create it and return immediately.
if (numArguments == 0) {
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), ValueRange(),
- op->getAttrs());
+ op.getAttrs());
return success();
}
if (numArguments == 1) {
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
- op, TypeRange(), updatedOperands, op->getAttrs());
+ op, TypeRange(), updatedOperands, op.getAttrs());
return success();
}
// Otherwise, we need to pack the arguments into an LLVM struct type before
// returning.
auto packedType = typeConverter.packFunctionResults(
- llvm::to_vector<4>(op->getOperandTypes()));
+ llvm::to_vector<4>(op.getOperandTypes()));
Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType);
for (unsigned i = 0; i < numArguments; ++i) {
@@ -3314,7 +3295,7 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<ReturnOp> {
rewriter.getI64ArrayAttr(i));
}
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed,
- op->getAttrs());
+ op.getAttrs());
return success();
}
};
@@ -3335,29 +3316,30 @@ struct SplatOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(SplatOp splatOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto splatOp = cast<SplatOp>(op);
VectorType resultType = splatOp.getType().dyn_cast<VectorType>();
if (!resultType || resultType.getRank() != 1)
return failure();
// First insert it into an undef vector so we can shuffle it.
auto vectorType = typeConverter.convertType(splatOp.getType());
- Value undef = rewriter.create<LLVM::UndefOp>(op->getLoc(), vectorType);
+ Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType);
auto zero = rewriter.create<LLVM::ConstantOp>(
- op->getLoc(), typeConverter.convertType(rewriter.getIntegerType(32)),
+ splatOp.getLoc(),
+ typeConverter.convertType(rewriter.getIntegerType(32)),
rewriter.getZeroAttr(rewriter.getIntegerType(32)));
auto v = rewriter.create<LLVM::InsertElementOp>(
- op->getLoc(), vectorType, undef, splatOp.getOperand(), zero);
+ splatOp.getLoc(), vectorType, undef, splatOp.getOperand(), zero);
int64_t width = splatOp.getType().cast<VectorType>().getDimSize(0);
SmallVector<int32_t, 4> zeroValues(width, 0);
// Shuffle the value across the desired number of elements.
ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
- rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(op, v, undef, zeroAttrs);
+ rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef,
+ zeroAttrs);
return success();
}
};
@@ -3369,16 +3351,15 @@ struct SplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(SplatOp splatOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto splatOp = cast<SplatOp>(op);
SplatOp::Adaptor adaptor(operands);
VectorType resultType = splatOp.getType().dyn_cast<VectorType>();
if (!resultType || resultType.getRank() == 1)
return failure();
// First insert it into an undef vector so we can shuffle it.
- auto loc = op->getLoc();
+ auto loc = splatOp.getLoc();
auto vectorTypeInfo = extractNDVectorTypeInfo(resultType, typeConverter);
auto llvmArrayTy = vectorTypeInfo.llvmArrayTy;
auto llvmVectorTy = vectorTypeInfo.llvmVectorTy;
@@ -3409,7 +3390,7 @@ struct SplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayTy, desc, v,
position);
});
- rewriter.replaceOp(op, desc);
+ rewriter.replaceOp(splatOp, desc);
return success();
}
};
@@ -3431,10 +3412,9 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
using ConvertOpToLLVMPattern<SubViewOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(SubViewOp subViewOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto loc = op->getLoc();
- auto subViewOp = cast<SubViewOp>(op);
+ auto loc = subViewOp.getLoc();
auto sourceMemRefType = subViewOp.source().getType().cast<MemRefType>();
auto sourceElementTy =
@@ -3545,7 +3525,7 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
j--;
}
- rewriter.replaceOp(op, {targetMemRef});
+ rewriter.replaceOp(subViewOp, {targetMemRef});
return success();
}
};
@@ -3562,16 +3542,15 @@ class TransposeOpLowering : public ConvertOpToLLVMPattern<TransposeOp> {
using ConvertOpToLLVMPattern<TransposeOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(TransposeOp transposeOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto loc = op->getLoc();
+ auto loc = transposeOp.getLoc();
TransposeOpAdaptor adaptor(operands);
MemRefDescriptor viewMemRef(adaptor.in());
- auto transposeOp = cast<TransposeOp>(op);
// No permutation, early exit.
if (transposeOp.permutation().isIdentity())
- return rewriter.replaceOp(op, {viewMemRef}), success();
+ return rewriter.replaceOp(transposeOp, {viewMemRef}), success();
auto targetMemRef = MemRefDescriptor::undef(
rewriter, loc, typeConverter.convertType(transposeOp.getShapedType()));
@@ -3596,7 +3575,7 @@ class TransposeOpLowering : public ConvertOpToLLVMPattern<TransposeOp> {
viewMemRef.stride(rewriter, loc, sourcePos));
}
- rewriter.replaceOp(op, {targetMemRef});
+ rewriter.replaceOp(transposeOp, {targetMemRef});
return success();
}
};
@@ -3643,10 +3622,9 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<ViewOp> {
}
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(ViewOp viewOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto loc = op->getLoc();
- auto viewOp = cast<ViewOp>(op);
+ auto loc = viewOp.getLoc();
ViewOpAdaptor adaptor(operands);
auto viewMemRefType = viewOp.getType();
@@ -3656,14 +3634,14 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<ViewOp> {
auto targetDescTy =
typeConverter.convertType(viewMemRefType).dyn_cast<LLVM::LLVMType>();
if (!targetDescTy)
- return op->emitWarning("Target descriptor type not converted to LLVM"),
+ return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
failure();
int64_t offset;
SmallVector<int64_t, 4> strides;
auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
if (failed(successStrides))
- return op->emitWarning("cannot cast to non-strided shape"), failure();
+ return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
assert(offset == 0 && "expected offset to be 0");
// Create the descriptor.
@@ -3695,11 +3673,12 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<ViewOp> {
// Early exit for 0-D corner case.
if (viewMemRefType.getRank() == 0)
- return rewriter.replaceOp(op, {targetMemRef}), success();
+ return rewriter.replaceOp(viewOp, {targetMemRef}), success();
// Fields 4 and 5: Update sizes and strides.
if (strides.back() != 1)
- return op->emitWarning("cannot cast to non-contiguous shape"), failure();
+ return viewOp.emitWarning("cannot cast to non-contiguous shape"),
+ failure();
Value stride = nullptr, nextSize = nullptr;
for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
// Update size.
@@ -3712,7 +3691,7 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<ViewOp> {
nextSize = size;
}
- rewriter.replaceOp(op, {targetMemRef});
+ rewriter.replaceOp(viewOp, {targetMemRef});
return success();
}
};
@@ -3722,11 +3701,12 @@ struct AssumeAlignmentOpLowering
using ConvertOpToLLVMPattern<AssumeAlignmentOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(AssumeAlignmentOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
AssumeAlignmentOp::Adaptor transformed(operands);
Value memref = transformed.memref();
- unsigned alignment = cast<AssumeAlignmentOp>(op).alignment();
+ unsigned alignment = op.alignment();
+ auto loc = op.getLoc();
MemRefDescriptor memRefDescriptor(memref);
Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());
@@ -3741,16 +3721,14 @@ struct AssumeAlignmentOpLowering
// pointer SSA value.
auto intPtrType =
getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
- Value zero = createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0);
- Value mask = createIndexAttrConstant(rewriter, op->getLoc(), intPtrType,
- alignment - 1);
- Value ptrValue =
- rewriter.create<LLVM::PtrToIntOp>(op->getLoc(), intPtrType, ptr);
+ Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
+ Value mask =
+ createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
+ Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
rewriter.create<LLVM::AssumeOp>(
- op->getLoc(),
- rewriter.create<LLVM::ICmpOp>(
- op->getLoc(), LLVM::ICmpPredicate::eq,
- rewriter.create<LLVM::AndOp>(op->getLoc(), ptrValue, mask), zero));
+ loc, rewriter.create<LLVM::ICmpOp>(
+ loc, LLVM::ICmpPredicate::eq,
+ rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));
rewriter.eraseOp(op);
return success();
@@ -3789,9 +3767,10 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
using Base::Base;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(AtomicRMWOp atomicOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto atomicOp = cast<AtomicRMWOp>(op);
+ if (failed(match(atomicOp)))
+ return failure();
auto maybeKind = matchSimpleAtomicOp(atomicOp);
if (!maybeKind)
return failure();
@@ -3799,10 +3778,10 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
auto resultType = adaptor.value().getType();
auto memRefType = atomicOp.getMemRefType();
auto dataPtr =
- getStridedElementPtr(op->getLoc(), memRefType, adaptor.memref(),
+ getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.memref(),
adaptor.indices(), rewriter);
rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
- op, resultType, *maybeKind, dataPtr, adaptor.value(),
+ atomicOp, resultType, *maybeKind, dataPtr, adaptor.value(),
LLVM::AtomicOrdering::acq_rel);
return success();
}
@@ -3840,11 +3819,10 @@ struct GenericAtomicRMWOpLowering
using Base::Base;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(GenericAtomicRMWOp atomicOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto atomicOp = cast<GenericAtomicRMWOp>(op);
- auto loc = op->getLoc();
+ auto loc = atomicOp.getLoc();
GenericAtomicRMWOp::Adaptor adaptor(operands);
LLVM::LLVMType valueType =
typeConverter.convertType(atomicOp.getResult().getType())
@@ -3908,7 +3886,7 @@ struct GenericAtomicRMWOpLowering
std::next(opsToMoveEnd), rewriter);
// The 'result' of the atomic_rmw op is the newly loaded value.
- rewriter.replaceOp(op, {newLoaded});
+ rewriter.replaceOp(atomicOp, {newLoaded});
return success();
}
diff --git a/mlir/test/lib/Transforms/TestConvertCallOp.cpp b/mlir/test/lib/Transforms/TestConvertCallOp.cpp
index a612738c5dcc..61062c7938fe 100644
--- a/mlir/test/lib/Transforms/TestConvertCallOp.cpp
+++ b/mlir/test/lib/Transforms/TestConvertCallOp.cpp
@@ -25,7 +25,7 @@ class TestTypeProducerOpConverter
test::TestTypeProducerOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(test::TestTypeProducerOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<LLVM::NullOp>(op, getVoidPtrType());
return success();
More information about the Mlir-commits
mailing list