[clang] [CIR] Upstream global initialization for VectorType (PR #137511)
Amr Hesham via cfe-commits
cfe-commits at lists.llvm.org
Tue Apr 29 14:01:54 PDT 2025
https://github.com/AmrDeveloper updated https://github.com/llvm/llvm-project/pull/137511
>From 98eaaf0a64dc811481aab37da0939fa0d374a4f6 Mon Sep 17 00:00:00 2001
From: AmrDeveloper <amr96 at programmer.net>
Date: Sat, 26 Apr 2025 18:43:00 +0200
Subject: [PATCH 1/3] [CIR] Upstream global initialization for VectorType
---
.../include/clang/CIR/Dialect/IR/CIRAttrs.td | 33 ++++++-
clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp | 23 ++++-
clang/lib/CIR/Dialect/IR/CIRAttrs.cpp | 88 +++++++++++++++++++
clang/lib/CIR/Dialect/IR/CIRDialect.cpp | 2 +-
.../CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp | 40 +++++++--
clang/test/CIR/CodeGen/vector-ext.cpp | 11 ++-
clang/test/CIR/CodeGen/vector.cpp | 9 ++
7 files changed, 196 insertions(+), 10 deletions(-)
diff --git a/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td b/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td
index fb3f7b1632436..624a82762ab18 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td
@@ -204,7 +204,7 @@ def ConstArrayAttr : CIR_Attr<"ConstArray", "const_array", [TypedAttrInterface]>
}]>
];
- // Printing and parsing available in CIRDialect.cpp
+ // Printing and parsing available in CIRAttrs.cpp
let hasCustomAssemblyFormat = 1;
// Enable verifier.
@@ -215,6 +215,37 @@ def ConstArrayAttr : CIR_Attr<"ConstArray", "const_array", [TypedAttrInterface]>
}];
}
+//===----------------------------------------------------------------------===//
+// ConstVectorAttr
+//===----------------------------------------------------------------------===//
+
+def ConstVectorAttr : CIR_Attr<"ConstVector", "const_vector",
+                               [TypedAttrInterface]> {
+  let summary = "A constant vector from ArrayAttr";
+  let description = [{
+    A constant `!cir.vector` value. `elts` must contain exactly as many
+    elements as the vector size, each a typed attribute of the element type.
+  }];
+
+  let parameters = (ins AttributeSelfTypeParameter<"">:$type,
+                        "mlir::ArrayAttr":$elts);
+
+  // Custom builder that infers the MLIRContext from `type`, so callers only
+  // need to supply the vector type and the element array.
+  let builders = [
+    AttrBuilderWithInferredContext<(ins "cir::VectorType":$type,
+                                        "mlir::ArrayAttr":$elts), [{
+      return $_get(type.getContext(), type, elts);
+    }]>
+  ];
+
+  // Printing and parsing available in CIRAttrs.cpp
+  let hasCustomAssemblyFormat = 1;
+
+  // Enable verifier.
+  let genVerifyDecl = 1;
+}
+
//===----------------------------------------------------------------------===//
// ConstPtrAttr
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp b/clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp
index b9a74e90a5960..6e5c7b8fb51f8 100644
--- a/clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp
@@ -373,8 +373,27 @@ mlir::Attribute ConstantEmitter::tryEmitPrivate(const APValue &value,
elements, typedFiller);
}
case APValue::Vector: {
- cgm.errorNYI("ConstExprEmitter::tryEmitPrivate vector");
- return {};
+ const QualType elementType =
+ destType->castAs<VectorType>()->getElementType();
+ const unsigned numElements = value.getVectorLength();
+
+ SmallVector<mlir::Attribute, 16> elements;
+ elements.reserve(numElements);
+
+ for (unsigned i = 0; i < numElements; ++i) {
+ const mlir::Attribute element =
+ tryEmitPrivateForMemory(value.getVectorElt(i), elementType);
+ if (!element)
+ return {};
+ elements.push_back(element);
+ }
+
+ const auto desiredVecTy =
+ mlir::cast<cir::VectorType>(cgm.convertType(destType));
+
+ return cir::ConstVectorAttr::get(
+ desiredVecTy,
+ mlir::ArrayAttr::get(cgm.getBuilder().getContext(), elements));
}
case APValue::MemberPointer: {
cgm.errorNYI("ConstExprEmitter::tryEmitPrivate member pointer");
diff --git a/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp b/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp
index a940651f1e9eb..fff849b141562 100644
--- a/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp
@@ -299,6 +299,94 @@ void ConstArrayAttr::print(AsmPrinter &printer) const {
printer << ">";
}
+//===----------------------------------------------------------------------===//
+// CIR ConstVectorAttr
+//===----------------------------------------------------------------------===//
+
+/// Checks that `type` is a cir::VectorType, that `elts` holds exactly as many
+/// entries as the vector has elements, and that each entry is a TypedAttr
+/// whose type is the vector element type.
+LogicalResult cir::ConstVectorAttr::verify(
+    function_ref<::mlir::InFlightDiagnostic()> emitError, Type type,
+    ArrayAttr elts) {
+
+  const auto vecType = mlir::dyn_cast<cir::VectorType>(type);
+  if (!vecType) {
+    return emitError() << "type of cir::ConstVectorAttr is not a "
+                          "cir::VectorType: "
+                       << type;
+  }
+
+  if (vecType.getSize() != elts.size()) {
+    return emitError()
+           << "number of constant elements should match vector size";
+  }
+
+  // Every element must be a TypedAttr whose type is exactly the vector
+  // element type. Stop at the first offender so that exactly one diagnostic
+  // is emitted.
+  const Type elementType = vecType.getElementType();
+  LogicalResult elementTypeCheck = success();
+  for (const Attribute element : elts) {
+    const auto typedElement = mlir::dyn_cast<TypedAttr>(element);
+    if (typedElement && typedElement.getType() == elementType)
+      continue;
+
+    elementTypeCheck = failure();
+    emitError() << "constant type should match vector element type";
+    break;
+  }
+
+  return elementTypeCheck;
+}
+
+Attribute cir::ConstVectorAttr::parse(AsmParser &parser, Type type) {
+ FailureOr<Type> resultType;
+ FailureOr<ArrayAttr> resultValue;
+
+ const SMLoc loc = parser.getCurrentLocation();
+
+ // Parse literal '<'
+ if (parser.parseLess()) {
+ return {};
+ }
+
+ // Parse variable 'value'
+ resultValue = FieldParser<ArrayAttr>::parse(parser);
+ if (failed(resultValue)) {
+ parser.emitError(parser.getCurrentLocation(),
+ "failed to parse ConstVectorAttr parameter 'value' as "
+ "an attribute");
+ return {};
+ }
+
+ if (parser.parseOptionalColon().failed()) {
+ resultType = type;
+ } else {
+ resultType = ::mlir::FieldParser<Type>::parse(parser);
+ if (failed(resultType)) {
+ parser.emitError(parser.getCurrentLocation(),
+ "failed to parse ConstVectorAttr parameter 'type' as "
+ "an MLIR type");
+ return {};
+ }
+ }
+
+ // Parse literal '>'
+ if (parser.parseGreater()) {
+ return {};
+ }
+
+ return parser.getChecked<ConstVectorAttr>(
+ loc, parser.getContext(), resultType.value(), resultValue.value());
+}
+
+void cir::ConstVectorAttr::print(AsmPrinter &printer) const {
+ printer << "<";
+ printer.printStrippedAttrOrType(getElts());
+ printer << ">";
+}
+
//===----------------------------------------------------------------------===//
// CIR Dialect
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index 25993063ee7fd..c2cfa6fd81e9b 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -244,7 +244,7 @@ static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType,
return success();
}
- if (mlir::isa<cir::ConstArrayAttr>(attrType))
+ if (mlir::isa<cir::ConstArrayAttr, cir::ConstVectorAttr>(attrType))
return success();
assert(isa<TypedAttr>(attrType) && "What else could we be looking at here?");
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index 2c87255045df8..1cd82169a4cf8 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -188,8 +188,9 @@ class CIRAttrToValue {
mlir::Value visit(mlir::Attribute attr) {
return llvm::TypeSwitch<mlir::Attribute, mlir::Value>(attr)
- .Case<cir::IntAttr, cir::FPAttr, cir::ConstArrayAttr, cir::ConstPtrAttr,
- cir::ZeroAttr>([&](auto attrT) { return visitCirAttr(attrT); })
+ .Case<cir::IntAttr, cir::FPAttr, cir::ConstArrayAttr,
+ cir::ConstVectorAttr, cir::ConstPtrAttr, cir::ZeroAttr>(
+ [&](auto attrT) { return visitCirAttr(attrT); })
.Default([&](auto attrT) { return mlir::Value(); });
}
@@ -197,6 +198,7 @@ class CIRAttrToValue {
mlir::Value visitCirAttr(cir::FPAttr fltAttr);
mlir::Value visitCirAttr(cir::ConstPtrAttr ptrAttr);
mlir::Value visitCirAttr(cir::ConstArrayAttr attr);
+ mlir::Value visitCirAttr(cir::ConstVectorAttr attr);
mlir::Value visitCirAttr(cir::ZeroAttr attr);
private:
@@ -275,6 +277,33 @@ mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstArrayAttr attr) {
return result;
}
+/// ConstVectorAttr visitor: emits one LLVM::ConstantOp with a dense payload.
+mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstVectorAttr attr) {
+  const mlir::Type llvmTy = converter->convertType(attr.getType());
+  const mlir::Location loc = parentOp->getLoc();
+
+  // Map a CIR int/float constant onto the equivalent builtin attribute.
+  auto toBuiltinAttr = [&](mlir::Attribute elementAttr) -> mlir::Attribute {
+    if (auto intAttr = mlir::dyn_cast<cir::IntAttr>(elementAttr))
+      return rewriter.getIntegerAttr(converter->convertType(intAttr.getType()),
+                                     intAttr.getValue());
+    if (auto fpAttr = mlir::dyn_cast<cir::FPAttr>(elementAttr))
+      return rewriter.getFloatAttr(converter->convertType(fpAttr.getType()),
+                                   fpAttr.getValue());
+    llvm_unreachable("vector constant with an element that is neither an int "
+                     "nor a float");
+  };
+
+  SmallVector<mlir::Attribute> mlirValues;
+  mlirValues.reserve(attr.getElts().size());
+  for (const mlir::Attribute elementAttr : attr.getElts())
+    mlirValues.push_back(toBuiltinAttr(elementAttr));
+
+  const auto shapedTy = mlir::cast<mlir::ShapedType>(llvmTy);
+  return rewriter.create<mlir::LLVM::ConstantOp>(
+      loc, llvmTy, mlir::DenseElementsAttr::get(shapedTy, mlirValues));
+}
+
/// ZeroAttr visitor.
mlir::Value CIRAttrToValue::visitCirAttr(cir::ZeroAttr attr) {
mlir::Location loc = parentOp->getLoc();
@@ -888,7 +917,8 @@ CIRToLLVMGlobalOpLowering::matchAndRewriteRegionInitializedGlobal(
cir::GlobalOp op, mlir::Attribute init,
mlir::ConversionPatternRewriter &rewriter) const {
// TODO: Generalize this handling when more types are needed here.
- assert((isa<cir::ConstArrayAttr, cir::ConstPtrAttr, cir::ZeroAttr>(init)));
+ assert((isa<cir::ConstArrayAttr, cir::ConstVectorAttr, cir::ConstPtrAttr,
+ cir::ZeroAttr>(init)));
// TODO(cir): once LLVM's dialect has proper equivalent attributes this
// should be updated. For now, we use a custom op to initialize globals
@@ -941,8 +971,8 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite(
op.emitError() << "unsupported initializer '" << init.value() << "'";
return mlir::failure();
}
- } else if (mlir::isa<cir::ConstArrayAttr, cir::ConstPtrAttr, cir::ZeroAttr>(
- init.value())) {
+ } else if (mlir::isa<cir::ConstArrayAttr, cir::ConstVectorAttr,
+ cir::ConstPtrAttr, cir::ZeroAttr>(init.value())) {
// TODO(cir): once LLVM's dialect has proper equivalent attributes this
// should be updated. For now, we use a custom op to initialize globals
// to the appropriate value.
diff --git a/clang/test/CIR/CodeGen/vector-ext.cpp b/clang/test/CIR/CodeGen/vector-ext.cpp
index 13726edf3d259..7759a32fc1378 100644
--- a/clang/test/CIR/CodeGen/vector-ext.cpp
+++ b/clang/test/CIR/CodeGen/vector-ext.cpp
@@ -31,7 +31,7 @@ vi2 vec_c;
// OGCG: @[[VEC_C:.*]] = global <2 x i32> zeroinitializer
-vd2 d;
+vd2 vec_d;
// CIR: cir.global external @[[VEC_D:.*]] = #cir.zero : !cir.vector<2 x !cir.double>
@@ -39,6 +39,15 @@ vd2 d;
// OGCG: @[[VEC_D:.*]] = global <2 x double> zeroinitializer
+vi4 vec_e = { 1, 2, 3, 4 };
+
+// CIR: cir.global external @[[VEC_E:.*]] = #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<2> :
+// CIR-SAME: !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i]> : !cir.vector<4 x !s32i>
+
+// LLVM: @[[VEC_E:.*]] = dso_local global <4 x i32> <i32 1, i32 2, i32 3, i32 4>
+
+// OGCG: @[[VEC_E:.*]] = global <4 x i32> <i32 1, i32 2, i32 3, i32 4>
+
void foo() {
vi4 a;
vi3 b;
diff --git a/clang/test/CIR/CodeGen/vector.cpp b/clang/test/CIR/CodeGen/vector.cpp
index 8f9e98fb6b3c0..4c1850141a21c 100644
--- a/clang/test/CIR/CodeGen/vector.cpp
+++ b/clang/test/CIR/CodeGen/vector.cpp
@@ -30,6 +30,15 @@ vll2 c;
// OGCG: @[[VEC_C:.*]] = global <2 x i64> zeroinitializer
+vi4 d = { 1, 2, 3, 4 };
+
+// CIR: cir.global external @[[VEC_D:.*]] = #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<2> :
+// CIR-SAME: !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i]> : !cir.vector<4 x !s32i>
+
+// LLVM: @[[VEC_D:.*]] = dso_local global <4 x i32> <i32 1, i32 2, i32 3, i32 4>
+
+// OGCG: @[[VEC_D:.*]] = global <4 x i32> <i32 1, i32 2, i32 3, i32 4>
+
void vec_int_test() {
vi4 a;
vd2 b;
>From 581a8f8db0745e8ac92896d1f677ecfec9fb9da5 Mon Sep 17 00:00:00 2001
From: AmrDeveloper <amr96 at programmer.net>
Date: Mon, 28 Apr 2025 19:05:11 +0200
Subject: [PATCH 2/3] Add test for parsing const vector
---
clang/test/CIR/IR/vector.cir | 6 ++++++
1 file changed, 6 insertions(+)
diff --git a/clang/test/CIR/IR/vector.cir b/clang/test/CIR/IR/vector.cir
index 74ddf7691e7d4..bc70a8b55fa5c 100644
--- a/clang/test/CIR/IR/vector.cir
+++ b/clang/test/CIR/IR/vector.cir
@@ -13,6 +13,12 @@ cir.global external @vec_b = #cir.zero : !cir.vector<3 x !s32i>
cir.global external @vec_c = #cir.zero : !cir.vector<2 x !s32i>
// CHECK: cir.global external @vec_c = #cir.zero : !cir.vector<2 x !s32i>
+cir.global external @vec_d = #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<2>
+: !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i]> : !cir.vector<4 x !s32i>
+
+// CHECK: cir.global external @vec_d = #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<2> :
+// CHECK-SAME: !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i]> : !cir.vector<4 x !s32i>
+
cir.func @vec_int_test() {
%0 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a"]
%1 = cir.alloca !cir.vector<3 x !s32i>, !cir.ptr<!cir.vector<3 x !s32i>>, ["b"]
>From cd4f06123701160fcc117c7a3f0c99304191004d Mon Sep 17 00:00:00 2001
From: AmrDeveloper <amr96 at programmer.net>
Date: Tue, 29 Apr 2025 22:59:54 +0200
Subject: [PATCH 3/3] Use tablegen based asm fmt
---
.../include/clang/CIR/Dialect/IR/CIRAttrs.td | 5 +-
clang/lib/CIR/Dialect/IR/CIRAttrs.cpp | 47 -------------------
2 files changed, 3 insertions(+), 49 deletions(-)
diff --git a/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td b/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td
index 624a82762ab18..8152535930095 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td
@@ -239,8 +239,9 @@ def ConstVectorAttr : CIR_Attr<"ConstVector", "const_vector",
}]>
];
- // Printing and parsing available in CIRAttrs.cpp
- let hasCustomAssemblyFormat = 1;
+ let assemblyFormat = [{
+ `<` $elts `>`
+ }];
// Enable verifier.
let genVerifyDecl = 1;
diff --git a/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp b/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp
index fff849b141562..6f41cd4388ac2 100644
--- a/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp
@@ -340,53 +340,6 @@ LogicalResult cir::ConstVectorAttr::verify(
return elementTypeCheck;
}
-Attribute cir::ConstVectorAttr::parse(AsmParser &parser, Type type) {
- FailureOr<Type> resultType;
- FailureOr<ArrayAttr> resultValue;
-
- const SMLoc loc = parser.getCurrentLocation();
-
- // Parse literal '<'
- if (parser.parseLess()) {
- return {};
- }
-
- // Parse variable 'value'
- resultValue = FieldParser<ArrayAttr>::parse(parser);
- if (failed(resultValue)) {
- parser.emitError(parser.getCurrentLocation(),
- "failed to parse ConstVectorAttr parameter 'value' as "
- "an attribute");
- return {};
- }
-
- if (parser.parseOptionalColon().failed()) {
- resultType = type;
- } else {
- resultType = ::mlir::FieldParser<Type>::parse(parser);
- if (failed(resultType)) {
- parser.emitError(parser.getCurrentLocation(),
- "failed to parse ConstVectorAttr parameter 'type' as "
- "an MLIR type");
- return {};
- }
- }
-
- // Parse literal '>'
- if (parser.parseGreater()) {
- return {};
- }
-
- return parser.getChecked<ConstVectorAttr>(
- loc, parser.getContext(), resultType.value(), resultValue.value());
-}
-
-void cir::ConstVectorAttr::print(AsmPrinter &printer) const {
- printer << "<";
- printer.printStrippedAttrOrType(getElts());
- printer << ">";
-}
-
//===----------------------------------------------------------------------===//
// CIR Dialect
//===----------------------------------------------------------------------===//
More information about the cfe-commits
mailing list