[Mlir-commits] [mlir] 0252357 - [mlir][LLVM] Add support for Calling Convention in LLVMFuncOp

Alexander Batashev llvmlistbot at llvm.org
Thu May 26 23:43:51 PDT 2022


Author: Alexander Batashev
Date: 2022-05-27T09:43:31+03:00
New Revision: 0252357b3e1e8f6d3bc51ac6d7ac075842b2c956

URL: https://github.com/llvm/llvm-project/commit/0252357b3e1e8f6d3bc51ac6d7ac075842b2c956
DIFF: https://github.com/llvm/llvm-project/commit/0252357b3e1e8f6d3bc51ac6d7ac075842b2c956.diff

LOG: [mlir][LLVM] Add support for Calling Convention in LLVMFuncOp

This patch adds support for the calling convention attribute in the LLVM
dialect, including enums, custom syntax, and import from LLVM IR.
Additionally, fix the import of the dso_local attribute.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D126161

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
    mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
    mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
    mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
    mlir/test/Dialect/LLVMIR/func.mlir
    mlir/test/Target/LLVMIR/Import/basic.ll
    mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index f5b97002b5a15..4a095817ab0ce 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -35,6 +35,15 @@ def LinkageAttr : LLVM_Attr<"Linkage"> {
   let hasCustomAssemblyFormat = 1;
 }
 
+// Attribute definition for the LLVM Calling Convention enum.
+def CConvAttr : LLVM_Attr<"CConv"> {
+  let mnemonic = "cconv";
+  let parameters = (ins
+    "CConv":$CConv
+  );
+  let hasCustomAssemblyFormat = 1;
+}
+
 def LoopOptionsAttr : LLVM_Attr<"LoopOptions"> {
   let mnemonic = "loopopts";
 

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
index 67709646a2227..d97a267f4aba3 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
@@ -39,6 +39,7 @@ namespace LLVM {
 // attribute definition itself.
 // TODO: this shouldn't be needed after we unify the attribute generation, i.e.
 // --gen-attr-* and --gen-attrdef-*.
+using cconv::CConv;
 using linkage::Linkage;
 } // namespace LLVM
 } // namespace mlir

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index 1b986968f928a..2fe1130c6bb73 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -234,6 +234,14 @@ class LLVM_EnumAttr<string name, string llvmName, string description,
   string llvmClassName = llvmName;
 }
 
+// Counterpart of LLVM_EnumAttr for plain (non-class) LLVM C++ enums; the
+// `llvmNamespace` argument is the scope that qualifies the LLVM enumerators.
+class LLVM_CEnumAttr<string name, string llvmNamespace, string description,
+                     list<LLVM_EnumAttrCase> cases>
+    : I64EnumAttr<name, description, cases> {
+  string llvmClassName = llvmNamespace;
+}
+
 // For every value in the list, substitutes the value in the place of "$0" in
 // "pattern" and stores the list of strings as "lst".
 class ListIntSubst<string pattern, list<int> values> {

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 039294a6e200b..9b77699ee7d8b 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -67,6 +67,134 @@ def LoopOptionCase : I32EnumAttr<
   let cppNamespace = "::mlir::LLVM";
 }
 
+// Values must match the llvm::CallingConv enumerators; the second argument is
+// the keyword used in the MLIR textual form (e.g. `fastcc`). Full list at
+// https://llvm.org/doxygen/namespacellvm_1_1CallingConv.html
+def CConvC : LLVM_EnumAttrCase<"C", "ccc", "C", 0>;
+def CConvFast : LLVM_EnumAttrCase<"Fast", "fastcc", "Fast", 8>;
+def CConvCold : LLVM_EnumAttrCase<"Cold", "coldcc", "Cold", 9>;
+def CConvGHC : LLVM_EnumAttrCase<"GHC", "cc_10", "GHC", 10>;
+def CConvHiPE : LLVM_EnumAttrCase<"HiPE", "cc_11", "HiPE", 11>;
+def CConvWebKitJS : LLVM_EnumAttrCase<"WebKit_JS", "webkit_jscc",
+                                      "WebKit_JS", 12>;
+def CConvAnyReg : LLVM_EnumAttrCase<"AnyReg", "anyregcc", "AnyReg", 13>;
+def CConvPreserveMost : LLVM_EnumAttrCase<"PreserveMost", "preserve_mostcc",
+                                          "PreserveMost", 14>;
+def CConvPreserveAll : LLVM_EnumAttrCase<"PreserveAll", "preserve_allcc",
+                                         "PreserveAll", 15>;
+def CConvSwift : LLVM_EnumAttrCase<"Swift", "swiftcc", "Swift", 16>;
+def CConvCXXFastTLS : LLVM_EnumAttrCase<"CXX_FAST_TLS", "cxx_fast_tlscc",
+                                        "CXX_FAST_TLS", 17>;
+def CConvTail : LLVM_EnumAttrCase<"Tail", "tailcc", "Tail", 18>;
+def CConvCFGuard_Check : LLVM_EnumAttrCase<"CFGuard_Check",
+                                            "cfguard_checkcc",
+                                            "CFGuard_Check", 19>;
+def CConvSwiftTail : LLVM_EnumAttrCase<"SwiftTail", "swifttailcc",
+                                       "SwiftTail", 20>;
+def CConvX86_StdCall : LLVM_EnumAttrCase<"X86_StdCall", "x86_stdcallcc",
+                                         "X86_StdCall", 64>;
+def CConvX86_FastCall : LLVM_EnumAttrCase<"X86_FastCall", "x86_fastcallcc",
+                                          "X86_FastCall", 65>;
+def CConvARM_APCS : LLVM_EnumAttrCase<"ARM_APCS", "arm_apcscc", "ARM_APCS", 66>;
+def CConvARM_AAPCS : LLVM_EnumAttrCase<"ARM_AAPCS", "arm_aapcscc", "ARM_AAPCS",
+                                       67>;
+def CConvARM_AAPCS_VFP : LLVM_EnumAttrCase<"ARM_AAPCS_VFP", "arm_aapcs_vfpcc",
+                                           "ARM_AAPCS_VFP", 68>;
+def CConvMSP430_INTR : LLVM_EnumAttrCase<"MSP430_INTR", "msp430_intrcc",
+                                          "MSP430_INTR", 69>;
+def CConvX86_ThisCall : LLVM_EnumAttrCase<"X86_ThisCall", "x86_thiscallcc",
+                                          "X86_ThisCall", 70>;
+def CConvPTX_Kernel : LLVM_EnumAttrCase<"PTX_Kernel", "ptx_kernelcc",
+                                        "PTX_Kernel", 71>;
+def CConvPTX_Device : LLVM_EnumAttrCase<"PTX_Device", "ptx_devicecc",
+                                        "PTX_Device", 72>;
+def CConvSPIR_FUNC : LLVM_EnumAttrCase<"SPIR_FUNC", "spir_funccc",
+                                       "SPIR_FUNC", 75>;
+def CConvSPIR_KERNEL : LLVM_EnumAttrCase<"SPIR_KERNEL", "spir_kernelcc",
+                                         "SPIR_KERNEL", 76>;
+def CConvIntel_OCL_BI : LLVM_EnumAttrCase<"Intel_OCL_BI", "intel_ocl_bicc",
+                                          "Intel_OCL_BI", 77>;
+def CConvX86_64_SysV : LLVM_EnumAttrCase<"X86_64_SysV", "x86_64_sysvcc",
+                                         "X86_64_SysV", 78>;
+def CConvWin64 : LLVM_EnumAttrCase<"Win64", "win64cc", "Win64", 79>;
+def CConvX86_VectorCall : LLVM_EnumAttrCase<"X86_VectorCall",
+                                            "x86_vectorcallcc",
+                                            "X86_VectorCall", 80>;
+def CConvHHVM : LLVM_EnumAttrCase<"HHVM", "hhvmcc", "HHVM", 81>;
+def CConvHHVM_C : LLVM_EnumAttrCase<"HHVM_C", "hhvm_ccc", "HHVM_C", 82>;
+def CConvX86_INTR : LLVM_EnumAttrCase<"X86_INTR", "x86_intrcc", "X86_INTR", 83>;
+def CConvAVR_INTR : LLVM_EnumAttrCase<"AVR_INTR", "avr_intrcc", "AVR_INTR", 84>;
+def CConvAVR_SIGNAL : LLVM_EnumAttrCase<"AVR_SIGNAL", "avr_signalcc",
+                                        "AVR_SIGNAL", 85>;
+def CConvAVR_BUILTIN : LLVM_EnumAttrCase<"AVR_BUILTIN", "avr_builtincc",
+                                         "AVR_BUILTIN", 86>;
+def CConvAMDGPU_VS : LLVM_EnumAttrCase<"AMDGPU_VS", "amdgpu_vscc", "AMDGPU_VS",
+                                       87>;
+def CConvAMDGPU_GS : LLVM_EnumAttrCase<"AMDGPU_GS", "amdgpu_gscc", "AMDGPU_GS",
+                                       88>;
+def CConvAMDGPU_PS : LLVM_EnumAttrCase<"AMDGPU_PS", "amdgpu_pscc", "AMDGPU_PS",
+                                       89>;
+def CConvAMDGPU_CS : LLVM_EnumAttrCase<"AMDGPU_CS", "amdgpu_cscc", "AMDGPU_CS",
+                                       90>;
+def CConvAMDGPU_KERNEL : LLVM_EnumAttrCase<"AMDGPU_KERNEL", "amdgpu_kernelcc",
+                                           "AMDGPU_KERNEL", 91>;
+def CConvX86_RegCall : LLVM_EnumAttrCase<"X86_RegCall", "x86_regcallcc",
+                                         "X86_RegCall", 92>;
+def CConvAMDGPU_HS : LLVM_EnumAttrCase<"AMDGPU_HS", "amdgpu_hscc", "AMDGPU_HS",
+                                       93>;
+def CConvMSP430_BUILTIN : LLVM_EnumAttrCase<"MSP430_BUILTIN",
+                                             "msp430_builtincc",
+                                             "MSP430_BUILTIN", 94>;
+def CConvAMDGPU_LS : LLVM_EnumAttrCase<"AMDGPU_LS", "amdgpu_lscc", "AMDGPU_LS",
+                                       95>;
+def CConvAMDGPU_ES : LLVM_EnumAttrCase<"AMDGPU_ES", "amdgpu_escc", "AMDGPU_ES",
+                                       96>;
+def CConvAArch64_VectorCall : LLVM_EnumAttrCase<"AArch64_VectorCall",
+                                                "aarch64_vectorcallcc",
+                                                "AArch64_VectorCall", 97>;
+def CConvAArch64_SVE_VectorCall : LLVM_EnumAttrCase<"AArch64_SVE_VectorCall",
+                                                    "aarch64_sve_vectorcallcc",
+                                                    "AArch64_SVE_VectorCall",
+                                                    98>;
+def CConvWASM_EmscriptenInvoke : LLVM_EnumAttrCase<"WASM_EmscriptenInvoke",
+                                                   "wasm_emscripten_invokecc",
+                                                   "WASM_EmscriptenInvoke", 99>;
+def CConvAMDGPU_Gfx : LLVM_EnumAttrCase<"AMDGPU_Gfx", "amdgpu_gfxcc",
+                                        "AMDGPU_Gfx", 100>;
+def CConvM68k_INTR : LLVM_EnumAttrCase<"M68k_INTR", "m68k_intrcc", "M68k_INTR",
+                                       101>;
+
+// Every CConv* case defined above must be listed here, ordered by value.
+def CConvEnum : LLVM_CEnumAttr<"CConv", "::llvm::CallingConv",
+    "Calling Conventions",
+    [CConvC, CConvFast, CConvCold, CConvGHC, CConvHiPE, CConvWebKitJS,
+     CConvAnyReg, CConvPreserveMost, CConvPreserveAll, CConvSwift,
+     CConvCXXFastTLS, CConvTail, CConvCFGuard_Check, CConvSwiftTail,
+     CConvX86_StdCall, CConvX86_FastCall, CConvARM_APCS, CConvARM_AAPCS,
+     CConvARM_AAPCS_VFP, CConvMSP430_INTR, CConvX86_ThisCall, CConvPTX_Kernel,
+     CConvPTX_Device, CConvSPIR_FUNC, CConvSPIR_KERNEL, CConvIntel_OCL_BI,
+     CConvX86_64_SysV, CConvWin64, CConvX86_VectorCall, CConvHHVM, CConvHHVM_C,
+     CConvX86_INTR, CConvAVR_INTR, CConvAVR_SIGNAL, CConvAVR_BUILTIN,
+     CConvAMDGPU_VS, CConvAMDGPU_GS, CConvAMDGPU_PS, CConvAMDGPU_CS,
+     CConvAMDGPU_KERNEL, CConvX86_RegCall, CConvAMDGPU_HS, CConvMSP430_BUILTIN,
+     CConvAMDGPU_LS, CConvAMDGPU_ES, CConvAArch64_VectorCall,
+     CConvAArch64_SVE_VectorCall, CConvWASM_EmscriptenInvoke, CConvAMDGPU_Gfx,
+     CConvM68k_INTR
+    ]> {
+  let cppNamespace = "::mlir::LLVM::cconv";
+}
+
+// Attribute constraint for CConvAttr that exposes the cconv::CConv enum value.
+def CConv : DialectAttr<LLVM_Dialect,
+                        CPred<"$_self.isa<::mlir::LLVM::CConvAttr>()">,
+                        "LLVM Calling Convention specification"> {
+  let returnType = "::mlir::LLVM::cconv::CConv";
+  let convertFromStorage = "$_self.getCConv()";
+  let storageType = "::mlir::LLVM::CConvAttr";
+  let constBuilderCall =
+      "::mlir::LLVM::CConvAttr::get($_builder.getContext(), $0)";
+}
+
 class LLVM_Builder<string builder> {
   string llvmBuilder = builder;
 }
@@ -1233,6 +1361,7 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
     TypeAttrOf<LLVM_FunctionType>:$function_type,
     DefaultValuedAttr<Linkage, "Linkage::External">:$linkage,
     UnitAttr:$dso_local,
+    DefaultValuedAttr<CConv, "CConv::C">:$CConv,
     OptionalAttr<FlatSymbolRefAttr>:$personality,
     OptionalAttr<StrAttr>:$garbageCollector,
     OptionalAttr<ArrayAttr>:$passthrough
@@ -1246,6 +1375,7 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
     OpBuilder<(ins "StringRef":$name, "Type":$type,
       CArg<"Linkage", "Linkage::External">:$linkage,
       CArg<"bool", "false">:$dsoLocal,
+      CArg<"CConv", "CConv::C">:$cconv,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs,
       CArg<"ArrayRef<DictionaryAttr>", "{}">:$argAttrs)>
   ];

diff  --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 3ffd52604fd78..f162f779922f8 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -139,7 +139,8 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
     prependResAttrsToArgAttrs(rewriter, attributes, funcOp.getNumArguments());
   auto wrapperFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
       loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
-      wrapperFuncType, LLVM::Linkage::External, /*dsoLocal*/ false, attributes);
+      wrapperFuncType, LLVM::Linkage::External, /*dsoLocal*/ false,
+      /*cconv*/ LLVM::CConv::C, attributes);
 
   OpBuilder::InsertionGuard guard(rewriter);
   rewriter.setInsertionPointToStart(wrapperFuncOp.addEntryBlock());
@@ -206,7 +207,8 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
   // Create the auxiliary function.
   auto wrapperFunc = builder.create<LLVM::LLVMFuncOp>(
       loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
-      wrapperType, LLVM::Linkage::External, /*dsoLocal*/ false, attributes);
+      wrapperType, LLVM::Linkage::External, /*dsoLocal*/ false,
+      /*cconv*/ LLVM::CConv::C, attributes);
 
   builder.setInsertionPointToStart(newFuncOp.addEntryBlock());
 
@@ -345,7 +347,7 @@ struct FuncOpConversionBase : public ConvertOpToLLVMPattern<func::FuncOp> {
     }
     auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
         funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
-        /*dsoLocal*/ false, attributes);
+        /*dsoLocal*/ false, /*cconv*/ LLVM::CConv::C, attributes);
     rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
                                 newFuncOp.end());
     if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter,

diff  --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 6cb5a5bdf2e02..85d1a5234b8f2 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -68,7 +68,8 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
     attributes.emplace_back(kernelAttributeName, rewriter.getUnitAttr());
   auto llvmFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
       gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType,
-      LLVM::Linkage::External, /*dsoLocal*/ false, attributes);
+      LLVM::Linkage::External, /*dsoLocal*/ false, /*cconv*/ LLVM::CConv::C,
+      attributes);
 
   {
     // Insert operations that correspond to converted workgroup and private

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 958142b03d245..41d0521b981c2 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -37,6 +37,7 @@
 
 using namespace mlir;
 using namespace mlir::LLVM;
+using mlir::LLVM::cconv::getMaxEnumValForCConv;
 using mlir::LLVM::linkage::getMaxEnumValForLinkage;
 
 #include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc"
@@ -1821,6 +1822,7 @@ struct EnumTraits {};
 
 REGISTER_ENUM_TYPE(Linkage);
 REGISTER_ENUM_TYPE(UnnamedAddr);
+REGISTER_ENUM_TYPE(CConv);
 } // namespace
 
 /// Parse an enum from the keyword, or default to the provided default value.
@@ -2124,7 +2126,8 @@ Block *LLVMFuncOp::addEntryBlock() {
 
 void LLVMFuncOp::build(OpBuilder &builder, OperationState &result,
                        StringRef name, Type type, LLVM::Linkage linkage,
-                       bool dsoLocal, ArrayRef<NamedAttribute> attrs,
+                       bool dsoLocal, CConv cconv,
+                       ArrayRef<NamedAttribute> attrs,
                        ArrayRef<DictionaryAttr> argAttrs) {
   result.addRegion();
   result.addAttribute(SymbolTable::getSymbolAttrName(),
@@ -2133,6 +2136,8 @@ void LLVMFuncOp::build(OpBuilder &builder, OperationState &result,
                       TypeAttr::get(type));
   result.addAttribute(getLinkageAttrName(result.name),
                       LinkageAttr::get(builder.getContext(), linkage));
+  result.addAttribute(getCConvAttrName(result.name),
+                      CConvAttr::get(builder.getContext(), cconv));
   result.attributes.append(attrs.begin(), attrs.end());
   if (dsoLocal)
     result.addAttribute("dso_local", builder.getUnitAttr());
@@ -2185,7 +2190,8 @@ buildLLVMFunctionType(OpAsmParser &parser, SMLoc loc, ArrayRef<Type> inputs,
 
 // Parses an LLVM function.
 //
-// operation ::= `llvm.func` linkage? function-signature function-attributes?
+// operation ::= `llvm.func` linkage? cconv? function-signature
+//               function-attributes?
 //               function-body
 //
 ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -2196,6 +2202,12 @@ ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) {
                        parseOptionalLLVMKeyword<Linkage>(
                            parser, result, LLVM::Linkage::External)));
 
+  // Default to C Calling Convention if no keyword is provided.
+  result.addAttribute(
+      getCConvAttrName(result.name),
+      CConvAttr::get(parser.getContext(), parseOptionalLLVMKeyword<CConv>(
+                                              parser, result, LLVM::CConv::C)));
+
   StringAttr nameAttr;
   SmallVector<OpAsmParser::Argument> entryArgs;
   SmallVector<DictionaryAttr> resultAttrs;
@@ -2239,6 +2251,9 @@ void LLVMFuncOp::print(OpAsmPrinter &p) {
   p << ' ';
   if (getLinkage() != LLVM::Linkage::External)
     p << stringifyLinkage(getLinkage()) << ' ';
+  if (getCConv() != LLVM::CConv::C)
+    p << stringifyCConv(getCConv()) << ' ';
+
   p.printSymbolName(getName());
 
   LLVMFunctionType fnType = getFunctionType();
@@ -2255,7 +2270,8 @@ void LLVMFuncOp::print(OpAsmPrinter &p) {
   function_interface_impl::printFunctionSignature(p, *this, argTypes,
                                                   isVarArg(), resTypes);
   function_interface_impl::printFunctionAttributes(
-      p, *this, argTypes.size(), resTypes.size(), {getLinkageAttrName()});
+      p, *this, argTypes.size(), resTypes.size(),
+      {getLinkageAttrName(), getCConvAttrName()});
 
   // Print the body if this is not an external function.
   Region &body = getBody();
@@ -2645,7 +2661,7 @@ OpFoldResult LLVM::GEPOp::fold(ArrayRef<Attribute> operands) {
 //===----------------------------------------------------------------------===//
 
 void LLVMDialect::initialize() {
-  addAttributes<FMFAttr, LinkageAttr, LoopOptionsAttr>();
+  addAttributes<FMFAttr, LinkageAttr, CConvAttr, LoopOptionsAttr>();
 
   // clang-format off
   addTypes<LLVMVoidType,
@@ -2940,6 +2956,31 @@ Attribute LinkageAttr::parse(AsmParser &parser, Type type) {
   return LinkageAttr::get(parser.getContext(), linkage);
 }
 
+void CConvAttr::print(AsmPrinter &printer) const {
+  printer << "<";
+  // CConv values are sparse; only values with an enum case have a name.
+  StringRef convName = stringifyEnum(getCConv());
+  if (!convName.empty())
+    printer << convName;
+  else
+    printer << "INVALID_cc_" << static_cast<uint64_t>(getCConv());
+  printer << ">";
+}
+
+Attribute CConvAttr::parse(AsmParser &parser, Type type) {
+  StringRef convName;
+  if (parser.parseLess() || parser.parseKeyword(&convName) ||
+      parser.parseGreater())
+    return {};
+  auto cconv = cconv::symbolizeCConv(convName);
+  if (!cconv) {
+    parser.emitError(parser.getNameLoc(), "unknown calling convention: ")
+        << convName;
+    return {};
+  }
+  return CConvAttr::get(parser.getContext(), *cconv);
+}
+
 LoopOptionsAttrBuilder::LoopOptionsAttrBuilder(LoopOptionsAttr attr)
     : options(attr.getOptions().begin(), attr.getOptions().end()) {}
 

diff  --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
index ab69c6a002751..eced5896a167e 100644
--- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
+++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
@@ -1139,10 +1139,13 @@ LogicalResult Importer::processFunction(llvm::Function *f) {
   if (!functionType)
     return failure();
 
+  // isDSOLocal() also covers non-local functions marked `dso_local`.
+  bool dsoLocal = f->isDSOLocal();
+  CConv cconv = convertCConvFromLLVM(f->getCallingConv());
   b.setInsertionPoint(module.getBody(), getFuncInsertPt());
-  LLVMFuncOp fop =
-      b.create<LLVMFuncOp>(UnknownLoc::get(context), f->getName(), functionType,
-                           convertLinkageFromLLVM(f->getLinkage()));
+  LLVMFuncOp fop = b.create<LLVMFuncOp>(
+      UnknownLoc::get(context), f->getName(), functionType,
+      convertLinkageFromLLVM(f->getLinkage()), dsoLocal, cconv);
 
   if (FlatSymbolRefAttr personality = getPersonalityAsAttr(f))
     fop->setAttr(b.getStringAttr("personality"), personality);

diff  --git a/mlir/test/Dialect/LLVMIR/func.mlir b/mlir/test/Dialect/LLVMIR/func.mlir
index 810622da275ee..852ff1f6191f6 100644
--- a/mlir/test/Dialect/LLVMIR/func.mlir
+++ b/mlir/test/Dialect/LLVMIR/func.mlir
@@ -144,6 +144,21 @@ module {
       -> (!llvm.struct<(i32)> {llvm.struct_attrs = [{llvm.noalias}]}) {
     llvm.return %arg0 : !llvm.struct<(i32)>
   }
+
+  // CHECK: llvm.func @cconv1
+  llvm.func ccc @cconv1() {
+    llvm.return
+  }
+
+  // CHECK: llvm.func weak @cconv2
+  llvm.func weak ccc @cconv2() {
+    llvm.return
+  }
+
+  // CHECK: llvm.func weak fastcc @cconv3
+  llvm.func weak fastcc @cconv3() {
+    llvm.return
+  }
 }
 
 // -----
@@ -251,3 +266,18 @@ module {
   // expected-error at +1 {{functions cannot have 'common' linkage}}
   llvm.func common @common_linkage_func()
 }
+
+// -----
+
+module {
+  // expected-error at +1 {{custom op 'llvm.func' expected valid '@'-identifier for symbol name}}
+  llvm.func cc_12 @unknown_calling_convention()
+}
+
+// -----
+
+module {
+  // expected-error at +2 {{unknown calling convention: cc_12}}
+  "llvm.func"() ({
+  }) {sym_name = "generic_unknown_calling_convention", CConv = #llvm.cconv<cc_12>, function_type = !llvm.func<i64 (i64, i64)>} : () -> ()
+}

diff  --git a/mlir/test/Target/LLVMIR/Import/basic.ll b/mlir/test/Target/LLVMIR/Import/basic.ll
index 6b74e7da00c6c..3691448eb6ee1 100644
--- a/mlir/test/Target/LLVMIR/Import/basic.ll
+++ b/mlir/test/Target/LLVMIR/Import/basic.ll
@@ -122,8 +122,13 @@ define internal void @func_internal() {
 ; CHECK: llvm.func @fe(i32) -> f32
 declare float @fe(i32)
 
+; CHECK: llvm.func internal spir_funccc @spir_func_internal()
+define internal spir_func void @spir_func_internal() {
+  ret void
+}
+
 ; FIXME: function attributes.
-; CHECK-LABEL: llvm.func internal @f1(%arg0: i64) -> i32 {
+; CHECK-LABEL: llvm.func internal @f1(%arg0: i64) -> i32 attributes {dso_local} {
 ; CHECK-DAG: %[[c2:[0-9]+]] = llvm.mlir.constant(2 : i32) : i32
 ; CHECK-DAG: %[[c42:[0-9]+]] = llvm.mlir.constant(42 : i32) : i32
 ; CHECK-DAG: %[[c1:[0-9]+]] = llvm.mlir.constant(true) : i1

diff  --git a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
index e098e89da39d8..50eb2e7978af5 100644
--- a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
+++ b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
@@ -210,6 +210,27 @@ class LLVMEnumAttr : public tblgen::EnumAttr {
     return cases;
   }
 };
+
+// Wrapper class around a TableGen definition of a C-style LLVM enum attribute.
+class LLVMCEnumAttr : public tblgen::EnumAttr {
+public:
+  using tblgen::EnumAttr::EnumAttr;
+
+  // Returns the C++ scope (class or namespace) qualifying the LLVM enumerants,
+  // as recorded in the `llvmClassName` field of the TableGen definition.
+  StringRef getLLVMClassName() const {
+    return def->getValueAsString("llvmClassName");
+  }
+
+  // Returns all cases of this enum wrapped as LLVMEnumAttrCase, which exposes
+  // the LLVM-side enumerant name.
+  std::vector<LLVMEnumAttrCase> getAllCases() const {
+    std::vector<LLVMEnumAttrCase> llvmCases;
+    for (auto &baseCase : tblgen::EnumAttr::getAllCases())
+      llvmCases.emplace_back(baseCase);
+    return llvmCases;
+  }
+};
 } // namespace
 
 // Emits conversion function "LLVMClass convertEnumToLLVM(Enum)" and containing
@@ -242,6 +263,37 @@ static void emitOneEnumToConversion(const llvm::Record *record,
   os << "}\n\n";
 }
 
+// Emits conversion function "int64_t convertEnumToLLVM(Enum)" and containing
+// switch-based logic to convert from the MLIR LLVM dialect enum attribute case
+// (Enum) to the corresponding LLVM API C-style enumerant, widened to int64_t.
+static void emitOneCEnumToConversion(const llvm::Record *record,
+                                     raw_ostream &os) {
+  LLVMCEnumAttr enumAttr(record);
+  StringRef llvmClass = enumAttr.getLLVMClassName();
+  StringRef cppClassName = enumAttr.getEnumClassName();
+  StringRef cppNamespace = enumAttr.getCppNamespace();
+
+  // Emit the function converting the enum attribute to its LLVM counterpart.
+  os << formatv("static LLVM_ATTRIBUTE_UNUSED int64_t "
+                "convert{0}ToLLVM({1}::{0} value) {{\n",
+                cppClassName, cppNamespace);
+  os << "  switch (value) {\n";
+
+  for (const auto &enumerant : enumAttr.getAllCases()) {
+    StringRef llvmEnumerant = enumerant.getLLVMEnumerant();
+    StringRef cppEnumerant = enumerant.getSymbol();
+    os << formatv("  case {0}::{1}::{2}:\n", cppNamespace, cppClassName,
+                  cppEnumerant);
+    os << formatv("    return static_cast<int64_t>({0}::{1});\n", llvmClass,
+                  llvmEnumerant);
+  }
+
+  os << "  }\n";
+  os << formatv("  llvm_unreachable(\"unknown {0} type\");\n",
+                cppClassName);
+  os << "}\n\n";
+}
+
 // Emits conversion function "Enum convertEnumFromLLVM(LLVMClass)" and
 // containing switch-based logic to convert from the LLVM API enumerant to MLIR
 // LLVM dialect enum attribute (Enum).
@@ -272,6 +324,38 @@ static void emitOneEnumFromConversion(const llvm::Record *record,
   os << "}\n\n";
 }
 
+// Emits conversion function "Enum convertEnumFromLLVM(int64_t)" and
+// containing switch-based logic to convert from the LLVM API C-style enumerant
+// to MLIR LLVM dialect enum attribute (Enum).
+// NOTE(review): values with no matching case fall through to llvm_unreachable,
+// so callers must only pass values that are listed in the enum definition.
+static void emitOneCEnumFromConversion(const llvm::Record *record,
+                                       raw_ostream &os) {
+  LLVMCEnumAttr enumAttr(record);
+  StringRef llvmClass = enumAttr.getLLVMClassName();
+  StringRef cppClassName = enumAttr.getEnumClassName();
+  StringRef cppNamespace = enumAttr.getCppNamespace();
+
+  // Emit the function converting the enum attribute from its LLVM counterpart.
+  os << formatv("inline LLVM_ATTRIBUTE_UNUSED {0}::{1} "
+                "convert{1}FromLLVM(int64_t value) {{\n",
+                cppNamespace, cppClassName);
+  os << "  switch (value) {\n";
+
+  for (const auto &enumerant : enumAttr.getAllCases()) {
+    StringRef llvmEnumerant = enumerant.getLLVMEnumerant();
+    StringRef cppEnumerant = enumerant.getSymbol();
+    os << formatv("  case static_cast<int64_t>({0}::{1}):\n", llvmClass,
+                  llvmEnumerant);
+    os << formatv("    return {0}::{1}::{2};\n", cppNamespace, cppClassName,
+                  cppEnumerant);
+  }
+
+  os << "  }\n";
+  os << formatv("  llvm_unreachable(\"unknown {0} type\");\n", llvmClass);
+  os << "}\n\n";
+}
+
 // Emits conversion functions between MLIR enum attribute case and corresponding
 // LLVM API enumerants for all registered LLVM dialect enum attributes.
 template <bool ConvertTo>
@@ -283,6 +367,13 @@ static bool emitEnumConversionDefs(const RecordKeeper &recordKeeper,
     else
       emitOneEnumFromConversion(def, os);
 
+  for (const auto *def :
+       recordKeeper.getAllDerivedDefinitions("LLVM_CEnumAttr"))
+    if (ConvertTo)
+      emitOneCEnumToConversion(def, os);
+    else
+      emitOneCEnumFromConversion(def, os);
+
   return false;
 }
 


        


More information about the Mlir-commits mailing list