[Mlir-commits] [mlir] [MLIR][LLVM] Add vararg support in LLVM::LLVMFuncOp (PR #67274)
Ivan R. Ivanov
llvmlistbot at llvm.org
Mon Sep 25 21:59:20 PDT 2023
https://github.com/ivanradanov updated https://github.com/llvm/llvm-project/pull/67274
>From ee20d089575575c420d6f475cdd1040b1169774b Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sat, 23 Sep 2023 17:46:17 +0900
Subject: [PATCH 01/10] Build LLVM::CallOp with callee type info
---
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 9 +++-
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 51 ++++++++++++++++++---
2 files changed, 53 insertions(+), 7 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 2b4c8b609cfdd4f..63deb00802446b9 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -615,7 +615,8 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
```
}];
- dag args = (ins OptionalAttr<FlatSymbolRefAttr>:$callee,
+ dag args = (ins TypeAttrOf<LLVM_FunctionType>:$callee_type,
+ OptionalAttr<FlatSymbolRefAttr>:$callee,
Variadic<LLVM_Type>:$callee_operands,
DefaultValuedAttr<LLVM_FastmathFlagsAttr,
"{}">:$fastmathFlags,
@@ -630,6 +631,12 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
OpBuilder<(ins "TypeRange":$results, "FlatSymbolRefAttr":$callee,
CArg<"ValueRange", "{}">:$args)>,
OpBuilder<(ins "TypeRange":$results, "StringRef":$callee,
+ CArg<"ValueRange", "{}">:$args)>,
+ OpBuilder<(ins "LLVMFunctionType":$calleeType, "StringAttr":$callee,
+ CArg<"ValueRange", "{}">:$args)>,
+ OpBuilder<(ins "LLVMFunctionType":$calleeType, "FlatSymbolRefAttr":$callee,
+ CArg<"ValueRange", "{}">:$args)>,
+ OpBuilder<(ins "LLVMFunctionType":$calleeType, "StringRef":$callee,
CArg<"ValueRange", "{}">:$args)>
];
let hasCustomAssemblyFormat = 1;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index fd0d2b3fb3c1a08..2def142b1d2d457 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -950,6 +950,16 @@ static void printStoreType(OpAsmPrinter &printer, Operation *op,
// CallOp
//===----------------------------------------------------------------------===//
+namespace {
+static TypeRange getCallOpResults(LLVMFunctionType calleeType) {
+ SmallVector<Type> results;
+ Type resultType = calleeType.getReturnType();
+ if (!llvm::isa<LLVM::LLVMVoidType>(resultType))
+ results.push_back(resultType);
+ return results;
+}
+} // namespace
+
void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
StringRef callee, ValueRange args) {
build(builder, state, results, builder.getStringAttr(callee), args);
@@ -962,7 +972,38 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
FlatSymbolRefAttr callee, ValueRange args) {
- build(builder, state, results, callee, args, /*fastmathFlags=*/nullptr,
+ Type resultType;
+ if (results.empty())
+ resultType = LLVMVoidType::get(builder.getContext());
+ else
+ resultType = results.front();
+ std::vector<Type> argTypes(args.getTypes().begin(), args.getTypes().end());
+ auto calleeType =
+ LLVMFunctionType::get(resultType, argTypes, /*isVariadic*/ false);
+ build(builder, state, results, TypeAttr::get(calleeType), callee, args,
+ /*fastmathFlags=*/nullptr,
+ /*branch_weights=*/nullptr,
+ /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
+ /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
+}
+
+void CallOp::build(OpBuilder &builder, OperationState &state,
+ LLVMFunctionType calleeType, StringRef callee,
+ ValueRange args) {
+ build(builder, state, calleeType, builder.getStringAttr(callee), args);
+}
+
+void CallOp::build(OpBuilder &builder, OperationState &state,
+ LLVMFunctionType calleeType, StringAttr callee,
+ ValueRange args) {
+ build(builder, state, calleeType, SymbolRefAttr::get(callee), args);
+}
+
+void CallOp::build(OpBuilder &builder, OperationState &state,
+ LLVMFunctionType calleeType, FlatSymbolRefAttr callee,
+ ValueRange args) {
+ build(builder, state, getCallOpResults(calleeType), TypeAttr::get(calleeType),
+ callee, args, /*fastmathFlags=*/nullptr,
/*branch_weights=*/nullptr,
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
/*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
@@ -970,11 +1011,9 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
ValueRange args) {
- SmallVector<Type> results;
- Type resultType = func.getFunctionType().getReturnType();
- if (!llvm::isa<LLVM::LLVMVoidType>(resultType))
- results.push_back(resultType);
- build(builder, state, results, SymbolRefAttr::get(func), args,
+ auto calleeType = func.getFunctionType();
+ build(builder, state, getCallOpResults(calleeType), TypeAttr::get(calleeType),
+ SymbolRefAttr::get(func), args,
/*fastmathFlags=*/nullptr,
/*branch_weights=*/nullptr,
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
>From 6f95cba5452169d46679a9adfe6296bc1899605a Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sat, 23 Sep 2023 17:54:54 +0900
Subject: [PATCH 02/10] Fix conversion of CallOp function type
---
.../LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp | 7 ++++---
1 file changed, 4 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index 9b438090c84cacc..c27bc1b6f29076b 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -211,9 +211,10 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
call = builder.CreateCall(
moduleTranslation.lookupFunction(attr.getValue()), operandsRef);
} else {
- call = builder.CreateCall(getCalleeFunctionType(callOp.getResultTypes(),
- callOp.getArgOperands()),
- operandsRef.front(), operandsRef.drop_front());
+ call = builder.CreateCall(
+ llvm::cast<llvm::FunctionType>(
+ moduleTranslation.convertType(callOp.getCalleeType())),
+ operandsRef.front(), operandsRef.drop_front());
}
moduleTranslation.setAccessGroupsMetadata(callOp, call);
moduleTranslation.setAliasScopeMetadata(callOp, call);
>From 5c54d75f18134f4867217751c1be605bef7a7f3b Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Mon, 25 Sep 2023 10:37:30 +0900
Subject: [PATCH 03/10] Print vararg info
---
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 2 +-
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 36 ++++++++++++++++---
.../LLVMIR/LLVMToLLVMIRTranslation.cpp | 2 +-
3 files changed, 33 insertions(+), 7 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 63deb00802446b9..2f65077dddf5aaf 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -615,7 +615,7 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
```
}];
- dag args = (ins TypeAttrOf<LLVM_FunctionType>:$callee_type,
+ dag args = (ins OptionalAttr<TypeAttrOf<LLVM_FunctionType>>:$callee_type,
OptionalAttr<FlatSymbolRefAttr>:$callee,
Variadic<LLVM_Type>:$callee_operands,
DefaultValuedAttr<LLVM_FastmathFlagsAttr,
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 2def142b1d2d457..87138d24f7fc09a 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1172,6 +1172,8 @@ LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
void CallOp::print(OpAsmPrinter &p) {
auto callee = getCallee();
bool isDirect = callee.has_value();
+ auto calleeType = *getCalleeType();
+ bool isVarArg = calleeType.isVarArg();
// Print the direct callee if present as a function attribute, or an indirect
// callee (first operand) otherwise.
@@ -1181,9 +1183,13 @@ void CallOp::print(OpAsmPrinter &p) {
else
p << getOperand(0);
+ if (isVarArg)
+ p << " vararg " << getCalleeType() << ' ';
+
auto args = getOperands().drop_front(isDirect ? 0 : 1);
p << '(' << args << ')';
- p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()), {"callee"});
+ p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()),
+ {"callee", "callee_type"});
p << " : ";
if (!isDirect)
@@ -1197,7 +1203,7 @@ void CallOp::print(OpAsmPrinter &p) {
/// succeeds. Returns failure otherwise.
static ParseResult parseCallTypeAndResolveOperands(
OpAsmParser &parser, OperationState &result, bool isDirect,
- ArrayRef<OpAsmParser::UnresolvedOperand> operands) {
+ ArrayRef<OpAsmParser::UnresolvedOperand> operands, bool isVarArg) {
SMLoc trailingTypesLoc = parser.getCurrentLocation();
SmallVector<Type> types;
if (parser.parseColonTypeList(types))
@@ -1233,6 +1239,18 @@ static ParseResult parseCallTypeAndResolveOperands(
if (funcType.getNumResults() != 0)
result.addTypes(funcType.getResults());
+ if (!isVarArg) {
+ Type returnType;
+ if (funcType.getNumResults() == 0)
+ returnType = LLVM::LLVMVoidType::get(result.getContext());
+ else
+ returnType = funcType.getResult(0);
+ result.addAttribute(
+ "callee_type",
+ TypeAttr::get(LLVM::LLVMFunctionType::get(
+ returnType, funcType.getInputs(), /*isVarArg*/ false)));
+ }
+
return success();
}
@@ -1251,10 +1269,12 @@ static ParseResult parseOptionalCallFuncPtr(
return success();
}
-// <operation> ::= `llvm.call` (function-id | ssa-use)`(` ssa-use-list `)`
+// <operation> ::= `llvm.call` (function-id | ssa-use) var-arg-func-type?
+// `(` ssa-use-list `)`
// attribute-dict? `:` (type `,`)? function-type
ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
SymbolRefAttr funcAttr;
+ TypeAttr calleeType;
SmallVector<OpAsmParser::UnresolvedOperand> operands;
// Parse a function pointer for indirect calls.
@@ -1267,13 +1287,18 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
if (parser.parseAttribute(funcAttr, "callee", result.attributes))
return failure();
+ bool isVarArg = parser.parseOptionalKeyword("vararg").succeeded();
+ if (isVarArg)
+ parser.parseOptionalAttribute(calleeType, "callee_type", result.attributes);
+
// Parse the function arguments.
if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
parser.parseOptionalAttrDict(result.attributes))
return failure();
// Parse the trailing type list and resolve the operands.
- return parseCallTypeAndResolveOperands(parser, result, isDirect, operands);
+ return parseCallTypeAndResolveOperands(parser, result, isDirect, operands,
+ isVarArg);
}
///===---------------------------------------------------------------------===//
@@ -1388,7 +1413,8 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
return failure();
// Parse the trailing type list and resolve the function operands.
- if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands))
+ if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands,
+ /*isVarArg*/ false))
return failure();
result.addSuccessors({normalDest, unwindDest});
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index c27bc1b6f29076b..cdd22f116a4ef9d 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -213,7 +213,7 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
} else {
call = builder.CreateCall(
llvm::cast<llvm::FunctionType>(
- moduleTranslation.convertType(callOp.getCalleeType())),
+ moduleTranslation.convertType(*callOp.getCalleeType())),
operandsRef.front(), operandsRef.drop_front());
}
moduleTranslation.setAccessGroupsMetadata(callOp, call);
>From 66d8e4fdcf9e55934e7ab1858bac8eac44728346 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Tue, 26 Sep 2023 12:03:03 +0900
Subject: [PATCH 04/10] Handle optional callee_type and move vararg syntax after call operands
---
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 49 ++++++++++++-------
.../LLVMIR/LLVMToLLVMIRTranslation.cpp | 13 +++--
2 files changed, 41 insertions(+), 21 deletions(-)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 87138d24f7fc09a..39bc824d25b25bc 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -950,15 +950,14 @@ static void printStoreType(OpAsmPrinter &printer, Operation *op,
// CallOp
//===----------------------------------------------------------------------===//
-namespace {
-static TypeRange getCallOpResults(LLVMFunctionType calleeType) {
- SmallVector<Type> results;
+/// Get the MLIR Op-like result types of a LLVM fuction type
+static SmallVector<Type, 1> getCallOpResults(LLVMFunctionType calleeType) {
+ SmallVector<Type, 1> results;
Type resultType = calleeType.getReturnType();
- if (!llvm::isa<LLVM::LLVMVoidType>(resultType))
+ if (!isa<LLVM::LLVMVoidType>(resultType))
results.push_back(resultType);
return results;
}
-} // namespace
void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
StringRef callee, ValueRange args) {
@@ -977,9 +976,8 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
resultType = LLVMVoidType::get(builder.getContext());
else
resultType = results.front();
- std::vector<Type> argTypes(args.getTypes().begin(), args.getTypes().end());
- auto calleeType =
- LLVMFunctionType::get(resultType, argTypes, /*isVariadic*/ false);
+ auto calleeType = LLVMFunctionType::get(
+ resultType, llvm::to_vector(args.getTypes()), /*isVariadic*/ false);
build(builder, state, results, TypeAttr::get(calleeType), callee, args,
/*fastmathFlags=*/nullptr,
/*branch_weights=*/nullptr,
@@ -1172,8 +1170,17 @@ LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
void CallOp::print(OpAsmPrinter &p) {
auto callee = getCallee();
bool isDirect = callee.has_value();
- auto calleeType = *getCalleeType();
- bool isVarArg = calleeType.isVarArg();
+
+ LLVMFunctionType calleeType;
+ bool isVarArg;
+
+ std::optional<LLVMFunctionType> optionalCalleeType = getCalleeType();
+ if (optionalCalleeType.has_value()) {
+ calleeType = *optionalCalleeType;
+ isVarArg = calleeType.isVarArg();
+ } else {
+ isVarArg = false;
+ }
// Print the direct callee if present as a function attribute, or an indirect
// callee (first operand) otherwise.
@@ -1183,11 +1190,12 @@ void CallOp::print(OpAsmPrinter &p) {
else
p << getOperand(0);
- if (isVarArg)
- p << " vararg " << getCalleeType() << ' ';
-
auto args = getOperands().drop_front(isDirect ? 0 : 1);
p << '(' << args << ')';
+
+ if (isVarArg)
+ p << " vararg(" << calleeType << ") ";
+
p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()),
{"callee", "callee_type"});
@@ -1287,15 +1295,22 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
if (parser.parseAttribute(funcAttr, "callee", result.attributes))
return failure();
- bool isVarArg = parser.parseOptionalKeyword("vararg").succeeded();
- if (isVarArg)
- parser.parseOptionalAttribute(calleeType, "callee_type", result.attributes);
-
// Parse the function arguments.
if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
parser.parseOptionalAttrDict(result.attributes))
return failure();
+ bool isVarArg = parser.parseOptionalKeyword("vararg").succeeded();
+ if (isVarArg) {
+ if (parser.parseLParen().failed() ||
+ !parser
+ .parseOptionalAttribute(calleeType, "callee_type",
+ result.attributes)
+ .has_value() ||
+ parser.parseRParen().failed())
+ return failure();
+ }
+
// Parse the trailing type list and resolve the operands.
return parseCallTypeAndResolveOperands(parser, result, isDirect, operands,
isVarArg);
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index cdd22f116a4ef9d..61f9088d7e30f73 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -211,10 +211,15 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
call = builder.CreateCall(
moduleTranslation.lookupFunction(attr.getValue()), operandsRef);
} else {
- call = builder.CreateCall(
- llvm::cast<llvm::FunctionType>(
- moduleTranslation.convertType(*callOp.getCalleeType())),
- operandsRef.front(), operandsRef.drop_front());
+ llvm::FunctionType *calleeType;
+ if (callOp.getCalleeType().has_value())
+ calleeType = llvm::cast<llvm::FunctionType>(
+ moduleTranslation.convertType(*callOp.getCalleeType()));
+ else
+ calleeType = getCalleeFunctionType(callOp.getResultTypes(),
+ callOp.getArgOperands());
+ call = builder.CreateCall(calleeType, operandsRef.front(),
+ operandsRef.drop_front());
}
moduleTranslation.setAccessGroupsMetadata(callOp, call);
moduleTranslation.setAliasScopeMetadata(callOp, call);
>From 07bea0eeb76b40d9de365044bc0d45c13e237304 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Tue, 26 Sep 2023 13:11:09 +0900
Subject: [PATCH 05/10] Also add callee type to invoke op
---
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 19 ++--
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 95 +++++++++++++++----
.../LLVMIR/LLVMToLLVMIRTranslation.cpp | 10 +-
3 files changed, 90 insertions(+), 34 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 2f65077dddf5aaf..03da0287d6c7937 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -539,7 +539,9 @@ def LLVM_InvokeOp : LLVM_Op<"invoke", [
DeclareOpInterfaceMethods<CallOpInterface>,
DeclareOpInterfaceMethods<BranchWeightOpInterface>,
Terminator]> {
- let arguments = (ins OptionalAttr<FlatSymbolRefAttr>:$callee,
+ let arguments = (ins
+ OptionalAttr<TypeAttrOf<LLVM_FunctionType>>:$callee_type,
+ OptionalAttr<FlatSymbolRefAttr>:$callee,
Variadic<LLVM_Type>:$callee_operands,
Variadic<LLVM_Type>:$normalDestOperands,
Variadic<LLVM_Type>:$unwindDestOperands,
@@ -549,19 +551,14 @@ def LLVM_InvokeOp : LLVM_Op<"invoke", [
AnySuccessor:$unwindDest);
let builders = [
+ OpBuilder<(ins "LLVMFuncOp":$func,
+ "ValueRange":$ops, "Block*":$normal, "ValueRange":$normalOps,
+ "Block*":$unwind, "ValueRange":$unwindOps)>,
OpBuilder<(ins "TypeRange":$tys, "FlatSymbolRefAttr":$callee,
"ValueRange":$ops, "Block*":$normal, "ValueRange":$normalOps,
- "Block*":$unwind, "ValueRange":$unwindOps),
- [{
- $_state.addAttribute("callee", callee);
- build($_builder, $_state, tys, ops, normal, normalOps, unwind, unwindOps);
- }]>,
+ "Block*":$unwind, "ValueRange":$unwindOps)>,
OpBuilder<(ins "TypeRange":$tys, "ValueRange":$ops, "Block*":$normal,
- "ValueRange":$normalOps, "Block*":$unwind, "ValueRange":$unwindOps),
- [{
- build($_builder, $_state, tys, /*callee=*/FlatSymbolRefAttr(), ops, normalOps,
- unwindOps, nullptr, normal, unwind);
- }]>];
+ "ValueRange":$normalOps, "Block*":$unwind, "ValueRange":$unwindOps)>];
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 39bc824d25b25bc..5c3f894f005f892 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -950,7 +950,7 @@ static void printStoreType(OpAsmPrinter &printer, Operation *op,
// CallOp
//===----------------------------------------------------------------------===//
-/// Get the MLIR Op-like result types of a LLVM fuction type
+/// Get the MLIR Op-like result types of a LLVMFunctionType
static SmallVector<Type, 1> getCallOpResults(LLVMFunctionType calleeType) {
SmallVector<Type, 1> results;
Type resultType = calleeType.getReturnType();
@@ -959,6 +959,18 @@ static SmallVector<Type, 1> getCallOpResults(LLVMFunctionType calleeType) {
return results;
}
+/// Construct a LLVMFunctionType from MLIR results and args
+static LLVMFunctionType getLLVMFuncType(OpBuilder &builder, TypeRange results,
+ ValueRange args) {
+ Type resultType;
+ if (results.empty())
+ resultType = LLVMVoidType::get(builder.getContext());
+ else
+ resultType = results.front();
+ return LLVMFunctionType::get(resultType, llvm::to_vector(args.getTypes()),
+ /*isVariadic*/ false);
+}
+
void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
StringRef callee, ValueRange args) {
build(builder, state, results, builder.getStringAttr(callee), args);
@@ -971,14 +983,8 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
FlatSymbolRefAttr callee, ValueRange args) {
- Type resultType;
- if (results.empty())
- resultType = LLVMVoidType::get(builder.getContext());
- else
- resultType = results.front();
- auto calleeType = LLVMFunctionType::get(
- resultType, llvm::to_vector(args.getTypes()), /*isVariadic*/ false);
- build(builder, state, results, TypeAttr::get(calleeType), callee, args,
+ build(builder, state, results,
+ TypeAttr::get(getLLVMFuncType(builder, results, args)), callee, args,
/*fastmathFlags=*/nullptr,
/*branch_weights=*/nullptr,
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
@@ -1117,14 +1123,8 @@ LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
if (!funcType)
return emitOpError("callee does not have a functional type: ") << fnType;
- // Indirect variadic function calls are not supported since the translation to
- // LLVM IR reconstructs the LLVM function type from the argument and result
- // types. An additional type attribute that stores the LLVM function type
- // would be needed to distinguish normal and variadic function arguments.
- // TODO: Support indirect calls to variadic function pointers.
- if (isIndirect && funcType.isVarArg())
- return emitOpError()
- << "indirect calls to variadic functions are not supported";
+ if (funcType.isVarArg() && !getCalleeType().has_value())
+ return emitOpError() << "Missing callee type attribute for vararg call";
// Verify that the operand and result types match the callee.
@@ -1320,6 +1320,30 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
/// LLVM::InvokeOp
///===---------------------------------------------------------------------===//
+void InvokeOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
+ ValueRange ops, Block *normal, ValueRange normalOps,
+ Block *unwind, ValueRange unwindOps) {
+ auto calleeType = func.getFunctionType();
+ build(builder, state, getCallOpResults(calleeType), TypeAttr::get(calleeType),
+ SymbolRefAttr::get(func), ops, normalOps, unwindOps, nullptr, normal,
+ unwind);
+}
+
+void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys,
+ FlatSymbolRefAttr callee, ValueRange ops, Block *normal,
+ ValueRange normalOps, Block *unwind,
+ ValueRange unwindOps) {
+ build(builder, state, tys, TypeAttr::get(getLLVMFuncType(builder, tys, ops)),
+ callee, ops, normalOps, unwindOps, nullptr, normal, unwind);
+}
+
+void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys,
+ ValueRange ops, Block *normal, ValueRange normalOps,
+ Block *unwind, ValueRange unwindOps) {
+ build(builder, state, tys, TypeAttr::get(getLLVMFuncType(builder, tys, ops)),
+ /*callee=*/nullptr, ops, normalOps, unwindOps, nullptr, normal, unwind);
+}
+
SuccessorOperands InvokeOp::getSuccessorOperands(unsigned index) {
assert(index < getNumSuccessors() && "invalid successor index");
return SuccessorOperands(index == 0 ? getNormalDestOperandsMutable()
@@ -1373,6 +1397,17 @@ void InvokeOp::print(OpAsmPrinter &p) {
auto callee = getCallee();
bool isDirect = callee.has_value();
+ LLVMFunctionType calleeType;
+ bool isVarArg;
+
+ std::optional<LLVMFunctionType> optionalCalleeType = getCalleeType();
+ if (optionalCalleeType.has_value()) {
+ calleeType = *optionalCalleeType;
+ isVarArg = calleeType.isVarArg();
+ } else {
+ isVarArg = false;
+ }
+
p << ' ';
// Either function name or pointer
@@ -1387,8 +1422,12 @@ void InvokeOp::print(OpAsmPrinter &p) {
p << " unwind ";
p.printSuccessorAndUseList(getUnwindDest(), getUnwindDestOperands());
- p.printOptionalAttrDict((*this)->getAttrs(),
- {InvokeOp::getOperandSegmentSizeAttr(), "callee"});
+ if (isVarArg)
+ p << " vararg(" << calleeType << ") ";
+
+ p.printOptionalAttrDict(
+ (*this)->getAttrs(),
+ {InvokeOp::getOperandSegmentSizeAttr(), "callee", "callee_type"});
p << " : ";
if (!isDirect)
@@ -1405,6 +1444,7 @@ void InvokeOp::print(OpAsmPrinter &p) {
ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::UnresolvedOperand, 8> operands;
SymbolRefAttr funcAttr;
+ TypeAttr calleeType;
Block *normalDest, *unwindDest;
SmallVector<Value, 4> normalOperands, unwindOperands;
Builder &builder = parser.getBuilder();
@@ -1423,8 +1463,21 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
parser.parseKeyword("to") ||
parser.parseSuccessorAndUseList(normalDest, normalOperands) ||
parser.parseKeyword("unwind") ||
- parser.parseSuccessorAndUseList(unwindDest, unwindOperands) ||
- parser.parseOptionalAttrDict(result.attributes))
+ parser.parseSuccessorAndUseList(unwindDest, unwindOperands))
+ return failure();
+
+ bool isVarArg = parser.parseOptionalKeyword("vararg").succeeded();
+ if (isVarArg) {
+ if (parser.parseLParen().failed() ||
+ !parser
+ .parseOptionalAttribute(calleeType, "callee_type",
+ result.attributes)
+ .has_value() ||
+ parser.parseRParen().failed())
+ return failure();
+ }
+
+ if (parser.parseOptionalAttrDict(result.attributes))
return failure();
// Parse the trailing type list and resolve the function operands.
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index 61f9088d7e30f73..8c84e80494e324a 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -303,9 +303,15 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
moduleTranslation.lookupBlock(invOp.getSuccessor(0)),
moduleTranslation.lookupBlock(invOp.getSuccessor(1)), operandsRef);
} else {
+ llvm::FunctionType *calleeType;
+ if (invOp.getCalleeType().has_value())
+ calleeType = llvm::cast<llvm::FunctionType>(
+ moduleTranslation.convertType(*invOp.getCalleeType()));
+ else
+ calleeType = getCalleeFunctionType(invOp.getResultTypes(),
+ invOp.getArgOperands());
result = builder.CreateInvoke(
- getCalleeFunctionType(invOp.getResultTypes(), invOp.getArgOperands()),
- operandsRef.front(),
+ calleeType, operandsRef.front(),
moduleTranslation.lookupBlock(invOp.getSuccessor(0)),
moduleTranslation.lookupBlock(invOp.getSuccessor(1)),
operandsRef.drop_front());
>From 65f67f90aa8415c84e7155e99a2b83893e216d89 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Tue, 26 Sep 2023 13:24:13 +0900
Subject: [PATCH 06/10] Add builder with func type to InvokeOp
---
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 5 +++--
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 7 ++++---
2 files changed, 7 insertions(+), 5 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 03da0287d6c7937..07de6986289ce5f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -557,8 +557,9 @@ def LLVM_InvokeOp : LLVM_Op<"invoke", [
OpBuilder<(ins "TypeRange":$tys, "FlatSymbolRefAttr":$callee,
"ValueRange":$ops, "Block*":$normal, "ValueRange":$normalOps,
"Block*":$unwind, "ValueRange":$unwindOps)>,
- OpBuilder<(ins "TypeRange":$tys, "ValueRange":$ops, "Block*":$normal,
- "ValueRange":$normalOps, "Block*":$unwind, "ValueRange":$unwindOps)>];
+ OpBuilder<(ins "LLVMFunctionType":$calleeType, "FlatSymbolRefAttr":$callee,
+ "ValueRange":$ops, "Block*":$normal, "ValueRange":$normalOps,
+ "Block*":$unwind, "ValueRange":$unwindOps)>];
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 5c3f894f005f892..96791d0cd99a618 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1337,11 +1337,12 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys,
callee, ops, normalOps, unwindOps, nullptr, normal, unwind);
}
-void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys,
+void InvokeOp::build(OpBuilder &builder, OperationState &state,
+ LLVMFunctionType calleeType, FlatSymbolRefAttr callee,
ValueRange ops, Block *normal, ValueRange normalOps,
Block *unwind, ValueRange unwindOps) {
- build(builder, state, tys, TypeAttr::get(getLLVMFuncType(builder, tys, ops)),
- /*callee=*/nullptr, ops, normalOps, unwindOps, nullptr, normal, unwind);
+ build(builder, state, getCallOpResults(calleeType), TypeAttr::get(calleeType),
+ callee, ops, normalOps, unwindOps, nullptr, normal, unwind);
}
SuccessorOperands InvokeOp::getSuccessorOperands(unsigned index) {
>From 2a6f2c71fbe9d7622a07c9aa9372e894084180ae Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Tue, 26 Sep 2023 13:24:44 +0900
Subject: [PATCH 07/10] Fix importing of LLVM Call and Invoke Instructions
---
mlir/lib/Target/LLVMIR/ModuleImport.cpp | 18 +++++++++++++-----
1 file changed, 13 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 35b2fcd3d3abe4e..d51f8d6754113a3 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1349,12 +1349,17 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
if (failed(convertCallTypeAndOperands(callInst, types, operands)))
return failure();
+ auto funcTy =
+ dyn_cast<LLVMFunctionType>(convertType(callInst->getFunctionType()));
+
CallOp callOp;
+
if (llvm::Function *callee = callInst->getCalledFunction()) {
callOp = builder.create<CallOp>(
- loc, types, SymbolRefAttr::get(context, callee->getName()), operands);
+ loc, funcTy, SymbolRefAttr::get(context, callee->getName()),
+ operands);
} else {
- callOp = builder.create<CallOp>(loc, types, operands);
+ callOp = builder.create<CallOp>(loc, funcTy, operands);
}
setFastmathFlagsAttr(inst, callOp);
if (!callInst->getType()->isVoidTy())
@@ -1413,20 +1418,23 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
unwindArgs)))
return failure();
+ auto funcTy =
+ dyn_cast<LLVMFunctionType>(convertType(invokeInst->getFunctionType()));
+
// Create the invoke operation. Normal destination block arguments will be
// added later on to handle the case in which the operation result is
// included in this list.
InvokeOp invokeOp;
if (llvm::Function *callee = invokeInst->getCalledFunction()) {
invokeOp = builder.create<InvokeOp>(
- loc, types,
+ loc, funcTy,
SymbolRefAttr::get(builder.getContext(), callee->getName()), operands,
directNormalDest, ValueRange(),
lookupBlock(invokeInst->getUnwindDest()), unwindArgs);
} else {
invokeOp = builder.create<InvokeOp>(
- loc, types, operands, directNormalDest, ValueRange(),
- lookupBlock(invokeInst->getUnwindDest()), unwindArgs);
+ loc, funcTy, /*callee*/ nullptr, operands, directNormalDest,
+ ValueRange(), lookupBlock(invokeInst->getUnwindDest()), unwindArgs);
}
if (!invokeInst->getType()->isVoidTy())
mapValue(inst, invokeOp.getResults().front());
>From 3dfff15293920dd0293c4a8251e34401ee0f393a Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Tue, 26 Sep 2023 13:47:37 +0900
Subject: [PATCH 08/10] Add a builder for indirect calls
---
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 1 +
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 10 ++++++++++
mlir/lib/Target/LLVMIR/ModuleImport.cpp | 4 ++++
3 files changed, 15 insertions(+)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 07de6986289ce5f..9f5b1f7f2738c78 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -624,6 +624,7 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
let results = (outs Optional<LLVM_Type>:$result);
let builders = [
OpBuilder<(ins "LLVMFuncOp":$func, "ValueRange":$args)>,
+ OpBuilder<(ins "LLVMFunctionType":$calleeType, "ValueRange":$args)>,
OpBuilder<(ins "TypeRange":$results, "StringAttr":$callee,
CArg<"ValueRange", "{}">:$args)>,
OpBuilder<(ins "TypeRange":$results, "FlatSymbolRefAttr":$callee,
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 96791d0cd99a618..79d3d25e4ebd1d9 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1013,6 +1013,16 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
/*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
}
+void CallOp::build(OpBuilder &builder, OperationState &state,
+ LLVMFunctionType calleeType, ValueRange args) { // Indirect-call builder: no callee symbol is set; the callee's function type is recorded in the callee_type attribute instead.
+ build(builder, state, getCallOpResults(calleeType), TypeAttr::get(calleeType),
+ /*callee=*/nullptr, args, // NOTE(review): with no callee symbol, the callee pointer is presumably the first element of args -- confirm against callers.
+ /*fastmathFlags=*/nullptr,
+ /*branch_weights=*/nullptr,
+ /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
+ /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); // All optional attributes are left unset, mirroring the preceding CallOp builder.
+}
+
void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
ValueRange args) {
auto calleeType = func.getFunctionType();
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index d51f8d6754113a3..ccabfb8d9a05527 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1351,6 +1351,8 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
auto funcTy =
dyn_cast<LLVMFunctionType>(convertType(callInst->getFunctionType()));
+ if (!funcTy)
+ return failure();
CallOp callOp;
@@ -1420,6 +1422,8 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
auto funcTy =
dyn_cast<LLVMFunctionType>(convertType(invokeInst->getFunctionType()));
+ if (!funcTy)
+ return failure();
// Create the invoke operation. Normal destination block arguments will be
// added later on to handle the case in which the operation result is
>From 06f929b39bd46cd9f45fb0bef5c8a2c569693e2d Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Tue, 26 Sep 2023 13:58:50 +0900
Subject: [PATCH 09/10] Remove now-supported case from expected-error test
---
mlir/test/Dialect/LLVMIR/invalid-typed-pointers.mlir | 8 --------
1 file changed, 8 deletions(-)
diff --git a/mlir/test/Dialect/LLVMIR/invalid-typed-pointers.mlir b/mlir/test/Dialect/LLVMIR/invalid-typed-pointers.mlir
index bb177eb1500ad68..a87b1952b6dca70 100644
--- a/mlir/test/Dialect/LLVMIR/invalid-typed-pointers.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid-typed-pointers.mlir
@@ -35,14 +35,6 @@ func.func @gep_too_few_dynamic(%base : !llvm.ptr<f32>) {
// -----
-func.func @call_variadic(%callee : !llvm.ptr<func<i8 (i8, ...)>>, %arg : i8) {
- // expected-error@+1 {{indirect calls to variadic functions are not supported}}
- llvm.call %callee(%arg) : !llvm.ptr<func<i8 (i8, ...)>>, (i8) -> (i8)
- llvm.return
-}
-
-// -----
-
func.func @indirect_callee_arg_mismatch(%arg0 : i32, %callee : !llvm.ptr<func<void(i8)>>) {
// expected-error at +1 {{'llvm.call' op operand type mismatch for operand 0: 'i32' != 'i8'}}
"llvm.call"(%callee, %arg0) : (!llvm.ptr<func<void(i8)>>, i32) -> ()
>From 10629e81fce616a018e82e52ba10366e72bdba50 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Tue, 26 Sep 2023 13:59:05 +0900
Subject: [PATCH 10/10] Add test for importing vararg indirect call
---
mlir/test/Target/LLVMIR/Import/instructions.ll | 13 ++++++++++++-
1 file changed, 12 insertions(+), 1 deletion(-)
diff --git a/mlir/test/Target/LLVMIR/Import/instructions.ll b/mlir/test/Target/LLVMIR/Import/instructions.ll
index 3f5ade4f1573579..79712a2fadb9ea3 100644
--- a/mlir/test/Target/LLVMIR/Import/instructions.ll
+++ b/mlir/test/Target/LLVMIR/Import/instructions.ll
@@ -480,6 +480,17 @@ define void @indirect_call(ptr addrspace(42) %fn) {
; // -----
+; CHECK-LABEL: @indirect_vararg_call
+; CHECK-SAME: %[[PTR:[a-zA-Z0-9]+]]
+define void @indirect_vararg_call(ptr addrspace(42) %fn) {
+ ; CHECK: %[[C0:[0-9]+]] = llvm.mlir.constant(0 : i16) : i16
+ ; CHECK: llvm.call %[[PTR]](%[[C0]]) vararg(!llvm.func<void (...)>) : !llvm.ptr<42>, (i16) -> ()
+ call addrspace(42) void (...) %fn(i16 0) ; indirect call through a variadic function type; the import is expected to keep it as vararg(!llvm.func<void (...)>)
+ ret void
+}
+
+; // -----
+
; CHECK-LABEL: @gep_static_idx
; CHECK-SAME: %[[PTR:[a-zA-Z0-9]+]]
define void @gep_static_idx(ptr %ptr) {
@@ -497,7 +508,7 @@ declare void @varargs(...)
; CHECK-LABEL: @varargs_call
; CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
define void @varargs_call(i32 %0) {
- ; CHECK: llvm.call @varargs(%[[ARG1]]) : (i32) -> ()
+ ; CHECK: llvm.call @varargs(%[[ARG1]]) vararg(!llvm.func<void (...)>) : (i32) -> ()
call void (...) @varargs(i32 %0)
ret void
}
More information about the Mlir-commits
mailing list