[Mlir-commits] [mlir] 120591e - [mlir] Replace usages of Identifier with StringAttr
River Riddle
llvmlistbot at llvm.org
Wed Nov 10 18:06:41 PST 2021
Author: River Riddle
Date: 2021-11-11T02:02:24Z
New Revision: 120591e126f97924f665baacea49080642b1162f
URL: https://github.com/llvm/llvm-project/commit/120591e126f97924f665baacea49080642b1162f
DIFF: https://github.com/llvm/llvm-project/commit/120591e126f97924f665baacea49080642b1162f.diff
LOG: [mlir] Replace usages of Identifier with StringAttr
Identifier and StringAttr essentially serve the same purpose, i.e. to hold a string value. Keeping these seemingly identical pieces of functionality separate has caused problems in certain situations:
* Identifier has nice accessors that StringAttr doesn't
* Identifier can't be used as an Attribute, meaning strings are often duplicated between Identifier/StringAttr (e.g. in PDL)
The only thing that Identifier has that StringAttr doesn't is support for caching a dialect that is referenced by the string (e.g. dialect.foo). This functionality is added to StringAttr, as it is useful for StringAttr in largely the same ways it was useful for Identifier.
Differential Revision: https://reviews.llvm.org/D113536
Added:
Modified:
mlir/include/mlir/CAPI/IR.h
mlir/include/mlir/IR/AttributeSupport.h
mlir/include/mlir/IR/Attributes.h
mlir/include/mlir/IR/BuiltinAttributes.h
mlir/include/mlir/IR/BuiltinAttributes.td
mlir/include/mlir/IR/BuiltinTypes.h
mlir/include/mlir/IR/Diagnostics.h
mlir/include/mlir/IR/FunctionSupport.h
mlir/include/mlir/IR/Identifier.h
mlir/include/mlir/IR/Location.h
mlir/include/mlir/IR/OperationSupport.h
mlir/include/mlir/IR/StorageUniquerSupport.h
mlir/include/mlir/IR/SymbolTable.h
mlir/include/mlir/Pass/PassManager.h
mlir/include/mlir/Support/StorageUniquer.h
mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h
mlir/lib/Bindings/Python/IRCore.cpp
mlir/lib/CAPI/IR/BuiltinAttributes.cpp
mlir/lib/CAPI/IR/IR.cpp
mlir/lib/Dialect/DLTI/DLTI.cpp
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/AttributeDetail.h
mlir/lib/IR/Attributes.cpp
mlir/lib/IR/BuiltinAttributes.cpp
mlir/lib/IR/BuiltinDialect.cpp
mlir/lib/IR/Diagnostics.cpp
mlir/lib/IR/MLIRContext.cpp
mlir/lib/IR/OperationSupport.cpp
mlir/lib/IR/Verifier.cpp
mlir/lib/Interfaces/DataLayoutInterfaces.cpp
mlir/lib/Parser/AttributeParser.cpp
mlir/lib/Parser/Parser.cpp
mlir/lib/Target/Cpp/TranslateToCpp.cpp
mlir/lib/Transforms/ViewOpGraph.cpp
mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
mlir/test/lib/Dialect/Test/TestPatterns.cpp
mlir/test/lib/IR/TestPrintNesting.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/CAPI/IR.h b/mlir/include/mlir/CAPI/IR.h
index 8366b0bce6d70..7fd47504aa85b 100644
--- a/mlir/include/mlir/CAPI/IR.h
+++ b/mlir/include/mlir/CAPI/IR.h
@@ -30,7 +30,7 @@ DEFINE_C_API_PTR_METHODS(MlirRegion, mlir::Region)
DEFINE_C_API_PTR_METHODS(MlirSymbolTable, mlir::SymbolTable)
DEFINE_C_API_METHODS(MlirAttribute, mlir::Attribute)
-DEFINE_C_API_METHODS(MlirIdentifier, mlir::Identifier)
+DEFINE_C_API_METHODS(MlirIdentifier, mlir::StringAttr)
DEFINE_C_API_METHODS(MlirLocation, mlir::Location)
DEFINE_C_API_METHODS(MlirModule, mlir::ModuleOp)
DEFINE_C_API_METHODS(MlirType, mlir::Type)
diff --git a/mlir/include/mlir/IR/AttributeSupport.h b/mlir/include/mlir/IR/AttributeSupport.h
index 2964246ae2a59..dec3e69beeb49 100644
--- a/mlir/include/mlir/IR/AttributeSupport.h
+++ b/mlir/include/mlir/IR/AttributeSupport.h
@@ -15,6 +15,7 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/StorageUniquerSupport.h"
+#include "mlir/IR/Types.h"
#include "llvm/ADT/PointerIntPair.h"
#include "llvm/ADT/Twine.h"
@@ -118,7 +119,7 @@ class alignas(8) AttributeStorage : public StorageUniquer::BaseStorage {
public:
/// Get the type of this attribute.
- Type getType() const;
+ Type getType() const { return type; }
/// Return the abstract descriptor for this attribute.
const AbstractAttribute &getAbstractAttribute() const {
@@ -131,24 +132,27 @@ class alignas(8) AttributeStorage : public StorageUniquer::BaseStorage {
/// Note: All attributes require a valid type. If no type is provided here,
/// the type of the attribute will automatically default to NoneType
/// upon initialization in the uniquer.
- AttributeStorage(Type type);
- AttributeStorage();
+ AttributeStorage(Type type = nullptr) : type(type) {}
/// Set the type of this attribute.
- void setType(Type type);
+ void setType(Type newType) { type = newType; }
- // Set the abstract attribute for this storage instance. This is used by the
- // AttributeUniquer when initializing a newly constructed storage object.
- void initialize(const AbstractAttribute &abstractAttr) {
+ /// Set the abstract attribute for this storage instance. This is used by the
+ /// AttributeUniquer when initializing a newly constructed storage object.
+ void initializeAbstractAttribute(const AbstractAttribute &abstractAttr) {
abstractAttribute = &abstractAttr;
}
+ /// Default initialization for attribute storage classes that require no
+ /// additional initialization.
+ void initialize(MLIRContext *context) {}
+
private:
+ /// The type of the attribute value.
+ Type type;
+
/// The abstract descriptor for this attribute.
const AbstractAttribute *abstractAttribute;
-
- /// The opaque type of the attribute value.
- const void *type;
};
/// Default storage type for attributes that require no additional
@@ -188,6 +192,10 @@ class AttributeUniquer {
return ctx->getAttributeUniquer().get<typename T::ImplType>(
[ctx](AttributeStorage *storage) {
initializeAttributeStorage(storage, ctx, T::getTypeID());
+
+ // Execute any additional attribute storage initialization with the
+ // context.
+ static_cast<typename T::ImplType *>(storage)->initialize(ctx);
},
T::getTypeID(), std::forward<Args>(args)...);
}
diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index 69e48d7a72561..74a4405abd59f 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -13,7 +13,10 @@
#include "llvm/Support/PointerLikeTypeTraits.h"
namespace mlir {
-class Identifier;
+class StringAttr;
+
+// TODO: Remove this when all usages have been replaced with StringAttr.
+using Identifier = StringAttr;
/// Attributes are known-constant values of operations.
///
@@ -61,7 +64,7 @@ class Attribute {
TypeID getTypeID() { return impl->getAbstractAttribute().getTypeID(); }
/// Return the type of this attribute.
- Type getType() const;
+ Type getType() const { return impl->getType(); }
/// Return the context this attribute belongs to.
MLIRContext *getContext() const;
@@ -126,7 +129,7 @@ template <typename U> U Attribute::cast() const {
}
inline ::llvm::hash_code hash_value(Attribute arg) {
- return ::llvm::hash_value(arg.impl);
+ return DenseMapInfo<const Attribute::ImplType *>::getHashValue(arg.impl);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index ba0fbe49239fc..6dabb6f3bf06c 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -885,7 +885,35 @@ auto SparseElementsAttr::value_begin() const -> iterator<T> {
};
return iterator<T>(llvm::seq<ptrdiff_t>(0, getNumElements()).begin(), mapFn);
}
-} // end namespace mlir.
+
+//===----------------------------------------------------------------------===//
+// StringAttr
+//===----------------------------------------------------------------------===//
+
+/// Define comparisons for StringAttr against nullptr and itself to avoid the
+/// StringRef overloads from being chosen when not desirable.
+inline bool operator==(StringAttr lhs, std::nullptr_t) { return !lhs; }
+inline bool operator!=(StringAttr lhs, std::nullptr_t) {
+ return static_cast<bool>(lhs);
+}
+inline bool operator==(StringAttr lhs, StringAttr rhs) {
+ return (Attribute)lhs == (Attribute)rhs;
+}
+inline bool operator!=(StringAttr lhs, StringAttr rhs) { return !(lhs == rhs); }
+
+/// Allow direct comparison with StringRef.
+inline bool operator==(StringAttr lhs, StringRef rhs) {
+ return lhs.getValue() == rhs;
+}
+inline bool operator!=(StringAttr lhs, StringRef rhs) { return !(lhs == rhs); }
+inline bool operator==(StringRef lhs, StringAttr rhs) {
+ return rhs.getValue() == lhs;
+}
+inline bool operator!=(StringRef lhs, StringAttr rhs) { return !(lhs == rhs); }
+
+inline Type StringAttr::getType() const { return Attribute::getType(); }
+
+} // end namespace mlir
//===----------------------------------------------------------------------===//
// Attribute Utilities
@@ -893,12 +921,30 @@ auto SparseElementsAttr::value_begin() const -> iterator<T> {
namespace llvm {
+template <>
+struct DenseMapInfo<mlir::StringAttr> : public DenseMapInfo<mlir::Attribute> {
+ static mlir::StringAttr getEmptyKey() {
+ const void *pointer = llvm::DenseMapInfo<const void *>::getEmptyKey();
+ return mlir::StringAttr::getFromOpaquePointer(pointer);
+ }
+ static mlir::StringAttr getTombstoneKey() {
+ const void *pointer = llvm::DenseMapInfo<const void *>::getTombstoneKey();
+ return mlir::StringAttr::getFromOpaquePointer(pointer);
+ }
+};
+template <>
+struct PointerLikeTypeTraits<mlir::StringAttr>
+ : public PointerLikeTypeTraits<mlir::Attribute> {
+ static inline mlir::StringAttr getFromVoidPointer(void *p) {
+ return mlir::StringAttr::getFromOpaquePointer(p);
+ }
+};
+
template <>
struct PointerLikeTypeTraits<mlir::SymbolRefAttr>
: public PointerLikeTypeTraits<mlir::Attribute> {
static inline mlir::SymbolRefAttr getFromVoidPointer(void *ptr) {
- return PointerLikeTypeTraits<mlir::Attribute>::getFromVoidPointer(ptr)
- .cast<mlir::SymbolRefAttr>();
+ return mlir::SymbolRefAttr::getFromOpaquePointer(ptr);
}
};
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index c6631cd79fe58..a0393df7aa896 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -915,6 +915,44 @@ def Builtin_StringAttr : Builtin_Attr<"String"> {
let extraClassDeclaration = [{
using ValueType = StringRef;
+ /// If the value of this string is prefixed with a dialect namespace,
+ /// returns the dialect corresponding to that namespace if it is loaded,
+ /// nullptr otherwise. For example, the string `llvm.fastmathflags` would
+ /// return the LLVM dialect, assuming it is loaded in the context.
+ Dialect *getReferencedDialect() const;
+
+ /// Enable conversion to StringRef.
+ operator StringRef() const { return getValue(); }
+
+ /// Returns the underlying string value
+ StringRef strref() const { return getValue(); }
+
+ /// Convert the underlying value to an std::string.
+ std::string str() const { return getValue().str(); }
+
+ /// Return a pointer to the start of the string data.
+ const char *data() const { return getValue().data(); }
+
+ /// Return the number of bytes in this string.
+ size_t size() const { return getValue().size(); }
+
+ /// Iterate over the underlying string data.
+ StringRef::iterator begin() const { return getValue().begin(); }
+ StringRef::iterator end() const { return getValue().end(); }
+
+ /// Compare the underlying string value to the one in `rhs`.
+ int compare(StringAttr rhs) const {
+ if (*this == rhs)
+ return 0;
+ return getValue().compare(rhs.getValue());
+ }
+
+ /// FIXME: Defined as part of transition of Identifier->StringAttr. Prefer
+ /// using the other `get` methods instead.
+ static StringAttr get(const Twine &str, MLIRContext *context) {
+ return get(context, str);
+ }
+
private:
/// Return an empty StringAttr with NoneType type. This is a special variant
/// of the `get` method that is used by the MLIRContext to cache the
@@ -923,6 +961,7 @@ def Builtin_StringAttr : Builtin_Attr<"String"> {
friend MLIRContext;
public:
}];
+ let genStorageClass = 0;
let skipDefaultBuilders = 1;
}
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index f2f3ccf537626..0daabe25ea1cd 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -20,11 +20,14 @@ namespace mlir {
class AffineExpr;
class AffineMap;
class FloatType;
-class Identifier;
class IndexType;
class IntegerType;
+class StringAttr;
class TypeRange;
+// TODO: Remove this when all usages have been replaced with StringAttr.
+using Identifier = StringAttr;
+
//===----------------------------------------------------------------------===//
// FloatType
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/Diagnostics.h b/mlir/include/mlir/IR/Diagnostics.h
index 5f6d870fc805e..e8af2fe782f3d 100644
--- a/mlir/include/mlir/IR/Diagnostics.h
+++ b/mlir/include/mlir/IR/Diagnostics.h
@@ -24,7 +24,6 @@ class SourceMgr;
namespace mlir {
class DiagnosticEngine;
-class Identifier;
struct LogicalResult;
class MLIRContext;
class Operation;
@@ -196,6 +195,7 @@ class Diagnostic {
arguments.push_back(DiagnosticArgument(std::forward<Arg>(val)));
return *this;
}
+ Diagnostic &operator<<(StringAttr val);
/// Stream in a string literal.
Diagnostic &operator<<(const char *val) {
@@ -208,9 +208,6 @@ class Diagnostic {
Diagnostic &operator<<(const Twine &val);
Diagnostic &operator<<(Twine &&val);
- /// Stream in an Identifier.
- Diagnostic &operator<<(Identifier val);
-
/// Stream in an OperationName.
Diagnostic &operator<<(OperationName val);
diff --git a/mlir/include/mlir/IR/FunctionSupport.h b/mlir/include/mlir/IR/FunctionSupport.h
index c7ee9429d583e..51a3f9d582bf4 100644
--- a/mlir/include/mlir/IR/FunctionSupport.h
+++ b/mlir/include/mlir/IR/FunctionSupport.h
@@ -612,7 +612,7 @@ LogicalResult FunctionLike<ConcreteType>::verifyTrait(Operation *op) {
if (!attr.first.strref().contains('.'))
return funcOp.emitOpError(
"arguments may only have dialect attributes");
- if (Dialect *dialect = attr.first.getDialect()) {
+ if (Dialect *dialect = attr.first.getReferencedDialect()) {
if (failed(dialect->verifyRegionArgAttribute(op, /*regionIndex=*/0,
/*argIndex=*/i, attr)))
return failure();
@@ -645,7 +645,7 @@ LogicalResult FunctionLike<ConcreteType>::verifyTrait(Operation *op) {
for (auto attr : resultAttrs) {
if (!attr.first.strref().contains('.'))
return funcOp.emitOpError("results may only have dialect attributes");
- if (Dialect *dialect = attr.first.getDialect()) {
+ if (Dialect *dialect = attr.first.getReferencedDialect()) {
if (failed(dialect->verifyRegionResultAttribute(op, /*regionIndex=*/0,
/*resultIndex=*/i,
attr)))
diff --git a/mlir/include/mlir/IR/Identifier.h b/mlir/include/mlir/IR/Identifier.h
index 75208892b77f2..f72e1a2e06714 100644
--- a/mlir/include/mlir/IR/Identifier.h
+++ b/mlir/include/mlir/IR/Identifier.h
@@ -9,151 +9,12 @@
#ifndef MLIR_IR_IDENTIFIER_H
#define MLIR_IR_IDENTIFIER_H
-#include "mlir/Support/LLVM.h"
-#include "llvm/ADT/DenseMapInfo.h"
-#include "llvm/ADT/PointerUnion.h"
-#include "llvm/ADT/StringMapEntry.h"
-#include "llvm/ADT/Twine.h"
-#include "llvm/Support/PointerLikeTypeTraits.h"
+#include "mlir/IR/BuiltinAttributes.h"
namespace mlir {
-class Dialect;
-class MLIRContext;
-
-/// This class represents a uniqued string owned by an MLIRContext. Strings
-/// represented by this type cannot contain nul characters, and may not have a
-/// zero length.
-///
-/// This is a POD type with pointer size, so it should be passed around by
-/// value. The underlying data is owned by MLIRContext and is thus immortal for
-/// almost all clients.
-///
-/// An Identifier may be prefixed with a dialect namespace followed by a single
-/// dot `.`. This is particularly useful when used as a key in a NamedAttribute
-/// to differentiate a dependent attribute (specific to an operation) from a
-/// generic attribute defined by the dialect (in general applicable to multiple
-/// operations).
-class Identifier {
- using EntryType =
- llvm::StringMapEntry<PointerUnion<Dialect *, MLIRContext *>>;
-
-public:
- /// Return an identifier for the specified string.
- static Identifier get(const Twine &string, MLIRContext *context);
-
- Identifier(const Identifier &) = default;
- Identifier &operator=(const Identifier &other) = default;
-
- /// Return a StringRef for the string.
- StringRef strref() const { return entry->first(); }
-
- /// Identifiers implicitly convert to StringRefs.
- operator StringRef() const { return strref(); }
-
- /// Return an std::string.
- std::string str() const { return strref().str(); }
-
- /// Return a null terminated C string.
- const char *c_str() const { return entry->getKeyData(); }
-
- /// Return a pointer to the start of the string data.
- const char *data() const { return entry->getKeyData(); }
-
- /// Return the number of bytes in this string.
- unsigned size() const { return entry->getKeyLength(); }
-
- /// Return the dialect loaded in the context for this identifier or nullptr if
- /// this identifier isn't prefixed with a loaded dialect. For example the
- /// `llvm.fastmathflags` identifier would return the LLVM dialect here,
- /// assuming it is loaded in the context.
- Dialect *getDialect();
-
- /// Return the current MLIRContext associated with this identifier.
- MLIRContext *getContext();
-
- const char *begin() const { return data(); }
- const char *end() const { return entry->getKeyData() + size(); }
-
- bool operator==(Identifier other) const { return entry == other.entry; }
- bool operator!=(Identifier rhs) const { return !(*this == rhs); }
-
- void print(raw_ostream &os) const;
- void dump() const;
-
- const void *getAsOpaquePointer() const {
- return static_cast<const void *>(entry);
- }
- static Identifier getFromOpaquePointer(const void *entry) {
- return Identifier(static_cast<const EntryType *>(entry));
- }
-
- /// Compare the underlying StringRef.
- int compare(Identifier rhs) const { return strref().compare(rhs.strref()); }
-
-private:
- /// This contains the bytes of the string, which is guaranteed to be nul
- /// terminated.
- const EntryType *entry;
- explicit Identifier(const EntryType *entry) : entry(entry) {}
-};
-
-inline raw_ostream &operator<<(raw_ostream &os, Identifier identifier) {
- identifier.print(os);
- return os;
-}
-
-// Identifier/Identifier equality comparisons are defined inline.
-inline bool operator==(Identifier lhs, StringRef rhs) {
- return lhs.strref() == rhs;
-}
-inline bool operator!=(Identifier lhs, StringRef rhs) { return !(lhs == rhs); }
-
-inline bool operator==(StringRef lhs, Identifier rhs) {
- return rhs.strref() == lhs;
-}
-inline bool operator!=(StringRef lhs, Identifier rhs) { return !(lhs == rhs); }
-
-// Make identifiers hashable.
-inline llvm::hash_code hash_value(Identifier arg) {
- // Identifiers are uniqued, so we can just hash the pointer they contain.
- return llvm::hash_value(arg.getAsOpaquePointer());
-}
+/// NOTICE: Identifier is deprecated and usages of it should be replaced with
+/// StringAttr.
+using Identifier = StringAttr;
} // end namespace mlir
-namespace llvm {
-// Identifiers hash just like pointers, there is no need to hash the bytes.
-template <>
-struct DenseMapInfo<mlir::Identifier> {
- static mlir::Identifier getEmptyKey() {
- auto pointer = llvm::DenseMapInfo<const void *>::getEmptyKey();
- return mlir::Identifier::getFromOpaquePointer(pointer);
- }
- static mlir::Identifier getTombstoneKey() {
- auto pointer = llvm::DenseMapInfo<const void *>::getTombstoneKey();
- return mlir::Identifier::getFromOpaquePointer(pointer);
- }
- static unsigned getHashValue(mlir::Identifier val) {
- return mlir::hash_value(val);
- }
- static bool isEqual(mlir::Identifier lhs, mlir::Identifier rhs) {
- return lhs == rhs;
- }
-};
-
-/// The pointer inside of an identifier comes from a StringMap, so its alignment
-/// is always at least 4 and probably 8 (on 64-bit machines). Allow LLVM to
-/// steal the low bits.
-template <>
-struct PointerLikeTypeTraits<mlir::Identifier> {
-public:
- static inline void *getAsVoidPointer(mlir::Identifier i) {
- return const_cast<void *>(i.getAsOpaquePointer());
- }
- static inline mlir::Identifier getFromVoidPointer(void *p) {
- return mlir::Identifier::getFromOpaquePointer(p);
- }
- static constexpr int NumLowBitsAvailable = 2;
-};
-
-} // end namespace llvm
#endif
diff --git a/mlir/include/mlir/IR/Location.h b/mlir/include/mlir/IR/Location.h
index b625a4f0ed0e9..1077c9f49b89a 100644
--- a/mlir/include/mlir/IR/Location.h
+++ b/mlir/include/mlir/IR/Location.h
@@ -19,7 +19,6 @@
namespace mlir {
-class Identifier;
class Location;
class WalkResult;
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index ca3a1a61f85f1..887306fbc545c 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -456,7 +456,7 @@ class OperationName {
Dialect *getDialect() const {
if (const auto *abstractOp = getAbstractOperation())
return &abstractOp->dialect;
- return representation.get<Identifier>().getDialect();
+ return representation.get<Identifier>().getReferencedDialect();
}
/// Return the operation name with dialect name stripped, if it has one.
diff --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h
index 28b3326e9f75d..8cd159e6f0438 100644
--- a/mlir/include/mlir/IR/StorageUniquerSupport.h
+++ b/mlir/include/mlir/IR/StorageUniquerSupport.h
@@ -164,8 +164,7 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
/// Get an instance of the concrete type from a void pointer.
static ConcreteT getFromOpaquePointer(const void *ptr) {
- return ptr ? BaseT::getFromOpaquePointer(ptr).template cast<ConcreteT>()
- : nullptr;
+ return ConcreteT((const typename BaseT::ImplType *)ptr);
}
protected:
diff --git a/mlir/include/mlir/IR/SymbolTable.h b/mlir/include/mlir/IR/SymbolTable.h
index 3950fee156abd..15433cd0b69c9 100644
--- a/mlir/include/mlir/IR/SymbolTable.h
+++ b/mlir/include/mlir/IR/SymbolTable.h
@@ -15,8 +15,6 @@
#include "llvm/ADT/StringMap.h"
namespace mlir {
-class Identifier;
-class Operation;
/// This class allows for representing and managing the symbol table used by
/// operations with the 'SymbolTable' trait. Inserting into and erasing from
diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h
index 4ab63ae06ab8f..9e68b9086368d 100644
--- a/mlir/include/mlir/Pass/PassManager.h
+++ b/mlir/include/mlir/Pass/PassManager.h
@@ -27,12 +27,15 @@ class Any;
namespace mlir {
class AnalysisManager;
-class Identifier;
class MLIRContext;
class Operation;
class Pass;
class PassInstrumentation;
class PassInstrumentor;
+class StringAttr;
+
+// TODO: Remove this when all usages have been replaced with StringAttr.
+using Identifier = StringAttr;
namespace detail {
struct OpPassManagerImpl;
diff --git a/mlir/include/mlir/Support/StorageUniquer.h b/mlir/include/mlir/Support/StorageUniquer.h
index 35ac619955b23..170c77d43a298 100644
--- a/mlir/include/mlir/Support/StorageUniquer.h
+++ b/mlir/include/mlir/Support/StorageUniquer.h
@@ -105,8 +105,13 @@ class StorageUniquer {
/// Copy the provided string into memory managed by our bump pointer
/// allocator.
StringRef copyInto(StringRef str) {
- auto result = copyInto(ArrayRef<char>(str.data(), str.size()));
- return StringRef(result.data(), str.size());
+ if (str.empty())
+ return StringRef();
+
+ char *result = allocator.Allocate<char>(str.size() + 1);
+ std::uninitialized_copy(str.begin(), str.end(), result);
+ result[str.size()] = 0;
+ return StringRef(result, str.size());
}
/// Allocate an instance of the provided type.
diff --git a/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h b/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h
index 0c563e6e7d39c..beabfa41d3269 100644
--- a/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h
+++ b/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h
@@ -82,7 +82,7 @@ class LLVMTranslationInterface
amendOperation(Operation *op, NamedAttribute attribute,
LLVM::ModuleTranslation &moduleTranslation) const {
if (const LLVMTranslationDialectInterface *iface =
- getInterfaceFor(attribute.first.getDialect())) {
+ getInterfaceFor(attribute.first.getReferencedDialect())) {
return iface->amendOperation(op, attribute, moduleTranslation);
}
return success();
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index cf59a67f9c8f0..4c25fd4505b76 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1845,7 +1845,8 @@ class PyOpAttributeMap {
mlirOperationGetAttribute(operation->get(), index);
return PyNamedAttribute(
namedAttr.attribute,
- std::string(mlirIdentifierStr(namedAttr.name).data));
+ std::string(mlirIdentifierStr(namedAttr.name).data,
+ mlirIdentifierStr(namedAttr.name).length));
}
void dunderSetItem(const std::string &name, PyAttribute attr) {
@@ -2601,7 +2602,8 @@ void mlir::python::populateIRCore(py::module &m) {
PyPrintAccumulator printAccum;
printAccum.parts.append("NamedAttribute(");
printAccum.parts.append(
- mlirIdentifierStr(self.namedAttr.name).data);
+ py::str(mlirIdentifierStr(self.namedAttr.name).data,
+ mlirIdentifierStr(self.namedAttr.name).length));
printAccum.parts.append("=");
mlirAttributePrint(self.namedAttr.attribute,
printAccum.getCallback(),
diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
index 8d6c4ccf6a8b0..7ce4283600c8f 100644
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -186,11 +186,11 @@ bool mlirAttributeIsAString(MlirAttribute attr) {
}
MlirAttribute mlirStringAttrGet(MlirContext ctx, MlirStringRef str) {
- return wrap(StringAttr::get(unwrap(ctx), unwrap(str)));
+ return wrap((Attribute)StringAttr::get(unwrap(ctx), unwrap(str)));
}
MlirAttribute mlirStringAttrTypedGet(MlirType type, MlirStringRef str) {
- return wrap(StringAttr::get(unwrap(str), unwrap(type)));
+ return wrap((Attribute)StringAttr::get(unwrap(str), unwrap(type)));
}
MlirStringRef mlirStringAttrGetValue(MlirAttribute attr) {
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 8bed10a9d9d56..7bbbc4a1d24a0 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -805,7 +805,7 @@ MlirOperation mlirSymbolTableLookup(MlirSymbolTable symbolTable,
MlirAttribute mlirSymbolTableInsert(MlirSymbolTable symbolTable,
MlirOperation operation) {
- return wrap(unwrap(symbolTable)->insert(unwrap(operation)));
+ return wrap((Attribute)unwrap(symbolTable)->insert(unwrap(operation)));
}
void mlirSymbolTableErase(MlirSymbolTable symbolTable,
diff --git a/mlir/lib/Dialect/DLTI/DLTI.cpp b/mlir/lib/Dialect/DLTI/DLTI.cpp
index 87fbe2585a6c6..243e7865be1a5 100644
--- a/mlir/lib/Dialect/DLTI/DLTI.cpp
+++ b/mlir/lib/Dialect/DLTI/DLTI.cpp
@@ -154,7 +154,7 @@ DataLayoutSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
} else {
auto id = entry.getKey().get<Identifier>();
if (!ids.insert(id).second)
- return emitError() << "repeated layout entry key: " << id;
+ return emitError() << "repeated layout entry key: " << id.getValue();
}
}
return success();
@@ -221,7 +221,7 @@ combineOneSpec(DataLayoutSpecInterface spec,
for (const auto &kvp : newEntriesForID) {
Identifier id = kvp.second.getKey().get<Identifier>();
- Dialect *dialect = id.getDialect();
+ Dialect *dialect = id.getReferencedDialect();
if (!entriesForID.count(id)) {
entriesForID[id] = kvp.second;
continue;
@@ -377,6 +377,6 @@ LogicalResult DLTIDialect::verifyOperationAttribute(Operation *op,
return success();
}
- return op->emitError() << "attribute '" << attr.first
+ return op->emitError() << "attribute '" << attr.first.getValue()
<< "' not supported by dialect";
}
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index d747a23644114..d0db26a1bb3bb 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -753,7 +753,7 @@ struct DeduplicateGenericOpInputs : public OpRewritePattern<GenericOp> {
// Copy over unknown attributes. They might be load bearing for some flow.
ArrayRef<StringRef> odsAttrs = genericOp.getAttributeNames();
for (NamedAttribute kv : genericOp->getAttrs()) {
- if (!llvm::is_contained(odsAttrs, kv.first.c_str())) {
+ if (!llvm::is_contained(odsAttrs, kv.first.getValue())) {
newOp->setAttr(kv.first, kv.second);
}
}
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 15625255f3374..0dd935557d121 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -46,10 +46,6 @@
using namespace mlir;
using namespace mlir::detail;
-void Identifier::print(raw_ostream &os) const { os << str(); }
-
-void Identifier::dump() const { print(llvm::errs()); }
-
void OperationName::print(raw_ostream &os) const { os << getStringRef(); }
void OperationName::dump() const { print(llvm::errs()); }
@@ -1339,7 +1335,7 @@ void AsmPrinter::Impl::printLocationInternal(LocationAttr loc, bool pretty) {
})
.Case<FileLineColLoc>([&](FileLineColLoc loc) {
if (pretty) {
- os << loc.getFilename();
+ os << loc.getFilename().getValue();
} else {
os << "\"";
printEscapedString(loc.getFilename(), os);
@@ -1693,7 +1689,7 @@ void AsmPrinter::Impl::printAttribute(Attribute attr,
if (printerFlags.shouldElideElementsAttr(opaqueAttr)) {
printElidedElementsAttr(os);
} else {
- os << "opaque<\"" << opaqueAttr.getDialect() << "\", \"0x"
+ os << "opaque<" << opaqueAttr.getDialect() << ", \"0x"
<< llvm::toHex(opaqueAttr.getValue()) << "\">";
}
diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h
index f62d886cacb62..03373ae70abe4 100644
--- a/mlir/lib/IR/AttributeDetail.h
+++ b/mlir/lib/IR/AttributeDetail.h
@@ -319,6 +319,41 @@ struct DenseStringElementsAttrStorage : public DenseElementsAttributeStorage {
ArrayRef<StringRef> data;
};
+//===----------------------------------------------------------------------===//
+// StringAttr
+//===----------------------------------------------------------------------===//
+
+struct StringAttrStorage : public AttributeStorage {
+ StringAttrStorage(StringRef value, Type type)
+ : AttributeStorage(type), value(value), referencedDialect(nullptr) {}
+
+ /// The hash key is a tuple of the parameter types.
+ using KeyTy = std::pair<StringRef, Type>;
+ bool operator==(const KeyTy &key) const {
+ return value == key.first && getType() == key.second;
+ }
+ static ::llvm::hash_code hashKey(const KeyTy &key) {
+ return DenseMapInfo<KeyTy>::getHashValue(key);
+ }
+
+ /// Define a construction method for creating a new instance of this
+ /// storage.
+ static StringAttrStorage *construct(AttributeStorageAllocator &allocator,
+ const KeyTy &key) {
+ return new (allocator.allocate<StringAttrStorage>())
+ StringAttrStorage(allocator.copyInto(key.first), key.second);
+ }
+
+ /// Initialize the storage given an MLIRContext.
+ void initialize(MLIRContext *context);
+
+ /// The raw string value.
+ StringRef value;
+ /// If the string value contains a dialect namespace prefix (e.g.
+ /// dialect.blah), this is the dialect referenced.
+ Dialect *referencedDialect;
+};
+
} // namespace detail
} // namespace mlir
diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp
index 4cc501d5819a6..dc6149fc262f0 100644
--- a/mlir/lib/IR/Attributes.cpp
+++ b/mlir/lib/IR/Attributes.cpp
@@ -12,28 +12,10 @@
using namespace mlir;
using namespace mlir::detail;
-//===----------------------------------------------------------------------===//
-// AttributeStorage
-//===----------------------------------------------------------------------===//
-
-AttributeStorage::AttributeStorage(Type type)
- : type(type.getAsOpaquePointer()) {}
-AttributeStorage::AttributeStorage() : type(nullptr) {}
-
-Type AttributeStorage::getType() const {
- return Type::getFromOpaquePointer(type);
-}
-void AttributeStorage::setType(Type newType) {
- type = newType.getAsOpaquePointer();
-}
-
//===----------------------------------------------------------------------===//
// Attribute
//===----------------------------------------------------------------------===//
-/// Return the type of this attribute.
-Type Attribute::getType() const { return impl->getType(); }
-
/// Return the context this attribute belongs to.
MLIRContext *Attribute::getContext() const { return getDialect().getContext(); }
@@ -42,13 +24,8 @@ MLIRContext *Attribute::getContext() const { return getDialect().getContext(); }
//===----------------------------------------------------------------------===//
bool mlir::operator<(const NamedAttribute &lhs, const NamedAttribute &rhs) {
- return strcmp(lhs.first.data(), rhs.first.data()) < 0;
+ return lhs.first.compare(rhs.first) < 0;
}
bool mlir::operator<(const NamedAttribute &lhs, StringRef rhs) {
- // This is correct even when attr.first.data()[name.size()] is not a zero
- // string terminator, because we only care about a less than comparison.
- // This can't use memcmp, because it doesn't guarantee that it will stop
- // reading both buffers if one is shorter than the other, even if there is
- // a difference.
- return strncmp(lhs.first.data(), rhs.data(), rhs.size()) < 0;
+ return lhs.first.getValue().compare(rhs) < 0;
}
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 38c8430268985..b99e988eb2276 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -264,6 +264,12 @@ StringAttr StringAttr::get(const Twine &twine, Type type) {
return Base::get(type.getContext(), twine.toStringRef(tempStr), type);
}
+StringRef StringAttr::getValue() const { return getImpl()->value; }
+
+Dialect *StringAttr::getReferencedDialect() const {
+ return getImpl()->referencedDialect;
+}
+
//===----------------------------------------------------------------------===//
// FloatAttr
//===----------------------------------------------------------------------===//
@@ -1250,7 +1256,7 @@ bool DenseIntElementsAttr::classof(Attribute attr) {
//===----------------------------------------------------------------------===//
bool OpaqueElementsAttr::decode(ElementsAttr &result) {
- Dialect *dialect = getDialect().getDialect();
+ Dialect *dialect = getContext()->getLoadedDialect(getDialect());
if (!dialect)
return true;
auto *interface =
diff --git a/mlir/lib/IR/BuiltinDialect.cpp b/mlir/lib/IR/BuiltinDialect.cpp
index 20514957b5fa6..a007a487d8afc 100644
--- a/mlir/lib/IR/BuiltinDialect.cpp
+++ b/mlir/lib/IR/BuiltinDialect.cpp
@@ -253,7 +253,7 @@ static LogicalResult verify(ModuleOp op) {
attr.first.strref()))
return op.emitOpError() << "can only contain attributes with "
"dialect-prefixed names, found: '"
- << attr.first << "'";
+ << attr.first.getValue() << "'";
}
// Check that there is at most one data layout spec attribute.
@@ -266,7 +266,8 @@ static LogicalResult verify(ModuleOp op) {
op.emitOpError() << "expects at most one data layout attribute";
diag.attachNote() << "'" << layoutSpecAttrName
<< "' is a data layout attribute";
- diag.attachNote() << "'" << na.first << "' is a data layout attribute";
+ diag.attachNote() << "'" << na.first.getValue()
+ << "' is a data layout attribute";
}
layoutSpecAttrName = na.first.strref();
layoutSpec = spec;
diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp
index 31391b2b9405e..8400870b91351 100644
--- a/mlir/lib/IR/Diagnostics.cpp
+++ b/mlir/lib/IR/Diagnostics.cpp
@@ -8,7 +8,6 @@
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Identifier.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
@@ -109,11 +108,8 @@ Diagnostic &Diagnostic::operator<<(Twine &&val) {
return *this;
}
-/// Stream in an Identifier.
-Diagnostic &Diagnostic::operator<<(Identifier val) {
- // An identifier is stored in the context, so we don't need to worry about the
- // lifetime of its data.
- arguments.push_back(DiagnosticArgument(val.strref()));
+Diagnostic &Diagnostic::operator<<(StringAttr val) {
+ arguments.push_back(DiagnosticArgument(val));
return *this;
}
@@ -469,7 +465,7 @@ void SourceMgrDiagnosticHandler::emitDiagnostic(Location loc, Twine message,
// the constructor of SMDiagnostic that takes a location.
std::string locStr;
llvm::raw_string_ostream locOS(locStr);
- locOS << fileLoc->getFilename() << ":" << fileLoc->getLine() << ":"
+ locOS << fileLoc->getFilename().getValue() << ":" << fileLoc->getLine() << ":"
<< fileLoc->getColumn();
llvm::SMDiagnostic diag(locOS.str(), getDiagKind(kind), message.str());
diag.print(nullptr, os);
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 065568379a06b..aa907ea75fc3d 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -18,7 +18,6 @@
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
-#include "mlir/IR/Identifier.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/OpImplementation.h"
@@ -33,6 +32,7 @@
#include "llvm/Support/Allocator.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/Mutex.h"
#include "llvm/Support/RWMutex.h"
#include "llvm/Support/ThreadPool.h"
#include "llvm/Support/raw_ostream.h"
@@ -227,14 +227,6 @@ class MLIRContextImpl {
/// An action manager for use within the context.
DebugActionManager debugActionManager;
- //===--------------------------------------------------------------------===//
- // Identifier uniquing
- //===--------------------------------------------------------------------===//
-
- // Identifier allocator and mutex for thread safety.
- llvm::BumpPtrAllocator identifierAllocator;
- llvm::sys::SmartRWMutex<true> identifierMutex;
-
//===--------------------------------------------------------------------===//
// Diagnostics
//===--------------------------------------------------------------------===//
@@ -289,12 +281,6 @@ class MLIRContextImpl {
/// operations.
llvm::StringMap<AbstractOperation> registeredOperations;
- /// Identifiers are uniqued by string value and use the internal string set
- /// for storage.
- llvm::StringMap<PointerUnion<Dialect *, MLIRContext *>,
- llvm::BumpPtrAllocator &>
- identifiers;
-
/// An allocator used for AbstractAttribute and AbstractType objects.
llvm::BumpPtrAllocator abstractDialectSymbolAllocator;
@@ -349,10 +335,15 @@ class MLIRContextImpl {
DictionaryAttr emptyDictionaryAttr;
StringAttr emptyStringAttr;
+ /// Map of string attributes that may reference a dialect, that are awaiting
+ /// that dialect to be loaded.
+ llvm::sys::SmartMutex<true> dialectRefStrAttrMutex;
+ DenseMap<StringRef, SmallVector<StringAttrStorage *>>
+ dialectReferencingStrAttrs;
+
public:
MLIRContextImpl(bool threadingIsEnabled)
- : threadingIsEnabled(threadingIsEnabled),
- identifiers(identifierAllocator) {
+ : threadingIsEnabled(threadingIsEnabled) {
if (threadingIsEnabled) {
ownedThreadPool = std::make_unique<llvm::ThreadPool>();
threadPool = ownedThreadPool.get();
@@ -541,12 +532,12 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
// Refresh all the identifiers dialect field, this catches cases where a
// dialect may be loaded after identifier prefixed with this dialect name
// were already created.
- llvm::SmallString<32> dialectPrefix(dialectNamespace);
- dialectPrefix.push_back('.');
- for (auto &identifierEntry : impl.identifiers)
- if (identifierEntry.second.is<MLIRContext *>() &&
- identifierEntry.first().startswith(dialectPrefix))
- identifierEntry.second = dialect.get();
+ auto stringAttrsIt = impl.dialectReferencingStrAttrs.find(dialectNamespace);
+ if (stringAttrsIt != impl.dialectReferencingStrAttrs.end()) {
+ for (StringAttrStorage *storage : stringAttrsIt->second)
+ storage->referencedDialect = dialect.get();
+ impl.dialectReferencingStrAttrs.erase(stringAttrsIt);
+ }
// Actually register the interfaces with delayed registration.
impl.dialectsRegistry.registerDelayedInterfaces(dialect.get());
@@ -784,7 +775,8 @@ void AbstractOperation::insert(
MutableArrayRef<Identifier> cachedAttrNames;
if (!attrNames.empty()) {
cachedAttrNames = MutableArrayRef<Identifier>(
- impl.identifierAllocator.Allocate<Identifier>(attrNames.size()),
+ impl.abstractDialectSymbolAllocator.Allocate<Identifier>(
+ attrNames.size()),
attrNames.size());
for (unsigned i : llvm::seq<unsigned>(0, attrNames.size()))
new (&cachedAttrNames[i]) Identifier(Identifier::get(attrNames[i], ctx));
@@ -840,63 +832,6 @@ AbstractType *AbstractType::lookupMutable(TypeID typeID, MLIRContext *context) {
return it->second;
}
-//===----------------------------------------------------------------------===//
-// Identifier uniquing
-//===----------------------------------------------------------------------===//
-
-/// Return an identifier for the specified string.
-Identifier Identifier::get(const Twine &string, MLIRContext *context) {
- SmallString<32> tempStr;
- StringRef str = string.toStringRef(tempStr);
-
- // Check invariants after seeing if we already have something in the
- // identifier table - if we already had it in the table, then it already
- // passed invariant checks.
- assert(!str.empty() && "Cannot create an empty identifier");
- assert(!str.contains('\0') &&
- "Cannot create an identifier with a nul character");
-
- auto getDialectOrContext = [&]() {
- PointerUnion<Dialect *, MLIRContext *> dialectOrContext = context;
- auto dialectNamePair = str.split('.');
- if (!dialectNamePair.first.empty())
- if (Dialect *dialect = context->getLoadedDialect(dialectNamePair.first))
- dialectOrContext = dialect;
- return dialectOrContext;
- };
-
- auto &impl = context->getImpl();
- if (!context->isMultithreadingEnabled()) {
- auto insertedIt = impl.identifiers.insert({str, nullptr});
- if (insertedIt.second)
- insertedIt.first->second = getDialectOrContext();
- return Identifier(&*insertedIt.first);
- }
-
- // Check for an existing identifier in read-only mode.
- {
- llvm::sys::SmartScopedReader<true> contextLock(impl.identifierMutex);
- auto it = impl.identifiers.find(str);
- if (it != impl.identifiers.end())
- return Identifier(&*it);
- }
-
- // Acquire a writer-lock so that we can safely create the new instance.
- llvm::sys::SmartScopedWriter<true> contextLock(impl.identifierMutex);
- auto it = impl.identifiers.insert({str, getDialectOrContext()}).first;
- return Identifier(&*it);
-}
-
-Dialect *Identifier::getDialect() {
- return entry->second.dyn_cast<Dialect *>();
-}
-
-MLIRContext *Identifier::getContext() {
- if (Dialect *dialect = getDialect())
- return dialect->getContext();
- return entry->second.get<MLIRContext *>();
-}
-
//===----------------------------------------------------------------------===//
// Type uniquing
//===----------------------------------------------------------------------===//
@@ -995,7 +930,7 @@ StorageUniquer &MLIRContext::getAttributeUniquer() {
void AttributeUniquer::initializeAttributeStorage(AttributeStorage *storage,
MLIRContext *ctx,
TypeID attrID) {
- storage->initialize(AbstractAttribute::lookup(attrID, ctx));
+ storage->initializeAbstractAttribute(AbstractAttribute::lookup(attrID, ctx));
// If the attribute did not provide a type, then default to NoneType.
if (!storage->getType())
@@ -1019,6 +954,24 @@ DictionaryAttr DictionaryAttr::getEmpty(MLIRContext *context) {
return context->getImpl().emptyDictionaryAttr;
}
+void StringAttrStorage::initialize(MLIRContext *context) {
+ // Check for a dialect namespace prefix, if there isn't one we don't need to
+ // do any additional initialization.
+ auto dialectNamePair = value.split('.');
+ if (dialectNamePair.first.empty() || dialectNamePair.second.empty())
+ return;
+
+ // If one exists, we check to see if this dialect is loaded. If it is, we set
+ // the dialect now, if it isn't we record this storage for initialization
+ // later if the dialect ever gets loaded.
+ if ((referencedDialect = context->getLoadedDialect(dialectNamePair.first)))
+ return;
+
+ MLIRContextImpl &impl = context->getImpl();
+ llvm::sys::SmartScopedLock<true> lock(impl.dialectRefStrAttrMutex);
+ impl.dialectReferencingStrAttrs[dialectNamePair.first].push_back(this);
+}
+
/// Return an empty string.
StringAttr StringAttr::get(MLIRContext *context) {
return context->getImpl().emptyStringAttr;
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 4c9bc848ce472..c78fc52c9d240 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -73,10 +73,10 @@ void NamedAttrList::assign(const_iterator in_start, const_iterator in_end) {
void NamedAttrList::push_back(NamedAttribute newAttribute) {
assert(newAttribute.second && "unexpected null attribute");
- if (isSorted())
- dictionarySorted.setInt(
- attrs.empty() ||
- strcmp(attrs.back().first.data(), newAttribute.first.data()) < 0);
+ if (isSorted()) {
+ dictionarySorted.setInt(attrs.empty() ||
+ attrs.back().first.compare(newAttribute.first) < 0);
+ }
dictionarySorted.setPointer(nullptr);
attrs.push_back(newAttribute);
}
diff --git a/mlir/lib/IR/Verifier.cpp b/mlir/lib/IR/Verifier.cpp
index 0b381d638bfd0..90acbecb626d8 100644
--- a/mlir/lib/IR/Verifier.cpp
+++ b/mlir/lib/IR/Verifier.cpp
@@ -170,7 +170,7 @@ LogicalResult OperationVerifier::verifyOperation(
/// Verify that all of the attributes are okay.
for (auto attr : op.getAttrs()) {
// Check for any optional dialect specific attributes.
- if (auto *dialect = attr.first.getDialect())
+ if (auto *dialect = attr.first.getReferencedDialect())
if (failed(dialect->verifyOperationAttribute(&op, attr)))
return failure();
}
diff --git a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
index 5d238a5bd326d..5223c31cd75c0 100644
--- a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
+++ b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
@@ -431,7 +431,7 @@ LogicalResult mlir::detail::verifyDataLayoutSpec(DataLayoutSpecInterface spec,
for (const auto &kvp : ids) {
Identifier identifier = kvp.second.getKey().get<Identifier>();
- Dialect *dialect = identifier.getDialect();
+ Dialect *dialect = identifier.getReferencedDialect();
// Ignore attributes that belong to an unknown dialect, the dialect may
// actually implement the relevant interface but we don't know about that.
diff --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp
index 7150bfc6a682d..11b9f03b10d3e 100644
--- a/mlir/lib/Parser/AttributeParser.cpp
+++ b/mlir/lib/Parser/AttributeParser.cpp
@@ -273,7 +273,7 @@ ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
return emitError("expected attribute name");
if (!seenKeys.insert(*nameId).second)
return emitError("duplicate key '")
- << *nameId << "' in dictionary attribute";
+ << nameId->getValue() << "' in dictionary attribute";
consumeToken();
// Lazy load a dialect in the context if there is a possible namespace.
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index 4d3f3b051a406..d69f16f840040 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -1127,7 +1127,7 @@ class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {
Optional<NamedAttribute> duplicate = opState.attributes.findDuplicate();
if (duplicate)
return emitError(getNameLoc(), "attribute '")
- << duplicate->first
+ << duplicate->first.getValue()
<< "' occurs more than once in the attribute list";
return success();
}
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 038ad85335cd0..abddb790fbc99 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -822,7 +822,7 @@ CppEmitter::emitOperandsAndAttributes(Operation &op,
auto emitNamedAttribute = [&](NamedAttribute attr) -> LogicalResult {
if (llvm::is_contained(exclude, attr.first.strref()))
return success();
- os << "/* " << attr.first << " */";
+ os << "/* " << attr.first.getValue() << " */";
if (failed(emitAttribute(op.getLoc(), attr.second)))
return failure();
return success();
diff --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp
index 641df2859fb11..11a276470368c 100644
--- a/mlir/lib/Transforms/ViewOpGraph.cpp
+++ b/mlir/lib/Transforms/ViewOpGraph.cpp
@@ -221,7 +221,7 @@ class PrintOpPass : public ViewOpGraphPassBase<PrintOpPass> {
if (printAttrs) {
os << "\n";
for (const NamedAttribute &attr : op->getAttrs()) {
- os << '\n' << attr.first << ": ";
+ os << '\n' << attr.first.getValue() << ": ";
emitMlirAttr(os, attr.second);
}
}
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 117e2e515beb4..c3415c6e2a5d6 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -494,7 +494,7 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
LinalgTilingOptions()
.setTileSizes({8, 8, 4})
.setLoopType(LinalgTilingLoopType::Loops)
- .setDistributionOptions(cyclicNprocsEqNiters),
+ .setDistributionOptions(cyclicNprocsEqNiters),
LinalgTransformationFilter(
Identifier::get("tensors_distribute1", context),
Identifier::get("tensors_after_distribute1", context)));
@@ -508,8 +508,7 @@ applyMatmulToVectorPatterns(FuncOp funcOp,
MLIRContext *ctx = funcOp.getContext();
SmallVector<RewritePatternSet, 4> stage1Patterns;
if (testMatmulToVectorPatterns1dTiling) {
- fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("START", ctx),
- stage1Patterns);
+ fillL1TilingAndMatmulToVectorPatterns(funcOp, "START", stage1Patterns);
} else if (testMatmulToVectorPatterns2dTiling) {
stage1Patterns.emplace_back(
ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
@@ -519,8 +518,7 @@ applyMatmulToVectorPatterns(FuncOp funcOp,
.setInterchange({1, 2, 0}),
LinalgTransformationFilter(Identifier::get("START", ctx),
Identifier::get("L2", ctx))));
- fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx),
- stage1Patterns);
+ fillL1TilingAndMatmulToVectorPatterns(funcOp, "L2", stage1Patterns);
}
{
// Canonicalization patterns
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 4df5730b282fd..d2b049bfb14cd 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -243,7 +243,7 @@ void TestDerivedAttributeDriver::runOnFunction() {
if (!dAttr)
return;
for (auto d : dAttr)
- dOp.emitRemark() << d.first << " = " << d.second;
+ dOp.emitRemark() << d.first.getValue() << " = " << d.second;
});
}
diff --git a/mlir/test/lib/IR/TestPrintNesting.cpp b/mlir/test/lib/IR/TestPrintNesting.cpp
index b85e0b788a087..f48a9181f1af5 100644
--- a/mlir/test/lib/IR/TestPrintNesting.cpp
+++ b/mlir/test/lib/IR/TestPrintNesting.cpp
@@ -37,8 +37,8 @@ struct TestPrintNestingPass
if (!op->getAttrs().empty()) {
printIndent() << op->getAttrs().size() << " attributes:\n";
for (NamedAttribute attr : op->getAttrs())
- printIndent() << " - '" << attr.first << "' : '" << attr.second
- << "'\n";
+ printIndent() << " - '" << attr.first.getValue() << "' : '"
+ << attr.second << "'\n";
}
// Recurse into each of the regions attached to the operation.
More information about the Mlir-commits
mailing list