[clang] [CIR] Upstream initial attribute support (PR #121069)

David Olsen via cfe-commits cfe-commits at lists.llvm.org
Thu Dec 26 15:55:42 PST 2024


https://github.com/dkolsen-pgi updated https://github.com/llvm/llvm-project/pull/121069

>From f81f3d0b52ee343eb26eb00f42de97f8792e9172 Mon Sep 17 00:00:00 2001
From: David Olsen <dolsen at nvidia.com>
Date: Tue, 24 Dec 2024 13:16:32 -0800
Subject: [PATCH 1/2] [CIR] Upstream initial attribute support

Upstream several ClangIR-specific MLIR attributes, in particular
attributes for integer, floating-point, and null pointer constants.
These are the first ClangIR attributes to be upstreamed, so
infrastructure changes are included, such as the table-gen file
`CIRAttrs.td`.

Attributes can be used as the initial values for global variables.  The
existing automated test global-var-simple.cpp is updated to give initial
values to some of the global variables in the test.
---
 .../CIR/Dialect/Builder/CIRBaseBuilder.h      |  11 ++
 clang/include/clang/CIR/Dialect/IR/CIRAttrs.h |  36 ++++
 .../include/clang/CIR/Dialect/IR/CIRAttrs.td  | 142 +++++++++++++++
 .../include/clang/CIR/Dialect/IR/CIRDialect.h |   1 +
 clang/include/clang/CIR/Dialect/IR/CIROps.td  |  54 +++++-
 .../include/clang/CIR/Dialect/IR/CIRTypes.td  |  12 +-
 .../clang/CIR/Dialect/IR/CMakeLists.txt       |   3 +
 clang/lib/CIR/CodeGen/CIRGenModule.cpp        |  42 +++++
 clang/lib/CIR/Dialect/IR/CIRAttrs.cpp         | 168 +++++++++++++++++-
 clang/lib/CIR/Dialect/IR/CIRDialect.cpp       | 107 ++++++++++-
 clang/lib/CIR/Dialect/IR/CMakeLists.txt       |   1 +
 clang/lib/CIR/Interfaces/CMakeLists.txt       |   1 +
 clang/test/CIR/global-var-simple.cpp          |  24 +--
 13 files changed, 579 insertions(+), 23 deletions(-)
 create mode 100644 clang/include/clang/CIR/Dialect/IR/CIRAttrs.h
 create mode 100644 clang/include/clang/CIR/Dialect/IR/CIRAttrs.td

diff --git a/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h b/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
index 0e414921324b7f..1b2cb81683f22c 100644
--- a/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
+++ b/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
@@ -9,7 +9,11 @@
 #ifndef LLVM_CLANG_CIR_DIALECT_BUILDER_CIRBASEBUILDER_H
 #define LLVM_CLANG_CIR_DIALECT_BUILDER_CIRBASEBUILDER_H
 
+#include "clang/CIR/Dialect/IR/CIRAttrs.h"
+
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Types.h"
 
 namespace cir {
 
@@ -26,6 +30,13 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
   cir::PointerType getVoidPtrTy() {
     return getPointerTo(cir::VoidType::get(getContext()));
   }
+
+  mlir::TypedAttr getConstPtrAttr(mlir::Type t, int64_t v) {
+    auto val =
+        mlir::IntegerAttr::get(mlir::IntegerType::get(t.getContext(), 64), v);
+    return cir::ConstPtrAttr::get(getContext(), mlir::cast<cir::PointerType>(t),
+                                  val);
+  }
 };
 
 } // namespace cir
diff --git a/clang/include/clang/CIR/Dialect/IR/CIRAttrs.h b/clang/include/clang/CIR/Dialect/IR/CIRAttrs.h
new file mode 100644
index 00000000000000..438fb7d09608db
--- /dev/null
+++ b/clang/include/clang/CIR/Dialect/IR/CIRAttrs.h
@@ -0,0 +1,36 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the attributes in the CIR dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_CIR_DIALECT_IR_CIRATTRS_H
+#define LLVM_CLANG_CIR_DIALECT_IR_CIRATTRS_H
+
+#include "clang/CIR/Dialect/IR/CIRTypes.h"
+
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinAttributeInterfaces.h"
+
+#include "llvm/ADT/SmallVector.h"
+
+//===----------------------------------------------------------------------===//
+// CIR Dialect Attrs
+//===----------------------------------------------------------------------===//
+
+namespace clang {
+class FunctionDecl;
+class VarDecl;
+class RecordDecl;
+} // namespace clang
+
+#define GET_ATTRDEF_CLASSES
+#include "clang/CIR/Dialect/IR/CIROpsAttributes.h.inc"
+
+#endif // LLVM_CLANG_CIR_DIALECT_IR_CIRATTRS_H
diff --git a/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td b/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td
new file mode 100644
index 00000000000000..bd1665e1ac1a06
--- /dev/null
+++ b/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td
@@ -0,0 +1,142 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the CIR dialect attributes.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_CIR_DIALECT_IR_CIRATTRS_TD
+#define LLVM_CLANG_CIR_DIALECT_IR_CIRATTRS_TD
+
+include "mlir/IR/BuiltinAttributeInterfaces.td"
+include "mlir/IR/EnumAttr.td"
+
+include "clang/CIR/Dialect/IR/CIRDialect.td"
+
+//===----------------------------------------------------------------------===//
+// CIR Attrs
+//===----------------------------------------------------------------------===//
+
+class CIR_Attr<string name, string attrMnemonic, list<Trait> traits = []>
+    : AttrDef<CIR_Dialect, name, traits> {
+  let mnemonic = attrMnemonic;
+}
+
+class CIRUnitAttr<string name, string attrMnemonic, list<Trait> traits = []>
+    : CIR_Attr<name, attrMnemonic, traits> {
+  let returnType = "bool";
+  let defaultValue = "false";
+  let valueType = NoneType;
+  let isOptional = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// IntegerAttr
+//===----------------------------------------------------------------------===//
+
+def IntAttr : CIR_Attr<"Int", "int", [TypedAttrInterface]> {
+  let summary = "An attribute containing an integer value";
+  let description = [{
+    An integer attribute is a literal attribute that represents an integral
+    value of the specified integer type.
+  }];
+  let parameters = (ins AttributeSelfTypeParameter<"">:$type,
+                        "llvm::APInt":$value);
+  let builders = [
+    AttrBuilderWithInferredContext<(ins "mlir::Type":$type,
+                                        "const llvm::APInt &":$value), [{
+      return $_get(type.getContext(), type, value);
+    }]>,
+    AttrBuilderWithInferredContext<(ins "mlir::Type":$type,
+                                        "int64_t":$value), [{
+      IntType intType = mlir::cast<IntType>(type);
+      mlir::APInt apValue(intType.getWidth(), value, intType.isSigned());
+      return $_get(intType.getContext(), intType, apValue);
+    }]>,
+  ];
+  let extraClassDeclaration = [{
+    int64_t getSInt() const { return getValue().getSExtValue(); }
+    uint64_t getUInt() const { return getValue().getZExtValue(); }
+    bool isNullValue() const { return getValue() == 0; }
+    uint64_t getBitWidth() const {
+      return mlir::cast<IntType>(getType()).getWidth();
+    }
+  }];
+  let genVerifyDecl = 1;
+  let hasCustomAssemblyFormat = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// FPAttr
+//===----------------------------------------------------------------------===//
+
+def FPAttr : CIR_Attr<"FP", "fp", [TypedAttrInterface]> {
+  let summary = "An attribute containing a floating-point value";
+  let description = [{
+    An fp attribute is a literal attribute that represents a floating-point
+    value of the specified floating-point type. Supporting only CIR FP types.
+  }];
+  let parameters = (ins
+    AttributeSelfTypeParameter<"", "::cir::CIRFPTypeInterface">:$type,
+    APFloatParameter<"">:$value
+  );
+  let builders = [
+    AttrBuilderWithInferredContext<(ins "mlir::Type":$type,
+                                        "const llvm::APFloat &":$value), [{
+      return $_get(type.getContext(), mlir::cast<CIRFPTypeInterface>(type),
+                   value);
+    }]>,
+    AttrBuilder<(ins "mlir::Type":$type,
+                     "const llvm::APFloat &":$value), [{
+      return $_get($_ctxt, mlir::cast<CIRFPTypeInterface>(type), value);
+    }]>,
+  ];
+  let extraClassDeclaration = [{
+    static FPAttr getZero(mlir::Type type);
+  }];
+  let genVerifyDecl = 1;
+
+  let assemblyFormat = [{
+    `<` custom<FloatLiteral>($value, ref($type)) `>`
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// ConstPtrAttr
+//===----------------------------------------------------------------------===//
+
+def ConstPtrAttr : CIR_Attr<"ConstPtr", "ptr", [TypedAttrInterface]> {
+  let summary = "Holds a constant pointer value";
+  let parameters = (ins
+    AttributeSelfTypeParameter<"", "::cir::PointerType">:$type,
+    "mlir::IntegerAttr":$value);
+  let description = [{
+    A pointer attribute is a literal attribute that represents an integral
+    value of a pointer type.
+  }];
+  let builders = [
+    AttrBuilderWithInferredContext<(ins "mlir::Type":$type,
+                                        "mlir::IntegerAttr":$value), [{
+      return $_get(type.getContext(), mlir::cast<cir::PointerType>(type),
+                   value);
+    }]>,
+    AttrBuilder<(ins "mlir::Type":$type,
+                     "mlir::IntegerAttr":$value), [{
+      return $_get($_ctxt, mlir::cast<cir::PointerType>(type), value);
+    }]>,
+  ];
+  let extraClassDeclaration = [{
+    bool isNullValue() const { return getValue().getInt() == 0; }
+  }];
+
+  let assemblyFormat = [{
+    `<` custom<ConstPtr>($value) `>`
+  }];
+}
+
+#endif // LLVM_CLANG_CIR_DIALECT_IR_CIRATTRS_TD
diff --git a/clang/include/clang/CIR/Dialect/IR/CIRDialect.h b/clang/include/clang/CIR/Dialect/IR/CIRDialect.h
index 0b71bdad29a3af..683176b139ca49 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIRDialect.h
+++ b/clang/include/clang/CIR/Dialect/IR/CIRDialect.h
@@ -26,6 +26,7 @@
 #include "mlir/Interfaces/MemorySlotInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
+#include "clang/CIR/Dialect/IR/CIRAttrs.h"
 #include "clang/CIR/Dialect/IR/CIROpsDialect.h.inc"
 
 // TableGen'erated files for MLIR dialects require that a macro be defined when
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 0d6c65ecf41029..b15e0415360ead 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -16,6 +16,7 @@
 
 include "clang/CIR/Dialect/IR/CIRDialect.td"
 include "clang/CIR/Dialect/IR/CIRTypes.td"
+include "clang/CIR/Dialect/IR/CIRAttrs.td"
 
 include "mlir/IR/BuiltinAttributeInterfaces.td"
 include "mlir/IR/EnumAttr.td"
@@ -75,6 +76,45 @@ class LLVMLoweringInfo {
 class CIR_Op<string mnemonic, list<Trait> traits = []> :
     Op<CIR_Dialect, mnemonic, traits>, LLVMLoweringInfo;
 
+//===----------------------------------------------------------------------===//
+// ConstantOp
+//===----------------------------------------------------------------------===//
+
+def ConstantOp : CIR_Op<"const",
+                        [ConstantLike, Pure, AllTypesMatch<["value", "res"]>]> {
+  let summary = "Defines a CIR constant";
+  let description = [{
+    The `cir.const` operation turns a literal into an SSA value. The data is
+    attached to the operation as an attribute.
+
+    ```mlir
+      %0 = cir.const 42 : i32
+      %1 = cir.const 4.2 : f32
+      %2 = cir.const nullptr : !cir.ptr<i32>
+    ```
+  }];
+
+  // The constant operation takes an attribute as the only input.
+  let arguments = (ins TypedAttrInterface:$value);
+
+  // The constant operation returns a single value of CIR_AnyType.
+  let results = (outs CIR_AnyType:$res);
+
+  let assemblyFormat = "attr-dict $value";
+
+  let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    bool isNullPtr() {
+      if (const auto ptrAttr = mlir::dyn_cast<cir::ConstPtrAttr>(getValue()))
+        return ptrAttr.isNullValue();
+      return false;
+    }
+  }];
+
+  let hasFolder = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // GlobalOp
 //===----------------------------------------------------------------------===//
@@ -92,9 +132,19 @@ def GlobalOp : CIR_Op<"global"> {
     described by the type of the variable.
   }];
 
-  let arguments = (ins SymbolNameAttr:$sym_name, TypeAttr:$sym_type);
+  let arguments = (ins SymbolNameAttr:$sym_name, TypeAttr:$sym_type,
+                       OptionalAttr<AnyAttr>:$initial_value);
+
+  let assemblyFormat = [{
+    $sym_name
+    custom<GlobalOpTypeAndInitialValue>($sym_type, $initial_value)
+    attr-dict
+  }];
 
-  let assemblyFormat = [{ $sym_name `:` $sym_type attr-dict }];
+  let extraClassDeclaration = [{
+    bool isDeclaration() { return !getInitialValue(); }
+    bool hasInitializer() { return !isDeclaration(); }
+  }];
 
   let skipDefaultBuilders = 1;
 
diff --git a/clang/include/clang/CIR/Dialect/IR/CIRTypes.td b/clang/include/clang/CIR/Dialect/IR/CIRTypes.td
index ef00b26c1fd98c..a32fb3c801114a 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIRTypes.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIRTypes.td
@@ -220,8 +220,8 @@ def CIR_LongDouble : CIR_FloatType<"LongDouble", "long_double"> {
 
 // Constraints
 
-def CIR_AnyFloat: AnyTypeOf<[CIR_Single, CIR_Double, CIR_FP80, CIR_FP128, CIR_LongDouble,
-    CIR_FP16, CIR_BFloat16]>;
+def CIR_AnyFloat: AnyTypeOf<[CIR_Single, CIR_Double, CIR_FP80, CIR_FP128,
+                             CIR_LongDouble, CIR_FP16, CIR_BFloat16]>;
 def CIR_AnyIntOrFloat: AnyTypeOf<[CIR_AnyFloat, CIR_IntType]>;
 
 //===----------------------------------------------------------------------===//
@@ -350,4 +350,12 @@ def VoidPtr : Type<
       "cir::VoidType::get($_builder.getContext()))"> {
 }
 
+//===----------------------------------------------------------------------===//
+// Global type constraints
+//===----------------------------------------------------------------------===//
+
+def CIR_AnyType : AnyTypeOf<[
+  CIR_VoidType, CIR_IntType, CIR_AnyFloat, CIR_PointerType, CIR_FuncType
+]>;
+
 #endif // MLIR_CIR_DIALECT_CIR_TYPES
diff --git a/clang/include/clang/CIR/Dialect/IR/CMakeLists.txt b/clang/include/clang/CIR/Dialect/IR/CMakeLists.txt
index 28ae30dab8dfb2..1fdbc24ba6b4a3 100644
--- a/clang/include/clang/CIR/Dialect/IR/CMakeLists.txt
+++ b/clang/include/clang/CIR/Dialect/IR/CMakeLists.txt
@@ -14,3 +14,6 @@ mlir_tablegen(CIROpsDialect.cpp.inc -gen-dialect-defs)
 add_public_tablegen_target(MLIRCIROpsIncGen)
 add_dependencies(mlir-headers MLIRCIROpsIncGen)
 
+mlir_tablegen(CIROpsAttributes.h.inc -gen-attrdef-decls)
+mlir_tablegen(CIROpsAttributes.cpp.inc -gen-attrdef-defs)
+add_public_tablegen_target(MLIRCIRAttrsEnumsGen)
diff --git a/clang/lib/CIR/CodeGen/CIRGenModule.cpp b/clang/lib/CIR/CodeGen/CIRGenModule.cpp
index 416d532028d090..2615ae382cb8b5 100644
--- a/clang/lib/CIR/CodeGen/CIRGenModule.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenModule.cpp
@@ -115,6 +115,48 @@ void CIRGenModule::emitGlobalVarDefinition(const clang::VarDecl *vd,
   if (clang::IdentifierInfo *identifier = vd->getIdentifier()) {
     auto varOp = builder.create<cir::GlobalOp>(getLoc(vd->getSourceRange()),
                                                identifier->getName(), type);
+    // TODO(CIR): This code for processing initial values is a placeholder
+    // until class ConstantEmitter is upstreamed and the code for processing
+    // constant expressions is filled out.  Only the most basic handling of
+    // certain constant expressions is implemented for now.
+    const VarDecl *initDecl;
+    const Expr *initExpr = vd->getAnyInitializer(initDecl);
+    if (initExpr) {
+      mlir::Attribute initializer;
+      if (APValue *value = initDecl->evaluateValue()) {
+        switch (value->getKind()) {
+        case APValue::Int: {
+          initializer = builder.getAttr<cir::IntAttr>(type, value->getInt());
+          break;
+        }
+        case APValue::Float: {
+          initializer = builder.getAttr<cir::FPAttr>(type, value->getFloat());
+          break;
+        }
+        case APValue::LValue: {
+          if (value->getLValueBase()) {
+            errorNYI(initExpr->getSourceRange(),
+                     "non-null pointer initialization");
+          } else {
+            if (auto ptrType = mlir::dyn_cast<cir::PointerType>(type)) {
+              initializer = builder.getConstPtrAttr(
+                  ptrType, value->getLValueOffset().getQuantity());
+            } else {
+              llvm_unreachable(
+                  "non-pointer variable initialized with a pointer");
+            }
+          }
+          break;
+        }
+        default:
+          errorNYI(initExpr->getSourceRange(), "unsupported initializer kind");
+          break;
+        }
+      } else {
+        errorNYI(initExpr->getSourceRange(), "non-constant initializer");
+      }
+      varOp.setInitialValueAttr(initializer);
+    }
     theModule.push_back(varOp);
   } else {
     errorNYI(vd->getSourceRange().getBegin(),
diff --git a/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp b/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp
index 7d42da1ab20d76..11dd927f79d789 100644
--- a/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp
@@ -12,6 +12,24 @@
 
 #include "clang/CIR/Dialect/IR/CIRDialect.h"
 
+#include "mlir/IR/DialectImplementation.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+static void printFloatLiteral(mlir::AsmPrinter &p, llvm::APFloat value,
+                              mlir::Type ty);
+static mlir::ParseResult
+parseFloatLiteral(mlir::AsmParser &parser,
+                  mlir::FailureOr<llvm::APFloat> &value,
+                  cir::CIRFPTypeInterface fpType);
+
+static mlir::ParseResult parseConstPtr(mlir::AsmParser &parser,
+                                       mlir::IntegerAttr &value);
+
+static void printConstPtr(mlir::AsmPrinter &p, mlir::IntegerAttr value);
+
+#define GET_ATTRDEF_CLASSES
+#include "clang/CIR/Dialect/IR/CIROpsAttributes.cpp.inc"
+
 using namespace mlir;
 using namespace cir;
 
@@ -21,12 +39,151 @@ using namespace cir;
 
 Attribute CIRDialect::parseAttribute(DialectAsmParser &parser,
                                      Type type) const {
-  // No attributes yet to parse
-  return Attribute{};
+  llvm::SMLoc typeLoc = parser.getCurrentLocation();
+  llvm::StringRef mnemonic;
+  Attribute genAttr;
+  OptionalParseResult parseResult =
+      generatedAttributeParser(parser, &mnemonic, type, genAttr);
+  if (parseResult.has_value())
+    return genAttr;
+  parser.emitError(typeLoc, "unknown attribute in CIR dialect");
+  return Attribute();
 }
 
 void CIRDialect::printAttribute(Attribute attr, DialectAsmPrinter &os) const {
-  // No attributes yet to print
+  if (failed(generatedAttributePrinter(attr, os)))
+    llvm_unreachable("unexpected CIR attribute kind");
+}
+
+//===----------------------------------------------------------------------===//
+// ConstPtrAttr definitions
+//===----------------------------------------------------------------------===//
+
+// TODO(CIR): Consider encoding the null value differently and use conditional
+// assembly format instead of custom parsing/printing.
+static ParseResult parseConstPtr(AsmParser &parser, mlir::IntegerAttr &value) {
+
+  if (parser.parseOptionalKeyword("null").succeeded()) {
+    value = mlir::IntegerAttr::get(
+        mlir::IntegerType::get(parser.getContext(), 64), 0);
+    return success();
+  }
+
+  return parser.parseAttribute(value);
+}
+
+static void printConstPtr(AsmPrinter &p, mlir::IntegerAttr value) {
+  if (!value.getInt())
+    p << "null";
+  else
+    p << value;
+}
+
+//===----------------------------------------------------------------------===//
+// IntAttr definitions
+//===----------------------------------------------------------------------===//
+
+Attribute IntAttr::parse(AsmParser &parser, Type odsType) {
+  mlir::APInt apValue;
+
+  if (!mlir::isa<IntType>(odsType))
+    return {};
+  auto type = mlir::cast<IntType>(odsType);
+
+  // Consume the '<' symbol.
+  if (parser.parseLess())
+    return {};
+
+  // Fetch arbitrary precision integer value.
+  if (type.isSigned()) {
+    int64_t value;
+    if (parser.parseInteger(value))
+      parser.emitError(parser.getCurrentLocation(), "expected integer value");
+    apValue = mlir::APInt(type.getWidth(), value, type.isSigned(),
+                          /*implicitTrunc=*/true);
+    if (apValue.getSExtValue() != value)
+      parser.emitError(parser.getCurrentLocation(),
+                       "integer value too large for the given type");
+  } else {
+    uint64_t value;
+    if (parser.parseInteger(value))
+      parser.emitError(parser.getCurrentLocation(), "expected integer value");
+    apValue = mlir::APInt(type.getWidth(), value, type.isSigned(),
+                          /*implicitTrunc=*/true);
+    if (apValue.getZExtValue() != value)
+      parser.emitError(parser.getCurrentLocation(),
+                       "integer value too large for the given type");
+  }
+
+  // Consume the '>' symbol.
+  if (parser.parseGreater())
+    return {};
+
+  return IntAttr::get(type, apValue);
+}
+
+void IntAttr::print(AsmPrinter &printer) const {
+  auto type = mlir::cast<IntType>(getType());
+  printer << '<';
+  if (type.isSigned())
+    printer << getSInt();
+  else
+    printer << getUInt();
+  printer << '>';
+}
+
+LogicalResult IntAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+                              Type type, APInt value) {
+  auto intType = mlir::dyn_cast<IntType>(type);
+  if (!intType) {
+    emitError() << "expected 'cir.int' type";
+    return failure();
+  }
+  // The APInt payload must be exactly as wide as the CIR integer type.
+  if (value.getBitWidth() != intType.getWidth()) {
+    emitError() << "type and value bitwidth mismatch: " << intType.getWidth()
+                << " != " << value.getBitWidth();
+    return failure();
+  }
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// FPAttr definitions
+//===----------------------------------------------------------------------===//
+
+static void printFloatLiteral(AsmPrinter &p, APFloat value, Type ty) {
+  p << value;
+}
+
+static ParseResult parseFloatLiteral(AsmParser &parser,
+                                     FailureOr<APFloat> &value,
+                                     CIRFPTypeInterface fpType) {
+
+  APFloat parsedValue(0.0);
+  if (parser.parseFloat(fpType.getFloatSemantics(), parsedValue))
+    return failure();
+
+  value.emplace(parsedValue);
+  return success();
+}
+
+FPAttr FPAttr::getZero(Type type) {
+  return get(type,
+             APFloat::getZero(
+                 mlir::cast<CIRFPTypeInterface>(type).getFloatSemantics()));
+}
+
+LogicalResult FPAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+                             CIRFPTypeInterface fpType, APFloat value) {
+  if (APFloat::SemanticsToEnum(fpType.getFloatSemantics()) !=
+      APFloat::SemanticsToEnum(value.getSemantics())) {
+    emitError() << "floating-point semantics mismatch";
+    return failure();
+  }
+
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
@@ -34,5 +191,8 @@ void CIRDialect::printAttribute(Attribute attr, DialectAsmPrinter &os) const {
 //===----------------------------------------------------------------------===//
 
 void CIRDialect::registerAttributes() {
-  // No attributes yet to register
+  addAttributes<
+#define GET_ATTRDEF_LIST
+#include "clang/CIR/Dialect/IR/CIROpsAttributes.cpp.inc"
+      >();
 }
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index dbdca1f8401663..f98d8b60f6ff87 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -12,6 +12,8 @@
 
 #include "clang/CIR/Dialect/IR/CIRDialect.h"
 
+#include "clang/CIR/Dialect/IR/CIRTypes.h"
+
 #include "mlir/Support/LogicalResult.h"
 
 #include "clang/CIR/Dialect/IR/CIROpsDialect.cpp.inc"
@@ -32,13 +34,73 @@ void cir::CIRDialect::initialize() {
       >();
 }
 
+//===----------------------------------------------------------------------===//
+// ConstantOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType,
+                                        mlir::Attribute attrType) {
+  if (isa<cir::ConstPtrAttr>(attrType)) {
+    if (!mlir::isa<cir::PointerType>(opType))
+      return op->emitOpError(
+          "pointer constant initializing a non-pointer type");
+    return success();
+  }
+
+  if (mlir::isa<cir::IntAttr, cir::FPAttr>(attrType)) {
+    auto at = cast<TypedAttr>(attrType);
+    if (at.getType() != opType)
+      return op->emitOpError("result type (")
+             << opType << ") does not match value type (" << at.getType()
+             << ")";
+    return success();
+  }
+  // Untyped attrs (GlobalOp takes AnyAttr) are diagnosed, not asserted.
+  if (auto typedAttr = mlir::dyn_cast<TypedAttr>(attrType))
+    return op->emitOpError("constant with type ")
+           << typedAttr.getType() << " not yet supported";
+  return op->emitOpError("unsupported constant attribute ") << attrType;
+}
+
+LogicalResult cir::ConstantOp::verify() {
+  // ODS already generates checks to make sure the result type is valid. We just
+  // need to additionally check that the value's attribute type is consistent
+  // with the result type.
+  return checkConstantTypes(getOperation(), getType(), getValue());
+}
+
+OpFoldResult cir::ConstantOp::fold(FoldAdaptor /*adaptor*/) {
+  return getValue();
+}
+
 //===----------------------------------------------------------------------===//
 // GlobalOp
 //===----------------------------------------------------------------------===//
 
-// TODO(CIR): The properties of global variables that require verification
-// haven't been implemented yet.
-mlir::LogicalResult cir::GlobalOp::verify() { return success(); }
+static ParseResult parseConstantValue(OpAsmParser &parser,
+                                      mlir::Attribute &valueAttr) {
+  NamedAttrList attr;
+  return parser.parseAttribute(valueAttr, "value", attr);
+}
+
+static void printConstant(OpAsmPrinter &p, Attribute value) {
+  p.printAttribute(value);
+}
+
+mlir::LogicalResult cir::GlobalOp::verify() {
+  // Verify that the initial value, if present, is an attribute CIR supports
+  // and that it is consistent with the global's declared type.
+  if (getInitialValue().has_value()) {
+    if (checkConstantTypes(getOperation(), getSymType(), *getInitialValue())
+            .failed())
+      return failure();
+  }
+
+  // TODO(CIR): Many other checks for properties that haven't been upstreamed
+  // yet.
+
+  return success();
+}
 
 void cir::GlobalOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                           llvm::StringRef sym_name, mlir::Type sym_type) {
@@ -48,6 +110,45 @@ void cir::GlobalOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                         mlir::TypeAttr::get(sym_type));
 }
 
+static void printGlobalOpTypeAndInitialValue(OpAsmPrinter &p, cir::GlobalOp op,
+                                             TypeAttr type,
+                                             Attribute initAttr) {
+  if (!op.isDeclaration()) {
+    p << "= ";
+    // This also prints the type...
+    if (initAttr)
+      printConstant(p, initAttr);
+  } else {
+    p << ": " << type;
+  }
+}
+
+static ParseResult
+parseGlobalOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
+                                 Attribute &initialValueAttr) {
+  mlir::Type opTy;
+  if (parser.parseOptionalEqual().failed()) {
+    // Absence of equal means a declaration, so we need to parse the type.
+    //  cir.global @a : !cir.int<s, 32>
+    if (parser.parseColonType(opTy))
+      return failure();
+  } else {
+    // Parse constant with initializer; the attribute carries the type, e.g.
+    //  cir.global @y = #cir.fp<1.250000e+00> : !cir.double
+    llvm::SMLoc attrLoc = parser.getCurrentLocation();
+    if (parseConstantValue(parser, initialValueAttr).failed())
+      return failure();
+
+    auto typedAttr = mlir::dyn_cast<mlir::TypedAttr>(initialValueAttr);
+    if (!typedAttr)
+      return parser.emitError(attrLoc, "expected a typed initial value");
+    opTy = typedAttr.getType();
+  }
+
+  typeAttr = TypeAttr::get(opTy);
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // FuncOp
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Dialect/IR/CMakeLists.txt b/clang/lib/CIR/Dialect/IR/CMakeLists.txt
index df60f69df6fc0e..baf8bff1852212 100644
--- a/clang/lib/CIR/Dialect/IR/CMakeLists.txt
+++ b/clang/lib/CIR/Dialect/IR/CMakeLists.txt
@@ -5,6 +5,7 @@ add_clang_library(MLIRCIR
 
   DEPENDS
   MLIRCIROpsIncGen
+  MLIRCIRAttrsEnumsGen
 
   LINK_LIBS PUBLIC
   MLIRIR
diff --git a/clang/lib/CIR/Interfaces/CMakeLists.txt b/clang/lib/CIR/Interfaces/CMakeLists.txt
index fcd8b6963d06c2..b826bf612cc356 100644
--- a/clang/lib/CIR/Interfaces/CMakeLists.txt
+++ b/clang/lib/CIR/Interfaces/CMakeLists.txt
@@ -5,6 +5,7 @@ add_clang_library(MLIRCIRInterfaces
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Interfaces
 
   DEPENDS
+  MLIRCIRAttrsEnumsGen
   MLIRCIRFPTypeInterfaceIncGen
 
   LINK_LIBS
diff --git a/clang/test/CIR/global-var-simple.cpp b/clang/test/CIR/global-var-simple.cpp
index bbd452655a01bf..ffcc3ef71a6c74 100644
--- a/clang/test/CIR/global-var-simple.cpp
+++ b/clang/test/CIR/global-var-simple.cpp
@@ -13,11 +13,11 @@ unsigned char uc;
 short ss;
 // CHECK: cir.global @ss : !cir.int<s, 16>
 
-unsigned short us;
-// CHECK: cir.global @us : !cir.int<u, 16>
+unsigned short us = 100;
+// CHECK: cir.global @us = #cir.int<100> : !cir.int<u, 16>
 
-int si;
-// CHECK: cir.global @si : !cir.int<s, 32>
+int si = 42;
+// CHECK: cir.global @si = #cir.int<42> : !cir.int<s, 32>
 
 unsigned ui;
 // CHECK: cir.global @ui : !cir.int<u, 32>
@@ -31,8 +31,8 @@ unsigned long ul;
 long long sll;
 // CHECK: cir.global @sll : !cir.int<s, 64>
 
-unsigned long long ull;
-// CHECK: cir.global @ull : !cir.int<u, 64>
+unsigned long long ull = 123456;
+// CHECK: cir.global @ull = #cir.int<123456> : !cir.int<u, 64>
 
 __int128 s128;
 // CHECK: cir.global @s128 : !cir.int<s, 128>
@@ -67,8 +67,8 @@ __bf16 bf16;
 float f;
 // CHECK: cir.global @f : !cir.float
 
-double d;
-// CHECK: cir.global @d : !cir.double
+double d = 1.25;
+// CHECK: cir.global @d = #cir.fp<1.250000e+00> : !cir.double
 
 long double ld;
 // CHECK: cir.global @ld : !cir.long_double<!cir.f80>
@@ -79,8 +79,8 @@ __float128 f128;
 void *vp;
 // CHECK: cir.global @vp : !cir.ptr<!cir.void>
 
-int *ip;
-// CHECK: cir.global @ip : !cir.ptr<!cir.int<s, 32>>
+int *ip = 0;
+// CHECK: cir.global @ip = #cir.ptr<null> : !cir.ptr<!cir.int<s, 32>>
 
 double *dp;
 // CHECK: cir.global @dp : !cir.ptr<!cir.double>
@@ -91,8 +91,8 @@ char **cpp;
 void (*fp)();
 // CHECK: cir.global @fp : !cir.ptr<!cir.func<!cir.void ()>>
 
-int (*fpii)(int);
-// CHECK: cir.global @fpii : !cir.ptr<!cir.func<!cir.int<s, 32> (!cir.int<s, 32>)>>
+int (*fpii)(int) = 0;
+// CHECK: cir.global @fpii = #cir.ptr<null> : !cir.ptr<!cir.func<!cir.int<s, 32> (!cir.int<s, 32>)>>
 
 void (*fpvar)(int, ...);
 // CHECK: cir.global @fpvar : !cir.ptr<!cir.func<!cir.void (!cir.int<s, 32>, ...)>>

>From 01ae70bd2ee7de814be342cc08d3c69f61d5c66d Mon Sep 17 00:00:00 2001
From: David Olsen <dolsen at nvidia.com>
Date: Thu, 26 Dec 2024 15:54:08 -0800
Subject: [PATCH 2/2] [CIR] Upstream attributes followup

Better parameter names for `CIRBaseBuilderTy::getConstPtrAttr`

Slightly better error handling for `IntAttr::parse`
---
 .../CIR/Dialect/Builder/CIRBaseBuilder.h      | 10 +++---
 clang/lib/CIR/Dialect/IR/CIRAttrs.cpp         | 32 +++++++++++--------
 2 files changed, 23 insertions(+), 19 deletions(-)

diff --git a/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h b/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
index 1b2cb81683f22c..b4a961de224aa0 100644
--- a/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
+++ b/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
@@ -31,11 +31,11 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
     return getPointerTo(cir::VoidType::get(getContext()));
   }
 
-  mlir::TypedAttr getConstPtrAttr(mlir::Type t, int64_t v) {
-    auto val =
-        mlir::IntegerAttr::get(mlir::IntegerType::get(t.getContext(), 64), v);
-    return cir::ConstPtrAttr::get(getContext(), mlir::cast<cir::PointerType>(t),
-                                  val);
+  mlir::TypedAttr getConstPtrAttr(mlir::Type type, int64_t value) {
+    auto valueAttr = mlir::IntegerAttr::get(
+        mlir::IntegerType::get(type.getContext(), 64), value);
+    return cir::ConstPtrAttr::get(
+        getContext(), mlir::cast<cir::PointerType>(type), valueAttr);
   }
 };
 
diff --git a/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp b/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp
index 11dd927f79d789..8e8f7d5b7d7cb4 100644
--- a/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp
@@ -96,23 +96,27 @@ Attribute IntAttr::parse(AsmParser &parser, Type odsType) {
 
   // Fetch arbitrary precision integer value.
   if (type.isSigned()) {
-    int64_t value;
-    if (parser.parseInteger(value))
+    int64_t value = 0;
+    if (parser.parseInteger(value)) {
       parser.emitError(parser.getCurrentLocation(), "expected integer value");
-    apValue = mlir::APInt(type.getWidth(), value, type.isSigned(),
-                          /*implicitTrunc=*/true);
-    if (apValue.getSExtValue() != value)
-      parser.emitError(parser.getCurrentLocation(),
-                       "integer value too large for the given type");
+    } else {
+      apValue = mlir::APInt(type.getWidth(), value, type.isSigned(),
+                            /*implicitTrunc=*/true);
+      if (apValue.getSExtValue() != value)
+        parser.emitError(parser.getCurrentLocation(),
+                         "integer value too large for the given type");
+    }
   } else {
-    uint64_t value;
-    if (parser.parseInteger(value))
+    uint64_t value = 0;
+    if (parser.parseInteger(value)) {
       parser.emitError(parser.getCurrentLocation(), "expected integer value");
-    apValue = mlir::APInt(type.getWidth(), value, type.isSigned(),
-                          /*implicitTrunc=*/true);
-    if (apValue.getZExtValue() != value)
-      parser.emitError(parser.getCurrentLocation(),
-                       "integer value too large for the given type");
+    } else {
+      apValue = mlir::APInt(type.getWidth(), value, type.isSigned(),
+                            /*implicitTrunc=*/true);
+      if (apValue.getZExtValue() != value)
+        parser.emitError(parser.getCurrentLocation(),
+                         "integer value too large for the given type");
+    }
   }
 
   // Consume the '>' symbol.



More information about the cfe-commits mailing list