[Mlir-commits] [mlir] c7cae0e - [mlir][Attributes][NFC] Move all builtin Attribute classes to BuiltinAttributes.h
River Riddle
llvmlistbot at llvm.org
Thu Dec 3 18:02:55 PST 2020
Author: River Riddle
Date: 2020-12-03T18:02:11-08:00
New Revision: c7cae0e4fa4e1ed4bdca186096a408578225fc2b
URL: https://github.com/llvm/llvm-project/commit/c7cae0e4fa4e1ed4bdca186096a408578225fc2b
DIFF: https://github.com/llvm/llvm-project/commit/c7cae0e4fa4e1ed4bdca186096a408578225fc2b.diff
LOG: [mlir][Attributes][NFC] Move all builtin Attribute classes to BuiltinAttributes.h
This mirrors the file structure of Types.
Differential Revision: https://reviews.llvm.org/D92499
Added:
mlir/include/mlir-c/BuiltinAttributes.h
mlir/include/mlir/IR/BuiltinAttributes.h
mlir/lib/CAPI/IR/BuiltinAttributes.cpp
mlir/lib/IR/BuiltinAttributes.cpp
Modified:
flang/include/flang/Optimizer/Dialect/FIRType.h
mlir/docs/CAPI.md
mlir/docs/LangRef.md
mlir/include/mlir/Bindings/Python/Attributes.td
mlir/include/mlir/Dialect/CommonFolders.h
mlir/include/mlir/Dialect/GPU/ParallelLoopMapper.h
mlir/include/mlir/Dialect/SPIRV/SPIRVAttributes.h
mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
mlir/include/mlir/IR/Attributes.h
mlir/include/mlir/IR/Operation.h
mlir/include/mlir/IR/OperationSupport.h
mlir/include/mlir/Interfaces/DecodeAttributesInterfaces.h
mlir/include/mlir/Transforms/LoopUtils.h
mlir/lib/Bindings/Python/IRModules.cpp
mlir/lib/CAPI/IR/CMakeLists.txt
mlir/lib/IR/AffineMap.cpp
mlir/lib/IR/AttributeDetail.h
mlir/lib/IR/Attributes.cpp
mlir/lib/IR/Builders.cpp
mlir/lib/IR/BuiltinTypes.cpp
mlir/lib/IR/CMakeLists.txt
mlir/test/CAPI/ir.c
mlir/unittests/IR/AttributeTest.cpp
Removed:
mlir/include/mlir-c/StandardAttributes.h
mlir/lib/CAPI/IR/StandardAttributes.cpp
################################################################################
diff --git a/flang/include/flang/Optimizer/Dialect/FIRType.h b/flang/include/flang/Optimizer/Dialect/FIRType.h
index 6d2aec25fa8f..6484ca5156bd 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRType.h
+++ b/flang/include/flang/Optimizer/Dialect/FIRType.h
@@ -9,8 +9,8 @@
#ifndef OPTIMIZER_DIALECT_FIRTYPE_H
#define OPTIMIZER_DIALECT_FIRTYPE_H
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Types.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/SmallVector.h"
namespace llvm {
diff --git a/mlir/docs/CAPI.md b/mlir/docs/CAPI.md
index e71dee091774..a292c1fc1833 100644
--- a/mlir/docs/CAPI.md
+++ b/mlir/docs/CAPI.md
@@ -185,12 +185,12 @@ for (iter = mlirXGetFirst<Y>(x); !mlirYIsNull(iter);
### Extensions for Dialect Attributes and Types
-Dialect attributes and types can follow the example of standard attributes and
+Dialect attributes and types can follow the example of builtin attributes and
types, provided that implementations live in separate directories, i.e.
`include/mlir-c/<...>Dialect/` and `lib/CAPI/<...>Dialect/`. The core APIs
provide implementation-private headers in `include/mlir/CAPI/IR` that allow one
to convert between opaque C structures for core IR components and their C++
counterparts. `wrap` converts a C++ class into a C structure and `unwrap` does
-the inverse conversion. Once the C++ object is available, the API
-implementation should rely on `isa` to implement `mlirXIsAY` and is expected to
-use `cast` inside other API calls.
+the inverse conversion. Once the C++ object is available, the API implementation
+should rely on `isa` to implement `mlirXIsAY` and is expected to use `cast`
+inside other API calls.
diff --git a/mlir/docs/LangRef.md b/mlir/docs/LangRef.md
index d32b6148030a..28962a8893ee 100644
--- a/mlir/docs/LangRef.md
+++ b/mlir/docs/LangRef.md
@@ -1337,7 +1337,7 @@ dialect and not the function argument.
Attribute values are represented by the following forms:
```
-attribute-value ::= attribute-alias | dialect-attribute | standard-attribute
+attribute-value ::= attribute-alias | dialect-attribute | builtin-attribute
```
### Attribute Value Aliases
@@ -1404,14 +1404,14 @@ characters that are not allowed in the lighter syntax, as well as unbalanced
See [here](Tutorials/DefiningAttributesAndTypes.md) on how to define dialect
attribute values.
-### Standard Attribute Values
+### Builtin Attribute Values
-Standard attributes are a core set of
+Builtin attributes are a core set of
[dialect attributes](#dialect-attribute-values) that are defined in a builtin
dialect and thus available to all users of MLIR.
```
-standard-attribute ::= affine-map-attribute
+builtin-attribute ::= affine-map-attribute
| array-attribute
| bool-attribute
| dictionary-attribute
diff --git a/mlir/include/mlir-c/StandardAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h
similarity index 97%
rename from mlir/include/mlir-c/StandardAttributes.h
rename to mlir/include/mlir-c/BuiltinAttributes.h
index 39879d8e31ed..7c280f67e791 100644
--- a/mlir/include/mlir-c/StandardAttributes.h
+++ b/mlir/include/mlir-c/BuiltinAttributes.h
@@ -1,4 +1,4 @@
-//===-- mlir-c/StandardAttributes.h - C API for Std Attributes-----*- C -*-===//
+//===-- mlir-c/BuiltinAttributes.h - C API for Builtin Attributes -*- C -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions.
@@ -7,12 +7,12 @@
//
//===----------------------------------------------------------------------===//
//
-// This header declares the C interface to MLIR Standard attributes.
+// This header declares the C interface to MLIR Builtin attributes.
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_C_STANDARDATTRIBUTES_H
-#define MLIR_C_STANDARDATTRIBUTES_H
+#ifndef MLIR_C_BUILTINATTRIBUTES_H
+#define MLIR_C_BUILTINATTRIBUTES_H
#include "mlir-c/AffineMap.h"
#include "mlir-c/IR.h"
@@ -45,9 +45,8 @@ MLIR_CAPI_EXPORTED bool mlirAttributeIsAArray(MlirAttribute attr);
/** Creates an array element containing the given list of elements in the given
* context. */
-MLIR_CAPI_EXPORTED MlirAttribute mlirArrayAttrGet(MlirContext ctx,
- intptr_t numElements,
- MlirAttribute const *elements);
+MLIR_CAPI_EXPORTED MlirAttribute mlirArrayAttrGet(
+ MlirContext ctx, intptr_t numElements, MlirAttribute const *elements);
/// Returns the number of elements stored in the given array attribute.
MLIR_CAPI_EXPORTED intptr_t mlirArrayAttrGetNumElements(MlirAttribute attr);
@@ -437,4 +436,4 @@ mlirSparseElementsAttrGetValues(MlirAttribute attr);
}
#endif
-#endif // MLIR_C_STANDARDATTRIBUTES_H
+#endif // MLIR_C_BUILTINATTRIBUTES_H
diff --git a/mlir/include/mlir/Bindings/Python/Attributes.td b/mlir/include/mlir/Bindings/Python/Attributes.td
index 0ed155035a99..4fe746d2a737 100644
--- a/mlir/include/mlir/Bindings/Python/Attributes.td
+++ b/mlir/include/mlir/Bindings/Python/Attributes.td
@@ -15,13 +15,13 @@
#define PYTHON_BINDINGS_ATTRIBUTES
// A mapping between the attribute storage type and the corresponding Python
-// type. There is not necessarily a 1-1 match for non-standard attributes.
+// type. There is not necessarily a 1-1 match for non-builtin attributes.
class PythonAttr<string c, string p> {
string cppStorageType = c;
string pythonType = p;
}
-// Mappings between supported standard attribtues and Python types.
+// Mappings between supported builtin attribtues and Python types.
def : PythonAttr<"::mlir::Attribute", "_ir.Attribute">;
def : PythonAttr<"::mlir::BoolAttr", "_ir.BoolAttr">;
def : PythonAttr<"::mlir::IntegerAttr", "_ir.IntegerAttr">;
diff --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h
index 574aab190302..fb5991fa72af 100644
--- a/mlir/include/mlir/Dialect/CommonFolders.h
+++ b/mlir/include/mlir/Dialect/CommonFolders.h
@@ -15,7 +15,7 @@
#ifndef MLIR_DIALECT_COMMONFOLDERS_H
#define MLIR_DIALECT_COMMONFOLDERS_H
-#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
diff --git a/mlir/include/mlir/Dialect/GPU/ParallelLoopMapper.h b/mlir/include/mlir/Dialect/GPU/ParallelLoopMapper.h
index 8bce2fd0ad2b..ad8982546cab 100644
--- a/mlir/include/mlir/Dialect/GPU/ParallelLoopMapper.h
+++ b/mlir/include/mlir/Dialect/GPU/ParallelLoopMapper.h
@@ -14,7 +14,7 @@
#ifndef MLIR_DIALECT_GPU_PARALLELLOOPMAPPER_H
#define MLIR_DIALECT_GPU_PARALLELLOOPMAPPER_H
-#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseMap.h"
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVAttributes.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVAttributes.h
index 5cd6b033d058..b948df39b3fc 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVAttributes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVAttributes.h
@@ -14,7 +14,7 @@
#define MLIR_DIALECT_SPIRV_SPIRVATTRIBUTES_H
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
-#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Support/LLVM.h"
// Pull in SPIR-V attribute definitions for target and ABI.
diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
index 0a55b68b350e..794417e99652 100644
--- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
@@ -18,7 +18,7 @@
#define MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H
#include "mlir/IR/AffineMap.h"
-#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/StringRef.h"
diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index 5464a1f1c80e..04e2dccc6641 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -10,47 +10,16 @@
#define MLIR_IR_ATTRIBUTES_H
#include "mlir/IR/AttributeSupport.h"
-#include "llvm/ADT/APFloat.h"
-#include "llvm/ADT/Sequence.h"
#include "llvm/Support/PointerLikeTypeTraits.h"
-#include <complex>
namespace mlir {
-class AffineMap;
-class Dialect;
-class FunctionType;
class Identifier;
-class IntegerSet;
-class Location;
-class MLIRContext;
-class ShapedType;
-class Type;
-namespace detail {
-
-struct AffineMapAttributeStorage;
-struct ArrayAttributeStorage;
-struct DictionaryAttributeStorage;
-struct IntegerAttributeStorage;
-struct IntegerSetAttributeStorage;
-struct FloatAttributeStorage;
-struct OpaqueAttributeStorage;
-struct StringAttributeStorage;
-struct SymbolRefAttributeStorage;
-struct TypeAttributeStorage;
-
-/// Elements Attributes.
-struct DenseIntOrFPElementsAttributeStorage;
-struct DenseStringElementsAttributeStorage;
-struct OpaqueElementsAttributeStorage;
-struct SparseElementsAttributeStorage;
-} // namespace detail
-
-/// Attributes are known-constant values of operations and functions.
+/// Attributes are known-constant values of operations.
///
/// Instances of the Attribute class are references to immortal key-value pairs
-/// with immutable, uniqued key owned by MLIRContext. As such, an Attribute is a
-/// thin wrapper around an underlying storage pointer. Attributes are usually
+/// with immutable, uniqued keys owned by MLIRContext. As such, an Attribute is
+/// a thin wrapper around an underlying storage pointer. Attributes are usually
/// passed by value.
class Attribute {
public:
@@ -126,1469 +95,6 @@ inline raw_ostream &operator<<(raw_ostream &os, Attribute attr) {
return os;
}
-//===----------------------------------------------------------------------===//
-// AttributeTraitBase
-//===----------------------------------------------------------------------===//
-
-namespace AttributeTrait {
-/// This class represents the base of an attribute trait.
-template <typename ConcreteType, template <typename> class TraitType>
-using TraitBase = detail::StorageUserTraitBase<ConcreteType, TraitType>;
-} // namespace AttributeTrait
-
-//===----------------------------------------------------------------------===//
-// AttributeInterface
-//===----------------------------------------------------------------------===//
-
-/// This class represents the base of an attribute interface. See the definition
-/// of `detail::Interface` for requirements on the `Traits` type.
-template <typename ConcreteType, typename Traits>
-class AttributeInterface
- : public detail::Interface<ConcreteType, Attribute, Traits, Attribute,
- AttributeTrait::TraitBase> {
-public:
- using Base = AttributeInterface<ConcreteType, Traits>;
- using InterfaceBase = detail::Interface<ConcreteType, Attribute, Traits,
- Attribute, AttributeTrait::TraitBase>;
- using InterfaceBase::InterfaceBase;
-
-private:
- /// Returns the impl interface instance for the given type.
- static typename InterfaceBase::Concept *getInterfaceFor(Attribute attr) {
- return attr.getAbstractAttribute().getInterface<ConcreteType>();
- }
-
- /// Allow access to 'getInterfaceFor'.
- friend InterfaceBase;
-};
-
-//===----------------------------------------------------------------------===//
-// AffineMapAttr
-//===----------------------------------------------------------------------===//
-
-class AffineMapAttr
- : public Attribute::AttrBase<AffineMapAttr, Attribute,
- detail::AffineMapAttributeStorage> {
-public:
- using Base::Base;
- using ValueType = AffineMap;
-
- static AffineMapAttr get(AffineMap value);
-
- AffineMap getValue() const;
-};
-
-//===----------------------------------------------------------------------===//
-// ArrayAttr
-//===----------------------------------------------------------------------===//
-
-/// Array attributes are lists of other attributes. They are not necessarily
-/// type homogenous given that attributes don't, in general, carry types.
-class ArrayAttr : public Attribute::AttrBase<ArrayAttr, Attribute,
- detail::ArrayAttributeStorage> {
-public:
- using Base::Base;
- using ValueType = ArrayRef<Attribute>;
-
- static ArrayAttr get(ArrayRef<Attribute> value, MLIRContext *context);
-
- ArrayRef<Attribute> getValue() const;
- Attribute operator[](unsigned idx) const;
-
- /// Support range iteration.
- using iterator = llvm::ArrayRef<Attribute>::iterator;
- iterator begin() const { return getValue().begin(); }
- iterator end() const { return getValue().end(); }
- size_t size() const { return getValue().size(); }
- bool empty() const { return size() == 0; }
-
-private:
- /// Class for underlying value iterator support.
- template <typename AttrTy>
- class attr_value_iterator final
- : public llvm::mapped_iterator<ArrayAttr::iterator,
- AttrTy (*)(Attribute)> {
- public:
- explicit attr_value_iterator(ArrayAttr::iterator it)
- : llvm::mapped_iterator<ArrayAttr::iterator, AttrTy (*)(Attribute)>(
- it, [](Attribute attr) { return attr.cast<AttrTy>(); }) {}
- AttrTy operator*() const { return (*this->I).template cast<AttrTy>(); }
- };
-
-public:
- template <typename AttrTy>
- iterator_range<attr_value_iterator<AttrTy>> getAsRange() {
- return llvm::make_range(attr_value_iterator<AttrTy>(begin()),
- attr_value_iterator<AttrTy>(end()));
- }
- template <typename AttrTy, typename UnderlyingTy = typename AttrTy::ValueType>
- auto getAsValueRange() {
- return llvm::map_range(getAsRange<AttrTy>(), [](AttrTy attr) {
- return static_cast<UnderlyingTy>(attr.getValue());
- });
- }
-};
-
-//===----------------------------------------------------------------------===//
-// DictionaryAttr
-//===----------------------------------------------------------------------===//
-
-/// NamedAttribute is used for dictionary attributes, it holds an identifier for
-/// the name and a value for the attribute. The attribute pointer should always
-/// be non-null.
-using NamedAttribute = std::pair<Identifier, Attribute>;
-
-bool operator<(const NamedAttribute &lhs, const NamedAttribute &rhs);
-bool operator<(const NamedAttribute &lhs, StringRef rhs);
-
-/// Dictionary attribute is an attribute that represents a sorted collection of
-/// named attribute values. The elements are sorted by name, and each name must
-/// be unique within the collection.
-class DictionaryAttr
- : public Attribute::AttrBase<DictionaryAttr, Attribute,
- detail::DictionaryAttributeStorage> {
-public:
- using Base::Base;
- using ValueType = ArrayRef<NamedAttribute>;
-
- /// Construct a dictionary attribute with the provided list of named
- /// attributes. This method assumes that the provided list is unordered. If
- /// the caller can guarantee that the attributes are ordered by name,
- /// getWithSorted should be used instead.
- static DictionaryAttr get(ArrayRef<NamedAttribute> value,
- MLIRContext *context);
-
- /// Construct a dictionary with an array of values that is known to already be
- /// sorted by name and uniqued.
- static DictionaryAttr getWithSorted(ArrayRef<NamedAttribute> value,
- MLIRContext *context);
-
- ArrayRef<NamedAttribute> getValue() const;
-
- /// Return the specified attribute if present, null otherwise.
- Attribute get(StringRef name) const;
- Attribute get(Identifier name) const;
-
- /// Return the specified named attribute if present, None otherwise.
- Optional<NamedAttribute> getNamed(StringRef name) const;
- Optional<NamedAttribute> getNamed(Identifier name) const;
-
- /// Support range iteration.
- using iterator = llvm::ArrayRef<NamedAttribute>::iterator;
- iterator begin() const;
- iterator end() const;
- bool empty() const { return size() == 0; }
- size_t size() const;
-
- /// Sorts the NamedAttributes in the array ordered by name as expected by
- /// getWithSorted and returns whether the values were sorted.
- /// Requires: uniquely named attributes.
- static bool sort(ArrayRef<NamedAttribute> values,
- SmallVectorImpl<NamedAttribute> &storage);
-
- /// Sorts the NamedAttributes in the array ordered by name as expected by
- /// getWithSorted in place on an array and returns whether the values needed
- /// to be sorted.
- /// Requires: uniquely named attributes.
- static bool sortInPlace(SmallVectorImpl<NamedAttribute> &array);
-
- /// Returns an entry with a duplicate name in `array`, if it exists, else
- /// returns llvm::None. If `isSorted` is true, the array is assumed to be
- /// sorted else it will be sorted in place before finding the duplicate entry.
- static Optional<NamedAttribute>
- findDuplicate(SmallVectorImpl<NamedAttribute> &array, bool isSorted);
-
-private:
- /// Return empty dictionary.
- static DictionaryAttr getEmpty(MLIRContext *context);
-};
-
-//===----------------------------------------------------------------------===//
-// FloatAttr
-//===----------------------------------------------------------------------===//
-
-class FloatAttr : public Attribute::AttrBase<FloatAttr, Attribute,
- detail::FloatAttributeStorage> {
-public:
- using Base::Base;
- using ValueType = APFloat;
-
- /// Return a float attribute for the specified value in the specified type.
- /// These methods should only be used for simple constant values, e.g 1.0/2.0,
- /// that are known-valid both as host double and the 'type' format.
- static FloatAttr get(Type type, double value);
- static FloatAttr getChecked(Type type, double value, Location loc);
-
- /// Return a float attribute for the specified value in the specified type.
- static FloatAttr get(Type type, const APFloat &value);
- static FloatAttr getChecked(Type type, const APFloat &value, Location loc);
-
- APFloat getValue() const;
-
- /// This function is used to convert the value to a double, even if it loses
- /// precision.
- double getValueAsDouble() const;
- static double getValueAsDouble(APFloat val);
-
- /// Verify the construction invariants for a double value.
- static LogicalResult verifyConstructionInvariants(Location loc, Type type,
- double value);
- static LogicalResult verifyConstructionInvariants(Location loc, Type type,
- const APFloat &value);
-};
-
-//===----------------------------------------------------------------------===//
-// IntegerAttr
-//===----------------------------------------------------------------------===//
-
-class IntegerAttr
- : public Attribute::AttrBase<IntegerAttr, Attribute,
- detail::IntegerAttributeStorage> {
-public:
- using Base::Base;
- using ValueType = APInt;
-
- static IntegerAttr get(Type type, int64_t value);
- static IntegerAttr get(Type type, const APInt &value);
-
- APInt getValue() const;
- /// Return the integer value as a 64-bit int. The attribute must be a signless
- /// integer.
- // TODO: Change callers to use getValue instead.
- int64_t getInt() const;
- /// Return the integer value as a signed 64-bit int. The attribute must be
- /// a signed integer.
- int64_t getSInt() const;
- /// Return the integer value as a unsigned 64-bit int. The attribute must be
- /// an unsigned integer.
- uint64_t getUInt() const;
-
- static LogicalResult verifyConstructionInvariants(Location loc, Type type,
- int64_t value);
- static LogicalResult verifyConstructionInvariants(Location loc, Type type,
- const APInt &value);
-};
-
-//===----------------------------------------------------------------------===//
-// BoolAttr
-
-/// Special case of IntegerAttr to represent boolean integers, i.e., signless i1
-/// integers.
-class BoolAttr : public Attribute {
-public:
- using Attribute::Attribute;
- using ValueType = bool;
-
- static BoolAttr get(bool value, MLIRContext *context);
-
- /// Enable conversion to IntegerAttr. This uses conversion vs. inheritance to
- /// avoid bringing in all of IntegerAttrs methods.
- operator IntegerAttr() const { return IntegerAttr(impl); }
-
- /// Return the boolean value of this attribute.
- bool getValue() const;
-
- /// Methods for support type inquiry through isa, cast, and dyn_cast.
- static bool classof(Attribute attr);
-};
-
-//===----------------------------------------------------------------------===//
-// IntegerSetAttr
-//===----------------------------------------------------------------------===//
-
-class IntegerSetAttr
- : public Attribute::AttrBase<IntegerSetAttr, Attribute,
- detail::IntegerSetAttributeStorage> {
-public:
- using Base::Base;
- using ValueType = IntegerSet;
-
- static IntegerSetAttr get(IntegerSet value);
-
- IntegerSet getValue() const;
-};
-
-//===----------------------------------------------------------------------===//
-// OpaqueAttr
-//===----------------------------------------------------------------------===//
-
-/// Opaque attributes represent attributes of non-registered dialects. These are
-/// attribute represented in their raw string form, and can only usefully be
-/// tested for attribute equality.
-class OpaqueAttr : public Attribute::AttrBase<OpaqueAttr, Attribute,
- detail::OpaqueAttributeStorage> {
-public:
- using Base::Base;
-
- /// Get or create a new OpaqueAttr with the provided dialect and string data.
- static OpaqueAttr get(Identifier dialect, StringRef attrData, Type type,
- MLIRContext *context);
-
- /// Get or create a new OpaqueAttr with the provided dialect and string data.
- /// If the given identifier is not a valid namespace for a dialect, then a
- /// null attribute is returned.
- static OpaqueAttr getChecked(Identifier dialect, StringRef attrData,
- Type type, Location location);
-
- /// Returns the dialect namespace of the opaque attribute.
- Identifier getDialectNamespace() const;
-
- /// Returns the raw attribute data of the opaque attribute.
- StringRef getAttrData() const;
-
- /// Verify the construction of an opaque attribute.
- static LogicalResult verifyConstructionInvariants(Location loc,
- Identifier dialect,
- StringRef attrData,
- Type type);
-};
-
-//===----------------------------------------------------------------------===//
-// StringAttr
-//===----------------------------------------------------------------------===//
-
-class StringAttr : public Attribute::AttrBase<StringAttr, Attribute,
- detail::StringAttributeStorage> {
-public:
- using Base::Base;
- using ValueType = StringRef;
-
- /// Get an instance of a StringAttr with the given string.
- static StringAttr get(StringRef bytes, MLIRContext *context);
-
- /// Get an instance of a StringAttr with the given string and Type.
- static StringAttr get(StringRef bytes, Type type);
-
- StringRef getValue() const;
-};
-
-//===----------------------------------------------------------------------===//
-// SymbolRefAttr
-//===----------------------------------------------------------------------===//
-
-class FlatSymbolRefAttr;
-
-/// A symbol reference attribute represents a symbolic reference to another
-/// operation.
-class SymbolRefAttr
- : public Attribute::AttrBase<SymbolRefAttr, Attribute,
- detail::SymbolRefAttributeStorage> {
-public:
- using Base::Base;
-
- /// Construct a symbol reference for the given value name.
- static FlatSymbolRefAttr get(StringRef value, MLIRContext *ctx);
-
- /// Construct a symbol reference for the given value name, and a set of nested
- /// references that are further resolve to a nested symbol.
- static SymbolRefAttr get(StringRef value,
- ArrayRef<FlatSymbolRefAttr> references,
- MLIRContext *ctx);
-
- /// Returns the name of the top level symbol reference, i.e. the root of the
- /// reference path.
- StringRef getRootReference() const;
-
- /// Returns the name of the fully resolved symbol, i.e. the leaf of the
- /// reference path.
- StringRef getLeafReference() const;
-
- /// Returns the set of nested references representing the path to the symbol
- /// nested under the root reference.
- ArrayRef<FlatSymbolRefAttr> getNestedReferences() const;
-};
-
-/// A symbol reference with a reference path containing a single element. This
-/// is used to refer to an operation within the current symbol table.
-class FlatSymbolRefAttr : public SymbolRefAttr {
-public:
- using SymbolRefAttr::SymbolRefAttr;
- using ValueType = StringRef;
-
- /// Construct a symbol reference for the given value name.
- static FlatSymbolRefAttr get(StringRef value, MLIRContext *ctx) {
- return SymbolRefAttr::get(value, ctx);
- }
-
- /// Returns the name of the held symbol reference.
- StringRef getValue() const { return getRootReference(); }
-
- /// Methods for support type inquiry through isa, cast, and dyn_cast.
- static bool classof(Attribute attr) {
- SymbolRefAttr refAttr = attr.dyn_cast<SymbolRefAttr>();
- return refAttr && refAttr.getNestedReferences().empty();
- }
-
-private:
- using SymbolRefAttr::get;
- using SymbolRefAttr::getNestedReferences;
-};
-
-//===----------------------------------------------------------------------===//
-// Type
-//===----------------------------------------------------------------------===//
-
-class TypeAttr : public Attribute::AttrBase<TypeAttr, Attribute,
- detail::TypeAttributeStorage> {
-public:
- using Base::Base;
- using ValueType = Type;
-
- static TypeAttr get(Type value);
-
- Type getValue() const;
-};
-
-//===----------------------------------------------------------------------===//
-// UnitAttr
-//===----------------------------------------------------------------------===//
-
-/// Unit attributes are attributes that hold no specific value and are given
-/// meaning by their existence.
-class UnitAttr
- : public Attribute::AttrBase<UnitAttr, Attribute, AttributeStorage> {
-public:
- using Base::Base;
-
- static UnitAttr get(MLIRContext *context);
-};
-
-//===----------------------------------------------------------------------===//
-// Elements Attributes
-//===----------------------------------------------------------------------===//
-
-namespace detail {
-template <typename T> class ElementsAttrIterator;
-template <typename T> class ElementsAttrRange;
-} // namespace detail
-
-/// A base attribute that represents a reference to a static shaped tensor or
-/// vector constant.
-class ElementsAttr : public Attribute {
-public:
- using Attribute::Attribute;
- template <typename T> using iterator = detail::ElementsAttrIterator<T>;
- template <typename T> using iterator_range = detail::ElementsAttrRange<T>;
-
- /// Return the type of this ElementsAttr, guaranteed to be a vector or tensor
- /// with static shape.
- ShapedType getType() const;
-
- /// Return the value at the given index. The index is expected to refer to a
- /// valid element.
- Attribute getValue(ArrayRef<uint64_t> index) const;
-
- /// Return the value of type 'T' at the given index, where 'T' corresponds to
- /// an Attribute type.
- template <typename T> T getValue(ArrayRef<uint64_t> index) const {
- return getValue(index).template cast<T>();
- }
-
- /// Return the elements of this attribute as a value of type 'T'. Note:
- /// Aborts if the subclass is OpaqueElementsAttrs, these attrs do not support
- /// iteration.
- template <typename T> iterator_range<T> getValues() const;
-
- /// Return if the given 'index' refers to a valid element in this attribute.
- bool isValidIndex(ArrayRef<uint64_t> index) const;
-
- /// Returns the number of elements held by this attribute.
- int64_t getNumElements() const;
-
- /// Returns the number of elements held by this attribute.
- int64_t size() const { return getNumElements(); }
-
- /// Generates a new ElementsAttr by mapping each int value to a new
- /// underlying APInt. The new values can represent either an integer or float.
- /// This ElementsAttr should contain integers.
- ElementsAttr mapValues(Type newElementType,
- function_ref<APInt(const APInt &)> mapping) const;
-
- /// Generates a new ElementsAttr by mapping each float value to a new
- /// underlying APInt. The new values can represent either an integer or float.
- /// This ElementsAttr should contain floats.
- ElementsAttr mapValues(Type newElementType,
- function_ref<APInt(const APFloat &)> mapping) const;
-
- /// Method for support type inquiry through isa, cast and dyn_cast.
- static bool classof(Attribute attr);
-
-protected:
- /// Returns the 1 dimensional flattened row-major index from the given
- /// multi-dimensional index.
- uint64_t getFlattenedIndex(ArrayRef<uint64_t> index) const;
-};
-
-namespace detail {
-/// DenseElementsAttr data is aligned to uint64_t, so this traits class is
-/// necessary to interop with PointerIntPair.
-class DenseElementDataPointerTypeTraits {
-public:
- static inline const void *getAsVoidPointer(const char *ptr) { return ptr; }
- static inline const char *getFromVoidPointer(const void *ptr) {
- return static_cast<const char *>(ptr);
- }
-
- // Note: We could steal more bits if the need arises.
- static constexpr int NumLowBitsAvailable = 1;
-};
-
-/// Pair of raw pointer and a boolean flag of whether the pointer holds a splat,
-using DenseIterPtrAndSplat =
- llvm::PointerIntPair<const char *, 1, bool,
- DenseElementDataPointerTypeTraits>;
-
-/// Impl iterator for indexed DenseElementsAttr iterators that records a data
-/// pointer and data index that is adjusted for the case of a splat attribute.
-template <typename ConcreteT, typename T, typename PointerT = T *,
- typename ReferenceT = T &>
-class DenseElementIndexedIteratorImpl
- : public llvm::indexed_accessor_iterator<ConcreteT, DenseIterPtrAndSplat, T,
- PointerT, ReferenceT> {
-protected:
- DenseElementIndexedIteratorImpl(const char *data, bool isSplat,
- size_t dataIndex)
- : llvm::indexed_accessor_iterator<ConcreteT, DenseIterPtrAndSplat, T,
- PointerT, ReferenceT>({data, isSplat},
- dataIndex) {}
-
- /// Return the current index for this iterator, adjusted for the case of a
- /// splat.
-  ptrdiff_t getDataIndex() const {
- bool isSplat = this->base.getInt();
- return isSplat ? 0 : this->index;
- }
-
- /// Return the data base pointer.
- const char *getData() const { return this->base.getPointer(); }
-};
-
-/// Type trait detector that checks if a given type T is a complex type.
-template <typename T> struct is_complex_t : public std::false_type {};
-template <typename T>
-struct is_complex_t<std::complex<T>> : public std::true_type {};
-} // namespace detail
-
-/// An attribute that represents a reference to a dense vector or tensor object.
-///
-class DenseElementsAttr : public ElementsAttr {
-public:
- using ElementsAttr::ElementsAttr;
-
- /// Type trait used to check if the given type T is a potentially valid C++
- /// floating point type that can be used to access the underlying element
- /// types of a DenseElementsAttr.
- // TODO: Use std::disjunction when C++17 is supported.
- template <typename T> struct is_valid_cpp_fp_type {
- /// The type is a valid floating point type if it is a builtin floating
- /// point type, or is a potentially user defined floating point type. The
- /// latter allows for supporting users that have custom types defined for
- /// bfloat16/half/etc.
- static constexpr bool value = llvm::is_one_of<T, float, double>::value ||
- (std::numeric_limits<T>::is_specialized &&
- !std::numeric_limits<T>::is_integer);
- };
-
- /// Method for support type inquiry through isa, cast and dyn_cast.
- static bool classof(Attribute attr);
-
- /// Constructs a dense elements attribute from an array of element values.
- /// Each element attribute value is expected to be an element of 'type'.
- /// 'type' must be a vector or tensor with static shape. If the element of
- /// `type` is non-integer/index/float it is assumed to be a string type.
- static DenseElementsAttr get(ShapedType type, ArrayRef<Attribute> values);
-
- /// Constructs a dense integer elements attribute from an array of integer
- /// or floating-point values. Each value is expected to be the same bitwidth
- /// of the element type of 'type'. 'type' must be a vector or tensor with
- /// static shape.
- template <typename T, typename = typename std::enable_if<
- std::numeric_limits<T>::is_integer ||
- is_valid_cpp_fp_type<T>::value>::type>
- static DenseElementsAttr get(const ShapedType &type, ArrayRef<T> values) {
- const char *data = reinterpret_cast<const char *>(values.data());
- return getRawIntOrFloat(
- type, ArrayRef<char>(data, values.size() * sizeof(T)), sizeof(T),
- std::numeric_limits<T>::is_integer, std::numeric_limits<T>::is_signed);
- }
-
- /// Constructs a dense integer elements attribute from a single element.
- template <typename T, typename = typename std::enable_if<
- std::numeric_limits<T>::is_integer ||
- is_valid_cpp_fp_type<T>::value ||
- detail::is_complex_t<T>::value>::type>
- static DenseElementsAttr get(const ShapedType &type, T value) {
- return get(type, llvm::makeArrayRef(value));
- }
-
- /// Constructs a dense complex elements attribute from an array of complex
- /// values. Each value is expected to be the same bitwidth of the element type
- /// of 'type'. 'type' must be a vector or tensor with static shape.
- template <typename T, typename ElementT = typename T::value_type,
- typename = typename std::enable_if<
- detail::is_complex_t<T>::value &&
- (std::numeric_limits<ElementT>::is_integer ||
- is_valid_cpp_fp_type<ElementT>::value)>::type>
- static DenseElementsAttr get(const ShapedType &type, ArrayRef<T> values) {
- const char *data = reinterpret_cast<const char *>(values.data());
- return getRawComplex(type, ArrayRef<char>(data, values.size() * sizeof(T)),
- sizeof(T), std::numeric_limits<ElementT>::is_integer,
- std::numeric_limits<ElementT>::is_signed);
- }
-
- /// Overload of the above 'get' method that is specialized for boolean values.
- static DenseElementsAttr get(ShapedType type, ArrayRef<bool> values);
-
- /// Overload of the above 'get' method that is specialized for StringRef
- /// values.
- static DenseElementsAttr get(ShapedType type, ArrayRef<StringRef> values);
-
- /// Constructs a dense integer elements attribute from an array of APInt
- /// values. Each APInt value is expected to have the same bitwidth as the
- /// element type of 'type'. 'type' must be a vector or tensor with static
- /// shape.
- static DenseElementsAttr get(ShapedType type, ArrayRef<APInt> values);
-
- /// Constructs a dense complex elements attribute from an array of APInt
- /// values. Each APInt value is expected to have the same bitwidth as the
- /// element type of 'type'. 'type' must be a vector or tensor with static
- /// shape.
- static DenseElementsAttr get(ShapedType type,
- ArrayRef<std::complex<APInt>> values);
-
- /// Constructs a dense float elements attribute from an array of APFloat
- /// values. Each APFloat value is expected to have the same bitwidth as the
- /// element type of 'type'. 'type' must be a vector or tensor with static
- /// shape.
- static DenseElementsAttr get(ShapedType type, ArrayRef<APFloat> values);
-
- /// Constructs a dense complex elements attribute from an array of APFloat
- /// values. Each APFloat value is expected to have the same bitwidth as the
- /// element type of 'type'. 'type' must be a vector or tensor with static
- /// shape.
- static DenseElementsAttr get(ShapedType type,
- ArrayRef<std::complex<APFloat>> values);
-
- /// Construct a dense elements attribute for an initializer_list of values.
- /// Each value is expected to be the same bitwidth of the element type of
- /// 'type'. 'type' must be a vector or tensor with static shape.
- template <typename T>
- static DenseElementsAttr get(const ShapedType &type,
- const std::initializer_list<T> &list) {
- return get(type, ArrayRef<T>(list));
- }
-
- /// Construct a dense elements attribute from a raw buffer representing the
- /// data for this attribute. Users should generally not use this methods as
- /// the expected buffer format may not be a form the user expects.
- static DenseElementsAttr getFromRawBuffer(ShapedType type,
- ArrayRef<char> rawBuffer,
- bool isSplatBuffer);
-
- /// Returns true if the given buffer is a valid raw buffer for the given type.
- /// `detectedSplat` is set if the buffer is valid and represents a splat
- /// buffer.
- static bool isValidRawBuffer(ShapedType type, ArrayRef<char> rawBuffer,
- bool &detectedSplat);
-
- //===--------------------------------------------------------------------===//
- // Iterators
- //===--------------------------------------------------------------------===//
-
- /// A utility iterator that allows walking over the internal Attribute values
- /// of a DenseElementsAttr.
- class AttributeElementIterator
- : public llvm::indexed_accessor_iterator<AttributeElementIterator,
- const void *, Attribute,
- Attribute, Attribute> {
- public:
- /// Accesses the Attribute value at this iterator position.
- Attribute operator*() const;
-
- private:
- friend DenseElementsAttr;
-
- /// Constructs a new iterator.
- AttributeElementIterator(DenseElementsAttr attr, size_t index);
- };
-
- /// Iterator for walking raw element values of the specified type 'T', which
- /// may be any c++ data type matching the stored representation: int32_t,
- /// float, etc.
- template <typename T>
- class ElementIterator
- : public detail::DenseElementIndexedIteratorImpl<ElementIterator<T>,
- const T> {
- public:
- /// Accesses the raw value at this iterator position.
- const T &operator*() const {
- return reinterpret_cast<const T *>(this->getData())[this->getDataIndex()];
- }
-
- private:
- friend DenseElementsAttr;
-
- /// Constructs a new iterator.
- ElementIterator(const char *data, bool isSplat, size_t dataIndex)
- : detail::DenseElementIndexedIteratorImpl<ElementIterator<T>, const T>(
- data, isSplat, dataIndex) {}
- };
-
- /// A utility iterator that allows walking over the internal bool values.
- class BoolElementIterator
- : public detail::DenseElementIndexedIteratorImpl<BoolElementIterator,
- bool, bool, bool> {
- public:
- /// Accesses the bool value at this iterator position.
- bool operator*() const;
-
- private:
- friend DenseElementsAttr;
-
- /// Constructs a new iterator.
- BoolElementIterator(DenseElementsAttr attr, size_t dataIndex);
- };
-
- /// A utility iterator that allows walking over the internal raw APInt values.
- class IntElementIterator
- : public detail::DenseElementIndexedIteratorImpl<IntElementIterator,
- APInt, APInt, APInt> {
- public:
- /// Accesses the raw APInt value at this iterator position.
- APInt operator*() const;
-
- private:
- friend DenseElementsAttr;
-
- /// Constructs a new iterator.
- IntElementIterator(DenseElementsAttr attr, size_t dataIndex);
-
- /// The bitwidth of the element type.
- size_t bitWidth;
- };
-
- /// A utility iterator that allows walking over the internal raw complex APInt
- /// values.
- class ComplexIntElementIterator
- : public detail::DenseElementIndexedIteratorImpl<
- ComplexIntElementIterator, std::complex<APInt>, std::complex<APInt>,
- std::complex<APInt>> {
- public:
- /// Accesses the raw std::complex<APInt> value at this iterator position.
- std::complex<APInt> operator*() const;
-
- private:
- friend DenseElementsAttr;
-
- /// Constructs a new iterator.
- ComplexIntElementIterator(DenseElementsAttr attr, size_t dataIndex);
-
- /// The bitwidth of the element type.
- size_t bitWidth;
- };
-
- /// Iterator for walking over APFloat values.
- class FloatElementIterator final
- : public llvm::mapped_iterator<IntElementIterator,
- std::function<APFloat(const APInt &)>> {
- friend DenseElementsAttr;
-
- /// Initializes the float element iterator to the specified iterator.
- FloatElementIterator(const llvm::fltSemantics &smt, IntElementIterator it);
-
- public:
- using reference = APFloat;
- };
-
- /// Iterator for walking over complex APFloat values.
- class ComplexFloatElementIterator final
- : public llvm::mapped_iterator<
- ComplexIntElementIterator,
- std::function<std::complex<APFloat>(const std::complex<APInt> &)>> {
- friend DenseElementsAttr;
-
- /// Initializes the float element iterator to the specified iterator.
- ComplexFloatElementIterator(const llvm::fltSemantics &smt,
- ComplexIntElementIterator it);
-
- public:
- using reference = std::complex<APFloat>;
- };
-
- //===--------------------------------------------------------------------===//
- // Value Querying
- //===--------------------------------------------------------------------===//
-
- /// Returns true if this attribute corresponds to a splat, i.e. if all element
- /// values are the same.
- bool isSplat() const;
-
- /// Return the splat value for this attribute. This asserts that the attribute
- /// corresponds to a splat.
- Attribute getSplatValue() const { return getSplatValue<Attribute>(); }
- template <typename T>
- typename std::enable_if<!std::is_base_of<Attribute, T>::value ||
- std::is_same<Attribute, T>::value,
- T>::type
- getSplatValue() const {
- assert(isSplat() && "expected the attribute to be a splat");
- return *getValues<T>().begin();
- }
- /// Return the splat value for derived attribute element types.
- template <typename T>
- typename std::enable_if<std::is_base_of<Attribute, T>::value &&
- !std::is_same<Attribute, T>::value,
- T>::type
- getSplatValue() const {
- return getSplatValue().template cast<T>();
- }
-
- /// Return the value at the given index. The 'index' is expected to refer to a
- /// valid element.
- Attribute getValue(ArrayRef<uint64_t> index) const {
- return getValue<Attribute>(index);
- }
- template <typename T> T getValue(ArrayRef<uint64_t> index) const {
- // Skip to the element corresponding to the flattened index.
- return *std::next(getValues<T>().begin(), getFlattenedIndex(index));
- }
-
- /// Return the held element values as a range of integer or floating-point
- /// values.
- template <typename T, typename = typename std::enable_if<
- (!std::is_same<T, bool>::value &&
- std::numeric_limits<T>::is_integer) ||
- is_valid_cpp_fp_type<T>::value>::type>
- llvm::iterator_range<ElementIterator<T>> getValues() const {
- assert(isValidIntOrFloat(sizeof(T), std::numeric_limits<T>::is_integer,
- std::numeric_limits<T>::is_signed));
- const char *rawData = getRawData().data();
- bool splat = isSplat();
- return {ElementIterator<T>(rawData, splat, 0),
- ElementIterator<T>(rawData, splat, getNumElements())};
- }
-
- /// Return the held element values as a range of std::complex.
- template <typename T, typename ElementT = typename T::value_type,
- typename = typename std::enable_if<
- detail::is_complex_t<T>::value &&
- (std::numeric_limits<ElementT>::is_integer ||
- is_valid_cpp_fp_type<ElementT>::value)>::type>
- llvm::iterator_range<ElementIterator<T>> getValues() const {
- assert(isValidComplex(sizeof(T), std::numeric_limits<ElementT>::is_integer,
- std::numeric_limits<ElementT>::is_signed));
- const char *rawData = getRawData().data();
- bool splat = isSplat();
- return {ElementIterator<T>(rawData, splat, 0),
- ElementIterator<T>(rawData, splat, getNumElements())};
- }
-
- /// Return the held element values as a range of StringRef.
- template <typename T, typename = typename std::enable_if<
- std::is_same<T, StringRef>::value>::type>
- llvm::iterator_range<ElementIterator<StringRef>> getValues() const {
- auto stringRefs = getRawStringData();
- const char *ptr = reinterpret_cast<const char *>(stringRefs.data());
- bool splat = isSplat();
- return {ElementIterator<StringRef>(ptr, splat, 0),
- ElementIterator<StringRef>(ptr, splat, getNumElements())};
- }
-
- /// Return the held element values as a range of Attributes.
- llvm::iterator_range<AttributeElementIterator> getAttributeValues() const;
- template <typename T, typename = typename std::enable_if<
- std::is_same<T, Attribute>::value>::type>
- llvm::iterator_range<AttributeElementIterator> getValues() const {
- return getAttributeValues();
- }
- AttributeElementIterator attr_value_begin() const;
- AttributeElementIterator attr_value_end() const;
-
- /// Return the held element values a range of T, where T is a derived
- /// attribute type.
- template <typename T>
- using DerivedAttributeElementIterator =
- llvm::mapped_iterator<AttributeElementIterator, T (*)(Attribute)>;
- template <typename T, typename = typename std::enable_if<
- std::is_base_of<Attribute, T>::value &&
- !std::is_same<Attribute, T>::value>::type>
- llvm::iterator_range<DerivedAttributeElementIterator<T>> getValues() const {
- auto castFn = [](Attribute attr) { return attr.template cast<T>(); };
- return llvm::map_range(getAttributeValues(),
- static_cast<T (*)(Attribute)>(castFn));
- }
-
- /// Return the held element values as a range of bool. The element type of
- /// this attribute must be of integer type of bitwidth 1.
- llvm::iterator_range<BoolElementIterator> getBoolValues() const;
- template <typename T, typename = typename std::enable_if<
- std::is_same<T, bool>::value>::type>
- llvm::iterator_range<BoolElementIterator> getValues() const {
- return getBoolValues();
- }
-
- /// Return the held element values as a range of APInts. The element type of
- /// this attribute must be of integer type.
- llvm::iterator_range<IntElementIterator> getIntValues() const;
- template <typename T, typename = typename std::enable_if<
- std::is_same<T, APInt>::value>::type>
- llvm::iterator_range<IntElementIterator> getValues() const {
- return getIntValues();
- }
- IntElementIterator int_value_begin() const;
- IntElementIterator int_value_end() const;
-
- /// Return the held element values as a range of complex APInts. The element
- /// type of this attribute must be a complex of integer type.
- llvm::iterator_range<ComplexIntElementIterator> getComplexIntValues() const;
- template <typename T, typename = typename std::enable_if<
- std::is_same<T, std::complex<APInt>>::value>::type>
- llvm::iterator_range<ComplexIntElementIterator> getValues() const {
- return getComplexIntValues();
- }
-
- /// Return the held element values as a range of APFloat. The element type of
- /// this attribute must be of float type.
- llvm::iterator_range<FloatElementIterator> getFloatValues() const;
- template <typename T, typename = typename std::enable_if<
- std::is_same<T, APFloat>::value>::type>
- llvm::iterator_range<FloatElementIterator> getValues() const {
- return getFloatValues();
- }
- FloatElementIterator float_value_begin() const;
- FloatElementIterator float_value_end() const;
-
- /// Return the held element values as a range of complex APFloat. The element
- /// type of this attribute must be a complex of float type.
- llvm::iterator_range<ComplexFloatElementIterator>
- getComplexFloatValues() const;
- template <typename T, typename = typename std::enable_if<std::is_same<
- T, std::complex<APFloat>>::value>::type>
- llvm::iterator_range<ComplexFloatElementIterator> getValues() const {
- return getComplexFloatValues();
- }
-
- /// Return the raw storage data held by this attribute. Users should generally
- /// not use this directly, as the internal storage format is not always in the
- /// form the user might expect.
- ArrayRef<char> getRawData() const;
-
- /// Return the raw StringRef data held by this attribute.
- ArrayRef<StringRef> getRawStringData() const;
-
- //===--------------------------------------------------------------------===//
- // Mutation Utilities
- //===--------------------------------------------------------------------===//
-
- /// Return a new DenseElementsAttr that has the same data as the current
- /// attribute, but has been reshaped to 'newType'. The new type must have the
- /// same total number of elements as well as element type.
- DenseElementsAttr reshape(ShapedType newType);
-
- /// Generates a new DenseElementsAttr by mapping each int value to a new
- /// underlying APInt. The new values can represent either an integer or float.
- /// This underlying type must be an DenseIntElementsAttr.
- DenseElementsAttr mapValues(Type newElementType,
- function_ref<APInt(const APInt &)> mapping) const;
-
- /// Generates a new DenseElementsAttr by mapping each float value to a new
- /// underlying APInt. the new values can represent either an integer or float.
- /// This underlying type must be an DenseFPElementsAttr.
- DenseElementsAttr
- mapValues(Type newElementType,
- function_ref<APInt(const APFloat &)> mapping) const;
-
-protected:
- /// Get iterators to the raw APInt values for each element in this attribute.
- IntElementIterator raw_int_begin() const {
- return IntElementIterator(*this, 0);
- }
- IntElementIterator raw_int_end() const {
- return IntElementIterator(*this, getNumElements());
- }
-
- /// Overload of the raw 'get' method that asserts that the given type is of
- /// complex type. This method is used to verify type invariants that the
- /// templatized 'get' method cannot.
- static DenseElementsAttr getRawComplex(ShapedType type, ArrayRef<char> data,
- int64_t dataEltSize, bool isInt,
- bool isSigned);
-
- /// Overload of the raw 'get' method that asserts that the given type is of
- /// integer or floating-point type. This method is used to verify type
- /// invariants that the templatized 'get' method cannot.
- static DenseElementsAttr getRawIntOrFloat(ShapedType type,
- ArrayRef<char> data,
- int64_t dataEltSize, bool isInt,
- bool isSigned);
-
- /// Check the information for a C++ data type, check if this type is valid for
- /// the current attribute. This method is used to verify specific type
- /// invariants that the templatized 'getValues' method cannot.
- bool isValidIntOrFloat(int64_t dataEltSize, bool isInt, bool isSigned) const;
-
- /// Check the information for a C++ data type, check if this type is valid for
- /// the current attribute. This method is used to verify specific type
- /// invariants that the templatized 'getValues' method cannot.
- bool isValidComplex(int64_t dataEltSize, bool isInt, bool isSigned) const;
-};
-
-/// An attribute class for representing dense arrays of strings. The structure
-/// storing and querying a list of densely packed strings.
-class DenseStringElementsAttr
- : public Attribute::AttrBase<DenseStringElementsAttr, DenseElementsAttr,
- detail::DenseStringElementsAttributeStorage> {
-
-public:
- using Base::Base;
-
- /// Overload of the raw 'get' method that asserts that the given type is of
- /// integer or floating-point type. This method is used to verify type
- /// invariants that the templatized 'get' method cannot.
- static DenseStringElementsAttr get(ShapedType type, ArrayRef<StringRef> data);
-
-protected:
- friend DenseElementsAttr;
-};
-
-/// An attribute class for specializing behavior of Int and Floating-point
-/// densely packed string arrays.
-class DenseIntOrFPElementsAttr
- : public Attribute::AttrBase<DenseIntOrFPElementsAttr, DenseElementsAttr,
- detail::DenseIntOrFPElementsAttributeStorage> {
-
-public:
- using Base::Base;
-
- /// Convert endianess of input ArrayRef for big-endian(BE) machines. All of
- /// the elements of `inRawData` has `type`. If `inRawData` is little endian
- /// (LE), it is converted to big endian (BE). Conversely, if `inRawData` is
- /// BE, converted to LE.
- static void
- convertEndianOfArrayRefForBEmachine(ArrayRef<char> inRawData,
- MutableArrayRef<char> outRawData,
- ShapedType type);
-
- /// Convert endianess of input for big-endian(BE) machines. The number of
- /// elements of `inRawData` is `numElements`, and each element has
- /// `elementBitWidth` bits. If `inRawData` is little endian (LE), it is
- /// converted to big endian (BE) and saved in `outRawData`. Conversely, if
- /// `inRawData` is BE, converted to LE.
- static void convertEndianOfCharForBEmachine(const char *inRawData,
- char *outRawData,
- size_t elementBitWidth,
- size_t numElements);
-
-protected:
- friend DenseElementsAttr;
-
- /// Constructs a dense elements attribute from an array of raw APFloat values.
- /// Each APFloat value is expected to have the same bitwidth as the element
- /// type of 'type'. 'type' must be a vector or tensor with static shape.
- static DenseElementsAttr getRaw(ShapedType type, size_t storageWidth,
- ArrayRef<APFloat> values, bool isSplat);
-
- /// Constructs a dense elements attribute from an array of raw APInt values.
- /// Each APInt value is expected to have the same bitwidth as the element type
- /// of 'type'. 'type' must be a vector or tensor with static shape.
- static DenseElementsAttr getRaw(ShapedType type, size_t storageWidth,
- ArrayRef<APInt> values, bool isSplat);
-
- /// Get or create a new dense elements attribute instance with the given raw
- /// data buffer. 'type' must be a vector or tensor with static shape.
- static DenseElementsAttr getRaw(ShapedType type, ArrayRef<char> data,
- bool isSplat);
-
- /// Overload of the raw 'get' method that asserts that the given type is of
- /// complex type. This method is used to verify type invariants that the
- /// templatized 'get' method cannot.
- static DenseElementsAttr getRawComplex(ShapedType type, ArrayRef<char> data,
- int64_t dataEltSize, bool isInt,
- bool isSigned);
-
- /// Overload of the raw 'get' method that asserts that the given type is of
- /// integer or floating-point type. This method is used to verify type
- /// invariants that the templatized 'get' method cannot.
- static DenseElementsAttr getRawIntOrFloat(ShapedType type,
- ArrayRef<char> data,
- int64_t dataEltSize, bool isInt,
- bool isSigned);
-};
-
-/// An attribute that represents a reference to a dense float vector or tensor
-/// object. Each element is stored as a double.
-class DenseFPElementsAttr : public DenseIntOrFPElementsAttr {
-public:
- using iterator = DenseElementsAttr::FloatElementIterator;
-
- using DenseIntOrFPElementsAttr::DenseIntOrFPElementsAttr;
-
- /// Get an instance of a DenseFPElementsAttr with the given arguments. This
- /// simply wraps the DenseElementsAttr::get calls.
- template <typename Arg>
- static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg) {
- return DenseElementsAttr::get(type, llvm::makeArrayRef(arg))
- .template cast<DenseFPElementsAttr>();
- }
- template <typename T>
- static DenseFPElementsAttr get(const ShapedType &type,
- const std::initializer_list<T> &list) {
- return DenseElementsAttr::get(type, list)
- .template cast<DenseFPElementsAttr>();
- }
-
- /// Generates a new DenseElementsAttr by mapping each value attribute, and
- /// constructing the DenseElementsAttr given the new element type.
- DenseElementsAttr
- mapValues(Type newElementType,
- function_ref<APInt(const APFloat &)> mapping) const;
-
- /// Iterator access to the float element values.
- iterator begin() const { return float_value_begin(); }
- iterator end() const { return float_value_end(); }
-
- /// Method for supporting type inquiry through isa, cast and dyn_cast.
- static bool classof(Attribute attr);
-};
-
-/// An attribute that represents a reference to a dense integer vector or tensor
-/// object.
-class DenseIntElementsAttr : public DenseIntOrFPElementsAttr {
-public:
- /// DenseIntElementsAttr iterates on APInt, so we can use the raw element
- /// iterator directly.
- using iterator = DenseElementsAttr::IntElementIterator;
-
- using DenseIntOrFPElementsAttr::DenseIntOrFPElementsAttr;
-
- /// Get an instance of a DenseIntElementsAttr with the given arguments. This
- /// simply wraps the DenseElementsAttr::get calls.
- template <typename Arg>
- static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg) {
- return DenseElementsAttr::get(type, llvm::makeArrayRef(arg))
- .template cast<DenseIntElementsAttr>();
- }
- template <typename T>
- static DenseIntElementsAttr get(const ShapedType &type,
- const std::initializer_list<T> &list) {
- return DenseElementsAttr::get(type, list)
- .template cast<DenseIntElementsAttr>();
- }
-
- /// Generates a new DenseElementsAttr by mapping each value attribute, and
- /// constructing the DenseElementsAttr given the new element type.
- DenseElementsAttr mapValues(Type newElementType,
- function_ref<APInt(const APInt &)> mapping) const;
-
- /// Iterator access to the integer element values.
- iterator begin() const { return raw_int_begin(); }
- iterator end() const { return raw_int_end(); }
-
- /// Method for supporting type inquiry through isa, cast and dyn_cast.
- static bool classof(Attribute attr);
-};
-
-/// An opaque attribute that represents a reference to a vector or tensor
-/// constant with opaque content. This representation is for tensor constants
-/// which the compiler may not need to interpret. This attribute is always
-/// associated with a particular dialect, which provides a method to convert
-/// tensor representation to a non-opaque format.
-class OpaqueElementsAttr
- : public Attribute::AttrBase<OpaqueElementsAttr, ElementsAttr,
- detail::OpaqueElementsAttributeStorage> {
-public:
- using Base::Base;
- using ValueType = StringRef;
-
- static OpaqueElementsAttr get(Dialect *dialect, ShapedType type,
- StringRef bytes);
-
- StringRef getValue() const;
-
- /// Return the value at the given index. The 'index' is expected to refer to a
- /// valid element.
- Attribute getValue(ArrayRef<uint64_t> index) const;
-
- /// Decodes the attribute value using dialect-specific decoding hook.
- /// Returns false if decoding is successful. If not, returns true and leaves
- /// 'result' argument unspecified.
- bool decode(ElementsAttr &result);
-
- /// Returns dialect associated with this opaque constant.
- Dialect *getDialect() const;
-};
-
-/// An attribute that represents a reference to a sparse vector or tensor
-/// object.
-///
-/// This class uses COO (coordinate list) encoding to represent the sparse
-/// elements in an element attribute. Specifically, the sparse vector/tensor
-/// stores the indices and values as two separate dense elements attributes of
-/// tensor type (even if the sparse attribute is of vector type, in order to
-/// support empty lists). The dense elements attribute indices is a 2-D tensor
-/// of 64-bit integer elements with shape [N, ndims], which specifies the
-/// indices of the elements in the sparse tensor that contains nonzero values.
-/// The dense elements attribute values is a 1-D tensor with shape [N], and it
-/// supplies the corresponding values for the indices.
-///
-/// For example,
-/// `sparse<tensor<3x4xi32>, [[0, 0], [1, 2]], [1, 5]>` represents tensor
-/// [[1, 0, 0, 0],
-/// [0, 0, 5, 0],
-/// [0, 0, 0, 0]].
-class SparseElementsAttr
- : public Attribute::AttrBase<SparseElementsAttr, ElementsAttr,
- detail::SparseElementsAttributeStorage> {
-public:
- using Base::Base;
-
- template <typename T>
- using iterator =
- llvm::mapped_iterator<llvm::detail::value_sequence_iterator<ptrdiff_t>,
- std::function<T(ptrdiff_t)>>;
-
- /// 'type' must be a vector or tensor with static shape.
- static SparseElementsAttr get(ShapedType type, DenseElementsAttr indices,
- DenseElementsAttr values);
-
- DenseIntElementsAttr getIndices() const;
-
- DenseElementsAttr getValues() const;
-
- /// Return the values of this attribute in the form of the given type 'T'. 'T'
- /// may be any of Attribute, APInt, APFloat, c++ integer/float types, etc.
- template <typename T> llvm::iterator_range<iterator<T>> getValues() const {
- auto zeroValue = getZeroValue<T>();
- auto valueIt = getValues().getValues<T>().begin();
- const std::vector<ptrdiff_t> flatSparseIndices(getFlattenedSparseIndices());
- // TODO: Move-capture flatSparseIndices when c++14 is available.
- std::function<T(ptrdiff_t)> mapFn = [=](ptrdiff_t index) {
- // Try to map the current index to one of the sparse indices.
- for (unsigned i = 0, e = flatSparseIndices.size(); i != e; ++i)
- if (flatSparseIndices[i] == index)
- return *std::next(valueIt, i);
- // Otherwise, return the zero value.
- return zeroValue;
- };
- return llvm::map_range(llvm::seq<ptrdiff_t>(0, getNumElements()), mapFn);
- }
-
- /// Return the value of the element at the given index. The 'index' is
- /// expected to refer to a valid element.
- Attribute getValue(ArrayRef<uint64_t> index) const;
-
-private:
- /// Get a zero APFloat for the given sparse attribute.
- APFloat getZeroAPFloat() const;
-
- /// Get a zero APInt for the given sparse attribute.
- APInt getZeroAPInt() const;
-
- /// Get a zero attribute for the given sparse attribute.
- Attribute getZeroAttr() const;
-
- /// Utility methods to generate a zero value of some type 'T'. This is used by
- /// the 'iterator' class.
- /// Get a zero for a given attribute type.
- template <typename T>
- typename std::enable_if<std::is_base_of<Attribute, T>::value, T>::type
- getZeroValue() const {
- return getZeroAttr().template cast<T>();
- }
- /// Get a zero for an APInt.
- template <typename T>
- typename std::enable_if<std::is_same<APInt, T>::value, T>::type
- getZeroValue() const {
- return getZeroAPInt();
- }
- template <typename T>
- typename std::enable_if<std::is_same<std::complex<APInt>, T>::value, T>::type
- getZeroValue() const {
- APInt intZero = getZeroAPInt();
- return {intZero, intZero};
- }
- /// Get a zero for an APFloat.
- template <typename T>
- typename std::enable_if<std::is_same<APFloat, T>::value, T>::type
- getZeroValue() const {
- return getZeroAPFloat();
- }
- template <typename T>
- typename std::enable_if<std::is_same<std::complex<APFloat>, T>::value,
- T>::type
- getZeroValue() const {
- APFloat floatZero = getZeroAPFloat();
- return {floatZero, floatZero};
- }
-
- /// Get a zero for an C++ integer, float, StringRef, or complex type.
- template <typename T>
- typename std::enable_if<
- std::numeric_limits<T>::is_integer ||
- DenseElementsAttr::is_valid_cpp_fp_type<T>::value ||
- std::is_same<T, StringRef>::value ||
- (detail::is_complex_t<T>::value &&
- !llvm::is_one_of<T, std::complex<APInt>,
- std::complex<APFloat>>::value),
- T>::type
- getZeroValue() const {
- return T();
- }
-
- /// Flatten, and return, all of the sparse indices in this attribute in
- /// row-major order.
- std::vector<ptrdiff_t> getFlattenedSparseIndices() const;
-};
-
-/// An attribute that represents a reference to a splat vector or tensor
-/// constant, meaning all of the elements have the same value.
-class SplatElementsAttr : public DenseElementsAttr {
-public:
- using DenseElementsAttr::DenseElementsAttr;
-
- /// Method for support type inquiry through isa, cast and dyn_cast.
- static bool classof(Attribute attr) {
- auto denseAttr = attr.dyn_cast<DenseElementsAttr>();
- return denseAttr && denseAttr.isSplat();
- }
-};
-
-namespace detail {
-/// This class represents a general iterator over the values of an ElementsAttr.
-/// It supports all subclasses aside from OpaqueElementsAttr.
-template <typename T>
-class ElementsAttrIterator
- : public llvm::iterator_facade_base<ElementsAttrIterator<T>,
- std::random_access_iterator_tag, T,
- std::ptrdiff_t, T, T> {
- // NOTE: We use a dummy enable_if here because MSVC cannot use 'decltype'
- // inside of a conversion operator.
- using DenseIteratorT = typename std::enable_if<
- true,
- decltype(std::declval<DenseElementsAttr>().getValues<T>().begin())>::type;
- using SparseIteratorT = SparseElementsAttr::iterator<T>;
-
- /// A union containing the specific iterators for each derived attribute kind.
- union Iterator {
- Iterator(DenseIteratorT &&it) : denseIt(std::move(it)) {}
- Iterator(SparseIteratorT &&it) : sparseIt(std::move(it)) {}
- Iterator() {}
- ~Iterator() {}
-
- operator const DenseIteratorT &() const { return denseIt; }
- operator const SparseIteratorT &() const { return sparseIt; }
- operator DenseIteratorT &() { return denseIt; }
- operator SparseIteratorT &() { return sparseIt; }
-
- /// An instance of a dense elements iterator.
- DenseIteratorT denseIt;
- /// An instance of a sparse elements iterator.
- SparseIteratorT sparseIt;
- };
-
- /// Utility method to process a functor on each of the internal iterator
- /// types.
- template <typename RetT, template <typename> class ProcessFn,
- typename... Args>
- RetT process(Args &... args) const {
- if (attr.isa<DenseElementsAttr>())
- return ProcessFn<DenseIteratorT>()(args...);
- if (attr.isa<SparseElementsAttr>())
- return ProcessFn<SparseIteratorT>()(args...);
- llvm_unreachable("unexpected attribute kind");
- }
-
- /// Utility functors used to generically implement the iterators methods.
- template <typename ItT> struct PlusAssign {
- void operator()(ItT &it, ptrdiff_t offset) { it += offset; }
- };
- template <typename ItT> struct Minus {
- ptrdiff_t operator()(const ItT &lhs, const ItT &rhs) { return lhs - rhs; }
- };
- template <typename ItT> struct MinusAssign {
- void operator()(ItT &it, ptrdiff_t offset) { it -= offset; }
- };
- template <typename ItT> struct Dereference {
- T operator()(ItT &it) { return *it; }
- };
- template <typename ItT> struct ConstructIter {
- void operator()(ItT &dest, const ItT &it) { ::new (&dest) ItT(it); }
- };
- template <typename ItT> struct DestructIter {
- void operator()(ItT &it) { it.~ItT(); }
- };
-
-public:
- ElementsAttrIterator(const ElementsAttrIterator<T> &rhs) : attr(rhs.attr) {
- process<void, ConstructIter>(it, rhs.it);
- }
- ~ElementsAttrIterator() { process<void, DestructIter>(it); }
-
- /// Methods necessary to support random access iteration.
- ptrdiff_t operator-(const ElementsAttrIterator<T> &rhs) const {
- assert(attr == rhs.attr && "incompatible iterators");
- return process<ptrdiff_t, Minus>(it, rhs.it);
- }
- bool operator==(const ElementsAttrIterator<T> &rhs) const {
- return rhs.attr == attr && process<bool, std::equal_to>(it, rhs.it);
- }
- bool operator<(const ElementsAttrIterator<T> &rhs) const {
- assert(attr == rhs.attr && "incompatible iterators");
- return process<bool, std::less>(it, rhs.it);
- }
- ElementsAttrIterator<T> &operator+=(ptrdiff_t offset) {
- process<void, PlusAssign>(it, offset);
- return *this;
- }
- ElementsAttrIterator<T> &operator-=(ptrdiff_t offset) {
- process<void, MinusAssign>(it, offset);
- return *this;
- }
-
- /// Dereference the iterator at the current index.
- T operator*() { return process<T, Dereference>(it); }
-
-private:
- template <typename IteratorT>
- ElementsAttrIterator(Attribute attr, IteratorT &&it)
- : attr(attr), it(std::forward<IteratorT>(it)) {}
-
- /// Allow accessing the constructor.
- friend ElementsAttr;
-
- /// The parent elements attribute.
- Attribute attr;
-
- /// A union containing the specific iterators for each derived kind.
- Iterator it;
-};
-
-template <typename T>
-class ElementsAttrRange : public llvm::iterator_range<ElementsAttrIterator<T>> {
- using llvm::iterator_range<ElementsAttrIterator<T>>::iterator_range;
-};
-} // namespace detail
-
-/// Return the elements of this attribute as a value of type 'T'.
-template <typename T>
-auto ElementsAttr::getValues() const -> iterator_range<T> {
- if (DenseElementsAttr denseAttr = dyn_cast<DenseElementsAttr>()) {
- auto values = denseAttr.getValues<T>();
- return {iterator<T>(*this, values.begin()),
- iterator<T>(*this, values.end())};
- }
- if (SparseElementsAttr sparseAttr = dyn_cast<SparseElementsAttr>()) {
- auto values = sparseAttr.getValues<T>();
- return {iterator<T>(*this, values.begin()),
- iterator<T>(*this, values.end())};
- }
- llvm_unreachable("unexpected attribute kind");
-}
-
-//===----------------------------------------------------------------------===//
-// Attributes Utils
-//===----------------------------------------------------------------------===//
-
template <typename U> bool Attribute::isa() const {
assert(impl && "isa<> used on a null attribute.");
return U::classof(*this);
@@ -1610,80 +116,58 @@ template <typename U> U Attribute::cast() const {
return U(impl);
}
-// Make Attribute hashable.
inline ::llvm::hash_code hash_value(Attribute arg) {
return ::llvm::hash_value(arg.impl);
}
//===----------------------------------------------------------------------===//
-// MutableDictionaryAttr
+// NamedAttribute
//===----------------------------------------------------------------------===//
-/// A MutableDictionaryAttr is a mutable wrapper around a DictionaryAttr. It
-/// provides additional interfaces for adding, removing, replacing attributes
-/// within a DictionaryAttr.
-///
-/// We assume there will be relatively few attributes on a given operation
-/// (maybe a dozen or so, but not hundreds or thousands) so we use linear
-/// searches for everything.
-class MutableDictionaryAttr {
-public:
- MutableDictionaryAttr(DictionaryAttr attrs = nullptr)
- : attrs((attrs && !attrs.empty()) ? attrs : nullptr) {}
- MutableDictionaryAttr(ArrayRef<NamedAttribute> attributes);
-
- bool operator!=(const MutableDictionaryAttr &other) const {
- return !(*this == other);
- }
- bool operator==(const MutableDictionaryAttr &other) const {
- return attrs == other.attrs;
- }
-
- /// Return the underlying dictionary attribute.
- DictionaryAttr getDictionary(MLIRContext *context) const;
-
- /// Return the underlying dictionary attribute or null if there are no
- /// attributes within this dictionary.
- DictionaryAttr getDictionaryOrNull() const { return attrs; }
-
- /// Return all of the attributes on this operation.
- ArrayRef<NamedAttribute> getAttrs() const;
-
- /// Replace the held attributes with ones provided in 'newAttrs'.
- void setAttrs(ArrayRef<NamedAttribute> attributes);
-
- /// Return the specified attribute if present, null otherwise.
- Attribute get(StringRef name) const;
- Attribute get(Identifier name) const;
+/// NamedAttribute is combination of a name, represented by an Identifier, and a
+/// value, represented by an Attribute. The attribute pointer should always be
+/// non-null.
+using NamedAttribute = std::pair<Identifier, Attribute>;
- /// Return the specified named attribute if present, None otherwise.
- Optional<NamedAttribute> getNamed(StringRef name) const;
- Optional<NamedAttribute> getNamed(Identifier name) const;
+bool operator<(const NamedAttribute &lhs, const NamedAttribute &rhs);
+bool operator<(const NamedAttribute &lhs, StringRef rhs);
- /// If the an attribute exists with the specified name, change it to the new
- /// value. Otherwise, add a new attribute with the specified name/value.
- void set(Identifier name, Attribute value);
+//===----------------------------------------------------------------------===//
+// AttributeTraitBase
+//===----------------------------------------------------------------------===//
- enum class RemoveResult { Removed, NotFound };
+namespace AttributeTrait {
+/// This class represents the base of an attribute trait.
+template <typename ConcreteType, template <typename> class TraitType>
+using TraitBase = detail::StorageUserTraitBase<ConcreteType, TraitType>;
+} // namespace AttributeTrait
- /// Remove the attribute with the specified name if it exists. The return
- /// value indicates whether the attribute was present or not.
- RemoveResult remove(Identifier name);
+//===----------------------------------------------------------------------===//
+// AttributeInterface
+//===----------------------------------------------------------------------===//
- bool empty() const { return attrs == nullptr; }
+/// This class represents the base of an attribute interface. See the definition
+/// of `detail::Interface` for requirements on the `Traits` type.
+template <typename ConcreteType, typename Traits>
+class AttributeInterface
+ : public detail::Interface<ConcreteType, Attribute, Traits, Attribute,
+ AttributeTrait::TraitBase> {
+public:
+ using Base = AttributeInterface<ConcreteType, Traits>;
+ using InterfaceBase = detail::Interface<ConcreteType, Attribute, Traits,
+ Attribute, AttributeTrait::TraitBase>;
+ using InterfaceBase::InterfaceBase;
private:
- friend ::llvm::hash_code hash_value(const MutableDictionaryAttr &arg);
+ /// Returns the impl interface instance for the given type.
+ static typename InterfaceBase::Concept *getInterfaceFor(Attribute attr) {
+ return attr.getAbstractAttribute().getInterface<ConcreteType>();
+ }
- DictionaryAttr attrs;
+ /// Allow access to 'getInterfaceFor'.
+ friend InterfaceBase;
};
-inline ::llvm::hash_code hash_value(const MutableDictionaryAttr &arg) {
- if (!arg.attrs)
- return ::llvm::hash_value((void *)nullptr);
- return hash_value(arg.attrs);
-}
-
} // end namespace mlir.
namespace llvm {
@@ -1718,15 +202,6 @@ template <> struct PointerLikeTypeTraits<mlir::Attribute> {
mlir::AttributeStorage *>::NumLowBitsAvailable;
};
-template <>
-struct PointerLikeTypeTraits<mlir::SymbolRefAttr>
- : public PointerLikeTypeTraits<mlir::Attribute> {
- static inline mlir::SymbolRefAttr getFromVoidPointer(void *ptr) {
- return PointerLikeTypeTraits<mlir::Attribute>::getFromVoidPointer(ptr)
- .cast<mlir::SymbolRefAttr>();
- }
-};
-
} // namespace llvm
#endif
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
new file mode 100644
index 000000000000..300741a7923b
--- /dev/null
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -0,0 +1,1559 @@
+//===- BuiltinAttributes.h - MLIR Builtin Attribute Classes -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_BUILTINATTRIBUTES_H
+#define MLIR_IR_BUILTINATTRIBUTES_H
+
+#include "mlir/IR/Attributes.h"
+#include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/Sequence.h"
+#include <complex>
+
+namespace mlir {
+class AffineMap;
+class FunctionType;
+class IntegerSet;
+class Location;
+class ShapedType;
+
+namespace detail {
+
+struct AffineMapAttributeStorage;
+struct ArrayAttributeStorage;
+struct DictionaryAttributeStorage;
+struct IntegerAttributeStorage;
+struct IntegerSetAttributeStorage;
+struct FloatAttributeStorage;
+struct OpaqueAttributeStorage;
+struct StringAttributeStorage;
+struct SymbolRefAttributeStorage;
+struct TypeAttributeStorage;
+
+/// Elements Attributes.
+struct DenseIntOrFPElementsAttributeStorage;
+struct DenseStringElementsAttributeStorage;
+struct OpaqueElementsAttributeStorage;
+struct SparseElementsAttributeStorage;
+} // namespace detail
+
+//===----------------------------------------------------------------------===//
+// AffineMapAttr
+//===----------------------------------------------------------------------===//
+
+class AffineMapAttr
+ : public Attribute::AttrBase<AffineMapAttr, Attribute,
+ detail::AffineMapAttributeStorage> {
+public:
+ using Base::Base;
+ using ValueType = AffineMap;
+
+ static AffineMapAttr get(AffineMap value);
+
+ AffineMap getValue() const;
+};
+
+//===----------------------------------------------------------------------===//
+// ArrayAttr
+//===----------------------------------------------------------------------===//
+
+/// Array attributes are lists of other attributes. They are not necessarily
+/// type homogenous given that attributes don't, in general, carry types.
+class ArrayAttr : public Attribute::AttrBase<ArrayAttr, Attribute,
+ detail::ArrayAttributeStorage> {
+public:
+ using Base::Base;
+ using ValueType = ArrayRef<Attribute>;
+
+ static ArrayAttr get(ArrayRef<Attribute> value, MLIRContext *context);
+
+ ArrayRef<Attribute> getValue() const;
+ Attribute operator[](unsigned idx) const;
+
+ /// Support range iteration.
+ using iterator = llvm::ArrayRef<Attribute>::iterator;
+ iterator begin() const { return getValue().begin(); }
+ iterator end() const { return getValue().end(); }
+ size_t size() const { return getValue().size(); }
+ bool empty() const { return size() == 0; }
+
+private:
+ /// Class for underlying value iterator support.
+ template <typename AttrTy>
+ class attr_value_iterator final
+ : public llvm::mapped_iterator<ArrayAttr::iterator,
+ AttrTy (*)(Attribute)> {
+ public:
+ explicit attr_value_iterator(ArrayAttr::iterator it)
+ : llvm::mapped_iterator<ArrayAttr::iterator, AttrTy (*)(Attribute)>(
+ it, [](Attribute attr) { return attr.cast<AttrTy>(); }) {}
+ AttrTy operator*() const { return (*this->I).template cast<AttrTy>(); }
+ };
+
+public:
+ template <typename AttrTy>
+ iterator_range<attr_value_iterator<AttrTy>> getAsRange() {
+ return llvm::make_range(attr_value_iterator<AttrTy>(begin()),
+ attr_value_iterator<AttrTy>(end()));
+ }
+ template <typename AttrTy, typename UnderlyingTy = typename AttrTy::ValueType>
+ auto getAsValueRange() {
+ return llvm::map_range(getAsRange<AttrTy>(), [](AttrTy attr) {
+ return static_cast<UnderlyingTy>(attr.getValue());
+ });
+ }
+};
+
+//===----------------------------------------------------------------------===//
+// DictionaryAttr
+//===----------------------------------------------------------------------===//
+
+/// Dictionary attribute is an attribute that represents a sorted collection of
+/// named attribute values. The elements are sorted by name, and each name must
+/// be unique within the collection.
+class DictionaryAttr
+ : public Attribute::AttrBase<DictionaryAttr, Attribute,
+ detail::DictionaryAttributeStorage> {
+public:
+ using Base::Base;
+ using ValueType = ArrayRef<NamedAttribute>;
+
+ /// Construct a dictionary attribute with the provided list of named
+ /// attributes. This method assumes that the provided list is unordered. If
+ /// the caller can guarantee that the attributes are ordered by name,
+ /// getWithSorted should be used instead.
+ static DictionaryAttr get(ArrayRef<NamedAttribute> value,
+ MLIRContext *context);
+
+ /// Construct a dictionary with an array of values that is known to already be
+ /// sorted by name and uniqued.
+ static DictionaryAttr getWithSorted(ArrayRef<NamedAttribute> value,
+ MLIRContext *context);
+
+ ArrayRef<NamedAttribute> getValue() const;
+
+ /// Return the specified attribute if present, null otherwise.
+ Attribute get(StringRef name) const;
+ Attribute get(Identifier name) const;
+
+ /// Return the specified named attribute if present, None otherwise.
+ Optional<NamedAttribute> getNamed(StringRef name) const;
+ Optional<NamedAttribute> getNamed(Identifier name) const;
+
+ /// Support range iteration.
+ using iterator = llvm::ArrayRef<NamedAttribute>::iterator;
+ iterator begin() const;
+ iterator end() const;
+ bool empty() const { return size() == 0; }
+ size_t size() const;
+
+ /// Sorts the NamedAttributes in the array ordered by name as expected by
+ /// getWithSorted and returns whether the values were sorted.
+ /// Requires: uniquely named attributes.
+ static bool sort(ArrayRef<NamedAttribute> values,
+ SmallVectorImpl<NamedAttribute> &storage);
+
+ /// Sorts the NamedAttributes in the array ordered by name as expected by
+ /// getWithSorted in place on an array and returns whether the values needed
+ /// to be sorted.
+ /// Requires: uniquely named attributes.
+ static bool sortInPlace(SmallVectorImpl<NamedAttribute> &array);
+
+ /// Returns an entry with a duplicate name in `array`, if it exists, else
+ /// returns llvm::None. If `isSorted` is true, the array is assumed to be
+ /// sorted else it will be sorted in place before finding the duplicate entry.
+ static Optional<NamedAttribute>
+ findDuplicate(SmallVectorImpl<NamedAttribute> &array, bool isSorted);
+
+private:
+ /// Return empty dictionary.
+ static DictionaryAttr getEmpty(MLIRContext *context);
+};
+
+//===----------------------------------------------------------------------===//
+// FloatAttr
+//===----------------------------------------------------------------------===//
+
+class FloatAttr : public Attribute::AttrBase<FloatAttr, Attribute,
+ detail::FloatAttributeStorage> {
+public:
+ using Base::Base;
+ using ValueType = APFloat;
+
+ /// Return a float attribute for the specified value in the specified type.
+ /// These methods should only be used for simple constant values, e.g 1.0/2.0,
+ /// that are known-valid both as host double and the 'type' format.
+ static FloatAttr get(Type type, double value);
+ static FloatAttr getChecked(Type type, double value, Location loc);
+
+ /// Return a float attribute for the specified value in the specified type.
+ static FloatAttr get(Type type, const APFloat &value);
+ static FloatAttr getChecked(Type type, const APFloat &value, Location loc);
+
+ APFloat getValue() const;
+
+ /// This function is used to convert the value to a double, even if it loses
+ /// precision.
+ double getValueAsDouble() const;
+ static double getValueAsDouble(APFloat val);
+
+ /// Verify the construction invariants for a double value.
+ static LogicalResult verifyConstructionInvariants(Location loc, Type type,
+ double value);
+ static LogicalResult verifyConstructionInvariants(Location loc, Type type,
+ const APFloat &value);
+};
+
+//===----------------------------------------------------------------------===//
+// IntegerAttr
+//===----------------------------------------------------------------------===//
+
+class IntegerAttr
+ : public Attribute::AttrBase<IntegerAttr, Attribute,
+ detail::IntegerAttributeStorage> {
+public:
+ using Base::Base;
+ using ValueType = APInt;
+
+ static IntegerAttr get(Type type, int64_t value);
+ static IntegerAttr get(Type type, const APInt &value);
+
+ APInt getValue() const;
+ /// Return the integer value as a 64-bit int. The attribute must be a signless
+ /// integer.
+ // TODO: Change callers to use getValue instead.
+ int64_t getInt() const;
+ /// Return the integer value as a signed 64-bit int. The attribute must be
+ /// a signed integer.
+ int64_t getSInt() const;
+ /// Return the integer value as a unsigned 64-bit int. The attribute must be
+ /// an unsigned integer.
+ uint64_t getUInt() const;
+
+ static LogicalResult verifyConstructionInvariants(Location loc, Type type,
+ int64_t value);
+ static LogicalResult verifyConstructionInvariants(Location loc, Type type,
+ const APInt &value);
+};
+
+//===----------------------------------------------------------------------===//
+// BoolAttr
+
+/// Special case of IntegerAttr to represent boolean integers, i.e., signless i1
+/// integers.
+class BoolAttr : public Attribute {
+public:
+ using Attribute::Attribute;
+ using ValueType = bool;
+
+ static BoolAttr get(bool value, MLIRContext *context);
+
+ /// Enable conversion to IntegerAttr. This uses conversion vs. inheritance to
+ /// avoid bringing in all of IntegerAttrs methods.
+ operator IntegerAttr() const { return IntegerAttr(impl); }
+
+ /// Return the boolean value of this attribute.
+ bool getValue() const;
+
+ /// Methods for support type inquiry through isa, cast, and dyn_cast.
+ static bool classof(Attribute attr);
+};
+
+//===----------------------------------------------------------------------===//
+// IntegerSetAttr
+//===----------------------------------------------------------------------===//
+
+class IntegerSetAttr
+ : public Attribute::AttrBase<IntegerSetAttr, Attribute,
+ detail::IntegerSetAttributeStorage> {
+public:
+ using Base::Base;
+ using ValueType = IntegerSet;
+
+ static IntegerSetAttr get(IntegerSet value);
+
+ IntegerSet getValue() const;
+};
+
+//===----------------------------------------------------------------------===//
+// OpaqueAttr
+//===----------------------------------------------------------------------===//
+
+/// Opaque attributes represent attributes of non-registered dialects. These are
+/// attribute represented in their raw string form, and can only usefully be
+/// tested for attribute equality.
+class OpaqueAttr : public Attribute::AttrBase<OpaqueAttr, Attribute,
+ detail::OpaqueAttributeStorage> {
+public:
+ using Base::Base;
+
+ /// Get or create a new OpaqueAttr with the provided dialect and string data.
+ static OpaqueAttr get(Identifier dialect, StringRef attrData, Type type,
+ MLIRContext *context);
+
+ /// Get or create a new OpaqueAttr with the provided dialect and string data.
+ /// If the given identifier is not a valid namespace for a dialect, then a
+ /// null attribute is returned.
+ static OpaqueAttr getChecked(Identifier dialect, StringRef attrData,
+ Type type, Location location);
+
+ /// Returns the dialect namespace of the opaque attribute.
+ Identifier getDialectNamespace() const;
+
+ /// Returns the raw attribute data of the opaque attribute.
+ StringRef getAttrData() const;
+
+ /// Verify the construction of an opaque attribute.
+ static LogicalResult verifyConstructionInvariants(Location loc,
+ Identifier dialect,
+ StringRef attrData,
+ Type type);
+};
+
+//===----------------------------------------------------------------------===//
+// StringAttr
+//===----------------------------------------------------------------------===//
+
+class StringAttr : public Attribute::AttrBase<StringAttr, Attribute,
+ detail::StringAttributeStorage> {
+public:
+ using Base::Base;
+ using ValueType = StringRef;
+
+ /// Get an instance of a StringAttr with the given string.
+ static StringAttr get(StringRef bytes, MLIRContext *context);
+
+ /// Get an instance of a StringAttr with the given string and Type.
+ static StringAttr get(StringRef bytes, Type type);
+
+ StringRef getValue() const;
+};
+
+//===----------------------------------------------------------------------===//
+// SymbolRefAttr
+//===----------------------------------------------------------------------===//
+
+class FlatSymbolRefAttr;
+
+/// A symbol reference attribute represents a symbolic reference to another
+/// operation.
+class SymbolRefAttr
+ : public Attribute::AttrBase<SymbolRefAttr, Attribute,
+ detail::SymbolRefAttributeStorage> {
+public:
+ using Base::Base;
+
+ /// Construct a symbol reference for the given value name.
+ static FlatSymbolRefAttr get(StringRef value, MLIRContext *ctx);
+
+ /// Construct a symbol reference for the given value name, and a set of nested
+ /// references that are further resolve to a nested symbol.
+ static SymbolRefAttr get(StringRef value,
+ ArrayRef<FlatSymbolRefAttr> references,
+ MLIRContext *ctx);
+
+ /// Returns the name of the top level symbol reference, i.e. the root of the
+ /// reference path.
+ StringRef getRootReference() const;
+
+ /// Returns the name of the fully resolved symbol, i.e. the leaf of the
+ /// reference path.
+ StringRef getLeafReference() const;
+
+ /// Returns the set of nested references representing the path to the symbol
+ /// nested under the root reference.
+ ArrayRef<FlatSymbolRefAttr> getNestedReferences() const;
+};
+
+/// A symbol reference with a reference path containing a single element. This
+/// is used to refer to an operation within the current symbol table.
+class FlatSymbolRefAttr : public SymbolRefAttr {
+public:
+ using SymbolRefAttr::SymbolRefAttr;
+ using ValueType = StringRef;
+
+ /// Construct a symbol reference for the given value name.
+ static FlatSymbolRefAttr get(StringRef value, MLIRContext *ctx) {
+ return SymbolRefAttr::get(value, ctx);
+ }
+
+ /// Returns the name of the held symbol reference.
+ StringRef getValue() const { return getRootReference(); }
+
+ /// Methods for support type inquiry through isa, cast, and dyn_cast.
+ static bool classof(Attribute attr) {
+ SymbolRefAttr refAttr = attr.dyn_cast<SymbolRefAttr>();
+ return refAttr && refAttr.getNestedReferences().empty();
+ }
+
+private:
+ using SymbolRefAttr::get;
+ using SymbolRefAttr::getNestedReferences;
+};
+
+//===----------------------------------------------------------------------===//
+// Type
+//===----------------------------------------------------------------------===//
+
+class TypeAttr : public Attribute::AttrBase<TypeAttr, Attribute,
+ detail::TypeAttributeStorage> {
+public:
+ using Base::Base;
+ using ValueType = Type;
+
+ static TypeAttr get(Type value);
+
+ Type getValue() const;
+};
+
+//===----------------------------------------------------------------------===//
+// UnitAttr
+//===----------------------------------------------------------------------===//
+
+/// Unit attributes are attributes that hold no specific value and are given
+/// meaning by their existence.
+class UnitAttr
+ : public Attribute::AttrBase<UnitAttr, Attribute, AttributeStorage> {
+public:
+ using Base::Base;
+
+ static UnitAttr get(MLIRContext *context);
+};
+
+//===----------------------------------------------------------------------===//
+// Elements Attributes
+//===----------------------------------------------------------------------===//
+
+namespace detail {
+template <typename T>
+class ElementsAttrIterator;
+template <typename T>
+class ElementsAttrRange;
+} // namespace detail
+
+/// A base attribute that represents a reference to a static shaped tensor or
+/// vector constant.
+class ElementsAttr : public Attribute {
+public:
+ using Attribute::Attribute;
+ template <typename T>
+ using iterator = detail::ElementsAttrIterator<T>;
+ template <typename T>
+ using iterator_range = detail::ElementsAttrRange<T>;
+
+ /// Return the type of this ElementsAttr, guaranteed to be a vector or tensor
+ /// with static shape.
+ ShapedType getType() const;
+
+ /// Return the value at the given index. The index is expected to refer to a
+ /// valid element.
+ Attribute getValue(ArrayRef<uint64_t> index) const;
+
+ /// Return the value of type 'T' at the given index, where 'T' corresponds to
+ /// an Attribute type.
+ template <typename T>
+ T getValue(ArrayRef<uint64_t> index) const {
+ return getValue(index).template cast<T>();
+ }
+
+ /// Return the elements of this attribute as a value of type 'T'. Note:
+ /// Aborts if the subclass is OpaqueElementsAttrs, these attrs do not support
+ /// iteration.
+ template <typename T>
+ iterator_range<T> getValues() const;
+
+ /// Return if the given 'index' refers to a valid element in this attribute.
+ bool isValidIndex(ArrayRef<uint64_t> index) const;
+
+ /// Returns the number of elements held by this attribute.
+ int64_t getNumElements() const;
+
+ /// Returns the number of elements held by this attribute.
+ int64_t size() const { return getNumElements(); }
+
+ /// Generates a new ElementsAttr by mapping each int value to a new
+ /// underlying APInt. The new values can represent either an integer or float.
+ /// This ElementsAttr should contain integers.
+ ElementsAttr mapValues(Type newElementType,
+ function_ref<APInt(const APInt &)> mapping) const;
+
+ /// Generates a new ElementsAttr by mapping each float value to a new
+ /// underlying APInt. The new values can represent either an integer or float.
+ /// This ElementsAttr should contain floats.
+ ElementsAttr mapValues(Type newElementType,
+ function_ref<APInt(const APFloat &)> mapping) const;
+
+ /// Method for support type inquiry through isa, cast and dyn_cast.
+ static bool classof(Attribute attr);
+
+protected:
+ /// Returns the 1 dimensional flattened row-major index from the given
+ /// multi-dimensional index.
+ uint64_t getFlattenedIndex(ArrayRef<uint64_t> index) const;
+};
+
+namespace detail {
+/// DenseElementsAttr data is aligned to uint64_t, so this traits class is
+/// necessary to interop with PointerIntPair.
+class DenseElementDataPointerTypeTraits {
+public:
+ static inline const void *getAsVoidPointer(const char *ptr) { return ptr; }
+ static inline const char *getFromVoidPointer(const void *ptr) {
+ return static_cast<const char *>(ptr);
+ }
+
+ // Note: We could steal more bits if the need arises.
+ static constexpr int NumLowBitsAvailable = 1;
+};
+
+/// Pair of raw pointer and a boolean flag of whether the pointer holds a splat,
+using DenseIterPtrAndSplat =
+ llvm::PointerIntPair<const char *, 1, bool,
+ DenseElementDataPointerTypeTraits>;
+
+/// Impl iterator for indexed DenseElementsAttr iterators that records a data
+/// pointer and data index that is adjusted for the case of a splat attribute.
+template <typename ConcreteT, typename T, typename PointerT = T *,
+ typename ReferenceT = T &>
+class DenseElementIndexedIteratorImpl
+ : public llvm::indexed_accessor_iterator<ConcreteT, DenseIterPtrAndSplat, T,
+ PointerT, ReferenceT> {
+protected:
+ DenseElementIndexedIteratorImpl(const char *data, bool isSplat,
+ size_t dataIndex)
+ : llvm::indexed_accessor_iterator<ConcreteT, DenseIterPtrAndSplat, T,
+ PointerT, ReferenceT>({data, isSplat},
+ dataIndex) {}
+
+ /// Return the current index for this iterator, adjusted for the case of a
+ /// splat.
+ ptrdiff_t getDataIndex() const {
+ bool isSplat = this->base.getInt();
+ return isSplat ? 0 : this->index;
+ }
+
+ /// Return the data base pointer.
+ const char *getData() const { return this->base.getPointer(); }
+};
+
+/// Type trait detector that checks if a given type T is a complex type.
+template <typename T>
+struct is_complex_t : public std::false_type {};
+template <typename T>
+struct is_complex_t<std::complex<T>> : public std::true_type {};
+} // namespace detail
+
+/// An attribute that represents a reference to a dense vector or tensor object.
+///
+class DenseElementsAttr : public ElementsAttr {
+public:
+ using ElementsAttr::ElementsAttr;
+
+ /// Type trait used to check if the given type T is a potentially valid C++
+ /// floating point type that can be used to access the underlying element
+ /// types of a DenseElementsAttr.
+ // TODO: Use std::disjunction when C++17 is supported.
+ template <typename T>
+ struct is_valid_cpp_fp_type {
+ /// The type is a valid floating point type if it is a builtin floating
+ /// point type, or is a potentially user defined floating point type. The
+ /// latter allows for supporting users that have custom types defined for
+ /// bfloat16/half/etc.
+ static constexpr bool value = llvm::is_one_of<T, float, double>::value ||
+ (std::numeric_limits<T>::is_specialized &&
+ !std::numeric_limits<T>::is_integer);
+ };
+
+ /// Method for support type inquiry through isa, cast and dyn_cast.
+ static bool classof(Attribute attr);
+
+ /// Constructs a dense elements attribute from an array of element values.
+ /// Each element attribute value is expected to be an element of 'type'.
+ /// 'type' must be a vector or tensor with static shape. If the element of
+ /// `type` is non-integer/index/float it is assumed to be a string type.
+ static DenseElementsAttr get(ShapedType type, ArrayRef<Attribute> values);
+
+ /// Constructs a dense integer elements attribute from an array of integer
+ /// or floating-point values. Each value is expected to be the same bitwidth
+ /// of the element type of 'type'. 'type' must be a vector or tensor with
+ /// static shape.
+ template <typename T, typename = typename std::enable_if<
+ std::numeric_limits<T>::is_integer ||
+ is_valid_cpp_fp_type<T>::value>::type>
+ static DenseElementsAttr get(const ShapedType &type, ArrayRef<T> values) {
+ const char *data = reinterpret_cast<const char *>(values.data());
+ return getRawIntOrFloat(
+ type, ArrayRef<char>(data, values.size() * sizeof(T)), sizeof(T),
+ std::numeric_limits<T>::is_integer, std::numeric_limits<T>::is_signed);
+ }
+
+ /// Constructs a dense integer, floating-point, or complex elements attribute from a single element value.
+ template <typename T, typename = typename std::enable_if<
+ std::numeric_limits<T>::is_integer ||
+ is_valid_cpp_fp_type<T>::value ||
+ detail::is_complex_t<T>::value>::type>
+ static DenseElementsAttr get(const ShapedType &type, T value) {
+ return get(type, llvm::makeArrayRef(value));
+ }
+
+ /// Constructs a dense complex elements attribute from an array of complex
+ /// values. Each value is expected to be the same bitwidth of the element type
+ /// of 'type'. 'type' must be a vector or tensor with static shape.
+ template <typename T, typename ElementT = typename T::value_type,
+ typename = typename std::enable_if<
+ detail::is_complex_t<T>::value &&
+ (std::numeric_limits<ElementT>::is_integer ||
+ is_valid_cpp_fp_type<ElementT>::value)>::type>
+ static DenseElementsAttr get(const ShapedType &type, ArrayRef<T> values) {
+ const char *data = reinterpret_cast<const char *>(values.data());
+ return getRawComplex(type, ArrayRef<char>(data, values.size() * sizeof(T)),
+ sizeof(T), std::numeric_limits<ElementT>::is_integer,
+ std::numeric_limits<ElementT>::is_signed);
+ }
+
+ /// Overload of the above 'get' method that is specialized for boolean values.
+ static DenseElementsAttr get(ShapedType type, ArrayRef<bool> values);
+
+ /// Overload of the above 'get' method that is specialized for StringRef
+ /// values.
+ static DenseElementsAttr get(ShapedType type, ArrayRef<StringRef> values);
+
+ /// Constructs a dense integer elements attribute from an array of APInt
+ /// values. Each APInt value is expected to have the same bitwidth as the
+ /// element type of 'type'. 'type' must be a vector or tensor with static
+ /// shape.
+ static DenseElementsAttr get(ShapedType type, ArrayRef<APInt> values);
+
+ /// Constructs a dense complex elements attribute from an array of APInt
+ /// values. Each APInt value is expected to have the same bitwidth as the
+ /// element type of 'type'. 'type' must be a vector or tensor with static
+ /// shape.
+ static DenseElementsAttr get(ShapedType type,
+ ArrayRef<std::complex<APInt>> values);
+
+ /// Constructs a dense float elements attribute from an array of APFloat
+ /// values. Each APFloat value is expected to have the same bitwidth as the
+ /// element type of 'type'. 'type' must be a vector or tensor with static
+ /// shape.
+ static DenseElementsAttr get(ShapedType type, ArrayRef<APFloat> values);
+
+ /// Constructs a dense complex elements attribute from an array of APFloat
+ /// values. Each APFloat value is expected to have the same bitwidth as the
+ /// element type of 'type'. 'type' must be a vector or tensor with static
+ /// shape.
+ static DenseElementsAttr get(ShapedType type,
+ ArrayRef<std::complex<APFloat>> values);
+
+ /// Construct a dense elements attribute for an initializer_list of values.
+ /// Each value is expected to be the same bitwidth of the element type of
+ /// 'type'. 'type' must be a vector or tensor with static shape.
+ template <typename T>
+ static DenseElementsAttr get(const ShapedType &type,
+ const std::initializer_list<T> &list) {
+ return get(type, ArrayRef<T>(list));
+ }
+
+ /// Construct a dense elements attribute from a raw buffer representing the
+ /// data for this attribute. Users should generally not use this method as
+ /// the expected buffer format may not be a form the user expects.
+ static DenseElementsAttr getFromRawBuffer(ShapedType type,
+ ArrayRef<char> rawBuffer,
+ bool isSplatBuffer);
+
+ /// Returns true if the given buffer is a valid raw buffer for the given type.
+ /// `detectedSplat` is set if the buffer is valid and represents a splat
+ /// buffer.
+ static bool isValidRawBuffer(ShapedType type, ArrayRef<char> rawBuffer,
+ bool &detectedSplat);
+
+ //===--------------------------------------------------------------------===//
+ // Iterators
+ //===--------------------------------------------------------------------===//
+
+ /// A utility iterator that allows walking over the internal Attribute values
+ /// of a DenseElementsAttr.
+ class AttributeElementIterator
+ : public llvm::indexed_accessor_iterator<AttributeElementIterator,
+ const void *, Attribute,
+ Attribute, Attribute> {
+ public:
+ /// Accesses the Attribute value at this iterator position.
+ Attribute operator*() const;
+
+ private:
+ friend DenseElementsAttr;
+
+ /// Constructs a new iterator.
+ AttributeElementIterator(DenseElementsAttr attr, size_t index);
+ };
+
+ /// Iterator for walking raw element values of the specified type 'T', which
+ /// may be any c++ data type matching the stored representation: int32_t,
+ /// float, etc.
+ template <typename T>
+ class ElementIterator
+ : public detail::DenseElementIndexedIteratorImpl<ElementIterator<T>,
+ const T> {
+ public:
+ /// Accesses the raw value at this iterator position.
+ const T &operator*() const {
+ return reinterpret_cast<const T *>(this->getData())[this->getDataIndex()];
+ }
+
+ private:
+ friend DenseElementsAttr;
+
+ /// Constructs a new iterator.
+ ElementIterator(const char *data, bool isSplat, size_t dataIndex)
+ : detail::DenseElementIndexedIteratorImpl<ElementIterator<T>, const T>(
+ data, isSplat, dataIndex) {}
+ };
+
+ /// A utility iterator that allows walking over the internal bool values.
+ class BoolElementIterator
+ : public detail::DenseElementIndexedIteratorImpl<BoolElementIterator,
+ bool, bool, bool> {
+ public:
+ /// Accesses the bool value at this iterator position.
+ bool operator*() const;
+
+ private:
+ friend DenseElementsAttr;
+
+ /// Constructs a new iterator.
+ BoolElementIterator(DenseElementsAttr attr, size_t dataIndex);
+ };
+
+ /// A utility iterator that allows walking over the internal raw APInt values.
+ class IntElementIterator
+ : public detail::DenseElementIndexedIteratorImpl<IntElementIterator,
+ APInt, APInt, APInt> {
+ public:
+ /// Accesses the raw APInt value at this iterator position.
+ APInt operator*() const;
+
+ private:
+ friend DenseElementsAttr;
+
+ /// Constructs a new iterator.
+ IntElementIterator(DenseElementsAttr attr, size_t dataIndex);
+
+ /// The bitwidth of the element type.
+ size_t bitWidth;
+ };
+
+ /// A utility iterator that allows walking over the internal raw complex APInt
+ /// values.
+ class ComplexIntElementIterator
+ : public detail::DenseElementIndexedIteratorImpl<
+ ComplexIntElementIterator, std::complex<APInt>, std::complex<APInt>,
+ std::complex<APInt>> {
+ public:
+ /// Accesses the raw std::complex<APInt> value at this iterator position.
+ std::complex<APInt> operator*() const;
+
+ private:
+ friend DenseElementsAttr;
+
+ /// Constructs a new iterator.
+ ComplexIntElementIterator(DenseElementsAttr attr, size_t dataIndex);
+
+ /// The bitwidth of the element type.
+ size_t bitWidth;
+ };
+
+ /// Iterator for walking over APFloat values.
+ class FloatElementIterator final
+ : public llvm::mapped_iterator<IntElementIterator,
+ std::function<APFloat(const APInt &)>> {
+ friend DenseElementsAttr;
+
+ /// Initializes the float element iterator to the specified iterator.
+ FloatElementIterator(const llvm::fltSemantics &smt, IntElementIterator it);
+
+ public:
+ using reference = APFloat;
+ };
+
+ /// Iterator for walking over complex APFloat values.
+ class ComplexFloatElementIterator final
+ : public llvm::mapped_iterator<
+ ComplexIntElementIterator,
+ std::function<std::complex<APFloat>(const std::complex<APInt> &)>> {
+ friend DenseElementsAttr;
+
+ /// Initializes the complex float element iterator to the specified iterator.
+ ComplexFloatElementIterator(const llvm::fltSemantics &smt,
+ ComplexIntElementIterator it);
+
+ public:
+ using reference = std::complex<APFloat>;
+ };
+
+ //===--------------------------------------------------------------------===//
+ // Value Querying
+ //===--------------------------------------------------------------------===//
+
+ /// Returns true if this attribute corresponds to a splat, i.e. if all element
+ /// values are the same.
+ bool isSplat() const;
+
+ /// Return the splat value for this attribute. This asserts that the attribute
+ /// corresponds to a splat.
+ Attribute getSplatValue() const { return getSplatValue<Attribute>(); }
+ template <typename T>
+ typename std::enable_if<!std::is_base_of<Attribute, T>::value ||
+ std::is_same<Attribute, T>::value,
+ T>::type
+ getSplatValue() const {
+ assert(isSplat() && "expected the attribute to be a splat");
+ return *getValues<T>().begin();
+ }
+ /// Return the splat value for derived attribute element types.
+ template <typename T>
+ typename std::enable_if<std::is_base_of<Attribute, T>::value &&
+ !std::is_same<Attribute, T>::value,
+ T>::type
+ getSplatValue() const {
+ return getSplatValue().template cast<T>();
+ }
+
+ /// Return the value at the given index. The 'index' is expected to refer to a
+ /// valid element.
+ Attribute getValue(ArrayRef<uint64_t> index) const {
+ return getValue<Attribute>(index);
+ }
+ template <typename T>
+ T getValue(ArrayRef<uint64_t> index) const {
+ // Skip to the element corresponding to the flattened index.
+ return *std::next(getValues<T>().begin(), getFlattenedIndex(index));
+ }
+
+ /// Return the held element values as a range of integer or floating-point
+ /// values.
+ template <typename T, typename = typename std::enable_if<
+ (!std::is_same<T, bool>::value &&
+ std::numeric_limits<T>::is_integer) ||
+ is_valid_cpp_fp_type<T>::value>::type>
+ llvm::iterator_range<ElementIterator<T>> getValues() const {
+ assert(isValidIntOrFloat(sizeof(T), std::numeric_limits<T>::is_integer,
+ std::numeric_limits<T>::is_signed));
+ const char *rawData = getRawData().data();
+ bool splat = isSplat();
+ return {ElementIterator<T>(rawData, splat, 0),
+ ElementIterator<T>(rawData, splat, getNumElements())};
+ }
+
+ /// Return the held element values as a range of std::complex.
+ template <typename T, typename ElementT = typename T::value_type,
+ typename = typename std::enable_if<
+ detail::is_complex_t<T>::value &&
+ (std::numeric_limits<ElementT>::is_integer ||
+ is_valid_cpp_fp_type<ElementT>::value)>::type>
+ llvm::iterator_range<ElementIterator<T>> getValues() const {
+ assert(isValidComplex(sizeof(T), std::numeric_limits<ElementT>::is_integer,
+ std::numeric_limits<ElementT>::is_signed));
+ const char *rawData = getRawData().data();
+ bool splat = isSplat();
+ return {ElementIterator<T>(rawData, splat, 0),
+ ElementIterator<T>(rawData, splat, getNumElements())};
+ }
+
+ /// Return the held element values as a range of StringRef.
+ template <typename T, typename = typename std::enable_if<
+ std::is_same<T, StringRef>::value>::type>
+ llvm::iterator_range<ElementIterator<StringRef>> getValues() const {
+ auto stringRefs = getRawStringData();
+ const char *ptr = reinterpret_cast<const char *>(stringRefs.data());
+ bool splat = isSplat();
+ return {ElementIterator<StringRef>(ptr, splat, 0),
+ ElementIterator<StringRef>(ptr, splat, getNumElements())};
+ }
+
+ /// Return the held element values as a range of Attributes.
+ llvm::iterator_range<AttributeElementIterator> getAttributeValues() const;
+ template <typename T, typename = typename std::enable_if<
+ std::is_same<T, Attribute>::value>::type>
+ llvm::iterator_range<AttributeElementIterator> getValues() const {
+ return getAttributeValues();
+ }
+ AttributeElementIterator attr_value_begin() const;
+ AttributeElementIterator attr_value_end() const;
+
+ /// Return the held element values as a range of T, where T is a derived
+ /// attribute type.
+ template <typename T>
+ using DerivedAttributeElementIterator =
+ llvm::mapped_iterator<AttributeElementIterator, T (*)(Attribute)>;
+ template <typename T, typename = typename std::enable_if<
+ std::is_base_of<Attribute, T>::value &&
+ !std::is_same<Attribute, T>::value>::type>
+ llvm::iterator_range<DerivedAttributeElementIterator<T>> getValues() const {
+ auto castFn = [](Attribute attr) { return attr.template cast<T>(); };
+ return llvm::map_range(getAttributeValues(),
+ static_cast<T (*)(Attribute)>(castFn));
+ }
+
+ /// Return the held element values as a range of bool. The element type of
+ /// this attribute must be of integer type of bitwidth 1.
+ llvm::iterator_range<BoolElementIterator> getBoolValues() const;
+ template <typename T, typename = typename std::enable_if<
+ std::is_same<T, bool>::value>::type>
+ llvm::iterator_range<BoolElementIterator> getValues() const {
+ return getBoolValues();
+ }
+
+ /// Return the held element values as a range of APInts. The element type of
+ /// this attribute must be of integer type.
+ llvm::iterator_range<IntElementIterator> getIntValues() const;
+ template <typename T, typename = typename std::enable_if<
+ std::is_same<T, APInt>::value>::type>
+ llvm::iterator_range<IntElementIterator> getValues() const {
+ return getIntValues();
+ }
+ IntElementIterator int_value_begin() const;
+ IntElementIterator int_value_end() const;
+
+ /// Return the held element values as a range of complex APInts. The element
+ /// type of this attribute must be a complex of integer type.
+ llvm::iterator_range<ComplexIntElementIterator> getComplexIntValues() const;
+ template <typename T, typename = typename std::enable_if<
+ std::is_same<T, std::complex<APInt>>::value>::type>
+ llvm::iterator_range<ComplexIntElementIterator> getValues() const {
+ return getComplexIntValues();
+ }
+
+ /// Return the held element values as a range of APFloat. The element type of
+ /// this attribute must be of float type.
+ llvm::iterator_range<FloatElementIterator> getFloatValues() const;
+ template <typename T, typename = typename std::enable_if<
+ std::is_same<T, APFloat>::value>::type>
+ llvm::iterator_range<FloatElementIterator> getValues() const {
+ return getFloatValues();
+ }
+ FloatElementIterator float_value_begin() const;
+ FloatElementIterator float_value_end() const;
+
+ /// Return the held element values as a range of complex APFloat. The element
+ /// type of this attribute must be a complex of float type.
+ llvm::iterator_range<ComplexFloatElementIterator>
+ getComplexFloatValues() const;
+ template <typename T, typename = typename std::enable_if<std::is_same<
+ T, std::complex<APFloat>>::value>::type>
+ llvm::iterator_range<ComplexFloatElementIterator> getValues() const {
+ return getComplexFloatValues();
+ }
+
+ /// Return the raw storage data held by this attribute. Users should generally
+ /// not use this directly, as the internal storage format is not always in the
+ /// form the user might expect.
+ ArrayRef<char> getRawData() const;
+
+ /// Return the raw StringRef data held by this attribute.
+ ArrayRef<StringRef> getRawStringData() const;
+
+ //===--------------------------------------------------------------------===//
+ // Mutation Utilities
+ //===--------------------------------------------------------------------===//
+
+ /// Return a new DenseElementsAttr that has the same data as the current
+ /// attribute, but has been reshaped to 'newType'. The new type must have the
+ /// same total number of elements as well as element type.
+ DenseElementsAttr reshape(ShapedType newType);
+
+ /// Generates a new DenseElementsAttr by mapping each int value to a new
+ /// underlying APInt. The new values can represent either an integer or float.
+ /// This underlying type must be a DenseIntElementsAttr.
+ DenseElementsAttr mapValues(Type newElementType,
+ function_ref<APInt(const APInt &)> mapping) const;
+
+ /// Generates a new DenseElementsAttr by mapping each float value to a new
+ /// underlying APInt. The new values can represent either an integer or float.
+ /// This underlying type must be a DenseFPElementsAttr.
+ DenseElementsAttr
+ mapValues(Type newElementType,
+ function_ref<APInt(const APFloat &)> mapping) const;
+
+protected:
+ /// Get iterators to the raw APInt values for each element in this attribute.
+ IntElementIterator raw_int_begin() const {
+ return IntElementIterator(*this, 0);
+ }
+ IntElementIterator raw_int_end() const {
+ return IntElementIterator(*this, getNumElements());
+ }
+
+ /// Overload of the raw 'get' method that asserts that the given type is of
+ /// complex type. This method is used to verify type invariants that the
+ /// templatized 'get' method cannot.
+ static DenseElementsAttr getRawComplex(ShapedType type, ArrayRef<char> data,
+ int64_t dataEltSize, bool isInt,
+ bool isSigned);
+
+ /// Overload of the raw 'get' method that asserts that the given type is of
+ /// integer or floating-point type. This method is used to verify type
+ /// invariants that the templatized 'get' method cannot.
+ static DenseElementsAttr getRawIntOrFloat(ShapedType type,
+ ArrayRef<char> data,
+ int64_t dataEltSize, bool isInt,
+ bool isSigned);
+
+ /// Given the information for a C++ data type, check if this type is valid for
+ /// the current attribute. This method is used to verify specific type
+ /// invariants that the templatized 'getValues' method cannot.
+ bool isValidIntOrFloat(int64_t dataEltSize, bool isInt, bool isSigned) const;
+
+ /// Given the information for a C++ data type, check if this type is valid for
+ /// the current attribute. This method is used to verify specific type
+ /// invariants that the templatized 'getValues' method cannot.
+ bool isValidComplex(int64_t dataEltSize, bool isInt, bool isSigned) const;
+};
+
+/// An attribute class for representing dense arrays of strings. This class
+/// supports storing and querying a list of densely packed strings.
+class DenseStringElementsAttr
+ : public Attribute::AttrBase<DenseStringElementsAttr, DenseElementsAttr,
+ detail::DenseStringElementsAttributeStorage> {
+
+public:
+ using Base::Base;
+
+ /// Constructs a dense string elements attribute of the given 'type' from an
+ /// array of StringRef values. Presumably 'type' must be a vector or tensor
+ /// with static shape, as for DenseElementsAttr::get; confirm in the .cpp.
+ static DenseStringElementsAttr get(ShapedType type, ArrayRef<StringRef> data);
+
+protected:
+ friend DenseElementsAttr;
+};
+
+/// An attribute class for specializing behavior of integer and floating-point
+/// densely packed arrays.
+class DenseIntOrFPElementsAttr
+ : public Attribute::AttrBase<DenseIntOrFPElementsAttr, DenseElementsAttr,
+ detail::DenseIntOrFPElementsAttributeStorage> {
+
+public:
+ using Base::Base;
+
+ /// Convert endianness of input ArrayRef for big-endian(BE) machines. All of
+ /// the elements of `inRawData` have `type`. If `inRawData` is little endian
+ /// (LE), it is converted to big endian (BE). Conversely, if `inRawData` is
+ /// BE, it is converted to LE.
+ static void
+ convertEndianOfArrayRefForBEmachine(ArrayRef<char> inRawData,
+ MutableArrayRef<char> outRawData,
+ ShapedType type);
+
+ /// Convert endianness of input for big-endian(BE) machines. The number of
+ /// elements of `inRawData` is `numElements`, and each element has
+ /// `elementBitWidth` bits. If `inRawData` is little endian (LE), it is
+ /// converted to big endian (BE) and saved in `outRawData`. Conversely, if
+ /// `inRawData` is BE, it is converted to LE.
+ static void convertEndianOfCharForBEmachine(const char *inRawData,
+ char *outRawData,
+ size_t elementBitWidth,
+ size_t numElements);
+
+protected:
+ friend DenseElementsAttr;
+
+ /// Constructs a dense elements attribute from an array of raw APFloat values.
+ /// Each APFloat value is expected to have the same bitwidth as the element
+ /// type of 'type'. 'type' must be a vector or tensor with static shape.
+ static DenseElementsAttr getRaw(ShapedType type, size_t storageWidth,
+ ArrayRef<APFloat> values, bool isSplat);
+
+ /// Constructs a dense elements attribute from an array of raw APInt values.
+ /// Each APInt value is expected to have the same bitwidth as the element type
+ /// of 'type'. 'type' must be a vector or tensor with static shape.
+ static DenseElementsAttr getRaw(ShapedType type, size_t storageWidth,
+ ArrayRef<APInt> values, bool isSplat);
+
+ /// Get or create a new dense elements attribute instance with the given raw
+ /// data buffer. 'type' must be a vector or tensor with static shape.
+ static DenseElementsAttr getRaw(ShapedType type, ArrayRef<char> data,
+ bool isSplat);
+
+ /// Overload of the raw 'get' method that asserts that the given type is of
+ /// complex type. This method is used to verify type invariants that the
+ /// templatized 'get' method cannot.
+ static DenseElementsAttr getRawComplex(ShapedType type, ArrayRef<char> data,
+ int64_t dataEltSize, bool isInt,
+ bool isSigned);
+
+ /// Overload of the raw 'get' method that asserts that the given type is of
+ /// integer or floating-point type. This method is used to verify type
+ /// invariants that the templatized 'get' method cannot.
+ static DenseElementsAttr getRawIntOrFloat(ShapedType type,
+ ArrayRef<char> data,
+ int64_t dataEltSize, bool isInt,
+ bool isSigned);
+};
+
+/// An attribute that represents a reference to a dense float vector or tensor
+/// object. Each element is stored as a double.
+class DenseFPElementsAttr : public DenseIntOrFPElementsAttr {
+public:
+ using iterator = DenseElementsAttr::FloatElementIterator;
+
+ using DenseIntOrFPElementsAttr::DenseIntOrFPElementsAttr;
+
+ /// Get an instance of a DenseFPElementsAttr with the given arguments. This
+ /// simply wraps the DenseElementsAttr::get calls and casts the result.
+ template <typename Arg>
+ static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg) {
+ return DenseElementsAttr::get(type, llvm::makeArrayRef(arg))
+ .template cast<DenseFPElementsAttr>();
+ }
+ template <typename T>
+ static DenseFPElementsAttr get(const ShapedType &type,
+ const std::initializer_list<T> &list) {
+ return DenseElementsAttr::get(type, list)
+ .template cast<DenseFPElementsAttr>();
+ }
+
+ /// Generates a new DenseElementsAttr by applying 'mapping' to each APFloat
+ /// element, and constructing the result with the new element type.
+ DenseElementsAttr
+ mapValues(Type newElementType,
+ function_ref<APInt(const APFloat &)> mapping) const;
+
+ /// Iterator access to the float element values.
+ iterator begin() const { return float_value_begin(); }
+ iterator end() const { return float_value_end(); }
+
+ /// Method for supporting type inquiry through isa, cast and dyn_cast.
+ static bool classof(Attribute attr);
+};
+
+/// An attribute that represents a reference to a dense integer vector or tensor
+/// object.
+class DenseIntElementsAttr : public DenseIntOrFPElementsAttr {
+public:
+ /// DenseIntElementsAttr iterates on APInt, so we can use the raw element
+ /// iterator directly.
+ using iterator = DenseElementsAttr::IntElementIterator;
+
+ using DenseIntOrFPElementsAttr::DenseIntOrFPElementsAttr;
+
+ /// Get an instance of a DenseIntElementsAttr with the given arguments. This
+ /// simply wraps the DenseElementsAttr::get calls and casts the result.
+ template <typename Arg>
+ static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg) {
+ return DenseElementsAttr::get(type, llvm::makeArrayRef(arg))
+ .template cast<DenseIntElementsAttr>();
+ }
+ template <typename T>
+ static DenseIntElementsAttr get(const ShapedType &type,
+ const std::initializer_list<T> &list) {
+ return DenseElementsAttr::get(type, list)
+ .template cast<DenseIntElementsAttr>();
+ }
+
+ /// Generates a new DenseElementsAttr by applying 'mapping' to each APInt
+ /// element, and constructing the result with the new element type.
+ DenseElementsAttr mapValues(Type newElementType,
+ function_ref<APInt(const APInt &)> mapping) const;
+
+ /// Iterator access to the integer element values.
+ iterator begin() const { return raw_int_begin(); }
+ iterator end() const { return raw_int_end(); }
+
+ /// Method for supporting type inquiry through isa, cast and dyn_cast.
+ static bool classof(Attribute attr);
+};
+
+/// An opaque attribute that represents a reference to a vector or tensor
+/// constant with opaque content. This representation is for tensor constants
+/// which the compiler may not need to interpret. This attribute is always
+/// associated with a particular dialect, which provides a method to convert
+/// tensor representation to a non-opaque format.
+class OpaqueElementsAttr
+ : public Attribute::AttrBase<OpaqueElementsAttr, ElementsAttr,
+ detail::OpaqueElementsAttributeStorage> {
+public:
+ using Base::Base;
+ using ValueType = StringRef;
+
+ static OpaqueElementsAttr get(Dialect *dialect, ShapedType type,
+ StringRef bytes);
+
+ StringRef getValue() const;
+
+ /// Return the value at the given index. The 'index' is expected to refer to a
+ /// valid element.
+ Attribute getValue(ArrayRef<uint64_t> index) const;
+
+ /// Decodes the attribute value using a dialect-specific decoding hook.
+ /// Returns false if decoding is successful. If not, returns true and leaves
+ /// the 'result' argument unspecified.
+ bool decode(ElementsAttr &result);
+
+ /// Returns the dialect associated with this opaque constant.
+ Dialect *getDialect() const;
+};
+
+/// An attribute that represents a reference to a sparse vector or tensor
+/// object.
+///
+/// This class uses COO (coordinate list) encoding to represent the sparse
+/// elements in an element attribute. Specifically, the sparse vector/tensor
+/// stores the indices and values as two separate dense elements attributes of
+/// tensor type (even if the sparse attribute is of vector type, in order to
+/// support empty lists). The dense elements attribute indices is a 2-D tensor
+/// of 64-bit integer elements with shape [N, ndims], which specifies the
+/// indices of the elements in the sparse tensor that contain nonzero values.
+/// The dense elements attribute values is a 1-D tensor with shape [N], and it
+/// supplies the corresponding values for the indices.
+///
+/// For example,
+/// `sparse<tensor<3x4xi32>, [[0, 0], [1, 2]], [1, 5]>` represents tensor
+/// [[1, 0, 0, 0],
+/// [0, 0, 5, 0],
+/// [0, 0, 0, 0]].
+class SparseElementsAttr
+ : public Attribute::AttrBase<SparseElementsAttr, ElementsAttr,
+ detail::SparseElementsAttributeStorage> {
+public:
+ using Base::Base;
+
+ template <typename T>
+ using iterator =
+ llvm::mapped_iterator<llvm::detail::value_sequence_iterator<ptrdiff_t>,
+ std::function<T(ptrdiff_t)>>;
+
+ /// 'type' must be a vector or tensor with static shape.
+ static SparseElementsAttr get(ShapedType type, DenseElementsAttr indices,
+ DenseElementsAttr values);
+
+ DenseIntElementsAttr getIndices() const;
+
+ DenseElementsAttr getValues() const;
+
+ /// Return the values of this attribute in the form of the given type 'T'. 'T'
+ /// may be any of Attribute, APInt, APFloat, c++ integer/float types, etc.
+ template <typename T>
+ llvm::iterator_range<iterator<T>> getValues() const {
+ auto zeroValue = getZeroValue<T>();
+ auto valueIt = getValues().getValues<T>().begin();
+ const std::vector<ptrdiff_t> flatSparseIndices(getFlattenedSparseIndices());
+ // TODO: Move-capture flatSparseIndices when c++14 is available.
+ std::function<T(ptrdiff_t)> mapFn = [=](ptrdiff_t index) {
+ // Try to map the current index to one of the sparse indices.
+ for (unsigned i = 0, e = flatSparseIndices.size(); i != e; ++i)
+ if (flatSparseIndices[i] == index)
+ return *std::next(valueIt, i);
+ // Otherwise, return the zero value.
+ return zeroValue;
+ };
+ return llvm::map_range(llvm::seq<ptrdiff_t>(0, getNumElements()), mapFn);
+ }
+
+ /// Return the value of the element at the given index. The 'index' is
+ /// expected to refer to a valid element.
+ Attribute getValue(ArrayRef<uint64_t> index) const;
+
+private:
+ /// Get a zero APFloat for the given sparse attribute.
+ APFloat getZeroAPFloat() const;
+
+ /// Get a zero APInt for the given sparse attribute.
+ APInt getZeroAPInt() const;
+
+ /// Get a zero attribute for the given sparse attribute.
+ Attribute getZeroAttr() const;
+
+ /// Utility methods to generate a zero value of some type 'T'. This is used by
+ /// the 'iterator' class.
+ /// Get a zero for a given attribute type.
+ template <typename T>
+ typename std::enable_if<std::is_base_of<Attribute, T>::value, T>::type
+ getZeroValue() const {
+ return getZeroAttr().template cast<T>();
+ }
+ /// Get a zero for an APInt.
+ template <typename T>
+ typename std::enable_if<std::is_same<APInt, T>::value, T>::type
+ getZeroValue() const {
+ return getZeroAPInt();
+ }
+ template <typename T>
+ typename std::enable_if<std::is_same<std::complex<APInt>, T>::value, T>::type
+ getZeroValue() const {
+ APInt intZero = getZeroAPInt();
+ return {intZero, intZero};
+ }
+ /// Get a zero for an APFloat.
+ template <typename T>
+ typename std::enable_if<std::is_same<APFloat, T>::value, T>::type
+ getZeroValue() const {
+ return getZeroAPFloat();
+ }
+ template <typename T>
+ typename std::enable_if<std::is_same<std::complex<APFloat>, T>::value,
+ T>::type
+ getZeroValue() const {
+ APFloat floatZero = getZeroAPFloat();
+ return {floatZero, floatZero};
+ }
+
+ /// Get a zero for a C++ integer, float, StringRef, or complex type.
+ template <typename T>
+ typename std::enable_if<
+ std::numeric_limits<T>::is_integer ||
+ DenseElementsAttr::is_valid_cpp_fp_type<T>::value ||
+ std::is_same<T, StringRef>::value ||
+ (detail::is_complex_t<T>::value &&
+ !llvm::is_one_of<T, std::complex<APInt>,
+ std::complex<APFloat>>::value),
+ T>::type
+ getZeroValue() const {
+ return T();
+ }
+
+ /// Flatten, and return, all of the sparse indices in this attribute in
+ /// row-major order.
+ std::vector<ptrdiff_t> getFlattenedSparseIndices() const;
+};
+
+/// An attribute that represents a reference to a splat vector or tensor
+/// constant, meaning all of the elements have the same value.
+class SplatElementsAttr : public DenseElementsAttr {
+public:
+ using DenseElementsAttr::DenseElementsAttr;
+
+ /// Method for supporting type inquiry through isa, cast and dyn_cast.
+ static bool classof(Attribute attr) {
+ auto denseAttr = attr.dyn_cast<DenseElementsAttr>();
+ return denseAttr && denseAttr.isSplat();
+ }
+};
+
+namespace detail {
+/// This class represents a general iterator over the values of an ElementsAttr.
+/// It supports all subclasses aside from OpaqueElementsAttr.
+template <typename T>
+class ElementsAttrIterator
+ : public llvm::iterator_facade_base<ElementsAttrIterator<T>,
+ std::random_access_iterator_tag, T,
+ std::ptr
diff _t, T, T> {
+ // NOTE: We use a dummy enable_if here because MSVC cannot use 'decltype'
+ // inside of a conversion operator.
+ using DenseIteratorT = typename std::enable_if<
+ true,
+ decltype(std::declval<DenseElementsAttr>().getValues<T>().begin())>::type;
+ using SparseIteratorT = SparseElementsAttr::iterator<T>;
+
+ /// A union containing the specific iterators for each derived attribute kind.
+ union Iterator {
+ Iterator(DenseIteratorT &&it) : denseIt(std::move(it)) {}
+ Iterator(SparseIteratorT &&it) : sparseIt(std::move(it)) {}
+ Iterator() {}
+ ~Iterator() {}
+
+ operator const DenseIteratorT &() const { return denseIt; }
+ operator const SparseIteratorT &() const { return sparseIt; }
+ operator DenseIteratorT &() { return denseIt; }
+ operator SparseIteratorT &() { return sparseIt; }
+
+ /// An instance of a dense elements iterator.
+ DenseIteratorT denseIt;
+ /// An instance of a sparse elements iterator.
+ SparseIteratorT sparseIt;
+ };
+
+ /// Utility method to process a functor on each of the internal iterator
+ /// types.
+ template <typename RetT, template <typename> class ProcessFn,
+ typename... Args>
+ RetT process(Args &...args) const {
+ if (attr.isa<DenseElementsAttr>())
+ return ProcessFn<DenseIteratorT>()(args...);
+ if (attr.isa<SparseElementsAttr>())
+ return ProcessFn<SparseIteratorT>()(args...);
+ llvm_unreachable("unexpected attribute kind");
+ }
+
+ /// Utility functors used to generically implement the iterators methods.
+ template <typename ItT>
+ struct PlusAssign {
+ void operator()(ItT &it, ptr
diff _t offset) { it += offset; }
+ };
+ template <typename ItT>
+ struct Minus {
+ ptr
diff _t operator()(const ItT &lhs, const ItT &rhs) { return lhs - rhs; }
+ };
+ template <typename ItT>
+ struct MinusAssign {
+ void operator()(ItT &it, ptr
diff _t offset) { it -= offset; }
+ };
+ template <typename ItT>
+ struct Dereference {
+ T operator()(ItT &it) { return *it; }
+ };
+ template <typename ItT>
+ struct ConstructIter {
+ void operator()(ItT &dest, const ItT &it) { ::new (&dest) ItT(it); }
+ };
+ template <typename ItT>
+ struct DestructIter {
+ void operator()(ItT &it) { it.~ItT(); }
+ };
+
+public:
+ ElementsAttrIterator(const ElementsAttrIterator<T> &rhs) : attr(rhs.attr) {
+ process<void, ConstructIter>(it, rhs.it);
+ }
+ ~ElementsAttrIterator() { process<void, DestructIter>(it); }
+
+ /// Methods necessary to support random access iteration.
+ ptrdiff_t operator-(const ElementsAttrIterator<T> &rhs) const {
+ assert(attr == rhs.attr && "incompatible iterators");
+ return process<ptrdiff_t, Minus>(it, rhs.it);
+ }
+ bool operator==(const ElementsAttrIterator<T> &rhs) const {
+ return rhs.attr == attr && process<bool, std::equal_to>(it, rhs.it);
+ }
+ bool operator<(const ElementsAttrIterator<T> &rhs) const {
+ assert(attr == rhs.attr && "incompatible iterators");
+ return process<bool, std::less>(it, rhs.it);
+ }
+ ElementsAttrIterator<T> &operator+=(ptrdiff_t offset) {
+ process<void, PlusAssign>(it, offset);
+ return *this;
+ }
+ ElementsAttrIterator<T> &operator-=(ptrdiff_t offset) {
+ process<void, MinusAssign>(it, offset);
+ return *this;
+ }
+
+ /// Dereference the iterator at the current index.
+ T operator*() { return process<T, Dereference>(it); }
+
+private:
+ template <typename IteratorT>
+ ElementsAttrIterator(Attribute attr, IteratorT &&it)
+ : attr(attr), it(std::forward<IteratorT>(it)) {}
+
+ /// Allow accessing the constructor.
+ friend ElementsAttr;
+
+ /// The parent elements attribute.
+ Attribute attr;
+
+ /// A union containing the specific iterators for each derived kind.
+ Iterator it;
+};
+
+template <typename T>
+class ElementsAttrRange : public llvm::iterator_range<ElementsAttrIterator<T>> {
+ using llvm::iterator_range<ElementsAttrIterator<T>>::iterator_range;
+};
+} // namespace detail
+
+/// Return the elements of this attribute as a value of type 'T'.
+template <typename T>
+auto ElementsAttr::getValues() const -> iterator_range<T> {
+ if (DenseElementsAttr denseAttr = dyn_cast<DenseElementsAttr>()) {
+ auto values = denseAttr.getValues<T>();
+ return {iterator<T>(*this, values.begin()),
+ iterator<T>(*this, values.end())};
+ }
+ if (SparseElementsAttr sparseAttr = dyn_cast<SparseElementsAttr>()) {
+ auto values = sparseAttr.getValues<T>();
+ return {iterator<T>(*this, values.begin()),
+ iterator<T>(*this, values.end())};
+ }
+ llvm_unreachable("unexpected attribute kind");
+}
+
+//===----------------------------------------------------------------------===//
+// MutableDictionaryAttr
+//===----------------------------------------------------------------------===//
+
+/// A MutableDictionaryAttr is a mutable wrapper around a DictionaryAttr. It
+/// provides additional interfaces for adding, removing, replacing attributes
+/// within a DictionaryAttr.
+///
+/// We assume there will be relatively few attributes on a given operation
+/// (maybe a dozen or so, but not hundreds or thousands) so we use linear
+/// searches for everything.
+class MutableDictionaryAttr {
+public:
+ MutableDictionaryAttr(DictionaryAttr attrs = nullptr)
+ : attrs((attrs && !attrs.empty()) ? attrs : nullptr) {}
+ MutableDictionaryAttr(ArrayRef<NamedAttribute> attributes);
+
+ bool operator!=(const MutableDictionaryAttr &other) const {
+ return !(*this == other);
+ }
+ bool operator==(const MutableDictionaryAttr &other) const {
+ return attrs == other.attrs;
+ }
+
+ /// Return the underlying dictionary attribute.
+ DictionaryAttr getDictionary(MLIRContext *context) const;
+
+ /// Return the underlying dictionary attribute or null if there are no
+ /// attributes within this dictionary.
+ DictionaryAttr getDictionaryOrNull() const { return attrs; }
+
+ /// Return all of the attributes on this operation.
+ ArrayRef<NamedAttribute> getAttrs() const;
+
+ /// Replace the held attributes with the ones provided in 'attributes'.
+ void setAttrs(ArrayRef<NamedAttribute> attributes);
+
+ /// Return the specified attribute if present, null otherwise.
+ Attribute get(StringRef name) const;
+ Attribute get(Identifier name) const;
+
+ /// Return the specified named attribute if present, None otherwise.
+ Optional<NamedAttribute> getNamed(StringRef name) const;
+ Optional<NamedAttribute> getNamed(Identifier name) const;
+
+ /// If an attribute exists with the specified name, change it to the new
+ /// value. Otherwise, add a new attribute with the specified name/value.
+ void set(Identifier name, Attribute value);
+
+ enum class RemoveResult { Removed, NotFound };
+
+ /// Remove the attribute with the specified name if it exists. The return
+ /// value indicates whether the attribute was present or not.
+ RemoveResult remove(Identifier name);
+
+ bool empty() const { return attrs == nullptr; }
+
+private:
+ friend ::llvm::hash_code hash_value(const MutableDictionaryAttr &arg);
+
+ DictionaryAttr attrs;
+};
+
+inline ::llvm::hash_code hash_value(const MutableDictionaryAttr &arg) {
+ if (!arg.attrs)
+ return ::llvm::hash_value((void *)nullptr);
+ return hash_value(arg.attrs);
+}
+
+} // end namespace mlir.
+
+namespace llvm {
+
+template <>
+struct PointerLikeTypeTraits<mlir::SymbolRefAttr>
+ : public PointerLikeTypeTraits<mlir::Attribute> {
+ static inline mlir::SymbolRefAttr getFromVoidPointer(void *ptr) {
+ return PointerLikeTypeTraits<mlir::Attribute>::getFromVoidPointer(ptr)
+ .cast<mlir::SymbolRefAttr>();
+ }
+};
+
+} // namespace llvm
+
+#endif // MLIR_IR_BUILTINATTRIBUTES_H
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index 3d5bc66ee9e2..5ed5cd9930fa 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -14,6 +14,7 @@
#define MLIR_IR_OPERATION_H
#include "mlir/IR/Block.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Region.h"
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 74899c9565fe..88857f174783 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -30,6 +30,9 @@
namespace mlir {
class Dialect;
+class DictionaryAttr;
+class ElementsAttr;
+class MutableDictionaryAttr;
class Operation;
struct OperationState;
class OpAsmParser;
diff --git a/mlir/include/mlir/Interfaces/DecodeAttributesInterfaces.h b/mlir/include/mlir/Interfaces/DecodeAttributesInterfaces.h
index 2b7607e6655a..5a224d93cfdf 100644
--- a/mlir/include/mlir/Interfaces/DecodeAttributesInterfaces.h
+++ b/mlir/include/mlir/Interfaces/DecodeAttributesInterfaces.h
@@ -8,7 +8,7 @@
#ifndef MLIR_INTERFACES_DECODEATTRIBUTESINTERFACES_H_
#define MLIR_INTERFACES_DECODEATTRIBUTESINTERFACES_H_
-#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/DialectInterface.h"
#include "mlir/Support/LogicalResult.h"
diff --git a/mlir/include/mlir/Transforms/LoopUtils.h b/mlir/include/mlir/Transforms/LoopUtils.h
index aaff786fbe2f..6e11249054bc 100644
--- a/mlir/include/mlir/Transforms/LoopUtils.h
+++ b/mlir/include/mlir/Transforms/LoopUtils.h
@@ -21,6 +21,7 @@
namespace mlir {
class AffineForOp;
+class AffineMap;
class FuncOp;
class LoopLikeOpInterface;
struct MemRefRegion;
diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index d413816a45fa..3a80064866c0 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -12,9 +12,9 @@
#include "PybindUtils.h"
#include "mlir-c/Bindings/Python/Interop.h"
+#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Registration.h"
-#include "mlir-c/StandardAttributes.h"
#include "llvm/ADT/SmallVector.h"
#include <pybind11/stl.h>
@@ -1396,7 +1396,7 @@ class PyOpAttributeMap {
} // end namespace
//------------------------------------------------------------------------------
-// Standard attribute subclasses.
+// Builtin attribute subclasses.
//------------------------------------------------------------------------------
namespace {
@@ -3045,7 +3045,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
py::keep_alive<0, 1>(),
"The underlying generic attribute of the NamedAttribute binding");
- // Standard attribute bindings.
+ // Builtin attribute bindings.
PyFloatAttribute::bind(m);
PyIntegerAttribute::bind(m);
PyBoolAttribute::bind(m);
diff --git a/mlir/lib/CAPI/IR/StandardAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
similarity index 99%
rename from mlir/lib/CAPI/IR/StandardAttributes.cpp
rename to mlir/lib/CAPI/IR/BuiltinAttributes.cpp
index 9c48bc963794..8db4ac3d3a38 100644
--- a/mlir/lib/CAPI/IR/StandardAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -1,4 +1,4 @@
-//===- StandardAttributes.cpp - C Interface to MLIR Standard Attributes ---===//
+//===- BuiltinAttributes.cpp - C Interface to MLIR Builtin Attributes -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir-c/StandardAttributes.h"
+#include "mlir-c/BuiltinAttributes.h"
#include "mlir/CAPI/AffineMap.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"
diff --git a/mlir/lib/CAPI/IR/CMakeLists.txt b/mlir/lib/CAPI/IR/CMakeLists.txt
index 427c0a655309..411e0582bf4d 100644
--- a/mlir/lib/CAPI/IR/CMakeLists.txt
+++ b/mlir/lib/CAPI/IR/CMakeLists.txt
@@ -2,11 +2,11 @@
add_mlir_public_c_api_library(MLIRCAPIIR
AffineExpr.cpp
AffineMap.cpp
+ BuiltinAttributes.cpp
BuiltinTypes.cpp
Diagnostics.cpp
IR.cpp
Pass.cpp
- StandardAttributes.cpp
Support.cpp
LINK_LIBS PUBLIC
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index 1913fde9de1e..634917fe0faf 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -8,7 +8,7 @@
#include "mlir/IR/AffineMap.h"
#include "AffineMapDetail.h"
-#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Support/MathExtras.h"
diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h
index 3a140790fe51..db86d7a1c27b 100644
--- a/mlir/lib/IR/AttributeDetail.h
+++ b/mlir/lib/IR/AttributeDetail.h
@@ -14,7 +14,7 @@
#define ATTRIBUTEDETAIL_H_
#include "mlir/IR/AffineMap.h"
-#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Identifier.h"
#include "mlir/IR/IntegerSet.h"
diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp
index 06f741f6ef9f..bc7816430622 100644
--- a/mlir/lib/IR/Attributes.cpp
+++ b/mlir/lib/IR/Attributes.cpp
@@ -7,16 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/IR/Attributes.h"
-#include "AttributeDetail.h"
-#include "mlir/IR/AffineMap.h"
-#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
-#include "mlir/IR/IntegerSet.h"
-#include "mlir/IR/Types.h"
-#include "mlir/Interfaces/DecodeAttributesInterfaces.h"
-#include "llvm/ADT/Sequence.h"
-#include "llvm/ADT/Twine.h"
-#include "llvm/Support/Endian.h"
using namespace mlir;
using namespace mlir::detail;
@@ -52,1550 +43,9 @@ Dialect &Attribute::getDialect() const {
}
//===----------------------------------------------------------------------===//
-// AffineMapAttr
+// NamedAttribute
//===----------------------------------------------------------------------===//
-AffineMapAttr AffineMapAttr::get(AffineMap value) {
- return Base::get(value.getContext(), value);
-}
-
-AffineMap AffineMapAttr::getValue() const { return getImpl()->value; }
-
-//===----------------------------------------------------------------------===//
-// ArrayAttr
-//===----------------------------------------------------------------------===//
-
-ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) {
- return Base::get(context, value);
-}
-
-ArrayRef<Attribute> ArrayAttr::getValue() const { return getImpl()->value; }
-
-Attribute ArrayAttr::operator[](unsigned idx) const {
- assert(idx < size() && "index out of bounds");
- return getValue()[idx];
-}
-
-//===----------------------------------------------------------------------===//
-// DictionaryAttr
-//===----------------------------------------------------------------------===//
-
-/// Helper function that does either an in place sort or sorts from source array
-/// into destination. If inPlace then storage is both the source and the
-/// destination, else value is the source and storage destination. Returns
-/// whether source was sorted.
-template <bool inPlace>
-static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value,
- SmallVectorImpl<NamedAttribute> &storage) {
- // Specialize for the common case.
- switch (value.size()) {
- case 0:
- // Zero already sorted.
- break;
- case 1:
- // One already sorted but may need to be copied.
- if (!inPlace)
- storage.assign({value[0]});
- break;
- case 2: {
- bool isSorted = value[0] < value[1];
- if (inPlace) {
- if (!isSorted)
- std::swap(storage[0], storage[1]);
- } else if (isSorted) {
- storage.assign({value[0], value[1]});
- } else {
- storage.assign({value[1], value[0]});
- }
- return !isSorted;
- }
- default:
- if (!inPlace)
- storage.assign(value.begin(), value.end());
- // Check to see they are sorted already.
- bool isSorted = llvm::is_sorted(value);
- if (!isSorted) {
- // If not, do a general sort.
- llvm::array_pod_sort(storage.begin(), storage.end());
- value = storage;
- }
- return !isSorted;
- }
- return false;
-}
-
-/// Returns an entry with a duplicate name from the given sorted array of named
-/// attributes. Returns llvm::None if all elements have unique names.
-static Optional<NamedAttribute>
-findDuplicateElement(ArrayRef<NamedAttribute> value) {
- const Optional<NamedAttribute> none{llvm::None};
- if (value.size() < 2)
- return none;
-
- if (value.size() == 2)
- return value[0].first == value[1].first ? value[0] : none;
-
- auto it = std::adjacent_find(
- value.begin(), value.end(),
- [](NamedAttribute l, NamedAttribute r) { return l.first == r.first; });
- return it != value.end() ? *it : none;
-}
-
-bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value,
- SmallVectorImpl<NamedAttribute> &storage) {
- bool isSorted = dictionaryAttrSort</*inPlace=*/false>(value, storage);
- assert(!findDuplicateElement(storage) &&
- "DictionaryAttr element names must be unique");
- return isSorted;
-}
-
-bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) {
- bool isSorted = dictionaryAttrSort</*inPlace=*/true>(array, array);
- assert(!findDuplicateElement(array) &&
- "DictionaryAttr element names must be unique");
- return isSorted;
-}
-
-Optional<NamedAttribute>
-DictionaryAttr::findDuplicate(SmallVectorImpl<NamedAttribute> &array,
- bool isSorted) {
- if (!isSorted)
- dictionaryAttrSort</*inPlace=*/true>(array, array);
- return findDuplicateElement(array);
-}
-
-DictionaryAttr DictionaryAttr::get(ArrayRef<NamedAttribute> value,
- MLIRContext *context) {
- if (value.empty())
- return DictionaryAttr::getEmpty(context);
- assert(llvm::all_of(value,
- [](const NamedAttribute &attr) { return attr.second; }) &&
- "value cannot have null entries");
-
- // We need to sort the element list to canonicalize it.
- SmallVector<NamedAttribute, 8> storage;
- if (dictionaryAttrSort</*inPlace=*/false>(value, storage))
- value = storage;
- assert(!findDuplicateElement(value) &&
- "DictionaryAttr element names must be unique");
- return Base::get(context, value);
-}
-/// Construct a dictionary with an array of values that is known to already be
-/// sorted by name and uniqued.
-DictionaryAttr DictionaryAttr::getWithSorted(ArrayRef<NamedAttribute> value,
- MLIRContext *context) {
- if (value.empty())
- return DictionaryAttr::getEmpty(context);
- // Ensure that the attribute elements are unique and sorted.
- assert(llvm::is_sorted(value,
- [](NamedAttribute l, NamedAttribute r) {
- return l.first.strref() < r.first.strref();
- }) &&
- "expected attribute values to be sorted");
- assert(!findDuplicateElement(value) &&
- "DictionaryAttr element names must be unique");
- return Base::get(context, value);
-}
-
-ArrayRef<NamedAttribute> DictionaryAttr::getValue() const {
- return getImpl()->getElements();
-}
-
-/// Return the specified attribute if present, null otherwise.
-Attribute DictionaryAttr::get(StringRef name) const {
- Optional<NamedAttribute> attr = getNamed(name);
- return attr ? attr->second : nullptr;
-}
-Attribute DictionaryAttr::get(Identifier name) const {
- Optional<NamedAttribute> attr = getNamed(name);
- return attr ? attr->second : nullptr;
-}
-
-/// Return the specified named attribute if present, None otherwise.
-Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const {
- ArrayRef<NamedAttribute> values = getValue();
- const auto *it = llvm::lower_bound(values, name);
- return it != values.end() && it->first == name ? *it
- : Optional<NamedAttribute>();
-}
-Optional<NamedAttribute> DictionaryAttr::getNamed(Identifier name) const {
- for (auto elt : getValue())
- if (elt.first == name)
- return elt;
- return llvm::None;
-}
-
-DictionaryAttr::iterator DictionaryAttr::begin() const {
- return getValue().begin();
-}
-DictionaryAttr::iterator DictionaryAttr::end() const {
- return getValue().end();
-}
-size_t DictionaryAttr::size() const { return getValue().size(); }
-
-//===----------------------------------------------------------------------===//
-// FloatAttr
-//===----------------------------------------------------------------------===//
-
-FloatAttr FloatAttr::get(Type type, double value) {
- return Base::get(type.getContext(), type, value);
-}
-
-FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) {
- return Base::getChecked(loc, type, value);
-}
-
-FloatAttr FloatAttr::get(Type type, const APFloat &value) {
- return Base::get(type.getContext(), type, value);
-}
-
-FloatAttr FloatAttr::getChecked(Type type, const APFloat &value, Location loc) {
- return Base::getChecked(loc, type, value);
-}
-
-APFloat FloatAttr::getValue() const { return getImpl()->getValue(); }
-
-double FloatAttr::getValueAsDouble() const {
- return getValueAsDouble(getValue());
-}
-double FloatAttr::getValueAsDouble(APFloat value) {
- if (&value.getSemantics() != &APFloat::IEEEdouble()) {
- bool losesInfo = false;
- value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven,
- &losesInfo);
- }
- return value.convertToDouble();
-}
-
-/// Verify construction invariants.
-static LogicalResult verifyFloatTypeInvariants(Location loc, Type type) {
- if (!type.isa<FloatType>())
- return emitError(loc, "expected floating point type");
- return success();
-}
-
-LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type,
- double value) {
- return verifyFloatTypeInvariants(loc, type);
-}
-
-LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type,
- const APFloat &value) {
- // Verify that the type is correct.
- if (failed(verifyFloatTypeInvariants(loc, type)))
- return failure();
-
- // Verify that the type semantics match that of the value.
- if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) {
- return emitError(
- loc, "FloatAttr type doesn't match the type implied by its value");
- }
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// SymbolRefAttr
-//===----------------------------------------------------------------------===//
-
-FlatSymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) {
- return Base::get(ctx, value, llvm::None).cast<FlatSymbolRefAttr>();
-}
-
-SymbolRefAttr SymbolRefAttr::get(StringRef value,
- ArrayRef<FlatSymbolRefAttr> nestedReferences,
- MLIRContext *ctx) {
- return Base::get(ctx, value, nestedReferences);
-}
-
-StringRef SymbolRefAttr::getRootReference() const { return getImpl()->value; }
-
-StringRef SymbolRefAttr::getLeafReference() const {
- ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences();
- return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getValue();
-}
-
-ArrayRef<FlatSymbolRefAttr> SymbolRefAttr::getNestedReferences() const {
- return getImpl()->getNestedRefs();
-}
-
-//===----------------------------------------------------------------------===//
-// IntegerAttr
-//===----------------------------------------------------------------------===//
-
-IntegerAttr IntegerAttr::get(Type type, const APInt &value) {
- if (type.isSignlessInteger(1))
- return BoolAttr::get(value.getBoolValue(), type.getContext());
- return Base::get(type.getContext(), type, value);
-}
-
-IntegerAttr IntegerAttr::get(Type type, int64_t value) {
- // This uses 64 bit APInts by default for index type.
- if (type.isIndex())
- return get(type, APInt(IndexType::kInternalStorageBitWidth, value));
-
- auto intType = type.cast<IntegerType>();
- return get(type, APInt(intType.getWidth(), value, intType.isSignedInteger()));
-}
-
-APInt IntegerAttr::getValue() const { return getImpl()->getValue(); }
-
-int64_t IntegerAttr::getInt() const {
- assert((getImpl()->getType().isIndex() ||
- getImpl()->getType().isSignlessInteger()) &&
- "must be signless integer");
- return getValue().getSExtValue();
-}
-
-int64_t IntegerAttr::getSInt() const {
- assert(getImpl()->getType().isSignedInteger() && "must be signed integer");
- return getValue().getSExtValue();
-}
-
-uint64_t IntegerAttr::getUInt() const {
- assert(getImpl()->getType().isUnsignedInteger() &&
- "must be unsigned integer");
- return getValue().getZExtValue();
-}
-
-static LogicalResult verifyIntegerTypeInvariants(Location loc, Type type) {
- if (type.isa<IntegerType, IndexType>())
- return success();
- return emitError(loc, "expected integer or index type");
-}
-
-LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type,
- int64_t value) {
- return verifyIntegerTypeInvariants(loc, type);
-}
-
-LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type,
- const APInt &value) {
- if (failed(verifyIntegerTypeInvariants(loc, type)))
- return failure();
- if (auto integerType = type.dyn_cast<IntegerType>())
- if (integerType.getWidth() != value.getBitWidth())
- return emitError(loc, "integer type bit width (")
- << integerType.getWidth() << ") doesn't match value bit width ("
- << value.getBitWidth() << ")";
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// BoolAttr
-
-bool BoolAttr::getValue() const {
- auto *storage = reinterpret_cast<IntegerAttributeStorage *>(impl);
- return storage->getValue().getBoolValue();
-}
-
-bool BoolAttr::classof(Attribute attr) {
- IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>();
- return intAttr && intAttr.getType().isSignlessInteger(1);
-}
-
-//===----------------------------------------------------------------------===//
-// IntegerSetAttr
-//===----------------------------------------------------------------------===//
-
-IntegerSetAttr IntegerSetAttr::get(IntegerSet value) {
- return Base::get(value.getConstraint(0).getContext(), value);
-}
-
-IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; }
-
-//===----------------------------------------------------------------------===//
-// OpaqueAttr
-//===----------------------------------------------------------------------===//
-
-OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type,
- MLIRContext *context) {
- return Base::get(context, dialect, attrData, type);
-}
-
-OpaqueAttr OpaqueAttr::getChecked(Identifier dialect, StringRef attrData,
- Type type, Location location) {
- return Base::getChecked(location, dialect, attrData, type);
-}
-
-/// Returns the dialect namespace of the opaque attribute.
-Identifier OpaqueAttr::getDialectNamespace() const {
- return getImpl()->dialectNamespace;
-}
-
-/// Returns the raw attribute data of the opaque attribute.
-StringRef OpaqueAttr::getAttrData() const { return getImpl()->attrData; }
-
-/// Verify the construction of an opaque attribute.
-LogicalResult OpaqueAttr::verifyConstructionInvariants(Location loc,
- Identifier dialect,
- StringRef attrData,
- Type type) {
- if (!Dialect::isValidNamespace(dialect.strref()))
- return emitError(loc, "invalid dialect namespace '") << dialect << "'";
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// StringAttr
-//===----------------------------------------------------------------------===//
-
-StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) {
- return get(bytes, NoneType::get(context));
-}
-
-/// Get an instance of a StringAttr with the given string and Type.
-StringAttr StringAttr::get(StringRef bytes, Type type) {
- return Base::get(type.getContext(), bytes, type);
-}
-
-StringRef StringAttr::getValue() const { return getImpl()->value; }
-
-//===----------------------------------------------------------------------===//
-// TypeAttr
-//===----------------------------------------------------------------------===//
-
-TypeAttr TypeAttr::get(Type value) {
- return Base::get(value.getContext(), value);
-}
-
-Type TypeAttr::getValue() const { return getImpl()->value; }
-
-//===----------------------------------------------------------------------===//
-// ElementsAttr
-//===----------------------------------------------------------------------===//
-
-ShapedType ElementsAttr::getType() const {
- return Attribute::getType().cast<ShapedType>();
-}
-
-/// Returns the number of elements held by this attribute.
-int64_t ElementsAttr::getNumElements() const {
- return getType().getNumElements();
-}
-
-/// Return the value at the given index. If index does not refer to a valid
-/// element, then a null attribute is returned.
-Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const {
- if (auto denseAttr = dyn_cast<DenseElementsAttr>())
- return denseAttr.getValue(index);
- if (auto opaqueAttr = dyn_cast<OpaqueElementsAttr>())
- return opaqueAttr.getValue(index);
- return cast<SparseElementsAttr>().getValue(index);
-}
-
-/// Return if the given 'index' refers to a valid element in this attribute.
-bool ElementsAttr::isValidIndex(ArrayRef<uint64_t> index) const {
- auto type = getType();
-
- // Verify that the rank of the indices matches the held type.
- auto rank = type.getRank();
- if (rank != static_cast<int64_t>(index.size()))
- return false;
-
- // Verify that all of the indices are within the shape dimensions.
- auto shape = type.getShape();
- return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) {
- return static_cast<int64_t>(index[i]) < shape[i];
- });
-}
-
-ElementsAttr
-ElementsAttr::mapValues(Type newElementType,
- function_ref<APInt(const APInt &)> mapping) const {
- if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>())
- return intOrFpAttr.mapValues(newElementType, mapping);
- llvm_unreachable("unsupported ElementsAttr subtype");
-}
-
-ElementsAttr
-ElementsAttr::mapValues(Type newElementType,
- function_ref<APInt(const APFloat &)> mapping) const {
- if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>())
- return intOrFpAttr.mapValues(newElementType, mapping);
- llvm_unreachable("unsupported ElementsAttr subtype");
-}
-
-/// Method for support type inquiry through isa, cast and dyn_cast.
-bool ElementsAttr::classof(Attribute attr) {
- return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr,
- OpaqueElementsAttr, SparseElementsAttr>();
-}
-
-/// Returns the 1 dimensional flattened row-major index from the given
-/// multi-dimensional index.
-uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const {
- assert(isValidIndex(index) && "expected valid multi-dimensional index");
- auto type = getType();
-
- // Reduce the provided multidimensional index into a flattended 1D row-major
- // index.
- auto rank = type.getRank();
- auto shape = type.getShape();
- uint64_t valueIndex = 0;
- uint64_t dimMultiplier = 1;
- for (int i = rank - 1; i >= 0; --i) {
- valueIndex += index[i] * dimMultiplier;
- dimMultiplier *= shape[i];
- }
- return valueIndex;
-}
-
-//===----------------------------------------------------------------------===//
-// DenseElementsAttr Utilities
-//===----------------------------------------------------------------------===//
-
-/// Get the bitwidth of a dense element type within the buffer.
-/// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8.
-static size_t getDenseElementStorageWidth(size_t origWidth) {
- return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth);
-}
-static size_t getDenseElementStorageWidth(Type elementType) {
- return getDenseElementStorageWidth(getDenseElementBitWidth(elementType));
-}
-
-/// Set a bit to a specific value.
-static void setBit(char *rawData, size_t bitPos, bool value) {
- if (value)
- rawData[bitPos / CHAR_BIT] |= (1 << (bitPos % CHAR_BIT));
- else
- rawData[bitPos / CHAR_BIT] &= ~(1 << (bitPos % CHAR_BIT));
-}
-
-/// Return the value of the specified bit.
-static bool getBit(const char *rawData, size_t bitPos) {
- return (rawData[bitPos / CHAR_BIT] & (1 << (bitPos % CHAR_BIT))) != 0;
-}
-
-/// Copy actual `numBytes` data from `value` (APInt) to char array(`result`) for
-/// BE format.
-static void copyAPIntToArrayForBEmachine(APInt value, size_t numBytes,
- char *result) {
- assert(llvm::support::endian::system_endianness() == // NOLINT
- llvm::support::endianness::big); // NOLINT
- assert(value.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes);
-
- // Copy the words filled with data.
- // For example, when `value` has 2 words, the first word is filled with data.
- // `value` (10 bytes, BE):|abcdefgh|------ij| ==> `result` (BE):|abcdefgh|--|
- size_t numFilledWords = (value.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
- std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
- numFilledWords, result);
- // Convert last word of APInt to LE format and store it in char
- // array(`valueLE`).
- // ex. last word of `value` (BE): |------ij| ==> `valueLE` (LE): |ji------|
- size_t lastWordPos = numFilledWords;
- SmallVector<char, 8> valueLE(APInt::APINT_WORD_SIZE);
- DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
- reinterpret_cast<const char *>(value.getRawData()) + lastWordPos,
- valueLE.begin(), APInt::APINT_BITS_PER_WORD, 1);
- // Extract actual APInt data from `valueLE`, convert endianness to BE format,
- // and store it in `result`.
- // ex. `valueLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|ij|
- DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
- valueLE.begin(), result + lastWordPos,
- (numBytes - lastWordPos) * CHAR_BIT, 1);
-}
-
-/// Copy `numBytes` data from `inArray`(char array) to `result`(APINT) for BE
-/// format.
-static void copyArrayToAPIntForBEmachine(const char *inArray, size_t numBytes,
- APInt &result) {
- assert(llvm::support::endian::system_endianness() == // NOLINT
- llvm::support::endianness::big); // NOLINT
- assert(result.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes);
-
- // Copy the data that fills the word of `result` from `inArray`.
- // For example, when `result` has 2 words, the first word will be filled with
- // data. So, the first 8 bytes are copied from `inArray` here.
- // `inArray` (10 bytes, BE): |abcdefgh|ij|
- // ==> `result` (2 words, BE): |abcdefgh|--------|
- size_t numFilledWords = (result.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
- std::copy_n(
- inArray, numFilledWords,
- const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())));
-
- // Convert array data which will be last word of `result` to LE format, and
- // store it in char array(`inArrayLE`).
- // ex. `inArray` (last two bytes, BE): |ij| ==> `inArrayLE` (LE): |ji------|
- size_t lastWordPos = numFilledWords;
- SmallVector<char, 8> inArrayLE(APInt::APINT_WORD_SIZE);
- DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
- inArray + lastWordPos, inArrayLE.begin(),
- (numBytes - lastWordPos) * CHAR_BIT, 1);
-
- // Convert `inArrayLE` to BE format, and store it in last word of `result`.
- // ex. `inArrayLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|------ij|
- DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
- inArrayLE.begin(),
- const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())) +
- lastWordPos,
- APInt::APINT_BITS_PER_WORD, 1);
-}
-
-/// Writes value to the bit position `bitPos` in array `rawData`.
-static void writeBits(char *rawData, size_t bitPos, APInt value) {
- size_t bitWidth = value.getBitWidth();
-
- // If the bitwidth is 1 we just toggle the specific bit.
- if (bitWidth == 1)
- return setBit(rawData, bitPos, value.isOneValue());
-
- // Otherwise, the bit position is guaranteed to be byte aligned.
- assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
- if (llvm::support::endian::system_endianness() ==
- llvm::support::endianness::big) {
- // Copy from `value` to `rawData + (bitPos / CHAR_BIT)`.
- // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
- // work correctly in BE format.
- // ex. `value` (2 words including 10 bytes)
- // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------|
- copyAPIntToArrayForBEmachine(value, llvm::divideCeil(bitWidth, CHAR_BIT),
- rawData + (bitPos / CHAR_BIT));
- } else {
- std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
- llvm::divideCeil(bitWidth, CHAR_BIT),
- rawData + (bitPos / CHAR_BIT));
- }
-}
-
-/// Reads the next `bitWidth` bits from the bit position `bitPos` in array
-/// `rawData`.
-static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) {
- // Handle a boolean bit position.
- if (bitWidth == 1)
- return APInt(1, getBit(rawData, bitPos) ? 1 : 0);
-
- // Otherwise, the bit position must be 8-bit aligned.
- assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
- APInt result(bitWidth, 0);
- if (llvm::support::endian::system_endianness() ==
- llvm::support::endianness::big) {
- // Copy from `rawData + (bitPos / CHAR_BIT)` to `result`.
- // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
- // work correctly in BE format.
- // ex. `result` (2 words including 10 bytes)
- // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------| This function
- copyArrayToAPIntForBEmachine(rawData + (bitPos / CHAR_BIT),
- llvm::divideCeil(bitWidth, CHAR_BIT), result);
- } else {
- std::copy_n(rawData + (bitPos / CHAR_BIT),
- llvm::divideCeil(bitWidth, CHAR_BIT),
- const_cast<char *>(
- reinterpret_cast<const char *>(result.getRawData())));
- }
- return result;
-}
-
-/// Returns true if 'values' corresponds to a splat, i.e. one element, or has
-/// the same element count as 'type'.
-template <typename Values>
-static bool hasSameElementsOrSplat(ShapedType type, const Values &values) {
- return (values.size() == 1) ||
- (type.getNumElements() == static_cast<int64_t>(values.size()));
-}
-
-//===----------------------------------------------------------------------===//
-// DenseElementsAttr Iterators
-//===----------------------------------------------------------------------===//
-
-//===----------------------------------------------------------------------===//
-// AttributeElementIterator
-
-DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
- DenseElementsAttr attr, size_t index)
- : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *,
- Attribute, Attribute, Attribute>(
- attr.getAsOpaquePointer(), index) {}
-
-Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
- auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>();
- Type eltTy = owner.getType().getElementType();
- if (auto intEltTy = eltTy.dyn_cast<IntegerType>())
- return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
- if (eltTy.isa<IndexType>())
- return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
- if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) {
- IntElementIterator intIt(owner, index);
- FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt);
- return FloatAttr::get(eltTy, *floatIt);
- }
- if (owner.isa<DenseStringElementsAttr>()) {
- ArrayRef<StringRef> vals = owner.getRawStringData();
- return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy);
- }
- llvm_unreachable("unexpected element type");
-}
-
-//===----------------------------------------------------------------------===//
-// BoolElementIterator
-
-DenseElementsAttr::BoolElementIterator::BoolElementIterator(
- DenseElementsAttr attr, size_t dataIndex)
- : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>(
- attr.getRawData().data(), attr.isSplat(), dataIndex) {}
-
-bool DenseElementsAttr::BoolElementIterator::operator*() const {
- return getBit(getData(), getDataIndex());
-}
-
-//===----------------------------------------------------------------------===//
-// IntElementIterator
-
-DenseElementsAttr::IntElementIterator::IntElementIterator(
- DenseElementsAttr attr, size_t dataIndex)
- : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>(
- attr.getRawData().data(), attr.isSplat(), dataIndex),
- bitWidth(getDenseElementBitWidth(attr.getType().getElementType())) {}
-
-APInt DenseElementsAttr::IntElementIterator::operator*() const {
- return readBits(getData(),
- getDataIndex() * getDenseElementStorageWidth(bitWidth),
- bitWidth);
-}
-
-//===----------------------------------------------------------------------===//
-// ComplexIntElementIterator
-
-DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator(
- DenseElementsAttr attr, size_t dataIndex)
- : DenseElementIndexedIteratorImpl<ComplexIntElementIterator,
- std::complex<APInt>, std::complex<APInt>,
- std::complex<APInt>>(
- attr.getRawData().data(), attr.isSplat(), dataIndex) {
- auto complexType = attr.getType().getElementType().cast<ComplexType>();
- bitWidth = getDenseElementBitWidth(complexType.getElementType());
-}
-
-std::complex<APInt>
-DenseElementsAttr::ComplexIntElementIterator::operator*() const {
- size_t storageWidth = getDenseElementStorageWidth(bitWidth);
- size_t offset = getDataIndex() * storageWidth * 2;
- return {readBits(getData(), offset, bitWidth),
- readBits(getData(), offset + storageWidth, bitWidth)};
-}
-
-//===----------------------------------------------------------------------===//
-// FloatElementIterator
-
-DenseElementsAttr::FloatElementIterator::FloatElementIterator(
- const llvm::fltSemantics &smt, IntElementIterator it)
- : llvm::mapped_iterator<IntElementIterator,
- std::function<APFloat(const APInt &)>>(
- it, [&](const APInt &val) { return APFloat(smt, val); }) {}
-
-//===----------------------------------------------------------------------===//
-// ComplexFloatElementIterator
-
-DenseElementsAttr::ComplexFloatElementIterator::ComplexFloatElementIterator(
- const llvm::fltSemantics &smt, ComplexIntElementIterator it)
- : llvm::mapped_iterator<
- ComplexIntElementIterator,
- std::function<std::complex<APFloat>(const std::complex<APInt> &)>>(
- it, [&](const std::complex<APInt> &val) -> std::complex<APFloat> {
- return {APFloat(smt, val.real()), APFloat(smt, val.imag())};
- }) {}
-
-//===----------------------------------------------------------------------===//
-// DenseElementsAttr
-//===----------------------------------------------------------------------===//
-
-/// Method for support type inquiry through isa, cast and dyn_cast.
-bool DenseElementsAttr::classof(Attribute attr) {
- return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>();
-}
-
-DenseElementsAttr DenseElementsAttr::get(ShapedType type,
- ArrayRef<Attribute> values) {
- assert(hasSameElementsOrSplat(type, values));
-
- // If the element type is not based on int/float/index, assume it is a string
- // type.
- auto eltType = type.getElementType();
- if (!type.getElementType().isIntOrIndexOrFloat()) {
- SmallVector<StringRef, 8> stringValues;
- stringValues.reserve(values.size());
- for (Attribute attr : values) {
- assert(attr.isa<StringAttr>() &&
- "expected string value for non integer/index/float element");
- stringValues.push_back(attr.cast<StringAttr>().getValue());
- }
- return get(type, stringValues);
- }
-
- // Otherwise, get the raw storage width to use for the allocation.
- size_t bitWidth = getDenseElementBitWidth(eltType);
- size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
-
- // Compress the attribute values into a character buffer.
- SmallVector<char, 8> data(llvm::divideCeil(storageBitWidth, CHAR_BIT) *
- values.size());
- APInt intVal;
- for (unsigned i = 0, e = values.size(); i < e; ++i) {
- assert(eltType == values[i].getType() &&
- "expected attribute value to have element type");
- if (eltType.isa<FloatType>())
- intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt();
- else if (eltType.isa<IntegerType>())
- intVal = values[i].cast<IntegerAttr>().getValue();
- else
- llvm_unreachable("unexpected element type");
-
- assert(intVal.getBitWidth() == bitWidth &&
- "expected value to have same bitwidth as element type");
- writeBits(data.data(), i * storageBitWidth, intVal);
- }
- return DenseIntOrFPElementsAttr::getRaw(type, data,
- /*isSplat=*/(values.size() == 1));
-}
-
-DenseElementsAttr DenseElementsAttr::get(ShapedType type,
- ArrayRef<bool> values) {
- assert(hasSameElementsOrSplat(type, values));
- assert(type.getElementType().isInteger(1));
-
- std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT));
- for (int i = 0, e = values.size(); i != e; ++i)
- setBit(buff.data(), i, values[i]);
- return DenseIntOrFPElementsAttr::getRaw(type, buff,
- /*isSplat=*/(values.size() == 1));
-}
-
-DenseElementsAttr DenseElementsAttr::get(ShapedType type,
- ArrayRef<StringRef> values) {
- assert(!type.getElementType().isIntOrFloat());
- return DenseStringElementsAttr::get(type, values);
-}
-
-/// Constructs a dense integer elements attribute from an array of APInt
-/// values. Each APInt value is expected to have the same bitwidth as the
-/// element type of 'type'.
-DenseElementsAttr DenseElementsAttr::get(ShapedType type,
- ArrayRef<APInt> values) {
- assert(type.getElementType().isIntOrIndex());
- assert(hasSameElementsOrSplat(type, values));
- size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
- return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values,
- /*isSplat=*/(values.size() == 1));
-}
-DenseElementsAttr DenseElementsAttr::get(ShapedType type,
- ArrayRef<std::complex<APInt>> values) {
- ComplexType complex = type.getElementType().cast<ComplexType>();
- assert(complex.getElementType().isa<IntegerType>());
- assert(hasSameElementsOrSplat(type, values));
- size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
- ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()),
- values.size() * 2);
- return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals,
- /*isSplat=*/(values.size() == 1));
-}
-
-// Constructs a dense float elements attribute from an array of APFloat
-// values. Each APFloat value is expected to have the same bitwidth as the
-// element type of 'type'.
-DenseElementsAttr DenseElementsAttr::get(ShapedType type,
- ArrayRef<APFloat> values) {
- assert(type.getElementType().isa<FloatType>());
- assert(hasSameElementsOrSplat(type, values));
- size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
- return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values,
- /*isSplat=*/(values.size() == 1));
-}
-DenseElementsAttr
-DenseElementsAttr::get(ShapedType type,
- ArrayRef<std::complex<APFloat>> values) {
- ComplexType complex = type.getElementType().cast<ComplexType>();
- assert(complex.getElementType().isa<FloatType>());
- assert(hasSameElementsOrSplat(type, values));
- ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()),
- values.size() * 2);
- size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
- return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals,
- /*isSplat=*/(values.size() == 1));
-}
-
-/// Construct a dense elements attribute from a raw buffer representing the
-/// data for this attribute. Users should generally not use this methods as
-/// the expected buffer format may not be a form the user expects.
-DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type,
- ArrayRef<char> rawBuffer,
- bool isSplatBuffer) {
- return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer, isSplatBuffer);
-}
-
-/// Returns true if the given buffer is a valid raw buffer for the given type.
-bool DenseElementsAttr::isValidRawBuffer(ShapedType type,
- ArrayRef<char> rawBuffer,
- bool &detectedSplat) {
- size_t storageWidth = getDenseElementStorageWidth(type.getElementType());
- size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT;
-
- // Storage width of 1 is special as it is packed by the bit.
- if (storageWidth == 1) {
- // Check for a splat, or a buffer equal to the number of elements.
- if ((detectedSplat = rawBuffer.size() == 1))
- return true;
- return rawBufferWidth == llvm::alignTo<8>(type.getNumElements());
- }
- // All other types are 8-bit aligned.
- if ((detectedSplat = rawBufferWidth == storageWidth))
- return true;
- return rawBufferWidth == (storageWidth * type.getNumElements());
-}
-
-/// Check the information for a C++ data type, check if this type is valid for
-/// the current attribute. This method is used to verify specific type
-/// invariants that the templatized 'getValues' method cannot.
-static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt,
- bool isSigned) {
- // Make sure that the data element size is the same as the type element width.
- if (getDenseElementBitWidth(type) !=
- static_cast<size_t>(dataEltSize * CHAR_BIT))
- return false;
-
- // Check that the element type is either float or integer or index.
- if (!isInt)
- return type.isa<FloatType>();
- if (type.isIndex())
- return true;
-
- auto intType = type.dyn_cast<IntegerType>();
- if (!intType)
- return false;
-
- // Make sure signedness semantics is consistent.
- if (intType.isSignless())
- return true;
- return intType.isSigned() ? isSigned : !isSigned;
-}
-
-/// Defaults down the subclass implementation.
-DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type,
- ArrayRef<char> data,
- int64_t dataEltSize,
- bool isInt, bool isSigned) {
- return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt,
- isSigned);
-}
-DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
- ArrayRef<char> data,
- int64_t dataEltSize,
- bool isInt,
- bool isSigned) {
- return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize,
- isInt, isSigned);
-}
-
-/// A method used to verify specific type invariants that the templatized 'get'
-/// method cannot.
-bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt,
- bool isSigned) const {
- return ::isValidIntOrFloat(getType().getElementType(), dataEltSize, isInt,
- isSigned);
-}
-
-/// Check the information for a C++ data type, check if this type is valid for
-/// the current attribute.
-bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt,
- bool isSigned) const {
- return ::isValidIntOrFloat(
- getType().getElementType().cast<ComplexType>().getElementType(),
- dataEltSize / 2, isInt, isSigned);
-}
-
-/// Returns true if this attribute corresponds to a splat, i.e. if all element
-/// values are the same.
-bool DenseElementsAttr::isSplat() const {
- return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat;
-}
-
-/// Return the held element values as a range of Attributes.
-auto DenseElementsAttr::getAttributeValues() const
- -> llvm::iterator_range<AttributeElementIterator> {
- return {attr_value_begin(), attr_value_end()};
-}
-auto DenseElementsAttr::attr_value_begin() const -> AttributeElementIterator {
- return AttributeElementIterator(*this, 0);
-}
-auto DenseElementsAttr::attr_value_end() const -> AttributeElementIterator {
- return AttributeElementIterator(*this, getNumElements());
-}
-
-/// Return the held element values as a range of bool. The element type of
-/// this attribute must be of integer type of bitwidth 1.
-auto DenseElementsAttr::getBoolValues() const
- -> llvm::iterator_range<BoolElementIterator> {
- auto eltType = getType().getElementType().dyn_cast<IntegerType>();
- assert(eltType && eltType.getWidth() == 1 && "expected i1 integer type");
- (void)eltType;
- return {BoolElementIterator(*this, 0),
- BoolElementIterator(*this, getNumElements())};
-}
-
-/// Return the held element values as a range of APInts. The element type of
-/// this attribute must be of integer type.
-auto DenseElementsAttr::getIntValues() const
- -> llvm::iterator_range<IntElementIterator> {
- assert(getType().getElementType().isIntOrIndex() && "expected integral type");
- return {raw_int_begin(), raw_int_end()};
-}
-auto DenseElementsAttr::int_value_begin() const -> IntElementIterator {
- assert(getType().getElementType().isIntOrIndex() && "expected integral type");
- return raw_int_begin();
-}
-auto DenseElementsAttr::int_value_end() const -> IntElementIterator {
- assert(getType().getElementType().isIntOrIndex() && "expected integral type");
- return raw_int_end();
-}
-auto DenseElementsAttr::getComplexIntValues() const
- -> llvm::iterator_range<ComplexIntElementIterator> {
- Type eltTy = getType().getElementType().cast<ComplexType>().getElementType();
- (void)eltTy;
- assert(eltTy.isa<IntegerType>() && "expected complex integral type");
- return {ComplexIntElementIterator(*this, 0),
- ComplexIntElementIterator(*this, getNumElements())};
-}
-
-/// Return the held element values as a range of APFloat. The element type of
-/// this attribute must be of float type.
-auto DenseElementsAttr::getFloatValues() const
- -> llvm::iterator_range<FloatElementIterator> {
- auto elementType = getType().getElementType().cast<FloatType>();
- const auto &elementSemantics = elementType.getFloatSemantics();
- return {FloatElementIterator(elementSemantics, raw_int_begin()),
- FloatElementIterator(elementSemantics, raw_int_end())};
-}
-auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator {
- return getFloatValues().begin();
-}
-auto DenseElementsAttr::float_value_end() const -> FloatElementIterator {
- return getFloatValues().end();
-}
-auto DenseElementsAttr::getComplexFloatValues() const
- -> llvm::iterator_range<ComplexFloatElementIterator> {
- Type eltTy = getType().getElementType().cast<ComplexType>().getElementType();
- assert(eltTy.isa<FloatType>() && "expected complex float type");
- const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics();
- return {{semantics, {*this, 0}},
- {semantics, {*this, static_cast<size_t>(getNumElements())}}};
-}
-
-/// Return the raw storage data held by this attribute.
-ArrayRef<char> DenseElementsAttr::getRawData() const {
- return static_cast<DenseIntOrFPElementsAttributeStorage *>(impl)->data;
-}
-
-ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const {
- return static_cast<DenseStringElementsAttributeStorage *>(impl)->data;
-}
-
-/// Return a new DenseElementsAttr that has the same data as the current
-/// attribute, but has been reshaped to 'newType'. The new type must have the
-/// same total number of elements as well as element type.
-DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
- ShapedType curType = getType();
- if (curType == newType)
- return *this;
-
- (void)curType;
- assert(newType.getElementType() == curType.getElementType() &&
- "expected the same element type");
- assert(newType.getNumElements() == curType.getNumElements() &&
- "expected the same number of elements");
- return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat());
-}
-
-DenseElementsAttr
-DenseElementsAttr::mapValues(Type newElementType,
- function_ref<APInt(const APInt &)> mapping) const {
- return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping);
-}
-
-DenseElementsAttr DenseElementsAttr::mapValues(
- Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
- return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping);
-}
-
-//===----------------------------------------------------------------------===//
-// DenseStringElementsAttr
-//===----------------------------------------------------------------------===//
-
-DenseStringElementsAttr
-DenseStringElementsAttr::get(ShapedType type, ArrayRef<StringRef> values) {
- return Base::get(type.getContext(), type, values, (values.size() == 1));
-}
-
-//===----------------------------------------------------------------------===//
-// DenseIntOrFPElementsAttr
-//===----------------------------------------------------------------------===//
-
-/// Utility method to write a range of APInt values to a buffer.
-template <typename APRangeT>
-static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data,
- APRangeT &&values) {
- data.resize(llvm::divideCeil(storageWidth, CHAR_BIT) * llvm::size(values));
- size_t offset = 0;
- for (auto it = values.begin(), e = values.end(); it != e;
- ++it, offset += storageWidth) {
- assert((*it).getBitWidth() <= storageWidth);
- writeBits(data.data(), offset, *it);
- }
-}
-
-/// Constructs a dense elements attribute from an array of raw APFloat values.
-/// Each APFloat value is expected to have the same bitwidth as the element
-/// type of 'type'. 'type' must be a vector or tensor with static shape.
-DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
- size_t storageWidth,
- ArrayRef<APFloat> values,
- bool isSplat) {
- std::vector<char> data;
- auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); };
- writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat));
- return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat);
-}
-
-/// Constructs a dense elements attribute from an array of raw APInt values.
-/// Each APInt value is expected to have the same bitwidth as the element type
-/// of 'type'.
-DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
- size_t storageWidth,
- ArrayRef<APInt> values,
- bool isSplat) {
- std::vector<char> data;
- writeAPIntsToBuffer(storageWidth, data, values);
- return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat);
-}
-
-DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
- ArrayRef<char> data,
- bool isSplat) {
- assert((type.isa<RankedTensorType, VectorType>()) &&
- "type must be ranked tensor or vector");
- assert(type.hasStaticShape() && "type must have static shape");
- return Base::get(type.getContext(), type, data, isSplat);
-}
-
-/// Overload of the raw 'get' method that asserts that the given type is of
-/// complex type. This method is used to verify type invariants that the
-/// templatized 'get' method cannot.
-DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type,
- ArrayRef<char> data,
- int64_t dataEltSize,
- bool isInt,
- bool isSigned) {
- assert(::isValidIntOrFloat(
- type.getElementType().cast<ComplexType>().getElementType(),
- dataEltSize / 2, isInt, isSigned));
-
- int64_t numElements = data.size() / dataEltSize;
- assert(numElements == 1 || numElements == type.getNumElements());
- return getRaw(type, data, /*isSplat=*/numElements == 1);
-}
-
-/// Overload of the 'getRaw' method that asserts that the given type is of
-/// integer type. This method is used to verify type invariants that the
-/// templatized 'get' method cannot.
-DenseElementsAttr
-DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data,
- int64_t dataEltSize, bool isInt,
- bool isSigned) {
- assert(
- ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned));
-
- int64_t numElements = data.size() / dataEltSize;
- assert(numElements == 1 || numElements == type.getNumElements());
- return getRaw(type, data, /*isSplat=*/numElements == 1);
-}
-
-void DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
- const char *inRawData, char *outRawData, size_t elementBitWidth,
- size_t numElements) {
- using llvm::support::ulittle16_t;
- using llvm::support::ulittle32_t;
- using llvm::support::ulittle64_t;
-
- assert(llvm::support::endian::system_endianness() == // NOLINT
- llvm::support::endianness::big); // NOLINT
- // NOLINT to avoid warning message about replacing by static_assert()
-
- // Following std::copy_n always converts endianness on BE machine.
- switch (elementBitWidth) {
- case 16: {
- const ulittle16_t *inRawDataPos =
- reinterpret_cast<const ulittle16_t *>(inRawData);
- uint16_t *outDataPos = reinterpret_cast<uint16_t *>(outRawData);
- std::copy_n(inRawDataPos, numElements, outDataPos);
- break;
- }
- case 32: {
- const ulittle32_t *inRawDataPos =
- reinterpret_cast<const ulittle32_t *>(inRawData);
- uint32_t *outDataPos = reinterpret_cast<uint32_t *>(outRawData);
- std::copy_n(inRawDataPos, numElements, outDataPos);
- break;
- }
- case 64: {
- const ulittle64_t *inRawDataPos =
- reinterpret_cast<const ulittle64_t *>(inRawData);
- uint64_t *outDataPos = reinterpret_cast<uint64_t *>(outRawData);
- std::copy_n(inRawDataPos, numElements, outDataPos);
- break;
- }
- default: {
- size_t nBytes = elementBitWidth / CHAR_BIT;
- for (size_t i = 0; i < nBytes; i++)
- std::copy_n(inRawData + (nBytes - 1 - i), 1, outRawData + i);
- break;
- }
- }
-}
-
-void DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
- ArrayRef<char> inRawData, MutableArrayRef<char> outRawData,
- ShapedType type) {
- size_t numElements = type.getNumElements();
- Type elementType = type.getElementType();
- if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) {
- elementType = complexTy.getElementType();
- numElements = numElements * 2;
- }
- size_t elementBitWidth = getDenseElementStorageWidth(elementType);
- assert(numElements * elementBitWidth == inRawData.size() * CHAR_BIT &&
- inRawData.size() <= outRawData.size());
- convertEndianOfCharForBEmachine(inRawData.begin(), outRawData.begin(),
- elementBitWidth, numElements);
-}
-
-//===----------------------------------------------------------------------===//
-// DenseFPElementsAttr
-//===----------------------------------------------------------------------===//
-
-template <typename Fn, typename Attr>
-static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType,
- Type newElementType,
- llvm::SmallVectorImpl<char> &data) {
- size_t bitWidth = getDenseElementBitWidth(newElementType);
- size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
-
- ShapedType newArrayType;
- if (inType.isa<RankedTensorType>())
- newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
- else if (inType.isa<UnrankedTensorType>())
- newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
- else if (inType.isa<VectorType>())
- newArrayType = VectorType::get(inType.getShape(), newElementType);
- else
- assert(newArrayType && "Unhandled tensor type");
-
- size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements();
- data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * numRawElements);
-
- // Functor used to process a single element value of the attribute.
- auto processElt = [&](decltype(*attr.begin()) value, size_t index) {
- auto newInt = mapping(value);
- assert(newInt.getBitWidth() == bitWidth);
- writeBits(data.data(), index * storageBitWidth, newInt);
- };
-
- // Check for the splat case.
- if (attr.isSplat()) {
- processElt(*attr.begin(), /*index=*/0);
- return newArrayType;
- }
-
- // Otherwise, process all of the element values.
- uint64_t elementIdx = 0;
- for (auto value : attr)
- processElt(value, elementIdx++);
- return newArrayType;
-}
-
-DenseElementsAttr DenseFPElementsAttr::mapValues(
- Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
- llvm::SmallVector<char, 8> elementData;
- auto newArrayType =
- mappingHelper(mapping, *this, getType(), newElementType, elementData);
-
- return getRaw(newArrayType, elementData, isSplat());
-}
-
-/// Method for supporting type inquiry through isa, cast and dyn_cast.
-bool DenseFPElementsAttr::classof(Attribute attr) {
- return attr.isa<DenseElementsAttr>() &&
- attr.getType().cast<ShapedType>().getElementType().isa<FloatType>();
-}
-
-//===----------------------------------------------------------------------===//
-// DenseIntElementsAttr
-//===----------------------------------------------------------------------===//
-
-DenseElementsAttr DenseIntElementsAttr::mapValues(
- Type newElementType, function_ref<APInt(const APInt &)> mapping) const {
- llvm::SmallVector<char, 8> elementData;
- auto newArrayType =
- mappingHelper(mapping, *this, getType(), newElementType, elementData);
-
- return getRaw(newArrayType, elementData, isSplat());
-}
-
-/// Method for supporting type inquiry through isa, cast and dyn_cast.
-bool DenseIntElementsAttr::classof(Attribute attr) {
- return attr.isa<DenseElementsAttr>() &&
- attr.getType().cast<ShapedType>().getElementType().isIntOrIndex();
-}
-
-//===----------------------------------------------------------------------===//
-// OpaqueElementsAttr
-//===----------------------------------------------------------------------===//
-
-OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect, ShapedType type,
- StringRef bytes) {
- assert(TensorType::isValidElementType(type.getElementType()) &&
- "Input element type should be a valid tensor element type");
- return Base::get(type.getContext(), type, dialect, bytes);
-}
-
-StringRef OpaqueElementsAttr::getValue() const { return getImpl()->bytes; }
-
-/// Return the value at the given index. If index does not refer to a valid
-/// element, then a null attribute is returned.
-Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const {
- assert(isValidIndex(index) && "expected valid multi-dimensional index");
- return Attribute();
-}
-
-Dialect *OpaqueElementsAttr::getDialect() const { return getImpl()->dialect; }
-
-bool OpaqueElementsAttr::decode(ElementsAttr &result) {
- auto *d = getDialect();
- if (!d)
- return true;
- auto *interface =
- d->getRegisteredInterface<DialectDecodeAttributesInterface>();
- if (!interface)
- return true;
- return failed(interface->decode(*this, result));
-}
-
-//===----------------------------------------------------------------------===//
-// SparseElementsAttr
-//===----------------------------------------------------------------------===//
-
-SparseElementsAttr SparseElementsAttr::get(ShapedType type,
- DenseElementsAttr indices,
- DenseElementsAttr values) {
- assert(indices.getType().getElementType().isInteger(64) &&
- "expected sparse indices to be 64-bit integer values");
- assert((type.isa<RankedTensorType, VectorType>()) &&
- "type must be ranked tensor or vector");
- assert(type.hasStaticShape() && "type must have static shape");
- return Base::get(type.getContext(), type,
- indices.cast<DenseIntElementsAttr>(), values);
-}
-
-DenseIntElementsAttr SparseElementsAttr::getIndices() const {
- return getImpl()->indices;
-}
-
-DenseElementsAttr SparseElementsAttr::getValues() const {
- return getImpl()->values;
-}
-
-/// Return the value of the element at the given index.
-Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
- assert(isValidIndex(index) && "expected valid multi-dimensional index");
- auto type = getType();
-
- // The sparse indices are 64-bit integers, so we can reinterpret the raw data
- // as a 1-D index array.
- auto sparseIndices = getIndices();
- auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
-
- // Check to see if the indices are a splat.
- if (sparseIndices.isSplat()) {
- // If the index is also not a splat of the index value, we know that the
- // value is zero.
- auto splatIndex = *sparseIndexValues.begin();
- if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; }))
- return getZeroAttr();
-
- // If the indices are a splat, we also expect the values to be a splat.
- assert(getValues().isSplat() && "expected splat values");
- return getValues().getSplatValue();
- }
-
- // Build a mapping between known indices and the offset of the stored element.
- llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices;
- auto numSparseIndices = sparseIndices.getType().getDimSize(0);
- size_t rank = type.getRank();
- for (size_t i = 0, e = numSparseIndices; i != e; ++i)
- mappedIndices.try_emplace(
- {&*std::next(sparseIndexValues.begin(), i * rank), rank}, i);
-
- // Look for the provided index key within the mapped indices. If the provided
- // index is not found, then return a zero attribute.
- auto it = mappedIndices.find(index);
- if (it == mappedIndices.end())
- return getZeroAttr();
-
- // Otherwise, return the held sparse value element.
- return getValues().getValue(it->second);
-}
-
-/// Get a zero APFloat for the given sparse attribute.
-APFloat SparseElementsAttr::getZeroAPFloat() const {
- auto eltType = getType().getElementType().cast<FloatType>();
- return APFloat(eltType.getFloatSemantics());
-}
-
-/// Get a zero APInt for the given sparse attribute.
-APInt SparseElementsAttr::getZeroAPInt() const {
- auto eltType = getType().getElementType().cast<IntegerType>();
- return APInt::getNullValue(eltType.getWidth());
-}
-
-/// Get a zero attribute for the given attribute type.
-Attribute SparseElementsAttr::getZeroAttr() const {
- auto eltType = getType().getElementType();
-
- // Handle floating point elements.
- if (eltType.isa<FloatType>())
- return FloatAttr::get(eltType, 0);
-
- // Otherwise, this is an integer.
- // TODO: Handle StringAttr here.
- return IntegerAttr::get(eltType, 0);
-}
-
-/// Flatten, and return, all of the sparse indices in this attribute in
-/// row-major order.
-std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const {
- std::vector<ptrdiff_t> flatSparseIndices;
-
- // The sparse indices are 64-bit integers, so we can reinterpret the raw data
- // as a 1-D index array.
- auto sparseIndices = getIndices();
- auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
- if (sparseIndices.isSplat()) {
- SmallVector<uint64_t, 8> indices(getType().getRank(),
- *sparseIndexValues.begin());
- flatSparseIndices.push_back(getFlattenedIndex(indices));
- return flatSparseIndices;
- }
-
- // Otherwise, reinterpret each index as an ArrayRef when flattening.
- auto numSparseIndices = sparseIndices.getType().getDimSize(0);
- size_t rank = getType().getRank();
- for (size_t i = 0, e = numSparseIndices; i != e; ++i)
- flatSparseIndices.push_back(getFlattenedIndex(
- {&*std::next(sparseIndexValues.begin(), i * rank), rank}));
- return flatSparseIndices;
-}
-
-//===----------------------------------------------------------------------===//
-// MutableDictionaryAttr
-//===----------------------------------------------------------------------===//
-
-MutableDictionaryAttr::MutableDictionaryAttr(
- ArrayRef<NamedAttribute> attributes) {
- setAttrs(attributes);
-}
-
-/// Return the underlying dictionary attribute.
-DictionaryAttr
-MutableDictionaryAttr::getDictionary(MLIRContext *context) const {
- // Construct empty DictionaryAttr if needed.
- if (!attrs)
- return DictionaryAttr::get({}, context);
- return attrs;
-}
-
-ArrayRef<NamedAttribute> MutableDictionaryAttr::getAttrs() const {
- return attrs ? attrs.getValue() : llvm::None;
-}
-
-/// Replace the held attributes with ones provided in 'newAttrs'.
-void MutableDictionaryAttr::setAttrs(ArrayRef<NamedAttribute> attributes) {
- // Don't create an attribute list if there are no attributes.
- if (attributes.empty())
- attrs = nullptr;
- else
- attrs = DictionaryAttr::get(attributes, attributes[0].second.getContext());
-}
-
-/// Return the specified attribute if present, null otherwise.
-Attribute MutableDictionaryAttr::get(StringRef name) const {
- return attrs ? attrs.get(name) : nullptr;
-}
-
-/// Return the specified attribute if present, null otherwise.
-Attribute MutableDictionaryAttr::get(Identifier name) const {
- return attrs ? attrs.get(name) : nullptr;
-}
-
-/// Return the specified named attribute if present, None otherwise.
-Optional<NamedAttribute> MutableDictionaryAttr::getNamed(StringRef name) const {
- return attrs ? attrs.getNamed(name) : Optional<NamedAttribute>();
-}
-Optional<NamedAttribute>
-MutableDictionaryAttr::getNamed(Identifier name) const {
- return attrs ? attrs.getNamed(name) : Optional<NamedAttribute>();
-}
-
-/// If an attribute exists with the specified name, change it to the new
-/// value. Otherwise, add a new attribute with the specified name/value.
-void MutableDictionaryAttr::set(Identifier name, Attribute value) {
- assert(value && "attributes may never be null");
-
- // Look for an existing value for the given name, and set it in-place.
- ArrayRef<NamedAttribute> values = getAttrs();
- const auto *it = llvm::find_if(
- values, [name](NamedAttribute attr) { return attr.first == name; });
- if (it != values.end()) {
- // Bail out early if the value is the same as what we already have.
- if (it->second == value)
- return;
-
- SmallVector<NamedAttribute, 8> newAttrs(values.begin(), values.end());
- newAttrs[it - values.begin()].second = value;
- attrs = DictionaryAttr::getWithSorted(newAttrs, value.getContext());
- return;
- }
-
- // Otherwise, insert the new attribute into its sorted position.
- it = llvm::lower_bound(values, name);
- SmallVector<NamedAttribute, 8> newAttrs;
- newAttrs.reserve(values.size() + 1);
- newAttrs.append(values.begin(), it);
- newAttrs.push_back({name, value});
- newAttrs.append(it, values.end());
- attrs = DictionaryAttr::getWithSorted(newAttrs, value.getContext());
-}
-
-/// Remove the attribute with the specified name if it exists. The return
-/// value indicates whether the attribute was present or not.
-auto MutableDictionaryAttr::remove(Identifier name) -> RemoveResult {
- auto origAttrs = getAttrs();
- for (unsigned i = 0, e = origAttrs.size(); i != e; ++i) {
- if (origAttrs[i].first == name) {
- // Handle the simple case of removing the only attribute in the list.
- if (e == 1) {
- attrs = nullptr;
- return RemoveResult::Removed;
- }
-
- SmallVector<NamedAttribute, 8> newAttrs;
- newAttrs.reserve(origAttrs.size() - 1);
- newAttrs.append(origAttrs.begin(), origAttrs.begin() + i);
- newAttrs.append(origAttrs.begin() + i + 1, origAttrs.end());
- attrs = DictionaryAttr::getWithSorted(newAttrs,
- newAttrs[0].second.getContext());
- return RemoveResult::Removed;
- }
- }
- return RemoveResult::NotFound;
-}
-
bool mlir::operator<(const NamedAttribute &lhs, const NamedAttribute &rhs) {
return strcmp(lhs.first.data(), rhs.first.data()) < 0;
}
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 72fb240f6be1..64eb37c1e277 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -14,6 +14,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Matchers.h"
+#include "mlir/IR/SymbolTable.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
new file mode 100644
index 000000000000..efd4ec657f3c
--- /dev/null
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -0,0 +1,1567 @@
+//===- BuiltinAttributes.cpp - MLIR Builtin Attribute Classes -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "AttributeDetail.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Interfaces/DecodeAttributesInterfaces.h"
+#include "llvm/ADT/Sequence.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/Endian.h"
+
+using namespace mlir;
+using namespace mlir::detail;
+
+//===----------------------------------------------------------------------===//
+// AffineMapAttr
+//===----------------------------------------------------------------------===//
+
+AffineMapAttr AffineMapAttr::get(AffineMap value) {
+ return Base::get(value.getContext(), value);
+}
+
+AffineMap AffineMapAttr::getValue() const { return getImpl()->value; }
+
+//===----------------------------------------------------------------------===//
+// ArrayAttr
+//===----------------------------------------------------------------------===//
+
+ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) {
+ return Base::get(context, value);
+}
+
+ArrayRef<Attribute> ArrayAttr::getValue() const { return getImpl()->value; }
+
+Attribute ArrayAttr::operator[](unsigned idx) const {
+ assert(idx < size() && "index out of bounds");
+ return getValue()[idx];
+}
+
+//===----------------------------------------------------------------------===//
+// DictionaryAttr
+//===----------------------------------------------------------------------===//
+
+/// Helper function that does either an in place sort or sorts from source array
+/// into destination. If inPlace then storage is both the source and the
+/// destination, else value is the source and storage destination. Returns
+/// whether source was sorted.
+template <bool inPlace>
+static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value,
+ SmallVectorImpl<NamedAttribute> &storage) {
+ // Specialize for the common case.
+ switch (value.size()) {
+ case 0:
+ // Zero already sorted.
+ break;
+ case 1:
+ // One already sorted but may need to be copied.
+ if (!inPlace)
+ storage.assign({value[0]});
+ break;
+ case 2: {
+ bool isSorted = value[0] < value[1];
+ if (inPlace) {
+ if (!isSorted)
+ std::swap(storage[0], storage[1]);
+ } else if (isSorted) {
+ storage.assign({value[0], value[1]});
+ } else {
+ storage.assign({value[1], value[0]});
+ }
+ return !isSorted;
+ }
+ default:
+ if (!inPlace)
+ storage.assign(value.begin(), value.end());
+ // Check to see they are sorted already.
+ bool isSorted = llvm::is_sorted(value);
+ if (!isSorted) {
+ // If not, do a general sort.
+ llvm::array_pod_sort(storage.begin(), storage.end());
+ value = storage;
+ }
+ return !isSorted;
+ }
+ return false;
+}
+
+/// Returns an entry with a duplicate name from the given sorted array of named
+/// attributes. Returns llvm::None if all elements have unique names.
+static Optional<NamedAttribute>
+findDuplicateElement(ArrayRef<NamedAttribute> value) {
+ const Optional<NamedAttribute> none{llvm::None};
+ if (value.size() < 2)
+ return none;
+
+ if (value.size() == 2)
+ return value[0].first == value[1].first ? value[0] : none;
+
+ auto it = std::adjacent_find(
+ value.begin(), value.end(),
+ [](NamedAttribute l, NamedAttribute r) { return l.first == r.first; });
+ return it != value.end() ? *it : none;
+}
+
+bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value,
+ SmallVectorImpl<NamedAttribute> &storage) {
+ bool isSorted = dictionaryAttrSort</*inPlace=*/false>(value, storage);
+ assert(!findDuplicateElement(storage) &&
+ "DictionaryAttr element names must be unique");
+ return isSorted;
+}
+
+bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) {
+ bool isSorted = dictionaryAttrSort</*inPlace=*/true>(array, array);
+ assert(!findDuplicateElement(array) &&
+ "DictionaryAttr element names must be unique");
+ return isSorted;
+}
+
+Optional<NamedAttribute>
+DictionaryAttr::findDuplicate(SmallVectorImpl<NamedAttribute> &array,
+ bool isSorted) {
+ if (!isSorted)
+ dictionaryAttrSort</*inPlace=*/true>(array, array);
+ return findDuplicateElement(array);
+}
+
+DictionaryAttr DictionaryAttr::get(ArrayRef<NamedAttribute> value,
+ MLIRContext *context) {
+ if (value.empty())
+ return DictionaryAttr::getEmpty(context);
+ assert(llvm::all_of(value,
+ [](const NamedAttribute &attr) { return attr.second; }) &&
+ "value cannot have null entries");
+
+ // We need to sort the element list to canonicalize it.
+ SmallVector<NamedAttribute, 8> storage;
+ if (dictionaryAttrSort</*inPlace=*/false>(value, storage))
+ value = storage;
+ assert(!findDuplicateElement(value) &&
+ "DictionaryAttr element names must be unique");
+ return Base::get(context, value);
+}
+/// Construct a dictionary with an array of values that is known to already be
+/// sorted by name and uniqued.
+DictionaryAttr DictionaryAttr::getWithSorted(ArrayRef<NamedAttribute> value,
+ MLIRContext *context) {
+ if (value.empty())
+ return DictionaryAttr::getEmpty(context);
+ // Ensure that the attribute elements are unique and sorted.
+ assert(llvm::is_sorted(value,
+ [](NamedAttribute l, NamedAttribute r) {
+ return l.first.strref() < r.first.strref();
+ }) &&
+ "expected attribute values to be sorted");
+ assert(!findDuplicateElement(value) &&
+ "DictionaryAttr element names must be unique");
+ return Base::get(context, value);
+}
+
+ArrayRef<NamedAttribute> DictionaryAttr::getValue() const {
+ return getImpl()->getElements();
+}
+
+/// Return the specified attribute if present, null otherwise.
+Attribute DictionaryAttr::get(StringRef name) const {
+ Optional<NamedAttribute> attr = getNamed(name);
+ return attr ? attr->second : nullptr;
+}
+Attribute DictionaryAttr::get(Identifier name) const {
+ Optional<NamedAttribute> attr = getNamed(name);
+ return attr ? attr->second : nullptr;
+}
+
+/// Return the specified named attribute if present, None otherwise.
+Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const {
+ ArrayRef<NamedAttribute> values = getValue();
+ const auto *it = llvm::lower_bound(values, name);
+ return it != values.end() && it->first == name ? *it
+ : Optional<NamedAttribute>();
+}
+Optional<NamedAttribute> DictionaryAttr::getNamed(Identifier name) const {
+ for (auto elt : getValue())
+ if (elt.first == name)
+ return elt;
+ return llvm::None;
+}
+
+DictionaryAttr::iterator DictionaryAttr::begin() const {
+ return getValue().begin();
+}
+DictionaryAttr::iterator DictionaryAttr::end() const {
+ return getValue().end();
+}
+size_t DictionaryAttr::size() const { return getValue().size(); }
+
+//===----------------------------------------------------------------------===//
+// FloatAttr
+//===----------------------------------------------------------------------===//
+
+FloatAttr FloatAttr::get(Type type, double value) {
+ return Base::get(type.getContext(), type, value);
+}
+
+FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) {
+ return Base::getChecked(loc, type, value);
+}
+
+FloatAttr FloatAttr::get(Type type, const APFloat &value) {
+ return Base::get(type.getContext(), type, value);
+}
+
+FloatAttr FloatAttr::getChecked(Type type, const APFloat &value, Location loc) {
+ return Base::getChecked(loc, type, value);
+}
+
+APFloat FloatAttr::getValue() const { return getImpl()->getValue(); }
+
+double FloatAttr::getValueAsDouble() const {
+ return getValueAsDouble(getValue());
+}
+double FloatAttr::getValueAsDouble(APFloat value) {
+ if (&value.getSemantics() != &APFloat::IEEEdouble()) {
+ bool losesInfo = false;
+ value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven,
+ &losesInfo);
+ }
+ return value.convertToDouble();
+}
+
+/// Verify construction invariants.
+static LogicalResult verifyFloatTypeInvariants(Location loc, Type type) {
+ if (!type.isa<FloatType>())
+ return emitError(loc, "expected floating point type");
+ return success();
+}
+
+LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type,
+ double value) {
+ return verifyFloatTypeInvariants(loc, type);
+}
+
+LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type,
+ const APFloat &value) {
+ // Verify that the type is correct.
+ if (failed(verifyFloatTypeInvariants(loc, type)))
+ return failure();
+
+ // Verify that the type semantics match that of the value.
+ if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) {
+ return emitError(
+ loc, "FloatAttr type doesn't match the type implied by its value");
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// SymbolRefAttr
+//===----------------------------------------------------------------------===//
+
+FlatSymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) {
+ return Base::get(ctx, value, llvm::None).cast<FlatSymbolRefAttr>();
+}
+
+SymbolRefAttr SymbolRefAttr::get(StringRef value,
+ ArrayRef<FlatSymbolRefAttr> nestedReferences,
+ MLIRContext *ctx) {
+ return Base::get(ctx, value, nestedReferences);
+}
+
+StringRef SymbolRefAttr::getRootReference() const { return getImpl()->value; }
+
+StringRef SymbolRefAttr::getLeafReference() const {
+ ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences();
+ return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getValue();
+}
+
+ArrayRef<FlatSymbolRefAttr> SymbolRefAttr::getNestedReferences() const {
+ return getImpl()->getNestedRefs();
+}
+
+//===----------------------------------------------------------------------===//
+// IntegerAttr
+//===----------------------------------------------------------------------===//
+
+IntegerAttr IntegerAttr::get(Type type, const APInt &value) {
+ if (type.isSignlessInteger(1))
+ return BoolAttr::get(value.getBoolValue(), type.getContext());
+ return Base::get(type.getContext(), type, value);
+}
+
+IntegerAttr IntegerAttr::get(Type type, int64_t value) {
+ // This uses 64 bit APInts by default for index type.
+ if (type.isIndex())
+ return get(type, APInt(IndexType::kInternalStorageBitWidth, value));
+
+ auto intType = type.cast<IntegerType>();
+ return get(type, APInt(intType.getWidth(), value, intType.isSignedInteger()));
+}
+
+APInt IntegerAttr::getValue() const { return getImpl()->getValue(); }
+
+int64_t IntegerAttr::getInt() const {
+ assert((getImpl()->getType().isIndex() ||
+ getImpl()->getType().isSignlessInteger()) &&
+ "must be signless integer");
+ return getValue().getSExtValue();
+}
+
+int64_t IntegerAttr::getSInt() const {
+ assert(getImpl()->getType().isSignedInteger() && "must be signed integer");
+ return getValue().getSExtValue();
+}
+
+uint64_t IntegerAttr::getUInt() const {
+ assert(getImpl()->getType().isUnsignedInteger() &&
+ "must be unsigned integer");
+ return getValue().getZExtValue();
+}
+
+static LogicalResult verifyIntegerTypeInvariants(Location loc, Type type) {
+ if (type.isa<IntegerType, IndexType>())
+ return success();
+ return emitError(loc, "expected integer or index type");
+}
+
+LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type,
+ int64_t value) {
+ return verifyIntegerTypeInvariants(loc, type);
+}
+
+LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type,
+ const APInt &value) {
+ if (failed(verifyIntegerTypeInvariants(loc, type)))
+ return failure();
+ if (auto integerType = type.dyn_cast<IntegerType>())
+ if (integerType.getWidth() != value.getBitWidth())
+ return emitError(loc, "integer type bit width (")
+ << integerType.getWidth() << ") doesn't match value bit width ("
+ << value.getBitWidth() << ")";
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// BoolAttr
+
+bool BoolAttr::getValue() const {
+ auto *storage = reinterpret_cast<IntegerAttributeStorage *>(impl);
+ return storage->getValue().getBoolValue();
+}
+
+bool BoolAttr::classof(Attribute attr) {
+ IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>();
+ return intAttr && intAttr.getType().isSignlessInteger(1);
+}
+
+//===----------------------------------------------------------------------===//
+// IntegerSetAttr
+//===----------------------------------------------------------------------===//
+
+IntegerSetAttr IntegerSetAttr::get(IntegerSet value) {
+ return Base::get(value.getConstraint(0).getContext(), value);
+}
+
+IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; }
+
+//===----------------------------------------------------------------------===//
+// OpaqueAttr
+//===----------------------------------------------------------------------===//
+
+OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type,
+ MLIRContext *context) {
+ return Base::get(context, dialect, attrData, type);
+}
+
+OpaqueAttr OpaqueAttr::getChecked(Identifier dialect, StringRef attrData,
+ Type type, Location location) {
+ return Base::getChecked(location, dialect, attrData, type);
+}
+
+/// Returns the dialect namespace of the opaque attribute.
+Identifier OpaqueAttr::getDialectNamespace() const {
+ return getImpl()->dialectNamespace;
+}
+
+/// Returns the raw attribute data of the opaque attribute.
+StringRef OpaqueAttr::getAttrData() const { return getImpl()->attrData; }
+
+/// Verify the construction of an opaque attribute.
+LogicalResult OpaqueAttr::verifyConstructionInvariants(Location loc,
+ Identifier dialect,
+ StringRef attrData,
+ Type type) {
+ if (!Dialect::isValidNamespace(dialect.strref()))
+ return emitError(loc, "invalid dialect namespace '") << dialect << "'";
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// StringAttr
+//===----------------------------------------------------------------------===//
+
+StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) {
+ return get(bytes, NoneType::get(context));
+}
+
+/// Get an instance of a StringAttr with the given string and Type.
+StringAttr StringAttr::get(StringRef bytes, Type type) {
+ return Base::get(type.getContext(), bytes, type);
+}
+
+StringRef StringAttr::getValue() const { return getImpl()->value; }
+
+//===----------------------------------------------------------------------===//
+// TypeAttr
+//===----------------------------------------------------------------------===//
+
+TypeAttr TypeAttr::get(Type value) {
+ return Base::get(value.getContext(), value);
+}
+
+Type TypeAttr::getValue() const { return getImpl()->value; }
+
+//===----------------------------------------------------------------------===//
+// ElementsAttr
+//===----------------------------------------------------------------------===//
+
+ShapedType ElementsAttr::getType() const {
+ return Attribute::getType().cast<ShapedType>();
+}
+
+/// Returns the number of elements held by this attribute.
+int64_t ElementsAttr::getNumElements() const {
+ return getType().getNumElements();
+}
+
+/// Return the value at the given index. If index does not refer to a valid
+/// element, then a null attribute is returned.
+Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const {
+ if (auto denseAttr = dyn_cast<DenseElementsAttr>())
+ return denseAttr.getValue(index);
+ if (auto opaqueAttr = dyn_cast<OpaqueElementsAttr>())
+ return opaqueAttr.getValue(index);
+ return cast<SparseElementsAttr>().getValue(index);
+}
+
+/// Return if the given 'index' refers to a valid element in this attribute.
+bool ElementsAttr::isValidIndex(ArrayRef<uint64_t> index) const {
+ auto type = getType();
+
+ // Verify that the rank of the indices matches the held type.
+ auto rank = type.getRank();
+ if (rank != static_cast<int64_t>(index.size()))
+ return false;
+
+ // Verify that all of the indices are within the shape dimensions.
+ auto shape = type.getShape();
+ return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) {
+ return static_cast<int64_t>(index[i]) < shape[i];
+ });
+}
+
+ElementsAttr
+ElementsAttr::mapValues(Type newElementType,
+ function_ref<APInt(const APInt &)> mapping) const {
+ if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>())
+ return intOrFpAttr.mapValues(newElementType, mapping);
+ llvm_unreachable("unsupported ElementsAttr subtype");
+}
+
+ElementsAttr
+ElementsAttr::mapValues(Type newElementType,
+ function_ref<APInt(const APFloat &)> mapping) const {
+ if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>())
+ return intOrFpAttr.mapValues(newElementType, mapping);
+ llvm_unreachable("unsupported ElementsAttr subtype");
+}
+
+/// Method for support type inquiry through isa, cast and dyn_cast.
+bool ElementsAttr::classof(Attribute attr) {
+ return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr,
+ OpaqueElementsAttr, SparseElementsAttr>();
+}
+
+/// Returns the 1 dimensional flattened row-major index from the given
+/// multi-dimensional index.
+uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const {
+ assert(isValidIndex(index) && "expected valid multi-dimensional index");
+ auto type = getType();
+
+ // Reduce the provided multidimensional index into a flattened 1D row-major
+ // index.
+ auto rank = type.getRank();
+ auto shape = type.getShape();
+ uint64_t valueIndex = 0;
+ uint64_t dimMultiplier = 1;
+ for (int i = rank - 1; i >= 0; --i) {
+ valueIndex += index[i] * dimMultiplier;
+ dimMultiplier *= shape[i];
+ }
+ return valueIndex;
+}
+
+//===----------------------------------------------------------------------===//
+// DenseElementsAttr Utilities
+//===----------------------------------------------------------------------===//
+
+/// Get the bitwidth of a dense element type within the buffer.
+/// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8.
+static size_t getDenseElementStorageWidth(size_t origWidth) {
+ return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth);
+}
+static size_t getDenseElementStorageWidth(Type elementType) {
+ return getDenseElementStorageWidth(getDenseElementBitWidth(elementType));
+}
+
+/// Set a bit to a specific value.
+static void setBit(char *rawData, size_t bitPos, bool value) {
+ if (value)
+ rawData[bitPos / CHAR_BIT] |= (1 << (bitPos % CHAR_BIT));
+ else
+ rawData[bitPos / CHAR_BIT] &= ~(1 << (bitPos % CHAR_BIT));
+}
+
+/// Return the value of the specified bit.
+static bool getBit(const char *rawData, size_t bitPos) {
+ return (rawData[bitPos / CHAR_BIT] & (1 << (bitPos % CHAR_BIT))) != 0;
+}
+
+/// Copy actual `numBytes` data from `value` (APInt) to char array(`result`) for
+/// BE format.
+static void copyAPIntToArrayForBEmachine(APInt value, size_t numBytes,
+ char *result) {
+ assert(llvm::support::endian::system_endianness() == // NOLINT
+ llvm::support::endianness::big); // NOLINT
+ assert(value.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes);
+
+ // Copy the words filled with data.
+ // For example, when `value` has 2 words, the first word is filled with data.
+ // `value` (10 bytes, BE):|abcdefgh|------ij| ==> `result` (BE):|abcdefgh|--|
+ size_t numFilledWords = (value.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
+ std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
+ numFilledWords, result);
+ // Convert last word of APInt to LE format and store it in char
+ // array(`valueLE`).
+ // ex. last word of `value` (BE): |------ij| ==> `valueLE` (LE): |ji------|
+ size_t lastWordPos = numFilledWords;
+ SmallVector<char, 8> valueLE(APInt::APINT_WORD_SIZE);
+ DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
+ reinterpret_cast<const char *>(value.getRawData()) + lastWordPos,
+ valueLE.begin(), APInt::APINT_BITS_PER_WORD, 1);
+ // Extract actual APInt data from `valueLE`, convert endianness to BE format,
+ // and store it in `result`.
+ // ex. `valueLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|ij|
+ DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
+ valueLE.begin(), result + lastWordPos,
+ (numBytes - lastWordPos) * CHAR_BIT, 1);
+}
+
+/// Copy `numBytes` data from `inArray`(char array) to `result`(APINT) for BE
+/// format.
+static void copyArrayToAPIntForBEmachine(const char *inArray, size_t numBytes,
+ APInt &result) {
+ assert(llvm::support::endian::system_endianness() == // NOLINT
+ llvm::support::endianness::big); // NOLINT
+ assert(result.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes);
+
+ // Copy the data that fills the word of `result` from `inArray`.
+ // For example, when `result` has 2 words, the first word will be filled with
+ // data. So, the first 8 bytes are copied from `inArray` here.
+ // `inArray` (10 bytes, BE): |abcdefgh|ij|
+ // ==> `result` (2 words, BE): |abcdefgh|--------|
+ size_t numFilledWords = (result.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
+ std::copy_n(
+ inArray, numFilledWords,
+ const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())));
+
+ // Convert array data which will be last word of `result` to LE format, and
+ // store it in char array(`inArrayLE`).
+ // ex. `inArray` (last two bytes, BE): |ij| ==> `inArrayLE` (LE): |ji------|
+ size_t lastWordPos = numFilledWords;
+ SmallVector<char, 8> inArrayLE(APInt::APINT_WORD_SIZE);
+ DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
+ inArray + lastWordPos, inArrayLE.begin(),
+ (numBytes - lastWordPos) * CHAR_BIT, 1);
+
+ // Convert `inArrayLE` to BE format, and store it in last word of `result`.
+ // ex. `inArrayLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|------ij|
+ DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
+ inArrayLE.begin(),
+ const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())) +
+ lastWordPos,
+ APInt::APINT_BITS_PER_WORD, 1);
+}
+
+/// Writes value to the bit position `bitPos` in array `rawData`.
+static void writeBits(char *rawData, size_t bitPos, APInt value) {
+ size_t bitWidth = value.getBitWidth();
+
+ // If the bitwidth is 1 we just toggle the specific bit.
+ if (bitWidth == 1)
+ return setBit(rawData, bitPos, value.isOneValue());
+
+ // Otherwise, the bit position is guaranteed to be byte aligned.
+ assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
+ if (llvm::support::endian::system_endianness() ==
+ llvm::support::endianness::big) {
+ // Copy from `value` to `rawData + (bitPos / CHAR_BIT)`.
+ // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
+ // work correctly in BE format.
+ // ex. `value` (2 words including 10 bytes)
+ // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------|
+ copyAPIntToArrayForBEmachine(value, llvm::divideCeil(bitWidth, CHAR_BIT),
+ rawData + (bitPos / CHAR_BIT));
+ } else {
+ std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
+ llvm::divideCeil(bitWidth, CHAR_BIT),
+ rawData + (bitPos / CHAR_BIT));
+ }
+}
+
+/// Reads the next `bitWidth` bits from the bit position `bitPos` in array
+/// `rawData`.
+static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) {
+ // Handle a boolean bit position.
+ if (bitWidth == 1)
+ return APInt(1, getBit(rawData, bitPos) ? 1 : 0);
+
+ // Otherwise, the bit position must be 8-bit aligned.
+ assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
+ APInt result(bitWidth, 0);
+ if (llvm::support::endian::system_endianness() ==
+ llvm::support::endianness::big) {
+ // Copy from `rawData + (bitPos / CHAR_BIT)` to `result`.
+ // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
+ // work correctly in BE format.
+ // ex. `result` (2 words including 10 bytes)
+ // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------|
+ copyArrayToAPIntForBEmachine(rawData + (bitPos / CHAR_BIT),
+ llvm::divideCeil(bitWidth, CHAR_BIT), result);
+ } else {
+ std::copy_n(rawData + (bitPos / CHAR_BIT),
+ llvm::divideCeil(bitWidth, CHAR_BIT),
+ const_cast<char *>(
+ reinterpret_cast<const char *>(result.getRawData())));
+ }
+ return result;
+}
+
+/// Returns true if 'values' corresponds to a splat, i.e. one element, or has
+/// the same element count as 'type'.
+template <typename Values>
+static bool hasSameElementsOrSplat(ShapedType type, const Values &values) {
+ return (values.size() == 1) ||
+ (type.getNumElements() == static_cast<int64_t>(values.size()));
+}
+
+//===----------------------------------------------------------------------===//
+// DenseElementsAttr Iterators
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// AttributeElementIterator
+
+/// Constructs an iterator over `attr` positioned at element `index`. The
+/// attribute is handed to the indexed_accessor_iterator base as an opaque
+/// pointer and recovered again when the iterator is dereferenced.
+DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
+ DenseElementsAttr attr, size_t index)
+ : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *,
+ Attribute, Attribute, Attribute>(
+ attr.getAsOpaquePointer(), index) {}
+
+/// Materializes the element at the current index as an Attribute whose kind
+/// follows the element type: IntegerAttr for integer/index elements, FloatAttr
+/// for float elements and StringAttr for dense string attributes.
+Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
+  auto denseAttr = getFromOpaquePointer(base).cast<DenseElementsAttr>();
+  Type elementType = denseAttr.getType().getElementType();
+
+  // Integer and index elements are both exposed through IntegerAttr.
+  if (elementType.isa<IntegerType>() || elementType.isa<IndexType>())
+    return IntegerAttr::get(elementType,
+                            *IntElementIterator(denseAttr, index));
+
+  // Float elements are read as raw integers and reinterpreted with the
+  // semantics of the element type.
+  if (auto floatType = elementType.dyn_cast<FloatType>()) {
+    FloatElementIterator floatIt(floatType.getFloatSemantics(),
+                                 IntElementIterator(denseAttr, index));
+    return FloatAttr::get(elementType, *floatIt);
+  }
+
+  // String elements: a splat stores only a single value.
+  if (denseAttr.isa<DenseStringElementsAttr>()) {
+    ArrayRef<StringRef> strings = denseAttr.getRawStringData();
+    StringRef str = denseAttr.isSplat() ? strings.front() : strings[index];
+    return StringAttr::get(str, elementType);
+  }
+  llvm_unreachable("unexpected element type");
+}
+
+//===----------------------------------------------------------------------===//
+// BoolElementIterator
+
+/// Constructs a bool iterator over the raw data of `attr`, starting at
+/// `dataIndex`. The splat flag of `attr` is forwarded to the base iterator.
+DenseElementsAttr::BoolElementIterator::BoolElementIterator(
+ DenseElementsAttr attr, size_t dataIndex)
+ : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>(
+ attr.getRawData().data(), attr.isSplat(), dataIndex) {}
+
+/// Reads the bit stored at the current data index.
+bool DenseElementsAttr::BoolElementIterator::operator*() const {
+  const char *packedBits = getData();
+  return getBit(packedBits, getDataIndex());
+}
+
+//===----------------------------------------------------------------------===//
+// IntElementIterator
+
+/// Constructs an APInt iterator over the raw data of `attr`, starting at
+/// `dataIndex`. The dense element bit width of the attribute's element type
+/// is cached in `bitWidth` for use when dereferencing.
+DenseElementsAttr::IntElementIterator::IntElementIterator(
+ DenseElementsAttr attr, size_t dataIndex)
+ : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>(
+ attr.getRawData().data(), attr.isSplat(), dataIndex),
+ bitWidth(getDenseElementBitWidth(attr.getType().getElementType())) {}
+
+/// Decodes the element at the current data index as an APInt of `bitWidth`
+/// bits. Elements are laid out at multiples of the storage width.
+APInt DenseElementsAttr::IntElementIterator::operator*() const {
+  size_t storageWidth = getDenseElementStorageWidth(bitWidth);
+  size_t bitPos = getDataIndex() * storageWidth;
+  return readBits(getData(), bitPos, bitWidth);
+}
+
+//===----------------------------------------------------------------------===//
+// ComplexIntElementIterator
+
+/// Constructs a complex-integer iterator over the raw data of `attr`,
+/// starting at `dataIndex`. `bitWidth` is the dense bit width of the complex
+/// type's underlying element type.
+DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator(
+    DenseElementsAttr attr, size_t dataIndex)
+    : DenseElementIndexedIteratorImpl<ComplexIntElementIterator,
+                                      std::complex<APInt>, std::complex<APInt>,
+                                      std::complex<APInt>>(
+          attr.getRawData().data(), attr.isSplat(), dataIndex) {
+  Type componentType =
+      attr.getType().getElementType().cast<ComplexType>().getElementType();
+  bitWidth = getDenseElementBitWidth(componentType);
+}
+
+/// Decodes the complex element at the current data index. The real part is
+/// stored first and is immediately followed by the imaginary part, each
+/// occupying one storage width.
+std::complex<APInt>
+DenseElementsAttr::ComplexIntElementIterator::operator*() const {
+  size_t componentWidth = getDenseElementStorageWidth(bitWidth);
+  size_t realPos = getDataIndex() * componentWidth * 2;
+  size_t imagPos = realPos + componentWidth;
+  return {readBits(getData(), realPos, bitWidth),
+          readBits(getData(), imagPos, bitWidth)};
+}
+
+//===----------------------------------------------------------------------===//
+// FloatElementIterator
+
+/// Constructs a float iterator that wraps the integer iterator `it` and maps
+/// each raw APInt to an APFloat with semantics `smt`.
+/// The mapping lambda captures `smt` by reference.
+DenseElementsAttr::FloatElementIterator::FloatElementIterator(
+ const llvm::fltSemantics &smt, IntElementIterator it)
+ : llvm::mapped_iterator<IntElementIterator,
+ std::function<APFloat(const APInt &)>>(
+ it, [&](const APInt &val) { return APFloat(smt, val); }) {}
+
+//===----------------------------------------------------------------------===//
+// ComplexFloatElementIterator
+
+/// Constructs a complex-float iterator that wraps the complex-integer
+/// iterator `it` and maps the real and imaginary APInt parts to APFloats with
+/// semantics `smt`. The mapping lambda captures `smt` by reference.
+DenseElementsAttr::ComplexFloatElementIterator::ComplexFloatElementIterator(
+ const llvm::fltSemantics &smt, ComplexIntElementIterator it)
+ : llvm::mapped_iterator<
+ ComplexIntElementIterator,
+ std::function<std::complex<APFloat>(const std::complex<APInt> &)>>(
+ it, [&](const std::complex<APInt> &val) -> std::complex<APFloat> {
+ return {APFloat(smt, val.real()), APFloat(smt, val.imag())};
+ }) {}
+
+//===----------------------------------------------------------------------===//
+// DenseElementsAttr
+//===----------------------------------------------------------------------===//
+
+/// Supports isa/cast/dyn_cast: a DenseElementsAttr is either a dense
+/// int-or-float elements attribute or a dense string elements attribute.
+bool DenseElementsAttr::classof(Attribute attr) {
+  return attr.isa<DenseIntOrFPElementsAttr>() ||
+         attr.isa<DenseStringElementsAttr>();
+}
+
+/// Constructs a dense elements attribute of type `type` from a list of
+/// element attributes. `values` must contain either a single element (splat)
+/// or one element per element of `type`.
+///
+/// Element types that are not integer, index or float are assumed to be
+/// string-like: every value must then be a StringAttr. Otherwise each value
+/// must have exactly the element type of `type`; float values are bitcast to
+/// their integer representation and all values are packed into a raw buffer.
+DenseElementsAttr DenseElementsAttr::get(ShapedType type,
+                                         ArrayRef<Attribute> values) {
+  assert(hasSameElementsOrSplat(type, values));
+
+  // If the element type is not based on int/float/index, assume it is a string
+  // type.
+  auto eltType = type.getElementType();
+  if (!eltType.isIntOrIndexOrFloat()) {
+    SmallVector<StringRef, 8> stringValues;
+    stringValues.reserve(values.size());
+    for (Attribute attr : values) {
+      assert(attr.isa<StringAttr>() &&
+             "expected string value for non integer/index/float element");
+      stringValues.push_back(attr.cast<StringAttr>().getValue());
+    }
+    return get(type, stringValues);
+  }
+
+  // Otherwise, get the raw storage width to use for the allocation.
+  size_t bitWidth = getDenseElementBitWidth(eltType);
+  size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
+
+  // Compress the attribute values into a character buffer.
+  SmallVector<char, 8> data(llvm::divideCeil(storageBitWidth, CHAR_BIT) *
+                            values.size());
+  APInt intVal;
+  for (unsigned i = 0, e = values.size(); i < e; ++i) {
+    assert(eltType == values[i].getType() &&
+           "expected attribute value to have element type");
+    if (eltType.isa<FloatType>())
+      intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt();
+    else if (eltType.isa<IntegerType, IndexType>())
+      // Index elements pass the int/index/float check above and are carried by
+      // IntegerAttr just like integer elements, so they are packed the same
+      // way instead of falling into the unreachable below.
+      intVal = values[i].cast<IntegerAttr>().getValue();
+    else
+      llvm_unreachable("unexpected element type");
+
+    assert(intVal.getBitWidth() == bitWidth &&
+           "expected value to have same bitwidth as element type");
+    writeBits(data.data(), i * storageBitWidth, intVal);
+  }
+  return DenseIntOrFPElementsAttr::getRaw(type, data,
+                                          /*isSplat=*/(values.size() == 1));
+}
+
+/// Constructs a dense i1 elements attribute from an array of bools. The
+/// values are packed one bit per element; a single value denotes a splat.
+DenseElementsAttr DenseElementsAttr::get(ShapedType type,
+                                         ArrayRef<bool> values) {
+  assert(hasSameElementsOrSplat(type, values));
+  assert(type.getElementType().isInteger(1));
+
+  // One bit per value, rounded up to whole bytes.
+  std::vector<char> packedBits(llvm::divideCeil(values.size(), CHAR_BIT));
+  size_t bitIndex = 0;
+  for (bool value : values)
+    setBit(packedBits.data(), bitIndex++, value);
+
+  bool isSplat = values.size() == 1;
+  return DenseIntOrFPElementsAttr::getRaw(type, packedBits, isSplat);
+}
+
+/// Constructs a dense string elements attribute from an array of strings by
+/// forwarding to DenseStringElementsAttr::get. The element type of `type`
+/// must not be an integer or float type.
+// NOTE(review): this assert uses isIntOrFloat() while the Attribute overload
+// dispatches here on !isIntOrIndexOrFloat() -- confirm whether index element
+// types should be rejected here as well.
+DenseElementsAttr DenseElementsAttr::get(ShapedType type,
+ ArrayRef<StringRef> values) {
+ assert(!type.getElementType().isIntOrFloat());
+ return DenseStringElementsAttr::get(type, values);
+}
+
+/// Builds a dense integer (or index) elements attribute from APInt values.
+/// Every APInt is expected to have the bitwidth of the element type of
+/// 'type'; a single value denotes a splat.
+DenseElementsAttr DenseElementsAttr::get(ShapedType type,
+                                         ArrayRef<APInt> values) {
+  Type eltType = type.getElementType();
+  assert(eltType.isIntOrIndex());
+  assert(hasSameElementsOrSplat(type, values));
+
+  bool isSplat = values.size() == 1;
+  return DenseIntOrFPElementsAttr::getRaw(
+      type, getDenseElementStorageWidth(eltType), values, isSplat);
+}
+/// Builds a dense complex-integer elements attribute. The complex values are
+/// reinterpreted as a flat array holding twice as many APInts, and each
+/// component uses half of the complex element's storage width.
+DenseElementsAttr DenseElementsAttr::get(ShapedType type,
+                                         ArrayRef<std::complex<APInt>> values) {
+  ComplexType complex = type.getElementType().cast<ComplexType>();
+  assert(complex.getElementType().isa<IntegerType>());
+  assert(hasSameElementsOrSplat(type, values));
+
+  size_t componentStorageWidth = getDenseElementStorageWidth(complex) / 2;
+  const APInt *firstComponent = reinterpret_cast<const APInt *>(values.data());
+  ArrayRef<APInt> components(firstComponent, values.size() * 2);
+  bool isSplat = values.size() == 1;
+  return DenseIntOrFPElementsAttr::getRaw(type, componentStorageWidth,
+                                          components, isSplat);
+}
+
+/// Builds a dense float elements attribute from APFloat values. Every
+/// APFloat is expected to have the bitwidth of the element type of 'type'; a
+/// single value denotes a splat.
+DenseElementsAttr DenseElementsAttr::get(ShapedType type,
+                                         ArrayRef<APFloat> values) {
+  Type eltType = type.getElementType();
+  assert(eltType.isa<FloatType>());
+  assert(hasSameElementsOrSplat(type, values));
+
+  bool isSplat = values.size() == 1;
+  return DenseIntOrFPElementsAttr::getRaw(
+      type, getDenseElementStorageWidth(eltType), values, isSplat);
+}
+/// Builds a dense complex-float elements attribute. The complex values are
+/// reinterpreted as a flat array holding twice as many APFloats, and each
+/// component uses half of the complex element's storage width.
+DenseElementsAttr
+DenseElementsAttr::get(ShapedType type,
+                       ArrayRef<std::complex<APFloat>> values) {
+  ComplexType complex = type.getElementType().cast<ComplexType>();
+  assert(complex.getElementType().isa<FloatType>());
+  assert(hasSameElementsOrSplat(type, values));
+
+  const APFloat *firstComponent =
+      reinterpret_cast<const APFloat *>(values.data());
+  ArrayRef<APFloat> components(firstComponent, values.size() * 2);
+  size_t componentStorageWidth = getDenseElementStorageWidth(complex) / 2;
+  bool isSplat = values.size() == 1;
+  return DenseIntOrFPElementsAttr::getRaw(type, componentStorageWidth,
+                                          components, isSplat);
+}
+
+/// Construct a dense elements attribute from a raw buffer representing the
+/// data for this attribute. Users should generally not use this method, as
+/// the expected buffer format may not be a form the user expects.
+/// The buffer and the `isSplatBuffer` flag are forwarded unchanged to
+/// DenseIntOrFPElementsAttr::getRaw.
+DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type,
+ ArrayRef<char> rawBuffer,
+ bool isSplatBuffer) {
+ return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer, isSplatBuffer);
+}
+
+/// Checks whether `rawBuffer` has a valid size for an attribute of type
+/// `type`. `detectedSplat` is set to true when the buffer size matches that of
+/// a single (splat) element, and to false otherwise.
+bool DenseElementsAttr::isValidRawBuffer(ShapedType type,
+                                         ArrayRef<char> rawBuffer,
+                                         bool &detectedSplat) {
+  size_t storageWidth = getDenseElementStorageWidth(type.getElementType());
+  size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT;
+
+  // Elements with a storage width of 1 are packed by the bit: the buffer is
+  // either a single byte (treated as a splat) or the element count rounded up
+  // to whole bytes.
+  if (storageWidth == 1) {
+    detectedSplat = (rawBuffer.size() == 1);
+    if (detectedSplat)
+      return true;
+    return rawBufferWidth == llvm::alignTo<8>(type.getNumElements());
+  }
+
+  // All other element types are 8-bit aligned: the buffer holds either one
+  // element (splat) or all of them.
+  detectedSplat = (rawBufferWidth == storageWidth);
+  if (detectedSplat)
+    return true;
+  return rawBufferWidth == (storageWidth * type.getNumElements());
+}
+
+/// Checks whether a C++ data type, described by its size in bytes
+/// (`dataEltSize`), whether it is integral (`isInt`) and whether it is signed
+/// (`isSigned`), is a valid representation for elements of type `type`. This
+/// verifies invariants that the templatized 'getValues' method cannot.
+static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt,
+                              bool isSigned) {
+  // The C++ type must be exactly as wide as the dense element.
+  size_t dataBitWidth = static_cast<size_t>(dataEltSize * CHAR_BIT);
+  if (getDenseElementBitWidth(type) != dataBitWidth)
+    return false;
+
+  // Non-integral C++ types only match float element types.
+  if (!isInt)
+    return type.isa<FloatType>();
+
+  // Index elements accept any integral C++ type of the right width.
+  if (type.isIndex())
+    return true;
+
+  // Otherwise the element must be an integer type; signless integers accept
+  // either signedness, everything else must agree with `isSigned`.
+  if (auto intType = type.dyn_cast<IntegerType>())
+    return intType.isSignless() || intType.isSigned() == isSigned;
+  return false;
+}
+
+/// Forwards to the subclass implementation,
+/// DenseIntOrFPElementsAttr::getRawComplex, with all arguments unchanged.
+DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type,
+ ArrayRef<char> data,
+ int64_t dataEltSize,
+ bool isInt, bool isSigned) {
+ return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt,
+ isSigned);
+}
+/// Forwards to the subclass implementation,
+/// DenseIntOrFPElementsAttr::getRawIntOrFloat, with all arguments unchanged.
+DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
+ ArrayRef<char> data,
+ int64_t dataEltSize,
+ bool isInt,
+ bool isSigned) {
+ return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize,
+ isInt, isSigned);
+}
+
+/// Verifies type invariants that the templatized 'get' method cannot: checks
+/// the described C++ data type against this attribute's element type.
+bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt,
+                                          bool isSigned) const {
+  Type eltType = getType().getElementType();
+  return ::isValidIntOrFloat(eltType, dataEltSize, isInt, isSigned);
+}
+
+/// Checks whether the described C++ complex data type is valid for this
+/// attribute. `dataEltSize` is the size of the whole complex value, so each
+/// component is checked against half of it.
+bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt,
+                                       bool isSigned) const {
+  ComplexType complexType = getType().getElementType().cast<ComplexType>();
+  Type componentType = complexType.getElementType();
+  return ::isValidIntOrFloat(componentType, dataEltSize / 2, isInt, isSigned);
+}
+
+/// Reports whether this attribute is a splat, i.e. whether all element values
+/// are the same. The flag is read from the attribute storage.
+bool DenseElementsAttr::isSplat() const {
+  auto *storage = static_cast<DenseElementsAttributeStorage *>(impl);
+  return storage->isSplat;
+}
+
+/// Exposes the held element values as a range of Attributes.
+auto DenseElementsAttr::getAttributeValues() const
+    -> llvm::iterator_range<AttributeElementIterator> {
+  AttributeElementIterator first = attr_value_begin();
+  AttributeElementIterator last = attr_value_end();
+  return llvm::iterator_range<AttributeElementIterator>(first, last);
+}
+/// Returns an Attribute iterator positioned at the first element.
+auto DenseElementsAttr::attr_value_begin() const -> AttributeElementIterator {
+ return AttributeElementIterator(*this, 0);
+}
+/// Returns an Attribute iterator positioned one past the last element, i.e.
+/// at index getNumElements().
+auto DenseElementsAttr::attr_value_end() const -> AttributeElementIterator {
+ return AttributeElementIterator(*this, getNumElements());
+}
+
+/// Exposes the held element values as a range of bool. The element type of
+/// this attribute must be an integer type of bitwidth 1.
+auto DenseElementsAttr::getBoolValues() const
+    -> llvm::iterator_range<BoolElementIterator> {
+  auto intType = getType().getElementType().dyn_cast<IntegerType>();
+  assert(intType && intType.getWidth() == 1 && "expected i1 integer type");
+  (void)intType;
+
+  BoolElementIterator first(*this, 0);
+  BoolElementIterator last(*this, getNumElements());
+  return {first, last};
+}
+
+/// Exposes the held element values as a range of APInts. The element type of
+/// this attribute must be an integer or index type.
+auto DenseElementsAttr::getIntValues() const
+    -> llvm::iterator_range<IntElementIterator> {
+  assert(getType().getElementType().isIntOrIndex() && "expected integral type");
+  IntElementIterator first = raw_int_begin();
+  IntElementIterator last = raw_int_end();
+  return llvm::iterator_range<IntElementIterator>(first, last);
+}
+/// Returns an APInt iterator positioned at the first element. The element
+/// type of this attribute must be an integer or index type.
+auto DenseElementsAttr::int_value_begin() const -> IntElementIterator {
+ assert(getType().getElementType().isIntOrIndex() && "expected integral type");
+ return raw_int_begin();
+}
+/// Returns the end APInt iterator. The element type of this attribute must be
+/// an integer or index type.
+auto DenseElementsAttr::int_value_end() const -> IntElementIterator {
+ assert(getType().getElementType().isIntOrIndex() && "expected integral type");
+ return raw_int_end();
+}
+/// Exposes the held element values as a range of std::complex<APInt>. The
+/// element type must be a complex type over an integer type.
+auto DenseElementsAttr::getComplexIntValues() const
+    -> llvm::iterator_range<ComplexIntElementIterator> {
+  ComplexType complexType = getType().getElementType().cast<ComplexType>();
+  Type componentType = complexType.getElementType();
+  (void)componentType;
+  assert(componentType.isa<IntegerType>() && "expected complex integral type");
+
+  ComplexIntElementIterator first(*this, 0);
+  ComplexIntElementIterator last(*this, getNumElements());
+  return {first, last};
+}
+
+/// Exposes the held element values as a range of APFloat. The element type of
+/// this attribute must be a float type; the raw integer values are
+/// reinterpreted using that type's float semantics.
+auto DenseElementsAttr::getFloatValues() const
+    -> llvm::iterator_range<FloatElementIterator> {
+  FloatType floatType = getType().getElementType().cast<FloatType>();
+  const llvm::fltSemantics &semantics = floatType.getFloatSemantics();
+
+  FloatElementIterator first(semantics, raw_int_begin());
+  FloatElementIterator last(semantics, raw_int_end());
+  return {first, last};
+}
+/// Returns the begin iterator of the range produced by getFloatValues().
+auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator {
+ return getFloatValues().begin();
+}
+/// Returns the end iterator of the range produced by getFloatValues().
+auto DenseElementsAttr::float_value_end() const -> FloatElementIterator {
+ return getFloatValues().end();
+}
+/// Exposes the held element values as a range of std::complex<APFloat>. The
+/// element type must be a complex type over a float type.
+auto DenseElementsAttr::getComplexFloatValues() const
+    -> llvm::iterator_range<ComplexFloatElementIterator> {
+  ComplexType complexType = getType().getElementType().cast<ComplexType>();
+  Type componentType = complexType.getElementType();
+  assert(componentType.isa<FloatType>() && "expected complex float type");
+  const llvm::fltSemantics &semantics =
+      componentType.cast<FloatType>().getFloatSemantics();
+
+  // Wrap complex-integer iterators over the raw data with the float semantics
+  // of the component type.
+  ComplexFloatElementIterator first(semantics,
+                                    ComplexIntElementIterator(*this, 0));
+  ComplexFloatElementIterator last(
+      semantics,
+      ComplexIntElementIterator(*this, static_cast<size_t>(getNumElements())));
+  return {first, last};
+}
+
+/// Returns the raw data buffer held by the int-or-float attribute storage.
+ArrayRef<char> DenseElementsAttr::getRawData() const {
+  auto *storage = static_cast<DenseIntOrFPElementsAttributeStorage *>(impl);
+  return storage->data;
+}
+
+/// Returns the raw string values held by the string attribute storage.
+ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const {
+  auto *storage = static_cast<DenseStringElementsAttributeStorage *>(impl);
+  return storage->data;
+}
+
+/// Produces a DenseElementsAttr holding the same data as this attribute but
+/// with shape 'newType'. The new type must have the same element type and the
+/// same total number of elements. If 'newType' equals the current type, this
+/// attribute is returned unchanged.
+DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
+  ShapedType oldType = getType();
+  if (oldType != newType) {
+    (void)oldType;
+    assert(newType.getElementType() == oldType.getElementType() &&
+           "expected the same element type");
+    assert(newType.getNumElements() == oldType.getNumElements() &&
+           "expected the same number of elements");
+    // The raw buffer and splat flag are reused as-is under the new type.
+    return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat());
+  }
+  return *this;
+}
+
+/// Maps every integer element through `mapping`, producing an attribute with
+/// element type `newElementType`. This attribute is cast to
+/// DenseIntElementsAttr, which provides the implementation.
+DenseElementsAttr
+DenseElementsAttr::mapValues(Type newElementType,
+ function_ref<APInt(const APInt &)> mapping) const {
+ return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping);
+}
+
+/// Maps every float element through `mapping`, producing an attribute with
+/// element type `newElementType`. This attribute is cast to
+/// DenseFPElementsAttr, which provides the implementation.
+DenseElementsAttr DenseElementsAttr::mapValues(
+ Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
+ return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping);
+}
+
+//===----------------------------------------------------------------------===//
+// DenseStringElementsAttr
+//===----------------------------------------------------------------------===//
+
+/// Constructs a dense string elements attribute of type `type` from `values`.
+/// A single value is recorded as a splat.
+DenseStringElementsAttr
+DenseStringElementsAttr::get(ShapedType type, ArrayRef<StringRef> values) {
+ return Base::get(type.getContext(), type, values, (values.size() == 1));
+}
+
+//===----------------------------------------------------------------------===//
+// DenseIntOrFPElementsAttr
+//===----------------------------------------------------------------------===//
+
+/// Writes a range of APInt values into `data`, one value per `storageWidth`
+/// bits. `data` is resized to hold every value of the range.
+template <typename APRangeT>
+static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data,
+                                APRangeT &&values) {
+  size_t bytesPerValue = llvm::divideCeil(storageWidth, CHAR_BIT);
+  data.resize(bytesPerValue * llvm::size(values));
+
+  size_t bitOffset = 0;
+  for (const APInt &value : values) {
+    assert(value.getBitWidth() <= storageWidth);
+    writeBits(data.data(), bitOffset, value);
+    bitOffset += storageWidth;
+  }
+}
+
+/// Builds a dense elements attribute from raw APFloat values. Each APFloat is
+/// expected to have the bitwidth of the element type of 'type', and 'type'
+/// must be a vector or tensor with static shape. The floats are bitcast to
+/// integers before being packed into the raw buffer.
+DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
+                                                   size_t storageWidth,
+                                                   ArrayRef<APFloat> values,
+                                                   bool isSplat) {
+  auto toRawBits = [](const APFloat &value) { return value.bitcastToAPInt(); };
+  std::vector<char> rawData;
+  writeAPIntsToBuffer(storageWidth, rawData,
+                      llvm::map_range(values, toRawBits));
+  return DenseIntOrFPElementsAttr::getRaw(type, rawData, isSplat);
+}
+
+/// Builds a dense elements attribute from raw APInt values. Each APInt is
+/// expected to have the bitwidth of the element type of 'type'. The values
+/// are packed into a raw buffer at `storageWidth` bits per value.
+DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
+                                                   size_t storageWidth,
+                                                   ArrayRef<APInt> values,
+                                                   bool isSplat) {
+  std::vector<char> rawData;
+  writeAPIntsToBuffer(storageWidth, rawData, values);
+  return DenseIntOrFPElementsAttr::getRaw(type, rawData, isSplat);
+}
+
+/// Constructs a dense int-or-float elements attribute directly from an
+/// already-packed raw buffer. `type` must be a ranked tensor or vector with
+/// static shape; `data` and `isSplat` are stored as given.
+DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
+ ArrayRef<char> data,
+ bool isSplat) {
+ assert((type.isa<RankedTensorType, VectorType>()) &&
+ "type must be ranked tensor or vector");
+ assert(type.hasStaticShape() && "type must have static shape");
+ return Base::get(type.getContext(), type, data, isSplat);
+}
+
+/// Overload of the raw 'get' method for complex element types. It asserts
+/// invariants that the templatized 'get' method cannot: the described C++
+/// component type must match the complex type's element type, and `data`
+/// must hold either one value (splat) or one value per element of `type`.
+DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type,
+                                                          ArrayRef<char> data,
+                                                          int64_t dataEltSize,
+                                                          bool isInt,
+                                                          bool isSigned) {
+  // `dataEltSize` covers the whole complex value; each component is half.
+  assert(::isValidIntOrFloat(
+      type.getElementType().cast<ComplexType>().getElementType(),
+      dataEltSize / 2, isInt, isSigned));
+
+  int64_t valueCount = data.size() / dataEltSize;
+  bool isSplat = valueCount == 1;
+  assert(isSplat || valueCount == type.getNumElements());
+  return getRaw(type, data, isSplat);
+}
+
+/// Overload of the 'getRaw' method for integer, index or float element types.
+/// It asserts invariants that the templatized 'get' method cannot: the
+/// described C++ type must match the element type, and `data` must hold
+/// either one value (splat) or one value per element of `type`.
+DenseElementsAttr
+DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data,
+                                           int64_t dataEltSize, bool isInt,
+                                           bool isSigned) {
+  assert(
+      ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned));
+
+  int64_t valueCount = data.size() / dataEltSize;
+  bool isSplat = valueCount == 1;
+  assert(isSplat || valueCount == type.getNumElements());
+  return getRaw(type, data, isSplat);
+}
+
+/// Converts `numElements` elements of `elementBitWidth` bits each between
+/// little-endian and host byte order, reading from `inRawData` and writing to
+/// `outRawData`. Must only be called on a big-endian host (asserted).
+///
+/// 16-, 32- and 64-bit elements are converted through the little-endian
+/// integer wrapper types; every other width is converted by reversing the
+/// bytes of each element.
+void DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
+    const char *inRawData, char *outRawData, size_t elementBitWidth,
+    size_t numElements) {
+  using llvm::support::ulittle16_t;
+  using llvm::support::ulittle32_t;
+  using llvm::support::ulittle64_t;
+
+  assert(llvm::support::endian::system_endianness() == // NOLINT
+         llvm::support::endianness::big);              // NOLINT
+  // NOLINT to avoid warning message about replacing by static_assert()
+
+  // Following std::copy_n always converts endianness on BE machine.
+  switch (elementBitWidth) {
+  case 16: {
+    const ulittle16_t *inRawDataPos =
+        reinterpret_cast<const ulittle16_t *>(inRawData);
+    uint16_t *outDataPos = reinterpret_cast<uint16_t *>(outRawData);
+    std::copy_n(inRawDataPos, numElements, outDataPos);
+    break;
+  }
+  case 32: {
+    const ulittle32_t *inRawDataPos =
+        reinterpret_cast<const ulittle32_t *>(inRawData);
+    uint32_t *outDataPos = reinterpret_cast<uint32_t *>(outRawData);
+    std::copy_n(inRawDataPos, numElements, outDataPos);
+    break;
+  }
+  case 64: {
+    const ulittle64_t *inRawDataPos =
+        reinterpret_cast<const ulittle64_t *>(inRawData);
+    uint64_t *outDataPos = reinterpret_cast<uint64_t *>(outRawData);
+    std::copy_n(inRawDataPos, numElements, outDataPos);
+    break;
+  }
+  default: {
+    // Reverse the bytes of every element. Previously only the first element
+    // was converted here, leaving the remaining `numElements - 1` elements of
+    // the output untouched.
+    size_t nBytes = elementBitWidth / CHAR_BIT;
+    for (size_t elt = 0; elt < numElements; ++elt) {
+      const char *inElt = inRawData + elt * nBytes;
+      char *outElt = outRawData + elt * nBytes;
+      for (size_t i = 0; i < nBytes; i++)
+        outElt[i] = inElt[nBytes - 1 - i];
+    }
+    break;
+  }
+  }
+}
+
+/// Converts the endianness of the whole raw buffer `inRawData` of an
+/// attribute of type `type`, writing the result to `outRawData`. Complex
+/// element types are treated as twice as many elements of the underlying
+/// component type.
+void DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
+    ArrayRef<char> inRawData, MutableArrayRef<char> outRawData,
+    ShapedType type) {
+  Type scalarType = type.getElementType();
+  size_t scalarCount = type.getNumElements();
+  if (auto complexTy = scalarType.dyn_cast<ComplexType>()) {
+    scalarType = complexTy.getElementType();
+    scalarCount *= 2;
+  }
+
+  size_t scalarBitWidth = getDenseElementStorageWidth(scalarType);
+  // The input must hold exactly all scalars and the output must be able to
+  // receive them.
+  assert(scalarCount * scalarBitWidth == inRawData.size() * CHAR_BIT &&
+         inRawData.size() <= outRawData.size());
+  convertEndianOfCharForBEmachine(inRawData.begin(), outRawData.begin(),
+                                  scalarBitWidth, scalarCount);
+}
+
+//===----------------------------------------------------------------------===//
+// DenseFPElementsAttr
+//===----------------------------------------------------------------------===//
+
+/// Shared implementation of the `mapValues` methods. Applies `mapping` to the
+/// elements of `attr` and packs the resulting APInts into `data`.
+///
+/// \param mapping callable turning one element of `attr` into an APInt whose
+///        bitwidth equals the dense bit width of `newElementType`.
+/// \param attr the attribute being mapped; iterated with begin()/range-for.
+/// \param inType the shaped type of `attr`.
+/// \param newElementType element type of the result.
+/// \param data output buffer, resized to one storage slot for a splat or one
+///        per element otherwise.
+/// \returns the shaped type of the result: same shape as `inType`, with
+///          element type `newElementType`.
+template <typename Fn, typename Attr>
+static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType,
+                                Type newElementType,
+                                llvm::SmallVectorImpl<char> &data) {
+  size_t bitWidth = getDenseElementBitWidth(newElementType);
+  size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
+
+  // Rebuild the container type around the new element type. An unsupported
+  // container is a programming error; use llvm_unreachable rather than a bare
+  // assert so that builds without assertions do not continue with a null type.
+  ShapedType newArrayType;
+  if (inType.isa<RankedTensorType, UnrankedTensorType>())
+    newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
+  else if (inType.isa<VectorType>())
+    newArrayType = VectorType::get(inType.getShape(), newElementType);
+  else
+    llvm_unreachable("Unhandled tensor type");
+
+  size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements();
+  data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * numRawElements);
+
+  // Functor used to process a single element value of the attribute.
+  auto processElt = [&](decltype(*attr.begin()) value, size_t index) {
+    auto newInt = mapping(value);
+    assert(newInt.getBitWidth() == bitWidth);
+    writeBits(data.data(), index * storageBitWidth, newInt);
+  };
+
+  // Check for the splat case.
+  if (attr.isSplat()) {
+    processElt(*attr.begin(), /*index=*/0);
+    return newArrayType;
+  }
+
+  // Otherwise, process all of the element values.
+  uint64_t elementIdx = 0;
+  for (auto value : attr)
+    processElt(value, elementIdx++);
+  return newArrayType;
+}
+
+/// Maps every float element through `mapping` and returns a new attribute of
+/// the same shape with element type `newElementType`. The splat flag of this
+/// attribute is carried over to the result.
+DenseElementsAttr DenseFPElementsAttr::mapValues(
+    Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
+  llvm::SmallVector<char, 8> mappedData;
+  ShapedType mappedType =
+      mappingHelper(mapping, *this, getType(), newElementType, mappedData);
+  return getRaw(mappedType, mappedData, isSplat());
+}
+
+/// Supports isa/cast/dyn_cast: matches dense elements attributes whose
+/// element type is a float type.
+bool DenseFPElementsAttr::classof(Attribute attr) {
+  if (!attr.isa<DenseElementsAttr>())
+    return false;
+  Type eltType = attr.getType().cast<ShapedType>().getElementType();
+  return eltType.isa<FloatType>();
+}
+
+//===----------------------------------------------------------------------===//
+// DenseIntElementsAttr
+//===----------------------------------------------------------------------===//
+
+/// Maps every integer element through `mapping` and returns a new attribute
+/// of the same shape with element type `newElementType`. The splat flag of
+/// this attribute is carried over to the result.
+DenseElementsAttr DenseIntElementsAttr::mapValues(
+    Type newElementType, function_ref<APInt(const APInt &)> mapping) const {
+  llvm::SmallVector<char, 8> mappedData;
+  ShapedType mappedType =
+      mappingHelper(mapping, *this, getType(), newElementType, mappedData);
+  return getRaw(mappedType, mappedData, isSplat());
+}
+
+/// Supports isa/cast/dyn_cast: matches dense elements attributes whose
+/// element type is an integer or index type.
+bool DenseIntElementsAttr::classof(Attribute attr) {
+  if (!attr.isa<DenseElementsAttr>())
+    return false;
+  Type eltType = attr.getType().cast<ShapedType>().getElementType();
+  return eltType.isIntOrIndex();
+}
+
+//===----------------------------------------------------------------------===//
+// OpaqueElementsAttr
+//===----------------------------------------------------------------------===//
+
+/// Constructs an opaque elements attribute of type `type` that stores `bytes`
+/// together with the `dialect` they belong to. The element type of `type`
+/// must be a valid tensor element type.
+OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect, ShapedType type,
+ StringRef bytes) {
+ assert(TensorType::isValidElementType(type.getElementType()) &&
+ "Input element type should be a valid tensor element type");
+ return Base::get(type.getContext(), type, dialect, bytes);
+}
+
+/// Returns the raw bytes held by this attribute.
+StringRef OpaqueElementsAttr::getValue() const { return getImpl()->bytes; }
+
+/// Return the value at the given index. `index` must be a valid
+/// multi-dimensional index (asserted). This implementation does not decode
+/// the opaque bytes and always returns a null attribute.
+Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const {
+ assert(isValidIndex(index) && "expected valid multi-dimensional index");
+ return Attribute();
+}
+
+/// Returns the dialect stored with this attribute.
+Dialect *OpaqueElementsAttr::getDialect() const { return getImpl()->dialect; }
+
+/// Attempts to decode this attribute into `result` through the
+/// DialectDecodeAttributesInterface of the owning dialect. Returns true on
+/// failure: when there is no dialect, when the dialect does not register the
+/// interface, or when the interface fails to decode.
+bool OpaqueElementsAttr::decode(ElementsAttr &result) {
+  if (Dialect *dialect = getDialect()) {
+    if (auto *decoder =
+            dialect->getRegisteredInterface<DialectDecodeAttributesInterface>())
+      return failed(decoder->decode(*this, result));
+  }
+  return true;
+}
+
+//===----------------------------------------------------------------------===//
+// SparseElementsAttr
+//===----------------------------------------------------------------------===//
+
+/// Constructs a sparse elements attribute of type `type` from a dense
+/// attribute of `indices` and a dense attribute of `values`. The indices must
+/// have a 64-bit integer element type, and `type` must be a ranked tensor or
+/// vector with static shape.
+SparseElementsAttr SparseElementsAttr::get(ShapedType type,
+ DenseElementsAttr indices,
+ DenseElementsAttr values) {
+ assert(indices.getType().getElementType().isInteger(64) &&
+ "expected sparse indices to be 64-bit integer values");
+ assert((type.isa<RankedTensorType, VectorType>()) &&
+ "type must be ranked tensor or vector");
+ assert(type.hasStaticShape() && "type must have static shape");
+ return Base::get(type.getContext(), type,
+ indices.cast<DenseIntElementsAttr>(), values);
+}
+
+/// Returns the dense attribute holding the sparse indices.
+DenseIntElementsAttr SparseElementsAttr::getIndices() const {
+ return getImpl()->indices;
+}
+
+/// Returns the dense attribute holding the values stored at the sparse
+/// indices.
+DenseElementsAttr SparseElementsAttr::getValues() const {
+ return getImpl()->values;
+}
+
+/// Return the value of the element at the given index. `index` must be a
+/// valid multi-dimensional index (asserted). Positions that are not listed in
+/// the sparse indices yield the zero attribute of the element type.
+Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
+ assert(isValidIndex(index) && "expected valid multi-dimensional index");
+ auto type = getType();
+
+ // The sparse indices are 64-bit integers, so we can reinterpret the raw data
+ // as a 1-D index array.
+ auto sparseIndices = getIndices();
+ auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
+
+ // Check to see if the indices are a splat.
+ if (sparseIndices.isSplat()) {
+ // If the index is also not a splat of the index value, we know that the
+ // value is zero.
+ auto splatIndex = *sparseIndexValues.begin();
+ if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; }))
+ return getZeroAttr();
+
+ // If the indices are a splat, we also expect the values to be a splat.
+ assert(getValues().isSplat() && "expected splat values");
+ return getValues().getSplatValue();
+ }
+
+ // Build a mapping between known indices and the offset of the stored element.
+ // Each key is an ArrayRef of `rank` consecutive coordinates taken from the
+ // flat index values; the mapped value is the position of that sparse entry.
+ llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices;
+ auto numSparseIndices = sparseIndices.getType().getDimSize(0);
+ size_t rank = type.getRank();
+ for (size_t i = 0, e = numSparseIndices; i != e; ++i)
+ mappedIndices.try_emplace(
+ {&*std::next(sparseIndexValues.begin(), i * rank), rank}, i);
+
+ // Look for the provided index key within the mapped indices. If the provided
+ // index is not found, then return a zero attribute.
+ auto it = mappedIndices.find(index);
+ if (it == mappedIndices.end())
+ return getZeroAttr();
+
+ // Otherwise, return the held sparse value element.
+ return getValues().getValue(it->second);
+}
+
+/// Returns a zero APFloat using the float semantics of this attribute's
+/// element type, which must be a float type.
+APFloat SparseElementsAttr::getZeroAPFloat() const {
+  FloatType floatType = getType().getElementType().cast<FloatType>();
+  const llvm::fltSemantics &semantics = floatType.getFloatSemantics();
+  return APFloat(semantics);
+}
+
+/// Returns a zero APInt with the bit width of this attribute's element type,
+/// which must be an integer type.
+APInt SparseElementsAttr::getZeroAPInt() const {
+  IntegerType intType = getType().getElementType().cast<IntegerType>();
+  unsigned width = intType.getWidth();
+  return APInt::getNullValue(width);
+}
+
+/// Returns a zero-valued attribute of this attribute's element type: a
+/// FloatAttr for float element types and an IntegerAttr for everything else.
+Attribute SparseElementsAttr::getZeroAttr() const {
+  Type elementType = getType().getElementType();
+
+  // Anything that is not a float is treated as an integer.
+  // TODO: Handle StringAttr here.
+  if (!elementType.isa<FloatType>())
+    return IntegerAttr::get(elementType, 0);
+
+  // Floating point elements.
+  return FloatAttr::get(elementType, 0);
+}
+
+/// Flatten, and return, all of the sparse indices in this attribute in
+/// row-major order. Each multi-dimensional sparse index is converted with
+/// getFlattenedIndex; a splat index attribute yields a single flattened index.
+std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const {
+  std::vector<ptrdiff_t> flatSparseIndices;
+
+  // The sparse indices are 64-bit integers, so we can reinterpret the raw data
+  // as a 1-D index array.
+  auto sparseIndices = getIndices();
+  auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
+  if (sparseIndices.isSplat()) {
+    // A splat index uses the same coordinate in every dimension.
+    SmallVector<uint64_t, 8> indices(getType().getRank(),
+                                     *sparseIndexValues.begin());
+    flatSparseIndices.push_back(getFlattenedIndex(indices));
+    return flatSparseIndices;
+  }
+
+  // Otherwise, reinterpret each index as an ArrayRef when flattening.
+  auto numSparseIndices = sparseIndices.getType().getDimSize(0);
+  size_t rank = getType().getRank();
+  flatSparseIndices.reserve(numSparseIndices);
+  for (size_t i = 0, e = numSparseIndices; i != e; ++i)
+    flatSparseIndices.push_back(getFlattenedIndex(
+        {&*std::next(sparseIndexValues.begin(), i * rank), rank}));
+  return flatSparseIndices;
+}
+
+//===----------------------------------------------------------------------===//
+// MutableDictionaryAttr
+//===----------------------------------------------------------------------===//
+
+/// Constructs a mutable dictionary holding `attributes`, by delegating to
+/// setAttrs.
+MutableDictionaryAttr::MutableDictionaryAttr(
+ ArrayRef<NamedAttribute> attributes) {
+ setAttrs(attributes);
+}
+
+/// Returns the underlying dictionary attribute. When no dictionary is held,
+/// an empty DictionaryAttr is created in `context`.
+DictionaryAttr
+MutableDictionaryAttr::getDictionary(MLIRContext *context) const {
+  if (attrs)
+    return attrs;
+  // Nothing stored yet: hand out an empty dictionary.
+  return DictionaryAttr::get({}, context);
+}
+
+/// Returns the held attributes, or an empty list when no dictionary is held.
+ArrayRef<NamedAttribute> MutableDictionaryAttr::getAttrs() const {
+  if (!attrs)
+    return llvm::None;
+  return attrs.getValue();
+}
+
+/// Replaces the held attributes with `attributes`. An empty list clears the
+/// held dictionary instead of creating an empty one.
+void MutableDictionaryAttr::setAttrs(ArrayRef<NamedAttribute> attributes) {
+  if (attributes.empty()) {
+    attrs = nullptr;
+    return;
+  }
+  // The context is taken from the first attribute value.
+  MLIRContext *context = attributes[0].second.getContext();
+  attrs = DictionaryAttr::get(attributes, context);
+}
+
+/// Looks up the attribute called `name`; returns a null attribute when it is
+/// absent or when no dictionary is held.
+Attribute MutableDictionaryAttr::get(StringRef name) const {
+  if (!attrs)
+    return Attribute();
+  return attrs.get(name);
+}
+
+/// Looks up the attribute identified by `name`; returns a null attribute when
+/// it is absent or when no dictionary is held.
+Attribute MutableDictionaryAttr::get(Identifier name) const {
+  if (!attrs)
+    return Attribute();
+  return attrs.get(name);
+}
+
+/// Looks up the named attribute called `name`; returns None when it is absent
+/// or when no dictionary is held.
+Optional<NamedAttribute> MutableDictionaryAttr::getNamed(StringRef name) const {
+  if (!attrs)
+    return Optional<NamedAttribute>();
+  return attrs.getNamed(name);
+}
+/// Looks up the named attribute identified by `name`; returns None when it is
+/// absent or when no dictionary is held.
+Optional<NamedAttribute>
+MutableDictionaryAttr::getNamed(Identifier name) const {
+  if (!attrs)
+    return Optional<NamedAttribute>();
+  return attrs.getNamed(name);
+}
+
+/// Sets the attribute `name` to `value`. If an attribute with that name
+/// already exists its value is replaced; otherwise a new entry is inserted at
+/// its sorted position. `value` must not be null.
+void MutableDictionaryAttr::set(Identifier name, Attribute value) {
+  assert(value && "attributes may never be null");
+
+  ArrayRef<NamedAttribute> current = getAttrs();
+  const auto *existing = llvm::find_if(
+      current, [name](NamedAttribute attr) { return attr.first == name; });
+
+  SmallVector<NamedAttribute, 8> updated;
+  if (existing == current.end()) {
+    // No entry with this name yet: splice the new one in at the position that
+    // keeps the list sorted.
+    const auto *insertPos = llvm::lower_bound(current, name);
+    updated.reserve(current.size() + 1);
+    updated.append(current.begin(), insertPos);
+    updated.push_back({name, value});
+    updated.append(insertPos, current.end());
+  } else {
+    // Nothing to do when the stored value is already the requested one.
+    if (existing->second == value)
+      return;
+
+    // Copy the list and overwrite the value of the matching entry.
+    updated.assign(current.begin(), current.end());
+    updated[existing - current.begin()].second = value;
+  }
+  attrs = DictionaryAttr::getWithSorted(updated, value.getContext());
+}
+
+/// Removes the attribute called `name` if it exists. The result tells whether
+/// the attribute was present (Removed) or not (NotFound).
+auto MutableDictionaryAttr::remove(Identifier name) -> RemoveResult {
+  ArrayRef<NamedAttribute> current = getAttrs();
+  const auto *match = llvm::find_if(
+      current, [name](NamedAttribute attr) { return attr.first == name; });
+  if (match == current.end())
+    return RemoveResult::NotFound;
+
+  // Removing the only attribute clears the held dictionary entirely.
+  if (current.size() == 1) {
+    attrs = nullptr;
+    return RemoveResult::Removed;
+  }
+
+  // Otherwise rebuild the list without the matching entry; the remaining
+  // entries stay in their sorted order.
+  SmallVector<NamedAttribute, 8> remaining;
+  remaining.reserve(current.size() - 1);
+  remaining.append(current.begin(), match);
+  remaining.append(match + 1, current.end());
+  attrs = DictionaryAttr::getWithSorted(remaining,
+                                        remaining[0].second.getContext());
+  return RemoveResult::Removed;
+}
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 237525ffd10d..50a5a64da69e 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -14,6 +14,7 @@
#include "mlir/IR/Dialect.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/BitVector.h"
+#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/Twine.h"
using namespace mlir;
diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index 9e7af9047692..42cdb3a91a50 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -5,6 +5,7 @@ add_mlir_library(MLIRIR
Attributes.cpp
Block.cpp
Builders.cpp
+ BuiltinAttributes.cpp
BuiltinDialect.cpp
BuiltinTypes.cpp
Diagnostics.cpp
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 723e96bc8ce7..0d4d1617d5af 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -13,10 +13,10 @@
#include "mlir-c/IR.h"
#include "mlir-c/AffineExpr.h"
#include "mlir-c/AffineMap.h"
+#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Diagnostics.h"
#include "mlir-c/Registration.h"
-#include "mlir-c/StandardAttributes.h"
#include "mlir-c/StandardDialect.h"
#include <assert.h>
@@ -739,7 +739,7 @@ bool stringIsEqual(const char *lhs, MlirStringRef rhs) {
return !strncmp(lhs, rhs.data, rhs.length);
}
-int printStandardAttributes(MlirContext ctx) {
+int printBuiltinAttributes(MlirContext ctx) {
MlirAttribute floating =
mlirFloatAttrDoubleGet(ctx, mlirF64TypeGet(ctx), 2.0);
if (!mlirAttributeIsAFloat(floating) ||
@@ -1334,7 +1334,7 @@ int main() {
if (printBuiltinTypes(ctx))
return 2;
- if (printStandardAttributes(ctx))
+ if (printBuiltinAttributes(ctx))
return 3;
if (printAffineMap(ctx))
return 4;
diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp
index 9d47a5c96406..e80e2f6ae7a6 100644
--- a/mlir/unittests/IR/AttributeTest.cpp
+++ b/mlir/unittests/IR/AttributeTest.cpp
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Identifier.h"
#include "gtest/gtest.h"
More information about the Mlir-commits
mailing list