[Mlir-commits] [mlir] 5e0c3b4 - [mlir][LLVMIR] Clean up the definitions of ReturnOp/CallOp
Jeff Niu
llvmlistbot at llvm.org
Wed Aug 10 21:35:17 PDT 2022
Author: Jeff Niu
Date: 2022-08-11T00:35:02-04:00
New Revision: 5e0c3b4309df8ad74ff096593c3a1dd28f8dd571
URL: https://github.com/llvm/llvm-project/commit/5e0c3b4309df8ad74ff096593c3a1dd28f8dd571
DIFF: https://github.com/llvm/llvm-project/commit/5e0c3b4309df8ad74ff096593c3a1dd28f8dd571.diff
LOG: [mlir][LLVMIR] Clean up the definitions of ReturnOp/CallOp
Added:
Modified:
mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
mlir/lib/Conversion/LLVMCommon/Pattern.cpp
mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index 406f838f14606..9de4334a9d70b 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -14,6 +14,7 @@
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
+class CallOpInterface;
namespace LLVM {
namespace detail {
diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
index 7a9167c5151f2..ac86e8461d277 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
@@ -29,10 +29,9 @@ namespace LLVM {
class LLVMFuncOp;
/// Helper functions to lookup or create the declaration for commonly used
-/// external C function calls. Such ops can then be invoked by creating a CallOp
-/// with the proper arguments via `createLLVMCall`.
-/// The list of functions provided here must be implemented separately (e.g. as
-/// part of a support runtime library or as part of the libc).
+/// external C function calls. The list of functions provided here must be
+/// implemented separately (e.g. as part of a support runtime library or as part
+/// of the libc).
LLVM::LLVMFuncOp lookupOrCreatePrintI64Fn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(ModuleOp moduleOp);
@@ -58,12 +57,6 @@ LLVM::LLVMFuncOp lookupOrCreateFn(ModuleOp moduleOp, StringRef name,
ArrayRef<Type> paramTypes = {},
Type resultType = {});
-/// Helper wrapper to create a call to `fn` with `args` and `resultTypes`.
-Operation::result_range createLLVMCall(OpBuilder &b, Location loc,
- LLVM::LLVMFuncOp fn,
- ValueRange args = {},
- ArrayRef<Type> resultTypes = {});
-
} // namespace LLVM
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 203f35f36c197..84fc59243dc4f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -645,14 +645,16 @@ def LLVM_LandingpadOp : LLVM_Op<"landingpad"> {
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// CallOp
+//===----------------------------------------------------------------------===//
+
def LLVM_CallOp : LLVM_Op<"call",
[DeclareOpInterfaceMethods<FastmathFlagsInterface>,
DeclareOpInterfaceMethods<CallOpInterface>,
DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
let summary = "Call to an LLVM function.";
let description = [{
-
-
In LLVM IR, functions may return either 0 or 1 value. LLVM IR dialect
implements this behavior by providing a variadic `call` operation for 0- and
1-result functions. Even though MLIR supports multi-result functions, LLVM
@@ -678,29 +680,20 @@ def LLVM_CallOp : LLVM_Op<"call",
llvm.call %1(%0) : (f32) -> ()
```
}];
+
let arguments = (ins OptionalAttr<FlatSymbolRefAttr>:$callee,
- Variadic<LLVM_Type>,
- DefaultValuedAttr<LLVM_FMFAttr, "{}">:$fastmathFlags);
- let results = (outs Variadic<LLVM_Type>);
+ Variadic<LLVM_Type>,
+ DefaultValuedAttr<LLVM_FMFAttr, "{}">:$fastmathFlags);
+ let results = (outs Optional<LLVM_Type>:$result);
+
let builders = [
- OpBuilder<(ins "LLVMFuncOp":$func, "ValueRange":$operands,
- CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes), [{
- Type resultType = func.getFunctionType().getReturnType();
- if (!resultType.isa<LLVM::LLVMVoidType>())
- $_state.addTypes(resultType);
- $_state.addAttribute("callee", SymbolRefAttr::get(func));
- $_state.addAttributes(attributes);
- $_state.addOperands(operands);
- }]>,
+ OpBuilder<(ins "LLVMFuncOp":$func, "ValueRange":$args)>,
OpBuilder<(ins "TypeRange":$results, "StringAttr":$callee,
- CArg<"ValueRange", "{}">:$operands), [{
- build($_builder, $_state, results, SymbolRefAttr::get(callee), operands);
- }]>,
+ CArg<"ValueRange", "{}">:$args)>,
OpBuilder<(ins "TypeRange":$results, "StringRef":$callee,
- CArg<"ValueRange", "{}">:$operands), [{
- build($_builder, $_state, results,
- StringAttr::get($_builder.getContext(), callee), operands);
- }]>];
+ CArg<"ValueRange", "{}">:$args)>
+ ];
+
let hasCustomAssemblyFormat = 1;
}
@@ -878,25 +871,38 @@ def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br",
falseOperands);
}]>, LLVM_TerminatorPassthroughOpBuilder];
}
+
+//===----------------------------------------------------------------------===//
+// ReturnOp
+//===----------------------------------------------------------------------===//
+
def LLVM_ReturnOp : LLVM_TerminatorOp<"return", [NoSideEffect]> {
- let arguments = (ins Variadic<LLVM_Type>:$args);
+ let arguments = (ins Optional<LLVM_Type>:$arg);
+ let assemblyFormat = "attr-dict ($arg^ `:` type($arg))?";
+
+ let builders = [
+ OpBuilder<(ins "ValueRange":$args), [{
+ build($_builder, $_state, TypeRange(), args);
+ }]>
+ ];
+
+ let hasVerifier = 1;
+
string llvmBuilder = [{
if ($_numOperands != 0)
- builder.CreateRet($args[0]);
+ builder.CreateRet($arg[0]);
else
builder.CreateRetVoid();
}];
-
- let assemblyFormat = "attr-dict ($args^ `:` type($args))?";
- let hasVerifier = 1;
}
-def LLVM_ResumeOp : LLVM_TerminatorOp<"resume", []> {
+
+def LLVM_ResumeOp : LLVM_TerminatorOp<"resume"> {
let arguments = (ins LLVM_Type:$value);
string llvmBuilder = [{ builder.CreateResume($value); }];
let assemblyFormat = "$value attr-dict `:` type($value)";
let hasVerifier = 1;
}
-def LLVM_UnreachableOp : LLVM_TerminatorOp<"unreachable", []> {
+def LLVM_UnreachableOp : LLVM_TerminatorOp<"unreachable"> {
string llvmBuilder = [{ builder.CreateUnreachable(); }];
let assemblyFormat = "attr-dict";
}
diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index e7930653aee1a..aa6598c26853d 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -350,8 +350,8 @@ class CoroBeginOpConversion : public OpConversionPattern<CoroBeginOp> {
// requires the size parameter be an integral multiple of the alignment
// parameter.
auto makeConstant = [&](uint64_t c) {
- return rewriter.create<LLVM::ConstantOp>(
- op->getLoc(), rewriter.getI64Type(), c);
+ return rewriter.create<LLVM::ConstantOp>(op->getLoc(),
+ rewriter.getI64Type(), c);
};
coroSize = rewriter.create<LLVM::AddOp>(op->getLoc(), coroSize, coroAlign);
coroSize =
@@ -365,13 +365,12 @@ class CoroBeginOpConversion : public OpConversionPattern<CoroBeginOp> {
auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
op->getParentOfType<ModuleOp>(), rewriter.getI64Type());
auto coroAlloc = rewriter.create<LLVM::CallOp>(
- loc, i8Ptr, SymbolRefAttr::get(allocFuncOp),
- ValueRange{coroAlign, coroSize});
+ loc, allocFuncOp, ValueRange{coroAlign, coroSize});
// Begin a coroutine: @llvm.coro.begin.
auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).id();
rewriter.replaceOpWithNewOp<LLVM::CoroBeginOp>(
- op, i8Ptr, ValueRange({coroId, coroAlloc.getResult(0)}));
+ op, i8Ptr, ValueRange({coroId, coroAlloc.getResult()}));
return success();
}
@@ -400,8 +399,7 @@ class CoroFreeOpConversion : public OpConversionPattern<CoroFreeOp> {
// Free the memory.
auto freeFuncOp =
LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
- rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, TypeRange(),
- SymbolRefAttr::get(freeFuncOp),
+ rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFuncOp,
ValueRange(coroMem.getResult()));
return success();
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index e8feb90198c1a..1d7e28291fc23 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -164,7 +164,7 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
auto call = rewriter.create<LLVM::CallOp>(loc, newFuncOp, args);
if (resultIsNowArg) {
- rewriter.create<LLVM::StoreOp>(loc, call.getResult(0),
+ rewriter.create<LLVM::StoreOp>(loc, call.getResult(),
wrapperFuncOp.getArgument(0));
rewriter.create<LLVM::ReturnOp>(loc, ValueRange{});
} else {
@@ -265,7 +265,7 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
if (resultIsNowArg) {
Value result = builder.create<LLVM::LoadOp>(loc, args.front());
- builder.create<LLVM::ReturnOp>(loc, ValueRange{result});
+ builder.create<LLVM::ReturnOp>(loc, result);
} else {
builder.create<LLVM::ReturnOp>(loc, call.getResults());
}
@@ -617,12 +617,7 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
}
// If ReturnOp has 0 or 1 operand, create it and return immediately.
- if (numArguments == 0) {
- rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), ValueRange(),
- op->getAttrs());
- return success();
- }
- if (numArguments == 1) {
+ if (numArguments <= 1) {
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
op, TypeRange(), updatedOperands, op->getAttrs());
return success();
@@ -630,13 +625,13 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
// Otherwise, we need to pack the arguments into an LLVM struct type before
// returning.
- auto packedType = getTypeConverter()->packFunctionResults(
- llvm::to_vector<4>(op.getOperandTypes()));
+ auto packedType =
+ getTypeConverter()->packFunctionResults(op.getOperandTypes());
Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType);
- for (unsigned i = 0; i < numArguments; ++i) {
- packed = rewriter.create<LLVM::InsertValueOp>(loc, packed,
- updatedOperands[i], i);
+ for (auto &it : llvm::enumerate(updatedOperands)) {
+ packed = rewriter.create<LLVM::InsertValueOp>(loc, packed, it.value(),
+ it.index());
}
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed,
op->getAttrs());
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index a62191d586787..459180b9d9e4c 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -220,7 +220,7 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
/// Start the printf hostcall
Value zeroI64 = rewriter.create<LLVM::ConstantOp>(loc, llvmI64, 0);
auto printfBeginCall = rewriter.create<LLVM::CallOp>(loc, ocklBegin, zeroI64);
- Value printfDesc = printfBeginCall.getResult(0);
+ Value printfDesc = printfBeginCall.getResult();
// Create a global constant for the format string
unsigned stringNumber = 0;
@@ -259,7 +259,7 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
loc, ocklAppendStringN,
ValueRange{printfDesc, stringStart, stringLen,
adaptor.args().empty() ? oneI32 : zeroI32});
- printfDesc = appendFormatCall.getResult(0);
+ printfDesc = appendFormatCall.getResult();
// __ockl_printf_append_args takes 7 values per append call
constexpr size_t argsPerAppend = 7;
@@ -293,7 +293,7 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
auto isLast = (bound == nArgs) ? oneI32 : zeroI32;
arguments.push_back(isLast);
auto call = rewriter.create<LLVM::CallOp>(loc, ocklAppendArgs, arguments);
- printfDesc = call.getResult(0);
+ printfDesc = call.getResult();
}
rewriter.eraseOp(gpuPrintfOp);
return success();
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 2704e1408f769..8756603acf3e6 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -482,7 +482,7 @@ LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
Type elementPtrType = this->getElementPtrType(memRefType);
auto stream = adaptor.asyncDependencies().front();
Value allocatedPtr =
- allocCallBuilder.create(loc, rewriter, {sizeBytes, stream}).getResult(0);
+ allocCallBuilder.create(loc, rewriter, {sizeBytes, stream}).getResult();
allocatedPtr =
rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, allocatedPtr);
@@ -539,7 +539,7 @@ LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
continue;
auto idx = operand.getOperandNumber();
auto stream = adaptor.getOperands()[idx];
- auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult(0);
+ auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
eventRecordCallBuilder.create(loc, rewriter, {event, stream});
newOperands[idx] = event;
streams.insert(stream);
@@ -612,8 +612,7 @@ LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
// into the stream just after the last use of the original token operand.
auto *defOp = std::get<0>(pair).getDefiningOp();
rewriter.setInsertionPointAfter(defOp);
- auto event =
- eventCreateCallBuilder.create(loc, rewriter, {}).getResult(0);
+ auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
eventRecordCallBuilder.create(loc, rewriter, {event, operand});
events.push_back(event);
} else {
@@ -623,7 +622,7 @@ LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
}
}
rewriter.restoreInsertionPoint(insertionPoint);
- auto stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult(0);
+ auto stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
for (auto event : events)
streamWaitEventCallBuilder.create(loc, rewriter, {stream, event});
for (auto event : events)
@@ -784,11 +783,11 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
launchOp.getKernelModuleName().getValue(),
launchOp.getKernelName().getValue(), loc, rewriter);
auto function = moduleGetFunctionCallBuilder.create(
- loc, rewriter, {module.getResult(0), kernelName});
+ loc, rewriter, {module.getResult(), kernelName});
auto zero = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type, 0);
Value stream =
adaptor.asyncDependencies().empty()
- ? streamCreateCallBuilder.create(loc, rewriter, {}).getResult(0)
+ ? streamCreateCallBuilder.create(loc, rewriter, {}).getResult()
: adaptor.asyncDependencies().front();
// Create array of pointers to kernel arguments.
auto kernelParams = generateParamsArray(launchOp, adaptor, rewriter);
@@ -798,7 +797,7 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
: zero;
launchKernelCallBuilder.create(
loc, rewriter,
- {function.getResult(0), adaptor.gridSizeX(), adaptor.gridSizeY(),
+ {function.getResult(), adaptor.gridSizeX(), adaptor.gridSizeY(),
adaptor.gridSizeZ(), adaptor.blockSizeX(), adaptor.blockSizeY(),
adaptor.blockSizeZ(), dynamicSharedMemorySize, stream, kernelParams,
/*extra=*/nullpointer});
@@ -814,7 +813,7 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
streamDestroyCallBuilder.create(loc, rewriter, stream);
rewriter.eraseOp(launchOp);
}
- moduleUnloadCallBuilder.create(loc, rewriter, module.getResult(0));
+ moduleUnloadCallBuilder.create(loc, rewriter, module.getResult());
return success();
}
diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
index 2cf16934acdc0..0168e6e016449 100644
--- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
@@ -60,17 +60,17 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
return failure();
LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op);
- auto callOp = rewriter.create<LLVM::CallOp>(
- op->getLoc(), resultType, SymbolRefAttr::get(funcOp), castedOperands);
+ auto callOp =
+ rewriter.create<LLVM::CallOp>(op->getLoc(), funcOp, castedOperands);
if (resultType == adaptor.getOperands().front().getType()) {
- rewriter.replaceOp(op, {callOp.getResult(0)});
+ rewriter.replaceOp(op, {callOp.getResult()});
return success();
}
Value truncated = rewriter.create<LLVM::FPTruncOp>(
op->getLoc(), adaptor.getOperands().front().getType(),
- callOp.getResult(0));
+ callOp.getResult());
rewriter.replaceOp(op, {truncated});
return success();
}
diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
index 09e3ec41a43f6..3242add0799d1 100644
--- a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
+++ b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
@@ -374,7 +374,7 @@ void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
loc, TypeRange{getPointerType()}, kInitVulkan);
// The result of `initVulkan` function is a pointer to Vulkan runtime, we
// need to pass that pointer to each Vulkan runtime call.
- auto vulkanRuntime = initVulkanCall.getResult(0);
+ auto vulkanRuntime = initVulkanCall.getResult();
// Create LLVM global with SPIR-V binary data, so we can pass a pointer with
// that data to runtime call.
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 15839695978b7..73bb64d278237 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -273,7 +273,7 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
Value memory =
toDynamic
? builder.create<LLVM::CallOp>(loc, mallocFunc, allocationSize)
- .getResult(0)
+ .getResult()
: builder.create<LLVM::AllocaOp>(loc, voidPtrType, allocationSize,
/*alignment=*/0);
Value source = desc.memRefDescPtr(builder, loc);
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index ba92ae9ebfb59..8b3348f039b28 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -71,10 +71,9 @@ struct AllocOpLowering : public AllocLikeOpLLVMLowering {
// descriptor.
Type elementPtrType = this->getElementPtrType(memRefType);
auto allocFuncOp = getAllocFn(allocOp->getParentOfType<ModuleOp>());
- auto results = createLLVMCall(rewriter, loc, allocFuncOp, {sizeBytes},
- getVoidPtrType());
- Value allocatedPtr =
- rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);
+ auto results = rewriter.create<LLVM::CallOp>(loc, allocFuncOp, sizeBytes);
+ Value allocatedPtr = rewriter.create<LLVM::BitcastOp>(loc, elementPtrType,
+ results.getResult());
Value alignedPtr = allocatedPtr;
if (alignment) {
@@ -168,11 +167,10 @@ struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
Type elementPtrType = this->getElementPtrType(memRefType);
auto allocFuncOp = getAllocFn(allocOp->getParentOfType<ModuleOp>());
- auto results =
- createLLVMCall(rewriter, loc, allocFuncOp, {allocAlignment, sizeBytes},
- getVoidPtrType());
- Value allocatedPtr =
- rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);
+ auto results = rewriter.create<LLVM::CallOp>(
+ loc, allocFuncOp, ValueRange({allocAlignment, sizeBytes}));
+ Value allocatedPtr = rewriter.create<LLVM::BitcastOp>(loc, elementPtrType,
+ results.getResult());
return std::make_tuple(allocatedPtr, allocatedPtr);
}
@@ -330,8 +328,7 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
Value casted = rewriter.create<LLVM::BitcastOp>(
op.getLoc(), getVoidPtrType(),
memref.allocatedPtr(rewriter, op.getLoc()));
- rewriter.replaceOpWithNewOp<LLVM::CallOp>(
- op, TypeRange(), SymbolRefAttr::get(freeFunc), casted);
+ rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFunc, casted);
return success();
}
};
diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
index 65e26aaed5672..409c513f69ddb 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -149,13 +149,3 @@ mlir::LLVM::lookupOrCreateMemRefCopyFn(ModuleOp moduleOp, Type indexType,
ArrayRef<Type>{indexType, unrankedDescriptorType, unrankedDescriptorType},
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
-
-Operation::result_range mlir::LLVM::createLLVMCall(OpBuilder &b, Location loc,
- LLVM::LLVMFuncOp fn,
- ValueRange paramTypes,
- ArrayRef<Type> resultTypes) {
- return b
- .create<LLVM::CallOp>(loc, resultTypes, SymbolRefAttr::get(fn),
- paramTypes)
- ->getResults();
-}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 9706e4b2e622e..7c4f1a4c827f4 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1145,9 +1145,28 @@ ParseResult LandingpadOp::parse(OpAsmParser &parser, OperationState &result) {
}
//===----------------------------------------------------------------------===//
-// Verifying/Printing/parsing for LLVM::CallOp.
+// CallOp
//===----------------------------------------------------------------------===//
+void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
+ StringRef callee, ValueRange args) {
+ build(builder, state, results, builder.getStringAttr(callee), args);
+}
+
+void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
+ StringAttr callee, ValueRange args) {
+ build(builder, state, results, SymbolRefAttr::get(callee), args, nullptr);
+}
+
+void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
+ ValueRange args) {
+ SmallVector<Type> results;
+ Type resultType = func.getFunctionType().getReturnType();
+ if (!resultType.isa<LLVM::LLVMVoidType>())
+ results.push_back(resultType);
+ build(builder, state, results, SymbolRefAttr::get(func), args, nullptr);
+}
+
CallInterfaceCallable CallOp::getCallableForCallee() {
// Direct call.
if (FlatSymbolRefAttr calleeAttr = getCalleeAttr())
@@ -1235,8 +1254,8 @@ LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return emitOpError()
<< "expected LLVM function call to produce 0 or 1 result";
- if (getNumResults() && getResult(0).getType() != funcType.getReturnType())
- return emitOpError() << "result type mismatch: " << getResult(0).getType()
+ if (getNumResults() && getResult().getType() != funcType.getReturnType())
+ return emitOpError() << "result type mismatch: " << getResult().getType()
<< " != " << funcType.getReturnType();
return success();
@@ -1608,34 +1627,33 @@ void InsertValueOp::build(OpBuilder &builder, OperationState &state,
}
//===----------------------------------------------------------------------===//
-// Printing, parsing and verification for LLVM::ReturnOp.
+// ReturnOp
//===----------------------------------------------------------------------===//
LogicalResult ReturnOp::verify() {
- if (getNumOperands() > 1)
- return emitOpError("expected at most 1 operand");
-
- if (auto parent = (*this)->getParentOfType<LLVMFuncOp>()) {
- Type expectedType = parent.getFunctionType().getReturnType();
- if (expectedType.isa<LLVMVoidType>()) {
- if (getNumOperands() == 0)
- return success();
- InFlightDiagnostic diag = emitOpError("expected no operands");
- diag.attachNote(parent->getLoc()) << "when returning from function";
- return diag;
- }
- if (getNumOperands() == 0) {
- if (expectedType.isa<LLVMVoidType>())
- return success();
- InFlightDiagnostic diag = emitOpError("expected 1 operand");
- diag.attachNote(parent->getLoc()) << "when returning from function";
- return diag;
- }
- if (expectedType != getOperand(0).getType()) {
- InFlightDiagnostic diag = emitOpError("mismatching result types");
- diag.attachNote(parent->getLoc()) << "when returning from function";
- return diag;
- }
+ auto parent = (*this)->getParentOfType<LLVMFuncOp>();
+ if (!parent)
+ return success();
+
+ Type expectedType = parent.getFunctionType().getReturnType();
+ if (expectedType.isa<LLVMVoidType>()) {
+ if (!getArg())
+ return success();
+ InFlightDiagnostic diag = emitOpError("expected no operands");
+ diag.attachNote(parent->getLoc()) << "when returning from function";
+ return diag;
+ }
+ if (!getArg()) {
+ if (expectedType.isa<LLVMVoidType>())
+ return success();
+ InFlightDiagnostic diag = emitOpError("expected 1 operand");
+ diag.attachNote(parent->getLoc()) << "when returning from function";
+ return diag;
+ }
+ if (expectedType != getArg().getType()) {
+ InFlightDiagnostic diag = emitOpError("mismatching result types");
+ diag.attachNote(parent->getLoc()) << "when returning from function";
+ return diag;
}
return success();
}
More information about the Mlir-commits mailing list