[Mlir-commits] [mlir] 95d993a - [MLIR] Fix import of calls with mismatched variadic types (#124286)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 24 11:28:42 PST 2025


Author: Henrich Lauko
Date: 2025-01-24T20:28:36+01:00
New Revision: 95d993a838863269dc1b90de3808c1e40ac6d5f2

URL: https://github.com/llvm/llvm-project/commit/95d993a838863269dc1b90de3808c1e40ac6d5f2
DIFF: https://github.com/llvm/llvm-project/commit/95d993a838863269dc1b90de3808c1e40ac6d5f2.diff

LOG: [MLIR] Fix import of calls with mismatched variadic types (#124286)

Previously, an indirect call was incorrectly generated when
`llvm::CallBase::getCalledFunction` returned null due to a type mismatch
between the call and the function. This patch updates the code to use
`llvm::CallBase::getCalledOperand` instead.

Added: 
    

Modified: 
    mlir/lib/Target/LLVMIR/ModuleImport.cpp
    mlir/test/Target/LLVMIR/Import/instructions.ll

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index f6826a2362bfdf..40d86efe605ad0 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1495,15 +1495,22 @@ LogicalResult ModuleImport::convertCallTypeAndOperands(
   if (!callInst->getType()->isVoidTy())
     types.push_back(convertType(callInst->getType()));
 
-  if (!callInst->getCalledFunction()) {
-    if (!allowInlineAsm ||
-        !isa<llvm::InlineAsm>(callInst->getCalledOperand())) {
-      FailureOr<Value> called = convertValue(callInst->getCalledOperand());
-      if (failed(called))
-        return failure();
-      operands.push_back(*called);
-    }
+  bool isInlineAsm = callInst->isInlineAsm();
+  if (isInlineAsm && !allowInlineAsm)
+    return failure();
+
+  // Cannot use isIndirectCall() here because we need to handle Constant callees
+  // that are not considered indirect calls by LLVM. However, in MLIR, they are
+  // treated as indirect calls to constant operands that need to be converted.
+  // Skip the callee operand if it's inline assembly, as it's handled separately
+  // in InlineAsmOp.
+  if (!isa<llvm::Function>(callInst->getCalledOperand()) && !isInlineAsm) {
+    FailureOr<Value> called = convertValue(callInst->getCalledOperand());
+    if (failed(called))
+      return failure();
+    operands.push_back(*called);
   }
+
   SmallVector<llvm::Value *> args(callInst->args());
   FailureOr<SmallVector<Value>> arguments = convertValues(args);
   if (failed(arguments))
@@ -1593,7 +1600,8 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
     return success();
   }
   if (inst->getOpcode() == llvm::Instruction::Call) {
-    auto *callInst = cast<llvm::CallInst>(inst);
+    auto callInst = cast<llvm::CallInst>(inst);
+    llvm::Value *calledOperand = callInst->getCalledOperand();
 
     SmallVector<Type> types;
     SmallVector<Value> operands;
@@ -1601,15 +1609,12 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
                                           /*allowInlineAsm=*/true)))
       return failure();
 
-    auto funcTy =
-        dyn_cast<LLVMFunctionType>(convertType(callInst->getFunctionType()));
-    if (!funcTy)
-      return failure();
-
-    if (auto asmI = dyn_cast<llvm::InlineAsm>(callInst->getCalledOperand())) {
+    if (auto asmI = dyn_cast<llvm::InlineAsm>(calledOperand)) {
+      Type resultTy = convertType(callInst->getType());
+      if (!resultTy)
+        return failure();
       auto callOp = builder.create<InlineAsmOp>(
-          loc, funcTy.getReturnType(), operands,
-          builder.getStringAttr(asmI->getAsmString()),
+          loc, resultTy, operands, builder.getStringAttr(asmI->getAsmString()),
           builder.getStringAttr(asmI->getConstraintString()),
           /*has_side_effects=*/true,
           /*is_align_stack=*/false, /*asm_dialect=*/nullptr,
@@ -1619,27 +1624,35 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
       else
         mapNoResultOp(inst, callOp);
     } else {
-      CallOp callOp;
+      auto funcTy = dyn_cast<LLVMFunctionType>([&]() -> Type {
+        // Retrieve the real function type. For direct calls, use the callee's
+        // function type, as it may differ from the operand type in the case of
+        // variadic functions. For indirect calls, use the call function type.
+        if (auto callee = dyn_cast<llvm::Function>(calledOperand))
+          return convertType(callee->getFunctionType());
+        return convertType(callInst->getFunctionType());
+      }());
+
+      if (!funcTy)
+        return failure();
 
-      if (llvm::Function *callee = callInst->getCalledFunction()) {
-        callOp = builder.create<CallOp>(
-            loc, funcTy, SymbolRefAttr::get(context, callee->getName()),
-            operands);
-      } else {
-        callOp = builder.create<CallOp>(loc, funcTy, operands);
-      }
+      auto callOp = [&]() -> CallOp {
+        if (auto callee = dyn_cast<llvm::Function>(calledOperand)) {
+          auto name = SymbolRefAttr::get(context, callee->getName());
+          return builder.create<CallOp>(loc, funcTy, name, operands);
+        }
+        return builder.create<CallOp>(loc, funcTy, operands);
+      }();
+
+      // Handle function attributes.
       callOp.setCConv(convertCConvFromLLVM(callInst->getCallingConv()));
       callOp.setTailCallKind(
           convertTailCallKindFromLLVM(callInst->getTailCallKind()));
       setFastmathFlagsAttr(inst, callOp);
 
-      // Handle function attributes.
-      if (callInst->hasFnAttr(llvm::Attribute::Convergent))
-        callOp.setConvergent(true);
-      if (callInst->hasFnAttr(llvm::Attribute::NoUnwind))
-        callOp.setNoUnwind(true);
-      if (callInst->hasFnAttr(llvm::Attribute::WillReturn))
-        callOp.setWillReturn(true);
+      callOp.setConvergent(callInst->isConvergent());
+      callOp.setNoUnwind(callInst->doesNotThrow());
+      callOp.setWillReturn(callInst->hasFnAttr(llvm::Attribute::WillReturn));
 
       llvm::MemoryEffects memEffects = callInst->getMemoryEffects();
       ModRefInfo othermem = convertModRefInfoFromLLVM(

diff  --git a/mlir/test/Target/LLVMIR/Import/instructions.ll b/mlir/test/Target/LLVMIR/Import/instructions.ll
index 7377e2584110b5..77052ab6e41f6c 100644
--- a/mlir/test/Target/LLVMIR/Import/instructions.ll
+++ b/mlir/test/Target/LLVMIR/Import/instructions.ll
@@ -570,6 +570,31 @@ define void @varargs_call(i32 %0) {
 
 ; // -----
 
+; CHECK: @varargs(...)
+declare void @varargs(...)
+
+; CHECK-LABEL: @varargs_call
+; CHECK-SAME:  %[[ARG1:[a-zA-Z0-9]+]]
+define void @varargs_call(i32 %0) {
+  ; CHECK:  llvm.call @varargs(%[[ARG1]]) vararg(!llvm.func<void (...)>) : (i32) -> ()
+  call void @varargs(i32 %0)
+  ret void
+}
+
+; // -----
+
+; CHECK: @varargs(...)
+declare void @varargs(...)
+
+; CHECK-LABEL: @empty_varargs_call
+define void @empty_varargs_call() {
+  ; CHECK:  llvm.call @varargs() vararg(!llvm.func<void (...)>) : () -> ()
+  call void @varargs()
+  ret void
+}
+
+; // -----
+
 ; CHECK: llvm.func @f()
 declare void @f()
 


        


More information about the Mlir-commits mailing list