[Mlir-commits] [mlir] [MLIR][LLVM] Always print variadic callee type (PR #99293)
Tobias Gysi
llvmlistbot at llvm.org
Wed Jul 17 05:42:57 PDT 2024
https://github.com/gysit updated https://github.com/llvm/llvm-project/pull/99293
>From 5ff39fb112ebda763d147ba91d8f430c8854f678 Mon Sep 17 00:00:00 2001
From: Tobias Gysi <tobias.gysi at nextsilicon.com>
Date: Wed, 17 Jul 2024 08:02:11 +0000
Subject: [PATCH] [MLIR][LLVM] Always print variadic callee type
This commit updates the LLVM dialect CallOp and InvokeOp to always print
the callee type if present. An additional verifier checks that only
variadic calls have a non-null callee type, and the builders are adapted
accordingly to only set the callee type for variadic calls. To reflect
this change, the `callee_type` attribute is renamed to `var_callee_type`.
The motivation for this change is to ensure that CallOp and InvokeOp carry
no hidden state that is used during the export to LLVM IR but is not
pretty printed. Previously, a call could look correct in MLIR while its
return type changed after exporting to LLVM IR (since the type was taken
from the hidden calleeType). With this change, this is no longer
possible, since the variadic callee type is always printed if it is
present.
---
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 15 +--
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 101 +++++++++++---------
mlir/test/Dialect/LLVMIR/invalid.mlir | 31 +++++-
3 files changed, 92 insertions(+), 55 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index f0dec69a5032a..7adbe8a4e5b31 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -560,7 +560,7 @@ def LLVM_InvokeOp : LLVM_Op<"invoke", [
DeclareOpInterfaceMethods<BranchWeightOpInterface>,
Terminator]> {
let arguments = (ins
- OptionalAttr<TypeAttrOf<LLVM_FunctionType>>:$callee_type,
+ OptionalAttr<TypeAttrOf<LLVM_FunctionType>>:$var_callee_type,
OptionalAttr<FlatSymbolRefAttr>:$callee,
Variadic<LLVM_Type>:$callee_operands,
Variadic<LLVM_Type>:$normalDestOperands,
@@ -617,11 +617,12 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
start with a function name (`@`-prefixed) and indirect calls start with an
SSA value (`%`-prefixed). The direct callee, if present, is stored as a
function attribute `callee`. For indirect calls, the callee is of `!llvm.ptr` type
- and is stored as the first value in `callee_operands`. If the callee is a variadic
- function, then the `callee_type` attribute must carry the function type. The
- trailing type list contains the optional indirect callee type and the MLIR
- function type, which differs from the LLVM function type that uses a explicit
- void type to model functions that do not return a value.
+ and is stored as the first value in `callee_operands`. If and only if the
+ callee is a variadic function, then the `var_callee_type` attribute must
+ carry the variadic LLVM function type. The trailing type list contains the
+ optional indirect callee type and the MLIR function type, which differs from
+ the LLVM function type that uses an explicit void type to model functions
+ that do not return a value.
Examples:
@@ -644,7 +645,7 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
```
}];
- dag args = (ins OptionalAttr<TypeAttrOf<LLVM_FunctionType>>:$callee_type,
+ dag args = (ins OptionalAttr<TypeAttrOf<LLVM_FunctionType>>:$var_callee_type,
OptionalAttr<FlatSymbolRefAttr>:$callee,
Variadic<LLVM_Type>:$callee_operands,
DefaultValuedAttr<LLVM_FastmathFlagsAttr,
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 9372caf6e32a7..40e1187e80216 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -948,6 +948,11 @@ static SmallVector<Type, 1> getCallOpResultTypes(LLVMFunctionType calleeType) {
return results;
}
+/// Gets the variadic callee type for a LLVMFunctionType.
+static TypeAttr getCallOpVarCalleeType(LLVMFunctionType calleeType) {
+ return calleeType.isVarArg() ? TypeAttr::get(calleeType) : nullptr;
+}
+
/// Constructs a LLVMFunctionType from MLIR `results` and `args`.
static LLVMFunctionType getLLVMFuncType(MLIRContext *context, TypeRange results,
ValueRange args) {
@@ -974,8 +979,8 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
FlatSymbolRefAttr callee, ValueRange args) {
assert(callee && "expected non-null callee in direct call builder");
build(builder, state, results,
- TypeAttr::get(getLLVMFuncType(builder.getContext(), results, args)),
- callee, args, /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
+ /*var_callee_type=*/nullptr, callee, args, /*fastmathFlags=*/nullptr,
+ /*branch_weights=*/nullptr,
/*CConv=*/nullptr, /*TailCallKind=*/nullptr,
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
/*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
@@ -997,7 +1002,8 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
LLVMFunctionType calleeType, FlatSymbolRefAttr callee,
ValueRange args) {
build(builder, state, getCallOpResultTypes(calleeType),
- TypeAttr::get(calleeType), callee, args, /*fastmathFlags=*/nullptr,
+ getCallOpVarCalleeType(calleeType), callee, args,
+ /*fastmathFlags=*/nullptr,
/*branch_weights=*/nullptr, /*CConv=*/nullptr,
/*TailCallKind=*/nullptr, /*access_groups=*/nullptr,
/*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
@@ -1006,7 +1012,8 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
void CallOp::build(OpBuilder &builder, OperationState &state,
LLVMFunctionType calleeType, ValueRange args) {
build(builder, state, getCallOpResultTypes(calleeType),
- TypeAttr::get(calleeType), /*callee=*/nullptr, args,
+ getCallOpVarCalleeType(calleeType),
+ /*callee=*/nullptr, args,
/*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
/*CConv=*/nullptr, /*TailCallKind=*/nullptr,
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
@@ -1017,7 +1024,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
ValueRange args) {
auto calleeType = func.getFunctionType();
build(builder, state, getCallOpResultTypes(calleeType),
- TypeAttr::get(calleeType), SymbolRefAttr::get(func), args,
+ getCallOpVarCalleeType(calleeType), SymbolRefAttr::get(func), args,
/*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
/*CConv=*/nullptr, /*TailCallKind=*/nullptr,
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
@@ -1080,6 +1087,11 @@ LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
if (getNumResults() > 1)
return emitOpError("must have 0 or 1 result");
+ // Verify the variadic callee type is a variadic function type.
+ if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
+ if (!varCalleeType->isVarArg())
+ return emitOpError("expected variadic callee type attribute");
+
// Type for the callee, we'll get it differently depending if it is a direct
// or indirect call.
Type fnType;
@@ -1120,7 +1132,7 @@ LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
if (!funcType)
return emitOpError("callee does not have a functional type: ") << fnType;
- if (funcType.isVarArg() && !getCalleeType())
+ if (funcType.isVarArg() && !getVarCalleeType())
return emitOpError() << "missing callee type attribute for vararg call";
// Verify that the operand and result types match the callee.
@@ -1168,14 +1180,6 @@ void CallOp::print(OpAsmPrinter &p) {
auto callee = getCallee();
bool isDirect = callee.has_value();
- LLVMFunctionType calleeType;
- bool isVarArg = false;
-
- if (std::optional<LLVMFunctionType> optionalCalleeType = getCalleeType()) {
- calleeType = *optionalCalleeType;
- isVarArg = calleeType.isVarArg();
- }
-
p << ' ';
// Print calling convention.
@@ -1195,11 +1199,13 @@ void CallOp::print(OpAsmPrinter &p) {
auto args = getOperands().drop_front(isDirect ? 0 : 1);
p << '(' << args << ')';
- if (isVarArg)
- p << " vararg(" << calleeType << ")";
+ // Print the variadic callee type if the call is variadic.
+ if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
+ p << " vararg(" << *varCalleeType << ")";
p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()),
- {getCConvAttrName(), "callee", "callee_type",
+ {getCConvAttrName(), "callee",
+ getVarCalleeTypeAttrName(),
getTailCallKindAttrName()});
p << " : ";
@@ -1270,11 +1276,11 @@ static ParseResult parseOptionalCallFuncPtr(
// <operation> ::= `llvm.call` (cconv)? (tailcallkind)? (function-id | ssa-use)
// `(` ssa-use-list `)`
-// ( `vararg(` var-arg-func-type `)` )?
+// ( `vararg(` var-callee-type `)` )?
// attribute-dict? `:` (type `,`)? function-type
ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
SymbolRefAttr funcAttr;
- TypeAttr calleeType;
+ TypeAttr varCalleeType;
SmallVector<OpAsmParser::UnresolvedOperand> operands;
// Default to C Calling Convention if no keyword is provided.
@@ -1305,8 +1311,12 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
bool isVarArg = parser.parseOptionalKeyword("vararg").succeeded();
if (isVarArg) {
+ StringAttr varCalleeTypeAttrName =
+ CallOp::getVarCalleeTypeAttrName(result.name);
if (parser.parseLParen().failed() ||
- parser.parseAttribute(calleeType, "callee_type", result.attributes)
+ parser
+ .parseAttribute(varCalleeType, varCalleeTypeAttrName,
+ result.attributes)
.failed() ||
parser.parseRParen().failed())
return failure();
@@ -1320,8 +1330,8 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
}
LLVMFunctionType CallOp::getCalleeFunctionType() {
- if (getCalleeType())
- return *getCalleeType();
+ if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
+ return *varCalleeType;
return getLLVMFuncType(getContext(), getResultTypes(), getArgOperands());
}
@@ -1334,8 +1344,8 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
Block *unwind, ValueRange unwindOps) {
auto calleeType = func.getFunctionType();
build(builder, state, getCallOpResultTypes(calleeType),
- TypeAttr::get(calleeType), SymbolRefAttr::get(func), ops, normalOps,
- unwindOps, nullptr, nullptr, normal, unwind);
+ getCallOpVarCalleeType(calleeType), SymbolRefAttr::get(func), ops,
+ normalOps, unwindOps, nullptr, nullptr, normal, unwind);
}
void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys,
@@ -1343,8 +1353,8 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys,
ValueRange normalOps, Block *unwind,
ValueRange unwindOps) {
build(builder, state, tys,
- TypeAttr::get(getLLVMFuncType(builder.getContext(), tys, ops)), callee,
- ops, normalOps, unwindOps, nullptr, nullptr, normal, unwind);
+ /*var_callee_type=*/nullptr, callee, ops, normalOps, unwindOps, nullptr,
+ nullptr, normal, unwind);
}
void InvokeOp::build(OpBuilder &builder, OperationState &state,
@@ -1352,8 +1362,8 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state,
ValueRange ops, Block *normal, ValueRange normalOps,
Block *unwind, ValueRange unwindOps) {
build(builder, state, getCallOpResultTypes(calleeType),
- TypeAttr::get(calleeType), callee, ops, normalOps, unwindOps, nullptr,
- nullptr, normal, unwind);
+ getCallOpVarCalleeType(calleeType), callee, ops, normalOps, unwindOps,
+ nullptr, nullptr, normal, unwind);
}
SuccessorOperands InvokeOp::getSuccessorOperands(unsigned index) {
@@ -1393,6 +1403,11 @@ LogicalResult InvokeOp::verify() {
if (getNumResults() > 1)
return emitOpError("must have 0 or 1 result");
+ // Verify the variadic callee type is a variadic function type.
+ if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
+ if (!varCalleeType->isVarArg())
+ return emitOpError("expected variadic callee type attribute");
+
Block *unwindDest = getUnwindDest();
if (unwindDest->empty())
return emitError("must have at least one operation in unwind destination");
@@ -1409,14 +1424,6 @@ void InvokeOp::print(OpAsmPrinter &p) {
auto callee = getCallee();
bool isDirect = callee.has_value();
- LLVMFunctionType calleeType;
- bool isVarArg = false;
-
- if (std::optional<LLVMFunctionType> optionalCalleeType = getCalleeType()) {
- calleeType = *optionalCalleeType;
- isVarArg = calleeType.isVarArg();
- }
-
p << ' ';
// Print calling convention.
@@ -1435,12 +1442,14 @@ void InvokeOp::print(OpAsmPrinter &p) {
p << " unwind ";
p.printSuccessorAndUseList(getUnwindDest(), getUnwindDestOperands());
- if (isVarArg)
- p << " vararg(" << calleeType << ")";
+ // Print the variadic callee type if the invoke is variadic.
+ if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
+ p << " vararg(" << *varCalleeType << ")";
p.printOptionalAttrDict((*this)->getAttrs(),
{InvokeOp::getOperandSegmentSizeAttr(), "callee",
- "callee_type", InvokeOp::getCConvAttrName()});
+ InvokeOp::getVarCalleeTypeAttrName(),
+ InvokeOp::getCConvAttrName()});
p << " : ";
if (!isDirect)
@@ -1453,12 +1462,12 @@ void InvokeOp::print(OpAsmPrinter &p) {
// `(` ssa-use-list `)`
// `to` bb-id (`[` ssa-use-and-type-list `]`)?
// `unwind` bb-id (`[` ssa-use-and-type-list `]`)?
-// ( `vararg(` var-arg-func-type `)` )?
+// ( `vararg(` var-callee-type `)` )?
// attribute-dict? `:` (type `,`)? function-type
ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::UnresolvedOperand, 8> operands;
SymbolRefAttr funcAttr;
- TypeAttr calleeType;
+ TypeAttr varCalleeType;
Block *normalDest, *unwindDest;
SmallVector<Value, 4> normalOperands, unwindOperands;
Builder &builder = parser.getBuilder();
@@ -1488,8 +1497,12 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
bool isVarArg = parser.parseOptionalKeyword("vararg").succeeded();
if (isVarArg) {
+ StringAttr varCalleeTypeAttrName =
+ InvokeOp::getVarCalleeTypeAttrName(result.name);
if (parser.parseLParen().failed() ||
- parser.parseAttribute(calleeType, "callee_type", result.attributes)
+ parser
+ .parseAttribute(varCalleeType, varCalleeTypeAttrName,
+ result.attributes)
.failed() ||
parser.parseRParen().failed())
return failure();
@@ -1515,8 +1528,8 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
}
LLVMFunctionType InvokeOp::getCalleeFunctionType() {
- if (getCalleeType())
- return *getCalleeType();
+ if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
+ return *varCalleeType;
return getLLVMFuncType(getContext(), getResultTypes(), getArgOperands());
}
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 39f8e70b9fb7b..e932187a8e614 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1415,6 +1415,29 @@ func.func @invalid_zext_target_type_two(%arg: vector<1xi32>) {
// -----
+llvm.func @non_variadic(%arg: i32)
+
+llvm.func @invalid_callee_type(%arg: i32) {
+ // expected-error at below {{expected variadic callee type attribute}}
+ llvm.call @non_variadic(%arg) vararg(!llvm.func<void (i32)>) : (i32) -> ()
+ llvm.return
+}
+
+// -----
+
+llvm.func @non_variadic(%arg: i32)
+
+llvm.func @invalid_callee_type(%arg: i32) {
+ // expected-error at below {{expected variadic callee type attribute}}
+ llvm.invoke @non_variadic(%arg) to ^bb2 unwind ^bb1 vararg(!llvm.func<void (i32)>) : (i32) -> ()
+^bb1:
+ llvm.return
+^bb2:
+ llvm.return
+}
+
+// -----
+
llvm.func @variadic(...)
llvm.func @invalid_variadic_call(%arg: i32) {
@@ -1445,14 +1468,14 @@ llvm.func @foo(%arg: !llvm.ptr) {
func.func @tma_load(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %off0: i16, %off1: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
// expected-error at +1 {{to use im2col mode, the tensor has to be at least 3-dimensional}}
- nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1] im2col[%off0] multicast_mask = %ctamask l2_cache_hint = %cacheHint : !llvm.ptr<3>, !llvm.ptr
+ nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1] im2col[%off0] multicast_mask = %ctamask l2_cache_hint = %cacheHint : !llvm.ptr<3>, !llvm.ptr
return
}
// -----
func.func @tma_load(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %off0: i16, %off1: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
// expected-error at +1 {{im2col offsets must be 2 less than number of coordinates}}
- nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1,%crd2,%crd3] im2col[%off0] multicast_mask = %ctamask l2_cache_hint = %cacheHint : !llvm.ptr<3>, !llvm.ptr
+ nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1,%crd2,%crd3] im2col[%off0] multicast_mask = %ctamask l2_cache_hint = %cacheHint : !llvm.ptr<3>, !llvm.ptr
return
}
@@ -1460,7 +1483,7 @@ func.func @tma_load(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !
func.func @tma_load(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %off0: i16, %off1: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
// expected-error at +1 {{expects coordinates between 1 to 5 dimension}}
- nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[]: !llvm.ptr<3>, !llvm.ptr
+ nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[]: !llvm.ptr<3>, !llvm.ptr
return
}
@@ -1469,7 +1492,7 @@ func.func @tma_load(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !
func.func @tma_load(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %off0: i16, %off1: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
// expected-error at +1 {{expects coordinates between 1 to 5 dimension}}
- nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1,%crd2,%crd3,%crd0,%crd1,%crd2,%crd3]: !llvm.ptr<3>, !llvm.ptr
+ nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1,%crd2,%crd3,%crd0,%crd1,%crd2,%crd3]: !llvm.ptr<3>, !llvm.ptr
return
}
More information about the Mlir-commits
mailing list