[Mlir-commits] [mlir] 8cdc16d - [MLIR][LLVM] Avoid importing broken calls and invokes (#125041)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jan 30 05:06:32 PST 2025


Author: Christian Ulmann
Date: 2025-01-30T14:06:28+01:00
New Revision: 8cdc16d350c9f162759689402db841685e9534b1

URL: https://github.com/llvm/llvm-project/commit/8cdc16d350c9f162759689402db841685e9534b1
DIFF: https://github.com/llvm/llvm-project/commit/8cdc16d350c9f162759689402db841685e9534b1.diff

LOG: [MLIR][LLVM] Avoid importing broken calls and invokes (#125041)

This commit adds a check to catch calls/invokes that do not satisfy the
type constraints of their callee. The LLVM IR verifier does not check this,
but such calls are undefined behavior. Importing them into MLIR would lead
to verification errors, so we reject them early on.

Added: 
    

Modified: 
    mlir/include/mlir/Target/LLVMIR/ModuleImport.h
    mlir/lib/Target/LLVMIR/ModuleImport.cpp
    mlir/test/Target/LLVMIR/Import/import-failure.ll

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
index 3c1221e20afbc5..84aecbd4373e05 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -326,8 +326,9 @@ class ModuleImport {
   /// Converts the callee's function type. For direct calls, it converts the
  /// actual function type, which may differ from the called operand type in
   /// variadic functions. For indirect calls, it converts the function type
-  /// associated with the call instruction.
-  LLVMFunctionType convertFunctionType(llvm::CallBase *callInst);
+  /// associated with the call instruction. Returns failure when the call and
+  /// the callee are not compatible or when nested type conversions failed.
+  FailureOr<LLVMFunctionType> convertFunctionType(llvm::CallBase *callInst);
   /// Returns the callee name, or an empty symbol if the call is not direct.
   FlatSymbolRefAttr convertCalleeName(llvm::CallBase *callInst);
   /// Converts the parameter attributes attached to `func` and adds them to

diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 1d1a985c46fb5b..e23ffdedd9a60c 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1519,22 +1519,72 @@ ModuleImport::convertCallOperands(llvm::CallBase *callInst,
   return operands;
 }
 
-LLVMFunctionType ModuleImport::convertFunctionType(llvm::CallBase *callInst) {
-  llvm::Value *calledOperand = callInst->getCalledOperand();
-  Type converted = [&] {
-    if (auto callee = dyn_cast<llvm::Function>(calledOperand))
-      return convertType(callee->getFunctionType());
-    return convertType(callInst->getFunctionType());
-  }();
+/// Checks if `callType` and `calleeType` are compatible and can be represented
+/// in MLIR.
+static LogicalResult
+verifyFunctionTypeCompatibility(LLVMFunctionType callType,
+                                LLVMFunctionType calleeType) {
+  if (callType.getReturnType() != calleeType.getReturnType())
+    return failure();
+
+  if (calleeType.isVarArg()) {
+    // For variadic functions, the call can have more types than the callee
+    // specifies.
+    if (callType.getNumParams() < calleeType.getNumParams())
+      return failure();
+  } else {
+    // For non-variadic functions, the number of parameters needs to be the
+    // same.
+    if (callType.getNumParams() != calleeType.getNumParams())
+      return failure();
+  }
+
+  // Check that all operands match.
+  for (auto [operandType, argumentType] :
+       llvm::zip(callType.getParams(), calleeType.getParams()))
+    if (operandType != argumentType)
+      return failure();
+
+  return success();
+}
 
-  if (auto funcTy = dyn_cast_or_null<LLVMFunctionType>(converted))
+FailureOr<LLVMFunctionType>
+ModuleImport::convertFunctionType(llvm::CallBase *callInst) {
+  auto castOrFailure = [](Type convertedType) -> FailureOr<LLVMFunctionType> {
+    auto funcTy = dyn_cast_or_null<LLVMFunctionType>(convertedType);
+    if (!funcTy)
+      return failure();
     return funcTy;
-  return {};
+  };
+
+  llvm::Value *calledOperand = callInst->getCalledOperand();
+  FailureOr<LLVMFunctionType> callType =
+      castOrFailure(convertType(callInst->getFunctionType()));
+  if (failed(callType))
+    return failure();
+  auto *callee = dyn_cast<llvm::Function>(calledOperand);
+  // For indirect calls, return the type of the call itself.
+  if (!callee)
+    return callType;
+
+  FailureOr<LLVMFunctionType> calleeType =
+      castOrFailure(convertType(callee->getFunctionType()));
+  if (failed(calleeType))
+    return failure();
+
+  // Compare the types to avoid constructing illegal call/invoke operations.
+  if (failed(verifyFunctionTypeCompatibility(*callType, *calleeType))) {
+    Location loc = translateLoc(callInst->getDebugLoc());
+    return emitError(loc) << "incompatible call and callee types: " << *callType
+                          << " and " << *calleeType;
+  }
+
+  return calleeType;
 }
 
 FlatSymbolRefAttr ModuleImport::convertCalleeName(llvm::CallBase *callInst) {
   llvm::Value *calledOperand = callInst->getCalledOperand();
-  if (auto callee = dyn_cast<llvm::Function>(calledOperand))
+  if (auto *callee = dyn_cast<llvm::Function>(calledOperand))
     return SymbolRefAttr::get(context, callee->getName());
   return {};
 }
@@ -1620,7 +1670,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
     return success();
   }
   if (inst->getOpcode() == llvm::Instruction::Call) {
-    auto callInst = cast<llvm::CallInst>(inst);
+    auto *callInst = cast<llvm::CallInst>(inst);
     llvm::Value *calledOperand = callInst->getCalledOperand();
 
     FailureOr<SmallVector<Value>> operands =
@@ -1629,7 +1679,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
       return failure();
 
     auto callOp = [&]() -> FailureOr<Operation *> {
-      if (auto asmI = dyn_cast<llvm::InlineAsm>(calledOperand)) {
+      if (auto *asmI = dyn_cast<llvm::InlineAsm>(calledOperand)) {
         Type resultTy = convertType(callInst->getType());
         if (!resultTy)
           return failure();
@@ -1642,17 +1692,16 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
                 /*is_align_stack=*/false, /*asm_dialect=*/nullptr,
                 /*operand_attrs=*/nullptr)
             .getOperation();
-      } else {
-        LLVMFunctionType funcTy = convertFunctionType(callInst);
-        if (!funcTy)
-          return failure();
-
-        FlatSymbolRefAttr callee = convertCalleeName(callInst);
-        auto callOp = builder.create<CallOp>(loc, funcTy, callee, *operands);
-        if (failed(convertCallAttributes(callInst, callOp)))
-          return failure();
-        return callOp.getOperation();
       }
+      FailureOr<LLVMFunctionType> funcTy = convertFunctionType(callInst);
+      if (failed(funcTy))
+        return failure();
+
+      FlatSymbolRefAttr callee = convertCalleeName(callInst);
+      auto callOp = builder.create<CallOp>(loc, *funcTy, callee, *operands);
+      if (failed(convertCallAttributes(callInst, callOp)))
+        return failure();
+      return callOp.getOperation();
     }();
 
     if (failed(callOp))
@@ -1716,8 +1765,8 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
                                  unwindArgs)))
       return failure();
 
-    auto funcTy = convertFunctionType(invokeInst);
-    if (!funcTy)
+    FailureOr<LLVMFunctionType> funcTy = convertFunctionType(invokeInst);
+    if (failed(funcTy))
       return failure();
 
     FlatSymbolRefAttr calleeName = convertCalleeName(invokeInst);
@@ -1726,7 +1775,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
     // added later on to handle the case in which the operation result is
     // included in this list.
     auto invokeOp = builder.create<InvokeOp>(
-        loc, funcTy, calleeName, *operands, directNormalDest, ValueRange(),
+        loc, *funcTy, calleeName, *operands, directNormalDest, ValueRange(),
         lookupBlock(invokeInst->getUnwindDest()), unwindArgs);
 
     if (failed(convertInvokeAttributes(invokeInst, invokeOp)))

diff --git a/mlir/test/Target/LLVMIR/Import/import-failure.ll b/mlir/test/Target/LLVMIR/Import/import-failure.ll
index b616cb81e0a8a5..d929a592847622 100644
--- a/mlir/test/Target/LLVMIR/Import/import-failure.ll
+++ b/mlir/test/Target/LLVMIR/Import/import-failure.ll
@@ -1,4 +1,4 @@
-; RUN: not mlir-translate -import-llvm -emit-expensive-warnings -split-input-file %s 2>&1 | FileCheck %s
+; RUN: not mlir-translate -import-llvm -emit-expensive-warnings -split-input-file %s 2>&1 -o /dev/null | FileCheck %s
 
 ; CHECK:      <unknown>
 ; CHECK-SAME: error: unhandled instruction: indirectbr ptr %dst, [label %bb1, label %bb2]
@@ -353,3 +353,32 @@ declare void @llvm.experimental.noalias.scope.decl(metadata)
 ; CHECK:      import-failure.ll
 ; CHECK-SAME: warning: unhandled data layout token: ni:42
 target datalayout = "e-ni:42-i64:64"
+
+; // -----
+
+; CHECK:      <unknown>
+; CHECK-SAME: incompatible call and callee types: '!llvm.func<void (i64)>' and '!llvm.func<void (ptr)>'
+define void @incompatible_call_and_callee_types() {
+  call void @callee(i64 0)
+  ret void
+}
+
+declare void @callee(ptr)
+
+; // -----
+
+; CHECK:      <unknown>
+; CHECK-SAME: incompatible call and callee types: '!llvm.func<void ()>' and '!llvm.func<i32 ()>'
+define void @f() personality ptr @__gxx_personality_v0 {
+entry:
+  invoke void @g() to label %bb1 unwind label %bb2
+bb1:
+  ret void
+bb2:
+  %0 = landingpad i32 cleanup
+  unreachable
+}
+
+declare i32 @g()
+
+declare i32 @__gxx_personality_v0(...)


        


More information about the Mlir-commits mailing list