[flang-commits] [flang] 5eae715 - [mlir] Add NamedAttrList
Jacques Pienaar via flang-commits
flang-commits at lists.llvm.org
Thu May 7 12:38:40 PDT 2020
Author: Jacques Pienaar
Date: 2020-05-07T12:33:36-07:00
New Revision: 5eae715a3115be2640d0fd37d0bd4771abf2ab9b
URL: https://github.com/llvm/llvm-project/commit/5eae715a3115be2640d0fd37d0bd4771abf2ab9b
DIFF: https://github.com/llvm/llvm-project/commit/5eae715a3115be2640d0fd37d0bd4771abf2ab9b.diff
LOG: [mlir] Add NamedAttrList
This is a wrapper around a vector of NamedAttributes that keeps track of whether it is sorted and makes some minimal effort to remain sorted (doing more, e.g., appending attributes in sorted order, could be done in a follow-up). It records whether the list is sorted and, if a DictionaryAttr is queried, caches the returned DictionaryAttr along with that sorted flag.
Change MutableDictionaryAttr to always return a non-null Attribute, even when empty (null is reserved for error cases). To this end, change the getter to take a context as input so that the empty DictionaryAttr can be queried. Also create a single instance of the empty dictionary attribute that can be reused without needing to lock the context, etc.
Update the InferTypeOpInterface to take a DictionaryAttr, and use NamedAttrList to avoid incurring repeated conversion costs.
Fix a bug in the sorting helper function.
Differential Revision: https://reviews.llvm.org/D79463
Added:
Modified:
flang/include/flang/Optimizer/Dialect/FIROps.td
flang/lib/Optimizer/Dialect/FIROps.cpp
mlir/include/mlir/IR/Attributes.h
mlir/include/mlir/IR/FunctionImplementation.h
mlir/include/mlir/IR/FunctionSupport.h
mlir/include/mlir/IR/OpImplementation.h
mlir/include/mlir/IR/Operation.h
mlir/include/mlir/IR/OperationSupport.h
mlir/include/mlir/Interfaces/InferTypeOpInterface.h
mlir/include/mlir/Interfaces/InferTypeOpInterface.td
mlir/lib/Analysis/CallGraph.cpp
mlir/lib/Dialect/Affine/IR/AffineOps.cpp
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/lib/IR/Attributes.cpp
mlir/lib/IR/Function.cpp
mlir/lib/IR/FunctionImplementation.cpp
mlir/lib/IR/MLIRContext.cpp
mlir/lib/IR/Operation.cpp
mlir/lib/IR/OperationSupport.cpp
mlir/lib/IR/SymbolTable.cpp
mlir/lib/Interfaces/InferTypeOpInterface.cpp
mlir/lib/Parser/Parser.cpp
mlir/lib/Pass/IRPrinting.cpp
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestPatterns.cpp
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
mlir/tools/mlir-tblgen/OpFormatGen.cpp
Removed:
################################################################################
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 383256c3916f..b2486c954fc5 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -654,7 +654,7 @@ class fir_IntegralSwitchTerminatorOp<string mnemonic,
mlir::Attribute ivalue; // Integer or Unit
mlir::Block *dest;
llvm::SmallVector<mlir::Value, 8> destArg;
- llvm::SmallVector<mlir::NamedAttribute, 1> temp;
+ mlir::NamedAttrList temp;
if (parser.parseAttribute(ivalue, "i", temp) ||
parser.parseComma() ||
parser.parseSuccessorAndUseList(dest, destArg))
@@ -2215,7 +2215,7 @@ def fir_StringLitOp : fir_Op<"string_lit", [NoSideEffect]> {
let parser = [{
auto &builder = parser.getBuilder();
mlir::Attribute val;
- llvm::SmallVector<mlir::NamedAttribute, 2> attrs;
+ mlir::NamedAttrList attrs;
if (parser.parseAttribute(val, "fake", attrs))
return mlir::failure();
if (auto v = val.dyn_cast<mlir::StringAttr>())
@@ -2858,8 +2858,8 @@ def fir_DispatchTableOp : fir_Op<"dispatch_table",
return failure();
// Convert the parsed name attr into a string attr.
- result.attributes.back().second =
- parser.getBuilder().getStringAttr(nameAttr.getRootReference());
+ result.attributes.set(mlir::SymbolTable::getSymbolAttrName(),
+ parser.getBuilder().getStringAttr(nameAttr.getRootReference()));
// Parse the optional table body.
mlir::Region *body = result.addRegion();
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 2e67c86486fc..30cd365f139b 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -181,7 +181,7 @@ static mlir::ParseResult parseCallOp(mlir::OpAsmParser &parser,
if (parser.parseOperandList(operands))
return mlir::failure();
- llvm::SmallVector<mlir::NamedAttribute, 4> attrs;
+ mlir::NamedAttrList attrs;
mlir::SymbolRefAttr funcAttr;
bool isDirect = operands.empty();
if (isDirect)
@@ -259,7 +259,7 @@ template <typename OPTY>
static mlir::ParseResult parseCmpOp(mlir::OpAsmParser &parser,
mlir::OperationState &result) {
llvm::SmallVector<mlir::OpAsmParser::OperandType, 2> ops;
- llvm::SmallVector<mlir::NamedAttribute, 4> attrs;
+ mlir::NamedAttrList attrs;
mlir::Attribute predicateNameAttr;
mlir::Type type;
if (parser.parseAttribute(predicateNameAttr, OPTY::getPredicateAttrName(),
@@ -279,7 +279,8 @@ static mlir::ParseResult parseCmpOp(mlir::OpAsmParser &parser,
auto predicate = fir::CmpfOp::getPredicateByName(predicateName);
auto builder = parser.getBuilder();
mlir::Type i1Type = builder.getI1Type();
- attrs[0].second = builder.getI64IntegerAttr(static_cast<int64_t>(predicate));
+ attrs.set(OPTY::getPredicateAttrName(),
+ builder.getI64IntegerAttr(static_cast<int64_t>(predicate)));
result.attributes = attrs;
result.addTypes({i1Type});
return success();
@@ -1102,7 +1103,7 @@ static mlir::ParseResult parseSelectCase(mlir::OpAsmParser &parser,
mlir::Attribute attr;
mlir::Block *dest;
llvm::SmallVector<mlir::Value, 8> destArg;
- llvm::SmallVector<mlir::NamedAttribute, 1> temp;
+ mlir::NamedAttrList temp;
if (parser.parseAttribute(attr, "a", temp) || isValidCaseAttr(attr) ||
parser.parseComma())
return mlir::failure();
@@ -1323,7 +1324,7 @@ static ParseResult parseSelectType(OpAsmParser &parser,
mlir::Attribute attr;
mlir::Block *dest;
llvm::SmallVector<mlir::Value, 8> destArg;
- llvm::SmallVector<mlir::NamedAttribute, 1> temp;
+ mlir::NamedAttrList temp;
if (parser.parseAttribute(attr, "a", temp) || parser.parseComma() ||
parser.parseSuccessorAndUseList(dest, destArg))
return mlir::failure();
diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index e0b4e5f43737..cb9fc1687b33 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -269,6 +269,9 @@ class BoolAttr : public Attribute::AttrBase<BoolAttr, Attribute,
/// be non-null.
using NamedAttribute = std::pair<Identifier, Attribute>;
+bool operator<(const NamedAttribute &lhs, const NamedAttribute &rhs);
+bool operator<(const NamedAttribute &lhs, StringRef rhs);
+
/// Dictionary attribute is an attribute that represents a sorted collection of
/// named attribute values. The elements are sorted by name, and each name must
/// be unique within the collection.
@@ -309,14 +312,25 @@ class DictionaryAttr
size_t size() const;
/// Sorts the NamedAttributes in the array ordered by name as expected by
- /// getWithSorted.
+ /// getWithSorted and returns whether the values were sorted.
/// Requires: uniquely named attributes.
- static void sort(SmallVectorImpl<NamedAttribute> &array);
+ static bool sort(ArrayRef<NamedAttribute> values,
+ SmallVectorImpl<NamedAttribute> &storage);
+
+ /// Sorts the NamedAttributes in the array ordered by name as expected by
+ /// getWithSorted in place on an array and returns whether the values needed
+ /// to be sorted.
+ /// Requires: uniquely named attributes.
+ static bool sortInPlace(SmallVectorImpl<NamedAttribute> &array);
/// Methods for supporting type inquiry through isa, cast, and dyn_cast.
static bool kindof(unsigned kind) {
return kind == StandardAttributes::Dictionary;
}
+
+private:
+ /// Return empty dictionary.
+ static DictionaryAttr getEmpty(MLIRContext *context);
};
//===----------------------------------------------------------------------===//
@@ -1652,9 +1666,8 @@ class MutableDictionaryAttr {
return attrs == other.attrs;
}
- /// Return the underlying dictionary attribute. This may be null, if this list
- /// has no attributes.
- DictionaryAttr getDictionary() const { return attrs; }
+ /// Return the underlying dictionary attribute.
+ DictionaryAttr getDictionary(MLIRContext *context) const;
/// Return all of the attributes on this operation.
ArrayRef<NamedAttribute> getAttrs() const;
@@ -1680,10 +1693,20 @@ class MutableDictionaryAttr {
/// value indicates whether the attribute was present or not.
RemoveResult remove(Identifier name);
+ bool empty() const { return attrs == nullptr; }
+
private:
+ friend ::llvm::hash_code hash_value(const MutableDictionaryAttr &arg);
+
DictionaryAttr attrs;
};
+inline ::llvm::hash_code hash_value(const MutableDictionaryAttr &arg) {
+ if (!arg.attrs)
+ return ::llvm::hash_value((void *)nullptr);
+ return hash_value(arg.attrs);
+}
+
} // end namespace mlir.
namespace llvm {
diff --git a/mlir/include/mlir/IR/FunctionImplementation.h b/mlir/include/mlir/IR/FunctionImplementation.h
index 37117b28682e..958cba51f6dc 100644
--- a/mlir/include/mlir/IR/FunctionImplementation.h
+++ b/mlir/include/mlir/IR/FunctionImplementation.h
@@ -38,8 +38,8 @@ class VariadicFlag {
/// Internally, argument and result attributes are stored as dict attributes
/// with special names given by getResultAttrName, getArgumentAttrName.
void addArgAndResultAttrs(Builder &builder, OperationState &result,
- ArrayRef<SmallVector<NamedAttribute, 2>> argAttrs,
- ArrayRef<SmallVector<NamedAttribute, 2>> resultAttrs);
+ ArrayRef<NamedAttrList> argAttrs,
+ ArrayRef<NamedAttrList> resultAttrs);
/// Callback type for `parseFunctionLikeOp`, the callback should produce the
/// type that will be associated with a function-like operation from lists of
@@ -53,13 +53,13 @@ using FuncTypeBuilder = function_ref<Type(
/// indicates whether functions with variadic arguments are supported. The
/// trailing arguments are populated by this function with names, types and
/// attributes of the arguments and those of the results.
-ParseResult parseFunctionSignature(
- OpAsmParser &parser, bool allowVariadic,
- SmallVectorImpl<OpAsmParser::OperandType> &argNames,
- SmallVectorImpl<Type> &argTypes,
- SmallVectorImpl<SmallVector<NamedAttribute, 2>> &argAttrs, bool &isVariadic,
- SmallVectorImpl<Type> &resultTypes,
- SmallVectorImpl<SmallVector<NamedAttribute, 2>> &resultAttrs);
+ParseResult
+parseFunctionSignature(OpAsmParser &parser, bool allowVariadic,
+ SmallVectorImpl<OpAsmParser::OperandType> &argNames,
+ SmallVectorImpl<Type> &argTypes,
+ SmallVectorImpl<NamedAttrList> &argAttrs,
+ bool &isVariadic, SmallVectorImpl<Type> &resultTypes,
+ SmallVectorImpl<NamedAttrList> &resultAttrs);
/// Parser implementation for function-like operations. Uses
/// `funcTypeBuilder` to construct the custom function type given lists of
diff --git a/mlir/include/mlir/IR/FunctionSupport.h b/mlir/include/mlir/IR/FunctionSupport.h
index 53c513023e29..87e4b6780164 100644
--- a/mlir/include/mlir/IR/FunctionSupport.h
+++ b/mlir/include/mlir/IR/FunctionSupport.h
@@ -517,10 +517,14 @@ void FunctionLike<ConcreteType>::setArgAttrs(unsigned index,
MutableDictionaryAttr attributes) {
assert(index < getNumArguments() && "invalid argument number");
SmallString<8> nameOut;
- if (auto newAttr = attributes.getDictionary())
+ if (attributes.getAttrs().empty()) {
+ this->getOperation()->removeAttr(getArgAttrName(index, nameOut));
+ } else {
+ auto newAttr = attributes.getDictionary(
+ attributes.getAttrs().front().second.getContext());
return this->getOperation()->setAttr(getArgAttrName(index, nameOut),
newAttr);
- static_cast<ConcreteType *>(this)->removeAttr(getArgAttrName(index, nameOut));
+ }
}
/// If the an attribute exists with the specified name, change it to the new
@@ -533,7 +537,7 @@ void FunctionLike<ConcreteType>::setArgAttr(unsigned index, Identifier name,
attrDict.set(name, value);
// If the attribute changed, then set the new arg attribute list.
- if (curAttr != attrDict.getDictionary())
+ if (curAttr != attrDict.getDictionary(value.getContext()))
setArgAttrs(index, attrDict);
}
@@ -564,7 +568,7 @@ void FunctionLike<ConcreteType>::setResultAttrs(
getResultAttrName(index, nameOut);
if (attributes.empty())
- return (void)static_cast<ConcreteType *>(this)->removeAttr(nameOut);
+ return (void)this->getOperation()->removeAttr(nameOut);
Operation *op = this->getOperation();
op->setAttr(nameOut, DictionaryAttr::get(attributes, op->getContext()));
}
@@ -574,11 +578,13 @@ void FunctionLike<ConcreteType>::setResultAttrs(
unsigned index, MutableDictionaryAttr attributes) {
assert(index < getNumResults() && "invalid result number");
SmallString<8> nameOut;
- if (auto newAttr = attributes.getDictionary())
+ if (attributes.empty()) {
+ this->getOperation()->removeAttr(getResultAttrName(index, nameOut));
+ } else {
+ auto newAttr = attributes.getDictionary(this->getOperation()->getContext());
return this->getOperation()->setAttr(getResultAttrName(index, nameOut),
newAttr);
- static_cast<ConcreteType *>(this)->removeAttr(
- getResultAttrName(index, nameOut));
+ }
}
/// If the an attribute exists with the specified name, change it to the new
@@ -591,7 +597,7 @@ void FunctionLike<ConcreteType>::setResultAttr(unsigned index, Identifier name,
attrDict.set(name, value);
// If the attribute changed, then set the new arg attribute list.
- if (curAttr != attrDict.getDictionary())
+ if (curAttr != attrDict.getDictionary(value.getContext()))
setResultAttrs(index, attrDict);
}
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 0588e693dd07..126d20eacbe4 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -366,28 +366,28 @@ class OpAsmParser {
/// Parse an arbitrary attribute and return it in result. This also adds the
/// attribute to the specified attribute list with the specified name.
ParseResult parseAttribute(Attribute &result, StringRef attrName,
- SmallVectorImpl<NamedAttribute> &attrs) {
+ NamedAttrList &attrs) {
return parseAttribute(result, Type(), attrName, attrs);
}
/// Parse an attribute of a specific kind and type.
template <typename AttrType>
ParseResult parseAttribute(AttrType &result, StringRef attrName,
- SmallVectorImpl<NamedAttribute> &attrs) {
+ NamedAttrList &attrs) {
return parseAttribute(result, Type(), attrName, attrs);
}
/// Parse an arbitrary attribute of a given type and return it in result. This
/// also adds the attribute to the specified attribute list with the specified
/// name.
- virtual ParseResult
- parseAttribute(Attribute &result, Type type, StringRef attrName,
- SmallVectorImpl<NamedAttribute> &attrs) = 0;
+ virtual ParseResult parseAttribute(Attribute &result, Type type,
+ StringRef attrName,
+ NamedAttrList &attrs) = 0;
/// Parse an attribute of a specific kind and type.
template <typename AttrType>
ParseResult parseAttribute(AttrType &result, Type type, StringRef attrName,
- SmallVectorImpl<NamedAttribute> &attrs) {
+ NamedAttrList &attrs) {
llvm::SMLoc loc = getCurrentLocation();
// Parse any kind of attribute.
@@ -404,13 +404,12 @@ class OpAsmParser {
}
/// Parse a named dictionary into 'result' if it is present.
- virtual ParseResult
- parseOptionalAttrDict(SmallVectorImpl<NamedAttribute> &result) = 0;
+ virtual ParseResult parseOptionalAttrDict(NamedAttrList &result) = 0;
/// Parse a named dictionary into 'result' if the `attributes` keyword is
/// present.
virtual ParseResult
- parseOptionalAttrDictWithKeyword(SmallVectorImpl<NamedAttribute> &result) = 0;
+ parseOptionalAttrDictWithKeyword(NamedAttrList &result) = 0;
/// Parse an affine map instance into 'map'.
virtual ParseResult parseAffineMap(AffineMap &map) = 0;
@@ -425,7 +424,7 @@ class OpAsmParser {
/// Parse an @-identifier and store it (without the '@' symbol) in a string
/// attribute named 'attrName'.
ParseResult parseSymbolName(StringAttr &result, StringRef attrName,
- SmallVectorImpl<NamedAttribute> &attrs) {
+ NamedAttrList &attrs) {
if (failed(parseOptionalSymbolName(result, attrName, attrs)))
return emitError(getCurrentLocation())
<< "expected valid '@'-identifier for symbol name";
@@ -434,9 +433,9 @@ class OpAsmParser {
/// Parse an optional @-identifier and store it (without the '@' symbol) in a
/// string attribute named 'attrName'.
- virtual ParseResult
- parseOptionalSymbolName(StringAttr &result, StringRef attrName,
- SmallVectorImpl<NamedAttribute> &attrs) = 0;
+ virtual ParseResult parseOptionalSymbolName(StringAttr &result,
+ StringRef attrName,
+ NamedAttrList &attrs) = 0;
//===--------------------------------------------------------------------===//
// Operand Parsing
@@ -552,8 +551,7 @@ class OpAsmParser {
/// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
virtual ParseResult
parseAffineMapOfSSAIds(SmallVectorImpl<OperandType> &operands, Attribute &map,
- StringRef attrName,
- SmallVectorImpl<NamedAttribute> &attrs,
+ StringRef attrName, NamedAttrList &attrs,
Delimiter delimiter = Delimiter::Square) = 0;
//===--------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index fcde73efd566..946b5c0b02ee 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -294,6 +294,11 @@ class Operation final
/// Return all of the attributes on this operation.
ArrayRef<NamedAttribute> getAttrs() { return attrs.getAttrs(); }
+ /// Return all of the attributes on this operation as a DictionaryAttr.
+ DictionaryAttr getAttrDictionary() {
+ return attrs.getDictionary(getContext());
+ }
+
/// Return mutable container of all the attributes on this operation.
MutableDictionaryAttr &getMutableAttrDict() { return attrs; }
@@ -326,6 +331,9 @@ class Operation final
MutableDictionaryAttr::RemoveResult removeAttr(Identifier name) {
return attrs.remove(name);
}
+ MutableDictionaryAttr::RemoveResult removeAttr(StringRef name) {
+ return attrs.remove(Identifier::get(name, getContext()));
+ }
/// A utility iterator that filters out non-dialect attributes.
class dialect_attr_iterator
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 8c0a3f12d426..a98c57cfba2e 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -201,6 +201,100 @@ class AbstractOperation {
bool (&hasRawTrait)(TypeID traitID);
};
+//===----------------------------------------------------------------------===//
+// NamedAttrList
+//===----------------------------------------------------------------------===//
+
+/// NamedAttrList is array of NamedAttributes that tracks whether it is sorted
+/// and does some basic work to remain sorted.
+class NamedAttrList {
+public:
+ using const_iterator = SmallVectorImpl<NamedAttribute>::const_iterator;
+ using const_reference = const NamedAttribute &;
+ using reference = NamedAttribute &;
+ using size_type = size_t;
+
+ NamedAttrList() : dictionarySorted({}, true) {}
+ NamedAttrList(ArrayRef<NamedAttribute> attributes);
+ NamedAttrList(const_iterator in_start, const_iterator in_end);
+
+ bool operator!=(const NamedAttrList &other) const {
+ return !(*this == other);
+ }
+ bool operator==(const NamedAttrList &other) const {
+ return attrs == other.attrs;
+ }
+
+ /// Add an attribute with the specified name.
+ void append(StringRef name, Attribute attr);
+
+ /// Add an attribute with the specified name.
+ void append(Identifier name, Attribute attr);
+
+ /// Add an array of named attributes.
+ void append(ArrayRef<NamedAttribute> newAttributes);
+
+ /// Add a range of named attributes.
+ void append(const_iterator in_start, const_iterator in_end);
+
+ /// Replaces the attributes with new list of attributes.
+ void assign(const_iterator in_start, const_iterator in_end);
+
+ /// Replaces the attributes with new list of attributes.
+ void assign(ArrayRef<NamedAttribute> range) {
+ append(range.begin(), range.end());
+ }
+
+ bool empty() const { return attrs.empty(); }
+
+ void reserve(size_type N) { attrs.reserve(N); }
+
+ /// Add an attribute with the specified name.
+ void push_back(NamedAttribute newAttribute);
+
+ /// Pop last element from list.
+ void pop_back() { attrs.pop_back(); }
+
+ /// Return a dictionary attribute for the underlying dictionary. This will
+ /// return an empty dictionary attribute if empty rather than null.
+ DictionaryAttr getDictionary(MLIRContext *context) const;
+
+ /// Return all of the attributes on this operation.
+ ArrayRef<NamedAttribute> getAttrs() const;
+
+ /// Return the specified attribute if present, null otherwise.
+ Attribute get(Identifier name) const;
+ Attribute get(StringRef name) const;
+
+ /// Return the specified named attribute if present, None otherwise.
+ Optional<NamedAttribute> getNamed(StringRef name) const;
+ Optional<NamedAttribute> getNamed(Identifier name) const;
+
+ /// If the an attribute exists with the specified name, change it to the new
+ /// value. Otherwise, add a new attribute with the specified name/value.
+ void set(Identifier name, Attribute value);
+ void set(StringRef name, Attribute value);
+
+ const_iterator begin() const { return attrs.begin(); }
+ const_iterator end() const { return attrs.end(); }
+
+ NamedAttrList &operator=(const SmallVectorImpl<NamedAttribute> &rhs);
+ operator ArrayRef<NamedAttribute>() const;
+ operator MutableDictionaryAttr() const;
+
+private:
+ /// Return whether the attributes are sorted.
+ bool isSorted() const { return dictionarySorted.getInt(); }
+
+ // These are marked mutable as they may be modified (e.g., sorted)
+ mutable SmallVector<NamedAttribute, 4> attrs;
+ // Pair with cached DictionaryAttr and status of whether attrs is sorted.
+ // Note: just because sorted does not mean a DictionaryAttr has been created
+ // but the case where there is a DictionaryAttr but attrs isn't sorted should
+ // not occur.
+ mutable llvm::PointerIntPair<Attribute, 1, bool> dictionarySorted;
+};
+
//===----------------------------------------------------------------------===//
// OperationName
//===----------------------------------------------------------------------===//
@@ -269,7 +363,7 @@ struct OperationState {
SmallVector<Value, 4> operands;
/// Types of the results of this operation.
SmallVector<Type, 4> types;
- SmallVector<NamedAttribute, 4> attributes;
+ NamedAttrList attributes;
/// Successors of this operation and their respective operands.
SmallVector<Block *, 1> successors;
/// Regions that the op will hold.
@@ -303,12 +397,12 @@ struct OperationState {
/// Add an attribute with the specified name.
void addAttribute(Identifier name, Attribute attr) {
- attributes.push_back({name, attr});
+ attributes.append(name, attr);
}
/// Add an array of named attributes.
void addAttributes(ArrayRef<NamedAttribute> newAttributes) {
- attributes.append(newAttributes.begin(), newAttributes.end());
+ attributes.append(newAttributes);
}
/// Add an array of successors.
@@ -329,7 +423,7 @@ struct OperationState {
void addRegion(std::unique_ptr<Region> &®ion);
/// Get the context held by this operation state.
- MLIRContext *getContext() { return location->getContext(); }
+ MLIRContext *getContext() const { return location->getContext(); }
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
index c13f2e358d6f..67faeb56a51c 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
@@ -83,11 +83,11 @@ namespace detail {
LogicalResult inferReturnTensorTypes(
function_ref<LogicalResult(
MLIRContext *, Optional<Location> location, ValueRange operands,
- ArrayRef<NamedAttribute> attributes, RegionRange regions,
+ DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &retComponents)>
componentTypeFn,
MLIRContext *context, Optional<Location> location, ValueRange operands,
- ArrayRef<NamedAttribute> attributes, RegionRange regions,
+ DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes);
/// Verifies that the inferred result types match the actual result types for
@@ -107,7 +107,7 @@ class InferTensorType : public TraitBase<ConcreteType, InferTensorType> {
public:
static LogicalResult
inferReturnTypes(MLIRContext *context, Optional<Location> location,
- ValueRange operands, ArrayRef<NamedAttribute> attributes,
+ ValueRange operands, DictionaryAttr attributes,
RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
return ::mlir::detail::inferReturnTensorTypes(
diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
index fb6bdc220c83..723cf99d38b3 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
@@ -40,7 +40,7 @@ def InferTypeOpInterface : OpInterface<"InferTypeOpInterface"> {
/*args=*/(ins "MLIRContext*":$context,
"Optional<Location>":$location,
"ValueRange":$operands,
- "ArrayRef<NamedAttribute>":$attributes,
+ "DictionaryAttr":$attributes,
"RegionRange":$regions,
"SmallVectorImpl<Type>&":$inferredReturnTypes)
>,
@@ -92,7 +92,7 @@ def InferShapedTypeOpInterface : OpInterface<"InferShapedTypeOpInterface"> {
/*args=*/(ins "MLIRContext*":$context,
"Optional<Location>":$location,
"ValueRange":$operands,
- "ArrayRef<NamedAttribute>":$attributes,
+ "DictionaryAttr":$attributes,
"RegionRange":$regions,
"SmallVectorImpl<ShapedTypeComponents>&":
$inferredReturnShapes)
diff --git a/mlir/lib/Analysis/CallGraph.cpp b/mlir/lib/Analysis/CallGraph.cpp
index 94965c7a623d..52898bfdb9ab 100644
--- a/mlir/lib/Analysis/CallGraph.cpp
+++ b/mlir/lib/Analysis/CallGraph.cpp
@@ -178,7 +178,8 @@ void CallGraph::print(raw_ostream &os) const {
auto *parentOp = callableRegion->getParentOp();
os << "'" << callableRegion->getParentOp()->getName() << "' - Region #"
<< callableRegion->getRegionNumber();
- if (auto attrs = parentOp->getMutableAttrDict().getDictionary())
+ auto attrs = parentOp->getAttrDictionary();
+ if (!attrs.empty())
os << " : " << attrs;
};
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index f87c993ddd44..a07c3fa71743 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -2445,7 +2445,7 @@ static ParseResult parseAffineParallelOp(OpAsmParser &parser,
return failure();
AffineMapAttr stepsMapAttr;
- SmallVector<NamedAttribute, 1> stepsAttrs;
+ NamedAttrList stepsAttrs;
SmallVector<OpAsmParser::OperandType, 4> stepsMapOperands;
if (failed(parser.parseOptionalKeyword("step"))) {
SmallVector<int64_t, 4> steps(ivs.size(), 1);
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 92ececb142d9..71cec77a0a4a 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -538,8 +538,8 @@ parseAttributions(OpAsmParser &parser, StringRef keyword,
/// function-attributes? region
static ParseResult parseGPUFuncOp(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::OperandType, 8> entryArgs;
- SmallVector<SmallVector<NamedAttribute, 2>, 1> argAttrs;
- SmallVector<SmallVector<NamedAttribute, 2>, 1> resultAttrs;
+ SmallVector<NamedAttrList, 1> argAttrs;
+ SmallVector<NamedAttrList, 1> resultAttrs;
SmallVector<Type, 8> argTypes;
SmallVector<Type, 4> resultTypes;
bool isVariadic;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 5c112710ec55..83b65798ed9e 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -92,8 +92,8 @@ static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) {
predicateValue = static_cast<int64_t>(predicate.getValue());
}
- result.attributes[0].second =
- parser.getBuilder().getI64IntegerAttr(predicateValue);
+ result.attributes.set("predicate",
+ parser.getBuilder().getI64IntegerAttr(predicateValue));
// The result type is either i1 or a vector type <? x i1> if the inputs are
// vectors.
@@ -1186,7 +1186,7 @@ void LLVMFuncOp::build(OpBuilder &builder, OperationState &result,
"expected as many argument attribute lists as arguments");
SmallString<8> argAttrName;
for (unsigned i = 0; i < numInputs; ++i)
- if (auto argDict = argAttrs[i].getDictionary())
+ if (auto argDict = argAttrs[i].getDictionary(builder.getContext()))
result.addAttribute(getArgAttrName(i, argAttrName), argDict);
}
@@ -1249,8 +1249,8 @@ static ParseResult parseLLVMFuncOp(OpAsmParser &parser,
StringAttr nameAttr;
SmallVector<OpAsmParser::OperandType, 8> entryArgs;
- SmallVector<SmallVector<NamedAttribute, 2>, 1> argAttrs;
- SmallVector<SmallVector<NamedAttribute, 2>, 1> resultAttrs;
+ SmallVector<NamedAttrList, 1> argAttrs;
+ SmallVector<NamedAttrList, 1> resultAttrs;
SmallVector<Type, 8> argTypes;
SmallVector<Type, 4> resultTypes;
bool isVariadic;
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp
index 05f0eb4462e8..a356a391d72b 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp
@@ -353,8 +353,8 @@ struct ConvertSelectionOpToSelect
}
bool isSameAttrList(spirv::StoreOp lhs, spirv::StoreOp rhs) const {
- return lhs.getOperation()->getMutableAttrDict().getDictionary() ==
- rhs.getOperation()->getMutableAttrDict().getDictionary();
+ return lhs.getOperation()->getAttrDictionary() ==
+ rhs.getOperation()->getAttrDictionary();
}
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index 5d4e309a2e96..e7bdfe902804 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -105,7 +105,7 @@ static ParseResult
parseEnumStrAttr(EnumClass &value, OpAsmParser &parser,
StringRef attrName = spirv::attributeName<EnumClass>()) {
Attribute attrVal;
- SmallVector<NamedAttribute, 1> attr;
+ NamedAttrList attr;
auto loc = parser.getCurrentLocation();
if (parser.parseAttribute(attrVal, parser.getBuilder().getNoneType(),
attrName, attr)) {
@@ -1019,7 +1019,7 @@ static ParseResult parseBranchConditionalOp(OpAsmParser &parser,
// Parse the optional branch weights.
if (succeeded(parser.parseOptionalLSquare())) {
IntegerAttr trueWeight, falseWeight;
- SmallVector<NamedAttribute, 2> weights;
+ NamedAttrList weights;
auto i32Type = builder.getIntegerType(32);
if (parser.parseAttribute(trueWeight, i32Type, "weight", weights) ||
@@ -1443,7 +1443,7 @@ static ParseResult parseEntryPointOp(OpAsmParser &parser,
// The name of the interface variable attribute isnt important
auto attrName = "var_symbol";
FlatSymbolRefAttr var;
- SmallVector<NamedAttribute, 1> attrs;
+ NamedAttrList attrs;
if (parser.parseAttribute(var, Type(), attrName, attrs)) {
return failure();
}
@@ -1497,7 +1497,7 @@ static ParseResult parseExecutionModeOp(OpAsmParser &parser,
SmallVector<int32_t, 4> values;
Type i32Type = parser.getBuilder().getIntegerType(32);
while (!parser.parseOptionalComma()) {
- SmallVector<NamedAttribute, 1> attr;
+ NamedAttrList attr;
Attribute value;
if (parser.parseAttribute(value, i32Type, "value", attr)) {
return failure();
@@ -1529,8 +1529,8 @@ static void print(spirv::ExecutionModeOp execModeOp, OpAsmPrinter &printer) {
static ParseResult parseFuncOp(OpAsmParser &parser, OperationState &state) {
SmallVector<OpAsmParser::OperandType, 4> entryArgs;
- SmallVector<SmallVector<NamedAttribute, 2>, 4> argAttrs;
- SmallVector<SmallVector<NamedAttribute, 2>, 4> resultAttrs;
+ SmallVector<NamedAttrList, 4> argAttrs;
+ SmallVector<NamedAttrList, 4> resultAttrs;
SmallVector<Type, 4> argTypes;
SmallVector<Type, 4> resultTypes;
auto &builder = parser.getBuilder();
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 10e766f3cc61..a420fa73cfad 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -92,10 +92,11 @@ void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
// BroadcastOp
//===----------------------------------------------------------------------===//
-LogicalResult BroadcastOp::inferReturnTypes(
- MLIRContext *context, Optional<Location> location, ValueRange operands,
- ArrayRef<NamedAttribute> attributes, RegionRange regions,
- SmallVectorImpl<Type> &inferredReturnTypes) {
+LogicalResult
+BroadcastOp::inferReturnTypes(MLIRContext *context, Optional<Location> location,
+ ValueRange operands, DictionaryAttr attributes,
+ RegionRange regions,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
inferredReturnTypes.push_back(ShapeType::get(context));
return success();
}
@@ -137,7 +138,7 @@ static ParseResult parseConstShapeOp(OpAsmParser &parser,
// shape as an ArrayAttr.
// TODO: Implement custom parser and maybe make syntax a bit more concise.
Attribute extentsRaw;
- SmallVector<NamedAttribute, 6> dummy;
+ NamedAttrList dummy;
if (parser.parseAttribute(extentsRaw, "dummy", dummy))
return failure();
auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>();
@@ -159,10 +160,11 @@ static ParseResult parseConstShapeOp(OpAsmParser &parser,
OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shape(); }
-LogicalResult ConstShapeOp::inferReturnTypes(
- MLIRContext *context, Optional<Location> location, ValueRange operands,
- ArrayRef<NamedAttribute> attributes, RegionRange regions,
- SmallVectorImpl<Type> &inferredReturnTypes) {
+LogicalResult
+ConstShapeOp::inferReturnTypes(MLIRContext *context,
+ Optional<Location> location, ValueRange operands,
+ DictionaryAttr attributes, RegionRange regions,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
inferredReturnTypes.push_back(ShapeType::get(context));
return success();
}
@@ -171,10 +173,11 @@ LogicalResult ConstShapeOp::inferReturnTypes(
// ConstSizeOp
//===----------------------------------------------------------------------===//
-LogicalResult ConstSizeOp::inferReturnTypes(
- MLIRContext *context, Optional<Location> location, ValueRange operands,
- ArrayRef<NamedAttribute> attributes, RegionRange regions,
- SmallVectorImpl<Type> &inferredReturnTypes) {
+LogicalResult
+ConstSizeOp::inferReturnTypes(MLIRContext *context, Optional<Location> location,
+ ValueRange operands, DictionaryAttr attributes,
+ RegionRange regions,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
inferredReturnTypes.push_back(SizeType::get(context));
return success();
}
@@ -183,10 +186,11 @@ LogicalResult ConstSizeOp::inferReturnTypes(
// ShapeOfOp
//===----------------------------------------------------------------------===//
-LogicalResult ShapeOfOp::inferReturnTypes(
- MLIRContext *context, Optional<Location> location, ValueRange operands,
- ArrayRef<NamedAttribute> attributes, RegionRange regions,
- SmallVectorImpl<Type> &inferredReturnTypes) {
+LogicalResult
+ShapeOfOp::inferReturnTypes(MLIRContext *context, Optional<Location> location,
+ ValueRange operands, DictionaryAttr attributes,
+ RegionRange regions,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
inferredReturnTypes.push_back(ShapeType::get(context));
return success();
}
@@ -203,10 +207,11 @@ OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
// SplitAtOp
//===----------------------------------------------------------------------===//
-LogicalResult SplitAtOp::inferReturnTypes(
- MLIRContext *context, Optional<Location> location, ValueRange operands,
- ArrayRef<NamedAttribute> attributes, RegionRange regions,
- SmallVectorImpl<Type> &inferredReturnTypes) {
+LogicalResult
+SplitAtOp::inferReturnTypes(MLIRContext *context, Optional<Location> location,
+ ValueRange operands, DictionaryAttr attributes,
+ RegionRange regions,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
auto shapeType = ShapeType::get(context);
inferredReturnTypes.push_back(shapeType);
inferredReturnTypes.push_back(shapeType);
@@ -238,10 +243,11 @@ LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
// ConcatOp
//===----------------------------------------------------------------------===//
-LogicalResult ConcatOp::inferReturnTypes(
- MLIRContext *context, Optional<Location> location, ValueRange operands,
- ArrayRef<NamedAttribute> attributes, RegionRange regions,
- SmallVectorImpl<Type> &inferredReturnTypes) {
+LogicalResult
+ConcatOp::inferReturnTypes(MLIRContext *context, Optional<Location> location,
+ ValueRange operands, DictionaryAttr attributes,
+ RegionRange regions,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
auto shapeType = ShapeType::get(context);
inferredReturnTypes.push_back(shapeType);
return success();
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 5ff858dcc3a2..f00f2843bd18 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -478,7 +478,7 @@ static void print(OpAsmPrinter &p, vector::ExtractOp op) {
static ParseResult parseExtractOp(OpAsmParser &parser, OperationState &result) {
llvm::SMLoc attributeLoc, typeLoc;
- SmallVector<NamedAttribute, 4> attrs;
+ NamedAttrList attrs;
OpAsmParser::OperandType vector;
Type type;
Attribute attr;
diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp
index f47f80ca4066..540c3c6258e2 100644
--- a/mlir/lib/IR/Attributes.cpp
+++ b/mlir/lib/IR/Attributes.cpp
@@ -84,24 +84,6 @@ bool BoolAttr::getValue() const { return getImpl()->value; }
// DictionaryAttr
//===----------------------------------------------------------------------===//
-/// Perform a three-way comparison between the names of the specified
-/// NamedAttributes.
-static int compareNamedAttributes(const NamedAttribute *lhs,
- const NamedAttribute *rhs) {
- return strcmp(lhs->first.data(), rhs->first.data());
-}
-
-/// Returns if the name of the given attribute precedes that of 'name'.
-static bool compareNamedAttributeWithName(const NamedAttribute &attr,
- StringRef name) {
- // This is correct even when attr.first.data()[name.size()] is not a zero
- // string terminator, because we only care about a less than comparison.
- // This can't use memcmp, because it doesn't guarantee that it will stop
- // reading both buffers if one is shorter than the other, even if there is
- // a difference.
- return strncmp(attr.first.data(), name.data(), name.size()) < 0;
-}
-
/// Helper function that does either an in place sort or sorts from source array
/// into destination. If inPlace then storage is both the source and the
/// destination, else value is the source and storage destination. Returns
@@ -112,32 +94,35 @@ static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value,
// Specialize for the common case.
switch (value.size()) {
case 0:
+ // Zero already sorted.
+ break;
case 1:
- // Zero or one elements are already sorted.
+ // One already sorted but may need to be copied.
+ if (!inPlace)
+ storage.assign({value[0]});
break;
- case 2:
+ case 2: {
assert(value[0].first != value[1].first &&
"DictionaryAttr element names must be unique");
- if (compareNamedAttributes(&value[0], &value[1]) > 0) {
- if (inPlace)
+ bool isSorted = value[0] < value[1];
+ if (inPlace) {
+ if (!isSorted)
std::swap(storage[0], storage[1]);
- else
- storage.append({value[1], value[0]});
- return true;
+ } else if (isSorted) {
+ storage.assign({value[0], value[1]});
+ } else {
+ storage.assign({value[1], value[0]});
}
- break;
+ return !isSorted;
+ }
default:
+ if (!inPlace)
+ storage.assign(value.begin(), value.end());
// Check to see they are sorted already.
- bool isSorted =
- llvm::is_sorted(value, [](NamedAttribute l, NamedAttribute r) {
- return compareNamedAttributes(&l, &r) < 0;
- });
+ bool isSorted = llvm::is_sorted(value);
if (!isSorted) {
// If not, do a general sort.
- if (!inPlace)
- storage.append(value.begin(), value.end());
- llvm::array_pod_sort(storage.begin(), storage.end(),
- compareNamedAttributes);
+ llvm::array_pod_sort(storage.begin(), storage.end());
value = storage;
}
@@ -152,15 +137,19 @@ static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value,
return false;
}
-/// Sorts the NamedAttributes in the array ordered by name as expected by
-/// getWithSorted.
-/// Requires: uniquely named attributes.
-void DictionaryAttr::sort(SmallVectorImpl<NamedAttribute> &array) {
- dictionaryAttrSort</*inPlace=*/true>(array, array);
+bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value,
+ SmallVectorImpl<NamedAttribute> &storage) {
+ return dictionaryAttrSort</*inPlace=*/false>(value, storage);
+}
+
+bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) {
+ return dictionaryAttrSort</*inPlace=*/true>(array, array);
}
DictionaryAttr DictionaryAttr::get(ArrayRef<NamedAttribute> value,
MLIRContext *context) {
+ if (value.empty())
+ return DictionaryAttr::getEmpty(context);
assert(llvm::all_of(value,
[](const NamedAttribute &attr) { return attr.second; }) &&
"value cannot have null entries");
@@ -172,11 +161,12 @@ DictionaryAttr DictionaryAttr::get(ArrayRef<NamedAttribute> value,
return Base::get(context, StandardAttributes::Dictionary, value);
}
-
/// Construct a dictionary with an array of values that is known to already be
/// sorted by name and uniqued.
DictionaryAttr DictionaryAttr::getWithSorted(ArrayRef<NamedAttribute> value,
MLIRContext *context) {
+ if (value.empty())
+ return DictionaryAttr::getEmpty(context);
// Ensure that the attribute elements are unique and sorted.
assert(llvm::is_sorted(value,
[](NamedAttribute l, NamedAttribute r) {
@@ -208,7 +198,7 @@ Attribute DictionaryAttr::get(Identifier name) const {
/// Return the specified named attribute if present, None otherwise.
Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const {
ArrayRef<NamedAttribute> values = getValue();
- auto it = llvm::lower_bound(values, name, compareNamedAttributeWithName);
+ const auto *it = llvm::lower_bound(values, name);
return it != values.end() && it->first == name ? *it
: Optional<NamedAttribute>();
}
@@ -1369,6 +1359,15 @@ MutableDictionaryAttr::MutableDictionaryAttr(
setAttrs(attributes);
}
+/// Return the underlying dictionary attribute. When no attributes are held,
+/// return the context's shared empty DictionaryAttr so the result is never
+/// null.
+DictionaryAttr
+MutableDictionaryAttr::getDictionary(MLIRContext *context) const {
+  return attrs ? attrs : DictionaryAttr::getEmpty(context);
+}
+
ArrayRef<NamedAttribute> MutableDictionaryAttr::getAttrs() const {
return attrs ? attrs.getValue() : llvm::None;
}
@@ -1408,7 +1407,7 @@ void MutableDictionaryAttr::set(Identifier name, Attribute value) {
// Look for an existing value for the given name, and set it in-place.
ArrayRef<NamedAttribute> values = getAttrs();
- auto it = llvm::find_if(
+ const auto *it = llvm::find_if(
values, [name](NamedAttribute attr) { return attr.first == name; });
if (it != values.end()) {
// Bail out early if the value is the same as what we already have.
@@ -1422,7 +1421,7 @@ void MutableDictionaryAttr::set(Identifier name, Attribute value) {
}
// Otherwise, insert the new attribute into its sorted position.
- it = llvm::lower_bound(values, name, compareNamedAttributeWithName);
+ it = llvm::lower_bound(values, name);
SmallVector<NamedAttribute, 8> newAttrs;
newAttrs.reserve(values.size() + 1);
newAttrs.append(values.begin(), it);
@@ -1454,3 +1453,15 @@ auto MutableDictionaryAttr::remove(Identifier name) -> RemoveResult {
}
return RemoveResult::NotFound;
}
+
+/// Order named attributes by name using strcmp on the identifier strings
+/// (strcmp requires the names to be null-terminated). This is the ordering
+/// used by DictionaryAttr's sorting helper (llvm::is_sorted/array_pod_sort).
+bool mlir::operator<(const NamedAttribute &lhs, const NamedAttribute &rhs) {
+ return strcmp(lhs.first.data(), rhs.first.data()) < 0;
+}
+/// Return true if the name of `lhs` orders before `rhs`. Used for
+/// lower_bound lookups of a name in a sorted NamedAttribute range.
+bool mlir::operator<(const NamedAttribute &lhs, StringRef rhs) {
+ // This is correct even when lhs.first.data()[rhs.size()] is not a zero
+ // string terminator, because we only care about a less than comparison.
+ // This can't use memcmp, because it doesn't guarantee that it will stop
+ // reading both buffers if one is shorter than the other, even if there is
+ // a difference.
+ return strncmp(lhs.first.data(), rhs.data(), rhs.size()) < 0;
+}
diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp
index a26d7db7e921..305022dd8451 100644
--- a/mlir/lib/IR/Function.cpp
+++ b/mlir/lib/IR/Function.cpp
@@ -57,7 +57,7 @@ void FuncOp::build(OpBuilder &builder, OperationState &result, StringRef name,
assert(type.getNumInputs() == argAttrs.size());
SmallString<8> argAttrName;
for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i)
- if (auto argDict = argAttrs[i].getDictionary())
+ if (auto argDict = argAttrs[i].getDictionary(builder.getContext()))
result.addAttribute(getArgAttrName(i, argAttrName), argDict);
}
diff --git a/mlir/lib/IR/FunctionImplementation.cpp b/mlir/lib/IR/FunctionImplementation.cpp
index 0049f4fc571c..8b90ff13244f 100644
--- a/mlir/lib/IR/FunctionImplementation.cpp
+++ b/mlir/lib/IR/FunctionImplementation.cpp
@@ -17,8 +17,7 @@ static ParseResult
parseArgumentList(OpAsmParser &parser, bool allowVariadic,
SmallVectorImpl<Type> &argTypes,
SmallVectorImpl<OpAsmParser::OperandType> &argNames,
- SmallVectorImpl<SmallVector<NamedAttribute, 2>> &argAttrs,
- bool &isVariadic) {
+ SmallVectorImpl<NamedAttrList> &argAttrs, bool &isVariadic) {
if (parser.parseLParen())
return failure();
@@ -54,7 +53,7 @@ parseArgumentList(OpAsmParser &parser, bool allowVariadic,
argTypes.push_back(argumentType);
// Parse any argument attributes.
- SmallVector<NamedAttribute, 2> attrs;
+ NamedAttrList attrs;
if (parser.parseOptionalAttrDict(attrs))
return failure();
argAttrs.push_back(attrs);
@@ -90,9 +89,9 @@ parseArgumentList(OpAsmParser &parser, bool allowVariadic,
/// function-result-list-no-parens ::= function-result (`,` function-result)*
/// function-result ::= type attribute-dict?
///
-static ParseResult parseFunctionResultList(
- OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
- SmallVectorImpl<SmallVector<NamedAttribute, 2>> &resultAttrs) {
+static ParseResult
+parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
+ SmallVectorImpl<NamedAttrList> &resultAttrs) {
if (failed(parser.parseOptionalLParen())) {
// We already know that there is no `(`, so parse a type.
// Because there is no `(`, it cannot be a function type.
@@ -127,10 +126,9 @@ static ParseResult parseFunctionResultList(
ParseResult mlir::impl::parseFunctionSignature(
OpAsmParser &parser, bool allowVariadic,
SmallVectorImpl<OpAsmParser::OperandType> &argNames,
- SmallVectorImpl<Type> &argTypes,
- SmallVectorImpl<SmallVector<NamedAttribute, 2>> &argAttrs, bool &isVariadic,
- SmallVectorImpl<Type> &resultTypes,
- SmallVectorImpl<SmallVector<NamedAttribute, 2>> &resultAttrs) {
+ SmallVectorImpl<Type> &argTypes, SmallVectorImpl<NamedAttrList> &argAttrs,
+ bool &isVariadic, SmallVectorImpl<Type> &resultTypes,
+ SmallVectorImpl<NamedAttrList> &resultAttrs) {
if (parseArgumentList(parser, allowVariadic, argTypes, argNames, argAttrs,
isVariadic))
return failure();
@@ -139,10 +137,9 @@ ParseResult mlir::impl::parseFunctionSignature(
return success();
}
-void mlir::impl::addArgAndResultAttrs(
- Builder &builder, OperationState &result,
- ArrayRef<SmallVector<NamedAttribute, 2>> argAttrs,
- ArrayRef<SmallVector<NamedAttribute, 2>> resultAttrs) {
+void mlir::impl::addArgAndResultAttrs(Builder &builder, OperationState &result,
+ ArrayRef<NamedAttrList> argAttrs,
+ ArrayRef<NamedAttrList> resultAttrs) {
// Add the attributes to the function arguments.
SmallString<8> attrNameBuf;
for (unsigned i = 0, e = argAttrs.size(); i != e; ++i)
@@ -164,8 +161,8 @@ mlir::impl::parseFunctionLikeOp(OpAsmParser &parser, OperationState &result,
bool allowVariadic,
mlir::impl::FuncTypeBuilder funcTypeBuilder) {
SmallVector<OpAsmParser::OperandType, 4> entryArgs;
- SmallVector<SmallVector<NamedAttribute, 2>, 4> argAttrs;
- SmallVector<SmallVector<NamedAttribute, 2>, 4> resultAttrs;
+ SmallVector<NamedAttrList, 4> argAttrs;
+ SmallVector<NamedAttrList, 4> resultAttrs;
SmallVector<Type, 4> argTypes;
SmallVector<Type, 4> resultTypes;
auto &builder = parser.getBuilder();
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index c59c53567488..0728f294be86 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -328,6 +328,7 @@ class MLIRContextImpl {
BoolAttr falseAttr, trueAttr;
UnitAttr unitAttr;
UnknownLoc unknownLocAttr;
+ DictionaryAttr emptyDictionaryAttr;
public:
MLIRContextImpl() : identifiers(identifierAllocator) {}
@@ -388,6 +389,9 @@ MLIRContext::MLIRContext() : impl(new MLIRContextImpl()) {
/// Unknown Location Attribute.
impl->unknownLocAttr = AttributeUniquer::get<UnknownLoc>(
this, StandardAttributes::UnknownLocation);
+ /// The empty dictionary attribute.
+ impl->emptyDictionaryAttr = AttributeUniquer::get<DictionaryAttr>(
+ this, StandardAttributes::Dictionary, ArrayRef<NamedAttribute>());
}
MLIRContext::~MLIRContext() {}
@@ -742,6 +746,11 @@ Location UnknownLoc::get(MLIRContext *context) {
return context->getImpl().unknownLocAttr;
}
+/// Return the empty DictionaryAttr for `context`. The instance is created once
+/// in the MLIRContext constructor and cached on the context impl, so this is a
+/// plain field read rather than a uniquing lookup.
+DictionaryAttr DictionaryAttr::getEmpty(MLIRContext *context) {
+ return context->getImpl().emptyDictionaryAttr;
+}
+
//===----------------------------------------------------------------------===//
// AffineMap uniquing
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index 5b439d67ad67..3489ee74a422 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -70,9 +70,9 @@ Operation *Operation::create(Location location, OperationName name,
/// Create a new Operation from operation state.
Operation *Operation::create(const OperationState &state) {
- return Operation::create(
- state.location, state.name, state.types, state.operands,
- MutableDictionaryAttr(state.attributes), state.successors, state.regions);
+ return Operation::create(state.location, state.name, state.types,
+ state.operands, state.attributes, state.successors,
+ state.regions);
}
/// Create a new Operation with the specific fields.
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 91842cf95e56..f326d621ee0c 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -18,6 +18,146 @@
#include "mlir/IR/StandardTypes.h"
using namespace mlir;
+//===----------------------------------------------------------------------===//
+// NamedAttrList
+//===----------------------------------------------------------------------===//
+
+/// Construct from an array of attributes. The attributes are copied via
+/// assign(), which stores them sorted by name.
+NamedAttrList::NamedAttrList(ArrayRef<NamedAttribute> attributes) {
+ assign(attributes.begin(), attributes.end());
+}
+
+/// Construct from an iterator range of attributes; like the ArrayRef
+/// constructor, the attributes are stored via assign().
+NamedAttrList::NamedAttrList(const_iterator in_start, const_iterator in_end) {
+ assign(in_start, in_end);
+}
+
+/// Return a view of the stored attributes in their current order, which is
+/// sorted by name only if the list is currently marked sorted.
+ArrayRef<NamedAttribute> NamedAttrList::getAttrs() const { return attrs; }
+
+/// Return the attributes as a DictionaryAttr in `context`.
+///
+/// If the list is not marked sorted it is first sorted in place, so this const
+/// method may reorder `attrs`. The dictionary is built once and cached in
+/// `dictionarySorted`; the mutating members in this file reset that cache.
+DictionaryAttr NamedAttrList::getDictionary(MLIRContext *context) const {
+ if (!isSorted()) {
+ DictionaryAttr::sortInPlace(attrs);
+ // Record that `attrs` is now sorted and drop any cached dictionary so it
+ // is rebuilt from the sorted list below.
+ dictionarySorted.setPointerAndInt(nullptr, true);
+ }
+ if (!dictionarySorted.getPointer())
+ dictionarySorted.setPointer(DictionaryAttr::getWithSorted(attrs, context));
+ return dictionarySorted.getPointer().cast<DictionaryAttr>();
+}
+
+/// Convert to a MutableDictionaryAttr. An empty list yields a
+/// default-constructed MutableDictionaryAttr; otherwise the context needed to
+/// build the dictionary is taken from the first attribute's value, since no
+/// context is passed in.
+NamedAttrList::operator MutableDictionaryAttr() const {
+ if (attrs.empty())
+ return MutableDictionaryAttr();
+ return getDictionary(attrs.front().second.getContext());
+}
+
+/// Append an attribute named `name`, uniquing the name into an Identifier in
+/// the attribute's context.
+void NamedAttrList::append(StringRef name, Attribute attr) {
+ append(Identifier::get(name, attr.getContext()), attr);
+}
+
+/// Append an attribute named `name`. Goes through push_back so the list stays
+/// marked sorted when the new name orders after the current last one.
+void NamedAttrList::append(Identifier name, Attribute attr) {
+ push_back({name, attr});
+}
+
+/// Append an array of named attributes.
+void NamedAttrList::append(ArrayRef<NamedAttribute> newAttributes) {
+ append(newAttributes.begin(), newAttributes.end());
+}
+
+/// Append a range of named attributes. The range's ordering is not inspected,
+/// so the list is conservatively marked unsorted and the cached dictionary is
+/// dropped.
+void NamedAttrList::append(const_iterator in_start, const_iterator in_end) {
+ // TODO: expand to handle case where values appended are in order & after
+ // end of current list.
+ dictionarySorted.setPointerAndInt(nullptr, false);
+ attrs.append(in_start, in_end);
+}
+
+/// Replace the attributes with the range [in_start, in_end), stored sorted by
+/// name. Any cached dictionary is dropped and the list is marked sorted.
+void NamedAttrList::assign(const_iterator in_start, const_iterator in_end) {
+  // DictionaryAttr::sort does not write its output for an empty input (the
+  // zero-element case of the sorting helper just breaks), so clear explicitly;
+  // otherwise stale attributes would survive assigning an empty range.
+  if (in_start == in_end)
+    attrs.clear();
+  else
+    DictionaryAttr::sort(ArrayRef<NamedAttribute>{in_start, in_end}, attrs);
+  dictionarySorted.setPointerAndInt(nullptr, true);
+}
+
+/// Append `newAttribute`, keeping the sorted flag accurate and dropping any
+/// cached dictionary.
+void NamedAttrList::push_back(NamedAttribute newAttribute) {
+  // The list stays sorted only if it was sorted already and the new name
+  // orders strictly after the current last name.
+  if (isSorted()) {
+    bool staysSorted =
+        attrs.empty() ||
+        strcmp(attrs.back().first.data(), newAttribute.first.data()) < 0;
+    dictionarySorted.setInt(staysSorted);
+  }
+  // Any cached dictionary no longer matches the contents.
+  dictionarySorted.setPointer(nullptr);
+  attrs.push_back(newAttribute);
+}
+
+/// Helper function to find an attribute in a possibly sorted vector of
+/// NamedAttributes.
+///
+/// `name` may be a StringRef or an Identifier. When `sorted` is true the
+/// vector is binary searched by name; otherwise it is scanned linearly.
+/// Returns an iterator to the matching attribute, or `attrs.end()` if none.
+template <typename T>
+static auto *findAttr(SmallVectorImpl<NamedAttribute> &attrs, T name,
+                      bool sorted) {
+  if (!sorted) {
+    return llvm::find_if(
+        attrs, [name](NamedAttribute attr) { return attr.first == name; });
+  }
+
+  auto *it = llvm::lower_bound(attrs, name);
+  // lower_bound returns end() when every stored name orders before `name`
+  // (including when the list is empty); end() must not be dereferenced.
+  if (it == attrs.end() || it->first != name)
+    return attrs.end();
+  return it;
+}
+
+/// Return the attribute named `name` if present, null otherwise.
+Attribute NamedAttrList::get(StringRef name) const {
+  auto *found = findAttr(attrs, name, isSorted());
+  if (found == attrs.end())
+    return nullptr;
+  return found->second;
+}
+
+/// Return the attribute named `name` if present, null otherwise.
+Attribute NamedAttrList::get(Identifier name) const {
+  auto *found = findAttr(attrs, name, isSorted());
+  if (found == attrs.end())
+    return nullptr;
+  return found->second;
+}
+
+/// Return the (name, attribute) pair named `name` if present, None otherwise.
+Optional<NamedAttribute> NamedAttrList::getNamed(StringRef name) const {
+  auto *found = findAttr(attrs, name, isSorted());
+  if (found == attrs.end())
+    return Optional<NamedAttribute>();
+  return *found;
+}
+/// Identifier overload of getNamed.
+Optional<NamedAttribute> NamedAttrList::getNamed(Identifier name) const {
+  auto *found = findAttr(attrs, name, isSorted());
+  if (found == attrs.end())
+    return Optional<NamedAttribute>();
+  return *found;
+}
+
+/// If an attribute exists with the specified name, change it to the new
+/// value. Otherwise, add a new attribute with the specified name/value.
+void NamedAttrList::set(Identifier name, Attribute value) {
+  assert(value && "attributes may never be null");
+
+  // Look for an existing value for the given name, and set it in-place.
+  auto *it = findAttr(attrs, name, isSorted());
+  if (it != attrs.end()) {
+    // Bail out early if the value is the same as what we already have.
+    if (it->second == value)
+      return;
+    dictionarySorted.setPointer(nullptr);
+    it->second = value;
+    return;
+  }
+
+  // Otherwise, add the new attribute; the cached dictionary is stale either
+  // way.
+  dictionarySorted.setPointer(nullptr);
+  if (!isSorted()) {
+    // lower_bound requires a sorted (partitioned) range, and an unsorted list
+    // has no meaningful sorted position: append and leave it marked unsorted.
+    attrs.push_back({name, value});
+    return;
+  }
+  // Insert at the sorted position so the list remains sorted.
+  attrs.insert(llvm::lower_bound(attrs, name), {name, value});
+}
+
+/// Convenience overload that uniques `name` into an Identifier in the
+/// attribute's context before setting it.
+void NamedAttrList::set(StringRef name, Attribute value) {
+  assert(value && "setting null attribute not supported");
+  return set(mlir::Identifier::get(name, value.getContext()), value);
+}
+
+/// Replace the contents with `rhs` via assign().
+NamedAttrList &
+NamedAttrList::operator=(const SmallVectorImpl<NamedAttribute> &rhs) {
+ assign(rhs.begin(), rhs.end());
+ return *this;
+}
+
+/// Implicit non-owning view of the stored attributes. NOTE(review): the view
+/// presumably dangles once the list is modified or destroyed.
+NamedAttrList::operator ArrayRef<NamedAttribute>() const { return attrs; }
+
//===----------------------------------------------------------------------===//
// OperationState
//===----------------------------------------------------------------------===//
@@ -133,7 +273,7 @@ void detail::OperandStorage::eraseOperands(unsigned start, unsigned length) {
// Shift all operands down if the operand to remove is not at the end.
if (start != storage.numOperands) {
- auto indexIt = std::next(operands.begin(), start);
+ auto *indexIt = std::next(operands.begin(), start);
std::rotate(indexIt, std::next(indexIt, length), operands.end());
}
for (unsigned i = 0; i != length; ++i)
@@ -416,8 +556,8 @@ llvm::hash_code OperationEquivalence::computeHash(Operation *op, Flags flags) {
// Hash operations based upon their:
// - Operation Name
// - Attributes
- llvm::hash_code hash = llvm::hash_combine(
- op->getName(), op->getMutableAttrDict().getDictionary());
+ llvm::hash_code hash =
+ llvm::hash_combine(op->getName(), op->getMutableAttrDict());
// - Result Types
ArrayRef<Type> resultTypes = op->getResultTypes();
diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
index 1d2235b61936..16e548f4430c 100644
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -383,9 +383,9 @@ static WalkResult walkSymbolRefs(
Operation *op,
function_ref<WalkResult(SymbolTable::SymbolUse, ArrayRef<int>)> callback) {
// Check to see if the operation has any attributes.
- DictionaryAttr attrDict = op->getMutableAttrDict().getDictionary();
- if (!attrDict)
+ if (op->getMutableAttrDict().empty())
return WalkResult::advance();
+ DictionaryAttr attrDict = op->getAttrDictionary();
// A worklist of a container attribute and the current index into the held
// attribute list.
@@ -799,7 +799,7 @@ replaceAllSymbolUsesImpl(SymbolT symbol, StringRef newSymbol, IRUnitT *limit) {
// Generate a new attribute dictionary for the current operation by replacing
// references to the old symbol.
auto generateNewAttrDict = [&] {
- auto oldDict = curOp->getMutableAttrDict().getDictionary();
+ auto oldDict = curOp->getAttrDictionary();
auto newDict = rebuildAttrAfterRAUW(oldDict, accessChains, /*depth=*/0);
return newDict.cast<DictionaryAttr>();
};
diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
index ef73ce3021f8..70dbc05592b7 100644
--- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp
+++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
@@ -24,11 +24,11 @@ namespace mlir {
LogicalResult mlir::detail::inferReturnTensorTypes(
function_ref<LogicalResult(
MLIRContext *, Optional<Location> location, ValueRange operands,
- ArrayRef<NamedAttribute> attributes, RegionRange regions,
+ DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &retComponents)>
componentTypeFn,
MLIRContext *context, Optional<Location> location, ValueRange operands,
- ArrayRef<NamedAttribute> attributes, RegionRange regions,
+ DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
SmallVector<ShapedTypeComponents, 2> retComponents;
if (failed(componentTypeFn(context, location, operands, attributes, regions,
@@ -49,9 +49,9 @@ LogicalResult mlir::detail::inferReturnTensorTypes(
LogicalResult mlir::detail::verifyInferredResultTypes(Operation *op) {
SmallVector<Type, 4> inferredReturnTypes;
auto retTypeFn = cast<InferTypeOpInterface>(op);
- if (failed(retTypeFn.inferReturnTypes(op->getContext(), op->getLoc(),
- op->getOperands(), op->getAttrs(),
- op->getRegions(), inferredReturnTypes)))
+ if (failed(retTypeFn.inferReturnTypes(
+ op->getContext(), op->getLoc(), op->getOperands(),
+ op->getAttrDictionary(), op->getRegions(), inferredReturnTypes)))
return failure();
if (!retTypeFn.isCompatibleReturnTypes(inferredReturnTypes,
op->getResultTypes()))
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index ebac3e214931..f5ad1b65f1a1 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -275,7 +275,7 @@ class Parser {
Attribute parseAttribute(Type type = {});
/// Parse an attribute dictionary.
- ParseResult parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes);
+ ParseResult parseAttributeDict(NamedAttrList &attributes);
/// Parse an extended attribute.
Attribute parseExtendedAttr(Type type);
@@ -1569,10 +1569,10 @@ Attribute Parser::parseAttribute(Type type) {
// Parse a dictionary attribute.
case Token::l_brace: {
- SmallVector<NamedAttribute, 4> elements;
+ NamedAttrList elements;
if (parseAttributeDict(elements))
return nullptr;
- return builder.getDictionaryAttr(elements);
+ return elements.getDictionary(getContext());
}
// Parse an extended attribute, i.e. alias or dialect attribute.
@@ -1671,8 +1671,7 @@ Attribute Parser::parseAttribute(Type type) {
/// | `{` attribute-entry (`,` attribute-entry)* `}`
/// attribute-entry ::= (bare-id | string-literal) `=` attribute-value
///
-ParseResult
-Parser::parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes) {
+ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
if (parseToken(Token::l_brace, "expected '{' in attribute dictionary"))
return failure();
@@ -1701,7 +1700,6 @@ Parser::parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes) {
auto attr = parseAttribute();
if (!attr)
return failure();
-
attributes.push_back({*nameId, attr});
return success();
};
@@ -4228,7 +4226,7 @@ class CustomOpAsmParser : public OpAsmParser {
/// also adds the attribute to the specified attribute list with the specified
/// name.
ParseResult parseAttribute(Attribute &result, Type type, StringRef attrName,
- SmallVectorImpl<NamedAttribute> &attrs) override {
+ NamedAttrList &attrs) override {
result = parser.parseAttribute(type);
if (!result)
return failure();
@@ -4238,8 +4236,7 @@ class CustomOpAsmParser : public OpAsmParser {
}
/// Parse a named dictionary into 'result' if it is present.
- ParseResult
- parseOptionalAttrDict(SmallVectorImpl<NamedAttribute> &result) override {
+ ParseResult parseOptionalAttrDict(NamedAttrList &result) override {
if (parser.getToken().isNot(Token::l_brace))
return success();
return parser.parseAttributeDict(result);
@@ -4247,8 +4244,7 @@ class CustomOpAsmParser : public OpAsmParser {
/// Parse a named dictionary into 'result' if the `attributes` keyword is
/// present.
- ParseResult parseOptionalAttrDictWithKeyword(
- SmallVectorImpl<NamedAttribute> &result) override {
+ ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result) override {
if (failed(parseOptionalKeyword("attributes")))
return success();
return parser.parseAttributeDict(result);
@@ -4296,9 +4292,8 @@ class CustomOpAsmParser : public OpAsmParser {
/// Parse an optional @-identifier and store it (without the '@' symbol) in a
/// string attribute named 'attrName'.
- ParseResult
- parseOptionalSymbolName(StringAttr &result, StringRef attrName,
- SmallVectorImpl<NamedAttribute> &attrs) override {
+ ParseResult parseOptionalSymbolName(StringAttr &result, StringRef attrName,
+ NamedAttrList &attrs) override {
Token atToken = parser.getToken();
if (atToken.isNot(Token::at_identifier))
return failure();
@@ -4446,7 +4441,7 @@ class CustomOpAsmParser : public OpAsmParser {
/// Parse an AffineMap of SSA ids.
ParseResult parseAffineMapOfSSAIds(SmallVectorImpl<OperandType> &operands,
Attribute &mapAttr, StringRef attrName,
- SmallVectorImpl<NamedAttribute> &attrs,
+ NamedAttrList &attrs,
Delimiter delimiter) override {
SmallVector<OperandType, 2> dimOperands;
SmallVector<OperandType, 1> symOperands;
diff --git a/mlir/lib/Pass/IRPrinting.cpp b/mlir/lib/Pass/IRPrinting.cpp
index 842b83c5d0f7..1e9760581fcc 100644
--- a/mlir/lib/Pass/IRPrinting.cpp
+++ b/mlir/lib/Pass/IRPrinting.cpp
@@ -33,9 +33,7 @@ class OperationFingerPrint {
// - Operation pointer
addDataToHash(hasher, op);
// - Attributes
- addDataToHash(
- hasher,
- op->getMutableAttrDict().getDictionary().getAsOpaquePointer());
+ addDataToHash(hasher, op->getMutableAttrDict());
// - Blocks in Regions
for (Region &region : op->getRegions()) {
for (Block &block : region) {
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 623bb0af4623..ac067e32e528 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -334,7 +334,7 @@ OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
LogicalResult mlir::OpWithInferTypeInterfaceOp::inferReturnTypes(
MLIRContext *, Optional<Location> location, ValueRange operands,
- ArrayRef<NamedAttribute> attributes, RegionRange regions,
+ DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
if (operands[0].getType() != operands[1].getType()) {
return emitOptionalError(location, "operand type mismatch ",
@@ -347,7 +347,7 @@ LogicalResult mlir::OpWithInferTypeInterfaceOp::inferReturnTypes(
LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
MLIRContext *context, Optional<Location> location, ValueRange operands,
- ArrayRef<NamedAttribute> attributes, RegionRange regions,
+ DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
// Create return type consisting of the last element of the first operand.
auto operandType = *operands.getTypes().begin();
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 7c91d5f6c682..87b1a697f8c9 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -97,9 +97,9 @@ static void invokeCreateWithInferredReturnType(Operation *op) {
for (int j = 0; j < e; ++j) {
std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}};
SmallVector<Type, 2> inferredReturnTypes;
- if (succeeded(OpTy::inferReturnTypes(context, llvm::None, values,
- op->getAttrs(), op->getRegions(),
- inferredReturnTypes))) {
+ if (succeeded(OpTy::inferReturnTypes(
+ context, llvm::None, values, op->getAttrDictionary(),
+ op->getRegions(), inferredReturnTypes))) {
OperationState state(location, OpTy::getOperationName());
// TODO(jpienaar): Expand to regions.
OpTy::build(b, state, values, op->getAttrs());
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 29df09b551b8..fc66313c323a 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -775,7 +775,8 @@ void OpEmitter::genSeparateArgParamBuilder() {
body << formatv(R"(
SmallVector<Type, 2> inferredReturnTypes;
if (succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
- {1}.location, {1}.operands, {1}.attributes,
+ {1}.location, {1}.operands,
+ {1}.attributes.getDictionary({1}.getContext()),
/*regions=*/{{}, inferredReturnTypes)))
{1}.addTypes(inferredReturnTypes);
else
@@ -867,13 +868,48 @@ void OpEmitter::genInferredTypeCollectiveParamBuilder() {
opClass.newMethod("void", "build", formatv(params, builderOpState).str(),
OpMethod::MP_Static);
auto &body = m.body();
+
+ int numResults = op.getNumResults();
+ int numVariadicResults = op.getNumVariableLengthResults();
+ int numNonVariadicResults = numResults - numVariadicResults;
+
+ int numOperands = op.getNumOperands();
+ int numVariadicOperands = op.getNumVariableLengthOperands();
+ int numNonVariadicOperands = numOperands - numVariadicOperands;
+
+ // Operands
+ if (numVariadicOperands == 0 || numNonVariadicOperands != 0)
+ body << " assert(operands.size()"
+ << (numVariadicOperands != 0 ? " >= " : " == ")
+ << numNonVariadicOperands
+ << "u && \"mismatched number of parameters\");\n";
+ body << " " << builderOpState << ".addOperands(operands);\n";
+ body << " " << builderOpState << ".addAttributes(attributes);\n";
+
+ // Create the correct number of regions
+ if (int numRegions = op.getNumRegions()) {
+ body << llvm::formatv(
+ " for (unsigned i = 0; i != {0}; ++i)\n",
+ (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
+ body << " (void)" << builderOpState << ".addRegion();\n";
+ }
+
+ // Result types
body << formatv(R"(
SmallVector<Type, 2> inferredReturnTypes;
if (succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
- {1}.location, operands, attributes,
- /*regions=*/{{}, inferredReturnTypes)))
- build(odsBuilder, odsState, inferredReturnTypes, operands, attributes);
- else
+ {1}.location, operands,
+ {1}.attributes.getDictionary({1}.getContext()),
+ /*regions=*/{{}, inferredReturnTypes))) {{)",
+ opClass.getClassName(), builderOpState);
+ if (numVariadicResults == 0 || numNonVariadicResults != 0)
+ body << " assert(inferredReturnTypes.size()"
+ << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
+ << "u && \"mismatched number of return types\");\n";
+ body << " " << builderOpState << ".addTypes(inferredReturnTypes);";
+
+ body << formatv(R"(
+ } else
llvm::report_fatal_error("Failed to infer result type(s).");)",
opClass.getClassName(), builderOpState);
}
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 47c145df29e0..127b6b976cd5 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -372,7 +372,7 @@ const char *const attrParserCode = R"(
const char *const enumAttrParserCode = R"(
{
StringAttr attrVal;
- SmallVector<NamedAttribute, 1> attrStorage;
+ NamedAttrList attrStorage;
auto loc = parser.getCurrentLocation();
if (parser.parseAttribute(attrVal, parser.getBuilder().getNoneType(),
"{0}", attrStorage))
More information about the flang-commits
mailing list