[Mlir-commits] [mlir] [mlir][LLVMIR] Add IFuncOp to LLVM dialect (PR #147697)

Robert Konicar llvmlistbot at llvm.org
Wed Jul 9 04:10:00 PDT 2025


https://github.com/Jezurko created https://github.com/llvm/llvm-project/pull/147697

Add an IFunc operation to the LLVM dialect, with support for importing (lifting) LLVM IR ifuncs and exporting them back to LLVM IR.

>From 36b982ffbf2f92b8048a4319b3fc4f8ed44475ce Mon Sep 17 00:00:00 2001
From: Robert Konicar <rkonicar at mail.muni.cz>
Date: Mon, 7 Jul 2025 14:01:49 +0200
Subject: [PATCH] [mlir][LLVMIR] Add IFuncOp to LLVM dialect

Add an IFunc operation to the LLVM dialect, with support for importing
(lifting) LLVM IR ifuncs and exporting them back to LLVM IR.
---
 .../include/mlir/Dialect/LLVMIR/LLVMDialect.h |   3 +
 mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td   |  40 ++++++
 .../include/mlir/Target/LLVMIR/ModuleImport.h |   5 +
 .../mlir/Target/LLVMIR/ModuleTranslation.h    |   9 ++
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp    | 106 ++++++++++++---
 .../LLVMIR/LLVMToLLVMIRTranslation.cpp        |  24 +++-
 mlir/lib/Target/LLVMIR/ModuleImport.cpp       |  49 ++++++-
 mlir/lib/Target/LLVMIR/ModuleTranslation.cpp  |  28 +++-
 mlir/test/Target/LLVMIR/Import/ifunc.ll       | 115 ++++++++++++++++
 mlir/test/Target/LLVMIR/ifunc.mlir            | 123 ++++++++++++++++++
 10 files changed, 471 insertions(+), 31 deletions(-)
 create mode 100644 mlir/test/Target/LLVMIR/Import/ifunc.ll
 create mode 100644 mlir/test/Target/LLVMIR/ifunc.mlir

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
index 63e007cdc335c..e355bb8f5ddae 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
@@ -223,6 +223,9 @@ Value createGlobalString(Location loc, OpBuilder &builder, StringRef name,
 /// function confirms that the Operation has the desired properties.
 bool satisfiesLLVMModule(Operation *op);
 
+/// Returns the closest ancestor of `op` that satisfies satisfiesLLVMModule.
+Operation *parentLLVMModule(Operation *op);
+
 /// Convert an array of integer attributes to a vector of integers that can be
 /// used as indices in LLVM operations.
 template <typename IntT = int64_t>
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index f4c1640098320..fe1418b12b90a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1285,6 +1285,10 @@ def LLVM_AddressOfOp : LLVM_Op<"mlir.addressof",
     /// Return the llvm.mlir.alias operation that defined the value referenced
     /// here.
     AliasOp getAlias(SymbolTableCollection &symbolTable);
+
+    /// Return the llvm.mlir.ifunc operation that defined the value referenced
+    /// here.
+    IFuncOp getIFunc(SymbolTableCollection &symbolTable);
   }];
 
   let assemblyFormat = "$global_name attr-dict `:` qualified(type($res))";
@@ -1601,6 +1605,42 @@ def LLVM_AliasOp : LLVM_Op<"mlir.alias",
   let hasRegionVerifier = 1;
 }
 
+def LLVM_IFuncOp : LLVM_Op<"mlir.ifunc",
+    [IsolatedFromAbove, Symbol, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
+  let arguments = (ins
+    SymbolNameAttr:$sym_name,
+    TypeAttr:$i_func_type,
+    FlatSymbolRefAttr:$resolver,
+    TypeAttr:$resolver_type,
+    UnitAttr:$dso_local,
+    DefaultValuedAttr<ConfinedAttr<I32Attr, [IntNonNegative]>, "0">:$address_space,
+    DefaultValuedAttr<Linkage, "mlir::LLVM::Linkage::External">:$linkage,
+    DefaultValuedAttr<UnnamedAddr, "mlir::LLVM::UnnamedAddr::None">:$unnamed_addr,
+    DefaultValuedAttr<Visibility, "mlir::LLVM::Visibility::Default">:$visibility_
+  );
+  let summary = "LLVM dialect ifunc";
+  let description = [{
+    `llvm.mlir.ifunc` is a top level operation that defines a global ifunc.
+    It defines a new symbol and takes a symbol referring to a resolver function.
+    IFuncs can be called as regular functions; their function type is given by
+    `i_func_type`. The symbol is resolved at runtime by calling the resolver
+    function, which is referenced through a pointer of type `resolver_type`.
+  }];
+
+  let builders = [
+    OpBuilder<(ins "StringRef":$name, "Type":$i_func_type,
+      "StringRef":$resolver, "Type":$resolver_type,
+      "Linkage":$linkage, "LLVM::Visibility":$visibility)>
+  ];
+
+  let assemblyFormat = [{
+    (custom<LLVMLinkage>($linkage)^)? ($visibility_^)? ($unnamed_addr^)?
+    $sym_name `:` $i_func_type `,` $resolver_type $resolver attr-dict
+  }];
+  let hasVerifier = 1;
+}
+
+
 def LLVM_DSOLocalEquivalentOp : LLVM_Op<"dso_local_equivalent",
     [Pure, ConstantLike, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
   let arguments = (ins FlatSymbolRefAttr:$function_name);
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
index 9902c6bb15caf..b21600b634d2e 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -71,6 +71,9 @@ class ModuleImport {
   /// Converts all aliases of the LLVM module to MLIR variables.
   LogicalResult convertAliases();
 
+  /// Converts all ifuncs of the LLVM module to MLIR ifunc operations.
+  LogicalResult convertIFuncs();
+
   /// Converts the data layout of the LLVM module to an MLIR data layout
   /// specification.
   LogicalResult convertDataLayout();
@@ -320,6 +323,8 @@ class ModuleImport {
   /// Converts an LLVM global alias variable into an MLIR LLVM dialect alias
   /// operation if a conversion exists. Otherwise, returns failure.
   LogicalResult convertAlias(llvm::GlobalAlias *alias);
+  /// Converts an LLVM global ifunc into an MLIR LLVM dialect ifunc operation.
+  LogicalResult convertIFunc(llvm::GlobalIFunc *ifunc);
   /// Returns personality of `func` as a FlatSymbolRefAttr.
   FlatSymbolRefAttr getPersonalityAsAttr(llvm::Function *func);
   /// Imports `bb` into `block`, which must be initially empty.
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index 79e8bb6add0da..fc82be3b88395 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -223,6 +223,12 @@ class ModuleTranslation {
     return aliasesMapping.lookup(op);
   }
 
+  /// Finds an LLVM IR global value that corresponds to the given MLIR operation
+  /// defining an IFunc.
+  llvm::GlobalValue *lookupIFunc(Operation *op) {
+    return ifuncMapping.lookup(op);
+  }
+
   /// Returns the OpenMP IR builder associated with the LLVM IR module being
   /// constructed.
   llvm::OpenMPIRBuilder *getOpenMPBuilder();
@@ -308,6 +314,7 @@ class ModuleTranslation {
                                  bool recordInsertions = false);
   LogicalResult convertFunctionSignatures();
   LogicalResult convertFunctions();
+  LogicalResult convertIFuncs();
   LogicalResult convertComdats();
 
   LogicalResult convertUnresolvedBlockAddress();
@@ -369,6 +376,9 @@ class ModuleTranslation {
   /// aliases.
   DenseMap<Operation *, llvm::GlobalValue *> aliasesMapping;
 
+  /// Mappings between llvm.mlir.ifunc ops and the corresponding LLVM IR ifuncs.
+  DenseMap<Operation *, llvm::GlobalValue *> ifuncMapping;
+
   /// A stateful object used to translate types.
   TypeToLLVMIRTranslator typeTranslator;
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 6dcd94e6eea17..da5a3f40faaa3 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -139,6 +139,17 @@ static RetTy parseOptionalLLVMKeyword(OpAsmParser &parser,
   return static_cast<RetTy>(index);
 }
 
+static void printLLVMLinkage(OpAsmPrinter &p, Operation *, LinkageAttr val) {
+  p << stringifyLinkage(val.getLinkage());
+}
+
+static OptionalParseResult parseLLVMLinkage(OpAsmParser &p, LinkageAttr &val) {
+  val = LinkageAttr::get(
+      p.getContext(),
+      parseOptionalLLVMKeyword<LLVM::Linkage>(p, LLVM::Linkage::External));
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Operand bundle helpers.
 //===----------------------------------------------------------------------===//
@@ -1175,14 +1186,17 @@ LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
       return emitOpError()
              << "'" << calleeName.getValue()
              << "' does not reference a symbol in the current scope";
-    auto fn = dyn_cast<LLVMFuncOp>(callee);
-    if (!fn)
-      return emitOpError() << "'" << calleeName.getValue()
-                           << "' does not reference a valid LLVM function";
-
-    if (failed(verifyCallOpDebugInfo(*this, fn)))
-      return failure();
-    fnType = fn.getFunctionType();
+    if (auto fn = dyn_cast<LLVMFuncOp>(callee)) {
+      if (failed(verifyCallOpDebugInfo(*this, fn)))
+        return failure();
+      fnType = fn.getFunctionType();
+    } else if (auto ifunc = dyn_cast<IFuncOp>(callee)) {
+      fnType = ifunc.getIFuncType();
+    } else {
+      return emitOpError()
+             << "'" << calleeName.getValue()
+             << "' does not reference a valid LLVM function or IFunc";
+    }
   }
 
   LLVMFunctionType funcType = llvm::dyn_cast<LLVMFunctionType>(fnType);
@@ -2038,14 +2052,6 @@ LogicalResult ReturnOp::verify() {
 // LLVM::AddressOfOp.
 //===----------------------------------------------------------------------===//
 
-static Operation *parentLLVMModule(Operation *op) {
-  Operation *module = op->getParentOp();
-  while (module && !satisfiesLLVMModule(module))
-    module = module->getParentOp();
-  assert(module && "unexpected operation outside of a module");
-  return module;
-}
-
 GlobalOp AddressOfOp::getGlobal(SymbolTableCollection &symbolTable) {
   return dyn_cast_or_null<GlobalOp>(
       symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr()));
@@ -2061,6 +2067,11 @@ AliasOp AddressOfOp::getAlias(SymbolTableCollection &symbolTable) {
       symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr()));
 }
 
+IFuncOp AddressOfOp::getIFunc(SymbolTableCollection &symbolTable) {
+  return dyn_cast_or_null<IFuncOp>(
+      symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr()));
+}
+
 LogicalResult
 AddressOfOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   Operation *symbol =
@@ -2069,10 +2080,11 @@ AddressOfOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   auto global = dyn_cast_or_null<GlobalOp>(symbol);
   auto function = dyn_cast_or_null<LLVMFuncOp>(symbol);
   auto alias = dyn_cast_or_null<AliasOp>(symbol);
+  auto ifunc = dyn_cast_or_null<IFuncOp>(symbol);
 
-  if (!global && !function && !alias)
+  if (!global && !function && !alias && !ifunc)
     return emitOpError("must reference a global defined by 'llvm.mlir.global', "
-                       "'llvm.mlir.alias' or 'llvm.func'");
+                       "'llvm.mlir.alias' or 'llvm.func' or 'llvm.mlir.ifunc'");
 
   LLVMPointerType type = getType();
   if ((global && global.getAddrSpace() != type.getAddressSpace()) ||
@@ -2682,6 +2694,56 @@ unsigned AliasOp::getAddrSpace() {
   return ptrTy.getAddressSpace();
 }
 
+//===----------------------------------------------------------------------===//
+// IFuncOp
+//===----------------------------------------------------------------------===//
+
+void IFuncOp::build(OpBuilder &builder, OperationState &result, StringRef name,
+                    Type iFuncType, StringRef resolverName, Type resolverType,
+                    Linkage linkage, LLVM::Visibility visibility) {
+  build(builder, result, name, iFuncType, resolverName, resolverType,
+        /*dso_local=*/false, /*address_space=*/0, linkage, UnnamedAddr::None,
+        visibility);
+}
+
+LogicalResult IFuncOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  // The resolver must name an llvm.func in the enclosing LLVM module.
+  Operation *symbol =
+      symbolTable.lookupSymbolIn(parentLLVMModule(*this), getResolverAttr());
+  auto resolver = dyn_cast_or_null<LLVMFuncOp>(symbol);
+  if (!resolver)
+    return emitOpError("IFunc must have a Function resolver");
+  // Mirror the ifunc resolver checks of llvm/lib/IR/Verifier.cpp.
+  Linkage linkage = resolver.getLinkage();
+  if (resolver.isExternal() || linkage == Linkage::AvailableExternally)
+    return emitOpError("IFunc resolver must be a definition");
+  if (!isa<LLVMPointerType>(resolver.getFunctionType().getReturnType()))
+    return emitOpError("IFunc resolver must return a pointer");
+  auto resolverPtr = dyn_cast<LLVMPointerType>(getResolverType());
+  if (!resolverPtr || resolverPtr.getAddressSpace() != getAddressSpace())
+    return emitOpError("IFunc resolver has incorrect type");
+  return success();
+}
+
+LogicalResult IFuncOp::verify() {
+  switch (getLinkage()) {
+  case Linkage::External:
+  case Linkage::Internal:
+  case Linkage::Private:
+  case Linkage::Weak:
+  case Linkage::WeakODR:
+  case Linkage::Linkonce:
+  case Linkage::LinkonceODR:
+    break;
+  default:
+    return emitOpError() << "'" << stringifyLinkage(getLinkage())
+                         << "' linkage not supported in ifuncs, available "
+                            "options: private, internal, linkonce, weak, "
+                            "linkonce_odr, weak_odr, or external linkage";
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ShuffleVectorOp
 //===----------------------------------------------------------------------===//
@@ -4329,3 +4391,11 @@ bool mlir::LLVM::satisfiesLLVMModule(Operation *op) {
   return op->hasTrait<OpTrait::SymbolTable>() &&
          op->hasTrait<OpTrait::IsIsolatedFromAbove>();
 }
+
+Operation *mlir::LLVM::parentLLVMModule(Operation *op) {
+  Operation *module = op->getParentOp();
+  while (module && !satisfiesLLVMModule(module))
+    module = module->getParentOp();
+  assert(module && "unexpected operation outside of a module");
+  return module;
+}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index 70029d7e15a90..9a648ecd3a8d2 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -422,9 +422,18 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
     ArrayRef<llvm::Value *> operandsRef(operands);
     llvm::CallInst *call;
     if (auto attr = callOp.getCalleeAttr()) {
-      call =
-          builder.CreateCall(moduleTranslation.lookupFunction(attr.getValue()),
-                             operandsRef, opBundles);
+      if (llvm::Function *function =
+              moduleTranslation.lookupFunction(attr.getValue())) {
+        call = builder.CreateCall(function, operandsRef, opBundles);
+      } else {
+        Operation *ifuncOp = moduleTranslation.symbolTable().lookupSymbolIn(
+            parentLLVMModule(&opInst), attr);
+        llvm::GlobalValue *ifunc = moduleTranslation.lookupIFunc(ifuncOp);
+        assert(ifunc && "callee must be an llvm.func or an llvm.mlir.ifunc");
+        llvm::FunctionType *calleeType = llvm::cast<llvm::FunctionType>(
+            moduleTranslation.convertType(callOp.getCalleeFunctionType()));
+        call = builder.CreateCall(calleeType, ifunc, operandsRef, opBundles);
+      }
     } else {
       llvm::FunctionType *calleeType = llvm::cast<llvm::FunctionType>(
           moduleTranslation.convertType(callOp.getCalleeFunctionType()));
@@ -648,18 +657,21 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
     LLVM::LLVMFuncOp function =
         addressOfOp.getFunction(moduleTranslation.symbolTable());
     LLVM::AliasOp alias = addressOfOp.getAlias(moduleTranslation.symbolTable());
+    LLVM::IFuncOp ifunc = addressOfOp.getIFunc(moduleTranslation.symbolTable());
 
     // The verifier should not have allowed this.
-    assert((global || function || alias) &&
-           "referencing an undefined global, function, or alias");
+    assert((global || function || alias || ifunc) &&
+           "referencing an undefined global, function, alias, or ifunc");
 
     llvm::Value *llvmValue = nullptr;
     if (global)
       llvmValue = moduleTranslation.lookupGlobal(global);
     else if (alias)
       llvmValue = moduleTranslation.lookupAlias(alias);
-    else
+    else if (function)
       llvmValue = moduleTranslation.lookupFunction(function.getName());
+    else
+      llvmValue = moduleTranslation.lookupIFunc(ifunc);
 
     moduleTranslation.mapValue(addressOfOp.getResult(), llvmValue);
     return success();
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index bfda223fe0f5f..a88e1c9847fcf 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1031,6 +1031,16 @@ LogicalResult ModuleImport::convertAliases() {
   return success();
 }
 
+LogicalResult ModuleImport::convertIFuncs() {
+  for (llvm::GlobalIFunc &ifunc : llvmModule->ifuncs()) {
+    if (failed(convertIFunc(&ifunc))) {
+      return emitError(UnknownLoc::get(context))
+             << "unhandled global ifunc: " << diag(ifunc);
+    }
+  }
+  return success();
+}
+
 LogicalResult ModuleImport::convertDataLayout() {
   Location loc = mlirModule.getLoc();
   DataLayoutImporter dataLayoutImporter(context, llvmModule->getDataLayout());
@@ -1369,6 +1379,21 @@ LogicalResult ModuleImport::convertAlias(llvm::GlobalAlias *alias) {
   return success();
 }
 
+LogicalResult ModuleImport::convertIFunc(llvm::GlobalIFunc *ifunc) {
+  OpBuilder::InsertionGuard guard = setGlobalInsertionPoint();
+
+  Type type = convertType(ifunc->getValueType());
+  llvm::Constant *resolver = ifunc->getResolver();
+  Type resolverType = convertType(resolver->getType());
+  builder.create<IFuncOp>(mlirModule.getLoc(), ifunc->getName(), type,
+                          resolver->getName(), resolverType,
+                          ifunc->isDSOLocal(), ifunc->getAddressSpace(),
+                          convertLinkageFromLLVM(ifunc->getLinkage()),
+                          convertUnnamedAddrFromLLVM(ifunc->getUnnamedAddr()),
+                          convertVisibilityFromLLVM(ifunc->getVisibility()));
+  return success();
+}
+
 LogicalResult ModuleImport::convertGlobal(llvm::GlobalVariable *globalVar) {
   // Insert the global after the last one or at the start of the module.
   OpBuilder::InsertionGuard guard = setGlobalInsertionPoint();
@@ -1973,8 +1998,9 @@ ModuleImport::convertCallOperands(llvm::CallBase *callInst,
   // treated as indirect calls to constant operands that need to be converted.
   // Skip the callee operand if it's inline assembly, as it's handled separately
   // in InlineAsmOp.
-  if (!isa<llvm::Function>(callInst->getCalledOperand()) && !isInlineAsm) {
-    FailureOr<Value> called = convertValue(callInst->getCalledOperand());
+  llvm::Value *calleeOperand = callInst->getCalledOperand();
+  if (!isa<llvm::Function, llvm::GlobalIFunc>(calleeOperand) && !isInlineAsm) {
+    FailureOr<Value> called = convertValue(calleeOperand);
     if (failed(called))
       return failure();
     operands.push_back(*called);
@@ -2035,12 +2061,21 @@ ModuleImport::convertFunctionType(llvm::CallBase *callInst,
   if (failed(callType))
     return failure();
   auto *callee = dyn_cast<llvm::Function>(calledOperand);
+
+  llvm::FunctionType *origCalleeType = nullptr;
+  if (callee) {
+    origCalleeType = callee->getFunctionType();
+  } else if (auto *ifunc = dyn_cast<llvm::GlobalIFunc>(calledOperand)) {
+    origCalleeType =
+        dyn_cast_or_null<llvm::FunctionType>(ifunc->getValueType());
+  }
+
   // For indirect calls, return the type of the call itself.
-  if (!callee)
+  if (!origCalleeType)
     return callType;
 
   FailureOr<LLVMFunctionType> calleeType =
-      castOrFailure(convertType(callee->getFunctionType()));
+      castOrFailure(convertType(origCalleeType));
   if (failed(calleeType))
     return failure();
 
@@ -2059,8 +2094,8 @@ ModuleImport::convertFunctionType(llvm::CallBase *callInst,
 
 FlatSymbolRefAttr ModuleImport::convertCalleeName(llvm::CallBase *callInst) {
   llvm::Value *calledOperand = callInst->getCalledOperand();
-  if (auto *callee = dyn_cast<llvm::Function>(calledOperand))
-    return SymbolRefAttr::get(context, callee->getName());
+  if (isa<llvm::Function, llvm::GlobalIFunc>(calledOperand))
+    return SymbolRefAttr::get(context, calledOperand->getName());
   return {};
 }
 
@@ -3162,6 +3197,8 @@ OwningOpRef<ModuleOp> mlir::translateLLVMIRToModule(
     return {};
   if (failed(moduleImport.convertAliases()))
     return {};
+  if (failed(moduleImport.convertIFuncs()))
+    return {};
   moduleImport.convertTargetTriple();
   return module;
 }
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 8908703cc1368..1a2a585e34d44 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -791,6 +791,8 @@ void ModuleTranslation::forgetMapping(Region &region) {
           globalsMapping.erase(&op);
         if (isa<LLVM::AliasOp>(op))
           aliasesMapping.erase(&op);
+        if (isa<LLVM::IFuncOp>(op))
+          ifuncMapping.erase(&op);
         if (isa<LLVM::CallOp>(op))
           callMapping.erase(&op);
         llvm::append_range(
@@ -1868,6 +1870,27 @@ LogicalResult ModuleTranslation::convertFunctions() {
   return success();
 }
 
+LogicalResult ModuleTranslation::convertIFuncs() {
+  for (auto op : getModuleBody(mlirModule).getOps<IFuncOp>()) {
+    llvm::Type *type = convertType(op.getIFuncType());
+    llvm::GlobalValue::LinkageTypes linkage =
+        convertLinkageToLLVM(op.getLinkage());
+    llvm::Function *resolver = lookupFunction(op.getResolver());
+    if (!resolver)
+      return op.emitError("could not find the translated ifunc resolver");
+    auto *ifunc =
+        llvm::GlobalIFunc::create(type, op.getAddressSpace(), linkage,
+                                  op.getSymName(), resolver, llvmModule.get());
+    addRuntimePreemptionSpecifier(op.getDsoLocal(), ifunc);
+    ifunc->setUnnamedAddr(convertUnnamedAddrToLLVM(op.getUnnamedAddr()));
+    ifunc->setVisibility(convertVisibilityToLLVM(op.getVisibility_()));
+
+    ifuncMapping.try_emplace(op, ifunc);
+  }
+
+  return success();
+}
+
 LogicalResult ModuleTranslation::convertComdats() {
   for (auto comdatOp : getModuleBody(mlirModule).getOps<ComdatOp>()) {
     for (auto selectorOp : comdatOp.getOps<ComdatSelectorOp>()) {
@@ -2284,6 +2307,8 @@ mlir::translateModuleToLLVMIR(Operation *module, llvm::LLVMContext &llvmContext,
     return nullptr;
+  if (failed(translator.convertIFuncs()))
+    return nullptr;
   if (failed(translator.convertGlobalsAndAliases()))
     return nullptr;
   if (failed(translator.createTBAAMetadata()))
     return nullptr;
   if (failed(translator.createIdentMetadata()))
@@ -2296,7 +2321,8 @@ mlir::translateModuleToLLVMIR(Operation *module, llvm::LLVMContext &llvmContext,
   // Convert other top-level operations if possible.
   for (Operation &o : getModuleBody(module).getOperations()) {
     if (!isa<LLVM::LLVMFuncOp, LLVM::AliasOp, LLVM::GlobalOp,
-             LLVM::GlobalCtorsOp, LLVM::GlobalDtorsOp, LLVM::ComdatOp>(&o) &&
+             LLVM::GlobalCtorsOp, LLVM::GlobalDtorsOp, LLVM::ComdatOp,
+             LLVM::IFuncOp>(&o) &&
         !o.hasTrait<OpTrait::IsTerminator>() &&
         failed(translator.convertOperation(o, llvmBuilder))) {
       return nullptr;
diff --git a/mlir/test/Target/LLVMIR/Import/ifunc.ll b/mlir/test/Target/LLVMIR/Import/ifunc.ll
new file mode 100644
index 0000000000000..020cf8f99d9b7
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/Import/ifunc.ll
@@ -0,0 +1,115 @@
+; RUN: mlir-translate --import-llvm %s -split-input-file | FileCheck %s
+
+@__const.main.data = private unnamed_addr constant [10 x i32] [i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10], align 16
+
+; CHECK: llvm.mlir.ifunc @foo : !llvm.func<void (ptr, i64)>, !llvm.ptr @resolve_foo {dso_local}
+@foo = dso_local ifunc void (ptr, i64), ptr @resolve_foo
+
+define dso_local void @foo_1(ptr noundef %0, i64 noundef %1) #0 {
+  %3 = alloca ptr, align 8
+  %4 = alloca i64, align 8
+  store ptr %0, ptr %3, align 8
+  store i64 %1, ptr %4, align 8
+  ret void
+}
+
+define dso_local void @foo_2(ptr noundef %0, i64 noundef %1) #0 {
+  %3 = alloca ptr, align 8
+  %4 = alloca i64, align 8
+  store ptr %0, ptr %3, align 8
+  store i64 %1, ptr %4, align 8
+  ret void
+}
+
+define dso_local i32 @main() #0 {
+  %1 = alloca [10 x i32], align 16
+  call void @llvm.memcpy.p0.p0.i64(ptr align 16 %1, ptr align 16 @__const.main.data, i64 40, i1 false)
+  %2 = getelementptr inbounds [10 x i32], ptr %1, i64 0, i64 0
+; CHECK: llvm.call @foo
+  call void @foo(ptr noundef %2, i64 noundef 10)
+  ret i32 0
+}
+
+declare void @llvm.memcpy.p0.p0.i64(ptr noalias nocapture writeonly, ptr noalias nocapture readonly, i64, i1 immarg) #1
+
+define internal ptr @resolve_foo() #2 {
+  %1 = alloca ptr, align 8
+  %2 = call i32 @check()
+  %3 = icmp ne i32 %2, 0
+  br i1 %3, label %4, label %5
+
+4:                                                ; preds = %0
+  store ptr @foo_1, ptr %1, align 8
+  br label %6
+
+5:                                                ; preds = %0
+  store ptr @foo_2, ptr %1, align 8
+  br label %6
+
+6:                                                ; preds = %5, %4
+  %7 = load ptr, ptr %1, align 8
+  ret ptr %7
+}
+
+declare i32 @check() #3
+
+; // -----
+
+@__const.main.data = private unnamed_addr constant [10 x i32] [i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10], align 16
+
+; CHECK: llvm.mlir.ifunc @foo : !llvm.func<void (ptr, i64)>, !llvm.ptr @resolve_foo {dso_local}
+@foo = dso_local ifunc void (ptr, i64), ptr @resolve_foo
+
+define dso_local void @foo_1(ptr noundef %0, i64 noundef %1) #0 {
+  %3 = alloca ptr, align 8
+  %4 = alloca i64, align 8
+  store ptr %0, ptr %3, align 8
+  store i64 %1, ptr %4, align 8
+  ret void
+}
+
+define dso_local void @foo_2(ptr noundef %0, i64 noundef %1) #0 {
+  %3 = alloca ptr, align 8
+  %4 = alloca i64, align 8
+  store ptr %0, ptr %3, align 8
+  store i64 %1, ptr %4, align 8
+  ret void
+}
+
+define dso_local i32 @main() #0 {
+  %1 = alloca [10 x i32], align 16
+  %2 = alloca ptr, align 8
+  call void @llvm.memcpy.p0.p0.i64(ptr align 16 %1, ptr align 16 @__const.main.data, i64 40, i1 false)
+; CHECK: [[CALLEE:%[0-9]+]] = llvm.mlir.addressof @foo
+; CHECK: llvm.store [[CALLEE]], [[STORED:%[0-9]+]]
+; CHECK: [[LOADED_CALLEE:%[0-9]+]] = llvm.load [[STORED]]
+  store ptr @foo, ptr %2, align 8
+  %3 = load ptr, ptr %2, align 8
+  %4 = getelementptr inbounds [10 x i32], ptr %1, i64 0, i64 0
+; CHECK: llvm.call [[LOADED_CALLEE]]
+  call void %3(ptr noundef %4, i64 noundef 10)
+  ret i32 0
+}
+
+declare void @llvm.memcpy.p0.p0.i64(ptr noalias nocapture writeonly, ptr noalias nocapture readonly, i64, i1 immarg) #1
+
+define internal ptr @resolve_foo() #2 {
+  %1 = alloca ptr, align 8
+  %2 = call i32 @check()
+  %3 = icmp ne i32 %2, 0
+  br i1 %3, label %4, label %5
+
+4:                                                ; preds = %0
+  store ptr @foo_1, ptr %1, align 8
+  br label %6
+
+5:                                                ; preds = %0
+  store ptr @foo_2, ptr %1, align 8
+  br label %6
+
+6:                                                ; preds = %5, %4
+  %7 = load ptr, ptr %1, align 8
+  ret ptr %7
+}
+
+declare i32 @check() #3
diff --git a/mlir/test/Target/LLVMIR/ifunc.mlir b/mlir/test/Target/LLVMIR/ifunc.mlir
new file mode 100644
index 0000000000000..ea0d590bbd0ce
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/ifunc.mlir
@@ -0,0 +1,123 @@
+// RUN: mlir-translate -mlir-to-llvmir %s -split-input-file | FileCheck %s
+
+llvm.mlir.global private unnamed_addr constant @__const.main.data(dense<[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]> : tensor<10xi32>) {addr_space = 0 : i32, alignment = 16 : i64, dso_local} : !llvm.array<10 x i32>
+
+// CHECK: @foo = dso_local ifunc void (ptr, i64), ptr @resolve_foo
+llvm.mlir.ifunc @foo : !llvm.func<void (ptr, i64)>, !llvm.ptr @resolve_foo {dso_local}
+llvm.func @foo_1(%arg0: !llvm.ptr {llvm.noundef}, %arg1: i64 {llvm.noundef}) attributes {dso_local} {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x !llvm.ptr {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  %2 = llvm.alloca %0 x i64 {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  llvm.store %arg0, %1 {alignment = 8 : i64} : !llvm.ptr, !llvm.ptr
+  llvm.store %arg1, %2 {alignment = 8 : i64} : i64, !llvm.ptr
+  llvm.return
+}
+llvm.func @foo_2(%arg0: !llvm.ptr {llvm.noundef}, %arg1: i64 {llvm.noundef}) attributes {dso_local} {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x !llvm.ptr {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  %2 = llvm.alloca %0 x i64 {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  llvm.store %arg0, %1 {alignment = 8 : i64} : !llvm.ptr, !llvm.ptr
+  llvm.store %arg1, %2 {alignment = 8 : i64} : i64, !llvm.ptr
+  llvm.return
+}
+llvm.func @main() -> i32 attributes {dso_local} {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.mlir.addressof @__const.main.data : !llvm.ptr
+  %2 = llvm.mlir.constant(40 : i64) : i64
+  %3 = llvm.mlir.constant(0 : i64) : i64
+  %4 = llvm.mlir.constant(10 : i64) : i64
+  %5 = llvm.mlir.constant(0 : i32) : i32
+  %6 = llvm.alloca %0 x !llvm.array<10 x i32> {alignment = 16 : i64} : (i32) -> !llvm.ptr
+  "llvm.intr.memcpy"(%6, %1, %2) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i64) -> ()
+  %7 = llvm.getelementptr inbounds %6[%3, %3] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.array<10 x i32>
+// A direct llvm.call of the ifunc symbol must become a direct call to @foo.
+// CHECK: call void @foo
+  llvm.call @foo(%7, %4) : (!llvm.ptr {llvm.noundef}, i64 {llvm.noundef}) -> ()
+  llvm.return %5 : i32
+}
+llvm.func internal @resolve_foo() -> !llvm.ptr attributes {dso_local} {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.mlir.constant(0 : i32) : i32
+  %2 = llvm.mlir.addressof @foo_2 : !llvm.ptr
+  %3 = llvm.mlir.addressof @foo_1 : !llvm.ptr
+  %4 = llvm.alloca %0 x !llvm.ptr {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  %5 = llvm.call @check() : () -> i32
+  %6 = llvm.icmp "ne" %5, %1 : i32
+  llvm.cond_br %6, ^bb1, ^bb2
+^bb1:  // pred: ^bb0
+  llvm.store %3, %4 {alignment = 8 : i64} : !llvm.ptr, !llvm.ptr
+  llvm.br ^bb3
+^bb2:  // pred: ^bb0
+  llvm.store %2, %4 {alignment = 8 : i64} : !llvm.ptr, !llvm.ptr
+  llvm.br ^bb3
+^bb3:  // 2 preds: ^bb1, ^bb2
+  %7 = llvm.load %4 {alignment = 8 : i64} : !llvm.ptr -> !llvm.ptr
+  llvm.return %7 : !llvm.ptr
+}
+llvm.func @check() -> i32
+
+// -----
+
+llvm.mlir.global private unnamed_addr constant @__const.main.data(dense<[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]> : tensor<10xi32>) {addr_space = 0 : i32, alignment = 16 : i64, dso_local} : !llvm.array<10 x i32>
+
+// CHECK: @foo = dso_local ifunc void (ptr, i64), ptr @resolve_foo
+llvm.mlir.ifunc @foo : !llvm.func<void (ptr, i64)>, !llvm.ptr @resolve_foo {dso_local}
+llvm.func @foo_1(%arg0: !llvm.ptr {llvm.noundef}, %arg1: i64 {llvm.noundef}) attributes {dso_local} {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x !llvm.ptr {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  %2 = llvm.alloca %0 x i64 {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  llvm.store %arg0, %1 {alignment = 8 : i64} : !llvm.ptr, !llvm.ptr
+  llvm.store %arg1, %2 {alignment = 8 : i64} : i64, !llvm.ptr
+  llvm.return
+}
+llvm.func @foo_2(%arg0: !llvm.ptr {llvm.noundef}, %arg1: i64 {llvm.noundef}) attributes {dso_local} {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x !llvm.ptr {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  %2 = llvm.alloca %0 x i64 {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  llvm.store %arg0, %1 {alignment = 8 : i64} : !llvm.ptr, !llvm.ptr
+  llvm.store %arg1, %2 {alignment = 8 : i64} : i64, !llvm.ptr
+  llvm.return
+}
+llvm.func @main() -> i32 attributes {dso_local} {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.mlir.addressof @__const.main.data : !llvm.ptr
+  %2 = llvm.mlir.constant(40 : i64) : i64
+  %3 = llvm.mlir.addressof @foo : !llvm.ptr
+  %4 = llvm.mlir.constant(0 : i64) : i64
+  %5 = llvm.mlir.constant(10 : i64) : i64
+  %6 = llvm.mlir.constant(0 : i32) : i32
+  %7 = llvm.alloca %0 x !llvm.array<10 x i32> {alignment = 16 : i64} : (i32) -> !llvm.ptr
+  %8 = llvm.alloca %0 x !llvm.ptr {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  "llvm.intr.memcpy"(%7, %1, %2) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i64) -> ()
+// llvm.mlir.addressof of the ifunc must lower to a plain pointer to @foo.
+// CHECK: store ptr @foo, ptr [[STORED:%[0-9]+]]
+  llvm.store %3, %8 {alignment = 8 : i64} : !llvm.ptr, !llvm.ptr
+
+// CHECK: [[LOADED:%[0-9]+]] = load ptr, ptr [[STORED]]
+  %9 = llvm.load %8 {alignment = 8 : i64} : !llvm.ptr -> !llvm.ptr
+  %10 = llvm.getelementptr inbounds %7[%4, %4] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.array<10 x i32>
+// The reloaded ifunc address is then called indirectly.
+// CHECK: call void [[LOADED]]
+  llvm.call %9(%10, %5) : !llvm.ptr, (!llvm.ptr {llvm.noundef}, i64 {llvm.noundef}) -> ()
+  llvm.return %6 : i32
+}
+llvm.func internal @resolve_foo() -> !llvm.ptr attributes {dso_local} {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.mlir.constant(0 : i32) : i32
+  %2 = llvm.mlir.addressof @foo_2 : !llvm.ptr
+  %3 = llvm.mlir.addressof @foo_1 : !llvm.ptr
+  %4 = llvm.alloca %0 x !llvm.ptr {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  %5 = llvm.call @check() : () -> i32
+  %6 = llvm.icmp "ne" %5, %1 : i32
+  llvm.cond_br %6, ^bb1, ^bb2
+^bb1:  // pred: ^bb0
+  llvm.store %3, %4 {alignment = 8 : i64} : !llvm.ptr, !llvm.ptr
+  llvm.br ^bb3
+^bb2:  // pred: ^bb0
+  llvm.store %2, %4 {alignment = 8 : i64} : !llvm.ptr, !llvm.ptr
+  llvm.br ^bb3
+^bb3:  // 2 preds: ^bb1, ^bb2
+  %7 = llvm.load %4 {alignment = 8 : i64} : !llvm.ptr -> !llvm.ptr
+  llvm.return %7 : !llvm.ptr
+}
+llvm.func @check() -> i32



More information about the Mlir-commits mailing list