[Mlir-commits] [mlir] [MLIR] Fix import of invokes with mismatched variadic types (PR #124828)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jan 28 12:03:17 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Henrich Lauko (xlauko)

<details>
<summary>Changes</summary>

This resolves the same issue addressed in https://github.com/llvm/llvm-project/pull/124286, but for invoke operations. The issue arose because the import logic was duplicated between call and invoke instructions, so the earlier fix reached only the call path. This PR also refactors the common import code for call and invoke instructions so that the two paths cannot diverge in the same way in the future.

---
Full diff: https://github.com/llvm/llvm-project/pull/124828.diff


3 Files Affected:

- (modified) mlir/include/mlir/Target/LLVMIR/ModuleImport.h (+19-12) 
- (modified) mlir/lib/Target/LLVMIR/ModuleImport.cpp (+116-91) 
- (modified) mlir/test/Target/LLVMIR/Import/instructions.ll (+18) 


``````````diff
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
index 33c9af7c6335a4..7977cf818649e3 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -316,24 +316,31 @@ class ModuleImport {
   LogicalResult convertBranchArgs(llvm::Instruction *branch,
                                   llvm::BasicBlock *target,
                                   SmallVectorImpl<Value> &blockArguments);
-  /// Appends the converted result type and operands of `callInst` to the
-  /// `types` and `operands` arrays. For indirect calls, the method additionally
-  /// inserts the called function at the beginning of the `operands` array.
-  /// If `allowInlineAsm` is set to false (the default), it will return failure
-  /// if the called operand is an inline asm which isn't convertible to MLIR as
-  /// a value.
-  LogicalResult convertCallTypeAndOperands(llvm::CallBase *callInst,
-                                           SmallVectorImpl<Type> &types,
-                                           SmallVectorImpl<Value> &operands,
-                                           bool allowInlineAsm = false);
-  /// Converts the parameter attributes attached to `func` and adds them to the
-  /// `funcOp`.
+  /// Convert `callInst` operands. For indirect calls, the method additionally
+  /// inserts the called function at the beginning of the returned `operands`
+  /// array.  If `allowInlineAsm` is set to false (the default), it will return
+  /// failure if the called operand is an inline asm which isn't convertible to
+  /// MLIR as a value.
+  FailureOr<SmallVector<Value>>
+  convertCallOperands(llvm::CallBase *callInst, bool allowInlineAsm = false);
+  /// Converts the callee's function type. For direct calls, it converts the
+  /// actual function type, which may differ from the called operand type in
+  /// variadic functions. For indirect calls, it converts the function type
+  /// associated with the call instruction.
+  LLVMFunctionType convertFunctionType(llvm::CallBase *callInst);
+  /// Returns the callee name, or an empty symbol if the call is not direct.
+  FlatSymbolRefAttr convertCalleeName(llvm::CallBase *callInst);
+  /// Converts the parameter attributes attached to `func` and adds them to
+  /// the `funcOp`.
   void convertParameterAttributes(llvm::Function *func, LLVMFuncOp funcOp,
                                   OpBuilder &builder);
   /// Converts the AttributeSet of one parameter in LLVM IR to a corresponding
   /// DictionaryAttr for the LLVM dialect.
   DictionaryAttr convertParameterAttribute(llvm::AttributeSet llvmParamAttrs,
                                            OpBuilder &builder);
+  /// Converts the attributes attached to `inst` and adds them to the `op`.
+  LogicalResult convertCallAttributes(llvm::CallInst *inst, CallOp op);
+  LogicalResult convertInvokeAttributes(llvm::InvokeInst *inst, InvokeOp op);
   /// Returns the builtin type equivalent to the given LLVM dialect type or
   /// nullptr if there is no equivalent. The returned type can be used to create
   /// an attribute for a GlobalOp or a ConstantOp.
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 40d86efe605ad0..b5a15ac5443ddf 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1489,16 +1489,15 @@ ModuleImport::convertBranchArgs(llvm::Instruction *branch,
   return success();
 }
 
-LogicalResult ModuleImport::convertCallTypeAndOperands(
-    llvm::CallBase *callInst, SmallVectorImpl<Type> &types,
-    SmallVectorImpl<Value> &operands, bool allowInlineAsm) {
-  if (!callInst->getType()->isVoidTy())
-    types.push_back(convertType(callInst->getType()));
-
+FailureOr<SmallVector<Value>>
+ModuleImport::convertCallOperands(llvm::CallBase *callInst,
+                                  bool allowInlineAsm) {
   bool isInlineAsm = callInst->isInlineAsm();
   if (isInlineAsm && !allowInlineAsm)
     return failure();
 
+  SmallVector<Value> operands;
+
   // Cannot use isIndirectCall() here because we need to handle Constant callees
   // that are not considered indirect calls by LLVM. However, in MLIR, they are
   // treated as indirect calls to constant operands that need to be converted.
@@ -1515,8 +1514,29 @@ LogicalResult ModuleImport::convertCallTypeAndOperands(
   FailureOr<SmallVector<Value>> arguments = convertValues(args);
   if (failed(arguments))
     return failure();
+
   llvm::append_range(operands, *arguments);
-  return success();
+  return operands;
+}
+
+LLVMFunctionType ModuleImport::convertFunctionType(llvm::CallBase *callInst) {
+  llvm::Value *calledOperand = callInst->getCalledOperand();
+  Type converted = [&] {
+    if (auto callee = dyn_cast<llvm::Function>(calledOperand))
+      return convertType(callee->getFunctionType());
+    return convertType(callInst->getFunctionType());
+  }();
+
+  if (auto funcTy = dyn_cast_or_null<LLVMFunctionType>(converted))
+    return funcTy;
+  return {};
+}
+
+FlatSymbolRefAttr ModuleImport::convertCalleeName(llvm::CallBase *callInst) {
+  llvm::Value *calledOperand = callInst->getCalledOperand();
+  if (auto callee = dyn_cast<llvm::Function>(calledOperand))
+    return SymbolRefAttr::get(context, callee->getName());
+  return {};
 }
 
 LogicalResult ModuleImport::convertIntrinsic(llvm::CallInst *inst) {
@@ -1603,75 +1623,45 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
     auto callInst = cast<llvm::CallInst>(inst);
     llvm::Value *calledOperand = callInst->getCalledOperand();
 
-    SmallVector<Type> types;
-    SmallVector<Value> operands;
-    if (failed(convertCallTypeAndOperands(callInst, types, operands,
-                                          /*allowInlineAsm=*/true)))
+    FailureOr<SmallVector<Value>> operands =
+        convertCallOperands(callInst, /*allowInlineAsm=*/true);
+    if (failed(operands))
       return failure();
 
-    if (auto asmI = dyn_cast<llvm::InlineAsm>(calledOperand)) {
-      Type resultTy = convertType(callInst->getType());
-      if (!resultTy)
-        return failure();
-      auto callOp = builder.create<InlineAsmOp>(
-          loc, resultTy, operands, builder.getStringAttr(asmI->getAsmString()),
-          builder.getStringAttr(asmI->getConstraintString()),
-          /*has_side_effects=*/true,
-          /*is_align_stack=*/false, /*asm_dialect=*/nullptr,
-          /*operand_attrs=*/nullptr);
-      if (!callInst->getType()->isVoidTy())
-        mapValue(inst, callOp.getResult(0));
-      else
-        mapNoResultOp(inst, callOp);
-    } else {
-      auto funcTy = dyn_cast<LLVMFunctionType>([&]() -> Type {
-        // Retrieve the real function type. For direct calls, use the callee's
-        // function type, as it may differ from the operand type in the case of
-        // variadic functions. For indirect calls, use the call function type.
-        if (auto callee = dyn_cast<llvm::Function>(calledOperand))
-          return convertType(callee->getFunctionType());
-        return convertType(callInst->getFunctionType());
-      }());
-
-      if (!funcTy)
-        return failure();
+    auto callOp = [&]() -> FailureOr<Operation *> {
+      if (auto asmI = dyn_cast<llvm::InlineAsm>(calledOperand)) {
+        Type resultTy = convertType(callInst->getType());
+        if (!resultTy)
+          return failure();
+        return builder
+            .create<InlineAsmOp>(
+                loc, resultTy, *operands,
+                builder.getStringAttr(asmI->getAsmString()),
+                builder.getStringAttr(asmI->getConstraintString()),
+                /*has_side_effects=*/true,
+                /*is_align_stack=*/false, /*asm_dialect=*/nullptr,
+                /*operand_attrs=*/nullptr)
+            .getOperation();
+      } else {
+        LLVMFunctionType funcTy = convertFunctionType(callInst);
+        if (!funcTy)
+          return failure();
 
-      auto callOp = [&]() -> CallOp {
-        if (auto callee = dyn_cast<llvm::Function>(calledOperand)) {
-          auto name = SymbolRefAttr::get(context, callee->getName());
-          return builder.create<CallOp>(loc, funcTy, name, operands);
-        }
-        return builder.create<CallOp>(loc, funcTy, operands);
-      }();
-
-      // Handle function attributes.
-      callOp.setCConv(convertCConvFromLLVM(callInst->getCallingConv()));
-      callOp.setTailCallKind(
-          convertTailCallKindFromLLVM(callInst->getTailCallKind()));
-      setFastmathFlagsAttr(inst, callOp);
-
-      callOp.setConvergent(callInst->isConvergent());
-      callOp.setNoUnwind(callInst->doesNotThrow());
-      callOp.setWillReturn(callInst->hasFnAttr(llvm::Attribute::WillReturn));
-
-      llvm::MemoryEffects memEffects = callInst->getMemoryEffects();
-      ModRefInfo othermem = convertModRefInfoFromLLVM(
-          memEffects.getModRef(llvm::MemoryEffects::Location::Other));
-      ModRefInfo argMem = convertModRefInfoFromLLVM(
-          memEffects.getModRef(llvm::MemoryEffects::Location::ArgMem));
-      ModRefInfo inaccessibleMem = convertModRefInfoFromLLVM(
-          memEffects.getModRef(llvm::MemoryEffects::Location::InaccessibleMem));
-      auto memAttr = MemoryEffectsAttr::get(callOp.getContext(), othermem,
-                                            argMem, inaccessibleMem);
-      // Only set the attribute when it does not match the default value.
-      if (!memAttr.isReadWrite())
-        callOp.setMemoryEffectsAttr(memAttr);
-
-      if (!callInst->getType()->isVoidTy())
-        mapValue(inst, callOp.getResult());
-      else
-        mapNoResultOp(inst, callOp);
-    }
+        FlatSymbolRefAttr callee = convertCalleeName(callInst);
+        auto callOp = builder.create<CallOp>(loc, funcTy, callee, *operands);
+        if (failed(convertCallAttributes(callInst, callOp)))
+          return failure();
+        return callOp.getOperation();
+      }
+    }();
+
+    if (failed(callOp))
+      return failure();
+
+    if (!callInst->getType()->isVoidTy())
+      mapValue(inst, (*callOp)->getResult(0));
+    else
+      mapNoResultOp(inst, *callOp);
     return success();
   }
   if (inst->getOpcode() == llvm::Instruction::LandingPad) {
@@ -1695,9 +1685,12 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
   if (inst->getOpcode() == llvm::Instruction::Invoke) {
     auto *invokeInst = cast<llvm::InvokeInst>(inst);
 
-    SmallVector<Type> types;
-    SmallVector<Value> operands;
-    if (failed(convertCallTypeAndOperands(invokeInst, types, operands)))
+    if (invokeInst->isInlineAsm()) {
+      return emitError(loc) << "invoke of inline assembly is not supported";
+    }
+
+    FailureOr<SmallVector<Value>> operands = convertCallOperands(invokeInst);
+    if (failed(operands))
       return failure();
 
     // Check whether the invoke result is an argument to the normal destination
@@ -1724,27 +1717,22 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
                                  unwindArgs)))
       return failure();
 
-    auto funcTy =
-        dyn_cast<LLVMFunctionType>(convertType(invokeInst->getFunctionType()));
+    auto funcTy = convertFunctionType(invokeInst);
     if (!funcTy)
       return failure();
 
+    FlatSymbolRefAttr calleeName = convertCalleeName(invokeInst);
+
     // Create the invoke operation. Normal destination block arguments will be
     // added later on to handle the case in which the operation result is
     // included in this list.
-    InvokeOp invokeOp;
-    if (llvm::Function *callee = invokeInst->getCalledFunction()) {
-      invokeOp = builder.create<InvokeOp>(
-          loc, funcTy,
-          SymbolRefAttr::get(builder.getContext(), callee->getName()), operands,
-          directNormalDest, ValueRange(),
-          lookupBlock(invokeInst->getUnwindDest()), unwindArgs);
-    } else {
-      invokeOp = builder.create<InvokeOp>(
-          loc, funcTy, /*callee=*/nullptr, operands, directNormalDest,
-          ValueRange(), lookupBlock(invokeInst->getUnwindDest()), unwindArgs);
-    }
-    invokeOp.setCConv(convertCConvFromLLVM(invokeInst->getCallingConv()));
+    auto invokeOp = builder.create<InvokeOp>(
+        loc, funcTy, calleeName, *operands, directNormalDest, ValueRange(),
+        lookupBlock(invokeInst->getUnwindDest()), unwindArgs);
+
+    if (failed(convertInvokeAttributes(invokeInst, invokeOp)))
+      return failure();
+
     if (!invokeInst->getType()->isVoidTy())
       mapValue(inst, invokeOp.getResults().front());
     else
@@ -2097,6 +2085,43 @@ void ModuleImport::convertParameterAttributes(llvm::Function *func,
       builder.getArrayAttr(convertParameterAttribute(llvmResAttr, builder)));
 }
 
+namespace impl {
+template <typename Op>
+LogicalResult convertCallBaseAttributes(llvm::CallBase *inst, Op op) {
+  op.setCConv(convertCConvFromLLVM(inst->getCallingConv()));
+  return success();
+}
+} // namespace impl
+
+LogicalResult ModuleImport::convertInvokeAttributes(llvm::InvokeInst *inst,
+                                                    InvokeOp op) {
+  return ::impl::convertCallBaseAttributes(inst, op);
+}
+
+LogicalResult ModuleImport::convertCallAttributes(llvm::CallInst *inst,
+                                                  CallOp op) {
+  setFastmathFlagsAttr(inst, op.getOperation());
+  op.setTailCallKind(convertTailCallKindFromLLVM(inst->getTailCallKind()));
+  op.setConvergent(inst->isConvergent());
+  op.setNoUnwind(inst->doesNotThrow());
+  op.setWillReturn(inst->hasFnAttr(llvm::Attribute::WillReturn));
+
+  llvm::MemoryEffects memEffects = inst->getMemoryEffects();
+  ModRefInfo othermem = convertModRefInfoFromLLVM(
+      memEffects.getModRef(llvm::MemoryEffects::Location::Other));
+  ModRefInfo argMem = convertModRefInfoFromLLVM(
+      memEffects.getModRef(llvm::MemoryEffects::Location::ArgMem));
+  ModRefInfo inaccessibleMem = convertModRefInfoFromLLVM(
+      memEffects.getModRef(llvm::MemoryEffects::Location::InaccessibleMem));
+  auto memAttr = MemoryEffectsAttr::get(op.getContext(), othermem, argMem,
+                                        inaccessibleMem);
+  // Only set the attribute when it does not match the default value.
+  if (!memAttr.isReadWrite())
+    op.setMemoryEffectsAttr(memAttr);
+
+  return ::impl::convertCallBaseAttributes(inst, op);
+}
+
 LogicalResult ModuleImport::processFunction(llvm::Function *func) {
   clearRegionState();
 
diff --git a/mlir/test/Target/LLVMIR/Import/instructions.ll b/mlir/test/Target/LLVMIR/Import/instructions.ll
index 77052ab6e41f6c..c294e1b34f9bbd 100644
--- a/mlir/test/Target/LLVMIR/Import/instructions.ll
+++ b/mlir/test/Target/LLVMIR/Import/instructions.ll
@@ -702,3 +702,21 @@ define void @fence() {
   fence syncscope("") seq_cst
   ret void
 }
+
+; // -----
+
+; CHECK-LABEL: @f
+define void @f() personality ptr @__gxx_personality_v0 {
+entry:
+; CHECK: llvm.invoke @g() to ^bb1 unwind ^bb2 vararg(!llvm.func<void (...)>) : () -> ()
+  invoke void @g() to label %bb1 unwind label %bb2
+bb1:
+  ret void
+bb2:
+  %0 = landingpad i32 cleanup
+  unreachable
+}
+
+declare void @g(...)
+
+declare i32 @__gxx_personality_v0(...)

``````````

</details>


https://github.com/llvm/llvm-project/pull/124828


More information about the Mlir-commits mailing list