[Mlir-commits] [mlir] [MLIR][LLVMIR] Relax mismatching calls (PR #135895)
Bruno Cardoso Lopes
llvmlistbot at llvm.org
Tue Apr 15 18:41:45 PDT 2025
https://github.com/bcardosolopes updated https://github.com/llvm/llvm-project/pull/135895
>From 4992f780eeec2ccbc944de2a76beef38f5b72847 Mon Sep 17 00:00:00 2001
From: Bruno Cardoso Lopes <bruno.cardoso at gmail.com>
Date: Tue, 15 Apr 2025 18:11:33 -0700
Subject: [PATCH] [MLIR][LLVMIR] Relax mismatching calls
LLVM IR currently [accepts](https://godbolt.org/z/nqnEsW1ja):
```
define void @incompatible_call_and_callee_types() {
call void @callee(i64 0)
ret void
}
define void @callee({ptr, i64}, i32) {
ret void
}
```
This currently fails to import. Even though these constructs are dangerous and
probably indicate some ODR violation (or optimization bug), they are "valid"
and should be importable into the LLVM dialect. This PR implements that by
using an indirect call to represent them. Translating back to LLVM IR already
works and reproduces the original input file.
The error is now a warning. The tests in
`mlir/test/Target/LLVMIR/Import/import-failure.ll` already use `CHECK` lines,
so there is no need to add extra diagnostic tests.
---
.../include/mlir/Target/LLVMIR/ModuleImport.h | 9 ++-
mlir/lib/Target/LLVMIR/ModuleImport.cpp | 62 ++++++++++++++-----
.../test/Target/LLVMIR/Import/instructions.ll | 16 +++++
3 files changed, 68 insertions(+), 19 deletions(-)
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
index 3dc848c413905..7b01a96026413 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -362,9 +362,12 @@ class ModuleImport {
/// Converts the callee's function type. For direct calls, it converts the
/// actual function type, which may differ from the called operand type in
/// variadic functions. For indirect calls, it converts the function type
- /// associated with the call instruction. Returns failure when the call and
- /// the callee are not compatible or when nested type conversions failed.
- FailureOr<LLVMFunctionType> convertFunctionType(llvm::CallBase *callInst);
+ /// associated with the call instruction. When the call and the callee are not
+ /// compatible, emit a warning and return the call-site type; `castResult` is
+ /// then set to the callee's address so the caller can build an indirect call
+ /// instead. Returns failure when nested type conversions failed.
+ FailureOr<LLVMFunctionType> convertFunctionType(llvm::CallBase *callInst,
+ Value &castResult);
/// Returns the callee name, or an empty symbol if the call is not direct.
FlatSymbolRefAttr convertCalleeName(llvm::CallBase *callInst);
/// Converts the parameter and result attributes attached to `func` and adds
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 8e7a7ab95b6b6..7a790df75e66e 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1668,8 +1668,8 @@ ModuleImport::convertCallOperands(llvm::CallBase *callInst,
/// Checks if `callType` and `calleeType` are compatible and can be represented
/// in MLIR.
static LogicalResult
-verifyFunctionTypeCompatibility(LLVMFunctionType callType,
- LLVMFunctionType calleeType) {
+checkFunctionTypeCompatibility(LLVMFunctionType callType,
+ LLVMFunctionType calleeType) {
if (callType.getReturnType() != calleeType.getReturnType())
return failure();
@@ -1695,7 +1695,7 @@ verifyFunctionTypeCompatibility(LLVMFunctionType callType,
}
FailureOr<LLVMFunctionType>
-ModuleImport::convertFunctionType(llvm::CallBase *callInst) {
+ModuleImport::convertFunctionType(llvm::CallBase *callInst, Value &castResult) {
auto castOrFailure = [](Type convertedType) -> FailureOr<LLVMFunctionType> {
auto funcTy = dyn_cast_or_null<LLVMFunctionType>(convertedType);
if (!funcTy)
@@ -1718,11 +1718,17 @@ ModuleImport::convertFunctionType(llvm::CallBase *callInst) {
if (failed(calleeType))
return failure();
- // Compare the types to avoid constructing illegal call/invoke operations.
- if (failed(verifyFunctionTypeCompatibility(*callType, *calleeType))) {
+ // Compare the types. If they are not compatible, avoid illegal call/invoke
+ // operations by taking the callee's address and issuing an indirect call
+ // with the call-site type. LLVM IR currently supports this usage.
+ if (failed(checkFunctionTypeCompatibility(*callType, *calleeType))) {
Location loc = translateLoc(callInst->getDebugLoc());
- return emitError(loc) << "incompatible call and callee types: " << *callType
- << " and " << *calleeType;
+ FlatSymbolRefAttr calleeSym = convertCalleeName(callInst);
+ castResult = builder.create<LLVM::AddressOfOp>(
+ loc, LLVM::LLVMPointerType::get(context), calleeSym);
+ emitWarning(loc) << "incompatible call and callee types: " << *callType
+ << " and " << *calleeType;
+ return callType;
}
return calleeType;
@@ -1839,16 +1845,29 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
/*operand_attrs=*/nullptr)
.getOperation();
}
- FailureOr<LLVMFunctionType> funcTy = convertFunctionType(callInst);
+ Value castResult;
+ FailureOr<LLVMFunctionType> funcTy =
+ convertFunctionType(callInst, castResult);
if (failed(funcTy))
return failure();
- FlatSymbolRefAttr callee = convertCalleeName(callInst);
- auto callOp = builder.create<CallOp>(loc, *funcTy, callee, *operands);
+ FlatSymbolRefAttr callee = nullptr;
+ // If no cast is needed, use the original callee name. Otherwise patch
+ // operands to include the indirect call target. Build an indirect call by
+ // passing a null `callee`.
+ if (!castResult)
+ callee = convertCalleeName(callInst);
+ else
+ operands->insert(operands->begin(), castResult);
+ CallOp callOp = builder.create<CallOp>(loc, *funcTy, callee, *operands);
+
if (failed(convertCallAttributes(callInst, callOp)))
return failure();
- // Handle parameter and result attributes.
- convertParameterAttributes(callInst, callOp, builder);
+
+ // Handle parameter and result attributes. Don't bother if there's a
+ // type mismatch.
+ if (!castResult)
+ convertParameterAttributes(callInst, callOp, builder);
return callOp.getOperation();
}();
@@ -1913,11 +1932,20 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
unwindArgs)))
return failure();
- FailureOr<LLVMFunctionType> funcTy = convertFunctionType(invokeInst);
+ Value castResult;
+ FailureOr<LLVMFunctionType> funcTy =
+ convertFunctionType(invokeInst, castResult);
if (failed(funcTy))
return failure();
- FlatSymbolRefAttr calleeName = convertCalleeName(invokeInst);
+ FlatSymbolRefAttr calleeName = nullptr;
+ // If no cast is needed, use the original callee name. Otherwise patch
+ // operands to include the indirect call target. Build an indirect call by
+ // passing a null `callee`.
+ if (!castResult)
+ calleeName = convertCalleeName(invokeInst);
+ else
+ operands->insert(operands->begin(), castResult);
// Create the invoke operation. Normal destination block arguments will be
// added later on to handle the case in which the operation result is
@@ -1929,8 +1957,10 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
if (failed(convertInvokeAttributes(invokeInst, invokeOp)))
return failure();
- // Handle parameter and result attributes.
- convertParameterAttributes(invokeInst, invokeOp, builder);
+ // Handle parameter and result attributes. Don't bother if there's a
+ // type mismatch.
+ if (!castResult)
+ convertParameterAttributes(invokeInst, invokeOp, builder);
if (!invokeInst->getType()->isVoidTy())
mapValue(inst, invokeOp.getResults().front());
diff --git a/mlir/test/Target/LLVMIR/Import/instructions.ll b/mlir/test/Target/LLVMIR/Import/instructions.ll
index c294e1b34f9bb..ef3c6430b7152 100644
--- a/mlir/test/Target/LLVMIR/Import/instructions.ll
+++ b/mlir/test/Target/LLVMIR/Import/instructions.ll
@@ -720,3 +720,19 @@ bb2:
declare void @g(...)
declare i32 @__gxx_personality_v0(...)
+
+; // -----
+
+; CHECK-LABEL: llvm.func @incompatible_call_and_callee_types
+define void @incompatible_call_and_callee_types() {
+ ; CHECK: %[[CST:.*]] = llvm.mlir.constant(0 : i64) : i64
+ ; CHECK: %[[TARGET:.*]] = llvm.mlir.addressof @callee : !llvm.ptr
+ ; CHECK: llvm.call %[[TARGET]](%[[CST]]) : !llvm.ptr, (i64) -> ()
+ call void @callee(i64 0)
+ ; CHECK: llvm.return
+ ret void
+}
+
+define void @callee({ptr, i64}, i32) {
+ ret void
+}
More information about the Mlir-commits
mailing list