[Mlir-commits] [mlir] [MLIR] Fix import of calls with mismatched variadic types (PR #124286)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 24 07:21:33 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Henrich Lauko (xlauko)

<details>
<summary>Changes</summary>

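Previously, the importer generated an indirect call whenever `llvm::CallBase::getCalledFunction` returned null. This also happens for direct calls whose call-site function type does not match the callee's function type, for example calls to variadic functions. This patch inspects `llvm::CallBase::getCalledOperand` instead, so such calls are imported as direct calls using the callee's function type.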

---
Full diff: https://github.com/llvm/llvm-project/pull/124286.diff


2 Files Affected:

- (modified) mlir/lib/Target/LLVMIR/ModuleImport.cpp (+45-31) 
- (modified) mlir/test/Target/LLVMIR/Import/instructions.ll (+35) 


``````````diff
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index f6826a2362bfdf..b8f66357d9250e 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1495,15 +1495,22 @@ LogicalResult ModuleImport::convertCallTypeAndOperands(
   if (!callInst->getType()->isVoidTy())
     types.push_back(convertType(callInst->getType()));
 
-  if (!callInst->getCalledFunction()) {
-    if (!allowInlineAsm ||
-        !isa<llvm::InlineAsm>(callInst->getCalledOperand())) {
-      FailureOr<Value> called = convertValue(callInst->getCalledOperand());
-      if (failed(called))
-        return failure();
-      operands.push_back(*called);
-    }
+  bool isInlineAsm = callInst->isInlineAsm();
+  if (isInlineAsm && !allowInlineAsm)
+    return failure();
+
+  // Cannot use isIndirectCall() here because we need to handle Constant callees
+  // that are not considered indirect calls by LLVM.  However, in MLIR, they are
+  // treated as indirect calls to constant operands that need to be converted.
+  // Skip the callee operand if it's inline assembly, as it's handled separately
+  // in InlineAsmOp.
+  if (!isa<llvm::Function>(callInst->getCalledOperand()) && !isInlineAsm) {
+    FailureOr<Value> called = convertValue(callInst->getCalledOperand());
+    if (failed(called))
+      return failure();
+    operands.push_back(*called);
   }
+
   SmallVector<llvm::Value *> args(callInst->args());
   FailureOr<SmallVector<Value>> arguments = convertValues(args);
   if (failed(arguments))
@@ -1593,7 +1600,8 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
     return success();
   }
   if (inst->getOpcode() == llvm::Instruction::Call) {
-    auto *callInst = cast<llvm::CallInst>(inst);
+    auto callInst = cast<llvm::CallInst>(inst);
+    auto calledOperand = callInst->getCalledOperand();
 
     SmallVector<Type> types;
     SmallVector<Value> operands;
@@ -1601,14 +1609,12 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
                                           /*allowInlineAsm=*/true)))
       return failure();
 
-    auto funcTy =
-        dyn_cast<LLVMFunctionType>(convertType(callInst->getFunctionType()));
-    if (!funcTy)
-      return failure();
-
-    if (auto asmI = dyn_cast<llvm::InlineAsm>(callInst->getCalledOperand())) {
+    if (auto asmI = dyn_cast<llvm::InlineAsm>(calledOperand)) {
+      auto resultTy = convertType(callInst->getType());
+      if (!resultTy)
+        return failure();
       auto callOp = builder.create<InlineAsmOp>(
-          loc, funcTy.getReturnType(), operands,
+          loc, resultTy, operands,
           builder.getStringAttr(asmI->getAsmString()),
           builder.getStringAttr(asmI->getConstraintString()),
           /*has_side_effects=*/true,
@@ -1619,27 +1625,35 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
       else
         mapNoResultOp(inst, callOp);
     } else {
-      CallOp callOp;
+      auto funcTy = dyn_cast<LLVMFunctionType>([&] () -> Type {
+        // Retrieve the real function type. For direct calls, use the callee's
+        // function type, as it may differ from the operand type in the case of
+        // variadic functions. For indirect calls, use the call function type.
+        if (auto callee = dyn_cast<llvm::Function>(calledOperand))
+          return convertType(callee->getFunctionType());
+        return convertType(callInst->getFunctionType());
+      }() );
+
+      if (!funcTy)
+        return failure();
 
-      if (llvm::Function *callee = callInst->getCalledFunction()) {
-        callOp = builder.create<CallOp>(
-            loc, funcTy, SymbolRefAttr::get(context, callee->getName()),
-            operands);
-      } else {
-        callOp = builder.create<CallOp>(loc, funcTy, operands);
-      }
+      auto callOp = [&]() -> CallOp {
+        if (auto callee = dyn_cast<llvm::Function>(calledOperand)) {
+          auto name = SymbolRefAttr::get(context, callee->getName());
+          return builder.create<CallOp>(loc, funcTy, name, operands);
+        }
+        return builder.create<CallOp>(loc, funcTy, operands);
+      }();
+
+      // Handle function attributes.
       callOp.setCConv(convertCConvFromLLVM(callInst->getCallingConv()));
       callOp.setTailCallKind(
           convertTailCallKindFromLLVM(callInst->getTailCallKind()));
       setFastmathFlagsAttr(inst, callOp);
 
-      // Handle function attributes.
-      if (callInst->hasFnAttr(llvm::Attribute::Convergent))
-        callOp.setConvergent(true);
-      if (callInst->hasFnAttr(llvm::Attribute::NoUnwind))
-        callOp.setNoUnwind(true);
-      if (callInst->hasFnAttr(llvm::Attribute::WillReturn))
-        callOp.setWillReturn(true);
+      callOp.setConvergent(callInst->isConvergent());
+      callOp.setNoUnwind(callInst->doesNotThrow());
+      callOp.setWillReturn(callInst->hasFnAttr(llvm::Attribute::WillReturn));
 
       llvm::MemoryEffects memEffects = callInst->getMemoryEffects();
       ModRefInfo othermem = convertModRefInfoFromLLVM(
diff --git a/mlir/test/Target/LLVMIR/Import/instructions.ll b/mlir/test/Target/LLVMIR/Import/instructions.ll
index 7377e2584110b5..47f821c8d29909 100644
--- a/mlir/test/Target/LLVMIR/Import/instructions.ll
+++ b/mlir/test/Target/LLVMIR/Import/instructions.ll
@@ -570,6 +570,41 @@ define void @varargs_call(i32 %0) {
 
 ; // -----
 
+; CHECK: @varargs(...)
+declare void @varargs(...)
+
+; CHECK-LABEL: @varargs_call
+; CHECK-SAME:  %[[ARG1:[a-zA-Z0-9]+]]
+define void @varargs_call(i32 %0) {
+  ; CHECK:  llvm.call @varargs(%[[ARG1]]) vararg(!llvm.func<void (...)>) : (i32) -> ()
+  call void @varargs(i32 %0)
+  ret void
+}
+
+; // -----
+
+; CHECK: @varargs(...)
+declare void @varargs(...)
+
+; CHECK-LABEL: @empty_varargs_call
+define void @empty_varargs_call() {
+  ; CHECK:  llvm.call @varargs() vararg(!llvm.func<void (...)>) : () -> ()
+  call void @varargs()
+  ret void
+}
+
+; // -----
+
+; CHECK-LABEL: @undef_call
+define void @undef_call() {
+  ; CHECK: %[[UNDEF:[0-9]+]] = llvm.mlir.undef : !llvm.ptr
+  ; CHECK-NEXT: %[[CONST:[0-9]+]] = llvm.mlir.constant(0 : i32) : i32
+  ; CHECK-NEXT: llvm.call %[[UNDEF]](%[[CONST]]) : !llvm.ptr, (i32) -> ()
+  call void undef(i32 0)
+  ret void
+}
+; // -----
+
 ; CHECK: llvm.func @f()
 declare void @f()
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/124286


More information about the Mlir-commits mailing list