[Mlir-commits] [mlir] [MLIR][LLVM] Avoid importing broken calls and invokes (PR #125041)

Christian Ulmann llvmlistbot at llvm.org
Thu Jan 30 03:53:21 PST 2025


https://github.com/Dinistro updated https://github.com/llvm/llvm-project/pull/125041

>From e4063f9c87462aa5b9fea79911e4ce8924dd390a Mon Sep 17 00:00:00 2001
From: Christian Ulmann <christian.ulmann at nextsilicon.com>
Date: Thu, 30 Jan 2025 09:19:34 +0000
Subject: [PATCH] [MLIR][LLVM] Avoid importing broken calls and invokes

This commit adds a check to catch calls/invokes whose function type does
not match the return or parameter types of their callee. Such a mismatch
is not verified in LLVM IR but is considered UB. Importing it into MLIR
leads to verification errors, so the import now fails early instead.
---
 .../include/mlir/Target/LLVMIR/ModuleImport.h |  5 +-
 mlir/lib/Target/LLVMIR/ModuleImport.cpp       | 99 ++++++++++++++-----
 .../Target/LLVMIR/Import/import-failure.ll    | 31 +++++-
 3 files changed, 107 insertions(+), 28 deletions(-)

diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
index 3c1221e20afbc5c..84aecbd4373e05e 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -326,8 +326,9 @@ class ModuleImport {
   /// Converts the callee's function type. For direct calls, it converts the
   /// actual function type, which may differ from the called operand type in
   /// variadic functions. For indirect calls, it converts the function type
-  /// associated with the call instruction.
-  LLVMFunctionType convertFunctionType(llvm::CallBase *callInst);
+  /// associated with the call instruction. Returns failure when the call and
+  /// the callee are not compatible or when nested type conversions failed.
+  FailureOr<LLVMFunctionType> convertFunctionType(llvm::CallBase *callInst);
   /// Returns the callee name, or an empty symbol if the call is not direct.
   FlatSymbolRefAttr convertCalleeName(llvm::CallBase *callInst);
   /// Converts the parameter attributes attached to `func` and adds them to
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 1d1a985c46fb5b9..e23ffdedd9a60cf 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1519,22 +1519,72 @@ ModuleImport::convertCallOperands(llvm::CallBase *callInst,
   return operands;
 }
 
-LLVMFunctionType ModuleImport::convertFunctionType(llvm::CallBase *callInst) {
-  llvm::Value *calledOperand = callInst->getCalledOperand();
-  Type converted = [&] {
-    if (auto callee = dyn_cast<llvm::Function>(calledOperand))
-      return convertType(callee->getFunctionType());
-    return convertType(callInst->getFunctionType());
-  }();
+/// Checks if `callType` and `calleeType` are compatible and can be represented
+/// in MLIR: same return type and matching types for all callee parameters.
+static LogicalResult
+verifyFunctionTypeCompatibility(LLVMFunctionType callType,
+                                LLVMFunctionType calleeType) {
+  if (callType.getReturnType() != calleeType.getReturnType())
+    return failure();
+
+  if (calleeType.isVarArg()) {
+    // For variadic functions, the call may pass more parameters than the
+    // callee declares, but never fewer.
+    if (callType.getNumParams() < calleeType.getNumParams())
+      return failure();
+  } else {
+    // For non-variadic functions, the number of parameters needs to be the
+    // same.
+    if (callType.getNumParams() != calleeType.getNumParams())
+      return failure();
+  }
+
+  // Check that each declared callee parameter type matches the call's.
+  for (auto [operandType, argumentType] :
+       llvm::zip(callType.getParams(), calleeType.getParams()))
+    if (operandType != argumentType)
+      return failure();
+
+  return success();
+}
 
-  if (auto funcTy = dyn_cast_or_null<LLVMFunctionType>(converted))
+FailureOr<LLVMFunctionType>
+ModuleImport::convertFunctionType(llvm::CallBase *callInst) {
+  auto castOrFailure = [](Type convertedType) -> FailureOr<LLVMFunctionType> {
+    auto funcTy = dyn_cast_or_null<LLVMFunctionType>(convertedType);
+    if (!funcTy)
+      return failure();
     return funcTy;
-  return {};
+  };
+
+  llvm::Value *calledOperand = callInst->getCalledOperand();
+  FailureOr<LLVMFunctionType> callType =
+      castOrFailure(convertType(callInst->getFunctionType()));
+  if (failed(callType))
+    return failure();
+  auto *callee = dyn_cast<llvm::Function>(calledOperand);
+  // For indirect calls, return the type of the call itself.
+  if (!callee)
+    return callType;
+
+  FailureOr<LLVMFunctionType> calleeType =
+      castOrFailure(convertType(callee->getFunctionType()));
+  if (failed(calleeType))
+    return failure();
+
+  // Compare the types to avoid constructing illegal call/invoke operations.
+  if (failed(verifyFunctionTypeCompatibility(*callType, *calleeType))) {
+    Location loc = translateLoc(callInst->getDebugLoc());
+    return emitError(loc) << "incompatible call and callee types: " << *callType
+                          << " and " << *calleeType;
+  }
+
+  return calleeType;
 }
 
 FlatSymbolRefAttr ModuleImport::convertCalleeName(llvm::CallBase *callInst) {
   llvm::Value *calledOperand = callInst->getCalledOperand();
-  if (auto callee = dyn_cast<llvm::Function>(calledOperand))
+  if (auto *callee = dyn_cast<llvm::Function>(calledOperand))
     return SymbolRefAttr::get(context, callee->getName());
   return {};
 }
@@ -1620,7 +1670,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
     return success();
   }
   if (inst->getOpcode() == llvm::Instruction::Call) {
-    auto callInst = cast<llvm::CallInst>(inst);
+    auto *callInst = cast<llvm::CallInst>(inst);
     llvm::Value *calledOperand = callInst->getCalledOperand();
 
     FailureOr<SmallVector<Value>> operands =
@@ -1629,7 +1679,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
       return failure();
 
     auto callOp = [&]() -> FailureOr<Operation *> {
-      if (auto asmI = dyn_cast<llvm::InlineAsm>(calledOperand)) {
+      if (auto *asmI = dyn_cast<llvm::InlineAsm>(calledOperand)) {
         Type resultTy = convertType(callInst->getType());
         if (!resultTy)
           return failure();
@@ -1642,17 +1692,16 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
                 /*is_align_stack=*/false, /*asm_dialect=*/nullptr,
                 /*operand_attrs=*/nullptr)
             .getOperation();
-      } else {
-        LLVMFunctionType funcTy = convertFunctionType(callInst);
-        if (!funcTy)
-          return failure();
-
-        FlatSymbolRefAttr callee = convertCalleeName(callInst);
-        auto callOp = builder.create<CallOp>(loc, funcTy, callee, *operands);
-        if (failed(convertCallAttributes(callInst, callOp)))
-          return failure();
-        return callOp.getOperation();
       }
+      FailureOr<LLVMFunctionType> funcTy = convertFunctionType(callInst);
+      if (failed(funcTy))
+        return failure();
+
+      FlatSymbolRefAttr callee = convertCalleeName(callInst);
+      auto callOp = builder.create<CallOp>(loc, *funcTy, callee, *operands);
+      if (failed(convertCallAttributes(callInst, callOp)))
+        return failure();
+      return callOp.getOperation();
     }();
 
     if (failed(callOp))
@@ -1716,8 +1765,8 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
                                  unwindArgs)))
       return failure();
 
-    auto funcTy = convertFunctionType(invokeInst);
-    if (!funcTy)
+    FailureOr<LLVMFunctionType> funcTy = convertFunctionType(invokeInst);
+    if (failed(funcTy))
       return failure();
 
     FlatSymbolRefAttr calleeName = convertCalleeName(invokeInst);
@@ -1726,7 +1775,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
     // added later on to handle the case in which the operation result is
     // included in this list.
     auto invokeOp = builder.create<InvokeOp>(
-        loc, funcTy, calleeName, *operands, directNormalDest, ValueRange(),
+        loc, *funcTy, calleeName, *operands, directNormalDest, ValueRange(),
         lookupBlock(invokeInst->getUnwindDest()), unwindArgs);
 
     if (failed(convertInvokeAttributes(invokeInst, invokeOp)))
diff --git a/mlir/test/Target/LLVMIR/Import/import-failure.ll b/mlir/test/Target/LLVMIR/Import/import-failure.ll
index b616cb81e0a8a5c..d929a5928476223 100644
--- a/mlir/test/Target/LLVMIR/Import/import-failure.ll
+++ b/mlir/test/Target/LLVMIR/Import/import-failure.ll
@@ -1,4 +1,4 @@
-; RUN: not mlir-translate -import-llvm -emit-expensive-warnings -split-input-file %s 2>&1 | FileCheck %s
+; RUN: not mlir-translate -import-llvm -emit-expensive-warnings -split-input-file %s 2>&1 -o /dev/null | FileCheck %s
 
 ; CHECK:      <unknown>
 ; CHECK-SAME: error: unhandled instruction: indirectbr ptr %dst, [label %bb1, label %bb2]
@@ -353,3 +353,32 @@ declare void @llvm.experimental.noalias.scope.decl(metadata)
 ; CHECK:      import-failure.ll
 ; CHECK-SAME: warning: unhandled data layout token: ni:42
 target datalayout = "e-ni:42-i64:64"
+
+; // -----
+
+; CHECK:      <unknown>
+; CHECK-SAME: incompatible call and callee types: '!llvm.func<void (i64)>' and '!llvm.func<void (ptr)>'
+define void @incompatible_call_and_callee_types() {
+  call void @callee(i64 0)
+  ret void
+}
+
+declare void @callee(ptr)
+
+; // -----
+
+; CHECK:      <unknown>
+; CHECK-SAME: incompatible call and callee types: '!llvm.func<void ()>' and '!llvm.func<i32 ()>'
+define void @f() personality ptr @__gxx_personality_v0 {
+entry:
+  invoke void @g() to label %bb1 unwind label %bb2
+bb1:
+  ret void
+bb2:
+  %0 = landingpad i32 cleanup
+  unreachable
+}
+
+declare i32 @g()
+
+declare i32 @__gxx_personality_v0(...)



More information about the Mlir-commits mailing list