[clang] [CIR] Function type return type improvements (PR #128787)
via cfe-commits
cfe-commits at lists.llvm.org
Tue Feb 25 15:35:50 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-clangir
Author: David Olsen (dkolsen-pgi)
<details>
<summary>Changes</summary>
When a C or C++ function has a return type of `void`, the function type is now represented in MLIR as having no return type rather than having a return type of `!cir.void`. This avoids breaking MLIR invariants that require the number of return types and the number of return values to match.
Change the assembly format for `cir::FuncType` from having a leading return type to having a trailing return type. In other words, change
```
!cir.func<!returnType (!argTypes)>
```
to
```
!cir.func<(!argTypes) -> !returnType>
```
If the function returns `void`, the return type is omitted entirely; that is, change
```
!cir.func<!cir.void (!argTypes)>
```
to
```
!cir.func<(!argTypes)>
```
---
Full diff: https://github.com/llvm/llvm-project/pull/128787.diff
8 Files Affected:
- (modified) clang/include/clang/CIR/Dialect/IR/CIRTypes.td (+28-12)
- (modified) clang/lib/CIR/CodeGen/CIRGenTypes.cpp (+1-1)
- (modified) clang/lib/CIR/Dialect/IR/CIRDialect.cpp (+4)
- (modified) clang/lib/CIR/Dialect/IR/CIRTypes.cpp (+93-32)
- (modified) clang/test/CIR/IR/func.cir (+4-4)
- (modified) clang/test/CIR/IR/global.cir (+6-6)
- (modified) clang/test/CIR/func-simple.cpp (+2-2)
- (modified) clang/test/CIR/global-var-simple.cpp (+3-3)
``````````diff
diff --git a/clang/include/clang/CIR/Dialect/IR/CIRTypes.td b/clang/include/clang/CIR/Dialect/IR/CIRTypes.td
index fc8edbcf3e166..c2d45ebeefe63 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIRTypes.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIRTypes.td
@@ -287,32 +287,43 @@ def CIR_BoolType :
def CIR_FuncType : CIR_Type<"Func", "func"> {
let summary = "CIR function type";
let description = [{
- The `!cir.func` is a function type. It consists of a single return type, a
- list of parameter types and can optionally be variadic.
+ The `!cir.func` is a function type. It consists of an optional return type,
+ a list of parameter types and can optionally be variadic.
Example:
```mlir
- !cir.func<!bool ()>
- !cir.func<!s32i (!s8i, !s8i)>
- !cir.func<!s32i (!s32i, ...)>
+ !cir.func<()>
+ !cir.func<() -> !bool>
+ !cir.func<(!s8i, !s8i)>
+ !cir.func<(!s8i, !s8i) -> !s32i>
+ !cir.func<(!s32i, ...) -> !s32i>
```
}];
let parameters = (ins ArrayRefParameter<"mlir::Type">:$inputs,
- "mlir::Type":$returnType, "bool":$varArg);
+ "mlir::Type":$optionalReturnType, "bool":$varArg);
+ // Use a custom parser to handle the argument types and optional return type.
let assemblyFormat = [{
- `<` $returnType ` ` `(` custom<FuncTypeArgs>($inputs, $varArg) `>`
+ `<` custom<FuncType>($optionalReturnType, $inputs, $varArg) `>`
}];
let builders = [
+ // Create a FuncType, converting the return type from C-style to
+ // MLIR-style. If the given return type is `cir::VoidType`, ignore it
+ // and create the FuncType with no return type, which is how MLIR
+ // represents function types.
TypeBuilderWithInferredContext<(ins
"llvm::ArrayRef<mlir::Type>":$inputs, "mlir::Type":$returnType,
CArg<"bool", "false">:$isVarArg), [{
- return $_get(returnType.getContext(), inputs, returnType, isVarArg);
+ return $_get(returnType.getContext(), inputs,
+ mlir::isa<cir::VoidType>(returnType) ? nullptr : returnType,
+ isVarArg);
}]>
];
+ let genVerifyDecl = 1;
+
let extraClassDeclaration = [{
/// Returns whether the function is variadic.
bool isVarArg() const { return getVarArg(); }
@@ -323,12 +334,17 @@ def CIR_FuncType : CIR_Type<"Func", "func"> {
/// Returns the number of arguments to the function.
unsigned getNumInputs() const { return getInputs().size(); }
- /// Returns the result type of the function as an ArrayRef, enabling better
- /// integration with generic MLIR utilities.
+ /// Get the C-style return type of the function, which is !cir.void if the
+ /// function returns nothing and the actual return type otherwise.
+ mlir::Type getReturnType() const;
+
+ /// Get the MLIR-style return type of the function, which is an empty
+ /// ArrayRef if the function returns nothing and a single-element ArrayRef
+ /// with the actual return type otherwise.
llvm::ArrayRef<mlir::Type> getReturnTypes() const;
- /// Returns whether the function is returns void.
- bool isVoid() const;
+ /// Does the function type return nothing?
+ bool hasVoidReturn() const;
/// Returns a clone of this function type with the given argument
/// and result types.
diff --git a/clang/lib/CIR/CodeGen/CIRGenTypes.cpp b/clang/lib/CIR/CodeGen/CIRGenTypes.cpp
index 16aec10fda81e..dcfaaedc2ef57 100644
--- a/clang/lib/CIR/CodeGen/CIRGenTypes.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenTypes.cpp
@@ -60,7 +60,7 @@ bool CIRGenTypes::isFuncTypeConvertible(const FunctionType *ft) {
mlir::Type CIRGenTypes::convertFunctionTypeInternal(QualType qft) {
assert(qft.isCanonical());
const FunctionType *ft = cast<FunctionType>(qft.getTypePtr());
- // First, check whether we can build the full fucntion type. If the function
+ // First, check whether we can build the full function type. If the function
// type depends on an incomplete type (e.g. a struct or enum), we cannot lower
// the function type.
if (!isFuncTypeConvertible(ft)) {
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index bfc74d4373f34..1a0740dea1fa8 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -424,6 +424,10 @@ LogicalResult cir::FuncOp::verifyType() {
if (!isa<cir::FuncType>(type))
return emitOpError("requires '" + getFunctionTypeAttrName().str() +
"' attribute of function type");
+ if (auto rt = type.getReturnTypes();
+ !rt.empty() && mlir::isa<cir::VoidType>(rt.front()))
+ return emitOpError("The return type for a function returning void should "
+ "be empty instead of an explicit !cir.void");
return success();
}
diff --git a/clang/lib/CIR/Dialect/IR/CIRTypes.cpp b/clang/lib/CIR/Dialect/IR/CIRTypes.cpp
index d1b143efb955e..67fa6c267cf0f 100644
--- a/clang/lib/CIR/Dialect/IR/CIRTypes.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRTypes.cpp
@@ -20,11 +20,12 @@
// CIR Custom Parser/Printer Signatures
//===----------------------------------------------------------------------===//
-static mlir::ParseResult
-parseFuncTypeArgs(mlir::AsmParser &p, llvm::SmallVector<mlir::Type> &params,
- bool &isVarArg);
-static void printFuncTypeArgs(mlir::AsmPrinter &p,
- mlir::ArrayRef<mlir::Type> params, bool isVarArg);
+static mlir::ParseResult parseFuncType(mlir::AsmParser &p,
+ mlir::Type &optionalReturnTypes,
+ llvm::SmallVector<mlir::Type> &params,
+ bool &isVarArg);
+static void printFuncType(mlir::AsmPrinter &p, mlir::Type optionalReturnTypes,
+ mlir::ArrayRef<mlir::Type> params, bool isVarArg);
//===----------------------------------------------------------------------===//
// Get autogenerated stuff
@@ -282,40 +283,55 @@ FuncType FuncType::clone(TypeRange inputs, TypeRange results) const {
return get(llvm::to_vector(inputs), results[0], isVarArg());
}
-mlir::ParseResult parseFuncTypeArgs(mlir::AsmParser &p,
- llvm::SmallVector<mlir::Type> &params,
- bool &isVarArg) {
+// A special parser is needed for functions returning void to handle the missing
+// return type.
+static mlir::ParseResult parseFuncTypeReturn(mlir::AsmParser &p,
+ mlir::Type &optionalReturnType) {
+ if (succeeded(p.parseOptionalArrow())) {
+ // `->` found. It must be followed by the return type.
+ return p.parseType(optionalReturnType);
+ }
+ // Function has `void` return in C++, no return in MLIR.
+ optionalReturnType = {};
+ return success();
+}
+
+// A special pretty-printer for the optional return type of a function.
+static void printFuncTypeReturn(mlir::AsmPrinter &p,
+ mlir::Type optionalReturnType) {
+ if (optionalReturnType)
+ p << " -> " << optionalReturnType;
+}
+
+static mlir::ParseResult
+parseFuncTypeArgs(mlir::AsmParser &p, llvm::SmallVector<mlir::Type> &params,
+ bool &isVarArg) {
isVarArg = false;
- // `(` `)`
- if (succeeded(p.parseOptionalRParen()))
+ if (failed(p.parseLParen()))
+ return failure();
+ if (succeeded(p.parseOptionalRParen())) {
+ // `()` empty argument list
return mlir::success();
-
- // `(` `...` `)`
- if (succeeded(p.parseOptionalEllipsis())) {
- isVarArg = true;
- return p.parseRParen();
}
-
- // type (`,` type)* (`,` `...`)?
- mlir::Type type;
- if (p.parseType(type))
- return mlir::failure();
- params.push_back(type);
- while (succeeded(p.parseOptionalComma())) {
+ do {
if (succeeded(p.parseOptionalEllipsis())) {
+ // `...`, which must be the last thing in the list.
isVarArg = true;
- return p.parseRParen();
+ break;
+ } else {
+ mlir::Type argType;
+ if (failed(p.parseType(argType)))
+ return failure();
+ params.push_back(argType);
}
- if (p.parseType(type))
- return mlir::failure();
- params.push_back(type);
- }
-
+ } while (succeeded(p.parseOptionalComma()));
return p.parseRParen();
}
-void printFuncTypeArgs(mlir::AsmPrinter &p, mlir::ArrayRef<mlir::Type> params,
- bool isVarArg) {
+static void printFuncTypeArgs(mlir::AsmPrinter &p,
+ mlir::ArrayRef<mlir::Type> params,
+ bool isVarArg) {
+ p << '(';
llvm::interleaveComma(params, p,
[&p](mlir::Type type) { p.printType(type); });
if (isVarArg) {
@@ -326,11 +342,56 @@ void printFuncTypeArgs(mlir::AsmPrinter &p, mlir::ArrayRef<mlir::Type> params,
p << ')';
}
+// Use a custom parser to handle the optional return and argument types without
+// an optional anchor.
+static mlir::ParseResult parseFuncType(mlir::AsmParser &p,
+ mlir::Type &optionalReturnType,
+ llvm::SmallVector<mlir::Type> &params,
+ bool &isVarArg) {
+ if (failed(parseFuncTypeArgs(p, params, isVarArg)))
+ return failure();
+ return parseFuncTypeReturn(p, optionalReturnType);
+}
+
+static void printFuncType(mlir::AsmPrinter &p, mlir::Type optionalReturnType,
+ mlir::ArrayRef<mlir::Type> params, bool isVarArg) {
+ printFuncTypeArgs(p, params, isVarArg);
+ printFuncTypeReturn(p, optionalReturnType);
+}
+
+/// Get the C-style return type of the function, which is !cir.void if the
+/// function returns nothing and the actual return type otherwise.
+mlir::Type FuncType::getReturnType() const {
+ if (hasVoidReturn())
+ return cir::VoidType::get(getContext());
+ return getOptionalReturnType();
+}
+
+/// Get the MLIR-style return type of the function, which is an empty
+/// ArrayRef if the function returns nothing and a single-element ArrayRef
+/// with the actual return type otherwise.
llvm::ArrayRef<mlir::Type> FuncType::getReturnTypes() const {
- return static_cast<detail::FuncTypeStorage *>(getImpl())->returnType;
+ if (hasVoidReturn())
+ return {};
+ // Can't use getOptionalReturnType() here because llvm::ArrayRef holds a
+ // pointer to its elements and doesn't do lifetime extension. That would
+ // result in returning a pointer to a temporary that has gone out of scope.
+ return getImpl()->optionalReturnType;
}
-bool FuncType::isVoid() const { return mlir::isa<VoidType>(getReturnType()); }
+// Does the function type return nothing?
+bool FuncType::hasVoidReturn() const { return !getOptionalReturnType(); }
+
+mlir::LogicalResult
+FuncType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
+ llvm::ArrayRef<mlir::Type> argTypes, mlir::Type returnType,
+ bool isVarArg) {
+ if (returnType && mlir::isa<cir::VoidType>(returnType)) {
+ emitError() << "!cir.func cannot have an explicit 'void' return type";
+ return mlir::failure();
+ }
+ return mlir::success();
+}
//===----------------------------------------------------------------------===//
// BoolType
diff --git a/clang/test/CIR/IR/func.cir b/clang/test/CIR/IR/func.cir
index a32c3e697ed25..4077bd33e0438 100644
--- a/clang/test/CIR/IR/func.cir
+++ b/clang/test/CIR/IR/func.cir
@@ -2,18 +2,18 @@
module {
// void empty() { }
-cir.func @empty() -> !cir.void {
+cir.func @empty() {
cir.return
}
-// CHECK: cir.func @empty() -> !cir.void {
+// CHECK: cir.func @empty() {
// CHECK: cir.return
// CHECK: }
// void voidret() { return; }
-cir.func @voidret() -> !cir.void {
+cir.func @voidret() {
cir.return
}
-// CHECK: cir.func @voidret() -> !cir.void {
+// CHECK: cir.func @voidret() {
// CHECK: cir.return
// CHECK: }
diff --git a/clang/test/CIR/IR/global.cir b/clang/test/CIR/IR/global.cir
index 6c68ab0a501ff..9d187686d996c 100644
--- a/clang/test/CIR/IR/global.cir
+++ b/clang/test/CIR/IR/global.cir
@@ -30,9 +30,9 @@ module attributes {cir.triple = "x86_64-unknown-linux-gnu"} {
cir.global @ip = #cir.ptr<null> : !cir.ptr<!cir.int<s, 32>>
cir.global @dp : !cir.ptr<!cir.double>
cir.global @cpp : !cir.ptr<!cir.ptr<!cir.int<s, 8>>>
- cir.global @fp : !cir.ptr<!cir.func<!cir.void ()>>
- cir.global @fpii = #cir.ptr<null> : !cir.ptr<!cir.func<!cir.int<s, 32> (!cir.int<s, 32>)>>
- cir.global @fpvar : !cir.ptr<!cir.func<!cir.void (!cir.int<s, 32>, ...)>>
+ cir.global @fp : !cir.ptr<!cir.func<()>>
+ cir.global @fpii = #cir.ptr<null> : !cir.ptr<!cir.func<(!cir.int<s, 32>) -> !cir.int<s, 32>>>
+ cir.global @fpvar : !cir.ptr<!cir.func<(!cir.int<s, 32>, ...)>>
}
// CHECK: cir.global @c : !cir.int<s, 8>
@@ -64,6 +64,6 @@ module attributes {cir.triple = "x86_64-unknown-linux-gnu"} {
// CHECK: cir.global @ip = #cir.ptr<null> : !cir.ptr<!cir.int<s, 32>>
// CHECK: cir.global @dp : !cir.ptr<!cir.double>
// CHECK: cir.global @cpp : !cir.ptr<!cir.ptr<!cir.int<s, 8>>>
-// CHECK: cir.global @fp : !cir.ptr<!cir.func<!cir.void ()>>
-// CHECK: cir.global @fpii = #cir.ptr<null> : !cir.ptr<!cir.func<!cir.int<s, 32> (!cir.int<s, 32>)>>
-// CHECK: cir.global @fpvar : !cir.ptr<!cir.func<!cir.void (!cir.int<s, 32>, ...)>>
+// CHECK: cir.global @fp : !cir.ptr<!cir.func<()>>
+// CHECK: cir.global @fpii = #cir.ptr<null> : !cir.ptr<!cir.func<(!cir.int<s, 32>) -> !cir.int<s, 32>>>
+// CHECK: cir.global @fpvar : !cir.ptr<!cir.func<(!cir.int<s, 32>, ...)>>
diff --git a/clang/test/CIR/func-simple.cpp b/clang/test/CIR/func-simple.cpp
index 22c120d3404d3..3947055e300a0 100644
--- a/clang/test/CIR/func-simple.cpp
+++ b/clang/test/CIR/func-simple.cpp
@@ -2,12 +2,12 @@
// RUN: %clang_cc1 -std=c++20 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir %s -o - | FileCheck %s
void empty() { }
-// CHECK: cir.func @empty() -> !cir.void {
+// CHECK: cir.func @empty() {
// CHECK: cir.return
// CHECK: }
void voidret() { return; }
-// CHECK: cir.func @voidret() -> !cir.void {
+// CHECK: cir.func @voidret() {
// CHECK: cir.return
// CHECK: }
diff --git a/clang/test/CIR/global-var-simple.cpp b/clang/test/CIR/global-var-simple.cpp
index dfe8371668e2c..f8e233cd5fe33 100644
--- a/clang/test/CIR/global-var-simple.cpp
+++ b/clang/test/CIR/global-var-simple.cpp
@@ -92,10 +92,10 @@ char **cpp;
// CHECK: cir.global @cpp : !cir.ptr<!cir.ptr<!cir.int<s, 8>>>
void (*fp)();
-// CHECK: cir.global @fp : !cir.ptr<!cir.func<!cir.void ()>>
+// CHECK: cir.global @fp : !cir.ptr<!cir.func<()>>
int (*fpii)(int) = 0;
-// CHECK: cir.global @fpii = #cir.ptr<null> : !cir.ptr<!cir.func<!cir.int<s, 32> (!cir.int<s, 32>)>>
+// CHECK: cir.global @fpii = #cir.ptr<null> : !cir.ptr<!cir.func<(!cir.int<s, 32>) -> !cir.int<s, 32>>>
void (*fpvar)(int, ...);
-// CHECK: cir.global @fpvar : !cir.ptr<!cir.func<!cir.void (!cir.int<s, 32>, ...)>>
+// CHECK: cir.global @fpvar : !cir.ptr<!cir.func<(!cir.int<s, 32>, ...)>>
``````````
</details>
https://github.com/llvm/llvm-project/pull/128787
More information about the cfe-commits
mailing list