[Mlir-commits] [mlir] 41d4aa7 - [SymbolRefAttr] Revise SymbolRefAttr to hold a StringAttr.
Chris Lattner
llvmlistbot at llvm.org
Sun Aug 29 22:01:21 PDT 2021
Author: Chris Lattner
Date: 2021-08-29T21:54:47-07:00
New Revision: 41d4aa7de68ed551010f27ff04ffc54e7616292a
URL: https://github.com/llvm/llvm-project/commit/41d4aa7de68ed551010f27ff04ffc54e7616292a
DIFF: https://github.com/llvm/llvm-project/commit/41d4aa7de68ed551010f27ff04ffc54e7616292a.diff
LOG: [SymbolRefAttr] Revise SymbolRefAttr to hold a StringAttr.
SymbolRefAttr is fundamentally a base string plus a sequence
of nested references. Instead of storing the string data as
a copied StringRef, store it as an already-uniqued StringAttr.
This makes a lot of things simpler and more efficient because:
1) references to the symbol are already stored as StringAttrs:
there is no need to copy the string data into MLIRContext
multiple times.
2) This allows pointer comparisons instead of string
comparisons (or redundant uniquing) within SymbolTable.cpp.
3) This allows SymbolTable to hold a DenseMap instead of a
StringMap (which again copies the string data and slows
lookup).
This is a moderately invasive patch, so I kept a lot of
compatibility APIs around. It would be nice to explore changing
getName() to return a StringAttr for example (right now you have
to use getNameAttr()), and eliminate things like the StringRef
version of getSymbol.
Differential Revision: https://reviews.llvm.org/D108899
Added:
Modified:
mlir/include/mlir/Dialect/GPU/GPUOps.td
mlir/include/mlir/IR/Builders.h
mlir/include/mlir/IR/BuiltinAttributes.h
mlir/include/mlir/IR/BuiltinAttributes.td
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/IR/SymbolInterfaces.td
mlir/include/mlir/IR/SymbolTable.h
mlir/lib/CAPI/IR/BuiltinAttributes.cpp
mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/Builders.cpp
mlir/lib/IR/BuiltinAttributes.cpp
mlir/lib/IR/SymbolTable.cpp
mlir/lib/Rewrite/ByteCode.cpp
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
mlir/test/lib/IR/TestSymbolUses.cpp
mlir/test/lib/Rewrite/TestPDLByteCode.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
index 92348a0f436e0..5eb0fbae06051 100644
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -378,10 +378,10 @@ def GPU_LaunchFuncOp : GPU_Op<"launch_func",
unsigned getNumKernelOperands();
/// The name of the kernel's containing module.
- StringRef getKernelModuleName();
+ StringAttr getKernelModuleName();
/// The name of the kernel.
- StringRef getKernelName();
+ StringAttr getKernelName();
/// The i-th operand passed to the kernel function.
Value getKernelOperand(unsigned i);
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index e27a1cb7ffb37..7e6aa710e1f94 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -98,9 +98,16 @@ class Builder {
StringAttr getStringAttr(const Twine &bytes);
ArrayAttr getArrayAttr(ArrayRef<Attribute> value);
FlatSymbolRefAttr getSymbolRefAttr(Operation *value);
- FlatSymbolRefAttr getSymbolRefAttr(StringRef value);
- SymbolRefAttr getSymbolRefAttr(StringRef value,
+ FlatSymbolRefAttr getSymbolRefAttr(StringAttr value);
+ SymbolRefAttr getSymbolRefAttr(StringAttr value,
ArrayRef<FlatSymbolRefAttr> nestedReferences);
+ SymbolRefAttr getSymbolRefAttr(StringRef value,
+ ArrayRef<FlatSymbolRefAttr> nestedReferences) {
+ return getSymbolRefAttr(getStringAttr(value), nestedReferences);
+ }
+ FlatSymbolRefAttr getSymbolRefAttr(StringRef value) {
+ return getSymbolRefAttr(getStringAttr(value));
+ }
// Returns a 0-valued attribute of the given `type`. This function only
// supports boolean, integer, and 16-/32-/64-bit float types, and vector or
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index c71c5d4bbd4ae..0240e17e83419 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -30,8 +30,10 @@ class ShapedType;
//===----------------------------------------------------------------------===//
namespace detail {
-template <typename T> class ElementsAttrIterator;
-template <typename T> class ElementsAttrRange;
+template <typename T>
+class ElementsAttrIterator;
+template <typename T>
+class ElementsAttrRange;
} // namespace detail
/// A base attribute that represents a reference to a static shaped tensor or
@@ -39,8 +41,10 @@ template <typename T> class ElementsAttrRange;
class ElementsAttr : public Attribute {
public:
using Attribute::Attribute;
- template <typename T> using iterator = detail::ElementsAttrIterator<T>;
- template <typename T> using iterator_range = detail::ElementsAttrRange<T>;
+ template <typename T>
+ using iterator = detail::ElementsAttrIterator<T>;
+ template <typename T>
+ using iterator_range = detail::ElementsAttrRange<T>;
/// Return the type of this ElementsAttr, guaranteed to be a vector or tensor
/// with static shape.
@@ -52,14 +56,16 @@ class ElementsAttr : public Attribute {
/// Return the value of type 'T' at the given index, where 'T' corresponds to
/// an Attribute type.
- template <typename T> T getValue(ArrayRef<uint64_t> index) const {
+ template <typename T>
+ T getValue(ArrayRef<uint64_t> index) const {
return getValue(index).template cast<T>();
}
/// Return the elements of this attribute as a value of type 'T'. Note:
/// Aborts if the subclass is OpaqueElementsAttrs, these attrs do not support
/// iteration.
- template <typename T> iterator_range<T> getValues() const;
+ template <typename T>
+ iterator_range<T> getValues() const;
/// Return if the given 'index' refers to a valid element in this attribute.
bool isValidIndex(ArrayRef<uint64_t> index) const;
@@ -139,7 +145,8 @@ class DenseElementIndexedIteratorImpl
};
/// Type trait detector that checks if a given type T is a complex type.
-template <typename T> struct is_complex_t : public std::false_type {};
+template <typename T>
+struct is_complex_t : public std::false_type {};
template <typename T>
struct is_complex_t<std::complex<T>> : public std::true_type {};
} // namespace detail
@@ -154,7 +161,8 @@ class DenseElementsAttr : public ElementsAttr {
/// floating point type that can be used to access the underlying element
/// types of a DenseElementsAttr.
// TODO: Use std::disjunction when C++17 is supported.
- template <typename T> struct is_valid_cpp_fp_type {
+ template <typename T>
+ struct is_valid_cpp_fp_type {
/// The type is a valid floating point type if it is a builtin floating
/// point type, or is a potentially user defined floating point type. The
/// latter allows for supporting users that have custom types defined for
@@ -423,7 +431,8 @@ class DenseElementsAttr : public ElementsAttr {
Attribute getValue(ArrayRef<uint64_t> index) const {
return getValue<Attribute>(index);
}
- template <typename T> T getValue(ArrayRef<uint64_t> index) const {
+ template <typename T>
+ T getValue(ArrayRef<uint64_t> index) const {
// Skip to the element corresponding to the flattened index.
return *std::next(getValues<T>().begin(), getFlattenedIndex(index));
}
@@ -680,8 +689,15 @@ class FlatSymbolRefAttr : public SymbolRefAttr {
return SymbolRefAttr::get(ctx, value);
}
+ static FlatSymbolRefAttr get(StringAttr value) {
+ return SymbolRefAttr::get(value);
+ }
+
+ /// Returns the name of the held symbol reference as a StringAttr.
+ StringAttr getAttr() const { return getRootReference(); }
+
/// Returns the name of the held symbol reference.
- StringRef getValue() const { return getRootReference(); }
+ StringRef getValue() const { return getAttr().getValue(); }
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(Attribute attr) {
@@ -845,22 +861,28 @@ class ElementsAttrIterator
}
/// Utility functors used to generically implement the iterators methods.
- template <typename ItT> struct PlusAssign {
+ template <typename ItT>
+ struct PlusAssign {
void operator()(ItT &it, ptrdiff_t offset) { it += offset; }
};
- template <typename ItT> struct Minus {
+ template <typename ItT>
+ struct Minus {
ptrdiff_t operator()(const ItT &lhs, const ItT &rhs) { return lhs - rhs; }
};
- template <typename ItT> struct MinusAssign {
+ template <typename ItT>
+ struct MinusAssign {
void operator()(ItT &it, ptrdiff_t offset) { it -= offset; }
};
- template <typename ItT> struct Dereference {
+ template <typename ItT>
+ struct Dereference {
T operator()(ItT &it) { return *it; }
};
- template <typename ItT> struct ConstructIter {
+ template <typename ItT>
+ struct ConstructIter {
void operator()(ItT &dest, const ItT &it) { ::new (&dest) ItT(it); }
};
- template <typename ItT> struct DestructIter {
+ template <typename ItT>
+ struct DestructIter {
void operator()(ItT &it) { it.~ItT(); }
};
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index d514168bae367..08c3d0f2ebade 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -881,17 +881,26 @@ def Builtin_SymbolRefAttr : Builtin_Attr<"SymbolRef"> {
@parent_reference::@nested_reference
```
}];
- let parameters = (ins
- StringRefParameter<"">:$rootReference,
- ArrayRefParameter<"FlatSymbolRefAttr", "">:$nestedReferences
- );
+ let parameters =
+ (ins "StringAttr":$rootReference,
+ ArrayRefParameter<"FlatSymbolRefAttr", "">:$nestedReferences);
+
+ let builders = [
+ AttrBuilderWithInferredContext<
+ (ins "StringAttr":$rootReference,
+ "ArrayRef<FlatSymbolRefAttr>":$nestedReferences), [{
+ return $_get(rootReference.getContext(), rootReference, nestedReferences);
+ }]>,
+ ];
let extraClassDeclaration = [{
static FlatSymbolRefAttr get(MLIRContext *ctx, StringRef value);
+ static FlatSymbolRefAttr get(StringAttr value);
/// Returns the name of the fully resolved symbol, i.e. the leaf of the
/// reference path.
- StringRef getLeafReference() const;
+ StringAttr getLeafReference() const;
}];
+ let skipDefaultBuilders = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index e1b7e28953578..facd2a8d77a14 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1734,7 +1734,7 @@ def IsNullAttr : AttrConstraint<
class ReferToOp<string opClass> : AttrConstraint<
CPred<"isa_and_nonnull<" # opClass # ">("
"::mlir::SymbolTable::lookupNearestSymbolFrom("
- "&$_op, $_self.cast<::mlir::FlatSymbolRefAttr>().getValue()))">,
+ "&$_op, $_self.cast<::mlir::FlatSymbolRefAttr>().getAttr()))">,
"referencing to a '" # opClass # "' symbol">;
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/SymbolInterfaces.td b/mlir/include/mlir/IR/SymbolInterfaces.td
index 8bd6c25939cc0..97bda0ea73b40 100644
--- a/mlir/include/mlir/IR/SymbolInterfaces.td
+++ b/mlir/include/mlir/IR/SymbolInterfaces.td
@@ -31,7 +31,7 @@ def Symbol : OpInterface<"SymbolOpInterface"> {
let methods = [
InterfaceMethod<"Returns the name of this symbol.",
- "StringRef", "getName", (ins), [{
+ "StringAttr", "getNameAttr", (ins), [{
// Don't rely on the trait implementation as optional symbol operations
// may override this.
return mlir::SymbolTable::getSymbolName($_op);
@@ -40,11 +40,10 @@ def Symbol : OpInterface<"SymbolOpInterface"> {
}]
>,
InterfaceMethod<"Sets the name of this symbol.",
- "void", "setName", (ins "StringRef":$name), [{}],
+ "void", "setName", (ins "StringAttr":$name), [{}],
/*defaultImplementation=*/[{
this->getOperation()->setAttr(
- mlir::SymbolTable::getSymbolAttrName(),
- StringAttr::get(this->getOperation()->getContext(), name));
+ mlir::SymbolTable::getSymbolAttrName(), name);
}]
>,
InterfaceMethod<"Gets the visibility of this symbol.",
@@ -122,7 +121,7 @@ def Symbol : OpInterface<"SymbolOpInterface"> {
symbol 'newSymbol' that are nested within the given operation 'from'.
Note: See mlir::SymbolTable::replaceAllSymbolUses for more details.
}],
- "LogicalResult", "replaceAllSymbolUses", (ins "StringRef":$newSymbol,
+ "LogicalResult", "replaceAllSymbolUses", (ins "StringAttr":$newSymbol,
"Operation *":$from), [{}],
/*defaultImplementation=*/[{
return ::mlir::SymbolTable::replaceAllSymbolUses(this->getOperation(),
@@ -176,6 +175,16 @@ def Symbol : OpInterface<"SymbolOpInterface"> {
}];
let extraClassDeclaration = [{
+ /// Convenience version of `getNameAttr` that returns a StringRef.
+ StringRef getName() {
+ return getNameAttr().getValue();
+ }
+
+ /// Convenience version of `setName` that takes a StringRef.
+ void setName(StringRef name) {
+ setName(StringAttr::get(this->getContext(), name));
+ }
+
/// Custom classof that handles the case where the symbol is optional.
static bool classof(Operation *op) {
auto *opConcept = getInterfaceFor(op);
@@ -188,6 +197,16 @@ def Symbol : OpInterface<"SymbolOpInterface"> {
let extraTraitClassDeclaration = [{
using Visibility = mlir::SymbolTable::Visibility;
+
+ /// Convenience version of `getNameAttr` that returns a StringRef.
+ StringRef getName() {
+ return getNameAttr().getValue();
+ }
+
+ /// Convenience version of `setName` that takes a StringRef.
+ void setName(StringRef name) {
+ setName(StringAttr::get(this->getContext(), name));
+ }
}];
}
diff --git a/mlir/include/mlir/IR/SymbolTable.h b/mlir/include/mlir/IR/SymbolTable.h
index 97c87a43c3855..07a8f3fbb2dbf 100644
--- a/mlir/include/mlir/IR/SymbolTable.h
+++ b/mlir/include/mlir/IR/SymbolTable.h
@@ -30,7 +30,16 @@ class SymbolTable {
/// Look up a symbol with the specified name, returning null if no such
/// name exists. Names never include the @ on them.
Operation *lookup(StringRef name) const;
- template <typename T> T lookup(StringRef name) const {
+ template <typename T>
+ T lookup(StringRef name) const {
+ return dyn_cast_or_null<T>(lookup(name));
+ }
+
+ /// Look up a symbol with the specified name, returning null if no such
+ /// name exists. Names never include the @ on them.
+ Operation *lookup(StringAttr name) const;
+ template <typename T>
+ T lookup(StringAttr name) const {
return dyn_cast_or_null<T>(lookup(name));
}
@@ -74,10 +83,15 @@ class SymbolTable {
Nested,
};
- /// Returns the name of the given symbol operation.
- static StringRef getSymbolName(Operation *symbol);
+ /// Returns the name of the given symbol operation, aborting if no symbol is
+ /// present.
+ static StringAttr getSymbolName(Operation *symbol);
+
/// Sets the name of the given symbol operation.
- static void setSymbolName(Operation *symbol, StringRef name);
+ static void setSymbolName(Operation *symbol, StringAttr name);
+ static void setSymbolName(Operation *symbol, StringRef name) {
+ setSymbolName(symbol, StringAttr::get(symbol->getContext(), name));
+ }
/// Returns the visibility of the given symbol operation.
static Visibility getSymbolVisibility(Operation *symbol);
@@ -100,7 +114,10 @@ class SymbolTable {
/// Returns the operation registered with the given symbol name with the
/// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation
/// with the 'OpTrait::SymbolTable' trait.
- static Operation *lookupSymbolIn(Operation *op, StringRef symbol);
+ static Operation *lookupSymbolIn(Operation *op, StringAttr symbol);
+ static Operation *lookupSymbolIn(Operation *op, StringRef symbol) {
+ return lookupSymbolIn(op, StringAttr::get(op->getContext(), symbol));
+ }
static Operation *lookupSymbolIn(Operation *op, SymbolRefAttr symbol);
/// A variant of 'lookupSymbolIn' that returns all of the symbols referenced
/// by a given SymbolRefAttr. Returns failure if any of the nested references
@@ -112,11 +129,11 @@ class SymbolTable {
/// closest parent operation of, or including, 'from' with the
/// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was
/// found.
- static Operation *lookupNearestSymbolFrom(Operation *from, StringRef symbol);
+ static Operation *lookupNearestSymbolFrom(Operation *from, StringAttr symbol);
static Operation *lookupNearestSymbolFrom(Operation *from,
SymbolRefAttr symbol);
template <typename T>
- static T lookupNearestSymbolFrom(Operation *from, StringRef symbol) {
+ static T lookupNearestSymbolFrom(Operation *from, StringAttr symbol) {
return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol));
}
template <typename T>
@@ -169,9 +186,9 @@ class SymbolTable {
/// operation 'from'. This does not traverse into any nested symbol tables.
/// This function returns None if there are any unknown operations that may
/// potentially be symbol tables.
- static Optional<UseRange> getSymbolUses(StringRef symbol, Operation *from);
+ static Optional<UseRange> getSymbolUses(StringAttr symbol, Operation *from);
static Optional<UseRange> getSymbolUses(Operation *symbol, Operation *from);
- static Optional<UseRange> getSymbolUses(StringRef symbol, Region *from);
+ static Optional<UseRange> getSymbolUses(StringAttr symbol, Region *from);
static Optional<UseRange> getSymbolUses(Operation *symbol, Region *from);
/// Return if the given symbol is known to have no uses that are nested
@@ -180,9 +197,9 @@ class SymbolTable {
/// unknown operations that may potentially be symbol tables. This doesn't
/// necessarily mean that there are no uses, we just can't conservatively
/// prove it.
- static bool symbolKnownUseEmpty(StringRef symbol, Operation *from);
+ static bool symbolKnownUseEmpty(StringAttr symbol, Operation *from);
static bool symbolKnownUseEmpty(Operation *symbol, Operation *from);
- static bool symbolKnownUseEmpty(StringRef symbol, Region *from);
+ static bool symbolKnownUseEmpty(StringAttr symbol, Region *from);
static bool symbolKnownUseEmpty(Operation *symbol, Region *from);
/// Attempt to replace all uses of the given symbol 'oldSymbol' with the
@@ -190,23 +207,24 @@ class SymbolTable {
/// 'from'. This does not traverse into any nested symbol tables. If there are
/// any unknown operations that may potentially be symbol tables, no uses are
/// replaced and failure is returned.
- static LogicalResult replaceAllSymbolUses(StringRef oldSymbol,
- StringRef newSymbol,
+ static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol,
+ StringAttr newSymbol,
Operation *from);
static LogicalResult replaceAllSymbolUses(Operation *oldSymbol,
- StringRef newSymbolName,
+ StringAttr newSymbolName,
Operation *from);
- static LogicalResult replaceAllSymbolUses(StringRef oldSymbol,
- StringRef newSymbol, Region *from);
+ static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol,
+ StringAttr newSymbol, Region *from);
static LogicalResult replaceAllSymbolUses(Operation *oldSymbol,
- StringRef newSymbolName,
+ StringAttr newSymbolName,
Region *from);
private:
Operation *symbolTableOp;
- /// This is a mapping from a name to the symbol with that name.
- llvm::StringMap<Operation *> symbolTable;
+ /// This is a mapping from a name to the symbol with that name. The key is
+ /// always known to be a StringAttr.
+ DenseMap<Attribute, Operation *> symbolTable;
/// This is used when name conflicts are detected.
unsigned uniquingCounter = 0;
@@ -226,7 +244,7 @@ class SymbolTableCollection {
public:
/// Look up a symbol with the specified name within the specified symbol table
/// operation, returning null if no such name exists.
- Operation *lookupSymbolIn(Operation *symbolTableOp, StringRef symbol);
+ Operation *lookupSymbolIn(Operation *symbolTableOp, StringAttr symbol);
Operation *lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr name);
template <typename T, typename NameT>
T lookupSymbolIn(Operation *symbolTableOp, NameT &&name) const {
@@ -244,10 +262,10 @@ class SymbolTableCollection {
/// closest parent operation of, or including, 'from' with the
/// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was
/// found.
- Operation *lookupNearestSymbolFrom(Operation *from, StringRef symbol);
+ Operation *lookupNearestSymbolFrom(Operation *from, StringAttr symbol);
Operation *lookupNearestSymbolFrom(Operation *from, SymbolRefAttr symbol);
template <typename T>
- T lookupNearestSymbolFrom(Operation *from, StringRef symbol) {
+ T lookupNearestSymbolFrom(Operation *from, StringAttr symbol) {
return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol));
}
template <typename T>
@@ -290,7 +308,7 @@ class SymbolUserMap {
}
/// Replace all of the uses of the given symbol with `newSymbolName`.
- void replaceAllUsesWith(Operation *symbol, StringRef newSymbolName);
+ void replaceAllUsesWith(Operation *symbol, StringAttr newSymbolName);
private:
/// A reference to the symbol table used to construct this map.
@@ -327,18 +345,28 @@ class SymbolTable : public TraitBase<ConcreteType, SymbolTable> {
/// Look up a symbol with the specified name, returning null if no such
/// name exists. Symbol names never include the @ on them. Note: This
/// performs a linear scan of held symbols.
- Operation *lookupSymbol(StringRef name) {
+ Operation *lookupSymbol(StringAttr name) {
return mlir::SymbolTable::lookupSymbolIn(this->getOperation(), name);
}
- template <typename T> T lookupSymbol(StringRef name) {
+ template <typename T>
+ T lookupSymbol(StringAttr name) {
return dyn_cast_or_null<T>(lookupSymbol(name));
}
Operation *lookupSymbol(SymbolRefAttr symbol) {
return mlir::SymbolTable::lookupSymbolIn(this->getOperation(), symbol);
}
- template <typename T> T lookupSymbol(SymbolRefAttr symbol) {
+ template <typename T>
+ T lookupSymbol(SymbolRefAttr symbol) {
return dyn_cast_or_null<T>(lookupSymbol(symbol));
}
+
+ Operation *lookupSymbol(StringRef name) {
+ return mlir::SymbolTable::lookupSymbolIn(this->getOperation(), name);
+ }
+ template <typename T>
+ T lookupSymbol(StringRef name) {
+ return dyn_cast_or_null<T>(lookupSymbol(name));
+ }
};
} // end namespace OpTrait
diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
index 3ec1b73c009fe..4ae54d4caad4e 100644
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -212,15 +212,16 @@ MlirAttribute mlirSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol,
refs.reserve(numReferences);
for (intptr_t i = 0; i < numReferences; ++i)
refs.push_back(unwrap(references[i]).cast<FlatSymbolRefAttr>());
- return wrap(SymbolRefAttr::get(unwrap(ctx), unwrap(symbol), refs));
+ auto symbolAttr = StringAttr::get(unwrap(ctx), unwrap(symbol));
+ return wrap(SymbolRefAttr::get(symbolAttr, refs));
}
MlirStringRef mlirSymbolRefAttrGetRootReference(MlirAttribute attr) {
- return wrap(unwrap(attr).cast<SymbolRefAttr>().getRootReference());
+ return wrap(unwrap(attr).cast<SymbolRefAttr>().getRootReference().getValue());
}
MlirStringRef mlirSymbolRefAttrGetLeafReference(MlirAttribute attr) {
- return wrap(unwrap(attr).cast<SymbolRefAttr>().getLeafReference());
+ return wrap(unwrap(attr).cast<SymbolRefAttr>().getLeafReference().getValue());
}
intptr_t mlirSymbolRefAttrGetNumNestedReferences(MlirAttribute attr) {
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 242d331ca76a8..1234a9a14fbdc 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -704,7 +704,8 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
// Get the function from the module. The name corresponds to the name of
// the kernel function.
auto kernelName = generateKernelNameConstant(
- launchOp.getKernelModuleName(), launchOp.getKernelName(), loc, rewriter);
+ launchOp.getKernelModuleName().getValue(),
+ launchOp.getKernelName().getValue(), loc, rewriter);
auto function = moduleGetFunctionCallBuilder.create(
loc, rewriter, {module.getResult(0), kernelName});
auto zero = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
index 520a3fa274318..20b7f9ad448c7 100644
--- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
@@ -106,7 +106,8 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
Operation *op) const {
using LLVM::LLVMFuncOp;
- Operation *funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcName);
+ auto funcAttr = StringAttr::get(op->getContext(), funcName);
+ Operation *funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcAttr);
if (funcOp)
return cast<LLVMFuncOp>(*funcOp);
diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
index 6a3231f6873b8..d648a293400d4 100644
--- a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
+++ b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
@@ -181,9 +181,8 @@ void ConvertGpuLaunchFuncToVulkanLaunchFunc::convertGpuLaunchFunc(
StringRef(binary.data(), binary.size())));
// Set entry point name as an attribute.
- vulkanLaunchCallOp->setAttr(
- kSPIRVEntryPointAttrName,
- StringAttr::get(loc->getContext(), launchOp.getKernelName()));
+ vulkanLaunchCallOp->setAttr(kSPIRVEntryPointAttrName,
+ launchOp.getKernelName());
launchOp.erase();
}
diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
index f47b8878a17c6..4a0dbf17ca2de 100644
--- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
+++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
@@ -52,9 +52,8 @@ static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op,
// fnName is a dynamic std::string, unique it via a SymbolRefAttr.
FlatSymbolRefAttr fnNameAttr = rewriter.getSymbolRefAttr(fnName);
auto module = op->getParentOfType<ModuleOp>();
- if (module.lookupSymbol(fnName)) {
+ if (module.lookupSymbol(fnNameAttr.getAttr()))
return fnNameAttr;
- }
SmallVector<Type, 4> inputTypes(extractOperandTypes(op));
assert(op->getNumResults() == 0 &&
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
index 354a0a62934b3..eef6f2c0b6f7b 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
@@ -127,8 +127,9 @@ static LogicalResult encodeKernelName(spirv::ModuleOp module) {
// {spv_module_name}_{function_name}
auto entryPoint = *module.getOps<spirv::EntryPointOp>().begin();
StringRef funcName = entryPoint.fn();
- auto funcOp = module.lookupSymbol<spirv::FuncOp>(funcName);
- std::string newFuncName = spvModuleName.str() + "_" + funcName.str();
+ auto funcOp = module.lookupSymbol<spirv::FuncOp>(entryPoint.fnAttr());
+ StringAttr newFuncName =
+ StringAttr::get(module->getContext(), spvModuleName + "_" + funcName);
if (failed(SymbolTable::replaceAllSymbolUses(funcOp, newFuncName, module)))
return failure();
SymbolTable::setSymbolName(funcOp, newFuncName);
@@ -166,9 +167,10 @@ class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
// is named:
// __spv__{kernel_module_name}
// based on GPU to SPIR-V conversion.
- StringRef kernelModuleName = launchOp.getKernelModuleName();
+ StringRef kernelModuleName = launchOp.getKernelModuleName().getValue();
std::string spvModuleName = kSPIRVModule + kernelModuleName.str();
- auto spvModule = module.lookupSymbol<spirv::ModuleOp>(spvModuleName);
+ auto spvModule = module.lookupSymbol<spirv::ModuleOp>(
+ StringAttr::get(context, spvModuleName));
if (!spvModule) {
return launchOp.emitOpError("SPIR-V kernel module '")
<< spvModuleName << "' is not found";
@@ -180,9 +182,10 @@ class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
// variables. The name of the kernel will be
// {spv_module_name}_{kernel_function_name}
// to avoid symbolic name conflicts.
- StringRef kernelFuncName = launchOp.getKernelName();
+ StringRef kernelFuncName = launchOp.getKernelName().getValue();
std::string newKernelFuncName = spvModuleName + "_" + kernelFuncName.str();
- auto kernelFunc = module.lookupSymbol<LLVM::LLVMFuncOp>(newKernelFuncName);
+ auto kernelFunc = module.lookupSymbol<LLVM::LLVMFuncOp>(
+ StringAttr::get(context, newKernelFuncName));
if (!kernelFunc) {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(module.getBody());
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index 108618602dde9..8d957e0df4338 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -1523,12 +1523,13 @@ void mlir::encodeBindAttribute(ModuleOp module) {
llvm::formatv("{0}_descriptor_set{1}_binding{2}", moduleAndName,
std::to_string(descriptorSet.getInt()),
std::to_string(binding.getInt()));
+ auto nameAttr = StringAttr::get(op->getContext(), name);
// Replace all symbol uses and set the new symbol name. Finally, remove
// descriptor set and binding attributes.
- if (failed(SymbolTable::replaceAllSymbolUses(op, name, spvModule)))
+ if (failed(SymbolTable::replaceAllSymbolUses(op, nameAttr, spvModule)))
op.emitError("unable to replace all symbol uses for ") << name;
- SymbolTable::setSymbolName(op, name);
+ SymbolTable::setSymbolName(op, nameAttr);
op->removeAttr(kDescriptorSet);
op->removeAttr(kBinding);
}
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index eb9145ee83839..b8a184fea3c3a 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -196,14 +196,15 @@ LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
return success();
// Check that `launch_func` refers to a well-formed GPU kernel module.
- StringRef kernelModuleName = launchOp.getKernelModuleName();
+ StringAttr kernelModuleName = launchOp.getKernelModuleName();
auto kernelModule = module.lookupSymbol<GPUModuleOp>(kernelModuleName);
if (!kernelModule)
return launchOp.emitOpError()
- << "kernel module '" << kernelModuleName << "' is undefined";
+ << "kernel module '" << kernelModuleName.getValue()
+ << "' is undefined";
// Check that `launch_func` refers to a well-formed kernel function.
- Operation *kernelFunc = module.lookupSymbol(launchOp.kernel());
+ Operation *kernelFunc = module.lookupSymbol(launchOp.kernelAttr());
auto kernelGPUFunction = dyn_cast_or_null<gpu::GPUFuncOp>(kernelFunc);
auto kernelLLVMFunction = dyn_cast_or_null<LLVM::LLVMFuncOp>(kernelFunc);
if (!kernelGPUFunction && !kernelLLVMFunction)
@@ -555,11 +556,11 @@ unsigned LaunchFuncOp::getNumKernelOperands() {
return getNumOperands() - asyncDependencies().size() - kNumConfigOperands;
}
-StringRef LaunchFuncOp::getKernelModuleName() {
+StringAttr LaunchFuncOp::getKernelModuleName() {
return kernel().getRootReference();
}
-StringRef LaunchFuncOp::getKernelName() { return kernel().getLeafReference(); }
+StringAttr LaunchFuncOp::getKernelName() { return kernel().getLeafReference(); }
Value LaunchFuncOp::getKernelOperand(unsigned i) {
return getOperand(asyncDependencies().size() + kNumConfigOperands + i);
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 7a8ff8fbdb317..6473b944b76d8 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -343,8 +343,8 @@ LogicalResult verifySymbolAttribute(
// a constraint in the operation definition.
for (SymbolRefAttr symbolRef :
attribute.cast<ArrayAttr>().getAsRange<SymbolRefAttr>()) {
- StringRef metadataName = symbolRef.getRootReference();
- StringRef symbolName = symbolRef.getLeafReference();
+ StringAttr metadataName = symbolRef.getRootReference();
+ StringAttr symbolName = symbolRef.getLeafReference();
// We want @metadata::@symbol, not just @symbol
if (metadataName == symbolName) {
return op->emitOpError() << "expected '" << symbolRef
@@ -770,7 +770,7 @@ static LogicalResult verify(CallOp &op) {
bool isIndirect = false;
// If this is an indirect call, the callee attribute is missing.
- Optional<StringRef> calleeName = op.callee();
+ FlatSymbolRefAttr calleeName = op.calleeAttr();
if (!calleeName) {
isIndirect = true;
if (!op.getNumOperands())
@@ -782,14 +782,15 @@ static LogicalResult verify(CallOp &op) {
<< ptrType;
fnType = ptrType.getElementType();
} else {
- Operation *callee = SymbolTable::lookupNearestSymbolFrom(op, *calleeName);
+ Operation *callee =
+ SymbolTable::lookupNearestSymbolFrom(op, calleeName.getAttr());
if (!callee)
return op.emitOpError()
- << "'" << *calleeName
+ << "'" << calleeName.getValue()
<< "' does not reference a symbol in the current scope";
auto fn = dyn_cast<LLVMFuncOp>(callee);
if (!fn)
- return op.emitOpError() << "'" << *calleeName
+ return op.emitOpError() << "'" << calleeName.getValue()
<< "' does not reference a valid LLVM function";
fnType = fn.getType();
@@ -2253,14 +2254,14 @@ LogicalResult LLVMDialect::verifyOperationAttribute(Operation *op,
if (!accessGroupRef)
return op->emitOpError()
<< "expected '" << attr << "' to be a symbol reference";
- StringRef metadataName = accessGroupRef.getRootReference();
+ StringAttr metadataName = accessGroupRef.getRootReference();
auto metadataOp =
SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
op->getParentOp(), metadataName);
if (!metadataOp)
return op->emitOpError()
<< "expected '" << attr << "' to reference a metadata op";
- StringRef accessGroupName = accessGroupRef.getLeafReference();
+ StringAttr accessGroupName = accessGroupRef.getLeafReference();
Operation *accessGroupOp =
SymbolTable::lookupNearestSymbolFrom(metadataOp, accessGroupName);
if (!accessGroupOp)
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index fc18fdd78b3cc..5f4d83286c712 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1066,7 +1066,7 @@ void spirv::AddressOfOp::build(OpBuilder &builder, OperationState &state,
static LogicalResult verify(spirv::AddressOfOp addressOfOp) {
auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>(
SymbolTable::lookupNearestSymbolFrom(addressOfOp->getParentOp(),
- addressOfOp.variable()));
+ addressOfOp.variableAttr()));
if (!varOp) {
return addressOfOp.emitOpError("expected spv.GlobalVariable symbol");
}
@@ -1953,14 +1953,14 @@ ArrayRef<Type> spirv::FuncOp::getCallableResults() {
//===----------------------------------------------------------------------===//
static LogicalResult verify(spirv::FunctionCallOp functionCallOp) {
- auto fnName = functionCallOp.callee();
+ auto fnName = functionCallOp.calleeAttr();
auto funcOp =
dyn_cast_or_null<spirv::FuncOp>(SymbolTable::lookupNearestSymbolFrom(
functionCallOp->getParentOp(), fnName));
if (!funcOp) {
return functionCallOp.emitOpError("callee function '")
- << fnName << "' not found in nearest symbol table";
+ << fnName.getValue() << "' not found in nearest symbol table";
}
auto functionType = funcOp.getType();
@@ -2115,7 +2115,7 @@ static LogicalResult verify(spirv::GlobalVariableOp varOp) {
if (auto init =
varOp->getAttrOfType<FlatSymbolRefAttr>(kInitializerAttrName)) {
Operation *initOp = SymbolTable::lookupNearestSymbolFrom(
- varOp->getParentOp(), init.getValue());
+ varOp->getParentOp(), init.getAttr());
// TODO: Currently only variable initialization with specialization
// constants and other variables is supported. They could be normal
// constants in the module scope as well.
@@ -2691,7 +2691,7 @@ static LogicalResult verify(spirv::ModuleOp moduleOp) {
static LogicalResult verify(spirv::ReferenceOfOp referenceOfOp) {
auto *specConstSym = SymbolTable::lookupNearestSymbolFrom(
- referenceOfOp->getParentOp(), referenceOfOp.spec_const());
+ referenceOfOp->getParentOp(), referenceOfOp.spec_constAttr());
Type constType;
auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym);
@@ -3516,17 +3516,17 @@ static LogicalResult verify(spirv::SpecConstantCompositeOp constOp) {
if (cType.isa<spirv::CooperativeMatrixNVType>())
return constOp.emitError("unsupported composite type ") << cType;
- else if (constituents.size() != cType.getNumElements())
+ if (constituents.size() != cType.getNumElements())
return constOp.emitError("has incorrect number of operands: expected ")
<< cType.getNumElements() << ", but provided "
<< constituents.size();
for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
- auto constituent = constituents[index].dyn_cast<FlatSymbolRefAttr>();
+ auto constituent = constituents[index].cast<FlatSymbolRefAttr>();
auto constituentSpecConstOp =
dyn_cast<spirv::SpecConstantOp>(SymbolTable::lookupNearestSymbolFrom(
- constOp->getParentOp(), constituent.getValue()));
+ constOp->getParentOp(), constituent.getAttr()));
if (constituentSpecConstOp.default_value().getType() !=
cType.getElementType(index))
diff --git a/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
index 1007603d3cf2f..2b4ac38d618f5 100644
--- a/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
+++ b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
@@ -30,21 +30,20 @@ static constexpr unsigned maxFreeID = 1 << 20;
/// Returns an unsed symbol in `module` for `oldSymbolName` by trying numeric
/// suffix in `lastUsedID`.
-static SmallString<64> renameSymbol(StringRef oldSymName, unsigned &lastUsedID,
- spirv::ModuleOp module) {
+static StringAttr renameSymbol(StringRef oldSymName, unsigned &lastUsedID,
+ spirv::ModuleOp module) {
SmallString<64> newSymName(oldSymName);
newSymName.push_back('_');
- while (lastUsedID < maxFreeID) {
- std::string possible = (newSymName + llvm::utostr(++lastUsedID)).str();
+ MLIRContext *ctx = module->getContext();
- if (!SymbolTable::lookupSymbolIn(module, possible)) {
- newSymName += llvm::utostr(lastUsedID);
- break;
- }
+ while (lastUsedID < maxFreeID) {
+ auto possible = StringAttr::get(ctx, newSymName + Twine(++lastUsedID));
+ if (!SymbolTable::lookupSymbolIn(module, possible))
+ return possible;
}
- return newSymName;
+ return StringAttr::get(ctx, newSymName);
}
/// Checks if a symbol with the same name as `op` already exists in `source`.
@@ -57,7 +56,7 @@ static LogicalResult updateSymbolAndAllUses(SymbolOpInterface op,
return success();
StringRef oldSymName = op.getName();
- SmallString<64> newSymName = renameSymbol(oldSymName, lastUsedID, target);
+ StringAttr newSymName = renameSymbol(oldSymName, lastUsedID, target);
if (failed(SymbolTable::replaceAllSymbolUses(op, newSymName, target)))
return op.emitError("unable to update all symbol uses for ")
@@ -234,7 +233,7 @@ OwningOpRef<spirv::ModuleOp> combine(ArrayRef<spirv::ModuleOp> inputModules,
SymbolOpInterface replacementSymOp = result.first->second;
if (failed(SymbolTable::replaceAllSymbolUses(
- symbolOp, replacementSymOp.getName(), combinedModule))) {
+ symbolOp, replacementSymOp.getNameAttr(), combinedModule))) {
symbolOp.emitError("unable to update all symbol uses for ")
<< symbolOp.getName() << " to " << replacementSymOp.getName();
return nullptr;
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp
index 372295a986afc..fad437ce1330f 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp
@@ -64,11 +64,11 @@ class SPIRVAddressOfOpLayoutInfoDecoration
LogicalResult matchAndRewrite(spirv::AddressOfOp op,
PatternRewriter &rewriter) const override {
auto spirvModule = op->getParentOfType<spirv::ModuleOp>();
- auto varName = op.variable();
+ auto varName = op.variableAttr();
auto varOp = spirvModule.lookupSymbol<spirv::GlobalVariableOp>(varName);
rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(
- op, varOp.type(), rewriter.getSymbolRefAttr(varName));
+ op, varOp.type(), rewriter.getSymbolRefAttr(varName.getAttr()));
return success();
}
};
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index e243e8e300d6a..14ff5679ff9a6 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -96,19 +96,21 @@ static Value getTensor(ConversionPatternRewriter &rewriter, unsigned width,
}
/// Returns function reference (first hit also inserts into module).
-static FlatSymbolRefAttr getFunc(Operation *op, StringRef name, Type result,
+static FlatSymbolRefAttr getFunc(Operation *op, StringRef name, Type resultType,
ValueRange operands) {
MLIRContext *context = op->getContext();
auto module = op->getParentOfType<ModuleOp>();
- auto func = module.lookupSymbol<FuncOp>(name);
+ auto result = SymbolRefAttr::get(context, name);
+ auto func = module.lookupSymbol<FuncOp>(result.getAttr());
if (!func) {
OpBuilder moduleBuilder(module.getBodyRegion());
moduleBuilder
- .create<FuncOp>(op->getLoc(), name,
- FunctionType::get(context, operands.getTypes(), result))
+ .create<FuncOp>(
+ op->getLoc(), name,
+ FunctionType::get(context, operands.getTypes(), resultType))
.setPrivate();
}
- return SymbolRefAttr::get(context, name);
+ return result;
}
/// Generates a call into the "swiss army knife" method of the sparse runtime
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index b689225925dbe..65cbc8a8c33eb 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1659,7 +1659,7 @@ void ModulePrinter::printAttribute(Attribute attr,
printType(typeAttr.getValue());
} else if (auto refAttr = attr.dyn_cast<SymbolRefAttr>()) {
- printSymbolReference(refAttr.getRootReference(), os);
+ printSymbolReference(refAttr.getRootReference().getValue(), os);
for (FlatSymbolRefAttr nestedRef : refAttr.getNestedReferences()) {
os << "::";
printSymbolReference(nestedRef.getValue(), os);
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index df04309707e63..0ced5e55a5183 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -216,13 +216,15 @@ FlatSymbolRefAttr Builder::getSymbolRefAttr(Operation *value) {
assert(symName && "value does not have a valid symbol name");
return getSymbolRefAttr(symName.getValue());
}
-FlatSymbolRefAttr Builder::getSymbolRefAttr(StringRef value) {
- return SymbolRefAttr::get(getContext(), value);
+
+FlatSymbolRefAttr Builder::getSymbolRefAttr(StringAttr value) {
+ return SymbolRefAttr::get(value);
}
+
SymbolRefAttr
-Builder::getSymbolRefAttr(StringRef value,
+Builder::getSymbolRefAttr(StringAttr value,
ArrayRef<FlatSymbolRefAttr> nestedReferences) {
- return SymbolRefAttr::get(getContext(), value, nestedReferences);
+ return SymbolRefAttr::get(value, nestedReferences);
}
ArrayAttr Builder::getBoolArrayAttr(ArrayRef<bool> values) {
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 5a754dfe7c3f5..e9e1ed8c25452 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -273,12 +273,16 @@ LogicalResult FloatAttr::verify(function_ref<InFlightDiagnostic()> emitError,
//===----------------------------------------------------------------------===//
FlatSymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value) {
- return get(ctx, value, llvm::None).cast<FlatSymbolRefAttr>();
+ return get(StringAttr::get(ctx, value));
}
-StringRef SymbolRefAttr::getLeafReference() const {
+FlatSymbolRefAttr SymbolRefAttr::get(StringAttr value) {
+ return get(value, {}).cast<FlatSymbolRefAttr>();
+}
+
+StringAttr SymbolRefAttr::getLeafReference() const {
ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences();
- return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getValue();
+ return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getAttr();
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
index 9f145a2db9ffb..ad4f08364e7bf 100644
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -22,17 +22,13 @@ static bool isPotentiallyUnknownSymbolTable(Operation *op) {
return op->getNumRegions() == 1 && !op->getDialect();
}
-/// Returns the string name of the given symbol, or None if this is not a
+/// Returns the string name of the given symbol, or null if this is not a
/// symbol.
-static Optional<StringRef> getNameIfSymbol(Operation *symbol) {
- auto nameAttr =
- symbol->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
- return nameAttr ? nameAttr.getValue() : Optional<StringRef>();
+static StringAttr getNameIfSymbol(Operation *op) {
+ return op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
}
-static Optional<StringRef> getNameIfSymbol(Operation *symbol,
- Identifier symbolAttrNameId) {
- auto nameAttr = symbol->getAttrOfType<StringAttr>(symbolAttrNameId);
- return nameAttr ? nameAttr.getValue() : Optional<StringRef>();
+static StringAttr getNameIfSymbol(Operation *op, Identifier symbolAttrNameId) {
+ return op->getAttrOfType<StringAttr>(symbolAttrNameId);
}
/// Computes the nested symbol reference attribute for the symbol 'symbolName'
@@ -40,13 +36,13 @@ static Optional<StringRef> getNameIfSymbol(Operation *symbol,
/// to the given operation 'within', where 'within' is an ancestor of 'symbol'.
/// Returns success if all references up to 'within' could be computed.
static LogicalResult
-collectValidReferencesFor(Operation *symbol, StringRef symbolName,
+collectValidReferencesFor(Operation *symbol, StringAttr symbolName,
Operation *within,
SmallVectorImpl<SymbolRefAttr> &results) {
assert(within->isAncestor(symbol) && "expected 'within' to be an ancestor");
MLIRContext *ctx = symbol->getContext();
- auto leafRef = FlatSymbolRefAttr::get(ctx, symbolName);
+ auto leafRef = FlatSymbolRefAttr::get(symbolName);
results.push_back(leafRef);
// Early exit for when 'within' is the parent of 'symbol'.
@@ -63,17 +59,16 @@ collectValidReferencesFor(Operation *symbol, StringRef symbolName,
if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>())
return failure();
// Each parent of 'symbol' should also be a symbol.
- Optional<StringRef> symbolTableName =
- getNameIfSymbol(symbolTableOp, symbolNameId);
+ StringAttr symbolTableName = getNameIfSymbol(symbolTableOp, symbolNameId);
if (!symbolTableName)
return failure();
- results.push_back(SymbolRefAttr::get(ctx, *symbolTableName, nestedRefs));
+ results.push_back(SymbolRefAttr::get(symbolTableName, nestedRefs));
symbolTableOp = symbolTableOp->getParentOp();
if (symbolTableOp == within)
break;
nestedRefs.insert(nestedRefs.begin(),
- FlatSymbolRefAttr::get(ctx, *symbolTableName));
+ FlatSymbolRefAttr::get(symbolTableName));
} while (true);
return success();
}
@@ -119,11 +114,11 @@ SymbolTable::SymbolTable(Operation *symbolTableOp)
Identifier symbolNameId = Identifier::get(SymbolTable::getSymbolAttrName(),
symbolTableOp->getContext());
for (auto &op : symbolTableOp->getRegion(0).front()) {
- Optional<StringRef> name = getNameIfSymbol(&op, symbolNameId);
+ StringAttr name = getNameIfSymbol(&op, symbolNameId);
if (!name)
continue;
- auto inserted = symbolTable.insert({*name, &op});
+ auto inserted = symbolTable.insert({name, &op});
(void)inserted;
assert(inserted.second &&
"expected region to contain uniquely named symbol operations");
@@ -133,18 +128,21 @@ SymbolTable::SymbolTable(Operation *symbolTableOp)
/// Look up a symbol with the specified name, returning null if no such name
/// exists. Names never include the @ on them.
Operation *SymbolTable::lookup(StringRef name) const {
+ return lookup(StringAttr::get(symbolTableOp->getContext(), name));
+}
+Operation *SymbolTable::lookup(StringAttr name) const {
return symbolTable.lookup(name);
}
/// Erase the given symbol from the table.
void SymbolTable::erase(Operation *symbol) {
- Optional<StringRef> name = getNameIfSymbol(symbol);
+ StringAttr name = getNameIfSymbol(symbol);
assert(name && "expected valid 'name' attribute");
assert(symbol->getParentOp() == symbolTableOp &&
"expected this operation to be inside of the operation with this "
"SymbolTable");
- auto it = symbolTable.find(*name);
+ auto it = symbolTable.find(name);
if (it != symbolTable.end() && it->second == symbol) {
symbolTable.erase(it);
symbol->erase();
@@ -180,7 +178,7 @@ void SymbolTable::insert(Operation *symbol, Block::iterator insertPt) {
// Add this symbol to the symbol table, uniquing the name if a conflict is
// detected.
- StringRef name = getSymbolName(symbol);
+ StringAttr name = getSymbolName(symbol);
if (symbolTable.insert({name, symbol}).second)
return;
// If the symbol was already in the table, also return.
@@ -188,28 +186,31 @@ void SymbolTable::insert(Operation *symbol, Block::iterator insertPt) {
return;
// If a conflict was detected, then the symbol will not have been added to
// the symbol table. Try suffixes until we get to a unique name that works.
- SmallString<128> nameBuffer(name);
+ SmallString<128> nameBuffer(name.getValue());
unsigned originalLength = nameBuffer.size();
+ MLIRContext *context = symbol->getContext();
+
// Iteratively try suffixes until we find one that isn't used.
do {
nameBuffer.resize(originalLength);
nameBuffer += '_';
nameBuffer += std::to_string(uniquingCounter++);
- } while (!symbolTable.insert({nameBuffer, symbol}).second);
+ } while (!symbolTable.insert({StringAttr::get(context, nameBuffer), symbol})
+ .second);
setSymbolName(symbol, nameBuffer);
}
/// Returns the name of the given symbol operation.
-StringRef SymbolTable::getSymbolName(Operation *symbol) {
- Optional<StringRef> name = getNameIfSymbol(symbol);
+StringAttr SymbolTable::getSymbolName(Operation *symbol) {
+ StringAttr name = getNameIfSymbol(symbol);
assert(name && "expected valid symbol name");
- return *name;
+ return name;
}
+
/// Sets the name of the given symbol operation.
-void SymbolTable::setSymbolName(Operation *symbol, StringRef name) {
- symbol->setAttr(getSymbolAttrName(),
- StringAttr::get(symbol->getContext(), name));
+void SymbolTable::setSymbolName(Operation *symbol, StringAttr name) {
+ symbol->setAttr(getSymbolAttrName(), name);
}
/// Returns the visibility of the given symbol operation.
@@ -295,7 +296,7 @@ void SymbolTable::walkSymbolTables(
/// with the 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol
/// was found.
Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp,
- StringRef symbol) {
+ StringAttr symbol) {
assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
Region &region = symbolTableOp->getRegion(0);
if (region.empty())
@@ -322,7 +323,7 @@ Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp,
static LogicalResult lookupSymbolInImpl(
Operation *symbolTableOp, SymbolRefAttr symbol,
SmallVectorImpl<Operation *> &symbols,
- function_ref<Operation *(Operation *, StringRef)> lookupSymbolFn) {
+ function_ref<Operation *(Operation *, StringAttr)> lookupSymbolFn) {
assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
// Lookup the root reference for this symbol.
@@ -343,7 +344,7 @@ static LogicalResult lookupSymbolInImpl(
// Otherwise, lookup each of the nested non-leaf references and ensure that
// each corresponds to a valid symbol table.
for (FlatSymbolRefAttr ref : nestedRefs.drop_back()) {
- symbolTableOp = lookupSymbolFn(symbolTableOp, ref.getValue());
+ symbolTableOp = lookupSymbolFn(symbolTableOp, ref.getAttr());
if (!symbolTableOp || !symbolTableOp->hasTrait<OpTrait::SymbolTable>())
return failure();
symbols.push_back(symbolTableOp);
@@ -355,7 +356,7 @@ static LogicalResult lookupSymbolInImpl(
LogicalResult
SymbolTable::lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr symbol,
SmallVectorImpl<Operation *> &symbols) {
- auto lookupFn = [](Operation *symbolTableOp, StringRef symbol) {
+ auto lookupFn = [](Operation *symbolTableOp, StringAttr symbol) {
return lookupSymbolIn(symbolTableOp, symbol);
};
return lookupSymbolInImpl(symbolTableOp, symbol, symbols, lookupFn);
@@ -365,7 +366,7 @@ SymbolTable::lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr symbol,
/// closes parent operation with the 'OpTrait::SymbolTable' trait. Returns
/// nullptr if no valid symbol was found.
Operation *SymbolTable::lookupNearestSymbolFrom(Operation *from,
- StringRef symbol) {
+ StringAttr symbol) {
Operation *symbolTableOp = getNearestSymbolTable(from);
return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
}
@@ -610,7 +611,7 @@ struct SymbolScope {
/// Collect all of the symbol scopes from 'symbol' to (inclusive) 'limit'.
static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol,
Operation *limit) {
- StringRef symName = SymbolTable::getSymbolName(symbol);
+ StringAttr symName = SymbolTable::getSymbolName(symbol);
assert(!symbol->hasTrait<OpTrait::SymbolTable>() || symbol != limit);
// Compute the ancestors of 'limit'.
@@ -625,7 +626,7 @@ static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol,
// doesn't support parent references.
if (SymbolTable::getNearestSymbolTable(limit->getParentOp()) ==
symbol->getParentOp())
- return {{SymbolRefAttr::get(symbol->getContext(), symName), limit}};
+ return {{SymbolRefAttr::get(symName), limit}};
return {};
}
@@ -679,9 +680,9 @@ static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol,
return scopes;
}
template <typename IRUnit>
-static SmallVector<SymbolScope, 1> collectSymbolScopes(StringRef symbol,
+static SmallVector<SymbolScope, 1> collectSymbolScopes(StringAttr symbol,
IRUnit *limit) {
- return {{SymbolRefAttr::get(limit->getContext(), symbol), limit}};
+ return {{SymbolRefAttr::get(symbol), limit}};
}
/// Returns true if the given reference 'SubRef' is a sub reference of the
@@ -753,7 +754,7 @@ static Optional<SymbolTable::UseRange> getSymbolUsesImpl(SymbolT symbol,
/// operation 'from', invoking the provided callback for each. This does not
/// traverse into any nested symbol tables. This function returns None if there
/// are any unknown operations that may potentially be symbol tables.
-auto SymbolTable::getSymbolUses(StringRef symbol, Operation *from)
+auto SymbolTable::getSymbolUses(StringAttr symbol, Operation *from)
-> Optional<UseRange> {
return getSymbolUsesImpl(symbol, from);
}
@@ -761,7 +762,7 @@ auto SymbolTable::getSymbolUses(Operation *symbol, Operation *from)
-> Optional<UseRange> {
return getSymbolUsesImpl(symbol, from);
}
-auto SymbolTable::getSymbolUses(StringRef symbol, Region *from)
+auto SymbolTable::getSymbolUses(StringAttr symbol, Region *from)
-> Optional<UseRange> {
return getSymbolUsesImpl(symbol, from);
}
@@ -792,13 +793,13 @@ static bool symbolKnownUseEmptyImpl(SymbolT symbol, IRUnitT *limit) {
/// the given operation 'from'. This does not traverse into any nested symbol
/// tables. This function will also return false if there are any unknown
/// operations that may potentially be symbol tables.
-bool SymbolTable::symbolKnownUseEmpty(StringRef symbol, Operation *from) {
+bool SymbolTable::symbolKnownUseEmpty(StringAttr symbol, Operation *from) {
return symbolKnownUseEmptyImpl(symbol, from);
}
bool SymbolTable::symbolKnownUseEmpty(Operation *symbol, Operation *from) {
return symbolKnownUseEmptyImpl(symbol, from);
}
-bool SymbolTable::symbolKnownUseEmpty(StringRef symbol, Region *from) {
+bool SymbolTable::symbolKnownUseEmpty(StringAttr symbol, Region *from) {
return symbolKnownUseEmptyImpl(symbol, from);
}
bool SymbolTable::symbolKnownUseEmpty(Operation *symbol, Region *from) {
@@ -861,14 +862,13 @@ static SymbolRefAttr generateNewRefAttr(SymbolRefAttr oldAttr,
return newLeafAttr;
auto nestedRefs = llvm::to_vector<2>(oldAttr.getNestedReferences());
nestedRefs.back() = newLeafAttr;
- return SymbolRefAttr::get(oldAttr.getContext(), oldAttr.getRootReference(),
- nestedRefs);
+ return SymbolRefAttr::get(oldAttr.getRootReference(), nestedRefs);
}
/// The implementation of SymbolTable::replaceAllSymbolUses below.
template <typename SymbolT, typename IRUnitT>
static LogicalResult
-replaceAllSymbolUsesImpl(SymbolT symbol, StringRef newSymbol, IRUnitT *limit) {
+replaceAllSymbolUsesImpl(SymbolT symbol, StringAttr newSymbol, IRUnitT *limit) {
// A collection of operations along with their new attribute dictionary.
std::vector<std::pair<Operation *, DictionaryAttr>> updatedAttrDicts;
@@ -888,8 +888,7 @@ replaceAllSymbolUsesImpl(SymbolT symbol, StringRef newSymbol, IRUnitT *limit) {
};
// Generate a new attribute to replace the given attribute.
- MLIRContext *ctx = limit->getContext();
- FlatSymbolRefAttr newLeafAttr = FlatSymbolRefAttr::get(ctx, newSymbol);
+ FlatSymbolRefAttr newLeafAttr = FlatSymbolRefAttr::get(newSymbol);
for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
SymbolRefAttr newAttr = generateNewRefAttr(scope.symbol, newLeafAttr);
auto walkFn = [&](SymbolTable::SymbolUse symbolUse,
@@ -905,13 +904,13 @@ replaceAllSymbolUsesImpl(SymbolT symbol, StringRef newSymbol, IRUnitT *limit) {
if (useRef != scope.symbol) {
if (scope.symbol.isa<FlatSymbolRefAttr>()) {
replacementRef =
- SymbolRefAttr::get(ctx, newSymbol, useRef.getNestedReferences());
+ SymbolRefAttr::get(newSymbol, useRef.getNestedReferences());
} else {
auto nestedRefs = llvm::to_vector<4>(useRef.getNestedReferences());
nestedRefs[scope.symbol.getNestedReferences().size() - 1] =
newLeafAttr;
replacementRef =
- SymbolRefAttr::get(ctx, useRef.getRootReference(), nestedRefs);
+ SymbolRefAttr::get(useRef.getRootReference(), nestedRefs);
}
}
@@ -949,23 +948,23 @@ replaceAllSymbolUsesImpl(SymbolT symbol, StringRef newSymbol, IRUnitT *limit) {
/// 'from'. This does not traverse into any nested symbol tables. If there are
/// any unknown operations that may potentially be symbol tables, no uses are
/// replaced and failure is returned.
-LogicalResult SymbolTable::replaceAllSymbolUses(StringRef oldSymbol,
- StringRef newSymbol,
+LogicalResult SymbolTable::replaceAllSymbolUses(StringAttr oldSymbol,
+ StringAttr newSymbol,
Operation *from) {
return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
}
LogicalResult SymbolTable::replaceAllSymbolUses(Operation *oldSymbol,
- StringRef newSymbol,
+ StringAttr newSymbol,
Operation *from) {
return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
}
-LogicalResult SymbolTable::replaceAllSymbolUses(StringRef oldSymbol,
- StringRef newSymbol,
+LogicalResult SymbolTable::replaceAllSymbolUses(StringAttr oldSymbol,
+ StringAttr newSymbol,
Region *from) {
return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
}
LogicalResult SymbolTable::replaceAllSymbolUses(Operation *oldSymbol,
- StringRef newSymbol,
+ StringAttr newSymbol,
Region *from) {
return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
}
@@ -975,7 +974,7 @@ LogicalResult SymbolTable::replaceAllSymbolUses(Operation *oldSymbol,
//===----------------------------------------------------------------------===//
Operation *SymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
- StringRef symbol) {
+ StringAttr symbol) {
return getSymbolTable(symbolTableOp).lookup(symbol);
}
Operation *SymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
@@ -992,7 +991,7 @@ LogicalResult
SymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
SymbolRefAttr name,
SmallVectorImpl<Operation *> &symbols) {
- auto lookupFn = [this](Operation *symbolTableOp, StringRef symbol) {
+ auto lookupFn = [this](Operation *symbolTableOp, StringAttr symbol) {
return lookupSymbolIn(symbolTableOp, symbol);
};
return lookupSymbolInImpl(symbolTableOp, name, symbols, lookupFn);
@@ -1003,7 +1002,7 @@ SymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
/// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was
/// found.
Operation *SymbolTableCollection::lookupNearestSymbolFrom(Operation *from,
- StringRef symbol) {
+ StringAttr symbol) {
Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(from);
return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
}
@@ -1052,7 +1051,7 @@ SymbolUserMap::SymbolUserMap(SymbolTableCollection &symbolTable,
}
void SymbolUserMap::replaceAllUsesWith(Operation *symbol,
- StringRef newSymbolName) {
+ StringAttr newSymbolName) {
auto it = symbolToUsers.find(symbol);
if (it == symbolToUsers.end())
return;
diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp
index a81387f3f58e2..4370b23bb7b2b 100644
--- a/mlir/lib/Rewrite/ByteCode.cpp
+++ b/mlir/lib/Rewrite/ByteCode.cpp
@@ -818,7 +818,7 @@ void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
ByteCodeField patternIndex = patterns.size();
patterns.emplace_back(PDLByteCodePattern::create(
- op, rewriterToAddr[op.rewriter().getLeafReference()]));
+ op, rewriterToAddr[op.rewriter().getLeafReference().getValue()]));
writer.append(OpCode::RecordMatch, patternIndex,
SuccessorRange(op.getOperation()), op.matchedOps());
writer.appendPDLValueList(op.inputs());
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 0c3839ae9b8c2..51133f63c722f 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -814,8 +814,8 @@ LogicalResult ModuleTranslation::createAliasScopeMetadata() {
llvm::MDNode *
ModuleTranslation::getAliasScope(Operation &opInst,
SymbolRefAttr aliasScopeRef) const {
- StringRef metadataName = aliasScopeRef.getRootReference();
- StringRef scopeName = aliasScopeRef.getLeafReference();
+ StringAttr metadataName = aliasScopeRef.getRootReference();
+ StringAttr scopeName = aliasScopeRef.getLeafReference();
auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
opInst.getParentOp(), metadataName);
Operation *aliasScopeOp =
diff --git a/mlir/test/lib/IR/TestSymbolUses.cpp b/mlir/test/lib/IR/TestSymbolUses.cpp
index 4d40155b8b655..22ff2adad32b7 100644
--- a/mlir/test/lib/IR/TestSymbolUses.cpp
+++ b/mlir/test/lib/IR/TestSymbolUses.cpp
@@ -84,7 +84,7 @@ struct SymbolUsesPass
table.erase(op);
assert(!table.lookup(name) &&
"expected erased operation to be unknown now");
- module.emitRemark() << name << " function successfully erased";
+ module.emitRemark() << name.getValue() << " function successfully erased";
}
}
};
@@ -110,8 +110,8 @@ struct SymbolReplacementPass
StringAttr newName = nestedOp->getAttrOfType<StringAttr>("sym.new_name");
if (!newName)
return;
- symbolUsers.replaceAllUsesWith(nestedOp, newName.getValue());
- SymbolTable::setSymbolName(nestedOp, newName.getValue());
+ symbolUsers.replaceAllUsesWith(nestedOp, newName);
+ SymbolTable::setSymbolName(nestedOp, newName);
});
}
};
diff --git a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
index 9441009ce3b83..159c623bf4eb3 100644
--- a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
+++ b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
@@ -80,8 +80,10 @@ struct TestPDLByteCodePass
// The test cases are encompassed via two modules, one containing the
// patterns and one containing the operations to rewrite.
- ModuleOp patternModule = module.lookupSymbol<ModuleOp>("patterns");
- ModuleOp irModule = module.lookupSymbol<ModuleOp>("ir");
+ ModuleOp patternModule = module.lookupSymbol<ModuleOp>(
+ StringAttr::get(module->getContext(), "patterns"));
+ ModuleOp irModule = module.lookupSymbol<ModuleOp>(
+ StringAttr::get(module->getContext(), "ir"));
if (!patternModule || !irModule)
return;
More information about the Mlir-commits
mailing list