[Mlir-commits] [mlir] [MLIR][LLVM] Always print variadic callee type (PR #99293)

Tobias Gysi llvmlistbot at llvm.org
Mon Jul 22 22:55:52 PDT 2024


https://github.com/gysit updated https://github.com/llvm/llvm-project/pull/99293

>From 6b774e063f7c23f990852b0fc140ac61a5c689a5 Mon Sep 17 00:00:00 2001
From: Tobias Gysi <tobias.gysi at nextsilicon.com>
Date: Wed, 17 Jul 2024 08:02:11 +0000
Subject: [PATCH] [MLIR][LLVM] Always print variadic callee type

This commit updates the LLVM dialect CallOp and InvokeOp to always print
the variadic callee type (the attribute previously named `callee_type`,
now `var_callee_type`) if present. An additional verifier check ensures
that only variadic calls have a non-null variadic callee type, and the
builders are adapted accordingly to set the variadic callee type for
variadic calls only. Finally, the CallOp and InvokeOp verifiers are
strengthened to check that the variadic callee type matches the call
argument and result types.

The motivation for this change is to ensure that CallOp and InvokeOp have
no hidden state that is used during the export to LLVM IR but not
pretty-printed. Previously, a call could look correct in MLIR, yet its
return type changed after exporting to LLVM IR (because the type was taken
from the hidden callee type attribute). With this change, this can no
longer happen, since the variadic callee type is always printed if present.
---
 mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td |  17 +-
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp  | 143 +++++++++++------
 mlir/test/Dialect/LLVMIR/invalid.mlir       | 165 +++++++++++++++++++-
 3 files changed, 260 insertions(+), 65 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 06656c791c594..d2d1fbaf304b2 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -560,14 +560,14 @@ def LLVM_InvokeOp : LLVM_Op<"invoke", [
                       DeclareOpInterfaceMethods<BranchWeightOpInterface>,
                       Terminator]> {
   let arguments = (ins
-                   OptionalAttr<TypeAttrOf<LLVM_FunctionType>>:$callee_type,
+                   OptionalAttr<TypeAttrOf<LLVM_FunctionType>>:$var_callee_type,
                    OptionalAttr<FlatSymbolRefAttr>:$callee,
                    Variadic<LLVM_Type>:$callee_operands,
                    Variadic<LLVM_Type>:$normalDestOperands,
                    Variadic<LLVM_Type>:$unwindDestOperands,
                    OptionalAttr<DenseI32ArrayAttr>:$branch_weights,
                    DefaultValuedAttr<CConv, "CConv::C">:$CConv);
-  let results = (outs Variadic<LLVM_Type>);
+  let results = (outs Optional<LLVM_Type>:$result);
   let successors = (successor AnySuccessor:$normalDest,
                               AnySuccessor:$unwindDest);
 
@@ -617,11 +617,12 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
     start with a function name (`@`-prefixed) and indirect calls start with an
     SSA value (`%`-prefixed). The direct callee, if present, is stored as a
     function attribute `callee`. For indirect calls, the callee is of `!llvm.ptr` type
-    and is stored as the first value in `callee_operands`. If the callee is a variadic
-    function, then the `callee_type` attribute must carry the function type. The
-    trailing type list contains the optional indirect callee type and the MLIR
-    function type, which differs from the LLVM function type that uses a explicit
-    void type to model functions that do not return a value.
+    and is stored as the first value in `callee_operands`. If and only if the
+    callee is a variadic function, the `var_callee_type` attribute must carry
+    the variadic LLVM function type. The trailing type list contains the
+    optional indirect callee type and the MLIR function type, which differs from
+    the LLVM function type that uses an explicit void type to model functions
+    that do not return a value.
 
     Examples:
 
@@ -644,7 +645,7 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
     ```
   }];
 
-  dag args = (ins OptionalAttr<TypeAttrOf<LLVM_FunctionType>>:$callee_type,
+  dag args = (ins OptionalAttr<TypeAttrOf<LLVM_FunctionType>>:$var_callee_type,
                   OptionalAttr<FlatSymbolRefAttr>:$callee,
                   Variadic<LLVM_Type>:$callee_operands,
                   DefaultValuedAttr<LLVM_FastmathFlagsAttr,
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 9372caf6e32a7..b572b79d089a6 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -948,6 +948,11 @@ static SmallVector<Type, 1> getCallOpResultTypes(LLVMFunctionType calleeType) {
   return results;
 }
 
+/// Returns `calleeType` as a TypeAttr if it is variadic, and null otherwise.
+static TypeAttr getCallOpVarCalleeType(LLVMFunctionType calleeType) {
+  return calleeType.isVarArg() ? TypeAttr::get(calleeType) : nullptr;
+}
+
 /// Constructs a LLVMFunctionType from MLIR `results` and `args`.
 static LLVMFunctionType getLLVMFuncType(MLIRContext *context, TypeRange results,
                                         ValueRange args) {
@@ -974,8 +979,8 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
                    FlatSymbolRefAttr callee, ValueRange args) {
   assert(callee && "expected non-null callee in direct call builder");
   build(builder, state, results,
-        TypeAttr::get(getLLVMFuncType(builder.getContext(), results, args)),
-        callee, args, /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
+        /*var_callee_type=*/nullptr, callee, args, /*fastmathFlags=*/nullptr,
+        /*branch_weights=*/nullptr,
         /*CConv=*/nullptr, /*TailCallKind=*/nullptr,
         /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
         /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
@@ -997,7 +1002,8 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
                    LLVMFunctionType calleeType, FlatSymbolRefAttr callee,
                    ValueRange args) {
   build(builder, state, getCallOpResultTypes(calleeType),
-        TypeAttr::get(calleeType), callee, args, /*fastmathFlags=*/nullptr,
+        getCallOpVarCalleeType(calleeType), callee, args,
+        /*fastmathFlags=*/nullptr,
         /*branch_weights=*/nullptr, /*CConv=*/nullptr,
         /*TailCallKind=*/nullptr, /*access_groups=*/nullptr,
         /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
@@ -1006,7 +1012,8 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
 void CallOp::build(OpBuilder &builder, OperationState &state,
                    LLVMFunctionType calleeType, ValueRange args) {
   build(builder, state, getCallOpResultTypes(calleeType),
-        TypeAttr::get(calleeType), /*callee=*/nullptr, args,
+        getCallOpVarCalleeType(calleeType),
+        /*callee=*/nullptr, args,
         /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
         /*CConv=*/nullptr, /*TailCallKind=*/nullptr,
         /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
@@ -1017,7 +1024,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
                    ValueRange args) {
   auto calleeType = func.getFunctionType();
   build(builder, state, getCallOpResultTypes(calleeType),
-        TypeAttr::get(calleeType), SymbolRefAttr::get(func), args,
+        getCallOpVarCalleeType(calleeType), SymbolRefAttr::get(func), args,
         /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
         /*CConv=*/nullptr, /*TailCallKind=*/nullptr,
         /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
@@ -1076,9 +1083,49 @@ static LogicalResult verifyCallOpDebugInfo(CallOp callOp, LLVMFuncOp callee) {
   return success();
 }
 
+/// Verifies the optional variadic callee type of `callOp`: when present, it
+/// must be a vararg function type matching the call's operand/result types.
+template <typename OpTy>
+LogicalResult verifyCallOpVarCalleeType(OpTy callOp) {
+  std::optional<LLVMFunctionType> varCalleeType = callOp.getVarCalleeType();
+  if (!varCalleeType)
+    return success();
+
+  // Verify the variadic callee type is a variadic function type.
+  if (!varCalleeType->isVarArg())
+    return callOp.emitOpError(
+        "expected var_callee_type to be a variadic function type");
+
+  // The call may pass extra operands through the `...` part, but it must
+  // provide at least one operand for every fixed parameter.
+  if (varCalleeType->getNumParams() > callOp.getArgOperands().size())
+    return callOp.emitOpError("expected var_callee_type to have at most ")
+           << callOp.getArgOperands().size() << " parameters";
+
+  // Compare fixed parameter types only; zip stops at the shorter range.
+  for (auto [paramType, operand] :
+       llvm::zip(varCalleeType->getParams(), callOp.getArgOperands()))
+    if (paramType != operand.getType())
+      return callOp.emitOpError()
+             << "var_callee_type parameter type mismatch: " << paramType
+             << " != " << operand.getType();
+
+  // No results requires a void return type; otherwise the types must match.
+  if (!callOp.getNumResults()) {
+    if (!isa<LLVMVoidType>(varCalleeType->getReturnType()))
+      return callOp.emitOpError("expected var_callee_type to return void");
+  } else {
+    if (callOp.getResult().getType() != varCalleeType->getReturnType())
+      return callOp.emitOpError("var_callee_type return type mismatch: ")
+             << varCalleeType->getReturnType()
+             << " != " << callOp.getResult().getType();
+  }
+  return success();
+}
+
 LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  if (getNumResults() > 1)
-    return emitOpError("must have 0 or 1 result");
+  if (failed(verifyCallOpVarCalleeType(*this)))
+    return failure();
 
   // Type for the callee, we'll get it differently depending if it is a direct
   // or indirect call.
@@ -1120,8 +1167,8 @@ LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   if (!funcType)
     return emitOpError("callee does not have a functional type: ") << fnType;
 
-  if (funcType.isVarArg() && !getCalleeType())
-    return emitOpError() << "missing callee type attribute for vararg call";
+  if (funcType.isVarArg() && !getVarCalleeType())
+    return emitOpError() << "missing var_callee_type attribute for vararg call";
 
   // Verify that the operand and result types match the callee.
 
@@ -1168,14 +1215,6 @@ void CallOp::print(OpAsmPrinter &p) {
   auto callee = getCallee();
   bool isDirect = callee.has_value();
 
-  LLVMFunctionType calleeType;
-  bool isVarArg = false;
-
-  if (std::optional<LLVMFunctionType> optionalCalleeType = getCalleeType()) {
-    calleeType = *optionalCalleeType;
-    isVarArg = calleeType.isVarArg();
-  }
-
   p << ' ';
 
   // Print calling convention.
@@ -1195,12 +1234,13 @@ void CallOp::print(OpAsmPrinter &p) {
   auto args = getOperands().drop_front(isDirect ? 0 : 1);
   p << '(' << args << ')';
 
-  if (isVarArg)
-    p << " vararg(" << calleeType << ")";
+  // Print the variadic callee type if the call is variadic.
+  if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
+    p << " vararg(" << *varCalleeType << ")";
 
   p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()),
-                          {getCConvAttrName(), "callee", "callee_type",
-                           getTailCallKindAttrName()});
+                          {getCalleeAttrName(), getTailCallKindAttrName(),
+                           getVarCalleeTypeAttrName(), getCConvAttrName()});
 
   p << " : ";
   if (!isDirect)
@@ -1270,11 +1310,11 @@ static ParseResult parseOptionalCallFuncPtr(
 
 // <operation> ::= `llvm.call` (cconv)? (tailcallkind)? (function-id | ssa-use)
 //                             `(` ssa-use-list `)`
-//                             ( `vararg(` var-arg-func-type `)` )?
+//                             ( `vararg(` var-callee-type `)` )?
 //                             attribute-dict? `:` (type `,`)? function-type
 ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
   SymbolRefAttr funcAttr;
-  TypeAttr calleeType;
+  TypeAttr varCalleeType;
   SmallVector<OpAsmParser::UnresolvedOperand> operands;
 
   // Default to C Calling Convention if no keyword is provided.
@@ -1305,8 +1345,12 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
 
   bool isVarArg = parser.parseOptionalKeyword("vararg").succeeded();
   if (isVarArg) {
+    StringAttr varCalleeTypeAttrName =
+        CallOp::getVarCalleeTypeAttrName(result.name);
     if (parser.parseLParen().failed() ||
-        parser.parseAttribute(calleeType, "callee_type", result.attributes)
+        parser
+            .parseAttribute(varCalleeType, varCalleeTypeAttrName,
+                            result.attributes)
             .failed() ||
         parser.parseRParen().failed())
       return failure();
@@ -1320,8 +1364,8 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
 }
 
 LLVMFunctionType CallOp::getCalleeFunctionType() {
-  if (getCalleeType())
-    return *getCalleeType();
+  if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
+    return *varCalleeType; // Variadic call: use the explicit type.
   return getLLVMFuncType(getContext(), getResultTypes(), getArgOperands());
 }
 
@@ -1334,8 +1378,8 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
                      Block *unwind, ValueRange unwindOps) {
   auto calleeType = func.getFunctionType();
   build(builder, state, getCallOpResultTypes(calleeType),
-        TypeAttr::get(calleeType), SymbolRefAttr::get(func), ops, normalOps,
-        unwindOps, nullptr, nullptr, normal, unwind);
+        getCallOpVarCalleeType(calleeType), SymbolRefAttr::get(func), ops,
+        normalOps, unwindOps, nullptr, nullptr, normal, unwind);
 }
 
 void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys,
@@ -1343,8 +1387,8 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys,
                      ValueRange normalOps, Block *unwind,
                      ValueRange unwindOps) {
   build(builder, state, tys,
-        TypeAttr::get(getLLVMFuncType(builder.getContext(), tys, ops)), callee,
-        ops, normalOps, unwindOps, nullptr, nullptr, normal, unwind);
+        /*var_callee_type=*/nullptr, callee, ops, normalOps, unwindOps, nullptr,
+        nullptr, normal, unwind);
 }
 
 void InvokeOp::build(OpBuilder &builder, OperationState &state,
@@ -1352,8 +1396,8 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state,
                      ValueRange ops, Block *normal, ValueRange normalOps,
                      Block *unwind, ValueRange unwindOps) {
   build(builder, state, getCallOpResultTypes(calleeType),
-        TypeAttr::get(calleeType), callee, ops, normalOps, unwindOps, nullptr,
-        nullptr, normal, unwind);
+        getCallOpVarCalleeType(calleeType), callee, ops, normalOps, unwindOps,
+        nullptr, nullptr, normal, unwind);
 }
 
 SuccessorOperands InvokeOp::getSuccessorOperands(unsigned index) {
@@ -1390,8 +1434,8 @@ MutableOperandRange InvokeOp::getArgOperandsMutable() {
 }
 
 LogicalResult InvokeOp::verify() {
-  if (getNumResults() > 1)
-    return emitOpError("must have 0 or 1 result");
+  if (failed(verifyCallOpVarCalleeType(*this)))
+    return failure();
 
   Block *unwindDest = getUnwindDest();
   if (unwindDest->empty())
@@ -1409,14 +1453,6 @@ void InvokeOp::print(OpAsmPrinter &p) {
   auto callee = getCallee();
   bool isDirect = callee.has_value();
 
-  LLVMFunctionType calleeType;
-  bool isVarArg = false;
-
-  if (std::optional<LLVMFunctionType> optionalCalleeType = getCalleeType()) {
-    calleeType = *optionalCalleeType;
-    isVarArg = calleeType.isVarArg();
-  }
-
   p << ' ';
 
   // Print calling convention.
@@ -1435,12 +1471,13 @@ void InvokeOp::print(OpAsmPrinter &p) {
   p << " unwind ";
   p.printSuccessorAndUseList(getUnwindDest(), getUnwindDestOperands());
 
-  if (isVarArg)
-    p << " vararg(" << calleeType << ")";
+  // Print the variadic callee type if the invoke is variadic.
+  if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
+    p << " vararg(" << *varCalleeType << ")";
 
   p.printOptionalAttrDict((*this)->getAttrs(),
-                          {InvokeOp::getOperandSegmentSizeAttr(), "callee",
-                           "callee_type", InvokeOp::getCConvAttrName()});
+                          {getCalleeAttrName(), getOperandSegmentSizeAttr(),
+                           getCConvAttrName(), getVarCalleeTypeAttrName()});
 
   p << " : ";
   if (!isDirect)
@@ -1453,12 +1490,12 @@ void InvokeOp::print(OpAsmPrinter &p) {
 //                  `(` ssa-use-list `)`
 //                  `to` bb-id (`[` ssa-use-and-type-list `]`)?
 //                  `unwind` bb-id (`[` ssa-use-and-type-list `]`)?
-//                  ( `vararg(` var-arg-func-type `)` )?
+//                  ( `vararg(` var-callee-type `)` )?
 //                  attribute-dict? `:` (type `,`)? function-type
 ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
   SmallVector<OpAsmParser::UnresolvedOperand, 8> operands;
   SymbolRefAttr funcAttr;
-  TypeAttr calleeType;
+  TypeAttr varCalleeType;
   Block *normalDest, *unwindDest;
   SmallVector<Value, 4> normalOperands, unwindOperands;
   Builder &builder = parser.getBuilder();
@@ -1488,8 +1525,12 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
 
   bool isVarArg = parser.parseOptionalKeyword("vararg").succeeded();
   if (isVarArg) {
+    StringAttr varCalleeTypeAttrName =
+        InvokeOp::getVarCalleeTypeAttrName(result.name);
     if (parser.parseLParen().failed() ||
-        parser.parseAttribute(calleeType, "callee_type", result.attributes)
+        parser
+            .parseAttribute(varCalleeType, varCalleeTypeAttrName,
+                            result.attributes)
             .failed() ||
         parser.parseRParen().failed())
       return failure();
@@ -1515,8 +1556,8 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
 }
 
 LLVMFunctionType InvokeOp::getCalleeFunctionType() {
-  if (getCalleeType())
-    return *getCalleeType();
+  if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
+    return *varCalleeType; // Variadic invoke: use the explicit type.
   return getLLVMFuncType(getContext(), getResultTypes(), getArgOperands());
 }
 
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 39f8e70b9fb7b..fe288dab973f5 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1415,10 +1415,163 @@ func.func @invalid_zext_target_type_two(%arg: vector<1xi32>)  {
 
 // -----
 
+llvm.func @non_variadic(%arg: i32)
+
+llvm.func @invalid_var_callee_type(%arg: i32)  {
+  // expected-error at below {{expected var_callee_type to be a variadic function type}}
+  llvm.call @non_variadic(%arg) vararg(!llvm.func<void (i32)>) : (i32) -> ()
+  llvm.return
+}
+
+// -----
+
+llvm.func @variadic(%arg: i32, ...)
+
+llvm.func @invalid_var_callee_type_num_parameters(%arg: i32)  {
+  // expected-error at below {{expected var_callee_type to have at most 1 parameters}}
+  llvm.call @variadic(%arg) vararg(!llvm.func<void (i32, i64, ...)>) : (i32) -> ()
+  llvm.return
+}
+
+// -----
+
+llvm.func @invalid_var_callee_type_num_parameters_indirect(%callee : !llvm.ptr, %arg: i32)  {
+  // expected-error at below {{expected var_callee_type to have at most 1 parameters}}
+  llvm.call %callee(%arg) vararg(!llvm.func<void (i32, i64, ...)>) : !llvm.ptr, (i32) -> ()
+  llvm.return
+}
+
+// -----
+
+llvm.func @variadic(%arg: i32, ...)
+
+llvm.func @invalid_var_callee_type_parameter_type_mismatch(%arg: i32)  {
+  // expected-error at below {{var_callee_type parameter type mismatch: 'i64' != 'i32'}}
+  llvm.call @variadic(%arg) vararg(!llvm.func<void (i64, ...)>) : (i32) -> ()
+  llvm.return
+}
+
+// -----
+
+llvm.func @invalid_var_callee_type_parameter_type_mismatch_indirect(%callee : !llvm.ptr, %arg: i32)  {
+  // expected-error at below {{var_callee_type parameter type mismatch: 'i64' != 'i32'}}
+  llvm.call %callee(%arg) vararg(!llvm.func<void (i64, ...)>) : !llvm.ptr, (i32) -> ()
+  llvm.return
+}
+
+// -----
+
+llvm.func @variadic(%arg: i32, ...)
+
+llvm.func @invalid_var_callee_type_non_void(%arg: i32)  {
+  // expected-error at below {{expected var_callee_type to return void}}
+  llvm.call @variadic(%arg) vararg(!llvm.func<i8 (i32, ...)>) : (i32) -> ()
+  llvm.return
+}
+
+// -----
+
+llvm.func @variadic(%arg: i32, ...) -> i32
+
+llvm.func @invalid_var_callee_type_return_type_mismatch(%arg: i32)  {
+  // expected-error at below {{var_callee_type return type mismatch: 'i8' != 'i32'}}
+  %0 = llvm.call @variadic(%arg) vararg(!llvm.func<i8 (i32, ...)>) : (i32) -> (i32)
+  llvm.return
+}
+
+// -----
+
+llvm.func @non_variadic(%arg: i32)
+
+llvm.func @invalid_var_callee_type(%arg: i32)  {
+  // expected-error at below {{expected var_callee_type to be a variadic function type}}
+  llvm.invoke @non_variadic(%arg) to ^bb2 unwind ^bb1 vararg(!llvm.func<void (i32)>) : (i32) -> ()
+^bb1:
+  llvm.return
+^bb2:
+  llvm.return
+}
+
+// -----
+
+llvm.func @variadic(%arg: i32, ...)
+
+llvm.func @invalid_var_callee_type_num_parameters(%arg: i32)  {
+  // expected-error at below {{expected var_callee_type to have at most 1 parameters}}
+  llvm.invoke @variadic(%arg) to ^bb2 unwind ^bb1 vararg(!llvm.func<void (i32, i64, ...)>) : (i32) -> ()
+^bb1:
+  llvm.return
+^bb2:
+  llvm.return
+}
+
+// -----
+
+llvm.func @invalid_var_callee_type_num_parameters_indirect(%callee : !llvm.ptr, %arg: i32)  {
+  // expected-error at below {{expected var_callee_type to have at most 1 parameters}}
+  llvm.invoke %callee(%arg) to ^bb2 unwind ^bb1 vararg(!llvm.func<void (i32, i64, ...)>) : !llvm.ptr, (i32) -> ()
+^bb1:
+  llvm.return
+^bb2:
+  llvm.return
+}
+
+// -----
+
+llvm.func @variadic(%arg: i32, ...)
+
+llvm.func @invalid_var_callee_type_parameter_type_mismatch(%arg: i32)  {
+  // expected-error at below {{var_callee_type parameter type mismatch: 'i64' != 'i32'}}
+  llvm.invoke @variadic(%arg) to ^bb2 unwind ^bb1 vararg(!llvm.func<void (i64, ...)>) : (i32) -> ()
+^bb1:
+  llvm.return
+^bb2:
+  llvm.return
+}
+
+// -----
+
+llvm.func @invalid_var_callee_type_parameter_type_mismatch_indirect(%callee : !llvm.ptr, %arg: i32)  {
+  // expected-error at below {{var_callee_type parameter type mismatch: 'i64' != 'i32'}}
+  llvm.invoke %callee(%arg) to ^bb2 unwind ^bb1 vararg(!llvm.func<void (i64, ...)>) : !llvm.ptr, (i32) -> ()
+^bb1:
+  llvm.return
+^bb2:
+  llvm.return
+}
+
+// -----
+
+llvm.func @variadic(%arg: i32, ...)
+
+llvm.func @invalid_var_callee_type_non_void(%arg: i32)  {
+  // expected-error at below {{expected var_callee_type to return void}}
+  llvm.invoke @variadic(%arg) to ^bb2 unwind ^bb1 vararg(!llvm.func<i8 (i32, ...)>) : (i32) -> ()
+^bb1:
+  llvm.return
+^bb2:
+  llvm.return
+}
+
+// -----
+
+llvm.func @variadic(%arg: i32, ...) -> i32
+
+llvm.func @invalid_var_callee_type_return_type_mismatch(%arg: i32)  {
+  // expected-error at below {{var_callee_type return type mismatch: 'i8' != 'i32'}}
+  %0 = llvm.invoke @variadic(%arg) to ^bb2 unwind ^bb1 vararg(!llvm.func<i8 (i32, ...)>) : (i32) -> (i32)
+^bb1:
+  llvm.return
+^bb2:
+  llvm.return
+}
+
+// -----
+
 llvm.func @variadic(...)
 
 llvm.func @invalid_variadic_call(%arg: i32)  {
-  // expected-error at +1 {{missing callee type attribute for vararg call}}
+  // expected-error at +1 {{missing var_callee_type attribute for vararg call}}
   "llvm.call"(%arg) <{callee = @variadic}> : (i32) -> ()
   llvm.return
 }
@@ -1428,7 +1581,7 @@ llvm.func @invalid_variadic_call(%arg: i32)  {
 llvm.func @variadic(...)
 
 llvm.func @invalid_variadic_call(%arg: i32)  {
-  // expected-error at +1 {{missing callee type attribute for vararg call}}
+  // expected-error at +1 {{missing var_callee_type attribute for vararg call}}
   "llvm.call"(%arg) <{callee = @variadic}> : (i32) -> ()
   llvm.return
 }
@@ -1445,14 +1598,14 @@ llvm.func @foo(%arg: !llvm.ptr) {
 
 func.func @tma_load(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %off0: i16, %off1: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
   // expected-error at +1 {{to use im2col mode, the tensor has to be at least 3-dimensional}}
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1] im2col[%off0] multicast_mask = %ctamask l2_cache_hint = %cacheHint : !llvm.ptr<3>, !llvm.ptr  
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1] im2col[%off0] multicast_mask = %ctamask l2_cache_hint = %cacheHint : !llvm.ptr<3>, !llvm.ptr
   return
 }
 // -----
 
 func.func @tma_load(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %off0: i16, %off1: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
   // expected-error at +1 {{im2col offsets must be 2 less than number of coordinates}}
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2,%crd3] im2col[%off0] multicast_mask = %ctamask l2_cache_hint = %cacheHint : !llvm.ptr<3>, !llvm.ptr  
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2,%crd3] im2col[%off0] multicast_mask = %ctamask l2_cache_hint = %cacheHint : !llvm.ptr<3>, !llvm.ptr
   return
 }
 
@@ -1460,7 +1613,7 @@ func.func @tma_load(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !
 
 func.func @tma_load(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %off0: i16, %off1: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
   // expected-error at +1 {{expects coordinates between 1 to 5 dimension}}
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[]: !llvm.ptr<3>, !llvm.ptr  
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[]: !llvm.ptr<3>, !llvm.ptr
   return
 }
 
@@ -1469,7 +1622,7 @@ func.func @tma_load(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !
 
 func.func @tma_load(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %off0: i16, %off1: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
   // expected-error at +1 {{expects coordinates between 1 to 5 dimension}}
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2,%crd3,%crd0,%crd1,%crd2,%crd3]: !llvm.ptr<3>, !llvm.ptr  
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2,%crd3,%crd0,%crd1,%crd2,%crd3]: !llvm.ptr<3>, !llvm.ptr
   return
 }
 



More information about the Mlir-commits mailing list