[Mlir-commits] [mlir] [MLIR] Add support for calling conventions to LLVM::CallOp and LLVM::InvokeOp (PR #71319)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Nov 5 07:28:39 PST 2023


https://github.com/Sirraide created https://github.com/llvm/llvm-project/pull/71319

Although the LLVM dialect’s `FuncOp` already supports calling conventions, the ops that actually perform function calls did not support them yet. As a result, setting a `FuncOp`’s calling convention to anything other than `ccc` produced incorrect LLVM IR, because call sites were always emitted with the default C calling convention.

This commit adds support for calling conventions to `LLVM::CallOp` and `LLVM::InvokeOp`, makes sure that calling conventions are parsed, printed, and lowered appropriately, and verifies that a direct `llvm.call` uses the same calling convention as its callee.

>From 1761abcbe65146df90761bf75574a5ce43f1ba3d Mon Sep 17 00:00:00 2001
From: Sirraide <aeternalmail at gmail.com>
Date: Sun, 5 Nov 2023 16:19:00 +0100
Subject: [PATCH] [MLIR] Add support for calling conventions to
 LLVM::CallOp/InvokeOp
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Although the LLVM dialect’s FuncOp already supports calling
conventions, the ops that actually perform function calls did not
support them yet. As a result, setting a FuncOp’s calling
convention to anything other than ccc produced incorrect LLVM IR,
because call sites were always emitted with the default C calling
convention.

This commit adds support for calling conventions to LLVM::CallOp
and LLVM::InvokeOp, makes sure that calling conventions are
parsed, printed, and lowered appropriately, and verifies that a
direct llvm.call uses the same calling convention as its callee.
---
 mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td   |   6 +-
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp    | 147 +++++++++++-------
 .../LLVMIR/LLVMToLLVMIRTranslation.cpp        |   4 +-
 mlir/test/Dialect/LLVMIR/inlining.mlir        |   2 +-
 mlir/test/Dialect/LLVMIR/invalid.mlir         |  10 ++
 mlir/test/Target/LLVMIR/llvmir.mlir           |  25 +++
 6 files changed, 133 insertions(+), 61 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index c8549f146d0297a..16bd478cc0a8685 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -552,7 +552,8 @@ def LLVM_InvokeOp : LLVM_Op<"invoke", [
                    Variadic<LLVM_Type>:$callee_operands,
                    Variadic<LLVM_Type>:$normalDestOperands,
                    Variadic<LLVM_Type>:$unwindDestOperands,
-                   OptionalAttr<DenseI32ArrayAttr>:$branch_weights);
+                   OptionalAttr<DenseI32ArrayAttr>:$branch_weights,
+                   DefaultValuedAttr<CConv, "CConv::C">:$CConv);
   let results = (outs Variadic<LLVM_Type>);
   let successors = (successor AnySuccessor:$normalDest,
                               AnySuccessor:$unwindDest);
@@ -633,7 +634,8 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
                   Variadic<LLVM_Type>:$callee_operands,
                   DefaultValuedAttr<LLVM_FastmathFlagsAttr,
                                    "{}">:$fastmathFlags,
-                  OptionalAttr<DenseI32ArrayAttr>:$branch_weights);
+                  OptionalAttr<DenseI32ArrayAttr>:$branch_weights,
+                  DefaultValuedAttr<CConv, "CConv::C">:$CConv);
   // Append the aliasing related attributes defined in LLVM_MemAccessOpBase.
   let arguments = !con(args, aliasAttrs);
   let results = (outs Optional<LLVM_Type>:$result);
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 7f5681e7bdc0592..a670c2b03f5d303 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -97,6 +97,52 @@ static Type getI1SameShape(Type type) {
   return i1Type;
 }
 
+// Parses one of the keywords provided in the list `keywords` and returns the
+// position of the parsed keyword in the list. If none of the keywords from the
+// list is parsed, returns -1.
+static int parseOptionalKeywordAlternative(OpAsmParser &parser,
+                                           ArrayRef<StringRef> keywords) {
+  for (const auto &en : llvm::enumerate(keywords)) {
+    if (succeeded(parser.parseOptionalKeyword(en.value())))
+      return en.index();
+  }
+  return -1;
+}
+
+namespace {
+template <typename Ty>
+struct EnumTraits {};
+
+#define REGISTER_ENUM_TYPE(Ty)                                                 \
+  template <>                                                                  \
+  struct EnumTraits<Ty> {                                                      \
+    static StringRef stringify(Ty value) { return stringify##Ty(value); }      \
+    static unsigned getMaxEnumVal() { return getMaxEnumValFor##Ty(); }         \
+  }
+
+REGISTER_ENUM_TYPE(Linkage);
+REGISTER_ENUM_TYPE(UnnamedAddr);
+REGISTER_ENUM_TYPE(CConv);
+REGISTER_ENUM_TYPE(Visibility);
+} // namespace
+
+/// Parse an enum from the keyword, or default to the provided default value.
+/// The return type is the enum type by default, unless overridden with the
+/// second template argument.
+template <typename EnumTy, typename RetTy = EnumTy>
+static RetTy parseOptionalLLVMKeyword(OpAsmParser &parser,
+                                      OperationState &result,
+                                      EnumTy defaultValue) {
+  SmallVector<StringRef, 10> names;
+  for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i)
+    names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));
+
+  int index = parseOptionalKeywordAlternative(parser, names);
+  if (index == -1)
+    return static_cast<RetTy>(defaultValue);
+  return static_cast<RetTy>(index);
+}
+
 //===----------------------------------------------------------------------===//
 // Printing, parsing, folding and builder for LLVM::CmpOp.
 //===----------------------------------------------------------------------===//
@@ -1012,6 +1058,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
   build(builder, state, results,
         TypeAttr::get(getLLVMFuncType(builder.getContext(), results, args)),
         callee, args, /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
+        /*CConv=*/nullptr,
         /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
         /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
 }
@@ -1033,7 +1080,8 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
                    ValueRange args) {
   build(builder, state, getCallOpResultTypes(calleeType),
         TypeAttr::get(calleeType), callee, args, /*fastmathFlags=*/nullptr,
-        /*branch_weights=*/nullptr, /*access_groups=*/nullptr,
+        /*branch_weights=*/nullptr, /*CConv=*/nullptr,
+        /*access_groups=*/nullptr,
         /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
 }
 
@@ -1042,6 +1090,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
   build(builder, state, getCallOpResultTypes(calleeType),
         TypeAttr::get(calleeType), /*callee=*/nullptr, args,
         /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
+        /*CConv=*/nullptr,
         /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
         /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
 }
@@ -1052,6 +1101,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
   build(builder, state, getCallOpResultTypes(calleeType),
         TypeAttr::get(calleeType), SymbolRefAttr::get(func), args,
         /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
+        /*CConv=*/nullptr,
         /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
         /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
 }
@@ -1067,6 +1117,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, Value callee,
   return build(builder, state, getCallOpResultTypes(calleeType),
                TypeAttr::get(calleeType), FlatSymbolRefAttr(), operands,
                /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
+               /*CConv=*/nullptr,
                /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
                /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
 }
@@ -1161,6 +1212,13 @@ LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
       return emitOpError() << "'" << calleeName.getValue()
                            << "' does not reference a valid LLVM function";
 
+    if (fn.getCConv() != getCConv()) {
+      auto fnCConv = stringifyCConv(fn.getCConv());
+      auto callCConv = stringifyCConv(getCConv());
+      return emitOpError() << "calling convention mismatch: " << callCConv
+                           << " != " << fnCConv;
+    }
+
     if (failed(verifyCallOpDebugInfo(*this, fn)))
       return failure();
     fnType = fn.getFunctionType();
@@ -1226,9 +1284,14 @@ void CallOp::print(OpAsmPrinter &p) {
     isVarArg = calleeType.isVarArg();
   }
 
+  p << ' ';
+
+  // Print the calling convention; the default C convention is elided.
+  if (getCConv() != LLVM::CConv::C)
+    p << stringifyCConv(getCConv()) << ' ';
+
   // Print the direct callee if present as a function attribute, or an indirect
   // callee (first operand) otherwise.
-  p << ' ';
   if (isDirect)
     p.printSymbolName(callee.value());
   else
@@ -1241,7 +1304,7 @@ void CallOp::print(OpAsmPrinter &p) {
     p << " vararg(" << calleeType << ")";
 
   p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()),
-                          {"callee", "callee_type"});
+                          {getCConvAttrName(), "callee", "callee_type"});
 
   p << " : ";
   if (!isDirect)
@@ -1309,7 +1372,7 @@ static ParseResult parseOptionalCallFuncPtr(
   return success();
 }
 
-// <operation> ::= `llvm.call` (function-id | ssa-use)
+// <operation> ::= `llvm.call` (cconv)? (function-id | ssa-use)
 //                             `(` ssa-use-list `)`
 //                             ( `vararg(` var-arg-func-type `)` )?
 //                             attribute-dict? `:` (type `,`)? function-type
@@ -1318,6 +1381,12 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
   TypeAttr calleeType;
   SmallVector<OpAsmParser::UnresolvedOperand> operands;
 
+  // Parse the optional calling convention keyword; default to C if absent.
+  result.addAttribute(
+      getCConvAttrName(result.name),
+      CConvAttr::get(parser.getContext(), parseOptionalLLVMKeyword<CConv>(
+                                              parser, result, LLVM::CConv::C)));
+
   // Parse a function pointer for indirect calls.
   if (parseOptionalCallFuncPtr(parser, operands))
     return failure();
@@ -1363,7 +1432,7 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
   auto calleeType = func.getFunctionType();
   build(builder, state, getCallOpResultTypes(calleeType),
         TypeAttr::get(calleeType), SymbolRefAttr::get(func), ops, normalOps,
-        unwindOps, nullptr, normal, unwind);
+        unwindOps, nullptr, /*CConv=*/nullptr, normal, unwind);
 }
 
 void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys,
@@ -1372,7 +1441,7 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys,
                      ValueRange unwindOps) {
   build(builder, state, tys,
         TypeAttr::get(getLLVMFuncType(builder.getContext(), tys, ops)), callee,
-        ops, normalOps, unwindOps, nullptr, normal, unwind);
+        ops, normalOps, unwindOps, nullptr, /*CConv=*/nullptr, normal, unwind);
 }
 
 void InvokeOp::build(OpBuilder &builder, OperationState &state,
@@ -1381,7 +1450,7 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state,
                      Block *unwind, ValueRange unwindOps) {
   build(builder, state, getCallOpResultTypes(calleeType),
         TypeAttr::get(calleeType), callee, ops, normalOps, unwindOps, nullptr,
-        normal, unwind);
+        /*CConv=*/nullptr, normal, unwind);
 }
 
 SuccessorOperands InvokeOp::getSuccessorOperands(unsigned index) {
@@ -1447,6 +1516,10 @@ void InvokeOp::print(OpAsmPrinter &p) {
 
   p << ' ';
 
+  // Print the calling convention; the default C convention is elided.
+  if (getCConv() != LLVM::CConv::C)
+    p << stringifyCConv(getCConv()) << ' ';
+
   // Either function name or pointer
   if (isDirect)
     p.printSymbolName(callee.value());
@@ -1462,9 +1535,9 @@ void InvokeOp::print(OpAsmPrinter &p) {
   if (isVarArg)
     p << " vararg(" << calleeType << ")";
 
-  p.printOptionalAttrDict(
-      (*this)->getAttrs(),
-      {InvokeOp::getOperandSegmentSizeAttr(), "callee", "callee_type"});
+  p.printOptionalAttrDict((*this)->getAttrs(),
+                          {InvokeOp::getOperandSegmentSizeAttr(), "callee",
+                           "callee_type", InvokeOp::getCConvAttrName()});
 
   p << " : ";
   if (!isDirect)
@@ -1473,7 +1546,7 @@ void InvokeOp::print(OpAsmPrinter &p) {
                          getResultTypes());
 }
 
-// <operation> ::= `llvm.invoke` (function-id | ssa-use)
+// <operation> ::= `llvm.invoke` (cconv)? (function-id | ssa-use)
 //                  `(` ssa-use-list `)`
 //                  `to` bb-id (`[` ssa-use-and-type-list `]`)?
 //                  `unwind` bb-id (`[` ssa-use-and-type-list `]`)?
@@ -1487,6 +1560,12 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
   SmallVector<Value, 4> normalOperands, unwindOperands;
   Builder &builder = parser.getBuilder();
 
+  // Parse the optional calling convention keyword; default to C if absent.
+  result.addAttribute(
+      getCConvAttrName(result.name),
+      CConvAttr::get(parser.getContext(), parseOptionalLLVMKeyword<CConv>(
+                                              parser, result, LLVM::CConv::C)));
+
   // Parse a function pointer for indirect calls.
   if (parseOptionalCallFuncPtr(parser, operands))
     return failure();
@@ -1971,52 +2050,6 @@ void GlobalOp::print(OpAsmPrinter &p) {
   }
 }
 
-// Parses one of the keywords provided in the list `keywords` and returns the
-// position of the parsed keyword in the list. If none of the keywords from the
-// list is parsed, returns -1.
-static int parseOptionalKeywordAlternative(OpAsmParser &parser,
-                                           ArrayRef<StringRef> keywords) {
-  for (const auto &en : llvm::enumerate(keywords)) {
-    if (succeeded(parser.parseOptionalKeyword(en.value())))
-      return en.index();
-  }
-  return -1;
-}
-
-namespace {
-template <typename Ty>
-struct EnumTraits {};
-
-#define REGISTER_ENUM_TYPE(Ty)                                                 \
-  template <>                                                                  \
-  struct EnumTraits<Ty> {                                                      \
-    static StringRef stringify(Ty value) { return stringify##Ty(value); }      \
-    static unsigned getMaxEnumVal() { return getMaxEnumValFor##Ty(); }         \
-  }
-
-REGISTER_ENUM_TYPE(Linkage);
-REGISTER_ENUM_TYPE(UnnamedAddr);
-REGISTER_ENUM_TYPE(CConv);
-REGISTER_ENUM_TYPE(Visibility);
-} // namespace
-
-/// Parse an enum from the keyword, or default to the provided default value.
-/// The return type is the enum type by default, unless overriden with the
-/// second template argument.
-template <typename EnumTy, typename RetTy = EnumTy>
-static RetTy parseOptionalLLVMKeyword(OpAsmParser &parser,
-                                      OperationState &result,
-                                      EnumTy defaultValue) {
-  SmallVector<StringRef, 10> names;
-  for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i)
-    names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));
-
-  int index = parseOptionalKeywordAlternative(parser, names);
-  if (index == -1)
-    return static_cast<RetTy>(defaultValue);
-  return static_cast<RetTy>(index);
-}
-
 static LogicalResult verifyComdat(Operation *op,
                                   std::optional<SymbolRefAttr> attr) {
   if (!attr)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index 1c0f51a66bf5e8c..5494a13acb6e1b6 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -200,6 +200,7 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
       call = builder.CreateCall(calleeType, operandsRef.front(),
                                 operandsRef.drop_front());
     }
+    call->setCallingConv(convertCConvToLLVM(callOp.getCConv()));
     moduleTranslation.setAccessGroupsMetadata(callOp, call);
     moduleTranslation.setAliasScopeMetadata(callOp, call);
     moduleTranslation.setTBAAMetadata(callOp, call);
@@ -275,7 +276,7 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
   if (auto invOp = dyn_cast<LLVM::InvokeOp>(opInst)) {
     auto operands = moduleTranslation.lookupValues(invOp.getCalleeOperands());
     ArrayRef<llvm::Value *> operandsRef(operands);
-    llvm::Instruction *result;
+    llvm::InvokeInst *result;
     if (auto attr = opInst.getAttrOfType<FlatSymbolRefAttr>("callee")) {
       result = builder.CreateInvoke(
           moduleTranslation.lookupFunction(attr.getValue()),
@@ -290,6 +291,7 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
           moduleTranslation.lookupBlock(invOp.getSuccessor(1)),
           operandsRef.drop_front());
     }
+    result->setCallingConv(convertCConvToLLVM(invOp.getCConv()));
     moduleTranslation.mapBranch(invOp, result);
     // InvokeOp can only have 0 or 1 result
     if (invOp->getNumResults() != 0) {
diff --git a/mlir/test/Dialect/LLVMIR/inlining.mlir b/mlir/test/Dialect/LLVMIR/inlining.mlir
index 1296b8e031c1330..b684be1f9626b1c 100644
--- a/mlir/test/Dialect/LLVMIR/inlining.mlir
+++ b/mlir/test/Dialect/LLVMIR/inlining.mlir
@@ -84,7 +84,7 @@ llvm.func internal fastcc @callee() -> (i32) attributes { function_entry_count =
 // CHECK-NEXT: llvm.return %[[CST]]
 llvm.func @caller() -> (i32) {
   // Include all call attributes that don't prevent inlining.
-  %0 = llvm.call @callee() { fastmathFlags = #llvm.fastmath<nnan, ninf>, branch_weights = dense<42> : vector<1xi32> } : () -> (i32)
+  %0 = llvm.call fastcc @callee() { fastmathFlags = #llvm.fastmath<nnan, ninf>, branch_weights = dense<42> : vector<1xi32> } : () -> (i32)
   llvm.return %0 : i32
 }
 
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index fe2f94454561a08..09358493b2c56cd 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -349,6 +349,16 @@ func.func @call_too_many_results(%callee : !llvm.ptr) {
 
 // -----
 
+llvm.func fastcc @callee_cconv_mismatch()
+
+llvm.func @call_cconv_mismatch() {
+  // expected-error@+1 {{calling convention mismatch: ccc != fastcc}}
+  llvm.call @callee_cconv_mismatch() : () -> ()
+  llvm.return
+}
+
+// -----
+
 llvm.func @void_func_result(%arg0: i32) {
   // expected-error@below {{expected no operands}}
   // expected-note@above {{when returning from function}}
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index 3f27b247edd3e67..2b510a2941c62d0 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -329,6 +329,31 @@ llvm.func @func_args(%arg0: i32, %arg1: i32) -> i32 {
   llvm.return %12 : i32
 }
 
+// CHECK: declare fastcc void @cconv_fastcc()
+// CHECK: declare        void @cconv_ccc()
+// CHECK: declare tailcc void @cconv_tailcc()
+// CHECK: declare ghccc  void @cconv_ghccc()
+llvm.func fastcc @cconv_fastcc()
+llvm.func ccc    @cconv_ccc()
+llvm.func tailcc @cconv_tailcc()
+llvm.func cc_10  @cconv_ghccc()
+
+// CHECK: define void @test_ccs() {
+llvm.func @test_ccs() {
+  // CHECK-NEXT: call fastcc void @cconv_fastcc()
+  // CHECK-NEXT: call        void @cconv_ccc()
+  // CHECK-NEXT: call        void @cconv_ccc()
+  // CHECK-NEXT: call tailcc void @cconv_tailcc()
+  // CHECK-NEXT: call ghccc  void @cconv_ghccc()
+  // CHECK-NEXT: ret void
+  llvm.call fastcc @cconv_fastcc() : () -> ()
+  llvm.call ccc    @cconv_ccc()    : () -> ()
+  llvm.call        @cconv_ccc()    : () -> ()
+  llvm.call tailcc @cconv_tailcc() : () -> ()
+  llvm.call cc_10  @cconv_ghccc()  : () -> ()
+  llvm.return
+}
+
 // CHECK: declare void @pre(i64)
 llvm.func @pre(i64)
 



More information about the Mlir-commits mailing list