[Mlir-commits] [mlir] [MLIR] Add support for calling conventions to LLVM::CallOp and LLVM::InvokeOp (PR #71319)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Nov 5 09:45:35 PST 2023


https://github.com/Sirraide updated https://github.com/llvm/llvm-project/pull/71319

>From 1761abcbe65146df90761bf75574a5ce43f1ba3d Mon Sep 17 00:00:00 2001
From: Sirraide <aeternalmail at gmail.com>
Date: Sun, 5 Nov 2023 16:19:00 +0100
Subject: [PATCH 1/5] [MLIR] Add support for calling conventions to
 LLVM::CallOp/InvokeOp
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Although the LLVM dialect’s FuncOp already supports calling conventions,
the ops that actually perform function calls did not support them yet.
As a result, setting a FuncOp’s calling convention to anything other
than ccc produced incorrect LLVM IR, since calls to that function were
still emitted with the default C calling convention.

This commit adds support for calling conventions to LLVM::CallOp
and LLVM::InvokeOp and makes sure that calling conventions are
parsed, printed, and lowered appropriately.
---
 mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td   |   6 +-
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp    | 147 +++++++++++-------
 .../LLVMIR/LLVMToLLVMIRTranslation.cpp        |   4 +-
 mlir/test/Dialect/LLVMIR/inlining.mlir        |   2 +-
 mlir/test/Dialect/LLVMIR/invalid.mlir         |  10 ++
 mlir/test/Target/LLVMIR/llvmir.mlir           |  25 +++
 6 files changed, 133 insertions(+), 61 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index c8549f146d0297a..16bd478cc0a8685 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -552,7 +552,8 @@ def LLVM_InvokeOp : LLVM_Op<"invoke", [
                    Variadic<LLVM_Type>:$callee_operands,
                    Variadic<LLVM_Type>:$normalDestOperands,
                    Variadic<LLVM_Type>:$unwindDestOperands,
-                   OptionalAttr<DenseI32ArrayAttr>:$branch_weights);
+                   OptionalAttr<DenseI32ArrayAttr>:$branch_weights,
+                   DefaultValuedAttr<CConv, "CConv::C">:$CConv);
   let results = (outs Variadic<LLVM_Type>);
   let successors = (successor AnySuccessor:$normalDest,
                               AnySuccessor:$unwindDest);
@@ -633,7 +634,8 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
                   Variadic<LLVM_Type>:$callee_operands,
                   DefaultValuedAttr<LLVM_FastmathFlagsAttr,
                                    "{}">:$fastmathFlags,
-                  OptionalAttr<DenseI32ArrayAttr>:$branch_weights);
+                  OptionalAttr<DenseI32ArrayAttr>:$branch_weights,
+                  DefaultValuedAttr<CConv, "CConv::C">:$CConv);
   // Append the aliasing related attributes defined in LLVM_MemAccessOpBase.
   let arguments = !con(args, aliasAttrs);
   let results = (outs Optional<LLVM_Type>:$result);
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 7f5681e7bdc0592..a670c2b03f5d303 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -97,6 +97,52 @@ static Type getI1SameShape(Type type) {
   return i1Type;
 }
 
+/// Tries each keyword in `keywords` in order and parses the first one that
+/// the parser accepts, returning its position in the list. Returns -1 if
+/// none of the keywords from the list is parsed.
+static int parseOptionalKeywordAlternative(OpAsmParser &parser,
+                                           ArrayRef<StringRef> keywords) {
+  for (const auto &en : llvm::enumerate(keywords)) {
+    if (succeeded(parser.parseOptionalKeyword(en.value())))
+      return en.index();
+  }
+  return -1;
+}
+
+namespace {
+template <typename Ty>
+struct EnumTraits {};
+
+#define REGISTER_ENUM_TYPE(Ty)                                                 \
+  template <>                                                                  \
+  struct EnumTraits<Ty> {                                                      \
+    static StringRef stringify(Ty value) { return stringify##Ty(value); }      \
+    static unsigned getMaxEnumVal() { return getMaxEnumValFor##Ty(); }         \
+  }
+
+REGISTER_ENUM_TYPE(Linkage);
+REGISTER_ENUM_TYPE(UnnamedAddr);
+REGISTER_ENUM_TYPE(CConv);
+REGISTER_ENUM_TYPE(Visibility);
+} // namespace
+
+/// Parses an optional keyword naming an EnumTy value (0..getMaxEnumVal) and
+/// returns it, or `defaultValue` if none is present. The return type is EnumTy
+/// unless overridden via RetTy. Note: `result` is currently unused.
+template <typename EnumTy, typename RetTy = EnumTy>
+static RetTy parseOptionalLLVMKeyword(OpAsmParser &parser,
+                                      OperationState &result,
+                                      EnumTy defaultValue) {
+  SmallVector<StringRef, 10> names;
+  for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i)
+    names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));
+
+  int index = parseOptionalKeywordAlternative(parser, names);
+  if (index == -1)
+    return static_cast<RetTy>(defaultValue);
+  return static_cast<RetTy>(index);
+}
+
 //===----------------------------------------------------------------------===//
 // Printing, parsing, folding and builder for LLVM::CmpOp.
 //===----------------------------------------------------------------------===//
@@ -1012,6 +1058,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
   build(builder, state, results,
         TypeAttr::get(getLLVMFuncType(builder.getContext(), results, args)),
         callee, args, /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
+        /*CConv=*/nullptr,
         /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
         /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
 }
@@ -1033,7 +1080,8 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
                    ValueRange args) {
   build(builder, state, getCallOpResultTypes(calleeType),
         TypeAttr::get(calleeType), callee, args, /*fastmathFlags=*/nullptr,
-        /*branch_weights=*/nullptr, /*access_groups=*/nullptr,
+        /*branch_weights=*/nullptr, /*CConv=*/nullptr,
+        /*access_groups=*/nullptr,
         /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
 }
 
@@ -1042,6 +1090,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
   build(builder, state, getCallOpResultTypes(calleeType),
         TypeAttr::get(calleeType), /*callee=*/nullptr, args,
         /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
+        /*CConv=*/nullptr,
         /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
         /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
 }
@@ -1052,6 +1101,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
   build(builder, state, getCallOpResultTypes(calleeType),
         TypeAttr::get(calleeType), SymbolRefAttr::get(func), args,
         /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
+        /*CConv=*/nullptr,
         /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
         /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
 }
@@ -1067,6 +1117,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, Value callee,
   return build(builder, state, getCallOpResultTypes(calleeType),
                TypeAttr::get(calleeType), FlatSymbolRefAttr(), operands,
                /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
+               /*CConv=*/nullptr,
                /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
                /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
 }
@@ -1161,6 +1212,13 @@ LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
       return emitOpError() << "'" << calleeName.getValue()
                            << "' does not reference a valid LLVM function";
 
+    if (fn.getCConv() != getCConv()) {
+      auto fnCConv = stringifyCConv(fn.getCConv());
+      auto callCConv = stringifyCConv(getCConv());
+      return emitOpError() << "calling convention mismatch: " << callCConv
+                           << " != " << fnCConv;
+    }
+
     if (failed(verifyCallOpDebugInfo(*this, fn)))
       return failure();
     fnType = fn.getFunctionType();
@@ -1226,9 +1284,14 @@ void CallOp::print(OpAsmPrinter &p) {
     isVarArg = calleeType.isVarArg();
   }
 
+  p << ' ';
+
+  // Print calling convention.
+  if (getCConv() != LLVM::CConv::C)
+    p << stringifyCConv(getCConv()) << ' ';
+
   // Print the direct callee if present as a function attribute, or an indirect
   // callee (first operand) otherwise.
-  p << ' ';
   if (isDirect)
     p.printSymbolName(callee.value());
   else
@@ -1241,7 +1304,7 @@ void CallOp::print(OpAsmPrinter &p) {
     p << " vararg(" << calleeType << ")";
 
   p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()),
-                          {"callee", "callee_type"});
+                          {getCConvAttrName(), "callee", "callee_type"});
 
   p << " : ";
   if (!isDirect)
@@ -1309,7 +1372,7 @@ static ParseResult parseOptionalCallFuncPtr(
   return success();
 }
 
-// <operation> ::= `llvm.call` (function-id | ssa-use)
+// <operation> ::= `llvm.call` (cconv) (function-id | ssa-use)
 //                             `(` ssa-use-list `)`
 //                             ( `vararg(` var-arg-func-type `)` )?
 //                             attribute-dict? `:` (type `,`)? function-type
@@ -1318,6 +1381,12 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
   TypeAttr calleeType;
   SmallVector<OpAsmParser::UnresolvedOperand> operands;
 
+  // Default to C Calling Convention if no keyword is provided.
+  result.addAttribute(
+      getCConvAttrName(result.name),
+      CConvAttr::get(parser.getContext(), parseOptionalLLVMKeyword<CConv>(
+                                              parser, result, LLVM::CConv::C)));
+
   // Parse a function pointer for indirect calls.
   if (parseOptionalCallFuncPtr(parser, operands))
     return failure();
@@ -1363,7 +1432,7 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
   auto calleeType = func.getFunctionType();
   build(builder, state, getCallOpResultTypes(calleeType),
         TypeAttr::get(calleeType), SymbolRefAttr::get(func), ops, normalOps,
-        unwindOps, nullptr, normal, unwind);
+        unwindOps, nullptr, nullptr, normal, unwind);
 }
 
 void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys,
@@ -1372,7 +1441,7 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys,
                      ValueRange unwindOps) {
   build(builder, state, tys,
         TypeAttr::get(getLLVMFuncType(builder.getContext(), tys, ops)), callee,
-        ops, normalOps, unwindOps, nullptr, normal, unwind);
+        ops, normalOps, unwindOps, nullptr, nullptr, normal, unwind);
 }
 
 void InvokeOp::build(OpBuilder &builder, OperationState &state,
@@ -1381,7 +1450,7 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state,
                      Block *unwind, ValueRange unwindOps) {
   build(builder, state, getCallOpResultTypes(calleeType),
         TypeAttr::get(calleeType), callee, ops, normalOps, unwindOps, nullptr,
-        normal, unwind);
+        nullptr, normal, unwind);
 }
 
 SuccessorOperands InvokeOp::getSuccessorOperands(unsigned index) {
@@ -1447,6 +1516,10 @@ void InvokeOp::print(OpAsmPrinter &p) {
 
   p << ' ';
 
+  // Print calling convention.
+  if (getCConv() != LLVM::CConv::C)
+    p << stringifyCConv(getCConv()) << ' ';
+
   // Either function name or pointer
   if (isDirect)
     p.printSymbolName(callee.value());
@@ -1462,9 +1535,9 @@ void InvokeOp::print(OpAsmPrinter &p) {
   if (isVarArg)
     p << " vararg(" << calleeType << ")";
 
-  p.printOptionalAttrDict(
-      (*this)->getAttrs(),
-      {InvokeOp::getOperandSegmentSizeAttr(), "callee", "callee_type"});
+  p.printOptionalAttrDict((*this)->getAttrs(),
+                          {InvokeOp::getOperandSegmentSizeAttr(), "callee",
+                           "callee_type", InvokeOp::getCConvAttrName()});
 
   p << " : ";
   if (!isDirect)
@@ -1473,7 +1546,7 @@ void InvokeOp::print(OpAsmPrinter &p) {
                         getResultTypes());
 }
 
-// <operation> ::= `llvm.invoke` (function-id | ssa-use)
+// <operation> ::= `llvm.invoke` (cconv) (function-id | ssa-use)
 //                  `(` ssa-use-list `)`
 //                  `to` bb-id (`[` ssa-use-and-type-list `]`)?
 //                  `unwind` bb-id (`[` ssa-use-and-type-list `]`)?
@@ -1487,6 +1560,12 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
   SmallVector<Value, 4> normalOperands, unwindOperands;
   Builder &builder = parser.getBuilder();
 
+  // Default to C Calling Convention if no keyword is provided.
+  result.addAttribute(
+      getCConvAttrName(result.name),
+      CConvAttr::get(parser.getContext(), parseOptionalLLVMKeyword<CConv>(
+                                              parser, result, LLVM::CConv::C)));
+
   // Parse a function pointer for indirect calls.
   if (parseOptionalCallFuncPtr(parser, operands))
     return failure();
@@ -1971,52 +2050,6 @@ void GlobalOp::print(OpAsmPrinter &p) {
   }
 }
 
-// Parses one of the keywords provided in the list `keywords` and returns the
-// position of the parsed keyword in the list. If none of the keywords from the
-// list is parsed, returns -1.
-static int parseOptionalKeywordAlternative(OpAsmParser &parser,
-                                           ArrayRef<StringRef> keywords) {
-  for (const auto &en : llvm::enumerate(keywords)) {
-    if (succeeded(parser.parseOptionalKeyword(en.value())))
-      return en.index();
-  }
-  return -1;
-}
-
-namespace {
-template <typename Ty>
-struct EnumTraits {};
-
-#define REGISTER_ENUM_TYPE(Ty)                                                 \
-  template <>                                                                  \
-  struct EnumTraits<Ty> {                                                      \
-    static StringRef stringify(Ty value) { return stringify##Ty(value); }      \
-    static unsigned getMaxEnumVal() { return getMaxEnumValFor##Ty(); }         \
-  }
-
-REGISTER_ENUM_TYPE(Linkage);
-REGISTER_ENUM_TYPE(UnnamedAddr);
-REGISTER_ENUM_TYPE(CConv);
-REGISTER_ENUM_TYPE(Visibility);
-} // namespace
-
-/// Parse an enum from the keyword, or default to the provided default value.
-/// The return type is the enum type by default, unless overriden with the
-/// second template argument.
-template <typename EnumTy, typename RetTy = EnumTy>
-static RetTy parseOptionalLLVMKeyword(OpAsmParser &parser,
-                                      OperationState &result,
-                                      EnumTy defaultValue) {
-  SmallVector<StringRef, 10> names;
-  for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i)
-    names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));
-
-  int index = parseOptionalKeywordAlternative(parser, names);
-  if (index == -1)
-    return static_cast<RetTy>(defaultValue);
-  return static_cast<RetTy>(index);
-}
-
 static LogicalResult verifyComdat(Operation *op,
                                   std::optional<SymbolRefAttr> attr) {
   if (!attr)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index 1c0f51a66bf5e8c..5494a13acb6e1b6 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -200,6 +200,7 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
       call = builder.CreateCall(calleeType, operandsRef.front(),
                                 operandsRef.drop_front());
     }
+    call->setCallingConv(convertCConvToLLVM(callOp.getCConv()));
     moduleTranslation.setAccessGroupsMetadata(callOp, call);
     moduleTranslation.setAliasScopeMetadata(callOp, call);
     moduleTranslation.setTBAAMetadata(callOp, call);
@@ -275,7 +276,7 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
   if (auto invOp = dyn_cast<LLVM::InvokeOp>(opInst)) {
     auto operands = moduleTranslation.lookupValues(invOp.getCalleeOperands());
     ArrayRef<llvm::Value *> operandsRef(operands);
-    llvm::Instruction *result;
+    llvm::InvokeInst *result;
     if (auto attr = opInst.getAttrOfType<FlatSymbolRefAttr>("callee")) {
       result = builder.CreateInvoke(
           moduleTranslation.lookupFunction(attr.getValue()),
@@ -290,6 +291,7 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
           moduleTranslation.lookupBlock(invOp.getSuccessor(1)),
           operandsRef.drop_front());
     }
+    result->setCallingConv(convertCConvToLLVM(invOp.getCConv()));
     moduleTranslation.mapBranch(invOp, result);
     // InvokeOp can only have 0 or 1 result
     if (invOp->getNumResults() != 0) {
diff --git a/mlir/test/Dialect/LLVMIR/inlining.mlir b/mlir/test/Dialect/LLVMIR/inlining.mlir
index 1296b8e031c1330..b684be1f9626b1c 100644
--- a/mlir/test/Dialect/LLVMIR/inlining.mlir
+++ b/mlir/test/Dialect/LLVMIR/inlining.mlir
@@ -84,7 +84,7 @@ llvm.func internal fastcc @callee() -> (i32) attributes { function_entry_count =
 // CHECK-NEXT: llvm.return %[[CST]]
 llvm.func @caller() -> (i32) {
   // Include all call attributes that don't prevent inlining.
-  %0 = llvm.call @callee() { fastmathFlags = #llvm.fastmath<nnan, ninf>, branch_weights = dense<42> : vector<1xi32> } : () -> (i32)
+  %0 = llvm.call fastcc @callee() { fastmathFlags = #llvm.fastmath<nnan, ninf>, branch_weights = dense<42> : vector<1xi32> } : () -> (i32)
   llvm.return %0 : i32
 }
 
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index fe2f94454561a08..09358493b2c56cd 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -349,6 +349,16 @@ func.func @call_too_many_results(%callee : !llvm.ptr) {
 
 // -----
 
+llvm.func fastcc @callee_cconv_mismatch()
+
+llvm.func @call_cconv_mismatch() {
+    // expected-error at +1 {{calling convention mismatch: ccc != fastcc}}
+    llvm.call @callee_cconv_mismatch() : () -> ()
+    llvm.return
+}
+
+// -----
+
 llvm.func @void_func_result(%arg0: i32) {
   // expected-error at below {{expected no operands}}
   // expected-note at above {{when returning from function}}
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index 3f27b247edd3e67..2b510a2941c62d0 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -329,6 +329,31 @@ llvm.func @func_args(%arg0: i32, %arg1: i32) -> i32 {
   llvm.return %12 : i32
 }
 
+// CHECK: declare fastcc void @cconv_fastcc()
+// CHECK: declare        void @cconv_ccc()
+// CHECK: declare tailcc void @cconv_tailcc()
+// CHECK: declare ghccc  void @cconv_ghccc()
+llvm.func fastcc @cconv_fastcc()
+llvm.func ccc    @cconv_ccc()
+llvm.func tailcc @cconv_tailcc()
+llvm.func cc_10  @cconv_ghccc()
+
+// CHECK: define void @test_ccs() {
+llvm.func @test_ccs() {
+  // CHECK-NEXT: call fastcc void @cconv_fastcc()
+  // CHECK-NEXT: call        void @cconv_ccc()
+  // CHECK-NEXT: call        void @cconv_ccc()
+  // CHECK-NEXT: call tailcc void @cconv_tailcc()
+  // CHECK-NEXT: call ghccc  void @cconv_ghccc()
+  // CHECK-NEXT: ret void
+  llvm.call fastcc @cconv_fastcc() : () -> ()
+  llvm.call ccc    @cconv_ccc()    : () -> ()
+  llvm.call        @cconv_ccc()    : () -> ()
+  llvm.call tailcc @cconv_tailcc() : () -> ()
+  llvm.call cc_10  @cconv_ghccc()  : () -> ()
+  llvm.return
+}
+
 // CHECK: declare void @pre(i64)
 llvm.func @pre(i64)
 

>From cf8a74afd343488ce81ec0b5d375d93dc151a237 Mon Sep 17 00:00:00 2001
From: Sirraide <74590115+Sirraide at users.noreply.github.com>
Date: Sun, 5 Nov 2023 17:20:45 +0100
Subject: [PATCH 2/5] Fix EBNF in comments
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Markus Böck <markus.boeck02 at gmail.com>
---
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index a670c2b03f5d303..7c5bd5e4374c300 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1372,7 +1372,7 @@ static ParseResult parseOptionalCallFuncPtr(
   return success();
 }
 
-// <operation> ::= `llvm.call` (cconv) (function-id | ssa-use)
+// <operation> ::= `llvm.call` (cconv)? (function-id | ssa-use)
 //                             `(` ssa-use-list `)`
 //                             ( `vararg(` var-arg-func-type `)` )?
 //                             attribute-dict? `:` (type `,`)? function-type
@@ -1546,7 +1546,7 @@ void InvokeOp::print(OpAsmPrinter &p) {
                         getResultTypes());
 }
 
-// <operation> ::= `llvm.invoke` (cconv) (function-id | ssa-use)
+// <operation> ::= `llvm.invoke` (cconv)? (function-id | ssa-use)
 //                  `(` ssa-use-list `)`
 //                  `to` bb-id (`[` ssa-use-and-type-list `]`)?
 //                  `unwind` bb-id (`[` ssa-use-and-type-list `]`)?

>From 7dc7f4101cee4e05b8ccebd3b5db8192f5716653 Mon Sep 17 00:00:00 2001
From: Sirraide <aeternalmail at gmail.com>
Date: Sun, 5 Nov 2023 17:23:53 +0100
Subject: [PATCH 3/5] [MLIR] Make cconv mismatch no longer a verification
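 failure

A mismatch between the calling convention of a call and that of its
callee is valid LLVM IR; the LangRef only makes such a call undefined
behavior at runtime. Rejecting it in the verifier would be stricter
than LLVM itself, so drop the check and its test.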
---
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp |  7 -------
 mlir/test/Dialect/LLVMIR/invalid.mlir      | 10 ----------
 2 files changed, 17 deletions(-)

diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 7c5bd5e4374c300..abc42ffcfa39512 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1212,13 +1212,6 @@ LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
       return emitOpError() << "'" << calleeName.getValue()
                            << "' does not reference a valid LLVM function";
 
-    if (fn.getCConv() != getCConv()) {
-      auto fnCConv = stringifyCConv(fn.getCConv());
-      auto callCConv = stringifyCConv(getCConv());
-      return emitOpError() << "calling convention mismatch: " << callCConv
-                           << " != " << fnCConv;
-    }
-
     if (failed(verifyCallOpDebugInfo(*this, fn)))
       return failure();
     fnType = fn.getFunctionType();
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 09358493b2c56cd..fe2f94454561a08 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -349,16 +349,6 @@ func.func @call_too_many_results(%callee : !llvm.ptr) {
 
 // -----
 
-llvm.func fastcc @callee_cconv_mismatch()
-
-llvm.func @call_cconv_mismatch() {
-    // expected-error at +1 {{calling convention mismatch: ccc != fastcc}}
-    llvm.call @callee_cconv_mismatch() : () -> ()
-    llvm.return
-}
-
-// -----
-
 llvm.func @void_func_result(%arg0: i32) {
   // expected-error at below {{expected no operands}}
   // expected-note at above {{when returning from function}}

>From 8bbeb90004e26f6096b879f1882b16692efa5a31 Mon Sep 17 00:00:00 2001
From: Sirraide <aeternalmail at gmail.com>
Date: Sun, 5 Nov 2023 18:44:33 +0100
Subject: [PATCH 4/5] [MLIR] Add calling convention tests for LLVM::InvokeOp

---
 .../Dialect/LLVMIR/calling-conventions.mlir   | 72 +++++++++++++++++++
 mlir/test/Target/LLVMIR/llvmir.mlir           | 25 -------
 2 files changed, 72 insertions(+), 25 deletions(-)
 create mode 100644 mlir/test/Dialect/LLVMIR/calling-conventions.mlir

diff --git a/mlir/test/Dialect/LLVMIR/calling-conventions.mlir b/mlir/test/Dialect/LLVMIR/calling-conventions.mlir
new file mode 100644
index 000000000000000..a6f692ba46f4550
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/calling-conventions.mlir
@@ -0,0 +1,72 @@
+// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
+
+llvm.func @__gxx_personality_v0(...) -> i32
+
+// CHECK: declare fastcc void @cconv_fastcc()
+// CHECK: declare        void @cconv_ccc()
+// CHECK: declare tailcc void @cconv_tailcc()
+// CHECK: declare ghccc  void @cconv_ghccc()
+llvm.func fastcc @cconv_fastcc()
+llvm.func ccc    @cconv_ccc()
+llvm.func tailcc @cconv_tailcc()
+llvm.func cc_10  @cconv_ghccc()
+
+// CHECK-LABEL: @test_ccs
+llvm.func @test_ccs() {
+  // CHECK-NEXT: call fastcc void @cconv_fastcc()
+  // CHECK-NEXT: call        void @cconv_ccc()
+  // CHECK-NEXT: call        void @cconv_ccc()
+  // CHECK-NEXT: call tailcc void @cconv_tailcc()
+  // CHECK-NEXT: call ghccc  void @cconv_ghccc()
+  // CHECK-NEXT: ret void
+  llvm.call fastcc @cconv_fastcc() : () -> ()
+  llvm.call ccc    @cconv_ccc()    : () -> ()
+  llvm.call        @cconv_ccc()    : () -> ()
+  llvm.call tailcc @cconv_tailcc() : () -> ()
+  llvm.call cc_10  @cconv_ghccc()  : () -> ()
+  llvm.return
+}
+
+// CHECK-LABEL: @test_ccs_invoke
+llvm.func @test_ccs_invoke() attributes { personality = @__gxx_personality_v0 } {
+  // CHECK-NEXT: invoke fastcc void @cconv_fastcc()
+  // CHECK-NEXT:   to label %[[normal1:[0-9]+]] unwind label %[[unwind:[0-9]+]]
+  llvm.invoke fastcc @cconv_fastcc() to ^bb1 unwind ^bb6 : () -> ()
+
+^bb1:
+  // CHECK: [[normal1]]:
+  // CHECK-NEXT: invoke void @cconv_ccc()
+  // CHECK-NEXT:   to label %[[normal2:[0-9]+]] unwind label %[[unwind]]
+  llvm.invoke ccc @cconv_ccc() to ^bb2 unwind ^bb6 : () -> ()
+
+^bb2:
+  // CHECK: [[normal2]]:
+  // CHECK-NEXT: invoke void @cconv_ccc()
+  // CHECK-NEXT:   to label %[[normal3:[0-9]+]] unwind label %[[unwind]]
+  llvm.invoke @cconv_ccc() to ^bb3 unwind ^bb6 : () -> ()
+
+^bb3:
+  // CHECK: [[normal3]]:
+  // CHECK-NEXT: invoke tailcc void @cconv_tailcc()
+  // CHECK-NEXT:   to label %[[normal4:[0-9]+]] unwind label %[[unwind]]
+  llvm.invoke tailcc @cconv_tailcc() to ^bb4 unwind ^bb6 : () -> ()
+
+^bb4:
+  // CHECK: [[normal4]]:
+  // CHECK-NEXT: invoke ghccc void @cconv_ghccc()
+  // CHECK-NEXT:   to label %[[normal5:[0-9]+]] unwind label %[[unwind]]
+  llvm.invoke cc_10 @cconv_ghccc() to ^bb5 unwind ^bb6 : () -> ()
+
+^bb5:
+  // CHECK: [[normal5]]:
+  // CHECK-NEXT: ret void
+  llvm.return
+
+  // CHECK: [[unwind]]:
+  // CHECK-NEXT: landingpad { ptr, i32 }
+  // CHECK-NEXT: cleanup
+  // CHECK-NEXT: ret void
+^bb6:
+  %0 = llvm.landingpad cleanup : !llvm.struct<(ptr, i32)>
+  llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index 2b510a2941c62d0..3f27b247edd3e67 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -329,31 +329,6 @@ llvm.func @func_args(%arg0: i32, %arg1: i32) -> i32 {
   llvm.return %12 : i32
 }
 
-// CHECK: declare fastcc void @cconv_fastcc()
-// CHECK: declare        void @cconv_ccc()
-// CHECK: declare tailcc void @cconv_tailcc()
-// CHECK: declare ghccc  void @cconv_ghccc()
-llvm.func fastcc @cconv_fastcc()
-llvm.func ccc    @cconv_ccc()
-llvm.func tailcc @cconv_tailcc()
-llvm.func cc_10  @cconv_ghccc()
-
-// CHECK: define void @test_ccs() {
-llvm.func @test_ccs() {
-  // CHECK-NEXT: call fastcc void @cconv_fastcc()
-  // CHECK-NEXT: call        void @cconv_ccc()
-  // CHECK-NEXT: call        void @cconv_ccc()
-  // CHECK-NEXT: call tailcc void @cconv_tailcc()
-  // CHECK-NEXT: call ghccc  void @cconv_ghccc()
-  // CHECK-NEXT: ret void
-  llvm.call fastcc @cconv_fastcc() : () -> ()
-  llvm.call ccc    @cconv_ccc()    : () -> ()
-  llvm.call        @cconv_ccc()    : () -> ()
-  llvm.call tailcc @cconv_tailcc() : () -> ()
-  llvm.call cc_10  @cconv_ghccc()  : () -> ()
-  llvm.return
-}
-
 // CHECK: declare void @pre(i64)
 llvm.func @pre(i64)
 

>From 575bba1a0ff7af88221f16b70fa00b618b2ba0c6 Mon Sep 17 00:00:00 2001
From: Sirraide <aeternalmail at gmail.com>
Date: Sun, 5 Nov 2023 18:45:20 +0100
Subject: [PATCH 5/5] [MLIR] Add round trip test for calling conventions

---
 mlir/test/Dialect/LLVMIR/func.mlir | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/mlir/test/Dialect/LLVMIR/func.mlir b/mlir/test/Dialect/LLVMIR/func.mlir
index b45e6c4ef897b10..63e20b1d8fc31ab 100644
--- a/mlir/test/Dialect/LLVMIR/func.mlir
+++ b/mlir/test/Dialect/LLVMIR/func.mlir
@@ -184,6 +184,24 @@ module {
     llvm.return
   }
 
+  // CHECK: llvm.func cc_10 @cconv4
+  llvm.func cc_10 @cconv4() {
+    llvm.return
+  }
+
+  // CHECK: llvm.func @test_ccs
+  llvm.func @test_ccs() {
+    // CHECK-NEXT: llvm.call        @cconv1() : () -> ()
+    // CHECK-NEXT: llvm.call        @cconv2() : () -> ()
+    // CHECK-NEXT: llvm.call fastcc @cconv3() : () -> ()
+    // CHECK-NEXT: llvm.call cc_10  @cconv4() : () -> ()
+    llvm.call        @cconv1() : () -> ()
+    llvm.call ccc    @cconv2() : () -> ()
+    llvm.call fastcc @cconv3() : () -> ()
+    llvm.call cc_10  @cconv4() : () -> ()
+    llvm.return
+  }
+
   // CHECK-LABEL: llvm.func @variadic_def
   llvm.func @variadic_def(...) {
     llvm.return



More information about the Mlir-commits mailing list