[Mlir-commits] [mlir] [MLIR] Fix import of invokes with mismatched variadic types (PR #124828)

Henrich Lauko llvmlistbot at llvm.org
Wed Jan 29 01:12:39 PST 2025


https://github.com/xlauko updated https://github.com/llvm/llvm-project/pull/124828

>From ec00971c5821aca0ac6ce75dfad5cd066f3c8f18 Mon Sep 17 00:00:00 2001
From: xlauko <xlauko at mail.muni.cz>
Date: Tue, 28 Jan 2025 13:00:08 +0100
Subject: [PATCH] [MLIR] Fix import of invokes with mismatched variadic types

This resolves the same issue addressed in #124286, but for invoke
operations. The issue arose because the import logic for call and invoke
instructions was duplicated. This PR also refactors the common import code
for call and invoke instructions to mitigate similar issues in the future.
---
 .../include/mlir/Target/LLVMIR/ModuleImport.h |  32 ++-
 mlir/lib/Target/LLVMIR/ModuleImport.cpp       | 208 ++++++++++--------
 .../test/Target/LLVMIR/Import/instructions.ll |  18 ++
 3 files changed, 153 insertions(+), 105 deletions(-)

diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
index 33c9af7c6335a4..3c1221e20afbc5 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -316,24 +316,32 @@ class ModuleImport {
   LogicalResult convertBranchArgs(llvm::Instruction *branch,
                                   llvm::BasicBlock *target,
                                   SmallVectorImpl<Value> &blockArguments);
-  /// Appends the converted result type and operands of `callInst` to the
-  /// `types` and `operands` arrays. For indirect calls, the method additionally
-  /// inserts the called function at the beginning of the `operands` array.
-  /// If `allowInlineAsm` is set to false (the default), it will return failure
-  /// if the called operand is an inline asm which isn't convertible to MLIR as
-  /// a value.
-  LogicalResult convertCallTypeAndOperands(llvm::CallBase *callInst,
-                                           SmallVectorImpl<Type> &types,
-                                           SmallVectorImpl<Value> &operands,
-                                           bool allowInlineAsm = false);
-  /// Converts the parameter attributes attached to `func` and adds them to the
-  /// `funcOp`.
+  /// Converts the operands of `callInst`. For indirect calls, the method
+  /// additionally inserts the called function at the beginning of the returned
+  /// `operands` array. If `allowInlineAsm` is set to false (the default), it
+  /// will return failure if the called operand is an inline asm which isn't
+  /// convertible to MLIR as a value.
+  FailureOr<SmallVector<Value>>
+  convertCallOperands(llvm::CallBase *callInst, bool allowInlineAsm = false);
+  /// Converts the callee's function type. For direct calls, it converts the
+  /// actual function type, which may differ from the called operand type in
+  /// variadic functions. For indirect calls, it converts the function type
+  /// associated with the call instruction.
+  LLVMFunctionType convertFunctionType(llvm::CallBase *callInst);
+  /// Returns the callee name, or an empty symbol if the call is not direct.
+  FlatSymbolRefAttr convertCalleeName(llvm::CallBase *callInst);
+  /// Converts the parameter attributes attached to `func` and adds them to
+  /// the `funcOp`.
   void convertParameterAttributes(llvm::Function *func, LLVMFuncOp funcOp,
                                   OpBuilder &builder);
   /// Converts the AttributeSet of one parameter in LLVM IR to a corresponding
   /// DictionaryAttr for the LLVM dialect.
   DictionaryAttr convertParameterAttribute(llvm::AttributeSet llvmParamAttrs,
                                            OpBuilder &builder);
+  /// Converts the attributes attached to `inst` and adds them to the `op`.
+  LogicalResult convertCallAttributes(llvm::CallInst *inst, CallOp op);
+  /// Converts the attributes attached to `inst` and adds them to the `op`.
+  LogicalResult convertInvokeAttributes(llvm::InvokeInst *inst, InvokeOp op);
   /// Returns the builtin type equivalent to the given LLVM dialect type or
   /// nullptr if there is no equivalent. The returned type can be used to create
   /// an attribute for a GlobalOp or a ConstantOp.
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 40d86efe605ad0..170662f5b276d4 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -139,8 +139,8 @@ static LogicalResult convertInstructionImpl(OpBuilder &odsBuilder,
   if (iface.isConvertibleInstruction(inst->getOpcode()))
     return iface.convertInstruction(odsBuilder, inst, llvmOperands,
                                     moduleImport);
-    // TODO: Implement the `convertInstruction` hooks in the
-    // `LLVMDialectLLVMIRImportInterface` and move the following include there.
+  // TODO: Implement the `convertInstruction` hooks in the
+  // `LLVMDialectLLVMIRImportInterface` and move the following include there.
 #include "mlir/Dialect/LLVMIR/LLVMOpFromLLVMIRConversions.inc"
   return failure();
 }
@@ -1489,16 +1489,15 @@ ModuleImport::convertBranchArgs(llvm::Instruction *branch,
   return success();
 }
 
-LogicalResult ModuleImport::convertCallTypeAndOperands(
-    llvm::CallBase *callInst, SmallVectorImpl<Type> &types,
-    SmallVectorImpl<Value> &operands, bool allowInlineAsm) {
-  if (!callInst->getType()->isVoidTy())
-    types.push_back(convertType(callInst->getType()));
-
+FailureOr<SmallVector<Value>>
+ModuleImport::convertCallOperands(llvm::CallBase *callInst,
+                                  bool allowInlineAsm) {
   bool isInlineAsm = callInst->isInlineAsm();
   if (isInlineAsm && !allowInlineAsm)
     return failure();
 
+  SmallVector<Value> operands;
+
   // Cannot use isIndirectCall() here because we need to handle Constant callees
   // that are not considered indirect calls by LLVM. However, in MLIR, they are
   // treated as indirect calls to constant operands that need to be converted.
@@ -1515,8 +1514,29 @@ LogicalResult ModuleImport::convertCallTypeAndOperands(
   FailureOr<SmallVector<Value>> arguments = convertValues(args);
   if (failed(arguments))
     return failure();
+
   llvm::append_range(operands, *arguments);
-  return success();
+  return operands;
+}
+
+LLVMFunctionType ModuleImport::convertFunctionType(llvm::CallBase *callInst) {
+  llvm::Value *calledOperand = callInst->getCalledOperand();
+  Type converted = [&] {
+    if (auto callee = dyn_cast<llvm::Function>(calledOperand))
+      return convertType(callee->getFunctionType());
+    return convertType(callInst->getFunctionType());
+  }();
+
+  if (auto funcTy = dyn_cast_or_null<LLVMFunctionType>(converted))
+    return funcTy;
+  return {};
+}
+
+FlatSymbolRefAttr ModuleImport::convertCalleeName(llvm::CallBase *callInst) {
+  llvm::Value *calledOperand = callInst->getCalledOperand();
+  if (auto callee = dyn_cast<llvm::Function>(calledOperand))
+    return SymbolRefAttr::get(context, callee->getName());
+  return {};
 }
 
 LogicalResult ModuleImport::convertIntrinsic(llvm::CallInst *inst) {
@@ -1603,75 +1623,45 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
     auto callInst = cast<llvm::CallInst>(inst);
     llvm::Value *calledOperand = callInst->getCalledOperand();
 
-    SmallVector<Type> types;
-    SmallVector<Value> operands;
-    if (failed(convertCallTypeAndOperands(callInst, types, operands,
-                                          /*allowInlineAsm=*/true)))
+    FailureOr<SmallVector<Value>> operands =
+        convertCallOperands(callInst, /*allowInlineAsm=*/true);
+    if (failed(operands))
       return failure();
 
-    if (auto asmI = dyn_cast<llvm::InlineAsm>(calledOperand)) {
-      Type resultTy = convertType(callInst->getType());
-      if (!resultTy)
-        return failure();
-      auto callOp = builder.create<InlineAsmOp>(
-          loc, resultTy, operands, builder.getStringAttr(asmI->getAsmString()),
-          builder.getStringAttr(asmI->getConstraintString()),
-          /*has_side_effects=*/true,
-          /*is_align_stack=*/false, /*asm_dialect=*/nullptr,
-          /*operand_attrs=*/nullptr);
-      if (!callInst->getType()->isVoidTy())
-        mapValue(inst, callOp.getResult(0));
-      else
-        mapNoResultOp(inst, callOp);
-    } else {
-      auto funcTy = dyn_cast<LLVMFunctionType>([&]() -> Type {
-        // Retrieve the real function type. For direct calls, use the callee's
-        // function type, as it may differ from the operand type in the case of
-        // variadic functions. For indirect calls, use the call function type.
-        if (auto callee = dyn_cast<llvm::Function>(calledOperand))
-          return convertType(callee->getFunctionType());
-        return convertType(callInst->getFunctionType());
-      }());
-
-      if (!funcTy)
-        return failure();
+    auto callOp = [&]() -> FailureOr<Operation *> {
+      if (auto asmI = dyn_cast<llvm::InlineAsm>(calledOperand)) {
+        Type resultTy = convertType(callInst->getType());
+        if (!resultTy)
+          return failure();
+        return builder
+            .create<InlineAsmOp>(
+                loc, resultTy, *operands,
+                builder.getStringAttr(asmI->getAsmString()),
+                builder.getStringAttr(asmI->getConstraintString()),
+                /*has_side_effects=*/true,
+                /*is_align_stack=*/false, /*asm_dialect=*/nullptr,
+                /*operand_attrs=*/nullptr)
+            .getOperation();
+      } else {
+        LLVMFunctionType funcTy = convertFunctionType(callInst);
+        if (!funcTy)
+          return failure();
 
-      auto callOp = [&]() -> CallOp {
-        if (auto callee = dyn_cast<llvm::Function>(calledOperand)) {
-          auto name = SymbolRefAttr::get(context, callee->getName());
-          return builder.create<CallOp>(loc, funcTy, name, operands);
-        }
-        return builder.create<CallOp>(loc, funcTy, operands);
-      }();
-
-      // Handle function attributes.
-      callOp.setCConv(convertCConvFromLLVM(callInst->getCallingConv()));
-      callOp.setTailCallKind(
-          convertTailCallKindFromLLVM(callInst->getTailCallKind()));
-      setFastmathFlagsAttr(inst, callOp);
-
-      callOp.setConvergent(callInst->isConvergent());
-      callOp.setNoUnwind(callInst->doesNotThrow());
-      callOp.setWillReturn(callInst->hasFnAttr(llvm::Attribute::WillReturn));
-
-      llvm::MemoryEffects memEffects = callInst->getMemoryEffects();
-      ModRefInfo othermem = convertModRefInfoFromLLVM(
-          memEffects.getModRef(llvm::MemoryEffects::Location::Other));
-      ModRefInfo argMem = convertModRefInfoFromLLVM(
-          memEffects.getModRef(llvm::MemoryEffects::Location::ArgMem));
-      ModRefInfo inaccessibleMem = convertModRefInfoFromLLVM(
-          memEffects.getModRef(llvm::MemoryEffects::Location::InaccessibleMem));
-      auto memAttr = MemoryEffectsAttr::get(callOp.getContext(), othermem,
-                                            argMem, inaccessibleMem);
-      // Only set the attribute when it does not match the default value.
-      if (!memAttr.isReadWrite())
-        callOp.setMemoryEffectsAttr(memAttr);
-
-      if (!callInst->getType()->isVoidTy())
-        mapValue(inst, callOp.getResult());
-      else
-        mapNoResultOp(inst, callOp);
-    }
+        FlatSymbolRefAttr callee = convertCalleeName(callInst);
+        auto callOp = builder.create<CallOp>(loc, funcTy, callee, *operands);
+        if (failed(convertCallAttributes(callInst, callOp)))
+          return failure();
+        return callOp.getOperation();
+      }
+    }();
+
+    if (failed(callOp))
+      return failure();
+
+    if (!callInst->getType()->isVoidTy())
+      mapValue(inst, (*callOp)->getResult(0));
+    else
+      mapNoResultOp(inst, *callOp);
     return success();
   }
   if (inst->getOpcode() == llvm::Instruction::LandingPad) {
@@ -1695,9 +1685,11 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
   if (inst->getOpcode() == llvm::Instruction::Invoke) {
     auto *invokeInst = cast<llvm::InvokeInst>(inst);
 
-    SmallVector<Type> types;
-    SmallVector<Value> operands;
-    if (failed(convertCallTypeAndOperands(invokeInst, types, operands)))
+    if (invokeInst->isInlineAsm())
+      return emitError(loc) << "invoke of inline assembly is not supported";
+
+    FailureOr<SmallVector<Value>> operands = convertCallOperands(invokeInst);
+    if (failed(operands))
       return failure();
 
     // Check whether the invoke result is an argument to the normal destination
@@ -1724,27 +1716,22 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
                                  unwindArgs)))
       return failure();
 
-    auto funcTy =
-        dyn_cast<LLVMFunctionType>(convertType(invokeInst->getFunctionType()));
+    auto funcTy = convertFunctionType(invokeInst);
     if (!funcTy)
       return failure();
 
+    FlatSymbolRefAttr calleeName = convertCalleeName(invokeInst);
+
     // Create the invoke operation. Normal destination block arguments will be
     // added later on to handle the case in which the operation result is
     // included in this list.
-    InvokeOp invokeOp;
-    if (llvm::Function *callee = invokeInst->getCalledFunction()) {
-      invokeOp = builder.create<InvokeOp>(
-          loc, funcTy,
-          SymbolRefAttr::get(builder.getContext(), callee->getName()), operands,
-          directNormalDest, ValueRange(),
-          lookupBlock(invokeInst->getUnwindDest()), unwindArgs);
-    } else {
-      invokeOp = builder.create<InvokeOp>(
-          loc, funcTy, /*callee=*/nullptr, operands, directNormalDest,
-          ValueRange(), lookupBlock(invokeInst->getUnwindDest()), unwindArgs);
-    }
-    invokeOp.setCConv(convertCConvFromLLVM(invokeInst->getCallingConv()));
+    auto invokeOp = builder.create<InvokeOp>(
+        loc, funcTy, calleeName, *operands, directNormalDest, ValueRange(),
+        lookupBlock(invokeInst->getUnwindDest()), unwindArgs);
+
+    if (failed(convertInvokeAttributes(invokeInst, invokeOp)))
+      return failure();
+
     if (!invokeInst->getType()->isVoidTy())
       mapValue(inst, invokeOp.getResults().front());
     else
@@ -2097,6 +2084,41 @@ void ModuleImport::convertParameterAttributes(llvm::Function *func,
       builder.getArrayAttr(convertParameterAttribute(llvmResAttr, builder)));
 }
 
+template <typename Op>
+static LogicalResult convertCallBaseAttributes(llvm::CallBase *inst, Op op) {
+  op.setCConv(convertCConvFromLLVM(inst->getCallingConv()));
+  return success();
+}
+
+LogicalResult ModuleImport::convertInvokeAttributes(llvm::InvokeInst *inst,
+                                                    InvokeOp op) {
+  return convertCallBaseAttributes(inst, op);
+}
+
+LogicalResult ModuleImport::convertCallAttributes(llvm::CallInst *inst,
+                                                  CallOp op) {
+  setFastmathFlagsAttr(inst, op.getOperation());
+  op.setTailCallKind(convertTailCallKindFromLLVM(inst->getTailCallKind()));
+  op.setConvergent(inst->isConvergent());
+  op.setNoUnwind(inst->doesNotThrow());
+  op.setWillReturn(inst->hasFnAttr(llvm::Attribute::WillReturn));
+
+  llvm::MemoryEffects memEffects = inst->getMemoryEffects();
+  ModRefInfo othermem = convertModRefInfoFromLLVM(
+      memEffects.getModRef(llvm::MemoryEffects::Location::Other));
+  ModRefInfo argMem = convertModRefInfoFromLLVM(
+      memEffects.getModRef(llvm::MemoryEffects::Location::ArgMem));
+  ModRefInfo inaccessibleMem = convertModRefInfoFromLLVM(
+      memEffects.getModRef(llvm::MemoryEffects::Location::InaccessibleMem));
+  auto memAttr = MemoryEffectsAttr::get(op.getContext(), othermem, argMem,
+                                        inaccessibleMem);
+  // Only set the attribute when it does not match the default value.
+  if (!memAttr.isReadWrite())
+    op.setMemoryEffectsAttr(memAttr);
+
+  return convertCallBaseAttributes(inst, op);
+}
+
 LogicalResult ModuleImport::processFunction(llvm::Function *func) {
   clearRegionState();
 
diff --git a/mlir/test/Target/LLVMIR/Import/instructions.ll b/mlir/test/Target/LLVMIR/Import/instructions.ll
index 77052ab6e41f6c..c294e1b34f9bbd 100644
--- a/mlir/test/Target/LLVMIR/Import/instructions.ll
+++ b/mlir/test/Target/LLVMIR/Import/instructions.ll
@@ -702,3 +702,21 @@ define void @fence() {
   fence syncscope("") seq_cst
   ret void
 }
+
+; // -----
+
+; CHECK-LABEL: @f
+define void @f() personality ptr @__gxx_personality_v0 {
+entry:
+; CHECK: llvm.invoke @g() to ^bb1 unwind ^bb2 vararg(!llvm.func<void (...)>) : () -> ()
+  invoke void @g() to label %bb1 unwind label %bb2
+bb1:
+  ret void
+bb2:
+  %0 = landingpad i32 cleanup
+  unreachable
+}
+
+declare void @g(...)
+
+declare i32 @__gxx_personality_v0(...)



More information about the Mlir-commits mailing list