[clang] [CIR] Upstream global initialization for ComplexType (PR #141369)
via cfe-commits
cfe-commits at lists.llvm.org
Sat May 24 13:29:33 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-clangir
Author: Amr Hesham (AmrDeveloper)
<details>
<summary>Changes</summary>
This change adds support for zero and global init for ComplexType
#<!-- -->141365
---
Full diff: https://github.com/llvm/llvm-project/pull/141369.diff
11 Files Affected:
- (modified) clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h (+2)
- (modified) clang/include/clang/CIR/Dialect/IR/CIRAttrs.td (+34)
- (modified) clang/include/clang/CIR/Dialect/IR/CIRTypes.td (+43)
- (modified) clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp (+23-4)
- (modified) clang/lib/CIR/CodeGen/CIRGenTypes.cpp (+7)
- (modified) clang/lib/CIR/Dialect/IR/CIRAttrs.cpp (+20)
- (modified) clang/lib/CIR/Dialect/IR/CIRDialect.cpp (+4-2)
- (modified) clang/lib/CIR/Dialect/IR/CIRTypes.cpp (+26)
- (modified) clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp (+53-12)
- (added) clang/test/CIR/CodeGen/complex.cpp (+29)
- (added) clang/test/CIR/IR/complex.cir (+16)
``````````diff
diff --git a/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h b/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
index 9de3a66755778..878aba69c0e24 100644
--- a/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
+++ b/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
@@ -89,6 +89,8 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
return cir::IntAttr::get(ty, 0);
if (cir::isAnyFloatingPointType(ty))
return cir::FPAttr::getZero(ty);
+ if (auto complexType = mlir::dyn_cast<cir::ComplexType>(ty))
+ return cir::ZeroAttr::get(complexType);
if (auto arrTy = mlir::dyn_cast<cir::ArrayType>(ty))
return cir::ZeroAttr::get(arrTy);
if (auto vecTy = mlir::dyn_cast<cir::VectorType>(ty))
diff --git a/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td b/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td
index 8152535930095..4effae1cf2e29 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td
@@ -276,4 +276,38 @@ def ConstPtrAttr : CIR_Attr<"ConstPtr", "ptr", [TypedAttrInterface]> {
}];
}
+//===----------------------------------------------------------------------===//
+// ConstComplexAttr
+//===----------------------------------------------------------------------===//
+
+def ConstComplexAttr : CIR_Attr<"ConstComplex", "const_complex", [TypedAttrInterface]> {
+ let summary = "An attribute that contains a constant complex value";
+ let description = [{
+ The `#cir.const_complex` attribute contains a constant value of complex number
+ type. The `real` parameter gives the real part of the complex number and the
+ `imag` parameter gives the imaginary part of the complex number.
+
+ The `real` and `imag` parameter must be either an IntAttr or an FPAttr that
+ contains values of the same CIR type.
+ }];
+
+ let parameters = (ins
+ AttributeSelfTypeParameter<"", "cir::ComplexType">:$type,
+ "mlir::TypedAttr":$real, "mlir::TypedAttr":$imag);
+
+ let builders = [
+ AttrBuilderWithInferredContext<(ins "cir::ComplexType":$type,
+ "mlir::TypedAttr":$real,
+ "mlir::TypedAttr":$imag), [{
+ return $_get(type.getContext(), type, real, imag);
+ }]>,
+ ];
+
+ let genVerifyDecl = 1;
+
+ let assemblyFormat = [{
+ `<` qualified($real) `,` qualified($imag) `>`
+ }];
+}
+
#endif // LLVM_CLANG_CIR_DIALECT_IR_CIRATTRS_TD
diff --git a/clang/include/clang/CIR/Dialect/IR/CIRTypes.td b/clang/include/clang/CIR/Dialect/IR/CIRTypes.td
index 26f1122a4b261..ec994620893fe 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIRTypes.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIRTypes.td
@@ -161,6 +161,49 @@ def CIR_LongDouble : CIR_FloatType<"LongDouble", "long_double"> {
}];
}
+//===----------------------------------------------------------------------===//
+// ComplexType
+//===----------------------------------------------------------------------===//
+
+def CIR_ComplexType : CIR_Type<"Complex", "complex",
+ [DeclareTypeInterfaceMethods<DataLayoutTypeInterface>]> {
+
+ let summary = "CIR complex type";
+ let description = [{
+ CIR type that represents a C complex number. `cir.complex` models the C type
+ `T _Complex`.
+
+ The type models complex values, per C99 6.2.5p11. It supports the C99
+ complex float types as well as the GCC integer complex extensions.
+
+ The parameter `elementType` gives the type of the real and imaginary part of
+ the complex number. `elementType` must be either a CIR integer type or a CIR
+ floating-point type.
+ }];
+
+ let parameters = (ins CIR_AnyIntOrFloatType:$elementType);
+
+ let builders = [
+ TypeBuilderWithInferredContext<(ins "mlir::Type":$elementType), [{
+ return $_get(elementType.getContext(), elementType);
+ }]>,
+ ];
+
+ let assemblyFormat = [{
+ `<` $elementType `>`
+ }];
+
+ let extraClassDeclaration = [{
+ bool isFloatingPointComplex() const {
+ return isAnyFloatingPointType(getElementType());
+ }
+
+ bool isIntegerComplex() const {
+ return mlir::isa<cir::IntType>(getElementType());
+ }
+ }];
+}
+
//===----------------------------------------------------------------------===//
// PointerType
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp b/clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp
index 9085ee2dfe506..973349b8c0443 100644
--- a/clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp
@@ -577,12 +577,31 @@ mlir::Attribute ConstantEmitter::tryEmitPrivate(const APValue &value,
case APValue::Union:
cgm.errorNYI("ConstExprEmitter::tryEmitPrivate struct or union");
return {};
- case APValue::FixedPoint:
case APValue::ComplexInt:
- case APValue::ComplexFloat:
+ case APValue::ComplexFloat: {
+ mlir::Type desiredType = cgm.convertType(destType);
+ cir::ComplexType complexType =
+ mlir::dyn_cast<cir::ComplexType>(desiredType);
+
+ mlir::Type compelxElemTy = complexType.getElementType();
+ if (isa<cir::IntType>(compelxElemTy)) {
+ llvm::APSInt real = value.getComplexIntReal();
+ llvm::APSInt imag = value.getComplexIntImag();
+ return builder.getAttr<cir::ConstComplexAttr>(
+ complexType, builder.getAttr<cir::IntAttr>(compelxElemTy, real),
+ builder.getAttr<cir::IntAttr>(compelxElemTy, imag));
+ }
+
+ llvm::APFloat real = value.getComplexFloatReal();
+ llvm::APFloat imag = value.getComplexFloatImag();
+ return builder.getAttr<cir::ConstComplexAttr>(
+ complexType, builder.getAttr<cir::FPAttr>(compelxElemTy, real),
+ builder.getAttr<cir::FPAttr>(compelxElemTy, imag));
+ }
+ case APValue::FixedPoint:
case APValue::AddrLabelDiff:
- cgm.errorNYI("ConstExprEmitter::tryEmitPrivate fixed point, complex int, "
- "complex float, addr label diff");
+ cgm.errorNYI(
+ "ConstExprEmitter::tryEmitPrivate fixed point, addr label diff");
return {};
}
llvm_unreachable("Unknown APValue kind");
diff --git a/clang/lib/CIR/CodeGen/CIRGenTypes.cpp b/clang/lib/CIR/CodeGen/CIRGenTypes.cpp
index 0665ea0506875..948be813ebe51 100644
--- a/clang/lib/CIR/CodeGen/CIRGenTypes.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenTypes.cpp
@@ -385,6 +385,13 @@ mlir::Type CIRGenTypes::convertType(QualType type) {
break;
}
+ case Type::Complex: {
+ const ComplexType *ct = cast<ComplexType>(ty);
+ mlir::Type elementTy = convertType(ct->getElementType());
+ resultType = cir::ComplexType::get(elementTy);
+ break;
+ }
+
case Type::LValueReference:
case Type::RValueReference: {
const ReferenceType *refTy = cast<ReferenceType>(ty);
diff --git a/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp b/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp
index c4fb4942aec75..d9426ced5f5ab 100644
--- a/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp
@@ -184,6 +184,26 @@ LogicalResult FPAttr::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}
+//===----------------------------------------------------------------------===//
+// ConstComplexAttr definitions
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+ConstComplexAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+ cir::ComplexType type, mlir::TypedAttr real,
+ mlir::TypedAttr imag) {
+ mlir::Type elemType = type.getElementType();
+ if (real.getType() != elemType)
+ return emitError()
+ << "type of the real part does not match the complex type";
+
+ if (imag.getType() != elemType)
+ return emitError()
+ << "type of the imaginary part does not match the complex type";
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// CIR ConstArrayAttr
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index 36dcbc6a4be4a..4a9386b1eed0f 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -231,7 +231,8 @@ static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType,
}
if (isa<cir::ZeroAttr>(attrType)) {
- if (isa<cir::RecordType, cir::ArrayType, cir::VectorType>(opType))
+ if (isa<cir::RecordType, cir::ArrayType, cir::VectorType, cir::ComplexType>(
+ opType))
return success();
return op->emitOpError("zero expects struct or array type");
}
@@ -253,7 +254,8 @@ static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType,
return success();
}
- if (mlir::isa<cir::ConstArrayAttr, cir::ConstVectorAttr>(attrType))
+ if (mlir::isa<cir::ConstArrayAttr, cir::ConstVectorAttr,
+ cir::ConstComplexAttr>(attrType))
return success();
assert(isa<TypedAttr>(attrType) && "What else could we be looking at here?");
diff --git a/clang/lib/CIR/Dialect/IR/CIRTypes.cpp b/clang/lib/CIR/Dialect/IR/CIRTypes.cpp
index b402177a5ec18..14050f36bbfdc 100644
--- a/clang/lib/CIR/Dialect/IR/CIRTypes.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRTypes.cpp
@@ -552,6 +552,32 @@ LongDoubleType::getABIAlignment(const mlir::DataLayout &dataLayout,
.getABIAlignment(dataLayout, params);
}
+//===----------------------------------------------------------------------===//
+// ComplexType Definitions
+//===----------------------------------------------------------------------===//
+
+llvm::TypeSize
+cir::ComplexType::getTypeSizeInBits(const mlir::DataLayout &dataLayout,
+ mlir::DataLayoutEntryListRef params) const {
+ // C17 6.2.5p13:
+ // Each complex type has the same representation and alignment requirements
+ // as an array type containing exactly two elements of the corresponding
+ // real type.
+
+ return dataLayout.getTypeSizeInBits(getElementType()) * 2;
+}
+
+uint64_t
+cir::ComplexType::getABIAlignment(const mlir::DataLayout &dataLayout,
+ mlir::DataLayoutEntryListRef params) const {
+ // C17 6.2.5p13:
+ // Each complex type has the same representation and alignment requirements
+ // as an array type containing exactly two elements of the corresponding
+ // real type.
+
+ return dataLayout.getTypeABIAlignment(getElementType());
+}
+
//===----------------------------------------------------------------------===//
// Floating-point and Float-point Vector type helpers
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index 8e82af7e62bc0..d0ae1d64e9afd 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -188,14 +188,15 @@ class CIRAttrToValue {
mlir::Value visit(mlir::Attribute attr) {
return llvm::TypeSwitch<mlir::Attribute, mlir::Value>(attr)
- .Case<cir::IntAttr, cir::FPAttr, cir::ConstArrayAttr,
- cir::ConstVectorAttr, cir::ConstPtrAttr, cir::ZeroAttr>(
- [&](auto attrT) { return visitCirAttr(attrT); })
+ .Case<cir::IntAttr, cir::FPAttr, cir::ConstComplexAttr,
+ cir::ConstArrayAttr, cir::ConstVectorAttr, cir::ConstPtrAttr,
+ cir::ZeroAttr>([&](auto attrT) { return visitCirAttr(attrT); })
.Default([&](auto attrT) { return mlir::Value(); });
}
mlir::Value visitCirAttr(cir::IntAttr intAttr);
mlir::Value visitCirAttr(cir::FPAttr fltAttr);
+ mlir::Value visitCirAttr(cir::ConstComplexAttr complexAttr);
mlir::Value visitCirAttr(cir::ConstPtrAttr ptrAttr);
mlir::Value visitCirAttr(cir::ConstArrayAttr attr);
mlir::Value visitCirAttr(cir::ConstVectorAttr attr);
@@ -226,6 +227,42 @@ mlir::Value CIRAttrToValue::visitCirAttr(cir::IntAttr intAttr) {
loc, converter->convertType(intAttr.getType()), intAttr.getValue());
}
+/// FPAttr visitor.
+mlir::Value CIRAttrToValue::visitCirAttr(cir::FPAttr fltAttr) {
+ mlir::Location loc = parentOp->getLoc();
+ return rewriter.create<mlir::LLVM::ConstantOp>(
+ loc, converter->convertType(fltAttr.getType()), fltAttr.getValue());
+}
+
+/// ConstComplexAttr visitor.
+mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstComplexAttr complexAttr) {
+ auto complexType = mlir::cast<cir::ComplexType>(complexAttr.getType());
+ auto complexElemTy = complexType.getElementType();
+ auto complexElemLLVMTy = converter->convertType(complexElemTy);
+
+ mlir::Attribute components[2];
+ if (const auto intType = mlir::dyn_cast<cir::IntType>(complexElemTy)) {
+ components[0] = rewriter.getIntegerAttr(
+ complexElemLLVMTy,
+ mlir::cast<cir::IntAttr>(complexAttr.getReal()).getValue());
+ components[1] = rewriter.getIntegerAttr(
+ complexElemLLVMTy,
+ mlir::cast<cir::IntAttr>(complexAttr.getImag()).getValue());
+ } else {
+ components[0] = rewriter.getFloatAttr(
+ complexElemLLVMTy,
+ mlir::cast<cir::FPAttr>(complexAttr.getReal()).getValue());
+ components[1] = rewriter.getFloatAttr(
+ complexElemLLVMTy,
+ mlir::cast<cir::FPAttr>(complexAttr.getImag()).getValue());
+ }
+
+ mlir::Location loc = parentOp->getLoc();
+ return rewriter.create<mlir::LLVM::ConstantOp>(
+ loc, converter->convertType(complexAttr.getType()),
+ rewriter.getArrayAttr(components));
+}
+
/// ConstPtrAttr visitor.
mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstPtrAttr ptrAttr) {
mlir::Location loc = parentOp->getLoc();
@@ -241,13 +278,6 @@ mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstPtrAttr ptrAttr) {
loc, converter->convertType(ptrAttr.getType()), ptrVal);
}
-/// FPAttr visitor.
-mlir::Value CIRAttrToValue::visitCirAttr(cir::FPAttr fltAttr) {
- mlir::Location loc = parentOp->getLoc();
- return rewriter.create<mlir::LLVM::ConstantOp>(
- loc, converter->convertType(fltAttr.getType()), fltAttr.getValue());
-}
-
// ConstArrayAttr visitor
mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstArrayAttr attr) {
mlir::Type llvmTy = converter->convertType(attr.getType());
@@ -341,9 +371,11 @@ class GlobalInitAttrRewriter {
mlir::Attribute visitCirAttr(cir::IntAttr attr) {
return rewriter.getIntegerAttr(llvmType, attr.getValue());
}
+
mlir::Attribute visitCirAttr(cir::FPAttr attr) {
return rewriter.getFloatAttr(llvmType, attr.getValue());
}
+
mlir::Attribute visitCirAttr(cir::BoolAttr attr) {
return rewriter.getBoolAttr(attr.getValue());
}
@@ -990,7 +1022,7 @@ CIRToLLVMGlobalOpLowering::matchAndRewriteRegionInitializedGlobal(
mlir::ConversionPatternRewriter &rewriter) const {
// TODO: Generalize this handling when more types are needed here.
assert((isa<cir::ConstArrayAttr, cir::ConstVectorAttr, cir::ConstPtrAttr,
- cir::ZeroAttr>(init)));
+ cir::ConstComplexAttr, cir::ZeroAttr>(init)));
// TODO(cir): once LLVM's dialect has proper equivalent attributes this
// should be updated. For now, we use a custom op to initialize globals
@@ -1043,7 +1075,8 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite(
return mlir::failure();
}
} else if (mlir::isa<cir::ConstArrayAttr, cir::ConstVectorAttr,
- cir::ConstPtrAttr, cir::ZeroAttr>(init.value())) {
+ cir::ConstPtrAttr, cir::ConstComplexAttr,
+ cir::ZeroAttr>(init.value())) {
// TODO(cir): once LLVM's dialect has proper equivalent attributes this
// should be updated. For now, we use a custom op to initialize globals
// to the appropriate value.
@@ -1559,6 +1592,14 @@ static void prepareTypeConverter(mlir::LLVMTypeConverter &converter,
converter.addConversion([&](cir::BF16Type type) -> mlir::Type {
return mlir::BFloat16Type::get(type.getContext());
});
+ converter.addConversion([&](cir::ComplexType type) -> mlir::Type {
+ // A complex type is lowered to an LLVM struct that contains the real and
+ // imaginary part as data fields.
+ mlir::Type elementTy = converter.convertType(type.getElementType());
+ mlir::Type structFields[2] = {elementTy, elementTy};
+ return mlir::LLVM::LLVMStructType::getLiteral(type.getContext(),
+ structFields);
+ });
converter.addConversion([&](cir::FuncType type) -> std::optional<mlir::Type> {
auto result = converter.convertType(type.getReturnType());
llvm::SmallVector<mlir::Type> arguments;
diff --git a/clang/test/CIR/CodeGen/complex.cpp b/clang/test/CIR/CodeGen/complex.cpp
new file mode 100644
index 0000000000000..1e0c9fcf08ef0
--- /dev/null
+++ b/clang/test/CIR/CodeGen/complex.cpp
@@ -0,0 +1,29 @@
+// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -Wno-unused-value -fclangir -emit-cir %s -o %t.cir
+// RUN: FileCheck --input-file=%t.cir %s -check-prefix=CIR
+// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -Wno-unused-value -fclangir -emit-llvm %s -o %t-cir.ll
+// RUN: FileCheck --input-file=%t-cir.ll %s -check-prefix=LLVM
+// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -Wno-unused-value -emit-llvm %s -o %t.ll
+// RUN: FileCheck --input-file=%t.ll %s -check-prefix=OGCG
+
+int _Complex ci;
+
+float _Complex cf;
+
+int _Complex ci2 = { 1, 2 };
+
+float _Complex cf2 = { 1.0f, 2.0f };
+
+// CIR: cir.global external @ci = #cir.zero : !cir.complex<!s32i>
+// CIR: cir.global external @cf = #cir.zero : !cir.complex<!cir.float>
+// CIR: cir.global external @ci2 = #cir.const_complex<#cir.int<1> : !s32i, #cir.int<2> : !s32i> : !cir.complex<!s32i>
+// CIR: cir.global external @cf2 = #cir.const_complex<#cir.fp<1.000000e+00> : !cir.float, #cir.fp<2.000000e+00> : !cir.float> : !cir.complex<!cir.float>
+
+// LLVM: {{.*}} = dso_local global { i32, i32 } zeroinitializer, align 4
+// LLVM: {{.*}} = dso_local global { float, float } zeroinitializer, align 4
+// LLVM: {{.*}} = dso_local global { i32, i32 } { i32 1, i32 2 }, align 4
+// LLVM: {{.*}} = dso_local global { float, float } { float 1.000000e+00, float 2.000000e+00 }, align 4
+
+// OGCG: {{.*}} = global { i32, i32 } zeroinitializer, align 4
+// OGCG: {{.*}} = global { float, float } zeroinitializer, align 4
+// OGCG: {{.*}} = global { i32, i32 } { i32 1, i32 2 }, align 4
+// OGCG: {{.*}} = global { float, float } { float 1.000000e+00, float 2.000000e+00 }, align 4
diff --git a/clang/test/CIR/IR/complex.cir b/clang/test/CIR/IR/complex.cir
new file mode 100644
index 0000000000000..a73a8654ca274
--- /dev/null
+++ b/clang/test/CIR/IR/complex.cir
@@ -0,0 +1,16 @@
+// RUN: cir-opt %s | FileCheck %s
+
+!s32i = !cir.int<s, 32>
+
+module {
+
+cir.global external @ci = #cir.zero : !cir.complex<!s32i>
+// CHECK: cir.global external {{.*}} = #cir.zero : !cir.complex<!s32i>
+
+cir.global external @cf = #cir.zero : !cir.complex<!cir.float>
+// CHECK: cir.global external {{.*}} = #cir.zero : !cir.complex<!cir.float>
+
+cir.global external @ci2 = #cir.const_complex<#cir.int<1> : !s32i, #cir.int<2> : !s32i> : !cir.complex<!s32i>
+// CHECK: cir.global external {{.*}} = #cir.const_complex<#cir.int<1> : !s32i, #cir.int<2> : !s32i> : !cir.complex<!s32i>
+
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/141369
More information about the cfe-commits
mailing list