[Mlir-commits] [mlir] 95d993a - [MLIR] Fix import of calls with mismatched variadic types (#124286)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jan 24 11:28:42 PST 2025
Author: Henrich Lauko
Date: 2025-01-24T20:28:36+01:00
New Revision: 95d993a838863269dc1b90de3808c1e40ac6d5f2
URL: https://github.com/llvm/llvm-project/commit/95d993a838863269dc1b90de3808c1e40ac6d5f2
DIFF: https://github.com/llvm/llvm-project/commit/95d993a838863269dc1b90de3808c1e40ac6d5f2.diff
LOG: [MLIR] Fix import of calls with mismatched variadic types (#124286)
Previously, an indirect call was incorrectly generated when
`llvm::CallBase::getCalledFunction` returned null due to a type mismatch
between the call and the function. This patch updates the code to use
`llvm::CallBase::getCalledOperand` instead.
Added:
Modified:
mlir/lib/Target/LLVMIR/ModuleImport.cpp
mlir/test/Target/LLVMIR/Import/instructions.ll
Removed:
################################################################################
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index f6826a2362bfdf..40d86efe605ad0 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1495,15 +1495,22 @@ LogicalResult ModuleImport::convertCallTypeAndOperands(
if (!callInst->getType()->isVoidTy())
types.push_back(convertType(callInst->getType()));
- if (!callInst->getCalledFunction()) {
- if (!allowInlineAsm ||
- !isa<llvm::InlineAsm>(callInst->getCalledOperand())) {
- FailureOr<Value> called = convertValue(callInst->getCalledOperand());
- if (failed(called))
- return failure();
- operands.push_back(*called);
- }
+ bool isInlineAsm = callInst->isInlineAsm();
+ if (isInlineAsm && !allowInlineAsm)
+ return failure();
+
+ // Cannot use isIndirectCall() here because we need to handle Constant callees
+ // that are not considered indirect calls by LLVM. However, in MLIR, they are
+ // treated as indirect calls to constant operands that need to be converted.
+ // Skip the callee operand if it's inline assembly, as it's handled separately
+ // in InlineAsmOp.
+ if (!isa<llvm::Function>(callInst->getCalledOperand()) && !isInlineAsm) {
+ FailureOr<Value> called = convertValue(callInst->getCalledOperand());
+ if (failed(called))
+ return failure();
+ operands.push_back(*called);
}
+
SmallVector<llvm::Value *> args(callInst->args());
FailureOr<SmallVector<Value>> arguments = convertValues(args);
if (failed(arguments))
@@ -1593,7 +1600,8 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
return success();
}
if (inst->getOpcode() == llvm::Instruction::Call) {
- auto *callInst = cast<llvm::CallInst>(inst);
+ auto callInst = cast<llvm::CallInst>(inst);
+ llvm::Value *calledOperand = callInst->getCalledOperand();
SmallVector<Type> types;
SmallVector<Value> operands;
@@ -1601,15 +1609,12 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
/*allowInlineAsm=*/true)))
return failure();
- auto funcTy =
- dyn_cast<LLVMFunctionType>(convertType(callInst->getFunctionType()));
- if (!funcTy)
- return failure();
-
- if (auto asmI = dyn_cast<llvm::InlineAsm>(callInst->getCalledOperand())) {
+ if (auto asmI = dyn_cast<llvm::InlineAsm>(calledOperand)) {
+ Type resultTy = convertType(callInst->getType());
+ if (!resultTy)
+ return failure();
auto callOp = builder.create<InlineAsmOp>(
- loc, funcTy.getReturnType(), operands,
- builder.getStringAttr(asmI->getAsmString()),
+ loc, resultTy, operands, builder.getStringAttr(asmI->getAsmString()),
builder.getStringAttr(asmI->getConstraintString()),
/*has_side_effects=*/true,
/*is_align_stack=*/false, /*asm_dialect=*/nullptr,
@@ -1619,27 +1624,35 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
else
mapNoResultOp(inst, callOp);
} else {
- CallOp callOp;
+ auto funcTy = dyn_cast<LLVMFunctionType>([&]() -> Type {
+ // Retrieve the real function type. For direct calls, use the callee's
+ // function type, as it may differ from the operand type in the case of
+ // variadic functions. For indirect calls, use the call function type.
+ if (auto callee = dyn_cast<llvm::Function>(calledOperand))
+ return convertType(callee->getFunctionType());
+ return convertType(callInst->getFunctionType());
+ }());
+
+ if (!funcTy)
+ return failure();
- if (llvm::Function *callee = callInst->getCalledFunction()) {
- callOp = builder.create<CallOp>(
- loc, funcTy, SymbolRefAttr::get(context, callee->getName()),
- operands);
- } else {
- callOp = builder.create<CallOp>(loc, funcTy, operands);
- }
+ auto callOp = [&]() -> CallOp {
+ if (auto callee = dyn_cast<llvm::Function>(calledOperand)) {
+ auto name = SymbolRefAttr::get(context, callee->getName());
+ return builder.create<CallOp>(loc, funcTy, name, operands);
+ }
+ return builder.create<CallOp>(loc, funcTy, operands);
+ }();
+
+ // Handle function attributes.
callOp.setCConv(convertCConvFromLLVM(callInst->getCallingConv()));
callOp.setTailCallKind(
convertTailCallKindFromLLVM(callInst->getTailCallKind()));
setFastmathFlagsAttr(inst, callOp);
- // Handle function attributes.
- if (callInst->hasFnAttr(llvm::Attribute::Convergent))
- callOp.setConvergent(true);
- if (callInst->hasFnAttr(llvm::Attribute::NoUnwind))
- callOp.setNoUnwind(true);
- if (callInst->hasFnAttr(llvm::Attribute::WillReturn))
- callOp.setWillReturn(true);
+ callOp.setConvergent(callInst->isConvergent());
+ callOp.setNoUnwind(callInst->doesNotThrow());
+ callOp.setWillReturn(callInst->hasFnAttr(llvm::Attribute::WillReturn));
llvm::MemoryEffects memEffects = callInst->getMemoryEffects();
ModRefInfo othermem = convertModRefInfoFromLLVM(
diff --git a/mlir/test/Target/LLVMIR/Import/instructions.ll b/mlir/test/Target/LLVMIR/Import/instructions.ll
index 7377e2584110b5..77052ab6e41f6c 100644
--- a/mlir/test/Target/LLVMIR/Import/instructions.ll
+++ b/mlir/test/Target/LLVMIR/Import/instructions.ll
@@ -570,6 +570,31 @@ define void @varargs_call(i32 %0) {
; // -----
+; CHECK: @varargs(...)
+declare void @varargs(...)
+
+; CHECK-LABEL: @varargs_call
+; CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+define void @varargs_call(i32 %0) {
+ ; CHECK: llvm.call @varargs(%[[ARG1]]) vararg(!llvm.func<void (...)>) : (i32) -> ()
+ call void @varargs(i32 %0)
+ ret void
+}
+
+; // -----
+
+; CHECK: @varargs(...)
+declare void @varargs(...)
+
+; CHECK-LABEL: @empty_varargs_call
+define void @empty_varargs_call() {
+ ; CHECK: llvm.call @varargs() vararg(!llvm.func<void (...)>) : () -> ()
+ call void @varargs()
+ ret void
+}
+
+; // -----
+
; CHECK: llvm.func @f()
declare void @f()
More information about the Mlir-commits mailing list