[Mlir-commits] [mlir] 26eb428 - [MLIR][LLVM] Add vararg support in LLVM::CallOp and InvokeOp (#67274)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Sep 27 23:26:49 PDT 2023


Author: Ivan R. Ivanov
Date: 2023-09-28T08:26:45+02:00
New Revision: 26eb4285b56edd8c897642078d91f16ff0fd3472

URL: https://github.com/llvm/llvm-project/commit/26eb4285b56edd8c897642078d91f16ff0fd3472
DIFF: https://github.com/llvm/llvm-project/commit/26eb4285b56edd8c897642078d91f16ff0fd3472.diff

LOG: [MLIR][LLVM] Add vararg support in LLVM::CallOp and InvokeOp (#67274)

In order to support indirect vararg calls, we need to have information about the
callee type - this patch adds a `callee_type` attribute that holds that.

The attribute is required for vararg calls, else, it is optional and the callee
type is inferred by the operands and results of the operation if not present.

The syntax for non-vararg calls remains the same, whereas for vararg calls, it
is changed to this:

```
llvm.call %p(%arg0, %arg0) vararg(!llvm.func<void (i32, ...)>) : !llvm.ptr, (i32, i32) -> ()
llvm.call @s(%arg0, %arg0) vararg(!llvm.func<void (i32, ...)>) : (i32, i32) -> ()
```

Added: 
    

Modified: 
    mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
    mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
    mlir/lib/Target/LLVMIR/ModuleImport.cpp
    mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-opencl.mlir
    mlir/test/Conversion/GPUToROCDL/typed-pointers.mlir
    mlir/test/Dialect/LLVMIR/invalid-typed-pointers.mlir
    mlir/test/Dialect/LLVMIR/invalid.mlir
    mlir/test/Dialect/LLVMIR/roundtrip.mlir
    mlir/test/Target/LLVMIR/Import/exception.ll
    mlir/test/Target/LLVMIR/Import/instructions.ll
    mlir/test/Target/LLVMIR/llvmir.mlir
    mlir/test/Target/LLVMIR/openmp-nested.mlir
    mlir/test/mlir-cpu-runner/x86-varargs.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
index a10588e51d9b4f3..684ce37b2398ce2 100644
--- a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
+++ b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
@@ -61,6 +61,7 @@ class PrintOpLowering : public ConversionPattern {
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
+    auto context = rewriter.getContext();
     auto memRefType = llvm::cast<MemRefType>((*op->operand_type_begin()));
     auto memRefShape = memRefType.getShape();
     auto loc = op->getLoc();
@@ -92,8 +93,8 @@ class PrintOpLowering : public ConversionPattern {
 
       // Insert a newline after each of the inner dimensions of the shape.
       if (i != e - 1)
-        rewriter.create<func::CallOp>(loc, printfRef,
-                                      rewriter.getIntegerType(32), newLineCst);
+        rewriter.create<LLVM::CallOp>(loc, getPrintfType(context), printfRef,
+                                      newLineCst);
       rewriter.create<scf::YieldOp>(loc);
       rewriter.setInsertionPointToStart(loop.getBody());
     }
@@ -102,8 +103,8 @@ class PrintOpLowering : public ConversionPattern {
     auto printOp = cast<toy::PrintOp>(op);
     auto elementLoad =
         rewriter.create<memref::LoadOp>(loc, printOp.getInput(), loopIvs);
-    rewriter.create<func::CallOp>(
-        loc, printfRef, rewriter.getIntegerType(32),
+    rewriter.create<LLVM::CallOp>(
+        loc, getPrintfType(context), printfRef,
         ArrayRef<Value>({formatSpecifierCst, elementLoad}));
 
     // Notify the rewriter that this operation has been removed.
@@ -112,6 +113,16 @@ class PrintOpLowering : public ConversionPattern {
   }
 
 private:
+  /// Create a function declaration for printf, the signature is:
+  ///   * `i32 (i8*, ...)`
+  static LLVM::LLVMFunctionType getPrintfType(MLIRContext *context) {
+    auto llvmI32Ty = IntegerType::get(context, 32);
+    auto llvmI8PtrTy = LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
+    auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmI8PtrTy,
+                                                  /*isVarArg=*/true);
+    return llvmFnType;
+  }
+
   /// Return a symbol reference to the printf function, inserting it into the
   /// module if necessary.
   static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter,
@@ -120,17 +131,11 @@ class PrintOpLowering : public ConversionPattern {
     if (module.lookupSymbol<LLVM::LLVMFuncOp>("printf"))
       return SymbolRefAttr::get(context, "printf");
 
-    // Create a function declaration for printf, the signature is:
-    //   * `i32 (i8*, ...)`
-    auto llvmI32Ty = IntegerType::get(context, 32);
-    auto llvmI8PtrTy = LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
-    auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmI8PtrTy,
-                                                  /*isVarArg=*/true);
-
     // Insert the printf function into the body of the parent module.
     PatternRewriter::InsertionGuard insertGuard(rewriter);
     rewriter.setInsertionPointToStart(module.getBody());
-    rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), "printf", llvmFnType);
+    rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), "printf",
+                                      getPrintfType(context));
     return SymbolRefAttr::get(context, "printf");
   }
 

diff  --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
index a10588e51d9b4f3..684ce37b2398ce2 100644
--- a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
+++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
@@ -61,6 +61,7 @@ class PrintOpLowering : public ConversionPattern {
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
+    auto context = rewriter.getContext();
     auto memRefType = llvm::cast<MemRefType>((*op->operand_type_begin()));
     auto memRefShape = memRefType.getShape();
     auto loc = op->getLoc();
@@ -92,8 +93,8 @@ class PrintOpLowering : public ConversionPattern {
 
       // Insert a newline after each of the inner dimensions of the shape.
       if (i != e - 1)
-        rewriter.create<func::CallOp>(loc, printfRef,
-                                      rewriter.getIntegerType(32), newLineCst);
+        rewriter.create<LLVM::CallOp>(loc, getPrintfType(context), printfRef,
+                                      newLineCst);
       rewriter.create<scf::YieldOp>(loc);
       rewriter.setInsertionPointToStart(loop.getBody());
     }
@@ -102,8 +103,8 @@ class PrintOpLowering : public ConversionPattern {
     auto printOp = cast<toy::PrintOp>(op);
     auto elementLoad =
         rewriter.create<memref::LoadOp>(loc, printOp.getInput(), loopIvs);
-    rewriter.create<func::CallOp>(
-        loc, printfRef, rewriter.getIntegerType(32),
+    rewriter.create<LLVM::CallOp>(
+        loc, getPrintfType(context), printfRef,
         ArrayRef<Value>({formatSpecifierCst, elementLoad}));
 
     // Notify the rewriter that this operation has been removed.
@@ -112,6 +113,16 @@ class PrintOpLowering : public ConversionPattern {
   }
 
 private:
+  /// Create a function declaration for printf, the signature is:
+  ///   * `i32 (i8*, ...)`
+  static LLVM::LLVMFunctionType getPrintfType(MLIRContext *context) {
+    auto llvmI32Ty = IntegerType::get(context, 32);
+    auto llvmI8PtrTy = LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
+    auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmI8PtrTy,
+                                                  /*isVarArg=*/true);
+    return llvmFnType;
+  }
+
   /// Return a symbol reference to the printf function, inserting it into the
   /// module if necessary.
   static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter,
@@ -120,17 +131,11 @@ class PrintOpLowering : public ConversionPattern {
     if (module.lookupSymbol<LLVM::LLVMFuncOp>("printf"))
       return SymbolRefAttr::get(context, "printf");
 
-    // Create a function declaration for printf, the signature is:
-    //   * `i32 (i8*, ...)`
-    auto llvmI32Ty = IntegerType::get(context, 32);
-    auto llvmI8PtrTy = LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
-    auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmI8PtrTy,
-                                                  /*isVarArg=*/true);
-
     // Insert the printf function into the body of the parent module.
     PatternRewriter::InsertionGuard insertGuard(rewriter);
     rewriter.setInsertionPointToStart(module.getBody());
-    rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), "printf", llvmFnType);
+    rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), "printf",
+                                      getPrintfType(context));
     return SymbolRefAttr::get(context, "printf");
   }
 

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 10bc8afeefa1777..cffa704e3145f1c 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -542,7 +542,9 @@ def LLVM_InvokeOp : LLVM_Op<"invoke", [
                       DeclareOpInterfaceMethods<CallOpInterface>,
                       DeclareOpInterfaceMethods<BranchWeightOpInterface>,
                       Terminator]> {
-  let arguments = (ins OptionalAttr<FlatSymbolRefAttr>:$callee,
+  let arguments = (ins
+                   OptionalAttr<TypeAttrOf<LLVM_FunctionType>>:$callee_type,
+                   OptionalAttr<FlatSymbolRefAttr>:$callee,
                    Variadic<LLVM_Type>:$callee_operands,
                    Variadic<LLVM_Type>:$normalDestOperands,
                    Variadic<LLVM_Type>:$unwindDestOperands,
@@ -552,21 +554,21 @@ def LLVM_InvokeOp : LLVM_Op<"invoke", [
                               AnySuccessor:$unwindDest);
 
   let builders = [
+    OpBuilder<(ins "LLVMFuncOp":$func,
+      "ValueRange":$ops, "Block*":$normal, "ValueRange":$normalOps,
+      "Block*":$unwind, "ValueRange":$unwindOps)>,
     OpBuilder<(ins "TypeRange":$tys, "FlatSymbolRefAttr":$callee,
       "ValueRange":$ops, "Block*":$normal, "ValueRange":$normalOps,
-      "Block*":$unwind, "ValueRange":$unwindOps),
-    [{
-      $_state.addAttribute("callee", callee);
-      build($_builder, $_state, tys, ops, normal, normalOps, unwind, unwindOps);
-    }]>,
-    OpBuilder<(ins "TypeRange":$tys, "ValueRange":$ops, "Block*":$normal,
-      "ValueRange":$normalOps, "Block*":$unwind, "ValueRange":$unwindOps),
-    [{
-      build($_builder, $_state, tys, /*callee=*/FlatSymbolRefAttr(), ops, normalOps,
-            unwindOps, nullptr, normal, unwind);
-    }]>];
+      "Block*":$unwind, "ValueRange":$unwindOps)>,
+    OpBuilder<(ins "LLVMFunctionType":$calleeType, "FlatSymbolRefAttr":$callee,
+      "ValueRange":$ops, "Block*":$normal, "ValueRange":$normalOps,
+      "Block*":$unwind, "ValueRange":$unwindOps)>];
   let hasCustomAssemblyFormat = 1;
   let hasVerifier = 1;
+  let extraClassDeclaration = [{
+    /// Returns the callee function type.
+    LLVMFunctionType getCalleeFunctionType();
+  }];
 }
 
 def LLVM_LandingpadOp : LLVM_Op<"landingpad"> {
@@ -581,9 +583,6 @@ def LLVM_LandingpadOp : LLVM_Op<"landingpad"> {
 // CallOp
 //===----------------------------------------------------------------------===//
 
-// FIXME: Add a type attribute that carries the LLVM function type to support
-// indirect calls to variadic functions. The type attribute is necessary to
-// distinguish normal and variadic arguments.
 def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
                     [DeclareOpInterfaceMethods<FastmathFlagsInterface>,
                      DeclareOpInterfaceMethods<CallOpInterface>,
@@ -599,10 +598,11 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
     The `call` instruction supports both direct and indirect calls. Direct calls
     start with a function name (`@`-prefixed) and indirect calls start with an
     SSA value (`%`-prefixed). The direct callee, if present, is stored as a
-    function attribute `callee`. The trailing type list contains the optional
-    indirect callee type and the MLIR function type, which differs from the
-    LLVM function type that uses a explicit void type to model functions that do
-    not return a value.
+    function attribute `callee`. If the callee is a variadic function, then the
+    `callee_type` attribute must carry the function type. The trailing type list
+    contains the optional indirect callee type and the MLIR function type, which
+    differs from the LLVM function type that uses a explicit void type to model
+    functions that do not return a value.
 
     Examples:
 
@@ -615,10 +615,17 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
 
     // Indirect call with an argument and without a result.
     llvm.call %1(%0) : !llvm.ptr, (f32) -> ()
+
+    // Direct variadic call.
+    llvm.call @printf(%0, %1) vararg(!llvm.func<i32 (ptr, ...)>) : (!llvm.ptr, i32) -> i32
+
+    // Indirect variadic call
+    llvm.call %1(%0) vararg(!llvm.func<void (...)>) : !llvm.ptr, (i32) -> ()
     ```
   }];
 
-  dag args = (ins OptionalAttr<FlatSymbolRefAttr>:$callee,
+  dag args = (ins OptionalAttr<TypeAttrOf<LLVM_FunctionType>>:$callee_type,
+                  OptionalAttr<FlatSymbolRefAttr>:$callee,
                   Variadic<LLVM_Type>:$callee_operands,
                   DefaultValuedAttr<LLVM_FastmathFlagsAttr,
                                    "{}">:$fastmathFlags,
@@ -628,14 +635,25 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
   let results = (outs Optional<LLVM_Type>:$result);
   let builders = [
     OpBuilder<(ins "LLVMFuncOp":$func, "ValueRange":$args)>,
+    OpBuilder<(ins "LLVMFunctionType":$calleeType, "ValueRange":$args)>,
     OpBuilder<(ins "TypeRange":$results, "StringAttr":$callee,
                    CArg<"ValueRange", "{}">:$args)>,
     OpBuilder<(ins "TypeRange":$results, "FlatSymbolRefAttr":$callee,
                    CArg<"ValueRange", "{}">:$args)>,
     OpBuilder<(ins "TypeRange":$results, "StringRef":$callee,
+                   CArg<"ValueRange", "{}">:$args)>,
+    OpBuilder<(ins "LLVMFunctionType":$calleeType, "StringAttr":$callee,
+                   CArg<"ValueRange", "{}">:$args)>,
+    OpBuilder<(ins "LLVMFunctionType":$calleeType, "FlatSymbolRefAttr":$callee,
+                   CArg<"ValueRange", "{}">:$args)>,
+    OpBuilder<(ins "LLVMFunctionType":$calleeType, "StringRef":$callee,
                    CArg<"ValueRange", "{}">:$args)>
   ];
   let hasCustomAssemblyFormat = 1;
+  let extraClassDeclaration = [{
+    /// Returns the callee function type.
+    LLVMFunctionType getCalleeFunctionType();
+  }];
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index d5f2f47bc147e6c..7f297fe79177560 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1018,6 +1018,27 @@ static void printStoreType(OpAsmPrinter &printer, Operation *op,
 // CallOp
 //===----------------------------------------------------------------------===//
 
+/// Gets the MLIR Op-like result types of a LLVMFunctionType.
+static SmallVector<Type, 1> getCallOpResultTypes(LLVMFunctionType calleeType) {
+  SmallVector<Type, 1> results;
+  Type resultType = calleeType.getReturnType();
+  if (!isa<LLVM::LLVMVoidType>(resultType))
+    results.push_back(resultType);
+  return results;
+}
+
+/// Constructs a LLVMFunctionType from MLIR `results` and `args`.
+static LLVMFunctionType getLLVMFuncType(MLIRContext *context, TypeRange results,
+                                        ValueRange args) {
+  Type resultType;
+  if (results.empty())
+    resultType = LLVMVoidType::get(context);
+  else
+    resultType = results.front();
+  return LLVMFunctionType::get(resultType, llvm::to_vector(args.getTypes()),
+                               /*isVarArg=*/false);
+}
+
 void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
                    StringRef callee, ValueRange args) {
   build(builder, state, results, builder.getStringAttr(callee), args);
@@ -1030,7 +1051,43 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
 
 void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
                    FlatSymbolRefAttr callee, ValueRange args) {
-  build(builder, state, results, callee, args, /*fastmathFlags=*/nullptr,
+  build(builder, state, results,
+        TypeAttr::get(getLLVMFuncType(builder.getContext(), results, args)),
+        callee, args,
+        /*fastmathFlags=*/nullptr,
+        /*branch_weights=*/nullptr,
+        /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
+        /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
+}
+
+void CallOp::build(OpBuilder &builder, OperationState &state,
+                   LLVMFunctionType calleeType, StringRef callee,
+                   ValueRange args) {
+  build(builder, state, calleeType, builder.getStringAttr(callee), args);
+}
+
+void CallOp::build(OpBuilder &builder, OperationState &state,
+                   LLVMFunctionType calleeType, StringAttr callee,
+                   ValueRange args) {
+  build(builder, state, calleeType, SymbolRefAttr::get(callee), args);
+}
+
+void CallOp::build(OpBuilder &builder, OperationState &state,
+                   LLVMFunctionType calleeType, FlatSymbolRefAttr callee,
+                   ValueRange args) {
+  build(builder, state, getCallOpResultTypes(calleeType),
+        TypeAttr::get(calleeType), callee, args, /*fastmathFlags=*/nullptr,
+        /*branch_weights=*/nullptr,
+        /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
+        /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
+}
+
+void CallOp::build(OpBuilder &builder, OperationState &state,
+                   LLVMFunctionType calleeType, ValueRange args) {
+  build(builder, state, getCallOpResultTypes(calleeType),
+        TypeAttr::get(calleeType),
+        /*callee=*/nullptr, args,
+        /*fastmathFlags=*/nullptr,
         /*branch_weights=*/nullptr,
         /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
         /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
@@ -1038,11 +1095,9 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
 
 void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
                    ValueRange args) {
-  SmallVector<Type> results;
-  Type resultType = func.getFunctionType().getReturnType();
-  if (!llvm::isa<LLVM::LLVMVoidType>(resultType))
-    results.push_back(resultType);
-  build(builder, state, results, SymbolRefAttr::get(func), args,
+  auto calleeType = func.getFunctionType();
+  build(builder, state, getCallOpResultTypes(calleeType),
+        TypeAttr::get(calleeType), SymbolRefAttr::get(func), args,
         /*fastmathFlags=*/nullptr,
         /*branch_weights=*/nullptr,
         /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
@@ -1148,14 +1203,8 @@ LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   if (!funcType)
     return emitOpError("callee does not have a functional type: ") << fnType;
 
-  // Indirect variadic function calls are not supported since the translation to
-  // LLVM IR reconstructs the LLVM function type from the argument and result
-  // types. An additional type attribute that stores the LLVM function type
-  // would be needed to distinguish normal and variadic function arguments.
-  // TODO: Support indirect calls to variadic function pointers.
-  if (isIndirect && funcType.isVarArg())
-    return emitOpError()
-           << "indirect calls to variadic functions are not supported";
+  if (funcType.isVarArg() && !getCalleeType())
+    return emitOpError() << "missing callee type attribute for vararg call";
 
   // Verify that the operand and result types match the callee.
 
@@ -1202,6 +1251,14 @@ void CallOp::print(OpAsmPrinter &p) {
   auto callee = getCallee();
   bool isDirect = callee.has_value();
 
+  LLVMFunctionType calleeType;
+  bool isVarArg = false;
+
+  if (std::optional<LLVMFunctionType> optionalCalleeType = getCalleeType()) {
+    calleeType = *optionalCalleeType;
+    isVarArg = calleeType.isVarArg();
+  }
+
   // Print the direct callee if present as a function attribute, or an indirect
   // callee (first operand) otherwise.
   p << ' ';
@@ -1212,7 +1269,12 @@ void CallOp::print(OpAsmPrinter &p) {
 
   auto args = getOperands().drop_front(isDirect ? 0 : 1);
   p << '(' << args << ')';
-  p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()), {"callee"});
+
+  if (isVarArg)
+    p << " vararg(" << calleeType << ")";
+
+  p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()),
+                          {"callee", "callee_type"});
 
   p << " : ";
   if (!isDirect)
@@ -1280,10 +1342,13 @@ static ParseResult parseOptionalCallFuncPtr(
   return success();
 }
 
-// <operation> ::= `llvm.call` (function-id | ssa-use)`(` ssa-use-list `)`
+// <operation> ::= `llvm.call` (function-id | ssa-use)
+//                             `(` ssa-use-list `)`
+//                             ( `vararg(` var-arg-func-type `)` )?
 //                             attribute-dict? `:` (type `,`)? function-type
 ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
   SymbolRefAttr funcAttr;
+  TypeAttr calleeType;
   SmallVector<OpAsmParser::UnresolvedOperand> operands;
 
   // Parse a function pointer for indirect calls.
@@ -1301,14 +1366,57 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
       parser.parseOptionalAttrDict(result.attributes))
     return failure();
 
+  bool isVarArg = parser.parseOptionalKeyword("vararg").succeeded();
+  if (isVarArg) {
+    if (parser.parseLParen().failed() ||
+        parser.parseAttribute(calleeType, "callee_type", result.attributes)
+            .failed() ||
+        parser.parseRParen().failed())
+      return failure();
+  }
+
   // Parse the trailing type list and resolve the operands.
   return parseCallTypeAndResolveOperands(parser, result, isDirect, operands);
 }
 
+LLVMFunctionType CallOp::getCalleeFunctionType() {
+  if (getCalleeType())
+    return *getCalleeType();
+  else
+    return getLLVMFuncType(getContext(), getResultTypes(), getArgOperands());
+}
+
 ///===---------------------------------------------------------------------===//
 /// LLVM::InvokeOp
 ///===---------------------------------------------------------------------===//
 
+void InvokeOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
+                     ValueRange ops, Block *normal, ValueRange normalOps,
+                     Block *unwind, ValueRange unwindOps) {
+  auto calleeType = func.getFunctionType();
+  build(builder, state, getCallOpResultTypes(calleeType),
+        TypeAttr::get(calleeType), SymbolRefAttr::get(func), ops, normalOps,
+        unwindOps, nullptr, normal, unwind);
+}
+
+void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys,
+                     FlatSymbolRefAttr callee, ValueRange ops, Block *normal,
+                     ValueRange normalOps, Block *unwind,
+                     ValueRange unwindOps) {
+  build(builder, state, tys,
+        TypeAttr::get(getLLVMFuncType(builder.getContext(), tys, ops)), callee,
+        ops, normalOps, unwindOps, nullptr, normal, unwind);
+}
+
+void InvokeOp::build(OpBuilder &builder, OperationState &state,
+                     LLVMFunctionType calleeType, FlatSymbolRefAttr callee,
+                     ValueRange ops, Block *normal, ValueRange normalOps,
+                     Block *unwind, ValueRange unwindOps) {
+  build(builder, state, getCallOpResultTypes(calleeType),
+        TypeAttr::get(calleeType), callee, ops, normalOps, unwindOps, nullptr,
+        normal, unwind);
+}
+
 SuccessorOperands InvokeOp::getSuccessorOperands(unsigned index) {
   assert(index < getNumSuccessors() && "invalid successor index");
   return SuccessorOperands(index == 0 ? getNormalDestOperandsMutable()
@@ -1362,6 +1470,14 @@ void InvokeOp::print(OpAsmPrinter &p) {
   auto callee = getCallee();
   bool isDirect = callee.has_value();
 
+  LLVMFunctionType calleeType;
+  bool isVarArg = false;
+
+  if (std::optional<LLVMFunctionType> optionalCalleeType = getCalleeType()) {
+    calleeType = *optionalCalleeType;
+    isVarArg = calleeType.isVarArg();
+  }
+
   p << ' ';
 
   // Either function name or pointer
@@ -1376,8 +1492,12 @@ void InvokeOp::print(OpAsmPrinter &p) {
   p << " unwind ";
   p.printSuccessorAndUseList(getUnwindDest(), getUnwindDestOperands());
 
-  p.printOptionalAttrDict((*this)->getAttrs(),
-                          {InvokeOp::getOperandSegmentSizeAttr(), "callee"});
+  if (isVarArg)
+    p << " vararg(" << calleeType << ")";
+
+  p.printOptionalAttrDict(
+      (*this)->getAttrs(),
+      {InvokeOp::getOperandSegmentSizeAttr(), "callee", "callee_type"});
 
   p << " : ";
   if (!isDirect)
@@ -1390,10 +1510,12 @@ void InvokeOp::print(OpAsmPrinter &p) {
 //                  `(` ssa-use-list `)`
 //                  `to` bb-id (`[` ssa-use-and-type-list `]`)?
 //                  `unwind` bb-id (`[` ssa-use-and-type-list `]`)?
+//                  ( `vararg(` var-arg-func-type `)` )?
 //                  attribute-dict? `:` (type `,`)? function-type
 ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
   SmallVector<OpAsmParser::UnresolvedOperand, 8> operands;
   SymbolRefAttr funcAttr;
+  TypeAttr calleeType;
   Block *normalDest, *unwindDest;
   SmallVector<Value, 4> normalOperands, unwindOperands;
   Builder &builder = parser.getBuilder();
@@ -1412,8 +1534,19 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
       parser.parseKeyword("to") ||
       parser.parseSuccessorAndUseList(normalDest, normalOperands) ||
       parser.parseKeyword("unwind") ||
-      parser.parseSuccessorAndUseList(unwindDest, unwindOperands) ||
-      parser.parseOptionalAttrDict(result.attributes))
+      parser.parseSuccessorAndUseList(unwindDest, unwindOperands))
+    return failure();
+
+  bool isVarArg = parser.parseOptionalKeyword("vararg").succeeded();
+  if (isVarArg) {
+    if (parser.parseLParen().failed() ||
+        parser.parseAttribute(calleeType, "callee_type", result.attributes)
+            .failed() ||
+        parser.parseRParen().failed())
+      return failure();
+  }
+
+  if (parser.parseOptionalAttrDict(result.attributes))
     return failure();
 
   // Parse the trailing type list and resolve the function operands.
@@ -1432,6 +1565,13 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
   return success();
 }
 
+LLVMFunctionType InvokeOp::getCalleeFunctionType() {
+  if (getCalleeType())
+    return *getCalleeType();
+  else
+    return getLLVMFuncType(getContext(), getResultTypes(), getArgOperands());
+}
+
 ///===----------------------------------------------------------------------===//
 /// Verifying/Printing/Parsing for LLVM::LandingpadOp.
 ///===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index 9b438090c84cacc..1c0f51a66bf5e8c 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -183,22 +183,6 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
 #include "mlir/Dialect/LLVMIR/LLVMConversions.inc"
 #include "mlir/Dialect/LLVMIR/LLVMIntrinsicConversions.inc"
 
-  // Helper function to reconstruct the function type for an indirect call given
-  // the result and argument types. The function cannot reconstruct the type of
-  // variadic functions since the call operation does not carry enough
-  // information to distinguish normal and variadic arguments. Supporting
-  // indirect variadic calls requires an additional type attribute on the call
-  // operation that stores the LLVM function type of the callee.
-  // TODO: Support indirect calls to variadic function pointers.
-  auto getCalleeFunctionType = [&](TypeRange resultTypes, ValueRange args) {
-    Type resultType = resultTypes.empty()
-                          ? LLVMVoidType::get(opInst.getContext())
-                          : resultTypes.front();
-    return llvm::cast<llvm::FunctionType>(moduleTranslation.convertType(
-        LLVMFunctionType::get(opInst.getContext(), resultType,
-                              llvm::to_vector(args.getTypes()), false)));
-  };
-
   // Emit function calls.  If the "callee" attribute is present, this is a
   // direct function call and we also need to look up the remapped function
   // itself.  Otherwise, this is an indirect call and the callee is the first
@@ -211,9 +195,10 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
       call = builder.CreateCall(
           moduleTranslation.lookupFunction(attr.getValue()), operandsRef);
     } else {
-      call = builder.CreateCall(getCalleeFunctionType(callOp.getResultTypes(),
-                                                      callOp.getArgOperands()),
-                                operandsRef.front(), operandsRef.drop_front());
+      llvm::FunctionType *calleeType = llvm::cast<llvm::FunctionType>(
+          moduleTranslation.convertType(callOp.getCalleeFunctionType()));
+      call = builder.CreateCall(calleeType, operandsRef.front(),
+                                operandsRef.drop_front());
     }
     moduleTranslation.setAccessGroupsMetadata(callOp, call);
     moduleTranslation.setAliasScopeMetadata(callOp, call);
@@ -297,9 +282,10 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
           moduleTranslation.lookupBlock(invOp.getSuccessor(0)),
           moduleTranslation.lookupBlock(invOp.getSuccessor(1)), operandsRef);
     } else {
+      llvm::FunctionType *calleeType = llvm::cast<llvm::FunctionType>(
+          moduleTranslation.convertType(invOp.getCalleeFunctionType()));
       result = builder.CreateInvoke(
-          getCalleeFunctionType(invOp.getResultTypes(), invOp.getArgOperands()),
-          operandsRef.front(),
+          calleeType, operandsRef.front(),
           moduleTranslation.lookupBlock(invOp.getSuccessor(0)),
           moduleTranslation.lookupBlock(invOp.getSuccessor(1)),
           operandsRef.drop_front());

diff  --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 672d81bd3fd9129..d070e42ac0c7db7 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1347,12 +1347,19 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
     if (failed(convertCallTypeAndOperands(callInst, types, operands)))
       return failure();
 
+    auto funcTy =
+        dyn_cast<LLVMFunctionType>(convertType(callInst->getFunctionType()));
+    if (!funcTy)
+      return failure();
+
     CallOp callOp;
+
     if (llvm::Function *callee = callInst->getCalledFunction()) {
       callOp = builder.create<CallOp>(
-          loc, types, SymbolRefAttr::get(context, callee->getName()), operands);
+          loc, funcTy, SymbolRefAttr::get(context, callee->getName()),
+          operands);
     } else {
-      callOp = builder.create<CallOp>(loc, types, operands);
+      callOp = builder.create<CallOp>(loc, funcTy, operands);
     }
     setFastmathFlagsAttr(inst, callOp);
     if (!callInst->getType()->isVoidTy())
@@ -1411,20 +1418,25 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
                                  unwindArgs)))
       return failure();
 
+    auto funcTy =
+        dyn_cast<LLVMFunctionType>(convertType(invokeInst->getFunctionType()));
+    if (!funcTy)
+      return failure();
+
     // Create the invoke operation. Normal destination block arguments will be
     // added later on to handle the case in which the operation result is
     // included in this list.
     InvokeOp invokeOp;
     if (llvm::Function *callee = invokeInst->getCalledFunction()) {
       invokeOp = builder.create<InvokeOp>(
-          loc, types,
+          loc, funcTy,
           SymbolRefAttr::get(builder.getContext(), callee->getName()), operands,
           directNormalDest, ValueRange(),
           lookupBlock(invokeInst->getUnwindDest()), unwindArgs);
     } else {
       invokeOp = builder.create<InvokeOp>(
-          loc, types, operands, directNormalDest, ValueRange(),
-          lookupBlock(invokeInst->getUnwindDest()), unwindArgs);
+          loc, funcTy, /*callee=*/nullptr, operands, directNormalDest,
+          ValueRange(), lookupBlock(invokeInst->getUnwindDest()), unwindArgs);
     }
     if (!invokeInst->getType()->isVoidTy())
       mapValue(inst, invokeOp.getResults().front());

diff  --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-opencl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-opencl.mlir
index dc776051c8aa7aa..68dfb852b88b52e 100644
--- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-opencl.mlir
+++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-opencl.mlir
@@ -8,7 +8,7 @@ gpu.module @test_module {
   gpu.func @test_printf(%arg0: i32) {
     // CHECK: %[[IMM0:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL]] : !llvm.ptr<4>
     // CHECK-NEXT: %[[IMM2:.*]] = llvm.getelementptr %[[IMM0]][0, 0] : (!llvm.ptr<4>) -> !llvm.ptr<4>, !llvm.array<11 x i8>
-    // CHECK-NEXT: %{{.*}} = llvm.call @printf(%[[IMM2]], %[[ARG0]]) : (!llvm.ptr<4>, i32) -> i32
+    // CHECK-NEXT: %{{.*}} = llvm.call @printf(%[[IMM2]], %[[ARG0]]) vararg(!llvm.func<i32 (ptr<4>, ...)>) : (!llvm.ptr<4>, i32) -> i32
     gpu.printf "Hello: %d\n" %arg0 : i32
     gpu.return
   }

diff  --git a/mlir/test/Conversion/GPUToROCDL/typed-pointers.mlir b/mlir/test/Conversion/GPUToROCDL/typed-pointers.mlir
index e6f3ed7f2a5e4bf..004a526790fc513 100644
--- a/mlir/test/Conversion/GPUToROCDL/typed-pointers.mlir
+++ b/mlir/test/Conversion/GPUToROCDL/typed-pointers.mlir
@@ -14,7 +14,7 @@ gpu.module @test_module {
   gpu.func @test_printf(%arg0: i32) {
     // OCL: %[[IMM0:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL]] : !llvm.ptr<array<11 x i8>, 4>
     // OCL-NEXT: %[[IMM2:.*]] = llvm.getelementptr %[[IMM0]][0, 0] : (!llvm.ptr<array<11 x i8>, 4>) -> !llvm.ptr<i8, 4>
-    // OCL-NEXT: %{{.*}} = llvm.call @printf(%[[IMM2]], %[[ARG0]]) : (!llvm.ptr<i8, 4>, i32) -> i32
+    // OCL-NEXT: %{{.*}} = llvm.call @printf(%[[IMM2]], %[[ARG0]]) vararg(!llvm.func<i32 (ptr<i8, 4>, ...)>) : (!llvm.ptr<i8, 4>, i32) -> i32
 
     // HIP: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
     // HIP-NEXT: %[[DESC0:.*]] = llvm.call @__ockl_printf_begin(%0) : (i64) -> i64

diff  --git a/mlir/test/Dialect/LLVMIR/invalid-typed-pointers.mlir b/mlir/test/Dialect/LLVMIR/invalid-typed-pointers.mlir
index bb177eb1500ad68..a87b1952b6dca70 100644
--- a/mlir/test/Dialect/LLVMIR/invalid-typed-pointers.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid-typed-pointers.mlir
@@ -35,14 +35,6 @@ func.func @gep_too_few_dynamic(%base : !llvm.ptr<f32>) {
 
 // -----
 
-func.func @call_variadic(%callee : !llvm.ptr<func<i8 (i8, ...)>>, %arg : i8) {
-  // expected-error@+1 {{indirect calls to variadic functions are not supported}}
-  llvm.call %callee(%arg) : !llvm.ptr<func<i8 (i8, ...)>>, (i8) -> (i8)
-  llvm.return
-}
-
-// -----
-
 func.func @indirect_callee_arg_mismatch(%arg0 : i32, %callee : !llvm.ptr<func<void(i8)>>) {
   // expected-error@+1 {{'llvm.call' op operand type mismatch for operand 0: 'i32' != 'i8'}}
   "llvm.call"(%callee, %arg0) : (!llvm.ptr<func<void(i8)>>, i32) -> ()

diff  --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 3db07f45a5babdc..6f119a140ba3c31 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1411,3 +1411,23 @@ func.func @invalid_zext_target_type_two(%arg: vector<1xi32>)  {
   // expected-error@+1 {{input type is a vector but output type is an integer}}
   %0 = llvm.zext %arg : vector<1xi32> to i64
 }
+
+// -----
+
+llvm.func @variadic(...)
+
+llvm.func @invalid_variadic_call(%arg: i32)  {
+  // expected-error@+1 {{missing callee type attribute for vararg call}}
+  "llvm.call"(%arg) <{callee = @variadic}> : (i32) -> ()
+  llvm.return
+}
+
+// -----
+
+llvm.func @variadic(...)
+
+llvm.func @invalid_variadic_call(%arg: i32)  {
+  // expected-error@+1 {{missing callee type attribute for vararg call}}
+  "llvm.call"(%arg) <{callee = @variadic}> : (i32) -> ()
+  llvm.return
+}

diff  --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index 654048d873234d9..1134027c6b6570e 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -71,6 +71,14 @@ func.func @ops(%arg0: i32, %arg1: f32,
   %20 = llvm.mlir.addressof @foo : !llvm.ptr
   %21 = llvm.call %20(%arg0) : !llvm.ptr, (i32) -> !llvm.struct<(i32, f64, i32)>
 
+// Variadic calls
+// CHECK:  llvm.call @vararg_func(%arg0, %arg0) vararg(!llvm.func<void (i32, ...)>) : (i32, i32) -> ()
+// CHECK:  %[[VARIADIC_FUNC:.*]] = llvm.mlir.addressof @vararg_func : !llvm.ptr
+// CHECK:  llvm.call %[[VARIADIC_FUNC]](%[[I32]], %[[I32]]) vararg(!llvm.func<void (i32, ...)>) : !llvm.ptr, (i32, i32) -> ()
+  llvm.call @vararg_func(%arg0, %arg0) vararg(!llvm.func<void (i32, ...)>) : (i32, i32) -> ()
+  %variadic_func = llvm.mlir.addressof @vararg_func : !llvm.ptr
+  llvm.call %variadic_func(%arg0, %arg0) vararg(!llvm.func<void (i32, ...)>) : !llvm.ptr, (i32, i32) -> ()
+
 // Terminator operations and their successors.
 //
 // CHECK: llvm.br ^[[BB1:.*]]
@@ -183,6 +191,8 @@ llvm.func @gep(%ptr: !llvm.ptr, %idx: i64, %ptr2: !llvm.ptr) {
   llvm.return
 }
 
+llvm.func @vararg_foo(i32, ...) -> !llvm.struct<(i32, f64, i32)>
+
 // An larger self-contained function.
 // CHECK-LABEL: llvm.func @foo(%{{.*}}: i32) -> !llvm.struct<(i32, f64, i32)> {
 llvm.func @foo(%arg0: i32) -> !llvm.struct<(i32, f64, i32)> {
@@ -427,8 +437,21 @@ llvm.func @invokeLandingpad() -> i32 attributes { personality = @__gxx_personali
   %13 = llvm.invoke %12(%5) to ^bb2 unwind ^bb1 : !llvm.ptr, (i32) -> !llvm.struct<(i32, f64, i32)>
 
 // CHECK: ^[[BB5:.*]]:
-// CHECK:   llvm.return %[[V0]] : i32
+// CHECK: %{{.*}} = llvm.invoke @{{.*}} vararg(!llvm.func<struct<(i32, f64, i32)> (i32, ...)>) : (i32, i32) -> !llvm.struct<(i32, f64, i32)>
+
 ^bb5:
+  %14 = llvm.invoke @vararg_foo(%5, %5) to ^bb2 unwind ^bb1 vararg(!llvm.func<struct<(i32, f64, i32)> (i32, ...)>) : (i32, i32) -> !llvm.struct<(i32, f64, i32)>
+
+// CHECK: ^[[BB6:.*]]:
+// CHECK: %[[FUNC:.*]] = llvm.mlir.addressof @vararg_foo : !llvm.ptr
+// CHECK: %{{.*}} = llvm.invoke %[[FUNC]]{{.*}} vararg(!llvm.func<struct<(i32, f64, i32)> (i32, ...)>) : !llvm.ptr, (i32, i32) -> !llvm.struct<(i32, f64, i32)>
+^bb6:
+  %15 = llvm.mlir.addressof @vararg_foo : !llvm.ptr
+  %16 = llvm.invoke %15(%5, %5) to ^bb2 unwind ^bb1 vararg(!llvm.func<!llvm.struct<(i32, f64, i32)> (i32, ...)>) : !llvm.ptr, (i32, i32) -> !llvm.struct<(i32, f64, i32)>
+
+// CHECK: ^[[BB7:.*]]:
+// CHECK:   llvm.return %[[V0]] : i32
+^bb7:
   llvm.return %0 : i32
 }
 

diff  --git a/mlir/test/Target/LLVMIR/Import/exception.ll b/mlir/test/Target/LLVMIR/Import/exception.ll
index 1bdfa6f3646e9bd..de227645cc15bbe 100644
--- a/mlir/test/Target/LLVMIR/Import/exception.ll
+++ b/mlir/test/Target/LLVMIR/Import/exception.ll
@@ -3,6 +3,7 @@
 @_ZTIi = external dso_local constant ptr
 @_ZTIii= external dso_local constant ptr
 declare void @foo(ptr)
+declare void @vararg_foo(ptr, ...)
 declare ptr @bar(ptr)
 declare i32 @__gxx_personality_v0(...)
 
@@ -29,6 +30,14 @@ define i32 @invokeLandingpad() personality ptr @__gxx_personality_v0 {
   %6 = invoke ptr @bar(ptr %1) to label %4 unwind label %2
 
 ; CHECK: ^bb4:
+  ; CHECK: llvm.invoke @vararg_foo(%[[a3]], %{{.*}}) to ^bb2 unwind ^bb1 vararg(!llvm.func<void (ptr, ...)>) : (!llvm.ptr, i32) -> ()
+  invoke void (ptr, ...) @vararg_foo(ptr %1, i32 0) to label %4 unwind label %2
+
+; CHECK: ^bb5:
+  ; CHECK: llvm.invoke %{{.*}}(%[[a3]], %{{.*}}) to ^bb2 unwind ^bb1 vararg(!llvm.func<void (ptr, ...)>) : !llvm.ptr, (!llvm.ptr, i32) -> ()
+  invoke void (ptr, ...) undef(ptr %1, i32 0) to label %4 unwind label %2
+
+; CHECK: ^bb6:
   ; CHECK: llvm.return %{{[0-9]+}} : i32
   ret i32 0
 }

diff  --git a/mlir/test/Target/LLVMIR/Import/instructions.ll b/mlir/test/Target/LLVMIR/Import/instructions.ll
index b72e72c8392d910..036efad0d099ae1 100644
--- a/mlir/test/Target/LLVMIR/Import/instructions.ll
+++ b/mlir/test/Target/LLVMIR/Import/instructions.ll
@@ -480,6 +480,17 @@ define void @indirect_call(ptr addrspace(42) %fn) {
 
 ; // -----
 
+; CHECK-LABEL: @indirect_vararg_call
+; CHECK-SAME:  %[[PTR:[a-zA-Z0-9]+]]
+define void @indirect_vararg_call(ptr addrspace(42) %fn) {
+  ; CHECK:  %[[C0:[0-9]+]] = llvm.mlir.constant(0 : i16) : i16
+  ; CHECK:  llvm.call %[[PTR]](%[[C0]]) vararg(!llvm.func<void (...)>) : !llvm.ptr<42>, (i16) -> ()
+  call addrspace(42) void (...) %fn(i16 0)
+  ret void
+}
+
+; // -----
+
 ; CHECK-LABEL: @gep_static_idx
 ; CHECK-SAME:  %[[PTR:[a-zA-Z0-9]+]]
 define void @gep_static_idx(ptr %ptr) {
@@ -497,7 +508,7 @@ declare void @varargs(...)
 ; CHECK-LABEL: @varargs_call
 ; CHECK-SAME:  %[[ARG1:[a-zA-Z0-9]+]]
 define void @varargs_call(i32 %0) {
-  ; CHECK:  llvm.call @varargs(%[[ARG1]]) : (i32) -> ()
+  ; CHECK:  llvm.call @varargs(%[[ARG1]]) vararg(!llvm.func<void (...)>) : (i32) -> ()
   call void (...) @varargs(i32 %0)
   ret void
 }

diff  --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index 8317fae7dd71099..7da44b6fbe1ab33 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -1238,7 +1238,14 @@ llvm.func @varargs(...)
 // CHECK-LABEL: define void @varargs_call
 llvm.func @varargs_call(%arg0 : i32) {
 // CHECK:  call void (...) @varargs(i32 %{{.*}})
-  llvm.call @varargs(%arg0) : (i32) -> ()
+  llvm.call @varargs(%arg0) vararg(!llvm.func<void (...)>) : (i32) -> ()
+  llvm.return
+}
+
+// CHECK-LABEL: define void @indirect_varargs_call(ptr %0, i32 %1)
+llvm.func @indirect_varargs_call(%arg0 : !llvm.ptr, %arg1 : i32) {
+// CHECK:  call void (...) %0(i32 %1)
+  llvm.call %arg0(%arg1) vararg(!llvm.func<void (...)>) : !llvm.ptr, (i32) -> ()
   llvm.return
 }
 
@@ -1526,6 +1533,7 @@ llvm.func @cmpxchg(%ptr : !llvm.ptr<i32>, %cmp : i32, %val: i32) {
 
 llvm.mlir.global external constant @_ZTIi() : !llvm.ptr<i8>
 llvm.func @foo(!llvm.ptr<i8>)
+llvm.func @vararg_foo(!llvm.ptr<i8>, ...)
 llvm.func @bar(!llvm.ptr<i8>) -> !llvm.ptr<i8>
 llvm.func @__gxx_personality_v0(...) -> i32
 
@@ -1563,6 +1571,17 @@ llvm.func @invokeLandingpad() -> i32 attributes { personality = @__gxx_personali
 // CHECK-NEXT:          to label %[[normal]] unwind label %[[unwind]]
 ^bb3:	// pred: ^bb1
   %8 = llvm.invoke @bar(%6) to ^bb2 unwind ^bb1 : (!llvm.ptr<i8>) -> !llvm.ptr<i8>
+
+// CHECK: [[BB4:.*]]:
+// CHECK: invoke void (ptr, ...) @vararg_foo(ptr %[[a1]], i32 0)
+^bb4:
+  llvm.invoke @vararg_foo(%6, %0) to ^bb2 unwind ^bb1 vararg(!llvm.func<void (ptr<i8>, ...)>) : (!llvm.ptr<i8>, i32) -> ()
+
+// CHECK: [[BB5:.*]]:
+// CHECK: invoke void (ptr, ...) undef(ptr %[[a1]], i32 0)
+^bb5:
+  %9 = llvm.mlir.undef : !llvm.ptr
+  llvm.invoke %9(%6, %0) to ^bb2 unwind ^bb1 vararg(!llvm.func<void (ptr<i8>, ...)>) : !llvm.ptr, (!llvm.ptr<i8>, i32) -> ()
 }
 
 // -----

diff  --git a/mlir/test/Target/LLVMIR/openmp-nested.mlir b/mlir/test/Target/LLVMIR/openmp-nested.mlir
index 5e047d5a58d288e..4a96746cbbc9137 100644
--- a/mlir/test/Target/LLVMIR/openmp-nested.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-nested.mlir
@@ -23,7 +23,7 @@ module {
         %20 = llvm.trunc %19 : i64 to i32
         %5 = llvm.mlir.addressof @str0 : !llvm.ptr<array<29 x i8>>
         %6 = llvm.getelementptr %5[%4, %4] : (!llvm.ptr<array<29 x i8>>, i32, i32) -> !llvm.ptr<i8>
-        %21 = llvm.call @printf(%6, %20, %20) : (!llvm.ptr<i8>, i32, i32) -> i32
+        %21 = llvm.call @printf(%6, %20, %20) vararg(!llvm.func<i32 (ptr<i8>, ...)>): (!llvm.ptr<i8>, i32, i32) -> i32
         omp.yield
       }
       omp.terminator

diff  --git a/mlir/test/mlir-cpu-runner/x86-varargs.mlir b/mlir/test/mlir-cpu-runner/x86-varargs.mlir
index 44024113c2b7b68..f3f4322ce87975e 100644
--- a/mlir/test/mlir-cpu-runner/x86-varargs.mlir
+++ b/mlir/test/mlir-cpu-runner/x86-varargs.mlir
@@ -10,7 +10,7 @@ llvm.func @caller() -> i32 {
   %0 = llvm.mlir.constant(3 : i32) : i32
   %1 = llvm.mlir.constant(2 : i32) : i32
   %2 = llvm.mlir.constant(1 : i32) : i32
-  %3 = llvm.call @foo(%2, %1, %0) : (i32, i32, i32) -> i32
+  %3 = llvm.call @foo(%2, %1, %0) vararg(!llvm.func<i32 (i32, ...)>) : (i32, i32, i32) -> i32
   llvm.return %3 : i32
 }
 


        


More information about the Mlir-commits mailing list