[Mlir-commits] [mlir] [MLIR][LLVM] Add vararg support in LLVM::LLVMFuncOp (PR #67274)
Ivan R. Ivanov
llvmlistbot at llvm.org
Mon Sep 25 20:08:02 PDT 2023
https://github.com/ivanradanov updated https://github.com/llvm/llvm-project/pull/67274
>From ee20d089575575c420d6f475cdd1040b1169774b Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sat, 23 Sep 2023 17:46:17 +0900
Subject: [PATCH 1/4] Build LLVM::CallOp with callee type info
---
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 9 +++-
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 51 ++++++++++++++++++---
2 files changed, 53 insertions(+), 7 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 2b4c8b609cfdd4f..63deb00802446b9 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -615,7 +615,8 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
```
}];
- dag args = (ins OptionalAttr<FlatSymbolRefAttr>:$callee,
+ dag args = (ins TypeAttrOf<LLVM_FunctionType>:$callee_type,
+ OptionalAttr<FlatSymbolRefAttr>:$callee,
Variadic<LLVM_Type>:$callee_operands,
DefaultValuedAttr<LLVM_FastmathFlagsAttr,
"{}">:$fastmathFlags,
@@ -630,6 +631,12 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
OpBuilder<(ins "TypeRange":$results, "FlatSymbolRefAttr":$callee,
CArg<"ValueRange", "{}">:$args)>,
OpBuilder<(ins "TypeRange":$results, "StringRef":$callee,
+ CArg<"ValueRange", "{}">:$args)>,
+ OpBuilder<(ins "LLVMFunctionType":$calleeType, "StringAttr":$callee,
+ CArg<"ValueRange", "{}">:$args)>,
+ OpBuilder<(ins "LLVMFunctionType":$calleeType, "FlatSymbolRefAttr":$callee,
+ CArg<"ValueRange", "{}">:$args)>,
+ OpBuilder<(ins "LLVMFunctionType":$calleeType, "StringRef":$callee,
CArg<"ValueRange", "{}">:$args)>
];
let hasCustomAssemblyFormat = 1;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index fd0d2b3fb3c1a08..2def142b1d2d457 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -950,6 +950,16 @@ static void printStoreType(OpAsmPrinter &printer, Operation *op,
// CallOp
//===----------------------------------------------------------------------===//
+namespace {
+static TypeRange getCallOpResults(LLVMFunctionType calleeType) {
+ SmallVector<Type> results;
+ Type resultType = calleeType.getReturnType();
+ if (!llvm::isa<LLVM::LLVMVoidType>(resultType))
+ results.push_back(resultType);
+ return results;
+}
+} // namespace
+
void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
StringRef callee, ValueRange args) {
build(builder, state, results, builder.getStringAttr(callee), args);
@@ -962,7 +972,38 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
FlatSymbolRefAttr callee, ValueRange args) {
- build(builder, state, results, callee, args, /*fastmathFlags=*/nullptr,
+ Type resultType;
+ if (results.empty())
+ resultType = LLVMVoidType::get(builder.getContext());
+ else
+ resultType = results.front();
+ std::vector<Type> argTypes(args.getTypes().begin(), args.getTypes().end());
+ auto calleeType =
+ LLVMFunctionType::get(resultType, argTypes, /*isVariadic*/ false);
+ build(builder, state, results, TypeAttr::get(calleeType), callee, args,
+ /*fastmathFlags=*/nullptr,
+ /*branch_weights=*/nullptr,
+ /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
+ /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
+}
+
+void CallOp::build(OpBuilder &builder, OperationState &state,
+ LLVMFunctionType calleeType, StringRef callee,
+ ValueRange args) {
+ build(builder, state, calleeType, builder.getStringAttr(callee), args);
+}
+
+void CallOp::build(OpBuilder &builder, OperationState &state,
+ LLVMFunctionType calleeType, StringAttr callee,
+ ValueRange args) {
+ build(builder, state, calleeType, SymbolRefAttr::get(callee), args);
+}
+
+void CallOp::build(OpBuilder &builder, OperationState &state,
+ LLVMFunctionType calleeType, FlatSymbolRefAttr callee,
+ ValueRange args) {
+ build(builder, state, getCallOpResults(calleeType), TypeAttr::get(calleeType),
+ callee, args, /*fastmathFlags=*/nullptr,
/*branch_weights=*/nullptr,
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
/*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
@@ -970,11 +1011,9 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
ValueRange args) {
- SmallVector<Type> results;
- Type resultType = func.getFunctionType().getReturnType();
- if (!llvm::isa<LLVM::LLVMVoidType>(resultType))
- results.push_back(resultType);
- build(builder, state, results, SymbolRefAttr::get(func), args,
+ auto calleeType = func.getFunctionType();
+ build(builder, state, getCallOpResults(calleeType), TypeAttr::get(calleeType),
+ SymbolRefAttr::get(func), args,
/*fastmathFlags=*/nullptr,
/*branch_weights=*/nullptr,
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
>From 6f95cba5452169d46679a9adfe6296bc1899605a Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sat, 23 Sep 2023 17:54:54 +0900
Subject: [PATCH 2/4] Fix Conversion of CallOp func type
---
.../LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp | 7 ++++---
1 file changed, 4 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index 9b438090c84cacc..c27bc1b6f29076b 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -211,9 +211,10 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
call = builder.CreateCall(
moduleTranslation.lookupFunction(attr.getValue()), operandsRef);
} else {
- call = builder.CreateCall(getCalleeFunctionType(callOp.getResultTypes(),
- callOp.getArgOperands()),
- operandsRef.front(), operandsRef.drop_front());
+ call = builder.CreateCall(
+ llvm::cast<llvm::FunctionType>(
+ moduleTranslation.convertType(callOp.getCalleeType())),
+ operandsRef.front(), operandsRef.drop_front());
}
moduleTranslation.setAccessGroupsMetadata(callOp, call);
moduleTranslation.setAliasScopeMetadata(callOp, call);
>From 5c54d75f18134f4867217751c1be605bef7a7f3b Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Mon, 25 Sep 2023 10:37:30 +0900
Subject: [PATCH 3/4] Print vararg info
---
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 2 +-
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 36 ++++++++++++++++---
.../LLVMIR/LLVMToLLVMIRTranslation.cpp | 2 +-
3 files changed, 33 insertions(+), 7 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 63deb00802446b9..2f65077dddf5aaf 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -615,7 +615,7 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
```
}];
- dag args = (ins TypeAttrOf<LLVM_FunctionType>:$callee_type,
+ dag args = (ins OptionalAttr<TypeAttrOf<LLVM_FunctionType>>:$callee_type,
OptionalAttr<FlatSymbolRefAttr>:$callee,
Variadic<LLVM_Type>:$callee_operands,
DefaultValuedAttr<LLVM_FastmathFlagsAttr,
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 2def142b1d2d457..87138d24f7fc09a 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1172,6 +1172,8 @@ LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
void CallOp::print(OpAsmPrinter &p) {
auto callee = getCallee();
bool isDirect = callee.has_value();
+ auto calleeType = *getCalleeType();
+ bool isVarArg = calleeType.isVarArg();
// Print the direct callee if present as a function attribute, or an indirect
// callee (first operand) otherwise.
@@ -1181,9 +1183,13 @@ void CallOp::print(OpAsmPrinter &p) {
else
p << getOperand(0);
+ if (isVarArg)
+ p << " vararg " << getCalleeType() << ' ';
+
auto args = getOperands().drop_front(isDirect ? 0 : 1);
p << '(' << args << ')';
- p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()), {"callee"});
+ p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()),
+ {"callee", "callee_type"});
p << " : ";
if (!isDirect)
@@ -1197,7 +1203,7 @@ void CallOp::print(OpAsmPrinter &p) {
/// succeeds. Returns failure otherwise.
static ParseResult parseCallTypeAndResolveOperands(
OpAsmParser &parser, OperationState &result, bool isDirect,
- ArrayRef<OpAsmParser::UnresolvedOperand> operands) {
+ ArrayRef<OpAsmParser::UnresolvedOperand> operands, bool isVarArg) {
SMLoc trailingTypesLoc = parser.getCurrentLocation();
SmallVector<Type> types;
if (parser.parseColonTypeList(types))
@@ -1233,6 +1239,18 @@ static ParseResult parseCallTypeAndResolveOperands(
if (funcType.getNumResults() != 0)
result.addTypes(funcType.getResults());
+ if (!isVarArg) {
+ Type returnType;
+ if (funcType.getNumResults() == 0)
+ returnType = LLVM::LLVMVoidType::get(result.getContext());
+ else
+ returnType = funcType.getResult(0);
+ result.addAttribute(
+ "callee_type",
+ TypeAttr::get(LLVM::LLVMFunctionType::get(
+ returnType, funcType.getInputs(), /*isVarArg*/ false)));
+ }
+
return success();
}
@@ -1251,10 +1269,12 @@ static ParseResult parseOptionalCallFuncPtr(
return success();
}
-// <operation> ::= `llvm.call` (function-id | ssa-use)`(` ssa-use-list `)`
+// <operation> ::= `llvm.call` (function-id | ssa-use) var-arg-func-type?
+// `(` ssa-use-list `)`
// attribute-dict? `:` (type `,`)? function-type
ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
SymbolRefAttr funcAttr;
+ TypeAttr calleeType;
SmallVector<OpAsmParser::UnresolvedOperand> operands;
// Parse a function pointer for indirect calls.
@@ -1267,13 +1287,18 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
if (parser.parseAttribute(funcAttr, "callee", result.attributes))
return failure();
+ bool isVarArg = parser.parseOptionalKeyword("vararg").succeeded();
+ if (isVarArg)
+ parser.parseOptionalAttribute(calleeType, "callee_type", result.attributes);
+
// Parse the function arguments.
if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
parser.parseOptionalAttrDict(result.attributes))
return failure();
// Parse the trailing type list and resolve the operands.
- return parseCallTypeAndResolveOperands(parser, result, isDirect, operands);
+ return parseCallTypeAndResolveOperands(parser, result, isDirect, operands,
+ isVarArg);
}
///===---------------------------------------------------------------------===//
@@ -1388,7 +1413,8 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
return failure();
// Parse the trailing type list and resolve the function operands.
- if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands))
+ if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands,
+ /*isVarArg*/ false))
return failure();
result.addSuccessors({normalDest, unwindDest});
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index c27bc1b6f29076b..cdd22f116a4ef9d 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -213,7 +213,7 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
} else {
call = builder.CreateCall(
llvm::cast<llvm::FunctionType>(
- moduleTranslation.convertType(callOp.getCalleeType())),
+ moduleTranslation.convertType(*callOp.getCalleeType())),
operandsRef.front(), operandsRef.drop_front());
}
moduleTranslation.setAccessGroupsMetadata(callOp, call);
>From 66d8e4fdcf9e55934e7ab1858bac8eac44728346 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Tue, 26 Sep 2023 12:03:03 +0900
Subject: [PATCH 4/4] Fixes
---
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 49 ++++++++++++-------
.../LLVMIR/LLVMToLLVMIRTranslation.cpp | 13 +++--
2 files changed, 41 insertions(+), 21 deletions(-)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 87138d24f7fc09a..39bc824d25b25bc 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -950,15 +950,14 @@ static void printStoreType(OpAsmPrinter &printer, Operation *op,
// CallOp
//===----------------------------------------------------------------------===//
-namespace {
-static TypeRange getCallOpResults(LLVMFunctionType calleeType) {
-  SmallVector<Type> results;
+/// Result types of a call to `calleeType`: empty if it returns void.
+static SmallVector<Type, 1> getCallOpResults(LLVMFunctionType calleeType) {
+  SmallVector<Type, 1> results;
Type resultType = calleeType.getReturnType();
-  if (!llvm::isa<LLVM::LLVMVoidType>(resultType))
+  if (!isa<LLVM::LLVMVoidType>(resultType))
results.push_back(resultType);
return results;
}
-} // namespace
void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
StringRef callee, ValueRange args) {
@@ -977,9 +976,8 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
resultType = LLVMVoidType::get(builder.getContext());
else
resultType = results.front();
- std::vector<Type> argTypes(args.getTypes().begin(), args.getTypes().end());
- auto calleeType =
- LLVMFunctionType::get(resultType, argTypes, /*isVariadic*/ false);
+ auto calleeType = LLVMFunctionType::get(
+ resultType, llvm::to_vector(args.getTypes()), /*isVariadic*/ false);
build(builder, state, results, TypeAttr::get(calleeType), callee, args,
/*fastmathFlags=*/nullptr,
/*branch_weights=*/nullptr,
@@ -1172,8 +1170,17 @@ LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
void CallOp::print(OpAsmPrinter &p) {
auto callee = getCallee();
bool isDirect = callee.has_value();
- auto calleeType = *getCalleeType();
- bool isVarArg = calleeType.isVarArg();
+
+ LLVMFunctionType calleeType;
+ bool isVarArg;
+
+ std::optional<LLVMFunctionType> optionalCalleeType = getCalleeType();
+ if (optionalCalleeType.has_value()) {
+ calleeType = *optionalCalleeType;
+ isVarArg = calleeType.isVarArg();
+ } else {
+ isVarArg = false;
+ }
// Print the direct callee if present as a function attribute, or an indirect
// callee (first operand) otherwise.
@@ -1183,11 +1190,12 @@ void CallOp::print(OpAsmPrinter &p) {
else
p << getOperand(0);
- if (isVarArg)
- p << " vararg " << getCalleeType() << ' ';
-
auto args = getOperands().drop_front(isDirect ? 0 : 1);
p << '(' << args << ')';
+  // Print `vararg(<type>)` with no trailing space: what follows adds one.
+  if (isVarArg)
+    p << " vararg(" << calleeType << ")";
+
p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()),
{"callee", "callee_type"});
@@ -1287,15 +1295,22 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
if (parser.parseAttribute(funcAttr, "callee", result.attributes))
return failure();
- bool isVarArg = parser.parseOptionalKeyword("vararg").succeeded();
- if (isVarArg)
- parser.parseOptionalAttribute(calleeType, "callee_type", result.attributes);
-
// Parse the function arguments.
if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
parser.parseOptionalAttrDict(result.attributes))
return failure();
+  // `vararg(<type>)` is printed after the operand list but before the
+  // attribute dictionary, so accept a dictionary after it as well.
+  bool isVarArg = parser.parseOptionalKeyword("vararg").succeeded();
+  if (isVarArg &&
+      (parser.parseLParen() ||
+       parser.parseAttribute(calleeType, "callee_type", result.attributes) ||
+       parser.parseRParen() ||
+       parser.parseOptionalAttrDict(result.attributes)))
+    return failure();
+
+
// Parse the trailing type list and resolve the operands.
return parseCallTypeAndResolveOperands(parser, result, isDirect, operands,
isVarArg);
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index cdd22f116a4ef9d..61f9088d7e30f73 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -211,10 +211,15 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
call = builder.CreateCall(
moduleTranslation.lookupFunction(attr.getValue()), operandsRef);
} else {
- call = builder.CreateCall(
- llvm::cast<llvm::FunctionType>(
- moduleTranslation.convertType(*callOp.getCalleeType())),
- operandsRef.front(), operandsRef.drop_front());
+ llvm::FunctionType *calleeType;
+ if (callOp.getCalleeType().has_value())
+ calleeType = llvm::cast<llvm::FunctionType>(
+ moduleTranslation.convertType(*callOp.getCalleeType()));
+ else
+ calleeType = getCalleeFunctionType(callOp.getResultTypes(),
+ callOp.getArgOperands());
+ call = builder.CreateCall(calleeType, operandsRef.front(),
+ operandsRef.drop_front());
}
moduleTranslation.setAccessGroupsMetadata(callOp, call);
moduleTranslation.setAliasScopeMetadata(callOp, call);
More information about the Mlir-commits mailing list