[Mlir-commits] [mlir] Add vararg support in LLVM::LLVMFuncOp (PR #67274)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Sep 24 18:59:58 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-llvm

<details>
<summary>Changes</summary>

Adds an attribute that specifies the callee function type of an `LLVM::CallOp`. This is needed because the LLVM dialect currently always translates `CallOp`s to non-vararg calls.

The new syntax I opted for is this:

```
%2 = llvm.call @printf vararg !llvm.func<i32 (ptr<4>, ...)> (%1, %arg0) : (!llvm.ptr<4>, i32) -> i32
```

for vararg calls. For non-vararg calls, the function type is omitted and inferred from the operation's operand and result types, for example:

```
%2 = llvm.call @foo(%1, %arg0) : (!llvm.ptr<4>, i32) -> i32
```

I think it makes sense to keep the trailing type after the `:`, which describes the type of the MLIR operation, separate from the callee function type.

If this syntax looks good, I will update the tests to reflect this change.




---
Full diff: https://github.com/llvm/llvm-project/pull/67274.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td (+8-1) 
- (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (+76-11) 
- (modified) mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp (+4-3) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 2b4c8b609cfdd4f..2f65077dddf5aaf 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -615,7 +615,8 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
     ```
   }];
 
-  dag args = (ins OptionalAttr<FlatSymbolRefAttr>:$callee,
+  dag args = (ins OptionalAttr<TypeAttrOf<LLVM_FunctionType>>:$callee_type,
+                  OptionalAttr<FlatSymbolRefAttr>:$callee,
                   Variadic<LLVM_Type>:$callee_operands,
                   DefaultValuedAttr<LLVM_FastmathFlagsAttr,
                                    "{}">:$fastmathFlags,
@@ -630,6 +631,12 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
     OpBuilder<(ins "TypeRange":$results, "FlatSymbolRefAttr":$callee,
                    CArg<"ValueRange", "{}">:$args)>,
     OpBuilder<(ins "TypeRange":$results, "StringRef":$callee,
+                   CArg<"ValueRange", "{}">:$args)>,
+    OpBuilder<(ins "LLVMFunctionType":$calleeType, "StringAttr":$callee,
+                   CArg<"ValueRange", "{}">:$args)>,
+    OpBuilder<(ins "LLVMFunctionType":$calleeType, "FlatSymbolRefAttr":$callee,
+                   CArg<"ValueRange", "{}">:$args)>,
+    OpBuilder<(ins "LLVMFunctionType":$calleeType, "StringRef":$callee,
                    CArg<"ValueRange", "{}">:$args)>
   ];
   let hasCustomAssemblyFormat = 1;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index fd0d2b3fb3c1a08..87138d24f7fc09a 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -950,6 +950,16 @@ static void printStoreType(OpAsmPrinter &printer, Operation *op,
 // CallOp
 //===----------------------------------------------------------------------===//
 
+/// Returns the op result types for a call of `calleeType`: empty for a void
+/// return. Returned by value: a TypeRange over a local vector would dangle.
+static SmallVector<Type, 1> getCallOpResults(LLVMFunctionType calleeType) {
+  SmallVector<Type, 1> results;
+  Type resultType = calleeType.getReturnType();
+  if (!llvm::isa<LLVM::LLVMVoidType>(resultType))
+    results.push_back(resultType);
+  return results;
+}
+
 void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
                    StringRef callee, ValueRange args) {
   build(builder, state, results, builder.getStringAttr(callee), args);
@@ -962,7 +972,38 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
 
 void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
                    FlatSymbolRefAttr callee, ValueRange args) {
-  build(builder, state, results, callee, args, /*fastmathFlags=*/nullptr,
+  // No explicit callee type: the call is non-variadic, so derive its function
+  // type from the result type (void when there is none) and the operand types.
+  Type resultType = results.empty()
+                        ? Type(LLVMVoidType::get(builder.getContext()))
+                        : results.front();
+  SmallVector<Type> argTypes = llvm::to_vector(args.getTypes());
+  auto calleeType =
+      LLVMFunctionType::get(resultType, argTypes, /*isVarArg=*/false);
+  build(builder, state, results, TypeAttr::get(calleeType), callee, args,
+        /*fastmathFlags=*/nullptr,
+        /*branch_weights=*/nullptr,
+        /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
+        /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
+}
+/// Explicit-callee-type overloads: each converts `callee` one step further.
+void CallOp::build(OpBuilder &builder, OperationState &state,
+                   LLVMFunctionType calleeType, StringRef callee,
+                   ValueRange args) {
+  build(builder, state, calleeType, builder.getStringAttr(callee), args);
+}
+
+void CallOp::build(OpBuilder &builder, OperationState &state,
+                   LLVMFunctionType calleeType, StringAttr callee,
+                   ValueRange args) {
+  build(builder, state, calleeType, SymbolRefAttr::get(callee), args);
+}
+
+void CallOp::build(OpBuilder &builder, OperationState &state,
+                   LLVMFunctionType calleeType, FlatSymbolRefAttr callee,
+                   ValueRange args) {
+  build(builder, state, getCallOpResults(calleeType), TypeAttr::get(calleeType),
+        callee, args, /*fastmathFlags=*/nullptr,
         /*branch_weights=*/nullptr,
         /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
         /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
@@ -970,11 +1011,9 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
 
 void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
                    ValueRange args) {
-  SmallVector<Type> results;
-  Type resultType = func.getFunctionType().getReturnType();
-  if (!llvm::isa<LLVM::LLVMVoidType>(resultType))
-    results.push_back(resultType);
-  build(builder, state, results, SymbolRefAttr::get(func), args,
+  auto calleeType = func.getFunctionType();
+  build(builder, state, getCallOpResults(calleeType), TypeAttr::get(calleeType),
+        SymbolRefAttr::get(func), args,
         /*fastmathFlags=*/nullptr,
         /*branch_weights=*/nullptr,
         /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
@@ -1133,6 +1172,8 @@ LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 void CallOp::print(OpAsmPrinter &p) {
   auto callee = getCallee();
   bool isDirect = callee.has_value();
+  auto calleeType = *getCalleeType();
+  bool isVarArg = calleeType.isVarArg();
 
   // Print the direct callee if present as a function attribute, or an indirect
   // callee (first operand) otherwise.
@@ -1142,9 +1183,13 @@ void CallOp::print(OpAsmPrinter &p) {
   else
     p << getOperand(0);
 
+  if (isVarArg)
+    p << " vararg " << getCalleeType() << ' ';
+
   auto args = getOperands().drop_front(isDirect ? 0 : 1);
   p << '(' << args << ')';
-  p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()), {"callee"});
+  p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()),
+                          {"callee", "callee_type"});
 
   p << " : ";
   if (!isDirect)
@@ -1158,7 +1203,7 @@ void CallOp::print(OpAsmPrinter &p) {
 /// succeeds. Returns failure otherwise.
 static ParseResult parseCallTypeAndResolveOperands(
     OpAsmParser &parser, OperationState &result, bool isDirect,
-    ArrayRef<OpAsmParser::UnresolvedOperand> operands) {
+    ArrayRef<OpAsmParser::UnresolvedOperand> operands, bool isVarArg) {
   SMLoc trailingTypesLoc = parser.getCurrentLocation();
   SmallVector<Type> types;
   if (parser.parseColonTypeList(types))
@@ -1194,6 +1239,18 @@ static ParseResult parseCallTypeAndResolveOperands(
   if (funcType.getNumResults() != 0)
     result.addTypes(funcType.getResults());
 
+  if (!isVarArg) {
+    Type returnType;
+    if (funcType.getNumResults() == 0)
+      returnType = LLVM::LLVMVoidType::get(result.getContext());
+    else
+      returnType = funcType.getResult(0);
+    result.addAttribute(
+        "callee_type",
+        TypeAttr::get(LLVM::LLVMFunctionType::get(
+            returnType, funcType.getInputs(), /*isVarArg*/ false)));
+  }
+
   return success();
 }
 
@@ -1212,10 +1269,12 @@ static ParseResult parseOptionalCallFuncPtr(
   return success();
 }
 
-// <operation> ::= `llvm.call` (function-id | ssa-use)`(` ssa-use-list `)`
+// <operation> ::= `llvm.call` (function-id | ssa-use) var-arg-func-type?
+//                             `(` ssa-use-list `)`
 //                             attribute-dict? `:` (type `,`)? function-type
 ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
   SymbolRefAttr funcAttr;
+  TypeAttr calleeType;
   SmallVector<OpAsmParser::UnresolvedOperand> operands;
 
   // Parse a function pointer for indirect calls.
@@ -1228,13 +1287,18 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
     if (parser.parseAttribute(funcAttr, "callee", result.attributes))
       return failure();
 
+  bool isVarArg = parser.parseOptionalKeyword("vararg").succeeded();
+  if (isVarArg)
+    parser.parseOptionalAttribute(calleeType, "callee_type", result.attributes);
+
   // Parse the function arguments.
   if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
       parser.parseOptionalAttrDict(result.attributes))
     return failure();
 
   // Parse the trailing type list and resolve the operands.
-  return parseCallTypeAndResolveOperands(parser, result, isDirect, operands);
+  return parseCallTypeAndResolveOperands(parser, result, isDirect, operands,
+                                         isVarArg);
 }
 
 ///===---------------------------------------------------------------------===//
@@ -1349,7 +1413,8 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
     return failure();
 
   // Parse the trailing type list and resolve the function operands.
-  if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands))
+  if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands,
+                                      /*isVarArg*/ false))
     return failure();
 
   result.addSuccessors({normalDest, unwindDest});
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index 9b438090c84cacc..cdd22f116a4ef9d 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -211,9 +211,10 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
       call = builder.CreateCall(
           moduleTranslation.lookupFunction(attr.getValue()), operandsRef);
     } else {
-      call = builder.CreateCall(getCalleeFunctionType(callOp.getResultTypes(),
-                                                      callOp.getArgOperands()),
-                                operandsRef.front(), operandsRef.drop_front());
+      call = builder.CreateCall(
+          llvm::cast<llvm::FunctionType>(
+              moduleTranslation.convertType(*callOp.getCalleeType())),
+          operandsRef.front(), operandsRef.drop_front());
     }
     moduleTranslation.setAccessGroupsMetadata(callOp, call);
     moduleTranslation.setAliasScopeMetadata(callOp, call);

``````````

</details>


https://github.com/llvm/llvm-project/pull/67274


More information about the Mlir-commits mailing list